In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp
from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils, cavi_lib

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  




In [2]:
import numpy as onp

# Fix NumPy's global (legacy) RNG so a Restart-&-Run-All reproduces the same
# numbers (assumes the simulation / random initialisation below draw from this
# global RNG — confirm in data_utils and paragami).
rng_seed = 53453
onp.random.seed(rng_seed)

# Draw data

In [3]:
# Problem size: number of observations, number of loci, and number of true
# populations used to simulate the data.
n_obs, n_loci, n_pop = 40, 50, 4

In [4]:
# Draw a synthetic dataset. Judging by the names, the second and third outputs
# are the ground-truth allele frequencies and admixture proportions used to
# generate g_obs (confirm in data_utils.draw_data).
g_obs, true_pop_allele_freq, true_ind_admix_propn = data_utils.draw_data(
    n_obs, n_loci, n_pop)

Generating datapoints  0  to  40


In [5]:
# Sanity-check the simulated genotypes: one row per observation, one column per
# locus; the trailing axis of size 3 presumably one-hot encodes the genotype
# (0/1/2 copies) — confirm in data_utils.draw_data.
assert g_obs.shape == (n_obs, n_loci, 3), g_obs.shape
g_obs.shape

(40, 50, 3)

# Get prior

In [6]:
prior_params_dict, prior_params_paragami = structure_model_lib.get_default_prior_params()
print(prior_params_dict)

# Flatten the prior parameters into paragami's free (unconstrained) vector.
prior_params_free = prior_params_paragami.flatten(prior_params_dict, free=True)

{'dp_prior_alpha': DeviceArray([3.], dtype=float64), 'allele_prior_alpha': DeviceArray([1.], dtype=float64), 'allele_prior_beta': DeviceArray([1.], dtype=float64)}


# Get VB params 

In [7]:
# Number of components in the variational approximation (presumably the
# stick-breaking truncation level — confirm in structure_model_lib).
k_approx: int = 8

In [8]:
# Gauss-Hermite quadrature of degree `gh_deg`. The degree is passed to
# hermgauss explicitly so that editing `gh_deg` actually changes the nodes and
# weights instead of leaving a stale hard-coded degree.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [9]:
use_logitnormal_sticks = False

# Only the paragami pattern is kept; the first return value is discarded and
# the VB parameters are instead drawn at random from the pattern.
_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks)
vb_params_dict = vb_params_paragami.random()

In [10]:
# Flatten the VB parameters into paragami's free (unconstrained) vector; its
# length is the size of the basis vectors used in the checks below.
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free=True)

# Check preconditioners

In [11]:
import numpy as onp
import vb_lib.preconditioner_lib as preconditioner_lib
import vb_lib.preconditioner_lib_autograd as preconditioner_lib_autograd

In [12]:
# autograd results: the reference covariance (return_info=False) and info
# matrix (return_info=True). Both come from the same function and differ only
# in the return_info flag, evaluated in that order.
cov_ag, info_ag = (
    preconditioner_lib_autograd.get_mfvb_cov(
        vb_params_dict, vb_params_paragami, use_logitnormal_sticks,
        return_info=want_info)
    for want_info in (False, True))

In [13]:
# jax results: jit-compiled matrix-vector products with the covariance
# (return_info=False) and the info matrix (return_info=True). The VB params
# and pattern are read from the notebook namespace when the functions are traced.
def _mfvb_cov_times(v):
    return preconditioner_lib.get_mfvb_cov_matmul(
        v, vb_params_dict, vb_params_paragami, return_info=False)


def _mfvb_info_times(v):
    return preconditioner_lib.get_mfvb_cov_matmul(
        v, vb_params_dict, vb_params_paragami, return_info=True)


cov_jax = jax.jit(_mfvb_cov_times)
info_jax = jax.jit(_mfvb_info_times)

# Check covariances / Infos

In [15]:
# check covariances: multiplying by the i-th standard basis vector extracts
# column i, so agreement on every basis vector means the jax matmul reproduces
# the whole autograd covariance. A failure reports which column broke and by
# how much.
n_free = len(vb_params_free)
atol = 1e-8
for i, e_i in enumerate(onp.eye(n_free)):
    out_jax = cov_jax(e_i)
    out_ag = cov_ag.dot(e_i)
    max_err = float(np.abs(out_ag - out_jax).max())
    assert max_err < atol, (
        f'covariance column {i}: max abs diff {max_err:.3e} '
        f'(tolerance {atol:.0e})')

In [16]:
# check infos: compare the jax info-matrix matmul with the autograd info matrix
# column by column (one standard basis vector at a time). A failure reports
# which column broke and by how much.
n_free = len(vb_params_free)
atol = 1e-8
for i, e_i in enumerate(onp.eye(n_free)):
    out_jax = info_jax(e_i)
    out_ag = info_ag.dot(e_i)
    max_err = float(np.abs(out_ag - out_jax).max())
    assert max_err < atol, (
        f'info column {i}: max abs diff {max_err:.3e} '
        f'(tolerance {atol:.0e})')

# Check square roots

In [17]:
from paragami.optimization_lib import _get_sym_matrix_inv_sqrt_funcs

In [18]:
# Reference square-root multipliers built from the densified autograd matrices.
# NOTE(review): element [0] of the helper's returned pair is used for both;
# confirm it is the multiplier matching `return_sqrt=True` in the jax code.
cov_sqrt_ag, info_sqrt_ag = (
    _get_sym_matrix_inv_sqrt_funcs(np.array(ag_mat.todense()))[0]
    for ag_mat in (cov_ag, info_ag))

In [19]:
# jax results: jit-compiled products with the square roots of the covariance
# (return_info=False) and of the info matrix (return_info=True).
def _mfvb_cov_sqrt_times(v):
    return preconditioner_lib.get_mfvb_cov_matmul(
        v, vb_params_dict, vb_params_paragami,
        return_info=False, return_sqrt=True)


def _mfvb_info_sqrt_times(v):
    return preconditioner_lib.get_mfvb_cov_matmul(
        v, vb_params_dict, vb_params_paragami,
        return_info=True, return_sqrt=True)


cov_sqrt_jax = jax.jit(_mfvb_cov_sqrt_times)
info_sqrt_jax = jax.jit(_mfvb_info_sqrt_times)

In [20]:
# check covariance square roots: compare the autograd-based square-root
# multiplier with the jax one on every standard basis vector (i.e. column by
# column). A failure reports which column broke and by how much.
n_free = len(vb_params_free)
atol = 1e-8
for i, e_i in enumerate(onp.eye(n_free)):
    out_ag = cov_sqrt_ag(e_i)
    out_jax = cov_sqrt_jax(e_i)
    max_err = float(np.abs(out_jax - out_ag).max())
    assert max_err < atol, (
        f'covariance sqrt column {i}: max abs diff {max_err:.3e} '
        f'(tolerance {atol:.0e})')


In [21]:
# check info square roots: compare the autograd-based square-root multiplier
# with the jax one on every standard basis vector (i.e. column by column).
# A failure reports which column broke and by how much.
n_free = len(vb_params_free)
atol = 1e-8
for i, e_i in enumerate(onp.eye(n_free)):
    out_ag = info_sqrt_ag(e_i)
    out_jax = info_sqrt_jax(e_i)
    max_err = float(np.abs(out_jax - out_ag).max())
    assert max_err < atol, (
        f'info sqrt column {i}: max abs diff {max_err:.3e} '
        f'(tolerance {atol:.0e})')