In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib, structure_optimization_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
# `np` is bound to jax.numpy in this notebook, so plain NumPy is imported as `onp`.
import numpy as onp
# Seed NumPy's global RNG so repeated runs start from the same random state.
# NOTE(review): presumably this is what makes the simulated data below
# repeatable — confirm that data_utils.draw_data draws from NumPy's global RNG.
onp.random.seed(53453)

# Draw data

In [3]:
# Dimensions of the simulated data set; all three values are passed to
# data_utils.draw_data when the data is drawn.
n_obs, n_loci, n_pop = 200, 500, 4

In [4]:
# Simulate the observations. draw_data also returns two `true_*` arrays
# (by their names, the generating allele frequencies and admixture
# proportions — confirm against data_utils).
(g_obs,
 true_pop_allele_freq,
 true_ind_admix_propn) = data_utils.draw_data(n_obs, n_loci, n_pop)

Generating datapoints  0  to  200


In [5]:
# Convert the simulated data to a JAX array (`np` is jax.numpy here).
# The name g_obs is rebound to the converted array.
g_obs = np.array(g_obs)

In [6]:
# Scalar summary of the data: average over axis 1, square the result,
# then average over all remaining entries.
np.mean(g_obs.mean(axis=1) ** 2)

DeviceArray(0.1143092, dtype=float64)

In [7]:
# Array dimensions; the first two axes correspond to n_obs and n_loci.
g_obs.shape

(200, 500, 3)

# Get prior

In [8]:
# Default prior hyper-parameters, returned together with the paragami
# object used to flatten them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)

print(prior_params_dict)

# Flat-vector form of the prior parameters in the "free" parameterization.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [9]:
# Approximation size handed to both the VB-parameter constructor and the
# VB-parameter initializer later in the notebook.
k_approx = 15

In [10]:
# Degree (number of nodes) of the Gauss-Hermite quadrature rule.
gh_deg = 8
# Quadrature nodes and weights. They are built from gh_deg so that changing the
# degree above actually takes effect; previously the literal 8 was repeated
# here and gh_deg was never used.
gh_loc, gh_weights = hermgauss(gh_deg)

In [11]:
# Flag forwarded to the VB-parameter constructor; with it switched off the
# printed pattern contains `ind_mix_stick_beta_params`.
use_logitnormal_sticks = False

# Template VB parameters plus the paragami pattern describing their layout.
vb_params_dict, vb_params_paragami = (
    structure_model_lib.get_vb_params_paragami_object(
        n_obs,
        n_loci,
        k_approx,
        use_logitnormal_sticks=use_logitnormal_sticks,
    )
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (500, 15, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_beta_params] = NumericArrayPattern (200, 14, 2) (lb=0.0, ub=inf)


## Initialize 

In [12]:
# Overwrite the template VB parameters with initial values; the data, the
# approximation size and a fixed seed are forwarded to the initializer.
# NOTE(review): vb_params_dict is both the input and the output here, so
# re-running this cell passes the already-initialized parameters back in.
vb_params_dict = structure_model_lib.set_init_vb_params(
    g_obs, k_approx, vb_params_dict, seed=143241
)

In [13]:
# KL objective evaluated at the initial VB parameters, using the
# Gauss-Hermite nodes and weights defined earlier.
structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights
)

DeviceArray(482925.84029194, dtype=float64)

# Define objective

In [14]:
# Build the optimization objective and the flat parameter vector it is
# evaluated at. compile_hvp=True requests the Hessian-vector product as well
# (the cell output reports "Compiling hvp ...").
optim_objective, vb_params_free = (
    structure_optimization_lib.define_structure_objective(
        g_obs,
        vb_params_dict,
        vb_params_paragami,
        prior_params_dict,
        compile_hvp=True,
    )
)

Compiling objective ...
Iter 0: f = 482925.84029194
Compiling grad ...
Compiling hvp ...
Compile time: 46.8943secs


In [15]:
# NOTE(review): presumably throttles the objective's "Iter ...: f = ..."
# progress line to one print per 1000 evaluations — confirm against
# structure_optimization_lib.
optim_objective.set_print_every(1000)

In [16]:
# Time ten evaluations of the compiled objective at the initial parameters.
# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() is wall-clock time and may jump if the
# system clock is adjusted.
for i in range(10):
    t0 = time.perf_counter()
    kl = optim_objective.f_np(vb_params_free)
    print(time.perf_counter() - t0)

Iter 0: f = 482925.84029194
0.2911083698272705
0.2938210964202881
0.28874850273132324
0.29052257537841797
0.28848767280578613
0.28859543800354004
0.28805994987487793
0.29024744033813477
0.28844332695007324
0.29029107093811035


In [17]:
# Time ten evaluations of the compiled gradient at the initial parameters.
# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() is wall-clock time and may jump if the
# system clock is adjusted.
for i in range(10):
    t0 = time.perf_counter()
    kl_grad = optim_objective.grad_np(vb_params_free)
    print(time.perf_counter() - t0)

0.41802096366882324
0.4140028953552246
0.4097018241882324
0.41280031204223633
0.4125094413757324
0.4142568111419678
0.41301798820495605
0.41332364082336426
0.41864633560180664
0.4147317409515381


In [18]:
# Time ten evaluations of the compiled Hessian-vector product. The initial
# parameter vector is passed as both arguments, which is enough for timing.
# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() is wall-clock time and may jump if the
# system clock is adjusted.
for i in range(10):
    t0 = time.perf_counter()
    kl_hvp = optim_objective.hvp_np(vb_params_free, vb_params_free)
    print(time.perf_counter() - t0)

0.6436164379119873
0.6356630325317383
0.6405189037322998
0.6439821720123291
0.6356256008148193
0.6437315940856934
0.6406965255737305
0.6421606540679932
0.6400279998779297
0.6372144222259521


In [20]:
# NOTE(review): this import belongs in the notebook's first (imports) cell, and
# the execution count jumps from 18 to 20 — re-run with Restart & Run All.
import inspect
# Print the library's current implementation of get_e_loglik for reference.
# `lines` holds the whole source as a single string.
lines = inspect.getsource(structure_model_lib.get_e_loglik)
print(lines)

def get_e_loglik(g_obs,
                    e_log_pop_freq, e_log_1m_pop_freq, \
                    e_log_sticks, e_log_1m_sticks,
                    detach_ez): 

    n_obs = g_obs.shape[0]
    n_loci = g_obs.shape[1]
    
    e_log_cluster_probs = \
        modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
                            e_log_sticks, e_log_1m_sticks)
    def body_fun(val, i): 
        n = i % n_obs 
        l = i // n_obs
        return get_e_loglik_nl(g_obs[n, l], e_log_pop_freq[l], e_log_1m_pop_freq[l],
                        e_log_cluster_probs[n], detach_ez) + val

    scan_fun = lambda val, x : (body_fun(val, x), None)
    
    init_val = np.array([0., 0.])
    out = jax.lax.scan(scan_fun, init_val,
                        xs = np.arange(n_obs * n_loci))[0]

    e_loglik = out[0]
    z_entropy = out[1]
    
    return e_loglik, z_entropy 

