In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from structure_vb_lib import structure_model_lib, data_utils, cavi_lib
import structure_vb_lib.structure_optimization_lib as s_optim_lib
from structure_vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
# Reference quantities presumably saved from the earlier autograd
# implementation; the keys read below are g_obs, vb_params_free,
# vb_params_flattened, kl, kl_grad and kl_hess.
# Load with plain NumPy (not jax.numpy) and materialise every array into a
# dict, so the .npz file handle is closed immediately instead of leaking.
import numpy as onp

with onp.load('./tmp.npz') as npz_file:
    autograd_results = dict(npz_file)

In [3]:
# Plain NumPy, kept separate from `np` (which is jax.numpy in this notebook).
import numpy as onp

# Seed NumPy's legacy global RNG so any random draws below are reproducible.
SEED = 53453
onp.random.seed(SEED)

# Load data

In [4]:
# Observed data from the saved results (presumably genotypes), as a JAX array.
g_obs = np.array(autograd_results['g_obs'])

# Leading two dimensions give the number of observations and of loci.
n_obs, n_loci = g_obs.shape[0], g_obs.shape[1]

# Get prior

In [5]:
# Default prior hyperparameters (a dict; its keys are shown in the output
# below) together with the paragami pattern describing them.
prior_params_dict, prior_params_paragami = (
    structure_model_lib.get_default_prior_params()
)

print(prior_params_dict)

# Flattened, unconstrained ("free") representation of the prior parameters.
prior_params_free = prior_params_paragami.flatten(
    prior_params_dict, free=True
)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
# Truncation level of the variational approximation; the stick parameters
# printed below have k_approx - 1 columns.
k_approx: int = 8

In [7]:
# Number of Gauss-Hermite quadrature points. gh_loc / gh_weights are passed to
# the objective below (presumably to approximate expectations under the
# logit-normal sticks). Use gh_deg rather than a repeated literal so changing
# the degree here actually takes effect.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
use_logitnormal_sticks = True

# Only the paragami pattern is kept; the first return value is discarded
# because the actual VB parameter values are loaded from the saved results
# further down.
_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs,
    n_loci,
    k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (10, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (5, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (5, 7) (lb=0.0, ub=inf)


In [9]:
# Free (unconstrained) VB parameter vector from the saved results, folded into
# a dictionary with the JAX-side paragami pattern.
vb_params_free = autograd_results['vb_params_free']
vb_params_dict = vb_params_paragami.fold(vb_params_free, free = True)

# Round-trip check: the constrained flattening must match the one saved in
# the results file. Use a tight tolerance instead of exact float equality,
# since the free <-> constrained transforms may differ from the other
# implementation in the last few ulps; assert_allclose also reports where
# any mismatch occurs.
onp.testing.assert_allclose(
    vb_params_paragami.flatten(vb_params_dict, free = False),
    autograd_results['vb_params_flattened'],
    rtol = 1e-12, atol = 1e-12)

# Define objective

In [10]:
# Build the KL objective whose value, gradient and HVPs are checked below.
# Construction compiles the objective, which is slow (see the timing printed
# underneath this cell).
stru_objective = s_optim_lib.StructureObjective(
    g_obs,
    vb_params_paragami,
    prior_params_dict,
    gh_loc,
    gh_weights,
)

compiling objective ... 
done. Elasped: 45.0574


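# Check the KL objective, its gradient, and Hessian-vector products against autograd

In this run the Hessian-vector-product check fails (max abs error ≈ 0.397), while the KL value and gradient checks pass.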

In [11]:
# The KL objective at the loaded VB parameters must match the saved value.
# Include the error in the message so a failure is diagnosable in place.
kl_err = np.abs(stru_objective.f(vb_params_free) - autograd_results['kl'])
assert kl_err < 1e-8, f'KL mismatch vs autograd: abs error {kl_err}'

In [12]:
# The KL gradient must match the saved gradient elementwise (max abs error).
# Include the error in the message so a failure is diagnosable in place.
kl_grad_err = np.abs(stru_objective.grad(vb_params_free) -
                     autograd_results['kl_grad']).max()
assert kl_grad_err < 1e-8, \
    f'KL gradient mismatch vs autograd: max abs error {kl_grad_err}'

In [13]:
# Compare Hessian-vector products from the JAX objective against the dense
# saved Hessian, one standard basis vector e_i at a time: each HVP should
# reproduce column i of kl_hess.
kl_hess = autograd_results['kl_hess']
hess_dim = kl_hess.shape[0]
hvp_tol = 1e-8

for i in range(hess_dim):

    e_i = onp.zeros(hess_dim)
    e_i[i] = 1

    hvp1 = stru_objective.hvp(vb_params_free, e_i)
    hvp2 = np.dot(kl_hess, e_i)

    # Stop at the first mismatching column (hvp1 / hvp2 then hold that
    # column's values) and say which column failed and by how much, instead
    # of raising a bare AssertionError.
    hvp_err = np.abs(hvp1 - hvp2).max()
    assert hvp_err < hvp_tol, \
        f'HVP mismatch vs autograd Hessian at column {i}: max abs error {hvp_err}'

AssertionError: 

In [14]:
# Size of the discrepancy at the column where the Hessian check above stopped
# (hvp1 / hvp2 are the values left over from that loop).
np.max(np.abs(hvp1 - hvp2))

DeviceArray(0.39678664, dtype=float64)

In [None]:
import inspect

# Print the source of the objective's private `_kl_z2` method, presumably to
# investigate the Hessian-vector-product mismatch above.
kl_z2_method = stru_objective._kl_z2
lines = inspect.getsource(kl_z2_method)
print(lines)