In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp

# Seed numpy's global RNG so the random draws below (e.g. the random HVP
# direction `v`) are reproducible. Presumably this also makes
# `vb_params_paragami.random()` reproducible -- TODO confirm it uses numpy's
# global RNG.
SEED = 53453
onp.random.seed(SEED)

# Load data

In [3]:
# Alternative dataset (HGDP):
# data_file = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/' + \
#                 'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz'
data_file = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'

# Read the .npz archive with host numpy inside a `with` block so its file
# handle is released once g_obs has been copied out; the previous bare
# `np.load(...)` left the returned NpzFile open.
with onp.load(data_file) as data:
    g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# Number of observations (axis 0 of g_obs) and loci (axis 1 of g_obs).
n_obs, n_loci = g_obs.shape[0], g_obs.shape[1]

print(n_obs)
print(n_loci)

20
50


# Get prior

In [5]:
# Default prior hyperparameters together with their paragami pattern.
prior_params_dict, prior_params_paragami = structure_model_lib.get_default_prior_params()
print(prior_params_dict)

# Flatten the prior parameters into their unconstrained ("free") vector.
prior_params_free = prior_params_paragami.flatten(prior_params_dict,
                                                  free = True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# Number of clusters in the variational approximation: the printed pattern
# below has k_approx columns of pop-freq params and k_approx - 1 sticks per
# observation (presumably the stick-breaking truncation level).
k_approx = 8

In [7]:
# Number of Gauss-Hermite quadrature nodes used for the expectations below.
gh_deg = 8
# Use gh_deg rather than a repeated literal so the two cannot drift apart.
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
use_logitnormal_sticks = True

# Build the variational parameter dictionary and its paragami pattern for
# n_obs observations, n_loci loci and k_approx clusters.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx, use_logitnormal_sticks = use_logitnormal_sticks)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


## Initialize 

In [9]:
# Random initialization drawn from the VB pattern; this replaces the dict
# returned when the pattern was built above.
vb_params_dict = vb_params_paragami.random()

In [10]:
# KL objective at the random initialization (shown as the cell output only).
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict,
                            gh_loc, gh_weights)

DeviceArray(2375.583033, dtype=float64)

In [11]:
# Unconstrained ("free") flat parameter vector at which all derivatives
# below are evaluated.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free = True)

# Define objective

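### The exact objective
Here, `detach_ez = False`. The gradient and dense Hessian of this objective, computed by autodiff, serve as the reference values for the checks below.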

In [12]:
# Wrap get_kl so its second argument (vb_params_dict) is supplied as a flat
# free-parameter vector instead of a dictionary.
_kl_fun_free = paragami.FlattenFunctionInput(
    original_fun = structure_model_lib.get_kl,
    patterns = vb_params_paragami,
    free = True,
    argnums = 1)


def kl_fun_free(x):
    """Exact KL (detach_ez = False) as a function of the free VB vector `x`."""
    return _kl_fun_free(g_obs, x, prior_params_dict,
                        gh_loc, gh_weights,
                        detach_ez = False)


# Reference derivatives obtained by autodiff of the exact objective.
get_kl_grad = jax.grad(kl_fun_free)
get_kl_hessian = jax.hessian(kl_fun_free)

In [13]:
# Reference gradient of the exact objective at vb_params_free.
kl_grad = get_kl_grad(vb_params_free)

In [14]:
# Dense reference Hessian of the exact objective; every HVP check below
# compares against products with (or columns of) this matrix.
kl_hess = get_kl_hessian(vb_params_free)

### My custom objective
Here, the gradient is evaluated with some shortcuts and the HVP is customized, so both are checked against the exact autodiff quantities in the next section.

In [15]:
# Custom objective whose f / grad / hvp are checked against the exact
# autodiff quantities below. Constructing it compiles the objective (see the
# printed timing).
stru_objective = s_optim_lib.StructureObjective(g_obs, 
                                                 vb_params_paragami,
                                                 prior_params_dict, 
                                                 gh_loc, gh_weights)

compiling objective ... 
done. Elasped: 33.6822


# Check objective derivatives

In [21]:
# The custom objective value must match the exact KL; report the
# discrepancy on failure, as the HVP loop below does.
f_diff = np.abs(stru_objective.f(vb_params_free) - kl_fun_free(vb_params_free))
assert f_diff < 1e-12, f_diff

In [22]:
# The custom (shortcut) gradient must match the autodiff gradient; report
# the largest discrepancy on failure.
grad_diff = np.abs(stru_objective.grad(vb_params_free) - kl_grad).max()
assert grad_diff < 1e-12, grad_diff

### The HVP in particular needs testing
Compare the custom HVP against the exact Hessian along every coordinate direction.

In [23]:
# Check the custom HVP one coordinate direction at a time. H @ e_i is just
# the i-th column of the dense Hessian, so slice the column instead of
# performing a full O(n^2) matrix-vector product on every iteration.
n_free = len(vb_params_free)
for i in range(n_free):

    if (i % 50) == 0:
        print(i)

    # i-th standard basis vector, built on the host and moved to jax.
    unit_vec = onp.zeros(n_free)
    unit_vec[i] = 1.
    unit_vec = np.array(unit_vec)

    hvp1 = stru_objective.hvp(vb_params_free, unit_vec)
    hvp2 = kl_hess[:, i]

    diff = np.abs(hvp1 - hvp2).max()
    assert diff < 1e-12, diff
print('done. ')

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
done. 


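# Custom HVP
The exact HVP is checked against the decomposition `term1 @ v - 4 * term2`. Here `term1` is the Hessian of the KL with `detach_ez = True`, and `term2` $= J^\top J v$, where $J$ is the Jacobian (w.r.t. the free VB parameters) of the square root of the optimal `ezl` returned by `get_optimal_ezl`. `term2` is computed twice: once for all loci at once, and once accumulated locus by locus with `lax.scan`.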

In [17]:
def kl_fun_free_customized(x, detach_ez, detach_vb_params):
    """KL as a function of the free VB vector `x`.

    The `detach_ez` and `detach_vb_params` flags are forwarded unchanged to
    the underlying get_kl call.
    """
    return _kl_fun_free(g_obs, x, prior_params_dict,
                        gh_loc, gh_weights,
                        detach_ez = detach_ez,
                        detach_vb_params = detach_vb_params)

In [18]:
# Dense Hessian of the KL with detach_ez = True and detach_vb_params = False.
term1 = jax.hessian(
    lambda x : kl_fun_free_customized(x,
                                      detach_ez = True,
                                      detach_vb_params = False)
)(vb_params_free)

detaching ez


In [19]:
# Explicit imports instead of `import *`: these are the only names the cells
# below take from structure_model_lib (`ef` is the helper-module alias that
# ps_loss_zl uses).
from vb_lib.structure_model_lib import get_moments_from_vb_params_dict, \
    get_optimal_ezl, ef

In [20]:
def ps_loss_z(vb_params_dict):
    """Square root of the optimal `ezl` for every locus, flattened.

    Uses the module-level g_obs, gh_loc and gh_weights. The per-locus
    computation is mapped over axis 1 of g_obs (the loci) with jax.lax.map.

    Args:
        vb_params_dict: VB parameter dictionary understood by
            get_moments_from_vb_params_dict.

    Returns:
        1-D array: the per-locus results of sqrt(get_optimal_ezl(...)[1]),
        flattened and concatenated.
    """
    moments = get_moments_from_vb_params_dict(vb_params_dict,
                                              gh_loc = gh_loc,
                                              gh_weights = gh_weights)
    e_log_sticks, e_log_1m_sticks, e_log_pop_freq, e_log_1m_pop_freq = moments

    e_log_cluster_probs = modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
        e_log_sticks, e_log_1m_sticks)

    def _sqrt_ez_at_locus(locus_inputs):
        # One slice per locus: genotypes plus that locus' pop-freq moments.
        g_obs_l, e_log_pop_freq_l, e_log_1m_pop_freq_l = locus_inputs
        ezl = get_optimal_ezl(g_obs_l, e_log_pop_freq_l, e_log_1m_pop_freq_l,
                              e_log_cluster_probs,
                              detach_ez = False)[1]
        return np.sqrt(ezl).flatten()

    # Swap the first two axes of g_obs so lax.map iterates over loci.
    per_locus_inputs = (np.transpose(g_obs, (1, 0, 2)),
                        e_log_pop_freq,
                        e_log_1m_pop_freq)
    return jax.lax.map(_sqrt_ez_at_locus, per_locus_inputs).flatten()


In [21]:
# ps_loss_z as a function of the flat free-parameter vector (argnums = 0 is
# its vb_params_dict argument).
ps_loss_z_free = paragami.FlattenFunctionInput(
                                original_fun=ps_loss_z,
                                patterns = vb_params_paragami,
                                free = True,
                                argnums = 0)

In [22]:
# Random direction for the HVP comparisons below (drawn from numpy's global
# RNG, seeded near the top of the notebook).
v = np.array(onp.random.randn(len(vb_params_free)))

In [23]:
def get_term2(v):
    """Return J^T (J v), where J is the Jacobian of ps_loss_z_free at vb_params_free."""
    # Forward mode gives J v ...
    _, jac_v = jax.jvp(ps_loss_z_free, (vb_params_free, ), (v, ))
    # ... and reverse mode pulls it back to J^T (J v).
    _, pullback = jax.vjp(ps_loss_z_free, vb_params_free)
    return pullback(jac_v)[0]

In [24]:
# Decomposed HVP: term1 @ v - 4 * J^T J v. Assert it matches the exact HVP
# from the dense Hessian instead of relying on eyeballing the printed arrays.
hvp_custom = np.dot(term1, v) - 4 * get_term2(v)
assert np.allclose(hvp_custom, np.dot(kl_hess, v)), \
    np.abs(hvp_custom - np.dot(kl_hess, v)).max()
hvp_custom

DeviceArray([ 1.50482607e+00, -1.41679872e+00, -1.96451443e-03, ...,
              1.49461175e+00, -4.59348122e+00, -6.41059686e-01],            dtype=float64)

In [25]:
# Reference: exact HVP from the dense Hessian (display only).
np.dot(kl_hess, v)

DeviceArray([ 1.50482607e+00, -1.41679872e+00, -1.96451443e-03, ...,
              1.49461175e+00, -4.59348122e+00, -6.41059686e-01],            dtype=float64)

In [34]:
def ps_loss_zl(vb_params_dict, g_obs, indx_l):
    """Square root of the optimal `ezl` for a single locus, flattened.

    Single-locus analogue of ps_loss_z. Only the population-frequency Beta
    parameters of locus `indx_l` (and column `indx_l` of `g_obs`) enter,
    while the cluster probabilities use every observation's stick parameters.

    Args:
        vb_params_dict: VB parameter dictionary. Reads
            ['ind_admix_params']['stick_means'],
            ['ind_admix_params']['stick_infos'] and
            ['pop_freq_beta_params'].
        g_obs: observed genotypes, indexed here as g_obs[:, indx_l].
        indx_l: locus index; may be a traced integer (body_fun calls this
            through ps_loss_zl_free inside jax.lax.scan).

    Returns:
        1-D array: sqrt of element [1] of get_optimal_ezl's output, flattened.
    """
    admix_params = vb_params_dict['ind_admix_params']

    # Expected log sticks under the logit-normal sticks (Gauss-Hermite
    # quadrature), then the implied expected log cluster probabilities.
    e_log_sticks, e_log_1m_sticks = ef.get_e_log_logitnormal(
        lognorm_means = admix_params['stick_means'],
        lognorm_infos = admix_params['stick_infos'],
        gh_loc = gh_loc,
        gh_weights = gh_weights)
    e_log_cluster_probs = modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
        e_log_sticks, e_log_1m_sticks)

    # Population-frequency Beta parameters of this locus only (these are not
    # stick parameters).
    beta_params_l = vb_params_dict['pop_freq_beta_params'][indx_l]
    e_log_pop_freq_l, e_log_1m_pop_freq_l = modeling_lib.get_e_log_beta(beta_params_l)

    # The per-locus moments get a leading axis of size 1 before the call.
    ezl = get_optimal_ezl(g_obs[:, indx_l],
                          np.expand_dims(e_log_pop_freq_l, 0),
                          np.expand_dims(e_log_1m_pop_freq_l, 0),
                          e_log_cluster_probs,
                          detach_ez = False)[1]
    return np.sqrt(ezl).flatten()

In [37]:
# ps_loss_zl as a function of the flat free-parameter vector (argnums = 0 is
# its vb_params_dict argument); g_obs and indx_l are still passed explicitly.
ps_loss_zl_free = paragami.FlattenFunctionInput(
                                original_fun=ps_loss_zl,
                                patterns = vb_params_paragami,
                                free = True,
                                argnums = 0)

In [53]:
def body_fun(val, l):
    """Add J_l^T (J_l v) to the running sum `val`.

    J_l is the Jacobian of ps_loss_zl_free at vb_params_free for locus `l`.
    The direction `v` is the module-level vector.
    """
    loss_at_locus = lambda x : ps_loss_zl_free(x, g_obs, l)
    _, jac_v = jax.jvp(loss_at_locus, (vb_params_free, ), (v, ))
    _, pullback = jax.vjp(loss_at_locus, vb_params_free)
    (jtj_v, ) = pullback(jac_v)
    return jtj_v + val

In [54]:
def scan_fun(val, l):
    """lax.scan step: update the carried sum via body_fun; emit no per-step output."""
    new_val = body_fun(val, l)
    return new_val, None

In [55]:
# Locus indices that lax.scan iterates over below (display only).
np.arange(g_obs.shape[1])

DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
             30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
             45, 46, 47, 48, 49], dtype=int64)

In [60]:
# Accumulate body_fun over all loci: the carry starts at zeros, and [0]
# keeps the final carry (scan_fun emits no per-step outputs).
term2 = jax.lax.scan(scan_fun,
             init = np.zeros(len(vb_params_free)),
             xs = np.arange(g_obs.shape[1]))[0]

In [61]:
# Same decomposition, with term2 accumulated locus by locus via lax.scan;
# assert agreement with the exact HVP instead of comparing printouts by eye.
hvp_custom_scan = np.dot(term1, v) - 4 * term2
assert np.allclose(hvp_custom_scan, np.dot(kl_hess, v)), \
    np.abs(hvp_custom_scan - np.dot(kl_hess, v)).max()
hvp_custom_scan

DeviceArray([ 1.50482607e+00, -1.41679872e+00, -1.96451443e-03, ...,
              1.49461175e+00, -4.59348122e+00, -6.41059686e-01],            dtype=float64)

In [62]:
# Reference: exact HVP from the dense Hessian (display only).
np.dot(kl_hess, v)

DeviceArray([ 1.50482607e+00, -1.41679872e+00, -1.96451443e-03, ...,
              1.49461175e+00, -4.59348122e+00, -6.41059686e-01],            dtype=float64)

In [36]:
# Sanity check: output size of the single-locus loss for locus 0 (display only).
ps_loss_zl(vb_params_dict, g_obs, indx_l = 0).shape

(320,)