In [1]:
import jax
from jax import numpy as np
from jax import scipy as sp

from numpy.polynomial.hermite import hermgauss

import paragami

# BNP sensitivity libraries
import bnpgmm_runjingdev.gmm_clustering_lib as gmm_lib
import bnpgmm_runjingdev.gmm_cavi_lib as cavi_lib
from bnpgmm_runjingdev import utils_lib

from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation
import bnpmodeling_runjingdev.functional_sensitivity_lib as func_sens_lib 
from bnpmodeling_runjingdev import cluster_quantities_lib as cluster_lib


import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.decomposition import PCA

from copy import deepcopy

import vittles
from tqdm import tqdm

import json 



# Load data

In [2]:
# Load the iris data: a feature array plus one species label per observation.
dataset_name = 'iris'
features, iris_species = utils_lib.load_data()

# Number of observations (one label each) and number of feature columns.
n_obs = len(iris_species)
dim = features.shape[1]

# Load model

### Get prior 

In [3]:
# Default prior hyperparameters (a dict) and, presumably, the paragami
# pattern that describes them.
prior_params_dict, prior_params_paragami = \
    gmm_lib.get_default_prior_params(dim)

In [4]:
# Show the default hyperparameters before 'alpha' is overridden in a later cell.
print(prior_params_dict)

{'alpha': DeviceArray([3.], dtype=float64), 'prior_centroid_mean': DeviceArray([0.], dtype=float64), 'prior_lambda': DeviceArray([1.], dtype=float64), 'prior_wishart_df': DeviceArray([10.], dtype=float64), 'prior_wishart_rate': DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float64)}


In [5]:
# DP prior parameter: override the default 'alpha'.
# NOTE(review): the default 'alpha' is a length-1 array (see the printed dict
# above); this stores a plain Python int instead. Confirm that downstream code
# (e.g. anything flattening with prior_params_paragami) accepts a scalar.
prior_params_dict.update(alpha=12)

### Variational parameters

In [6]:
# Number of variational components passed to get_vb_params_paragami_object
# (presumably the stick-breaking truncation level).
k_approx = 30

In [7]:
# Gauss-Hermite quadrature: gh_deg nodes (gh_loc) with matching weights.
gh_deg = 20
gh_loc, gh_weights = hermgauss(deg=gh_deg)

In [8]:
# Variational parameter dict returned by the library, plus its paragami
# pattern (used below to flatten/fold free parameter vectors).
vb_params_dict, vb_params_paragami = \
    gmm_lib.get_vb_params_paragami_object(dim, k_approx)

In [9]:
# The KL objective, evaluated at the library-provided variational parameters.
gmm_lib.get_kl(features, vb_params_dict, prior_params_dict,
               gh_loc, gh_weights)

DeviceArray(1889.86954163, dtype=float64)

In [10]:
# Unpack stick and cluster parameters for the per-term checks that follow.
stick_propn_mean, stick_propn_info = (
    vb_params_dict['stick_params'][key]
    for key in ('stick_propn_mean', 'stick_propn_info'))
centroids, cluster_info = (
    vb_params_dict['cluster_params'][key]
    for key in ('centroids', 'cluster_info'))

In [11]:
# Optimal assignment expectations e_z under the BNP prior, plus the
# log-likelihood terms (name suggests indexed by observation n and component k).
e_z, loglik_obs_by_nk = gmm_lib.get_optimal_z(
    features, stick_propn_mean, stick_propn_info, centroids, cluster_info,
    gh_loc, gh_weights, use_bnp_prior=True)

In [12]:
# Entropy term at the current parameters and e_z.
gmm_lib.get_entropy(
    stick_propn_mean, stick_propn_info, e_z, gh_loc, gh_weights)

DeviceArray(107.22703455, dtype=float64)

In [13]:
# Expected log prior at the current parameters.
gmm_lib.get_e_log_prior(
    stick_propn_mean, stick_propn_info, centroids, cluster_info,
    prior_params_dict, gh_loc, gh_weights)

DeviceArray(-391.68462054, dtype=float64)

In [14]:
# loglik_ind term from modeling_lib at the current parameters and e_z.
gmm_lib.modeling_lib.loglik_ind(
    stick_propn_mean, stick_propn_info, e_z, gh_loc, gh_weights)

DeviceArray(-569.29713507, dtype=float64)

# Optimize

### K-means initialization

In [15]:
# K-means-based initialization with a fixed seed; n_kmeans_init is presumably
# the number of k-means restarts.
n_kmeans_init = 50
init_vb_free_params, init_vb_params_dict, init_ez = \
    utils_lib.cluster_and_get_k_means_inits(
        features, vb_params_paragami,
        n_kmeans_init=n_kmeans_init, seed=32423)

In [16]:
# Mean squared entry of the initial free-parameter vector (scale sanity check).
np.square(init_vb_free_params).mean()

DeviceArray(20.18105018, dtype=float64)

### Run CAVI

In [17]:
# Rebind vb_params_dict to a copy of the k-means initialization so that
# init_vb_params_dict itself is not touched by CAVI.
vb_params_dict = deepcopy(init_vb_params_dict)

In [18]:
# Fit the unperturbed model with CAVI, starting from the k-means initialization.
vb_opt_dict, e_z_opt = cavi_lib.run_cavi(
    features, vb_params_dict, vb_params_paragami, prior_params_dict,
    gh_loc, gh_weights,
    debug=False,
    x_tol=1e-3)

Compiling CAVI update functions ... 
CAVI compile time: 4.26sec

Running CAVI ... 
done. num iterations = 140
stick_time: 1.61sec
cluster_time: 1.02sec
e_z_time: 0.0384sec
**CAVI time: 2.84sec**


In [19]:
# Free (flattened) representation of the CAVI optimum.
vb_opt = vb_params_paragami.flatten(vb_opt_dict, free=True)

In [20]:
# Mean squared entry of the optimal free-parameter vector.
np.square(vb_opt).mean()

DeviceArray(0.79410882, dtype=float64)

# Define a prior perturbation in logit-stick (logit v) space

In [21]:
def log_phi(logit_v):
    """Log of the perturbation function: the logistic sigmoid of ``logit_v``."""
    return sp.special.expit(logit_v)


def phi(logit_v):
    """Perturbation function itself, ``exp(log_phi(logit_v))``."""
    log_value = log_phi(logit_v)
    return np.exp(log_value)


In [22]:
# Evaluation grid on the logit-stick scale. Keep the bounds in sync with the
# logit_v_lb / logit_v_ub passed to PriorPerturbation.
logit_v_grid = np.linspace(-8, 8, 200)

# Stick proportions on the grid; expit is the overflow-safe form of
# exp(x) / (1 + exp(x)).
v_grid = sp.special.expit(logit_v_grid)

# Largest |log phi| over the grid, used to normalize the perturbation.
# Evaluate log_phi directly instead of np.log(phi(.)): the exp/log round trip
# is redundant, and it overflows/underflows (giving inf/-inf) if log_phi is
# ever redefined to take large-magnitude values.
log_phi_max = np.max(np.abs(log_phi(logit_v_grid)))


def rescaled_log_phi(logit_v, scale=10.):
    """log_phi normalized so that its largest |value| on the grid is ``scale``.

    Args:
        logit_v: logit-stick value(s) at which to evaluate.
        scale: target max |value| over ``logit_v_grid`` (default 10, as before).

    Returns:
        ``scale * log_phi(logit_v) / log_phi_max``.
    """
    return scale * log_phi(logit_v) / log_phi_max


In [23]:
# Perturbation object built from the rescaled log phi over the logit-stick
# range [-8, 8] (the same range as logit_v_grid).
prior_perturbation = func_sens_lib.PriorPerturbation(
    alpha0=prior_params_dict['alpha'],
    log_phi=rescaled_log_phi,
    logit_v_lb=-8,
    logit_v_ub=8)


In [24]:
def log_prior_pert(logit_v):
    """Rescaled log phi shifted by ``prior_perturbation.log_norm_pc_logit``.

    The shift is presumably the log normalizer of the perturbed prior on the
    logit-stick scale (confirm in func_sens_lib.PriorPerturbation).
    """
    return rescaled_log_phi(logit_v) - prior_perturbation.log_norm_pc_logit

# Define the perturbed objective and the sensitivity inputs

In [25]:
# Perturbed KL as a function of (free VB parameter vector, epsilon); this is
# the objective handed to the sensitivity class.
def _perturbed_kl(params, epsilon):
    return gmm_lib.get_perturbed_kl(
        features, params, epsilon, log_prior_pert, prior_params_dict,
        gh_loc, gh_weights)


get_epsilon_vb_loss = paragami.FlattenFunctionInput(
    _perturbed_kl, argnums=0, patterns=vb_params_paragami, free=True)


def get_e_log_perturbation(log_phi, vb_params_dict, epsilon, gh_loc, gh_weights):
    """Expected log perturbation at the given VB params (``sum_vector=True``).

    Only the stick parameters of ``vb_params_dict`` are used. Note that the
    arguments shadow the notebook-level names ``log_phi``, ``vb_params_dict``,
    ``gh_loc`` and ``gh_weights``.
    """
    stick_params = vb_params_dict['stick_params']
    return func_sens_lib.get_e_log_perturbation(
        log_phi,
        stick_params['stick_propn_mean'],
        stick_params['stick_propn_info'],
        epsilon, gh_loc, gh_weights,
        sum_vector=True)


# Same quantity as a function of the free VB parameter vector and epsilon.
def _e_log_perturbation_of_params(params, epsilon):
    return get_e_log_perturbation(
        log_prior_pert, params, epsilon, gh_loc, gh_weights)


hyper_par_objective_fun = paragami.FlattenFunctionInput(
    _e_log_perturbation_of_params,
    argnums=0, patterns=vb_params_paragami, free=True)

In [26]:
# Unperturbed value of the hyperparameter epsilon, as a length-1 vector.
epsilon0 = np.zeros(1)

In [27]:
# Linear approximation of how the optimum moves with epsilon, built at
# (vb_opt, epsilon0).
epsilon_sens = HyperparameterSensitivityLinearApproximation(
    objective_fun=get_epsilon_vb_loss,
    opt_par_value=vb_opt,
    hyper_par_value0=epsilon0,
    hyper_par_objective_fun=hyper_par_objective_fun)

Compiling objective function derivatives ... 
Compile time: 21sec

Objective function derivative time: 0.000858sec

Linear system compile time: 0.832sec
Linear system time: 0.000388sec


In [28]:
# Dimensions of the Hessian held by the sensitivity object.
np.shape(epsilon_sens.hessian)

(478, 478)

## Predict and refit under the perturbation

In [29]:
epsilon = 1.0
print('Epsilon: ', epsilon)

# Linear-approximation prediction of the free VB parameters at this epsilon.
vb_pert_pred = epsilon_sens.predict_opt_par_from_hyper_par(epsilon)

# Size of the predicted move away from the unperturbed optimum.
pred_shift = vb_pert_pred - vb_opt
print('Predicted differences: ', np.linalg.norm(pred_shift))

Epsilon:  1.0
Predicted differences:  4.287180477837872


In [30]:
# Fold the predicted free vector back into a dict, then compute the matching e_z.
vb_pert_pred_dict = vb_params_paragami.fold(vb_pert_pred, free=True)
e_z_pert_pred = gmm_lib.get_optimal_z_from_vb_params_dict(
    features, vb_pert_pred_dict, gh_loc, gh_weights)

In [31]:
# Warm start for the perturbed refit: a copy of the unperturbed optimum,
# plus its free (flattened) form.
new_init_dict = deepcopy(vb_opt_dict)
new_init_free = vb_params_paragami.flatten(new_init_dict, free=True)

In [32]:
# Refit with CAVI under the epsilon-perturbed prior, warm-started at the
# unperturbed optimum.
#  - x_tol matches the unperturbed fit (x_tol=1e-3), so differences between
#    vb_pert_opt and the linear prediction reflect epsilon rather than a
#    different convergence tolerance.
#  - log_phi here is rescaled_log_phi, whereas the sensitivity objective uses
#    log_prior_pert; the two differ only by prior_perturbation.log_norm_pc_logit,
#    which should not move the optimum. TODO: confirm in run_cavi.
vb_pert_dict, e_z_pert = cavi_lib.run_cavi(
    features, deepcopy(new_init_dict),
    vb_params_paragami, prior_params_dict,
    gh_loc, gh_weights,
    log_phi=rescaled_log_phi,
    epsilon=epsilon,
    debug=False,
    x_tol=1e-3)

vb_pert_opt = vb_params_paragami.flatten(vb_pert_dict, free=True)

Compiling CAVI update functions ... 
CAVI compile time: 4.4sec

Running CAVI ... 
done. num iterations = 132
stick_time: 1.35sec
cluster_time: 0.0654sec
e_z_time: 0.0353sec
**CAVI time: 1.54sec**


In [33]:
# Mean squared entry of the linearly predicted free parameters.
np.square(vb_pert_pred).mean()

DeviceArray(0.56207333, dtype=float64)

In [34]:
# Mean squared entry of the refit (perturbed CAVI) free parameters.
np.square(vb_pert_opt).mean()

DeviceArray(0.49917918, dtype=float64)