In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from jax import jit

So far, we have mostly shown how to use GenJAX to run simulations in parallel. Whether we ran a generative function on several random keys or on several sets of arguments, the examples all had the same flavor: "simply duplicating particles for inference."

Here we show a different kind of example, in which parallelism improves inference in another way: we apply a type of MCMC update to a trace, and the move itself benefits from parallel acceleration thanks to the structure of the generative function whose trace is being updated.

Let's first create a simple HMM and run it.

In [None]:
# Model dimensions: number of time steps, latent/observation dimension,
# and number of independent chains simulated in parallel.
length_chain = 50
state_size = 100
number_runs = 1000
# for numerical stability of the HMM, ensuring that the eigenvalues of the transition matrices are around 1.
# NOTE(review): for an i.i.d. Gaussian matrix with entry std s, the spectral
# radius is roughly s * sqrt(state_size); with this normalizer that is about
# sqrt(e) ~= 1.65 rather than 1 -- confirm this is the intended scaling.
magic_number = jnp.exp(1)
normalizer = 1.0 / jnp.sqrt(state_size / magic_number)
# Derive one independent PRNG key per random quantity. JAX keys must never be
# reused: drawing the initial state with the same key as the transition matrix
# would make the two draws statistically dependent.
transition_key, observation_key, initial_state_key = jax.random.split(
    jax.random.PRNGKey(42), 3
)
transition_matrix = (
    jax.random.normal(transition_key, (state_size, state_size)) * normalizer
)
observation_matrix = (
    jax.random.normal(observation_key, (state_size, state_size)) * normalizer
)
# Identity covariances for the latent transition noise and the observation noise.
latent_variance = jnp.eye(state_size)
obs_variance = jnp.eye(state_size)
initial_state = jax.random.normal(initial_state_key, (state_size,))


@genjax.gen
def hmm_step(x, _):
    """One step of the chain: sample the next latent state, then its observation.

    Args:
        x: current latent state (multiplied by `transition_matrix`).
        _: unused second argument (the caller passes `None`).

    Returns:
        `(new_x, None)`: the sampled next latent state and no per-step output.
    """
    # Latent transition: Gaussian centered on the linearly transformed state.
    latent_mean = transition_matrix @ x
    new_x = genjax.mv_normal(latent_mean, latent_variance) @ "new_x"

    # Observation: Gaussian centered on a linear map of the new latent state.
    obs_mean = observation_matrix @ new_x
    _ = genjax.mv_normal(obs_mean, obs_variance) @ "obs"
    return new_x, None


hmm = hmm_step.scan(n=length_chain)

key = jax.random.PRNGKey(0)
jitted = jit(hmm.repeat(n=number_runs).simulate)
trace = jitted(key, (initial_state, None))
trace.get_choices()
%timeit jitted(key, (initial_state, None))

Let's add observations and run the default importance sampling.

In [None]:
def _observation_at(step):
    """Constrain the "obs" address at time `step` to a ramp scaled by the step index."""
    # Keep the (step * arange) / state_size evaluation order of the expression.
    ramp = step.astype(float) * jnp.arange(state_size) / state_size
    return C[step, "obs"].set(ramp)


def _importance_given_obs(key):
    """Run importance sampling on `hmm` with every observation constrained by `chm`."""
    return hmm.importance(key, chm, (initial_state, None))


# Vectorize over every time step to build a single choice map of observations.
time_steps = jnp.arange(length_chain)
chm = jax.vmap(_observation_at)(time_steps)


jitted = jit(_importance_given_obs)
jitted(key)

In [None]:
def test_parallel_update_logic():
    """Check the indexing and matrix plumbing for updating alternate time steps at once.

    Simulates one chain, applies the placeholder linear map to a single time
    step and then to every time step of a given parity in one batched
    operation, and finally verifies on an all-ones input that the batched form
    agrees with the single-step form.

    Returns:
        A `(length_chain, state_size)` array that is zero on odd time steps and
        holds the map applied to a vector of ones on even time steps.

    Raises:
        AssertionError: if the batched update has the wrong shape or does not
            match the single-step update.
    """
    simple_jitted = jit(lambda key: hmm.simulate(key, (initial_state, None)))
    simple_tr = simple_jitted(key)
    # presumably shape (length_chain, state_size): one latent state per row
    vars_to_update = simple_tr.get_choices()[..., "new_x"]

    magic_c_matrix = jnp.matmul(transition_matrix.T, jnp.linalg.inv(latent_variance))

    # single update: apply the map to one state, treated as a column vector
    vars_to_update = vars_to_update.at[2].set(
        jnp.matmul(magic_c_matrix, vars_to_update[2])
    )

    # multiple updates: every time step of the given parity at once
    parity = 0
    # Strided range instead of a boolean mask: same indices, statically shaped.
    idx_to_update = jnp.arange(parity, length_chain, 2)
    # States are stored as rows, so `rows @ C.T` applies `C @ v` to each row,
    # matching the single update above.
    updating_vals = vars_to_update[idx_to_update] @ magic_c_matrix.T
    vars_to_update = vars_to_update.at[idx_to_update].set(updating_vals)

    # simple test: on an all-ones input every updated row must equal C @ ones
    vars_to_update = jnp.zeros((length_chain, state_size))
    updating_vals = (vars_to_update[idx_to_update] + 1) @ magic_c_matrix.T
    assert updating_vals.shape == (idx_to_update.shape[0], state_size)
    expected_row = jnp.matmul(magic_c_matrix, jnp.ones(state_size))
    # Loose tolerance: matmul precision differs across accelerators.
    assert jnp.allclose(updating_vals, expected_row, rtol=1e-2, atol=1e-2)
    vars_to_update = vars_to_update.at[idx_to_update].set(updating_vals)
    return vars_to_update


test_parallel_update_logic()

In [None]:
def gibbs_update(args, key):
    """Placeholder for one parallel Gibbs move on the time steps of one parity.

    Args:
        args: pair `(trace, parity)`; `parity` (0 or 1) selects which time
            steps are updated in this move.
        key: PRNG key; currently unused by the placeholder update.

    Returns:
        A pair `(trace, next_parity)` where `next_parity` is the flipped parity.
        NOTE(review): `jax.lax.scan` expects its step function to return
        `(carry, output)`; this pair would need wrapping before being scanned.
    """
    trace, parity = args
    vars_to_update = trace.get_choices()[..., "new_x"]

    # TODO: actual gibbs update, and need to use conditional rule from GaussPPL
    magic_c_matrix = jnp.matmul(transition_matrix.T, jnp.linalg.inv(latent_variance))
    # Compute the candidate value for every time step and keep it only where
    # the parity matches. Unlike boolean-mask indexing, this has a static shape,
    # so it still works when `parity` is a traced value under jit / lax.scan.
    # NOTE(review): `(C.T @ V.T).T` applies `C.T` to each state; confirm the
    # intended orientation when the real conditional update is implemented.
    update_mask = jnp.arange(length_chain) % 2 == parity
    candidate_vals = (magic_c_matrix.T @ vars_to_update.T).T
    vars_to_update = jnp.where(update_mask[:, None], candidate_vals, vars_to_update)

    # TODO: need to actually return a proper trace in order to be able to iterate the gibbs update, and need to return the vars_to_update
    # NOTE(review): the trace built below does not contain `vars_to_update`, and
    # its field arguments look like placeholders -- confirm against the GenJAX API.
    return (
        genjax._src.generative_functions.static.StaticTrace(
            trace.get_gen_fn(),
            trace.get_args(),
            trace.get_retval(),
            genjax.Pytree.field(
                default_factory=genjax._src.generative_functions.static.AddressVisitor
            ),
            genjax.Pytree.field(default_factory=list),
            trace.get_score(),
        ),
        (parity + 1) % 2,
    )


# gibbs_update((simple_tr, 0), key)

We can now test inference using the parallel Gibbs update.

In [None]:
# TODO: vmap the update over all chains, and scan it until convergence.

# Number of Gibbs sweeps to run, with one PRNG key derived per sweep.
number_gibbs_sweep = 1000
keys = jax.random.split(key, num=number_gibbs_sweep)
# args = (simple_tr, 0)
# TODO: need to fix a ShapedArray(bool[50]) and the returned trace
# iterated_gibbs = lambda keys: jax.lax.scan(gibbs_update, args, keys)
# jitted_gibbs = jax.jit(iterated_gibbs)

# TODO: a parallel version over all the different initial traces is still needed.

We can check the quality of the inference here, because this is one of the rare cases where exact inference is possible.

In [None]:
# TODO: exact inference using GaussPPL and comparison with Gibbs sampling

Note the time difference between the exact and approximate methods.

In [None]:
# TODO: time the exact inference and the Gibbs sampler, and compare the two.