# Bayesian Model

This notebook illustrates how to use a Bayesian Normal density model with the [beer framework](https://github.com/beer-asr/beer). The Normal distribution is a fairly basic model, but it is used extensively as a building block in other models.

In [1]:
# Add "beer" to the PYTHONPATH
import random
from collections import defaultdict
import sys
sys.path.insert(0, '../')


import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Data

Generate 100 samples from a correlated two-dimensional Normal distribution and plot them together with the generating density:

In [2]:
# Parameters of the 2-D Normal the samples are drawn from. They are named
# `true_*` to keep them distinct from the fitted parameters computed later.
true_mean = np.zeros(2)
true_cov = np.array([
    [2, .95],
    [.95, .9],
])
nsamples = 100
data = np.random.multivariate_normal(true_mean, true_cov, size=nsamples)

fig = figure(
    title='Data',
    width=400,
    height=400,
    x_range=(true_mean[0] - 5, true_mean[0] + 5),
    y_range=(true_mean[1] - 5, true_mean[1] + 5),
    x_axis_label='x1',
    y_axis_label='x2',
)
fig.circle(data[:, 0], data[:, 1])
# Overlay the generating density for reference.
plotting.plot_normal(fig, true_mean, true_cov, line_color='black', fill_alpha=.3)

show(fig)

## Model Creation

We create three Normal models that differ only in the structure of their covariance matrix: isotropic (one variance shared by all dimensions), diagonal, and full. All three are created with the same prior settings.

In [3]:
# Prior settings shared by all three models. NOTE(review): despite the
# `data_` prefix these are fixed values (zero mean, unit variance), not
# statistics estimated from `data`.
data_mean = torch.zeros(2)
data_var = torch.ones(2)

# Same prior, three covariance structures, created in the order
# isotropic -> diagonal -> full. The third positional argument (1.) is
# presumably the prior strength -- confirm against beer.Normal.create.
normal_iso, normal_diag, normal_full = (
    beer.Normal.create(data_mean, data_var, 1., cov_type=cov_type)
    for cov_type in ('isotropic', 'diagonal', 'full')
)

models = dict(
    normal_full=normal_full,
    normal_diag=normal_diag,
    normal_iso=normal_iso,
)
models['normal_diag']

Normal(
  (mean_precision): ConjugateBayesianParameter(prior=NormalGamma, posterior=NormalGamma)
)

In [4]:
# Select the diagonal-covariance model and display the prior over its
# mean/precision parameter.
model = models['normal_diag']
diag_mean_precision = model.mean_precision
diag_mean_precision.prior

NormalGamma(
  (params): NormalGammaStdParams(mean=tensor([0., 0.]), scale=tensor(1.), shape=tensor(1.), rates=tensor([1., 1.]))
)

In [5]:
# Size of the posterior mean: one entry per data dimension.
model.mean_precision.posterior.params.mean.size()

torch.Size([2])

## Variational Bayes Training 

In [6]:
nbatches = 1
X = torch.from_numpy(data).float()

# `view` requires the samples to split exactly into equal-size batches;
# fail with a clear message instead of an opaque reshape error.
if len(X) % nbatches != 0:
    raise ValueError(
        'nbatches ({}) must divide the number of samples ({})'.format(
            nbatches, len(X)))

# Shape: (nbatches, samples per batch, data dimension). The dimension is read
# from the data rather than hard-coded.
batches = X.view(nbatches, -1, X.shape[1])
batches.shape

torch.Size([1, 100, 2])

In [7]:
epochs = 10
lrate = 1.

# One variational Bayes optimizer per model, built over that model's
# mean-field factorization.
optims = {
    name: beer.VariationalBayesOptimizer(normal.mean_field_factorization(), lrate)
    for name, normal in models.items()
}

# Per-model history of the per-sample ELBO on the full dataset, with one entry
# per optimizer step.
elbos = {name: [] for name in models}

datasize = len(X)
for epoch in range(epochs):
    # `normal` (not `model`) so this cell does not rebind the global `model`
    # inspected in the cells above.
    for name, normal in models.items():
        optim = optims[name]
        batch_ids = list(range(len(batches)))
        random.shuffle(batch_ids)
        for batch_id in batch_ids:
            batch = batches[batch_id]

            optim.init_step()
            # `datasize` tells beer the full dataset size; presumably used to
            # rescale the mini-batch ELBO -- confirm in beer's docs.
            elbo = beer.evidence_lower_bound(normal, batch, datasize=datasize)
            elbo.backward()
            optim.step()

            # Monitoring only: ELBO of the whole dataset, normalized per sample.
            elbo = beer.evidence_lower_bound(normal, X)
            elbos[name].append(float(elbo) / datasize)

In [8]:
colors = {
    'normal_iso': 'green',
    'normal_diag': 'blue',
    'normal_full': 'red',
}

# Plot the per-sample ELBO recorded after every optimizer step. It is a lower
# bound on ln p(X) / N, not the log-likelihood itself.
fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
             y_axis_label='ELBO / N')
for model_name, history in elbos.items():
    # `legend_label` replaces the `legend` keyword, which was deprecated in
    # Bokeh 1.4 and removed in Bokeh 3.0.
    fig.line(range(len(history)), history, legend_label=model_name,
             color=colors[model_name])
fig.legend.location = 'bottom_right'

show(fig)

In [9]:
def _fitted_mean_cov(normal, cov_type):
    """Return the fitted (mean, covariance) of a trained beer Normal as numpy arrays.

    Args:
        normal: model whose ``mean_precision.value()`` yields ``(mean, precision)``.
        cov_type: 'isotropic', 'diagonal' or 'full' -- the ``cov_type`` the
            model was created with; it selects how ``precision`` is inverted.

    Raises:
        ValueError: if ``cov_type`` is not one of the three values above.
    """
    mean, precision = normal.mean_precision.value()
    if cov_type == 'isotropic':
        # A single precision shared by all dimensions (broadcast over the
        # identity); the dimension is taken from the mean, not hard-coded.
        cov = torch.eye(len(mean)) / precision
    elif cov_type == 'diagonal':
        # One precision per dimension.
        cov = (1 / precision).diag()
    elif cov_type == 'full':
        cov = precision.inverse()
    else:
        raise ValueError('unknown cov_type: {!r}'.format(cov_type))
    return mean.numpy(), cov.numpy()


fig = figure(
    title='Fitted Normal densities',
    width=400,
    height=400,
    x_range=(-5, 5),
    y_range=(-5, 5),
    x_axis_label='x1',
    y_axis_label='x2',
)
fig.circle(data[:, 0], data[:, 1])

# (model, covariance structure, fill color) for each fitted density.
for normal, cov_type, color in [
    (normal_iso, 'isotropic', 'green'),
    (normal_diag, 'diagonal', '#98AFC7'),
    (normal_full, 'full', '#C7B097'),
]:
    fitted_mean, fitted_cov = _fitted_mean_cov(normal, cov_type)
    plotting.plot_normal(fig, fitted_mean, fitted_cov,
                         line_color='black', fill_alpha=.3, color=color)

show(fig)