In [1]:
import autograd 

import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

import structure_model_lib 
import dp_modeling_lib
import data_utils

import paragami

import optimization_lib as opt_lib

# Draw data

In [2]:
n_obs = 10
n_loci = 5
n_pop = 4

# population allele frequencies
p1 = 0.99
p0 = 0.01
pop_allele_freq = np.maximum(np.eye(n_loci, n_pop) * p1, p0)

# individual admixtures
scale = 3
ind_admix_propn = np.random.choice(n_pop, n_obs)
ind_admix_propn = scale * data_utils.get_one_hot(ind_admix_propn, nb_classes = n_pop)
ind_admix_propn = np.exp(ind_admix_propn) / np.exp(ind_admix_propn).sum(axis = 1, keepdims=True)

In [3]:
g_obs = data_utils.draw_data(pop_allele_freq, ind_admix_propn)

# Get prior

In [4]:
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [5]:
k_approx = n_pop
vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (5, 4, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (10, 3) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (10, 3) (lb=0.0001, ub=inf)


In [6]:
ind_mix_stick_propn_mean = vb_params_dict['ind_mix_stick_propn_mean']
ind_mix_stick_propn_info = vb_params_dict['ind_mix_stick_propn_info']
pop_freq_beta_params = vb_params_dict['pop_freq_beta_params']

In [7]:
dp_modeling_lib.get_e_log_beta(pop_freq_beta_params)[0].shape

(5, 4)

# Set up model

In [8]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(8)

In [9]:
# get loss as a function of vb parameters
get_free_vb_params_loss = paragami.FlattenFunctionInput(
                                original_fun=structure_model_lib.get_kl, 
                                patterns = vb_params_paragami,
                                free = True,
                                argnums = 1)

get_free_vb_params_loss_cached = \
    lambda x : get_free_vb_params_loss(g_obs, x, prior_params_dict, gh_loc, gh_weights, 
                                      true_pop_allele_freq = None)

In [10]:
init_vb_free_params = vb_params_paragami.flatten(vb_params_dict, free = True)

In [11]:
structure_model_lib.get_kl(g_obs, vb_params_dict,
                   prior_params_dict, gh_loc, gh_weights)

184.34139383125708

In [12]:
get_free_vb_params_loss_cached(init_vb_free_params)

184.34139383125708

In [13]:
vb_opt_free_params = opt_lib.optimize_full(get_free_vb_params_loss_cached, init_vb_free_params,
                    bfgs_max_iter = 50, netwon_max_iter = 50,
                    max_precondition_iter = 10,
                    gtol=1e-8, ftol=1e-8, xtol=1e-8)

running bfgs ... 
         Current function value: 110.730018
         Iterations: 39
         Function evaluations: 102
         Gradient evaluations: 91

 running preconditioned newton; iter =  0
computing preconditioner 
running newton steps
         Current function value: 110.730018
         Iterations: 0
         Function evaluations: 2
         Gradient evaluations: 1
         Hessian evaluations: 0
Iter 0: x_diff = 3.5181232926895234e-13, f_diff = 5.684341886080802e-14, grad_l1 = 3.6412431609845773e-07
done. 


In [14]:
vb_opt_dict = vb_params_paragami.fold(vb_opt_free_params, free=True)

In [15]:
ind_mix_stick_propn_mean = vb_opt_dict['ind_mix_stick_propn_mean']
ind_mix_stick_propn_info = vb_opt_dict['ind_mix_stick_propn_info']
pop_freq_beta_params = vb_opt_dict['pop_freq_beta_params']

In [16]:
e_log_cluster_probs = dp_modeling_lib.get_e_log_cluster_probabilities(
                        ind_mix_stick_propn_mean, ind_mix_stick_propn_info,
                        gh_loc, gh_weights)

In [17]:
e_log_cluster_probs.argmax(axis = 1)

array([3, 0, 0, 3, 3, 0, 3, 3, 3, 0])

In [18]:
ind_admix_propn.argmax(axis=1)

array([3, 2, 2, 0, 0, 2, 0, 0, 1, 3])