# In [3]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import genjax
from genjax.incremental import diff, UnknownChange, NoChange
from genjax.inference.translator import extending_trace_translator
from genjax.inference import smc

# Model
@genjax.gen(genjax.Unfold, max_length=10)
@genjax.gen(genjax.Static)
def chain(z_prev):
    """Single transition of a linear-Gaussian state-space model.

    Samples the next latent ``z ~ Normal(z_prev, 1.0)`` at address ``"z"``
    and an emission ``x ~ Normal(z, 1.0)`` at address ``"x"``, then returns
    ``z``. The ``Unfold`` wrapper (``max_length=10``) repeats this kernel,
    presumably feeding each returned ``z`` into the next step as ``z_prev``.
    """
    z_next = genjax.normal(z_prev, 1.0) @ "z"
    # Only the traced "x" address matters here; the sampled value is unused.
    _x = genjax.normal(z_next, 1.0) @ "x"
    return z_next

# Proposal and step-by-step translator.
# This can be freely used inside scan'd code.
def get_translator(t, obs):
    """Construct the translator that extends each SMC particle to step ``t``.

    The returned translator (built by ``extending_trace_translator``) updates
    the model's first argument to ``t`` (marked ``UnknownChange``) while
    keeping the second argument at ``0.0`` (marked ``NoChange``). It proposes
    the new latent ``"z"`` from a Gaussian centred on the observed ``"x"`` at
    index ``t``. ``extending_smc`` calls this with a traced ``t`` inside
    ``jax.lax.scan``.

    Args:
        t: index of the chain step being added.
        obs: observation choice map handed to the translator (presumably
            passed on to the proposal as ``obs_chm``).

    Returns:
        The translator object returned by ``extending_trace_translator``.
    """

    @genjax.gen(genjax.Static)
    def proposal(obs_chm, _):
        # Look up x at index t in the indexed observation map; unmask() is
        # expected to raise (under checkify) if that entry is not valid.
        observed_x = obs_chm[t, "x"].unmask()
        return genjax.normal(observed_x, 1.0) @ "z"

    def to_model_choices(proposal_choices):
        # Forward map: place the proposal's choices at index t of the model.
        return genjax.indexed_choice_map(t, proposal_choices)

    def from_model_choices(model_choices):
        # Inverse map: read back the (unmasked) choices stored at index t.
        return model_choices[t].unmask()

    # Argdiffs for the model update: first argument becomes t, second held.
    # NOTE(review): 0.0 presumably has to equal the init_state SMC was
    # initialised with — confirm before running with a different value.
    argdiffs = (diff(t, UnknownChange), diff(0.0, NoChange))

    return extending_trace_translator(
        argdiffs,
        proposal,
        (),
        obs,
        to_model_choices,
        from_model_choices,
        # Check the forward/inverse pair at runtime; a failure is
        # registered with `jax.checkify` rather than raised immediately.
        check_bijection=True,
    )

# Obs.
# Observed x values 1.0, 2.0, 3.0, 4.0 stored at indices 0-3.
obs = genjax.indexed_choice_map(
    list(range(4)),
    genjax.choice_map(dict(x=jnp.array([1.0, 2.0, 3.0, 4.0]))),
)

# SMC with custom proposal.
def extending_smc(key, obs, init_state, *, n_particles=5):
    """Run SMC over the unfolded ``chain`` model, one observation per step.

    The particle collection is initialised on the first observation slice.
    It is then extended one step at a time inside ``jax.lax.scan``, using the
    custom proposal/translator from ``get_translator`` for each new step.

    Args:
        key: JAX PRNG key.
        obs: indexed observation choice map. It is assumed to carry a leading
            step axis, since it is sliced with ``obs.slice(0)`` and
            ``v[1:]`` — TODO confirm for other choice map layouts.
        init_state: initial latent state passed to ``chain`` at
            initialisation.
        n_particles: number of SMC particles. Defaults to 5, the value that
            was previously hard-coded. Keyword-only so it stays a static
            Python int under ``jax.jit``.

    Returns:
        ``(final_state, stacked)``: the SMC state after the last step, and
        the per-step SMC states stacked by ``jax.lax.scan``.
    """
    key, init_key = jax.random.split(key)
    smc_state = smc.smc_initialize(chain, n_particles).apply(
        init_key, obs.slice(0), (0, init_state)
    )
    # Drop the slice consumed by initialisation; scan walks the remainder.
    remaining_obs = jtu.tree_map(lambda v: v[1:], obs)

    def _inner(carry, obs_slice):
        step_key, state, t = carry
        t = t + 1
        # NOTE(review): get_translator hard-codes the initial-state argdiff
        # as diff(0.0, NoChange); this is presumably only consistent with
        # the initialised traces when init_state == 0.0 — confirm.
        translator = get_translator(t, obs_slice)
        step_key, update_key = jax.random.split(step_key)
        state = smc.smc_update(translator).apply(update_key, state)
        return (step_key, state, t), (state,)

    (_, final_state, _), (stacked,) = jax.lax.scan(
        _inner,
        (key, smc_state, 0),
        remaining_obs,
    )
    return final_state, stacked

# In [None]:
from jax.experimental.checkify import checkify

key = jax.random.PRNGKey(314159)

# Run inside the GenJAX console with enforce_checkify=True so GenJAX's
# runtime checks (e.g. masked unmasking, the translator bijection check)
# are active. jit(checkify(...)) returns any failed check in `err`, and
# err.throw() raises it.
with genjax.console(enforce_checkify=True):
    run_smc = jax.jit(checkify(extending_smc))
    err, (final_state, stacked) = run_smc(key, obs, 0.0)
    err.throw()