In [1]:
import jax

import jax.numpy as np
import jax.scipy as sp

import scipy as osp

from vb_lib import structure_model_lib, preconditioner_lib, testutils

from bnpmodeling_runjingdev.sensitivity_lib import HyperparameterSensitivityLinearApproximation

import unittest 




In [2]:
# Fit the model at its optimum. Returns the optimal free VB parameter vector,
# the paragami pattern used to fold it, and an objective object exposing
# (preconditioned) Hessian-vector products.
(vb_opt, vb_params_paragami, precond_objective) = testutils.construct_model_and_optimize()

Compiling cavi functions ...
CAVI compile time: 3.66sec

 running CAVI ...
iteration [1]; kl:151.878389; elapsed: 0.0235secs
iteration [2]; kl:151.276609; elapsed: 0.1796secs
iteration [3]; kl:151.07998; elapsed: 0.0303secs
iteration [4]; kl:150.965841; elapsed: 0.0054secs
iteration [5]; kl:150.884505; elapsed: 0.0057secs
iteration [6]; kl:150.82073; elapsed: 0.0057secs
iteration [7]; kl:150.767847; elapsed: 0.0056secs
iteration [8]; kl:150.722519; elapsed: 0.0054secs
iteration [9]; kl:150.683014; elapsed: 0.0054secs
iteration [10]; kl:150.648422; elapsed: 0.0054secs
iteration [11]; kl:150.618249; elapsed: 0.0053secs
iteration [12]; kl:150.592179; elapsed: 0.0053secs
iteration [13]; kl:150.569948; elapsed: 0.0052secs
iteration [14]; kl:150.551268; elapsed: 0.0053secs
iteration [15]; kl:150.535806; elapsed: 0.0052secs
iteration [16]; kl:150.523189; elapsed: 0.0053secs
iteration [17]; kl:150.51302; elapsed: 0.0053secs
iteration [18]; kl:150.504902; elapsed: 0.0053secs
iteration [19]; kl:

In [19]:
# Materialize dense Hessians by mapping Hessian-vector products over the rows
# of the identity matrix: row i holds the HVP with the i-th basis vector.

# unpreconditioned Hessian at the optimum
kl_hess = jax.lax.map(
    lambda e_i: precond_objective.hvp(vb_opt, e_i),
    np.eye(len(vb_opt)))

# preconditioned Hessian; x_c is vb_opt passed through the preconditioner
x_c = precond_objective.precondition(vb_opt, vb_opt)
kl_hess_precond = jax.lax.map(
    lambda e_i: precond_objective.hvp_precond(x_c, vb_opt, e_i),
    np.eye(len(vb_opt)))

In [20]:
def evaluate_condition_number(kl_hess):
    """Return the eigenvalue condition number max(eig) / min(eig) of a matrix.

    The eigenvalues are computed with a general (non-symmetric) eigensolver.
    They are checked to be exactly real and strictly positive before the
    ratio is formed. The largest and smallest eigenvalues are printed.

    Parameters
    ----------
    kl_hess : array, shape (d, d)
        Square matrix, e.g. a (preconditioned) Hessian.

    Returns
    -------
    cn_hess : scalar array
        Largest eigenvalue divided by smallest eigenvalue.

    Raises
    ------
    AssertionError
        If any eigenvalue has a nonzero imaginary part, or if any eigenvalue
        is not strictly positive.
    """
    # get eigenvalues (general solver; symmetry is not assumed)
    kl_hess_evals = np.linalg.eigvals(kl_hess)

    # all real
    assert np.all(np.imag(kl_hess_evals) == 0.), \
        'matrix has complex eigenvalues'
    kl_hess_evals = np.real(kl_hess_evals)

    # all strictly positive: the comparison must be inside np.all, otherwise
    # only "all eigenvalues are nonzero" would be checked
    assert np.all(kl_hess_evals > 0), \
        'matrix has non-positive eigenvalues'

    evals_max = kl_hess_evals.max()
    evals_min = kl_hess_evals.min()

    print('Hessian eigenvalues: ')
    print((evals_max, evals_min))

    return evals_max / evals_min

In [21]:
# Condition numbers without and with preconditioning; each call also prints
# the largest and smallest eigenvalues.
print('Hessian eigenvalues: ')
cn_hess = evaluate_condition_number(kl_hess)

print('Precond Hessian eigenvalues: ')
cn_hess_precond = evaluate_condition_number(kl_hess_precond)

# preconditioning should lower the condition number
assert cn_hess > cn_hess_precond

Hessian eigenvalues: 
Hessian eigenvalues: 
(DeviceArray(5.18349375, dtype=float64), DeviceArray(0.06618864, dtype=float64))
Precond Hessian eigenvalues: 
Hessian eigenvalues: 
(DeviceArray(1.16835871, dtype=float64), DeviceArray(0.02412579, dtype=float64))


In [39]:
# Now check the CG solver.
# VB parameters at the optimum, folded from the free vector into a dict.
vb_opt_dict = vb_params_paragami.fold(vb_opt, free = True)


def cg_precond(v):
    # preconditioner for CG: get_mfvb_cov_matmul applied to v at the optimum
    return preconditioner_lib.get_mfvb_cov_matmul(
        v, vb_opt_dict, vb_params_paragami,
        return_sqrt=False, return_info=True)


# Arguments for the sensitivity class.
kwargs = {
    'objective_fun': lambda x, y: 0.,
    'opt_par_value': vb_opt,
    'hyper_par_value0': np.array([0.]),
    'obj_fun_hvp': precond_objective.hvp,
    # a null perturbation ... will be set later
    'hyper_par_objective_fun': lambda x, y: 0.,
    # the scipy CG solver lets us track the solver's progress
    'use_scipy_cgsolve': True,
    # baseline runs without a preconditioner
    'cg_precond': None,
}


In [40]:
# baseline sensitivity object: CG solves run without a preconditioner
vb_sens = HyperparameterSensitivityLinearApproximation(
    **kwargs)

NOTE: using custom hvp
Compiling hvp ...
hvp compile time: 5.43286sec

Compiling cross hessian...
Cross-hessian compile time: 0.00439644sec

LR sensitivity time: 0.00159764sec



In [41]:
# same setup, but CG now uses the cg_precond preconditioner
kwargs['cg_precond'] = cg_precond
vb_sens_precond = HyperparameterSensitivityLinearApproximation(
    **kwargs)

NOTE: using custom hvp
Compiling hvp ...
hvp compile time: 5.38759sec

Compiling preconditioner ...
preconditioner compile time: 1.29937sec

Compiling cross hessian...
Cross-hessian compile time: 0.0043776sec

LR sensitivity time: 0.00134945sec



In [42]:
# random right-hand side for the Hessian solves (fixed seed for reproducibility)
b = jax.random.normal(jax.random.PRNGKey(443), (len(vb_opt),))

In [43]:
# Hessian solve without preconditioning
out1 = vb_sens.hessian_solver(
    b)

Iter [1]; elapsed 0.001sec; diff: 9.618408597646873
Iter [2]; elapsed 0.001sec; diff: 4.4078535876488525
Iter [3]; elapsed 0.001sec; diff: 2.771605525355097
Iter [4]; elapsed 0.001sec; diff: 1.9565825818799134
Iter [5]; elapsed 0.001sec; diff: 1.2979479019174516
Iter [6]; elapsed 0.0sec; diff: 1.3516364601385733
Iter [7]; elapsed 0.001sec; diff: 1.0287958713081684
Iter [8]; elapsed 0.001sec; diff: 0.8546277342587844
Iter [9]; elapsed 0.001sec; diff: 0.8049913666440518
Iter [10]; elapsed 0.001sec; diff: 0.49648247938920964
Iter [11]; elapsed 0.001sec; diff: 0.3521193870599186
Iter [12]; elapsed 0.001sec; diff: 0.2284653207940656
Iter [13]; elapsed 0.001sec; diff: 0.149397026952766
Iter [14]; elapsed 0.001sec; diff: 0.15323158864962438
Iter [15]; elapsed 0.001sec; diff: 0.08236512278335736
Iter [16]; elapsed 0.001sec; diff: 0.06547674414610535
Iter [17]; elapsed 0.001sec; diff: 0.06377983164335647
Iter [18]; elapsed 0.001sec; diff: 0.04085591060768166
Iter [19]; elapsed 0.001sec; diff: 0

In [44]:
# Hessian solve with the preconditioner
out2 = vb_sens_precond.hessian_solver(
    b)

Iter [1]; elapsed 0.002sec; diff: 2.332343027521112
Iter [2]; elapsed 0.001sec; diff: 1.5666595443002211
Iter [3]; elapsed 0.001sec; diff: 1.3444100080155763
Iter [4]; elapsed 0.001sec; diff: 1.130718475172583
Iter [5]; elapsed 0.001sec; diff: 0.7869283138418265
Iter [6]; elapsed 0.001sec; diff: 0.6711221969261818
Iter [7]; elapsed 0.001sec; diff: 0.36922920675220683
Iter [8]; elapsed 0.001sec; diff: 0.20947676025084536
Iter [9]; elapsed 0.001sec; diff: 0.15179745241993142
Iter [10]; elapsed 0.001sec; diff: 0.12207459936224838
Iter [11]; elapsed 0.001sec; diff: 0.13862285141683076
Iter [12]; elapsed 0.001sec; diff: 0.10790235871605819
Iter [13]; elapsed 0.001sec; diff: 0.0576105423722407
Iter [14]; elapsed 0.001sec; diff: 0.035972536780313924
Iter [15]; elapsed 0.001sec; diff: 0.021525549533695957
Iter [16]; elapsed 0.001sec; diff: 0.0159552181215984
Iter [17]; elapsed 0.001sec; diff: 0.004853000009076802
Iter [18]; elapsed 0.001sec; diff: 0.002108486011389165
Iter [19]; elapsed 0.001s

In [47]:
# Both solves should agree to within 1e-3. Assert it so a mismatch halts
# Restart & Run All instead of only displaying False; the final expression
# keeps the displayed result.
max_abs_diff = np.abs(out1 - out2).max()
assert max_abs_diff < 1e-3, \
    'preconditioned and unpreconditioned Hessian solves disagree'
max_abs_diff < 1e-3

DeviceArray(True, dtype=bool)

In [51]:
# the preconditioned CG solve should need fewer iterations than the baseline
assert vb_sens_precond.cg_solver.iter < vb_sens.cg_solver.iter