Here, we examine the sensitivity of the variational posterior to a functional perturbation of the prior, using a small simulated dataset.

In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp
from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils, cavi_lib, plotting_utils
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import influence_lib, modeling_lib, log_phi_lib
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation
import bnpmodeling_runjingdev.functional_sensitivity_lib as func_sens_lib

ModuleNotFoundError: No module named 'bnpmodeling_runjingdev'

In [None]:
import numpy as onp 
# Seed NumPy's global RNG for reproducibility.
# NOTE(review): presumably this is what makes data_utils.draw_data below
# reproducible; it does not seed jax.random, which uses explicit keys — confirm.
onp.random.seed(53453)

# Draw data

In [None]:
# Simulate a small dataset with n_obs observations, n_loci loci and n_pop
# true populations; the true allele frequencies and admixture proportions
# are kept for comparison with the fit.
# NOTE(review): array layouts are defined by data_utils.draw_data — confirm there.
n_obs = 50
n_loci = 200
n_pop = 3
g_obs, true_pop_allele_freq, true_ind_admix_propn = \
    data_utils.draw_data(n_obs, n_loci, n_pop)

# Get prior

In [None]:
# Default prior hyperparameters (as a dict) and the paragami pattern that
# flattens / folds them.
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

print(prior_params_dict)

# Prior hyperparameters as a flat vector in paragami's free parameterization.
# NOTE(review): prior_params_free does not appear to be used below.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

# Get VB params 

In [None]:
# Truncation level passed to get_vb_params_paragami_object; presumably the
# maximum number of populations in the VB approximation (n_pop = 3 is the
# true number) — confirm in structure_model_lib.
k_approx = 10

In [None]:
# Gauss-Hermite quadrature used for the expectations below (gh_loc are the
# nodes, gh_weights the weights). gh_deg is the number of nodes; it is passed
# to hermgauss so that changing gh_deg actually changes the quadrature.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [None]:
# Build the paragami pattern for the VB parameters. Only the pattern is kept;
# the first return value is discarded.
# NOTE(review): use_logitnormal_sticks = True presumably selects logit-normal
# variational distributions for the sticks — confirm in structure_model_lib.
_, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx,
                                                      use_logitnormal_sticks = True)
    
print(vb_params_paragami)

# Optimize

### Initialize with CAVI

In [None]:
# Initialize the VB parameters with CAVI (at most 200 iterations, fixed seed),
# printing progress every 20 iterations. cavi_init_time is presumably the
# elapsed initialization time.
vb_params_dict, cavi_init_time = \
            s_optim_lib.initialize_with_cavi(g_obs, 
                                 vb_params_paragami, 
                                 prior_params_dict, 
                                 gh_loc, gh_weights, 
                                 print_every = 20, 
                                 max_iter = 200, 
                                 seed = 1232)

# Optimize with preconditioned LBFGS

In [None]:
# Refine the CAVI initialization with preconditioned L-BFGS.
# Outputs, as named here: the optimum as a VB parameter dict (vb_opt_dict)
# and as a free-parameter vector (vb_opt), the optimizer output, the
# preconditioned objective (its .hvp is reused for linear response), and
# presumably the run time.
vb_opt_dict, vb_opt, out, precond_objective, lbfgs_time = \
    s_optim_lib.run_preconditioned_lbfgs(g_obs, 
                        vb_params_dict, 
                        vb_params_paragami,
                        prior_params_dict,
                        gh_loc, gh_weights)

# Check out the fit

In [None]:
# Side-by-side comparison: estimated admixture (left) and true admixture
# (right), with the true populations permuted to line up with the estimate.
fig, axarr = plt.subplots(1, 2, figsize = (12, 4))

###############
# estimated 
###############
# VB posterior expectations of individual admixture and population frequencies.
e_ind_admix, e_pop_freq = plotting_utils.get_vb_expectations(vb_opt_dict, gh_loc, gh_weights)

# Plot the top n_pop estimated clusters; sorted_indx is presumably the
# cluster ordering used by the plot.
_, sorted_indx = \
    plotting_utils.plot_top_clusters(e_ind_admix, axarr[0], n_top_clusters = n_pop)
axarr[0].set_title('estimated')

# Reorder the estimated population frequencies to the plotted ordering.
e_pop_freq = e_pop_freq[:, sorted_indx]

###############
# truth 
###############
# permute so that colors match (as well as possible)
perm = data_utils.find_min_perm(true_pop_allele_freq, e_pop_freq, axis = 1)

plotting_utils.plot_top_clusters(true_ind_admix_propn[:, perm], axarr[1], 
                                 n_top_clusters = n_pop);
axarr[1].set_title('truth')

# Set up linear response derivatives

In [None]:
# Placeholder hyperparameter objective for a null perturbation: it ignores
# its arguments and always returns 0. The real perturbation objective is
# supplied later through _set_cross_hess_and_solve.
def hyper_par_objective_fun(params, epsilon):
    return 0.

In [None]:
# set up linear approximation class
# epsilon0 = 0 is the unperturbed value of the perturbation scale epsilon.
epsilon0 = np.array([0.])

epsilon_sens = \
    HyperparameterSensitivityLinearApproximation(
        # objective_fun is not needed: the Hessian-vector product is passed
        # directly as obj_fun_hvp, and the real hyper_par_objective_fun is
        # supplied later (via _set_cross_hess_and_solve).
        objective_fun = None, 
        opt_par_value = vb_opt, 
        hyper_par_value0 = epsilon0, 
        obj_fun_hvp = precond_objective.hvp, 
        hyper_par_objective_fun = hyper_par_objective_fun)


# Compute influence function

### Define posterior quantity of interest

In [None]:
def g(vb_free_params, vb_params_paragami):
    """Posterior quantity of interest: expected number of clusters.

    Folds the free VB parameter vector into a dict, reads the
    ``'stick_means'`` and ``'stick_infos'`` entries of ``'ind_admix_params'``,
    and returns ``structure_model_lib.get_e_num_pred_clusters`` evaluated with
    the notebook-level quadrature ``gh_loc`` / ``gh_weights`` and 100 samples.

    Args:
        vb_free_params: free (flattened) VB parameter vector.
        vb_params_paragami: paragami pattern used to fold ``vb_free_params``.

    Returns:
        The value returned by ``get_e_num_pred_clusters``.
    """
    folded_params = vb_params_paragami.fold(vb_free_params, free = True)
    admix_params = folded_params['ind_admix_params']

    # The PRNG key is fixed, so every call uses the same standard-normal
    # samples. That makes g a deterministic function of vb_free_params, and
    # values at the initial, refit and linear-response parameters are
    # directly comparable.
    fixed_key = jax.random.PRNGKey(0)

    return structure_model_lib.get_e_num_pred_clusters(
        admix_params['stick_means'],
        admix_params['stick_infos'],
        gh_loc,
        gh_weights,
        key = fixed_key,
        n_samples = 100)

In [None]:
# Derivative of g with respect to the free VB parameters (argnums = 0),
# evaluated at the VB optimum. jax.jacobian is used; this is the gradient
# when g is scalar-valued.
get_grad_g = jax.jacobian(g, argnums = 0)
grad_g = get_grad_g(vb_opt, vb_params_paragami)

### the influence function

In [None]:
# Influence-function operator built from the VB optimum, the VB paragami
# pattern, the Hessian solver of the linear-response object, and
# prior_params_dict['dp_prior_alpha']. stick_key selects the
# 'ind_admix_params' sticks.
influence_operator = influence_lib.InfluenceOperator(vb_opt, 
                           vb_params_paragami, 
                           epsilon_sens.hessian_solver,
                           prior_params_dict['dp_prior_alpha'], 
                           stick_key = 'ind_admix_params')

### worst-case perturbation

In [None]:
# Evaluate the influence of g (through grad_g) on 200 logit-stick values.
# NOTE(review): the range [-10, 10] matches logit_v_lb / logit_v_ub passed to
# func_sens_lib.PriorPerturbation later; keep the two in sync.
logit_v_grid = np.linspace(-10, 10, 200)
influence_grid = influence_operator.get_influence(logit_v_grid, grad_g)

In [None]:
# Worst-case perturbation built from the precomputed influence values on
# logit_v_grid. influence_fun is None, so presumably only the cached grid is used.
worst_case_pert = influence_lib.WorstCasePerturbation(influence_fun = None, 
                                                      logit_v_grid = logit_v_grid, 
                                                      cached_influence_grid = influence_grid)

### Plot influence function

In [None]:
# Three panels: sign of the influence and the influence itself over logit(v),
# then the influence over v (worst_case_pert.v_grid).
fig, ax = plt.subplots(1, 3, figsize = (15, 4)) 

ax[0].plot(worst_case_pert.logit_v_grid, np.sign(worst_case_pert.influence_grid))
ax[0].set_xlabel('logit v')
ax[0].set_ylabel('sign(Influence)')

ax[1].plot(worst_case_pert.logit_v_grid, worst_case_pert.influence_grid)
ax[1].set_xlabel('logit v')
ax[1].set_ylabel('Influence')

ax[2].plot(worst_case_pert.v_grid, worst_case_pert.influence_grid)
ax[2].set_xlabel('v')
ax[2].set_ylabel('Influence')


fig.tight_layout()

# Define prior perturbation

In [None]:
# this contains a suite of perturbations, all built with the same logit_v_grid,
# influence_grid and 'ind_admix_params' stick key used above
f_obj_all = log_phi_lib.LogPhiPerturbations(vb_params_paragami, 
                                                 prior_params_dict['dp_prior_alpha'],
                                                 gh_loc, 
                                                 gh_weights,
                                                 logit_v_grid = logit_v_grid, 
                                                 influence_grid = influence_grid, 
                                                 stick_key = 'ind_admix_params')

# name of the perturbation 
perturbation = 'worst_case'

# get the object holding the methods for the chosen perturbation: the
# attribute named 'f_obj_' + perturbation (here f_obj_all.f_obj_worst_case)
f_obj = getattr(f_obj_all, 'f_obj_' + perturbation)

In [None]:
# Prior-perturbation object combining the base prior parameter
# alpha0 = dp_prior_alpha with the selected perturbation's log_phi, on
# logit-stick bounds [-10, 10]. In this notebook it is only used for plotting.
prior_perturbation = func_sens_lib.PriorPerturbation(
                                alpha0 = prior_params_dict['dp_prior_alpha'],
                                log_phi = f_obj.log_phi, 
                                logit_v_ub = 10, 
                                logit_v_lb = -10)

In [None]:
# Visualize the perturbation at epsilon = 1.0. The panels show the base prior
# p0 and the perturbed prior pc (legend label 'p1') in logit space (log and
# linear scale), log phi, and both priors in stick space v = expit(logit v).
# NOTE(review): set_epsilon mutates prior_perturbation; it does not appear to
# be used after this cell (the refits go through f_obj directly).
prior_perturbation.set_epsilon(1.0)

v_grid = sp.special.expit(logit_v_grid)

plt.figure(1, figsize=(18, 5))

plt.subplot(141)
plt.plot(logit_v_grid, prior_perturbation.get_log_p0_logit(logit_v_grid))
plt.plot(logit_v_grid, prior_perturbation.get_log_pc_logit(logit_v_grid))
plt.title('Log priors in logit space')

plt.subplot(142)
plt.plot(logit_v_grid, prior_perturbation.log_phi(logit_v_grid))
plt.title('log phi in logit space')

plt.subplot(143)
plt.plot(v_grid, np.exp(prior_perturbation.get_log_p0(v_grid)))
plt.plot(v_grid, np.exp(prior_perturbation.get_log_pc(v_grid)))
plt.title('Priors in stick space')

plt.subplot(144)
plt.plot(logit_v_grid, np.exp(prior_perturbation.get_log_p0_logit(logit_v_grid)),
            label = 'p0')
plt.plot(logit_v_grid, np.exp(prior_perturbation.get_log_pc_logit(logit_v_grid)),
            label = 'p1')
plt.title('Priors in logit space')
plt.legend()


# Get derivative for prior perturbation

In [None]:
# Install the real perturbation objective (replacing the zero placeholder)
# and, presumably, compute and solve the cross-Hessian term that
# predict_opt_par_from_hyper_par uses.
# NOTE(review): this calls a private method of
# HyperparameterSensitivityLinearApproximation; it may break on library updates.
epsilon_sens._set_cross_hess_and_solve(f_obj.hyper_par_objective_fun)

In [None]:
# Linear-response prediction of the optimal free VB parameters at
# epsilon = 0.1, compared against g at the unperturbed optimum.
epsilon = 0.1
print('Epsilon: ', epsilon)

lr_vb_free_params = epsilon_sens.predict_opt_par_from_hyper_par(epsilon)

print('init number of cluster: ', g(vb_opt, vb_params_paragami))
print('lr number of cluster: ', g(lr_vb_free_params, vb_params_paragami))

In [None]:
# Fold the linear-response free parameters back into a VB parameter dict.
# NOTE(review): vb_pert_pred_dict does not appear to be used below.
vb_pert_pred_dict = vb_params_paragami.fold(lr_vb_free_params, free = True)

# Re-optimize

In [None]:
# Warm start for the re-optimizations. Deep-copied, presumably so that
# in-place changes cannot reach vb_opt_dict.
new_init_dict = deepcopy(vb_opt_dict)

In [None]:
# Re-optimize from the warm start with the perturbed expected log phi at the
# current global epsilon. Only the first two outputs are kept: the VB dict
# and the free-parameter vector.
# NOTE(review): the lambda reads the global `epsilon` when it is called. If
# this cell is re-run after a later cell has rebound `epsilon` (e.g. the loops
# over epsilon_list), the refit uses that value, not 0.1.
vb_pert_dict, vb_opt_pert = \
    s_optim_lib.run_preconditioned_lbfgs(g_obs, 
                                            new_init_dict,
                                            vb_params_paragami,
                                            prior_params_dict,
                                            gh_loc = gh_loc,
                                            gh_weights = gh_weights,
                                            e_log_phi = lambda means, infos : \
                                                           f_obj.e_log_phi_epsilon(means,
                                                                                       infos,
                                                                                       epsilon))[0:2]


In [None]:
# Euclidean distance between the perturbed and unperturbed free-parameter optima.
np.linalg.norm(vb_opt_pert - vb_opt)

### compare

In [None]:
def print_diff_plot(lr_vb_free_params, vb_opt_pert, vb_opt):
    """Plot re-optimized against linear-response changes in the free VB params.

    Despite the name, this only plots (on the current matplotlib axes) and
    returns None. Each red '+' is one free parameter: its x value is the
    linear-response change ``lr_vb_free_params - vb_opt`` and its y value is
    the re-optimized change ``vb_opt_pert - vb_opt``. The blue line is y = x,
    so points on it mean the linear approximation matches the refit.
    """
    lr_change = lr_vb_free_params - vb_opt
    refit_change = vb_opt_pert - vb_opt

    plt.plot(lr_change, refit_change, '+', color = 'red')
    # Reference diagonal: the linear-response change plotted against itself.
    plt.plot(lr_change, lr_change, '-', color = 'blue')

    plt.xlabel('lr')
    plt.ylabel('re-optimized')


In [None]:
# compare free parameters: red points on the blue y = x line mean the
# linear-response prediction agrees with the re-optimized fit
print_diff_plot(lr_vb_free_params, vb_opt_pert, vb_opt)

In [None]:
# Posterior quantity g at the unperturbed optimum, the refit, and the
# linear-response prediction.
print('init number of cluster: ', g(vb_opt, vb_params_paragami))
print('pert number of cluster: ', g(vb_opt_pert, vb_params_paragami))
print('lr number of cluster: ', g(lr_vb_free_params, vb_params_paragami))

# Fit for a range of epsilon

In [None]:
# Eight epsilon values in [0, 1]. Squaring an evenly spaced grid puts more
# points close to 0, where the linear approximation is expected to be accurate.
epsilon_list = np.linspace(0, 1, 8) ** 2
print(epsilon_list)

In [None]:
def refit_with_epsilon(epsilon, new_init_dict):
    """Re-optimize the VB objective under the prior perturbation at ``epsilon``.

    Runs ``s_optim_lib.run_preconditioned_lbfgs`` from ``new_init_dict``, with
    the expected log phi term replaced by
    ``f_obj.e_log_phi_epsilon(means, infos, epsilon)``. Returns the second
    element of the optimizer output, which is the free-parameter optimum
    (the same slot that is unpacked as ``vb_opt`` for the unperturbed fit).
    """
    def perturbed_e_log_phi(means, infos):
        return f_obj.e_log_phi_epsilon(means, infos, epsilon)

    lbfgs_output = s_optim_lib.run_preconditioned_lbfgs(
        g_obs,
        new_init_dict,
        vb_params_paragami,
        prior_params_dict,
        gh_loc = gh_loc,
        gh_weights = gh_weights,
        e_log_phi = perturbed_e_log_phi)

    return lbfgs_output[1]


In [None]:
# Echo the epsilon values that the refit loop will sweep over.
print('epsilons: ', epsilon_list)

In [None]:
# Re-fit the VB optimum at every epsilon in epsilon_list, warm-starting each
# run from new_init_dict. vb_pert_list holds the refit free-parameter vectors
# in the same order as epsilon_list.
# The loop variable is `eps`, not `epsilon`, so this cell does not overwrite
# the global `epsilon` used by the single-epsilon cells above.
vb_pert_list = []
for eps in epsilon_list:
    print('\n re-optimizing with epsilon = ', eps)

    vb_pert_list.append(refit_with_epsilon(eps, new_init_dict))


# Check free parameters

In [None]:
# Linear-response prediction of the optimal free VB parameters at each
# epsilon, in the same order as epsilon_list. The comprehension keeps its
# loop variable out of the notebook's global namespace, so the global
# `epsilon` is not overwritten.
lr_list = [epsilon_sens.predict_opt_par_from_hyper_par(eps)
           for eps in epsilon_list]


In [None]:
# One figure per epsilon comparing re-optimized and linear-response changes
# in the free VB parameters, relative to the unperturbed optimum vb_opt.
# lr_list and vb_pert_list are parallel to epsilon_list. Check the lengths
# up front so that zip cannot silently drop unmatched entries.
if not (len(epsilon_list) == len(lr_list) == len(vb_pert_list)):
    raise ValueError('epsilon_list, lr_list and vb_pert_list must have equal length')

for eps, lr_free_params, refit_free_params in zip(epsilon_list, lr_list, vb_pert_list):
    plt.figure()
    print_diff_plot(lr_free_params, refit_free_params, vb_opt)

    plt.title('epsilon = {}'.format(eps))


# Number of clusters

In [None]:
# Posterior quantity g (expected number of clusters) at each epsilon, for
# the re-optimized fit and for the linear-response prediction.
# Check the lengths up front: a shorter list would otherwise leave trailing
# zeros in the result vectors.
if not (len(epsilon_list) == len(lr_list) == len(vb_pert_list)):
    raise ValueError('epsilon_list, lr_list and vb_pert_list must have equal length')

lr_e_num_clusters_vec = onp.zeros(len(epsilon_list))
refit_e_num_clusters_vec = onp.zeros(len(epsilon_list))

for i, (refit_free_params, lr_free_params) in enumerate(zip(vb_pert_list, lr_list)):
    refit_e_num_clusters_vec[i] = g(refit_free_params, vb_params_paragami)
    lr_e_num_clusters_vec[i] = g(lr_free_params, vb_params_paragami)


In [None]:
# Expected number of posterior clusters against epsilon: linear response
# (dashed, 'lr') vs. re-optimized fit (solid, 'refit').
plt.plot(epsilon_list, lr_e_num_clusters_vec, '+--')
plt.plot(epsilon_list, refit_e_num_clusters_vec, '+-')

plt.xlabel('epsilon')
plt.ylabel('num posterior clusters')
plt.legend(('lr', 'refit'))
