In [1]:
# Relative path to the directory containing the local repository checkouts;
# the sys.path entries in this cell are joined onto it, so relocating the
# notebook only requires editing this one line.
git_repo = '../../../../'

import sys
import os

# Make the local (non-installed) libraries importable. insert(0, ...) puts
# them ahead of any installed copies of the same packages.
sys.path.insert(0, os.path.join(git_repo, 'BNP_sensitivity/GMM_clustering/'))
sys.path.insert(0, os.path.join(git_repo, 'LinearResponseVariationalBayes.py/'))

# Linear response libraries
import LinearResponseVariationalBayes as vb
import LinearResponseVariationalBayes.SparseObjectives as obj_lib
import LinearResponseVariationalBayes.OptimizationUtils as opt_lib

import LinearResponseVariationalBayes.ModelSensitivity as model_sens

# Local libraries
# import gmm_clustering_lib as gmm_utils
# import modeling_lib 
import functional_sensitivity_lib as fun_sens_lib 
import utils_lib
import simulation_lib

import matplotlib.pyplot as plt
%matplotlib inline

from copy import deepcopy

import autograd
from autograd import numpy as np
from autograd import scipy as sp

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

# Seed numpy's global RNG so that stochastic steps drawing from it
# (presumably the data simulation and the k-means initialisation below)
# are reproducible across Restart & Run All.
np.random.seed(453453)

import json

# paragami is another local checkout under git_repo; build its path from
# git_repo like the other library paths instead of hard-coding it again.
sys.path.append(os.path.join(git_repo, 'paragami/'))
import paragami

import gmm_clustering_paragami_lib as gmm_parag_lib
import modeling_lib_paragami 

from numpy.polynomial.hermite import hermgauss

from scipy import optimize

SyntaxError: invalid syntax (gmm_clustering_paragami_lib.py, line 494)

In [None]:
# colors for plotting
import matplotlib.cm as cm

n_plot_colors = 12
# pyplot.get_cmap instead of cm.get_cmap, which Matplotlib deprecated in 3.7
# and removed in 3.9.
cmap = plt.get_cmap('gist_rainbow')

# Calling a colormap with an *integer* indexes its lookup table of cmap.N
# entries; every integer above cmap.N - 1 gets the same "over" colour. The
# old spacing cmap(k * 50) therefore gave colors1[6:] one identical colour,
# so with k_approx = 12 six clusters were indistinguishable.
# Instead take n_plot_colors evenly spaced table indices and list the even
# positions first: low-numbered clusters (e.g. the 5 simulated ones) still
# get widely separated hues, and all n_plot_colors colours are distinct.
color_idx = [round(i * (cmap.N - 1) / (n_plot_colors - 1))
             for i in range(n_plot_colors)]
colors1 = [cmap(color_idx[i])
           for i in list(range(0, n_plot_colors, 2)) + list(range(1, n_plot_colors, 2))]
# k * 25 stays inside the table except for k = 11, which is clipped to the
# top colour -- same palette as before.
colors2 = [cmap(k * 25) for k in range(12)]

# Load or simulate data

In [None]:
# Load the iris data, or simulate mixture data with known parameters.
simulate = True
if not simulate:
    # load iris data
    dataset_name = 'iris'
    features, iris_species = utils_lib.load_data()
    dim = features.shape[1]
    n_obs = len(iris_species)

else:
    # Simulate data
    dataset_name = 'simulation'

    n_obs = 1000
    dim = 2
    true_k = 5
    features, true_z, true_components, true_centroids, true_covs, true_probs = \
        simulation_lib.simulate_data(n_obs, dim, true_k, separation=0.6)

    # Scatter the raw (2-d) features, one colour per true component. Draw on
    # a fresh figure/axes rather than pyplot's implicit "current" figure,
    # which later plt.figure(1) calls could otherwise pick up and reuse.
    fig, ax = plt.subplots()
    for k in range(true_k):
        in_component = true_components == k
        ax.plot(features[in_component, 0], features[in_component, 1], '.',
                label='component {}'.format(k))
    ax.set(title='Simulated data by true component',
           xlabel='feature 0', ylabel='feature 1')
    ax.legend()

    # Ground-truth cluster labels; named like the iris labels so that the
    # cells below work for either dataset.
    iris_species = np.argmax(true_z, axis = 1)

In [None]:
# Project the data onto its principal components; the cluster plots below
# are drawn in (PC1, PC2) coordinates.
pca_fit = PCA()
pca_fit.fit(features)
pc_features = pca_fit.transform(features)

# Plot the true labels in PC space on a fresh figure. plt.figure(1) returns
# the *existing* figure 1 when it is still open, and in recent Matplotlib
# versions add_subplot(111) then stacks a second axes over the old one.
fig, ax = plt.subplots()
utils_lib.plot_clusters(pc_features[:, 0], pc_features[:, 1], iris_species, colors1, ax)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Data coloured by true label');


In [None]:
def plot_results_from_vb_params_dict(vb_params_dict, e_z, fig,
                                     pca=None, pc_feats=None, colors=None):
    """Plot the data in PC space coloured by each point's most likely cluster.

    Points are assigned to the argmax of their row of ``e_z``. The centroids
    and covariances of the clusters that own at least one point are projected
    into principal-component space and drawn by ``utils_lib.plot_clusters``.

    Parameters
    ----------
    vb_params_dict : dict
        VB parameters. Reads ``'centroids'`` (clusters indexed along the last
        axis -- presumably shape (dim, k)) and ``'gamma'`` (per-cluster
        precision matrices, inverted here to covariances).
    e_z : array
        Cluster responsibilities: one row per observation, one column per
        cluster.
    fig : matplotlib Axes
        Axes to draw on (despite the name, callers pass an Axes).
    pca : fitted sklearn PCA, optional
        Projection to use; defaults to the notebook global ``pca_fit``.
    pc_feats : array, optional
        PC-space features to scatter; defaults to the global ``pc_features``.
    colors : list, optional
        Cluster colours; defaults to the global ``colors1``.
    """
    if pca is None:
        pca = pca_fit
    if pc_feats is None:
        pc_feats = pc_features
    if colors is None:
        colors = colors1

    # Project the cluster parameters into principal-component space.
    bnp_centroids_pc, bnp_cluster_covs_pc = utils_lib.transform_params_to_pc_space(
        pca, vb_params_dict['centroids'], np.linalg.inv(vb_params_dict['gamma']))

    bnp_clusters = np.argmax(e_z, axis=1)
    # Only draw centroids/covariances for clusters that own at least one point.
    occupied = np.unique(bnp_clusters)

    utils_lib.plot_clusters(pc_feats[:, 0], pc_feats[:, 1], bnp_clusters,
                            colors, fig,
                            centroids=bnp_centroids_pc[:, occupied],
                            cov=bnp_cluster_covs_pc[occupied])


In [None]:
if simulate:
    # Pack the simulation's ground truth into the same keys as a VB parameter
    # dictionary so it can be compared with the fitted values. The stick
    # entries are filled with ones as placeholders (not derived from
    # true_probs); 'gamma' holds precisions, i.e. inverse covariances.
    true_vb_params_dict = {
        'centroids': true_centroids.T,
        'v_stick_mean': np.ones(true_k - 1),
        'v_stick_info': np.ones(true_k - 1),
        'gamma': np.linalg.inv(true_covs),
    }
    

# Set up model

### Get priors

In [None]:
prior_params_dict, prior_params_paragami = gmm_parag_lib.get_default_prior_params(dim)

In [None]:
print(prior_params_paragami)

In [None]:
# these are constrained parameters
print(prior_params_dict)

In [None]:
# these are free parameters
prior_params_paragami.flatten(prior_params_dict, free = True)

### The variational inference objective

In [None]:
if simulate: 
    k_approx = true_k
else: 
    k_approx = 12

In [None]:
# Gauss-Hermite points
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [None]:
# get vb parameters
vb_params_dict, vb_params_paragami = gmm_parag_lib.get_vb_params_paragami_object(dim, k_approx, n_obs)

In [None]:
gmm_parag_lib.get_kl(features, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

# Optimization 

### Run k-means initialization

In [None]:
n_kmeans_init = 10
init_vb_free_params, init_vb_params_dict, init_ez = \
    gmm_parag_lib.cluster_and_get_k_means_inits(features, vb_params_paragami, 
                                                n_kmeans_init = n_kmeans_init)

In [None]:
gmm_parag_lib.get_kl(features, init_vb_params_dict, prior_params_dict, gh_loc, gh_weights)

In [None]:
init_vb_params_dict

In [None]:
init_vb_free_params

In [None]:
# An identical second definition of plot_results_from_vb_params_dict used to
# live here. It silently shadowed the one defined right after the PCA cell, so
# any change made to that copy was undone from this point on. Keep a single
# definition and just confirm it is in scope.
assert callable(plot_results_from_vb_params_dict)


In [None]:
# Fresh figure/axes per plot (reusing plt.figure(1) can stack axes on an old
# figure outside the inline backend).
fig, ax = plt.subplots()
plot_results_from_vb_params_dict(init_vb_params_dict, init_ez, ax)
ax.set(title='k-means initialisation', xlabel='PC1', ylabel='PC2');

### Set up losses

In [None]:
# KL loss as a function of the VB parameter dictionary: wrap get_kl so that
# every argument except the one at argnums=1 (the VB parameters) is cached.
get_vb_params_loss = paragami.Functor(
    original_fun=gmm_parag_lib.get_kl, argnums=1)
get_vb_params_loss.cache_args(
    features, None, prior_params_dict, gh_loc, gh_weights)

# The same loss taking the flattened free-parameter vector, which is the
# form the optimisers below work with.
get_vb_free_params_loss = paragami.FlattenedFunction(
    original_fun=get_vb_params_loss, patterns=vb_params_paragami, free=True)


In [None]:
# get gradient 
get_vb_free_params_loss_grad = autograd.grad(get_vb_free_params_loss)
get_vb_free_params_loss_hess = autograd.hessian(get_vb_free_params_loss)

In [None]:
bfgs_opt = gmm_parag_lib.run_bfgs(get_vb_free_params_loss, init_vb_free_params, get_vb_free_params_loss_grad, 
                                 maxiter = 50)

In [None]:
bfgs_opt.success



In [None]:
bfgs_vb_free_pars = bfgs_opt.x
bfgs_vb_params_dict = vb_params_paragami.fold(bfgs_vb_free_pars, free=True)
bfgs_vb_params_dict

In [None]:
true_vb_params_dict

In [None]:
bfgs_ez = gmm_parag_lib.get_optimal_z_from_vb_params_dict(features, bfgs_vb_params_dict, gh_loc, gh_weights)

In [None]:
fig1 = plt.figure(1)
fig = fig1.add_subplot(111)

plot_results_from_vb_params_dict(bfgs_vb_params_dict, bfgs_ez, fig)

### Get preconditioner and run Newton trust-region

In [None]:
trust_ncg_vb_free_pars = gmm_parag_lib.precondition_and_optimize(get_vb_free_params_loss, bfgs_opt.x)

In [None]:
vb_opt_dict = \
    vb_params_paragami.fold(trust_ncg_vb_free_pars, free = True)
    
vb_opt_dict

In [None]:
opt_ez = gmm_parag_lib.get_optimal_z_from_vb_params_dict(features, vb_opt_dict, gh_loc, gh_weights)

In [None]:
fig1 = plt.figure(1)
fig = fig1.add_subplot(111)

plot_results_from_vb_params_dict(vb_opt_dict, opt_ez, fig)