In [1]:
import jax

import jax.numpy as np

from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib

import paragami

import time



In [2]:
import numpy as onp

# Fix NumPy's global RNG state so every later onp.random draw in this notebook
# is reproducible across kernel restarts.
SEED = 53453
onp.random.seed(SEED)

# Load data

In [3]:
# Problem dimensions: these match the size of the human genome data.
n_obs = 1107
n_loci = 2810

# Much smaller problem, handy for quick checks:
# n_obs = 20
# n_loci = 50

# Simulate a data set of that size with three populations; draw_data returns
# the observations together with the quantities they were generated from.
simulated = data_utils.draw_data(n_obs, n_loci, n_pop = 3)
g_obs, true_pop_allele_freq, true_ind_admix_propn = simulated

In [4]:
# Display the shape of the simulated observations; the first two axes should
# equal (n_obs, n_loci).
g_obs.shape

(1107, 2810, 3)

# Get prior

In [5]:
# Library defaults for the prior: a dict of values plus the paragami object
# used to flatten that dict.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)

print(prior_params_dict)

# The same prior parameters as a single flat vector in the free parameterization.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# Size parameter of the variational approximation; it is passed to
# get_vb_params_paragami_object below.
# NOTE(review): presumably the truncation level (maximum number of populations)
# of the stick-breaking approximation -- confirm in structure_model_lib.
k_approx = 20

In [7]:
# Gauss-Hermite quadrature rule: `gh_deg` sample points and their weights.
# The degree is taken from `gh_deg` (rather than a repeated literal) so that
# changing it here actually changes the rule handed to the objective and to
# get_kl below.
# NOTE(review): presumably used to integrate over the logit-normal sticks --
# confirm in structure_model_lib.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# Flag forwarded to the library when building the variational parameters.
use_logitnormal_sticks = True

# Build the variational parameters: a dict of values and the paragami object
# describing their shapes and bounds.
vb_init = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks = use_logitnormal_sticks)
vb_params_dict, vb_params_paragami = vb_init

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 20, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 19) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 19) (lb=0.0, ub=inf)


In [9]:
# Flatten the variational parameter dict into a single vector in the free
# parameterization; this vector is the point at which the objective, gradient
# and Hessian-vector products below are evaluated.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free = True)

# Define objective

In [10]:
# Build the objective wrapper from the data, the variational-parameter pattern,
# the prior and the quadrature rule. Construction is expensive: the printed log
# shows the objective being compiled at this point.
stru_objective = s_optim_lib.StructureObjective(
    g_obs, vb_params_paragami, prior_params_dict, gh_loc, gh_weights)

compiling objective ... 
done. Elasped: 78.7653


# Benchmark

### Function time

In [11]:
# Time one evaluation of the objective at vb_params_free.
# perf_counter() is used because it is a monotonic, high-resolution clock meant
# for measuring intervals; block_until_ready() waits for the result so that the
# computation itself is timed, not just its dispatch.
for i in range(1):
    start = time.perf_counter()
    _ = stru_objective.f(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - start))

elapsed: 5.417sec


### Gradient time

In [12]:
# Time one evaluation of the gradient at vb_params_free.
# perf_counter() is used because it is a monotonic, high-resolution clock meant
# for measuring intervals; block_until_ready() waits for the result so that the
# computation itself is timed, not just its dispatch.
for i in range(1):
    start = time.perf_counter()
    _ = stru_objective.grad(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - start))

elapsed: 6.899sec


### My custom Hessian-vector product

In [13]:
# Random direction for the Hessian-vector products: standard-normal draws from
# NumPy's global RNG, one per entry of vb_params_free, converted to a JAX array.
v = np.array(onp.random.randn(len(vb_params_free)))

In [14]:
# Time one custom Hessian-vector product in direction v. The result is kept in
# `my_hvp` because it is compared against the autodiff HVP further down.
# perf_counter() is used because it is a monotonic, high-resolution clock meant
# for measuring intervals; block_until_ready() waits for the result so that the
# computation itself is timed, not just its dispatch.
for i in range(1):
    start = time.perf_counter()
    my_hvp = stru_objective.hvp(vb_params_free, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - start))

elapsed: 14.265sec


# HVP using fully automatic differentiation

In [15]:
def get_kl(x):
    """Evaluate the KL objective at a flat, free-parameterized vector.

    `x` is folded back into the structured parameter dict with
    `vb_params_paragami`, and that dict is passed to
    `structure_model_lib.get_kl` together with the notebook-level data, prior
    and Gauss-Hermite rule.

    Note that `detach_ez` is explicitly False here.
    """
    folded_params = vb_params_paragami.fold(x, free = True)

    return structure_model_lib.get_kl(
        g_obs,
        folded_params,
        prior_params_dict,
        gh_loc,
        gh_weights,
        detach_ez = False)

def _get_kl_hvp(v):
    """Hessian-vector product of `get_kl` at `vb_params_free` in direction `v`.

    Computed purely by automatic differentiation: a forward-mode JVP pushed
    through the reverse-mode gradient. `jax.jvp` returns
    `(primals_out, tangents_out)`; the tangent output is the Hessian applied
    to `v`.
    """
    return jax.jvp(jax.grad(get_kl), (vb_params_free, ), (v, ))[1]


# Jitted version; compilation happens on the first call.
get_kl_hvp = jax.jit(_get_kl_hvp)

### Note: the cells below are very slow on full-sized data

In [16]:
# compile
true_hvp = get_kl_hvp(v).block_until_ready()
np.abs(true_hvp - my_hvp).max()

DeviceArray(8.18545232e-12, dtype=float64)

In [17]:
# Time one autodiff Hessian-vector product; it was already called once above,
# so compilation is not part of this measurement.
# The result is discarded into `_` rather than stored in `my_hvp`: `my_hvp`
# holds the custom HVP, and overwriting it would make a re-run of the
# comparison cell compare the autodiff HVP with itself.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; block_until_ready() waits for the result so that the computation
# itself is timed, not just its dispatch.
for i in range(1):
    start = time.perf_counter()
    _ = get_kl_hvp(v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - start))

elapsed: 485.335sec
