In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp
# Seed NumPy's global RNG for reproducibility. This does not seed JAX, whose
# randomness is driven by explicit PRNG keys.
onp.random.seed(53453)

# Load data

In [3]:
# Alternative dataset (HGDP):
# data_file = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/' + \
#                 'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz'
data_file = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'

# Loading an .npz returns a lazily-read NpzFile that keeps the file open.
# Read it with plain NumPy inside a context manager so the handle is closed
# once g_obs has been copied into a JAX array.
with onp.load(data_file) as data:
    g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# Axis 0 of g_obs indexes observations and axis 1 indexes loci.
n_obs, n_loci = g_obs.shape[0], g_obs.shape[1]

for dim in (n_obs, n_loci):
    print(dim)

20
50


# Get prior

In [5]:
# Default prior hyperparameters, plus the paragami pattern used to flatten them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params())

print(prior_params_dict)

# Flat vector of free (unconstrained) prior parameters.
prior_params_free = prior_params_paragami.flatten(
    prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get variational (VB) parameters

In [6]:
# Number of components in the variational approximation. The printed pattern
# below shows pop_freq_beta_params with k_approx columns and the stick
# parameters with k_approx - 1 columns.
k_approx = 8

In [7]:
# Degree (number of nodes) of the Gauss-Hermite quadrature rule passed to the KL.
gh_deg = 8
# Build the rule from gh_deg instead of repeating the literal 8, so changing
# the degree in one place actually changes the quadrature.
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# Parameterize the stick-breaking proportions with logit-normal distributions
# (their means and infos appear in the pattern printed below).
use_logitnormal_sticks = True

vb_params_dict, vb_params_paragami = (
    structure_model_lib.get_vb_params_paragami_object(
        n_obs, n_loci, k_approx,
        use_logitnormal_sticks=use_logitnormal_sticks))

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


## Initialize the VB parameters

In [9]:
# Overwrite the default VB parameters with the library's NMF-based
# initialization (per the helper's name), using a fixed seed.
vb_params_dict = s_optim_lib.set_nmf_init_vb_params(
    g_obs, k_approx, vb_params_dict, seed=143241)

In [10]:
# KL objective evaluated at the initial VB parameters.
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict,
                            gh_loc, gh_weights)

DeviceArray(5396.53199239, dtype=float64)

In [11]:
# Flatten the initial VB parameters into a single free (unconstrained) vector.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free = True)

# Define objective

In [12]:
import inspect
# Print the source of structure_model_lib.get_e_loglik for reference.
lines = inspect.getsource(structure_model_lib.get_e_loglik)
print(lines)

def get_e_loglik(g_obs, e_log_pop_freq, e_log_1m_pop_freq, \
                    e_log_sticks, e_log_1m_sticks,
                    detach_ez, detach_vb_params):


    e_log_cluster_probs = \
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
                            e_log_sticks, e_log_1m_sticks)
    
    body_fun = lambda val, x : get_e_loglik_l(x[0], x[1], x[2],
                                             e_log_cluster_probs,
                                             detach_ez, 
                                             detach_vb_params) + val
    
    scan_fun = lambda val, x : (body_fun(val, x), None)
    
    return jax.lax.scan(scan_fun,
                        init = 0.,
                        xs = (g_obs.transpose((1, 0, 2)),
                              e_log_pop_freq, 
                              e_log_1m_pop_freq))[0]



In [13]:
# Wrap get_kl so that its second positional argument (the VB parameters) is
# supplied as a flat vector of free (unconstrained) parameters.
_kl_fun_free = paragami.FlattenFunctionInput(
    original_fun=structure_model_lib.get_kl,
    patterns=vb_params_paragami,
    free=True,
    argnums=1)


def kl_fun_free(x):
    """KL objective as a function of the free VB parameter vector only.

    Data, priors and quadrature nodes come from the notebook globals; both
    detach flags are False (the un-detached objective).
    """
    return _kl_fun_free(g_obs, x, prior_params_dict, gh_loc, gh_weights,
                        detach_ez=False, detach_vb_params=False)


# Gradient and dense Hessian of the KL with respect to the free parameters.
get_kl_grad = jax.grad(kl_fun_free)
get_kl_hessian = jax.hessian(kl_fun_free)

In [14]:
# Evaluate the flattened objective; the saved output matches the dict-based KL
# computed earlier.
kl_fun_free(vb_params_free)

DeviceArray(5396.53199239, dtype=float64)

In [15]:
# Gradient of the KL w.r.t. the free VB parameters at the initialization.
kl_grad = get_kl_grad(vb_params_free)

In [16]:
# Dense Hessian of the KL. jax.hessian materializes the full matrix, so this is
# only practical for small problems like this one.
kl_hess = get_kl_hessian(vb_params_free)

# Custom Hessian-vector product (HVP)

In [17]:
def kl_fun_free_customized(x, detach_ez, detach_vb_params):
    """KL as a function of the free VB parameters, with caller-chosen detach flags.

    Same as kl_fun_free except that detach_ez and detach_vb_params are
    forwarded to get_kl instead of being fixed to False.
    """
    return _kl_fun_free(g_obs, x, prior_params_dict, gh_loc, gh_weights,
                        detach_ez=detach_ez,
                        detach_vb_params=detach_vb_params)

In [18]:
# TODO(review): commented-out alternative using get_jac_hvp_fun (imported at the
# top of the notebook); either enable it or delete this cell.
# get_dkl_dtheta2_v = get_jac_hvp_fun(lambda x : kl_fun_free_customized(x, True, False))

In [19]:
# Hessian of the KL with detach_ez=True. Combined with the E[z] correction built
# from ps_loss_z below, it reproduces kl_hess @ v (compare the outputs of the two
# HVP cells below).
term1 = jax.hessian(lambda x : kl_fun_free_customized(x, True, False))(vb_params_free)

detaching ez


In [20]:
# NOTE(review): star import. In the visible cells it supplies
# get_moments_from_vb_params_dict and get_optimal_ezl (used by ps_loss_z);
# consider importing those names explicitly to keep the namespace predictable.
from vb_lib.structure_model_lib import *

In [21]:
def ps_loss_z(vb_params_dict):
    """Element-wise square roots of the optimal E[z] at every locus, flattened.

    For each locus, element [1] of ``get_optimal_ezl`` (presumably the optimal
    E[z] for that locus -- confirm against structure_model_lib) is computed from
    the current variational moments. Its element-wise square root is taken and
    the results for all loci are concatenated into one flat vector. The
    notebook differentiates this map with ``jax.jvp`` / ``jax.vjp`` to build the
    E[z] correction to the z-detached Hessian.

    Uses the notebook-level ``g_obs``, ``gh_loc`` and ``gh_weights``.

    Args:
        vb_params_dict: dictionary of variational parameters matching
            ``vb_params_paragami``.

    Returns:
        1-d array of sqrt(E[z]) values with loci stacked first.
    """
    # Qualify the library helpers explicitly so this cell does not rely on the
    # `from vb_lib.structure_model_lib import *` cell having been run.
    e_log_sticks, e_log_1m_sticks, e_log_pop_freq, e_log_1m_pop_freq = \
        structure_model_lib.get_moments_from_vb_params_dict(
            vb_params_dict, gh_loc = gh_loc, gh_weights = gh_weights)

    e_log_cluster_probs = \
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
            e_log_sticks, e_log_1m_sticks)

    def _sqrt_ez_at_locus(locus_inputs):
        # Square root of the optimal E[z] for a single locus, flattened.
        g_obs_l, e_log_pop_freq_l, e_log_1m_pop_freq_l = locus_inputs
        ez_l = structure_model_lib.get_optimal_ezl(
            g_obs_l, e_log_pop_freq_l, e_log_1m_pop_freq_l,
            e_log_cluster_probs, detach_ez = False)[1]
        return np.sqrt(ez_l).flatten()

    # Map over loci. g_obs is transposed so that its loci axis (axis 1) leads,
    # matching the leading axis of the population-frequency moments.
    per_locus = jax.lax.map(
        _sqrt_ez_at_locus,
        (g_obs.transpose((1, 0, 2)), e_log_pop_freq, e_log_1m_pop_freq))

    return per_locus.flatten()

In [22]:
# ps_loss_z with its dict argument (position 0) replaced by the flat vector of
# free VB parameters, so it can be differentiated with jax.jvp / jax.vjp.
ps_loss_z_free = paragami.FlattenFunctionInput(
    original_fun=ps_loss_z, patterns=vb_params_paragami,
    free=True, argnums=0)

In [24]:
# Reference: jax.jvp(fun, primals, tangents) returns (fun(*primals), J @ tangents),
# where J is the Jacobian of fun at primals.

In [28]:
# Jacobian-vector product J v of ps_loss_z_free at v = vb_params_free (the tangent
# equals the primal here). Despite its name, grad_tmp is a JVP, not a gradient.
grad_tmp = jax.jvp(ps_loss_z_free, (vb_params_free, ), (vb_params_free, ))[1]

In [43]:
# Display the JVP computed above.
grad_tmp

DeviceArray([-0.26249259, -0.26249259, -0.28121894, ...,  0.00209781,
              0.43190552,  0.43190552], dtype=float64)

In [31]:
# Linearize ps_loss_z_free at vb_params_free; vjp_fun(u) returns the tuple (J^T u,).
# NOTE(review): the DeviceArray saved under this cell is stale (an assignment
# displays nothing) and should be cleared.
y, vjp_fun = jax.vjp(ps_loss_z_free, vb_params_free)

DeviceArray([ 0.15906895, -0.37978205, -0.06907311, ...,  0.3872382 ,
              0.44535961,  0.47694177], dtype=float64)

In [40]:
# HVP assembled from the z-detached Hessian (term1) minus a correction through
# the optimal E[z]. grad_tmp = J v and vjp_fun(u)[0] = J^T u, where J is the
# Jacobian of ps_loss_z_free (i.e. of sqrt(E[z])) and v = vb_params_free.
# Since J = dE[z] / (2 sqrt(E[z])), the factor 4 turns J^T J into
# sum_i dE[z_i] dE[z_i]^T / E[z_i].
hvp_decomposed = np.dot(term1, vb_params_free) - 4 * vjp_fun(grad_tmp)[0]

# It must agree with the HVP from the dense KL Hessian; check it explicitly
# rather than comparing printed arrays by eye.
onp.testing.assert_allclose(hvp_decomposed, np.dot(kl_hess, vb_params_free),
                            rtol=1e-6, atol=1e-8)
hvp_decomposed

DeviceArray([ 2.25740299, -0.09683445, -5.82278243, ...,  0.69953368,
             -0.01185289, -0.50171633], dtype=float64)

In [41]:
# Reference HVP from the dense KL Hessian; the saved output matches the
# decomposed HVP in the previous cell.
np.dot(kl_hess, vb_params_free)

DeviceArray([ 2.25740299, -0.09683445, -5.82278243, ...,  0.69953368,
             -0.01185289, -0.50171633], dtype=float64)

In [26]:
def ps_loss(x1, x2):
    """Product surrogate that splits KL derivatives between x1 and x2.

    x1 enters through the KL with detach_ez=True and x2 through the KL with
    detach_vb_params=True. Dividing by a stop-gradient copy of the x1 term
    makes the forward value kl_x1 * kl_x2 / kl_x1 (kl_x2 up to rounding),
    while the gradient w.r.t. x1 is grad(kl_x1) * kl_x2 / kl_x1.
    """
    kl_x1 = kl_fun_free_customized(x1, True, False)
    kl_x2 = kl_fun_free_customized(x2, False, True)
    kl_x1_const = jax.lax.stop_gradient(kl_x1)
    return kl_x1 * kl_x2 / kl_x1_const

In [27]:
# Gradient of ps_loss w.r.t. x1 only (the path through the detach_ez=True KL).
get_dkl_dtheta = jax.grad(ps_loss, argnums = 0)

In [39]:
# At x1 == x2 the x1-gradient of ps_loss should reproduce the full KL gradient;
# assert this instead of eyeballing the printed difference.
dkl_dtheta_minus_grad = get_dkl_dtheta(vb_params_free, vb_params_free) - kl_grad
assert np.abs(dkl_dtheta_minus_grad).max() < 1e-8, \
    'x1-gradient of ps_loss does not match kl_grad'
dkl_dtheta_minus_grad

detaching ez
detaching vb params


DeviceArray([-8.88178420e-16,  8.88178420e-16, -5.32907052e-15, ...,
              0.00000000e+00,  4.44089210e-16,  0.00000000e+00],            dtype=float64)

In [28]:
# Mixed second derivative of ps_loss: Jacobian w.r.t. x2 of the x1-gradient,
# evaluated at x1 = x2 = vb_params_free.
term3 = jax.jacobian(lambda x : get_dkl_dtheta(vb_params_free, x))(vb_params_free)

# FIXME(review): `term2` is used in the `term1 - term2 + 2 * term3` cell below,
# but this commented-out line is its only definition. As written it is the
# Hessian of a gradient, i.e. a rank-3 array rather than an (n, n) matrix.
# term2 = jax.hessian(lambda x : jax.grad(ps_loss, argnums = 1)(vb_params_free, x))(vb_params_free)

detaching ez
detaching vb params


In [46]:
# Largest |entry| of term3; the saved output (~5e-17) shows it is numerically zero.
np.abs(term3).max()

DeviceArray(5.48943093e-17, dtype=float64)

In [40]:
# FIXME(review): `term2` is not defined by any cell visible here (its only
# definition is commented out above), so this raises NameError on a fresh kernel;
# the saved output comes from an earlier kernel state.
term1 - term2 + 2 * term3

DeviceArray([[ 4.84130351, -1.25762774,  0.74459675, ...,  1.00809741,
               0.7334582 ,  0.33664258],
             [-1.25762774,  0.96394136, -0.74131977, ..., -1.00366076,
              -0.73023024, -0.33516101],
             [ 0.74459675, -0.74131977,  9.84005946, ..., -0.22693199,
              -0.16510818, -0.07578134],
             ...,
             [ 1.00809741, -1.00366076, -0.22693199, ...,  1.73172179,
              -0.22353727, -0.10259912],
             [ 0.7334582 , -0.73023024, -0.16510818, ..., -0.22353727,
               1.62606697, -0.07464771],
             [ 0.33664258, -0.33516101, -0.07578134, ..., -0.10259912,
              -0.07464771,  1.60220803]], dtype=float64)

In [41]:
# Dense KL Hessian, for comparison with the assembled matrix in the previous cell.
kl_hess

DeviceArray([[ 5.96604554e+00, -2.74168502e+00,  1.86555298e+00, ...,
               3.67664923e-06,  3.96854417e-06,  4.31185842e-06],
             [-2.74168502e+00,  2.57629049e+00, -1.36591512e+00, ...,
              -2.66603883e-06, -2.87769983e-06, -3.12664638e-06],
             [ 1.86555298e+00, -1.36591512e+00,  2.18200337e+00, ...,
               1.06604788e-06,  1.15068309e-06,  1.25022738e-06],
             ...,
             [ 3.67664923e-06, -2.66603883e-06,  1.06604788e-06, ...,
               1.98575217e+00, -4.72202325e-02, -4.25869908e-02],
             [ 3.96854417e-06, -2.87769983e-06,  1.15068309e-06, ...,
              -4.72202325e-02,  1.73085000e+00, -5.16537938e-02],
             [ 4.31185842e-06, -3.12664638e-06,  1.25022738e-06, ...,
              -4.25869908e-02, -5.16537938e-02,  1.57707094e+00]],            dtype=float64)

In [None]:
# Hessian of ps_loss w.r.t. its second argument; this cell has not been executed
# in the saved notebook (In [None]).
foo3 = jax.hessian(ps_loss, argnums = 1)(vb_params_free, vb_params_free)

In [141]:
# FIXME(review): `hvp_true` is not defined by any cell visible in this notebook,
# so this fails on a fresh kernel; the saved output is stale.
hvp_true

DeviceArray([-5.22529225, -0.07056143, -4.81493041, ..., -0.12960552,
             -0.60288711, -1.12083729], dtype=float64)

In [151]:
# FIXME(review): `foo` and `foo2` are not defined by any cell visible here, and
# the cell defining `foo3` has never been executed (In [None]); the saved output
# is stale.
foo + 2 * foo2 + foo3

DeviceArray([[ 7.98715958e+00, -5.28266157e-01,  4.62184693e-03, ...,
              -1.05419961e-04, -7.86721146e-05, -5.90573576e-05],
             [-5.28180217e-01,  1.87539284e-01, -7.62565287e-04, ...,
               1.73933935e-05,  1.29802273e-05,  9.74395984e-06],
             [ 4.49670019e-03, -8.30200420e-04,  8.16509627e+00, ...,
              -1.08294410e-04, -8.08172401e-05, -6.06676542e-05],
             ...,
             [ 1.11415401e-04, -2.05699977e-05,  1.17638653e-04, ...,
               6.13819331e-01, -2.00242062e-06, -1.50317138e-06],
             [ 1.13075309e-04, -2.08764572e-05,  1.19391278e-04, ...,
              -2.72320222e-06,  5.90889408e-01, -1.52556619e-06],
             [ 9.91589920e-05, -1.83071660e-05,  1.04697647e-04, ...,
              -2.38805437e-06, -1.78214151e-06,  5.55193807e-01]],            dtype=float64)

In [143]:
# FIXME(review): `foo2` is not defined by any cell visible here; stale output.
foo2 

DeviceArray([[ 2.18867232e-03, -4.04082239e-04,  2.31092346e-03, ...,
              -5.27099803e-05, -3.93360573e-05, -2.95286788e-05],
             [-3.61112249e-04,  6.66701199e-05, -3.81282643e-04, ...,
               8.69669677e-06,  6.49011365e-06,  4.87197992e-06],
             [ 2.24835009e-03, -4.15100210e-04,  2.37393462e-03, ...,
              -5.41472051e-05, -4.04086200e-05, -3.03338271e-05],
             ...,
             [ 5.57077004e-05, -1.02849989e-05,  5.88193265e-05, ...,
              -1.34161325e-06, -1.00121031e-06, -7.51585688e-07],
             [ 5.65376546e-05, -1.04382286e-05,  5.96956389e-05, ...,
              -1.36160111e-06, -1.01612672e-06, -7.62783093e-07],
             [ 4.95794960e-05, -9.15358300e-06,  5.23488233e-05, ...,
              -1.19402719e-06, -8.91070754e-07, -6.68906441e-07]],            dtype=float64)

In [145]:
# FIXME(review): depends on `foo` and `foo2`, which no visible cell defines.
my_hess = (foo + foo2)

In [148]:
# Largest asymmetry of my_hess. A true Hessian is symmetric, so this should be
# ~0; the saved value (~0.047) indicates my_hess is not a valid Hessian.
np.abs(my_hess - my_hess.transpose()).max()

DeviceArray(0.04655508, dtype=float64)

In [133]:
# Gradient of ps_loss w.r.t. x2 (the path through the detach_vb_params=True KL).
jax.grad(ps_loss, argnums = 1)(vb_params_free, vb_params_free)

DeviceArray([-6.42836712,  1.1868332 , -6.78743194, ...,  0.15481491,
              0.11553425,  0.08672892], dtype=float64)

In [134]:
# FIXME(review): `grad_true` is not defined by any cell visible here; stale output.
grad_true

DeviceArray([-6.79433951,  1.12100802, -6.97959842, ..., -0.17293453,
             -0.17551098, -0.15391063], dtype=float64)

In [66]:
# FIXME(review): `foo2` is undefined here. The saved length (3460) also disagrees
# with the 1080 free parameters implied by the vb_params_paragami pattern printed
# above, so this output comes from an earlier run.
len(foo2)

3460

In [67]:
# FIXME(review): `hvp_true` is undefined here. The saved length (3460) also
# disagrees with the 1080 free parameters implied by the vb_params_paragami
# pattern printed above, so this output comes from an earlier run.
len(hvp_true)

3460

In [24]:
# Reference: jax.jvp(fun, primals, tangents) returns (fun(*primals), J @ tangents),
# where J is the Jacobian of fun at primals.

In [23]:
# Gradient of ps_loss w.r.t. x2, i.e. the same function evaluated inline in
# the In [133] cell above.
get_dkl_dtheta_through_ez = jax.grad(ps_loss, argnums = 1)

In [None]:
# Evaluate the x2-gradient at x1 = x2 = vb_params_free (not yet executed).
get_dkl_dtheta_through_ez(vb_params_free, 
                          vb_params_free)

In [None]:
# FIXME(review): calling this lambda fails. np.dot receives the function
# get_dkl_dtheta_through_ez itself rather than its value, and jax.numpy.dot takes
# only two positional array arguments. If a Jacobian-vector product of the x2-gradient
# was intended, something like
#   jax.jvp(lambda x: get_dkl_dtheta_through_ez(x1, x), (x2,), (v,))[1]
# may be what is wanted -- confirm the intended semantics.
get_dkl_dtheta_through_ez_jvp = lambda x1, x2 : np.dot(get_dkl_dtheta_through_ez, x1, x2)