# Variational autoencoders

__Objective:__ define a variational autoencoder for images.

**Idea:** in an autoencoder, the encoder maps samples to points in latent space. In a variational autoencoder, it maps samples to **multivariate Gaussian distributions** on latent space. This helps reconstruct similar samples from nearby points in latent space, because the decoder must now minimize the reconstruction error for all the points sampled from the distribution corresponding to the same input sample.

### Ingredients

#### Encoder

The encoder part of the model is modified to output the parameters of a multivariate Gaussian on latent space with a diagonal covariance matrix. In practice, given the input sample $x$, the encoder outputs two vectors $\mu(x), \sigma(x) \in \mathbb{R}^d$, where $d$ is the dimension of latent space, parametrizing a distribution $\mathcal{N}(\mu(x), \Sigma(x))$, where $\Sigma(x) = \mathrm{diag}(\sigma^2_1(x), \ldots, \sigma^2_d(x))$.
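As a minimal sketch of what these outputs look like, the snippet below uses a single linear map in NumPy as a stand-in for the encoder. The names `toy_encoder`, `W_mu` and `W_log_var` are hypothetical, and predicting $\log\sigma^2$ instead of $\sigma$ is an assumption made for illustration (it mirrors the `z_log_var` name used in the code cells of this notebook and keeps the network output unconstrained). It is not the `VariationalEncoder` imported later.

```python
import numpy as np

rng = np.random.default_rng(0)

d = 2            # dimension of latent space (illustrative choice)
n_features = 16  # size of a flattened toy input (illustrative choice)

# Hypothetical weights of two linear heads reading the same input.
W_mu = rng.normal(scale=0.1, size=(n_features, d))
W_log_var = rng.normal(scale=0.1, size=(n_features, d))


def toy_encoder(x):
    """Map a batch of flattened inputs to (mu, log_var), each of shape (batch, d)."""
    mu = x @ W_mu
    log_var = x @ W_log_var  # log(sigma^2), one value per latent dimension
    return mu, log_var


x = rng.normal(size=(4, n_features))  # batch of 4 toy inputs
mu, log_var = toy_encoder(x)
sigma = np.exp(0.5 * log_var)         # standard deviations, always positive

# Diagonal covariance matrix Sigma(x) of the first sample in the batch.
Sigma_0 = np.diag(sigma[0] ** 2)
```

Each input sample gets its own pair $(\mu(x), \sigma(x))$, so a batch of 4 inputs yields 4 different Gaussians on latent space.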

#### Decoder

Latent vectors $z\in \mathbb{R}^d$ are obtained by sampling from the distributions on latent space. Given a latent vector, the decoder produces a realistic sample, as similar as possible to the input corresponding to the Gaussian distribution that generated $z$. The architecture of the decoder remains the same as in regular autoencoders.

#### Loss function

Compared to the usual reconstruction loss (MSE or categorical cross-entropy), the loss function to minimize has an additional term: the KL divergence between the Gaussian distribution on latent space corresponding to each sample and a (multivariate) standard normal distribution,
$$
\mathrm{KL}\left[ \mathcal{N}(\mu(x), \Sigma(x)) || \mathcal{N}(0, \mathbf{1}) \right]\,.
$$
This comes from assuming a multivariate standard normal prior on latent space, a Gaussian likelihood and an approximate variational posterior given by the multivariate Gaussian output by the encoder. Minimizing the KL divergence between the variational posterior and the true posterior (proportional to the product of likelihood and prior) then amounts to minimizing the sum of the reconstruction error and the KL term above. The reparametrization trick described below makes this loss differentiable w.r.t. the encoder parameters.
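The connection between this loss and the true posterior can be made explicit with the standard decomposition of the log-evidence. The shorthand used here is introduced for convenience only: $q(z|x) = \mathcal{N}(\mu(x), \Sigma(x))$ is the variational posterior, $p(z) = \mathcal{N}(0, \mathbf{1})$ the prior, $p(x|z)$ the likelihood defined by the decoder and $p(z|x)$ the true posterior.
$$
\log p(x) = \mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right] - \mathrm{KL}\left[ q(z|x) || p(z) \right] + \mathrm{KL}\left[ q(z|x) || p(z|x) \right]\,.
$$
The first two terms on the right-hand side form the evidence lower bound (ELBO). For a fixed decoder, $\log p(x)$ does not depend on the encoder, so minimizing the negative ELBO over the encoder parameters also minimizes the KL divergence to the true posterior. Under the assumption of a Gaussian likelihood with fixed variance, $-\mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right]$ reduces to a squared reconstruction error up to a scale factor and an additive constant, which is where the MSE term comes from.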

The KL divergence above can be computed analytically, so given $\mu(x)$ and $\sigma(x)$ it's easy to compute the exact contribution to the total loss:
$$
\begin{array}{lll}
\mathrm{KL}\left[ \mathcal{N}(\mu(x), \Sigma(x)) || \mathcal{N}(0, \mathbf{1}) \right] &\equiv& -\int \mathrm{d}^d z\, \mathcal{N}(z | \mu(x), \Sigma(x))\,\log\left( \frac{\mathcal{N}(z | 0, \mathbf{1})}{\mathcal{N}(z | \mu(x), \Sigma(x))} \right) \\
&=& -\frac{1}{2} \sum_{j=1}^d \left( 1 + \log(\sigma^2_j) - \mu_j^2 - \sigma_j^2 \right)\,.
\end{array}
$$
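The closed-form expression can be checked numerically. The sketch below, written in plain NumPy with hypothetical helper names (`kl_to_standard_normal`, `log_normal_pdf`) and arbitrary values for $\mu$ and $\sigma$, compares it with a Monte Carlo estimate of the integral in the definition.

```python
import numpy as np


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL[N(mu, diag(exp(log_var))) || N(0, 1)], summed over latent dims."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)


def log_normal_pdf(z, mu, log_var):
    """Log-density of a diagonal Gaussian, summed over latent dims."""
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var), axis=-1
    )


rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])                   # arbitrary example values
log_var = np.log(np.array([0.8, 1.5]) ** 2)  # log(sigma^2) for sigma = (0.8, 1.5)

# Monte Carlo estimate: average of log q(z) - log p(z) over samples z ~ q.
z = mu + np.exp(0.5 * log_var) * rng.standard_normal(size=(200_000, 2))
kl_mc = np.mean(
    log_normal_pdf(z, mu, log_var) - log_normal_pdf(z, np.zeros(2), np.zeros(2))
)

kl_exact = kl_to_standard_normal(mu, log_var)
print(f"closed form: {kl_exact:.4f}, Monte Carlo: {kl_mc:.4f}")
```

The two numbers should agree up to sampling noise, and the closed form vanishes when $\mu = 0$ and $\sigma = 1$, i.e. when the encoder distribution coincides with the prior.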

In the $\beta$-VAE variant of the model it's possible to tune the relative weight of the reconstruction and KL terms in the loss function via a coefficient $\beta$,
$$
\mathcal{L} = \mathrm{MSE} + \beta\,\mathrm{KL}\,.
$$
$\beta$ is a hyperparameter controlling the balance between the two terms in the loss: if $\beta$ is too small, the KL term has little effect (latent vectors are more spread out in latent space, farther from the origin and grouped in discontinuous clusters), while if $\beta$ is too big, the KL term prevails and the model has poor reconstruction power (essentially, the Gaussians end up fitting the standard normal prior).
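As an illustration of how $\beta$ enters the computation, here is a small NumPy sketch of the combined loss on toy data. The function name `beta_vae_loss`, the array shapes and the random "reconstructions" are all assumptions made for the example, and the KL term assumes the encoder outputs $\log\sigma^2$.

```python
import numpy as np


def beta_vae_loss(x, x_reconstructed, mu, log_var, beta=1.0):
    """Return (total, mse, kl), each averaged over the batch."""
    # Reconstruction term: mean squared error over the features of each sample.
    mse = np.mean((x - x_reconstructed) ** 2, axis=-1)
    # KL term: closed form for a diagonal Gaussian against a standard normal.
    kl = -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)
    return np.mean(mse + beta * kl), np.mean(mse), np.mean(kl)


rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 16))                          # toy batch of flattened inputs
x_reconstructed = x + 0.1 * rng.normal(size=x.shape)   # toy "decoder output"
mu = rng.normal(size=(8, 2))                           # toy encoder means
log_var = rng.normal(scale=0.5, size=(8, 2))           # toy encoder log-variances

for beta in (0.0, 1.0, 10.0):
    total, mse, kl = beta_vae_loss(x, x_reconstructed, mu, log_var, beta=beta)
    print(f"beta={beta:5.1f}  MSE={mse:.4f}  KL={kl:.4f}  total={total:.4f}")
```

With $\beta = 0$ the total reduces to the plain reconstruction loss, while for large $\beta$ the KL term dominates the total and therefore the gradients.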

#### Reparametrization trick

Given an input sample, the prediction has a random component coming from sampling the Gaussian distribution that the encoder associates with that input. Backpropagation would require "differentiating the sampling" w.r.t. the parameters of the Gaussian distribution, which is not possible: once drawn, a sample is just a numerical value and all the information about the distribution that generated it is lost. Nonetheless, it's possible to use a reparametrization of the Gaussian distribution that allows for explicit differentiation w.r.t. the $\mu(x)$ and $\sigma(x)$ parameters, the **reparametrization trick**.

Given the input sample $x$, the encoder outputs the parameters $\mu(x)$ and $\sigma(x)$ of the multivariate Gaussian $\mathcal{N}(\mu(x), \Sigma(x))$, from which the latent vector $z$ is sampled,
$$
z \sim \mathcal{N}(\mu(x), \Sigma(x))\,.
$$
The reparametrization trick consists in sampling $z$ in the equivalent way
$$
z = \mu(x) + \sigma(x) \odot \epsilon\,,
$$
where $\epsilon \sim \mathcal{N}(0, \mathbf{1})$ and $\odot$ denotes the elementwise product. This way the generated values for $z$ follow exactly the same distribution as before, but the parameters $\mu$ and $\sigma$ appear explicitly and differentiation w.r.t. them is possible.
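The equivalence of the two sampling procedures, and the fact that the reparametrized version is differentiable, can be checked with a short NumPy sketch. The values of `mu` and `sigma` are arbitrary, and the finite-difference check is only an illustration of the derivatives $\partial z_j / \partial \mu_j = 1$ and $\partial z_j / \partial \sigma_j = \epsilon_j$; in the actual model the gradients are computed by automatic differentiation.

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([1.0, -2.0])    # arbitrary example means
sigma = np.array([0.5, 2.0])  # arbitrary example standard deviations
n_samples = 100_000

# Direct sampling from N(mu, diag(sigma^2)).
z_direct = rng.normal(loc=mu, scale=sigma, size=(n_samples, 2))

# Reparametrized sampling: the noise epsilon does not depend on mu or sigma.
epsilon = rng.standard_normal(size=(n_samples, 2))
z_reparam = mu + sigma * epsilon

print("direct : mean", z_direct.mean(axis=0), "std", z_direct.std(axis=0))
print("reparam: mean", z_reparam.mean(axis=0), "std", z_reparam.std(axis=0))

# For fixed epsilon, z is a deterministic function of (mu, sigma), so
# dz_j/dmu_j = 1 and dz_j/dsigma_j = epsilon_j. Finite-difference check:
h = 1e-6
dz_dmu = ((mu + h) + sigma * epsilon - z_reparam) / h
dz_dsigma = (mu + (sigma + h) * epsilon - z_reparam) / h
```

Both sets of samples should show empirical means close to `mu` and standard deviations close to `sigma`, while only the second construction exposes how $z$ changes when the parameters change.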

In [None]:
import sys
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../../modules/')

from variational_autoencoders import VariationalEncoder, SampleLayer

tfd = tfp.distributions

sns.set_theme()

%load_ext autoreload
%autoreload 2

## Model definition and training

In [None]:
variational_encoder = VariationalEncoder()
sample_layer = SampleLayer()

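n_samples = 5000

# Sanity check: pass random single-channel 32x32 inputs through the encoder.
random_inputs = tf.random.normal(shape=(n_samples, 32, 32, 1))

# The encoder returns the means, the log-variances and the sampled latent
# vectors for the batch.
z_mean, z_log_var, z_samples = variational_encoder(random_inputs)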

In [None]:
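class VAE(tf.keras.Model):
    """
    Variational autoencoder combining a variational encoder and a decoder.
    """
    def __init__(self, variational_encoder, decoder):
        """
        Store the two sub-models and create the metrics tracking the
        total, reconstruction and KL losses.
        """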
        super().__init__()

        self.variational_encoder = variational_encoder
        self.decoder = decoder

        self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name='reconstruction_loss'
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker
        ]

    def call(self, x):
        z_mean, z_log_var, z_samples = self.variational_encoder(x)

        reconstructed_samples = self.decoder(z_samples)

        return z_mean, z_log_var, reconstructed_samples

    def train_step(self, x):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstructed_samples = self(x)

            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(
                    x,
                    reconstructed_samples,
                    axis=(1, 2, 3)
                )
            )

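            # Closed-form KL divergence from the standard normal prior,
            # summed over latent dimensions and averaged over the batch
            # (z_log_var is assumed to hold log(sigma^2)).
            kl_loss = tf.reduce_mean(tf.reduce_sum(
                -0.5 * (
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
                ),
                axis=1
            ))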