In [1]:
import jax
from jax import numpy as np
from jax import scipy as sp

from numpy.polynomial.hermite import hermgauss

import paragami

# BNP sensitivity libraries
import bnpgmm_runjingdev.gmm_clustering_lib as gmm_lib
import bnpgmm_runjingdev.gmm_cavi_lib as cavi_lib
from bnpgmm_runjingdev import utils_lib

from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation
import bnpmodeling_runjingdev.functional_sensitivity_lib as func_sens_lib 
from bnpmodeling_runjingdev import cluster_quantities_lib as cluster_lib


import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.decomposition import PCA

from copy import deepcopy

import vittles
from tqdm import tqdm

import json 



# Load data

In [2]:
# Load the iris data set: a feature matrix and the species label of each row.
dataset_name = 'iris'
features, iris_species = utils_lib.load_data()
n_obs = len(iris_species)
dim = features.shape[1]

# Load model

### Get default prior parameters

In [3]:
# Default prior hyperparameters for `dim`-dimensional data, plus their pattern.
prior_params_dict, prior_params_paragami = \
    gmm_lib.get_default_prior_params(dim)

In [4]:
# Inspect the default prior hyperparameters (rich display rather than print).
prior_params_dict

{'alpha': DeviceArray([3.], dtype=float64), 'prior_centroid_mean': DeviceArray([0.], dtype=float64), 'prior_lambda': DeviceArray([1.], dtype=float64), 'prior_wishart_df': DeviceArray([10.], dtype=float64), 'prior_wishart_rate': DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float64)}


In [5]:
# DP prior parameter `alpha`: override the default shown above. Note this
# mutates `prior_params_dict` in place, so the printed defaults are now stale.
DP_ALPHA = 12
prior_params_dict['alpha'] = DP_ALPHA

### Variational parameters

In [6]:
# Number of components in the variational approximation
# (presumably the stick-breaking truncation level).
k_approx = 30

In [7]:
# Gauss-Hermite quadrature rule; its nodes and weights are passed to the KL
# and perturbation expectations below.
gh_deg = 20
gh_loc, gh_weights = hermgauss(deg=gh_deg)

In [8]:
# Variational parameter dictionary and its paragami pattern for `k_approx`
# components in `dim` dimensions.
vb_params_dict, vb_params_paragami = \
    gmm_lib.get_vb_params_paragami_object(dim, k_approx)

# Optimize

### Kmeans

In [9]:
# Initialize the variational parameters from k-means clustering
# (`n_kmeans_init` restarts, fixed seed for reproducibility).
n_kmeans_init = 50
kmeans_seed = 32423
init_vb_free_params, init_vb_params_dict, init_ez = (
    utils_lib.cluster_and_get_k_means_inits(
        features,
        vb_params_paragami,
        n_kmeans_init=n_kmeans_init,
        seed=kmeans_seed,
    )
)

In [10]:
# Mean squared magnitude of the initial free (unconstrained) parameters.
np.mean(init_vb_free_params ** 2)

DeviceArray(20.18105018, dtype=float64)

In [11]:
# KL objective evaluated at the k-means initialization.
init_kl = gmm_lib.get_kl(features, init_vb_params_dict, prior_params_dict,
                         gh_loc, gh_weights)
init_kl

DeviceArray(12643.5064535, dtype=float64)

### Run CAVI

In [12]:
# Independent copy of the initialization, so `init_vb_params_dict` is not
# shared with the CAVI run below.
vb_params_dict = deepcopy(init_vb_params_dict)

In [13]:
# Fit the unperturbed model with CAVI, starting from the k-means init.
vb_opt_dict, e_z_opt = cavi_lib.run_cavi(
    features,
    vb_params_dict,
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
    debug=False,
    x_tol=1e-3,
)

Compiling CAVI update functions ... 
CAVI compile time: 3.87sec

Running CAVI ... 
done. num iterations = 140
stick_time: 1.24sec
cluster_time: 0.591sec
e_z_time: 0.0279sec
**CAVI time: 2.01sec**


In [14]:
# CAVI optimum as a flat vector of free (unconstrained) parameters.
vb_opt = vb_params_paragami.flatten(vb_opt_dict,
                                    free=True)

In [15]:
# Mean squared magnitude of the optimal free parameters.
np.mean(vb_opt ** 2)

DeviceArray(0.79410882, dtype=float64)

# Define a prior perturbation in logit-$v$ space

In [16]:
def log_phi(logit_v):
    """Log of the perturbation function on the logit-stick scale.

    Here log phi is the logistic sigmoid of `logit_v`, so it lies in (0, 1).
    """
    return sp.special.expit(logit_v)


def phi(logit_v):
    """Perturbation function itself: phi = exp(log_phi(logit_v))."""
    log_value = log_phi(logit_v)
    return np.exp(log_value)


In [17]:
# Grid over logit stick proportions, used to normalize the perturbation.
logit_v_grid = np.linspace(-8, 8, 200)

# Stick proportions on the grid. `expit` is the numerically stable form of
# exp(x) / (1 + exp(x)); the explicit ratio overflows to nan for large x.
v_grid = sp.special.expit(logit_v_grid)

# log(phi(x)) == log_phi(x), so evaluate log_phi directly instead of taking
# an exp -> log round trip.
log_phi_max = np.max(np.abs(log_phi(logit_v_grid)))


def rescaled_log_phi(logit_v):
    """log_phi rescaled so its largest magnitude on `logit_v_grid` is 10.

    Args:
        logit_v: logit-scale stick proportion(s).

    Returns:
        10 * log_phi(logit_v) / log_phi_max.
    """
    return 10 * log_phi(logit_v) / log_phi_max


In [18]:
# Wrap the rescaled log perturbation in a prior-perturbation object. The
# logit bounds match the grid used above to normalize it.
prior_perturbation = func_sens_lib.PriorPerturbation(
    alpha0=prior_params_dict['alpha'],
    log_phi=rescaled_log_phi,
    logit_v_lb=-8,
    logit_v_ub=8,
)


In [19]:
def log_prior_pert(logit_v):
    """Shifted log prior perturbation on the logit-stick scale.

    Returns `rescaled_log_phi(logit_v)` minus
    `prior_perturbation.log_norm_pc_logit` (presumably the log normalizer of
    the perturbed prior; confirm in func_sens_lib). The attribute is read at
    call time, as in the original lambda.
    """
    return rescaled_log_phi(logit_v) - prior_perturbation.log_norm_pc_logit

In [20]:
def e_log_prior_pert(means, info, epsilon):
    """Expected `log_prior_pert` under the variational stick distribution.

    Delegates to `func_sens_lib.get_e_log_perturbation` with the notebook's
    Gauss-Hermite rule, summing over sticks (`sum_vector=True`).

    Args:
        means: variational stick means (`stick_propn_mean`).
        info: variational stick information (`stick_propn_info`).
        epsilon: perturbation weight (how it enters is defined by
            `get_e_log_perturbation`).
    """
    return func_sens_lib.get_e_log_perturbation(
        log_prior_pert, means, info, epsilon,
        gh_loc, gh_weights, sum_vector=True)

# Perturbed KL objective and linear-response sensitivity

In [21]:
def _get_epsilon_vb_loss(params, epsilon):
    """Perturbed variational objective.

    Args:
        params: folded variational-parameter dictionary.
        epsilon: perturbation weight.

    Returns:
        The KL objective minus the expected log prior perturbation at the
        stick parameters in `params`.
    """
    stick_params = params['stick_params']
    kl = gmm_lib.get_kl(features, params, prior_params_dict,
                        gh_loc, gh_weights)
    e_log_pert = e_log_prior_pert(stick_params['stick_propn_mean'],
                                  stick_params['stick_propn_info'],
                                  epsilon)
    return kl - e_log_pert


# Same objective, taking the flat free-parameter vector as its first argument.
get_epsilon_vb_loss = paragami.FlattenFunctionInput(
    _get_epsilon_vb_loss,
    patterns=vb_params_paragami,
    free=True,
    argnums=0,
)


In [22]:
# Perturbed objective at the unperturbed optimum, full perturbation weight.
kl_pert_at_opt = get_epsilon_vb_loss(vb_opt, epsilon=1.)
kl_pert_at_opt

DeviceArray(-323.89585349, dtype=float64)

In [23]:
def _hyper_par_objective_fun(params, epsilon):
    """The epsilon-dependent part of the perturbed objective.

    Returns the negated expected log prior perturbation evaluated at the
    stick parameters in `params`.
    """
    sticks = params['stick_params']
    e_log_pert = e_log_prior_pert(sticks['stick_propn_mean'],
                                  sticks['stick_propn_info'],
                                  epsilon)
    return -e_log_pert


# Flat free-parameter version, consumed by the sensitivity object below.
hyper_par_objective_fun = paragami.FlattenFunctionInput(
    _hyper_par_objective_fun,
    patterns=vb_params_paragami,
    free=True,
    argnums=0,
)

In [24]:
# Unperturbed perturbation weight, stored as a one-element vector.
epsilon0 = np.zeros(1)

In [25]:
# Linear-response approximation of how the optimum moves with epsilon,
# linearized at epsilon0.
epsilon_sens = HyperparameterSensitivityLinearApproximation(
    objective_fun=get_epsilon_vb_loss,
    opt_par_value=vb_opt,
    hyper_par_value0=epsilon0,
    hyper_par_objective_fun=hyper_par_objective_fun,
)

Compiling hessian solver ...
Hessian solver compile time: 46.8478sec

Compiling cross hessian...
Cross-hessian compile time: 0.741282sec

LR sensitivity time: 0.0983145sec



## Fit with perturbation

In [26]:
# Linear-response prediction of the optimum at full perturbation weight, and
# the size of its move away from the unperturbed optimum.
epsilon = 1.0
print('Epsilon: ', epsilon)

vb_pert_pred = epsilon_sens.predict_opt_par_from_hyper_par(epsilon)
pred_shift = np.linalg.norm(vb_pert_pred - vb_opt)
print('Predicted differences: ', pred_shift)

Epsilon:  1.0
Predicted differences:  4.287180475877754


In [27]:
# Fold the predicted free parameters back into a dictionary and compute the
# corresponding e_z (presumably expected cluster memberships).
vb_pert_pred_dict = vb_params_paragami.fold(vb_pert_pred, free=True)
e_z_pert_pred = gmm_lib.get_optimal_z_from_vb_params_dict(
    features,
    vb_pert_pred_dict,
    gh_loc,
    gh_weights,
)

In [28]:
# Warm start for the perturbed refit: a copy of the unperturbed optimum,
# both as a dictionary and as a flat free-parameter vector.
new_init_dict = deepcopy(vb_opt_dict)
new_init_free = vb_params_paragami.flatten(new_init_dict,
                                           free=True)

In [29]:
def _e_log_phi_at_epsilon(means, info):
    """Expected log prior perturbation at the current global `epsilon`."""
    return e_log_prior_pert(means, info, epsilon)


# Refit with CAVI under the perturbed prior, warm-started from the
# unperturbed optimum.
vb_pert_dict, e_z_pert = cavi_lib.run_cavi(
    features,
    deepcopy(new_init_dict),
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
    e_log_phi=_e_log_phi_at_epsilon,
    debug=False,
)

vb_pert_opt = vb_params_paragami.flatten(vb_pert_dict, free=True)

Compiling CAVI update functions ... 
CAVI compile time: 2.96sec

Running CAVI ... 
done. num iterations = 132
stick_time: 1.02sec
cluster_time: 0.0455sec
e_z_time: 0.0271sec
**CAVI time: 1.16sec**


In [30]:
# Mean squared magnitude of the linear-response prediction.
np.mean(vb_pert_pred ** 2)

DeviceArray(0.56207331, dtype=float64)

In [31]:
# Mean squared magnitude of the refit perturbed optimum.
np.mean(vb_pert_opt ** 2)

DeviceArray(0.49917918, dtype=float64)

In [32]:
# Mean squared magnitude of the sensitivity d(optimum)/d(epsilon).
np.mean(epsilon_sens.dinput_dhyper ** 2)

DeviceArray(0.03845171, dtype=float64)

# Check against saved autograd results

In [33]:
# Reference values saved from the earlier autograd implementation.
# For .npz files, `np.load` returns a lazily-reading NpzFile that holds an
# open file handle. Read every array eagerly inside `with` so the handle is
# closed right away. Lookups by key behave exactly as before.
AUTOGRAD_RESULTS_PATH = '.tmp.npz'
with np.load(AUTOGRAD_RESULTS_PATH) as npz_file:
    autograd_results = {key: npz_file[key] for key in npz_file.files}

In [34]:
# Regression check against the saved autograd value: fail loudly instead of
# only displaying the gap.
vb_pert_pred_gap = (vb_pert_pred**2).mean() - autograd_results['vb_pert_pred']
assert np.abs(vb_pert_pred_gap) < 1e-6, vb_pert_pred_gap
vb_pert_pred_gap

DeviceArray(-2.09724346e-08, dtype=float64)

In [35]:
# Regression check against the saved autograd value: fail loudly instead of
# only displaying the gap.
vb_pert_opt_gap = (vb_pert_opt**2).mean() - autograd_results['vb_pert_opt']
assert np.abs(vb_pert_opt_gap) < 1e-6, vb_pert_opt_gap
vb_pert_opt_gap

DeviceArray(-2.22044605e-16, dtype=float64)

In [36]:
# Regression check against the saved autograd value: fail loudly instead of
# only displaying the gap.
dopt_dhyper_gap = ((epsilon_sens.dinput_dhyper**2).mean()
                   - autograd_results['dopt_dhyper'])
assert np.abs(dopt_dhyper_gap) < 1e-6, dopt_dhyper_gap
dopt_dhyper_gap

DeviceArray(-3.51604995e-11, dtype=float64)

In [37]:
# Regression check of the unperturbed KL against the saved autograd value.
kl_gap = get_epsilon_vb_loss(vb_opt, epsilon = 0) - autograd_results['kl']
assert np.abs(kl_gap) < 1e-6, kl_gap
kl_gap

DeviceArray(1.13686838e-13, dtype=float64)

In [38]:
# Regression check of the perturbed KL against the saved autograd value.
kl_pert_gap = (get_epsilon_vb_loss(vb_opt, epsilon = 1)
               - autograd_results['kl_pert'])
assert np.abs(kl_pert_gap) < 1e-6, kl_pert_gap
kl_pert_gap

DeviceArray(5.68434189e-14, dtype=float64)