# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

As an illustration, we generate a synthetic data set composed of two normally distributed clusters.
One has a diagonal covariance matrix whereas the other has a dense covariance matrix.
Those two clusters overlap so it is reasonable to map all the data to a single Gaussian in the latent space.

In [2]:
# Fix the seed of NumPy's global random generator so that a fresh
# "Restart & Run All" regenerates exactly the same data set.
SEED = 0
np.random.seed(SEED)

# First cluster: diagonal covariance matrix.
mean = np.array([-1.5, 4])
cov = np.array([[.75, 0], [0, 2.]])
data1 = np.random.multivariate_normal(mean, cov, size=200)

# Second cluster: dense (full) covariance matrix.
mean = np.array([3, 2.5])
cov = np.array([[2, 1], [1, .75]])
data2 = np.random.multivariate_normal(mean, cov, size=200)

# Merge both clusters into the final data set, then shuffle the rows
# in place so that the two clusters are interleaved.
data = np.vstack([data1, data2])
np.random.shuffle(data)

In [3]:
# Centre the figure on the data mean and size it from the largest
# per-axis variance, spanning two standard deviations on each side.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(var.max())
half_width = 2 * std_dev
x_range = (mean[0] - half_width, mean[0] + half_width)
y_range = (mean[1] - half_width, mean[1] + half_width)

# Use one common range for both axes so they share the same scale.
lower = min(x_range[0], y_range[0])
upper = max(x_range[1], y_range[1])
global_range = (lower, upper)

fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)
fig.circle(data[:, 0], data[:, 1])

show(fig)

In [4]:
# Dimension of the observed data points and of the latent code (both 2).
observed_dim = latent_dim = 2

In [29]:
# Width of the single hidden layer used by both the encoder and the decoder.
hidden_dim = 10

# Encoder network: observed space -> hidden features.
# NOTE(review): MLPNormalDiag presumably adds the output layers producing a
# diagonal Normal of the given dimension -- confirm against beer.models.
enc_nn = torch.nn.Sequential(
    torch.nn.Linear(observed_dim, hidden_dim),
    torch.nn.Tanh(),
)
enc_proto = beer.models.MLPNormalDiag(enc_nn, latent_dim)

# Decoder network: latent space -> hidden features.
dec_nn = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, hidden_dim),
    torch.nn.Tanh(),
)
# The decoder must wrap `dec_nn`. The original passed `enc_nn` here, which
# left `dec_nn` unused and only ran because observed_dim == latent_dim.
dec_proto = beer.models.MLPNormalDiag(dec_nn, observed_dim)

In [30]:
class FixedIdentityGaussian:
    """Diagonal Gaussian with fixed parameters.

    The natural parameters are computed once at construction time and
    returned unchanged on every query.
    """

    def __init__(self, dim):
        """Precompute the natural parameters of a ``dim``-dimensional Gaussian.

        Both the mean vector and the variance vector are filled with ones.
        NOTE(review): a standard-normal prior would use a zero mean --
        confirm that a mean of ones is intended.
        """
        self._np = beer.models._normal_diag_natural_params(
            torch.ones(dim),  # mean
            torch.ones(dim),  # variance
        )

    def expected_natural_params(self, mean, var):
        """Return the precomputed natural parameters and ``None``.

        ``mean`` and ``var`` are ignored (presumably kept to match the
        interface the VAE expects from its latent model -- verify).
        """
        return self._np, None


# Latent model later handed to the VAE constructor.
latent_normal = FixedIdentityGaussian(latent_dim)

In [44]:
# Build the model from deep copies so the encoder/decoder prototypes are left
# untouched and a fresh model can be rebuilt by re-running this cell.
# (`copy` is already imported in the first cell, so it is not re-imported.)
vae = beer.models.VAE(
    copy.deepcopy(enc_proto),
    copy.deepcopy(dec_proto),
    latent_normal,
    nsamples=5,  # presumably the number of latent samples drawn -- verify
)

# History of the mean ELBO; the training loop appends one value per step.
mean_elbos = []

In [45]:
# Import kept only so that any other cell still using `Variable` keeps working.
from torch.autograd import Variable

# `Variable` is deprecated since PyTorch 0.4 (tensors track gradients
# themselves), so the data is used directly as a float32 tensor.
X = torch.from_numpy(data).float()

# Plain SGD over all the parameters exposed by the VAE.
optim = torch.optim.SGD(vae.parameters(), lr=1e-3)

In [48]:
# Run 5000 full-batch updates, recording the mean ELBO at each step.
# `mean_elbos` is extended in place, so re-running this cell continues the
# existing curve instead of restarting it.
for i in range(5000):
    # `sth` is whatever state `vae.forward` returns; it is handed back to
    # `vae.loss` unchanged.
    sth = vae.forward(X)
    neg_elbo, llh, kld = vae.loss(X, sth)
    # Objective to minimise: the negative ELBO averaged over its elements
    # (presumably one per data point -- verify against beer.models.VAE).
    obj = neg_elbo.mean()
    mean_elbos.append(-obj.item())
    optim.zero_grad()
    obj.backward()
    optim.step()

fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
             y_axis_label='ln p(X)')
# The curve tracks the VAE trained above, so label it accordingly (the
# original label 'GMM (diag)' referred to a different model).
fig.line(np.arange(len(mean_elbos)), mean_elbos, legend='VAE', color='blue')
fig.legend.location = 'bottom_right'

show(fig)