In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils
import vb_lib.cavi_logitnormal_sticks_lib as cavi_logit_lib 
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
import vb_lib.structure_optimization_lib as s_optim_lib

from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp

# Seed NumPy's *global* RNG so any host-side (non-JAX) randomness is reproducible.
# NOTE(review): presumably consumed inside the vb_lib helpers (e.g. NMF init or
# minibatch sampling) — TODO confirm; JAX randomness below uses explicit PRNGKeys.
onp.random.seed(seed=53453)

# Load data

In [3]:
# data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
# g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# Location of the genotype data (.npz converted from PLINK files).
data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
filename = data_dir + filenamebase + '.npz'

# The .npz archive is opened as a context manager so its file handle is closed
# once the arrays have been copied out into (device) arrays.
with np.load(filename) as data:
    g_obs = np.array(data['g_obs'], dtype = int)
    g_obs_raw = np.array(data['g_obs_raw'])

# Sanity checks on the encoding; the value 3 in g_obs_raw marks a missing entry.
#  * for non-missing entries, argmax over the last axis of g_obs equals g_obs_raw;
#  * for missing entries, the corresponding g_obs values are all zero.
# Both checks are asserted: a bare expression in the middle of a cell is
# evaluated and then thrown away, so it would never flag a failure.
which_missing = (g_obs_raw == 3)
argmax_matches_raw = (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
missing_entries_zero = (g_obs[which_missing] == 0).all()
assert argmax_matches_raw, 'g_obs.argmax(-1) disagrees with g_obs_raw on observed entries'
assert missing_entries_zero, 'g_obs should be all zeros where g_obs_raw == 3'

# Display the combined result (True when both checks pass).
argmax_matches_raw & missing_entries_zero

DeviceArray(True, dtype=bool)

In [5]:
# Leading two axes of g_obs: observations (rows) and loci (columns).
n_obs, n_loci = g_obs.shape[:2]

In [6]:
# Show the dataset size: number of observations, then number of loci (one per line).
print(n_obs, n_loci, sep = '\n')

1107
2810


# Get prior

In [7]:
# Default prior hyperparameters and the paragami pattern describing them.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

print(prior_params_paragami)

# Unconstrained ("free") flat vector of the prior hyperparameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict,
                                                  free = True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB parameters

In [8]:
# Number of latent populations in the variational approximation; e.g. the middle
# axis of pop_freq_beta_params has size k_approx in the pattern printed below.
k_approx = 15

In [9]:
# Gauss-Hermite quadrature nodes and weights, passed to get_kl / the CAVI
# routines below (presumably for expectations under the logit-normal sticks).
# The degree is taken from gh_deg so changing it actually changes the quadrature;
# previously hermgauss(8) hard-coded the degree and gh_deg was unused.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [10]:
use_logitnormal_sticks = True

# Variational parameters and their paragami pattern, sized by the data
# dimensions and the number of populations k_approx.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks = use_logitnormal_sticks)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 14) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 14) (lb=0.0001, ub=inf)


In [11]:
# Random draw of the variational parameters from their pattern, using an
# explicit JAX PRNG key for reproducibility.
vb_init_key = jax.random.PRNGKey(41)
vb_params_dict = vb_params_paragami.random(key = vb_init_key)

In [12]:
# KL objective at the random initialization (the quantity the optimizers below drive down).
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

DeviceArray(5866337.80064644, dtype=float64)

# Timing results

## Set initialization

In [13]:
# Replace the random draw with a data-driven initialization (its log reports an
# NMF step followed by a few CAVI steps). NOTE(review): this overwrites
# vb_params_dict with a function of itself, so re-running the cell starts from
# the already-initialized values rather than the random draw.
vb_params_dict = s_optim_lib.set_init_vb_params(
    g_obs, k_approx, vb_params_dict, prior_params_dict,
    gh_loc, gh_weights,
    seed = 3421)

running NMF ...
running a few cavi steps for pop beta ...
done. Elapsed: 32.1097


In [14]:
# KL objective after the data-driven initialization, for comparison with the random init.
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)

DeviceArray(5198875.61385809, dtype=float64)

In [16]:
# Stochastic CAVI starting from the initialization above. batchsize presumably
# sets the number of observations per minibatch — TODO confirm in cavi_logit_lib.
# Each iteration takes roughly 17 s on this dataset (see the log below).
# The return value is bound to a name so it is not only reachable via `_`/Out[N];
# what run_stoch_cavi returns (updated params, a trace, ...) is not visible from
# this notebook — TODO confirm.
stoch_cavi_out = cavi_logit_lib.run_stoch_cavi(g_obs, vb_params_dict,
                                               vb_params_paragami,
                                               prior_params_dict,
                                               gh_loc, gh_weights,
                                               batchsize = 100,
                                               debug = False)

compiling stick objective and gradients ...
sticks compile time: 7.41sec
done compiling. 
Compiling ... first iteration of cavi might be slow
iteration [1]; kl:4894634.279518; elapsed: 16.4269secs
iteration [2]; kl:4855327.468211; elapsed: 17.165secs
iteration [3]; kl:4818247.0225; elapsed: 16.9571secs
iteration [4]; kl:4771253.490793; elapsed: 16.6934secs
iteration [5]; kl:4724350.643236; elapsed: 16.6961secs
iteration [6]; kl:4678835.52233; elapsed: 16.8344secs
iteration [7]; kl:4636742.952323; elapsed: 17.361secs
iteration [8]; kl:4596125.394588; elapsed: 16.9968secs
iteration [9]; kl:4552878.245967; elapsed: 16.6874secs
iteration [10]; kl:4515099.018962; elapsed: 16.7473secs
iteration [11]; kl:4475538.888854; elapsed: 16.818secs
iteration [12]; kl:4438165.990076; elapsed: 16.7119secs
iteration [13]; kl:4399818.359445; elapsed: 16.7037secs
iteration [14]; kl:4363602.5968; elapsed: 17.2026secs
iteration [15]; kl:4329722.276155; elapsed: 17.143secs


KeyboardInterrupt: 

In [None]:
# Stray interactive help lookup removed: it referred to `random`, which is never
# imported in this notebook. For NumPy's sampler, run `onp.random.choice?` in a
# scratch session instead of saving it in the notebook.