In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import jax
jax.config.update("jax_enable_x64", True)

from jax import vmap, grad, random, jit, tree_map
import jax.numpy as jnp
from inference.MP_Inference import lds_inference, lds_inference_sequential, lds_inference_homog, hmm_inference, lds_to_hmm_mf, hmm_to_lds_mf

import pickle
import time

from inference import SLDS_Inference

This notebook assumes `slds.pkl` has been constructed by the previous notebook

In [2]:
with open("saved_models/slds.pkl", 'rb') as f:
    recog_potentials, pgm_potentials, _, _ = pickle.load(f)

all_params = (recog_potentials, *pgm_potentials)

initializer = random.PRNGKey(47)

In this notebook, we will set the convergence threshold used to prematurely end iterative inference to 0, thus guaranteeing we take the specified number of inference steps

In [3]:
SLDS_Inference.CONV_THRESH = 0.
_, cat_expected_stats1 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

Once a function gets compiled by jit, changing a global variable won't change the inner workings of the function. If we want to try running it with different global variable setting, we have to reload (ideally this parameter would be an argument, but this made less sense for model training than it does here for debugging).

In [4]:
import importlib
importlib.reload(SLDS_Inference)
SLDS_Inference.MAX_ITER = 20
SLDS_Inference.CONV_THRESH = 0.

_, cat_expected_stats2 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

The endpoint of optimization is slightly different after 10 vs 20 steps

In [5]:
jnp.mean(jnp.abs(cat_expected_stats2 - cat_expected_stats1))

Array(0.00150787, dtype=float64)

We can check how converged these two values are by how much they change after an additional block update:

In [6]:
def forward_iter_block(cat_expected_stats):
    recog_potentials, E_mniw_params, init, E_init_normalizer, E_init_lps, E_trans_lps = all_params
    
    gaus_natparam, E_prior_logZ = hmm_to_lds_mf(cat_expected_stats, E_mniw_params, E_init_normalizer)
    gaus_expected_stats, gaus_logZ, _ = lds_inference(recog_potentials, init, gaus_natparam)

    cat_natparam = lds_to_hmm_mf(gaus_expected_stats, E_mniw_params)
    cat_es, hmm_logZ, _ = hmm_inference(E_init_lps, E_trans_lps, cat_natparam)
    return cat_es

jnp.mean(jnp.abs(forward_iter_block(cat_expected_stats1) - cat_expected_stats1))

Array(0.00027077, dtype=float64)

In [7]:
jnp.mean(jnp.abs(forward_iter_block(cat_expected_stats2) - cat_expected_stats2))

Array(9.92897613e-05, dtype=float64)

In [8]:
import importlib
importlib.reload(SLDS_Inference)
SLDS_Inference.MAX_ITER = 100
SLDS_Inference.CONV_THRESH = 0.

_, cat_expected_stats3 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

jnp.mean(jnp.abs(forward_iter_block(cat_expected_stats3) - cat_expected_stats3))

Array(8.96105988e-08, dtype=float64)

Due to numerical imprecisions, it may be impossible to perfectly reach a fixed point of this block update function. However, we can get extremely close.

Now let's check to make sure the implicit function gives the same (forward pass) output

In [9]:
gaus_expected_stats, cat_expected_stats = SLDS_Inference.slds_inference_implicit(*all_params, initializer)

In [10]:
assert jnp.isclose(cat_expected_stats, cat_expected_stats3).all()

#### Gradient check

Now, let's start comparing the gradients computed by the two methods. 

This example is so small (a single sequence with length $T=100$ and dimension $d=5$ that unrolled gradients work fine

In [11]:
# start by defining some arbitrary scalar-valued function of the output, so we can backprop
def composition_func(params):
    a,b,_ = SLDS_Inference.slds_inference_unrolled_baseline(*params, initializer)
    
    # computes KL(q(z,k) || p(z,k))
    c = SLDS_Inference.slds_kl(*params, a, b, 0.) 
    return sum(tree_map(lambda x: x.sum(), a)) + c

grad_func_unrolled = (grad(composition_func))
unrolled_grads = grad_func_unrolled(all_params)
print(composition_func(all_params))

1721.819184951477


In [12]:
def composition_func(params):
    # one of the implicit implementations, inherits the forward pass from slds_inference_implicit
    a,b = SLDS_Inference.slds_inference_cgsolve(*params, initializer) 
    
    # computes KL(q(z,k) || p(z,k))
    c = SLDS_Inference.slds_kl(*params, a, b, 0.) 
    return sum(tree_map(lambda x: x.sum(), a)) + c

grad_func_implicit = jit(grad(composition_func))
implicit_grads = grad_func_implicit(all_params)
print(composition_func(all_params))

1721.8191849515183


Numerical impressions accrue, but we see below the gradients are very similar!

In [13]:
# for each parameter, the mean absolute different in gradient values
tree_map(lambda x,y: jnp.abs(jnp.mean(x-y)), unrolled_grads, implicit_grads)

((Array(0.00029938, dtype=float32), Array(0.00022077, dtype=float32)),
 (Array(0.00552423, dtype=float64),
  Array(0.00748995, dtype=float64),
  Array(0.00061389, dtype=float64),
  Array(0.00748981, dtype=float64),
  Array(0.00552422, dtype=float64),
  Array(7.2011147e-13, dtype=float64)),
 (Array(4.28741647e-08, dtype=float64), Array(3.31068861e-07, dtype=float64)),
 Array(0., dtype=float64),
 Array(7.99967731e-15, dtype=float64),
 Array(1.78052018e-13, dtype=float64))