In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp
# Seed NumPy's legacy global RNG. The jax.random draws visible in this
# notebook take an explicit PRNGKey and are not affected by this seed.
onp.random.seed(53453)

# Load data

In [3]:
# Switch between the simulated data file and the HGDP genotype file.
use_simulated = False

if use_simulated:
    data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
    g_obs = np.array(data['g_obs'])
else:
    data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
    filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
    filename = data_dir + filenamebase + '.npz'
    data = np.load(filename)

    g_obs = np.array(data['g_obs'])
    g_obs_raw = np.array(data['g_obs_raw'])

    # Sanity checks on the two encodings. These used to be bare expressions
    # whose results were silently discarded; asserting makes a mismatch fail
    # loudly instead of going unnoticed.
    # Entries of g_obs_raw equal to 3 are the ones flagged as missing.
    which_missing = (g_obs_raw == 3)
    # Wherever the raw genotype is not missing, the argmax over the last
    # axis of g_obs must reproduce the raw value.
    assert (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all(), \
        'g_obs.argmax(-1) disagrees with g_obs_raw at non-missing entries'
    # Wherever the raw genotype is missing, g_obs must be all zeros.
    assert (g_obs[which_missing] == 0).all(), \
        'g_obs is non-zero at entries flagged missing in g_obs_raw'

In [4]:
# Sizes of the first two axes of g_obs.
n_obs, n_loci = g_obs.shape[:2]

In [5]:
# One value per line: n_obs, then n_loci.
print(n_obs, n_loci, sep='\n')

1107
2810


# Get prior

In [6]:
# Only the paragami pattern (second return value) is kept.
_, prior_params_paragami = structure_model_lib.get_default_prior_params()

# Draw prior parameters from the pattern with a fixed PRNG key, then flatten
# them with free=True.
prior_key = jax.random.PRNGKey(41)
prior_params_dict = prior_params_paragami.random(key=prior_key)
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

print(prior_params_dict)

OrderedDict([('dp_prior_alpha', DeviceArray([2.59776136], dtype=float64)), ('allele_prior_alpha', DeviceArray([1.15564526], dtype=float64)), ('allele_prior_beta', DeviceArray([1.0515046], dtype=float64))])


# Get VB params 

In [7]:
# Passed to structure_model_lib.get_vb_params_paragami_object when the
# variational-parameter pattern is built.
k_approx = 15

In [8]:
# Degree of the Gauss-Hermite quadrature rule. The nodes and weights are
# later handed to structure_model_lib.get_moments_from_vb_params_dict.
gh_deg = 8
# Use gh_deg rather than a repeated literal so that changing the degree
# above actually changes the quadrature rule.
gh_loc, gh_weights = hermgauss(gh_deg)

In [9]:
use_logitnormal_sticks = False

# Only the paragami pattern (second return value) is kept.
_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_beta] = NumericArrayPattern (1107, 14, 2) (lb=0.0, ub=inf)


In [10]:
# Random variational parameters drawn from the pattern with a fixed PRNG key.
# NOTE(review): the key (41) is the same one used for the prior draw; confirm
# this is intentional.
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

# Get moments from vb parameters

In [11]:
# Pull the three prior hyperparameters out of the dict by key.
dp_prior_alpha, allele_prior_alpha, allele_prior_beta = (
    prior_params_dict[prior_name]
    for prior_name in ('dp_prior_alpha', 'allele_prior_alpha', 'allele_prior_beta')
)

# Initial moments computed from vb_params_dict with the Gauss-Hermite rule.
(e_log_sticks, e_log_1m_sticks,
 e_log_pop_freq, e_log_1m_pop_freq) = \
    structure_model_lib.get_moments_from_vb_params_dict(
        vb_params_dict, gh_loc, gh_weights)

# Cluster log-probabilities derived from the two stick moments.
e_log_cluster_probs = modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
    e_log_sticks, e_log_1m_sticks)

# Population beta update

In [12]:
# jax.jit-wrapped version of cavi_lib.update_pop_beta; this wrapped function
# is the one called and timed in the cells below.
update_pop_beta = jax.jit(cavi_lib.update_pop_beta)

In [13]:
# Result is a sequence; the comparison cells below index element 0 as
# beta_param_update[0][:, :, 0] and beta_param_update[0][:, :, 1].
beta_param_update = update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                                        e_log_cluster_probs, prior_params_dict)

### check against autograd

In [14]:
# jax.jit-wrapped versions of the two `_ag` reference updates.
get_pop_beta_update1_ag, get_pop_beta_update2_ag = (
    jax.jit(ag_fun)
    for ag_fun in (cavi_lib.get_pop_beta_update1_ag,
                   cavi_lib.get_pop_beta_update2_ag)
)

In [15]:
# Positional arguments taken by the `_ag` update functions.
ag_args = (g_obs,
           e_log_pop_freq, e_log_1m_pop_freq,
           e_log_sticks, e_log_1m_sticks,
           dp_prior_alpha, allele_prior_alpha,
           allele_prior_beta)

# Wait for the result to be computed before moving on.
beta_update1 = get_pop_beta_update1_ag(*ag_args).block_until_ready()

In [16]:
# Positional arguments taken by the `_ag` update functions.
ag_args = (g_obs,
           e_log_pop_freq, e_log_1m_pop_freq,
           e_log_sticks, e_log_1m_sticks,
           dp_prior_alpha, allele_prior_alpha,
           allele_prior_beta)

# Wait for the result to be computed before moving on.
beta_update2 = get_pop_beta_update2_ag(*ag_args).block_until_ready()

In [17]:
# Largest absolute gap between (first update parameter - 1) and beta_update1.
np.max(np.abs(beta_param_update[0][:, :, 0] - 1 - beta_update1))

DeviceArray(3.41060513e-12, dtype=float64)

In [18]:
# Largest absolute gap between (second update parameter - 1) and beta_update2.
np.max(np.abs(beta_param_update[0][:, :, 1] - 1 - beta_update2))

DeviceArray(1.36424205e-12, dtype=float64)

### Timing

`update_pop_beta` should be faster than the autograd-based (`_ag`) updates; otherwise there is no reason not to simply use the autograd updates.

In [19]:
# Positional arguments taken by the `_ag` update functions (built before the
# clock starts so only the two calls are timed).
ag_args = (g_obs,
           e_log_pop_freq, e_log_1m_pop_freq,
           e_log_sticks, e_log_1m_sticks,
           dp_prior_alpha, allele_prior_alpha,
           allele_prior_beta)

t0 = time.time()

# Run both updates back to back, waiting on each result so that the
# computation itself is included in the elapsed time.
for ag_update in (get_pop_beta_update1_ag, get_pop_beta_update2_ag):
    _ = ag_update(*ag_args).block_until_ready()

print(time.time() - t0)

17.36200714111328


In [20]:
# Time one call of the jitted update_pop_beta (it was already called once in
# an earlier cell).
t0 = time.time() 

out = update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                        e_log_cluster_probs, prior_params_dict)

# NOTE(review): only the first returned array is waited on; presumably
# sufficient because all outputs come from the same jitted call -- confirm.
out[0].block_until_ready()
print(time.time() - t0)

4.0859293937683105


# Admixture stick updates

In [21]:
# jax.jit-wrapped version of cavi_lib.update_ind_admix_beta; this wrapped
# function is the one called and timed in the cells below.
update_ind_admix_beta = jax.jit(cavi_lib.update_ind_admix_beta)

In [22]:
# Result is a sequence; the comparison cells below index element 0 as
# ind_admix_beta_update[0][:, :, 0] and ind_admix_beta_update[0][:, :, 1].
ind_admix_beta_update = update_ind_admix_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                            e_log_cluster_probs, prior_params_dict)

In [23]:
# jax.jit-wrapped versions of the two `_ag` reference stick updates.
get_stick_update1_ag, get_stick_update2_ag = (
    jax.jit(ag_fun)
    for ag_fun in (cavi_lib.get_stick_update1_ag,
                   cavi_lib.get_stick_update2_ag)
)

In [24]:
# Positional arguments taken by the `_ag` update functions.
ag_args = (g_obs,
           e_log_pop_freq, e_log_1m_pop_freq,
           e_log_sticks, e_log_1m_sticks,
           dp_prior_alpha, allele_prior_alpha,
           allele_prior_beta)

# Reference values for the two stick parameters.
stick_update1 = get_stick_update1_ag(*ag_args)
stick_update2 = get_stick_update2_ag(*ag_args)

In [25]:
# Largest absolute gap between (stick_update1 + 1) and the first update parameter.
np.max(np.abs(stick_update1 + 1 - ind_admix_beta_update[0][:, :, 0]))

DeviceArray(1.09139364e-10, dtype=float64)

In [26]:
# Largest absolute gap between (stick_update2 + 1) and the second update parameter.
np.max(np.abs(stick_update2 + 1 - ind_admix_beta_update[0][:, :, 1]))

DeviceArray(7.23048288e-11, dtype=float64)

### Timing

`update_ind_admix_beta` should be faster than the autograd-based (`_ag`) updates; otherwise there is no reason not to simply use the autograd updates.

In [27]:
t0 = time.time() 

# Run both `_ag` stick updates back to back, waiting on each result so that
# the computation itself is included in the elapsed time printed below.
for ag_update in (get_stick_update1_ag, get_stick_update2_ag):
    _ = ag_update(g_obs,
                  e_log_pop_freq, e_log_1m_pop_freq,
                  e_log_sticks, e_log_1m_sticks,
                  dp_prior_alpha, allele_prior_alpha,
                  allele_prior_beta).block_until_ready()
print(time.time() - t0)

17.013890981674194


In [28]:
t0 = time.time() 

# Call the jitted update_ind_admix_beta (already called once in an earlier
# cell) and print the time elapsed since t0.
out = update_ind_admix_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                            e_log_cluster_probs, prior_params_dict)
# NOTE(review): only the first returned array is waited on; presumably
# sufficient because all outputs come from the same jitted call -- confirm.
out[0].block_until_ready()

print(time.time() - t0)

4.136917352676392


In [None]:
# NOTE(review): stray scratch expression in an unexecuted cell; it has no
# effect on any other cell and is a candidate for removal.
1

In [None]:
# Run cavi_lib.run_cavi starting from the randomly drawn vb_params_dict; the
# return value is discarded.
# NOTE(review): this cell has no execution count in the saved notebook, and
# the effect of debug = True is not visible from here -- confirm before
# relying on it.
_ = cavi_lib.run_cavi(g_obs, vb_params_dict,
                vb_params_paragami,
                prior_params_dict, 
                debug = True)

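# Check data weighting for pop beta

Calling `update_pop_beta` with `data_weight = 2` should give the same result as calling it on the data stacked twice with `data_weight = 1`.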

In [46]:
# Un-jitted cavi_lib.update_pop_beta on the original data with data_weight = 2.
# NOTE(review): this rebinds beta_update1, which an earlier cell set to the
# `_ag` reference result; re-running the earlier comparison cell after this
# one would compare against the wrong array.
beta_update1, e_log1, e_log2 = cavi_lib.update_pop_beta(g_obs, e_log_pop_freq, e_log_1m_pop_freq, 
                       e_log_cluster_probs, prior_params_dict, 
                       data_weight = 2)

In [47]:
# Same update with data_weight = 1, but with g_obs and e_log_cluster_probs
# each stacked on top of themselves (np.vstack), i.e. every row appears twice.
# NOTE(review): this rebinds beta_update2, which an earlier cell set to the
# `_ag` reference result; re-running the earlier comparison cell after this
# one would compare against the wrong array.
beta_update2, e_log12, e_log22 = cavi_lib.update_pop_beta(np.vstack((g_obs, g_obs)), 
                            e_log_pop_freq, e_log_1m_pop_freq, 
                            np.vstack((e_log_cluster_probs, e_log_cluster_probs)), 
                            prior_params_dict, 
                            data_weight = 1)

In [48]:
# Largest absolute difference between the weighted and the stacked update.
np.max(np.abs(beta_update1 - beta_update2))

DeviceArray(2.13162821e-14, dtype=float64)

In [49]:
# Largest absolute difference in the second returned array.
np.max(np.abs(e_log1 - e_log12))

DeviceArray(2.66453526e-15, dtype=float64)

In [50]:
# Largest absolute difference in the third returned array.
np.max(np.abs(e_log2 - e_log22))

DeviceArray(2.44249065e-15, dtype=float64)