# Generalized Subspace Model


In [1]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import math
import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
from bokeh.palettes import Category10 as palette
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

## Synthetic Data

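As an illustration, we generate a collection of synthetic data sets. Each data set is composed of two normally distributed clusters with elongated covariance matrices: the first cluster is rotated by some angle and the second by the opposite angle, so their covariance matrices are in general dense (they are diagonal only when the angle is a multiple of $\pi/2$). From one data set to the next, the clusters are shifted horizontally and the rotation angle increases, so the whole collection varies smoothly.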

In [55]:
def generate_data(global_mean, angle, size, rng=None):
    """Sample a synthetic 2-D dataset made of two rotated Gaussian clusters.

    Both clusters are drawn from a standard normal scaled by (0.5, 2.0)
    along the two axes. The first is rotated by ``angle`` and centred at
    ``global_mean + (0, 2)``; the second is rotated by ``-angle`` and
    centred at ``global_mean - (0, 2)``. The rows are shuffled before
    being returned.

    Args:
        global_mean: array-like of length 2, offset shared by both clusters.
        angle: rotation angle in radians.
        size: total number of points to generate (odd sizes are honoured;
            the second cluster gets the extra point).
        rng: optional random source exposing ``standard_normal`` and
            ``shuffle`` (e.g. ``np.random.default_rng(seed)``). Defaults to
            the legacy global ``np.random`` state.

    Returns:
        ``np.ndarray`` of shape ``(size, 2)``.
    """
    if rng is None:
        rng = np.random
    rotation1 = np.array([
        [math.cos(angle), -math.sin(angle)],
        [math.sin(angle), math.cos(angle)]
    ])
    # A rotation by -angle is the transpose (inverse) of a rotation by angle.
    rotation2 = rotation1.T
    scale = np.array([.5, 2.])
    mean1 = global_mean + np.array([0, 2])
    mean2 = global_mean - np.array([0, 2])
    # Split so that the two halves always add up to `size`, even when odd.
    size1 = size // 2
    size2 = size - size1
    data1 = (scale * rng.standard_normal((size1, 2))) @ rotation1 + mean1
    data2 = (scale * rng.standard_normal((size2, 2))) @ rotation2 + mean2
    data = np.vstack([data1, data2])
    rng.shuffle(data)
    return data

# Sweep a horizontal offset across [-boundary, boundary]; the rotation angle
# of the clusters grows linearly with it, from start_angle to start_angle + pi.
start_angle = -.5 * math.pi
boundary = 50
nmodels = 10
offsets = np.linspace(-boundary, boundary, nmodels)

datasets2 = []
for offset in offsets:
    shift = np.array([1, 0]) * offset
    rot = start_angle + (offset + boundary) * (math.pi) / (2 * boundary)
    datasets2.append(generate_data(shift, rot, size=200))

# All the points of all the datasets stacked together.
data = np.vstack(datasets2)

# One colour per dataset.
colors = palette[10]

fig = figure(width=400, height=400)
for points, color in zip(datasets2, colors):
    fig.circle(points[:, 0], points[:, 1], color=color)
show(fig)

In [58]:
def generate_data(global_mean, angle, size, rng=None):
    """Sample a synthetic 2-D dataset made of two rotated Gaussian clusters.

    Both clusters are drawn from a standard normal scaled by (0.5, 2.0)
    along the two axes. The first is rotated by ``angle`` and centred at
    ``global_mean + (0, 2)``; the second is rotated by ``-angle`` and
    centred at ``global_mean - (0, 2)``. The rows are shuffled before
    being returned.

    Args:
        global_mean: array-like of length 2, offset shared by both clusters.
        angle: rotation angle in radians.
        size: total number of points to generate (odd sizes are honoured;
            the second cluster gets the extra point).
        rng: optional random source exposing ``standard_normal`` and
            ``shuffle`` (e.g. ``np.random.default_rng(seed)``). Defaults to
            the legacy global ``np.random`` state.

    Returns:
        ``np.ndarray`` of shape ``(size, 2)``.
    """
    if rng is None:
        rng = np.random
    rotation1 = np.array([
        [math.cos(angle), -math.sin(angle)],
        [math.sin(angle), math.cos(angle)]
    ])
    # A rotation by -angle is the transpose (inverse) of a rotation by angle.
    rotation2 = rotation1.T
    scale = np.array([.5, 2.])
    mean1 = global_mean + np.array([0, 2])
    mean2 = global_mean - np.array([0, 2])
    # Split so that the two halves always add up to `size`, even when odd.
    size1 = size // 2
    size2 = size - size1
    data1 = (scale * rng.standard_normal((size1, 2))) @ rotation1 + mean1
    data2 = (scale * rng.standard_normal((size2, 2))) @ rotation2 + mean2
    data = np.vstack([data1, data2])
    rng.shuffle(data)
    return data

start_angle = -.5 * math.pi
boundary = 50
nmodels = 10

# One numpy dataset per model: the horizontal offset sweeps
# [-boundary, boundary] while the rotation angle sweeps
# [start_angle, start_angle + pi].
np_datasets = [
    generate_data(np.array([1, 0]) * h,
                  start_angle + (h + boundary) * (math.pi) / (2 * boundary),
                  size=100)
    for h in np.linspace(-boundary, boundary, nmodels)
]

# All the points stacked into a single array.
data = np.vstack(np_datasets)

# beer works on pytorch tensors (torch.from_numpy shares memory with the
# numpy arrays, no copy is made).
datasets = list(map(torch.from_numpy, np_datasets))

# One colour per dataset.
colors = palette[10]

fig = figure(width=400, height=400)
for tensor, color in zip(datasets, colors):
    points = tensor.numpy()
    fig.circle(points[:, 0], points[:, 1], color=color)
show(fig)

## Pre-training

First, we train a GMM on each dataset. These GMMs will serve as the starting point for building the GSM in the next steps.

In [38]:
def create_gmm(dataset, size, cov_type):
    """Build a ``size``-component Bayesian GMM initialised from ``dataset``.

    The normal set is created from the per-dimension empirical mean and one
    tenth of the per-dimension empirical variance of ``dataset`` (with
    ``noise_std=1``, see ``beer.NormalSet.create``), wrapped in a
    ``beer.Mixture`` and converted to double precision.
    """
    mean, var = dataset.mean(dim=0), dataset.var(dim=0)
    normalset = beer.NormalSet.create(mean, var / 10, size=size,
                                      noise_std=1, cov_type=cov_type)
    mixture = beer.Mixture.create(normalset)
    return mixture.double()

# One 2-component, full-covariance GMM per dataset.
gmms = [create_gmm(d, size=2, cov_type='full') for d in datasets]

print('Standard GMM:', '=============', sep='\n')
print(gmms[0])

Standard GMM:
Mixture(
  (modelset): NormalSet(
    (means_precisions): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
  )
  (weights): ConjugateBayesianParameter(prior=Dirichlet, posterior=Dirichlet)
)


In [39]:
def fit_gmm(gmm, dataset, epochs=100, lrate=1.):
    """Fit ``gmm`` to ``dataset`` with beer's variational Bayes optimizer.

    Args:
        gmm: model whose ``mean_field_factorization()`` is optimized.
        dataset: data passed to ``beer.evidence_lower_bound``.
        epochs: number of optimization steps.
        lrate: learning rate of the variational Bayes optimizer.

    Returns:
        list of float: the ELBO value at each step, useful to check
        convergence (callers may ignore it).
    """
    optim = beer.VariationalBayesOptimizer(gmm.mean_field_factorization(), lrate=lrate)
    elbos = []
    for _ in range(epochs):
        optim.init_step()
        elbo = beer.evidence_lower_bound(gmm, dataset)
        elbo.backward()
        optim.step()
        elbos.append(float(elbo))
    return elbos

for gmm, dataset in zip(gmms, datasets):
    fit_gmm(gmm, dataset)

In [40]:
# Draw each fitted GMM together with the (faded) points of its dataset.
fig = figure(width=400, height=400)
for model, points, color in zip(gmms, datasets, colors):
    xy = points.numpy()
    plotting.plot_gmm(fig, model, alpha=.5, color=color)
    fig.circle(xy[:, 0], xy[:, 1], color=color, alpha=.1)
show(fig)

## Generalized Subspace Model

In [41]:
# Prior over the latent space: zero mean and an all-ones second argument,
# both sized from `latent_dim` so the two stay consistent if it changes.
latent_dim = 2
latent_prior = beer.Normal.create(
    torch.zeros(latent_dim),
    torch.ones(latent_dim)
).double()

# Create a new set of GMMs (initialized from the original GMMs) whose
# parameters are modeled by the subspace: every Bayesian parameter is
# replaced by a SubspaceBayesianParameter built with `latent_prior`.
subspace_gmms = copy.deepcopy(gmms)
for gmm in subspace_gmms:
    newparams = {
        param: beer.SubspaceBayesianParameter.from_parameter(param, latent_prior)
        for param in gmm.bayesian_parameters()
    }
    gmm.replace_parameters(newparams)

print('Subspace GMM')
print('============')
print(subspace_gmms[0])
print()

# We keep a GMM which will serve as a "template model" for the GSM.
template_gmm = copy.deepcopy(subspace_gmms[0])

# Create the final Generalized Subspace Model.
gsm = beer.GSM.create(template_gmm, latent_dim, latent_prior,
                      latent_nsamples=10, params_nsamples=10).double()
print('Generalized Subspace Model')
print('==========================')
print(gsm)
print()

# One latent posterior per dataset-specific GMM.
latent_posts = gsm.new_latent_posteriors(len(gmms))
print('Latent Posteriors')
print('=================')
print(latent_posts)

Subspace GMM
Mixture(
  (modelset): NormalSet(
    (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
  (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
)

Generalized Subspace Model
GSM(
  (model): Mixture(
    (modelset): NormalSet(
      (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
    )
    (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
  (affine_transform): AffineTransform(
    (weights): ConjugateBayesianParameter(prior=NormalDiagonalCovariance, posterior=NormalDiagonalCovariance)
    (bias): ConjugateBayesianParameter(prior=NormalDiagonalCovariance, posterior=NormalDiagonalCovariance)
  )
  (latent_prior): Normal(
    (mean_precision): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
  )
)

Latent Posteriors
NormalDiagonalCovariance(
  (params): _MeanLogDiagCov(mean=Parameter containing:
  tensor([[0., 0.],
          [0

In [42]:
# Refresh the per-dataset GMMs from the (still untrained) GSM, then plot them.
gsm.update_models(subspace_gmms, latent_posts, params_nsamples=10, latent_nsamples=10)

fig = figure(width=400, height=400)
for model, points, color in zip(subspace_gmms, datasets, colors):
    xy = points.numpy()
    plotting.plot_gmm(fig, model, alpha=.5, color=color)
    fig.circle(xy[:, 0], xy[:, 1], color=color, alpha=.1)
show(fig)

In [43]:
def accumulate_stats(models, datasets):
    """Reset and re-accumulate the statistics of each model's Bayesian parameters.

    For every (model, data) pair, ``zero_stats()`` is called on each of the
    model's Bayesian parameters, then the ELBO of the model on its data is
    back-propagated with ``std_params=False`` (presumably so that only the
    Bayesian-parameter statistics are accumulated -- see beer's
    ``backward`` implementation to confirm).
    """
    for model, data in zip(models, datasets):
        for bayes_param in model.bayesian_parameters():
            bayes_param.zero_stats()
        beer.evidence_lower_bound(model, data).backward(std_params=False)

In [44]:
# First training phase: fit the latent posteriors and the GSM's standard
# (gradient-based) parameters with Adam.
epochs = 15000
lrate_cjg = 1e-1
lrate_std = 1e-1
params = list(latent_posts.parameters()) + list(gsm.parameters())
std_optim = torch.optim.Adam(params, lr=lrate_std)
# NOTE(review): the conjugate factorization passed here is empty (`[[]]`),
# so presumably no conjugate parameter is updated in this phase and
# `lrate_cjg` has no effect. The later training cell passes
# `gsm.mean_field_factorization()` instead -- confirm this is intended.
optim = beer.VariationalBayesOptimizer([[]], lrate=lrate_cjg, std_optim=std_optim)

elbos = []
for _ in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(gsm, subspace_gmms, latent_posts=latent_posts,
                                     latent_nsamples=5, params_nsamples=5)
    elbo.backward()
    optim.step()
    elbos.append(float(elbo))

In [45]:
# ELBO value recorded at each training iteration.
fig = figure(title='ELBO during training', x_axis_label='iteration',
             y_axis_label='ELBO')
fig.line(range(len(elbos)), elbos)
show(fig)

In [46]:
# Refresh the per-dataset GMMs from the trained GSM.
gsm.update_models(subspace_gmms, latent_posts, params_nsamples=10, latent_nsamples=10)

# Left: the independently trained GMMs; right: the GMMs generated by the
# GSM. Both panels share their axes so the fits are directly comparable.
fig1 = figure(title='Standard GMM')
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range)
for panel, models in ((fig1, gmms), (fig2, subspace_gmms)):
    for model, points, color in zip(models, datasets, colors):
        xy = points.numpy()
        plotting.plot_gmm(panel, model, alpha=.5, color=color)
        panel.circle(xy[:, 0], xy[:, 1], color=color, alpha=.1)

show(gridplot([[fig1, fig2]]))

In [76]:
%%prun -s tottime 
# Second training phase, profiled with %%prun (sorted by internal time).
# Unlike the previous training cell, the optimizer here receives the GSM's
# mean-field factorization, and on every epoch the per-dataset GMMs are
# refreshed from the GSM (`update_models`) and their statistics are
# re-accumulated (`accumulate_stats`) before the ELBO step.
# NOTE(review): the statement order inside the loop looks deliberate and
# depends on beer's optimizer semantics (e.g. what `init_step` resets) --
# confirm before reordering.
epochs = 1_000
lrate_cjg = 1e-1
lrate_std = 1e-1
# Standard (gradient-based) parameters: latent posteriors + GSM.
params = list(latent_posts.parameters()) + list(gsm.parameters())
std_optim = torch.optim.Adam(params, lr=lrate_std)
optim = beer.VariationalBayesOptimizer(gsm.mean_field_factorization(), lrate=lrate_cjg, std_optim=std_optim)


elbos = []
for epoch in range(1, epochs + 1): 
    # Regenerate the per-dataset models from the current GSM state.
    gsm.update_models(subspace_gmms, latent_posts, latent_nsamples=10, params_nsamples=10)
    accumulate_stats(subspace_gmms, datasets)
    optim.init_step()
    elbo = beer.evidence_lower_bound(gsm, subspace_gmms, latent_posts=latent_posts, 
                                     latent_nsamples=5, params_nsamples=5)
    elbo.backward()
    optim.step()
    elbos.append(float(elbo))

 

         7635540 function calls (7512422 primitive calls) in 33.539 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    7.117    0.007    7.117    0.007 {method 'run_backward' of 'torch._C._EngineBase' objects}
    84000    2.785    0.000    2.785    0.000 {built-in method matmul}
    20000    2.624    0.000    7.362    0.000 normalwishart.py:51(pdfvectors_from_rvectors)
   563000    1.873    0.000    2.256    0.000 module.py:537(__setattr__)
   239000    1.545    0.000    1.545    0.000 {method 'reshape' of 'torch._C._TensorBase' objects}
    30000    1.339    0.000    1.339    0.000 {built-in method logsumexp}
   128000    0.938    0.000    0.938    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
    60000    0.937    0.000    0.937    0.000 {built-in method cat}
    43000    0.835    0.000    0.835    0.000 {method 'mean' of 'torch._C._TensorBase' objects}
    20000    0.767    0.000   12.492    0.001 gsm

In [77]:
# ELBO value recorded at each training iteration.
fig = figure(title='ELBO during training', x_axis_label='iteration',
             y_axis_label='ELBO')
fig.line(range(len(elbos)), elbos)
show(fig)

In [79]:
# Refresh the per-dataset GMMs from the trained GSM.
gsm.update_models(subspace_gmms, latent_posts, params_nsamples=10, latent_nsamples=10)

# Independently trained GMMs.
fig1 = figure(title='Standard GMM')
for model, points, color in zip(gmms, datasets, colors):
    xy = points.numpy()
    plotting.plot_gmm(fig1, model, alpha=.5, color=color)
    fig1.circle(xy[:, 0], xy[:, 1], color=color, alpha=.5)

# GMMs generated by the GSM, on the same axes as the first panel.
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range)
for model, points, color in zip(subspace_gmms, datasets, colors):
    xy = points.numpy()
    plotting.plot_gmm(fig2, model, alpha=.5, color=color)
    fig2.circle(xy[:, 0], xy[:, 1], color=color, alpha=.5)

# Latent space: the GSM's latent prior (pink) and one posterior per dataset.
fig3 = figure(title='Latent space')
prior = gsm.latent_prior
plotting.plot_normal(fig3, prior.mean.numpy(), prior.cov.numpy(), alpha=.5, color='pink')
post_means = latent_posts.params.mean
post_diag_covs = latent_posts.params.diag_cov
for post_mean, post_diag_cov, color in zip(post_means, post_diag_covs, colors):
    # Each posterior has a diagonal covariance; expand it to a full matrix.
    plotting.plot_normal(fig3, post_mean.detach().numpy(),
                         post_diag_cov.detach().diag().numpy(),
                         alpha=.5, color=color)

show(gridplot([[fig1, fig2], [None, fig3]]))