In [1]:
import jax

import jax.numpy as np

from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib

import paragami

import time



In [2]:
import numpy as onp

# Seed NumPy's global legacy RNG so every NumPy-based random draw below is
# reproducible (presumably this includes data_utils.draw_data and the random
# VB initialization -- TODO confirm in vb_lib). Note that JAX's own PRNG
# (jax.random) uses explicit keys and is NOT affected by this seed.
onp.random.seed(seed=53453)

# Simulate data

In [3]:
# Simulation settings
n_obs = 1107   # number of individuals
n_loci = 2810  # number of loci
n_pop = 3      # number of true populations

# Only element 0 of draw_data's result is kept (presumably the observed
# genotype data; the remaining elements likely hold the true simulation
# parameters -- TODO confirm in data_utils.draw_data).
g_obs = data_utils.draw_data(
    n_obs, n_loci, n_pop
)[0]

In [4]:
# Inspect the dimensions of the simulated data and check the leading axes
# against the simulation settings above. The trailing axis is 3 in the
# recorded output; since n_pop is also 3, its meaning (populations vs. a
# per-locus genotype encoding) is ambiguous from here -- TODO confirm in
# data_utils.draw_data.
assert g_obs.shape[:2] == (n_obs, n_loci), g_obs.shape
g_obs.shape

(1107, 2810, 3)

# Get prior

In [5]:
# Default prior hyperparameters (as a folded dict) plus the paragami pattern
# used to flatten them into a vector.
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()

print(prior_params_dict)

# Unconstrained ("free") flat representation of the prior parameters.
prior_params_free = prior_params_paragami.flatten(prior_params_dict,
                                                  free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# number of components in vb approximation
k_approx = 20  # number of components in the VB approximation (stick-breaking truncation level)

In [7]:
# Gauss-Hermite quadrature: hermgauss(deg) returns `deg` nodes and weights for
# integrals of the form \int f(x) exp(-x^2) dx. They are passed to
# StructureObjective below (presumably to approximate expectations under the
# logit-normal sticks -- TODO confirm in structure_optimization_lib).
gh_deg = 8
# Use gh_deg rather than a literal so changing the degree actually takes effect.
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# define vb model, and paragami
# Choose the logit-normal variational family for the sticks.
use_logitnormal_sticks = True

# Randomly initialized VB parameters, together with the paragami pattern that
# records their shapes and constraints.
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 20, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 19) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 19) (lb=0.0, ub=inf)


In [9]:
# Flatten the folded VB parameter dict into its unconstrained ("free") vector;
# this vector is the point at which f / grad / hvp are benchmarked below.
vb_params_free = vb_params_paragami.flatten(
    vb_params_dict,
    free=True,
)

# Define objective

In [10]:
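# StructureObjective (vb_lib.structure_optimization_lib) is built in the next
# cell from the data, the VB parameter pattern, the prior parameters and the
# Gauss-Hermite quadrature. It exposes f, grad and hvp, which the benchmarks
# below evaluate at the free (flattened) VB parameters.
# TODO: document exactly which quantity f computes (presumably the KL /
# negative ELBO) by checking structure_optimization_lib.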

In [11]:
# Build the objective from the data, the VB parameter pattern, the prior
# parameters and the Gauss-Hermite nodes/weights. Construction compiles the
# objective and is slow (see the recorded output below).
stru_objective = s_optim_lib.StructureObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
)

compiling objective ... 
done. Elasped: 78.494


# Benchmark

### Objective evaluation time

In [12]:
# Time 5 evaluations of the objective at the free VB parameters.
# - block_until_ready() waits for JAX's asynchronous dispatch to finish, so
#   the interval covers the actual computation.
# - time.perf_counter() is monotonic and high-resolution; time.time() is a
#   wall clock that can jump with system clock adjustments.
# - The result is discarded rather than bound to `_`, which would shadow
#   IPython's output-history variable.
for _rep in range(5):
    t0 = time.perf_counter()
    stru_objective.f(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 5.406sec
elapsed: 5.407sec
elapsed: 5.424sec
elapsed: 5.403sec
elapsed: 5.406sec


### Gradient time

In [13]:
# Time 5 evaluations of the objective's gradient at the free VB parameters.
# - block_until_ready() waits for JAX's asynchronous dispatch to finish, so
#   the interval covers the actual computation.
# - time.perf_counter() is monotonic and high-resolution; time.time() is a
#   wall clock that can jump with system clock adjustments.
# - The result is discarded rather than bound to `_`, which would shadow
#   IPython's output-history variable.
for _rep in range(5):
    t0 = time.perf_counter()
    stru_objective.grad(vb_params_free).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 6.887sec
elapsed: 6.933sec
elapsed: 6.867sec
elapsed: 6.889sec
elapsed: 6.861sec


### Custom Hessian-vector product (`StructureObjective.hvp`)

In [14]:
# Random direction for the Hessian-vector products, with one entry per free VB
# parameter. It is drawn from the seeded global NumPy RNG and converted to a
# jax.numpy array.
v = np.array(
    onp.random.randn(len(vb_params_free))
)

In [None]:
# Time 5 Hessian-vector products H(vb_params_free) @ v using the objective's
# own hvp method; the last result is kept in my_hvp.
# - block_until_ready() waits for JAX's asynchronous dispatch to finish.
# - time.perf_counter() is monotonic and high-resolution, unlike time.time().
# NOTE(review): if hvp is compiled lazily, the first repetition includes
# compilation time -- TODO confirm in structure_optimization_lib.
for _rep in range(5):
    t0 = time.perf_counter()
    my_hvp = stru_objective.hvp(vb_params_free, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.perf_counter() - t0))

elapsed: 14.228sec


# HVP using fully automatic differentiation