# VAE - Gaussian Linear Classifier

This notebook illustrates how to combine a Variational AutoEncoder (VAE) and a Gaussian Linear Classifier (GLC) with the [beer framework](https://github.com/beer-asr/beer).

In [1]:
%load_ext autoreload
%autoreload 2

# Add the path of the beer source code to the PYTHONPATH.
import sys
sys.path.insert(0, '../')

import math
import yaml
import numpy as np
import torch
import torch.optim
from torch import nn



# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d

# Beer framework
import beer

# Convenience functions for plotting.
import plotting

output_notebook(verbose=False)

## Data 

As a simple example we consider the following synthetic data: 

In [41]:
# Problem size: number of classes and number of points generated per class.
ntargets = 4
N = 100

# NOTE(review): `x`, `means` and this first `cov` are not read by any visible
# cell; they are kept so that the notebook namespace stays unchanged.
x = np.linspace(0, 20, ntargets)
means = np.c_[x, (.1 * x)**2]
cov = np.array([[.75, 0.], [0., .075]])

# One 2-D location per class. A random alternative would be
# `np.random.randn(4, 2)`.
points = np.array([
    [1, 1],
    [1, -1],
    [-1, 1],
    [-1, -1.]
])

Xs, labels = [], []
for i in range(ntargets):
    mean = points[i]
    # `cov` is only needed by the (disabled) Gaussian sampling variant:
    #   X = np.random.multivariate_normal(mean, cov, N)
    #   X[:, 1] += X[:, 0] ** 2
    cov = .05 * np.array([[1, 0], [0, 1]])
    # Every point of the class sits exactly on the class location.
    X = np.tile(mean, (N, 1))
    Xs.append(X)
    labels.append(np.full(len(X), float(i)))

# Stack the per-class arrays into a single data matrix / label vector.
data = np.vstack(Xs)
labels = np.hstack(labels)

# Scatter plot of the data, one color per class.
colors = ['salmon', 'blue', 'green', 'black']
fig = figure(title='Synthetic data', width=400, height=400)
for class_points, color in zip(Xs, colors):
    fig.circle(class_points[:, 0], class_points[:, 1], color=color)
show(fig)

## Model Creation

We first create the VAE-GLC.

#### NOTE:
To obtain a Gaussian Quadratic Classifier, use a GMM model with individual (diagonal) covariance matrices.

In [47]:
# YAML configuration (as text) of the `VAENormalizingFlow` model: an encoder
# ending in a `NormalizingFlowLayer`, an `InverseAutoRegressive` flow, a decoder
# ending in a `NormalLayer` and a `Normal` latent model.
# `<feadim>` is a placeholder that is substituted with the data dimension
# before the text is parsed.
vae_iaf_conf_str = '''
type: VAENormalizingFlow
llh_type: normal
normalizing_flow:
  type: InverseAutoRegressive
  depth: 10
  iaf_block:
    activation: Tanh
    context_dim: 10
    data_dim: 2
    depth: 2
    width: 20
encoder:
  nnet_structure:
  - residual: false
    block_structure:
    - Linear:in_features=<feadim>;out_features=50
    - ReLU
    - Linear:in_features=50;out_features=50
    - ReLU
  prob_layer:
    type: NormalizingFlowLayer
    covariance: isotropic
    flow_params_dim: 10
    dim_in: 50
    dim_out: 2
decoder:
  nnet_structure:
  - residual: false
    block_structure:
    - Linear:in_features=2;out_features=50
    - ReLU
    - Linear:in_features=50;out_features=50
    - ReLU
  prob_layer:
    type: NormalLayer
    covariance: isotropic
    dim_in: 50
    dim_out: <feadim>
latent_model:
  type: Normal
  covariance: isotropic
  prior_strength: 1.
  noise_std: 0.
'''

# YAML configuration (as text) of the plain `VAE` model: encoder and decoder
# both end in a `NormalLayer` (isotropic for the encoder, diagonal for the
# decoder) and the latent model is a `Normal` with isotropic covariance.
# `<feadim>` is a placeholder that is substituted with the data dimension
# before the text is parsed.
vae_conf_str = '''
type: VAE
llh_type: normal
encoder:
  nnet_structure:
  - residual: false
    block_structure:
    - Linear:in_features=<feadim>;out_features=50
    - ReLU
    - Linear:in_features=50;out_features=50
    - ReLU
  prob_layer:
    type: NormalLayer
    covariance: isotropic
    dim_in: 50
    dim_out: 2
decoder:
  nnet_structure:
  - residual: false
    block_structure:
    - Linear:in_features=2;out_features=50
    - ReLU
    - Linear:in_features=50;out_features=50
    - ReLU
  prob_layer:
    type: NormalLayer
    covariance: diagonal
    dim_in: 50
    dim_out: <feadim>
latent_model:
  type: Normal
  covariance: isotropic
  prior_strength: 1
  noise_std: 0.
'''

# Alternative `latent_model` YAML section: a `Mixture` whose components are a
# `PLDASet`.
# NOTE(review): `tmp` is not referenced by any visible cell -- looks like a
# leftover experiment; confirm before removing or renaming.
tmp = '''latent_model:
  type: Mixture
  prior_strength: 1.
  components:
    type: PLDASet
    size: 5
    dim_noise_subspace: 1
    dim_class_subspace: 1
    prior_strength: 1.
    noise_std: 1.
'''

In [49]:
# Per-dimension mean and variance of the data (single precision); both are
# handed to `beer.create_model` below.
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(np.var(data, axis=0)).float()

print('feadim:', len(data_mean))

# Plain VAE: fill in the feature dimension, parse the YAML and build the model.
# `yaml.safe_load` is used instead of a bare `yaml.load(...)`: calling
# `yaml.load` without a `Loader` is deprecated since PyYAML 5.1 and fails on
# PyYAML 6. The configurations only contain plain scalars, lists and mappings,
# so the safe loader parses them to the same objects.
conf_data = vae_conf_str.replace('<feadim>', str(len(data_mean)))
conf = yaml.safe_load(conf_data)
vae = beer.create_model(conf, data_mean, data_var).double()

# VAE with normalizing flow: same procedure with the other configuration.
conf_data = vae_iaf_conf_str.replace('<feadim>', str(len(data_mean)))
conf = yaml.safe_load(conf_data)
vae_iaf = beer.create_model(conf, data_mean, data_var).double()

feadim: 2


## Variational Bayes Training

In [54]:
# --- Training setup for the plain VAE ---------------------------------------
npoints = N * ntargets
epochs = 5000
lrate_bayesmodel = 0.
lrate_encoder = 1e-3

# Training inputs in double precision, matching the model.
X = torch.from_numpy(data[:npoints]).double()
# NOTE(review): `targets` is not used by the training loop below.
targets = torch.from_numpy(labels[:npoints]).long()
vae = vae.double()

# Adam handles the encoder/decoder networks; it is handed to the beer
# optimizer through `std_optim`.
nnet_parameters = [*vae.encoder.parameters(), *vae.decoder.parameters()]
std_optimizer = torch.optim.Adam(
    nnet_parameters,
    lr=lrate_encoder,
    weight_decay=1e-2,
)
optimizer = beer.BayesianModelCoordinateAscentOptimizer(
    *vae.grouped_parameters,
    lrate=lrate_bayesmodel,
    std_optim=std_optimizer,
)

# --- Optimization loop -------------------------------------------------------
elbos = []
for epoch in range(epochs):
    optimizer.zero_grad()
    elbo = beer.evidence_lower_bound(vae, X, datasize=len(X), nsamples=1, kl_weight=1.)
    elbo.backward()
    elbo.natural_backward()
    optimizer.step()

    # The value of the very first step is not recorded.
    if epoch == 0:
        continue
    elbos.append(float(elbo) / len(X))

# --- Plot the per-point ELBO against the training step -----------------------
fig = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
fig.line(np.arange(len(elbos)), elbos, color='blue')
show(fig)

In [55]:
# Left panel: observed data and their reconstructions.
# Right panel: latent samples on top of the latent model's normal density.
fig1 = figure(title='Observed space', width=400, height=400)
fig2 = figure(title='Latent space', width=400, height=400,
              x_range=(-5, 5), y_range=(-5, 5))
plotting.plot_normal(
    fig2,
    vae.latent_model.mean.numpy(),
    vae.latent_model.cov.numpy(),
    alpha=.2,
    color='blue',
)

for class_X, color in zip(Xs, colors):
    class_X = torch.from_numpy(class_X).double()

    # Encode the class, draw 100 latent samples per point and flatten them
    # into a list of 2-D latent vectors.
    mean, variance = vae.encoder(class_X)
    samples = beer.utils.sample_from_normals(mean, variance, 100)
    samples = samples.view(-1, 2).detach()

    # Decode the samples; the first element of the decoder output is plotted
    # as the reconstruction.
    r_class_X = vae.decoder(samples)[0]

    # Convert everything to NumPy for plotting.
    samples = samples.data.numpy()
    class_X = class_X.detach().numpy()
    r_class_X = r_class_X.detach().numpy()

    fig1.circle(class_X[:, 0], class_X[:, 1], alpha=.5, color=color)
    fig1.cross(r_class_X[:, 0], r_class_X[:, 1], color=color)
    fig2.circle(samples[:, 0], samples[:, 1], color=color)

show(gridplot([[fig1, fig2]]))

In [52]:
# --- Training setup for the VAE with normalizing flow ------------------------
npoints = N * ntargets
epochs = 5000
lrate_bayesmodel = 0.
lrate_encoder = 1e-3

# Training inputs in double precision, matching the model.
X = torch.from_numpy(data[:npoints]).double()
# NOTE(review): `targets` is not used by the training loop below.
targets = torch.from_numpy(labels[:npoints]).long()
vae_iaf = vae_iaf.double()

# Adam handles the encoder, flow and decoder networks; it is handed to the
# beer optimizer through `std_optim`.
nnet_parameters = [
    *vae_iaf.encoder.parameters(),
    *vae_iaf.nflow.parameters(),
    *vae_iaf.decoder.parameters(),
]
std_optimizer = torch.optim.Adam(
    nnet_parameters,
    lr=lrate_encoder,
    weight_decay=1e-2,
)
optimizer = beer.BayesianModelCoordinateAscentOptimizer(
    *vae_iaf.grouped_parameters,
    lrate=lrate_bayesmodel,
    std_optim=std_optimizer,
)

# --- Optimization loop -------------------------------------------------------
elbos = []
for epoch in range(epochs):
    optimizer.zero_grad()
    elbo = beer.evidence_lower_bound(vae_iaf, X, datasize=len(X), nsamples=1, kl_weight=1.)
    elbo.backward()
    elbo.natural_backward()
    optimizer.step()

    # The value of the very first step is not recorded.
    if epoch == 0:
        continue
    elbos.append(float(elbo) / len(X))

# --- Plot the per-point ELBO against the training step -----------------------
fig = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
fig.line(np.arange(len(elbos)), elbos, color='blue')
show(fig)

In [53]:
# Left panel: observed data and their reconstructions.
# Right panel: latent samples produced by the flow with `stop_level=-1`.
fig1 = figure(title='Observed space', width=400, height=400)
fig2 = figure(title='Latent space', width=400, height=400,
              x_range=(-5, 5), y_range=(-5, 5))
# Disabled: overlay of the latent model's normal density.
#plotting.plot_normal(fig2, vae_iaf.latent_model.mean.numpy(), vae_iaf.latent_model.cov.numpy(), alpha=.2, color='blue')

for class_X, color in zip(Xs, colors):
    class_X = torch.from_numpy(class_X).double()

    # Encode the class, then draw 100 samples per point through the flow.
    mean, variance, flow_params = vae_iaf.encoder(class_X)
    _, samples = vae_iaf.nflow(
        mean,
        variance,
        flow_params,
        nsamples=100,
        stop_level=-1,
    )
    samples = samples.view(-1, 2).detach()

    # Decode the samples; the first element of the decoder output is plotted
    # as the reconstruction.
    r_class_X = vae_iaf.decoder(samples)[0]

    # Convert everything to NumPy for plotting.
    samples = samples.data.numpy()
    class_X = class_X.detach().numpy()
    r_class_X = r_class_X.detach().numpy()

    fig1.circle(class_X[:, 0], class_X[:, 1], alpha=.5, color=color)
    fig1.cross(r_class_X[:, 0], r_class_X[:, 1], color=color)
    fig2.circle(samples[:, 0], samples[:, 1], color=color)

show(gridplot([[fig1, fig2]]))

In [None]:
# Left panel: observed data and their reconstructions.
# Right panel: latent samples produced by the flow with `stop_level=5`.
fig1 = figure(title='Observed space', width=400, height=400)
fig2 = figure(title='Latent space', width=400, height=400,
              x_range=(-5, 5), y_range=(-5, 5))
# Disabled: overlay of the latent model's normal density.
#plotting.plot_normal(fig2, vae_iaf.latent_model.mean.numpy(), vae_iaf.latent_model.cov.numpy(), alpha=.2, color='blue')

for class_X, color in zip(Xs, colors):
    class_X = torch.from_numpy(class_X).double()

    # Encode the class, then draw 100 samples per point through the flow.
    mean, variance, flow_params = vae_iaf.encoder(class_X)
    _, samples = vae_iaf.nflow(
        mean,
        variance,
        flow_params,
        nsamples=100,
        stop_level=5,
    )
    samples = samples.view(-1, 2).detach()

    # Decode the samples; the first element of the decoder output is plotted
    # as the reconstruction.
    r_class_X = vae_iaf.decoder(samples)[0]

    # Convert everything to NumPy for plotting.
    samples = samples.data.numpy()
    class_X = class_X.detach().numpy()
    r_class_X = r_class_X.detach().numpy()

    fig1.circle(class_X[:, 0], class_X[:, 1], alpha=.5, color=color)
    fig1.cross(r_class_X[:, 0], r_class_X[:, 1], color=color)
    fig2.circle(samples[:, 0], samples[:, 1], color=color)

show(gridplot([[fig1, fig2]]))

In [109]:
# Sum over the training data of each model's output (flow model first).
(
    vae_iaf(X).sum(),
    vae(X).sum(),
)

(tensor(-1330.7556, dtype=torch.float64),
 tensor(1001.7643, dtype=torch.float64))

In [110]:
# Summed local posterior/prior KL divergence of each model (flow model first).
(
    vae_iaf.local_kl_div_posterior_prior().sum(),
    vae.local_kl_div_posterior_prior().sum(),
)

(tensor(7270.2387, dtype=torch.float64), tensor(815.0322, dtype=torch.float64))

In [41]:
# Display the first element of the flow's `nnet_flow` container.
vae_iaf.nflow.nnet_flow[0]

SequentialMultipleInput(
  (0): MergeTransform(
    (transforms): Sequential(
      (0): MaskedLinear(
        (_linear_transform): Linear(in_features=2, out_features=20, bias=True)
      )
      (1): Linear(in_features=10, out_features=20, bias=True)
    )
  )
  (1): ReLU()
  (2): MaskedLinear(
    (_linear_transform): Linear(in_features=20, out_features=20, bias=True)
  )
  (3): ReLU()
  (4): ARNetNormalDiagonalCovarianceLayer(
    (h2mean): MaskedLinear(
      (_linear_transform): Linear(in_features=20, out_features=2, bias=True)
    )
    (h2logvar): MaskedLinear(
      (_linear_transform): Linear(in_features=20, out_features=2, bias=True)
    )
  )
)

In [116]:
# Masks of the first flow block: the masked input transform, the hidden masked
# layer and the mean head of the output layer.
M1, M2, M3 = (
    vae_iaf.nflow.nnet_flow[0][0].transforms[0]._mask,
    vae_iaf.nflow.nnet_flow[0][2]._mask,
    vae_iaf.nflow.nnet_flow[0][-1].h2mean._mask,
)
# Show the three mask shapes side by side.
tuple(mask.shape for mask in (M1, M2, M3))

(torch.Size([20, 2]), torch.Size([20, 20]), torch.Size([2, 20]))

In [118]:
# Display the mask taken from the first masked transform of the first flow block.
M1

Parameter containing:
tensor([[ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.],
        [ 1.,  0.]], dtype=torch.float64)