# VAE - Gaussian Linear Classifier

This notebook illustrates how to combine a Variational AutoEncoder (VAE) and a Gaussian Linear Classifier (GLC) with the [beer framework](https://github.com/beer-asr/beer).

In [None]:
%load_ext autoreload
%autoreload 2

# Add the path of the beer source code to the PYTHONPATH.
from collections import defaultdict
import random
import sys
sys.path.insert(0, '../')

import math
import yaml
import numpy as np
import torch
import torch.optim
from torch import nn



# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d

# Beer framework
import beer

# Convenience functions for plotting.
import plotting

# Route bokeh's output to the notebook so that `show(...)` renders the
# figures inline in the cell outputs.
output_notebook(verbose=False)

## Data 

As a simple example we consider the following synthetic data: 

In [None]:
def generate_data(ntargets, npoints_per_target=100):
    """Draw a shuffled synthetic 2-D data set with one curved cluster per class.

    For every class ``i`` in ``range(ntargets)``, ``npoints_per_target``
    points are sampled from a Gaussian centred at ``(0, 2 - 1.5 * i)`` with
    covariance ``diag(.75, .075)``. The cloud is then bent by adding the
    squared (centred) first coordinate to the second coordinate.

    Args:
        ntargets: number of classes to generate.
        npoints_per_target: number of points sampled for each class.

    Returns:
        ``(data, labels)``: a float array of shape
        ``(ntargets * npoints_per_target, 2)`` and the matching float array
        of class indices, both reordered with the same random permutation.
    """
    covariance = np.array([[.75, 0.], [0., .075]])
    clusters, cluster_labels = [], []
    for target in range(ntargets):
        center = np.array([0, 2. - (target * 1.5)])
        latent = np.random.multivariate_normal(center, covariance,
                                               size=npoints_per_target)
        # Keep the first coordinate and warp the second one into a parabola.
        offset = latent[:, 0] - center[0]
        warped = np.column_stack([latent[:, 0], latent[:, 1] + offset ** 2])
        clusters.append(warped)
        cluster_labels.append(np.full(len(warped), float(target)))

    # Shuffle the points and their labels with one shared permutation.
    order = np.arange(0, ntargets * npoints_per_target)
    np.random.shuffle(order)
    return np.vstack(clusters)[order], np.hstack(cluster_labels)[order]

# Seed every random number generator used in this notebook (numpy for the
# synthetic data, `random` for the order of the batches, torch for the
# neural network weights) so that "Restart & Run All" reproduces the same
# data and the same training runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ntargets = 5
data, labels = generate_data(ntargets)
test_data, test_labels = generate_data(ntargets, npoints_per_target=50)

# Convert the data/labels to torch tensors; the training functions split
# them into (mini-)batches.
X = torch.from_numpy(data).float()
targets = torch.from_numpy(labels).long()
test_X = torch.from_numpy(test_data).float()
test_targets = torch.from_numpy(test_labels).long()

# Compute the global mean/variance to initialize the models.
global_mean, global_var = X.mean(dim=0), X.var(dim=0)

In [None]:
# One colour per class; only the first `ntargets` entries are used below.
colors = ['salmon', 'blue', 'green', 'orange', 'black', 'red', 'cyan',
          'purple', 'brown', 'pink']

# Axis limits shared by the two scatter plots.
x_range = (-3, 3)
y_range = (-(ntargets + 1), ntargets + 2)

# Draw the training and the test set the same way: one scatter layer per
# class, coloured by the class index.
figures = []
for title, points, point_labels in (('Training data', data, labels),
                                    ('Test data', test_data, test_labels)):
    fig = figure(title=title, width=400, height=400, x_range=x_range,
                 y_range=y_range)
    for target, color in zip(range(ntargets), colors):
        class_points = points[point_labels == target]
        fig.circle(class_points[:, 0], class_points[:, 1], color=color)
    figures.append(fig)

fig1, fig2 = figures
show(gridplot([[fig1, fig2]]))

## Model Creation

We first create the VAE-GLC.

#### NOTE:
To obtain a Gaussian Quadratic Classifier, use a GMM model with individual (diagonal) covariance matrices.

In [100]:
def create_vae(mean, var, nnet_width=50, nflow_width=20, nflow_depth=0,
               nflow_block_depth=2, nflow_params_dim=10, latent_space_dim=2,
               p_strength=1.):
    """Build a VAE whose decoder only predicts the mean of the observations.

    The encoder is a two-layer ELU network followed by an inverse
    autoregressive flow; the decoder is a two-hidden-layer tanh network. The
    observation model is an isotropic normal created from ``mean``/``var``,
    so the variance does not depend on the latent variable.

    Args:
        mean: mean used to create the observation normal; ``len(mean)``
            defines the dimension of the observation space.
        var: variance used to create the observation normal.
        nnet_width: number of hidden units in the encoder/decoder layers.
        nflow_width: ``width`` of each autoregressive network of the flow.
        nflow_depth: number of autoregressive networks chained in the flow
            (``0``, the default, means no flow step at all).
        nflow_block_depth: ``depth`` of each autoregressive network.
        nflow_params_dim: ``flow_params_dim`` given to the flow components.
        latent_space_dim: dimension of the latent space.
        p_strength: currently unused.
            NOTE(review): the literal ``1.`` passed to ``beer.Normal.create``
            below is presumably the prior strength this argument was meant
            to set -- confirm against the beer API before wiring it in.

    Returns:
        The ``beer.VAEGlobalMeanVariance`` model assembled from these parts.
    """
    obs_space_dim = len(mean)

    # Normal prior over the latent space.
    prior = beer.Normal.create(mean=torch.zeros(latent_space_dim),
                               cov=torch.ones(latent_space_dim),
                               cov_type='full')

    # Encoder network.
    encoder = torch.nn.Sequential(
        torch.nn.Linear(obs_space_dim, nnet_width),
        torch.nn.ELU(),
        torch.nn.Linear(nnet_width, nnet_width),
        torch.nn.ELU()
    )

    # Normalizing flow (1): initial distribution.
    normal_layer = beer.nnet.NormalIsotropicCovarianceLayer(nnet_width,
                                                            latent_space_dim)

    # Normalizing flow (2): sequence of autoregressive networks. The length
    # of the sequence is driven by `nflow_depth` (it used to be hard-coded
    # to 0, which silently ignored the argument).
    nnet_flow = [
        beer.nnet.AutoRegressiveNetwork(
            dim_in=latent_space_dim,
            flow_params_dim=nflow_params_dim,
            depth=nflow_block_depth,
            width=nflow_width,
            activation=torch.nn.ELU()
        )
        for _ in range(nflow_depth)
    ]

    # Normalizing flow (3): assemble the initial distribution and the
    #                       autoregressive networks.
    encoder_problayer = beer.nnet.InverseAutoRegressiveFlow(
        dim_in=nnet_width,
        flow_params_dim=nflow_params_dim,
        normal_layer=normal_layer,
        nnet_flow=nnet_flow
    )

    # Decoder network.
    decoder = torch.nn.Sequential(
        torch.nn.Linear(latent_space_dim, nnet_width),
        torch.nn.Tanh(),
        torch.nn.Linear(nnet_width, nnet_width),
        torch.nn.Tanh(),
        torch.nn.Linear(nnet_width, obs_space_dim)
    )

    # Normal distribution embedding the auto-encoder:
    #   N(μ + f(z), σ²)
    # Note that the variance does not depend on the latent space as it
    # usually does with a standard variational auto-encoder.
    normal_iso = beer.Normal.create(mean, var, 1., cov_type='isotropic')

    # Construct the VAE from all the parts.
    model = beer.VAEGlobalMeanVariance(encoder, encoder_problayer, decoder,
                                       normal_iso, prior)
    return model

    
def create_gaussian_classifier(ntargets, mean, var, p_strength=1., 
                               shared_cov=True):
    """Create a mixture of ``ntargets`` full-covariance normals.

    Args:
        ntargets: number of components (``size`` of the normal set).
        mean: mean given to ``beer.NormalSet.create``.
        var: variance given to ``beer.NormalSet.create``.
        p_strength: forwarded as ``prior_strength``.
        shared_cov: forwarded as ``shared_cov``.

    Returns:
        The result of ``beer.Mixture.create`` applied to the normal set.
    """
    components = beer.NormalSet.create(
        mean, var,
        size=ntargets,
        cov_type='full',
        shared_cov=shared_cov,
        prior_strength=p_strength,
        noise_std=0,
    )
    return beer.Mixture.create(components)

### 1. Pre-training

In [6]:
def train_cvb(model, X, labels=None, epochs=1, nbatches=1, lrate_nnet=1e-3,
              update_prior=True, update_nnet=True, kl_weight=1., state=None):
    """Train ``model`` with the collapsed variational Bayes (CVB) optimizer.

    Args:
        model: model providing ``bayesian_parameters()`` and
            ``modules_parameters()``.
        X: data tensor; it is split into ``nbatches`` equally sized batches
            along its first dimension.
        labels: optional tensor with one label per row of ``X``. When given,
            the labels are passed to the ELBO computations.
        epochs: number of passes over the batches.
        nbatches: number of batches; must divide the number of rows of ``X``.
        lrate_nnet: learning rate of the Adam optimizer of the networks.
        update_prior: if False, no Bayesian parameter is given to the
            optimizer.
        update_nnet: if False, no network parameter is given to Adam.
        kl_weight: forwarded to ``beer.collapsed_evidence_lower_bound``.
        state: ``(optimizer, batch_stats)`` returned by a previous call, to
            resume training instead of creating a new optimizer.

    Returns:
        ``(elbos, state)``: the ELBO divided by ``len(X)`` after each epoch,
        and the ``(optimizer, batch_stats)`` pair to resume training.
    """
    supervised = labels is not None

    # Split the data (and the labels, if any) into `nbatches` batches. The
    # labels come from the `labels` argument, not from a notebook global,
    # and the feature dimension is read from `X` instead of assuming 2.
    batches = X.view(nbatches, -1, X.shape[-1])
    batches_targets = labels.view(nbatches, -1) if supervised else None

    # NOTE(review): when `update_nnet` is False an empty iterable is handed
    # to torch.optim.Adam, which rejects empty parameter lists -- confirm
    # that this path is actually used.
    prior_parameters = model.bayesian_parameters() if update_prior else range(0)
    nnet_parameters = model.modules_parameters() if update_nnet else range(0)

    if state is None:
        std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_nnet,
                                         weight_decay=1e-2)
        optimizer = beer.CVBOptimizer(prior_parameters, std_optim=std_optimizer)
        # Per-batch statistics returned by the previous `elbo.backward()`;
        # unseen batches/keys default to None.
        batch_stats = defaultdict(lambda: defaultdict(lambda: None))
    else:
        optimizer, batch_stats = state

    elbos = []
    for epoch in range(epochs):
        # Randomize the order of the batches.
        batch_ids = list(range(len(batches)))
        random.shuffle(batch_ids)

        for batch_id in batch_ids:
            optimizer.init_step(batch_stats[batch_id])
            kwargs = {'kl_weight': kl_weight}
            if supervised:
                kwargs['labels'] = batches_targets[batch_id]
            elbo = beer.collapsed_evidence_lower_bound(model, batches[batch_id],
                                                       **kwargs)
            batch_stats[batch_id] = elbo.backward()
            optimizer.step()

        # Monitor the evidence lower bound after each epoch.
        kwargs = {'labels': labels} if supervised else {}
        elbo = beer.evidence_lower_bound(model, X, **kwargs)
        elbos.append(float(elbo) / len(X))

    return elbos, (optimizer, batch_stats)


def train_svb(model, X, labels=None, epochs=1, nbatches=1, lrate_nnet=1e-3,
              lrate_prior=1e-1, update_prior=True, update_nnet=True,
              kl_weight=1., state=None):
    """Train ``model`` with stochastic variational Bayes (SVB).

    Args:
        model: model providing ``mean_field_groups`` and
            ``modules_parameters()``.
        X: data tensor; it is split into ``nbatches`` equally sized batches
            along its first dimension.
        labels: optional tensor with one label per row of ``X``. When given,
            the labels are passed to the ELBO computations.
        epochs: number of passes over the batches.
        nbatches: number of batches; must divide the number of rows of ``X``.
        lrate_nnet: learning rate of the Adam optimizer of the networks.
        lrate_prior: ``lrate`` given to ``beer.BayesianModelOptimizer``.
        update_prior: if False, an empty group (``[[]]``) is given to the
            Bayesian optimizer.
        update_nnet: if False, no network parameter is given to Adam.
        kl_weight: forwarded to ``beer.evidence_lower_bound`` for each batch.
        state: optimizer returned by a previous call, to resume training
            instead of creating a new optimizer.

    Returns:
        ``(elbos, optimizer)``: the ELBO divided by ``len(X)`` after each
        epoch, and the optimizer to resume training.
    """
    supervised = labels is not None

    # Split the data (and the labels, if any) into `nbatches` batches. The
    # labels come from the `labels` argument, not from a notebook global,
    # and the feature dimension is read from `X` instead of assuming 2.
    batches = X.view(nbatches, -1, X.shape[-1])
    batches_targets = labels.view(nbatches, -1) if supervised else None

    # NOTE(review): when `update_nnet` is False an empty iterable is handed
    # to torch.optim.Adam, which rejects empty parameter lists -- confirm
    # that this path is actually used.
    mf_groups = model.mean_field_groups if update_prior else [[]]
    nnet_parameters = model.modules_parameters() if update_nnet else range(0)

    if state is None:
        std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_nnet,
                                         weight_decay=1e-2)
        optimizer = beer.BayesianModelOptimizer(mf_groups, lrate=lrate_prior,
                                                std_optim=std_optimizer)
    else:
        optimizer = state

    elbos = []
    for epoch in range(epochs):
        # Randomize the order of the batches.
        batch_ids = list(range(len(batches)))
        random.shuffle(batch_ids)
        for batch_id in batch_ids:
            optimizer.init_step()
            # `datasize` tells the ELBO the size of the full data set while
            # only one batch is presented.
            kwargs = {'kl_weight': kl_weight, 'datasize': len(X)}
            if supervised:
                kwargs['labels'] = batches_targets[batch_id]
            elbo = beer.evidence_lower_bound(model, batches[batch_id],
                                             **kwargs)
            elbo.backward()
            optimizer.step()

        # Monitor the evidence lower bound after each epoch.
        kwargs = {'labels': labels} if supervised else {}
        elbo = beer.evidence_lower_bound(model, X, **kwargs)
        elbos.append(float(elbo) / len(X))

    return elbos, optimizer

In [83]:
def plot_latent_space(fig, model, X, labels, use_mean=True):
    """Scatter the latent representation of ``X``, coloured by class.

    The points are encoded with ``model.encoder`` and
    ``model.encoder_problayer``; one scatter layer is added to ``fig`` per
    class. The number of classes and the palette are read from the
    notebook-level ``ntargets`` and ``colors``.

    Args:
        fig: bokeh figure (anything with a ``circle`` method) to draw on.
        model: VAE providing ``encoder`` and ``encoder_problayer``.
        X: tensor of data points to encode.
        labels: tensor with one class index per row of ``X``.
        use_mean: forwarded to ``encoder_problayer.samples_and_llh``.
    """
    # Use the `model` argument rather than the notebook-level `vae`, so the
    # function plots the model it was asked to plot.
    enc_states = model.encoder(X)
    post_params = model.encoder_problayer(enc_states)
    samples, _ = model.encoder_problayer.samples_and_llh(post_params,
                                                         use_mean=use_mean)
    samples = samples.detach().numpy()

    # Converted once, outside of the loop.
    labels = labels.numpy()
    for target, color in zip(range(ntargets), colors):
        class_samples = samples[labels == target]
        fig.circle(class_samples[:, 0], class_samples[:, 1], color=color)

In [275]:
# Build the VAE and plug the Gaussian classifier in as its latent model.
vae = create_vae(global_mean, global_var)
gmm = create_gaussian_classifier(
    ntargets, torch.zeros(2), torch.ones(2), p_strength=10.
)
vae.latent_model = gmm

# Pre-training: 50 epochs without labels and with `kl_weight=0`.
elbos, state = train_svb(vae, X, epochs=50, nbatches=50, kl_weight=0)

# Re-attach the classifier before the main training.
vae.latent_model = gmm

# Main training: SVB with the labels and `kl_weight=1`, resuming the
# optimizer of the pre-training.
svb_elbos, state = train_svb(
    vae, X,
    labels=targets,
    epochs=1000,
    nbatches=50,
    kl_weight=1,
    state=state,
)

# Left panel: ELBO after each epoch. Right panel: latent space and mixture.
fig1 = figure(width=400, height=400)
fig2 = figure(width=400, height=400)

svb_epochs = range(len(svb_elbos))
fig1.line(svb_epochs, svb_elbos, legend='ELBO')
fig1.legend.location = 'bottom_right'

plotting.plot_gmm(fig2, vae.latent_model, colors=colors, alpha=.5, color='blue')
plot_latent_space(fig2, vae, X, targets)

show(gridplot([[fig1, fig2]]))

KeyboardInterrupt: 

In [278]:
# Build a fresh VAE and plug the Gaussian classifier in as its latent model.
# NOTE(review): the mixture gets `ntargets + 1` components although the
# labels only cover `ntargets` classes -- confirm this is intended.
vae = create_vae(global_mean, global_var)
gmm = create_gaussian_classifier(
    ntargets + 1, torch.zeros(2), torch.ones(2), p_strength=1
)
vae.latent_model = gmm

# Pre-training: 50 CVB epochs with the labels and `kl_weight=0`.
elbos, state = train_cvb(
    vae, X, labels=targets, epochs=50, nbatches=50, kl_weight=0
)

# Main training: CVB with `kl_weight=1`, resuming the optimizer and the
# batch statistics of the pre-training.
cvb_elbos, state = train_cvb(
    vae, X,
    labels=targets,
    epochs=100,
    nbatches=50,
    kl_weight=1,
    state=state,
)

# Left panel: ELBO after each epoch. Right panel: latent space and mixture.
fig1 = figure(width=400, height=400)
fig2 = figure(width=400, height=400)

cvb_epochs = range(len(cvb_elbos))
fig1.line(cvb_epochs, cvb_elbos, legend='ELBO')
fig1.legend.location = 'bottom_right'

plotting.plot_gmm(fig2, vae.latent_model, colors=colors, alpha=.5, color='blue')
plot_latent_space(fig2, vae, X, targets)

show(gridplot([[fig1, fig2]]))

In [85]:
# Re-draw the CVB results without re-training.
# Left panel: ELBO after each epoch.
fig1 = figure(width=400, height=400)
cvb_epochs = range(len(cvb_elbos))
fig1.line(cvb_epochs, cvb_elbos, legend='ELBO')
fig1.legend.location = 'bottom_right'

# Right panel: mixture of the latent model and encoded training points.
fig2 = figure(width=400, height=400)
plotting.plot_gmm(fig2, vae.latent_model, colors=colors, alpha=.5,
                  color='blue')
plot_latent_space(fig2, vae, X, targets, use_mean=True)

show(gridplot([[fig1, fig2]]))

In [153]:
# Overlay the SVB and CVB training curves on a single figure. Both
# `svb_elbos` and `cvb_elbos` must exist, i.e. both training cells must have
# run to completion.
fig1 = figure(width=400, height=400)
for curve, curve_color, curve_name in ((svb_elbos, 'blue', 'SVB ELBO'),
                                       (cvb_elbos, 'green', 'CVB ELBO')):
    fig1.line(range(len(curve)), curve, color=curve_color, legend=curve_name)
fig1.legend.location = 'bottom_right'

show(fig1)

NameError: name 'svb_elbos' is not defined

In [268]:
# Scratch cell: accumulate the sufficient statistics of a single data point
# in a 5-component classifier and build, by hand, two stacked
# natural-parameter tensors (`nparams1`, `nparams2`).
gmm = create_gaussian_classifier(5, torch.ones(2), 2 * torch.ones(2), p_strength=1)
prior = gmm.modelset.means_precision.prior
post = gmm.modelset.means_precision.posterior

# Statistics of the first data point only.
stats = gmm.sufficient_statistics(X[:1])
# Random responsibilities normalised to sum to one ...
resps = 1 + torch.randn(1, 5) ** 2
resps /= resps.sum(dim=1)[:, None]
# ... immediately replaced by uniform responsibilities (1/5 per component).
resps = torch.ones_like(resps) / 5

acc_stats = gmm.modelset.accumulate(stats, resps)
#post.natural_parameters = prior.natural_parameters + acc_stats[gmm.modelset.means_precision]
print('stats:', stats)
print('acc_stats:', acc_stats[gmm.modelset.means_precision])

# Split the joint natural parameters in two parts, repeat the first part
# once per row of the second part, and concatenate them as
# [np1 without its last column, np2, last column of np1] with a new leading
# dimension.
joint_nparams = post.natural_parameters
np1, np2 = gmm.modelset._split_natural_parameters(joint_nparams)
np1 = torch.ones(len(np2), 1, dtype=np1.dtype,
                 device=np1.device) * np1.view(1, -1)
nparams1 = torch.cat([
    np1[:, :-1],
    np2,
    np1[:, -1].view(-1, 1)
], dim=1)[None]
print('stats', stats)
# Divide the columns from index 4 up to (but excluding) the last one by 5,
# the value of the uniform responsibilities; the other columns are kept.
stats = torch.cat([
    stats[:, :4],
    stats[:, 4:-1] / 5,
    stats[:, -1].view(-1, 1)
], dim=-1)
print('stats', stats)
nparams2 = stats + nparams1

stats: tensor([[-0.0365, -0.1330, -0.1330, -0.4846,  0.2702,  0.9844, -0.5000,  0.5000]])
acc_stats: tensor([-0.0365, -0.1330, -0.1330, -0.4846,  0.0540,  0.1969,  0.0540,  0.1969,
         0.0540,  0.1969,  0.0540,  0.1969,  0.0540,  0.1969, -0.1000, -0.1000,
        -0.1000, -0.1000, -0.1000,  0.5000])
stats tensor([[-0.0365, -0.1330, -0.1330, -0.4846,  0.2702,  0.9844, -0.5000,  0.5000]])
stats tensor([[-0.0365, -0.1330, -0.1330, -0.4846,  0.0540,  0.1969, -0.1000,  0.5000]])


In [269]:
# Scratch cell: compare `post.log_norm()` (after adding the accumulated
# statistics to the posterior) with `post.joint_log_norm(nparams2)`.
# Relies on `post`, `acc_stats`, `stats`, `gmm` and `nparams2` from the
# previous cell.
#print(post.natural_parameters)
#print(nparams1)
# NOTE: this update is not idempotent -- every run of the cell adds
# `acc_stats` to the posterior natural parameters once more.
post.natural_parameters = post.natural_parameters + acc_stats[gmm.modelset.means_precision]
print(stats)
print(acc_stats[gmm.modelset.means_precision])
print('nparams:', post.natural_parameters)
print(nparams2)
lnorm1 = post.log_norm()
lnorm2 = post.joint_log_norm(nparams2)
# Display both values as the cell output.
lnorm1, lnorm2

tensor([[-0.0365, -0.1330, -0.1330, -0.4846,  0.0540,  0.1969, -0.1000,  0.5000]])
tensor([-0.0365, -0.1330, -0.1330, -0.4846,  0.0540,  0.1969,  0.0540,  0.1969,
         0.0540,  0.1969,  0.0540,  0.1969,  0.0540,  0.1969, -0.1000, -0.1000,
        -0.1000, -0.1000, -0.1000,  0.5000])
nparams: tensor([-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969,  1.0540,  1.1969,
         1.0540,  1.1969,  1.0540,  1.1969,  1.0540,  1.1969, -0.6000, -0.6000,
        -0.6000, -0.6000, -0.6000,  2.5000])
tensor([[[-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969, -0.6000,
           2.5000],
         [-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969, -0.6000,
           2.5000],
         [-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969, -0.6000,
           2.5000],
         [-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969, -0.6000,
           2.5000],
         [-4.5365, -2.6330, -2.6330, -4.9846,  1.0540,  1.1969, -0.6000,
           2.5000]]])


(tensor([[-2.6974]]), tensor([[-1.9681, -1.9681, -1.9681, -1.9681, -1.9681]]))

In [254]:
# Scratch cell: log-normalizer of the posterior of a single normal created
# with mean `ones(2)` and variance `2 * ones(2)`.
normal = beer.Normal.create(torch.ones(2), 2 * torch.ones(2))
normal.mean_precision.posterior.log_norm()

tensor([[-0.2416]])

In [132]:
# Scratch cell: show the first stacked natural parameters next to the ones
# of the single normal. Relies on `nparams1` and `normal` from earlier cells.
nparams1[:, 0], normal.mean_precision.posterior.natural_parameters

(tensor([[-4.5000, -2.5000, -2.5000, -4.5000,  1.0000,  1.0000, -0.5000,  2.0000]]),
 tensor([-2.5000, -0.5000, -0.5000, -2.5000,  1.0000,  1.0000, -0.5000,  0.0000]))

In [None]:
# Scratch cell: log-normalizer of the posterior before (`lnorm1`) and after
# (`lnorm2`) setting it to prior + accumulated statistics of 10 data points.
gmm = create_gaussian_classifier(5, torch.ones(2), torch.ones(2), p_strength=1)
prior = gmm.modelset.means_precision.prior
post = gmm.modelset.means_precision.posterior

stats = gmm.sufficient_statistics(X[:10])
# NOTE(review): `torch.randn` draws signed values, so after the
# normalisation the rows sum to one but individual entries can be negative
# or larger than one -- confirm that this is acceptable for `accumulate`.
resps = torch.randn(10, 5)
resps /= resps.sum(dim=1)[:, None]
acc_stats = gmm.modelset.accumulate(stats, resps)

lnorm1 = post.log_norm()
post.natural_parameters = prior.natural_parameters + acc_stats[gmm.modelset.means_precision]
lnorm2 = post.log_norm()
# Display both values as the cell output.
lnorm1, lnorm2

In [None]:
# Manual CVB training loop.
# NOTE(review): `N`, `model`, `batches` and `batches_targets` are not
# defined in any visible cell of this notebook, so this cell depends on
# hidden kernel state -- confirm where they are supposed to come from.
npoints = N * ntargets
epochs = 200
lrate_bayesmodel = 1e-1
lrate_encoder = 1e-3
# Rebinds the notebook-level `targets` to the first `npoints` labels.
targets = torch.from_numpy(labels[:npoints]).long()

# Adam optimizes the parameters of the encoder, of its probabilistic layer
# and of the decoder.
nnet_parameters = list(model.encoder.parameters()) + \
    list(model.encoder_problayer.parameters()) + \
    list(model.decoder.parameters()) 
std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_encoder, weight_decay=1e-2)

# The SVB optimizer is kept commented out; the CVB optimizer is used.
#optimizer = beer.BayesianModelOptimizer(model.mean_field_groups, 
#                                        lrate=lrate_bayesmodel, 
#                                        std_optim=std_optimizer)
optimizer = beer.CVBOptimizer(model.bayesian_parameters(), std_optim=std_optimizer)
# Per-batch statistics returned by the previous `elbo.backward()`; unseen
# batches/keys default to None.
batch_stats = defaultdict(lambda: defaultdict(lambda: None))

elbos = []
for epoch in range(epochs):
    # Visit the batches in a new random order at every epoch.
    batch_ids = list(range(len(batches)))
    random.shuffle(batch_ids)
    for batch_id in batch_ids:
        # SVB variant of the update, kept commented out.
        #optimizer.init_step()
        #elbo = beer.evidence_lower_bound(model, 
        #                                 batches[batch_id], 
        #                                 labels=batches_targets[batch_id], 
        #                                 datasize=len(X))
        #elbo.backward()
        #optimizer.step()
        #elbos.append(float(elbo) / len(X))

        optimizer.init_step(batch_stats[batch_id])
        elbo = beer.collapsed_evidence_lower_bound(model, batches[batch_id], 
                                                   labels=batches_targets[batch_id])
        batch_stats[batch_id] = elbo.backward()
        optimizer.step()
    # Monitor the collapsed ELBO per data point on the whole data set after
    # each epoch.
    elbo = beer.collapsed_evidence_lower_bound(model, X, labels=targets)
    elbos.append(float(elbo) / len(X))

# Plot the ELBO.
#fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
#              y_axis_label='ln p(X)')
#fig.line(np.arange(len(elbos)), elbos, color='blue')

#show(fig)

In [None]:
# Training curve: the values stored in `elbos`, in the order they were
# recorded.
fig = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
steps = np.arange(len(elbos))
fig.line(steps, elbos, color='blue')
show(fig)

In [None]:
# Encode each per-class data set of `Xs` and scatter its latent
# representation, one colour per class, then overlay the latent mixture.
fig = figure(width=400, height=400)
for class_X, color in zip(Xs, colors):
    class_tensor = torch.from_numpy(class_X).float()
    post_params = model.encoder_problayer(model.encoder(class_tensor))
    latent, _ = model.encoder_problayer.samples_and_llh(post_params,
                                                        use_mean=True)
    latent = latent.data.numpy()
    fig.circle(latent[:, 0], latent[:, 1], color=color)

plotting.plot_gmm(fig, model.latent_model, colors=colors, alpha=.5,
                  color='blue')
show(fig)

In [None]:
# Scratch cell: display the `strength` attribute of the posterior of the
# model's normal distribution.
model.normal.mean_precision.posterior.strength