# Bayesian Nested Mixture Model

This notebook illustrate how to build and train a Bayesian Nested Mixture Model with the [beer framework](https://github.com/beer-asr/beer).

In [2]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

As an illustration, we generate a synthetic data set composed of two Normal distributed cluster. One has a diagonal covariance matrix whereas the other has a dense covariance matrix.

In [3]:
# First cluster.
mean = np.array([-5, 5]) 
cov = .5 *np.array([[.75, .5], [.5, 2.]])
data1 = np.random.multivariate_normal(mean, cov, size=200)

# Second cluster.
mean = np.array([5, 5]) 
cov = 2 * np.array([[2, -.5], [-.5, .75]])
data2 = np.random.multivariate_normal(mean, cov, size=200)

# Merge everything to get the finale data set.
data = np.vstack([data1, data2])
np.random.shuffle(data)

# We use the global mean/cov. matrix of the data to initialize the mixture.
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(np.var(data, axis=0)).float()

In [4]:
# Mean, variance of the data to scale the figure.
mean = data.mean(axis=0)
var = data.var(axis=0)
std_dev = np.sqrt(max(var))
x_range = (mean[0] - 2 * std_dev, mean[0] + 2 * std_dev)
y_range = (mean[1] - 2 * std_dev, mean[1] + 2 * std_dev)
global_range = (min(x_range[0], y_range[0]), max(x_range[1], y_range[1]))

fig = figure(title='Data', width=400, height=400,
             x_range=global_range, y_range=global_range)
fig.circle(data[:, 0], data[:, 1])

show(fig)

## Model Creation

We create two types of mixture model: one whose (Normal) components have full covariance matrix and the other whose (Normal) components have diagonal covariance matrix.

In [5]:
nmixtures = 4
ncomp_per_mixture = 3
total_components = nmixtures * ncomp_per_mixture

# We use the global mean/cov. matrix of the data to initialize the mixture.
data_mean = torch.from_numpy(data.mean(axis=0)).float()
data_var = torch.from_numpy(np.var(data, axis=0)).float()

# Isotropic covariance.
modelset = beer.NormalSet.create(
    data_mean, data_var, 
    size=total_components,
    prior_strength=1., 
    noise_std=1., 
    cov_type='isotropic'
)
mixtureset = beer.MixtureSet.create(nmixtures, modelset)
m_gmm_iso = beer.Mixture.create(mixtureset)

# Diagonal covariance.
modelset = beer.NormalSet.create(
    data_mean, data_var, 
    size=total_components,
    prior_strength=1., 
    noise_std=1., 
    cov_type='diagonal'
)
mixtureset = beer.MixtureSet.create(nmixtures, modelset)
m_gmm_diag = beer.Mixture.create(mixtureset)

# Full covariance.
modelset = beer.NormalSet.create(
    data_mean, data_var,
    size=total_components,
    prior_strength=1.,
    noise_std=1., 
    cov_type='full'
)
mixtureset = beer.MixtureSet.create(nmixtures, modelset)
m_gmm_full = beer.Mixture.create(mixtureset)


models = {
    'm_gmm_iso': m_gmm_iso,
    'm_gmm_diag': m_gmm_diag,
    'm_gmm_full': m_gmm_full
}

In [6]:
print(m_gmm_iso)

Mixture(
  (modelset): MixtureSet(
    (categoricalset): CategoricalSet(
      (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
    )
    (modelset): NormalSet(
      (means_precisions): ConjugateBayesianParameter(prior=IsotropicNormalGamma, posterior=IsotropicNormalGamma)
    )
  )
  (categorical): Categorical(
    (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
  )
)


## Variational Bayes Training 

In [7]:
epochs = 200
lrate = 1.
X = torch.from_numpy(data).float()

optims = {
    model_name: beer.VBConjugateOptimizer(
        model.mean_field_factorization(), 
        lrate
    )
    for model_name, model in models.items()
}

elbos = {
    model_name: [] 
    for model_name in models
}  
    
for epoch in range(epochs):
    for name, model in models.items():
        optim = optims[name]
        optim.init_step()
        elbo = beer.evidence_lower_bound(model, X, datasize=len(X))
        elbo.backward()
        elbos[name].append(float(elbo) / len(X))
        optim.step()

In [8]:
colors = {
    'm_gmm_iso': 'green',
    'm_gmm_diag': 'blue',
    'm_gmm_full': 'red',
    'm_gmm_iso_shared': 'grey',
    'm_gmm_diag_shared': 'brown',
    'm_gmm_full_shared': 'black'
}
# Plot the ELBO.
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
for model_name, elbo in elbos.items():
    fig.line(range(len(elbo)), elbo, legend=model_name, color=colors[model_name])
fig.legend.location = 'bottom_right'

show(fig)

In [9]:
figs = []
for i, model_name in enumerate(models):
    fig = figure(title=model_name, x_range=global_range, y_range=global_range,
                  width=250, height=250)
    model = models[model_name]
    weights = model.categorical.mean
    for j, gmm in enumerate(model.modelset):
        fig.circle(data[:, 0], data[:, 1], alpha=.1)
        plotting.plot_gmm(fig, gmm, alpha=weights[j].numpy())
    if i % 3 == 0:
        figs.append([])
    figs[-1].append(fig)
grid = gridplot(figs)
show(grid)

## Hierarchical Dirichlet Process Mixture Model



In [10]:
import pickle 
with open('/home/lucas/Desktop/test_reorder.mdl', 'rb') as f:
    ploop = pickle.load(f)
    
sb_categoricalset = beer.SBCategoricalSet.create(len(ploop.start_pdf), ploop.categorical, prior_strength=1)
sb_categoricalset.mean_field_factorization()
sb_categoricalset.stickbreaking

ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)

In [11]:
ploop.categorical.ordering

tensor([  0,  22,  28,  79,  90,  38,   4,  40,  60,  12,   9,  66,  43,   6,
         71,  81,  37,  16,  25,  63,  49,  11,  57,  23,  35,  44,  80,  82,
         48,  30,  18,  21,  85,  31,  97,  33,  84,  77,  32,  39,   2,  59,
         89,  13,   7,  56,  15,  78,  54,  95,  68,  52,  47,  65,  20,  69,
         29,  87,  41,  42,   3,  72,  19,   8,  58,  24,  17,  75,  61,  91,
          1,  94,  74,  36,  53,  26,  10,  27,  55,  64,  45,  51,  14,  50,
         46,  99,  76,  70,  62,   5,  67,  34,  83,  73, 100,  93,  98,  88,
         96,  92,  86])

In [12]:
data = torch.eye(len(ploop.start_pdf))
stats = ploop.categorical.sufficient_statistics(data)
log_weights = ploop.categorical.expected_log_likelihood(stats)
log_weights

tensor([ -2.8977,  -4.8074,  -4.4731,  -4.6469,  -4.1472, -12.7396,  -4.2089,
         -4.5137,  -4.6595,  -4.1685,  -4.8238,  -4.2644,  -4.2041,  -4.5405,
         -5.1539,  -4.5429,  -4.2365,  -4.7592,  -4.3686,  -4.6578,  -4.6209,
         -4.3934,  -3.9633,  -4.3094,  -4.7372,  -4.2841,  -4.9348,  -5.1012,
         -3.9816,  -4.6313,  -4.3646,  -4.4146,  -4.5209,  -4.5215, -12.7792,
         -4.3122,  -4.8477,  -4.2550,  -4.1250,  -4.4696,  -4.1186,  -4.6737,
         -4.6896,  -4.1682,  -4.3554,  -5.0000,  -5.4830,  -4.6407,  -4.3627,
         -4.2766,  -5.3377,  -5.0815,  -4.5643,  -4.8729,  -4.5636,  -4.9744,
         -4.5199,  -4.3090,  -4.8281,  -4.5632,  -4.1601,  -4.7726, -11.7298,
         -4.2697,  -4.9771,  -4.6150,  -4.1701, -12.7594,  -4.6234,  -4.6272,
        -11.7148,  -4.2120,  -4.7531, -12.8188,  -4.8203,  -4.7648, -11.2086,
         -4.4661,  -4.5616,  -4.0325,  -4.3350,  -4.2451,  -4.3838, -12.7990,
         -4.4803,  -4.4581, -12.9574,  -4.6455, -12.8980,  -4.47

In [13]:
data = torch.eye(len(ploop.start_pdf))
stats = sb_categoricalset.sufficient_statistics(data)
log_weights = sb_categoricalset.expected_log_likelihood(stats)
log_weights

tensor([[ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386],
        [ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386],
        [ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386],
        ...,
        [ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386],
        [ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386],
        [ -2.8977,  -4.8074,  -4.4731,  ..., -12.8782,  -5.8312, -12.8386]])

In [14]:
from scipy.special import gamma

def beta(x, a, b):
    norm = gamma(a + b) / (gamma(a) * gamma(b))
    return (x ** (a - 1) * (1 - x)**(b-1)) * norm

x = np.linspace(1e-3, 0.999, 1000)
mean = ploop.categorical.mean.numpy()
cmean = mean.cumsum()
concentration = 1000

fig = figure()

for i in range(90, 100):
    p_x = beta(x, concentration * mean[i], concentration * (1 - cmean[i]))
    if i == 0:
        fig.line(x, p_x, color='red')
    else:
        fig.line(x, p_x)
show(fig)


In [15]:
def hdp_sb(mean, concentration):
    cmean = np.cumsum(mean)
    v = np.array([np.random.beta(concentration * mean[i], concentration * (1 - cmean[i]))
                  for i in range(101)])
    residual = np.cumprod(1 - v)
    pi = v
    pi[1:] *= residual[:-1]
    return pi

samples = np.c_[[hdp_sb(mean, concentration=100) for i in range(50)]]

fig = figure()


for sample in samples:
    fig.line(range(101), sample, alpha=.3)
fig.line(range(101), mean, color='red')
fig.line(range(101), samples.mean(axis=0), color='green')

show(fig)

In [25]:
sb_categoricalset = beer.SBCategoricalSet.create(len(ploop.start_pdf), ploop.categorical, prior_strength=50)
vbinit = sb_categoricalset.mean.numpy()

fig = figure()
fig.line(range(101), mean, color='red')
fig.line(range(101), samples.mean(axis=0), color='green')
fig.line(range(101), vbinit.mean(axis=0), color='blue')

show(fig)

RuntimeError: Columns need to be 1D (y is not)

In [27]:
import pickle 
with open('/home/lucas/Desktop/test_reorder.mdl', 'rb') as f:
    ploop = pickle.load(f)
sb_categoricalset = beer.SBCategoricalSet.create(len(ploop.start_pdf), ploop.categorical, prior_strength=1)
bploop = beer.BigramPhoneLoop.create(ploop.graph, ploop.start_pdf, ploop.end_pdf,
                                     ploop.modelset, sb_categoricalset).double()

X = torch.from_numpy(np.load('/home/lucas/Desktop/mzmb0_sx176.npy'))
model = bploop
epochs = 10
optim =  beer.VBConjugateOptimizer(model.mean_field_factorization(), lrate=1)
elbos = []
    
for epoch in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(model, X)
    #elbo += beer.evidence_lower_bound(model, X)
    elbo.backward()
    elbos.append(float(elbo) / len(X))
    optim.step()
    print(elbos[-1])

-1065.3952062208352
-66.76383924880133
-64.51263460563408
-63.85281116045825
-63.589213531125836
-63.32554173223835
-63.210672161560296
-63.16201029741327
-63.13208599028739
-63.100564791618396


In [23]:
mean = bploop.categoricalset.mean.numpy()
print(mean.sum())

fig = figure(x_range=(0, 101), y_range=(0, 101))
fig.image(image=[mean], x=0, y=0, dh=101, dw=101)
show(fig)

16.164053065235425


In [47]:
with open('/home/lucas/Desktop/test_bigram_reorder.mdl', 'rb') as f:
    bploop = pickle.load(f)
    
mean = bploop.categoricalset.mean[:, bploop.categoricalset.ordering].numpy()
print(mean.sum())

fig = figure()
fig.line(range(101), mean[9])
show(fig)

fig = figure(x_range=(0, 101), y_range=(0, 101))
fig.image(image=[mean], x=0, y=0, dh=101, dw=101)
show(fig)

86.152115


In [29]:
with open('/home/lucas/Desktop/test_reorder.mdl', 'rb') as f:
    ploop = pickle.load(f)
    
print(ploop.categorical.ordering)
mean = ploop.categorical.mean[ploop.categorical.ordering].numpy()

fig = figure()
fig.line(range(101), mean, color='green')

with open('/home/lucas/Desktop/test_reorder.mdl', 'rb') as f:
    ploop = pickle.load(f)
    
print(ploop.categorical.ordering)
mean = ploop.categorical.mean[ploop.categorical.ordering].numpy()
fig.line(range(101), mean)

show(fig)

mean = ploop.categorical.mean[ploop.categorical.ordering].repeat(101, 1).numpy()
print(mean.sum())

fig = figure(x_range=(0, 101), y_range=(0, 101))
fig.image(image=[mean], x=0, y=0, dh=101, dw=101)
show(fig)

tensor([  0,  22,  28,  79,  90,  38,   4,  40,  60,  12,   9,  66,  43,   6,
         71,  81,  37,  16,  25,  63,  49,  11,  57,  23,  35,  44,  80,  82,
         48,  30,  18,  21,  85,  31,  97,  33,  84,  77,  32,  39,   2,  59,
         89,  13,   7,  56,  15,  78,  54,  95,  68,  52,  47,  65,  20,  69,
         29,  87,  41,  42,   3,  72,  19,   8,  58,  24,  17,  75,  61,  91,
          1,  94,  74,  36,  53,  26,  10,  27,  55,  64,  45,  51,  14,  50,
         46,  99,  76,  70,  62,   5,  67,  34,  83,  73, 100,  93,  98,  88,
         96,  92,  86])
tensor([  0,  22,  28,  79,  90,  38,   4,  40,  60,  12,   9,  66,  43,   6,
         71,  81,  37,  16,  25,  63,  49,  11,  57,  23,  35,  44,  80,  82,
         48,  30,  18,  21,  85,  31,  97,  33,  84,  77,  32,  39,   2,  59,
         89,  13,   7,  56,  15,  78,  54,  95,  68,  52,  47,  65,  20,  69,
         29,  87,  41,  42,   3,  72,  19,   8,  58,  24,  17,  75,  61,  91,
          1,  94,  74,  36,  53,  26,  1

100.951035


In [None]:
mean = ploop.categorical.mean.repeat(101, 1).numpy()
print(mean.sum())

fig = figure(x_range=(0, 101), y_range=(0, 101))
fig.image(image=[np.log(mean)], x=0, y=0, dh=101, dw=101, palette='Viridis256')
show(fig)

In [None]:
x = torch.tensor([1, 2, 3])
x.repeat(4, 2)