# Structured Variational AutoEncoder

This notebook illustrates how to build and train a Structured Variational AutoEncoder (SVAE) with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
%load_ext autoreload
%autoreload 2

# Add the path of the beer source code to the PYTHONPATH.
import sys
sys.path.insert(0, '../')

import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d

# Beer framework
import beer

# Convenience functions for plotting.
import plotting

# Configure bokeh so that figures passed to show() are displayed inline
# in the notebook.
output_notebook(verbose=False)

## Data 

As a simple example we consider the following synthetic data: 

$$ 
\begin{split}
    z &\sim \mathcal{N}(m, \Sigma) \\
    x &= 
        \begin{pmatrix}
        z_1 \\
        z_2 + (z_1 - m_1)^2
        \end{pmatrix} 
\end{split}
$$

In [2]:
# Generate some Normal distributed samples. A dedicated, seeded random
# generator is used so that "Restart & Run All" always yields the same
# data set, without touching numpy's global random state.
data_seed = 0
rng = np.random.RandomState(data_seed)
mean = np.array([3., 2.])
cov = np.array([[.75, 0.], [0., .075]])
Z = rng.multivariate_normal(mean, cov, size=100)

# Apply the non-linear transformation: the first coordinate is kept as
# is, the second one is shifted by the squared deviation of the first
# coordinate from its mean.
X = np.zeros_like(Z)
X[:, 0] = Z[:, 0]
X[:, 1] = Z[:, 1] + (Z[:, 0] - mean[0]) ** 2

# Scatter plot of the transformed samples.
fig = figure(title='Synthetic data', width=400, height=400)
fig.circle(X[:, 0], X[:, 1])
show(fig)

## Feature normalization

Since the VAE model is built upon neural network components, it is good practice to mean-variance normalize the features to ease training.

In [3]:
# Per-dimension mean and covariance matrix of the raw features.
data_mean = np.mean(X, axis=0)
data_cov = np.cov(X, rowvar=False)

# Standardize X in place: remove the mean, then divide each dimension by
# its standard deviation (square root of the covariance diagonal).
# NOTE: X is overwritten, so re-running this cell replaces data_mean and
# data_cov with the statistics of the already-normalized data.
np.subtract(X, data_mean, out=X)
np.divide(X, np.sqrt(np.diag(data_cov)), out=X)

## Model Creation

We first create the SVAE.

In [8]:
# Dimensions of the observed and of the latent space. The latent space
# can be bigger or smaller than the observed one.
obs_dim = X.shape[1]
latent_dim = 2

# Number of samples for the "reparameterization-trick".
nb_samples = 10

# Number of units per hidden-layer.
n_units = 10

# beer uses pytorch as a backend for the neural-network part
# of the model.
from torch import nn


def two_layer_tanh(in_dim, hidden_dim):
    """Return a network made of two (Linear -> Tanh) hidden layers.

    Args:
        in_dim: Dimension of the input of the network.
        hidden_dim: Number of units of each of the two hidden layers.

    Returns:
        The ``nn.Sequential`` network.
    """
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.Tanh(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.Tanh()
    )


# Encoder: observed space -> hidden representation, wrapped by beer.
enc_struct = two_layer_tanh(obs_dim, n_units)
encoder = beer.models.MLPNormalIso(enc_struct, latent_dim, residual=True)

# Decoder: latent space -> hidden representation, wrapped by beer.
dec_struct = two_layer_tanh(latent_dim, n_units)
decoder = beer.models.MLPNormalDiag(dec_struct, obs_dim)

# Model of the latent space (uncomment the one you want to try).
# It can be changed at any-time.
# NOTE(review): the Mixture alternatives below reference `beer.Mixture`
# and `beer.Normal*Covariance` without the `beer.models` prefix used by
# the other alternatives -- confirm the namespace before uncommenting.
# ----------------------------------------------------------------------

#latent_model = beer.models.NormalDiagonalCovariance.create(
#    torch.zeros(latent_dim), torch.ones(latent_dim)
#)

latent_model = beer.models.NormalFullCovariance.create(
    torch.zeros(latent_dim), torch.eye(latent_dim)
)

#args = {
#    'prior_mean': torch.zeros(latent_dim), 
#    'prior_cov': torch.eye(latent_dim), 
#    'prior_count': 1, 'random_init': True
#}
#latent_model = beer.Mixture.create(torch.ones(10), beer.NormalDiagonalCovariance.create, args)
#latent_model = beer.Mixture.create(torch.ones(10), beer.NormalFullCovariance.create, args)

# ----------------------------------------------------------------------

# Assemble the encoder, the decoder and the latent model into the SVAE.
svae = beer.models.VAE(encoder, decoder, latent_model, nb_samples)

Variational Bayes inference is sensitive to the initialization of the posterior. Our initialization scheme is fairly basic but seems to give good results on this toy example:
  1. Keep the prior fixed and optimize the expected value of the log-likelihood of the VAE (i.e. the loss function without the KL divergence term).
  2. Freeze the parameters of the encoder/decoder and update the latent model so that it fits the current distribution of the latent space.

In [9]:
# Settings shared by the two initialization steps. The KL divergence term
# is weighted by 0 and no monitoring callback is used.
init_args = dict(max_epochs=1000, kl_weight=0.0, callback=None)

# Step 1: train the encoder/decoder only (the learning rate of the latent
# model is set to 0).
beer.train_vae(svae, torch.from_numpy(X).float(), lrate=1e-3,
               latent_model_lrate=0, **init_args)

# Step 2: update the latent model only (the learning rate of the
# encoder/decoder is set to 0).
beer.train_vae(svae, torch.from_numpy(X).float(), lrate=0,
               latent_model_lrate=1e-1, **init_args)

## Variational Bayes Training

In [11]:
# Training history, filled in by the callback below. The lists are reset
# every time this cell is run.
elbos, llhs, klds = [], [], []


def callback(elbo, llh, kld):
    """Append one set of monitored values to the training history.

    Args:
        elbo: Value stored in ``elbos``.
        llh: Value stored in ``llhs``.
        kld: Value stored in ``klds``.
    """
    for history, value in ((elbos, elbo), (llhs, llh), (klds, kld)):
        history.append(value)


# This is the training.
beer.train_vae(
    svae,
    torch.from_numpy(X).float(),
    max_epochs=10000,
    lrate=1e-3,
    latent_model_lrate=1e-2,
    callback=callback
)

# Left figure: the ELBO as a function of the recorded step.
fig1 = figure(title='ELBO', width=400, height=400,
              x_axis_label='step', y_axis_label='ln p(X)')
fig1.line(np.arange(len(elbos)), elbos)

# Right figure: the LLH (left axis) and the KLD (right axis, green) are
# drawn on two separate vertical scales.
llh_bounds = (min(llhs) - 1, max(llhs) + 1)
fig2 = figure(title='LLH + KLD', width=400, height=400,
              y_range=llh_bounds,
              x_axis_label='step', y_axis_label='ln p(x|...)')
fig2.line(np.arange(len(llhs)), llhs)
fig2.extra_y_ranges['KLD'] = Range1d(0, max(klds) + 1)
fig2.add_layout(LinearAxis(y_range_name="KLD", axis_label='KLD'), 'right')
fig2.line(np.arange(len(klds)), klds, y_range_name='KLD', color='green')

show(gridplot([[fig1, fig2]]))

Let's see what the VAE has learnt.

In [12]:
# Number of grid points per axis and region of the observed space covered
# by the evaluation grid.
d = 100
x_range = (-3, 3)
y_range = (-3, 3)

# Regular d x d grid, one 2-dimensional point per row. The number of
# points is derived from `d` (a complex step in np.mgrid means "number of
# points") so that it always agrees with the `reshape(d, d)` below.
xy = np.mgrid[x_range[0]:x_range[1]:complex(0, d),
              y_range[0]:y_range[1]:complex(0, d)].reshape(2, -1).T

# Evaluate the model on the grid and convert every returned value to a
# numpy array.
elbo, llh, kld, mean, var = (
    out.data.numpy() for out in svae.evaluate(xy, sampling=False)
)
fig1 = figure(x_range=x_range,  y_range=y_range, width=400, height=400)

# must give a vector of image data for image parameter
fig1.image(image=[np.exp(elbo.reshape(d, d).T)],
           x=x_range[0], y=y_range[0],
           dw=(x_range[1] - x_range[0]), dh=(y_range[1] - y_range[0]),
           palette="Inferno256", alpha=.01)
fig1.circle(X[:, 0], X[:, 1], alpha=1)
fig2 = figure(x_range=(-10, 10), y_range=(-10, 10), width=400, height=400)

# Evaluate the model on the first 100 data points.
elbo, llh, kld, mean, var = (
    out.data.numpy() for out in svae.evaluate(X[:100], sampling=False)
)

# Draw each returned mean as a cross, surrounded by an axis-aligned
# ellipse extending one standard deviation on each side.
# NOTE(review): `mean`/`var` are presumably the per-point posterior
# parameters in the latent space -- confirm against `svae.evaluate`.
fig2.cross(mean[:, 0], mean[:, 1], color='black')
for m, v in zip(mean, var):
    fig2.ellipse(x=m[0], y=m[1],
                 width=2 * np.sqrt(v[0]),
                 height=2 * np.sqrt(v[1]),
                 fill_alpha=0, color='black')
plotting.plot_latent_model(fig2, svae.latent_model, alpha=.5, color='salmon')

grid = gridplot([[fig1, fig2]])
show(grid)

  elif np.issubdtype(type(obj), np.float):
