# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer).

In [18]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

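As an illustration, we generate a synthetic data set composed of two normally distributed clusters.
Both clusters have full (non-diagonal) covariance matrices: the coordinates of the first one are negatively correlated, those of the second one positively correlated.
The two clusters lie close to each other, so it is reasonable to map all the data to a single Gaussian in the latent space.
The next cell also builds an alternative, banana-shaped data set (`data_quad`: a Gaussian whose second coordinate is shifted by a parabola of the first); the assignment to `data` near the end of that cell selects which data set is used.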

In [20]:
# Seed every random number generator used in this notebook so that
# "Restart & Run All" reproduces the same data, initial weights and curves.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# First cluster: negatively correlated coordinates. A 2x2 covariance
# [[a, c], [c, b]] is positive definite only if |c| < sqrt(a * b); here
# sqrt(.25 * 2) ~= 0.707, so an off-diagonal term of -1 would not be a valid
# covariance matrix and numpy would not sample from the written distribution.
mean1 = np.array([-3, 3])
cov1 = np.array([[.25, -.6], [-.6, 2.]])
data1 = np.random.multivariate_normal(mean1, cov1, size=200)

# Second cluster: positively correlated coordinates.
mean2 = np.array([3, 2.5])
cov2 = np.array([[2, 1], [1, .75]])
data2 = np.random.multivariate_normal(mean2, cov2, size=200)

# Two-cluster data set.
data_gmm_2 = np.vstack([data1, data2])

# Alternative "banana" data set: a Gaussian whose second coordinate is bent
# by a parabola of the first one.
mean3 = np.array([1, 2])
cov3 = np.array([[2, 0], [0, 0.3]])
data_3 = np.random.multivariate_normal(mean3, cov3, size=100)
data_3[:, 1] = data_3[:, 1] + (data_3[:, 0] - mean3[0])**2
data_quad = data_3

# Select the data set to model: 'gmm_2' or 'quad'. The training schedule used
# later is tuned per data set, so keep the two choices in sync.
datasets = {'gmm_2': data_gmm_2, 'quad': data_quad}
dataset_name = 'gmm_2'

# Shuffle a copy so that the source arrays above stay untouched.
data = datasets[dataset_name].copy()
np.random.shuffle(data)

  after removing the cwd from sys.path.


In [21]:
# Square plotting window, identical on both axes, that contains
# mean +/- 3 * std on each coordinate, where std is the standard deviation
# of the most spread-out coordinate.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(max(var))
half_width = 3 * std_dev
lows, highs = mean - half_width, mean + half_width
x_range = (lows[0], highs[0])
y_range = (lows[1], highs[1])
global_range = (min(lows[0], lows[1]), max(highs[0], highs[1]))

fig = figure(title='Data', width=400, height=400,
             x_range=global_range, y_range=global_range)
fig.circle(x=data[:, 0], y=data[:, 1])
show(fig)

In [22]:
from torch.autograd import Variable

# Training inputs as a float32 tensor. `Variable` has been a deprecated no-op
# wrapper since PyTorch 0.4 (tensors carry autograd state themselves), so the
# tensor is used directly; the import above is kept in case other code in the
# notebook still references `Variable`.
X = torch.from_numpy(data).float()

In [23]:
# Dimension of the observations, taken from the selected data set so it stays
# correct if a data set with another number of columns is chosen.
observed_dim = data.shape[1]
# Dimension of the latent space (2 so that it can be visualised directly).
latent_dim = 2

In [24]:
# Width of the single hidden layer used by both the encoder and the decoder.
hidden_dim = 100


def hidden_layer(in_dim):
    """Return a Linear(in_dim -> hidden_dim) layer followed by a LeakyReLU."""
    return torch.nn.Sequential(
        torch.nn.Linear(in_dim, hidden_dim),
        torch.nn.LeakyReLU(),
    )


# Encoder prototype: observations -> hidden features -> distribution over the
# latent space (MLPNormalDiag; presumably a diagonal-covariance Normal, per its name).
enc_nn = hidden_layer(observed_dim)
enc_proto = beer.models.MLPNormalDiag(enc_nn, latent_dim)

# Decoder prototype: latent codes -> hidden features -> distribution over the
# observed space. Built after the encoder to keep the same initialisation order.
dec_nn = hidden_layer(latent_dim)
dec_proto = beer.models.MLPNormalDiag(dec_nn, observed_dim)

In [28]:
# Build a fresh VAE from deep copies of the encoder/decoder prototypes: the
# prototypes themselves are never trained, so re-running this cell restarts
# training from the same initial weights. (`copy` is imported in the first cell.)
# `latent_normal` is the latent prior; by its name a fixed isotropic Gaussian.
# NOTE(review): `nsamples=5` presumably sets the number of latent samples per
# data point in the ELBO estimate — confirm against beer's VAE.
latent_normal = beer.models.FixedIsotropicGaussian(latent_dim)
vae = beer.models.VAE(copy.deepcopy(enc_proto), copy.deepcopy(dec_proto), latent_normal, nsamples=5)

# Training history, one entry per optimisation step; reset with the model.
mean_elbos = []
mean_klds = []
mean_llhs = []

In [29]:
def train(nb_epochs):
    """Run `nb_epochs` gradient steps on the global `vae` using all of `X`.

    At every step the batch means of the ELBO, the KL-divergence term and the
    log-likelihood term are appended to the global lists `mean_elbos`,
    `mean_klds` and `mean_llhs`; parameters are then updated with the global
    optimiser `optim` (rebind `optim` to change the learning rate).
    """
    for _ in range(nb_epochs):
        optim.zero_grad()
        forward_out = vae.forward(X)
        neg_elbo, llh, kld = vae.loss(X, forward_out)
        loss = neg_elbo.mean()
        mean_elbos.append(-loss.item())
        mean_llhs.append(llh.mean().item())
        mean_klds.append(kld.mean().item())
        loss.backward()
        optim.step()

# Learning-rate schedules as (learning rate, number of steps) phases. A fresh
# SGD optimiser is bound to the global `optim` for each phase, which is the
# optimiser `train` uses.
schedules = {
    # Reasonable procedure for the "GMM" data.
    'gmm_2': [(5e-3, 10_000), (1e-3, 5_000)],
    # An unreasonable procedure for the "quadratic" data.
    'quad': [(1e-5, 1_000), (1e-3, 2_000), (2e-3, 2_000), (1e-3, 2_500), (3e-4, 2_500)],
}

# Pick the schedule matching the data set selected when generating the data.
for lr, nb_steps in schedules['gmm_2']:
    optim = torch.optim.SGD(vae.parameters(), lr=lr)
    train(nb_steps)

# `train` appends to the three history lists together, so they share one x-axis.
steps = np.arange(len(mean_elbos))
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
             y_axis_label='batch mean')
# `legend_label` replaces the `legend` keyword, which Bokeh 3 no longer accepts.
fig.line(steps, mean_elbos, legend_label='ELBO', color='blue')
fig.line(steps, mean_klds, legend_label='KLD', color='red')
fig.line(steps, mean_llhs, legend_label='LLH', color='black')
fig.legend.location = 'bottom_right'

show(fig)

In [30]:
# Evaluate the trained model on a regular grid covering the plotting window
# and show exp(-neg_elbo) = exp(ELBO) as an image under the data points.
# exp(ELBO) is (an estimate of) a lower bound on p(x), not p(x) itself.
resolution = 0.1
grid = np.mgrid[global_range[0]:global_range[1]:resolution,
                global_range[0]:global_range[1]:resolution]  # shape (2, nx, ny)
# Read the grid size from the grid itself instead of recovering it via sqrt.
single_dim_nb_points = grid.shape[1]
xy = torch.from_numpy(grid.reshape(2, -1).T).float()

# Pure evaluation: no_grad avoids building an autograd graph over the grid.
with torch.no_grad():
    sth = vae.forward(xy, sampling=False)
    neg_elbos, llhs, klds = vae.loss(xy, sth)

# After the view, rows index x and columns index y; transpose so that image
# rows run along y and columns along x, as bokeh's `image` expects.
p_x_mtx = (-neg_elbos).view(grid.shape[1], grid.shape[2]).t().exp()
p_x_mtx = p_x_mtx.detach().numpy()

fig = figure(title='p(X)', width=400, height=400,
             x_range=global_range, y_range=global_range)
plane_size = global_range[1] - global_range[0]
fig.image(image=[p_x_mtx], x=global_range[0], y=global_range[0], dw=plane_size, dh=plane_size)
fig.circle(data[:, 0], data[:, 1])

show(fig)