# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data
As an illustration, we generate a synthetic data set composed of two normally distributed clusters.
Both clusters have a dense (non-diagonal) covariance matrix: the two dimensions are negatively correlated in the first cluster and positively correlated in the second.
Those two clusters overlap so it is reasonable to map all the data to a single Gaussian in the latent space.

In [2]:
# Fix the seed of NumPy's global generator so that "Restart & Run All"
# regenerates exactly the same data set.
np.random.seed(0)

# First cluster: the two dimensions are negatively correlated.
# A covariance matrix has to be positive semi-definite, i.e. the squared
# off-diagonal term must not exceed the product of the variances
# (here 0.25 * 2 = 0.5). An off-diagonal value of -1 would violate this and
# make `multivariate_normal` emit a RuntimeWarning, hence -0.5.
mean1 = np.array([-3, 3])
cov1 = np.array([[.25, -.5], [-.5, 2.]])
data1 = np.random.multivariate_normal(mean1, cov1, size=100)

# Second cluster: the two dimensions are positively correlated.
mean2 = np.array([3, 2.5])
cov2 = np.array([[2, 1], [1, .75]])
data2 = np.random.multivariate_normal(mean2, cov2, size=100)

# Stack both clusters into one (200, 2) array and shuffle the rows in place
# so that the points of the two clusters are interleaved.
data = np.vstack([data1, data2])
np.random.shuffle(data)

  after removing the cwd from sys.path.


In [3]:
# Per-dimension statistics of the data, used to scale the figure.
mean = data.mean(axis=0)
var = data.var(axis=0)

# The plotted window extends three standard deviations (of the dimension
# with the largest variance) on each side of the data mean.
std_dev = np.sqrt(var.max())
half_width = 3 * std_dev
x_range = (mean[0] - half_width, mean[0] + half_width)
y_range = (mean[1] - half_width, mean[1] + half_width)

# Use one common range for both axes, wide enough to contain both windows.
lower = min(x_range[0], y_range[0])
upper = max(x_range[1], y_range[1])
global_range = (lower, upper)

fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)
fig.circle(data[:, 0], data[:, 1])

show(fig)

In [4]:
from torch.autograd import Variable

# Convert the NumPy data set to a single-precision torch tensor and wrap it
# in a `Variable`; `X` is the training set used by the cells below.
data_tensor = torch.from_numpy(data).float()
X = Variable(data_tensor)

In [5]:
# Dimension of the observed data points and of the VAE latent space.
observed_dim, latent_dim = 2, 2

In [6]:
# Number of units of the hidden layer of the encoder/decoder networks.
hidden_dim = 20


class GaussianMLP(nn.Module):
    """One-hidden-layer network predicting a diagonal-covariance Normal.

    The input goes through a linear layer followed by a leaky ReLU. Two
    separate linear heads then predict the mean and the log-variance.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        """Create the layers.

        Args:
            in_dim (int): Dimension of the input features.
            out_dim (int): Dimension of the predicted mean / variance.
            hidden_dim (int): Number of units of the hidden layer.
        """
        super().__init__()
        self.i2h = nn.Linear(in_dim, hidden_dim)
        self.h2mean = nn.Linear(hidden_dim, out_dim)
        self.h2logvar = nn.Linear(hidden_dim, out_dim)

    def forward(self, X):
        """Map ``X`` to a ``beer.NormalDiagonalCovarianceMLP`` object.

        The object is built from the predicted mean and from the
        exponential of the predicted log-variance, which keeps the
        variance strictly positive.
        """
        hidden = F.leaky_relu(self.i2h(X))
        mu = self.h2mean(hidden)
        variance = torch.exp(self.h2logvar(hidden))
        return beer.NormalDiagonalCovarianceMLP(mu, variance)

In [7]:
# Model of the latent space: a Normal with diagonal covariance whose prior
# and posterior are both Normal-Gamma objects built from a zero vector, a
# one vector and the scalar 1. The vectors are sized with `latent_dim` (not
# a hard-coded 2) so that they always match the output of the encoder.
# NOTE(review): the exact meaning of the three NormalGammaPrior arguments is
# defined by beer and is not visible here -- confirm against the library.
latent_normal = beer.NormalDiagonalCovariance(
    prior=beer.NormalGammaPrior(
        torch.zeros(latent_dim), torch.ones(latent_dim), 1.
    ),
    posterior=beer.NormalGammaPrior(
        torch.zeros(latent_dim), torch.ones(latent_dim), 1.
    )
)

# The first network maps observations to the latent space and the second
# one maps latent points back to the observed space.
# `nsamples` is the number of samples used by the VAE; it is raised later
# in the notebook to obtain a smoother plot.
vae = beer.VAE(
    GaussianMLP(observed_dim, latent_dim, hidden_dim),
    GaussianMLP(latent_dim, observed_dim, hidden_dim),
    latent_normal,
    nsamples=5
)

# History of the per-data-point loss, filled in by `train()`.
mean_elbos = []

In [8]:
def train():
    """Perform one optimisation step on the whole data set.

    Uses the notebook-level ``optim``, ``loss_fn``, ``vae`` and ``X``
    objects. The value of the loss divided by the number of data points is
    appended to the notebook-level ``mean_elbos`` list.
    """
    optim.zero_grad()
    objective = loss_fn(vae, X)
    objective.backward()

    # Record the loss per data point before the parameters are updated.
    per_point = float(objective) / len(X)
    mean_elbos.append(per_point)

    optim.step()
        
# The weights of the encoder and decoder networks are optimised with Adam.
nnet_parameters = [*vae.encoder.parameters(), *vae.decoder.parameters()]
nnet_optim = torch.optim.Adam(nnet_parameters, lr=1e-3)

# NOTE(review): `params` is not used below; the Bayesian optimizer is given
# an empty parameter list and a zero learning rate, so presumably the
# latent model is deliberately left untouched -- confirm.
params = vae.latent_model.parameters
optim = beer.BayesianModelOptimizer([], lrate=0., std_optim=nnet_optim)
loss_fn = beer.StochasticVariationalBayesLoss(len(X))

# reasonable procedure for "GMM" data
n_steps = 10_000
for i in range(n_steps):
    train()

# Plot the values recorded by `train()` against the step index.
steps = np.arange(len(mean_elbos))
fig = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
fig.line(steps, mean_elbos, legend='ELBO', color='blue')
fig.legend.location = 'bottom_right'

show(fig)

In [9]:
import math

# Regular grid of 2-D points covering the plotted square, one point every
# `resolution` units along each axis, stored as one point per row.
resolution = 0.1
lo, hi = global_range
xy = np.mgrid[lo:hi:resolution, lo:hi:resolution].reshape(2, -1).T
xy = Variable(torch.from_numpy(xy).float())

# Both axes share the same range, so the grid is square and its side is the
# square root of the total number of points.
single_dim_nb_points = int(math.sqrt(xy.shape[0]))

# For a smooth plot increase the number of samples.
vae.nsamples = 50

neg_elbos = -vae(xy)

# Put the values back on the 2-D grid. The grid is transposed before being
# exponentiated -- presumably so that rows run along the y axis, as the
# image glyph expects (confirm against bokeh).
grid_values = (-neg_elbos).view(single_dim_nb_points, single_dim_nb_points).t()
p_x_mtx = grid_values.exp().data.numpy()

fig = figure(
    title='p(X)',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)

# Draw the image over the whole square and overlay the (faded) data points.
plane_size = global_range[1] - global_range[0]
fig.image(
    image=[p_x_mtx],
    x=global_range[0],
    y=global_range[0],
    dw=plane_size,
    dh=plane_size,
)
fig.circle(data[:, 0], data[:, 1], alpha=.1)

show(fig)