In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib
import vb_lib.structure_preconditioned_optimization_lib as s_poptim_lib

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.optimization_lib import run_lbfgs
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp

# Seed NumPy's global RNG for reproducibility. The random parameter draws
# further down pass explicit jax.random PRNGKeys instead.
onp.random.seed(seed=53453)

# Load data

In [3]:
# Load the simulated genotype data. The .npz archive is opened in a context
# manager so its file handle is closed once `g_obs` has been copied out.
data_path = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'
with onp.load(data_path) as data:
    g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# Alternative (disabled): load the HGDP genotype data instead of the simulated
# data above. Uncommenting this cell overwrites `g_obs`.
# The "just checking" lines verify that, where `g_obs_raw == 3` (the entries
# flagged as missing), the corresponding `g_obs` rows are all zero, and that
# elsewhere `g_obs.argmax(-1)` recovers `g_obs_raw`.
# data_dir = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/'
# filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
# filename = data_dir + filenamebase + '.npz'
# data = np.load(filename)

# g_obs = np.array(data['g_obs'], dtype = int)
# g_obs_raw = np.array(data['g_obs_raw'])

# # just checking ... 
# which_missing = (g_obs_raw == 3)
# (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
# (g_obs[which_missing] == 0).all()

In [5]:
# Sizes of the first two axes of the genotype array.
n_obs, n_loci = g_obs.shape[0], g_obs.shape[1]

In [6]:
# One value per line: number of observations, then number of loci.
print(n_obs, n_loci, sep = '\n')

20
50


# Get prior

In [7]:
# Only the paragami pattern is kept from the defaults; the default prior
# values themselves are discarded and replaced by a random draw from the
# pattern, so the checks below run at arbitrary prior parameters.
_, prior_params_paragami = structure_model_lib.get_default_prior_params()

prior_key = jax.random.PRNGKey(41)
prior_params_dict = prior_params_paragami.random(key = prior_key)

print(prior_params_paragami)

# Free (unconstrained) vector form of the prior parameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict,
                                                  free = True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [8]:
# Number of components (populations) in the variational approximation;
# it sets the size-8 axis of `pop_freq_beta_params` in the pattern below.
k_approx = int('8')

In [9]:
# Degree of the Gauss-Hermite quadrature whose nodes/weights are passed to the
# KL objective. Use `gh_deg` (not a literal) so changing it takes effect.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [10]:
use_logitnormal_sticks = True

# Build the VB parameter pattern. The dict returned alongside it is
# overwritten in the next cell by a random draw from the pattern.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks = use_logitnormal_sticks)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=1000000.0)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


In [11]:
# Random point in VB-parameter space at which all checks below are run.
# NOTE(review): PRNGKey(41) is also the key used for the random prior
# parameters; reusing a key makes the two draws dependent — confirm intended.
vb_init_key = jax.random.PRNGKey(41)
vb_params_dict = vb_params_paragami.random(key = vb_init_key)

# The unpreconditioned objective

In [12]:
def _f(x):
    """KL objective as a function of the free VB parameter vector.

    Folds ``x`` back into a VB parameter dict with ``vb_params_paragami``
    (``free = True``) and evaluates ``structure_model_lib.get_kl`` at the
    notebook's data, prior parameters, and Gauss-Hermite nodes/weights.
    """
    # Local name differs from the global `vb_params_dict` to avoid shadowing it.
    vb_params = vb_params_paragami.fold(x, free = True)
    return structure_model_lib.get_kl(
        g_obs, vb_params, prior_params_dict, gh_loc, gh_weights)

# JIT-compiled objective, gradient, and Hessian-vector product of `_f`.
f = jax.jit(_f)
get_grad = jax.jit(jax.grad(_f))
get_hvp = jax.jit(get_jac_hvp_fun(_f))

In [13]:
# Free-parameter vector for the random VB parameters; all checks start here.
x = vb_params_paragami.flatten(vb_params_dict,
                               free = True)

In [14]:
# Reference value: the unpreconditioned KL at x.
kl = f(x)
print(kl)

4479.870303612657


In [15]:
# Gradient of the KL at x; block so the value is fully computed in this cell.
grad = get_grad(x)
grad = grad.block_until_ready()

In [16]:
# The point x itself doubles as the direction for the Hessian-vector products.
v = x

hvp = get_hvp(x, v).block_until_ready()

# Define preconditioned objective

In [17]:
# Preconditioned counterpart of the KL objective above. Constructing it
# compiles its functions (it prints a "compiling preconditioned objective"
# message with the elapsed time).
objective_precond = s_poptim_lib.StructurePrecondObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc = gh_loc,
    gh_weights = gh_weights)

compiling preconditioned objective ... 
done. Elasped: 49.3403


In [18]:
# Use the current point x as the preconditioner parameters, and express x in
# the preconditioned coordinates. Pass `precond_params` (not `x` twice) so a
# different choice of preconditioner parameters propagates here.
precond_params = x

x_c = objective_precond.precondition(x, precond_params)

# Check function values

In [19]:
# The preconditioned objective at x_c should reproduce the original KL at x.
# The tolerance is loose relative to |kl| but far below any real mismatch.
f_diff = objective_precond.f_precond(x_c, precond_params) - kl
assert np.abs(f_diff) < 1e-5, f_diff
f_diff

DeviceArray(6.70588634e-08, dtype=float64)

### Check gradient

In [20]:
grad_precond = objective_precond.grad_precond(x_c, precond_params)

# Two consistency checks between the gradients:
#   precondition(grad_precond)   vs. the original gradient `grad`
#   unprecondition(grad)         vs. the preconditioned gradient
grad1 = objective_precond.precondition(grad_precond, precond_params)
grad2 = objective_precond.unprecondition(grad, precond_params)

grad1_err = np.abs(grad1 - grad).max()
grad2_err = np.abs(grad2 - grad_precond).max()
print(grad1_err)
print(grad2_err)

# Fail loudly instead of relying on visual inspection of the printed errors.
assert grad1_err < 1e-6, grad1_err
assert grad2_err < 1e-6, grad2_err

1.3280080091160329e-09
1.7556965881126985e-09


### Check HVP

In [21]:
# HVP of the preconditioned objective at x_c in direction v.
hvp_precond = objective_precond.hvp_precond(x_c,
                                            precond_params,
                                            v)

In [22]:
# Reference HVP: apply `unprecondition` to the direction, take the HVP of
# the original objective at x, then apply `unprecondition` to the result.
v1 = objective_precond.unprecondition(v, precond_params)
hvp = get_hvp(x, v1)
hvp = objective_precond.unprecondition(hvp, precond_params)

hvp_err = np.abs(hvp_precond - hvp).max()
# Fail loudly instead of relying on visual inspection of the displayed error.
assert hvp_err < 1e-6, hvp_err
hvp_err

DeviceArray(3.48810544e-09, dtype=float64)

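# Timing

Per-call runtimes of the preconditioned functions versus their unpreconditioned counterparts. Every function has already been called once above, so these timings do not include JIT compilation.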

### Objective function

In [23]:
# Time 5 calls of the preconditioned objective. perf_counter is a monotonic,
# high-resolution clock (time.time is not), and block_until_ready ensures the
# measurement covers the whole asynchronous computation. The result is not
# assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    objective_precond.f_precond(x_c, precond_params).block_until_ready()
    print(time.perf_counter() - t0)

0.008582353591918945
0.006486177444458008
0.006678104400634766
0.006247043609619141
0.006412982940673828


In [24]:
# Time 5 calls of the unpreconditioned objective, using the monotonic
# high-resolution perf_counter clock and blocking on the async result.
# The result is not assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    f(x).block_until_ready()
    print(time.perf_counter() - t0)

0.003061532974243164
0.0030117034912109375
0.0027008056640625
0.0033049583435058594
0.0028066635131835938


### Gradient

In [25]:
# Time 5 calls of the preconditioned gradient, using the monotonic
# high-resolution perf_counter clock and blocking on the async result.
# The result is not assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    objective_precond.grad_precond(x_c, precond_params).block_until_ready()
    print(time.perf_counter() - t0)

0.016669034957885742
0.016231536865234375
0.015273809432983398
0.015501976013183594
0.015541553497314453


In [26]:
# Time 5 calls of the unpreconditioned gradient, using the monotonic
# high-resolution perf_counter clock and blocking on the async result.
# The result is not assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    get_grad(x).block_until_ready()
    print(time.perf_counter() - t0)

0.012613534927368164
0.008977413177490234
0.008396148681640625
0.008582115173339844
0.008532285690307617


### HVP

In [27]:
# Time 5 calls of the preconditioned HVP, using the monotonic
# high-resolution perf_counter clock and blocking on the async result.
# The result is not assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    objective_precond.hvp_precond(x_c, precond_params, v).block_until_ready()
    print(time.perf_counter() - t0)

0.039403676986694336
0.03425335884094238
0.03356218338012695
0.033171892166137695
0.033469200134277344


In [28]:
# Time 5 calls of the unpreconditioned HVP, using the monotonic
# high-resolution perf_counter clock and blocking on the async result.
# The result is not assigned to `_`, which would clobber IPython's output history.
for i in range(5):
    t0 = time.perf_counter()
    get_hvp(x, v).block_until_ready()
    print(time.perf_counter() - t0)

0.0214388370513916
0.0210416316986084
0.020429372787475586
0.02018451690673828
0.02038407325744629
