In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib, structure_optimization_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp

# Seed numpy's legacy global RNG so the simulated data below is reproducible.
NUMPY_SEED = 53453
onp.random.seed(NUMPY_SEED)

# Draw data

In [3]:
# Size of the simulated problem: individuals, loci, and true populations.
n_obs, n_loci, n_pop = 200, 500, 4

In [4]:
# Simulate a data set, keeping the true parameters it was drawn from.
g_obs, true_pop_allele_freq, true_ind_admix_propn = data_utils.draw_data(
    n_obs, n_loci, n_pop
)

Generating datapoints  0  to  200


In [5]:
# Convert the simulated observations to a JAX array (np is jax.numpy here).
g_obs = np.array(g_obs)

In [6]:
# Quick summary statistic: mean of the squared per-(obs, genotype) averages over loci.
np.mean(np.square(np.mean(g_obs, axis=1)))

DeviceArray(0.1143092, dtype=float64)

In [7]:
# Sanity check: expected (n_obs, n_loci, 3) given the sizes set above.
g_obs.shape

(200, 500, 3)

# Get prior

In [8]:
# Default prior hyperparameters together with their paragami pattern.
prior_params_dict, prior_params_paragami = structure_model_lib.get_default_prior_params()
print(prior_params_dict)

# Flattened prior parameters in paragami's free parameterization.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get variational (VB) parameters

In [9]:
# Number of components in the variational approximation
# (the VB pattern below shows 15 populations and 14 sticks).
k_approx = 15

In [10]:
# Gauss-Hermite quadrature nodes/weights, passed to get_kl further down.
gh_deg = 8
# Derive the rule from gh_deg instead of repeating the literal, so changing
# gh_deg actually changes the quadrature degree.
gh_loc, gh_weights = hermgauss(gh_deg)

In [11]:
use_logitnormal_sticks = False

# Build the variational parameter dict and its paragami pattern.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx, use_logitnormal_sticks=use_logitnormal_sticks
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (500, 15, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_beta_params] = NumericArrayPattern (200, 14, 2) (lb=0.0, ub=inf)


In [12]:
# Shape of the data with the first two axes swapped (loci first).
np.swapaxes(g_obs, 0, 1).shape

(500, 200, 3)

## Initialize VB parameters

In [13]:
# Data-dependent initialization of the VB parameters (fixed seed for reproducibility).
vb_params_dict = structure_model_lib.set_init_vb_params(
    g_obs, k_approx, vb_params_dict, seed=143241
)

In [14]:
# KL objective evaluated at the initial VB parameters.
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

DeviceArray(488772.73509143, dtype=float64)

# Define objective

In [15]:
# Build the optimization objective; this compiles the objective, gradient
# and HVP, which is slow (tens of seconds) on first run.
optim_objective, vb_params_free = structure_optimization_lib.define_structure_objective(
    g_obs,
    vb_params_dict,
    vb_params_paragami,
    prior_params_dict,
    compile_hvp=True,
)

Compiling objective ...
Iter 0: f = 488772.73509143
Compiling grad ...
Compiling hvp ...
Compile time: 45.2426secs


In [16]:
# Log the objective value only every 1000 evaluations so the timing loops
# below are not flooded with output.
optim_objective.set_print_every(1000)

In [17]:
# Time repeated objective evaluations. The first call carries one-off warm-up
# cost. perf_counter is the monotonic, high-resolution clock meant for short
# intervals (time.time can jump if the system clock is adjusted).
for _ in range(10):
    t0 = time.perf_counter()
    kl = optim_objective.f_np(vb_params_free)
    print(time.perf_counter() - t0)

Iter 0: f = 488772.73509143
1.594972848892212
0.28156137466430664
0.2793557643890381
0.28894925117492676
0.27843737602233887
0.27852511405944824
0.2793397903442383
0.27867984771728516
0.27886199951171875
0.2780947685241699


In [18]:
# Time repeated gradient evaluations using the monotonic high-resolution
# clock (time.perf_counter) rather than wall-clock time.time.
for _ in range(10):
    t0 = time.perf_counter()
    kl_grad = optim_objective.grad_np(vb_params_free)
    print(time.perf_counter() - t0)

0.658360481262207
0.6518728733062744
0.6499850749969482
0.6550426483154297
0.6555237770080566
0.6533718109130859
0.6517345905303955
0.6536388397216797
0.6495473384857178
0.6459848880767822


In [19]:
# Time repeated Hessian-vector products. The parameter vector itself is used
# as an arbitrary direction; only the cost matters here. Uses the monotonic
# high-resolution clock (time.perf_counter) rather than time.time.
for _ in range(10):
    t0 = time.perf_counter()
    kl_hvp = optim_objective.hvp_np(vb_params_free, vb_params_free)
    print(time.perf_counter() - t0)

1.3368937969207764
1.3207335472106934
1.319925308227539


KeyboardInterrupt: 

In [None]:
import inspect

# Debugging aid: dump the source of get_e_loglik.
# NOTE(review): exploratory leftover; consider removing from the final notebook.
lines = inspect.getsource(
    structure_model_lib.get_e_loglik
)
print(lines)
