# Bayesian Hidden Markov Model

This notebook illustrates how to build and train a Bayesian Hidden Markov Model (HMM) with the [beer framework](https://github.com/beer-asr/beer).

In [None]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data
Generate synthetic data by following the generative process of an HMM.
#### Probability of initial states
$$
p(s^0 = s_1) = 1 \\
p(s^0 = s_2) = 0 \\
p(s^0 = s_3) = 0
$$

#### Probability of transitions
$$
p(s^t = s_1 \vert s^{t-1} = s_1) = 0.5 \quad p(s^t = s_2 \vert s^{t-1} = s_1) = 0.5 \quad p(s^t = s_3 \vert s^{t-1} = s_1) = 0 \\
p(s^t = s_1 \vert s^{t-1} = s_2) = 0 \quad p(s^t = s_2 \vert s^{t-1} = s_2) = 0.5 \quad p(s^t = s_3 \vert s^{t-1} = s_2) = 0.5 \\
p(s^t = s_1 \vert s^{t-1} = s_3) = 0.5 \quad p(s^t = s_2 \vert s^{t-1} = s_3) = 0 \quad p(s^t = s_3 \vert s^{t-1} = s_3) = 0.5 \\  
$$

#### Emission
$$
p(x^t \vert s^t = s_1) = \mathcal{N}(x^t \vert \mu_1, \Sigma_1) \\
p(x^t \vert s^t = s_2) = \mathcal{N}(x^t \vert \mu_2, \Sigma_2) \\
p(x^t \vert s^t = s_3) = \mathcal{N}(x^t \vert \mu_3, \Sigma_3)
$$


In [None]:
# Seed NumPy's global generator so that "Restart Kernel -> Run All" always
# produces the same synthetic sequence. Change the value for another draw.
np.random.seed(0)

nsamples = 100
ndim = 2
nstates = 3

# Row i holds p(s^t = . | s^{t-1} = s_i); each row sums to one.
trans_mat = np.array([[.5, .5, 0], [0, .5, .5], [.5, 0, .5]])

# Mean / covariance matrix of the Normal emission density of each state.
means = [np.array([-1.5, 4]),np.array([5, 5]), np.array([1, -2])]
covs = [np.array([[.75, -.5], [-.5, 2.]]), np.array([[2, 1], [1, .75]]), np.array([[1, 0], [0, 1]]) ]
normal_sets = [[means[0], covs[0]], [means[1], covs[1]], [means[2], covs[2]]]

# Sample the state sequence and the observations. The chain always starts
# in the first state; every subsequent state is drawn from the row of the
# transition matrix selected by the previous state.
states = np.zeros(nsamples, dtype=np.int16)
data = np.zeros((nsamples, ndim))
states[0] = 0
data[0] = np.random.multivariate_normal(means[states[0]], covs[states[0]])
for n in range(1, nsamples):
    states[n] = np.random.choice(np.arange(nstates), p=trans_mat[states[n-1]])
    data[n] = np.random.multivariate_normal(means[states[n]], covs[states[n]])

# Left panel: the samples colored by their generating state, joined by a
# thin line that follows the order in which they were emitted. The glyphs
# are drawn with two vectorized calls rather than one call per sample.
colors = ['blue', 'red', 'green']
fig1 = figure(title='Samples', width=400, height=400)
fig1.circle(data[:, 0], data[:, 1], color=[colors[s] for s in states],
            line_width=1)
fig1.line(data[:, 0], data[:, 1], color='black', line_width=.5, alpha=.5)

# Right panel: the true emission density of each state.
fig2 = figure(title='Emissions',  width=400, height=400)
for state_color, (state_mean, state_cov) in zip(colors, normal_sets):
    plotting.plot_normal(fig2, state_mean, state_cov, alpha=.3, color=state_color)

grid = gridplot([[fig1, fig2]])
show(grid)

## Model Creation

We create several types of HMMs. They all share the same transition matrix and the same initial / final states; they differ only in the type of emission density:
  * one Normal density per state with full covariance matrix
  * one Normal density per state with diagonal covariance matrix
  * one Normal density per state with full covariance matrix shared across states
  * one Normal density per state with diagonal covariance matrix shared across states.

In [None]:
# We use the global mean/cov. matrix of the data to initialize the emission
# densities of every HMM.
p_mean = torch.from_numpy(data.mean(axis=0)).float()
p_cov = torch.from_numpy(np.cov(data.T)).float()

# Every state is listed both as a possible initial and as a possible final
# state.
init_states = torch.from_numpy(np.arange(nstates))
final_states = torch.from_numpy(np.arange(nstates))

# `torch.as_tensor` accepts a NumPy array as well as a tensor, so this cell
# can be re-run without failing (`torch.from_numpy` raises a TypeError when
# `trans_mat` has already been converted by a previous execution).
trans_mat = torch.as_tensor(trans_mat).float()

# HMM (diag cov).
normalset = beer.NormalDiagonalCovarianceSet.create(p_mean, torch.diag(p_cov), 
                                                    nstates, noise_std=0.5)
hmm_diag = beer.HMM.create(init_states, final_states, trans_mat, normalset)

# HMM (full cov).
normalset = beer.NormalFullCovarianceSet.create(p_mean, p_cov, nstates, 
                                                noise_std=0.5)
hmm_full = beer.HMM.create(init_states, final_states, trans_mat, normalset)

# HMM shared (diag) cov.
normalset = beer.NormalSetSharedDiagonalCovariance.create(p_mean, 
                                                        torch.diag(p_cov), 
                                                        nstates,
                                                        noise_std=0.5)
hmm_sharedcov_diag = beer.HMM.create(init_states, final_states, trans_mat, normalset)

# HMM shared (full) cov.
normalset = beer.NormalSetSharedFullCovariance.create(p_mean, p_cov, nstates,
                                                      noise_std=0.5)
hmm_sharedcov_full = beer.HMM.create(init_states, final_states, trans_mat, normalset)

# Order matters: the training and plotting cells pair the models with their
# legends / colors by position in this list.
models = [
    hmm_diag, 
    hmm_full,
    hmm_sharedcov_diag,
    hmm_sharedcov_full
]

## Variational Bayes Training 

In [None]:
epochs = 30
lrate = 1.
labels = states
X = torch.from_numpy(data).float()

# The models are trained without state labels. To pass the true state
# sequence to the ELBO instead, use `Z = torch.from_numpy(labels).long()`.
Z = None

elbo_fn = beer.EvidenceLowerBound(len(X))

# A single optimizer updates the parameters of all the models at once.
params = [param for model in models for param in model.parameters]
optimizer = beer.BayesianModelOptimizer(params, lrate)

# One ELBO curve per model; sized from `models` so that adding or removing
# a model does not require touching this line.
elbos = [[] for _ in models]
for epoch in range(epochs):
    optimizer.zero_grad()
    for i, model in enumerate(models):
        elbo = elbo_fn(model, X, Z)
        elbo.natural_backward()
        # Store the ELBO normalized by the number of samples.
        elbos[i].append(float(elbo) / len(X))
    optimizer.step()

# Plot the ELBO. The (legend, color) pairs follow the order of `models`.
curve_styles = [
    ('HMM (diag)', 'blue'),
    ('HMM (full)', 'red'),
    ('HMM (shared cov. diag)', 'green'),
    ('HMM (shared cov. full)', 'black'),
]
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
for (legend_text, curve_color), curve in zip(curve_styles, elbos):
    fig.line(range(epochs), curve, legend=legend_text, color=curve_color)
fig.legend.location = 'bottom_right'

show(fig)

### Plotting

In [None]:
# Use one square viewing window for all panels: centered on the data mean
# and extending two standard deviations (of the most spread-out dimension)
# in each direction.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(max(var))
x_range = (mean[0] - 2 * std_dev, mean[0] + 2 * std_dev)
y_range = (mean[1] - 2 * std_dev, mean[1] + 2 * std_dev)
global_range = (min(x_range[0], y_range[0]), max(x_range[1], y_range[1]))

# (title, trained model, color) of each panel, in display order.
panel_specs = [
    ('HMM (diag)', hmm_diag, 'blue'),
    ('HMM (full)', hmm_full, 'red'),
    ('HMM (shared cov. diag)', hmm_sharedcov_diag, 'green'),
    ('HMM (shared cov. full)', hmm_sharedcov_full, 'black'),
]

# Each panel shows the faded data points with the corresponding HMM drawn
# on top of them.
panels = []
for panel_title, panel_hmm, panel_color in panel_specs:
    panel = figure(title=panel_title, x_range=global_range,
                   y_range=global_range, width=400, height=400)
    panel.circle(data[:, 0], data[:, 1], alpha=.1)
    plotting.plot_hmm(panel, panel_hmm, alpha=.1, color=panel_color)
    panels.append(panel)

# Keep the individual figure names available to the rest of the notebook.
fig1, fig2, fig3, fig4 = panels

grid = gridplot([[fig1, fig2], [fig3, fig4]])
show(grid)