In [None]:
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import gen, normal

# Root PRNG key for the notebook. A fixed seed keeps the random draws in the
# cells below reproducible across kernel restarts.
key = jax.random.PRNGKey(seed=0)

Let's first create a simple GenJAX model and some observations, then run inference with the default importance sampling and see what we get.

In [None]:
length = 100


@gen
def model(x, _):
    """One step of a Gaussian random walk with a noisy observation.

    Used as a `scan` kernel: `x` is the state carried over from the previous
    step and the second argument is the (unused) per-step input. The new state
    is sampled at address "x" around the previous one, and an observation is
    sampled at address "y" around the new state. Returns the new state as the
    next carry, with no per-step output.
    """
    x_next = normal(x, 1.0) @ "x"
    # Only the traced choice at "y" matters; its value is not used here.
    _ = normal(x_next, 1.0) @ "y"
    return x_next, None


# Unroll the step `length` times; args are (initial carry, per-step inputs).
scanned = model.scan(n=length)
scanned.simulate(key, (0.0, None))

In [None]:
# Constrain the observation "y" at step i to the value i, for i = 0, ..., length - 1.
obs = jax.vmap(lambda idx: C[idx, "y"].set(idx.astype(float)))(jnp.arange(length))

# Derive a fresh key instead of passing `key` again: `scanned.simulate` above
# already consumed `key`, and reusing a JAX PRNG key replays the same random
# bits, so the latent draws here could simply mirror the simulated ones.
scanned.importance(jax.random.fold_in(key, 1), obs, (0.0, None))

Instead of doing it in one go, we split the problem in two: we first sample particles for the first half of the time steps and check how good they are.

In [None]:
num_particles = 1000
half_length = length // 2

# `num_particles` independent copies of the first `half_length` steps.
half_scanned = model.scan(n=half_length).repeat(n=num_particles)
# Same observations as the full problem (step i observes y = i), restricted
# to the first `half_length` steps.
half_obs = jax.vmap(lambda idx: C[idx, "y"].set(idx.astype(float)))(
    jnp.arange(half_length)
)
# NOTE(review): confirm how `repeat` addresses each copy's choices. If it nests
# them under a leading particle index, `half_obs` needs that extra index to
# constrain every particle. Also confirm whether `w` is per-particle or summed
# before resampling with it.
# Fresh key derived from `key` so these particles do not reuse the random bits
# already consumed by the cells above.
trs, w = half_scanned.importance(jax.random.fold_in(key, 2), half_obs, (0.0, None))
trs.get_choices()

We can now resample the particles and continue the simulation: