# Bayesian Nested Mixture Model

This notebook illustrates how to build and train a Bayesian Nested Mixture Model with the [beer framework](https://github.com/beer-asr/beer).

In [2]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

As an illustration, we generate a synthetic data set composed of two normally distributed clusters. Both clusters have a dense (full) covariance matrix: the first one has a positive correlation between the two dimensions whereas the second one has a negative correlation.

In [3]:
# Fix the NumPy seed so that "Restart Kernel -> Run All" regenerates exactly
# the same synthetic data set (and therefore comparable figures/ELBO curves).
np.random.seed(0)

# First cluster.
mean = np.array([-5, 5])
cov = .5 * np.array([[.75, .5], [.5, 2.]])
data1 = np.random.multivariate_normal(mean, cov, size=200)

# Second cluster.
mean = np.array([5, 5])
cov = 2 * np.array([[2, -.5], [-.5, .75]])
data2 = np.random.multivariate_normal(mean, cov, size=200)

# Merge everything to get the final data set.
data = np.vstack([data1, data2])
# Shuffle the rows in place so that the two clusters are interleaved.
np.random.shuffle(data)

# We use the global mean / per-dimension variance of the data to initialize
# the mixture.
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(np.var(data, axis=0)).float()

In [4]:
# Derive a square viewport from the data statistics: it is centred on the
# data mean and spans two standard deviations (of the widest dimension) on
# each side.
mean, var = data.mean(axis=0), data.var(axis=0)
std_dev = np.sqrt(var.max())
half_width = 2 * std_dev
x_range = (mean[0] - half_width, mean[0] + half_width)
y_range = (mean[1] - half_width, mean[1] + half_width)

# Use the same range on both axes so that the plot keeps a 1:1 aspect ratio.
lower = min(x_range[0], y_range[0])
upper = max(x_range[1], y_range[1])
global_range = (lower, upper)

fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=global_range,
    y_range=global_range,
)
fig.circle(data[:, 0], data[:, 1])

show(fig)

## Model Creation

We create three types of mixture model, whose (Normal) components have, respectively, an isotropic, a diagonal, or a full covariance matrix.

In [5]:
nmixtures = 4
ncomp_per_mixture = 3
total_components = nmixtures * ncomp_per_mixture

# We use the global mean / variance of the data to initialize the mixture.
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(np.var(data, axis=0)).float()

# (model name, covariance type) pairs, in creation order: isotropic,
# diagonal, then full covariance.
covariance_variants = (
    ('m_gmm_iso', 'isotropic'),
    ('m_gmm_diag', 'diagonal'),
    ('m_gmm_full', 'full'),
)

# Every variant is built the same way: a set of Normal components is grouped
# into `nmixtures` mixtures, which are themselves wrapped into a top-level
# mixture.
models = {}
for model_name, cov_type in covariance_variants:
    modelset = beer.NormalSet.create(
        data_mean, data_var,
        size=total_components,
        prior_strength=1.,
        noise_std=1.,
        cov_type=cov_type
    )
    mixtureset = beer.MixtureSet.create(nmixtures, modelset)
    models[model_name] = beer.Mixture.create(mixtureset)

# Keep direct handles on each model for interactive inspection.
m_gmm_iso = models['m_gmm_iso']
m_gmm_diag = models['m_gmm_diag']
m_gmm_full = models['m_gmm_full']

In [6]:
# Print the textual representation of the isotropic-covariance model to
# inspect its sub-models and their Bayesian parameters.
print(m_gmm_iso)

Mixture(
  (modelset): MixtureSet(
    (categoricalset): CategoricalSet(
      (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
    )
    (modelset): NormalSet(
      (means_precisions): ConjugateBayesianParameter(prior=IsotropicNormalGamma, posterior=IsotropicNormalGamma)
    )
  )
  (categorical): Categorical(
    (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
  )
)


## Variational Bayes Training 

In [7]:
epochs = 200
lrate = 1.
X = torch.from_numpy(data).float()

# One optimizer and one ELBO history per model, keyed by the model name.
optims = {}
elbos = {}
for model_name, model in models.items():
    optims[model_name] = beer.VBConjugateOptimizer(
        model.mean_field_factorization(),
        lrate
    )
    elbos[model_name] = []

# Full-batch training: at every epoch each model performs one update
# computed on the whole data set.
n_points = len(X)
for epoch in range(epochs):
    for name, model in models.items():
        optim = optims[name]
        optim.init_step()
        elbo = beer.evidence_lower_bound(model, X, datasize=n_points)
        elbo.backward()
        # Store the ELBO normalized by the number of data points.
        elbos[name].append(float(elbo) / n_points)
        optim.step()

In [8]:
# Line color of each model (includes entries for models that are not
# trained in this notebook).
colors = dict(
    m_gmm_iso='green',
    m_gmm_diag='blue',
    m_gmm_full='red',
    m_gmm_iso_shared='grey',
    m_gmm_diag_shared='brown',
    m_gmm_full_shared='black',
)

# Plot the ELBO: one curve per model, one point per training step.
fig = figure(
    title='ELBO',
    width=400,
    height=400,
    x_axis_label='step',
    y_axis_label='ln p(X)',
)
for model_name, elbo in elbos.items():
    steps = range(len(elbo))
    fig.line(steps, elbo, legend=model_name, color=colors[model_name])
fig.legend.location = 'bottom_right'

show(fig)

In [9]:
# Draw one figure per model, then lay the figures out on a grid with three
# figures per row.
panels = []
for model_name, model in models.items():
    fig = figure(
        title=model_name,
        x_range=global_range,
        y_range=global_range,
        width=250,
        height=250,
    )
    weights = model.categorical.mean
    for j, gmm in enumerate(model.modelset):
        # The data scatter is (re)drawn for every mixture of the set; each
        # mixture's transparency is its weight in the top-level mixture.
        fig.circle(data[:, 0], data[:, 1], alpha=.1)
        plotting.plot_gmm(fig, gmm, alpha=weights[j].numpy())
    panels.append(fig)

# Split the flat list of figures into rows of (at most) three figures.
figs = [panels[start:start + 3] for start in range(0, len(panels), 3)]
grid = gridplot(figs)
show(grid)

## Hierarchical Dirichlet Process Mixture Model



In [18]:
truncation_1 = 20  # root level truncation
truncation_2 = 20  # bottom level truncation
n_components = 5

# We use the global mean / variance of the data to initialize the mixture.
data_mean = torch.from_numpy(np.mean(data, axis=0)).float()
data_var = torch.from_numpy(data.var(axis=0)).float()

# Full-covariance Normal components; their number is set by the root level
# truncation.
normal_options = dict(
    size=truncation_1,
    prior_strength=1.,
    noise_std=1.,
    cov_type='full',
)
modelset = beer.NormalSet.create(data_mean, data_var, **normal_options)
mixtureset = beer.MixtureSet.create(nmixtures, modelset)
hdp_gmm = beer.Mixture.create(mixtureset)

# Set of stick-breaking categoricals; display its mean.
ssb = beer.SBCategoricalSet.create(n_components, truncation_2)
ssb.mean

tensor([[3.6788e-01, 1.3534e-01, 4.9787e-02, 1.8316e-02, 6.7379e-03, 2.4788e-03,
         9.1188e-04, 3.3546e-04, 1.2341e-04, 4.5400e-05, 1.6702e-05, 6.1442e-06,
         2.2603e-06, 8.3153e-07, 3.0590e-07, 1.1253e-07, 4.1399e-08, 1.5230e-08,
         5.6028e-09, 2.0611e-09],
        [3.6788e-01, 1.3534e-01, 4.9787e-02, 1.8316e-02, 6.7379e-03, 2.4788e-03,
         9.1188e-04, 3.3546e-04, 1.2341e-04, 4.5400e-05, 1.6702e-05, 6.1442e-06,
         2.2603e-06, 8.3153e-07, 3.0590e-07, 1.1253e-07, 4.1399e-08, 1.5230e-08,
         5.6028e-09, 2.0611e-09],
        [3.6788e-01, 1.3534e-01, 4.9787e-02, 1.8316e-02, 6.7379e-03, 2.4788e-03,
         9.1188e-04, 3.3546e-04, 1.2341e-04, 4.5400e-05, 1.6702e-05, 6.1442e-06,
         2.2603e-06, 8.3153e-07, 3.0590e-07, 1.1253e-07, 4.1399e-08, 1.5230e-08,
         5.6028e-09, 2.0611e-09],
        [3.6788e-01, 1.3534e-01, 4.9787e-02, 1.8316e-02, 6.7379e-03, 2.4788e-03,
         9.1188e-04, 3.3546e-04, 1.2341e-04, 4.5400e-05, 1.6702e-05, 6.1442e-06,
       

In [58]:
import pickle

# NOTE(review): this is a machine-specific absolute path, so the cell only
# runs where that file system is mounted. Also, `pickle.load` can execute
# arbitrary code: only load model files that come from a trusted source.
model_path = '/mnt/matylda5/iondel/workspace/2019/asru/beer/recipes/aud/exp_ch1_v2/timit/aud_mfcc_4g_dirichlet_process/final.mdl'
with open(model_path, 'rb') as f:
    ploop = pickle.load(f)

# Keep only the categorical of the loaded model.
sb_categorical = ploop.categorical

In [77]:
# Display the mean of the loaded categorical next to one minus its running
# (cumulative) sum.
m = sb_categorical.mean
remaining = 1 - m.cumsum(dim=0)
m, remaining

(tensor([8.4794e-02, 6.8245e-02, 5.9572e-02, 5.3810e-02, 4.9531e-02, 4.5954e-02,
         4.2675e-02, 3.9621e-02, 3.6804e-02, 3.4056e-02, 3.1768e-02, 2.9965e-02,
         2.8356e-02, 2.7020e-02, 2.5932e-02, 2.4905e-02, 2.4000e-02, 2.3210e-02,
         2.2438e-02, 2.1641e-02, 2.0904e-02, 2.0179e-02, 1.9428e-02, 1.8530e-02,
         1.7518e-02, 1.6477e-02, 1.5243e-02, 1.3951e-02, 1.2607e-02, 1.1147e-02,
         9.8266e-03, 8.5809e-03, 7.4084e-03, 6.3655e-03, 5.4483e-03, 4.6224e-03,
         3.8511e-03, 3.1464e-03, 2.5248e-03, 1.9762e-03, 1.5367e-03, 1.1640e-03,
         8.9236e-04, 6.6523e-04, 4.6442e-04, 3.0076e-04, 2.0410e-04, 1.2422e-04,
         8.2237e-05, 4.6845e-05, 2.8373e-05, 1.8312e-05, 9.2043e-06, 5.7885e-06,
         4.8481e-06, 3.9383e-06, 3.3966e-06, 2.9706e-06, 2.8055e-06, 2.7001e-06,
         2.6145e-06, 2.5215e-06, 2.4651e-06, 2.4145e-06, 2.3644e-06, 2.3170e-06,
         2.2709e-06, 2.2260e-06, 2.1822e-06, 2.1393e-06, 2.0973e-06, 2.0562e-06,
         2.0159e-06, 1.9763e

In [79]:
# Value passed as `prior_strength` to `_default_set_sb_param` in the call at
# the end of this cell.
concentration = 2.

from beer.dists import Dirichlet 
from beer.models import ConjugateBayesianParameter

def _default_set_sb_param(n_components, root_sb_categorical, prior_strength):
    """Create the stick-breaking parameter of a set of categoricals.

    The (prior and posterior) Dirichlet parameters of every component of the
    set are derived from the mean of a root stick-breaking categorical.

    Args:
        n_components: Number of categoricals in the set.
        root_sb_categorical: Root categorical; must expose a `mean` tensor
            and a `stickbreaking` attribute supporting `len()`. `mean` is
            assumed to be 1-D with `len(stickbreaking)` entries so that it
            broadcasts over the components -- TODO confirm.
        prior_strength: Scale applied to both Dirichlet parameters.

    Returns:
        A `ConjugateBayesianParameter` whose prior and posterior are two
        separate `Dirichlet` objects built from identical standard
        parameters of shape `(n_components * len(stickbreaking), 2)`.
    """
    # Read the mean from the root passed by the caller rather than from a
    # notebook-level variable: the result must only depend on the arguments.
    mean = root_sb_categorical.mean
    n_sticks = len(root_sb_categorical.stickbreaking)

    params = torch.ones(n_components, n_sticks, 2)
    # First parameter: scaled mean; second parameter: scaled "one minus the
    # cumulative sum of the mean". Both are broadcast over the components.
    params[:, :, 0] = prior_strength * mean
    params[:, :, 1] = prior_strength * (1 - mean.cumsum(dim=0))

    # Flatten the (component, stick) dimensions: one 2-dimensional Dirichlet
    # per row.
    params = params.reshape(-1, 2)
    prior = Dirichlet.from_std_parameters(params)
    # Clone so that the posterior does not share storage with the prior.
    posterior = Dirichlet.from_std_parameters(params.clone())
    return ConjugateBayesianParameter(prior, posterior)

# Quick check of the helper with 2 components and the loaded root
# categorical; the cell displays the repr of the returned parameter.
_default_set_sb_param(2, sb_categorical, concentration)

ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)

In [92]:
# Re-create `ssb`, this time passing the loaded `sb_categorical` object as
# second argument (presumably used as the root of the set -- TODO confirm
# against `beer.SBCategoricalSet.create`).
ssb = beer.SBCategoricalSet.create(n_components, sb_categorical, prior_strength=100.)

In [102]:
# Evaluate the expected log-likelihood of the set on a 101 x 101 identity
# matrix (each row has a single 1 on the diagonal).
# NOTE(review): `data` shadows the synthetic data set defined at the top of
# the notebook; earlier cells must not be re-run after this one without
# regenerating the data.
data = torch.eye(101)
stats = ssb.sufficient_statistics(data)
# The closing parenthesis was missing, which made the cell a syntax error.
ssb.expected_log_likelihood(stats)

torch.Size([101, 101]) torch.Size([5, 101])


tensor([[-2.5227e+00, -2.7547e+00, -2.9018e+00, -3.0131e+00, -3.1045e+00,
         -3.1878e+00, -3.2709e+00, -3.3549e+00, -3.4391e+00, -3.5287e+00,
         -3.6099e+00, -3.6787e+00, -3.7445e+00, -3.8025e+00, -3.8523e+00,
         -3.9017e+00, -3.9472e+00, -3.9888e+00, -4.0311e+00, -4.0766e+00,
         -4.1207e+00, -4.1659e+00, -4.2149e+00, -4.2768e+00, -4.3513e+00,
         -4.4339e+00, -4.5412e+00, -4.6666e+00, -4.8149e+00, -5.0031e+00,
         -5.2063e+00, -5.4386e+00, -5.7095e+00, -6.0142e+00, -6.3580e+00,
         -6.7629e+00, -7.2736e+00, -7.9311e+00, -8.7854e+00, -9.9525e+00,
         -1.1457e+01, -1.3592e+01, -1.6246e+01, -2.0105e+01, -2.6636e+01,
         -3.8378e+01, -5.4140e+01, -8.5657e+01, -1.2676e+02, -2.1864e+02,
         -3.5762e+02, -5.5127e+02, -1.0916e+03, -1.7327e+03, -2.0678e+03,
         -2.5443e+03, -2.9493e+03, -3.3715e+03, -3.5696e+03, -3.7088e+03,
         -3.8301e+03, -3.9711e+03, -4.0618e+03, -4.1469e+03, -4.2345e+03,
         -4.3211e+03, -4.4088e+03, -4.

In [129]:
# Random positive tensor of shape (7, 5, 101), normalized so that each of
# the 7 leading slices sums to one.
unnormalized = torch.randn(7, 5, 101).exp()
acc_stats = unnormalized / unnormalized.sum(dim=(1, 2))[:, None, None]

# Compare the shape of the statistics accumulated for the stick-breaking
# parameter with the shape of its posterior natural parameters.
accumulated_shape = ssb.accumulate_from_jointresps(acc_stats)[ssb.stickbreaking].shape
natural_params_shape = ssb.stickbreaking.posterior.natural_parameters().shape
accumulated_shape, natural_params_shape

(torch.Size([505, 2]), torch.Size([505, 2]))