# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer).

In [12]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch
import torch.nn as nn

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data
As an illustration, we generate a synthetic data set composed of two normally distributed clusters.
One has a diagonal covariance matrix, whereas the other has a dense covariance matrix.
Because the two clusters overlap, it is reasonable to map all the data to a single Gaussian in the latent space.

In [13]:
# Synthetic 2-D data: a single curved ("banana"-shaped) cluster.
# Seed NumPy's global RNG so that the data set -- and every figure derived
# from it -- is identical after "Restart Kernel -> Run All". The value 13 is
# the same seed this notebook passes to torch.manual_seed for the model.
np.random.seed(13)

# Offset applied to every point.
mean = np.array([3, 3])
# First coordinate: zero-mean Gaussian with standard deviation 1.5.
x1 = np.random.randn(100) * 1.5
# Second coordinate: a quadratic function of the first one (parabola).
x2 = np.power(.75 * x1, 2)
x = np.c_[x1, x2]
# Shift the curve and add isotropic Gaussian noise (standard deviation 0.2).
data = mean + x + .2 * np.random.randn(len(x), 2)

In [14]:
# Build a square plotting window centred on the data mean and extending
# five standard deviations (of the most spread-out axis) on either side.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(var.max())
half_width = 5 * std_dev
x_range, y_range = [(center - half_width, center + half_width) for center in mean]

# A single range shared by both axes keeps the aspect ratio square.
global_range = (
    min(x_range[0], y_range[0]),
    max(x_range[1], y_range[1]),
)

# Scatter plot of the raw data.
fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)
fig.circle(*data.T)

show(fig)

In [15]:
# Wrap the NumPy data in a torch tensor and cast it to single precision.
X = torch.from_numpy(data).to(torch.float32)

In [16]:
# Width of the hidden layers of the encoder / decoder networks.
hidden_dim = 20


class GaussianMLP(nn.Module):
    """MLP with two tanh hidden layers that outputs a diagonal Gaussian.

    The network maps its input to a mean and a log-variance vector. The mean
    head is residual: the input is added to the output of the linear layer,
    so ``in_dim`` must be broadcast-compatible with ``out_dim`` (in this
    notebook both are equal).

    Args:
        in_dim (int): Dimension of the input features.
        out_dim (int): Dimension of the mean / variance vectors.
        hidden_dim (int): Number of units of each hidden layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # NOTE: no stray reference to the ``torch.nn`` package is stored on
        # the instance: a Python module object as attribute is unused and
        # makes the network impossible to deep-copy or pickle.
        # The layers are created in a fixed order so that, for a given torch
        # seed, the random initialisation is reproducible.
        self.i2h = nn.Linear(in_dim, hidden_dim)
        self.h2h = nn.Linear(hidden_dim, hidden_dim)
        self.h2mean = nn.Linear(hidden_dim, out_dim)
        self.nonlinear_transform = nn.Tanh()
        self.h2logvar = nn.Linear(hidden_dim, out_dim)

    def forward(self, X):
        """Compute the Gaussian parameters for a batch of inputs.

        Args:
            X (torch.Tensor): Input batch; its last dimension is expected to
                be ``in_dim``.

        Returns:
            beer.NormalDiagonalCovarianceMLP: Object built from the predicted
            mean and variance.
        """
        h = self.nonlinear_transform(self.i2h(X))
        h = self.nonlinear_transform(self.h2h(h))
        # Residual connection: the network only predicts an offset from X.
        mean = self.h2mean(h) + X
        # The network predicts the log-variance; exponentiating guarantees a
        # strictly positive variance.
        logvar = self.h2logvar(h)
        return beer.NormalDiagonalCovarianceMLP(mean, logvar.exp())

In [22]:
# Fix the torch seed so the network weights are initialised reproducibly.
torch.manual_seed(13)

# Dimensions of the observed space, of the latent space and of the PPCA
# subspace (number of rows of the initial subspace matrix).
observed_dim = 2
latent_dim1 = 2
latent_dim2 = 1

# PPCA model over the latent space.
# NOTE(review): the three arguments are presumably the initial mean, the
# precision and the subspace matrix -- confirm against beer.PPCA.create.
prior_mean = torch.zeros(latent_dim1)
prior_subspace = torch.eye(latent_dim2, latent_dim1)
latent_normal = beer.PPCA.create(prior_mean, 1., prior_subspace)

# Encoder: observed -> latent; decoder: latent -> observed.
encoder = GaussianMLP(observed_dim, latent_dim1, hidden_dim)
decoder = GaussianMLP(latent_dim1, observed_dim, hidden_dim)
vae_ppca = beer.VAE(encoder, decoder, latent_normal, nsamples=5)

In [23]:
# Training configuration.
epochs = 15_000
lrate_bayesmodel = 1.
lrate_encoder = 1e-3
X = torch.from_numpy(data).float()
elbo_fn = beer.EvidenceLowerBound(len(X))

# Encoder and decoder weights are trained with Adam.
nnet_parameters = [
    *vae_ppca.encoder.parameters(),
    *vae_ppca.decoder.parameters(),
]
std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_encoder,
                                 weight_decay=1e-2)

# The parameters of the latent model are handed to beer's optimizer as one
# single group, together with the Adam optimizer for the networks.
latent_model = vae_ppca.latent_model
params = [
    latent_model.mean_param,
    latent_model.subspace_param,
    latent_model.precision_param,
]
optim = beer.BayesianModelCoordinateAscentOptimizer(
    params, lrate=lrate_bayesmodel, std_optim=std_optimizer
)

# Per-data-point training curves.
elbos, klds, llhs = [], [], []
n_points = len(X)
for epoch in range(epochs):
    optim.zero_grad()
    elbo = elbo_fn(vae_ppca, X)
    elbo.backward()
    elbo.natural_backward()
    optim.step()

    # The very first evaluation is not recorded.
    if epoch == 0:
        continue
    elbos.append(float(elbo) / n_points)
    llhs.append(float(elbo.expected_llh) / n_points)
    klds.append(float(elbo.kl_div) / n_points)

# One panel per curve: ELBO on the first row, LLH and KL on the second.
steps = np.arange(len(elbos))
panels = []
for title, values, color in (('ELBO', elbos, 'blue'),
                             ('LLH', llhs, 'green'),
                             ('KL', klds, 'red')):
    panel = figure(title=title, width=400, height=400, x_axis_label='step',
                   y_axis_label='ln p(X)')
    panel.line(steps, values, color=color)
    panels.append(panel)
fig1, fig2, fig3 = panels

show(gridplot([[fig1], [fig2, fig3]]))

In [24]:
# Re-evaluate the ELBO on the training data and display its local KL term.
# NOTE(review): `_local_kl_div` is a private attribute of the object returned
# by `elbo_fn`; presumably it holds one KL value per data point -- confirm
# against beer before relying on it.
elbo_fn(vae_ppca, X)._local_kl_div

tensor([ 3.1855,  2.5780,  2.4840,  2.7410,  6.5716,  2.3982,  2.8314,
         3.3785,  2.6146,  3.5420,  3.0759,  2.7748,  2.3189,  2.7190,
         2.3128,  2.3091,  2.8296,  2.6097,  2.5934,  2.3386,  2.4298,
         2.3287,  2.7397,  2.2879,  2.4677,  2.5909,  2.2962,  2.3399,
         3.6051,  2.3251,  2.3245,  2.6538,  2.3261,  2.3218,  2.5575,
         2.3126,  3.7724,  2.3612,  6.1385,  3.2250,  2.9334,  2.7001,
         3.1910,  2.6570,  2.4180,  2.6491,  3.8918,  2.3260,  2.5413,
         2.4948,  2.3615,  2.6912,  2.5617,  2.4130,  2.3223,  2.4160,
         2.3739,  2.4054,  2.3211,  2.3018,  3.8094,  3.3546,  2.2983,
         2.4065,  2.5738,  3.2421,  2.3598,  2.5908,  2.9743,  2.3376,
         2.3899,  2.3872,  2.3185,  2.4441,  4.1210,  2.3038,  2.5714,
         3.1548,  3.3538,  2.3245,  2.6554,  2.3918,  3.1146,  2.5590,
         3.1102,  2.5097,  3.1923,  2.5466,  4.3300,  2.9861,  2.4893,
         2.5549,  2.3937,  2.7484,  2.6009,  2.3269,  2.5029,  2.9929,
      

In [25]:
import math

# Regular grid of points covering the (square) plotting window.
resolution = 0.1
lo, hi = global_range
grid = np.mgrid[lo:hi:resolution, lo:hi:resolution]
xy = torch.from_numpy(grid.reshape(2, -1).T).float()
# The grid is square, so its side is the square root of the point count.
single_dim_nb_points = int(math.sqrt(xy.shape[0]))

# For a smooth plot increase the number of samples.
vae_ppca.nsamples = 100
# NOTE(review): the model output is treated as a per-point ELBO and is
# exponentiated as a stand-in for p(x) -- confirm against beer.VAE.
grid_elbos = vae_ppca(xy)

# Reshape to an image (transposed so that rows follow the second coordinate).
p_x_mtx = (
    grid_elbos
    .view(single_dim_nb_points, single_dim_nb_points)
    .t()
    .exp()
    .data
    .numpy()
)

fig = figure(title='p(X)', width=400, height=400,
             x_range=global_range, y_range=global_range)

# Density image with the training data overlaid on top.
plane_size = hi - lo
fig.image(image=[p_x_mtx], x=lo, y=lo, dw=plane_size, dh=plane_size)
fig.circle(data[:, 0], data[:, 1], alpha=.1)

show(fig)

In [26]:
# Latent-space view: the PPCA subspace (a line through the latent mean), its
# +/- 1 and 2 standard-deviation noise bands, the standard normal density
# along the subspace, and the encoder means of the training data.
model = vae_ppca.latent_model
enc_state = vae_ppca.encoder(X)
means, dcovs = enc_state.mean.detach().numpy(), enc_state.var.detach().numpy()

# Coordinate along the subspace.
x = np.linspace(-20, 20, 1000)

latent_mean = model.mean.numpy()
# First (and, in this notebook, only) basis vector of the subspace.
direction = model.subspace.numpy()[0, :]
# Standard deviation of the noise around the subspace.
noise_std = np.sqrt(np.ones_like(x) / model.precision.numpy())
# Standard normal density along the subspace coordinate.
p_h = np.sqrt(1 / (2 * np.pi)) * np.exp(-.5 * x ** 2)

# Orientation of the subspace. The direction is first flipped, if needed, to
# point towards positive x (a line has no preferred direction), then arctan2
# keeps the sign of the slope and copes with a vertical subspace. Taking
# arctan of |dy| / |dx| instead would mirror negatively sloped subspaces and
# divide by zero for a vertical one.
dx, dy = float(direction[0]), float(direction[1])
if dx < 0 or (dx == 0 and dy < 0):
    dx, dy = -dx, -dy
angle = np.arctan2(dy, dx)
R = np.array([
    [np.cos(angle), -np.sin(angle)],
    [np.sin(angle), np.cos(angle)]
])


def _to_latent_space(offsets):
    """Map (x, offsets), given in subspace coordinates, to latent coordinates."""
    return np.c_[x, offsets] @ R.T + latent_mean


def _band(offsets):
    """Return the outline of the area between the subspace line and `offsets`."""
    base = _to_latent_space(np.zeros_like(x))
    edge = _to_latent_space(offsets)
    outline_x = np.append(base[:, 0], edge[:, 0][::-1])
    outline_y = np.append(base[:, 1], edge[:, 1][::-1])
    return outline_x, outline_y


# `width` / `height` are used for consistency with the other figures of this
# notebook.
fig1 = figure(width=350, height=350, x_range=(-5, 5),
              y_range=(-5, 5))

# Noise bands: one and two standard deviations on each side of the subspace.
# Overlapping semi-transparent patches make the inner band darker.
for n_std in (1, 2):
    for side in (1, -1):
        band_x, band_y = _band(side * n_std * noise_std)
        fig1.patch(band_x, band_y, line_alpha=0., fill_alpha=0.5,
                   fill_color='LightBlue')

# Standard normal density drawn along the subspace.
band_x, band_y = _band(p_h)
fig1.patch(band_x, band_y, line_color='black', fill_color='LightGreen')

# Encoder means of the training data.
fig1.cross(means[:, 0], means[:, 1], color='red', alpha=.5)

show(fig1)