In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
import vb_lib.structure_optimization_lib as s_optim_lib

from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp

# Seed NumPy's global RNG so the random batch index drawn later is reproducible.
onp.random.seed(seed=53453)

# Load data

In [3]:
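# Alternative input (currently unused): a small simulated dataset; uncomment to use it instead of the HGDP data.
# data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
# g_obs = np.array(data['g_obs'], dtype = int)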

In [4]:
data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
filename = data_dir + filenamebase + '.npz'
data = np.load(filename)

# g_obs carries a trailing category axis (argmax(-1) below recovers a label);
# g_obs_raw holds the raw integer codes, with 3 treated as "missing".
g_obs = np.array(data['g_obs'], dtype = int)
g_obs_raw = np.array(data['g_obs_raw'])

# Sanity checks on the encoding. These are asserts so that a mismatch stops the
# notebook. As bare expressions, only the last one would be displayed and the
# first one's result would be silently thrown away.
which_missing = (g_obs_raw == 3)
labels_match = (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
missing_are_zero = (g_obs[which_missing] == 0).all()
assert labels_match, 'g_obs.argmax(-1) disagrees with g_obs_raw on observed entries'
assert missing_are_zero, 'missing entries of g_obs are not all zero'
labels_match & missing_are_zero

DeviceArray(True, dtype=bool)

In [5]:
# Leading two axes of the genotype array: observations, then loci.
n_obs, n_loci = g_obs.shape[:2]

In [6]:
# One dimension per line: n_obs, then n_loci.
print(n_obs, n_loci, sep='\n')

1107
2810


# Get prior

In [7]:
# Default prior hyperparameters: a dict plus the paragami pattern describing it.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

print(prior_params_paragami)

# Flattened prior parameters in paragami's free parameterization.
prior_params_free = prior_params_paragami.flatten(
    prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get variational (VB) parameters

In [8]:
k_approx: int = 15  # number of components; the stick arrays have k_approx - 1 columns

In [9]:
# Gauss-Hermite quadrature nodes and weights. The number of nodes comes from
# gh_deg so the degree is set in exactly one place; previously a literal 8 was
# passed and gh_deg was ignored.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [10]:
use_logitnormal_sticks = True

# Variational parameters (and their paragami pattern) for all n_obs individuals.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 14) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 14) (lb=0.0001, ub=inf)


In [11]:
# Replace the returned parameters with a random draw (fixed PRNG key 41).
vb_params_dict = vb_params_paragami.random(
    key=jax.random.PRNGKey(41))

In [12]:
# get_kl at the random initialization (presumably the KL objective being minimized).
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

DeviceArray(5866337.80064644, dtype=float64)

# Timing results

## Set the initialization

In [13]:
# Overwrite the random start with the library's initialization (seeded).
vb_params_dict = s_optim_lib.set_init_vb_params(
    g_obs, k_approx, vb_params_dict,
    prior_params_dict,
    gh_loc, gh_weights,
    seed=3421,
)

running NMF ...
running a few cavi steps for pop beta ...
done. Elapsed: 32.3909


In [14]:
# get_kl after set_init_vb_params; compare with the value at the random start.
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

DeviceArray(5198875.61385809, dtype=float64)

## Update population beta parameters

In [15]:
# Prior hyperparameters as individual names.
dp_prior_alpha, allele_prior_alpha, allele_prior_beta = (
    prior_params_dict[key]
    for key in ('dp_prior_alpha', 'allele_prior_alpha', 'allele_prior_beta'))

# Initial expectations computed from the current vb_params_dict.
(e_log_sticks, e_log_1m_sticks,
 e_log_pop_freq, e_log_1m_pop_freq) = structure_model_lib.get_moments_from_vb_params_dict(
    vb_params_dict, gh_loc, gh_weights)

e_log_cluster_probs = modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
    e_log_sticks, e_log_1m_sticks)

In [16]:
# jit-wrapped CAVI population-beta update; it is compiled on its first call
# (the "compile update" cell below).
update_pop_beta = jax.jit(
    cavi_lib.update_pop_beta
)

In [17]:
# Warm-up call: compiles the jitted update. We block on the result so the
# compile cost is not counted in the timing cell that follows.
out = update_pop_beta(
    g_obs,
    e_log_pop_freq, e_log_1m_pop_freq,
    e_log_cluster_probs,
    prior_params_dict,
)
_ = out[0].block_until_ready()

In [18]:
# Time one call of the already-compiled update.
# time.perf_counter is a monotonic, high-resolution clock meant for intervals;
# time.time can jump if the wall clock is adjusted.
# block_until_ready waits for JAX's asynchronous dispatch to finish before
# the timer is read.
t0 = time.perf_counter()

_ = update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq,
                    e_log_cluster_probs, prior_params_dict)[0].block_until_ready()

print(time.perf_counter() - t0)

4.152426242828369


## gradients of admixture stick ps-loss

In [19]:
# Stick objective over the full data set. Construction compiles the objective
# and its derivatives (see the printed compile time).
stick_objective = s_optim_lib.StickObjective(
    g_obs, vb_params_paragami, prior_params_dict, gh_loc, gh_weights)

compiling stick objective and gradients ...
compile time: 21.2sec


In [20]:
# Free-parameterized flat vector of the individual admixture (stick) parameters.
admix_pattern = vb_params_paragami['ind_admix_params']
ind_admix_params_free = admix_pattern.flatten(
    vb_params_dict['ind_admix_params'], free=True)
del admix_pattern

In [21]:
# Time one evaluation of the full-data stick objective. perf_counter is
# monotonic and high-resolution, unlike time.time.
t0 = time.perf_counter()
_ = stick_objective.f(g_obs, ind_admix_params_free,
                      e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

4.1950438022613525


In [22]:
# Time one gradient evaluation of the full-data stick objective, using a
# monotonic interval clock.
t0 = time.perf_counter()
_ = stick_objective.grad(g_obs, ind_admix_params_free,
                         e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

4.7713096141815186


In [23]:
# Time one Hessian-vector product of the full-data stick objective. The
# parameter vector itself is used as the direction. Timed with a monotonic
# interval clock.
t0 = time.perf_counter()
_ = stick_objective.hvp(g_obs, ind_admix_params_free, e_log_pop_freq, e_log_1m_pop_freq,
                        ind_admix_params_free).block_until_ready()
print(time.perf_counter() - t0)

6.440979719161987


# Timing on a subsample (a batch of `batchsize` observations)

In [56]:
batchsize = 1

# VB parameters and pattern sized for a single batch of `batchsize` individuals.
vb_params_dict_sub, vb_params_paragami_sub = structure_model_lib.get_vb_params_paragami_object(
    batchsize, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

# Free flat vector of the batch's admixture parameters, taken from the dict
# returned above.
x = vb_params_paragami_sub['ind_admix_params'].flatten(
    vb_params_dict_sub['ind_admix_params'], free=True)

# First `batchsize` rows; used to build the batch objective below.
g_obs_sub = g_obs[:batchsize]

In [57]:
# Batch-sized stick objective, also compiling the full Hessian (compute_hess).
stick_objective_sub = s_optim_lib.StickObjective(
    g_obs_sub,
    vb_params_paragami_sub,
    prior_params_dict,
    gh_loc, gh_weights,
    compute_hess=True,
)

compiling stick objective and gradients ...
compile time: 7.71sec


In [58]:
# Random start index of the contiguous batch g_obs[indx:indx + batchsize].
# Valid starts are 0 .. n_obs - batchsize inclusive. Because choice(a) samples
# from range(a), the bound must be n_obs - batchsize + 1; without the +1 the
# final batch could never be drawn.
indx = onp.random.choice(n_obs - batchsize + 1)

In [59]:
# Time one evaluation of the batch objective at x, using a monotonic clock.
t0 = time.perf_counter()
_ = stick_objective_sub.f(g_obs[indx:(indx+batchsize)],
                          x,
                          e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

0.00915670394897461


In [60]:
# Time one gradient of the batch objective at x, using a monotonic clock.
t0 = time.perf_counter()
_ = stick_objective_sub.grad(g_obs[indx:(indx+batchsize)], x,
                             e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

0.00901341438293457


In [61]:
# Time one Hessian-vector product of the batch objective (direction x),
# using a monotonic clock.
t0 = time.perf_counter()
_ = stick_objective_sub.hvp(g_obs[indx:(indx+batchsize)], x, e_log_pop_freq, e_log_1m_pop_freq,
                            x).block_until_ready()
print(time.perf_counter() - t0)

0.010412454605102539


In [62]:
# Time computing the full batch Hessian at x (kept in `hess`), using a
# monotonic clock.
t0 = time.perf_counter()
hess = stick_objective_sub.hess(g_obs[indx:(indx+batchsize)], x,
                                e_log_pop_freq, e_log_1m_pop_freq).block_until_ready()
print(time.perf_counter() - t0)

0.030073165893554688


# Optimize the stick parameters of one batch (here a single individual, since batchsize = 1)

In [63]:
# Optimize the batch's stick parameters, starting from vb_params_dict_sub.
# Wall time is measured with a monotonic interval clock.
t0 = time.perf_counter()
stick_updates, out = stick_objective_sub.optimize_sticks(
    g_obs[indx:(indx+batchsize)],
    vb_params_dict_sub['ind_admix_params'],
    e_log_pop_freq, e_log_1m_pop_freq)

print(time.perf_counter() - t0)

0.09354281425476074


In [64]:
# Batch objective at the starting point x.
stick_objective_sub.f(
    g_obs[indx:(indx+batchsize)], x, e_log_pop_freq, e_log_1m_pop_freq)

DeviceArray(7517.81766931, dtype=float64)

In [65]:
# Batch objective at the optimizer's returned point out.x; compare with the value at x.
stick_objective_sub.f(
    g_obs[indx:(indx+batchsize)], out.x, e_log_pop_freq, e_log_1m_pop_freq)

DeviceArray(7261.2916878, dtype=float64)