# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

import torchvision
import torchvision.transforms as transforms

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

As an illustration, we use the MNIST data set of handwritten digits (28x28 grayscale images), loaded with `torchvision`.
Each image is flipped vertically and converted to a tensor; no further normalization is applied.
The images are served in mini-batches of 100, and only the training batches are shuffled.

In [20]:
# Directory holding the MNIST files.
root = './data'
download = False  # set to True if the line "train_set = ..." complains

# Every image is flipped vertically (p=1.0 makes the flip unconditional) and
# then converted to a tensor; no normalisation step follows.
# NOTE(review): the flip presumably compensates for the bottom-up row order
# used when the images are drawn with bokeh below -- confirm.
preprocessing_steps = [
    transforms.RandomVerticalFlip(p=1.0),
    transforms.ToTensor(),
]
trans = transforms.Compose(preprocessing_steps)

train_set = torchvision.datasets.MNIST(
    root=root, train=True, transform=trans, download=download)
test_set = torchvision.datasets.MNIST(
    root=root, train=False, transform=trans)

batch_size = 100

# Only the training batches are shuffled; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=False)

nb_train_batches = len(train_loader)
nb_test_batches = len(test_loader)
print('==>>> total trainning batch number: {}'.format(nb_train_batches))
print('==>>> total testing batch number: {}'.format(nb_test_batches))

==>>> total trainning batch number: 600
==>>> total testing batch number: 100


In [21]:
# Pull a single mini-batch (images X, labels t) from the training loader.
batch_iterator = iter(train_loader)
X, t = next(batch_iterator)

In [22]:
# Show the first sqrt_nb**2 images of the batch on a sqrt_nb x sqrt_nb grid.
sqrt_nb = 5
fig = figure(x_range=[0, sqrt_nb*28], y_range=[0, sqrt_nb*28])
for idx in range(sqrt_nb * sqrt_nb):
    # Row-major layout: the first row of images goes at the top of the figure.
    row, col = divmod(idx, sqrt_nb)
    digit = X[idx][0].numpy()
    # Each tile is drawn 27 wide in a 28-wide slot, leaving a thin gap.
    fig.image(image=[digit], x=col*28, y=(sqrt_nb-row-1)*28, dw=27, dh=27)
show(fig)

# Labels of the displayed images, arranged like the grid.
print(t[:sqrt_nb**2].view(sqrt_nb,sqrt_nb))



 6  4  3  0  9
 2  7  5  5  2
 9  8  5  9  9
 8  7  9  6  7
 7  1  3  9  9
[torch.LongTensor of size (5,5)]



## Model

We build a VAE with a Gaussian distribution in the latent space and a Bernoulli distribution on each individual pixel in the observed space.

In [23]:
# Size of one flattened 28x28 image, and dimension of the latent space.
observed_dim = 28 ** 2
latent_dim = 2

In [40]:
hidden_dim = 500


def _tanh_layer(input_dim, output_dim):
    """Return one fully connected layer followed by a tanh non-linearity."""
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, output_dim),
        torch.nn.Tanh(),
    )


# Encoder prototype: observed space -> hidden features, wrapped in beer's
# MLPNormalDiag with a latent_dim-dimensional output.
enc_nn = _tanh_layer(observed_dim, hidden_dim)
enc_proto = beer.models.MLPNormalDiag(enc_nn, latent_dim)

# Decoder prototype: latent space -> hidden features, wrapped in beer's
# MLPBernoulli with an observed_dim-dimensional output.
dec_nn = _tanh_layer(latent_dim, hidden_dim)
dec_proto = beer.models.MLPBernoulli(dec_nn, observed_dim)

In [45]:
# Prior over the latent space (beer's FixedIsotropicGaussian).
latent_normal = beer.models.FixedIsotropicGaussian(latent_dim)

# The VAE receives deep copies of the prototypes, so the prototype objects
# themselves are not the ones being trained. `copy` is imported in the first
# cell of the notebook.
vae = beer.models.VAE(copy.deepcopy(enc_proto), copy.deepcopy(dec_proto), latent_normal, nsamples=5)

# Per-mini-batch training statistics; train() appends to these lists.
mean_elbos = []
mean_klds = []
mean_llhs = []

In [None]:
def train(nb_epochs):
    """Train the notebook-level `vae` for `nb_epochs` passes over `train_loader`.

    The function works on notebook-level state: it reads `vae`, `optim` and
    `train_loader`, and appends one value per mini-batch to the history lists
    `mean_elbos`, `mean_klds` and `mean_llhs`. After each epoch it prints the
    most recent ELBO value.

    Args:
        nb_epochs: number of full passes over the training loader.
    """
    for epoch in range(nb_epochs):
        for X, _ in train_loader:
            # Flatten each sample to a vector. Plain tensors are used
            # directly: the Variable wrapper is not needed with the PyTorch
            # versions that provide `.item()`.
            X = X.view(X.size(0), -1)
            sth = vae.forward(X)
            neg_elbo, llh, kld = vae.loss(X, sth)

            # Objective minimised by the optimizer: batch-averaged -ELBO.
            obj = neg_elbo.mean()

            # Record the statistics as Python floats.
            mean_elbos.append(-obj.item())
            mean_klds.append(kld.mean().item())
            mean_llhs.append(llh.mean().item())

            optim.zero_grad()
            obj.backward()
            optim.step()
        print("epoch {} done, last ELBO: {}".format(epoch, mean_elbos[-1]))

# Plain SGD over all VAE parameters, run for a single epoch.
optim = torch.optim.SGD(vae.parameters(), lr=1e-4)
train(1)

# Plot the three per-step training statistics on one figure.
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
curves = [
    (mean_elbos, 'ELBO', 'blue'),
    (mean_klds, 'KLD', 'red'),
    (mean_llhs, 'LLH', 'black'),
]
for values, label, color in curves:
    steps = np.arange(len(values))
    fig.line(steps, values, legend=label, color=color)
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
# Decode a regular sqrt_nb x sqrt_nb grid of latent points and tile the
# resulting images on one figure.
sqrt_nb = 20
latent_range = 5
half = sqrt_nb / 2.0
# Latent-space step between two neighbouring grid points.
scale = latent_range / half

fig = figure(x_range=[-half * 28, half * 28], y_range=[-half * 28, half * 28])
# Decoding only: no gradients are needed, so the outputs can be converted to
# numpy directly.
with torch.no_grad():
    for i in range(sqrt_nb):
        for j in range(sqrt_nb):
            # First latent coordinate follows the row index i (vertical axis),
            # second one the column index j (horizontal axis).
            latent_repre = torch.Tensor([(i - half) * scale, (j - half) * scale])
            image = vae.decoder(latent_repre).mu.view(28, 28)
            fig.image(image=[image.numpy()], x=(j - half) * 28, y=(i - half) * 28, dw=27, dh=27)
show(fig)


In [None]:
# One colour per digit class (index = label).
colors = ['red', 'blue', 'green', 'purple', 'black', 'cyan', 'yellow', 'brown', 'violet', 'olive']

# Encode one training mini-batch and keep the `mean` of the encoder output
# as the latent position of each image.
X, t = next(iter(train_loader))
X = X.view(X.size(0), -1)
# Encoding only: no gradients are needed.
with torch.no_grad():
    latent_images = vae.encoder(X).mean
print(latent_images.shape)
print(t.numpy())

fig = figure(title='p(X)', width=400, height=400)

for i in range(10):
    # Indices of the images in the batch whose label is i.
    mask = (t == i).nonzero().view(-1)
    selection = latent_images[mask]
    fig.circle(selection[:,0].numpy(), selection[:,1].numpy(), color=colors[i])
show(fig)