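Here, we examine the parametric sensitivity of the structure model. Note: the cell that loads the small simulated dataset is currently commented out, and the cells below load the HGDP genotype data instead.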

In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
# Plain NumPy is aliased as `onp` because `np` refers to jax.numpy in this notebook.
import numpy as onp
# Seed NumPy's global random state. JAX randomness is driven by explicit PRNG keys,
# so this only affects code that draws from NumPy's global generator.
onp.random.seed(53453)

# Load data

In [3]:
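# (disabled) alternative input: the small simulated dataset.
# Uncomment the two lines below to use it in place of the data loaded in the next cell.
# data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
# g_obs = np.array(data['g_obs'])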

In [4]:
# Location of the genotype data (path is relative to this notebook).
data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
filename = data_dir + filenamebase + '.npz'
data = np.load(filename)

g_obs = np.array(data['g_obs'])
g_obs_raw = np.array(data['g_obs_raw'])

# Sanity checks on the encoding.
# Entries of g_obs_raw equal to 3 are treated as missing.
which_missing = (g_obs_raw == 3)

# At non-missing entries, the argmax over the last axis of g_obs must reproduce
# the raw value. This is asserted rather than left as a bare expression: a cell
# only displays its last expression, so a failure here would otherwise go unseen.
assert (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all(), \
    'g_obs does not match g_obs_raw at non-missing entries'

# At missing entries g_obs should be all zeros; shown as the cell output.
(g_obs[which_missing] == 0).all()

DeviceArray(True, dtype=bool)

In [5]:
# The two leading dimensions of g_obs give the number of observations and loci.
n_obs, n_loci = g_obs.shape[:2]

# Get prior

In [6]:
# Default prior hyperparameters, together with the paragami pattern that describes them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)

print(prior_params_paragami)

# Flattened prior parameters in the unconstrained ("free") parameterization.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [7]:
# Size of the variational approximation passed to the VB parameter constructor.
# In the recorded pattern printout this yields 15 population columns and 14 sticks.
k_approx = 15

In [8]:
# Gauss-Hermite quadrature nodes (gh_loc) and weights (gh_weights).
# The degree is taken from gh_deg so that changing it in one place actually
# changes the quadrature rule used downstream.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [33]:
# Flag forwarded to the VB parameter constructor below.
use_logitnormal_sticks = True

# Build the VB parameter dictionary and the paragami pattern describing it.
vb_params_dict, vb_params_paragami = (
    structure_model_lib.get_vb_params_paragami_object(
        n_obs,
        n_loci,
        k_approx,
        use_logitnormal_sticks=use_logitnormal_sticks,
    )
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (1107, 14) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (1107, 14) (lb=0.0001, ub=inf)


In [46]:
# Draw random VB parameters from the pattern using a fixed JAX PRNG key.
# This overwrites the vb_params_dict returned when the pattern was built.
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

In [47]:
# Value of get_kl at the randomly drawn VB parameters (displayed as cell output).
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

DeviceArray(5866337.80064644, dtype=float64)

## Initialize 

In [48]:
# Replace the random draw with the library's initialization of the VB parameters;
# the seed is fixed so the initialization is repeatable.
vb_params_dict = structure_model_lib.set_init_vb_params(
    g_obs,
    k_approx,
    vb_params_dict,
    seed=34221,
)

In [50]:
# Double the stick `info` parameters of the initialization.
# NOTE(review): this reads and overwrites the same dictionary entry, so the cell is
# not idempotent -- running it a second time doubles the values again. Re-run the
# initialization cell before re-running this one. The factor 2 is unexplained; TODO document.
vb_params_dict['ind_mix_stick_propn_info'] = vb_params_dict['ind_mix_stick_propn_info'] * 2

In [51]:
# Value of get_kl at the current VB parameters (displayed as cell output).
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

DeviceArray(30594962.37881607, dtype=float64)

In [39]:
# Moments implied by the current VB parameters; the Gauss-Hermite nodes and
# weights are passed through to the moment computation.
(e_log_sticks,
 e_log_1m_sticks,
 e_log_pop_freq,
 e_log_1m_pop_freq) = structure_model_lib.get_moments_from_vb_params_dict(
    vb_params_dict, gh_loc, gh_weights
)

In [40]:
# One CAVI update of the population-frequency beta parameters, computed from the
# moments above and the prior hyperparameters.
# NOTE(review): index [2] picks the third element of update_pop_beta's return
# value -- presumably the updated beta parameters; TODO confirm against cavi_lib.
vb_params_dict['pop_freq_beta_params'] = cavi_lib.update_pop_beta(
    g_obs,
    e_log_pop_freq,
    e_log_1m_pop_freq,
    e_log_sticks,
    e_log_1m_sticks,
    prior_params_dict['dp_prior_alpha'],
    prior_params_dict['allele_prior_alpha'],
    prior_params_dict['allele_prior_beta'],
)[2]

In [41]:
# Value of get_kl after the beta-parameter update (displayed as cell output).
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

DeviceArray(8772160.85035737, dtype=float64)

# Optimize

In [80]:
from vb_lib.structure_optimization_lib import define_structure_objective
from bnpmodeling_runjingdev.optimization_lib import run_lbfgs

In [84]:
from scipy import optimize

In [74]:
# NOTE(review): sub_vb_params_paragami and sub_vb_params_dict are not defined in any
# cell shown in this notebook, so this cell fails on a fresh kernel. It looks like a
# leftover from a deleted subsampling experiment -- TODO restore the definitions or remove.
params = sub_vb_params_paragami.flatten(sub_vb_params_dict, free = True)

In [77]:
# NOTE(review): get_kl_np and g_obs_subsampled are not defined in any cell shown in
# this notebook, so this cell fails on a fresh kernel -- TODO restore the definitions or remove.
get_kl_np(g_obs_subsampled, params)

array(534238.29629317)

In [21]:
# Build the optimization objective and the initial free-parameter vector.
# compile_hvp=True also compiles the Hessian-vector product; the recorded log
# shows the objective, gradient and hvp being compiled in turn.
optim_objective, init_vb_free = define_structure_objective(
    g_obs,
    vb_params_dict,
    vb_params_paragami,
    prior_params_dict,
    gh_loc=gh_loc,
    gh_weights=gh_weights,
    compile_hvp=True,
)

Compiling objective ...
Iter 0: f = 5841761.27007069
Compiling grad ...
Compiling hvp ...
Compile time: 78.052secs


In [34]:
# Wall-clock time (seconds) of one Hessian-vector product at the initial point.
t0 = time.time()
# block_until_ready() waits for the computation to finish, so the timing covers the
# actual work rather than just the dispatch of the call.
_ = optim_objective.hessian_vector_product(init_vb_free, init_vb_free).block_until_ready()
time.time() - t0

27.336929321289062

In [13]:
# Consistency check: the objective evaluated at the initial free parameters
# should agree with a direct get_kl evaluation at vb_params_dict.
kl1 = optim_objective.f_np(init_vb_free)
kl2 = structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

# The difference should be ~0 up to floating-point error.
print(kl1 - kl2)

# Reset the objective after the check evaluation above -- presumably this clears
# its iteration counter/log; TODO confirm.
optim_objective.reset()

Iter 0: f = 5841761.27007069
9.313225746154785e-10


In [14]:
# Run L-BFGS on the objective starting from the initial free parameters.
# NOTE(review): the recorded run below ends in a KeyboardInterrupt, so `out` was not
# assigned by this execution; later cells that read `out` rely on an earlier run.
out = run_lbfgs(optim_objective, init_vb_free)


Running L-BFGS-B ... 
Iter 0: f = 5841761.27007069
Iter 1: f = 5796107.78372170
Iter 2: f = 5619834.29162614
Iter 3: f = 5023949.80218303
Iter 4: f = 3896302.37849859
Iter 5: f = 3651757.32707991
Iter 6: f = 3399301.58622861
Iter 7: f = 3270805.62961326
Iter 8: f = 3209808.54410049
Iter 9: f = 3160924.92712628
Iter 10: f = 3134901.41007459
Iter 11: f = 3110473.11365059
Iter 12: f = 3096498.20527866
Iter 13: f = 3087583.01671213
Iter 14: f = 3075292.14434937
Iter 15: f = 3067292.52546304
Iter 16: f = 3058673.63748094
Iter 17: f = 3052149.68444698
Iter 18: f = 3043506.80625176
Iter 19: f = 3037895.90628880
Iter 20: f = 3032464.63095157
Iter 21: f = 3028222.76437575
Iter 22: f = 3024704.57737861
Iter 23: f = 3021182.01036253
Iter 24: f = 3019238.16608886
Iter 25: f = 3016539.89794494
Iter 26: f = 3014768.26037245
Iter 27: f = 3012826.41680009
Iter 28: f = 3011375.18561922
Iter 29: f = 3009783.97638679
Iter 30: f = 3008476.71657809
Iter 31: f = 3006524.66913905
Iter 32: f = 3005481.352514

KeyboardInterrupt: 

In [82]:
# Free-parameter vector returned by the optimizer.
# NOTE(review): depends on `out` from a completed optimizer run; the run recorded
# in this notebook was interrupted -- TODO confirm which run produced it.
vb_opt = out.x

In [83]:
# Objective value at the optimizer output.
# NOTE(review): the recorded value is about -3.96e28, far below every value in the
# optimizer log above -- presumably a numerical blow-up; TODO investigate.
optim_objective.f_np(vb_opt)

Iter 55: f = -39613911251944964744892383232.00000000


array(-3.96139113e+28)

In [84]:
# Fold the free (unconstrained) parameter vector back into a parameter dictionary.
vb_opt_dict = vb_params_paragami.fold(vb_opt, free = True)

In [110]:
# Re-evaluate get_kl at the folded optimum with log_phi and epsilon set explicitly
# (no log_phi, epsilon = 0.), for comparison with the objective value.
structure_model_lib.get_kl(
    g_obs,
    vb_opt_dict,
    prior_params_dict,
    gh_loc,
    gh_weights,
    log_phi=None,
    epsilon=0.,
)

DeviceArray(2.13204905e+08, dtype=float64)

In [111]:
# Entropy term of the stick-breaking variational distribution, computed from the
# stick mean/info parameters with the Gauss-Hermite nodes and weights.
stick_entropy = modeling_lib.get_stick_breaking_entropy(
    vb_opt_dict['ind_mix_stick_propn_mean'],
    vb_opt_dict['ind_mix_stick_propn_info'],
    gh_loc,
    gh_weights,
)

In [112]:
# Display the stick-breaking entropy term.
stick_entropy

DeviceArray(-106793.63149807, dtype=float64)

In [113]:
# Entropy returned by the model library at the optimum, for comparison with the
# individual entropy terms computed in the neighbouring cells.
structure_model_lib.get_entropy(vb_opt_dict, gh_loc, gh_weights)

DeviceArray(-2.08260724e+08, dtype=float64)

In [115]:
from bnpmodeling_runjingdev import exponential_families as ef

In [116]:
# Entropy term of the beta variational distributions on the population frequencies.
pop_freq_beta_params = vb_opt_dict['pop_freq_beta_params']

# Collapse the two leading axes so that each row holds one pair of beta parameters.
lk = pop_freq_beta_params.shape[0] * pop_freq_beta_params.shape[1]
beta_entropy = ef.beta_entropy(
    tau=pop_freq_beta_params.reshape((lk, 2))
)


In [117]:
# Display the beta entropy term. In the recorded run it is about -2.08e8, i.e. it
# accounts for almost all of the entropy value reported by get_entropy.
beta_entropy

DeviceArray(-2.0815393e+08, dtype=float64)

In [118]:
# Inspect the fitted beta parameters behind the large negative entropy.
# NOTE(review): this displays the whole array; a slice or summary would keep the
# notebook output short.
pop_freq_beta_params

DeviceArray([[[1.87047580e+00, 6.79665324e-01],
              [1.86103167e+00, 6.63560490e-01],
              [5.34263030e+02, 9.63100759e-01],
              ...,
              [1.85587377e+00, 6.63984916e-01],
              [1.35427347e+00, 7.18684571e-01],
              [1.61582538e+02, 3.29681448e-02]],

             [[1.88271333e+00, 7.12465697e-01],
              [1.85667785e+00, 6.96534507e-01],
              [1.93067564e-01, 2.11307462e+01],
              ...,
              [1.84591506e+00, 6.92819824e-01],
              [1.84784200e+00, 7.65969630e-01],
              [1.12279194e+02, 2.24290522e-01]],

             [[1.62126944e+00, 1.42571931e+00],
              [1.72837243e+00, 1.27728199e+00],
              [3.38823901e+02, 3.49384973e-01],
              ...,
              [1.57199561e+00, 1.16508668e+00],
              [1.21305538e+00, 1.79800999e+00],
              [9.94449703e-02, 1.63357778e+02]],

             ...,

             [[1.98137725e+00, 6.45915993e-01],
      

# Check out the fit

In [None]:
# fitted
# NOTE(review): get_vb_expectations is not defined or imported in any cell shown in
# this notebook, and this cell has never been executed (In [None]) -- TODO import
# it from its module before running.
e_ind_admix = get_vb_expectations(vb_opt_dict, use_logitnormal_sticks)[0]
# Show the transposed matrix as an image.
plt.matshow(e_ind_admix.T)