# Variational autoencoders with TensorFlow Probability

__Objective:__ build a VAE using TFP.

**Idea:** in variational autoencoders (VAE), an encoder model maps samples $x$ to latent vectors $z$ and a decoder model maps latent vectors $z$ back to samples $x$, just as for regular autoencoders. Unlike regular autoencoders, though, the model is generative: encoder and decoder output probability distributions rather than single points, and both the modeling assumptions and the optimization are formulated in terms of those distributions.

Let $x$ denote a sample and $z\in\mathbb{R}^d$ be a latent vector ($d$ is the dimension of the latent space). VAEs look for a probabilistic relation between $z$ and $x$, i.e. given $x$ we assume that the corresponding $z$ is not deterministic, but rather given by a probability distribution
$$
p(z|x),
$$
in which the sample is the variable we condition on.

On the other hand, given a latent vector $z$, we assume that the corresponding sample $x$ is also given by a distribution,
$$
p(x|z).
$$

These two distributions are unknown: the model will try to approximate them. In particular, the **encoder** network is responsible for approximating $p(z|x)$ (as it's the distribution via which latent vectors are encoded, given a sample), while the **decoder** network is responsible for approximating $p(x|z)$ (as it's the distribution via which samples are reconstructed given a latent vector).

#### Encoder

The encoder network approximates $p(z|x)$ using **variational inference**: we use a **variational distribution**
$$
q_\phi(z|x),
$$
where $\phi$ is the set of parameters the distribution depends upon, and approximate $p(z|x)$ by minimizing the difference between $q_\phi(z|x)$ and $p(z|x)$ w.r.t. the parameters.

The **Kullback-Leibler** divergence between $q_\phi(z|x)$ and $p(z|x)$ is taken as a measure of the difference between them,
$$
\begin{array}{lll}
D_\mathrm{KL} \left[ q_\phi(z|x) || p(z|x) \right] &\equiv& - \int \mathrm{d}^d z \, q_\phi(z|x) \log \left( \frac{p(z|x)}{q_\phi(z|x)} \right) \\
&=& \mathbb{E}_{z\sim q_\phi} \left[ \log(q_\phi(z|x)) - \log(p(z|x)) \right] 
\end{array}
$$

**Note:** the KL divergence is **not** a symmetric quantity (it's not a metric on the space of distributions), so the order in which distributions are taken changes the result.
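
The asymmetry is easy to see numerically. The snippet below is a minimal sketch using the standard closed-form KL-divergence between two scalar Gaussians; the helper name `kl_gauss_1d` and the chosen parameter values are illustrative assumptions, not part of the notebook's model.

```python
import math


def kl_gauss_1d(mu_a, sigma_a, mu_b, sigma_b):
    """Closed-form KL[N(mu_a, sigma_a^2) || N(mu_b, sigma_b^2)] for scalar Gaussians."""
    return (
        math.log(sigma_b / sigma_a)
        + (sigma_a**2 + (mu_a - mu_b)**2) / (2.0 * sigma_b**2)
        - 0.5
    )


# Same pair of distributions, taken in the two possible orders.
kl_ab = kl_gauss_1d(0.0, 1.0, 1.0, 2.0)
kl_ba = kl_gauss_1d(1.0, 2.0, 0.0, 1.0)

print(f'KL[a || b] = {kl_ab:.4f}')  # approximately 0.4431
print(f'KL[b || a] = {kl_ba:.4f}')  # approximately 1.3069
```

Swapping the arguments changes the value, while the divergence of a distribution from itself is zero.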

Because $p(z|x)$ is not known, the above quantity can't be computed as it is. Something else is needed.

#### Decoder

We assume the decoder network gives the distribution of samples given latent vectors, depending on a set of parameters (the parameters of the network) $\theta$,
$$
p_\theta(x|z)\,.
$$

#### Training

If we followed the logic of normal autoencoders, the model should be trained to give the best possible reconstruction of samples. Schematically, as in usual autoencoders, given a sample, a corresponding latent vector is produced and then an output sample is reconstructed by the decoder: we want the output sample to be as similar as possible to the one we started with. VAEs add sampling as an additional ingredient: given the input sample $x$, we generate the corresponding latent vector $z$ by sampling $q_\phi(z|x)$ (the true conditional distribution is inaccessible, so we can only sample the variational one!), and once we have $z$, the output sample is generated by sampling $p_\theta(x|z)$. Optimization then happens by minimizing a "reconstruction loss" w.r.t. the parameters involved in the generation of the output samples, i.e. those of both the encoder and the decoder, $\phi$ and $\theta$.

In fact, VAE follow a slightly different logic that ends up including the above one. Instead of a reconstruction loss, the starting point is the minimization of $D_\mathrm{KL} \left[ q_\phi(z|x) || p(z|x) \right]$. Using Bayes' theorem, we can rewrite $p(z|x)$ precisely in terms of the distribution given by the decoder,
$$
p(z|x) = \frac{p_\theta(x|z)\,p(z)}{p(x)}\,,
$$
where $p(z)$ is a prior distribution on latent vectors and $p(x)$ is the evidence, i.e. the marginal likelihood of the data obtained by integrating $p_\theta(x|z)\,p(z)$ over $z$. Substituting this in the expression for the KL-divergence gives
$$
\begin{array}{lll}
D_\mathrm{KL} \left[ q_\phi(z|x) || p(z|x) \right] &=& \mathbb{E}_{z\sim q_\phi} \left[ \log(q_\phi(z|x)) - \log(p_\theta(x|z)) - \log(p(z)) + \log(p(x)) \right] \\
&=& D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right] - \mathbb{E}_{z\sim q_\phi} \left[ \log(p_\theta(x|z)) \right] + \log(p(x))\,,
\end{array}
$$
where we recognized that the terms with $q_\phi(z|x)$ and $p(z)$ reconstruct a KL-divergence and that $p(x)$ doesn't depend on $z$ and therefore the expectation value acts trivially on it.

The KL-divergence between any two distributions is always non-negative and in particular is zero if and only if the two distributions are the same, therefore the final expression is always non-negative and we can rearrange it as
$$
\log(p(x)) \geq \mathbb{E}_{z\sim q_\phi} \left[ \log(p_\theta(x|z)) \right] - D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right].
$$
The RHS is a lower bound for the (log) evidence and is therefore known as the **evidence lower bound (ELBO)**. We want to find values for the parameters that give maximal evidence, therefore our training task translates into the maximization of the ELBO w.r.t. $\phi$ and $\theta$.
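
To make the rearrangement explicit, it can help to keep the (intractable) KL-divergence to the true posterior on the right-hand side instead of dropping it. Moving terms across the identity derived above gives
$$
\log(p(x)) = \underbrace{\mathbb{E}_{z\sim q_\phi} \left[ \log(p_\theta(x|z)) \right] - D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right]}_{\mathrm{ELBO}} + D_\mathrm{KL} \left[ q_\phi(z|x) || p(z|x) \right],
$$
so the gap between the log evidence and the ELBO is exactly the divergence between the variational distribution and the true posterior. The bound is tight only when $q_\phi(z|x)$ matches $p(z|x)$, and at fixed $\theta$ raising the ELBO w.r.t. $\phi$ is the same as lowering that divergence.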

This maximization problem can be restated as the minimization of a loss function given by the negative ELBO,
$$
\mathcal{L} = - \mathbb{E}_{z\sim q_\phi} \left[ \log(p_\theta(x|z)) \right] + D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right] \equiv \mathcal{L}_R + \mathcal{L}_\mathrm{KL},
$$
where
- $\mathcal{L}_R \equiv - \mathbb{E}_{z\sim q_\phi} \left[ \log(p_\theta(x|z)) \right]$ is the expected negative log-likelihood, which here plays the role of the usual **reconstruction loss** (the loss term expected by analogy with regular autoencoders). Since the expectation value is taken with $z$ drawn from $q_\phi(z|x)$, $z$ itself depends on the input sample $x$, so this term effectively compares (a function of) the input samples with the output ones.
- $\mathcal{L}_\mathrm{KL} \equiv D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right]$ is a new term specific to VAEs and quantifies the difference between the encoding distribution and a prior distribution (to be chosen) on latent space. This acts as a regularization term for the encoding distribution.
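
To see concretely why the negative log-likelihood behaves like a reconstruction loss, consider the special case of a Gaussian $p_\theta(x|z)$ with identity covariance (an assumption made here for illustration; it matches the unit scale used by the decoder defined later in this notebook). The sketch below, with the hypothetical helper `gaussian_nll_unit_variance`, checks that the negative log-likelihood is then half the squared reconstruction error plus a constant that doesn't depend on the reconstruction.

```python
import numpy as np


def gaussian_nll_unit_variance(x, mean):
    """Negative log-density of N(x | mean, I), summed over all components of x."""
    x = np.asarray(x, dtype=float).ravel()
    mean = np.asarray(mean, dtype=float).ravel()

    return 0.5 * np.sum((x - mean) ** 2) + 0.5 * x.size * np.log(2.0 * np.pi)


# Toy "input sample" and "reconstruction" (the mean output by a decoder).
x = np.array([0.2, 0.7, 0.1])
reconstruction = np.array([0.0, 1.0, 0.0])

nll = gaussian_nll_unit_variance(x, reconstruction)
half_sse = 0.5 * np.sum((x - reconstruction) ** 2)

# The difference is (D / 2) * log(2 * pi), independent of the reconstruction.
constant = nll - half_sse

print(nll, half_sse, constant)
```

Under this assumption, minimizing $\mathcal{L}_R$ amounts to minimizing a squared error between input and reconstructed mean, as in a regular autoencoder.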

Both terms are in fact expectation values over the distribution $q_\phi(z|x)$, a distribution conditioned on $x$, for which we only have a finite number of (training) samples.

#### Modeling assumptions

We need to make assumptions on the distributions appearing in the loss function, namely $q_\phi(z|x)$ (encoder), $p_\theta(x|z)$ (decoder) and $p(z)$ (prior on latent space). The VAE algorithm uses the following assumptions:
- $q_\phi(z|x)$ is a multivariate Gaussian on latent space with diagonal covariance matrix and parameters (mean vector $\mu(x)\in\mathbb{R}^d$ and standard deviations $\sigma(x)\in\mathbb{R}^d$, whose squares are the diagonal elements of the covariance matrix) parametrized by the encoder network as functions of the input samples $x$:
$$
q_\phi(z|x) = \mathcal{N}\left( z | \mu(x), \mathrm{diag}\left( \sigma_1^2(x), \ldots, \sigma_d^2(x) \right) \right).
$$
- $p_\theta(x|z)$ is another multivariate Gaussian with diagonal covariance matrix and parameters (mean vector $\tilde{\mu}\in\mathbb{R}^D$ and standard deviations $\tilde{\sigma}\in\mathbb{R}^D$, whose squares are the diagonal elements of the covariance matrix, where $D$ is the dimension of the space of samples) parametrized by the decoder network as functions of the latent vector $z$:
$$
p_\theta(x|z) = \mathcal{N}\left( x | \tilde{\mu}(z), \mathrm{diag}\left(  \tilde{\sigma}_1^2(z), \ldots,  \tilde{\sigma}_D^2(z) \right) \right).
$$
- $p(z)$ is a multivariate Gaussian with zero mean and unit variance on latent space:
$$
p(z) = \mathcal{N}\left( z | 0, I \right).
$$

With these assumptions the KL loss term can be computed analytically,
$$
\mathcal{L}_\mathrm{KL} = D_\mathrm{KL} \left[ q_\phi(z|x) || p(z) \right] = -\frac{1}{2} \sum_{j=1}^d \left[ 1 + \log(\sigma_j^2(x)) - \mu_j^2(x) - \sigma_j^2(x) \right].
$$
Notice that this is a value for a **single input sample** $x$.
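
As a sanity check of the closed-form expression, the sketch below compares it with a Monte Carlo estimate of $\mathbb{E}_{z\sim q_\phi} \left[ \log(q_\phi(z|x)) - \log(p(z)) \right]$ for one arbitrary choice of $\mu$ and $\sigma$. The function names, the parameter values and the number of samples are illustrative assumptions.

```python
import numpy as np


def kl_diag_gaussian_vs_standard_normal(mu, sigma):
    """Closed-form KL[N(mu, diag(sigma^2)) || N(0, I)] for a single input sample."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    return -0.5 * np.sum(1.0 + np.log(sigma**2) - mu**2 - sigma**2)


def log_density_diag_gaussian(z, mu, sigma):
    """Log-density of N(mu, diag(sigma^2)) evaluated at each row of z."""
    return np.sum(
        -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((z - mu) / sigma) ** 2,
        axis=-1
    )


rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0])
sigma = np.array([0.8, 1.5])

# Draw latent vectors from the variational distribution.
z = mu + sigma * rng.standard_normal(size=(200_000, 2))

# Monte Carlo estimate of E_q[log q(z) - log p(z)].
kl_monte_carlo = np.mean(
    log_density_diag_gaussian(z, mu, sigma)
    - log_density_diag_gaussian(z, np.zeros(2), np.ones(2))
)
kl_closed_form = kl_diag_gaussian_vs_standard_normal(mu, sigma)

print(kl_closed_form, kl_monte_carlo)
```

The two numbers should agree up to sampling noise, and the closed form vanishes when $\mu = 0$ and $\sigma = 1$, i.e. when the variational distribution coincides with the prior.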

#### Reparametrization trick
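
In its standard form, which the custom sampling layer later in this notebook also follows, the trick writes a sample from $q_\phi(z|x)$ as $z = \mu(x) + \sigma(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The randomness is isolated in $\epsilon$, so $z$ is a deterministic, differentiable function of $\mu$ and $\sigma$. The following is a minimal NumPy sketch of the sampling step only (no gradients involved); the function name and the values used are illustrative assumptions.

```python
import numpy as np


def sample_with_reparametrization(mu, sigma, rng):
    """Draw z = mu + sigma * eps, with eps standard Normal noise of the same shape as mu."""
    eps = rng.standard_normal(size=np.shape(mu))

    return mu + sigma * eps


rng = np.random.default_rng(0)

mu = np.array([1.0, -2.0])
sigma = np.array([0.5, 2.0])

# Repeat the same parameters along a batch axis to draw many samples at once.
batch_mu = np.tile(mu, (100_000, 1))
batch_sigma = np.tile(sigma, (100_000, 1))

z = sample_with_reparametrization(batch_mu, batch_sigma, rng)

# Empirical moments should be close to mu and sigma.
print(z.mean(axis=0), z.std(axis=0))
```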

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Dense, Conv2DTranspose, Reshape
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

sns.set_theme()

## Get data

In [None]:
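def preprocess_images(img):
    """
    Rescale pixel values to [0, 1], zero-pad each image by 2 pixels
    per side and add a trailing channel dimension.
    """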
    # Normalize pixel values.
    img = img.astype('float32') / 255.

    # Add padding.
    img = np.pad(img, ((0, 0), (2, 2), (2, 2)), constant_values=0.)
    
    # The images come in grayscale without an explicit
    # channel dimension. Here we add it.
    img = np.expand_dims(img, -1)

    return img

In [None]:
# Note: we don't really care about the labels in the y arrays.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train = preprocess_images(x_train)
x_test = preprocess_images(x_test)

## Build and train the model

In [None]:
latent_space_dim = 16

encoder_tfp = tf.keras.Sequential([
    Conv2D(filters=4, kernel_size=(3, 3), strides=2, activation='relu'),
    Conv2D(filters=16, kernel_size=(3, 3), strides=2, activation='relu'),
    Flatten(),
    Dense(units=2 * latent_space_dim),
    tfp.layers.DistributionLambda(
        lambda t: tfd.MultivariateNormalDiag(
            loc=t[..., :latent_space_dim],
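            # The Dense output is unconstrained, while scales must be
            # positive: map it through a softplus.
            scale_diag=tf.nn.softplus(t[..., latent_space_dim:])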
        )
    )
])

decoder_tfp = tf.keras.Sequential([
    Dense(units=64, activation='relu'),
    Reshape(target_shape=(8, 8, 1)),
    Conv2DTranspose(filters=16, kernel_size=(3, 3), strides=2, output_padding=0, activation='relu'),
    Conv2DTranspose(filters=8, kernel_size=(3, 3), strides=2, output_padding=0, activation='relu'),
    Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='valid'),
    tfp.layers.DistributionLambda(
        lambda t: tfd.Independent(
            tfd.MultivariateNormalDiag(loc=t, scale_diag=tf.ones_like(t)),
            reinterpreted_batch_ndims=2
        )
    )
])
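
The layer hyperparameters above are chosen so that the decoder output has the same spatial size as the padded input images. A quick way to check this is to track the side length through each layer. The sketch below uses the usual output-size formulas for `'valid'` padding with no dilation (stated here as an assumption), and the helper names are hypothetical.

```python
def conv_out(size, kernel, stride):
    """Output length of a 'valid' convolution along one spatial axis."""
    return (size - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, output_padding=0):
    """Output length of a 'valid' transposed convolution along one spatial axis."""
    return (size - 1) * stride + kernel + output_padding


# Encoder: two strided convolutions on 32x32 (padded) images.
encoder_sizes = [32]
for kernel, stride in [(3, 2), (3, 2)]:
    encoder_sizes.append(conv_out(encoder_sizes[-1], kernel, stride))

# Decoder: two transposed convolutions on an 8x8 map, then a 4x4 convolution.
decoder_sizes = [8]
for kernel, stride in [(3, 2), (3, 2)]:
    decoder_sizes.append(conv_transpose_out(decoder_sizes[-1], kernel, stride))
decoder_sizes.append(conv_out(decoder_sizes[-1], kernel=4, stride=1))

print(encoder_sizes)  # [32, 15, 7]
print(decoder_sizes)  # [8, 17, 35, 32]
```

The final decoder size (32) matches the input size, which is what allows the log-probability of the input images to be evaluated under the reconstruction distribution.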

In [None]:
n_samples = 6

input_samples = x_train[:n_samples, ...]

# Intermediate between encoder and decoder: sample the variational distribution
# (outputted by the encoder), map the samples to the corresponding reconstruction
# distribution (outputted by the decoder) and then sample the latter.
reconstructed_samples = decoder_tfp(
    encoder_tfp(input_samples).sample()
).sample()

fig, axs = plt.subplots(nrows=2, ncols=n_samples, figsize=(14, 6))

for i in range(2):
    for j in range(n_samples):
        axs[i][j].imshow(
            input_samples[j, ...] if i == 0 else reconstructed_samples[j, ...].numpy(),
            cmap='gray'
        )
    
        axs[i][j].grid(False)

In [None]:
class SampleWithReparametrizationLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

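    def call(self, gaussians):
        # Parameters of the (diagonal Gaussian) variational distribution
        # output by the encoder, each of shape (batch_size, latent_dim).
        means = gaussians.mean()
        stddevs = gaussians.stddev()

        batch_size = tf.shape(means)[0]
        event_shape = tf.shape(means)[1:]

        # Standard Normal noise, one draw per sample in the batch.
        noise = tfd.MultivariateNormalDiag(
            loc=tf.zeros(shape=event_shape),
            scale_diag=tf.ones(shape=event_shape)
        ).sample(batch_size)

        # z = mu + sigma * eps: the randomness lives in the noise only,
        # so gradients can flow through the means and standard deviations.
        return means + stddevs * noise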

In [None]:
sample_layer = SampleWithReparametrizationLayer()

sample_layer(encoder_tfp(input_samples))

In [None]:
vae_model = tf.keras.Sequential([
    encoder_tfp,
    sample_layer,
    decoder_tfp
])

vae_model(input_samples).sample().shape

In [None]:
def reconstruction_loss(distr, input_samples, n_intermediate_samples=1):
    """
    Negative log-likelihood of the input samples under the
    reconstruction distribution `distr` (one value per sample).
    """
    if n_intermediate_samples == 1:
        loss = - distr.log_prob(input_samples)
    elif n_intermediate_samples > 1:
        # Unfortunately if we sampled the distribution on latent space
        # (outputted by the encoder) multiple times, we'd get a higher-rank
        # batch size, which the `Reshape` layer in the decoder cannot
        # deal with out of the box.
        raise NotImplementedError()
    else:
        raise ValueError('The number of intermediate samples should be positive')
    
    return loss


def kl_loss(variational_distr):
    # KL-divergence between the variational distribution and
    # a multivariate standard Normal distribution on the latent
    # space.
    # Q: how to treat input samples?
    loss = tfd.kl_divergence(
        # Output of the encoder (given some input samples).
        variational_distr,
        # Multivariate Gaussian on latent space.
        tfd.MultivariateNormalDiag(
            loc=tf.zeros(shape=(variational_distr.batch_shape[0], latent_space_dim)),
            scale_diag=tf.ones(shape=(variational_distr.batch_shape[0], latent_space_dim))
        )
    )

    return loss

In [None]:
learning_rate = 1e-2

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

epochs_counter = 0
training_history = {
    'total_loss': []
}

In [None]:
epochs = 10
batch_size = 128

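for i in range(epochs):
    # Note: each "epoch" here is a single gradient step on one randomly
    # drawn batch, not a full pass over the training set.
    epochs_counter += 1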

    # Fetch training batch.
    indices = tf.random.shuffle(
        tf.range(tf.shape(x_train)[0])
    )[:batch_size]
    
    batch = tf.gather(x_train, indices)

    with tf.GradientTape() as tape:
        total_loss = tf.reduce_mean(
            reconstruction_loss(vae_model(batch), batch)
            + kl_loss(vae_model.layers[0](batch))
        )

    training_history['total_loss'].append(total_loss)
    
    gradient = tape.gradient(total_loss, vae_model.trainable_variables)

    optimizer.apply_gradients(zip(gradient, vae_model.trainable_variables))

    print(f'Epoch: {epochs_counter} | Loss: {total_loss}')

In [None]:
n_samples = 6

input_samples = x_train[:n_samples, ...]

# Intermediate between encoder and decoder: sample the variational distribution
# (outputted by the encoder), map the samples to the corresponding reconstruction
# distribution (outputted by the decoder) and then sample the latter.
reconstructed_samples = decoder_tfp(
    encoder_tfp(input_samples).sample()
).sample()

fig, axs = plt.subplots(nrows=2, ncols=n_samples, figsize=(14, 6))

for i in range(2):
    for j in range(n_samples):
        axs[i][j].imshow(
            input_samples[j, ...] if i == 0 else reconstructed_samples[j, ...].numpy(),
            cmap='gray'
        )
    
        axs[i][j].grid(False)