In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

import numpy as onp
from numpy.polynomial.hermite import hermgauss

import structure_vb_lib.structure_model_lib as structure_model_lib
import structure_vb_lib.cavi_lib as cavi_lib
import structure_vb_lib.structure_optimization_lib as s_optim_lib

import paragami

import argparse
import distutils.util

import os

import time




In [2]:
######################
# Load Data
######################
data = np.load('../../data/huang2011_sub_nobs25_nloci75.npz')
g_obs = np.array(data['g_obs'], dtype = int)

n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

print('g_obs.shape', g_obs.shape)


g_obs.shape (25, 75, 3)


In [3]:
######################
# GET PRIOR
######################
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

print('prior params: ')
print(prior_params_dict)


prior params: 
{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


In [4]:
######################
# GET VB PARAMS AND INITIALIZE
######################
k_approx = 20
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

init_optim_time = time.time() 

cavi_init_time = 0.

_, vb_params_paragami = \
    structure_model_lib.\
        get_vb_params_paragami_object(n_obs, 
                                      n_loci,
                                      k_approx,
                                      use_logitnormal_sticks = True, 
                                      seed = 3453)
vb_params_dict, cavi_init_time = \
    s_optim_lib.initialize_with_cavi(g_obs, 
                         vb_params_paragami, 
                         prior_params_dict, 
                         gh_loc, gh_weights, 
                         print_every = 20, 
                         max_iter = 200, 
                         seed = 1232)


Compiling cavi functions ...
CAVI compile time: 5.32sec

 running CAVI ...
iteration [20]; kl:2819.874336; elapsed: 0.2569secs
iteration [40]; kl:2802.364975; elapsed: 0.2358secs
iteration [60]; kl:2797.128753; elapsed: 0.1448secs
iteration [80]; kl:2794.480526; elapsed: 0.1442secs
iteration [100]; kl:2792.529562; elapsed: 0.1442secs
iteration [120]; kl:2790.915119; elapsed: 0.1457secs
iteration [140]; kl:2790.511899; elapsed: 0.1448secs
iteration [160]; kl:2790.431946; elapsed: 0.1446secs
iteration [180]; kl:2790.392637; elapsed: 0.1476secs
final KL: 2790.366780
Elapsed: 199 steps in 1.64 seconds
Stick conversion time: 1.913secs


In [5]:
vb_cavi_free = vb_params_paragami.flatten(vb_params_dict, free = True)

In [6]:
vb_cavi_free - np.load('./testing.npz')['vb_cavi_free']

DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float64)

In [7]:
import inspect
lines = inspect.getsource(structure_model_lib.get_e_loglik_gene_nk)
print(lines)

def get_e_loglik_gene_nk(g_obs_l, e_log_pop_freq_l, e_log_1m_pop_freq_l):

    g_obs_l0 = g_obs_l[:, 0]
    g_obs_l1 = g_obs_l[:, 1]
    g_obs_l2 = g_obs_l[:, 2]

    loglik_a = \
        np.outer(g_obs_l0, e_log_1m_pop_freq_l) + \
            np.outer(g_obs_l1 + g_obs_l2, e_log_pop_freq_l)

    loglik_b = \
        np.outer(g_obs_l0 + g_obs_l1, e_log_1m_pop_freq_l) + \
            np.outer(g_obs_l2, e_log_pop_freq_l)


    return np.stack((loglik_a, loglik_b), axis = -1)



In [8]:
vb_opt_dict, vb_opt, out, precond_objective, lbfgs_time = \
    s_optim_lib.run_preconditioned_lbfgs(g_obs, 
                        vb_params_dict, 
                        vb_params_paragami,
                        prior_params_dict,
                        gh_loc, gh_weights)


compiling preconditioned objective ... 
done. Elasped: 33.6212
init kl: 2801.014344
iteration [20]; kl:2799.961902; elapsed: 1.075secs
iteration [37]; kl:2799.841273; elapsed: 1.025secs
lbfgs converged successfully
done. Elapsed 2.6482


In [9]:
np.abs(vb_opt - np.load('./testing.npz')['vb_opt']).max()

DeviceArray(4.85051999e-11, dtype=float64)

In [10]:
import bnpmodeling_runjingdev.cluster_quantities_lib as cluster_quantities_lib
cluster_quantities_lib._cumprod_through_log

<function bnpmodeling_runjingdev.cluster_quantities_lib._cumprod_through_log(x, axis=None)>

In [None]:
import inspect
lines = inspect.getsource(cluster_quantities_lib._cumprod_through_log)
print(lines)