In [1]:
import time
import pprint

import jax
import jax.numpy as np

from numpy.polynomial.hermite import hermgauss

import paragami

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib


In [2]:
import numpy as onp
onp.random.seed(53453)

# Simulate data

In [3]:
n_obs = 1107 # number of individuals
n_loci = 2810 # number of loci
n_pop = 5 # number of true populations

# array of size (n_obs x n_loci x 3)
g_obs, true_pop_allele_freq, true_ind_admix_propn = \
    data_utils.draw_data(n_obs, n_loci, n_pop)

Each individual is genotyped at `n_loci` loci. 

At each locus, the genotype is one of three classes: aa, Aa, or AA. 

For individual $n$ at locus $l$, `g_obs[n, l]` contains a one-hot encoding of its genotype: 

In [4]:
n = 0
l = 0

# either (1, 0, 0); (0, 1, 0); or (0, 0, 1)
print(g_obs[n, l])

[0. 0. 1.]
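
As a small illustration of this encoding, the sketch below builds one-hot genotypes from counts of the A allele. The index convention (position = number of copies of A, so aa, Aa, AA map to positions 0, 1, 2) and the `toy_` names are assumptions made for illustration; `data_utils.draw_data` may order the classes differently.

```python
import numpy as onp

# Hypothetical toy data: number of copies of allele A (0, 1, or 2)
# for 2 individuals at 3 loci.
toy_allele_counts = onp.array([[0, 1, 2],
                               [2, 2, 0]])

# Index the rows of a 3x3 identity matrix to one-hot encode each count.
# The result has shape (n_obs, n_loci, 3), the same layout as `g_obs`.
toy_g_obs = onp.eye(3)[toy_allele_counts]

print(toy_g_obs.shape)   # (2, 3, 3)
print(toy_g_obs[0, 2])   # [0. 0. 1.]
```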


Each individual is generated as a mixture of `n_pop` ancestral populations. 

The mixture proportions (as well as the number of ancestral populations) need to be inferred.

In [18]:
# the true mixture proportions for individual n
# (we know this because we simulated the data): 
true_ind_admix_propn[n]

array([6.36557155e-01, 2.18859797e-01, 2.50443253e-02, 1.19538717e-01,
       6.24815061e-09])

Now at each locus $l$, there are two chromosomes, call them chromosome $1$ and chromosome $2$. 

Each chromosome belongs to a particular population. 

That is, the population indicators $z_{nl1}$ and $z_{nl2}$ are drawn from a categorical with mixture weights 
`true_ind_admix_propn[n]` for each $l$. 

Each population has a particular frequency of the dominant allele $A$. This frequency also needs to be inferred. 

Suppose chromosome 1 at locus $l$ of individual $n$ comes from population $k = 0$. Then the probability that chromosome 1 carries the dominant allele (A) is: 

In [6]:
# the probability of observing the dominant allele A 
# at locus l, if it comes from population k
k = 0
true_pop_allele_freq[l, k]

0.9195577707266201
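
The generative steps described so far can be sketched for a single individual at a single locus. This is a toy illustration, not the code in `data_utils.draw_data`: the mixture weights, allele frequencies, and `toy_` names are made up, and the genotype index is assumed to be the number of copies of A.

```python
import numpy as onp

toy_rng = onp.random.default_rng(0)

# Hypothetical toy values (not taken from the simulated data above).
toy_admix = onp.array([0.6, 0.3, 0.1])        # mixture weights for one individual
toy_allele_freq = onp.array([0.9, 0.5, 0.2])  # P(allele A) per population, at one locus

# Draw a population indicator for each of the two chromosomes.
toy_z = toy_rng.choice(len(toy_admix), size=2, p=toy_admix)

# Each chromosome carries allele A with the frequency of its population.
toy_alleles = toy_rng.random(2) < toy_allele_freq[toy_z]

# The genotype class is the number of copies of A: 0 (aa), 1 (Aa), or 2 (AA).
toy_genotype = onp.eye(3)[int(toy_alleles.sum())]
print(toy_z, toy_alleles, toy_genotype)
```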

**In summary**, we need to infer: 
* A (`n_obs, n_pop`) array of individual mixture proportions. Rows sum to 1. 
* A (`n_loci, n_pop`) array of population allele frequencies. Entries are between 0 and 1. 
* A (`n_obs, n_loci, 2, n_pop`) array of cluster probabilities. The last dimension sums to 1 and gives the probability that chromosome $i$ at locus $l$ of individual $n$ belongs to population $k$. 

Below, we take a truncated Bayesian nonparametric (BNP) approach, in which `n_pop` is unknown. We replace `n_pop` with a large positive integer, `k_approx`. 

# Get prior

The individual mixture proportions are drawn from a Dirichlet stick-breaking process, with `dp_prior_alpha` the concentration parameter. 

The population allele frequencies at each $l$, $k$ are drawn iid from a Beta distribution, with parameters `allele_prior_alpha` and `allele_prior_beta`. 
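
To make the stick-breaking construction concrete, the sketch below draws stick proportions and converts them into mixture weights. The Beta(1, alpha) stick distribution is the standard Dirichlet-process construction and is assumed here; the truncation level and the helper `toy_sticks_to_weights` are hypothetical and are not part of `vb_lib`.

```python
import numpy as onp

def toy_sticks_to_weights(sticks):
    # weight_k = stick_k * prod_{j < k} (1 - stick_j);
    # the last weight takes whatever remains of the unit stick.
    remaining = onp.concatenate([[1.0], onp.cumprod(1.0 - sticks)])
    return onp.concatenate([sticks, [1.0]]) * remaining

toy_rng = onp.random.default_rng(0)
toy_alpha = 3.0   # plays the role of `dp_prior_alpha`
toy_k = 6         # truncation level: toy_k - 1 sticks give toy_k weights

toy_sticks = toy_rng.beta(1.0, toy_alpha, size=toy_k - 1)
toy_weights = toy_sticks_to_weights(toy_sticks)

print(toy_weights, toy_weights.sum())
```

Note that `toy_k - 1` sticks determine `toy_k` weights, which is why the stick parameters later in the notebook have a trailing dimension of `k_approx - 1`.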

In [7]:
prior_params_dict, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

pprint.pprint(prior_params_dict)

prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

{'allele_prior_alpha': DeviceArray([1.], dtype=float64),
 'allele_prior_beta': DeviceArray([1.], dtype=float64),
 'dp_prior_alpha': DeviceArray([3.], dtype=float64)}


# Get VB params 

In [8]:
# number of components in truncated BNP approximation
k_approx = 20

The variational parameters: 

* `pop_freq_beta_params`: array of size (`n_loci`, `k_approx`, 2). These are the parameters of the Beta-distributed population frequencies. 

* `ind_admix_params`, which contains `stick_means` and `stick_infos`, each of size (`n_obs`, `k_approx - 1`). These are the parameters of the logitnormal distribution on the stick-breaking proportions. (This differs from the usual approach of using a Beta variational distribution on the sticks.) 

In [9]:
# randomly initialized vb parameters along with corresponding pattern
vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs,
                                                      n_loci,
                                                      k_approx,
                                                      use_logitnormal_sticks = True)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (2810, 20, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (1107, 19) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (1107, 19) (lb=0.0, ub=inf)


The paragami object provides an easy way to convert parameter dictionaries to a flattened vector of real-valued, unconstrained parameters (`vb_params_paragami.flatten`), or vice versa (`vb_params_paragami.fold`).

In [10]:
# this is a real-valued vector of variational parameters
vb_params = vb_params_paragami.flatten(vb_params_dict, free = True)
print(vb_params.shape)

(154466,)
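
The length of the flattened vector can be checked by counting the entries in each pattern printed above. This is a plain arithmetic sketch; it assumes the free parameterization uses one unconstrained value per constrained entry, which is consistent with the shape reported here.

```python
# Shapes from the printed pattern (n_obs = 1107, n_loci = 2810, k_approx = 20).
count_pop_freq = 2810 * 20 * 2       # pop_freq_beta_params
count_stick_means = 1107 * (20 - 1)  # stick_means
count_stick_infos = 1107 * (20 - 1)  # stick_infos

count_free = count_pop_freq + count_stick_means + count_stick_infos
print(count_free)  # 154466
```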


Since we use a logitnormal approximation to the stick-breaking proportions (instead of the usual approach of making the variational distributions Beta distributed as well), we need to evaluate expectations under a logitnormal distribution. 

We do so using Gauss-Hermite quadrature. The quadrature locations and weights are computed here: 

In [11]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(gh_deg)
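
As a rough sketch of how such locations and weights are used, the snippet below approximates the mean of a logitnormal variable. It assumes that an "info" parameter is the inverse variance of the underlying normal; the helper `toy_logitnormal_mean` is hypothetical and is not the routine used inside `structure_model_lib`.

```python
import numpy as onp
from numpy.polynomial.hermite import hermgauss

def toy_logitnormal_mean(mu, info, deg=8):
    # E[sigmoid(X)] for X ~ Normal(mu, 1 / info), via Gauss-Hermite quadrature.
    # hermgauss integrates against exp(-x^2), so we substitute
    # X = mu + sqrt(2) * sigma * x and divide by sqrt(pi).
    loc, weights = hermgauss(deg)
    sigma = 1.0 / onp.sqrt(info)
    x = mu + onp.sqrt(2.0) * sigma * loc
    return onp.sum(weights / (1.0 + onp.exp(-x))) / onp.sqrt(onp.pi)

# By symmetry, the mean is 0.5 when mu = 0.
toy_symmetric = toy_logitnormal_mean(0.0, 1.0)
print(toy_symmetric)

# Compare degree 8 against a much finer rule for a non-zero mean.
toy_coarse = toy_logitnormal_mean(1.0, 4.0, deg=8)
toy_fine = toy_logitnormal_mean(1.0, 4.0, deg=100)
print(toy_coarse, toy_fine)
```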

# The KL objective

**Notice** that `vb_params_dict` does not include the `(n_obs, n_loci, 2, k_approx)` array of cluster belonging probabilities! An array of size `(n_obs, n_loci, 2, k_approx)` would be too large to instantiate in memory. 

Inside our KL objective, we implicitly set the optimal cluster belonging probabilities as a function of `vb_params`, and never instantiate the full array in memory. 

(We loop over the loci, so we only ever instantiate an `(n_obs, 2, k_approx)` array at a time.) 

Let $\theta$ be `vb_params`, and let $\zeta$ be the array of *unconstrained* cluster belonging probabilities. 
Let $z$ be the *constrained* probabilities: in pseudo-code, `z = softmax(zeta, axis = -1)`.  

Let $\zeta^*(\theta)$ be the optimal unconstrained cluster belonging probabilities. We work with the unconstrained parameterization so that the following optimality condition holds, where $f(\theta, \zeta)$ is the KL as a function of both sets of parameters and $f_\zeta$ is its partial derivative with respect to $\zeta$: 

\begin{align}
f_\zeta(\theta, \zeta^*(\theta)) = 0 \quad \forall\; \theta
\end{align}
We will need this condition later. 
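
A toy check of this condition for a single softmax block: if the part of the KL that depends on $z$ consists of terms linear in $z$ plus the negative entropy $\sum_k z_k \log z_k$, then the optimum is $z^* = \mathrm{softmax}(\ell)$, where $\ell$ collects the linear coefficients. The vector `toy_loglik` below is a hypothetical stand-in for those coefficients (which would depend on $\theta$ in the real model), and the gradient is checked by finite differences.

```python
import numpy as onp

def toy_softmax(zeta):
    e = onp.exp(zeta - zeta.max())
    return e / e.sum()

def toy_kl_part(zeta, loglik):
    # Terms linear in z plus the negative entropy sum_k z_k log z_k.
    z = toy_softmax(zeta)
    return -onp.sum(z * loglik) + onp.sum(z * onp.log(z))

# Hypothetical linear coefficients for 4 populations.
toy_loglik = onp.array([-1.0, -0.2, -3.0, -0.7])

# Candidate optimum: zeta* = loglik, so that z* = softmax(loglik).
toy_zeta_star = toy_loglik.copy()

# Central finite differences for the gradient with respect to zeta.
toy_eps = 1e-6
toy_grad = onp.array([
    (toy_kl_part(toy_zeta_star + toy_eps * e_k, toy_loglik)
     - toy_kl_part(toy_zeta_star - toy_eps * e_k, toy_loglik)) / (2 * toy_eps)
    for e_k in onp.eye(len(toy_loglik))
])
print(onp.abs(toy_grad).max())  # close to zero
```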

The function `get_kl` below is the function 
\begin{align}
    F(\theta) := f(\theta, \zeta^*(\theta))
\end{align}

In [12]:
def get_kl(vb_params): 
    
    # Below, the `detach_ez = False` argument 
    # allows gradients to backpropagate through the $z$'s. 
    # In other words, derivatives with respect to theta are
    # the **total** derivatives of the KL. 
    
    # When `detach_ez = True`, derivatives are instead 
    # the **partial** derivatives of the KL. 
    
    vb_params_dict = vb_params_paragami.fold(vb_params, free = True)

    return structure_model_lib.get_kl(g_obs,
                                      vb_params_dict, 
                                      prior_params_dict, 
                                      gh_loc,
                                      gh_weights, 
                                      detach_ez = False)

_get_kl_hvp = lambda v : jax.jvp(jax.grad(get_kl), (vb_params, ), (v, ))[1]
get_kl_hvp = jax.jit(_get_kl_hvp)

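Computing Hessian-vector products this way (forward-over-reverse autodiff through `get_kl`) is either extremely slow or crashes the notebook entirely, so the timing cell below is disabled by default (`run_kl_hvp = False`).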

In [13]:
v = jax.random.normal(key = jax.random.PRNGKey(0), 
                     shape = (len(vb_params), ))

In [14]:
run_kl_hvp = False
if run_kl_hvp: 
    print('compiling ...')
    t0 = time.time()
    _ = get_kl_hvp(v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.time() - t0))
    

    t0 = time.time() 
    true_hvp = get_kl_hvp(v).block_until_ready()
    print('Evaluation time: {:.3f}sec'.format(time.time() - t0))

# My custom objective

This class contains a custom implementation of the HVP using the Schur complement, which I detail below. 

In [15]:
stru_objective = s_optim_lib.StructureObjective(g_obs, 
                                                 vb_params_paragami,
                                                 prior_params_dict, 
                                                 gh_loc, gh_weights)

compiling objective ... 
done. Elasped: 78.5147


Using the above optimality condition, we can derive the Schur complement decomposition of the Hessian: 
\begin{align}
    \nabla^2 F(\theta) = f_{\theta\theta} - f_{\theta\zeta}^T f^{-1}_{\zeta\zeta} f_{\theta\zeta},
\end{align}
where the RHS is evaluated at $(\theta, \zeta^*(\theta))$. 
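
A toy numerical check of the Schur complement idea, on a hypothetical quadratic objective rather than the KL: with $f(\theta, \zeta) = \tfrac12\theta^T A\theta + \theta^T B\zeta + \tfrac12\zeta^T C\zeta$, minimizing over $\zeta$ gives $\zeta^*(\theta) = -C^{-1}B^T\theta$, and the Hessian of $F(\theta) = f(\theta, \zeta^*(\theta))$ is $A - BC^{-1}B^T$. The matrices and `toy_` names are assumptions made for illustration only.

```python
import numpy as onp

toy_rng = onp.random.default_rng(0)
d_theta, d_zeta = 4, 3

# Hypothetical blocks: A plays the role of f_theta_theta, B the cross block,
# and C the (positive definite) f_zeta_zeta.
toy_A = toy_rng.normal(size=(d_theta, d_theta))
toy_A = toy_A @ toy_A.T + d_theta * onp.eye(d_theta)
toy_B = toy_rng.normal(size=(d_theta, d_zeta))
toy_C = toy_rng.normal(size=(d_zeta, d_zeta))
toy_C = toy_C @ toy_C.T + d_zeta * onp.eye(d_zeta)

def toy_quad_f(theta, zeta):
    return 0.5 * theta @ toy_A @ theta + theta @ toy_B @ zeta + 0.5 * zeta @ toy_C @ zeta

def toy_profiled_F(theta):
    # zeta* solves the optimality condition f_zeta = B' theta + C zeta = 0.
    zeta_star = -onp.linalg.solve(toy_C, toy_B.T @ theta)
    return toy_quad_f(theta, zeta_star)

def toy_schur_hvp(vec):
    # (A - B C^{-1} B') vec, without forming the full Hessian.
    return toy_A @ vec - toy_B @ onp.linalg.solve(toy_C, toy_B.T @ vec)

toy_theta = toy_rng.normal(size=d_theta)
toy_vec = toy_rng.normal(size=d_theta)

# Second-order finite difference of F along toy_vec approximates vec' H vec.
toy_eps = 1e-2
toy_fd_quad = (toy_profiled_F(toy_theta + toy_eps * toy_vec)
               - 2.0 * toy_profiled_F(toy_theta)
               + toy_profiled_F(toy_theta - toy_eps * toy_vec)) / toy_eps ** 2
toy_schur_quad = toy_vec @ toy_schur_hvp(toy_vec)

print(toy_fd_quad, toy_schur_quad)
```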

We can decompose the second term. Let $s$ be the mapping from $\zeta$ to $z$ (this is just the softmax function). Then 
\begin{align}
f_{\theta\zeta}^Tf^{-1}_{\zeta\zeta}f_{\theta\zeta} = 
    [\nabla_\theta \zeta^*]^T [\nabla_\zeta s]^T
    f_{zz}
    [\nabla_\zeta s]
    [\nabla_\theta \zeta^*]
\end{align}

where $f_{zz}$ is the Hessian of the KL with respect to the cluster belonging probabilities $z$.

Note that
* $f_{zz}$ is diagonal, with entries $\{1 / z_{nlik}\}$: all terms in the KL are linear in $z$ except the entropy term (given by $z \log z$). 
* $\nabla_\zeta s$ is the Jacobian of the softmax function, which has a closed form. 
* $\nabla_\theta \zeta^*$ is sparse. 
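
The first two facts can be checked numerically for a single softmax block. The sketch below is an illustration with made-up values, not the code inside `stru_objective.hvp`: it compares the closed-form softmax Jacobian against finite differences, and then shows that with $f_{zz} = \mathrm{diag}(1/z)$ the product $[\nabla_\zeta s]^T f_{zz} [\nabla_\zeta s]$ collapses to $\mathrm{diag}(z) - zz^T$.

```python
import numpy as onp

def toy_softmax(zeta):
    e = onp.exp(zeta - zeta.max())
    return e / e.sum()

def toy_softmax_jacobian(zeta):
    # Closed form: d z_j / d zeta_k = z_j * (1[j == k] - z_k).
    z = toy_softmax(zeta)
    return onp.diag(z) - onp.outer(z, z)

toy_zeta = onp.array([0.3, -1.2, 2.0, 0.5])
toy_z = toy_softmax(toy_zeta)
toy_jac = toy_softmax_jacobian(toy_zeta)

# Finite-difference check of the Jacobian (column k = d z / d zeta_k).
toy_eps = 1e-6
toy_jac_fd = onp.stack([
    (toy_softmax(toy_zeta + toy_eps * e_k)
     - toy_softmax(toy_zeta - toy_eps * e_k)) / (2 * toy_eps)
    for e_k in onp.eye(len(toy_zeta))
], axis=1)
print(onp.abs(toy_jac - toy_jac_fd).max())

# With f_zz = diag(1 / z), the middle factor equals the Jacobian itself:
# J' diag(1 / z) J = diag(z) - z z'.
toy_middle = toy_jac.T @ onp.diag(1.0 / toy_z) @ toy_jac
print(onp.abs(toy_middle - toy_jac).max())
```

Because the middle factor only involves `z`, a matrix-vector product with it costs a few elementwise operations per softmax block, so no dense matrix needs to be stored.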


In [16]:
# the above decomposition is implemented in the hvp method of `stru_objective`. 

for i in range(5): 
    t0 = time.time() 
    my_hvp = stru_objective.hvp(vb_params, v).block_until_ready()
    print('elapsed: {:.3f}sec'.format(time.time() - t0))

elapsed: 14.257sec
elapsed: 14.314sec
elapsed: 14.222sec
elapsed: 14.211sec
elapsed: 14.283sec


In [17]:
if run_kl_hvp: 
    assert np.abs(my_hvp - true_hvp).max() < 1e-8