In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp
onp.random.seed(53453)

# Load data

In [3]:
data = np.load('../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# data_dir = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/'
# filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
# filename = data_dir + filenamebase + '.npz'
# data = np.load(filename)

# g_obs = np.array(data['g_obs'], dtype = int)
# g_obs_raw = np.array(data['g_obs_raw'])

# # just checking ... 
# which_missing = (g_obs_raw == 3)
# (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
# (g_obs[which_missing] == 0).all()

In [5]:
n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

In [6]:
print(n_obs)
print(n_loci)

20
50


# Get prior

In [7]:
_, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

prior_params_dict = prior_params_paragami.random(key=jax.random.PRNGKey(41))

print(prior_params_paragami)

prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [8]:
k_approx = 8

In [9]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(8)

In [10]:
use_logitnormal_sticks = True

vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx,
                                    use_logitnormal_sticks = use_logitnormal_sticks)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


In [11]:
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

# The (not preconditioned) objective

In [12]:
def _f(x): 
    vb_params_dict = vb_params_paragami.fold(x, free = True)
        
    return structure_model_lib.get_kl(g_obs, vb_params_dict, 
                                      prior_params_dict, 
                                      gh_loc, gh_weights, 
                                      detach_ez = False)

get_kl = jax.jit(_f)
get_grad = jax.jit(jax.grad(_f))
get_hvp = jax.jit(get_jac_hvp_fun(_f))

In [13]:
x = vb_params_paragami.flatten(vb_params_dict, free = True)

In [14]:
kl = get_kl(x)
print(kl)

2416.47883971296


In [15]:
grad = get_grad(x).block_until_ready()

In [16]:
_  = get_hvp(x, x).block_until_ready()

# Define preconditioned objective

In [17]:
# not preconditioned
objective_precond = s_optim_lib.StructurePrecondObjective(g_obs, 
                            vb_params_paragami,
                            prior_params_dict,
                            gh_loc = gh_loc, gh_weights = gh_weights)

compiling preconditioned objective ... 
done. Elasped: 45.473


In [18]:
precond_params = x

x_c = objective_precond.precondition(x, x)

# Check function values

In [19]:
objective_precond.f_precond(x_c, precond_params) - kl

DeviceArray(-4.54747351e-13, dtype=float64)

### Check gradient

In [20]:
# preconditioned function
def f_precond(x_c): 
    x = objective_precond._unprecondition(x_c, precond_params)
    return _f(x)

print(f_precond(x_c) - kl)


-4.547473508864641e-13


In [21]:
# the gradient from the preconditioned objective
grad_precond = objective_precond.grad_precond(x_c, precond_params)

grad_precond2 = jax.grad(f_precond)(x_c)

diff = np.abs(grad_precond2 - grad_precond).max()
assert diff < 1e-12, diff


# # convert gradient precond to gradient 
# grad1 = objective_precond.precondition(grad_precond, precond_params)

# # convert gradient to precond gradient 
# grad2 = objective_precond.unprecondition(grad, precond_params)


# print(np.abs(grad1 - grad).max())
# print(np.abs(grad2 - grad_precond).max())

### Check HVP

In [22]:
# get hessian from jax
get_kl_hess_pc = jax.hessian(f_precond)
kl_hess_pc = get_kl_hess_pc(x_c)

In [23]:
for i in range(len(x_c)): 
    
    if (i % 50) == 0: 
        print(i)
    
    v = onp.zeros(len(x_c))
    v[i] = 1.
    v = np.array(v)
    
    hvp1 = objective_precond.hvp_precond(x_c, precond_params, v)
    hvp2 = np.dot(kl_hess_pc, v)
    
    diff = np.abs(hvp1 - hvp2).max()
    assert diff < 1e-12, diff
    
print('done. ')

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
done. 


# Timing

### Objective function

In [24]:
for i in range(5): 
    t0 = time.time() 
    _ = objective_precond.f_precond(x_c, precond_params).block_until_ready()
    print(time.time() - t0)

0.006220340728759766
0.006279945373535156
0.006295680999755859
0.00620722770690918
0.006139039993286133


In [25]:
for i in range(5): 
    t0 = time.time() 
    _ = get_kl(x).block_until_ready()
    print(time.time() - t0)

0.0026390552520751953
0.002539396286010742
0.004297733306884766
0.004114866256713867
0.0038759708404541016


### Gradient

In [26]:
for i in range(5): 
    t0 = time.time() 
    _ = objective_precond.grad_precond(x_c, precond_params).block_until_ready()
    print(time.time() - t0)

0.014833450317382812
0.01473093032836914
0.014672994613647461
0.014631032943725586
0.014544010162353516


In [27]:
for i in range(5): 
    t0 = time.time() 
    _ = get_grad(x).block_until_ready()
    print(time.time() - t0)

0.008603096008300781
0.008400440216064453
0.008219003677368164
0.008244037628173828
0.008393049240112305


### HVP

In [28]:
v = np.array(onp.random.randn(len(x_c)))

In [29]:
for i in range(5): 
    t0 = time.time() 
    _ = objective_precond.hvp_precond(x_c, precond_params, v).block_until_ready()
    print(time.time() - t0)

0.04215669631958008
0.038106441497802734
0.037842750549316406
0.03804802894592285
0.03779172897338867


In [30]:
for i in range(5): 
    t0 = time.time() 
    _ = get_hvp(x, v).block_until_ready()
    print(time.time() - t0)

0.0198974609375
0.01949453353881836
0.019475698471069336
0.019238710403442383
0.01947498321533203
