In [1]:
import autograd 

import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

import structure_model_lib 
import dp_modeling_lib
import data_utils

import paragami

import optimization_lib as opt_lib

# Draw simulated genotype data

In [2]:
n_obs = 10
n_loci = 5

# Seed numpy's global RNG *before* drawing, so the simulated genotypes (and
# every KL value computed from them further down) are reproducible on a
# fresh-kernel Restart & Run All.
np.random.seed(453453)

# Simulated genotypes: an (n_obs, n_loci) matrix of integers in {0, 1, 2},
# currently drawn uniformly at random.
g_obs_int = np.random.choice(3, size=(n_obs, n_loci))

# One-hot encode the genotypes into a new trailing axis of length 3.
g_obs = data_utils.get_one_hot(g_obs_int, 3)
assert g_obs.shape == (n_obs, n_loci, 3)

In [3]:
# Reset numpy's legacy global RNG to a fixed state, so any random draws made
# by later cells start from a known seed.
SEED = 453453
np.random.seed(SEED)

# Get prior

In [4]:
# Default prior hyperparameters, together with the paragami pattern that
# describes their layout and bounds (printed below).
(prior_params_dict,
 prior_params_paragami) = structure_model_lib.get_default_prior_params()
print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get variational Bayes (VB) parameters

In [5]:
# Number of components in the truncated approximation (the stick parameters
# printed below have k_approx - 1 columns).
k_approx = 12
vb_params_dict, vb_params_paragami = structure_model_lib.get_vb_params_paragami_object(
    n_obs, n_loci, k_approx)

print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (5, 12, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (10, 11) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (10, 11) (lb=0.0001, ub=inf)


# Set up model

In [6]:
# Gauss-Hermite quadrature nodes and weights passed to get_kl below.
# The degree is taken from `gh_deg` (it was previously hard-coded to 8 in the
# hermgauss call, so editing gh_deg had no effect).
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [7]:
# get loss as a function of vb parameters
get_free_vb_params_loss = paragami.FlattenFunctionInput(
                                original_fun=structure_model_lib.get_kl, 
                                patterns = vb_params_paragami,
                                free = True,
                                argnums = 1)

get_free_vb_params_loss_cached = \
    lambda x : get_free_vb_params_loss(g_obs, x, prior_params_dict, gh_loc, gh_weights)

In [8]:
# Starting point for optimization: the initial VB parameters flattened into a
# single vector in the free (unconstrained) parameterization.
init_vb_free_params = vb_params_paragami.flatten(
    vb_params_dict, free=True)

In [9]:
# KL at the initial VB parameters, evaluated directly on the dict form.
init_kl = structure_model_lib.get_kl(
    g_obs, vb_params_dict, prior_params_dict, gh_loc, gh_weights)
init_kl

397.3183609901973

In [10]:
# Same KL via the flattened free-parameter wrapper; this should reproduce the
# value from the direct get_kl call above.
init_kl_free = get_free_vb_params_loss_cached(init_vb_free_params)
init_kl_free

397.3183609901973

In [None]:
# Minimize the KL over the free VB parameters, starting from the initial point.
# Judging by the argument names, this presumably runs BFGS followed by
# Newton-type steps.
# NOTE(review): `netwon_max_iter` looks like a misspelling of "newton". It must
# match optimization_lib.optimize_full's keyword exactly, so rename it there
# first if it is a typo.
vb_opt_params = opt_lib.optimize_full(
    get_free_vb_params_loss_cached,
    init_vb_free_params,
    bfgs_max_iter=50,
    netwon_max_iter=50,
    max_precondition_iter=10,
    gtol=1e-8,
    ftol=1e-8,
    xtol=1e-8,
)

running bfgs ... 
