In [6]:
import sys
sys.path.insert(0, './')

from collections import namedtuple
import torch
import numpy as np

from gsm import dists
from gsm import gsm

# For plotting.
from bokeh.io import show, output_notebook, export_png
from bokeh.plotting import figure, gridplot
from bokeh.application import Application
from bokeh.models import ColumnDataSource
from bokeh.palettes import viridis, cividis, inferno
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Shared axis limits for the latent-space scatter plots.
x_range = (-5, 5)
y_range = x_range

# Ground-truth generative setting: `nsamples` two-dimensional vectors
# `psis` drawn from a Gaussian with mean `mu` and covariance
# C = W^T W + I / tau.  W is 1 x 2, so W^T W has rank one and the small
# I / tau term keeps C positive definite for the Cholesky factorisation.
# NOTE(review): the NumPy RNG is not seeded in the visible cells, so the
# draw differs between runs.
nsamples = 20
mu = np.array([1, -.5])
weights = np.array([[1.2, 2.2]])
W = weights
tau = 10000
C = W.T @ W + np.identity(2) * (1. / tau)
L = np.linalg.cholesky(C)
noise = np.random.randn(nsamples, 2)
psis = mu + noise @ L.T
emp_mu = psis.mean(axis=0)

# Map each psi to probabilities with a softmax that has one extra,
# implicit class whose logit is fixed to zero.  `pi` stores only the two
# explicit classes, so each row sums to less than one.
exp_psis = np.exp(psis)
pi = exp_psis / (1 + exp_psis.sum(axis=1)[:, None])

# Additional quantities derived from `psis`; no cell visible in this
# notebook reads them.
precs = 1e-2 + 2 * np.exp(psis[:, -1])
means = psis[:, 0] / precs
n_var = 1 / precs
n_mean = means

# Left panel: the sampled psi vectors.
fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
fig1.circle(psis[:, 0], psis[:, 1], alpha=.5)

# Right panel: the first two class probabilities, with the line
# pi_1 + pi_2 = 1 drawn as a reference.
fig2 = figure(width=400, height=400, x_range=(0, 1), y_range=(0, 1))
fig2.line([1, 0], [0, 1], line_width=2, color='black')
fig2.circle(pi[:, 0], pi[:, 1], alpha=.5)

show(gridplot([[fig1, fig2]]))

In [8]:
class CategoricalLogLikelihood:
    """Log-likelihood of accumulated categorical statistics.

    For natural parameters ``eta`` (last axis of ``nparams``), accumulated
    statistics ``s`` and a sample count ``n`` the value is::

        sum_k eta_k * s_k  -  n * log(1 + sum_k exp(eta_k))

    i.e. a categorical distribution whose last class has its logit fixed
    to zero.
    """

    def __call__(self, nparams, stats_counts):
        """Evaluate the log-likelihood.

        Args:
            nparams: tensor of natural parameters; its last axis indexes the
                explicit classes.  It is multiplied with
                ``stats[:, None, :]``, so it has to broadcast against a
                ``(N, 1, D)`` tensor (presumably ``(N, S, D)`` with ``S``
                samples per data point -- TODO confirm against the caller).
            stats_counts: 2-D tensor whose columns ``[:-1]`` are the
                accumulated sufficient statistics and whose last column is
                the number of observations.

        Returns:
            Tensor holding the log-likelihood with the class axis summed out.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        # log(1 + sum_k exp(eta_k)) is the log-sum-exp of the parameters
        # with an extra logit equal to zero.  Using logsumexp avoids the
        # overflow of exp() for large parameters, which otherwise turns the
        # log-normaliser into +inf.
        zero_logit = torch.zeros_like(nparams[..., :1])
        lnorm = torch.logsumexp(torch.cat([nparams, zero_logit], dim=-1),
                                dim=-1)
        return (nparams * stats[:, None, :]).sum(dim=-1) - counts[:, None] * lnorm
    
def onehot(labels, max_label, dtype, device):
    """Encode integer labels as the rows of a one-hot matrix.

    Args:
        labels: 1-D integer tensor of class indices.
        max_label: number of columns of the output (indices must be
            smaller than this value).
        dtype: dtype of the returned tensor.
        device: device on which the returned tensor is allocated.

    Returns:
        Tensor of shape ``(len(labels), max_label)`` filled with zeros
        except for a one at ``[i, labels[i]]``.
    """
    n_rows = len(labels)
    encoded = torch.zeros(n_rows, max_label, dtype=dtype, device=device)
    row_ids = torch.arange(n_rows).long()
    # Advanced indexing: one (row, column) pair per label.
    encoded[row_ids, labels] = 1
    return encoded
    
def generate_data(pi, nsamples=100):
    """Sample categorical data and return per-model sufficient statistics.

    Args:
        pi: 2-D NumPy array with one row per model holding the
            probabilities of all classes but the last one; the last class
            receives the remaining mass ``1 - pi.sum(axis=-1)``.
        nsamples: total number of samples, split at random between the
            models with equal probability.

    Returns:
        ``torch.long`` tensor of shape ``(K, C)`` for ``K`` models and
        ``C`` classes (``C = pi.shape[1] + 1``).  Columns ``[:-1]`` are the
        per-class counts of the first ``C - 1`` classes and the last column
        is the number of samples drawn for that model.
    """
    # Append the probability of the implicit last class so rows sum to one.
    pi = np.c_[pi, 1 - pi.sum(axis=-1)]
    n_models, n_classes = pi.shape
    # Number of samples per model.
    nsamples_pm = np.random.multinomial(nsamples, [1. / n_models] * n_models)
    Ts = []
    for i in range(n_models):
        m_data = np.random.choice(n_classes, size=nsamples_pm[i], p=pi[i])
        m_data = torch.as_tensor(m_data, dtype=torch.long)
        # Accumulated sufficient statistics of the samples of this model:
        # the per-class counts, dropping the last class, which is implied
        # by the total count.  The number of classes is taken from `pi`
        # rather than being fixed to three.
        stats = torch.bincount(m_data, minlength=n_classes)[:-1]
        Ts.append(stats)
    T = torch.stack(Ts, dim=0)
    counts = torch.from_numpy(nsamples_pm).long()
    return torch.cat([T, counts[:, None]], dim=-1)

In [18]:
# Dimensions handed to the GSM constructor and to the latent posteriors.
obs_dim = latent_dim = 2

# One row per model: accumulated statistics followed by the sample count,
# cast to float for the model.
svectors_counts = generate_data(pi, nsamples=10000)
svectors_counts = svectors_counts.float()

# The model and one latent posterior per row of the data.
model = gsm.GSM(latent_dim, obs_dim, CategoricalLogLikelihood())
latents = gsm.create_latent_posteriors(len(svectors_counts), latent_dim)

In [19]:
epochs = 2000

# A single Adam optimiser updates the parameters of every latent
# posterior together with the parameters of the model.
params = []
for latent in latents:
    params.extend(latent.parameters())
params.extend(model.parameters())
optim = torch.optim.Adam(params, lr=1e-1)

# Objective function: -ELBO.
def neg_elbo_fn(model, data, latents, nsamples_latents,
                nsamples_params):
    """Return the normalised negative ELBO and the normalised KL term.

    Args:
        model: callable returning ``(expected log-likelihood, KL divergence
            of the latents)`` and exposing ``kl_div_posterior_prior()``.
        data: 2-D tensor whose last column holds the observation counts.
        latents: latent posteriors, forwarded to ``model``.
        nsamples_latents: forwarded to ``model``.
        nsamples_params: forwarded to ``model``.

    Returns:
        Tuple ``(-ELBO / n, KL / n)`` where ``n`` is the sum of the last
        column of ``data`` and ``KL`` adds the latent KL terms to
        ``model.kl_div_posterior_prior()``.
    """
    expected_llh, latent_kl = model(data, latents, nsamples_latents,
                                    nsamples_params)
    total_kl = latent_kl.sum() + model.kl_div_posterior_prior()
    n_obs = data[:, -1].sum()
    elbo = expected_llh.sum() - total_kl
    return -elbo / n_obs, total_kl / n_obs

# Training history: normalised ELBO and KL divergence at every epoch.
elbos, kl_divs = [], []

for epoch in range(epochs):
    optim.zero_grad()
    neg_elbo, kl_div = neg_elbo_fn(
        model, svectors_counts, latents,
        nsamples_latents=10, nsamples_params=10,
    )
    neg_elbo.backward()

    # Store plain Python floats so the history does not keep the graph.
    elbos.append(-neg_elbo.item())
    kl_divs.append(kl_div.item())

    optim.step()

    # Minimum Divergence Step.
    model.update_prior(latents)

In [20]:
# Plot the training history.  The x-axis is derived from the recorded
# history rather than from `epochs`, so the columns always have matching
# lengths, even if the training loop was interrupted.
epoch_axis = np.arange(len(elbos))

# ELBO + KL is the expected log-likelihood term of the objective (both
# are normalised by the total number of observations).
fig1 = figure(width=300, height=300,
              title='Expected log-likelihood (per observation)',
              x_axis_label='epoch')
fig1.line(epoch_axis, np.array(elbos) + np.array(kl_divs))

fig2 = figure(width=300, height=300,
              title='KL divergence (per observation)',
              x_axis_label='epoch')
fig2.line(epoch_axis, kl_divs)

fig3 = figure(width=300, height=300,
              title='ELBO (per observation)',
              x_axis_label='epoch')
fig3.line(epoch_axis, elbos)

show(gridplot([[fig1, fig2], [fig3]]))

In [21]:
fig = figure(x_range=(-3, 3), y_range=(-3, 3))

# Gaussian drawn for the prior: the covariance is the inverse of
# dof * scale_matrix taken from the prior's posterior.
# NOTE(review): the mean is read from `model.prior.prior` while the
# covariance comes from `model.prior.posterior` -- confirm this mix is
# intended.
p_m = model.prior.prior.mean.numpy()
p_cov = torch.inverse(
    model.prior.posterior.dof * model.prior.posterior.scale_matrix
)
p_cov = p_cov.numpy()
plotting.plot_normal(fig, p_m, p_cov, alpha=.3, color='salmon')

# Mark the mean of every latent posterior with a cross.  The covariance
# is computed but its ellipse is disabled in this cell.
for latent in latents:
    l_m = latent.mean.detach().numpy()
    l_cov = torch.diag(latent.diag_cov).detach().numpy()
    #plotting.plot_normal(fig, l_m, l_cov, n_std_dev=2, alpha=.2)
    fig.cross(x=[l_m[0]], y=[l_m[1]], size=5)

show(fig)

In [22]:
# Apply the prior update once more, outside the training loop (the loop
# issues the same call after every optimiser step).
model.update_prior(latents)

In [23]:
fig = figure(x_range=(-3, 3), y_range=(-3, 3))

# Gaussian drawn for the prior: the covariance is the inverse of
# dof * scale_matrix taken from the prior's posterior.
# NOTE(review): the mean is read from `model.prior.prior` while the
# covariance comes from `model.prior.posterior` -- confirm this mix is
# intended.
p_m = model.prior.prior.mean.numpy()
p_cov = torch.inverse(
    model.prior.posterior.dof * model.prior.posterior.scale_matrix
)
p_cov = p_cov.numpy()
plotting.plot_normal(fig, p_m, p_cov, alpha=.3, color='salmon')

# For every latent posterior draw its Gaussian (diagonal covariance
# expanded to a full matrix) and mark its mean with a cross.
for latent in latents:
    l_m = latent.mean.detach().numpy()
    l_cov = torch.diag(latent.diag_cov).detach().numpy()
    plotting.plot_normal(fig, l_m, l_cov, n_std_dev=2, alpha=.2)
    fig.cross(x=[l_m[0]], y=[l_m[1]], size=5)

show(fig)

In [24]:
# Draw parameter samples from the trained model and flatten them into a
# list of two-dimensional points.
s_params = model.sample_params(
    latents, nsamples_latents=100, nsamples_params=100
)
s_params = s_params.reshape(-1, 2).detach().numpy()
print(s_params.shape)

# Sampled parameters (red dots) against the ground-truth psi vectors
# (black crosses).
fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
fig1.circle(s_params[:, 0], s_params[:, 1], alpha=.1, color='red')
fig1.cross(psis[:, 0], psis[:, 1],
           alpha=1, size=10, line_width=2, color='black')

show(fig1)

(200000, 2)


In [16]:
# Display the posterior/prior KL divergence reported by the `trans` and
# `prior` components of the model, as a tuple.
model.trans.kl_div_posterior_prior(), model.prior.kl_div_posterior_prior()

(tensor(14.1702, grad_fn=<AddBackward0>), tensor(14.6824))

In [17]:
# Stack the expected sufficient statistics of every latent posterior
# along a new leading axis.
latents_exp_stats = torch.stack([
    latent.expected_sufficient_stats(store_full_cov=True)
    for latent in latents
], dim=0)

# New natural parameters of the prior's posterior: the natural parameters
# of the prior plus the statistics summed over the latents (detached from
# the autograd graph).
# NOTE(review): presumably the same update as `model.update_prior`,
# written out by hand -- TODO confirm.
summed_stats = latents_exp_stats.sum(dim=0).detach()
new_nparams = model.prior.prior.natural_parameters() + summed_stats
model.prior.posterior.params = \
    dists.NormalWishartStdParams.from_natural_parameters(new_nparams)