# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer).

In [None]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

import torchvision
import torchvision.transforms as transforms

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

As an illustration, we use the MNIST data set of handwritten digits: 28x28 single-channel images, each labelled with a digit from 0 to 9.
Every image is flattened into a 784-dimensional vector before it is given to the model.
All the data is mapped to a single Gaussian in the latent space.

In [None]:
# Directory holding the MNIST files.
root = './data'
# Switch to True if building `train_set` below complains about missing files.
download = False

# Every image is flipped vertically (p=1.0 makes the "random" flip unconditional)
# and then converted to a tensor.
# NOTE(review): the flip is presumably there so the digits appear upright once
# drawn with bokeh's `image` glyph -- confirm.
trans = transforms.Compose([
    transforms.RandomVerticalFlip(p=1.0),
    transforms.ToTensor(),
])
train_set = torchvision.datasets.MNIST(
    root=root, train=True, transform=trans, download=download)
test_set = torchvision.datasets.MNIST(
    root=root, train=False, transform=trans)

batch_size = 256

# Training batches are shuffled; test batches keep the data set order.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

In [None]:
# Pull one mini-batch from the training loader: X holds the images and t the
# matching labels. The loader shuffles, so re-running this cell generally
# yields a different batch.
sample_batch = next(iter(train_loader))
X, t = sample_batch

In [None]:
# Preview the first sqrt_nb**2 images of the batch on a sqrt_nb x sqrt_nb grid.
sqrt_nb = 5
grid_extent = sqrt_nb*28
fig = figure(x_range=[0, grid_extent], y_range=[0, grid_extent])
for idx in range(sqrt_nb**2):
    row, col = divmod(idx, sqrt_nb)
    # Row 0 is placed at the top of the figure. Each tile is drawn 27 units
    # wide inside a 28-unit slot, which leaves a thin gap between neighbours.
    fig.image(image=[X[idx][0].numpy()], x=col*28, y=(sqrt_nb-row-1)*28, dw=27, dh=27)
show(fig)

# Labels of the plotted images, arranged like the grid.
print(t[:sqrt_nb**2].view(sqrt_nb, sqrt_nb))


## Model

We build a VAE with a Gaussian distribution in the latent space and a Bernoulli distribution on each individual pixel in the observed space.

In [None]:
# Size of one flattened 28x28 image.
observed_dim = 28 ** 2
# Two latent dimensions, so the latent space can be drawn directly on a 2-D plot.
latent_dim = 2

In [None]:
hidden_dim = 400


def _tanh_layer(input_dim):
    """Return one Linear(input_dim -> hidden_dim) layer followed by a Tanh."""
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, hidden_dim),
        torch.nn.Tanh(),
    )


# Encoder prototype: hidden layer over the flattened image, combined with a
# diagonal-Normal builder of size latent_dim.
enc_nn = _tanh_layer(observed_dim)
latent_dist_builder = beer.NormalDiagBuilder(latent_dim)
enc_proto = beer.MLPModel(enc_nn, hidden_dim, latent_dist_builder)

# Decoder prototype: hidden layer over the latent vector, combined with a
# Bernoulli builder of size observed_dim (one value per pixel).
dec_nn = _tanh_layer(latent_dim)
obs_dist_builder = beer.BernoulliBuilder(observed_dim)
dec_proto = beer.MLPModel(dec_nn, hidden_dim, obs_dist_builder)

In [None]:
import copy


def _initial_normal_gamma():
    """Build the Normal-Gamma object used to initialise both prior and posterior."""
    return beer.NormalGammaPrior(torch.zeros(latent_dim), torch.ones(latent_dim), 1.)


# Model of the latent space; prior and posterior start from identical settings.
latent_normal = beer.NormalDiagonalCovariance(
    prior=_initial_normal_gamma(),
    posterior=_initial_normal_gamma(),
)

# The VAE works on deep copies, so the prototypes themselves stay untouched and
# re-running this cell yields a model built from the same prototypes again.
vae = beer.models.VAE(
    copy.deepcopy(enc_proto),
    copy.deepcopy(dec_proto),
    latent_normal,
    nsamples=1,
)

# Mean ELBO of every optimisation step; filled in by train().
mean_elbos = []

In [None]:
def train(nb_epochs):
    """Run `nb_epochs` passes over the training data, maximising the ELBO.

    The function works on notebook-level objects rather than arguments:
    `train_loader` provides the mini-batches, `vae` is the model being
    trained, `optim` is the optimiser that updates it and `mean_elbos` is
    the list to which the mean ELBO of every step is appended.

    Args:
        nb_epochs: number of complete passes over `train_loader`. With 0
            nothing is run and nothing is printed.
    """
    for epoch in range(nb_epochs):
        for X, _ in train_loader:
            # Flatten every 28x28 image into one row. The tensor is passed as
            # is: the `torch.autograd.Variable` wrapper is deprecated and
            # plain tensors already support autograd.
            X = X.view(-1, 28**2)
            elbo = vae.forward(X)
            # Optimisers minimise, so the objective is the negated mean ELBO.
            obj = -elbo.mean()
            mean_elbos.append(-obj.item())
            optim.zero_grad()
            obj.backward()
            optim.step()
        print("epoch {} done, last ELBO: {}".format(epoch, mean_elbos[-1]))

# A reasonable training procedure: Adam over the encoder and decoder weights.
# Only those two parameter lists are handed to the optimiser.
encoder_params = list(vae.encoder.parameters())
decoder_params = list(vae.decoder.parameters())
optim = torch.optim.Adam(encoder_params + decoder_params, lr=1e-3)
train(3)

# Learning curve: mean ELBO recorded at every optimisation step.
steps = np.arange(len(mean_elbos))
fig = figure(
    title='ELBO', width=400, height=400,
    x_axis_label='step', y_axis_label='ln p(X)',
)
fig.line(steps, mean_elbos, legend='ELBO', color='blue')
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
sqrt_nb = 15  # number of decoded samples along each latent axis

# Half-width of the square explored in latent space: 4 covers -4 .. +4 on both
# axes. The output of the next cell helps to choose a suitable value.
latent_range = 4

# sqrt_nb positions span 2*latent_range with both end points included, hence
# sqrt_nb-1 intervals between them.
latent_step = 2*latent_range / (sqrt_nb-1)
latent_positions = [-latent_range + i*latent_step for i in range(sqrt_nb)]

# Each decoded image is a tile centred on its latent position, so the axes
# extend half a step beyond the outermost positions.
half_step = latent_step/2
complete_range = [-latent_range-half_step, latent_range+half_step]
fig = figure(x_range=complete_range, y_range=complete_range)

# Same visiting order as two nested loops: ly outermost, lx innermost.
latent_grid = [(lx, ly) for ly in latent_positions for lx in latent_positions]
for lx, ly in latent_grid:
    latent_repre = torch.Tensor([lx, ly])
    decoded = vae.decoder(torch.autograd.Variable(latent_repre))
    # `mu` of the decoder output is reshaped back into a 28x28 image.
    image = decoded.mu.view(28, 28).data
    fig.image(
        image=[image.numpy()],
        x=lx-half_step, y=ly-half_step,
        dw=latent_step, dh=latent_step,
    )
show(fig)


In [None]:
# Embed the whole training set: for every image keep the mean of the encoder's
# output distribution, together with the image's label.
latent_images = []
ts = []
# Inference only. Under no_grad() autograd records nothing, so no computation
# graph is kept alive for each of the batches while the list fills up.
with torch.no_grad():
    # Dedicated loop names: the notebook-level preview batch `X, t` is left
    # intact, so the preview cell can still be re-run after this one.
    for batch_X, batch_t in train_loader:
        flat_X = batch_X.view(-1, 28**2)
        latent_images.append(vae.encoder(flat_X).mean)
        ts.append(batch_t)

latent_images = torch.cat(latent_images)
ts = torch.cat(ts)
# Centre of the embedded data; useful when picking `latent_range`.
print(latent_images.mean(dim=0))

#           0      1       2          3        4       5        6         7        8         9
colors = ['red', 'blue', 'green', 'purple', 'black', 'cyan', 'yellow', 'brown', 'violet', 'olive']
fig = figure(
    title='p(X)', width=400, height=400,
    x_range=[-latent_range, latent_range], y_range=[-latent_range, latent_range]
)
for i in range(10):  # plot each digit separately, one colour per digit
    mask = (ts == i).nonzero().view(-1)
    selection = latent_images[mask]
    fig.circle(selection[:, 0].numpy(), selection[:, 1].numpy(), color=colors[i], size=0.4)
show(fig)