In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib
import vb_lib.structure_preconditioned_optimization_lib as s_poptim_lib

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.optimization_lib import run_lbfgs




In [2]:
# Seed NumPy's legacy global RNG so anything drawn through `onp.random`
# is repeatable across runs.
import numpy as onp

onp.random.seed(seed=53453)

# Load data

In [3]:
# Read the simulated data archive and extract the `g_obs` entry as an
# integer array.
data_path = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'
data = np.load(data_path)
g_obs = np.array(data['g_obs'], dtype=int)

In [4]:
# data_dir = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/'
# filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
# filename = data_dir + filenamebase + '.npz'
# data = np.load(filename)

# g_obs = np.array(data['g_obs'], dtype = int)
# g_obs_raw = np.array(data['g_obs_raw'])

# # just checking ... 
# which_missing = (g_obs_raw == 3)
# (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
# (g_obs[which_missing] == 0).all()

In [5]:
# The two leading axes of g_obs give the observation and locus counts.
n_obs, n_loci = g_obs.shape[:2]

In [6]:
# Show both dimensions, one per line.
print(n_obs, n_loci, sep='\n')

20
50


# Get prior

In [7]:
# Only the paragami pattern is kept from the library call; its first return
# value is discarded.
_, prior_params_paragami = structure_model_lib.get_default_prior_params()

# Draw a set of prior parameters from the pattern using a fixed JAX key.
prior_key = jax.random.PRNGKey(41)
prior_params_dict = prior_params_paragami.random(key=prior_key)

print(prior_params_paragami)

# Flattened "free" representation of the prior parameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [8]:
# Passed as `k_approx` when the VB parameter pattern is built.
# NOTE(review): presumably the truncation level (number of components) of
# the variational approximation — confirm against structure_model_lib.
k_approx = 8

In [9]:
# Gauss-Hermite quadrature rule.  `gh_deg` is the number of quadrature nodes;
# `hermgauss` returns the node locations and their weights.  The degree is
# passed through (rather than repeated as a literal) so that changing
# `gh_deg` actually changes the rule.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [10]:
# Build the VB parameter dictionary and its paragami pattern with the
# `use_logitnormal_sticks` option switched on, then show the pattern.
use_logitnormal_sticks = True

vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


In [11]:
# Overwrite the VB parameters with a random draw from the pattern.
# NOTE(review): seed 41 is also the seed used for the prior-parameter draw —
# confirm that sharing the key is intentional.
vb_init_key = jax.random.PRNGKey(41)
vb_params_dict = vb_params_paragami.random(key=vb_init_key)

In [12]:
# Evaluate the KL objective directly at the current VB parameters; this is
# the reference value the objective wrappers are compared against.
kl = structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

print(kl)

2416.4788397129596


# Define preconditioned objective

In [13]:
# Baseline objective: same wrapper class, with preconditioning switched off.
objective = s_poptim_lib.StructurePrecondObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc=gh_loc,
    gh_weights=gh_weights,
    use_preconditioning=False,
)

compiling preconditioned objective ... 
done. Elasped: 14.1929


In [14]:
# Preconditioned objective: `use_preconditioning` is left at its default.
# NOTE(review): presumably the default enables preconditioning — confirm.
objective_precond = s_poptim_lib.StructurePrecondObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc=gh_loc,
    gh_weights=gh_weights,
)

compiling preconditioned objective ... 
done. Elasped: 18.4967


In [15]:
# Flatten the VB parameters into the "free" vector and use that same point
# as the preconditioner's reference parameters.
x = vb_params_paragami.flatten(vb_params_dict, free=True)
precond_params = x

# Map x into preconditioned coordinates.
x_c = objective_precond.precondition(x, precond_params)

In [16]:
# Sanity check: the preconditioned objective evaluated at the preconditioned
# point should reproduce the KL computed directly.
kl_precond_gap = objective_precond.f_precond(x_c, precond_params) - kl

# Enforce the check so a regression fails the run instead of only being
# displayed.  The bound is relative and generous: only floating-point
# round-off is expected here.
assert np.abs(kl_precond_gap) <= 1e-6 * np.abs(kl), kl_precond_gap

kl_precond_gap

DeviceArray(1.36424205e-12, dtype=float64)

In [17]:
# Sanity check: with preconditioning switched off, the wrapper evaluated at
# the free parameters should reproduce the KL computed directly.
kl_plain_gap = objective.f_precond(x, None) - kl

# Enforce the check so a regression fails the run instead of only being
# displayed.  The bound is relative and generous: only floating-point
# round-off is expected here.
assert np.abs(kl_plain_gap) <= 1e-6 * np.abs(kl), kl_plain_gap

kl_plain_gap

DeviceArray(0., dtype=float64)

### Timing

In [18]:
# Time five evaluations of the preconditioned objective.
# block_until_ready() waits for the asynchronously dispatched JAX computation
# to finish, so each printed figure covers the actual work.  perf_counter()
# is used because it is a monotonic, high-resolution clock meant for
# measuring short intervals.
for i in range(5):
    t0 = time.perf_counter()
    _ = objective_precond.f_precond(x_c, precond_params).block_until_ready()
    print(time.perf_counter() - t0)

0.00655364990234375
0.006070375442504883
0.006063222885131836
0.006056785583496094
0.006375789642333984


In [19]:
# Time five evaluations of the non-preconditioned objective.
# block_until_ready() waits for the asynchronously dispatched JAX computation
# to finish, so each printed figure covers the actual work.  perf_counter()
# is used because it is a monotonic, high-resolution clock meant for
# measuring short intervals.
for i in range(5):
    t0 = time.perf_counter()
    _ = objective.f_precond(x, None).block_until_ready()
    print(time.perf_counter() - t0)

0.0027022361755371094
0.0040569305419921875
0.0040433406829833984
0.004189252853393555
0.0027561187744140625


In [20]:
# Time five gradient evaluations of the preconditioned objective.
# block_until_ready() waits for the asynchronously dispatched JAX computation
# to finish, so each printed figure covers the actual work.  perf_counter()
# is used because it is a monotonic, high-resolution clock meant for
# measuring short intervals.
for i in range(5):
    t0 = time.perf_counter()
    _ = objective_precond.grad(x_c, precond_params).block_until_ready()
    print(time.perf_counter() - t0)

0.015915393829345703
0.01566314697265625
0.015961408615112305
0.016867637634277344
0.01677107810974121


In [21]:
# Time five gradient evaluations of the non-preconditioned objective.
# block_until_ready() waits for the asynchronously dispatched JAX computation
# to finish, so each printed figure covers the actual work.  perf_counter()
# is used because it is a monotonic, high-resolution clock meant for
# measuring short intervals.
# NOTE(review): the first iteration can be far slower than the rest
# (presumably one-time compilation) — read the later timings as the
# steady-state cost.
for i in range(5):
    t0 = time.perf_counter()
    _ = objective.grad(x, None).block_until_ready()
    print(time.perf_counter() - t0)

9.329540014266968
0.008954286575317383
0.008554220199584961
0.008484363555908203
0.008470535278320312


In [24]:
# Gradient of the preconditioned objective at the preconditioned point ...
grad1_precond = objective_precond.grad(x_c, precond_params)
# ... passed through `precondition` again so that it can be compared with the
# gradient of the non-preconditioned objective.
grad1 = objective_precond.precondition(grad1_precond, precond_params)

In [25]:
# Reference gradient from the non-preconditioned objective at the free
# parameters (no preconditioner parameters are supplied).
grad2 = objective.grad(x, None)

In [27]:
# Largest absolute discrepancy between the two gradients.
grad_gap = np.abs(grad1 - grad2).max()

# Enforce the check so a regression fails the run instead of only being
# displayed.  The bound scales with the gradient magnitude and is generous:
# only floating-point round-off is expected here.
assert grad_gap <= 1e-6 * (1.0 + np.abs(grad2).max()), grad_gap

grad_gap

DeviceArray(1.42108547e-14, dtype=float64)