In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
# Plain NumPy is imported as `onp` because `np` is bound to jax.numpy in this notebook.
import numpy as onp
# Seed NumPy's global RNG so anything drawing from `onp.random` is repeatable.
onp.random.seed(53453)

# Load data

In [3]:
# Toggle between the small simulated data set and the HGDP genotype file.
use_simulated = True

if use_simulated:
    data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
    g_obs = np.array(data['g_obs'])
else:
    data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
    filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
    filename = data_dir + filenamebase + '.npz'
    data = np.load(filename)

    g_obs = np.array(data['g_obs'])
    g_obs_raw = np.array(data['g_obs_raw'])

    # Sanity checks on the encoding. These were bare expressions whose
    # results were thrown away; asserting makes a malformed file fail loudly.
    # NOTE(review): a raw value of 3 is taken to mark a missing genotype
    # (per the variable name) -- confirm against the data preparation code.
    which_missing = (g_obs_raw == 3)
    assert (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all(), \
        'g_obs.argmax(-1) disagrees with g_obs_raw at non-missing entries'
    assert (g_obs[which_missing] == 0).all(), \
        'g_obs is not all zero at missing entries'

In [4]:
# the two leading axes of g_obs give the number of observations and of loci
n_obs, n_loci = g_obs.shape[:2]

In [5]:
# show the data dimensions, one per line
print(n_obs, n_loci, sep='\n')

20
50


# Get prior

In [6]:
# Only the paragami pattern is kept; the first return value is discarded.
_, prior_params_paragami = structure_model_lib.get_default_prior_params()

# Draw the prior parameters at random (fixed PRNG key), then flatten them
# with free = True.
prior_params_dict = prior_params_paragami.random(
    key=jax.random.PRNGKey(41)
)
prior_params_free = prior_params_paragami.flatten(
    prior_params_dict,
    free=True,
)

# Get VB params 

In [7]:
# Passed as `k_approx` to get_vb_params_paragami_object when the VB parameter
# pattern is built below.
k_approx = 15

In [8]:
# Number of Gauss-Hermite quadrature points.
gh_deg = 8
# Quadrature nodes and weights. The degree comes from gh_deg rather than a
# repeated literal, so changing gh_deg actually changes the rule.
gh_loc, gh_weights = hermgauss(gh_deg)

In [9]:
use_logitnormal_sticks = True

# Build an initial VB parameter dict together with its paragami pattern.
(vb_params_dict,
 vb_params_paragami) = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 14) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 14) (lb=0.0001, ub=inf)


In [10]:
# Replace the initial VB parameters with a random draw (fixed PRNG key).
vb_params_dict = vb_params_paragami.random(
    key=jax.random.PRNGKey(41)
)

# Define individual admixture objective

In [11]:
# Initial moments computed from vb_params_dict with the Gauss-Hermite rule.
(e_log_sticks,
 e_log_1m_sticks,
 e_log_pop_freq,
 e_log_1m_pop_freq) = structure_model_lib.get_moments_from_vb_params_dict(
    vb_params_dict,
    gh_loc,
    gh_weights,
)

In [12]:
# This objective uses the "pseudo-loss".
stick_objective = s_optim_lib.StickObjective(
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
)

compiling stick objective and gradients ...
compile time: 6.68sec


In [13]:
# Flatten (free = True) only the 'ind_admix_params' entry of the VB parameters.
ind_admix_params_free = vb_params_paragami['ind_admix_params'].flatten(
    vb_params_dict['ind_admix_params'],
    free=True,
)

In [14]:
# The "correct" loss. Positional argument 1 is supplied as a flattened
# (free = True) vector following the 'ind_admix_params' pattern.
_ind_admix_obj = paragami.FlattenFunctionInput(
    original_fun=lambda *args: s_optim_lib.get_ind_admix_params_loss(
        *args, detach_ez=True
    ),
    patterns=vb_params_paragami['ind_admix_params'],
    free=True,
    argnums=1,
)

# jit-compiled version of the same objective
ind_admix_obj = jax.jit(_ind_admix_obj)


# Check gradients match

In [15]:
# jit-compiled gradient of the "correct" loss w.r.t. positional argument 1
ind_admix_grad = jax.jit(
    jax.grad(_ind_admix_obj, argnums=1)
)

In [16]:
# Gradient from the stick objective, which is given the precomputed
# population-frequency moments.
grad1 = stick_objective.grad(
    g_obs,
    ind_admix_params_free,
    e_log_pop_freq,
    e_log_1m_pop_freq,
).block_until_ready()

# Gradient of the "correct" loss, which is given the beta parameters instead.
grad2 = ind_admix_grad(
    g_obs,
    ind_admix_params_free,
    vb_params_dict['pop_freq_beta_params'],
    prior_params_dict,
    gh_loc,
    gh_weights,
).block_until_ready()

In [17]:
# Enforce the check instead of only eyeballing the number: a Run All should
# stop here if the two gradient implementations ever disagree.
onp.testing.assert_allclose(grad1, grad2, rtol=0, atol=1e-8)

# Largest absolute discrepancy, displayed as the cell output.
np.abs(grad1 - grad2).max()

DeviceArray(2.13162821e-14, dtype=float64)

### Check timing

In [18]:
# Time one call of the stick-objective gradient. perf_counter() is the clock
# meant for short intervals; time.time() is wall-clock time and may jump.
# block_until_ready() makes the timing include the actual computation.
t0 = time.perf_counter()
grad1 = stick_objective.grad(g_obs, ind_admix_params_free, 
                              e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

0.0036454200744628906


In [19]:
# Time one call of the "correct"-loss gradient. perf_counter() is the clock
# meant for short intervals; time.time() is wall-clock time and may jump.
# block_until_ready() makes the timing include the actual computation.
t0 = time.perf_counter()

grad2 = ind_admix_grad(g_obs, 
                    ind_admix_params_free, 
                    vb_params_dict['pop_freq_beta_params'], 
                    prior_params_dict,
                    gh_loc, gh_weights).block_until_ready()

print(time.perf_counter() - t0)


0.0071086883544921875


# Check Hessian-vector products

In [20]:
# Our get-HVP helper requires a function with only one input argument, so
# every argument except the free admixture parameters is closed over.
def f(x):
    return _ind_admix_obj(
        g_obs,
        x,
        vb_params_dict['pop_freq_beta_params'],
        prior_params_dict,
        gh_loc,
        gh_weights,
    )


ind_admix_hvp = jax.jit(get_jac_hvp_fun(f))

In [21]:
# Hessian-vector product from the stick objective; the parameter vector
# itself is passed as the final (vector) argument.
hvp1 = stick_objective.hvp(
    g_obs,
    ind_admix_params_free,
    e_log_pop_freq,
    e_log_1m_pop_freq,
    ind_admix_params_free,
)

# Same product computed through the "correct" loss.
hvp2 = ind_admix_hvp(
    ind_admix_params_free,
    ind_admix_params_free,
)

In [22]:
# Enforce the check instead of only eyeballing the number: a Run All should
# stop here if the two Hessian-vector products ever disagree.
onp.testing.assert_allclose(hvp1, hvp2, rtol=0, atol=1e-8)

# Largest absolute discrepancy, displayed as the cell output.
np.abs(hvp1 - hvp2).max()

DeviceArray(7.10542736e-15, dtype=float64)

In [23]:
# Time one stick-objective Hessian-vector product. perf_counter() is the clock
# meant for short intervals; time.time() is wall-clock time and may jump.
# block_until_ready() makes the timing include the actual computation.
t0 = time.perf_counter()
hvp1 = stick_objective.hvp(g_obs, ind_admix_params_free, 
                           e_log_pop_freq, e_log_1m_pop_freq, 
                           ind_admix_params_free).block_until_ready()
print(time.perf_counter() - t0)

0.0047512054443359375


In [24]:
# Time one Hessian-vector product of the "correct" loss. perf_counter() is the
# clock meant for short intervals; time.time() is wall-clock time and may jump.
# The result is stored in hvp2 (as the sibling timing cells store theirs)
# rather than in `_`, which would shadow IPython's last-output variable.
t0 = time.perf_counter()
hvp2 = ind_admix_hvp(ind_admix_params_free, ind_admix_params_free).block_until_ready()
print(time.perf_counter() - t0)

0.006402015686035156
