# Bayesian Mixture Model

This notebook illustrates how to build and train a Bayesian Mixture Model with the [beer framework](https://github.com/beer-asr/beer).

In [None]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

As an illustration, we generate a synthetic data set composed of two normally distributed clusters. One has a diagonal covariance matrix whereas the other has a dense covariance matrix.

In [None]:
# Seed NumPy's global random generator so that re-running the notebook from
# a fresh kernel always produces the same data set (and the same shuffle).
SEED = 0
np.random.seed(SEED)

# Number of points drawn for each cluster.
npoints = 200

# First cluster: diagonal covariance matrix.
mean1 = np.array([-1.5, 4])
cov1 = np.array([[.75, 0], [0, 2.]])
data1 = np.random.multivariate_normal(mean1, cov1, size=npoints)

# Second cluster: dense covariance matrix.
mean2 = np.array([5, 5])
cov2 = np.array([[2, 1], [1, .75]])
data2 = np.random.multivariate_normal(mean2, cov2, size=npoints)

# Merge everything to get the final data set, then shuffle its rows in place
# so that the two clusters are interleaved.
data = np.vstack([data1, data2])
np.random.shuffle(data)

In [None]:
# Scale the figure from the data: the window is centred on the data mean and
# extends two standard deviations (of the dimension with the largest
# variance) on each side.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(max(var))
half_width = 2 * std_dev
x_range, y_range = [(center - half_width, center + half_width)
                    for center in mean]

# Use one common range for both axes so that the plot is square.
lower = min(x_range[0], y_range[0])
upper = max(x_range[1], y_range[1])
global_range = (lower, upper)

fig = figure(title='Data', x_range=global_range, y_range=global_range,
             width=400, height=400)
fig.circle(data[:, 0], data[:, 1])

show(fig)

## Model Creation

We create four types of mixture model: two whose (Normal) components each have their own covariance matrix (full or diagonal), and two whose (Normal) components share a single covariance matrix (full or diagonal).

In [None]:
# Number of components of each mixture.
ncomp = 10

# Value passed as `noise_std` to every `create` call below.
noise_std = 0.1

# Mean / covariance matrix given to the Normal sets: zero mean and identity
# covariance matrix. The commented-out lines are the alternative of using
# the global mean/cov. matrix of the data instead.
#p_mean = torch.from_numpy(data.mean(axis=0)).float()
#p_cov = torch.from_numpy(np.cov(data.T)).float()
p_mean = torch.zeros(2)
p_cov = torch.eye(2)

# Diagonal of the covariance matrix, for the diagonal-covariance models.
p_diag_cov = torch.diag(p_cov)

# Mean of the weights' prior.
weights = torch.ones(ncomp) / ncomp

# GMM (diag cov).
normalset = beer.NormalDiagonalCovarianceSet.create(p_mean, p_diag_cov,
                                                    ncomp,
                                                    noise_std=noise_std)
gmm_diag = beer.Mixture.create(weights, normalset)

# GMM (full cov).
normalset = beer.NormalFullCovarianceSet.create(p_mean, p_cov, ncomp,
                                                noise_std=noise_std)
gmm_full = beer.Mixture.create(weights, normalset)

# GMM shared (diag) cov.
normalset = beer.NormalSetSharedDiagonalCovariance.create(p_mean,
                                                          p_diag_cov,
                                                          ncomp,
                                                          noise_std=noise_std)
gmm_sharedcov_diag = beer.Mixture.create(weights, normalset)

# GMM shared (full) cov.
normalset = beer.NormalSetSharedFullCovariance.create(p_mean, p_cov, ncomp,
                                                      noise_std=noise_std)
gmm_sharedcov_full = beer.Mixture.create(weights, normalset)

# All the models, in the order used for training and plotting.
models = [
    gmm_diag,
    gmm_full,
    gmm_sharedcov_diag,
    gmm_sharedcov_full
]

## Variational Bayes Training 

In [None]:
# Training configuration.
epochs = 100
lrate = 1.
X = torch.from_numpy(data).float()
elbo_fn = beer.EvidenceLowerBound(len(X))

# A single optimizer is given the parameters of all the models, so that
# they are all updated by the same `step` call.
params = [param for model in models for param in model.parameters]
optimizer = beer.BayesianModelOptimizer(params, lrate)

# One ELBO history per model, in the same order as `models`.
elbos = [[] for _ in models]
for epoch in range(epochs):
    optimizer.zero_grad()
    for history, model in zip(elbos, models):
        elbo = elbo_fn(model, X)
        elbo.natural_backward()
        # The value of the very first epoch is not recorded; the plot
        # below therefore starts at step 1.
        if epoch > 0:
            history.append(float(elbo) / len(X))
    optimizer.step()

# Plot the ELBO (divided by the number of data points) of each model.
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
             y_axis_label='ln p(X)')
steps = range(1, epochs)
curve_styles = [
    ('GMM (diag)', 'blue'),
    ('GMM (full)', 'red'),
    ('GMM (shared cov. diag)', 'green'),
    ('GMM (shared cov. full)', 'black'),
]
for history, (label, color) in zip(elbos, curve_styles):
    fig.line(steps, history, legend=label, color=color)
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
def _plot_fitted_gmm(title, gmm, color):
    """Return a new figure showing the data set overlaid with a mixture.

    Args:
        title: Title of the figure.
        gmm: Mixture model handed to ``plotting.plot_gmm``.
        color: Color handed to ``plotting.plot_gmm``.

    Returns:
        The figure, using ``global_range`` for both axes.
    """
    gmm_fig = figure(title=title, x_range=global_range, y_range=global_range,
                     width=400, height=400)
    # The data points are drawn faintly so that the mixture stays readable.
    gmm_fig.circle(data[:, 0], data[:, 1], alpha=.1)
    plotting.plot_gmm(gmm_fig, gmm, alpha=.5, color=color)
    return gmm_fig


fig1 = _plot_fitted_gmm('GMM (diag)', gmm_diag, 'blue')
fig2 = _plot_fitted_gmm('GMM (full)', gmm_full, 'red')
# This panel shows the shared *diagonal* covariance model; its title says so.
fig3 = _plot_fitted_gmm('GMM (shared cov. diag)', gmm_sharedcov_diag, 'green')
fig4 = _plot_fitted_gmm('GMM (shared cov. full)', gmm_sharedcov_full, 'black')

# 2x2 grid: one covariance type per panel.
grid = gridplot([[fig1, fig2], [fig3, fig4]])
show(grid)