In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

from numpy.polynomial.hermite import hermgauss
import scipy as osp

from vb_lib import structure_model_lib, data_utils
import vb_lib.structure_optimization_lib as s_optim_lib
from vb_lib.preconditioner_lib import get_mfvb_cov_matmul

import paragami

from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib inline  

from bnpmodeling_runjingdev import cluster_quantities_lib, modeling_lib
from bnpmodeling_runjingdev.optimization_lib import run_lbfgs




In [2]:
import numpy as onp
onp.random.seed(53453)

# Load data

In [3]:
data = np.load('../../simulated_data/simulated_structure_data_nobs20_nloci50_npop4.npz')
g_obs = np.array(data['g_obs'], dtype = int)

In [4]:
# data_dir = '../../../../fastStructure/hgdp_data/huang2011_plink_files/'
# filenamebase = 'phased_HGDP+India+Africa_2810SNPs-regions1to36'
# filename = data_dir + filenamebase + '.npz'
# data = np.load(filename)

# g_obs = np.array(data['g_obs'], dtype = int)
# g_obs_raw = np.array(data['g_obs_raw'])

# # just checking ... 
# which_missing = (g_obs_raw == 3)
# (g_obs.argmax(-1) == g_obs_raw)[~which_missing].all()
# (g_obs[which_missing] == 0).all()

In [5]:
n_obs = g_obs.shape[0]
n_loci = g_obs.shape[1]

In [6]:
print(n_obs)
print(n_loci)

20
50


# Get prior

In [7]:
_, prior_params_paragami = \
    structure_model_lib.get_default_prior_params()

prior_params_dict = prior_params_paragami.random(key=jax.random.PRNGKey(41))

print(prior_params_paragami)

prior_params_free = prior_params_paragami.flatten(prior_params_dict, free = True)

OrderedDict:
	[dp_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_alpha] = NumericArrayPattern (1,) (lb=0.0, ub=inf)
	[allele_prior_beta] = NumericArrayPattern (1,) (lb=0.0, ub=inf)


# Get VB params 

In [8]:
k_approx = 8

In [9]:
gh_deg = 8
gh_loc, gh_weights = hermgauss(8)

In [10]:
use_logitnormal_sticks = True

vb_params_dict, vb_params_paragami = \
    structure_model_lib.get_vb_params_paragami_object(n_obs, n_loci, k_approx,
                                    use_logitnormal_sticks = use_logitnormal_sticks)
    
print(vb_params_paragami)

OrderedDict:
	[pop_freq_beta_params] = NumericArrayPattern (50, 8, 2) (lb=0.0, ub=inf)
	[ind_admix_params] = OrderedDict:
	[stick_means] = NumericArrayPattern (20, 7) (lb=-inf, ub=inf)
	[stick_infos] = NumericArrayPattern (20, 7) (lb=0.0001, ub=inf)


In [11]:
vb_params_dict = vb_params_paragami.random(key=jax.random.PRNGKey(41))

# Optimize

In [12]:
vb_init_dict = deepcopy(vb_params_dict)

In [13]:
# get optimization objective 
optim_objective, init_vb_free = \
    s_optim_lib.define_structure_objective(g_obs, vb_init_dict,
                        vb_params_paragami,
                        prior_params_dict,
                        gh_loc = gh_loc,
                        gh_weights = gh_weights)

out = run_lbfgs(optim_objective, init_vb_free)

vb_opt = out.x
vb_opt_dict = vb_params_paragami.fold(vb_opt, free = True)


Compiling objective ...
Iter 0: f = 2416.47883971
Compiling grad ...
Compile time: 12.2373secs

Running L-BFGS-B ... 
Iter 0: f = 2416.47883971
Iter 1: f = 2339.59615087
Iter 2: f = 2042.15460109
Iter 3: f = 1964.48538994
Iter 4: f = 1842.40209008
Iter 5: f = 1829.80247881
Iter 6: f = 1793.34241428
Iter 7: f = 1786.48851044
Iter 8: f = 1781.15586307
Iter 9: f = 1773.07413016
Iter 10: f = 1767.90596295
Iter 11: f = 1764.01683626
Iter 12: f = 1759.00108553
Iter 13: f = 1752.62770974
Iter 14: f = 1749.18441085
Iter 15: f = 1745.41599705
Iter 16: f = 1743.92272235
Iter 17: f = 1742.78820299
Iter 18: f = 1741.51612383
Iter 19: f = 1740.86902367
Iter 20: f = 1740.29480466
Iter 21: f = 1739.70705386
Iter 22: f = 1738.66142237
Iter 23: f = 1737.90358634
Iter 24: f = 1737.11893275
Iter 25: f = 1736.22924315
Iter 26: f = 1735.90258736
Iter 27: f = 1735.58752357
Iter 28: f = 1735.30538673
Iter 29: f = 1735.03437414
Iter 30: f = 1734.74637446
Iter 31: f = 1734.18640722
Iter 32: f = 1733.46884095
I

# Compute Hessian

In [14]:
get_kl_hess = jax.hessian(optim_objective._objective_fun)

In [15]:
kl_hess = get_kl_hess(vb_opt)

In [16]:
kl_hess = 0.5 * (kl_hess + kl_hess.transpose())

In [17]:
# check condition number 

kl_hess_evals = np.linalg.eigvals(kl_hess)

print((kl_hess_evals.max(), 
       kl_hess_evals.min()))

print('CN: ', kl_hess_evals.max() / kl_hess_evals.min())

(DeviceArray(10.98057129+0.j, dtype=complex128), DeviceArray(0.1086444+0.j, dtype=complex128))
CN:  (101.0689081069877+0j)


# Get preconditioner

In [18]:
# this approximates the inverse hessian square root
# this is "a" in PreconditionedFunction
_info_inv_sqrt_op = lambda v : \
                        get_mfvb_cov_matmul(v, vb_params_dict,
                                                vb_params_paragami,
                                                return_info = True, 
                                                return_sqrt = True)
    
info_inv_sqrt_op = jax.jit(_info_inv_sqrt_op)

print('compiling ... ')
_ = info_inv_sqrt_op(vb_opt).block_until_ready()
print('done')

compiling ... 
done


In [19]:
# this approximates the hessian square root
# this is "a_inv" in PreconditionedFunction

_cov_inv_sqrt_op = lambda v : \
                        get_mfvb_cov_matmul(v, vb_params_dict,
                                                vb_params_paragami,
                                                return_info = False,
                                                return_sqrt = True)
    
cov_inv_sqrt_op = jax.jit(_cov_inv_sqrt_op)

print('compiling ... ')
_ = cov_inv_sqrt_op(vb_opt).block_until_ready()
print('done')

compiling ... 
done


# Check condition number after preconditioning

In [20]:
basis_vecs = np.eye(kl_hess.shape[0])

In [21]:
info_inv_sqrt = jax.lax.map(info_inv_sqrt_op, basis_vecs)

# symmetrize
info_inv_sqrt = 0.5 * (info_inv_sqrt + info_inv_sqrt.transpose())

In [22]:
precond_hess = np.dot(np.dot(info_inv_sqrt, kl_hess), info_inv_sqrt)

In [23]:
# check condition after pre-conditioning

kl_precond_hess_evals = np.linalg.eigvals(precond_hess)

print((kl_precond_hess_evals.max(), 
       kl_precond_hess_evals.min()))

print('CN: ', kl_precond_hess_evals.max() / kl_precond_hess_evals.min())

(DeviceArray(5.41147702+0.j, dtype=complex128), DeviceArray(0.10539556+0.j, dtype=complex128))
CN:  (51.34444892804844+0j)
