In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation, get_jac_hvp_fun

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp
# Seed the global RNG of the original NumPy (imported as `onp`, since `np` is jax.numpy here).
# NOTE(review): JAX randomness below is driven by an explicit PRNGKey, which this seed
# presumably does not influence — confirm which code actually depends on it.
onp.random.seed(53453)

# Load data

In [3]:
# data = np.load('../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
# g_obs = np.array(data['g_obs'])

In [4]:
# Location of the pre-processed genotype archive (relative to this notebook).
data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
filename = data_dir + filenamebase + '.npz'
data = np.load(filename)

g_obs = np.array(data['g_obs'])
g_obs_raw = np.array(data['g_obs_raw'])

# Sanity checks on the two encodings of the genotypes.
# Entries of `g_obs_raw` equal to 3 are treated as missing.
which_missing = (g_obs_raw == 3)

# Wherever the raw value is not missing, the arg-max over the last axis of
# `g_obs` must reproduce it. This is asserted (rather than left as a bare
# expression) because only the last expression of a cell is displayed, so a
# failure here would otherwise go unnoticed.
assert (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()

# Missing entries must be all-zero in `g_obs`; displayed as the cell output.
(g_obs[which_missing] == 0).all()

DeviceArray(True, dtype=bool)

In [5]:
# The first two axes of `g_obs` give the number of observations and loci.
n_obs, n_loci = g_obs.shape[:2]

In [6]:
# Report the data dimensions, one per line.
print(n_obs, n_loci, sep='\n')

1107
2810


# Get prior

In [7]:
# Default prior parameters together with the paragami pattern describing them.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

print(prior_params_paragami)

# Flattened, unconstrained ("free") representation of the prior parameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [8]:
# Passed as `k_approx` when the VB parameter pattern is built.
# NOTE(review): presumably the truncation level (number of mixture components) of the
# variational approximation — confirm against structure_model_lib.
k_approx = 15

In [9]:
# Degree of the Gauss-Hermite quadrature rule; the nodes and weights are derived
# from this single constant so the two can never disagree.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [10]:
use_logitnormal_sticks = True

# Build the VB parameter dictionary and its paragami pattern for this data size.
(vb_params_dict,
 vb_params_paragami) = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 15, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 14) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 14) (lb=0.0001, ub=inf)


In [11]:
# Overwrite the VB parameters with a random draw from the pattern; the fixed
# JAX PRNG key makes the draw repeatable.
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

In [12]:
# Evaluate `get_kl` at the current VB parameters; the value is shown as the cell output.
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict,
                                    gh_loc, gh_weights)

DeviceArray(5866337.80064644, dtype=float64)

# Optimize

In [13]:
# Pull the individual prior parameters out of the prior dictionary.
dp_prior_alpha, allele_prior_alpha, allele_prior_beta = (
    prior_params_dict[key]
    for key in ('dp_prior_alpha', 'allele_prior_alpha', 'allele_prior_beta'))

# Initial moments computed from the current VB parameters.
(e_log_sticks, e_log_1m_sticks,
 e_log_pop_freq, e_log_1m_pop_freq) = \
    structure_model_lib.get_moments_from_vb_params_dict(
        vb_params_dict, gh_loc, gh_weights)

In [16]:
# Wrap the library's population-beta update in `jax.jit` for the timing cells.
get_pop_beta_update1_ad = jax.jit(cavi_lib.get_pop_beta_update1_ad)

In [17]:
# Time the first call of the jitted update.
# NOTE(review): the first call of a jitted function presumably includes
# tracing/compilation time, so this is not the steady-state cost.
t0 = time.time()
foo = get_pop_beta_update1_ad(g_obs,
                    e_log_pop_freq, e_log_1m_pop_freq,
                    e_log_sticks, e_log_1m_sticks,
                    dp_prior_alpha, allele_prior_alpha,
                    allele_prior_beta)
# Wait for the first output to be computed before stopping the clock.
foo[0].block_until_ready()
print(time.time() - t0)

10.463480234146118


In [18]:
# Time a second call with identical arguments, for comparison with the first call.
t0 = time.time() 
out = get_pop_beta_update1_ad(g_obs,
                    e_log_pop_freq, e_log_1m_pop_freq,
                    e_log_sticks, e_log_1m_sticks,
                    dp_prior_alpha, allele_prior_alpha,
                    allele_prior_beta)
# Wait for the first output to be computed before stopping the clock.
out[0].block_until_ready()

print(time.time() - t0)

8.657811880111694


In [20]:
# Display the full shape of the genotype array (its first two axes are n_obs and n_loci).
g_obs.shape

(1107, 2810, 3)

In [22]:
# Convert the stick moments into expected log cluster probabilities.
e_log_cluster_probs = modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
    e_log_sticks, e_log_1m_sticks)

In [27]:
# Index along the second axis of `g_obs` (the loci axis) used for the slicing experiments.
l = 0

In [30]:
# Observations at locus `l`, then each of the three columns of its last axis.
g_obs_l = g_obs[:, l, :]
g_obs_l0, g_obs_l1, g_obs_l2 = (g_obs_l[:, col] for col in range(3))


In [32]:
# Shape of the element-wise sum of the first two columns at locus `l`.
# (The previous version referenced the undefined name `g_obs_1`.)
(g_obs_l0 + g_obs_l1).shape

(1107,)

In [None]:
def update_pop_beta():
    """Placeholder for the population-frequency beta update.

    The update has not been written yet. The stub raises instead of silently
    returning so that an accidental call is noticed immediately.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError('update_pop_beta is not implemented yet')

In [17]:
class StickObjective():
    """Objective and gradient with respect to the individual-admixture stick parameters.

    The population-frequency beta parameters are passed as a second, fixed
    argument. The objective is `structure_model_lib.get_kl` evaluated on the
    VB parameter dictionary assembled from the two pieces.

    Attributes
    ----------
    f : callable
        Jitted `f(stick_free_params, pop_freq_beta_params)`, where
        `stick_free_params` is the free flattened vector of the
        `'ind_admix_params'` sub-pattern.
    grad : callable
        Jitted gradient of the same function with respect to its first argument.
    """

    def __init__(self, g_obs, vb_params_paragami, prior_params_dict,
                            gh_loc, gh_weights, log_phi, epsilon):
        """Store the fixed inputs, build the flattened objective and compile it.

        Parameters
        ----------
        g_obs :
            Observed genotype array, forwarded to `get_kl`.
        vb_params_paragami :
            Pattern for the full VB parameter dictionary; must contain an
            `'ind_admix_params'` sub-pattern.
        prior_params_dict :
            Prior parameters, forwarded to `get_kl`.
        gh_loc, gh_weights :
            Gauss-Hermite nodes and weights, forwarded to `get_kl`.
        log_phi, epsilon :
            Forwarded unchanged to `get_kl`.
        """
        self.g_obs = g_obs
        # Previously a no-op self-assignment of the local; keep the pattern
        # on the instance so it is available after construction.
        self.vb_params_paragami = vb_params_paragami

        self.prior_params_dict = prior_params_dict
        self.gh_loc = gh_loc
        self.gh_weights = gh_weights

        self.log_phi = log_phi
        self.epsilon = epsilon

        # Let the loss accept the stick parameters as a free flat vector.
        self.stick_objective_fun = \
            paragami.FlattenFunctionInput(
                original_fun = self._get_ind_admix_sticks_loss,
                patterns = vb_params_paragami['ind_admix_params'],
                free = True,
                argnums = 0)

        # objective and gradients
        self.f = jax.jit(self.stick_objective_fun)
        self.grad = jax.jit(jax.grad(self.stick_objective_fun, argnums = 0))

        # Call both functions once on random parameters so the jit
        # compilation cost is paid here rather than on first real use.
        # NOTE(review): elsewhere in this notebook `random` is called with an
        # explicit `key=`; confirm the no-argument call is intended.
        print('compiling stick objective and gradients ...')
        t0 = time.time()
        param_dict = vb_params_paragami.random()
        stick_free_params = vb_params_paragami['ind_admix_params'].flatten(
                                param_dict['ind_admix_params'], free = True)
        _ = self.f(stick_free_params, param_dict['pop_freq_beta_params'])
        _ = self.grad(stick_free_params, param_dict['pop_freq_beta_params'])
        print('compile time: {0:.3g}sec'.format(time.time() - t0))

    def _get_ind_admix_sticks_loss(self,
                                ind_admix_params,
                                pop_freq_beta_params):
        """Return `get_kl` for the VB parameters built from the two arguments.

        Parameters
        ----------
        ind_admix_params :
            Folded individual-admixture parameters (the `'ind_admix_params'` entry).
        pop_freq_beta_params :
            Population-frequency beta parameters (the `'pop_freq_beta_params'` entry).
        """
        vb_params_dict = {'pop_freq_beta_params': pop_freq_beta_params,
                          'ind_admix_params': ind_admix_params}

        return structure_model_lib.get_kl(self.g_obs,
                                            vb_params_dict,
                                            self.prior_params_dict,
                                            self.gh_loc, self.gh_weights,
                                            self.log_phi,
                                            self.epsilon,
                                            detach_ez = False)
    

In [18]:
# Build (and compile) the stick objective with no `log_phi` and zero `epsilon`.
stick_objective = StickObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
    log_phi=None,
    epsilon=0.)

compiling stick objective and gradients ...
compile time: 23sec


In [19]:
# Inputs for the stick objective: the fixed population-frequency parameters and the
# free flattened vector of the individual-admixture parameters.
pop_freq_beta_params = vb_params_dict['pop_freq_beta_params']
x = vb_params_paragami['ind_admix_params'].flatten(vb_params_dict['ind_admix_params'], free = True)

In [22]:
# Time one evaluation of the jitted stick objective, waiting for the result to be
# computed before stopping the clock.
t0 = time.time()
_ = stick_objective.f(x, pop_freq_beta_params).block_until_ready()
print(time.time() - t0)

8.488821506500244


In [28]:
def _get_ind_admix_sticks_loss(g_obs,
                                ind_admix_params,
                                pop_freq_beta_params,
                                prior_params_dict,
                                gh_loc, gh_weights,
                                log_phi, epsilon):
    """Evaluate `structure_model_lib.get_kl` with the VB parameters given in two pieces.

    The individual-admixture parameters and the population-frequency beta
    parameters are combined into one VB parameter dictionary; every other
    argument is forwarded to `get_kl` unchanged, with `detach_ez = False`.
    """
    vb_params = {
        'pop_freq_beta_params': pop_freq_beta_params,
        'ind_admix_params': ind_admix_params,
    }

    kl = structure_model_lib.get_kl(
        g_obs,
        vb_params,
        prior_params_dict,
        gh_loc,
        gh_weights,
        log_phi,
        epsilon,
        detach_ez=False)
    return kl

In [15]:
from vb_lib.structure_optimization_lib import define_structure_objective
from bnpmodeling_runjingdev.optimization_lib import run_lbfgs

In [84]:
from scipy import optimize

In [16]:
# Build the full optimisation objective (including the Hessian-vector product)
# together with the initial free parameter vector.
(optim_objective,
 init_vb_free) = define_structure_objective(
    g_obs,
    vb_params_dict,
    vb_params_paragami,
    prior_params_dict,
    gh_loc=gh_loc,
    gh_weights=gh_weights,
    compile_hvp=True)

Compiling objective ...
Iter 0: f = 30594962.37881606
Compiling grad ...
Compiling hvp ...
Compile time: 76.768secs


In [19]:
# Time one Hessian-vector product at the initial parameters, using the initial
# parameter vector itself as the direction.
# NOTE(review): no explicit block_until_ready here; presumably `hvp_np` returns a
# materialised NumPy result so the timing is complete — confirm.
t0 = time.time()
_ = optim_objective.hvp_np(init_vb_free, init_vb_free)
time.time() - t0

28.205904006958008

# Check out the fit

In [None]:
# fitted
# NOTE(review): neither `get_vb_expectations` nor `vb_opt_dict` is defined or imported
# in the visible part of this notebook, so this cell fails on a fresh kernel. Import
# the helper and assign the optimised VB parameters before running it.
e_ind_admix = get_vb_expectations(vb_opt_dict, use_logitnormal_sticks)[0]
# Plot the transpose of the first returned expectation as a matrix image.
plt.matshow(e_ind_admix.T)