In [None]:
import jax
import jax.numpy as jnp
from jax import jit

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import pretty
from genjax._src.generative_functions.combinators.scan import ScanTrace
from genjax._src.generative_functions.static import StaticTrace

# Presumably enables GenJAX's rich pretty-printing for notebook output — confirm against the GenJAX docs.
pretty()
# Root PRNG key; later cells derive their randomness from it via `jax.random.split`.
key = jax.random.PRNGKey(0)

This notebook assumes you are familiar with the Gibbs update rule. See the dedicated cookbook entry.

So far, we have mostly shown how to use GenJAX to run simulations in parallel. Whether we ran a generative function on several random keys or on different arguments, the examples all had a similar flavor of "simply duplicating particles for inference."

Here we show a different kind of example, in which parallelism improves inference in a very different way: we perform an MCMC update on a trace where the move itself benefits from parallel acceleration, thanks to the structure of the generative function being updated.

Let's first create a simple HMM and run it.

In [None]:
# Problem sizes: number of time steps in the chain, dimension of the latent and
# observed vectors, and number of independent runs/chains used in later cells.
length_chain = 50
state_size = 100
number_runs = 1000
# for numerical stability of the HMM, ensuring that the eigenvalues of the transition matrices are around 1.
magic_number = jnp.exp(1)
# Scale applied to the standard-normal matrix entries: sqrt(magic_number / state_size).
normalizer = 1.0 / jnp.sqrt(state_size / magic_number)
# Each random quantity below is drawn with its own subkey split off the running `key`.
key, subkey = jax.random.split(key)
transition_matrix = jax.random.normal(subkey, (state_size, state_size)) * normalizer

key, subkey = jax.random.split(key)
observation_matrix = jax.random.normal(subkey, (state_size, state_size)) * normalizer
# Identity covariances for the latent transition noise and the observation noise.
latent_variance = jnp.eye(state_size)
obs_variance = jnp.eye(state_size)
key, subkey = jax.random.split(key)
# NOTE(review): `initial_state` is not referenced by the model definitions visible
# here (`hmm` samples its own initial state) — confirm whether it is still needed.
initial_state = jax.random.normal(subkey, (state_size,))


@genjax.gen
def initial_state_model():
    """Sample the first latent state of the chain.

    Returns:
        A `state_size`-dimensional draw from a zero-mean, identity-covariance
        multivariate normal, traced at address "initial_state".
    """
    prior_mean = jnp.zeros(state_size, dtype=float)
    prior_cov = jnp.identity(state_size, dtype=float)
    first_state = genjax.mv_normal(prior_mean, prior_cov) @ "initial_state"
    return first_state


@genjax.gen
def hmm_step(x, _):
    """One time step of the linear-Gaussian HMM, written as a scan body.

    Args:
        x: Previous latent state (the scan carry).
        _: Unused per-step scanned input.

    Returns:
        `(new_x, None)`: the new latent state as the next carry, and no per-step
        output.
    """
    # Latent transition: Gaussian centred on the linearly propagated previous state.
    latent_mean = jnp.matmul(transition_matrix, x)
    new_x = genjax.mv_normal(latent_mean, latent_variance) @ "new_x"

    # Emission: Gaussian centred on a linear map of the new latent state. The value
    # is only recorded in the trace under "obs"; it is not returned.
    obs_mean = jnp.matmul(observation_matrix, new_x)
    _ = genjax.mv_normal(obs_mean, obs_variance) @ "obs"

    return new_x, None


@genjax.gen
def hmm():
    """Full HMM: an initial state followed by `length_chain` scanned steps.

    The initial state is traced under "init" and the scanned steps under "steps".
    The function has no explicit return value.
    """
    start_state = initial_state_model() @ "init"
    chain = hmm_step.scan(n=length_chain)
    _ = chain(start_state, None) @ "steps"


# Testing that the model runs
# The `repeat(n=number_runs)` variant of the model is jitted and simulated with a
# fresh subkey and no model arguments (`()`).
jitted = jit(hmm.repeat(n=number_runs).simulate)
key, subkey = jax.random.split(key)
trace = jitted(subkey, ())
# Last expression of the cell, so the sampled choices are displayed as its output.
trace.get_sample()

Let's add observations and run the default importance sampling.

In [None]:
# Build the observation constraints: for every time step `idx`, constrain the
# "obs" choice under "steps" to the vector idx * [0, 1, ..., state_size - 1] / state_size.
chm = jax.vmap(
    lambda idx: C["steps", idx, "obs"].set(
        idx.astype(float) * jnp.arange(state_size) / state_size
    )
)(jnp.arange(length_chain))


# Importance sampling of the model (no arguments) against the constraints above.
jitted = jit(lambda key: hmm.importance(key, chm, ()))
key, subkey = jax.random.split(key)
# Consume the freshly split `subkey`. Passing `key` here would reuse the carried
# key, which is split again by later cells and would correlate their randomness.
jitted(subkey)

In [None]:
def test_parallel_update_logic(key):
    """Exercise the array-indexing logic used by the parallel update.

    Simulates one trace of `hmm`, applies the placeholder linear update to a single
    time step and then to every other time step, and finally repeats the strided
    update on an all-zeros array as a simple sanity check.

    Args:
        key: JAX PRNG key used to simulate the trace.

    Returns:
        The `(length_chain, state_size)` array from the final sanity check: zero on
        the rows that were not selected, and the placeholder update applied to a
        vector of ones on the selected rows.
    """
    simpled_jitted = jit(lambda key: hmm.simulate(key, ()))
    key, subkey = jax.random.split(key)
    simple_tr = simpled_jitted(subkey)
    vars_to_update = simple_tr.get_choices()["steps", ..., "new_x"]

    magic_c_matrix = jnp.matmul(transition_matrix.T, jnp.linalg.inv(latent_variance))

    # single update
    vars_to_update = vars_to_update.at[2].set(
        jnp.matmul(magic_c_matrix, vars_to_update[2])
    )

    # multiple updates
    parity = 0
    # A strided arange selects every other time step. The equivalent boolean-mask
    # selection was dead code (immediately overwritten) and has been removed.
    idx_to_update = jnp.arange(start=parity, step=2, stop=length_chain)
    # TODO: maybe something like that. also maybe faster/simpler with vmap?
    updating_vals = (magic_c_matrix.T @ vars_to_update[idx_to_update].T).T
    vars_to_update = vars_to_update.at[idx_to_update].set(updating_vals)

    # simple test: start from zeros so the updated rows are easy to recognise
    vars_to_update = jnp.zeros((length_chain, state_size))
    updating_vals = (magic_c_matrix.T @ (vars_to_update[idx_to_update] + 1).T).T
    print(updating_vals.shape)
    vars_to_update = vars_to_update.at[idx_to_update].set(updating_vals)
    return vars_to_update


# Run the indexing check once with a fresh subkey; the returned array is the cell's output.
key, subkey = jax.random.split(key)
test_parallel_update_logic(subkey)

In [None]:
def gibbs_update(key, args):
    """Placeholder for one parity-restricted Gibbs half-sweep over the latent chain.

    Args:
        key: PRNG key. Currently unused, because the update below is still a
            deterministic stand-in for the real conditional sampling step.
        args: Pair `(trace, parity)`. `parity` (0 or 1) selects which time steps
            (even or odd) are updated in this half-sweep.

    Returns:
        `((new_trace, next_parity), None)`, where `next_parity` is `parity` flipped.
        The rebuilt trace reuses the original subtraces, so the updated values are
        not written back yet (see the TODOs below).
    """
    trace, parity = args
    vars_to_update = trace.get_choices()["steps", ..., "new_x"]
    # Select every other time step starting at `parity`. The index array has a
    # static shape, so this also works when `parity` is a traced value (e.g. inside
    # `jax.lax.scan`); a boolean mask would give a data-dependent shape.
    idx_to_update = parity + jnp.arange(start=0, step=2, stop=length_chain)

    # TODO: actual gibbs update, and need to use conditional rule from GaussPPL
    magic_c_matrix = jnp.matmul(transition_matrix.T, jnp.linalg.inv(latent_variance))
    updating_vals = (magic_c_matrix.T @ vars_to_update[idx_to_update].T).T
    # mode="drop" discards the single out-of-range index that arises when
    # `length_chain` is odd and `parity` is 1.
    vars_to_update = vars_to_update.at[idx_to_update].set(updating_vals, mode="drop")

    # TODO: need to actually return a proper trace in order to be able to iterate the gibbs update, and need to return the vars_to_update
    # NOTE(review): this assumes `trace` exposes `.inner` with a StaticTrace-like
    # layout (`subtraces[0]`, `subtraces[1]`) — confirm against the trace type that
    # `hmm.simulate` actually returns.
    updated_one = trace.inner.subtraces[1]  # TODO:
    inner = StaticTrace(
        trace.inner.gen_fn,
        trace.inner.args,
        trace.inner.retval,
        trace.inner.addresses,
        [trace.inner.subtraces[0], updated_one],
        trace.inner.score,
    )
    return (
        ScanTrace(trace.get_gen_fn(), inner, trace.args, trace.retval, trace.score),
        (parity + 1) % 2,
    ), None

We can now test inference using the parallel Gibbs update.

In [None]:
key, subkey = jax.random.split(key)
simple_tr = hmm.simulate(subkey, ())

number_gibbs_sweep = 1000
# Derive the per-sweep keys from a dedicated subkey, so that the carried `key` is
# never consumed directly and later splits stay independent.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, number_gibbs_sweep)
args = (simple_tr, 0)
# `jax.lax.scan(f, init, xs)` calls `f(carry, x)`. Here the carry is the
# `(trace, parity)` pair and the scanned inputs are the per-sweep keys.
# `gibbs_update` takes `(key, carry)`, so a small adapter swaps the arguments.
jitted_gibbs = jax.jit(
    lambda keys: jax.lax.scan(lambda carry, k: gibbs_update(k, carry), args, keys)
)
jitted_gibbs(keys)

We can also use `vmap` to launch different MCMC chains in parallel.

In [None]:
# Use a dedicated subkey for each batch of draws, so the carried `key` is never
# split twice for different purposes.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, number_runs)
# NOTE(review): these initial states are not consumed below, because `hmm` takes no
# arguments and samples its own initial state. They are kept for when the model is
# parameterised by its starting state.
initial_states = jax.vmap(lambda key: jax.random.normal(key, (state_size,)))(keys)

key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, number_runs)
# One simulated trace per key. `hmm` is a zero-argument model, so the argument
# tuple is empty (passing `(initial_state, None)` would not match its signature).
initial_traces = jax.jit(jax.vmap(lambda key: hmm.simulate(key, ())))(keys)

In [None]:
key, subkey = jax.random.split(key)
# Consume the fresh subkey rather than the carried `key`.
keys = jax.random.split(subkey, number_gibbs_sweep)
# `jax.lax.scan` calls its body as `f(carry, x)` with carry `(trace, parity)` and
# `x` a per-sweep key. `gibbs_update` takes `(key, carry)`, hence the adapter.
jitted_gibbs = jax.jit(
    lambda tr, keys: jax.lax.scan(
        lambda carry, k: gibbs_update(k, carry), (tr, 0), keys
    )
)
# TODO: need to feed the traces and keys with the right format
# traces = jitted_gibbs(initial_traces, keys)

We can check the quality of the inference here, since this is one of the rare cases where exact inference is possible.

In [None]:
# TODO: exact inference using GaussPPL and comparison with Gibbs sampling

Note the time difference between the exact and approximate method.

In [None]:
# TODO: time the exact method and the Gibbs-based approximation, and compare them