# Generalized Subspace Model

In [25]:
# Add "beer" to the PYTHONPATH
import sys
sys.path.insert(0, '../')

import math
import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
from bokeh.palettes import Category10 as palette
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Synthetic Data

In [2]:
def generate_data(global_mean, angle, size, weight, rng=None):
    """Sample a synthetic 2-D, two-component Gaussian mixture.

    Both components start from the same axis-aligned shape (standard
    deviations 0.5 along x and 2 along y). The first component is rotated
    by `rotation`, the second by its transpose (the opposite rotation).
    The components are centred 3 units above and below `global_mean`.

    Args:
        global_mean: array of shape (2,), midpoint between the two components.
        angle: rotation angle in radians.
        size: total number of points to sample.
        weight: fraction of the points drawn from the first component
            (expected in [0, 1]).
        rng: optional random source exposing `standard_normal` and `shuffle`
            (e.g. `np.random.default_rng(seed)`). Defaults to numpy's global
            random state.

    Returns:
        data: array of shape (size, 2) with the shuffled samples.
        (mean1, mean2): means of the two components.
        (cov1, cov2): covariance matrices of the two components.
        (weight, 1 - weight): nominal mixing weights.
    """
    if rng is None:
        rng = np.random
    rotation = np.array([
        [math.cos(angle), -math.sin(angle)],
        [math.sin(angle), math.cos(angle)]
    ])
    scale = np.array([.5, 2])
    init_cov = np.diag(scale**2)
    # Row samples z with cov C: z @ R has cov R^T C R, z @ R^T has cov R C R^T.
    cov1 = rotation.T @ init_cov @ rotation
    cov2 = rotation @ init_cov @ rotation.T
    offset = np.array([0, 3])
    mean1 = global_mean + offset
    mean2 = global_mean - offset
    # Derive the second count from the first so that the counts always add
    # up to `size`. Truncating both products independently can drop a point.
    n1 = int(size * weight)
    n2 = size - n1
    data1 = (scale * rng.standard_normal((n1, 2))) @ rotation + mean1
    data2 = (scale * rng.standard_normal((n2, 2))) @ rotation.T + mean2
    data = np.vstack([data1, data2])
    rng.shuffle(data)
    return data, (mean1, mean2), (cov1, cov2), (weight, 1 - weight)

# Seed numpy (data generation) and torch (used later during training) so
# that "Restart & Run All" reproduces the same data and results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build `nmodels` datasets whose centres sweep the x axis from -boundary to
# +boundary. Along the sweep, the rotation angle goes from -pi/2 to +pi/2 and
# the weight of the first component goes from 0.1 to 0.9.
datasets = []
means = []
covs = []
weights = []
start_angle = -.5 * math.pi
boundary = 50
nmodels = 10
for h in np.linspace(-boundary, boundary, nmodels):
    mean = np.array([1., 0]) * h
    # Position of the dataset along the sweep, mapped to [0, 1].
    ratio = (h + boundary) / (2 * boundary)
    angle = start_angle + ratio * math.pi
    w_ratio = .1 + .8 * ratio
    data, m_means, m_covs, m_weights = generate_data(mean, angle, size=200, weight=w_ratio)
    datasets.append(data)
    means.append(m_means)
    covs.append(m_covs)
    weights.append(m_weights)
data = np.vstack(datasets)

# Convert the data to pytorch tensors to work with beer.
datasets = [torch.from_numpy(dataset) for dataset in datasets]

# One colour per dataset, cycling through the 10-colour palette so that the
# zips below never silently drop datasets, whatever the value of `nmodels`.
colors = [palette[10][i % 10] for i in range(nmodels)]

# Ground-truth mixtures. Each component's opacity is proportional to its
# mixing weight, and its mean is marked with a cross.
fig = figure(title='True model')
for color, m_means, m_covs, m_weights in zip(colors, means, covs, weights):
    for k in range(2):
        plotting.plot_normal(fig, m_means[k], m_covs[k], alpha=.5 * m_weights[k], color=color)
    for k in range(2):
        fig.cross(m_means[k][0], m_means[k][1], color=color, size=7, line_width=2)
show(fig)

## Generalized Subspace Model

### Creating the GSM

The GSM is composed of four parts: a prior over the latent space; an affine transformation of the latent embeddings; a generic subspace model, which specifies how the projected embeddings are turned into a concrete model; and the instances of the generic subspace model, each paired with its own latent posterior distribution.

In [3]:
obs_dim, latent_dim = 2, 2

# Covariance type of the Normal components of the subspace GMMs
# (one of: full / diagonal / isotropic).
cov_type = 'full'

# Prior strength shared by the latent prior and the GSM.
prior_strength = 1e-3

# Prior over the latent space (zero mean, unit variances).
latent_prior = beer.Normal.create(
    torch.zeros(latent_dim),
    torch.ones(latent_dim),
    prior_strength=prior_strength,
).double()

# Generic data model: a mixture over a set of 2 Normal components (the SGMM).
modelset = beer.NormalSet.create(
    mean=torch.zeros(obs_dim),
    cov=torch.ones(obs_dim),
    size=2,
    cov_type=cov_type,
)
sgmm = beer.Mixture.create(modelset).double()

# Every Bayesian parameter of the generic GMM is handed over to the subspace.
newparams = {}
for param in sgmm.bayesian_parameters():
    newparams[param] = beer.SubspaceBayesianParameter.from_parameter(param, latent_prior)
sgmm.replace_parameters(newparams)

# The Generalized Subspace Model tying the generic SGMM to the latent space.
gsm = beer.GSM.create(
    sgmm, latent_dim, latent_prior,
    prior_strength=prior_strength,
).double()

# One concrete SGMM per dataset, together with the latent posteriors.
# NOTE(review): 'diagonal' presumably sets the covariance type of the latent
# posteriors (later cells read `latent_posts.params.diag_cov`), not of the
# GMM components — confirm.
sgmms, latent_posts = gsm.new_models(len(datasets), cov_type='diagonal')

# Print each part of the model under an underlined heading.
summary_sections = [
    ('Latent prior', '============', latent_prior),
    ('Subspace GMM (generic model)', '============================', sgmm),
    ('Generalized Subspace Model', '==========================', gsm),
]
for heading, underline, component in summary_sections:
    print(heading)
    print(underline)
    print(component)
    print()

# Only the first and last concrete instances are shown.
print('Subspace GMMs (concrete instances)')
print('==================================')
print('(1) -', sgmms[0])
print()
print('...')
print()
print(f'({len(datasets)}) -', sgmms[-1])
print()

Latent prior
Normal(
  (mean_precision): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
)

Subspace GMM (generic model)
Mixture(
  (modelset): NormalSet(
    (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
  (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
)

Generalized Subspace Model
GSM(
  (model): Mixture(
    (modelset): NormalSet(
      (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
    )
    (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
  (transform): AffineTransform(
    (weights): BayesianParameter(prior=NormalDiagonalCovariance, posterior=NormalDiagonalCovariance)
    (bias): BayesianParameter(prior=NormalDiagonalCovariance, posterior=NormalDiagonalCovariance)
  )
  (latent_prior): Normal(
    (mean_precision): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
  )
)

Subspace GMMs (concrete

### Pre-training

Before starting the training, we need to initialize the subspace. To do so, we first fit a Normal distribution to each dataset. We then use its statistics as the initial statistics for all the Normal components of the corresponding SGMM.

In [4]:
def create_normal(dataset, cov_type):
    """Build a double-precision beer Normal initialized from `dataset`.

    The per-dimension empirical mean and variance (taken over dim 0) are
    passed to `beer.Normal.create` together with `cov_type`.
    """
    return beer.Normal.create(
        dataset.mean(dim=0),
        dataset.var(dim=0),
        cov_type=cov_type,
    ).double()

def fit_normal(normal, dataset, epochs=100):
    """Fit `normal` to `dataset` with `epochs` VB conjugate updates.

    Each epoch clears the accumulated state (`init_step`), back-propagates
    the ELBO of the data, and applies an update with learning rate 1.
    """
    optimizer = beer.VBConjugateOptimizer(normal.mean_field_factorization(), lrate=1.)
    for _ in range(epochs):
        optimizer.init_step()
        beer.evidence_lower_bound(normal, dataset).backward()
        optimizer.step()
        
# One Normal per dataset, initialized from the data moments, then refined.
normals = [create_normal(dataset, cov_type=cov_type) for dataset in datasets]
for normal, dataset in zip(normals, datasets):
    fit_normal(normal, dataset)

# Overlay every fitted Normal on its own (faded) data points.
fig = figure(width=400, height=400)
for normal, dataset, color in zip(normals, datasets, colors):
    points = dataset.numpy()
    plotting.plot_normal(fig, normal.mean.numpy(), normal.cov.numpy(), alpha=.5, color=color)
    fig.circle(points[:, 0], points[:, 1], color=color, alpha=.1)

show(fig)

In [5]:
# Initial sufficient statistics of the mixture weights, one row per dataset.
# Every column first gets an equal share of the dataset's point count; the
# last column is then overwritten with the total count.
# NOTE(review): this presumably matches the stats layout beer expects for
# the (subspace) mixture weights — confirm against beer's implementation.
ncomp = len(sgmm.modelset)
counts = torch.tensor([float(len(dataset)) for dataset in datasets],
                      dtype=torch.float64)
weights_stats = counts[:, None].repeat(1, ncomp) / ncomp
weights_stats[:, -1] = counts

# Initial sufficient statistics for the components of the GMM: every
# component of a dataset's SGMM starts from the statistics of the Normal
# pre-trained on that dataset.
normals_stats = [normal.mean_precision.stats.repeat(ncomp, 1)
                 for normal in normals]
for gmm, gmm_weights_stats, gmm_normals_stats in zip(sgmms, weights_stats, normals_stats):
    gmm.weights.stats = gmm_weights_stats
    gmm.modelset.means_precisions.stats = gmm_normals_stats

# NOTE: we initialize the stats of all the parameters, whether they are
# included in the subspace or not. For parameters that are not included in
# the subspace, this initialization is discarded during the training
# ("optim.init_step()" clears the stats).

In [6]:
epochs = 15_000

# Conjugate Bayesian parameters of the GSM get VB conjugate updates.
cjg_params = gsm.conjugate_bayesian_parameters(keepgroups=True)
cjg_optim = beer.VBConjugateOptimizer(cjg_params, lrate=1.)

# The torch parameters of the latent posteriors and of the GSM get Adam.
std_params = [*latent_posts.parameters(), *gsm.parameters()]
std_optim = torch.optim.Adam(std_params, lr=1e-1)

# Combined optimizer driving both kinds of updates at every step.
optim = beer.VBOptimizer(cjg_optim, std_optim)

elbos = []
for epoch in range(1, epochs + 1):
    optim.init_step()
    elbo = beer.evidence_lower_bound(
        gsm, sgmms,
        latent_posts=latent_posts,
        latent_nsamples=5,
        params_nsamples=5,
    )
    elbo.backward()
    optim.step()
    elbos.append(float(elbo))

In [7]:
# ELBO after each pre-training epoch (epochs are numbered from 1).
fig = figure(title='Pre-training ELBO', x_axis_label='epoch', y_axis_label='ELBO')
fig.line(range(1, len(elbos) + 1), elbos)
show(fig)

In [8]:
# Ground truth: each component's opacity is proportional to its mixing weight.
fig1 = figure(title='True model')
for color, dataset, m_means, m_covs, m_weights in zip(colors, datasets, means, covs, weights):
    points = dataset.numpy()
    for k in range(2):
        plotting.plot_normal(fig1, m_means[k], m_covs[k], alpha=.7 * m_weights[k], color=color)
    fig1.circle(points[:, 0], points[:, 1], alpha=.5, color=color)

# Subspace GMMs after pre-training, on the same axes as the ground truth.
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range)
for gmm, dataset, color in zip(sgmms, datasets, colors):
    points = dataset.numpy()
    plotting.plot_gmm(fig2, gmm, alpha=.7, color=color)
    fig2.circle(points[:, 0], points[:, 1], color=color, alpha=.5)

# Latent space: the prior (pink) and one diagonal-covariance posterior per dataset.
fig3 = figure(title='Latent space')
plotting.plot_normal(fig3, gsm.latent_prior.mean.numpy(), gsm.latent_prior.cov.numpy(),
                     alpha=.5, color='pink')
post_params = zip(latent_posts.params.mean, latent_posts.params.diag_cov, colors)
for post_mean, post_diag_cov, color in post_params:
    plotting.plot_normal(fig3, post_mean.detach().numpy(),
                         post_diag_cov.diag().detach().numpy(),
                         alpha=.5, color=color)

show(gridplot([[fig1, fig2], [None, fig3]]))

### Training

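Now that the GSM is initialized, we start the "actual" training. Every `stats_update_rate` epochs (starting with the first), we re-accumulate the statistics of the subspace parameters from the data, and we then keep updating the GSM.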

In [9]:
# Number of training epochs, and the period (in epochs) at which the
# statistics of the subspace parameters are re-accumulated.
epochs = 5_000
stats_update_rate = 100

def accumulate_stats(models, datasets, optims):
    """Run one VB step per (model, dataset, optimizer) triple.

    Per the notebook's intent, this accumulates the statistics of the
    subspace parameters and updates the parameters that are not part of the
    subspace. The backward pass is called with `std_params=False`, which
    presumably skips gradients of the standard (non-conjugate) parameters —
    confirm against beer.
    """
    for model, features, model_optim in zip(models, datasets, optims):
        model_optim.init_step()
        beer.evidence_lower_bound(model, features).backward(std_params=False)
        model_optim.step()
        
def pfilter(param):
    """Return True for parameters that are NOT handled by the subspace."""
    return not isinstance(param, beer.SubspaceBayesianParameter)

# Prepare one optimizer for each SGMM. Each optimizer handles only the
# conjugate parameters that are not included in the subspace. The filter is
# loop-invariant, so it is defined once above instead of on every iteration.
sgmms_optims = []
for gmm in sgmms:
    params = gmm.bayesian_parameters(
        paramtype=beer.ConjugateBayesianParameter,
        paramfilter=pfilter,
        keepgroups=True,
    )
    sgmms_optims.append(beer.VBConjugateOptimizer(params, lrate=1.))
    

elbos = []
for epoch in range(1, epochs + 1):
    # Refresh the subspace statistics on epochs 1, 1 + rate, 1 + 2 * rate, ...
    refresh_stats = (epoch - 1) % stats_update_rate == 0
    if refresh_stats:
        accumulate_stats(sgmms, datasets, sgmms_optims)

    optim.init_step()
    elbo = beer.evidence_lower_bound(
        gsm, sgmms,
        latent_posts=latent_posts,
        latent_nsamples=10,
        params_nsamples=10,
    )
    elbo.backward()
    optim.step()
    elbos.append(float(elbo))

In [10]:
# ELBO after each training epoch (epochs are numbered from 1).
fig = figure(title='Training ELBO', x_axis_label='epoch', y_axis_label='ELBO')
fig.line(range(1, len(elbos) + 1), elbos)
show(fig)

In [24]:
# Ground truth: each component's opacity is proportional to its mixing weight.
# NOTE(review): fig1 is not displayed below; fig2 only borrows its axis
# ranges. Add it back to the grid to compare against the ground truth.
fig1 = figure(title='True model', width=600, height=600)
for color, dataset, m_means, m_covs, m_weights in zip(colors, datasets, means, covs, weights):
    dataset = dataset.numpy()
    plotting.plot_normal(fig1, m_means[0], m_covs[0], alpha=.8 * m_weights[0], color=color)
    plotting.plot_normal(fig1, m_means[1], m_covs[1], alpha=.8 * m_weights[1], color=color)
    fig1.circle(dataset[:, 0], dataset[:, 1], alpha=.5, color=color)

# Subspace GMMs after training.
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range, width=600, height=600)
for gmm, dataset, color in zip(sgmms, datasets, colors):
    dataset = dataset.numpy()
    plotting.plot_gmm(fig2, gmm, alpha=.8, color=color)
    fig2.circle(dataset[:, 0], dataset[:, 1], color=color, alpha=.5)

# Latent space: the prior (pink) and one diagonal-covariance posterior per dataset.
fig3 = figure(title='Latent space', width=600, height=600)
mean, cov = gsm.latent_prior.mean.numpy(), gsm.latent_prior.cov.numpy()
plotting.plot_normal(fig3, mean, cov, alpha=.5, color='pink')
for mean, cov, color in zip(latent_posts.params.mean, latent_posts.params.diag_cov, colors):
    mean = mean.detach().numpy()
    cov = cov.diag().detach().numpy()
    plotting.plot_normal(fig3, mean, cov, alpha=.5, color=color)

show(gridplot([[fig2, fig3]]))