# Preface

In this notebook, we show how to train a variational autoencoder (VAE). Things to take note of that may be new:
  * using the `add_loss` method for fully custom loss functions
  * writing noise generation layers

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from pathlib import Path
sns.set(font_scale=1.5, style='darkgrid')

# Import Data

We will use the very familiar MNIST dataset to illustrate generative models.

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Generating Samples using Autoencoders

First, let us train a simple fully connected autoencoder and see whether sampling its latent space produces sensible images.

## Train a simple FC Autoencoder

In [None]:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tqdm.keras import TqdmCallback

In [None]:
encoding_dim = 2  # dimension of the latent space

encoder = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(encoding_dim, activation='relu'),
])

decoder = Sequential([
    Dense(128, activation='relu', input_shape=(encoding_dim, )),
    Dense(784, activation='sigmoid'),
    Reshape((28, 28)),
])

autoencoder = Sequential([encoder, decoder])

In [None]:
autoencoder.compile(
    loss='binary_crossentropy',
    optimizer='adam',
)

In [None]:
filename = 'mnist_ae_gen.h5'
if Path(filename).exists():
    # Reuse previously trained weights
    autoencoder.load_weights(filename)
else:
    # Train from scratch and cache the weights for later runs
    autoencoder.fit(
        x=x_train,
        y=x_train,
        batch_size=128,
        epochs=50,
        validation_data=(x_test, x_test),
        callbacks=[TqdmCallback()],
        verbose=0,
    )
    autoencoder.save_weights(filename)

## Reconstruction Quality

Let us check the reconstruction quality of the simple AE.

In [None]:
def plot_images(images, n_plots=5):
    """Plot up to `n_plots` digit images side by side
    """
    n_plots = min(n_plots, len(images))
    with sns.axes_style("dark"):
        # squeeze=False keeps `ax` a 2D array even when n_plots == 1
        fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots, 4), squeeze=False)

        for img, a in zip(images, ax[0]):
            a.imshow(img, cmap='Greys_r')
            a.axis('off')

In [None]:
x_test_pred = autoencoder.predict(x_test)

In [None]:
plot_images(x_test)
plot_images(x_test_pred)

## Sampling the Latent Space

Let us now explore the latent space, i.e.
$$
    z = \mathrm{Encoder}(x)
$$

What we are going to do is the following:
  * Given two sample images $x^{(1)}$ and $x^{(2)}$, we obtain their latent states
  $$
      z^{(i)} = \mathrm{Encoder}(x^{(i)})
  $$
  * Consider their convex combination in latent space
  $$
      z(r) = (1-r) z^{(1)} + r z^{(2)}
  $$
  where $r\in [0,1]$. As $r$ goes from $0$ to $1$, $z(r)$ moves along the straight line from $z^{(1)}$ to $z^{(2)}$.
  * We then explore the decoded image
  $$
      x'(r) = \mathrm{Decoder}(z(r))
  $$
  as $r$ varies.
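
The interpolation above can also be computed for all values of $r$ at once with NumPy broadcasting. The decoder would then only need to be called on a single batch. The following is a minimal sketch. The helper name `interpolate_latents` is hypothetical, and the latent codes are assumed to have shape `(1, d)`, which is what `encoder.predict` returns for a single image.

```python
import numpy as np

def interpolate_latents(z_1, z_2, n_steps=10):
    """Hypothetical helper: return (1 - r) * z_1 + r * z_2 for n_steps values of r in [0, 1].

    z_1, z_2: arrays of shape (1, d). Returns an array of shape (n_steps, d).
    """
    r = np.linspace(0.0, 1.0, n_steps)[:, None]  # shape (n_steps, 1)
    return (1.0 - r) * z_1 + r * z_2              # broadcasts to (n_steps, d)

z_1 = np.array([[0.0, 1.0]])
z_2 = np.array([[2.0, -1.0]])
path = interpolate_latents(z_1, z_2, n_steps=5)
print(path)
```

In the notebook, the resulting batch could be passed to `decoder.predict` in one call instead of looping over $r$.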

In [None]:
image_1 = x_test[0]
image_2 = x_test[1]

Here are the two images we picked:

In [None]:
plot_images([image_1, image_2])

We now compute and plot the interpolation through the latent space.

In [None]:
z_1 = encoder.predict(image_1[None, :, :])
z_2 = encoder.predict(image_2[None, :, :])

In [None]:
interpolated_images = [image_1]
for r in np.linspace(0, 1, 10):
    z = z_2 * r + z_1 * (1-r)
    interpolated_images.append(np.squeeze(decoder.predict(z)))
interpolated_images.append(image_2)

In [None]:
plot_images(interpolated_images, n_plots=12)

As we can see, the decoded images do not vary smoothly between the two digits. To see this more clearly, we can plot the encoded test set in the latent space, together with the images decoded from a regular grid of latent points.

In [None]:
def plot_latents(models, data, n=15):
    """Plots labels and MNIST digits as a function of the 2D latent vector
    """
    encoder, decoder = models
    x_test, y_test = data

    z_mean = encoder.predict(x_test)
    if isinstance(z_mean, list):  # the VAE encoder returns [y1, y2, z]; use the mean y1
        z_mean = z_mean[0]

    with sns.axes_style("dark"):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        im = ax1.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
        plt.colorbar(im, ax=ax1)
        ax1.set_xlabel("$z_0$")
        ax1.set_ylabel("$z_1$")

        # Regular grid over the range of the encoded test set (top row = largest z_1)
        digit_size = 28
        x_min, x_max = z_mean[:, 0].min(), z_mean[:, 0].max()
        y_min, y_max = z_mean[:, 1].min(), z_mean[:, 1].max()
        grid_x = np.linspace(x_min, x_max, n)
        grid_y = np.linspace(y_min, y_max, n)[::-1]

        # Decode all n*n grid points in one batch, then tile them into a single image
        z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
        digits = decoder.predict(z_grid).reshape(n, n, digit_size, digit_size)
        figure = digits.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)
        ax2.imshow(figure, cmap='Greys_r')

        # Ticks at the centre of each tile, labelled with the latent coordinates
        start_range = digit_size // 2
        end_range = (n - 1) * digit_size + start_range + 1
        pixel_range = np.arange(start_range, end_range, digit_size)
        ax2.set_xticks(pixel_range)
        ax2.set_xticklabels(np.round(grid_x, 1))
        ax2.set_yticks(pixel_range)
        ax2.set_yticklabels(np.round(grid_y, 1))
        ax2.set_xlabel("$z_0$")
        ax2.set_ylabel("$z_1$")

In [None]:
plot_latents(
    models=(encoder, decoder),
    data=(x_test, y_test),
)

Observe that the features are not well disentangled. In particular, the transitions between digits in the decoded grid are very abrupt.

# Variational Autoencoder

Let us now apply the VAE ideas developed in class to generating handwritten digits.

## Building a VAE Network

In [None]:
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Lambda

### Encoder architecture

Let's first build the encoder network. Unlike a standard AE, the encoder has three outputs:
  * $y_1$ (mean of the latent Gaussian distribution)
  * $y_2$ (log of the standard deviation of the latent Gaussian distribution)
  * $z$ (a sample from the Gaussian with the mean and standard deviation above)

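We can obtain $y_1$ and $y_2$ as the outputs of two separate `Dense` layers applied to the same hidden representation. Neither layer has an activation, because a mean and a log standard deviation can take any real value.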

In [None]:
x = Input(shape=(28, 28))
h = Flatten()(x)
h = Dense(128, activation='relu')(h)
y1 = Dense(encoding_dim)(h)
y2 = Dense(encoding_dim)(h)

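To obtain $z$, we need a layer that performs sampling. We use the reparameterization trick $z = y_1 + e^{y_2} \odot u$, where $u \sim \mathcal{N}(0, I)$ is drawn with `tf.random.normal`, and wrap this function in a `Lambda` layer.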

In [None]:
def sample_gaussian(mean_and_logstd):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    """
    mean, logstd = mean_and_logstd
    u = tf.random.normal(tf.shape(mean))
    return mean + tf.math.exp(logstd) * u
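
As a sanity check on what `sample_gaussian` computes, here is a NumPy analogue of the same reparameterization trick. It is an illustrative sketch, not part of the training code, and the name `sample_gaussian_np` is hypothetical. With many samples, the empirical mean and standard deviation of $z$ should approach $y_1$ and $e^{y_2}$.

```python
import numpy as np

def sample_gaussian_np(mean, logstd, rng):
    """Hypothetical NumPy analogue of `sample_gaussian`: z = mean + exp(logstd) * u."""
    u = rng.standard_normal(mean.shape)  # u ~ N(0, I), independent of the parameters
    return mean + np.exp(logstd) * u

rng = np.random.default_rng(0)
n_samples = 100_000
mean = np.tile([1.0, -2.0], (n_samples, 1))           # y_1, repeated for every sample
logstd = np.tile(np.log([0.5, 2.0]), (n_samples, 1))  # y_2 = log(std)

z = sample_gaussian_np(mean, logstd, rng)
print(z.mean(axis=0), z.std(axis=0))
```

The randomness enters only through $u$, so $z$ is a differentiable function of $y_1$ and $y_2$. This is what allows gradients to flow back through the sampling step into the encoder.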

In [None]:
z_sample = Lambda(sample_gaussian)([y1, y2])

We can now build the encoder model, noting that there are a total of 3 outputs.

In [None]:
encoder = Model(x, [y1, y2, z_sample], name='encoder')

### Decoder architecture

The decoder models $p(x \mid z)$ as a factorized Bernoulli distribution. We therefore only need a network that outputs, for each pixel, a Bernoulli parameter $s \in [0,1]$, which is exactly what the sigmoid activation provides.

In [None]:
z_in = Input(shape=(encoding_dim, ))
h = Dense(128, activation='relu')(z_in)
h = Dense(784, activation='sigmoid')(h)
s = Reshape((28, 28))(h)

In [None]:
decoder = Model(z_in, s, name='decoder')
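
Under the factorized Bernoulli model, the log-likelihood of an image $x$ (pixel values in $[0,1]$) given the decoder output $s$ is $\sum_j \left[x_j \log s_j + (1-x_j)\log(1-s_j)\right]$. This is the negative of the pixel-wise binary cross-entropy summed over pixels. The following small NumPy sketch computes it. The function name and the clipping constant `eps` are illustrative choices.

```python
import numpy as np

def bernoulli_log_likelihood(x, s, eps=1e-7):
    """Sum over pixels of x*log(s) + (1-x)*log(1-s) for a factorized Bernoulli model."""
    s = np.clip(s, eps, 1.0 - eps)  # avoid log(0)
    return np.sum(x * np.log(s) + (1.0 - x) * np.log(1.0 - s), axis=-1)

x = np.array([[1.0, 0.0, 1.0]])  # a tiny 3-"pixel" binary image
s = np.array([[0.9, 0.2, 0.6]])  # decoder outputs (Bernoulli parameters)
ll = bernoulli_log_likelihood(x, s)

# For binary x this is the log of the product of per-pixel probabilities 0.9 * 0.8 * 0.6
print(ll, np.log(0.9 * 0.8 * 0.6))
```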

### Combining into a VAE

To combine into a VAE, we simply take the generated latent state $z$ from the encoder network and feed it into the decoder network.

In [None]:
xp = decoder(encoder(x)[-1])
vae = Model(x, xp, name='vae')

## Building the ELBO Loss

As derived in class, the (negative) ELBO loss consists of two parts:
  * Reconstruction loss
  $$
      \mathrm{BinaryCrossEntropy}(\theta,\phi)
  $$
  * KL-divergence loss
  $$
      \mathrm{KL}(\theta,\phi)
      =
      \underbrace{
          \frac{1}{2} \| y_1 \|^2      
      }_{\mathrm{KL}_1}
      +
      \underbrace{
          \frac{1}{2} \| e^{y_2} \|^2
      }_{\mathrm{KL}_2}
      -
      \underbrace{
          \sum_j y_{2,j}
      }_{\mathrm{KL}_3}
  $$
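
The three terms above match the closed-form KL divergence between the approximate posterior $\mathcal{N}(y_1, \mathrm{diag}(e^{2y_2}))$ and the prior $\mathcal{N}(0, I)$, up to a constant. The exact divergence is $\mathrm{KL}_1 + \mathrm{KL}_2 - \mathrm{KL}_3 - d/2$, where $d$ is the latent dimension (`encoding_dim`). The constant does not affect the gradients, so it can be dropped from the loss. The following NumPy sketch checks this numerically; the helper names are illustrative.

```python
import numpy as np

def kl_three_terms(y1, y2):
    """KL_1 + KL_2 - KL_3 as written above (y2 = log std); omits the constant -d/2."""
    return (0.5 * np.sum(y1**2, axis=-1)
            + 0.5 * np.sum(np.exp(y2)**2, axis=-1)
            - np.sum(y2, axis=-1))

def kl_exact(y1, y2):
    """Closed-form KL(N(y1, diag(exp(2*y2))) || N(0, I))."""
    var = np.exp(2.0 * y2)
    return 0.5 * np.sum(var + y1**2 - 1.0 - np.log(var), axis=-1)

rng = np.random.default_rng(0)
d = 2
y1 = rng.normal(size=(4, d))
y2 = rng.normal(scale=0.5, size=(4, d))

diff = kl_three_terms(y1, y2) - kl_exact(y1, y2)
print(diff)  # the same constant d / 2 for every row
```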

In [None]:
from tensorflow.keras.losses import binary_crossentropy

In [None]:
reconstruction_loss = binary_crossentropy(
    tf.reshape(x, [-1, 784]),
    tf.reshape(xp, [-1, 784]),
)

In [None]:
kl_loss_1 = 0.5 * tf.reduce_sum(y1**2, axis=-1)
kl_loss_2 = 0.5 * tf.reduce_sum(tf.math.exp(y2)**2, axis=-1)
kl_loss_3 = - tf.reduce_sum(y2, axis=-1)

In [None]:
kl_loss = kl_loss_1 + kl_loss_2 + kl_loss_3

We now combine the two losses. Keras' `binary_crossentropy` averages over the last axis, which means it divides the pixel-wise sum by the output dimension (784). To keep the correct relative weighting between the two terms, we scale the KL loss by the same factor $1/784$.

In [None]:
vae_loss = reconstruction_loss + (1.0 / 784) * kl_loss
vae_loss = tf.reduce_mean(vae_loss)
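
This factor gives the right weighting for the following reason. The summed BCE plus the KL term is the negative ELBO (up to the constant dropped from the KL term), and the per-example loss above is exactly that quantity divided by 784. Below is a NumPy check on random stand-in arrays. The helper `bce_mean` is illustrative: it mirrors only the averaging over the last axis, not every detail of the Keras implementation.

```python
import numpy as np

D = 784  # number of pixels

def bce_mean(x, s, eps=1e-7):
    """Binary cross-entropy averaged over the last (pixel) axis."""
    s = np.clip(s, eps, 1.0 - eps)
    return -np.mean(x * np.log(s) + (1.0 - x) * np.log(1.0 - s), axis=-1)

rng = np.random.default_rng(0)
x = rng.uniform(size=(3, D))              # stand-in images with values in [0, 1]
s = rng.uniform(0.05, 0.95, size=(3, D))  # stand-in decoder outputs
kl = rng.uniform(0.0, 5.0, size=3)        # stand-in per-example KL values

notebook_loss = bce_mean(x, s) + kl / D   # per-example loss, as in the cell above
# Summed BCE + KL, i.e. the negative ELBO up to an additive constant
neg_elbo = -np.sum(x * np.log(s) + (1.0 - x) * np.log(1.0 - s), axis=-1) + kl

print(np.allclose(D * notebook_loss, neg_elbo))
```

The loss is only rescaled by a positive constant, so it has the same minimizers as the negative ELBO.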

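Finally, we attach the loss to the model with
```python
vae.add_loss(vae_loss)
```
This is the most general way to use custom loss functions. Where possible, I suggest using the `tensorflow.keras.losses` API instead, subclassing from there if needed. For a VAE, however, the loss depends on the output $x'$ and also on the intermediate encoder outputs $y_1$ and $y_2$. A standard `loss(y_true, y_pred)` function does not see those intermediate outputs, so `add_loss` is the most convenient option here.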

In [None]:
vae.add_loss(vae_loss)

In [None]:
vae.compile(optimizer='adam')

## Compile and Train

In [None]:
enc_filename = 'mnist_vae_encoder.h5'
dec_filename = 'mnist_vae_decoder.h5'
if Path(enc_filename).exists() and Path(dec_filename).exists():
    # Reuse previously trained weights
    encoder.load_weights(enc_filename)
    decoder.load_weights(dec_filename)
else:
    # Train from scratch (y=None: the loss was attached via add_loss) and cache the weights
    vae.fit(
        x=x_train,
        y=None,
        epochs=50,
        batch_size=128,
        validation_data=(x_test, None),
        callbacks=[TqdmCallback()],
        verbose=0,
    )
    encoder.save_weights(enc_filename)
    decoder.save_weights(dec_filename)

## Exploring the Latent Space

Now, let us explore the latent space learned by the VAE, and try both interpolation and the generation of new samples.

### Interpolation

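We start with the same interpolation as for the plain AE. In the comparison plot, the first row shows the AE interpolation and the second row shows the VAE interpolation.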

In [None]:
_, _, z_1 = encoder.predict(image_1[None, :, :])
_, _, z_2 = encoder.predict(image_2[None, :, :])

In [None]:
interpolated_images_vae = [image_1]
for r in np.linspace(0, 1, 10):
    z = z_2 * r + z_1 * (1-r)
    interpolated_images_vae.append(np.squeeze(decoder.predict(z)))
interpolated_images_vae.append(image_2)

In [None]:
plot_images(interpolated_images, n_plots=12)
plot_images(interpolated_images_vae, n_plots=12)

### Latent Distribution and Image Generation

In [None]:
plot_latents(
    models=(encoder, decoder),
    data=(x_test, y_test),
)
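
The grid above decodes points that span the range of the encoded test set. The KL term pulls each latent distribution toward the prior $\mathcal{N}(0, I)$, so new digits can also be generated by sampling $z$ directly from the prior and decoding it. The following is a hedged sketch. The helper `generate_from_prior` and the stand-in `fake_decode` are hypothetical and are used only so the example runs on its own.

```python
import numpy as np

def generate_from_prior(decode, n_samples, latent_dim, seed=0):
    """Hypothetical helper: draw z ~ N(0, I) from the prior and decode it.

    `decode` is any callable mapping an array of shape (n_samples, latent_dim)
    to a batch of images.
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, latent_dim))
    return decode(z)

def fake_decode(z):
    """Stand-in decoder for illustration only: one constant-brightness image per sample."""
    brightness = 1.0 / (1.0 + np.exp(-z.sum(axis=1)))  # sigmoid, so values lie in (0, 1)
    return np.ones((z.shape[0], 28, 28)) * brightness[:, None, None]

samples = generate_from_prior(fake_decode, n_samples=5, latent_dim=2)
print(samples.shape)
```

In this notebook, one could instead pass the trained decoder, e.g. `plot_images(generate_from_prior(decoder.predict, 5, encoding_dim))`.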

# Exercise

1. Try VAEs on other image tasks, e.g. generation and interpolation between human faces
2. Try convolution layers in VAEs
3. Try using a latent distribution with a non-diagonal covariance (you need to derive the corresponding loss functions!)