In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun, HyperparameterSensitivityLinearApproximation



In [2]:
import numpy as onp
onp.random.seed(53453)

# Load data

In [3]:
data_file = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/' + \
                'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz'

# data_file = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'
data = np.load(data_file)
g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

print(n_obs)
print(n_loci)

1107
2810


# Load fit

In [5]:
# fit_file = '../../fits/fits_20201122/simulated_fit_alpha7.5.npz'
fit_file = '../../fits/fits_20201122/huang2011_fit_alpha3.5.npz'

In [6]:
vb_opt_dict, vb_params_paragami, meta_data = \
    paragami.load_folded(fit_file)

vb_opt = vb_params_paragami.flatten(vb_opt_dict, free = True)

k_approx = vb_opt_dict['pop_freq_beta_params'].shape[1]

gh_deg = int(meta_data['gh_deg'])
gh_loc, gh_weights = hermgauss(gh_deg)

gh_loc = np.array(gh_loc)
gh_weights = np.array(gh_weights)


# Get prior

In [7]:
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

prior_params_dict['dp_prior_alpha'] = np.array(meta_data['dp_prior_alpha'])
prior_params_dict['allele_prior_alpha'] = np.array(meta_data['allele_prior_alpha'])
prior_params_dict['allele_prior_beta'] = np.array(meta_data['allele_prior_beta'])

print(prior_params_dict)

prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)


{'dp_prior_alpha': DeviceArray([3.5], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Define objective 

In [8]:
stru_objective = s_optim_lib.StructureObjective(g_obs, 
                                                vb_params_paragami,
                                                prior_params_dict, 
                                                gh_loc, gh_weights, 
                                                jit_functions = False)

# check KL's match
kl = stru_objective.f(vb_opt)
diff = np.abs(kl - meta_data['final_kl'])
assert diff < 1e-8, diff


# Define preconditioner

In [9]:
cg_precond = lambda v : get_mfvb_cov_matmul(v, vb_opt_dict,
                                            vb_params_paragami,
                                            return_sqrt = False, 
                                            return_info = True)


# Hyper parameter objective function

In [10]:
import bnpmodeling_runjingdev.exponential_families as ef

In [11]:
def _hyper_par_objective_fun(vb_params_dict, alpha): 
    
    means = vb_params_dict['ind_admix_params']['stick_means']
    infos = vb_params_dict['ind_admix_params']['stick_infos']

    e_log_1m_sticks = \
        ef.get_e_log_logitnormal(
            lognorm_means = means,
            lognorm_infos = infos,
            gh_loc = gh_loc,
            gh_weights = gh_weights)[1]

    return - (alpha - 1) * np.sum(e_log_1m_sticks)

hyper_par_objective_fun = paragami.FlattenFunctionInput(
                                original_fun=_hyper_par_objective_fun, 
                                patterns = [vb_params_paragami, prior_params_paragami['dp_prior_alpha']],
                                free = [True, True],
                                argnums = [0, 1])

In [12]:
alpha0 = prior_params_dict['dp_prior_alpha']
alpha_free = prior_params_paragami['dp_prior_alpha'].flatten(alpha0, 
                                                              free = True)

In [None]:
vb_sens = HyperparameterSensitivityLinearApproximation(stru_objective.f, 
                                             vb_opt, 
                                             alpha_free, 
                                             obj_fun_hvp=stru_objective.hvp, 
                                             hyper_par_objective_fun=hyper_par_objective_fun, 
                                             cg_precond=cg_precond)

NOTE custom hvp
Compiling hessian solver ...


In [None]:
vb_sens.lr_time

In [18]:
cross_hessian = dobj_dhyper_dinput(vb_opt, alpha_free).block_until_ready()

In [27]:
from jax.scipy.sparse.linalg import cg

In [28]:
@jax.jit
def invert_hessian(b): 
    return cg(A = stru_hvp, 
                b = b, 
                M = cg_precond)[0]

In [29]:
# compile time
t0 = time.time()
_ = invert_hessian(cross_hessian.squeeze() * 0.).block_until_ready()
print(time.time() - t0)

75.18531966209412


In [30]:
t0 = time.time()
_ = invert_hessian(cross_hessian.squeeze()).block_until_ready()
print(time.time() - t0)

0.44893884658813477


In [40]:
stru_hvp = jax.jit(lambda x : stru_objective.hvp(vb_opt, x))

print('compiling hessian vector products ...')
t0 = time.time()
_ = stru_hvp(vb_opt).block_until_ready()

compiling hessian vector products ...


In [46]:
# runtime ...
for i in range(5): 
    t0 = time.time()
    _ = stru_hvp(vb_opt).block_until_ready()
    print(time.time() - t0)

0.03542518615722656
0.029263734817504883
0.029251575469970703
0.029056310653686523
0.029242753982543945


# define preconditioner

In [47]:
# define preconditioner
cg_precond = jax.jit(lambda v : get_mfvb_cov_matmul(v, vb_opt_dict,
                                            vb_params_paragami,
                                            return_sqrt = False, 
                                            return_info = True))

print('compiling preconditioner ...')
_ = cg_precond(vb_opt).block_until_ready()


compiling preconditioner ...


# Get alpha sensitivity 

In [50]:
print('Computing cross-hessian ...  ')

# actual runtime: 
t0 = time.time() 
cross_hessian = dobj_dhyper_dinput(vb_opt, alpha_free)
print('Cross-hessian time: {:.3f}'.format(time.time() - t0))


Computing cross-hessian ...  
Cross-hessian time: 0.001


In [55]:
# from jax.scipy.sparse.linalg import cg
from scipy.sparse.linalg import cg

In [62]:
import scipy.sparse as sparse

In [79]:
A = sparse.linalg.LinearOperator(shape = (len(vb_opt), len(vb_opt)), 
                                  matvec = stru_hvp)
M = sparse.linalg.LinearOperator(shape = (len(vb_opt), len(vb_opt)), 
                                  matvec = cg_precond)

t0 = time.time()
cg_out = cg(A = A, 
            b = cross_hessian.squeeze(), 
            M = M, 
            callback = lambda x : print('foo'))

print(time.time() - t0)

foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
foo
1.256953239440918
