# Variational AutoEncoder

This notebook illustrates how to build and train a Variational AutoEncoder (VAE) with the [beer framework](https://github.com/beer-asr/beer). The latent space of the VAE is modelled with a probabilistic PCA (PPCA) prior.

In [None]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch
import torch.nn as nn

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data
As an illustration, we generate a synthetic data set of points scattered around a parabola: the first coordinate is Normally distributed, and the second is the square of a scaled copy of the first plus a little Gaussian noise.
These data are clearly not Gaussian in the observed space, but they lie close to a one-dimensional curve, so it is reasonable to map all of them to a single Gaussian in the latent space.

In [None]:
# Synthetic data: points scattered around the parabola y = (0.75 x)^2,
# shifted by `offset` and blurred with isotropic Gaussian noise (std .2).
# A dedicated, seeded generator makes the data - and hence the training run
# below - reproducible without touching numpy's global random state.
DATA_SEED = 9
N_SAMPLES = 100
rng = np.random.default_rng(DATA_SEED)

offset = np.array([3, 3])
x1 = rng.standard_normal(N_SAMPLES) * 1.5
x2 = np.power(.75 * x1, 2)
data = offset + np.c_[x1, x2] + .2 * rng.standard_normal((N_SAMPLES, 2))

In [None]:
# Square viewport centred on the data: +/- 5 standard deviations of the most
# spread-out coordinate. The same range is used for both axes so the scatter
# plot keeps its aspect ratio.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(max(var))
half_width = 5 * std_dev
x_range, y_range = [(centre - half_width, centre + half_width)
                    for centre in (mean[0], mean[1])]
global_range = (min(x_range[0], y_range[0]),
                max(x_range[1], y_range[1]))

fig = figure(title='Data', width=400, height=400,
             x_range=global_range, y_range=global_range)
fig.circle(data[:, 0], data[:, 1])
show(fig)

In [None]:
# Seed PyTorch's global RNG so that network initialisation is repeatable.
torch.manual_seed(9)

# Dimensions: observed space, VAE latent space, and the PPCA subspace that
# lives inside the latent space.
observed_dim, latent_dim1, latent_dim2 = 2, 2, 1

# PPCA prior over the latent space: zero mean, scalar 1. (presumably the
# initial noise precision - see beer.PPCA.create) and a
# (latent_dim2 x latent_dim1) subspace taken from the identity matrix.
prior_mean = torch.zeros(latent_dim1)
prior_subspace = torch.eye(latent_dim2, latent_dim1)
latent_normal = beer.PPCA.create(prior_mean, 1., prior_subspace)

n_units = 20  # width of every hidden layer of the encoder/decoder

class GaussianMLP(nn.Module):
    '''Neural network mapping its input to a Normal with diagonal covariance.

    Args:
        structure (nn.Module): network producing the hidden representation.
        space_dim (int): dimension of the output Normal.
        residual (bool): if True, the input is added to the predicted mean,
            which requires the input dimension to equal ``space_dim``.
        hidden_dim (int): output size of ``structure``. Defaults to the
            notebook-level ``n_units`` (the previous hard-coded behaviour).
    '''

    def __init__(self, structure, space_dim, residual=False, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = n_units
        self.residual = residual
        self.nn = structure
        self.h2mean = nn.Linear(hidden_dim, space_dim)
        self.h2logvar = nn.Linear(hidden_dim, space_dim)

        # Shift the default bias down by 1 so the initial variances,
        # exp(logvar), are small. no_grad() is the supported way to modify
        # parameters in place (instead of going through `.data`).
        with torch.no_grad():
            self.h2logvar.bias.sub_(1.0)

    def forward(self, X):
        '''Return the Normal (mean, variance) predicted for the rows of X.'''
        h = self.nn(X)
        mean = self.h2mean(h)
        if self.residual:
            # Out-of-place: do not mutate the Linear layer's output tensor.
            mean = mean + X
        logvar = self.h2logvar(h)
        return beer.NormalDiagonalCovarianceMLP(mean, logvar.exp())

    
# Neural network structure of the encoder/decoder of the model.
# The encoder reads observations; the decoder reads latent vectors.
enc_struct = nn.Sequential(
    nn.Linear(observed_dim, n_units),
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh(),
)

dec_struct = nn.Sequential(
    nn.Linear(latent_dim1, n_units),  # input is a latent vector
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh(),
)

# Both networks are residual (predicted mean = input + correction), which is
# only well defined when the latent and observed spaces share a dimension.
assert latent_dim1 == observed_dim, \
    'residual encoder/decoder requires latent_dim1 == observed_dim'

# A diagonal Normal prior can be used instead of the PPCA one:
#   beer.NormalDiagonalCovariance.create(torch.zeros(latent_dim1), torch.ones(latent_dim1))
vae_ppca = beer.VAE(
    GaussianMLP(enc_struct, latent_dim1, residual=True),   # encoder -> latent space
    GaussianMLP(dec_struct, observed_dim, residual=True),  # decoder -> observed space
    latent_normal,
    nsamples=5
)

In [None]:
epochs = 5_000
lrate_bayesmodel = 1.
lrate_encoder = 1e-3  # Adam learning rate for both encoder and decoder
X = torch.from_numpy(data).float()
elbo_fn = beer.EvidenceLowerBound(len(X))

# The networks are trained with Adam; the Bayesian latent model is updated
# by BayesianModelOptimizer (natural gradients, see elbo.natural_backward()).
nnet_parameters = list(vae_ppca.encoder.parameters()) + list(vae_ppca.decoder.parameters())
std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_encoder, weight_decay=1e-2)
params = vae_ppca.latent_model.parameters
optim = beer.BayesianModelOptimizer(params, lrate=lrate_bayesmodel,
                                    std_optim=std_optimizer)

n_frames = len(X)
elbos, klds, llhs = [], [], []
for epoch in range(epochs):
    optim.zero_grad()
    elbo = elbo_fn(vae_ppca, X)
    elbo.backward()
    elbo.natural_backward()
    optim.step()

    # Epoch 0 is not recorded (presumably so the initial value does not
    # dominate the plot scale). Statistics are normalised per data point.
    if epoch > 0:
        elbos.append(float(elbo) / n_frames)
        llhs.append(float(elbo.expected_llh) / n_frames)
        klds.append(float(elbo.kl_div) / n_frames)


def plot_training_curve(title, values, color, y_label):
    '''Return a bokeh line plot of a per-epoch training statistic.'''
    curve_fig = figure(title=title, width=400, height=400, x_axis_label='epoch',
                       y_axis_label=y_label)
    curve_fig.line(np.arange(len(values)), values, color=color)
    return curve_fig


fig1 = plot_training_curve('ELBO', elbos, 'blue', 'ELBO / N')
fig2 = plot_training_curve('LLH', llhs, 'green', 'expected log-likelihood / N')
fig3 = plot_training_curve('KL', klds, 'red', 'KL divergence / N')

show(gridplot([[fig1], [fig2, fig3]]))

In [None]:
# Decode the principal axis of the PPCA latent model back to the observed
# space and compare it with the data.
model = vae_ppca.latent_model
model_mean = model.mean.numpy()
S = model.subspace.numpy()
# Latent-space covariance implied by the model: S^T S + I / precision
# (kept for inspection; not used in the plot below).
model_cov = S.T @ S + np.identity(2) / model.precision.numpy()


def transform2(xy):
    '''Decode latent points.

    Args:
        xy (numpy.ndarray): latent points, one per row, in absolute latent
            coordinates (no offset is added).

    Returns:
        tuple: the decoder's per-point mean and diagonal variance as numpy
        arrays.
    '''
    with torch.no_grad():
        dec_state = vae_ppca.decoder(torch.from_numpy(xy).float())
    return dec_state.mean.detach().numpy(), dec_state.var.detach().numpy()


def transform(xy):
    '''Decode latent points and return only the decoder's mean.'''
    return transform2(xy)[0]


# Principal axis: the line through the latent mean A along the first row of
# the subspace. The points are already absolute latent coordinates, so they
# must not be shifted by the mean a second time before decoding.
A = model_mean
B = model_mean + S[0, :]
slope = S[0, 1] / S[0, 0]
intercept = A[1] - slope * A[0]
x = np.linspace(-20, 20, 1000)
o_s_line = np.c_[x, slope * x + intercept]
# GaussianMLP.forward draws no samples, so one decoder pass provides both the
# decoded curve and its per-point variances.
means, variances = transform2(o_s_line)
xy = means

fig1 = figure(width=400, height=400, x_range=(-10, 10), y_range=(0, 15))
fig1.line(xy[:, 0], xy[:, 1])
fig1.circle(data[:, 0], data[:, 1], color='red', alpha=.5)
for mean, var_diag in zip(means, variances):
    plotting.plot_normal(fig1, mean, np.diag(var_diag), n_std_dev=2, alpha=.1)
show(fig1)

In [None]:
# Latent-space view: PPCA principal axis with noise bands, a Normal density
# drawn along the axis, and the encoder's posteriors for a few data points.
model = vae_ppca.latent_model
with torch.no_grad():
    enc_state = vae_ppca.encoder(X)
means, dcovs = enc_state.mean.detach().numpy(), enc_state.var.detach().numpy()

origin = model.mean.numpy()
noise_std = np.sqrt(1 / model.precision.numpy())
n_shown = 30  # number of encoded data points drawn
x = np.linspace(-20, 20, 1000)  # coordinate along the principal axis

A, B = origin, origin + model.subspace.numpy()[0, :]
dx, dy = B - A
# Signed angle of the axis. An abs()-based arctan is always in [0, pi/2] and
# mirrors negatively sloped axes; pointing the direction rightwards keeps the
# angle in (-pi/2, pi/2] instead.
if dx < 0:
    dx, dy = -dx, -dy
angle = np.arctan2(dy, dx)
R = np.array([
    [np.cos(angle), -np.sin(angle)],
    [np.sin(angle), np.cos(angle)]
])


def to_latent(along, across):
    '''Map coordinates along / across the principal axis to latent points.'''
    return np.c_[along, across] @ R.T + origin


def add_band(fig, lower, upper, **style):
    '''Fill the area between two polylines given as (n, 2) point arrays.'''
    band_x = np.append(lower[:, 0], upper[::-1, 0])
    band_y = np.append(lower[:, 1], upper[::-1, 1])
    fig.patch(band_x, band_y, **style)


fig1 = figure(width=400, height=400, x_range=(-5, 5), y_range=(-5, 5))

# Bands at 1 and 2 noise standard deviations on each side of the axis.
axis_points = to_latent(x, np.zeros_like(x))
for n_std in (1, 2):
    for side in (1, -1):
        edge = to_latent(x, side * n_std * noise_std * np.ones_like(x))
        add_band(fig1, axis_points, edge, line_alpha=0., fill_alpha=0.3,
                 fill_color='LightBlue')

fig1.cross(means[:n_shown, 0], means[:n_shown, 1], color='red', alpha=.5)
for mean, dcov in zip(means[:n_shown], dcovs[:n_shown]):
    plotting.plot_normal(fig1, mean, np.diag(dcov), n_std_dev=2, line_alpha=.0,
                         line_color='black', fill_alpha=.3, fill_color='red')

# Standard Normal density drawn perpendicular to the axis.
# NOTE(review): it is evaluated at the Euclidean distance `x` along the axis
# and is not rescaled by the norm of the subspace row - confirm this is the
# intended picture.
p_h = np.sqrt(1 / (2 * np.pi)) * np.exp(-.5 * x ** 2)
add_band(fig1, axis_points, to_latent(x, p_h), line_color='black',
         fill_color='LightGreen', alpha=.5)

show(fig1)