# Structured Variational AutoEncoder

This notebook illustrates how to build and train a Structured Variational AutoEncoder (SVAE) with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
%load_ext autoreload
%autoreload 2

# Add the path of the beer source code to the PYTHONPATH.
import sys
sys.path.append('../')

import numpy as np

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d

# Beer framework
import beer

output_notebook(verbose=False)

## Data 

As a simple example we consider the following synthetic data: 

$$ 
\begin{split}
    z &\sim \mathcal{N}(m, \Sigma) \\
    x &= 
        \begin{pmatrix}
        z_1 \\
        z_2 + (z_1 - m_1)^2
        \end{pmatrix} 
\end{split}
$$

In [2]:
# Generate some Normal distributed samples. A dedicated, seeded generator
# keeps the synthetic data identical across "Restart & Run All" without
# touching the global numpy random state.
SEED = 0
rng = np.random.RandomState(SEED)
mean = np.array([3., 2.])
cov = np.array([[2., 1.], [1., .75]])
Z = rng.multivariate_normal(mean, cov, size=100)

# Apply the non-linear transformation.
X = np.zeros_like(Z)
X[:, 0] = Z[:, 0]
X[:, 1] = Z[:, 1] + (Z[:, 0]-mean[0])**2
#X = Z

# Shift the data and scale each dimension to unit variance.
data_mean = X.mean(axis=0)
data_cov = np.cov(X.T)
# NOTE(review): the shift uses the mean of the latent variable (`mean`),
# not `data_mean`, so the second dimension is not zero-centered afterwards
# -- confirm this is intended.
X -= mean
X /= np.sqrt(np.diag(data_cov))

# We will use the statistics of the (normalized) data to
# set the range of the plots: mean +/- 3 standard deviations.
data_mean = X.mean(axis=0)
data_cov = np.cov(X.T)
x_range = (data_mean[0] - 3 * np.sqrt(data_cov[0,0]), data_mean[0] + 3 * np.sqrt(data_cov[0,0]))
y_range = (data_mean[1] - 3 * np.sqrt(data_cov[1,1]), data_mean[1] + 3 * np.sqrt(data_cov[1,1]))

fig = figure(
    title='Non-Linear subspace',
    width=400,
    height=400,
    x_range=x_range,
    y_range=y_range
)
fig.circle(X[:, 0], X[:, 1])
show(fig)

In [3]:
def plot_normal(fig, mean, cov, alpha=1., color='blue'):
    """Draw a 2-D Normal density on ``fig``.

    The density is rendered as two concentric ellipses aligned with the
    eigenvectors of ``cov`` and a cross at ``mean``. The outer ellipse
    has axes of 4 standard deviations and half the opacity, the inner
    one has axes of 2 standard deviations and the full opacity.

    Args:
        fig: Object exposing bokeh-like ``ellipse`` and ``cross`` methods.
        mean: Center of the density (indexed at 0 and 1).
        cov: 2x2 covariance matrix.
        alpha: Opacity of the inner ellipse and of the cross.
        color: Color of all the glyphs.
    """
    # ``eigh`` returns the eigenvalues in ascending order, so index 0 is
    # the axis with the smallest variance.
    evals, evecs = np.linalg.eigh(cov)
    axis_std = np.sqrt(evals)

    # Angle of the rotation. The sign of the off-diagonal term selects
    # the quadrant of the first eigenvector.
    off_diag = cov[1, 0]
    sign = np.sign(off_diag) if off_diag != 0 else 1
    angle = -np.arccos(sign * abs(evecs[0, 0]))

    center = {'x': mean[0], 'y': mean[1]}

    # Outer (fainter) ellipse, then the center, then the inner ellipse.
    fig.ellipse(width=4 * axis_std[0], height=4 * axis_std[1],
                angle=angle, alpha=.5 * alpha, color=color, **center)
    fig.cross(mean[0], mean[1], color=color, alpha=alpha)
    fig.ellipse(width=2 * axis_std[0], height=2 * axis_std[1],
                angle=angle, alpha=alpha, color=color, **center)
    
def plot_gmm(fig, gmm, alpha=1., color='blue'):
    """Plot every component of a mixture of Normal densities.

    Each component is drawn with ``plot_normal``; its opacity is
    ``alpha`` scaled by the weight of the component in the mixture.

    Args:
        fig: Figure to draw on.
        gmm: Object exposing ``weights`` and ``components``; each
            component exposes ``mean`` and ``cov``.
        alpha: Base opacity.
        color: Color of all the glyphs.
    """
    for mix_weight, component in zip(gmm.weights, gmm.components):
        component_alpha = alpha * mix_weight
        plot_normal(fig, component.mean, component.cov, component_alpha, color)
        
def plot_latent_model(fig, latent_model, alpha=1., color='blue'):
    """Plot a latent model, dispatching on the name of its type.

    Args:
        fig: Figure to draw on.
        latent_model: Model whose type name contains ``'Mixture'``
            (plotted with ``plot_gmm``) or ``'Normal'`` (plotted with
            ``plot_normal`` from its ``mean`` and ``cov``).
        alpha: Base opacity.
        color: Color of all the glyphs.

    Raises:
        ValueError: If the type name matches neither kind of model.
    """
    # 'Mixture' is tested first: a mixture type name may also contain
    # 'Normal'.
    type_name = str(type(latent_model))
    if 'Mixture' in type_name:
        plot_gmm(fig, latent_model, alpha, color)
    elif 'Normal' in type_name:
        plot_normal(fig, latent_model.mean, latent_model.cov, alpha, color)
    else:
        raise ValueError(
            'Cannot plot latent model of type {}: expected a "Mixture" or '
            'a "Normal" model.'.format(type_name)
        )

## Model Creation

We first create the SVAE.

In [11]:
# Dimension of the observed space.
obs_dim = X.shape[1]

# Dimension of the latent space. It can be bigger or smaller
# than the dimension of the observed space.
latent_dim = 2

# Number of samples for the "reparameterization-trick".
nb_samples = 10

# Number of units per hidden-layer.
n_units = 10

# beer uses pytorch as a backend for the neural-network part
# of the model.
from torch import nn

# Neural network structure of the encoder of the model.
enc_struct = nn.Sequential(
    nn.Linear(obs_dim, n_units),
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh()
)
encoder = beer.models.MLPNormalIso(enc_struct, latent_dim, residual=True)

# Neural network structure of the decoder of the model.
dec_struct = nn.Sequential(
    nn.Linear(latent_dim, n_units),
    nn.Tanh(),
    nn.Linear(n_units, n_units),
    nn.Tanh()
)
decoder = beer.models.MLPNormalDiag(dec_struct, obs_dim)

# Model of the latent space. It can be changed at any-time.
# Alternatives kept for reference:
#latent_model = beer.models.NormalDiagonalCovariance.create(latent_dim)
#latent_model = beer.models.NormalFullCovariance.create(latent_dim, prior_count=1)
#args = {'dim':2, 'prior_count':1, 'mean': data_mean, 'cov': data_cov, 'random_init':True}
#latent_model = beer.Mixture.create(10, beer.NormalDiagonalCovariance.create, args, prior_count=1e-6)
#
# The dimension of the mixture components follows `latent_dim` instead of
# a hard-coded 2 so that both settings cannot get out of sync.
# NOTE(review): the components are initialized from `data_mean`/`data_cov`,
# which live in the observed space -- this presumably requires
# latent_dim == obs_dim; confirm before changing `latent_dim`.
args = {'dim': latent_dim, 'prior_count':1, 'mean': data_mean, 'cov': data_cov, 'random_init':True}
latent_model = beer.Mixture.create(10, beer.NormalFullCovariance.create, args, prior_count=1)

# Putting everything together to build the SVAE.
svae = beer.models.VAE(encoder, decoder, latent_model, nb_samples)

In [12]:
# Two consecutive training stages of 500 epochs each, both with
# kl_weight=0.0 and no callback:
#   1. learning rate 1e-3 for the model, 0 for the latent model;
#   2. learning rate 0 for the model, 1e-1 for the latent model.
pretrain_stages = [
    dict(lrate=1e-3, latent_model_lrate=0),
    dict(lrate=0, latent_model_lrate=1e-1),
]
for stage in pretrain_stages:
    svae.fit(X, max_epochs=500, kl_weight=0.0, callback=None, **stage)
#elbo, llh, kld, mean, var, dec_mean, dec_var = svae.evaluate(X, sampling=False)
#svae.latent_model.fit(mean, max_epochs=100)
#svae.fit(X, max_epochs=500, lrate=1e-3, latent_model_lrate=0, kl_weight=1.0, callback=None)

In [None]:
# Display the weights of the latent (mixture) model of the SVAE.
svae.latent_model.weights

In [None]:
# Histories filled in by `callback` while the model is fitted.
c = 0
elbos, llhs, klds = [], [], []

def callback(elbo, llh, kld):
    'Append the given `elbo`, `llh` and `kld` values to their history.'
    for history, value in ((elbos, elbo), (llhs, llh), (klds, kld)):
        history.append(value)

    #if c % 50 == 0:
    #    print('ln p(X) >=', elbo)

svae.fit(X, max_epochs=2000, lrate=1e-3, latent_model_lrate=1e-2, callback=callback)

# Left panel: the ELBO history.
fig1 = figure(title='ELBO', width=400, height=400,
              x_axis_label='step', y_axis_label='ln p(X)')
fig1.line(np.arange(len(elbos)), elbos)

# Right panel: the LLH history on the left axis ...
llh_range = (min(llhs) - 1, max(llhs) + 1)
fig2 = figure(title='LLH + KLD', width=400, height=400, y_range=llh_range,
              x_axis_label='step', y_axis_label='ln p(x|...)')
fig2.line(np.arange(len(llhs)), llhs)

# ... and the KLD history (green) on an extra axis on the right.
fig2.extra_y_ranges['KLD'] = Range1d(0, max(klds) + 1)
kld_axis = LinearAxis(y_range_name="KLD", axis_label='KLD')
fig2.add_layout(kld_axis, 'right')
fig2.line(np.arange(len(klds)), klds, y_range_name='KLD', color='green')

#show(fig2)
show(gridplot([[fig1, fig2]]))

In [14]:
# Number of grid points along each axis of the observed space.
d = 100
# Regular d x d grid over the plotting range, flattened to (d * d, 2).
# The imaginary step tells np.mgrid to generate `d` points per axis; it is
# derived from `d` so the grid always matches the reshape of the image below.
xy = np.mgrid[x_range[0]:x_range[1]:d * 1j, y_range[0]:y_range[1]:d * 1j].reshape(2,-1).T
elbo, llh, kld, mean, var, _, _ = svae.evaluate(xy, sampling=False)
fig1 = figure(
    x_range=x_range, 
    y_range=y_range,
    width=400,
    height=400
)

# must give a vector of image data for image parameter
fig1.image(
    image=[np.exp((elbo).reshape(d, d).T)], 
    x=x_range[0], 
    y=y_range[0], 
    dw=(x_range[1] - x_range[0]), 
    dh=(y_range[1] - y_range[0]),
    palette="Inferno256", alpha=.01
)
fig1.circle(X[:, 0], X[:, 1], alpha=1)

# The ranges of the second figure are taken from `mean` as returned by the
# evaluation of the grid points above (same bounds for both axes).
fig2 = figure(
    width=400,
    height=400,
    
    x_range=(mean.min() - .1, mean.max() + .1),
    y_range=(mean.min() - .1, mean.max() + .1)
)

# Evaluate the data itself and draw, for each point, the returned mean and
# an ellipse whose axes are 2 * sqrt(var) long.
elbo, llh, kld, mean, var, _, _ = svae.evaluate(X[:100], sampling=False)
fig2.circle(mean[:, 0], mean[:, 1])
for m, v in zip(mean, var):
    fig2.ellipse(x=m[0], y=m[1], 
                 width=2 * np.sqrt(v[0]), 
                 height=2 * np.sqrt(v[1]), 
                 fill_alpha=0, color='black') 
plot_latent_model(fig2, svae.latent_model, alpha=.5, color='salmon')

grid = gridplot([[fig1, fig2]])
show(grid)

In [15]:
# Display the weights of the latent (mixture) model of the SVAE.
svae.latent_model.weights

array([ 0.00533224,  0.45035742,  0.00533224,  0.00533224,  0.00533224,
        0.00533224,  0.02588363,  0.00533224,  0.00533224,  0.48643328])

In [None]:
# Stand-alone mixture of 10 diagonal-covariance Normal components, created
# with randomly initialized 2-dimensional components.
args = dict(dim=2, prior_count=1, random_init=True)
gmm_diag = beer.Mixture.create(
    10,
    beer.NormalDiagonalCovariance.create,
    args,
    prior_count=1e-6
)

In [None]:
# Histories filled in by `callback` while the mixture is fitted.
c = 0
elbos, llhs, klds = [], [], []

def callback(elbo, llh, kld):
    'Append the given `elbo`, `llh` and `kld` values to their history.'
    for history, value in ((elbos, elbo), (llhs, llh), (klds, kld)):
        history.append(value)

    #if c % 50 == 0:
    #    print('ln p(X) >=', elbo)

gmm_diag.fit(X, max_epochs=200, lrate=1e-1, callback=callback)

# Left panel: the ELBO history.
fig1 = figure(title='ELBO', width=400, height=400,
              x_axis_label='step', y_axis_label='ln p(X)')
fig1.line(np.arange(len(elbos)), elbos)

# Right panel: the LLH history on the left axis ...
llh_range = (min(llhs) - 1, max(llhs) + 1)
fig2 = figure(title='LLH + KLD', width=400, height=400, y_range=llh_range,
              x_axis_label='step', y_axis_label='ln p(x|...)')
fig2.line(np.arange(len(llhs)), llhs)

# ... and the KLD history (green) on an extra axis on the right.
fig2.extra_y_ranges['KLD'] = Range1d(0, max(klds) + 1)
kld_axis = LinearAxis(y_range_name="KLD", axis_label='KLD')
fig2.add_layout(kld_axis, 'right')
fig2.line(np.arange(len(klds)), klds, y_range_name='KLD', color='green')

#show(fig2)
show(gridplot([[fig1, fig2]]))

In [None]:
# Scatter plot of the data overlaid with the components of `gmm_diag`.
fig = figure(width=400, height=400)
fig.circle(X[:, 0], X[:, 1])
plot_gmm(fig, gmm_diag)
show(fig)

In [None]:
# Evaluate the SVAE on the data with sampling enabled and display the last
# returned value (`dec_var`, presumably the variance predicted by the
# decoder -- confirm against `evaluate`).
elbo, llh, kld, mean, var, dec_mean, dec_var = svae.evaluate(X, sampling=True)
dec_var

In [None]:
import torch
from torch.autograd import Variable

# Forward the data through the networks and collect the mean and the
# variance (squared standard deviation) stored in the encoder state.
state = svae(Variable(torch.from_numpy(X)).float())
means = state['encoder_state']['mean'].data.numpy()
variances = np.square(state['encoder_state']['std_dev'].data.numpy())

# Estimate m^* and Sigma^*: average of the means, their (biased)
# covariance, and that covariance plus the diagonal average variance.
m = means.mean(axis=0)
m_cov = np.cov(means, rowvar=False, bias=True)
S = m_cov + np.diag(variances.mean(axis=0))

for estimate in (m, m_cov, S):
    print(estimate)

In [None]:
# Plain alias, not a copy: `new_model` and `svae` name the same object.
new_model = svae


In [None]:
#import copy
#new_model = copy.deepcopy(svae)
# `new_model` is an alias of `svae` (the deepcopy above is disabled), so the
# parameter replacement at the end of this cell modifies `svae` itself.
new_model = svae

# Forward the data through the networks.
state = svae(Variable(torch.from_numpy(X)).float())
means = state['encoder_state']['mean'].data.numpy()
variances = state['encoder_state']['std_dev'].data.numpy()**2

# Estimate m^* and Sigma^*
m = np.mean(means, axis=0) 
m_cov = np.cov(means.T, bias=True)
# NOTE(review): `S` is computed but not used in the rest of this cell.
S = np.diag(np.mean(variances, axis=0)) + m_cov

# NOTE(review): `balance` is not used in the rest of this cell.
balance = .5

# Cholesky factor of the covariance of the encoder means and its inverse.
L = np.linalg.cholesky(m_cov)

inv_L = np.linalg.inv(L)

#evals, evecs = np.linalg.eigh(m_cov)
#inv_L = evecs.T

# Replace the parameters (W, b) of the `hid_to_mu` layer of the encoder by
# (inv_L W, inv_L (b - m)). For a purely affine layer this shifts its output
# by -m and multiplies it by inv_L.
# NOTE(review): the encoder is created with `residual=True` elsewhere in this
# notebook -- confirm that transforming `hid_to_mu` alone has the intended
# effect on the encoder means.
W, b = svae.encoder.hid_to_mu.weight.data.numpy(), \
    svae.encoder.hid_to_mu.bias.data.numpy()
new_W = torch.from_numpy(inv_L @ W).float()
new_b = torch.from_numpy(inv_L @ (b - m)).float()
new_model.encoder.hid_to_mu.weight = nn.Parameter(new_W)
new_model.encoder.hid_to_mu.bias = nn.Parameter(new_b)

#b = svae.encoder.hid_to_logprec.bias.data.numpy()
#new_b = torch.from_numpy(b + np.log(10 * variances.mean(axis=0)))
#new_model.encoder.hid_to_logprec.bias = nn.Parameter(new_b)

In [None]:
# Forward the data through the networks (without sampling) and collect the
# mean and the variance (squared standard deviation) of the decoder state.
state = svae(Variable(torch.from_numpy(X)).float(), sampling=False)
means = state['decoder_state']['mean'].data.numpy()
variances = state['decoder_state']['std_dev'].data.numpy()**2

# Estimate m^* and Sigma^*
m = np.mean(means, axis=0) 
m_cov = np.cov(means.T, bias=True)
S = np.diag(np.mean(variances, axis=0)) + m_cov
data_L = np.linalg.cholesky(data_cov)

# NOTE(review): `balance` is not used in the rest of this cell.
balance = .5

# Cholesky factors (and their inverses) of the covariance of the decoder
# means and of `S`.
L = np.linalg.cholesky(m_cov)
inv_L = np.linalg.inv(L)
L_S = np.linalg.cholesky(S)
inv_L_S = np.linalg.inv(L_S)


#evals, evecs = np.linalg.eigh(m_cov)
#inv_L = evecs.T

# NOTE(review): every assignment to the decoder parameters below is commented
# out, so this cell only computes `new_b` (and the factors above) and leaves
# the model unchanged.
W, b = svae.decoder.hid_to_mu.weight.data.numpy(), \
    svae.decoder.hid_to_mu.bias.data.numpy()
#new_W = torch.from_numpy(data_L @ inv_L @ W).float()
new_b = torch.from_numpy((b - m)).float()
#new_model.decoder.hid_to_mu.weight = nn.Parameter(new_W)
#new_model.decoder.hid_to_mu.bias = nn.Parameter(new_b)

#b = svae.decoder.hid_to_logprec.bias.data.numpy()
#new_b = torch.from_numpy(b + np.log(10 * variances.mean(axis=0)))
#new_model.decoder.hid_to_logprec.bias = nn.Parameter(new_b)

In [None]:
# Number of grid points along each axis of the observed space.
d = 200
# Regular d x d grid over the plotting range, flattened to (d * d, 2).
# The imaginary step tells np.mgrid to generate `d` points per axis; it is
# derived from `d` so the grid always matches the reshape of the image below.
xy = np.mgrid[x_range[0]:x_range[1]:d * 1j, y_range[0]:y_range[1]:d * 1j].reshape(2,-1).T
# `evaluate` is unpacked into seven values everywhere else in this notebook;
# the two trailing ones are not needed here.
elbo, llh, kld, mean, var, _, _ = new_model.evaluate(xy, sampling=False)

fig1 = figure(
    x_range=x_range, 
    y_range=y_range,
    width=400,
    height=400
)

# must give a vector of image data for image parameter
fig1.image(
    image=[elbo.reshape(d, d).T], 
    x=x_range[0], 
    y=y_range[0], 
    dw=(x_range[1] - x_range[0]), 
    dh=(y_range[1] - y_range[0]),
    palette="Greys9"
)
fig1.circle(X[:, 0], X[:, 1])

fig2 = figure(
    width=400,
    height=400
)

# One Normal density (diagonal covariance built from `var`) per data point.
elbo, llh, kld, mean, var, _, _ = new_model.evaluate(X, sampling=False)
for i in range(len(X)):
    plot_normal(fig2, mean[i], np.diag(var[i]), color='black')

grid = gridplot([[fig1, fig2]])
show(grid)

In [None]:
# Fit `new_model` for 250 more epochs with a learning rate of 0 for the
# latent model.
new_model.fit(X, max_epochs=250, lrate=1e-3, latent_model_lrate=0)

In [None]:
# NOTE(review): `model`, `data`, `optimizer` and `history` are not defined in
# the visible part of this notebook -- confirm where they come from before
# running this cell.
model.sample = False
beer.inference.run_training(data, model, optimizer, 1000, history, batch_size=20, lrate_latent_model=1.0, kl_weight=0.0)

In [None]:
# NOTE(review): `plot_model_outputs`, `model` and `data` are not defined in
# the visible part of this notebook -- confirm where they come from.
plot_model_outputs(model, data[:100])

In [None]:
# NOTE(review): `history` is not defined in the visible part of this
# notebook -- confirm where it comes from.
history.plot()

In [None]:
# NOTE(review): `model` is not defined in the visible part of this notebook --
# confirm where it comes from. The printed value is 1 / (-2 * exp[0]), where
# `exp` is whatever `grad_lognorm()` returns for the posterior of the latent
# model.
exp = model.latent_model.posterior.grad_lognorm()
print(1/ (exp[0] * -2))