# Generalized Subspace Model: Normal distribution

This notebook illustrates how to build and train a Bayesian Generalized Subspace Model for a Normal distribution with the [beer framework](https://github.com/beer-asr/beer). 

In [None]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.append('../')

import beer
import numpy as np
import torch
from torch.autograd import Variable

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import Range1d, LinearAxis
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

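The "data" is actually a set of Normal distributions whose parameters lie in a (non-linear) subspace. For our synthetic example, each distribution is obtained from a scalar $x$ as follows (the second parameter of $\mathcal{N}$ is the variance):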

$$
\begin{split}
    x & \sim \mathcal{N}(\frac{3}{2}, 10) \\
    \mu &= 
        \begin{pmatrix}
            x \\
            1 - \ln \big( 1 + e^{-x} \big)
        \end{pmatrix} \\
    \Sigma &= 
        \begin{pmatrix}
            1 & 0 \\
            0 & 1 + e^{-\frac{x}{2}}
        \end{pmatrix}
\end{split}
$$

In [None]:
def normal_from_subspace(x):
    """Map the subspace coordinate ``x`` to a 2-dimensional Normal.

    Returns the tuple ``(mean, cov)`` where ``mean`` is the vector
    ``[x, 1 - ln(1 + exp(-x))]`` and ``cov`` is the diagonal matrix
    ``diag(1, 1 + exp(-x / 2))``.
    """
    second_mean = 1 - np.log(1 + np.exp(-x))
    second_var = 1 + np.exp(-.5 * x)
    return np.array([x, second_mean]), np.diag([1., second_var])

# Seed numpy's global random generator so that the synthetic data set is
# the same every time the notebook is re-run from a fresh kernel.
np.random.seed(0)

# Draw 100 subspace coordinates x ~ N(3/2, 10) (10 is the variance, hence
# the sqrt) and map each of them to the (mean, covariance) of a Normal.
xs = np.sqrt(10) * np.random.randn(100) + 1.5
dists = [normal_from_subspace(x) for x in xs]

fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=(-10, 10),
    y_range=(-10, 10)
)

# Draw every Normal of the data set on the same figure.
for mean, cov in dists:
    plotting.plot_normal(fig, mean, cov, alpha=.1)

show(fig)

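Each Normal distribution is represented by its natural parameters, computed independently for each dimension, where $\mu$ is the mean and $\tau$ is the precision (the inverse of the variance):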

$$
\eta(\tau, \mu) = \begin{pmatrix}
    -\frac{\tau}{2} \\
    \tau \mu \\
    - \frac{\tau \mu^2}{2} \\
    \frac{1}{2} \ln \tau
\end{pmatrix}
$$

In [None]:
def natural_params(mean, diag_prec):
    """Stack the natural parameters of a diagonal Normal in one vector.

    ``mean`` is the mean and ``diag_prec`` the per-dimension precision.
    The result concatenates, in this order: ``-prec / 2``, ``prec * mean``,
    ``-prec * mean^2 / 2`` and ``ln(prec) / 2``.
    """
    neg_half_prec = -.5 * diag_prec
    blocks = (
        neg_half_prec,
        diag_prec * mean,
        neg_half_prec * (mean ** 2),
        .5 * np.log(diag_prec),
    )
    return np.hstack(blocks)

In [None]:
# One row per distribution: its natural parameters, computed from the mean
# and from the precisions (inverse of the diagonal of the covariance).
data = np.vstack([
    natural_params(mean, 1 / np.diag(cov))
    for mean, cov in dists
])
data.shape

## Model Creation

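We build the model from three components: an encoder, a decoder and a model of the latent space. For the latent model we can use two types of Normal distribution: one with a diagonal covariance matrix (the default below) and another one with a full covariance matrix.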

In [None]:
# Dimension of the observed space, i.e. the number of columns of `data`.
obs_dim = data.shape[1]

# Dimension of the latent space. It can be bigger or smaller
# than the dimension of the observed space.
latent_dim = 2

# Number of samples for the "reparameterization-trick".
nb_samples = 10

# Number of units per hidden-layer.
n_units = 10

# beer uses pytorch as a backend for the neural-network part
# of the model.
from torch import nn

# Neural network structure of the encoder of the model: two fully
# connected hidden layers of `n_units` units, each followed by a tanh.
enc_struct = nn.Sequential(
    nn.Linear(obs_dim, n_units),
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh()
)
# Presumably wraps the network so that it outputs a Normal with diagonal
# covariance over the latent space -- TODO confirm against beer.
encoder = beer.models.MLPNormalDiag(enc_struct, latent_dim, residual=True)

# Neural network structure of the decoder of the model: same layout as
# the encoder but taking a latent vector as input.
dec_struct = nn.Sequential(
    nn.Linear(latent_dim, n_units),
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh()
)
# Presumably wraps the network so that it outputs a distribution over the
# observed space -- TODO confirm against beer.
decoder = beer.models.MLPNormalGamma(dec_struct, obs_dim, prior_count=1)

# Model of the latent space (uncomment the one you want to try).
# It can be changed at any time.
# NOTE(review): the two mixture alternatives below use `data_mean` and
# `data_cov`, which are not defined in the cells visible here, and they
# access the classes through `beer.` rather than `beer.models.` --
# confirm both points before uncommenting them.
# ----------------------------------------------------------------------
latent_model = beer.models.NormalDiagonalCovariance.create(latent_dim)

#latent_model = beer.models.NormalFullCovariance.create(latent_dim, prior_count=1e-3)

#args = {'dim':2, 'prior_count':1, 'mean': data_mean, 'cov': data_cov, 'random_init':True}
#latent_model = beer.Mixture.create(10, beer.NormalDiagonalCovariance.create, args, prior_count=1e-6)

#args = {'dim':2, 'prior_count':1, 'mean': data_mean, 'cov': data_cov, 'random_init':True}
#latent_model = beer.Mixture.create(10, beer.NormalFullCovariance.create, args, prior_count=1)

# ----------------------------------------------------------------------

# Putting everything together to build the SVAE.
svae = beer.models.VAE(encoder, decoder, latent_model, nb_samples)

Initialization of the model.

In [None]:
# Options shared by the two initialization passes: 500 epochs, a KL weight
# of zero and no monitoring callback.
init_options = dict(max_epochs=500, kl_weight=0.0, callback=None)

# First pass: learning rate of 1e-3 while the latent model learning rate
# is zero (presumably only the encoder/decoder are updated -- TODO confirm).
beer.training.train_vae(svae, data, lrate=1e-3, latent_model_lrate=0,
                        **init_options)

# Second pass: the roles are swapped, learning rate of zero and a latent
# model learning rate of 1e-1.
beer.training.train_vae(svae, data, lrate=0, latent_model_lrate=1e-1,
                        **init_options)

## Variational Bayes Training 

In [None]:
# Histories of the values reported during training; they are filled in by
# the callback below.
elbos, llhs, klds = [], [], []


def callback(elbo, llh, kld):
    """Append the three reported values to their respective history."""
    for history, value in ((elbos, elbo), (llhs, llh), (klds, kld)):
        history.append(value)


# Run the training, recording the progress through the callback.
beer.training.train_vae(
    svae,
    data,
    max_epochs=5000,
    lrate=1e-3,
    latent_model_lrate=1e-2,
    callback=callback,
)

# Left panel: the ELBO at each recorded step.
fig1 = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
fig1.line(np.arange(len(elbos)), elbos)

# Right panel: the LLH on the left axis and the KLD (green) on an extra
# axis on the right, since the two are plotted over different ranges.
llh_range = (min(llhs) - 1, max(llhs) + 1)
fig2 = figure(
    title='LLH + KLD',
    width=400,
    height=400,
    y_range=llh_range,
    x_axis_label='step',
    y_axis_label='ln p(x|...)',
)
fig2.line(np.arange(len(llhs)), llhs)
fig2.extra_y_ranges['KLD'] = Range1d(0, max(klds) + 1)
kld_axis = LinearAxis(y_range_name="KLD", axis_label='KLD')
fig2.add_layout(kld_axis, 'right')
fig2.line(np.arange(len(klds)), klds, y_range_name='KLD', color='green')

show(gridplot([[fig1, fig2]]))

To see what the model has learned, we sample "embeddings" from the prior distribution and transform them with the decoder of the VAE.

In [None]:
# Draw 100 embeddings from the Normal distribution given by the mean and
# the covariance of the latent model, and convert them to a float tensor.
eps = Variable(torch.from_numpy(
    np.random.multivariate_normal(svae.latent_model.mean,
                                  svae.latent_model.cov, 100)
).float())

# Observed space: pass the embeddings through the decoder and plot one
# Normal distribution per embedding.
fig1 = figure(title='Observed space', x_range=(-10, 10), y_range=(-10, 10),
              width=400, height=400)
decoder_state = svae.decoder(eps)
for prior in decoder_state.as_priors():
    # Use the same `beer.models` path as the model-creation cell.
    normal = beer.models.NormalDiagonalCovariance(prior, prior)
    plotting.plot_normal(fig1, normal.mean, normal.cov, alpha=.1)

# Latent space: the latent model together with the sampled embeddings
# (black crosses).
fig2 = figure(title='Latent space', x_range=(-10, 10), y_range=(-10, 10),
              width=400, height=400)
embeddings = eps.data.numpy()
plotting.plot_latent_model(fig2, svae.latent_model, color='salmon')
fig2.cross(embeddings[:, 0], embeddings[:, 1], color='black')

grid = gridplot([[fig1, fig2]])
show(grid)