In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp
onp.random.seed(53453)

# Load data

In [3]:
use_simulated = True

if use_simulated: 
    data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
    g_obs = np.array(data['g_obs'])
else: 
    data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
    filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
    filename = data_dir + filenamebase + '.npz'
    data = np.load(filename)

    g_obs = np.array(data['g_obs'])
    g_obs_raw = np.array(data['g_obs_raw'])

    # just checking ... 
    which_missing = (g_obs_raw == 3)
    (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
    (g_obs[which_missing] == 0).all()

In [4]:
n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

In [5]:
print(n_obs)
print(n_loci)

20
50


# Get prior

In [6]:
_, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

prior_params_dict = prior_params_paragami.random(key=jax.random.PRNGKey(41))
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

print(prior_params_dict)

OrderedDict([('dp_prior_alpha', DeviceArray([2.59776136], dtype=float64)), ('allele_prior_alpha', DeviceArray([1.15564526], dtype=float64)), ('allele_prior_beta', DeviceArray([1.0515046], dtype=float64))])


# Get VB params 

In [7]:
k_approx = 15

In [8]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(8)

In [9]:
use_logitnormal_sticks = False

_, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx,
                                    use_logitnormal_sticks = use_logitnormal_sticks)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_beta] = NumericArrayPattern (20, 14, 2) (lb=0.0, ub=inf)


In [10]:
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

# Get moments from vb parameters

In [11]:
dp_prior_alpha = prior_params_dict['dp_prior_alpha']
allele_prior_alpha = prior_params_dict['allele_prior_alpha']
allele_prior_beta = prior_params_dict['allele_prior_beta']

# get initial moments from vb_params
e_log_sticks, e_log_1m_sticks, \
    e_log_pop_freq, e_log_1m_pop_freq = \
        structure_model_lib.get_moments_from_vb_params_dict(
            vb_params_dict, gh_loc, gh_weights)

e_log_cluster_probs = \
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
            e_log_sticks, e_log_1m_sticks)

# Population beta update

In [12]:
update_pop_beta = jax.jit(cavi_lib.update_pop_beta)

In [13]:
beta_param_update = update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                                        e_log_cluster_probs, prior_params_dict)

### check against autograd

In [14]:
get_pop_beta_update1_ag = jax.jit(cavi_lib.get_pop_beta_update1_ag)
get_pop_beta_update2_ag = jax.jit(cavi_lib.get_pop_beta_update2_ag)

In [15]:
beta_update1 = get_pop_beta_update1_ag(g_obs,
                                    e_log_pop_freq, e_log_1m_pop_freq,
                                    e_log_sticks, e_log_1m_sticks,
                                    dp_prior_alpha, allele_prior_alpha,
                                    allele_prior_beta).block_until_ready()

In [16]:
beta_update2 = get_pop_beta_update2_ag(g_obs,
                                    e_log_pop_freq, e_log_1m_pop_freq,
                                    e_log_sticks, e_log_1m_sticks,
                                    dp_prior_alpha, allele_prior_alpha,
                                    allele_prior_beta).block_until_ready()

In [17]:
np.abs(beta_param_update[0][:, :, 0] - 1 - beta_update1).max()

DeviceArray(7.10542736e-15, dtype=float64)

In [18]:
np.abs(beta_param_update[0][:, :, 1] - 1 - beta_update2).max()

DeviceArray(7.10542736e-15, dtype=float64)

### Timing

It should be faster than the autograd updates ... otherwise why not just use autograd updates

In [19]:
t0 = time.time() 

_ = get_pop_beta_update1_ag(g_obs,
                            e_log_pop_freq, e_log_1m_pop_freq,
                            e_log_sticks, e_log_1m_sticks,
                            dp_prior_alpha, allele_prior_alpha,
                            allele_prior_beta).block_until_ready()

_ = get_pop_beta_update2_ag(g_obs,
                            e_log_pop_freq, e_log_1m_pop_freq,
                            e_log_sticks, e_log_1m_sticks,
                            dp_prior_alpha, allele_prior_alpha,
                            allele_prior_beta).block_until_ready()

print(time.time() - t0)

0.010199785232543945


In [20]:
t0 = time.time() 

out = update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                        e_log_cluster_probs, prior_params_dict)

out[0].block_until_ready()
print(time.time() - t0)

0.0029418468475341797


# Admixture stick updates

In [21]:
update_ind_admix_beta = jax.jit(cavi_lib.update_ind_admix_beta)

In [22]:
ind_admix_beta_update = update_ind_admix_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                            e_log_cluster_probs, prior_params_dict)

In [23]:
get_stick_update1_ag = jax.jit(cavi_lib.get_stick_update1_ag)
get_stick_update2_ag = jax.jit(cavi_lib.get_stick_update2_ag)

In [24]:
stick_update1 = get_stick_update1_ag(g_obs,
                        e_log_pop_freq, e_log_1m_pop_freq,
                        e_log_sticks, e_log_1m_sticks,
                        dp_prior_alpha, allele_prior_alpha,
                        allele_prior_beta)

stick_update2 = get_stick_update2_ag(g_obs,
                        e_log_pop_freq, e_log_1m_pop_freq,
                        e_log_sticks, e_log_1m_sticks,
                        dp_prior_alpha, allele_prior_alpha,
                        allele_prior_beta)

In [25]:
np.abs(stick_update1 + 1 -  ind_admix_beta_update[0][:, :, 0]).max()

DeviceArray(5.68434189e-14, dtype=float64)

In [26]:
np.abs(stick_update2 + 1 - ind_admix_beta_update[0][:, :, 1]).max()

DeviceArray(1.42108547e-14, dtype=float64)

### Timing

It should be faster than the autograd updates ... otherwise why not just use autograd updates

In [27]:
t0 = time.time() 

_ = get_stick_update1_ag(g_obs,
                        e_log_pop_freq, e_log_1m_pop_freq,
                        e_log_sticks, e_log_1m_sticks,
                        dp_prior_alpha, allele_prior_alpha,
                        allele_prior_beta).block_until_ready()

_ = get_stick_update2_ag(g_obs,
                        e_log_pop_freq, e_log_1m_pop_freq,
                        e_log_sticks, e_log_1m_sticks,
                        dp_prior_alpha, allele_prior_alpha,
                        allele_prior_beta).block_until_ready()
print(time.time() - t0)

0.010127544403076172


In [28]:
t0 = time.time() 

out = update_ind_admix_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                            e_log_cluster_probs, prior_params_dict)
out[0].block_until_ready()

print(time.time() - t0)

0.0032122135162353516


In [29]:
cavi_lib.run_cavi(g_obs, vb_params_dict,
                vb_params_paragami,
                prior_params_dict, 
                debug = True)

CAVI compile time: 2.69sec

 running CAVI ...
iteration [1]; kl:2036.442476; elapsed: 0.0343secs
iteration [2]; kl:2022.359559; elapsed: 0.192secs
iteration [3]; kl:2013.471353; elapsed: 0.039secs
iteration [4]; kl:2006.331928; elapsed: 0.0228secs
iteration [5]; kl:2000.120481; elapsed: 0.0195secs
iteration [6]; kl:1994.438231; elapsed: 0.0183secs
iteration [7]; kl:1989.037013; elapsed: 0.0183secs
iteration [8]; kl:1983.768687; elapsed: 0.0196secs
iteration [9]; kl:1978.57033; elapsed: 0.0183secs
iteration [10]; kl:1973.449046; elapsed: 0.0191secs
iteration [11]; kl:1968.459106; elapsed: 0.0187secs
iteration [12]; kl:1963.673786; elapsed: 0.0196secs
iteration [13]; kl:1959.159472; elapsed: 0.0197secs
iteration [14]; kl:1954.959799; elapsed: 0.0194secs
iteration [15]; kl:1951.092835; elapsed: 0.0187secs
iteration [16]; kl:1947.557926; elapsed: 0.0192secs
iteration [17]; kl:1944.345364; elapsed: 0.0197secs
iteration [18]; kl:1941.443241; elapsed: 0.0187secs
iteration [19]; kl:1938.839679

iteration [168]; kl:1914.613674; elapsed: 0.0173secs
iteration [169]; kl:1914.613196; elapsed: 0.0174secs
iteration [170]; kl:1914.612743; elapsed: 0.0173secs
iteration [171]; kl:1914.612314; elapsed: 0.0168secs
iteration [172]; kl:1914.611907; elapsed: 0.0174secs
iteration [173]; kl:1914.611521; elapsed: 0.0196secs
iteration [174]; kl:1914.611155; elapsed: 0.0193secs
iteration [175]; kl:1914.610809; elapsed: 0.0188secs
iteration [176]; kl:1914.61048; elapsed: 0.019secs
iteration [177]; kl:1914.610169; elapsed: 0.0193secs
iteration [178]; kl:1914.609874; elapsed: 0.0181secs
iteration [179]; kl:1914.609594; elapsed: 0.0194secs
iteration [180]; kl:1914.609329; elapsed: 0.0197secs
iteration [181]; kl:1914.609078; elapsed: 0.0191secs
iteration [182]; kl:1914.608839; elapsed: 0.0179secs
iteration [183]; kl:1914.608613; elapsed: 0.019secs
iteration [184]; kl:1914.608399; elapsed: 0.0191secs
iteration [185]; kl:1914.608196; elapsed: 0.0181secs
iteration [186]; kl:1914.608003; elapsed: 0.0193s

(OrderedDict([('pop_freq_beta_params',
               DeviceArray([[[18.02754992,  2.48862109],
                             [ 5.82430939, 17.22473896],
                             [ 1.33488308,  1.20352748],
                             ...,
                             [ 1.15801348,  1.05351108],
                             [ 1.15725565,  1.05286904],
                             [ 1.16173061,  1.05665971]],
               
                            [[ 5.60392194, 15.8088414 ],
                             [ 9.2665638 , 12.88052435],
                             [ 1.3260009 ,  1.21260054],
                             ...,
                             [ 1.15794576,  1.05364128],
                             [ 1.15720968,  1.05295753],
                             [ 1.16155539,  1.05699554]],
               
                            [[ 8.65526164, 12.09400826],
                             [ 5.21946547, 17.59198312],
                             [ 1.32477334,  1.21398128],
    