# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data
As an illustration, we generate a synthetic data set composed of two normally distributed clusters.
Both clusters have a dense (non-diagonal) covariance matrix: the first has a negative correlation between its two dimensions whereas the second has a positive one.
Those two clusters overlap so it is reasonable to map all the data to a single Gaussian in the latent space.

In [6]:
# Draw the training set from a dedicated, seeded generator so that a
# "Restart & Run All" reproduces exactly the same points, independently of
# whatever else consumed the global NumPy random state.
rng = np.random.RandomState(0)

# First cluster.
mean = np.array([-3, 3])
cov = np.array([[1, -1], [-1, 2.]])
data1 = rng.multivariate_normal(mean, cov, size=100)

# Second cluster.
mean = np.array([3, 2.5])
cov = np.array([[2, 1], [1, .75]])
data2 = rng.multivariate_normal(mean, cov, size=100)

# Stack both clusters into a single (200, 2) array ...
data = np.vstack([data1, data2])

# ... and shuffle the rows in place so the two clusters are interleaved.
rng.shuffle(data)

In [7]:
# Per-dimension statistics of the data, used only to choose the axis limits.
mean = data.mean(axis=0)
var = data.var(axis=0)

# Use the largest per-dimension standard deviation for both axes and extend
# each axis three standard deviations on either side of its mean.
std_dev = np.sqrt(var.max())
half_width = 3 * std_dev
x_range, y_range = [(center - half_width, center + half_width)
                    for center in mean[:2]]

# A single square range that contains both the x and the y interval.
lower_bound = min(x_range[0], y_range[0])
upper_bound = max(x_range[1], y_range[1])
global_range = (lower_bound, upper_bound)

# Scatter plot of the data on the square range.
fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)
fig.circle(data[:, 0], data[:, 1])

show(fig)

In [21]:
# Generate some test data from the same two Gaussian clusters.

# Dedicated, seeded generator so the test set is reproducible and does not
# depend on the global NumPy random state. The seed must differ from any seed
# used to draw the training data, otherwise the test set would replicate it.
test_rng = np.random.RandomState(1)

# First cluster.
mean = np.array([-3, 3])
cov = np.array([[1, -1], [-1, 2.]])
test_data1 = test_rng.multivariate_normal(mean, cov, size=100)

# Second cluster.
mean = np.array([3, 2.5])
cov = np.array([[2, 1], [1, .75]])
test_data2 = test_rng.multivariate_normal(mean, cov, size=100)

test_data = np.vstack([test_data1, test_data2])

# Title the figure as test data so it cannot be confused with the
# training-data scatter plot, which uses the same axes.
fig = figure(title='Test data', width=400, height=400,
             x_range=global_range, y_range=global_range)
fig.circle(test_data[:, 0], test_data[:, 1])

show(fig)

In [9]:
# Convert the NumPy arrays (training set, then test set) to float32 tensors.
X, test_X = (torch.from_numpy(array).float() for array in (data, test_data))

In [10]:
# Dimension of the observations and of the latent space, respectively.
observed_dim, latent_dim = 2, 2

In [11]:
# Number of hidden units, passed as `hidden_dim` to the MLPs when the models are built.
hidden_dim = 50

class GaussianMLP(nn.Module):
    """One-hidden-layer MLP that outputs the mean and variance of a Gaussian.

    The input goes through a linear layer followed by a tanh non-linearity;
    two separate linear heads then produce the mean and the log-variance.

    Args:
        in_dim (int): Dimension of the input features.
        out_dim (int): Dimension of the mean and of the variance.
        hidden_dim (int): Number of hidden units.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        self.i2h = nn.Linear(in_dim, hidden_dim)
        self.h2mean = nn.Linear(hidden_dim, out_dim)
        self.h2logvar = nn.Linear(hidden_dim, out_dim)

    def forward(self, X):
        """Return the tuple ``(mean, variance)`` computed from ``X``.

        The variance is the exponential of the log-variance head, which keeps
        it strictly positive.
        """
        # `torch.tanh` replaces the deprecated `torch.nn.functional.tanh`.
        h = torch.tanh(self.i2h(X))
        mean = self.h2mean(h)
        logvar = self.h2logvar(h)
        return mean, logvar.exp()
    
class GaussianMLP2(nn.Module):
    """One-hidden-layer MLP that outputs only a mean (no variance head).

    Args:
        in_dim (int): Dimension of the input features.
        out_dim (int): Dimension of the returned mean.
        hidden_dim (int): Number of hidden units.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        self.i2h = nn.Linear(in_dim, hidden_dim)
        self.h2mean = nn.Linear(hidden_dim, out_dim)

    def forward(self, X):
        """Return the mean computed from ``X``."""
        # `torch.tanh` replaces the deprecated `torch.nn.functional.tanh`.
        h = torch.tanh(self.i2h(X))
        mean = self.h2mean(h)
        return mean

In [12]:
# --- Model 1: beer.VAE whose decoder (GaussianMLP) returns a mean and a variance.
# The latent Normal is created from a vector of zeros and a vector of ones
# (presumably a standard normal prior -- argument semantics are defined by beer).
latent_normal1 = beer.NormalDiagonalCovariance.create(
    torch.zeros(latent_dim),
    torch.ones(latent_dim),
)
encoder1 = GaussianMLP(observed_dim, latent_dim, hidden_dim)
decoder1 = GaussianMLP(latent_dim, observed_dim, hidden_dim)
vae1 = beer.VAE(encoder1, decoder1, latent_normal1, nsamples=5)

# --- Model 2: beer.SubspaceVAE whose decoder (GaussianMLP2) returns only a mean.
# It is created from the per-dimension mean and variance of the training data.
latent_normal2 = beer.NormalDiagonalCovariance.create(
    torch.zeros(latent_dim),
    torch.ones(latent_dim),
)
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(data.var(axis=0)).float()
encoder2 = GaussianMLP(observed_dim, latent_dim, hidden_dim)
decoder2 = GaussianMLP2(latent_dim, observed_dim, hidden_dim)
vae2 = beer.SubspaceVAE.create(
    data_mean,
    data_var,
    encoder2,
    decoder2,
    latent_normal2,
    nsamples=5,
    pseudo_counts=1,
)

# Both models are trained and compared side by side.
models = [vae1, vae2]

In [13]:
# Training configuration.
epochs = 20_000
lrate = 1.          # learning rate given to beer.BayesianModelOptimizer
lrate_nnet = 1e-3   # learning rate of the Adam optimizer (encoder/decoder weights)

# Neural-network weights of both models are handled by a single Adam optimizer.
nnet_parameters = list(vae1.encoder.parameters()) + list(vae1.decoder.parameters())
nnet_parameters += list(vae2.encoder.parameters()) + list(vae2.decoder.parameters())
# Use the configured learning rate instead of a second hard-coded 1e-3.
nnet_optim = torch.optim.Adam(nnet_parameters, lr=lrate_nnet)
params = vae2.parameters + vae1.parameters
optim = beer.BayesianModelOptimizer(params, lrate=lrate, std_optim=nnet_optim)
elbo_fn = beer.EvidenceLowerBound(len(X))

# NOTE(review): both models share one optimizer, so every `optim.step()` is
# also applied to the parameters of the model that was not evaluated in that
# step. Whether this is a no-op depends on how zero_grad()/step() treat
# parameters without a fresh gradient -- confirm, or use one optimizer per model.

# elbos[k] holds the per-data-point ELBO of models[k] at every epoch.
elbos = [[] for _ in models]
for epoch in range(epochs):
    # Distinct loop variables: the model index no longer shadows the epoch counter.
    for model_idx, model in enumerate(models):
        optim.zero_grad()
        elbo = elbo_fn(model, X)
        elbo.backward()
        elbo.natural_backward()
        optim.step()
        elbos[model_idx].append(float(elbo) / len(X))


# Training curves of both models.
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
fig.line(np.arange(len(elbos[0])), elbos[0], legend='ELBO (VAE 1)', color='blue')
fig.line(np.arange(len(elbos[1])), elbos[1], legend='ELBO (VAE 2)', color='red')
fig.legend.location = 'bottom_right'

show(fig)

In [14]:
resolution = 0.1

# Regular square grid covering the plotted area. Both axes use the same range
# and step, so the side length is read from the array shape rather than
# recovered with a floating-point square root of the number of points.
grid = np.mgrid[global_range[0]:global_range[1]:resolution,
                global_range[0]:global_range[1]:resolution]
single_dim_nb_points = grid.shape[1]
xy = torch.from_numpy(grid.reshape(2, -1).T).float()

# For a smooth plot increase the number of samples.
# Note: this changes the model's setting for every later evaluation as well.
vae1.nsamples = 200

# Per-grid-point value that is exponentiated below and displayed as p(X).
# Stored under its own name so the training history `elbos` is not overwritten.
log_p_x = vae1(xy) - vae1.local_kl_div_posterior_prior() - vae1.kl_div_posterior_prior()

# Back to a 2-D image; the transpose puts the first grid coordinate on the
# horizontal axis of the image.
p_x_mtx = log_p_x.view(single_dim_nb_points, single_dim_nb_points).t().exp()
p_x_mtx = p_x_mtx.data.numpy()
fig = figure(title='p(X)', width=400, height=400,
             x_range=global_range, y_range=global_range)

plane_size = global_range[1] - global_range[0]
fig.image(image=[p_x_mtx], x=global_range[0], y=global_range[0], dw=plane_size, dh=plane_size)
# Overlay the training points for reference.
fig.circle(data[:, 0], data[:, 1], alpha=.1)

show(fig)

In [15]:
resolution = 0.1

# Regular square grid covering the plotted area. Both axes use the same range
# and step, so the side length is read from the array shape rather than
# recovered with a floating-point square root of the number of points.
grid = np.mgrid[global_range[0]:global_range[1]:resolution,
                global_range[0]:global_range[1]:resolution]
single_dim_nb_points = grid.shape[1]
xy = torch.from_numpy(grid.reshape(2, -1).T).float()

# For a smooth plot increase the number of samples.
# Note: this changes the model's setting for every later evaluation as well.
vae2.nsamples = 200

# Per-grid-point value that is exponentiated below and displayed as p(X).
# Stored under its own name so the training history `elbos` is not overwritten.
log_p_x = vae2(xy) - vae2.local_kl_div_posterior_prior() - vae2.kl_div_posterior_prior()

# Back to a 2-D image; the transpose puts the first grid coordinate on the
# horizontal axis of the image.
p_x_mtx = log_p_x.view(single_dim_nb_points, single_dim_nb_points).t().exp()
p_x_mtx = p_x_mtx.data.numpy()
fig = figure(title='p(X)', width=400, height=400,
             x_range=global_range, y_range=global_range)

plane_size = global_range[1] - global_range[0]
fig.image(image=[p_x_mtx], x=global_range[0], y=global_range[0], dw=plane_size, dh=plane_size)
# Overlay the training points for reference.
fig.circle(data[:, 0], data[:, 1], alpha=.1)

show(fig)

In [25]:
# Per-data-point ELBO of both models: first on the training set, then on the
# test set. Each printed line reads "<VAE 1> <VAE 2>".
for Y in (X, test_X):
    elbo1 = float(elbo_fn(vae1, Y)) / len(Y)
    elbo2 = float(elbo_fn(vae2, Y)) / len(Y)
    print(elbo1, elbo2)

-3.440498046875 -3.6354058837890624
-8582.01 -1189.866484375


In [27]:
# Inspect, on the test set, the three terms that the density plots combine:
# the model output and the two KL-divergence terms of vae2.
(
    vae2(test_X),
    vae2.local_kl_div_posterior_prior(),
    vae2.kl_div_posterior_prior(),
)

(tensor([ -834.3262, -1317.6401, -1426.7592, -1091.1256, -1193.1471,
         -1146.8143, -1051.6378, -1472.5848, -1440.8253, -1058.5867,
         -1063.1042, -1012.1676, -1076.9736, -1172.8970, -1160.8318,
         -1239.2485, -1096.8867, -1279.8191, -1074.1561, -1197.0640,
         -1078.8087, -1245.4554, -1110.7629, -1080.7571, -1148.0337,
         -1443.4719, -1543.0397,  -842.8319, -1439.9722, -1044.9510,
         -1331.5238, -1249.8940, -1251.4414,  -952.2350,  -922.5039,
         -1393.3263, -1421.5881, -1226.1929, -1106.0793, -1537.0229,
          -963.1926, -1013.3795, -1162.7195, -1101.4579, -1349.0364,
         -1148.4849, -1126.9088, -1034.5486, -1309.4576, -1131.8815,
         -1180.2729, -1008.1373,  -968.0485, -1036.5656, -1296.3148,
         -1181.5266, -1407.0381, -1301.5968, -1114.5181, -1177.3335,
         -1161.0660, -1357.5637, -1270.5541, -1253.8571,  -911.0944,
         -1336.3091,  -943.0217, -1437.4319, -1314.9695, -1409.0460,
         -1149.9050, -1175.9561, -

In [28]:
# Sanity check: mean over all entries of the test tensor.
torch.mean(test_X)

tensor(21.4987)

In [24]:
# KL-divergence term of each model, in the order (VAE 1, VAE 2).
tuple(model.kl_div_posterior_prior() for model in (vae1, vae2))

(tensor([ 9.1794]), tensor([ 36.2939]))

In [23]:
# Display the `mean` attribute of vae2's `normal` component
# (presumably the mean learned during training -- confirm against beer).
vae2.normal.mean

tensor([ 0.1112,  2.7058])