# VAE-GMM

This notebook illustrate how to combine a Variational AutoEncoder (VAE) and a Gaussian Mixture Model (GMM) with the [beer framework](https://github.com/beer-asr/beer).

In [3]:
%load_ext autoreload
%autoreload 2

# Add the path of the beer source code ot the PYTHONPATH.
from collections import defaultdict
import random
import sys
sys.path.insert(0, '../')

import math
import yaml
import numpy as np
import torch
import torch.optim
from torch import nn



# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d

# Beer framework
import beer

# Convenience functions for plotting.
import plotting

output_notebook(verbose=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data 

As a simple example we consider the following synthetic data: 

In [320]:
def generate_cluster(npoints, mean, angle):
    x = np.random.randn(npoints) 
    data = np.c_[x, np.cos(x) + np.random.randn(npoints) * 1e-1] 
    R = np.array([
        [np.cos(angle), np.sin(angle)],
        [-np.sin(angle), np.cos(angle)]
    ])
    return ((data - np.array([0, 1])) @ R)  + mean

data1 = generate_cluster(npoints=50, mean=np.array([0., -1.]), angle=0)
data2 = generate_cluster(npoints=50, mean=np.array([1., 0.]), angle=np.pi/2)
data3 = generate_cluster(npoints=50, mean=np.array([0., 1.]), angle=np.pi)
data4 = generate_cluster(npoints=50, mean=np.array([-1., 0.]), angle=3 * np.pi/2)


X = torch.from_numpy(np.r_[data1, data2, data3, data4])

fig = figure()
fig.circle(data1[:, 0], data1[:, 1], color='blue')
fig.circle(data2[:, 0], data2[:, 1], color='red')
fig.circle(data3[:, 0], data3[:, 1], color='green')
fig.circle(data4[:, 0], data4[:, 1], color='orange')


show(fig)

## Model Creation


In [356]:
data_mean = torch.from_numpy(data.mean(axis=0)).double()
data_var = torch.from_numpy(np.var(data, axis=0)).double()

gaussians = beer.NormalSet.create(
    data_mean, data_var,      # use to set the mean/variance of the prior
    size=50,                  # total number of components in the mixture
    prior_strength=1e-3,        # how much the prior affect the training ("pseudo-counts")
    noise_std=1.,             # standard deviation of the noise to initialize the mean of the posterior
    cov_type='full',          # type of the covariance matrix  ('full', 'diagonal' or 'isotropic')
)

gmm = beer.Mixture.create(
    gaussians, 
    prior_strength=1.         # how much the prior over the weights will affect the training ("pseudo-counts")
)

gmm = gmm.double()            # set all the parameters in double precision
#gmm = gmm.cuda()             # move the model on a GPU. If you do so, you'll have
                              # to move the data as well.
    
    
# Fit the GMM to the data as initialization.
# Note that this initialization is valid as 
# we assume a residual network as encoder (see cells below)
optim = beer.VBConjugateOptimizer(gmm.mean_field_factorization(), lrate=1.)
for e in range(100):
    optim.init_step()
    elbo = beer.evidence_lower_bound(gmm, X)
    elbo.backward()
    optim.step()
    
print(gmm)

Mixture(
  (modelset): NormalSet(
    (means_precisions): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
  )
  (categorical): Categorical(
    (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
  )
)


In [357]:
class ResidualFeedFowardBlock(torch.nn.Module):
    '''Block of two feed-forward layer with a reisdual connection:
      
            f(W1^T x + b1)         f(W2^T h1 + b2 )         h2 + x 
        x ------------------> h1 --------------------> h2 ----------> y
        |                                              ^
        |               Residual connection            | 
        +----------------------------------------------+
        
    '''
    
    def __init__(self, dim_in, width, activation_fn=torch.nn.Tanh):
        super().__init__()
        self.layer1 = torch.nn.Linear(dim_in, width)
        self.layer2 = torch.nn.Linear(width, dim_in)
        self.activation_fn = activation_fn()
    
    def forward(self, x):
        h1 = self.activation_fn(self.layer1(x))
        h2 = self.activation_fn(self.layer2(h1))
        return h2 + x
    

class ResidualFeedForwardNet(torch.nn.Module):
    
    def __init__(self, dim_in, nblocks=1, block_width=10):
        super().__init__()
        self._dim_in = dim_in
        self.blocks = torch.nn.Sequential(*[
            ResidualFeedFowardBlock(dim_in, block_width)
            for i in range(nblocks)
        ])
    
    @property
    def dim_in(self):
        return self._dim_in

    @property
    def dim_out(self):
        # input and output dimension are the same in our residual network.
        return self._dim_in
    
    def forward(self, X):
        return self.blocks(X)
    
    
# Parameterization of the Normal using the
# log diagonal covariance matrix.
class MeanLogDiagCov(torch.nn.Module):

    def __init__(self, mean, log_diag_cov):
        super().__init__()
        self.mean = mean
        self.log_diag_cov = log_diag_cov

    @property
    def diag_cov(self):
        # Make sure the variance is never 0.
        return 1e-5 + self.log_diag_cov.exp()
        
        
class VAE(beer.Model):
    
    def __init__(self, prior, encoder, decoder):
        super().__init__()
        self.prior = prior
        self.encoder = encoder
        self.decoder = decoder
        self.enc_mean_layer = torch.nn.Linear(encoder.dim_out, decoder.dim_in)
        self.enc_var_layer = torch.nn.Linear(encoder.dim_out, decoder.dim_in)
        self.dec_mean_layer = torch.nn.Linear(decoder.dim_out, encoder.dim_in)
        self.dec_var_layer = torch.nn.Linear(decoder.dim_out, encoder.dim_in)
         
    def posteriors(self, X):
        'Forward the data to the encoder to get the variational posteriors.'
        H = self.encoder(X)
        return beer.dists.NormalDiagonalCovariance(
            MeanLogDiagCov(self.enc_mean_layer(H), self.enc_var_layer(H))
        )
    
    def pdfs(self, Z):
        'Return the normal densities given the latent variable Z'
        return beer.dists.NormalDiagonalCovariance(
            MeanLogDiagCov(self.dec_mean_layer(Z), self.dec_var_layer(Z))
        )
    
    ####################################################################
    # Model interface.

    def mean_field_factorization(self):
        return self.prior.mean_field_factorization()

    def sufficient_statistics(self, data):
        return data

    def expected_log_likelihood(self, data, nsamples=1, llh_weight=1.,
                                kl_weight=1., **kwargs):
        posts = self.posteriors(data)
        
        # Local KL-divergence. There is a close for solution
        # for this term but we use sampling as it allows to
        # change the prior (GMM, HMM, ...) easily.
        samples = posts.sample(nsamples)
        s_samples = posts.sufficient_statistics(samples).mean(dim=1)
        ent = -posts(s_samples, pdfwise=True)
        s_samples = self.prior.sufficient_statistics(samples.view(-1, samples.shape[-1]))
        s_samples = s_samples.reshape(len(samples), -1, s_samples.shape[-1]).mean(dim=1)
        self.cache['prior_stats'] = s_samples
        xent = -self.prior.expected_log_likelihood(s_samples)
        local_kl_div = xent - ent
        
        # Approximate the expected log-likelihood with the
        # reparameterization trick.
        pdfs = self.pdfs(samples.view(-1, samples.shape[-1]))
        r_data = data[:, None, :].repeat(1, nsamples, 1).view(-1, data.shape[-1])
        llh = pdfs(pdfs.sufficient_statistics(r_data), pdfwise=True)
        llh = llh.reshape(len(data), nsamples, -1).mean(dim=1)
        
        return llh_weight * llh - kl_weight * local_kl_div

    def accumulate(self, stats, parent_msg=None):
        return self.prior.accumulate(self.cache['prior_stats'])

    ####################################################################
        
encoder = ResidualFeedForwardNet(dim_in=2, nblocks=2, block_width=20)
decoder = ResidualFeedForwardNet(dim_in=2, nblocks=2, block_width=20)
vae = VAE(gmm, encoder, decoder).double()

In [358]:
epochs = 30_000
update_prior_after_epoch = 20_000
prior_lrate = 1.
cjg_optim = beer.VBConjugateOptimizer(vae.mean_field_factorization(), lrate=0)
std_optim = torch.optim.Adam(vae.parameters(), lr=1e-3)
optim = beer.VBOptimizer(cjg_optim, std_optim)

elbos = []
for e in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(vae, X, nsamples=5)
    elbo.backward()
    optim.step()
    
    if e >= update_prior_after_epoch:
        cjg_optim.lrate = prior_lrate
    elbos.append(float(elbo) / len(X))
    
fig = figure()
fig.line(range(len(elbos)), elbos)
show(fig)

In [359]:
posts = vae.posteriors(X)
rX = vae.dec_mean_layer(vae.decoder(posts.params.mean)).detach().numpy()
weights = gmm.categorical.mean.numpy()
print(weights)

fig = figure(title='Components', width=400, height=400)
for i in range(len(posts)):
    normal = posts[i]
    mean = normal.params.mean.detach().numpy()
    cov = normal.params.diag_cov.detach().diag().numpy()
    plotting.plot_normal(fig, mean, cov, alpha=.5, n_std_dev=2, color='salmon')
#fig.circle(data[:, 0], data[:, 1], alpha=.5)
for weight, normal in zip(weights, gmm.modelset):
    mean = normal.mean.numpy()
    cov = normal.cov.numpy()
    plotting.plot_normal(fig, mean, cov, alpha=.5 * weight, color='green')
    
fig2 = figure(width=400, height=400, x_range=fig.x_range)
fig2.circle(data[:, 0], data[:, 1])
fig2.circle(rX[:, 0], rX[:, 1], color='red')
#fig2 = figure(width=400, height=400, y_range=(-0.1, 1.1), title='Mixing weights')
#fig2.vbar(range(len(weights)), width=.5, top=weights)
#fig2.xaxis.ticker = list(range(len(weights)))
#fig2.xgrid.visible = False
show(gridplot([[fig, fig2]]))


[9.95024876e-05 9.95024876e-05 9.95024876e-05 9.95024876e-05
 9.95024876e-05 9.95024876e-05 9.95024876e-05 9.95024876e-05
 9.95024876e-05 9.95024876e-05 9.95024876e-05 9.95024876e-05
 9.95024876e-05 9.95024876e-05 9.95024876e-05 9.95024876e-05
 9.95024876e-05 9.95024876e-05 6.83044527e-02 9.95024876e-05
 9.95024876e-05 9.95024876e-05 9.95024876e-05 1.60012816e-01
 9.95024876e-05 9.95024876e-05 2.07797602e-02 9.95024876e-05
 9.95024876e-05 9.95024876e-05 2.31196178e-01 2.04946893e-01
 9.95024876e-05 9.95024876e-05 9.95024876e-05 9.95024876e-05
 1.42687671e-01 9.95024876e-05 9.95024876e-05 9.95024876e-05
 1.42689935e-01 9.95024876e-05 9.95024876e-05 9.95024876e-05
 2.52031896e-02 9.95024876e-05 9.95024876e-05 9.95024876e-05
 9.95024876e-05 9.95024876e-05]


In [219]:
posts[0].params.diag_cov.diag()

tensor([[0.8643, 0.0000],
        [0.0000, 2.9946]], dtype=torch.float64, grad_fn=<DiagBackward>)

In [204]:
epochs = 1_000
cjg_optim = beer.VBConjugateOptimizer(vae.mean_field_factorization(), lrate=1)
std_optim = torch.optim.Adam(vae.parameters(), lr=1e-3)
optim = beer.VBOptimizer(cjg_optim, std_optim)

elbos = []
for e in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(vae, X, nsamples=5)
    elbo.backward()
    optim.step()
    elbos.append(float(elbo) / len(X))
    
fig = figure()
fig.line(range(len(elbos)), elbos)
show(fig)

In [206]:
weights = gmm.categorical.mean.numpy()
fig = figure(title='Components', width=400, height=400,
             x_range=(-10, 10), y_range=(-5, 15))
fig.circle(data[:, 0], data[:, 1], alpha=.5)
for weight, normal in zip(weights, gmm.modelset):
    mean = normal.mean.numpy()
    cov = normal.cov.numpy()
    plotting.plot_normal(fig, mean, cov, alpha=.5 * weight, color='green')
    
fig2 = figure(width=400, height=400, y_range=(-0.1, 1.1), title='Mixing weights')
fig2.vbar(range(len(weights)), width=.5, top=weights)
fig2.xaxis.ticker = list(range(len(weights)))
fig2.xgrid.visible = False
show(gridplot([[fig, fig2]]))

### 1. Pre-training

In [14]:
def train_cvb(model, X, epochs=1, nbatches=1, lrate_nnet=1e-3, 
              update_prior=True, update_nnet=True, kl_weight=1., state=None,
              nsamples=1, callback=None):
    
    batches = X.view(nbatches, -1, 2)

    prior_parameters = model.bayesian_parameters() if update_prior else model.normal.bayesian_parameters()
    
    if state is None:
        if update_nnet:
            std_optimizer = torch.optim.Adam(model.modules_parameters(), lr=lrate_nnet, 
                                             weight_decay=1e-2)
        else:
            std_optimizer = None
        optimizer = beer.CVBOptimizer(prior_parameters, std_optim=std_optimizer)
        batch_stats = defaultdict(lambda: defaultdict(lambda: None))
    else:
        optimizer, batch_stats = state

    for epoch in range(epochs):
        # Randomized the order of the batches.
        batch_ids = list(range(len(batches)))
        random.shuffle(batch_ids)
        
        for batch_id in batch_ids:
            optimizer.init_step(batch_stats[batch_id])
            kwargs = {'kl_weight': kl_weight}
            elbo = beer.collapsed_evidence_lower_bound(model, batches[batch_id], 
                                                       nsamples=nsamples,
                                                       **kwargs)
            batch_stats[batch_id] = elbo.backward()
            optimizer.step()
            
        if callback is not None:
            callback()
            
    return (optimizer, batch_stats)


def train_svb(model, X, epochs=1, nbatches=1, lrate_nnet=1e-3,
              lrate_prior=1e-1, update_prior=True, update_nnet=True, 
              kl_weight=1., state=None, nsamples=1, callback=None):
    
    batches = X.view(nbatches, -1, 2)
    
    mf_groups = model.mean_field_groups if update_prior else model.normal.mean_field_groups
    nnet_parameters = model.modules_parameters() if update_nnet else range(0)

    if state is None:
        std_optimizer = torch.optim.Adam(nnet_parameters, lr=lrate_nnet, 
                                         weight_decay=1e-2)
        optimizer = beer.BayesianModelOptimizer(mf_groups, lrate=lrate_prior, 
                                                std_optim=std_optimizer)
    else:
        optimizer = state
    
    for epoch in range(epochs):
        # Randomized the order of the batches.
        batch_ids = list(range(len(batches)))
        random.shuffle(batch_ids)
        for batch_id in batch_ids:
            optimizer.init_step()
            kwargs = {'kl_weight': kl_weight, 'datasize': len(X)}
            elbo = beer.evidence_lower_bound(model, batches[batch_id], 
                                             **kwargs)
            elbo.backward()
            optimizer.step()
        
        if callback is not None:
            callback()
            
        # Monitor the evidence lower bound after each epoch.
        #elbo = beer.evidence_lower_bound(model, X, **kwargs)
        #elbos.append(float(elbo) / len(X))
    
    return optimizer

In [15]:
def plot_latent_space(fig, model, X, use_mean=True):
    enc_states = vae.encoder(X)
    post_params = vae.encoder_problayer(enc_states)
    samples, _ = vae.encoder_problayer.samples_and_llh(post_params, use_mean=use_mean)
    samples = samples.data.numpy()
    fig.circle(samples[:, 0], samples[:, 1])
    
def plot_density(fig, model, x_range, y_range, nsamples=10, marginal=False):
    xy = np.mgrid[x_range[0]:x_range[1]:100j, y_range[0]:y_range[1]:100j].reshape(2,-1).T
    xy = torch.from_numpy(xy).float()
    
    mllhs = []
    for i in range(nsamples):
        if marginal:
            mllhs.append(model.marginal_log_likelihood(xy).view(-1, 1))
        else:
            mllhs.append(model.expected_log_likelihood(xy).view(-1, 1))
    mllhs = torch.cat(mllhs, dim=-1).mean(dim=-1)
    mllhs = mllhs.detach().numpy().reshape(100, 100)
    mlhs = np.exp(mllhs)
    width, height = x_range[1] - x_range[0] / 100, y_range[1] - y_range[0] / 100
    fig.image(image=[mlhs.T], x=x_range[0], y=y_range[0], dw=2 * width, dh=2 * height)

In [16]:
vae = create_vae(global_mean, global_var)

svb_elbos = []
svb_elbos2 = []
svb_elbos_test = []
def log_pred():
    elbo = beer.evidence_lower_bound(vae, X)
    svb_elbos.append(float(elbo) / len(X))
    elbo = beer.collapsed_evidence_lower_bound(vae, X, kl_weight=1.)
    svb_elbos2.append(float(elbo) / len(X))
    elbo = beer.evidence_lower_bound(vae, test_X, datasize=len(test_X), use_mean=False)
    svb_elbos_test.append(float(elbo) / len(test_X))
    
# training the vae.
state = train_svb(vae, X, epochs=5000, nbatches=10, nsamples=1, callback=log_pred,
                  update_prior=True)

# Plotting
fig1 = figure(width=300, height=300)
fig1.line(range(len(svb_elbos)), svb_elbos, legend='ELBO')
fig1.legend.location = 'bottom_right'

fig2 = figure(width=300, height=300, x_range=(-7, 7), y_range=(-7, 7))
mean, cov = vae.latent_model.mean, vae.latent_model.cov
plotting.plot_normal(fig2, mean.numpy(), cov.numpy(),alpha=.1)
plot_latent_space(fig2, vae, X, use_mean=False)

fig3 = figure(width=300, height=300, x_range=x_range, y_range=y_range)
plot_density(fig3, vae, x_range, y_range, nsamples=100)

show(gridplot([[fig1, fig2, fig3]]))

AttributeError: 'VAEGlobalMeanVariance' object has no attribute 'mean_field_groups'

In [64]:
vae = create_vae(global_mean, global_var)

cvb_elbos_test = []
cvb_elbos = []
cvb_elbos2 = []
def log_pred():
    elbo = beer.evidence_lower_bound(vae, X)
    cvb_elbos.append(float(elbo) / len(X))
    elbo = beer.collapsed_evidence_lower_bound(vae, X, kl_weight=1.)
    cvb_elbos2.append(float(elbo) / len(X))
    elbo = beer.collapsed_evidence_lower_bound(vae, test_X, kl_weight=1., use_mean=False)
    cvb_elbos_test.append(float(elbo) / len(test_X))

# training the vae.
#state = train_cvb(vae, X, epochs=10, nbatches=10, callback=log_pred,  kl_weight=0., update_prior=True)
state = train_cvb(vae, X, epochs=5000, nbatches=10, callback=log_pred, 
                  nsamples=1, update_prior=True)

# Plotting
fig1 = figure(width=300, height=300)
fig1.line(range(len(cvb_elbos)), cvb_elbos, legend='ELBO')
fig1.legend.location = 'bottom_right'

fig2 = figure(width=300, height=300, x_range=(-7, 7), y_range=(-7, 7))
mean, cov = vae.latent_model.mean, vae.latent_model.cov
plotting.plot_normal(fig2, mean.numpy(), cov.numpy(),alpha=.1)
plot_latent_space(fig2, vae, X, use_mean=False)

fig3 = figure(width=300, height=300, x_range=x_range, y_range=y_range)
plot_density(fig3, vae, x_range, y_range, nsamples=100, marginal=True)

show(gridplot([[fig1, fig2, fig3]]))

In [65]:
# Plotting

fig1 = figure(title='ELBO (train set)', width=300, height=300, y_range=(-5, 1))
fig1.line(range(len(svb_elbos2)), svb_elbos, color='blue', legend='SVB')
fig1.line(range(len(cvb_elbos2)), cvb_elbos, color='green', legend='CVB')
fig1.legend.location = 'bottom_right'

fig2 = figure(title='Col. ELBO (train set)', width=300, height=300, y_range=(-5, 1))
fig2.line(range(len(svb_elbos2)), svb_elbos2, color='blue', legend='SVB')
fig2.line(range(len(cvb_elbos2)), cvb_elbos2, color='green', legend='CVB')
fig2.legend.location = 'bottom_right'

fig3 = figure(title='log pred. (test set)', width=300, height=300, y_range=(-5, 1))
fig3.line(range(len(svb_elbos_test)), svb_elbos_test, color='blue', legend='SVB')
fig3.line(range(len(cvb_elbos_test)), cvb_elbos_test, color='green', legend='CVB')
fig3.legend.location = 'bottom_right'

show(gridplot([[fig1, fig2, fig3]]))