# Hierarchical Generalized Subspace Model

In [1]:
# Add the parent directory to the module search path (sys.path),
# presumably where the "beer" package lives, so it can be imported.
import os
import sys
sys.path.insert(0, '../')

import math
import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
from bokeh.palettes import Category10 as palette
# Render bokeh figures inline in the notebook.
output_notebook()

# Convenience functions for plotting (plot_normal, plot_gmm, ...).
import plotting

# IPython magics: automatically reload modified modules (e.g. beer,
# plotting) before executing code.
%load_ext autoreload
%autoreload 2

In [2]:
from beer.models.gsm import HierarchicalAffineTransform, HierarchicalGSM, AffineTransform
# torch.set_num_threads(4)

## Synthetic Data

In [3]:
def generate_data(global_mean, angle, size, weight):
    """Sample a two-component, two-dimensional Gaussian mixture.

    Both components use the axis-aligned scales (0.5, 2); the first one is
    rotated by ``angle`` and the second one by ``-angle``. Their means are
    offset by +3 / -3 along the second axis from ``global_mean``.

    Args:
        global_mean: array-like of shape (2,), centre of the mixture.
        angle: rotation angle in radians.
        size: total number of samples to draw.
        weight: fraction of the samples drawn from the first component.

    Returns:
        A tuple ``(data, (mean1, mean2), (cov1, cov2), (weight, 1 - weight))``
        where ``data`` is a shuffled ``(size, 2)`` numpy array.
    """
    rotation = np.array([
        [math.cos(angle), -math.sin(angle)],
        [math.sin(angle), math.cos(angle)],
    ])
    scale = np.array([.5, 2])
    init_cov = np.diag(scale**2)
    # Row vectors z ~ N(0, D) mapped to z @ R have covariance R^T D R, and
    # z @ R^T has covariance R D R^T: these match the samples drawn below.
    cov1 = rotation.T @ init_cov @ rotation
    cov2 = rotation @ init_cov @ rotation.T
    mean1 = global_mean + np.array([0, 3])
    mean2 = global_mean - np.array([0, 3])
    # Derive the second count from the first so both always add up to
    # `size`; truncating size * (1 - weight) on its own can drop a sample
    # (e.g. 1 - 0.8 == 0.19999999999999996, so int(200 * (1 - 0.8)) == 39).
    size1 = int(size * weight)
    size2 = size - size1
    data1 = (scale * np.random.randn(size1, 2)) @ rotation + mean1
    data2 = (scale * np.random.randn(size2, 2)) @ rotation.T + mean2
    data = np.vstack([data1, data2])
    np.random.shuffle(data)
    return data, (mean1, mean2), (cov1, cov2), (weight, 1 - weight)


def generate_dataset(num_sets=1, nmodels=10, size_data=200):
    """Generate ``num_sets`` groups ("languages") of synthetic 2-D datasets.

    Each group holds ``nmodels`` datasets whose centres are spread along the
    first axis between ``-boundary`` and ``boundary``. The rotation angle and
    the first-component weight of each dataset vary linearly with its
    position along that axis. Successive groups use a larger angle range, a
    larger constant mean shift and a wider boundary.

    Args:
        num_sets: number of groups to generate.
        nmodels: number of datasets per group.
        size_data: number of samples per dataset (forwarded to generate_data).

    Returns:
        A list with one ``[means, covs, weights, datasets]`` entry per group.
        ``datasets`` is a list of torch tensors; ``means``, ``covs`` and
        ``weights`` hold the per-dataset values returned by generate_data.
    """
    angle_shifts = np.linspace(0, np.pi, num_sets)
    mean_shifts = [0.1 * n for n in range(num_sets)]
    boundary_shifts = [5 * n for n in range(num_sets)]
    start_angle = -.5 * math.pi
    full_dataset = []
    for ni in range(num_sets):
        datasets, means, covs, weights = [], [], [], []
        boundary = 50 + boundary_shifts[ni]
        for h in np.linspace(-boundary, boundary, nmodels):
            mean = np.array([1., 0]) * h + mean_shifts[ni]
            # Relative position of this dataset along the axis, in [0, 1].
            ratio = (h + boundary) / (2 * boundary)
            angle = start_angle + ratio * (math.pi + angle_shifts[ni])
            w_ratio = .1 + .8 * ratio
            data, m_means, m_covs, m_weights = generate_data(
                mean, angle, size=size_data, weight=w_ratio)
            datasets.append(data)
            means.append(m_means)
            covs.append(m_covs)
            weights.append(m_weights)

        # Convert the data to pytorch tensors to work with beer.
        datasets = [torch.from_numpy(data) for data in datasets]
        full_dataset.append([means, covs, weights, datasets])
    return full_dataset

num_langs = 4
full_data = generate_dataset(num_langs)
# For each language, the list of its per-dataset torch tensors.
train_data = [d[-1] for d in full_data]
# Category10 with 10 entries, repeated so that up to 20 datasets get a color.
colors = palette[10] * 2

# Plot the true generating components (alpha scaled by their weight) and
# mark their means, for every dataset of every language.
fig = figure()
for means, covs, weights, datasets in full_data:
    for color, m_means, m_covs, m_weights in zip(colors, means, covs, weights):
        plotting.plot_normal(fig, m_means[0], m_covs[0], alpha=.5 * m_weights[0], color=color)
        plotting.plot_normal(fig, m_means[1], m_covs[1], alpha=.5 * m_weights[1], color=color)
        fig.cross(m_means[0][0], m_means[0][1], color=color, size=7, line_width=2)
        fig.cross(m_means[1][0], m_means[1][1], color=color, size=7, line_width=2)
show(fig)

## Hierarchical Generalized Subspace Model

### Creating the GSM

The GSM is composed of four parts:

- a latent prior;
- an affine transformation;
- a generic subspace model, which specifies how the projection of an embedding is turned into a concrete model;
- the instances of that generic subspace model, each paired with its own latent posterior distribution.

In the HGSM, the affine transformation of each GSM is itself generated by a parent GSM. The parent GSM has its own parameters and is shared across (potentially) multiple child GSMs.

In [4]:
obs_dim = 2  # Dimension of the observations
lang_latent_dim = 2  # Dimension of the latent space of the child GSMs
latent_dim = 2  # Dimension of the latent space of the parent GSM

# Type of covariance for the Subspace GMMs.
cov_type = 'full'  # full/diagonal/isotropic

# Prior over the latent space of the parent GSM.
latent_prior = beer.Normal.create(
    torch.zeros(latent_dim),
    torch.ones(latent_dim),
    prior_strength=1e-3
).double()

# Priors over the latent spaces of the child GSMs.
# NOTE(review): num_langs + 1 priors (and NormalSets below) are created but
# only the first num_langs are used.
language_priors = [beer.Normal.create(
    torch.zeros(lang_latent_dim),
    torch.ones(lang_latent_dim),
    prior_strength=1e-3
).double() for _ in range(num_langs + 1)]

# Data model (SGMM): one generic 2-component mixture per language.
modelsets = [beer.NormalSet.create(
    mean=torch.zeros(obs_dim), cov=torch.ones(obs_dim),
    size=2,
    cov_type=cov_type
) for _ in range(num_langs + 1)]
sgmm_list = [beer.Mixture.create(modelsets[i]).double() for i in range(num_langs)]

# We specify which parameters will be handled by the
# subspace in the GMM.
for i, sgmm in enumerate(sgmm_list):
    newparams = {
        param: beer.SubspaceBayesianParameter.from_parameter(param, language_priors[i])
        for param in sgmm.bayesian_parameters()
    }
    sgmm.replace_parameters(newparams)


# Create the Generalized Subspace Models (one per language).
lang_gsms = [beer.GSM.create(sg, lang_latent_dim, lang_p, prior_strength=1e-3).double()
             for sg, lang_p in zip(sgmm_list, language_priors)]

# Create the parent GSM's transform. Its output size is
# out_dim * (in_dim + 1) of a child transform, presumably the flattened
# weights plus bias of that child transform.
univ_affine_transform = AffineTransform.create(
    latent_dim,
    lang_gsms[0].transform.out_dim * (lang_gsms[0].transform.in_dim + 1),
    prior_strength=1e-3
)

# Create each child GSM's transform from the parent GSM (all children share
# `latent_prior` and `univ_affine_transform`).
pseudo_transforms = [HierarchicalAffineTransform.create(latent_prior, lang_latent_dim,
                                                        lang_gsms[0].transform.out_dim,
                                                        univ_affine_transform, cov_type='diagonal').double()
                     for _ in lang_gsms]
# Create the root GSM object which will be used to link all GSMs together in training
root_gsm = HierarchicalGSM(univ_affine_transform, latent_prior)

# Replace the child GSM transforms with the generated transforms
for pseudo_transform, lang_gsm in zip(pseudo_transforms, lang_gsms):
    lang_gsm.transform = pseudo_transform

# Create the instance of SGMM for each dataset
lang_sgmms = [gsm.new_models(len(train_data_single), cov_type='diagonal')
              for gsm, train_data_single in zip(lang_gsms, train_data)]

# new_models returns (models, latent posteriors): split them.
lang_latent_posts = [l[1] for l in lang_sgmms]
lang_sgmms = [l[0] for l in lang_sgmms]

print('Latent prior')
print('============')
print(latent_prior)
print()

print('Child GSM latent prior')
print('============')
print(language_priors[0])
print()

print('Subspace GMM (generic model)')
print('============================')
print(sgmm_list[0])
print()

print('Generalized Subspace Model')
print('==========================')
print(lang_gsms[0])
print()

print('Subspace GMMs (concrete instances)')
print('==================================')
print('(1) -', lang_sgmms[0][0])
print()
print('...')
print()
# Count language 0's instances from its own data, not from a variable
# leaked by an earlier cell's loop.
print(f'({len(train_data[0])}) -', lang_sgmms[0][-1])
print()

Latent prior
Normal(
  (mean_precision): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
)

Child GSM latent prior
Normal(
  (mean_precision): ConjugateBayesianParameter(prior=NormalWishart, posterior=NormalWishart)
)

Subspace GMM (generic model)
Mixture(
  (modelset): NormalSet(
    (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
  (categorical): Categorical(
    (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
  )
)

Generalized Subspace Model
GSM(
  (model): Mixture(
    (modelset): NormalSet(
      (means_precisions): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
    )
    (categorical): Categorical(
      (weights): SubspaceBayesianParameter(prior=Normal, posterior=<unspecified>)
    )
  )
  (transform): HierarchicalAffineTransform(
    (root_transform): AffineTransform(
      (weights): BayesianParameter(prior=NormalDiagonalCovariance, posterior=NormalDiagonalCovarianc

### Pre-training

Before starting the training, we need to initialize the subspace. To do so, we first train a Normal distribution on each dataset. We then use its statistics as the initial statistics for all the Normal components of the SGMM associated with that dataset.

In [5]:
def create_normal(dataset, cov_type):
    """Build a double-precision Normal seeded with the moments of ``dataset``.

    The mean and variance are computed over dim 0 of ``dataset`` and passed
    to ``beer.Normal.create`` together with ``cov_type``.
    """
    empirical_mean, empirical_var = dataset.mean(dim=0), dataset.var(dim=0)
    normal = beer.Normal.create(empirical_mean, empirical_var, cov_type=cov_type)
    return normal.double()

def fit_normal(normal, dataset, epochs=1):
    """Fit ``normal`` to ``dataset`` with ``epochs`` conjugate VB updates.

    Every epoch resets the optimizer step, back-propagates the ELBO of the
    model on the data and applies one update with learning rate 1.
    """
    optimizer = beer.VBConjugateOptimizer(normal.mean_field_factorization(), lrate=1.)
    for _ in range(epochs):
        optimizer.init_step()
        beer.evidence_lower_bound(normal, dataset).backward()
        optimizer.step()

# Fit one Normal per dataset, grouped by language; their statistics are
# used further below to initialize the SGMMs.
lang_normals = []
for i in range(num_langs):
    # Initialize each Normal from its dataset's empirical mean/variance...
    normals = [create_normal(dataset, cov_type=cov_type) for dataset in train_data[i]]
    # ...then refine it with fit_normal (one epoch by default).
    for normal, dataset in zip(normals, train_data[i]):
        fit_normal(normal, dataset)
    lang_normals.append(normals)

def _plot_fitted_normals(fig, normals, datasets):
    """Draw each fitted Normal and a scatter of its dataset on ``fig``."""
    for normal, dataset, color in zip(normals, datasets, colors):
        dataset = dataset.numpy()
        plotting.plot_normal(fig, normal.mean.numpy(), normal.cov.numpy(),
                             alpha=.5, color=color)
        fig.circle(dataset[:, 0], dataset[:, 1], color=color, alpha=.1)


# One figure per language...
figs = []
for ind, (normals, train_data_single) in enumerate(zip(lang_normals, train_data), start=1):
    fig = figure(width=400, height=400, title=f'Lang{ind}')
    _plot_fitted_normals(fig, normals, train_data_single)
    figs.append(fig)

# ...plus one figure overlaying all of them.
fig = figure(width=400, height=400, title='All languages')
for normals, train_data_single in zip(lang_normals, train_data):
    _plot_fitted_normals(fig, normals, train_data_single)
figs.append(fig)

# Arrange the figures in rows of at most `figs_per_line`. Stepping the slice
# start avoids the empty trailing row produced by `1 + len(figs) //
# figs_per_line` rows when len(figs) is a multiple of figs_per_line.
figs_per_line = 2
figs = [figs[start:start + figs_per_line] for start in range(0, len(figs), figs_per_line)]
show(gridplot(figs))

In [6]:
# Prepare the initial weights sufficient statistics.


for sgmms, train_data_single, normals in zip(lang_sgmms, train_data, lang_normals):
    ncomp = len(sgmms[0].modelset)
    # One row per dataset: its sample count is spread uniformly over the
    # components, then the last column is overwritten with the full count.
    # NOTE(review): presumably the stats layout expected by the subspace
    # Categorical weights (last entry = total count) - confirm in beer.
    weights_stats = torch.zeros(len(train_data_single), ncomp).double()
    counts = torch.tensor([float(len(dataset)) for dataset in train_data_single]).double()
    weights_stats[:] = counts[:, None] / ncomp
    weights_stats[:, -1] = counts

    # Prepare the initial sufficient statistics for the
    # components of the GMM: every component of a dataset's GMM starts
    # from the statistics of the Normal fitted on that dataset.
    normals_stats = [normal.mean_precision.stats.repeat(ncomp, 1)
                     for normal in normals]
    for i, gmm in enumerate(sgmms):
        gmm.categorical.weights.stats = weights_stats[i]
        gmm.modelset.means_precisions.stats = normals_stats[i]

# NOTE: we initialize the stats of all the parameters
# whether they are included in the subspace or not.
# For parameters that are not included in the subspace,
# this initialization will be discarded during
# the training ("optim.init_step()" clear the stats).

In [7]:
epochs = 20_000

# Conjugate parameters of every child GSM, updated by the VB optimizer.
params = [group for gsm in lang_gsms
          for group in gsm.conjugate_bayesian_parameters(keepgroups=True)]
cjg_optim = beer.VBConjugateOptimizer(params, lrate=1.)

# Remaining parameters (latent posteriors and GSMs), optimized with Adam.
# Every child transform wraps the same `univ_affine_transform`, so any
# parameters it owns are yielded once per child GSM: deduplicate them
# (by identity, keeping the first occurrence) so that Adam does not step the
# same tensor several times per iteration.
params = list(dict.fromkeys(
    param
    for latent_posts, gsm in zip(lang_latent_posts, lang_gsms)
    for module in (latent_posts, gsm)
    for param in module.parameters()
))
std_optim = torch.optim.Adam(params, lr=5e-2)
optim = beer.VBOptimizer(cjg_optim, std_optim)
elbos = []

In [8]:
# Initialization stage: train the HGSM with the SGMM statistics fixed.
for i in range(1, epochs + 1):
    optim.init_step()
    elbo = beer.evidence_lower_bound(
        root_gsm,
        list(zip(lang_gsms, lang_sgmms)),
        univ_latent_nsamples=5,
        latent_posts=lang_latent_posts,
        latent_nsamples=5,
        params_nsamples=5,
    )
    elbo.backward()
    optim.step()
    elbos.append(float(elbo))

In [9]:
figs_per_line = 2
fig = figure(title='ELBO')
fig.line(range(len(elbos)), elbos)
figs = [fig]

# Rows of at most `figs_per_line` figures, without an empty trailing row
# when len(figs) is a multiple of figs_per_line.
figs = [figs[start:start + figs_per_line] for start in range(0, len(figs), figs_per_line)]
show(gridplot(figs))

In [10]:
fig1 = figure(title='True model', x_range=(-100, 100), y_range=(-10, 10))
for means, covs, weights, datasets in full_data:
    for color, dataset, m_means, m_covs, m_weights in zip(colors, datasets, means, covs, weights):
        dataset = dataset.numpy()
        plotting.plot_normal(fig1, m_means[0], m_covs[0], alpha=.5 * m_weights[0], color=color)
        plotting.plot_normal(fig1, m_means[1], m_covs[1], alpha=.5 * m_weights[1], color=color)
        fig1.circle(dataset[:, 0], dataset[:, 1], alpha=.2, color=color)

# Plot each language's SGMMs against that language's own datasets (not the
# `datasets` left over from the loop above).
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range)
for sgmms, datasets in zip(lang_sgmms, train_data):
    for gmm, dataset, color in zip(sgmms, datasets, colors):
        dataset = dataset.numpy()
        plotting.plot_gmm(fig2, gmm, alpha=.7, color=color)
        fig2.circle(dataset[:, 0], dataset[:, 1], color=color, alpha=.5)

fig3 = figure(title='Unit latent space')
mean, cov = language_priors[0].mean.numpy(), language_priors[0].cov.numpy()
plotting.plot_normal(fig3, mean, cov, alpha=.5, color='pink')
for post, color in zip(lang_latent_posts, colors):
    for mean, cov in zip(post.params.mean, post.params.diag_cov):
        mean = mean.detach().numpy()
        cov = (cov.diag().detach().numpy())
        plotting.plot_normal(fig3, mean, cov, alpha=0.5, color=color)

# All child transforms were created with the same `latent_prior`; read it
# from the list rather than from a loop variable leaked by another cell.
fig4 = figure(title='Latent space')
parent_prior = pseudo_transforms[-1].latent_prior
mean, cov = parent_prior.mean.numpy(), parent_prior.cov.numpy()
plotting.plot_normal(fig4, mean, cov, alpha=.5, color='pink')

for gsm, color in zip(lang_gsms, colors):
    mean, cov = gsm.transform.latent_posterior.params.mean, gsm.transform.latent_posterior.params.diag_cov
    mean = mean.detach().numpy()
    cov = (cov.squeeze().diag().detach().numpy())
    plotting.plot_normal(fig4, mean, cov, alpha=0.5, color=color)

show(gridplot([[fig1, fig2], [fig3, fig4]]))

### Training

Now that the HGSM is initialized, we start the "actual" training. While the HGSM is being updated, the statistics of the SGMM parameters are periodically re-accumulated (every `stats_update_rate` iterations).

In [11]:
# Number of training iterations.
epochs = 20_000
# The SGMM statistics are re-accumulated once every `stats_update_rate`
# iterations of the training loop.
stats_update_rate = 100

def accumulate_stats(models, datasets, optims):
    """Accumulate subspace statistics and update the non-subspace parameters.

    For each (model, data, optimizer) triple taken in lockstep, the optimizer
    step is reset, the ELBO of the model on its data is back-propagated with
    ``std_params=False``, and the optimizer applies its update.
    """
    for model, data, optimizer in zip(models, datasets, optims):
        optimizer.init_step()
        beer.evidence_lower_bound(model, data).backward(std_params=False)
        optimizer.step()
        
# Prepare an optimizer for each SGMM. Each optimizer handles the conjugate
# parameters that are NOT included in the subspace.
def pfilter(param):
    return not isinstance(param, beer.SubspaceBayesianParameter)


all_sgmm_optims = []
for sgmms in lang_sgmms:
    sgmms_optims = [
        beer.VBConjugateOptimizer(
            gmm.bayesian_parameters(
                paramtype=beer.ConjugateBayesianParameter,
                paramfilter=pfilter,
                keepgroups=True,
            ),
            lrate=1.,
        )
        for gmm in sgmms
    ]
    all_sgmm_optims.append(sgmms_optims)


elbos_f = []

In [12]:
for epoch in range(1, epochs + 1):
    # Periodically refresh the SGMM statistics. Each language's SGMMs and
    # data are paired with that language's own optimizers (the loop variable
    # `sgmm_optims`, not the `sgmms_optims` leaked by the cell that built them).
    if (epoch - 1) % stats_update_rate == 0:
        for sgmms, train_data_single, sgmm_optims in zip(lang_sgmms, train_data, all_sgmm_optims):
            accumulate_stats(sgmms, train_data_single, sgmm_optims)
    optim.init_step()
    elbo = beer.evidence_lower_bound(root_gsm,
                                     [(gsm, sgmm) for gsm, sgmm in zip(lang_gsms, lang_sgmms)],
                                     univ_latent_nsamples=5,
                                     latent_posts=lang_latent_posts,
                                     latent_nsamples=5, params_nsamples=5)
    elbo.backward()
    optim.step()
    elbos_f.append(float(elbo))

In [13]:
figs_per_line = 3
fig = figure(title='Elbos')
fig.line(range(len(elbos_f)), elbos_f)
figs = [fig]

# Rows of at most `figs_per_line` figures, without an empty trailing row
# when len(figs) is a multiple of figs_per_line.
figs = [figs[start:start + figs_per_line] for start in range(0, len(figs), figs_per_line)]
show(gridplot(figs))

In [14]:
fig1 = figure(title='True model', x_range=(-100, 100), y_range=(-10, 10))
for means, covs, weights, datasets in full_data:
    for color, dataset, m_means, m_covs, m_weights in zip(colors, datasets, means, covs, weights):
        dataset = dataset.numpy()
        plotting.plot_normal(fig1, m_means[0], m_covs[0], alpha=.5 * m_weights[0], color=color)
        plotting.plot_normal(fig1, m_means[1], m_covs[1], alpha=.5 * m_weights[1], color=color)
        fig1.circle(dataset[:, 0], dataset[:, 1], alpha=.05, color=color)

# Each language's SGMMs plotted over that language's own datasets.
fig2 = figure(title='Subspace GMM', x_range=fig1.x_range, y_range=fig1.y_range)
for sgmms, datasets in zip(lang_sgmms, train_data):
    for gmm, dataset, color in zip(sgmms, datasets, colors):
        dataset = dataset.numpy()
        fig2.circle(dataset[:, 0], dataset[:, 1], color=color, alpha=.05)
        plotting.plot_gmm(fig2, gmm, alpha=.7, color=color)

fig3 = figure(title='Unit latent space')
mean, cov = language_priors[0].mean.numpy(), language_priors[0].cov.numpy()
plotting.plot_normal(fig3, mean, cov, alpha=.5, color='pink')
for post, color in zip(lang_latent_posts, colors):
    for mean, cov in zip(post.params.mean, post.params.diag_cov):
        mean = mean.detach().numpy()
        cov = (cov.diag().detach().numpy())
        plotting.plot_normal(fig3, mean, cov, alpha=0.5, color=color)

# All child transforms were created with the same `latent_prior`; read it
# from the list rather than from a loop variable leaked by another cell.
fig4 = figure(title='Language latent space')
parent_prior = pseudo_transforms[-1].latent_prior
mean, cov = parent_prior.mean.numpy(), parent_prior.cov.numpy()
plotting.plot_normal(fig4, mean, cov, alpha=.5, color='pink')

for gsm, color in zip(lang_gsms, colors):
    mean, cov = gsm.transform.latent_posterior.params.mean, gsm.transform.latent_posterior.params.diag_cov
    mean = mean.detach().numpy()
    cov = (cov.squeeze().diag().detach().numpy())
    plotting.plot_normal(fig4, mean, cov, alpha=0.5, color=color)

show(gridplot([[fig1, fig2], [fig3, fig4]]))