In [1]:
import autograd

import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

import scipy as osp

import sys
sys.path.insert(0, '../')

import structure_model_lib
import structure_optimization_lib as str_opt_lib
import preconditioner_lib 

import paragami
import vittles

from copy import deepcopy

import argparse
import distutils.util

import os

import time

import data_utils

import matplotlib.pyplot as plt
%matplotlib inline  

from BNP_modeling import cluster_quantities_lib, modeling_lib
import BNP_modeling.optimization_lib as opt_lib


In [2]:
# Seed the global NumPy RNG (reached through the autograd.numpy wrapper) so
# that reruns are repeatable. Presumably data_utils.draw_data draws from this
# global stream -- confirm in data_utils.
np.random.seed(53453)

# Draw data

In [4]:
# Sizes handed to data_utils.draw_data when simulating the data set;
# n_obs and n_loci are reused when the VB parameters are built.
n_obs, n_loci, n_pop = 10, 6, 4

In [7]:
# draw_data returns an indexable result; only its first entry (the observed
# data used throughout the notebook) is kept under the name g_obs.
simulated = data_utils.draw_data(n_obs, n_loci, n_pop)
g_obs = simulated[0]

# Get prior

In [8]:
# Default prior hyperparameters, returned together with the paragami pattern
# that is later used to flatten them into a free vector.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

# Show the pattern so the hyperparameter names, shapes and bounds are visible.
print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [9]:
# Passed as `k_approx` when the VB parameters are built and initialized.
# Presumably the number of components kept in the variational approximation
# -- confirm against structure_model_lib.
k_approx = 4

In [10]:
# Gauss-Hermite quadrature rule: `gh_loc` are the nodes and `gh_weights` the
# weights of a rule of degree `gh_deg`.
# The degree is read from `gh_deg` instead of being repeated as a literal, so
# changing `gh_deg` actually changes the rule that is built.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [11]:
# Flag forwarded to every model/optimizer call below; here the logit-normal
# stick parameterization is switched off.
use_logitnormal_sticks = False

# Build the VB parameter dictionary and the paragami pattern that
# flattens/folds it.
(vb_params_dict,
 vb_params_paragami) = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks)

## Initialize 

In [12]:
# Rebind vb_params_dict to the initial values computed by set_init_vb_params
# from the data and the current dictionary.
vb_params_dict = structure_model_lib.set_init_vb_params(
    g_obs, k_approx, vb_params_dict, use_logitnormal_sticks)

# Optimize

In [13]:
# Keyword options for the optimizer, gathered in one place.
# NOTE(review): `netwon_max_iter` (sic) is kept exactly as spelled because it
# has to match the keyword accepted by str_opt_lib.optimize_structure --
# confirm there before renaming.
optimizer_options = dict(
    use_logitnormal_sticks=use_logitnormal_sticks,
    run_cavi=True,
    cavi_max_iter=100,
    cavi_tol=1e-2,
    netwon_max_iter=20,
    max_precondition_iter=25,
    gtol=1e-8,
    ftol=1e-8,
    xtol=1e-8,
    approximate_hessian=True,
)

# Run the optimizer; the result is kept as a flat vector in the free
# parameterization (it is folded with free=True just below).
vb_opt_free_params = str_opt_lib.optimize_structure(
    g_obs, vb_params_dict, vb_params_paragami, prior_params_dict,
    gh_loc, gh_weights,
    **optimizer_options)

# Fold the free vector back into a parameter dictionary.
vb_opt_dict = vb_params_paragami.fold(vb_opt_free_params, free=True)

iteration [0]; kl:175.811206
iteration [1]; kl:149.944569
iteration [2]; kl:144.741878
iteration [3]; kl:142.551463
iteration [4]; kl:141.172401
iteration [5]; kl:140.196596
iteration [6]; kl:139.491752
iteration [7]; kl:138.981946
iteration [8]; kl:138.608283
iteration [9]; kl:138.326908
iteration [10]; kl:138.108797
iteration [11]; kl:137.935623
iteration [12]; kl:137.795538
iteration [13]; kl:137.680534
iteration [14]; kl:137.584977
iteration [15]; kl:137.50478
iteration [16]; kl:137.436881
iteration [17]; kl:137.378912
iteration [18]; kl:137.328991
iteration [19]; kl:137.285588
iteration [20]; kl:137.247453
iteration [21]; kl:137.21356
iteration [22]; kl:137.183076
iteration [23]; kl:137.155332
iteration [24]; kl:137.129796
iteration [25]; kl:137.106052
iteration [26]; kl:137.083777
iteration [27]; kl:137.062725
iteration [28]; kl:137.042712
iteration [29]; kl:137.023602
iteration [30]; kl:137.005304
iteration [31]; kl:136.987757
iteration [32]; kl:136.97093
iteration [33]; kl:136.

In [14]:
# From here on vb_params_dict refers to an independent copy of the fitted
# parameters; the initial values bound to this name earlier are replaced.
vb_params_dict = deepcopy(vb_opt_dict)

# Sensitivities

### Get sensitivity class

In [17]:
# Wrap structure_model_lib.get_kl so that its positional arguments 1 and 2
# (matched, in order, to the VB pattern and the prior pattern) are supplied
# as flat vectors in the free parameterization.
get_kl_from_vb_free_prior_free = paragami.FlattenFunctionInput(
    original_fun=structure_model_lib.get_kl,
    patterns=[vb_params_paragami, prior_params_paragami],
    free=True,
    argnums=[1, 2],
)

In [18]:
def objective_fun(x, y):
    """KL objective as a function of two flat free-parameter vectors.

    x -- VB parameters, flattened in the free parameterization.
    y -- prior parameters, flattened in the free parameterization.

    The data, the stick flag and the quadrature rule are looked up in the
    notebook's global namespace each time the function is called.
    """
    return get_kl_from_vb_free_prior_free(
        g_obs, x, y, use_logitnormal_sticks, gh_loc, gh_weights)

In [19]:
# Baseline: no custom solver is supplied, so vittles computes the Hessian
# itself (see the "computing hessian ..." line in the output). Wall-clock
# time of the construction is reported.
t0 = time.time()

vb_sens = vittles.HyperparameterSensitivityLinearApproximation(
    objective_fun=objective_fun,
    opt_par_value=vb_opt_free_params,
    hyper_par_value=prior_params_paragami.flatten(
        prior_params_dict, free=True),
    validate_optimum=False,
    factorize_hessian=True,
    hyper_par_objective_fun=None,
    grad_tol=1e-8,
)

elapsed = time.time() - t0
print('time: {:.08}sec'.format(elapsed))

computing hessian ... 
computing cross hessian ... 
solving sensitivity matrix ... 
time: 0.9278667sec


## Preconditioned CG?

In [20]:
# Preconditioner ingredients computed from the current VB parameters. The
# helper returns a pair; presumably the MFVB covariance and its inverse
# (information) -- confirm in preconditioner_lib.
(mfvb_cov,
 mfvb_info) = preconditioner_lib.get_mfvb_cov_preconditioner(
    vb_params_dict, vb_params_paragami, use_logitnormal_sticks)

In [21]:
# Hessian-vector product of the objective with respect to its first argument
# (the VB free parameters); lets the solver below work without forming the
# Hessian explicitly.
hvp = autograd.hessian_vector_product(objective_fun, argnum=0)

In [22]:
# Point at which the derivatives are evaluated: independent copies of the
# optimum and of the flattened (free) prior parameters.
opt0, hyper0 = (
    deepcopy(vb_opt_free_params),
    deepcopy(prior_params_paragami.flatten(prior_params_dict, free=True)),
)

In [23]:
# Linear-system solver built from the HVP and the evaluation point
# (opt0, hyper0). Presumably an iterative CG solve of H x = b, given the
# `cg_opts` argument used with it further down -- confirm in preconditioner_lib.
system_solver = preconditioner_lib.SystemSolverFromHVP(hvp, opt0, hyper0)

In [24]:
# Gradient of the objective with respect to argument 0 (VB free parameters) ...
_hyper_obj_fun_grad = autograd.grad(objective_fun, argnum=0)
# ... and the Jacobian of that gradient with respect to argument 1 (prior free
# parameters), i.e. the mixed second derivative ("cross Hessian").
_hyper_obj_cross_hess = autograd.jacobian(_hyper_obj_fun_grad, argnum=1)

In [25]:
# Evaluate the cross Hessian at the optimum and the current prior parameters.
cross_hess = _hyper_obj_cross_hess(opt0, hyper0)

In [26]:
# Sensitivity matrix: the negated solution of the HVP-based system with the
# cross Hessian as right-hand side (presumably -H^{-1} @ cross_hess -- confirm
# the convention of SystemSolverFromHVP.solve).
sens_mat = - system_solver.solve(cross_hess)

In [27]:
# Check against the sensitivity matrix computed by vittles (vb_sens): largest
# absolute entrywise difference. Note this reads the private `_sens_mat`
# attribute of the vittles object.
np.max(np.abs(vb_sens._sens_mat - sens_mat))

7.134802772573723e-05

In [28]:
# Check the `system_solver` and `compute_hess` arguments of
# HyperparameterSensitivityLinearApproximation, using the HVP-based solver.

In [29]:
# Same sensitivity object as the baseline, but the linear solve is delegated
# to `system_solver` and the Hessian computation is switched off
# (compute_hess=False). Wall-clock time of the construction is reported.
t0 = time.time()

vb_sens2 = vittles.HyperparameterSensitivityLinearApproximation(
    objective_fun=objective_fun,
    opt_par_value=vb_opt_free_params,
    hyper_par_value=prior_params_paragami.flatten(
        prior_params_dict, free=True),
    validate_optimum=False,
    factorize_hessian=True,
    hyper_par_objective_fun=None,
    grad_tol=1e-8,
    system_solver=system_solver,
    compute_hess=False,
)

elapsed = time.time() - t0
print('time: {:.08}sec'.format(elapsed))

computing cross hessian ... 
solving sensitivity matrix ... 
time: 1.5493872sec


In [30]:
# Largest absolute entrywise difference between the baseline sensitivity
# matrix and the one obtained through the HVP-based solver (private
# `_sens_mat` attributes of both vittles objects).
np.max(np.abs(vb_sens._sens_mat - vb_sens2._sens_mat))

7.134802772573723e-05

In [31]:
# Now check with a preconditioner: the same HVP-based solver, with `mfvb_info`
# supplied as the 'M' entry of the CG options.
# NOTE(review): if `cg_opts` is forwarded to scipy.sparse.linalg.cg, 'M' is
# expected to approximate the INVERSE of the system matrix -- confirm that
# `mfvb_info` (rather than `mfvb_cov`) is the intended preconditioner.
cg_options = {'M': mfvb_info}
system_solver2 = preconditioner_lib.SystemSolverFromHVP(
    hvp, opt0, hyper0, cg_opts=cg_options)

In [32]:
# Sensitivity object built with the preconditioned solver `system_solver2`;
# the Hessian computation is again switched off (compute_hess=False).
# Wall-clock time of the construction is reported.
t0 = time.time()

vb_sens3 = vittles.HyperparameterSensitivityLinearApproximation(
    objective_fun=objective_fun,
    opt_par_value=vb_opt_free_params,
    hyper_par_value=prior_params_paragami.flatten(
        prior_params_dict, free=True),
    validate_optimum=False,
    factorize_hessian=True,
    hyper_par_objective_fun=None,
    grad_tol=1e-8,
    system_solver=system_solver2,
    compute_hess=False,
)

elapsed = time.time() - t0
print('time: {:.08}sec'.format(elapsed))

computing cross hessian ... 
solving sensitivity matrix ... 
time: 1.2601016sec


In [33]:
# Largest absolute entrywise difference between the baseline sensitivity
# matrix and the one obtained with the preconditioned solver (private
# `_sens_mat` attributes of both vittles objects).
np.max(np.abs(vb_sens._sens_mat - vb_sens3._sens_mat))

1.837315988245658e-05