In [1]:
import autograd 

import autograd.numpy as np
import autograd.scipy as sp
from numpy.polynomial.hermite import hermgauss

from vb_lib import structure_model_lib, data_utils, \
                    preconditioner_lib, structure_optimization_lib

import paragami

import time

In [2]:
from BNP_modeling import cluster_quantities_lib, modeling_lib
import BNP_modeling.optimization_lib as opt_lib

from itertools import permutations

import matplotlib.pyplot as plt
%matplotlib inline  

import vittles

from copy import deepcopy

import os


In [3]:
np.random.seed(534534)

# Draw data

In [4]:
# problem dimensions: individuals, loci, and populations
n_obs = 160
n_loci = 5
n_pop = 4

# population allele frequencies: p1 on the diagonal of a rectangular
# (n_loci x n_pop) identity, p0 everywhere else. With n_loci > n_pop the
# trailing loci are p0 in every population.
p1 = 0.99
p0 = 0.01
true_pop_allele_freq = np.maximum(p1 * np.eye(n_loci, n_pop), p0)

# individual admixtures: draw one population label per individual, then take a
# softmax over `scale` times its encoding so nearly all mass sits on that label
# (get_one_hot is presumably a one-hot encoder -- confirm in data_utils)
scale = 10
true_pop_labels = np.random.choice(n_pop, n_obs)
admix_logits = scale * data_utils.get_one_hot(true_pop_labels, nb_classes = n_pop)
admix_weights = np.exp(admix_logits)
true_ind_admix_propn = admix_weights / admix_weights.sum(axis = 1, keepdims=True)

In [5]:
print(true_pop_allele_freq)

[[0.99 0.01 0.01 0.01]
 [0.01 0.99 0.01 0.01]
 [0.01 0.01 0.99 0.01]
 [0.01 0.01 0.01 0.99]
 [0.01 0.01 0.01 0.01]]


In [6]:
true_ind_admix_propn[0:5, :]

array([[9.99863819e-01, 4.53937471e-05, 4.53937471e-05, 4.53937471e-05],
       [9.99863819e-01, 4.53937471e-05, 4.53937471e-05, 4.53937471e-05],
       [4.53937471e-05, 4.53937471e-05, 9.99863819e-01, 4.53937471e-05],
       [9.99863819e-01, 4.53937471e-05, 4.53937471e-05, 4.53937471e-05],
       [4.53937471e-05, 9.99863819e-01, 4.53937471e-05, 4.53937471e-05]])

In [7]:
# simulate the observed data from the true allele frequencies and admixture
# proportions (g_obs is presumably a genotype array -- confirm in data_utils)
g_obs = data_utils.draw_data_from_popfreq_and_admix(true_pop_allele_freq, true_ind_admix_propn)

# Get prior

In [8]:
# default prior hyperparameters: a dict of values plus the paragami pattern
# describing each entry's shape and bounds
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

# show the pattern (parameter names, shapes, and bounds)
print(prior_params_paragami)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [9]:
use_logitnormal_sticks = True

In [10]:
# number of components in the variational approximation, set here to the
# true number of populations
k_approx = n_pop
# initial variational parameters (dict) and the paragami pattern used to
# flatten / fold them
vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx, use_logitnormal_sticks)
    
# show the pattern (parameter names, shapes, and bounds)
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (5, 4, 2) (lb=0.0, ub=inf)
	[ind_mix_stick_propn_mean] = NumericArrayPattern (160, 3) (lb=-inf, ub=inf)
	[ind_mix_stick_propn_info] = NumericArrayPattern (160, 3) (lb=0.0001, ub=inf)


In [11]:
ind_mix_stick_propn_mean = vb_params_dict['ind_mix_stick_propn_mean']
ind_mix_stick_propn_info = vb_params_dict['ind_mix_stick_propn_info']
pop_freq_beta_params = vb_params_dict['pop_freq_beta_params']

In [12]:
modeling_lib.get_e_log_beta(pop_freq_beta_params)[0].shape

(5, 4)

# Set up model

In [13]:
# Gauss-Hermite quadrature rule: nodes (gh_loc) and weights (gh_weights) that
# are passed to the KL objective and the cluster-probability helper below.
# The degree is taken from gh_deg so that changing it actually changes the rule.
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)

In [14]:
# get loss as a function of vb parameters
get_free_vb_params_loss = paragami.FlattenFunctionInput(
                                original_fun=structure_model_lib.get_kl, 
                                patterns = vb_params_paragami,
                                free = True,
                                argnums = 1)

get_free_vb_params_loss_cached = \
    lambda x : get_free_vb_params_loss(g_obs, x, prior_params_dict, 
                                       use_logitnormal_sticks, gh_loc, gh_weights)

In [15]:
init_vb_free_params = vb_params_paragami.flatten(vb_params_dict, free = True)

In [16]:
structure_model_lib.get_kl(g_obs, vb_params_dict,
                   prior_params_dict, 
                    use_logitnormal_sticks, 
                    gh_loc, gh_weights)

2695.7424932015747

In [17]:
get_free_vb_params_loss_cached(init_vb_free_params)

2695.7424932015747

In [18]:
# Minimize the KL objective over the free (unconstrained) VB parameters,
# starting from the flattened initial values.
# NOTE(review): `netwon_max_iter` looks like a misspelling of "newton"; it is
# left unchanged because the keyword has to match whatever
# opt_lib.optimize_full declares -- confirm against the library.
# NOTE(review): the saved log below stops at "Iter 36" with f jumping to
# ~3600 and no Newton/preconditioning output; re-run to confirm the
# optimizer actually converged.
vb_opt_free_params = opt_lib.optimize_full(get_free_vb_params_loss_cached, init_vb_free_params,
                    bfgs_max_iter = 50, netwon_max_iter = 50,
                    max_precondition_iter = 10,
                    gtol=1e-8, ftol=1e-8, xtol=1e-8)

running bfgs ... 
Iter 0: f = 2695.74249320
Iter 1: f = 2485.41702413
Iter 2: f = 2400.93040842
Iter 3: f = 2333.73846688
Iter 4: f = 2252.37495055
Iter 5: f = 2101.94356879
Iter 6: f = 1908.68223269
Iter 7: f = 1678.32081880
Iter 8: f = 1627.00351029
Iter 9: f = 1579.44051065
Iter 10: f = 1542.14648002
Iter 11: f = 1542.86847625
Iter 12: f = 1529.59678591
Iter 13: f = 1511.14681564
Iter 14: f = 1494.38397198
Iter 15: f = 1490.13495097
Iter 16: f = 1483.33497014
Iter 17: f = 1478.19091800
Iter 18: f = 1467.17033975
Iter 19: f = 1520.14083955
Iter 20: f = 1459.72185696
Iter 21: f = 1505.67372309
Iter 22: f = 1458.18081525
Iter 23: f = 1455.29442850
Iter 24: f = 1450.05524976
Iter 25: f = 1441.68068839
Iter 26: f = 1435.47876897
Iter 27: f = 1433.64117713
Iter 28: f = 1431.57390526
Iter 29: f = 1431.46669191
Iter 30: f = 1430.61112485
Iter 31: f = 1429.22492394
Iter 32: f = 1426.94396495
Iter 33: f = 1421.14192567
Iter 34: f = 1437.74020062
Iter 35: f = 1411.87319507
Iter 36: f = 3600.80

In [19]:
vb_opt_dict = vb_params_paragami.fold(vb_opt_free_params, free=True)

In [20]:
ind_mix_stick_propn_mean = vb_opt_dict['ind_mix_stick_propn_mean']
ind_mix_stick_propn_info = vb_opt_dict['ind_mix_stick_propn_info']
pop_freq_beta_params = vb_opt_dict['pop_freq_beta_params']

In [21]:
# ratio of the first entry on the last axis to the sum over that axis; if the
# last axis holds Beta(a, b) shape parameters in that order this is the Beta
# mean a / (a + b) -- TODO confirm the parameter ordering in structure_model_lib
e_pop_allele_freq = pop_freq_beta_params[:, :, 0] / pop_freq_beta_params.sum(axis=2)

In [22]:
def find_min_perm(x, y, axis = 0):
    """Return the permutation of ``x`` along ``axis`` that best matches ``y``.

    Every ordering of the ``x.shape[axis]`` slices of ``x`` is tried and the
    one with the smallest summed squared difference to ``y`` is kept. The
    search is exhaustive (factorial in ``x.shape[axis]``), so it is only
    practical for a small number of slices.

    Parameters
    ----------
    x : array
        Array whose slices along ``axis`` are reordered.
    y : array
        Target array; must broadcast against the permuted ``x``.
    axis : int, optional
        Axis of ``x`` to permute. Defaults to 0.

    Returns
    -------
    tuple
        Indices such that ``x.take(result, axis)`` is the closest match to
        ``y``. Ties go to the first permutation produced by
        ``itertools.permutations``.

    Raises
    ------
    ValueError
        If no permutation gives a squared difference below infinity (for
        example when ``x`` or ``y`` contains NaN), so no best match exists.
    """
    best_perm = None
    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0
    best_diff = np.inf

    # iterate lazily instead of materialising all n! permutations in a list
    for perm in permutations(np.arange(x.shape[axis])):
        diff = np.sum((x.take(perm, axis) - y)**2)

        # strict '<' keeps the earliest permutation on ties; NaN never compares
        # as smaller, so NaN differences are skipped
        if diff < best_diff:
            best_diff = diff
            best_perm = perm

    if best_perm is None:
        raise ValueError(
            'no permutation produced a finite squared difference; '
            'check x and y for NaN or infinite values')

    return best_perm

In [23]:
perm_best = find_min_perm(e_pop_allele_freq, true_pop_allele_freq, axis = 1)

In [24]:
e_pop_allele_freq[:, perm_best]

array([[0.92553933, 0.0110584 , 0.01560233, 0.01646585],
       [0.00963405, 0.98609818, 0.01836387, 0.01945906],
       [0.01089783, 0.01387308, 0.98260488, 0.02029114],
       [0.01630353, 0.0152515 , 0.02121923, 0.98607813],
       [0.0091834 , 0.02231597, 0.01778837, 0.03408285]])

In [25]:
true_pop_allele_freq

array([[0.99, 0.01, 0.01, 0.01],
       [0.01, 0.99, 0.01, 0.01],
       [0.01, 0.01, 0.99, 0.01],
       [0.01, 0.01, 0.01, 0.99],
       [0.01, 0.01, 0.01, 0.01]])

In [26]:
# expected cluster-membership probabilities per individual under the fitted
# stick parameters, evaluated with the Gauss-Hermite rule
e_ind_admix_unaligned = cluster_quantities_lib.get_e_cluster_probabilities(
    ind_mix_stick_propn_mean, ind_mix_stick_propn_info, gh_loc, gh_weights)

# reorder columns with the permutation found for the allele frequencies so
# cluster indices line up with the true population labels
e_ind_admix = e_ind_admix_unaligned[:, perm_best]

In [27]:
e_ind_admix.argmax(axis=1)

array([0, 0, 2, 0, 1, 1, 0, 2, 2, 3, 1, 0, 2, 2, 3, 3, 3, 0, 1, 1, 2, 1,
       0, 1, 3, 2, 0, 3, 1, 3, 2, 1, 1, 1, 3, 3, 1, 0, 0, 1, 3, 2, 3, 3,
       1, 3, 3, 2, 2, 2, 3, 1, 3, 3, 0, 3, 1, 0, 3, 0, 0, 0, 1, 0, 1, 1,
       2, 0, 1, 3, 3, 1, 3, 0, 2, 0, 0, 3, 0, 3, 3, 0, 0, 1, 2, 0, 2, 1,
       3, 1, 1, 1, 3, 1, 0, 0, 3, 3, 3, 2, 2, 3, 0, 1, 2, 0, 2, 3, 1, 0,
       2, 1, 2, 2, 0, 0, 0, 1, 2, 0, 2, 3, 1, 3, 0, 0, 0, 1, 3, 0, 0, 0,
       0, 1, 0, 2, 1, 3, 2, 1, 2, 3, 2, 3, 1, 0, 2, 0, 2, 1, 1, 1, 2, 1,
       2, 1, 0, 1, 1, 2])

In [28]:
true_ind_admix_propn.argmax(axis=1)

array([0, 0, 2, 0, 1, 1, 0, 2, 2, 3, 1, 0, 2, 2, 3, 3, 3, 0, 2, 1, 2, 1,
       0, 1, 3, 2, 0, 3, 1, 3, 2, 1, 1, 1, 3, 3, 1, 0, 0, 1, 3, 2, 3, 3,
       1, 3, 3, 2, 2, 2, 3, 1, 3, 3, 3, 3, 1, 0, 3, 0, 0, 0, 1, 0, 1, 1,
       2, 0, 1, 3, 3, 1, 3, 0, 2, 0, 0, 3, 0, 3, 3, 0, 0, 1, 2, 0, 2, 1,
       3, 1, 1, 1, 3, 1, 0, 0, 3, 3, 3, 2, 2, 3, 0, 1, 2, 0, 2, 3, 1, 0,
       2, 1, 2, 2, 0, 3, 0, 1, 2, 0, 2, 3, 1, 3, 0, 0, 0, 1, 3, 0, 0, 0,
       0, 1, 0, 2, 1, 3, 2, 1, 2, 3, 2, 3, 1, 0, 2, 0, 2, 1, 1, 1, 2, 1,
       2, 1, 0, 1, 1, 2])

In [29]:
# fraction of individuals whose most probable inferred population matches the
# population with the largest true admixture proportion
est_pop_labels = e_ind_admix.argmax(axis=1)
true_pop_argmax = true_ind_admix_propn.argmax(axis=1)
np.mean(est_pop_labels == true_pop_argmax)

0.98125