In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from jax.config import config
config.update("jax_enable_x64", True)

from jax import vmap, grad, random, jit, tree_map
import jax.numpy as jnp
from inference.MP_Inference import lds_inference, lds_inference_sequential, lds_inference_homog, hmm_inference, lds_to_hmm_mf, hmm_to_lds_mf

import pickle
import time

from inference import SLDS_Inference

This notebook assumes `slds.pkl` has been constructed by the previous notebook

In [2]:
with open("saved_models/slds.pkl", 'rb') as f:
    recog_potentials, pgm_potentials, _, _ = pickle.load(f)

all_params = (recog_potentials, *pgm_potentials)

initializer = random.PRNGKey(47)

In this notebook, we will set the convergence threshold used to prematurely end iterative inference to 0, thus guaranteeing we take the specified number of inference steps

In [3]:
SLDS_Inference.CONV_THRESH = 0.
_, cat_expected_stats1 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

Once a function gets compiled by jit, changing a global variable won't change the inner workings of the function. If we want to try running it with different global variable setting, we have to reload (obviously in an ideal world, this parameter would be an argument, but this made less sense for model training than it does here for debugging).

In [4]:
import importlib
importlib.reload(SLDS_Inference)
SLDS_Inference.MAX_ITER = 20
SLDS_Inference.CONV_THRESH = 0.

_, cat_expected_stats2 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

The endpoint of optimization is slightly different after 10 vs 20 steps

In [5]:
jnp.mean(jnp.abs(cat_expected_stats2 - cat_expected_stats1))

Array(0.00154468, dtype=float64)

In [6]:
import importlib
importlib.reload(SLDS_Inference)
SLDS_Inference.MAX_ITER = 50
SLDS_Inference.CONV_THRESH = 0.

_, cat_expected_stats3 = SLDS_Inference.slds_inference_unrolled(*all_params, initializer)

After 20 iterations, it's a lot closer to converged

In [7]:
jnp.mean(jnp.abs(cat_expected_stats3 - cat_expected_stats2))

Array(0.00010522, dtype=float64)

Now let's check to make sure the implicit function gives the same (forward pass) output

In [8]:
gaus_expected_stats, cat_expected_stats = SLDS_Inference.slds_inference_implicit(*all_params, initializer)

In [9]:
jnp.mean(jnp.abs(cat_expected_stats - cat_expected_stats3))

Array(7.3085262e-15, dtype=float64)

Almost exactly the same! 

#### Gradient check

Now, let's start comparing the gradients computed by the two methods. 

This example is so small (a single sequence with length $T=100$ and dimension $d=5$ that unrolled gradients work fine

In [10]:
# start by defining some arbitrary scalar-valued function of the output, so we can backprop
def composition_func(params):
    a,b,_ = SLDS_Inference.slds_inference_unrolled_baseline(*params, initializer)
    
    # computes KL(q(z,k) || p(z,k))
    c = SLDS_Inference.slds_kl(*params, a, b, 0.) 
    return sum(tree_map(lambda x: x.sum(), a)) + c

grad_func_unrolled = (grad(composition_func))
unrolled_grads = grad_func_unrolled(all_params)
print(composition_func(all_params))

2483.4785798474977


In [11]:
def composition_func(params):
    # one of the implicit implementations, inherits the forward pass from slds_inference_implicit
    a,b = SLDS_Inference.slds_inference_itersolve_uncapped(*params, initializer) 
    
    # computes KL(q(z,k) || p(z,k))
    c = SLDS_Inference.slds_kl(*params, a, b, 0.) 
    return sum(tree_map(lambda x: x.sum(), a)) + c

grad_func_implicit = jit(grad(composition_func))
implicit_grads = grad_func_implicit(all_params)
print(composition_func(all_params))

2483.4785986185616


Numerical impressions accrue, but we see below the gradients are very similar!

In [12]:
# for each parameter, the mean absolute different in gradient values
tree_map(lambda x,y: jnp.abs(jnp.mean(x-y)), unrolled_grads, implicit_grads)

((Array(0.00013658, dtype=float32), Array(0.00029361, dtype=float32)),
 (Array(0.00734411, dtype=float64),
  Array(0.0034201, dtype=float64),
  Array(0.00395355, dtype=float64),
  Array(0.00342012, dtype=float64),
  Array(0.00734407, dtype=float64),
  Array(7.05213665e-13, dtype=float64)),
 (Array(3.1836066e-08, dtype=float64), Array(1.60530376e-07, dtype=float64)),
 Array(0., dtype=float64),
 Array(5.48519563e-15, dtype=float64),
 Array(1.68606448e-13, dtype=float64))