In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
# Plain NumPy is imported under the alias `onp`, since `np` is already bound
# to `jax.numpy` in this notebook.
import numpy as onp

# Fix the seed of NumPy's global random state.
onp.random.seed(seed=53453)

# Load data

In [3]:
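# Alternative inputs (simulated datasets), kept for reference but disabled;
# the HGDP data loaded in the next cell is used instead.
# data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
# data = np.load('/scratch/users/genomic_times_series_bnp/structure/simulated_data/' + 
#                'simulated_structure_data_nobs200_nloci500_npop4.npz')
# g_obs = np.array(data['g_obs'])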

In [4]:
# The genotype data lives at <data_dir>/<filenamebase>.npz
data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
filename = data_dir + filenamebase + '.npz'
data = np.load(filename)

g_obs = np.array(data['g_obs'])
g_obs_raw = np.array(data['g_obs_raw'])

# Sanity checks on the encoding. Both are asserted so that a failure stops the
# notebook: a cell only displays its last bare expression, so an un-asserted
# earlier check would be silently dropped.
# NOTE(review): the value 3 in g_obs_raw appears to flag missing genotypes
# (inferred from the variable name) -- confirm against the data-prep code.
which_missing = (g_obs_raw == 3)

# At non-missing entries, the argmax of g_obs over its last axis must equal
# the raw value.
observed_consistent = (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
assert observed_consistent, \
    'g_obs.argmax(-1) disagrees with g_obs_raw at non-missing entries'

# At entries flagged as missing, g_obs must be identically zero.
missing_all_zero = (g_obs[which_missing] == 0).all()
assert missing_all_zero, 'g_obs is nonzero at entries flagged as missing'

# Keep the final check as the cell's displayed value.
missing_all_zero

DeviceArray(True, dtype=bool)

In [5]:
# The number of observations and the number of loci are the sizes of the
# first two axes of g_obs.
n_obs, n_loci = g_obs.shape[:2]

# Get prior

In [6]:
# Default prior hyperparameters, returned as a dictionary together with the
# paragami pattern that describes it.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

# Show the pattern (parameter names, shapes and bounds).
print(prior_params_paragami)

# Flatten the prior parameters into their free (unconstrained) vector form.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [7]:
# Passed as the third argument to get_vb_params_paragami_object below.
# NOTE(review): presumably the truncation level (maximum number of
# populations) of the variational approximation -- confirm.
k_approx = 15

In [8]:
# Degree (number of sample points) of the Gauss-Hermite quadrature rule.
gh_deg = 8

# Sample points and weights of the rule. Use gh_deg rather than repeating the
# literal, so that changing the degree above actually changes the rule.
gh_loc, gh_weights = hermgauss(gh_deg)

In [13]:
# Flag forwarded to get_vb_params_paragami_object below.
use_logitnormal_sticks = True

# Build the variational-parameter dictionary and its paragami pattern.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

# Show the pattern (parameter names, shapes and bounds).
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 14) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 14) (lb=0.0001, ub=inf)


# Benchmarking

In [14]:
from vb_lib.structure_optimization_lib import define_structure_objective

In [16]:
# Build the optimization objective for the structure model. The second return
# value is presumably the initial free-parameter vector (going by its name).
# compile_hvp=True also requests the Hessian-vector product, as reflected by
# the "Compiling hvp ..." line in the log output.
optim_objective, init_vb_free = define_structure_objective(
    g_obs,
    vb_params_dict,
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
    compile_hvp=True,
)

Compiling objective ...
Iter 0: f = 5841761.27007069
Compiling grad ...
Compiling hvp ...
Compile time: 76.7123secs


In [None]:
# Bare attribute reference: evaluating this cell only displays the `f`
# attribute of optim_objective; nothing is called.
optim_objective.f

In [12]:
# Wrap structure_model_lib.get_kl so that its argument at position 1 (the VB
# parameters) is supplied as a flat, free (unconstrained) vector.
_kl_fun_free = paragami.FlattenFunctionInput(
    original_fun=structure_model_lib.get_kl,
    patterns=vb_params_paragami,
    free=True,
    argnums=1,
)


def kl_fun_free(x):
    """Evaluate the KL objective at the free VB-parameter vector ``x``.

    All remaining arguments (data, prior parameters, Gauss-Hermite nodes and
    weights) are fixed to the notebook-level values; ``log_phi`` is ``None``
    and ``epsilon`` is ``0.``.
    """
    return _kl_fun_free(
        g_obs, x, prior_params_dict, gh_loc, gh_weights,
        log_phi=None, epsilon=0.,
    )

In [13]:
# Draw a random set of variational parameters and flatten them to the free
# (unconstrained) vector. The benchmarking cells below refer to this vector as
# `vb_free_params`, so it is defined here explicitly rather than relying on
# leftover kernel state.
vb_free_params = vb_params_paragami.flatten(vb_params_paragami.random(), free = True)

# Short alias used by the next cell.
x = vb_free_params

In [14]:
# Evaluate the KL objective at the random free-parameter vector x; the value
# is displayed as the cell output.
kl_fun_free(x)

DeviceArray(206710.48885037, dtype=float64)

In [15]:
# Two wrappers around the same KL function, compared in the timing cells below.
# NOTE(review): OptimizationObjective and OptimizationObjectiveJaxtoNumpy are
# not imported in any visible cell; add the import to the import cell (module
# not identifiable from here) so this survives Restart & Run All.
# NOTE(review): print_every = 1e16 is presumably meant to suppress periodic
# progress printing -- confirm.
optim_objective1 = OptimizationObjective(kl_fun_free, print_every = 1e16)
optim_objective2 = OptimizationObjectiveJaxtoNumpy(kl_fun_free, vb_free_params, print_every = 1e16)

Compiling objective ...
Iter 0: f = 206710.48885037
Compiling grad ...
Compile time: 13.0393secs


In [16]:
# Time a single evaluation of optim_objective1.f.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations, unlike time.time(), which follows the adjustable system clock.
t0 = time.perf_counter()
# block_until_ready() waits for the result to be computed before the clock is
# read again.
_ = optim_objective1.f(vb_free_params).block_until_ready()
time.perf_counter() - t0

Iter 0: f = 206710.48885037


1.2726421356201172

In [17]:
# Time a single evaluation of optim_objective2.f_np.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations, unlike time.time(), which follows the adjustable system clock.
t0 = time.perf_counter()
_ = optim_objective2.f_np(vb_free_params)
time.perf_counter() - t0

Iter 0: f = 206710.48885037


0.2721726894378662

In [18]:
# Repeated timing of optim_objective1.f; block_until_ready() waits for the
# result to be computed, so the whole evaluation is included in the measurement.
%timeit optim_objective1.f(vb_free_params).block_until_ready()

716 ms ± 17.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
# optim_objective2.f should be jit-compiled and hence faster than
# optim_objective1.f timed above.
%timeit optim_objective2.f(vb_free_params).block_until_ready()

271 ms ± 862 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
# f_np returns a NumPy array, so there is no block_until_ready() call here.
%timeit optim_objective2.f_np(vb_free_params)

271 ms ± 900 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
