In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils, cavi_lib
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami
import vittles

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.sensitivity_lib import get_jac_hvp_fun



In [2]:
import numpy as onp
onp.random.seed(53453)

# Load data

In [3]:
# data_file = '../../../../../fastStructure/hgdp_data/huang2011_plink_files/' + \
#                 'phased_HGDP+India+Africa_2810SNPs-regions1to36.npz'
data_file = '../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz'

data = np.load(data_file)
g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

print(n_obs)
print(n_loci)

20
50


# Get prior

In [5]:
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

print(prior_params_dict)

prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [6]:
k_approx = 8

In [7]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(8)

In [8]:
use_logitnormal_sticks = True

vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx,
                                    use_logitnormal_sticks = use_logitnormal_sticks)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


## Initialize 

In [9]:
vb_params_dict = vb_params_paragami.random()

In [10]:
structure_model_lib.get_kl(g_obs, vb_params_dict, prior_params_dict,
                            gh_loc, gh_weights)

DeviceArray(2375.583033, dtype=float64)

In [11]:
vb_free_params = vb_params_paragami.flatten(vb_params_dict, free = True)

### My custom objective
Here, we took some short-cuts in evaluating the gradient and we fiddled with the HVP

In [63]:
stru_objective = s_optim_lib.StructureObjective(g_obs, 
                                                 vb_params_paragami,
                                                 prior_params_dict, 
                                                 gh_loc, gh_weights, 
                                                 jit_functions = False)

In [64]:
v = onp.zeros(len(vb_free_params))
v[0] = 1
v = np.array(v)

In [65]:
stru_objective._kl_zz(vb_free_params, v)

DeviceArray([ 35.6275913 , -33.67329717,   0.        , ...,   0.        ,
               0.        ,   0.        ], dtype=float64)

In [32]:
moments_tuple = stru_objective._get_moments_from_vb_free_params(vb_free_params)

In [33]:
moments_jvp = jax.jvp(stru_objective._get_moments_from_vb_free_params, \
                                      (vb_free_params, ), (v, ))[1]

In [34]:
moments_vjp = jax.vjp(stru_objective._get_moments_from_vb_free_params, 
                             vb_free_params)[1]

In [35]:
l = 0

In [55]:
def scan_fun(val, x): 
    # x[0] is g_obs[:, l]
    # x[1] is e_log_pop
    # x[2] is e_log_pop jvp
    
    fun = lambda clust_probs, pop_freq : \
            stru_objective._ps_loss_zl(x[0], clust_probs, pop_freq)
    
    jvp1 = jax.jvp(fun, 
                    (moments_tuple[0], x[1]), 
                    (moments_jvp[0], x[2]))[1]
    
    return jvp1 
    # vjp1 = jax.vjp(fun, *(moments_tuple[0], x[1]))[1](jvp1)
    
    # return vjp1

In [56]:
e_z_l.shape

(20, 8, 2)

In [61]:
moments_jvp

(DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float64),
 DeviceArray([[[ 0.7963

In [62]:
scan_fun(0, (g_obs[:, l], moments_tuple[1][l], moments_jvp[1][l]))

DeviceArray([-0.74569034, -0.74569034,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
             -0.80380262, -0.80380262,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              1.06367703, -0.80379569,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.99389352, -0.77510784,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,

In [58]:
vjp1 = jax.lax.scan(scan_fun,
             init = np.zeros(moments_tuple[0].shape), 
             xs = (g_obs.transpose((1, 0, 2)), 
                   moments_tuple[1], 
                   moments_jvp[1]))

TypeError: scan body output must be a pair, got ShapedArray(float64[320]).

In [29]:
vjp1[0].shape

(20, 8)

In [30]:
vjp1[1].shape

(50, 8, 2)

In [32]:
moments_vjp(vjp1)[0]

DeviceArray([-1.04380913e+00,  2.06411121e+00,  1.48310283e+00, ...,
             -1.96914989e-03, -1.17091790e-03, -6.72791142e-04],            dtype=float64)

In [None]:
jax.jvp(structure_model_lib._ps_loss_zl, 
        (moments_tuple[0, l], moments_tuple[]

In [30]:
moments_tuple[0]

DeviceArray([[[-0.82741892, -1.03241251],
              [-0.48572532, -1.38910329],
              [-0.91153441, -1.05561225],
              [-1.10927797, -0.68664253],
              [-0.9792916 , -0.77146759],
              [-0.45600275, -1.49331377],
              [-1.0996217 , -0.61163253],
              [-0.91436016, -0.84017733]],

             [[-0.81688144, -1.00946629],
              [-1.22457811, -0.62709289],
              [-0.98789019, -0.73206267],
              [-0.78909338, -1.13599018],
              [-0.73767236, -0.90320678],
              [-1.4744108 , -0.48331476],
              [-0.56750962, -1.26715667],
              [-0.59739746, -1.36187634]],

             [[-1.12063888, -0.61642108],
              [-0.64450045, -1.09954715],
              [-0.99022447, -0.79967659],
              [-1.0559823 , -0.62219574],
              [-0.81530391, -0.82565833],
              [-0.76211209, -1.15504168],
              [-0.78566752, -0.90779666],
              [-0.79493075, -1

In [38]:
fun = lambda x, y, z: stru_objective._ps_loss_zl(x, y, z, l)
                
jvp1 = jax.jvp(fun, moments_tuple, moments_jvp)[1]
vjp1 = jax.vjp(fun, *moments_tuple)[1](jvp1)

In [44]:
moments_jvp

(DeviceArray([[ 0.14917525,  0.33544684,  0.08425488, -0.12411104,
               -0.00131144,  0.3498302 , -0.13852111,  0.05999757],
              [ 0.15363427, -0.24086096, -0.01748675,  0.18520526,
                0.17154381, -0.51796494,  0.30535658,  0.30772497],
              [-0.15214363,  0.25459988, -0.00433194, -0.10436821,
                0.11315939,  0.20473179,  0.15241164,  0.17887636],
              [-0.17032621, -0.35407862,  0.08035337,  0.15743843,
               -0.27103981, -0.10137679, -0.4339894 , -0.31136587],
              [ 0.09297595, -0.54988979,  0.29957498,  0.29787629,
                0.19336748, -0.15372246, -0.10417238, -0.00797628],
              [ 0.10909973,  0.35621418,  0.35711571,  0.31658662,
                0.22711777,  0.29272999,  0.3403026 , -0.53762502],
              [ 0.19721481,  0.0122966 ,  0.25583434, -0.15210551,
                0.14306055, -0.18733439,  0.19087443, -0.34484987],
              [ 0.15523415,  0.08868858,  0.2757443 , -

In [42]:
vjp1[0].shape

(50, 8)

In [28]:
jvp1

DeviceArray([ 0.19888119,  0.19888119, -0.32993966, -0.32993966,
             -0.15612458, -0.15612458, -0.04668224, -0.04668224,
             -0.10698422, -0.10698422, -0.08535478, -0.08535478,
             -0.05630256, -0.05630256, -0.05839513, -0.05839513,
              0.15584981,  0.15584981, -0.13873141, -0.13873141,
             -0.08428724, -0.08428724, -0.12248975, -0.12248975,
             -0.08649573, -0.08649573, -0.07856049, -0.07856049,
             -0.04328068, -0.04328068, -0.04540428, -0.04540428,
              0.03319317,  0.14988376,  0.08959133, -0.12910482,
             -0.18117827, -0.15575321, -0.12841917, -0.03627768,
             -0.09517473, -0.0618535 , -0.06909015, -0.07837284,
             -0.05071287, -0.03812505, -0.05064085, -0.04622517,
              0.13127232,  0.17600339, -0.08387132, -0.25550081,
             -0.09305379, -0.07237535, -0.13618611, -0.0706363 ,
             -0.10596484, -0.08988662, -0.08148679, -0.08373148,
             -0.05615018,

In [26]:
vjp1

(DeviceArray([[ 0.8821253 ,  0.74594663, -0.8326863 , -0.43687188,
               -0.19017812, -0.10683612, -0.03291352, -0.028586  ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  

In [25]:
moments_vjp(vjp1)[0].shape

(1080,)

In [13]:
moments_tuple = stru_objective._get_moments_from_vb_free_params(vb_free_params)

In [25]:
moments_jvp = jax.jvp(stru_objective._get_moments_from_vb_free_params, \
                        (vb_free_params, ), (v, ))[1]

In [94]:
moments_vjp = jax.vjp(stru_objective._get_moments_from_vb_free_params, 
                     vb_free_params)[1]

In [98]:
l = 0

In [99]:
fun = lambda x, y, z: stru_objective._ps_loss_zl(x, y, z, l)

In [100]:
jvp1 = jax.jvp(fun, moments_tuple, moments_jvp)[1]
vjp1 = jax.vjp(fun, *moments_tuple)[1](jvp1)

In [102]:
moments_vjp(vjp1)[0]

(1080,)

In [89]:
jax.vjp?

In [80]:
foo.shape

(320,)

In [79]:
jax.vjp(fun, *moments_tuple)[1](foo)

(DeviceArray([[ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 1.81030992, -0.8563201 , -0.49874139, -0.23451346,
               -0.12096671, -0.0478392 , -0.02823992, -0.02368915],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  0.        ,
                0.        ,  0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ,  

In [38]:
jax.vjp(fun, moments_tuple)[1](jax.jvp(fun, moments_tuple, moments_jvp)[1])[0] + val

TypeError: <lambda>() takes 0 positional arguments but 1 was given

# Check objective derivatives

In [16]:
assert np.abs(stru_objective.f(vb_params_free) - kl_fun_free(vb_params_free)) < 1e-12

In [17]:
assert np.abs(stru_objective.grad(vb_params_free) - kl_grad).max() < 1e-12

### The HVP in particular needs testing ...

In [18]:
for i in range(len(vb_params_free)): 
    
    if (i % 50) == 0: 
        print(i)
    
    v = onp.zeros(len(vb_params_free))
    v[i] = 1.
    v = np.array(v)
    
    hvp1 = stru_objective.hvp(vb_params_free, v)
    hvp2 = np.dot(kl_hess, v)
    
    diff = np.abs(hvp1 - hvp2).max()
    assert diff < 1e-12, diff
print('done. ')

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
done. 


In [38]:
def foo(x, y): 
    return x**2, y**2

In [39]:
x = 5.
y = 6.

In [40]:
foo(x, y)

(25.0, 36.0)

In [41]:
jax.jvp(foo, (x, y), (x, y))[1]

(DeviceArray(50., dtype=float64), DeviceArray(72., dtype=float64))

In [None]:
#         vb_params_dict = self.vb_params_paragami.fold(vb_free_params, free = True)

#         # cluster probabilitites
#         e_log_sticks, e_log_1m_sticks = \
#             ef.get_e_log_logitnormal(
#                 lognorm_means = vb_params_dict['ind_admix_params']['stick_means'],
#                 lognorm_infos = vb_params_dict['ind_admix_params']['stick_infos'],
#                 gh_loc = self.gh_loc,
#                 gh_weights = self.gh_weights)

#         e_log_cluster_probs = \
#             modeling_lib.get_e_log_cluster_probabilities_from_e_log_stick(
#                                 e_log_sticks, e_log_1m_sticks)
