<a href="https://colab.research.google.com/github/jarrydmartinx/generative-models/blob/master/variational-autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Autoencoder with Sonnet and TensorFlow Datasets

In [0]:
#@title Install
! pip install -q plotnine

In [0]:
#@title Imports

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotnine as gg
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
from collections import namedtuple

from typing import Text, List, Tuple

## Data

In [0]:
#@title Getting information about our data

def get_metadata(dataset_name: Text):
  """Prints the image shape recorded for a TensorFlow Datasets dataset.

  Args:
    dataset_name: Name of the dataset, forwarded to `tfds.load`.
  """
  train_split = tfds.load(dataset_name, split=tfds.Split.TRAIN)

  print('-----------Dataset metadata---------------')
  # Only the shape stored under the 'image' feature is reported.
  print('Image shape: {}'.format(train_split.output_shapes['image']))


In [0]:
#@title Loading and preparing our data

def get_data(dataset_name: Text,
             batch_size: int) -> tf.Tensor:
  """Builds an op that yields float32 image batches from a TFDS dataset.

  Args:
    dataset_name: Name of the dataset, forwarded to `tfds.load`.
    batch_size: Number of examples per batch; incomplete batches are dropped.

  Returns:
    The 'image' feature of the next batch, cast to `tf.float32`.
  """
  # Every split is loaded: the model is unsupervised, so nothing is held out.
  full_dataset = tfds.load(dataset_name, split=tfds.Split.ALL)

  # Cycle through the data forever and group it into fixed-size batches.
  batched = full_dataset.repeat()
  batched = batched.batch(batch_size, drop_remainder=True)

  # The iterator's `get_next` op produces a batch each time it is run.
  next_batch = batched.make_one_shot_iterator().get_next()

  # Keep only the images and convert them to float32.
  return tf.cast(next_batch['image'], dtype=tf.float32)

Run the cell below twice and you won't have all this annoying output clogging up your screen.

In [13]:
#@title Visualising our data

def show_image(dataset_name: Text) -> None:
  """Displays one randomly chosen image from a batch of the dataset.

  Args:
    dataset_name: Name of the dataset, forwarded to `get_data`.
  """
  batch_size = 25
  data_op = get_data(dataset_name, batch_size)

  with tf.train.MonitoredSession() as sess:
    data_batch = sess.run(data_op)

  # Drop the trailing singleton axis so each entry is a 2-D image for `imshow`.
  images = data_batch.squeeze(-1)
  # `randint(high)` samples from [0, high), so pass `batch_size` itself;
  # `batch_size - 1` would make the last image in the batch unreachable.
  index = np.random.randint(batch_size)
  plt.imshow(images[index])

# Display a random MNIST image, then print the dataset's image shape.
show_image('mnist')
get_metadata('mnist')

ValueError: ignored

## Model

### The variational autoencoder model

![alt text](https://jaan.io/images/encoder-decoder.png)

### Making Variational Autoencoder Models with Sonnet

In [0]:
# HYPERPARAMETERS
# `n_latent` must be a real assignment: `encode` uses it as a default argument.
n_latent = 16  # Size of the latent mean/variance vectors.
batch_size = 100  # Number of images per training batch.
dataset_name = 'mnist'  # Dataset name passed to `get_data`.
model_name = 'mlp_vae'  # Key selecting the encoder network in `encode`.
learning_rate = 1e-3  # Learning rate given to the Adam optimizer.
num_training_steps = 10000  # Controls the length of the training loop.
log_every = 20  # The loss is recorded every `log_every` steps.
max_val = 10.  # Upper clip applied to the encoder's log-variance output.


In [0]:

#@title Encoder/Decoder Models

def encode(input_data: tf.Tensor, 
           model_name: Text = 'mlp_vae',
           n_latent: int = n_latent) -> Tuple[tf.Tensor, tf.Tensor]:
  """Maps a batch of inputs to the mean and variance of the latent posterior.

  Args:
    input_data: Batch of inputs; each example is flattened before the MLP.
    model_name: Key selecting the encoder network. Only 'mlp_vae' is defined.
    n_latent: Output size of the mean and log-variance linear layers.

  Returns:
    A `(latent_mu, latent_var)` tuple. `latent_var` is the exponential of the
    clipped log-variance head, so it is always positive.

  Raises:
    KeyError: If `model_name` is not one of the defined encoders.
  """
  encoders = {
      'mlp_vae': snt.Sequential([
          snt.BatchFlatten(),
          snt.nets.MLP([64, 32], activate_final=True)])
  }

  embedding = encoders[model_name](input_data)
  latent_mu = snt.Linear(output_size=n_latent)(embedding)
  latent_log_var = snt.Linear(output_size=n_latent)(embedding)

  # Clip the log-variance to [-max_val, max_val] so `tf.exp` cannot overflow
  # or underflow. The lower bound used to be 1e-5, but it was applied to the
  # *log*-variance, which forced the variance to be >= exp(1e-5) ~= 1 and so
  # stopped the posterior from ever becoming narrower than unit variance.
  max_log_var = tf.cast(max_val, tf.float32)
  latent_log_var = tf.clip_by_value(latent_log_var,
                                    clip_value_min=-max_log_var,
                                    clip_value_max=max_log_var)

  latent_var = tf.exp(latent_log_var)

  return latent_mu, latent_var

def decode(latent_values: tf.Tensor, 
           model_name: Text = 'mlp_vae',
           batch_size: int = batch_size):
  """Decodes latent samples into image-shaped reconstructions.

  Args:
    latent_values: Samples from the latent space defined by the encoder's
      mean and variance.
    model_name: Currently unused; only the MLP decoder is implemented.
    batch_size: Unused; kept so existing callers keep working. The batch
      dimension is inferred from `latent_values` instead, because the default
      here is bound once at definition time and would go stale if the global
      `batch_size` were changed later.

  Returns:
    The decoder output reshaped to `[batch, 28, 28, 1]`.
  """
  mlp_decoder = snt.Sequential([
      snt.nets.MLP([32, 64, 784]),
      # -1 lets TensorFlow infer the batch dimension from the input.
      lambda x: tf.reshape(x, shape=[-1, 28, 28, 1])])

  reconstruction = mlp_decoder(latent_values)

  return reconstruction


## Loss

### The variational lower bound  $\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$ on the marginal log-likelihood $\log p_\text{model}(\mathbf{\mathbf{x}^{(i)}};\theta)$


The marginal log-likelihood under our model for a single training sample $\mathbf{x}^{(i)}$ can be rewritten as:
\begin{align}
\log p(\mathbf{x}^{(i)}; \mathbf{\theta}) \quad
&= \quad \mathbb{E}_{q(z|x)}\left[\log p_\theta(\mathbf{x}^{(i)})\right] \\
&= \quad \mathbb{E}_{q(z|x)}\left[\log \frac{p_\theta(\mathbf{x}^{(i)},\mathbf{z})}{p_\theta(\mathbf{z}\mid \mathbf{x}^{(i)})}\right] \\
&= \quad \mathbb{E}_{q(z|\mathbf{x})}\left[\log \frac{p_\theta(\mathbf{x}^{(i)}, \mathbf{z})\ q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})}{q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})\ p_\theta(\mathbf{z}\mid \mathbf{x}^{(i)})} \right]\\
&= \quad \mathbb{E}_{q(z|\mathbf{x})}\left[\log \frac{p_\theta(\mathbf{x}^{(i)},\mathbf{z})}{q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})} + \log \frac{q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})}{p_\theta(\mathbf{z}\mid \mathbf{x}^{(i)})}\right]\\
&= \quad \mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\left[\log \frac{p_\theta(\mathbf{x}^{(i)}, \mathbf{z})}{q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})}\right] +  \mathbb{E}_{q(z|\mathbf{x})}\left[\log \frac{q_\phi(\mathbf{z} \mid \mathbf{x}^{(i)})}{ p_\theta(\mathbf{z}\mid \mathbf{x}^{(i)})}\right] \\
&= \underbrace{\mathcal{L}(\theta, \phi ; \mathbf{x}^{(i)})}_\text{Var. lower bound on marg. log-likelihood of $\mathbf{x}^{(i)}$} + \qquad   \underbrace{D_{KL} (q_\phi(z \mid \mathbf{x}^{(i)}) \Vert p_\theta(\mathbf{z}\mid \mathbf{x}^{(i)}))}_\text{KL bw approximate and 'true' model posterior ($\theta$-dependent)} 
\end{align}

We can write the evidence lower bound (ELBO) as:
\begin{align}
\log p_\theta(\mathbf{x}^{(i)}) 
&\geq \quad \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) \\
&= \quad \mathbb{E}_{q(z\mid\mathbf{x}^{(i)})}\left[\log \frac{p(\mathbf{x}^{(i)}, z)}{q(z|\mathbf{x}^{(i)})}\right] \\
&= \quad \mathbb{E}_{q_\phi (z\mid \mathbf{x}^{(i)})}\big\{\log p_\theta (\mathbf{x}^{(i)}, z) -\log q_\phi (z \mid \mathbf{x}^{(i)}) \big\} \\
&= \quad \underbrace{\mathbb{E}_{q_\phi (z\mid \mathbf{x}^{(i)})}\big\{\log p_\theta (\mathbf{x}^{(i)}, z) \big\}}_\text{Exp. complete data log-likelihood} \qquad +  \underbrace{H(q_\phi (z\mid \mathbf{x}^{(i)}))}_\text{Entropy of approx. posterior $q$}\\
\end{align}


For our purposes, we will rewrite this equation in a form that best reflects the way VAEs compute an estimate of the (negative) ELBO:
\begin{align}
\boxed{- \mathcal{L}(\theta, \phi; x^{(i)}) = \qquad \underbrace{- \mathbb{E}_{q_\phi(z \mid x^{(i)})} \big\{\log p_\theta (x^{(i)}\mid z) \big\}}_{\text{expected reconstruction error}}  \qquad + \underbrace{D_{KL}(q_\phi (z\mid x^{(i)}) \Vert p_\theta(z))}_{\text{regularizes $\phi$ so approx. posterior $q_\phi$ is close to prior $p_\theta$}}}
\end{align}

(This form is convenient because it is equivalent to maximising the expectation pictured above, but taken with respect to $q_\phi(z \mid x)$ instead of $p_\theta(z)$, with a KL correction term that penalises how far $q_\phi(z \mid x)$ is from $p_\theta(z)$.)




### Estimating the ELBO $\mathcal{L}(\theta, \phi; x^{(i)})$ and its gradients
* VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood, which is an expectation taken against $q_\phi (z'\mid x)$. In practice we often only compute the single sample Monte Carlo estimate of this expectation:
\begin{align}
&\mathbb{E}_{q_\phi (\mathbf{z}^{(i)}\mid \mathbf{x}^{(i)})}\big\{\log p_\theta (\mathbf{x}^{(i)}, \mathbf{z}^{(i)}) -\log q_\phi (\mathbf{z}^{(i)} \mid \mathbf{x}^{(i)}) \big\} \\
\approx \quad &\frac{1}{L}\sum_{l = 1}^{L} \log p_\theta (\mathbf{z}^{(i,l)}) + \log p_\theta (\mathbf{x}^{(i)} \mid \mathbf{z}^{(i,l)}) - \log q_\phi (\mathbf{z}^{(i,l)} \mid \mathbf{x}^{(i)}) \qquad \mathbf{z}^{(i,l)}  \sim q_\phi (\mathbf{z}\mid \mathbf{x}^{(i)}) \\
\approx \quad &\log p_\theta (\mathbf{z}^{(i,l)}) + \log p_\theta (\mathbf{x}^{(i)} \mid \mathbf{\mathbf{z}}^{(i,l)}) - \log q_\phi (\mathbf{z}^{(i,l)} \mid \mathbf{x}^{(i)}) \qquad \mathbf{z}^{(i,l)}  \sim q_\phi (\mathbf{z} \mid \mathbf{x}^{(i)})
\end{align}

If the KL term can be integrated analytically, we need only compute an estimate of the expected reconstruction error.
  * The single sample Monte Carlo estimate of the expected reconstruction error is:
$$-\log p_\theta(\mathbf{x}^{(i)} \mid z') \qquad z'\sim q_{\phi}(z \mid \mathbf{x}^{(i)})$$

## Training

### Training our model with the ELBO Loss

In [0]:
# Record type bundling a batch of original images, their reconstructions and
# the latent samples the reconstructions were decoded from.
ImageReconstructions = namedtuple(
    'ImageReconstructions',
    ('true_images', 'recon_images', 'latent_values'),
)

In [0]:
#@title Our train function

# Build the VAE graph from scratch, train it, then fetch one batch of
# reconstructions and decode its latents a second time via a placeholder.
tf.reset_default_graph()

# Get our image data
input_data = get_data(dataset_name, batch_size)

latent_mu, latent_var = encode(input_data, model_name)

latent_sigma = tf.sqrt(latent_var)

# One independent noise draw per example and per latent dimension.
noise = tf.random_normal(shape=latent_mu.shape)

print('Latent_sigma.shape: {}, Latent_mu.shape: {}'.format(latent_sigma.shape,
                                                          latent_mu.shape))
#@TODO(jarryd@google.com): CHECK THIS HAS A BATCH SIZE!!, you should be 
# sampling for each, not trying to get the whole batch to fit to a sample

# Reparameterisation trick: z = mu + sigma * eps. `latent_sigma` is already
# the standard deviation (sqrt of the variance), so it must not be halved
# again; the extra 0.5 made the samples inconsistent with the KL term below.
latent_value = noise * latent_sigma + latent_mu

output = decode(latent_value)

# The reconst. loss term in the ELBO loss function of our VAE
print('Output Shape: {}'.format(output.shape))
reconstruction_loss = tf.reduce_mean((output - input_data) ** 2)

# The KL term in the ELBO loss function of our VAE: closed-form KL between
# the diagonal Gaussian N(mu, var) and the standard normal prior.
kl_loss = 0.5 * (latent_var + latent_mu ** 2
                 - tf.log(latent_var) - 1)  # [B, L]
kl_loss = tf.reduce_sum(kl_loss, axis=-1)  # [B]
kl_loss = tf.reduce_mean(kl_loss)  # []

# This is the ELBO loss for which we will compute gradients
loss_op = reconstruction_loss + kl_loss

# Pick an optimizer; Adam is a common choice that often works well
optimizer = tf.train.AdamOptimizer(learning_rate)

# This op computes and applies our gradients for the batch (takes one sgd step)
sgd_op = optimizer.minimize(loss_op)


num_samples = batch_size

# Sampling stuff, generative stuff: decode latents supplied through a
# placeholder rather than produced by the encoder.
sample_latents = tf.placeholder(tf.float32, shape=(num_samples, n_latent))
sample_images = decode(sample_latents, batch_size=num_samples)

#@TODO(jarryd@google.com): You need to get the real latent for EVERY image that 
# you're outputting 

# We need an open session to run our sgd and loss ops. It is deliberately left
# open because later cells reuse `sess`.
sess = tf.Session()
sess.run(tf.global_variables_initializer()) # Initialize our variables

results = []
images = {}
print('----------------------------------------------')
print('Training {} model on the {} dataset...'.format(model_name, dataset_name))
# Integer interval (at least 1) so the modulo test below is exact.
print_every = max(1, num_training_steps // 10)

# Training Loop
for step in range(num_training_steps + 1):
  sess.run(sgd_op)

  if step % log_every == 0:
    # This is a separate `sess.run`, so the losses are evaluated on a new
    # batch drawn from the iterator, not the batch just trained on.
    loss, kl, recon_loss = sess.run([loss_op, kl_loss, reconstruction_loss])
    result = {
        'step': step,
        'loss': loss,
    }
    results.append(result)

  if step % print_every == 0:
    # Prints the most recently logged values.
    print('Iteration: {}, Loss: {}, KL: {}, Recon: {}'.format(step,
                                                              loss,
                                                              kl, 
                                                              recon_loss))

# Fetch inputs, latents and reconstructions in ONE run so they all belong to
# the same batch.
true_images, latent_values, recons = sess.run([input_data, latent_value, output])
# Was `latents.shape`, a name that is never defined.
print(latent_values.shape)
images = ImageReconstructions(true_images=true_images.squeeze(-1), 
                              recon_images=recons.squeeze(-1),
                              latent_values=latent_values)

# Pass the tensor itself, not a one-element list: `sess.run([t])` returns a
# Python list, which has no `.squeeze()`.
ims = sess.run(sample_images,
               feed_dict={sample_latents: images.latent_values}).squeeze()


In [0]:
# Decode the stored latent samples and drop any singleton axes.
decoded = sess.run(sample_images,
                   feed_dict={sample_latents: images.latent_values})
ims = decoded.squeeze()


ims.shape
fig, axarray = plt.subplots(2, 4)
# Fill the 2x4 grid in row-major order with the first eight decoded images.
for i, axis in enumerate(axarray.flat):
  axis.imshow(ims[i])
    
    
    


In [0]:
#@title Collecting our results

# Datasets and models to sweep over, plus the settings used for each run.
datasets = ['mnist', 'fashion_mnist']
models = ['mlp_vae']
batch_size = 25
num_steps = 2000

# Accumulators: one long-format table of losses, and the reconstructions
# keyed first by dataset and then by model.
results = pd.DataFrame()
images = {}

# NOTE(review): `train_vae` is not defined in the visible part of this
# notebook; confirm where it comes from before running this cell.
for dataset in datasets:
  per_model_images = {}
  images[dataset] = per_model_images

  for model in models:
    run_frame, run_images = train_vae(dataset,
                                      model,
                                      num_training_steps=num_steps,
                                      batch_size=batch_size)
    # Tag the rows so runs can be told apart after concatenation.
    run_frame['dataset'] = dataset
    run_frame['model'] = model

    results = pd.concat([results, run_frame])
    per_model_images[model] = run_images

## Plotting the training loss and viewing the reconstructions

In [0]:
#@title Plotting the training loss for each model/dataset

# Training loss against step: one panel per dataset, one coloured line per
# model. The x-axis limit follows `num_steps` (set alongside `results`) rather
# than a hard-coded 2000, so the plot stays correct if the run length changes.
p = (gg.ggplot(results)
     + gg.aes(x='step', y='loss', color='model')
     + gg.theme(figure_size=(8, 4))
     + gg.facet_wrap('dataset')
     + gg.lims(x=(0, num_steps))
     + gg.geom_line()
    )

# Bare expression so the notebook renders the figure.
p

In [0]:
#@title Viewing the Generated Images

# Show one reconstruction per (dataset, model) pair, side by side.
num_combinations = len(datasets) * len(models)
# `squeeze=False` always returns a 2-D array of axes, so indexing also works
# when there is only a single dataset/model combination.
fig, axarray = plt.subplots(1, num_combinations, squeeze=False)
axarray = axarray[0]  # The single row of axes.
fig.set_figwidth(14)
fig.set_figheight(7)

column = 0
# `randint(high)` draws from [0, high); using `batch_size - 1` would make the
# last image of each batch unreachable. The same index is used for every panel.
rand_idx = np.random.randint(batch_size)

for dataset in datasets:
  for model in models:
    axarray[column].set_title('Dataset: {}, Model: {}'.format(dataset, model))
    axarray[column].imshow(images[dataset][model].recon_images[rand_idx])
    column += 1


In [0]:
# Show a randomly chosen original image followed by its reconstruction.
# NOTE(review): this reads `images.true_images`, so it expects `images` to be
# an `ImageReconstructions` tuple rather than the per-dataset dict built in
# the results-collection cell; confirm the intended cell order.
# The index range comes from the data itself: the previous hard-coded
# `randint(99)` assumed 100 images and could never select the last one.
idx = np.random.randint(len(images.true_images))

plt.imshow(images.true_images[idx])
plt.show()
plt.imshow(images.recon_images[idx])

## Playground


In [0]:
# Bind the current `images` object to a second name (presumably so it is not
# lost when `images` is rebound by re-running another cell; TODO confirm).
kept_images = images

In [0]:
# except tf.errors.InvalidArgumentError:
#   var = tf.clip_by_value(tf.exp(log_var), 
#                          clip_value_min=0., 
#                          clip_value_max=tf.float32.max)