## Model comparison toy example: Poisson vs. negative binomial

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import scipy

import sys 
sys.path.append('../../')
from model_comparison.utils import *
from model_comparison.mdns import ClassificationSingleLayerMDN, Trainer, UnivariateMogMDN
from model_comparison.models import PoissonModel, NegativeBinomialModel

## Set up the Poisson and negative binomial models

In [None]:
# Randomness: set `seed` to an int to make the draws from `rng` (and the models) repeatable.
seed = None
rng = np.random.RandomState(seed=seed)

sample_size = 10  # number of counts in each simulated data set
ntrain = 1000     # total number of simulated training data sets, split between the two models

# Gamma prior hyperparameters (shape k, scale theta) for the two negative binomial parameters.
k2, theta2 = 5., 2.0
k3, theta3 = 2., .8

# Gamma prior for the Poisson rate: shape k1 is fixed, and the scale is chosen so that the
# Poisson prior mean k1 * theta1 equals the product of the two NB prior means
# (k2 * theta2) * (k3 * theta3). NOTE(review): presumably this makes both models predict the
# same mean count; that depends on NegativeBinomialModel's parameterization, so confirm there.
k1 = 2.0
theta1 = k2 * theta2 * k3 * theta3 / k1
print(theta1)

model_poisson = PoissonModel(k1, theta1, sample_size=sample_size, seed=seed)
model_nb = NegativeBinomialModel(k2, theta2, k3, theta3, sample_size=sample_size, seed=seed)

In [None]:
# Prior draws for the training set; each model gets half of the `ntrain` data sets.
n_per_model = int(ntrain / 2)

# One Poisson rate per Poisson data set, drawn from Gamma(k1, theta1).
params_poi = rng.gamma(shape=k1, scale=theta1, size=n_per_model)

# One (Gamma(k2, theta2), Gamma(k3, theta3)) parameter pair per NB data set.
# NOTE: the draw order (Poisson rates, then first NB parameter, then second) fixes which
# values `rng` produces, so keep it if results must be reproducible.
nb_first = rng.gamma(shape=k2, scale=theta2, size=n_per_model)
nb_second = rng.gamma(shape=k3, scale=theta3, size=n_per_model)
params_nb = np.vstack((nb_first, nb_second)).T  # shape (n_per_model, 2)

## Generate data from models and calculate summary stats

In [None]:
# Simulate data from each model using the prior draws above. The row counts of these arrays
# are used below as the number of data sets per model (presumably one row per parameter draw).
data_poi = model_poisson.gen(params_poi)
data_nb = model_nb.gen(params_nb)

In [None]:
# Stack the simulated data sets of both models and label each row with its model index:
# 0 = Poisson, 1 = negative binomial.
x = np.vstack((data_poi, data_nb))
m = np.hstack((np.zeros(data_poi.shape[0]), np.ones(data_nb.shape[0]))).astype(int)

# Shuffle rows and labels with the same permutation.
# - Use the seeded `rng` rather than the global np.random state, so that fixing `seed`
#   also makes the shuffle (and thus training) reproducible.
# - Size the permutation from the stacked data, not from `ntrain`: the models each
#   contribute int(ntrain / 2) rows, so for odd ntrain the row count is ntrain - 1 and
#   indexing with np.arange(ntrain) would go out of bounds.
shuffle_indices = rng.permutation(x.shape[0])
x = x[shuffle_indices]
m = m[shuffle_indices].tolist()

# Summary statistics of every data set, normalized. `training_norm` is kept so that
# observed data can later be normalized the same way before being fed to the classifier.
sx = calculate_stats_toy_examples(x)
sx, training_norm = normalize(sx)

## Set up the NN and train it 

In [None]:
# Training hyperparameters for the model-selection classifier.
n_epochs = 200
n_minibatch = 10

# Classifier from normalized summary stats to model-index scores.
# NOTE(review): ndim_input=2 assumes calculate_stats_toy_examples yields two stats per data set.
model = ClassificationSingleLayerMDN(ndim_input=2, n_hidden=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
trainer = Trainer(model, optimizer, verbose=True, classification=True)

# Train on (summary stats, model index) pairs and show the loss curve.
loss_trace = trainer.train(sx, m, n_epochs=n_epochs, n_minibatch=n_minibatch)

plt.plot(loss_trace)
plt.ylabel('loss')
plt.xlabel('iterations');

## Observe Poisson data and predict the underlying model

In [None]:
# One observed data set of `sample_size` counts from a Poisson whose rate is k1 * theta1,
# the mean of the Poisson model's Gamma prior.
x_obs = rng.poisson(lam=k1*theta1, size=sample_size)

# Summary stats of the observation, normalized with the classifier's training normalization.
sx_obs, training_norm = normalize(calculate_stats_toy_examples(x_obs).squeeze(), training_norm)

# The classifier was trained with label 0 = Poisson, so the first softmax entry is the
# predicted probability of the Poisson model.
softmax = nn.Softmax(dim=0)
out_act = model(Variable(torch.Tensor(sx_obs)))
p_vec = softmax(out_act).data.numpy()
p_poisson = p_vec[0]
print('P(poisson | data) = {:.2f}'.format(p_poisson))

## Given the predicted underlying model, we can learn the posterior over its parameters

In [None]:
# Posterior network for the Poisson model: normalized summary stats in, a 3-component
# univariate mixture of Gaussians (MoG) over the Poisson rate out.
# NOTE(review): this rebinds `model`, `optimizer` and `trainer`; the classifier trained
# above is no longer reachable through these names.
model = UnivariateMogMDN(n_components=3, ndim_input=2, n_hidden=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
trainer = Trainer(model, optimizer, verbose=True)

In [None]:
# Normalized summary stats of the Poisson-only training data. This rebinds `training_norm`
# to the Poisson-only normalization, so inputs to the posterior network must use this norm.
sx_poi, training_norm = normalize(calculate_stats_toy_examples(data_poi))

In [None]:
# Fit the MoG posterior network on (summary stats, Poisson rate) training pairs.
loss_trace = trainer.train(sx_poi, params_poi,
                           n_minibatch=10, n_epochs=100)

In [None]:
# Evaluate the learned MoG posterior on a grid of Poisson rates.
n_thetas = 1000
thetas_poisson = np.linspace(0, 30, n_thetas)

# The posterior network was trained on Poisson-only summary stats normalized with the
# `training_norm` recomputed from data_poi. `sx_obs` was normalized earlier with the
# classifier's (both-model) normalization, so it cannot be fed to this network as is.
# Re-normalize the raw observed stats with the posterior network's normalization so its
# input matches the scale it was trained on.
sx_obs_poi, _ = normalize(calculate_stats_toy_examples(x_obs).squeeze(), training_norm)
post_poisson = get_mog_posterior(model, sx_obs_poi, thetas_poisson)

In [None]:
# get true posteriors 
# get analytical gamma posterior 
k_post = k1 + np.sum(x_obs)

# use the posterior given the summary stats, not the data vector 
scale_post = 1. / (sample_size + theta1**-1)

# somehow we have to scale with N again, why? because the scale is changed due to s(x) by 1/N  
true_post_poisson = gamma.pdf(x=thetas_poisson, a=k_post, scale=scale_post)

In [None]:
# Compare the learned MoG posterior with the analytical Gamma posterior.
# NOTE: subplot(211) occupies only the top half of the figure; the bottom half stays empty.
plt.figure(figsize=(15, 10))
ax = plt.subplot(211)

ax.set_title(r'Gamma posterior fit for Poisson $\lambda$')
ax.plot(thetas_poisson, true_post_poisson, label='analytical posterior')
ax.plot(thetas_poisson, post_poisson.data.numpy(), label='predicted posterior')
# Mark the rate k1 * theta1 that the observed data set x_obs was simulated with.
ax.axvline(x=k1 * theta1, color="C2", label='true lambda')
ax.legend()
ax.set_xlabel(r'$\lambda$');
