In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp

# Seed NumPy's legacy global RNG so that the random probe vector drawn later
# with onp.random.randn is reproducible across runs.
onp.random.seed(seed=53453)

# Load data

In [3]:
# Genotype data file; the path is relative to the notebook's working directory.
data_file = ('../../../../../fastStructure/hgdp_data/huang2011_plink_files/'
             'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz')

# Loading an .npz archive returns an NpzFile that keeps the underlying file
# open until it is closed.  Use it as a context manager so the handle is
# released once the genotype array has been copied into a JAX array.
with onp.load(data_file) as data:
    g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# Number of individuals (first axis) and loci (second axis) of the genotypes.
n_obs, n_loci = g_obs.shape[:2]

print(n_obs, n_loci, sep='\n')

1107
2810


# Get prior

In [5]:
# Default prior hyperparameters, plus the paragami pattern describing them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params())

print(prior_params_dict)

# Flatten the priors into paragami's free (unconstrained) vector form.
prior_params_free = prior_params_paragami.flatten(
    prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# Truncation level of the variational approximation: the number of latent
# populations (components) fitted.  The stick parameters printed below have
# k_approx - 1 columns.
k_approx = 20

In [7]:
# Degree of the Gauss-Hermite quadrature handed to the KL objectives below.
gh_deg = 8
# Nodes and weights of the gh_deg-point Gauss-Hermite rule (weight function
# exp(-x**2)).  Pass gh_deg rather than repeating the literal, so that
# changing the degree above actually takes effect.
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# Use a logit-normal variational family for the stick-breaking proportions.
use_logitnormal_sticks = True

# Build the variational parameter dictionary and its paragami pattern.
vb_params_dict, vb_params_paragami = (
    structure_model_lib.get_vb_params_paragami_object(
        n_obs, n_loci, k_approx,
        use_logitnormal_sticks=use_logitnormal_sticks))

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 20, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 19) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 19) (lb=0.0001, ub=inf)


## Initialize 

In [9]:
# Initialize the variational parameters with the library's initializer (an
# NMF-based scheme, per its name), using a fixed seed for reproducibility.
# NOTE(review): this rebinds vb_params_dict using its previous value as an
# input -- confirm that re-running this cell is idempotent.
vb_params_dict = s_optim_lib.set_nmf_init_vb_params(
    g_obs, k_approx, vb_params_dict, seed=143241)

In [10]:
# KL objective evaluated at the initialization; the value is the cell output.
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

DeviceArray(37288025.69132804, dtype=float64)

In [11]:
# Flatten the initialized VB parameters into paragami's free (unconstrained)
# vector; the objectives below are evaluated at this point.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free=True)

# Define objective

In [12]:
# Build the objective wrapper.  Its constructor compiles the objective and
# reports how long that took (see the output below).
stru_objective = s_optim_lib.StructureObjective(
    g_obs, vb_params_paragami, prior_params_dict, gh_loc, gh_weights)

compiling objective ... 
done. Elasped: 105.922


# Timing the objective and its derivatives

In [14]:
# Time evaluations of the (already compiled) objective.  JAX dispatches
# work asynchronously, so block on the result before stopping the clock.
# perf_counter is the monotonic, high-resolution clock meant for durations;
# time.time can jump if the system clock is adjusted.
n_timing_reps = 1
for _ in range(n_timing_reps):
    t0 = time.perf_counter()
    stru_objective.f(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 10.589sec


In [15]:
# Time evaluations of the objective's gradient.  Block on the JAX result so
# asynchronous dispatch does not stop the clock early.  Use perf_counter,
# the monotonic clock intended for measuring durations.
n_timing_reps = 1
for _ in range(n_timing_reps):
    t0 = time.perf_counter()
    stru_objective.grad(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 12.363sec


In [16]:
# Random probe direction for timing Hessian-vector products.  It holds one
# standard-normal draw per free variational parameter, taken from the seeded
# global NumPy RNG and converted to a JAX array.
v = np.array(
    onp.random.randn(len(vb_params_free))
)

In [17]:
# Time Hessian-vector products along the random direction v.  Block on the
# JAX result so asynchronous dispatch does not stop the clock early.  Use
# perf_counter, the monotonic clock intended for measuring durations.
n_timing_reps = 1
for _ in range(n_timing_reps):
    t0 = time.perf_counter()
    stru_objective.hvp(vb_params_free, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 29.280sec


# Preconditioned objective

In [18]:
# Build the preconditioned objective wrapper.  Its constructor compiles the
# preconditioned functions and reports how long that took (see the output
# below).
precond_objective = s_optim_lib.StructurePrecondObjective(
    g_obs, vb_params_paragami, prior_params_dict, gh_loc, gh_weights)

compiling preconditioned objective ... 
done. Elasped: 120.225


In [19]:
# Time one evaluation of the preconditioned objective.  Converting the result
# to a host NumPy array forces any asynchronously dispatched JAX work to
# finish before the clock stops; a jax.numpy.array copy is not guaranteed to
# wait.  perf_counter is the monotonic clock intended for measuring durations.
t0 = time.perf_counter()
onp.asarray(precond_objective.f_precond(vb_params_free, vb_params_free))
print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 11.363sec


In [20]:
# Time one evaluation of the preconditioned gradient.  Converting the result
# to a host NumPy array forces any asynchronously dispatched JAX work to
# finish before the clock stops; a jax.numpy.array copy is not guaranteed to
# wait.  perf_counter is the monotonic clock intended for measuring durations.
t0 = time.perf_counter()
onp.asarray(precond_objective.grad_precond(vb_params_free, vb_params_free))
print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 13.458sec


In [21]:
# Time one preconditioned Hessian-vector product along v.  Converting the
# result to a host NumPy array forces any asynchronously dispatched JAX work
# to finish before the clock stops; a jax.numpy.array copy is not guaranteed
# to wait.  perf_counter is the monotonic clock intended for durations.
t0 = time.perf_counter()
onp.asarray(
    precond_objective.hvp_precond(vb_params_free, vb_params_free, v))
print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 31.294sec


In [22]:
# Time repeated calls of `unprecondition`; several repetitions show the
# run-to-run variation.  Block on the JAX result so asynchronous dispatch
# does not stop the clock early.  Use perf_counter for durations, and report
# in the same format as the other timing cells.
n_timing_reps = 5
for _ in range(n_timing_reps):
    t0 = time.perf_counter()
    precond_objective.unprecondition(vb_params_free, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

0.45952844619750977
0.4561750888824463
0.453643798828125
0.4516425132751465
0.4517691135406494
