In [1]:
from autograd import numpy as np
from autograd import scipy as sp

from numpy.polynomial.hermite import hermgauss

np.random.seed(453453)

import paragami

# BNP sensitivity libraries
import bnpgmm_runjingdev.gmm_clustering_lib as gmm_lib
from bnpgmm_runjingdev import hessian_lib
import bnpgmm_runjingdev.gmm_cavi_lib as cavi_lib
import bnpgmm_runjingdev.gmm_preconditioner_lib as preconditioner_lib
from bnpgmm_runjingdev import utils_lib

import bnpmodeling_runjingdev.functional_sensitivity_lib as func_sens_lib 
from bnpmodeling_runjingdev import modeling_lib 
from bnpmodeling_runjingdev import optimization_lib 
from bnpmodeling_runjingdev import cluster_quantities_lib as cluster_lib


import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.decomposition import PCA

from copy import deepcopy

import vittles
from tqdm import tqdm

import json 

The defvjp method is deprecated. See the update guide and tutorial:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md
https://github.com/HIPS/autograd/blob/master/docs/tutorial.md


# Load data

In [2]:
# Load the iris measurements together with their species labels.
dataset_name = 'iris'
features, iris_species = utils_lib.load_data()

# Problem size: n_obs counts the species labels, dim counts the feature columns.
n_obs = len(iris_species)
dim = features.shape[1]

# Load model

### Get prior 

In [3]:
# Default prior hyperparameters for `dim`-dimensional data, plus the companion
# object returned alongside them (presumably their paragami pattern — confirm).
prior_params_dict, prior_params_paragami = gmm_lib.get_default_prior_params(dim)

In [4]:
# Show the default prior values returned by gmm_lib.get_default_prior_params.
print(prior_params_dict)

{'alpha': array([3.]), 'prior_centroid_mean': array([0.]), 'prior_lambda': array([1.]), 'prior_wishart_df': array([10.]), 'prior_wishart_rate': array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]])}


In [5]:
# DP prior parameter: override the default value of `alpha`.
# NOTE(review): the default returned by get_default_prior_params is a length-1
# float array (array([3.])); this replaces it with a plain Python int — confirm
# that downstream code accepts a scalar here.
prior_params_dict['alpha'] = 12

### Variational parameters

In [6]:
# Size argument passed to gmm_lib.get_vb_params_paragami_object
# (presumably the number of components in the variational approximation — confirm).
k_approx = 30

In [7]:
# Gauss-Hermite points: degree-20 quadrature sample points (gh_loc) and weights
# (gh_weights); these are handed to the KL, CAVI and sensitivity routines.
gh_deg = 20
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# get vb parameters: a dictionary of variational parameters built from `dim` and
# `k_approx`, and the companion object used later to fold/flatten that dictionary.
vb_params_dict, vb_params_paragami = gmm_lib.get_vb_params_paragami_object(dim, k_approx)

# Optimize

### Kmeans

In [9]:
# Initialise the variational parameters with utils_lib.cluster_and_get_k_means_inits,
# using a fixed seed. n_kmeans_init is presumably the number of k-means
# restarts — confirm against utils_lib.
# The call returns the initial values as a free flattened vector and as a
# dictionary, together with init_ez.
n_kmeans_init = 50
init_vb_free_params, init_vb_params_dict, init_ez = \
    utils_lib.cluster_and_get_k_means_inits(features, vb_params_paragami, 
                                                n_kmeans_init = n_kmeans_init, 
                                                seed = 32423)

In [10]:
# Mean squared value of the initial free parameters (a quick scale check).
(init_vb_free_params**2).mean()

20.181050178389253

In [11]:
# initial KL
# NOTE(review): this evaluates `vb_params_dict` (as returned by
# get_vb_params_paragami_object), not `init_vb_params_dict` from the k-means
# step — confirm which parameters are intended as the "initial" ones.
gmm_lib.get_kl(features, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

2009.3998815516832

### Run CAVI

In [12]:
# Start CAVI from an independent deep copy of the k-means initialisation.
vb_params_dict = deepcopy(init_vb_params_dict)

In [13]:
# Fit the unperturbed model with cavi_lib.run_cavi, starting from vb_params_dict.
# Returns the fitted parameter dictionary and e_z_opt.
# x_tol = 1e-3 is passed explicitly (presumably the convergence tolerance on
# the parameters — confirm against cavi_lib).
vb_opt_dict, e_z_opt = cavi_lib.run_cavi(features, vb_params_dict,
                                            vb_params_paragami, prior_params_dict,
                                            gh_loc, gh_weights,
                                            debug = False, 
                                            x_tol = 1e-3)

done. num iterations = 140
stick_time: 0.691sec
cluster_time: 0.069sec
e_z_time: 0.175sec
**TOTAL time: 2.022sec**


In [14]:
# Flatten the fitted dictionary into a vector of free parameters.
vb_opt = vb_params_paragami.flatten(vb_opt_dict, free = True)

In [15]:
# Mean squared value of the fitted free parameters.
(vb_opt**2).mean()

0.7941088233684164

# Define a perturbation in the logit v space.

In [16]:
def log_phi(logit_v):
    """Log of the multiplicative perturbation at ``logit_v``.

    The value is the logistic sigmoid (``expit``) of ``logit_v``, so it lies
    between 0 and 1. Works elementwise on arrays.
    """
    sigmoid_value = sp.special.expit(logit_v)
    return sigmoid_value


def phi(logit_v):
    """Multiplicative perturbation: the exponential of ``log_phi(logit_v)``."""
    log_value = log_phi(logit_v)
    return np.exp(log_value)


In [17]:
# Grid of 200 logit-stick values on [-8, 8], and the same grid pushed through
# the logistic function.
logit_v_grid = np.linspace(-8, 8, 200)
v_grid = np.exp(logit_v_grid) / (1 + np.exp(logit_v_grid))

# Largest magnitude of log(phi) over the grid; used to normalise the perturbation.
log_phi_max = np.abs(np.log(phi(logit_v_grid))).max()


def rescaled_log_phi(logit_v):
    """Return ``log_phi(logit_v)`` multiplied by 10 and divided by ``log_phi_max``."""
    scaled_numerator = 10 * log_phi(logit_v)
    return scaled_numerator / log_phi_max


In [18]:
# Build the functional-sensitivity perturbation object from the rescaled log
# perturbation, the current DP parameter alpha, and logit-stick bounds of
# [-8, 8] (the same range spanned by logit_v_grid).
prior_perturbation = func_sens_lib.PriorPerturbation(
                                alpha0 = prior_params_dict['alpha'],
                                log_phi = rescaled_log_phi, 
                                logit_v_ub=8, logit_v_lb = -8)


In [19]:
def log_prior_pert(logit_v):
    """Rescaled log perturbation minus ``prior_perturbation.log_norm_pc_logit``."""
    return rescaled_log_phi(logit_v) - prior_perturbation.log_norm_pc_logit

# Define the perturbed objective for the sensitivity analysis

In [20]:
# Objective handed to the sensitivity class: the perturbed KL as a function of
# the flattened free variational parameters and of epsilon.
def _perturbed_kl_from_folded(params, epsilon):
    """Evaluate gmm_lib.get_perturbed_kl on folded parameters ``params``."""
    return gmm_lib.get_perturbed_kl(
        features, params, epsilon, log_prior_pert,
        prior_params_dict, gh_loc, gh_weights)


# Wrap the helper so its first argument is accepted as a flattened free vector.
get_epsilon_vb_loss = paragami.FlattenFunctionInput(
    _perturbed_kl_from_folded,
    patterns=vb_params_paragami,
    free=True,
    argnums=0)

def get_e_log_perturbation(log_phi, vb_params_dict, epsilon, gh_loc, gh_weights):
    """Forward the stick parameters to func_sens_lib.get_e_log_perturbation.

    Reads ``stick_propn_mean`` and ``stick_propn_info`` from
    ``vb_params_dict['stick_params']`` and passes them, together with
    ``log_phi``, ``epsilon`` and the quadrature points/weights, to the library
    routine with ``sum_vector=True``. Returns whatever that routine returns.
    """
    stick_params = vb_params_dict['stick_params']
    stick_mean = stick_params['stick_propn_mean']
    stick_info = stick_params['stick_propn_info']
    return func_sens_lib.get_e_log_perturbation(
        log_phi, stick_mean, stick_info,
        epsilon, gh_loc, gh_weights, sum_vector=True)


def _e_log_perturbation_from_folded(params, epsilon):
    """Evaluate get_e_log_perturbation with log_prior_pert on folded ``params``."""
    return get_e_log_perturbation(
        log_prior_pert, params, epsilon, gh_loc, gh_weights)


# Epsilon-dependent term, taking flattened free parameters as its first argument.
hyper_par_objective_fun = paragami.FlattenFunctionInput(
    _e_log_perturbation_from_folded,
    patterns=vb_params_paragami,
    free=True,
    argnums=0)

In [21]:
# perturbed KL, evaluated at the unperturbed optimum vb_opt with epsilon = 1
get_epsilon_vb_loss(vb_opt, epsilon = 1.)

-323.8958534860422

In [22]:
# Value of epsilon at which the sensitivity is computed (passed as hyper_par_value).
epsilon0 = np.array([0.])

In [23]:
import time

In [24]:
# Build the linear approximation of the optimal free parameters as a function
# of epsilon, around (vb_opt, epsilon0), and report how long construction takes.
# The interval is measured with time.perf_counter(), a monotonic high-resolution
# clock, so system clock adjustments cannot distort it (differences of
# time.time() can be).
t0 = time.perf_counter()
epsilon_sens = vittles.HyperparameterSensitivityLinearApproximation(
    objective_fun = get_epsilon_vb_loss,
    opt_par_value = vb_opt,
    hyper_par_value = epsilon0,
    hyper_par_objective_fun = hyper_par_objective_fun)

print('derivative time: {}sec'.format(np.round(time.perf_counter() - t0, 3)))

cross hess time:  0.10590982437133789
solver time:  0.0006532669067382812
derivative time: 11.329sec


## Fit with perturbation

In [25]:
# Use the linear approximation to predict the optimal free parameters at
# epsilon = 1, and report how far the prediction moves from the unperturbed
# optimum (norm of the difference of the free-parameter vectors).
epsilon = 1.0 
print('Epsilon: ', epsilon)

vb_pert_pred = epsilon_sens.predict_opt_par_from_hyper_par(epsilon)

print('Predicted differences: ', np.linalg.norm(vb_pert_pred - vb_opt))

Epsilon:  1.0
Predicted differences:  4.287180477837869


In [26]:
# Fold the predicted free parameters back into a dictionary, then compute the
# corresponding optimal z via gmm_lib.get_optimal_z_from_vb_params_dict.
vb_pert_pred_dict = vb_params_paragami.fold(vb_pert_pred, free = True)
e_z_pert_pred = gmm_lib.get_optimal_z_from_vb_params_dict(
    features, vb_pert_pred_dict, gh_loc, gh_weights)

In [27]:
# Starting point for the perturbed refit: a deep copy of the unperturbed
# optimum, kept both as a dictionary and as a flattened free-parameter vector.
new_init_dict = deepcopy(vb_opt_dict)
new_init_free = vb_params_paragami.flatten(new_init_dict, free = True)

In [28]:
# Refit with the perturbation switched on (log_phi and epsilon supplied),
# starting from a fresh copy of new_init_dict.
# NOTE(review): unlike the unperturbed fit, no x_tol is passed here, so
# run_cavi's default applies — confirm the two fits are meant to use the same
# stopping rule.
vb_pert_dict, e_z_pert = cavi_lib.run_cavi(
    features,
    deepcopy(new_init_dict),
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
    log_phi=rescaled_log_phi,
    epsilon=epsilon,
    debug=False,
)

# Flattened free-parameter vector of the refit, comparable with vb_pert_pred.
vb_pert_opt = vb_params_paragami.flatten(vb_pert_dict, free=True)

done. num iterations = 132
stick_time: 0.654sec
cluster_time: 0.058sec
e_z_time: 0.156sec
**TOTAL time: 1.82sec**


In [29]:
# Mean squared value of the linearly predicted free parameters.
(vb_pert_pred**2).mean()

0.5620733320868303

In [30]:
# Mean squared value of the refit (perturbed) free parameters.
(vb_pert_opt**2).mean()

0.4991791832418399

In [31]:
# Mean squared entry of the sensitivity returned by get_dopt_dhyper().
(epsilon_sens.get_dopt_dhyper()**2).mean()

0.03845170805346053

In [32]:
# Save summary scalars with np.savez: the objective at vb_opt for epsilon = 0
# and epsilon = 1, and the mean squared values of the predicted parameters,
# the refit parameters, and dopt/dhyper.
# np.savez appends '.npz' when the file name lacks it, so this writes
# '.tmp.npz' in the current working directory.
np.savez('.tmp', 
         kl = get_epsilon_vb_loss(vb_opt, epsilon = 0), 
         kl_pert = get_epsilon_vb_loss(vb_opt, epsilon = 1), 
         vb_pert_pred = (vb_pert_pred**2).mean(), 
         vb_pert_opt = (vb_pert_opt**2).mean(), 
         dopt_dhyper = (epsilon_sens.get_dopt_dhyper()**2).mean())