In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp
from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils, cavi_lib
from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib




In [2]:
import numpy as onp 
# Fix NumPy's global RNG; presumably this is what makes data_utils.draw_data and
# vb_params_paragami.random() below reproducible (confirm they use it).
onp.random.seed(seed=53453)

# Draw data

In [3]:
n_obs = 50
n_loci = 1000
n_pop = 4

In [4]:
# Simulate a data set. Besides g_obs (presumably the observed genotypes),
# draw_data returns the generating parameters, kept for later comparison.
(g_obs,
 true_pop_allele_freq,
 true_ind_admix_propn) = data_utils.draw_data(n_obs, n_loci, n_pop)

# Get prior

In [5]:
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

# Pull the individual hyperparameters out of the prior dictionary; they are
# passed positionally to every CAVI update below.
dp_prior_alpha, allele_prior_alpha, allele_prior_beta = (
    prior_params_dict[name]
    for name in ('dp_prior_alpha', 'allele_prior_alpha', 'allele_prior_beta')
)

# Get variational (VB) parameters

In [6]:
# Number of components in the variational approximation, passed to
# get_vb_params_paragami_object below (presumably the stick-breaking
# truncation level -- confirm in structure_model_lib).
k_approx = 12

In [7]:
# Gauss-Hermite quadrature nodes and weights used to compute the moments below.
# Derive them from gh_deg (rather than a second literal) so that changing the
# degree here actually changes the quadrature.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [8]:
# Logit-normal sticks are switched off for this comparison.
use_logitnormal_sticks = False

_, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx,
    use_logitnormal_sticks=use_logitnormal_sticks,
)

# Random starting point for the variational parameters, drawn from the
# paragami pattern returned above.
vb_params_dict = vb_params_paragami.random()

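# CAVI updates

Compare the updates in `cavi_lib` against those in `cavi_lib_full_ez`, which takes the expectations `e_z` as an explicit input.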

In [9]:
# Moments of the current variational parameters, evaluated with the
# Gauss-Hermite nodes/weights from above; every update below consumes these.
(e_log_sticks, e_log_1m_sticks,
 e_log_pop_freq, e_log_1m_pop_freq) = structure_model_lib.get_moments_from_vb_params_dict(
    vb_params_dict, gh_loc, gh_weights)

### Check population updates

In [10]:
# Run cavi_lib.update_pop_beta and keep only its third output; it is compared
# against the cavi_lib_full_ez version further down.
_, _, beta_params = cavi_lib.update_pop_beta(
    g_obs,
    e_log_pop_freq, e_log_1m_pop_freq,
    e_log_sticks, e_log_1m_sticks,
    dp_prior_alpha, allele_prior_alpha, allele_prior_beta,
)

In [11]:
import vb_lib.cavi_lib_full_ez as cavi_lib_full_ez

In [12]:
# cavi_lib_full_ez's update functions take e_z as an explicit input, so compute
# it once from the current moments.
e_z = cavi_lib_full_ez.update_z(
    g_obs,
    e_log_sticks, e_log_1m_sticks,
    e_log_pop_freq, e_log_1m_pop_freq,
)

In [13]:
# Same population update via cavi_lib_full_ez, which receives e_z in place of
# e_log_pop_freq / e_log_1m_pop_freq.
_, _, beta_params2 = cavi_lib_full_ez.update_pop_beta(
    g_obs, e_z,
    e_log_sticks, e_log_1m_sticks,
    dp_prior_alpha, allele_prior_alpha, allele_prior_beta,
)

In [14]:
# The two implementations should agree up to floating-point round-off.
# Fail loudly if they diverge instead of only displaying the discrepancy,
# but still show the largest elementwise difference as the cell output.
pop_beta_max_abs_diff = np.abs(beta_params - beta_params2).max()
onp.testing.assert_allclose(beta_params, beta_params2, rtol=1e-10, atol=1e-10)
pop_beta_max_abs_diff

DeviceArray(7.10542736e-15, dtype=float64)

### Check stick updates

In [15]:
# Stick update from cavi_lib; only the third output is kept, for comparison
# with the cavi_lib_full_ez version in the next cell.
_, _, stick_beta = cavi_lib.update_stick_beta(
    g_obs,
    e_log_pop_freq, e_log_1m_pop_freq,
    e_log_sticks, e_log_1m_sticks,
    dp_prior_alpha, allele_prior_alpha, allele_prior_beta,
)

In [16]:
# Recompute e_z from the same moments so this cell does not rely on the earlier
# e_z cell, then apply the cavi_lib_full_ez stick update (which receives e_z in
# place of e_log_sticks / e_log_1m_sticks).
e_z = cavi_lib_full_ez.update_z(
    g_obs,
    e_log_sticks, e_log_1m_sticks,
    e_log_pop_freq, e_log_1m_pop_freq,
)

_, _, stick_beta2 = cavi_lib_full_ez.update_stick_beta(
    g_obs, e_z,
    e_log_pop_freq, e_log_1m_pop_freq,
    dp_prior_alpha, allele_prior_alpha, allele_prior_beta,
)

In [17]:
# The two implementations should agree up to floating-point round-off.
# Fail loudly if they diverge instead of only displaying the discrepancy,
# but still show the largest elementwise difference as the cell output.
stick_beta_max_abs_diff = np.abs(stick_beta - stick_beta2).max()
onp.testing.assert_allclose(stick_beta, stick_beta2, rtol=1e-10, atol=1e-10)
stick_beta_max_abs_diff

DeviceArray(2.27373675e-12, dtype=float64)