In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from jax import jit

So far, we have mostly shown how to use GenJAX to run simulations in parallel. Whether we ran a generative function on several random keys or on different arguments, every example had the same flavor: "simply duplicating particles for inference."

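Here we show an example with a very different flavor, in which parallelism improves the inference step itself: we perform an MCMC update on a trace, and the move itself benefits from parallel acceleration thanks to the structure of the generative function being updated. In an HMM, each latent state depends directly only on its neighboring states and its own observation. So, given all odd-indexed states, the even-indexed states are conditionally independent of one another (and vice versa), and a whole parity class can be resampled at once.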

Let's first define a simple linear-Gaussian HMM and simulate from it.

In [None]:
length_chain = 100
state_size = 300
number_runs = 100

# Give every random model parameter its own key. JAX random functions are
# deterministic in the key, so drawing all of them from PRNGKey(0) made
# `transition_matrix` and `observation_matrix` bit-for-bit identical.
model_key = jax.random.PRNGKey(0)
transition_key, observation_key, initial_key = jax.random.split(model_key, 3)

# Scale by 1/sqrt(state_size) so the transition's spectral radius is ~1
# (circular law). With unit-variance entries it is ~sqrt(state_size) ≈ 17, so the
# latent state grows by roughly that factor per step and overflows float32
# after a few dozen of the `length_chain` steps.
transition_matrix = jax.random.normal(
    transition_key, (state_size, state_size)
) / jnp.sqrt(state_size)
observation_matrix = jax.random.normal(observation_key, (state_size, state_size))
latent_variance = jnp.eye(state_size)
obs_variance = jnp.eye(state_size)
initial_state = jax.random.normal(initial_key, (state_size,))


@genjax.gen
def hmm_step(x, _):
    """One step of a linear-Gaussian HMM.

    Samples the next latent state at address "new_x", centred on
    `transition_matrix @ x`. It then samples an observation of that state at
    address "obs", centred on `observation_matrix @ new_x`. Returns
    `(next_state, None)` so the step can be driven by `.scan`, threading the
    latent state through as the carry.
    """
    latent_mean = jnp.matmul(transition_matrix, x)
    next_state = genjax.mv_normal(latent_mean, latent_variance) @ "new_x"
    obs_mean = jnp.matmul(observation_matrix, next_state)
    # Only the traced choice matters here; the sampled observation is not reused.
    genjax.mv_normal(obs_mean, obs_variance) @ "obs"
    return next_state, None


# Unroll `hmm_step` over `length_chain` steps to obtain the full HMM.
hmm = hmm_step.scan(n=length_chain)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# `repeat` batches `number_runs` chains; jit compiles the whole batch at once.
jitted = jit(hmm.repeat(n=number_runs).simulate)
# Consume the fresh `subkey` (as the %timeit line below does); `key` is kept
# only as the parent for later splits.
trace = jitted(subkey, (initial_state, None))
# Author's measurement: ~1.7s to run 100 runs of the HMM of length 100 with a
# state size of 300, on an M2 CPU. The two 300x300 matrix-vector products per
# step amount to ~1.8B multiply-adds in total, yet reportedly take only ~0.2s.
# NOTE(review): the remainder is presumably spent inside `mv_normal` (e.g.
# factorizing the covariance on every call) — unverified.
# %timeit jitted(subkey, (initial_state, None))
trace.get_choices()

Now let's constrain the observations and run GenJAX's default importance sampling.

In [None]:
# Constrain the "obs" address at every step: step `idx` observes the vector
# idx * [0, 1, ..., state_size - 1]. vmapping over the step index builds a
# single vectorized choice map addressed as C[idx, "obs"].
chm = jax.vmap(
    lambda idx: C[idx, "obs"].set(idx.astype(float) * jnp.arange(state_size))
)(jnp.arange(length_chain))


jitted = jit(lambda rng_key: hmm.importance(rng_key, chm, (initial_state, None)))
# Draw a fresh subkey instead of passing `key` itself, which earlier cells also use.
key, importance_key = jax.random.split(key)
jitted(importance_key)

In [None]:
def gibbs_update(trace, parity, rng_key=None):
    """Resample every other latent state of an HMM trace (sketch of a block-Gibbs move).

    All steps of one parity (even or odd index) are replaced in a single
    vectorized write. NOTE(review): this sketch draws standard-normal
    replacements; it does not sample from the true conditional of each state
    given its neighbours and observation, so it is not yet a valid Gibbs move.

    Args:
        trace: trace whose `get_choices()["new_x"]` is an array with the step
            index on the leading axis (assumed — TODO confirm against the
            scan's choice map).
        parity: 0/False to update even steps, 1/True to update odd steps.
        rng_key: PRNG key for the fresh draws. Defaults to the notebook-level
            `key` (the previous behaviour); pass a fresh key per call so
            successive updates do not reuse randomness.

    Returns:
        Whatever `trace.set_choices` returns for the updated "new_x" values.

    Raises:
        ValueError: if `parity` is not 0 or 1.
    """
    if parity not in (0, 1):
        raise ValueError(f"parity must be 0 or 1, got {parity!r}")
    sample_key = key if rng_key is None else rng_key
    current_states = trace.get_choices()["new_x"]
    # Same indices as the mask `arange(length_chain) % 2 == parity`, but built
    # without boolean-mask indexing (which jit cannot trace).
    step_idx = jnp.arange(int(parity), length_chain, 2)
    proposals = jax.random.normal(sample_key, current_states[step_idx].shape)
    # `jax.ops.index_update` was removed from JAX; `.at[...].set` replaces it.
    new_states = current_states.at[step_idx].set(proposals)
    # NOTE(review): GenJAX traces may not provide `set_choices`; a real move
    # should go through the trace update API with a constraint choice map — TODO confirm.
    return trace.set_choices(new_x=new_states)

In [None]:
# Sanity check: which step indices a given parity selects (even steps for False).
parity = False
jnp.flatnonzero(jnp.arange(length_chain) % 2 == parity)