In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp
# Seed NumPy's global random number generator before any random draws are made.
# NOTE(review): presumably this is what makes the `.random()` draws further down
# reproducible -- confirm that they rely on NumPy's global RNG.
onp.random.seed(53453)

# Load data

In [42]:
# Simulated data set used for the benchmark (nobs200 / nloci500 / npop4 per
# the file name). A smaller alternative that was used previously:
#   '../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'
data_file = ('/scratch/users/genomic_times_series_bnp/structure/simulated_data/'
             'simulated_structure_data_nobs200_nloci500_npop4.npz')
data = np.load(data_file)

# Only the 'g_obs' entry of the archive is used in this notebook.
g_obs = np.array(data['g_obs'])

In [43]:
# data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
# filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
# filename = data_dir + filenamebase + '.npz'
# data = np.load(filename)

# g_obs = np.array(data['g_obs'])
# g_obs_raw = np.array(data['g_obs_raw'])

# # just checking ... 
# which_missing = (g_obs_raw == 3)
# (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
# (g_obs[which_missing] == 0).all()

In [44]:
# Leading two axes of g_obs give the number of observations and of loci.
n_obs, n_loci = g_obs.shape[0], g_obs.shape[1]

# Get prior

In [45]:
# Default prior hyper-parameters, plus the paragami pattern that describes them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)
print(prior_params_paragami)

# Flat vector form of the prior parameters, using paragami's free=True parameterization.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [46]:
# Number of components passed to get_vb_params_paragami_object below.
# NOTE(review): presumably the truncation level of the variational approximation -- confirm.
k_approx = 15

In [47]:
# Gauss-Hermite quadrature rule: `gh_deg` sample points (gh_loc) and their
# weights (gh_weights). The degree is passed through `gh_deg` so that changing
# it in one place actually changes the quadrature rule.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [48]:
use_logitnormal_sticks = True

# Only the paragami pattern is kept; the first return value is discarded.
_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (500, 15, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (200, 14) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (200, 14) (lb=0.0001, ub=inf)


In [49]:
# Point at which the objectives are benchmarked: a random draw from the
# pattern, flattened with free=True.
vb_free_params = vb_params_paragami.flatten(
    vb_params_paragami.random(),
    free=True,
)

# Benchmarking

In [50]:
from bnpmodeling_runjingdev.optimization_lib import OptimizationObjective, OptimizationObjectiveJaxtoNumpy

In [51]:
# Wrap structure_model_lib.get_kl so that its positional argument 1 (the VB
# parameters) is supplied as a flat vector in the free parameterization.
_kl_fun_free = paragami.FlattenFunctionInput(
    original_fun=structure_model_lib.get_kl,
    patterns=vb_params_paragami,
    free=True,
    argnums=1,
)


def kl_fun_free(x):
    """Evaluate the KL objective at the flat free-parameter vector `x`.

    The data (g_obs), the prior parameters and the Gauss-Hermite nodes and
    weights are taken from the notebook's globals; `log_phi` is None and
    `epsilon` is 0.
    """
    return _kl_fun_free(
        g_obs,
        x,
        prior_params_dict,
        gh_loc,
        gh_weights,
        log_phi=None,
        epsilon=0.,
    )

In [52]:
# Flat free-parameter vector used to try out kl_fun_free.
# NOTE(review): this is a second call to .random(), separate from the one that
# produced vb_free_params; if the two are meant to be the same point, reuse
# vb_free_params instead -- confirm.
x = vb_params_paragami.flatten(vb_params_paragami.random(), free = True)

In [53]:
# Evaluate the objective once at x and display the value.
kl_fun_free(x)

DeviceArray(206710.48885037, dtype=float64)

In [54]:
# Two wrappers around the same objective, compared in the timing cells below.
# NOTE(review): print_every is set to a huge value, presumably to silence
# periodic logging during the benchmarks -- confirm.
optim_objective1 = OptimizationObjective(kl_fun_free, print_every=1e16)
optim_objective2 = OptimizationObjectiveJaxtoNumpy(
    kl_fun_free,
    vb_free_params,
    print_every=1e16,
)

Compiling objective ...
Iter 0: f = 206710.48885037
Compiling grad ...
Compile time: 11.8376secs


In [58]:
# Time a single evaluation of optim_objective1.f.
# time.perf_counter() is used rather than time.time(): it is monotonic and has
# the highest available resolution, which is what short timings need.
# block_until_ready() makes sure the result has actually been computed before
# the clock is read again.
t0 = time.perf_counter()
_ = optim_objective1.f(vb_free_params).block_until_ready()
time.perf_counter() - t0

0.7237265110015869

In [63]:
# Time a single evaluation of optim_objective2.f_np.
# time.perf_counter() is used rather than time.time(): it is monotonic and has
# the highest available resolution, which is what short timings need.
t0 = time.perf_counter()
_ = optim_objective2.f_np(vb_free_params)
time.perf_counter() - t0

0.2738654613494873

In [32]:
# Repeated timing of optim_objective1.f.
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt, so
# no timing was recorded for it.
%timeit optim_objective1.f(vb_free_params).block_until_ready()

Iter 0: f = 5841761.27007069


KeyboardInterrupt: 

In [None]:
# optim_objective2.f is expected to be jit-compiled, and hence faster
# (presumably relative to optim_objective1.f timed above -- confirm).
# No output was recorded for this cell.
%timeit optim_objective2.f(vb_free_params).block_until_ready()

In [None]:
# f_np returns a NumPy array, so it is timed without block_until_ready().
# No output was recorded for this cell.
%timeit optim_objective2.f_np(vb_free_params)