# Variational autoencoders

__Objective:__ define a variational autoencoder for images.

**Idea:** in an autoencoder, the encoder maps samples to points in latent space. In a variational autoencoder, it maps samples to **multivariate Gaussian distributions** on latent space. This helps similar samples to be reconstructed from nearby points in latent space, because the decoder now has to minimize the reconstruction error for all the points sampled from the distribution associated with a given input sample.

### Ingredients

#### Encoder

The encoder part of the model is modified to output the parameters of a multivariate Gaussian on latent space with a diagonal covariance matrix. In practice, given the input sample $x$, the encoder outputs two vectors $\mu(x), \sigma(x) \in \mathbb{R}^d$, where $d$ is the dimension of latent space, parametrizing a distribution $\mathcal{N}(\mu(x), \Sigma(x))$, where $\Sigma(x) = \mathrm{diag}(\sigma^2_1(x), \ldots, \sigma^2_d(x))$.
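
The cells below unpack the encoder output as `z_mean, z_log_var, z_samples`. Assuming (from the name only) that `z_log_var` holds $\log \sigma^2(x)$, a common choice because the network can then output unconstrained real numbers, the following minimal NumPy sketch shows how $\sigma(x)$ and $\Sigma(x)$ would be recovered. The numerical values are made up for illustration.

```python
import numpy as np

# Hypothetical raw encoder outputs for a single sample, with d = 3.
mu = np.array([0.1, -0.4, 1.2])
log_var = np.array([-1.0, 0.0, 0.5])  # unconstrained real numbers

# Standard deviations: the exponential guarantees positivity.
sigma = np.exp(0.5 * log_var)

# Diagonal covariance matrix diag(sigma_1^2, ..., sigma_d^2).
Sigma = np.diag(sigma**2)

print(sigma)
print(Sigma)
```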

#### Decoder

Latent vectors $z\in \mathbb{R}^d$ are obtained by sampling the distributions on latent space. Given a latent vector, the decoder produces a realistic sample, as similar as possible to the input sample whose Gaussian distribution generated $z$. The architecture of the decoder remains the same as in regular autoencoders.

#### Loss function

With respect to the usual reconstruction loss (MSE or categorical cross-entropy), the loss function to minimize has an additional term: the KL divergence between the Gaussian distribution on latent space associated with each sample and a (multivariate) standard normal distribution,
$$
\mathrm{KL}\left[ \mathcal{N}(\mu(x), \Sigma(x)) || \mathcal{N}(0, \mathbf{1}) \right]\,.
$$
This comes from assuming a multivariate standard normal prior on latent space, a Gaussian likelihood and an approximate (variational) posterior given by the multivariate Gaussian output by the encoder. Minimizing the resulting loss, i.e. the negative evidence lower bound (ELBO), is equivalent to minimizing the KL divergence between the variational posterior and the true posterior (which is proportional to the product of likelihood and prior).
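
As a sketch of where the two terms come from, write $q(z|x) = \mathcal{N}(\mu(x), \Sigma(x))$ for the variational posterior, $p(z) = \mathcal{N}(0, \mathbf{1})$ for the prior and $p(x|z)$ for the likelihood (these symbols are introduced here for illustration). The log-evidence then decomposes as
$$
\log p(x) = \underbrace{\mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right] - \mathrm{KL}\left[ q(z|x) \,||\, p(z) \right]}_{\mathrm{ELBO}} + \mathrm{KL}\left[ q(z|x) \,||\, p(z|x) \right]\,.
$$
Since $\log p(x)$ does not depend on the encoder, maximizing the ELBO pushes $q(z|x)$ towards the true posterior $p(z|x)$. For a Gaussian likelihood with fixed variance, $-\mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right]$ reduces to a squared reconstruction error up to a scale factor and an additive constant, which is how the MSE term arises.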

The KL divergence above can be computed analytically, so given $\mu(x)$ and $\sigma(x)$ it's easy to compute the exact contribution to the total loss:
$$
\begin{array}{lll}
\mathrm{KL}\left[ \mathcal{N}(\mu(x), \Sigma(x)) || \mathcal{N}(0, \mathbf{1}) \right] &\equiv& -\int \mathrm{d}^d z\, \mathcal{N}(z | \mu(x), \Sigma(x))\,\log\left( \frac{\mathcal{N}(z | 0, \mathbf{1})}{\mathcal{N}(z | \mu(x), \Sigma(x))} \right) \\
&=& -\frac{1}{2} \sum_{j=1}^d \left( 1 + \log(\sigma^2_j) - \mu_j^2 - \sigma_j^2 \right)\,.
\end{array}
$$
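
As a sanity check of the closed-form expression, the following NumPy sketch compares it with a Monte Carlo estimate of the defining integral. The function name `kl_to_standard_normal`, the log-variance parametrization and the numerical values are assumptions made for illustration, not part of the modules used in this notebook.

```python
import numpy as np


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL[N(mu, diag(exp(log_var))) || N(0, 1)].

    The sum runs over the last axis (the latent dimensions).
    """
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)


rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one sample in a 2-dimensional latent space.
mu = np.array([0.5, -1.0])
log_var = np.array([-0.2, 0.3])
sigma = np.exp(0.5 * log_var)

kl_exact = kl_to_standard_normal(mu, log_var)

# Monte Carlo estimate of E_q[log q(z) - log p(z)] with z drawn from q.
z = mu + sigma * rng.standard_normal(size=(200_000, 2))
log_q = -0.5 * np.sum(((z - mu) / sigma) ** 2 + log_var + np.log(2 * np.pi), axis=-1)
log_p = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=-1)
kl_monte_carlo = np.mean(log_q - log_p)

print(kl_exact, kl_monte_carlo)
```

The two numbers should agree up to sampling noise, and the closed form vanishes when $\mu = 0$ and $\sigma = 1$, i.e. when the encoded Gaussian coincides with the prior.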

In the $\beta$-VAE variant of the model it's possible to tune the relative weight of the reconstruction and KL terms in the loss function via a coefficient $\beta$,
$$
\mathcal{L} = \mathrm{MSE} + \beta\,\mathrm{KL}\,.
$$
$\beta$ is a hyperparameter controlling the balance between the two terms in the loss. If $\beta$ is too small, the KL term has little effect: latent vectors are more spread out in latent space, farther away from the origin and grouped in discontinuous clusters. If $\beta$ is too big, the KL term prevails and the model has poor reconstruction power (essentially, all the encoded Gaussians end up matching the standard normal).
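
A minimal NumPy sketch of the weighted loss is given below. The function `beta_vae_loss` and the choice of reductions (MSE averaged over pixels and batch, KL summed over latent dimensions and averaged over the batch) are assumptions for illustration; the `VAE` class used later may reduce the terms differently, and the reduction convention changes the effective value of $\beta$.

```python
import numpy as np


def beta_vae_loss(x, x_reconstructed, mu, log_var, beta=1.0):
    """Return (total, mse, kl) for a batch, with total = mse + beta * kl."""
    # Reconstruction term: squared error averaged over pixels and batch.
    mse = np.mean((x - x_reconstructed) ** 2)
    # KL term: closed form, summed over latent dimensions, averaged over batch.
    kl = np.mean(-0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1))
    return mse + beta * kl, mse, kl


rng = np.random.default_rng(0)

# Synthetic stand-ins for a batch of images, reconstructions and encoder outputs.
x = rng.random((4, 32, 32, 1))
x_reconstructed = rng.random((4, 32, 32, 1))
mu = rng.standard_normal((4, 2))
log_var = rng.standard_normal((4, 2))

for beta in (0.0, 1.0, 10.0):
    total, mse, kl = beta_vae_loss(x, x_reconstructed, mu, log_var, beta=beta)
    print(f"beta={beta}: total={total:.4f}, mse={mse:.4f}, kl={kl:.4f}")
```

With $\beta = 0$ the loss reduces to the plain autoencoder reconstruction error, while larger values make the KL term dominate.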

#### Reparametrization trick

Given an input sample, the prediction has a random component coming from the sampling of the Gaussian distribution that the encoder associates with the input. Backpropagation would require "differentiating the sampling" w.r.t. the parameters of the Gaussian distribution, which is not possible: once drawn, a sample is just a numerical value and all the information about the distribution that generated it is lost. Nonetheless, it's possible to use a reparametrization of the Gaussian distribution that allows for explicit differentiation w.r.t. the $\mu(x)$ and $\sigma(x)$ parameters: the **reparametrization trick**.

Given the input sample $x$, the encoder outputs the parameters $\mu(x)$ and $\sigma(x)$ of the multivariate Gaussian $\mathcal{N}(\mu(x), \Sigma(x))$, from which the latent vector $z$ is sampled,
$$
z \sim \mathcal{N}(\mu(x), \Sigma(x))\,.
$$
The reparametrization trick consists in sampling $z$ in the equivalent way
$$
z = \mu(x) + \sigma(x) \odot \epsilon\,,
$$
where $\epsilon \sim \mathcal{N}(0, \mathbf{1})$ and $\odot$ denotes the element-wise product. The values generated for $z$ follow exactly the same distribution as before, but the parameters $\mu$ and $\sigma$ now appear explicitly and differentiation w.r.t. them is possible.
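
The equivalence of the two sampling procedures can be checked numerically. The sketch below uses NumPy and made-up values for $\mu$ and $\sigma$; it only illustrates the statistics of the samples, not the gradient computation itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one sample in a 2-dimensional latent space.
mu = np.array([1.0, -2.0])
sigma = np.array([0.5, 2.0])
n_draws = 100_000

# Direct sampling from N(mu, diag(sigma^2)).
z_direct = rng.normal(loc=mu, scale=sigma, size=(n_draws, 2))

# Reparametrized sampling: the randomness lives entirely in epsilon,
# while mu and sigma enter through a deterministic, differentiable map.
epsilon = rng.standard_normal(size=(n_draws, 2))
z_reparam = mu + sigma * epsilon

print(z_direct.mean(axis=0), z_reparam.mean(axis=0))
print(z_direct.std(axis=0), z_reparam.std(axis=0))
```

Both sets of samples should show the same means and standard deviations up to sampling noise. In the reparametrized form, for a fixed draw of $\epsilon$, $\partial z_j / \partial \mu_j = 1$ and $\partial z_j / \partial \sigma_j = \epsilon_j$, which is what lets gradients flow back to the encoder.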

In [None]:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../../modules/')

from variational_autoencoders import VariationalEncoder, VAE
from autoencoders import Decoder
from keras_utilities import get_intermediate_output, append_to_full_history, plot_history

tfd = tfp.distributions

sns.set_theme()

%load_ext autoreload
%autoreload 2

## Get data

In [None]:
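def preprocess_images(img):
    """
    Rescale pixel values to [0, 1], zero-pad each image by 2 pixels
    on every side and add a trailing channel dimension.

    Expects an array of shape (n_images, height, width); for the
    28x28 Fashion-MNIST images the output has shape
    (n_images, 32, 32, 1).
    """
    # Normalize pixel values.
    img = img.astype('float32') / 255.

    # Add padding (only along the two spatial dimensions).
    img = np.pad(img, ((0, 0), (2, 2), (2, 2)), constant_values=0.)

    # The images come in grayscale without an explicit
    # channel dimension. Here we add it.
    img = np.expand_dims(img, -1)

    return img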

In [None]:
# Note: we don't really care about the labels in the y arrays.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

In [None]:
x_train = preprocess_images(x_train)
x_test = preprocess_images(x_test)

## Model definition and training

In [None]:
variational_encoder = VariationalEncoder()

n_samples = 5000

random_inputs = tf.random.normal(shape=(n_samples, 32, 32, 1))

z_mean, z_log_var, z_samples = variational_encoder(random_inputs)

In [None]:
variational_encoder = VariationalEncoder()

image_reshaping_size = get_intermediate_output(
    tf.random.normal(shape=(15, 32, 32, 1)),
    variational_encoder,
    3
).shape[1:]

decoder = Decoder(image_reshaping_size)

vae_model = VAE(
    variational_encoder=variational_encoder,
    decoder=decoder
)

vae_model(tf.random.normal(shape=(21, 32, 32, 1)))

vae_model.summary()

In [None]:
vae_model.compile(
    optimizer='adam'
)

full_history = {}

In [None]:
epochs = 20
batch_size = 100

history = vae_model.fit(
    x_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(x_test, x_test)
)

append_to_full_history(history, full_history)

plot_history(full_history)

In [None]:
saved_model_path = '../../models/variational_autoencoders/vae_model.keras'

vae_model.save(saved_model_path)

# loaded_model = tf.keras.models.load_model(saved_model_path)

## Image reconstruction after training

In [None]:
nrows = 2
ncols = 6

reconstructed_images = tf.concat(
    [
        x_test[:ncols, ...][None, ...],
        vae_model(x_test[:ncols, ...])[2][None, ...]
    ],
    axis=0
)

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 4))

for i in range(nrows):
    for j in range(ncols):
        ax = axs[i][j]
        
        ax.imshow(
            reconstructed_images[i, j, ...],
            cmap='gray'
        )

        ax.grid(False)

## Exploration of the latent space

In [None]:
n_samples = 5000

z_means, _, z_samples = variational_encoder(x_test[:n_samples, ...])

fig = plt.figure(figsize=(14, 6))

sns.scatterplot(
    x=z_means[:, 0],
    y=z_means[:, 1],
    hue=y_test[:n_samples],
    palette=sns.color_palette()
)

## Generating new images

Thanks to the KL divergence term in the loss function, the distributions into which samples are encoded should not be too far from a standard normal distribution. This implies that, to generate new samples, we can simply draw latent vectors from a standard normal distribution and expect the decoder to turn them into realistic images.

In [None]:
n_images = 6

# Random 2-dimensional vectors in latent space.
random_latent_vectors = tf.random.normal(shape=(n_images, 2))

# Decode the randomly-generated latent vectors into
# images via the decoder.
random_images = vae_model.decoder(random_latent_vectors)


# Plot the position of the random latent vectors over
# existing samples.
fig = plt.figure(figsize=(14, 6))

sns.scatterplot(
    x=z_means[:, 0],
    y=z_means[:, 1],
    color=sns.color_palette()[0],
    alpha=.3
)

sns.scatterplot(
    x=random_latent_vectors[:, 0],
    y=random_latent_vectors[:, 1],
    color=sns.color_palette()[3],
)


# Show the decoded images corresponding to the random
# latent vectors.
fig, axs = plt.subplots(ncols=n_images, figsize=(14, 4))

for i in range(n_images):
    axs[i].imshow(
        random_images[i, ...],
        cmap='gray'
    )

    axs[i].grid(False)

    plt.sca(axs[i])
    plt.title(f'{random_latent_vectors[i, ...].numpy().round(2)}')

Build a path in latent space and observe the morphing of the corresponding reconstructed images.

In [None]:
starting_point = tf.constant([.8, 0.5])
endpoint = tf.constant([2.5, 1.5])

n_points = 20

path = (endpoint - starting_point) * tf.linspace(0., 1., n_points)[..., None] + starting_point

reconstructed_images_path = vae_model.decoder(path)
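
The path construction in the cell above relies on broadcasting: the interpolation parameter is reshaped to a column so that each of its values multiplies the full displacement vector. A minimal NumPy sketch of the same computation, written here only to make the shapes explicit:

```python
import numpy as np

starting_point = np.array([0.8, 0.5])
endpoint = np.array([2.5, 1.5])
n_points = 20

# Interpolation parameter t in [0, 1], shaped (n_points, 1) so that it
# broadcasts against the (2,)-shaped displacement vector.
t = np.linspace(0.0, 1.0, n_points)[:, None]

# Each row is a point on the segment from starting_point to endpoint.
path = starting_point + t * (endpoint - starting_point)

print(path.shape)
print(path[0], path[-1])
```

The first and last rows coincide with the two endpoints, and consecutive points are equally spaced along the segment.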

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.scatterplot(
    x=z_means[:, 0],
    y=z_means[:, 1],
    hue=y_test[:n_samples],
    palette=sns.color_palette()
)

sns.lineplot(
    x=path[:, 0],
    y=path[:, 1],
    color=sns.color_palette()[3],
    linestyle='dashdot',
    linewidth=3
)

In [None]:
# Generate images corresponding to points along
# the path.
images_along_path = vae_model.decoder(path)

# Plot generated images.
ncols = 10
nrows = n_points // ncols

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 4))

for i in range(nrows):
    for j in range(ncols):
        ax = axs[i][j]
        
        ax.imshow(
            images_along_path[i * ncols + j, ...],
            cmap='gray'
        )

        ax.grid(False)

        plt.sca(ax)
        plt.xticks([])
        plt.yticks([])
        plt.title(f'{path[i * ncols + j, ...].numpy().round(2)}')

## Morphing towards a particular item

Let's say that we want to start from a point in latent space and move along a straight line to morph the reconstructed object into a pair of trousers (class label 1 in the dataset).

In [None]:
trousers_samples = x_train[y_train == 1]

In [None]:
ncols = 6

random_trousers_samples = tf.gather(
    trousers_samples,
    indices=np.random.choice(range(trousers_samples.shape[0]), ncols)
)

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(14, 4))

for i in range(ncols):
    ax = axs[i]
    
    ax.imshow(
        random_trousers_samples[i, ...],
        cmap='gray'
    )

    ax.grid(False)

In order to move in the direction of "trousers", we do the following:
1. Compute the average latent vector (i.e. mean of the corresponding Gaussians) for all (training) samples belonging to the class.
2. Compute the average latent vector for all the (training) samples belonging to any other class.
3. Subtract the second from the first and normalize it. This is the general "trousers" direction.
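
The three steps above can be sketched on synthetic data as follows. The arrays `latent_means` and `labels` are made-up stand-ins for the encoder means and the class labels, with class 1 artificially shifted so that it occupies a distinct region of a 2-dimensional latent space.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for the encoder means and the class labels.
latent_means = rng.standard_normal(size=(1000, 2))
labels = rng.integers(0, 10, size=1000)

# Shift class 1 so that it has a distinct location in latent space.
latent_means[labels == 1] += np.array([3.0, 0.0])

# Step 1: average latent vector of the target class.
target_mean = latent_means[labels == 1].mean(axis=0)

# Step 2: average latent vector of all the other classes.
others_mean = latent_means[labels != 1].mean(axis=0)

# Step 3: difference of the two, normalized to unit length.
direction = target_mean - others_mean
direction = direction / np.linalg.norm(direction)

print(direction)
```

Since the target class was shifted along the first axis, the resulting unit vector should point almost entirely along that axis.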

In [None]:
average_latent_trousers = tf.reduce_mean(
    vae_model.variational_encoder(trousers_samples)[0],
    axis=0
)
average_latent_other_classes = tf.reduce_mean(
    vae_model.variational_encoder(x_train[y_train != 1, ...])[0],
    axis=0
)

latent_direction_trousers = average_latent_trousers - average_latent_other_classes
latent_direction_trousers = latent_direction_trousers / tf.norm(latent_direction_trousers)

In [None]:
latent_direction_trousers

Let's start from a random point in latent space and let's make the corresponding reconstructed image more "trousers-y".

In [None]:
random_latent_vectors = tf.random.normal(shape=(1, 2))

t_path = random_latent_vectors + latent_direction_trousers * tf.linspace(0., 3., n_points)[..., None]

fig = plt.figure(figsize=(14, 6))

sns.scatterplot(
    x=z_means[:, 0],
    y=z_means[:, 1],
    hue=y_test[:n_samples],
    palette=sns.color_palette()
)

sns.lineplot(
    x=t_path[:, 0],
    y=t_path[:, 1],
    color=sns.color_palette()[3],
    linestyle='dashdot',
    linewidth=3
)

In [None]:
# Generate images corresponding to points along
# the path.
t_morphing_images = vae_model.decoder(t_path)

# Plot generated images.
ncols = 10
nrows = n_points // ncols

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 4))

for i in range(nrows):
    for j in range(ncols):
        ax = axs[i][j]
        
        ax.imshow(
            t_morphing_images[i * ncols + j, ...],
            cmap='gray'
        )

        ax.grid(False)

        plt.sca(ax)
        plt.xticks([])
        plt.yticks([])
        plt.title(f'{t_path[i * ncols + j, ...].numpy().round(2)}')