In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
# Plain NumPy is imported as `onp` because `np` is bound to jax.numpy in this notebook.
import numpy as onp
# Seed NumPy's global RNG so the random draws made with `onp.random` are reproducible.
onp.random.seed(53453)

# Load data

In [3]:
# Location of the genotype archive, relative to the notebook's working directory.
data_file = (
    '../../../../../fastStructure/hgdp_data/huang2011_plink_files/'
    'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz'
)

data = np.load(data_file)

# Pull out the observed genotypes as an integer-typed array.
g_obs = np.array(data['g_obs'], dtype=int)

In [4]:
# The first two axes of the genotype array give the number of observations
# and the number of loci.
n_obs, n_loci = g_obs.shape[:2]

print(n_obs)
print(n_loci)

1107
2810


# Get prior

In [5]:
# Default prior parameters, together with the paragami pattern describing them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)

print(prior_params_dict)

# Flat, free-parameterized copy of the prior parameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# Size of the variational approximation. In the parameter pattern printed further
# down, the population-frequency block has k_approx columns and the stick
# blocks have k_approx - 1.
k_approx = 25

In [7]:
# Degree of the Gauss-Hermite quadrature rule. The nodes (`gh_loc`) and weights
# (`gh_weights`) are passed to the KL and logit-normal expectation routines.
gh_deg = 8
# Build the rule from `gh_deg` rather than a repeated literal, so changing the
# degree above actually changes the quadrature.
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# Use the logit-normal stick parameterization (stick_means / stick_infos).
use_logitnormal_sticks = True

# Variational parameter dict and the paragami pattern describing its layout.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 25, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 24) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 24) (lb=0.0001, ub=inf)


## Initialize 

In [9]:
# Replace the variational parameters with the initialization returned by
# set_nmf_init_vb_params (fixed seed for reproducibility).
# NOTE(review): this rebinds `vb_params_dict`, so re-running the cell feeds the
# already-initialized dict back in — confirm the initializer ignores its values.
vb_params_dict = \
        s_optim_lib.set_nmf_init_vb_params(g_obs, k_approx, vb_params_dict,
                                                seed = 143241)

In [10]:
# Evaluate the KL objective at the initial variational parameters; the value is
# shown as the cell output.
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict,
                            gh_loc, gh_weights)

DeviceArray(41777567.83342966, dtype=float64)

In [11]:
# Flat, free-parameterized vector of the initial variational parameters; this is
# the argument every objective / derivative function below is evaluated at.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free = True)

# Define objective

In [12]:
# Print the library's expected log-likelihood implementation for reference.
import inspect
lines = inspect.getsource(structure_model_lib.get_e_loglik)
print(lines)

def get_e_loglik(g_obs, e_log_pop_freq, e_log_1m_pop_freq, \
                    e_log_sticks, e_log_1m_sticks,
                    detach_ez, detach_vb_params):


    e_log_cluster_probs = \
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
                            e_log_sticks, e_log_1m_sticks)
    
    body_fun = lambda val, x : get_e_loglik_l(x[0], x[1], x[2],
                                             e_log_cluster_probs,
                                             detach_ez, 
                                             detach_vb_params) + val
    
    scan_fun = lambda val, x : (body_fun(val, x), None)
    
    return jax.lax.scan(scan_fun,
                        init = 0.,
                        xs = (g_obs.transpose((1, 0, 2)),
                              e_log_pop_freq, 
                              e_log_1m_pop_freq))[0]



In [13]:
# Wrap the library KL so that argument 1 (the variational parameters) is
# supplied as a flat, free vector instead of a parameter dict.
_kl_fun_free = paragami.FlattenFunctionInput(
    original_fun=structure_model_lib.get_kl,
    patterns=vb_params_paragami,
    free=True,
    argnums=1,
)


def kl_fun_free(x):
    """KL objective as a function of the free variational-parameter vector `x`.

    The data, prior parameters and quadrature rule are read from the notebook
    globals; `detach_ez = True` is forwarded to `structure_model_lib.get_kl`.
    """
    return _kl_fun_free(g_obs, x, prior_params_dict,
                        gh_loc, gh_weights,
                        detach_ez=True)


# Jitted objective and gradient.
get_kl = jax.jit(kl_fun_free)
get_grad = jax.jit(jax.grad(kl_fun_free))

# Function of (parameters, vector) built by get_jac_hvp_fun — presumably a
# Hessian-vector product of the objective; confirm against sensitivity_lib.
get_hvp_theta = get_jac_hvp_fun(kl_fun_free)

# Compile times (first call of each function)

In [14]:
# First call of the jitted KL, so the elapsed time includes tracing/compilation.
t0 = time.time() 
# block_until_ready() waits for the asynchronous JAX computation to finish
# before the timer is read.
_ = get_kl(vb_params_free).block_until_ready()
print('elapsed: {:.3f}sec'.format(time.time() - t0))

detaching ez
elapsed: 19.336sec


In [15]:
# First call of the jitted gradient, so the elapsed time includes tracing/compilation.
t0 = time.time() 
# block_until_ready() waits for the asynchronous JAX computation to finish
# before the timer is read.
_ = get_grad(vb_params_free).block_until_ready()
print('elapsed: {:.3f}sec'.format(time.time() - t0))

detaching ez
elapsed: 29.538sec


In [16]:
# Random standard-normal vector with the same length as the free parameter
# vector; used as the direction in the vector-product timings.
v = np.array(onp.random.randn(len(vb_params_free)))

In [17]:
# First call of get_hvp_theta at (vb_params_free, v).
t0 = time.time() 
# block_until_ready() waits for the asynchronous JAX computation to finish
# before the timer is read.
_ = get_hvp_theta(vb_params_free, v).block_until_ready()
print('elapsed: {:.3f}sec'.format(time.time() - t0))

detaching ez
elapsed: 24.997sec


# Objective and derivative times (repeat calls)

In [18]:
# Repeat call of the jitted KL (one repetition; raise the range to average
# over more runs).
for i in range(1): 
    t0 = time.time() 
    # Wait for the asynchronous result before reading the timer.
    _ = get_kl(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 13.267sec


In [19]:
# Repeat call of the jitted gradient (one repetition; raise the range to
# average over more runs).
for i in range(1): 
    t0 = time.time() 
    # Wait for the asynchronous result before reading the timer.
    _ = get_grad(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 15.121sec


In [20]:
# Repeat call of get_hvp_theta; the result is kept in `hvp_theta`.
# NOTE(review): the recorded output prints 'detaching ez' again on this call,
# which suggests the objective is re-traced here rather than served from a
# compiled cache — confirm whether get_jac_hvp_fun jits its result.
for i in range(1): 
    t0 = time.time() 
    # Wait for the asynchronous result before reading the timer.
    hvp_theta = get_hvp_theta(vb_params_free, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.time() - t0))

detaching ez
elapsed: 21.059sec


In [22]:
# Import only the library names used by ps_loss_zl. A wildcard import would
# rebind notebook names that the library also defines (for example the jitted
# `get_kl` wrapper created earlier) to the library versions, unless the module
# restricts its exports.
from vb_lib.structure_model_lib import ef, get_optimal_ezl

In [23]:
def ps_loss_zl(vb_params_dict, g_obs, indx_l):
    """Flattened square root of the optimal-e_z term for a single locus.

    Parameters
    ----------
    vb_params_dict : dict
        Variational parameters; reads `ind_admix_params` (`stick_means`,
        `stick_infos`) and `pop_freq_beta_params`.
    g_obs : array
        Observed genotypes; column `indx_l` of the second axis is used.
    indx_l : int
        Index of the locus.

    Returns
    -------
    array
        `np.sqrt` of element 1 of the `get_optimal_ezl` output, flattened
        (presumably the optimal e_z for this locus — confirm against the library).

    The quadrature rule (`gh_loc`, `gh_weights`) is read from the notebook globals.
    """
    admix_params = vb_params_dict['ind_admix_params']

    # Expected log sticks under the logit-normal, then expected log cluster
    # probabilities derived from them.
    e_log_sticks, e_log_1m_sticks = ef.get_e_log_logitnormal(
        lognorm_means=admix_params['stick_means'],
        lognorm_infos=admix_params['stick_infos'],
        gh_loc=gh_loc,
        gh_weights=gh_weights,
    )
    e_log_cluster_probs = (
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
            e_log_sticks, e_log_1m_sticks)
    )

    # Expected log population frequencies for this locus only.
    beta_params_l = vb_params_dict['pop_freq_beta_params'][indx_l]
    e_log_pop_freq_l, e_log_1m_pop_freq_l = modeling_lib.get_e_log_beta(beta_params_l)

    # A leading axis of length one is added to the per-locus expectations.
    ezl_out = get_optimal_ezl(
        g_obs[:, indx_l],
        np.expand_dims(e_log_pop_freq_l, 0),
        np.expand_dims(e_log_1m_pop_freq_l, 0),
        e_log_cluster_probs,
        detach_ez=False,
    )

    return np.sqrt(ezl_out[1]).flatten()

In [24]:
# Version of ps_loss_zl whose first argument is the flat, free parameter vector
# rather than the variational-parameter dict.
ps_loss_zl_free = paragami.FlattenFunctionInput(
    original_fun=ps_loss_zl,
    patterns=vb_params_paragami,
    free=True,
    argnums=0,
)

In [33]:
@jax.jit
def get_hvp_zz(vb_params_free, v):
    """Sum over loci of J_l^T (J_l v).

    J_l is the Jacobian of `ps_loss_zl_free(., g_obs, l)` evaluated at
    `vb_params_free`. The forward-mode product J_l v is pulled back through the
    same function with reverse mode, and the per-locus results are accumulated
    with `jax.lax.scan` over the second axis of `g_obs`.
    """

    def accumulate(running_sum, l):
        loss_l = lambda x: ps_loss_zl_free(x, g_obs, l)

        # Forward mode: J_l v.
        jac_v = jax.jvp(loss_l, (vb_params_free,), (v,))[1]
        # Reverse mode: J_l^T applied to the forward-mode result.
        pullback = jax.vjp(loss_l, vb_params_free)[1]
        jt_jac_v = pullback(jac_v)[0]

        # scan expects (carry, per-step output); nothing is stacked per step.
        return jt_jac_v + running_sum, None

    locus_indices = np.arange(g_obs.shape[1])
    total, _ = jax.lax.scan(accumulate,
                            init=np.zeros(len(vb_params_free)),
                            xs=locus_indices)
    return total

In [35]:
# Time get_hvp_zz at (vb_params_free, v); the result is kept in `hvp_zz`.
t0 = time.time() 
# Wait for the asynchronous result before reading the timer.
hvp_zz = get_hvp_zz(vb_params_free, v).block_until_ready()
print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 54.717sec


In [30]:
# Display the vector returned by get_hvp_zz.
hvp_zz

DeviceArray([ 0.71833449, -0.07177968,  2.5052839 , ...,  7.58099324,
              8.8273031 ,  9.19396516], dtype=float64)

In [None]:
# TODO: `term2` was left unfinished. A bare `term2 =` is a SyntaxError that stops
# Restart & Run All, so the line is commented out until the expression is written.
# term2 =

# Preconditioned objective

In [28]:
# Build the preconditioned objective from the data, parameter pattern, prior and
# quadrature rule. Per the recorded output, construction also compiles it
# (about 205 s in that run).
precond_objective = s_optim_lib.StructurePrecondObjective(g_obs, 
                                                           vb_params_paragami, 
                                                           prior_params_dict, 
                                                           gh_loc, gh_weights)

compiling preconditioned objective ... 
done. Elasped: 205.473


In [29]:
# Time the preconditioned objective; the free parameter vector is passed for
# both arguments.
# NOTE(review): `np` is jax.numpy here, so np.array may not force device
# synchronization — if f_precond returns a JAX array, confirm the timing is
# not cut short by asynchronous dispatch.
t0 = time.time() 
_ = np.array(precond_objective.f_precond(vb_params_free, vb_params_free))
print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 5.284sec


In [30]:
# Time the preconditioned gradient; the free parameter vector is passed for
# both arguments.
# NOTE(review): `np` is jax.numpy here, so np.array may not force device
# synchronization — if grad_precond returns a JAX array, confirm the timing is
# not cut short by asynchronous dispatch.
t0 = time.time() 
_ = np.array(precond_objective.grad_precond(vb_params_free, vb_params_free))
print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 9.165sec


In [31]:
# Time the preconditioned vector product; the free parameter vector is reused
# for all three arguments, so only the timing is meaningful here.
# NOTE(review): `np` is jax.numpy here, so np.array may not force device
# synchronization — if hvp_precond returns a JAX array, confirm the timing is
# not cut short by asynchronous dispatch.
t0 = time.time() 
_ = np.array(precond_objective.hvp_precond(vb_params_free, vb_params_free, vb_params_free))
print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 16.065sec
