In [None]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr

import genjax
from genjax import normal, pjax, pretty

# Turn on GenJAX's pretty display for values shown in this notebook
# (presumably installs rich reprs for traces/choicemaps — confirm in genjax docs).
pretty()

In [None]:
def prog():
    """Return one draw from `normal` with parameters (0.0, 1.0).

    Takes no arguments, so any vectorized use must give the batch size explicitly.
    """
    return normal.sample(0.0, 1.0)

In [None]:
# Show the jaxpr of `prog` under `pjax.vmap`. `prog` takes no inputs, so the
# batch size (10) must come from `axis_size` rather than from a mapped argument.
make_jaxpr(pjax.vmap(prog, axis_size=10))()

In [None]:
def prog(b):
    """Draw a base normal sample, then a second draw centered on it.

    The predicate `b` picks the second draw's scale parameter via `jax.lax.cond`:
    1.0 when true, 2.0 when false. The second draw is returned.
    """
    def tight(loc):
        return normal.sample(loc, 1.0)

    def loose(loc):
        return normal.sample(loc, 2.0)

    base = normal.sample(0.0, 1.0)
    return jax.lax.cond(b, tight, loose, base)

In [None]:
# Show the jaxpr of the `cond`-based `prog` over 5 lanes. Assuming pjax.vmap
# follows jax.vmap's `in_axes` convention, `None` leaves `b` unbatched, so all
# lanes share the same predicate (False here).
make_jaxpr(pjax.vmap(prog, in_axes=(None,), axis_size=5))(False)

In [None]:
import jax.random as jrand

from genjax import ChoiceMapBuilder as C


@genjax.vmap(in_axes=(0,))
@genjax.gen
def kernel(x):
    """Two-step Gaussian model, vectorized over the leading axis of `x`.

    Traces a latent at address "z" drawn from normal(x, 1.0). Then traces an
    observation at address "y" drawn from normal(latent, 2.0). Returns the latent.
    """
    latent = genjax.normal(x, 1.0) @ "z"
    obs = genjax.normal(latent, 2.0) @ "y"
    return latent


# Inputs for the vectorized `kernel`: three float arguments 0., 1., 2.
map_over = jnp.arange(0, 3, dtype=float)
# "Good" choicemap: vmap a constructor that sets the "z" address directly, so
# (presumably) one choicemap holds all three "z" values as a batched leaf.
chm = pjax.vmap(lambda v: C["z"].set(v))(jnp.array([3.0, 2.0, 3.0]))

# `pjax.seed` makes the call take a PRNG key as its first argument (see call below).
good_jitted = jax.jit(pjax.seed(kernel.importance))
good_jitted(jrand.key(2), chm, (map_over,))

In [None]:
# Inspect the jaxpr of the un-seeded `importance` call. No key argument is
# passed here, unlike the seeded/jitted call above.
make_jaxpr(kernel.importance)(chm, (map_over,))

In [None]:
# "Bad" choicemap for comparison: each vmapped call sets "z" under a traced
# index `idx` (C[idx, "z"]). The good path instead sets the "z" address
# directly and lets vmap batch the value.
bad_chm = pjax.vmap(lambda idx, v: C[idx, "z"].set(v))(
    jnp.arange(3), jnp.array([3.0, 2.0, 3.0])
)
# Use the same generative function and the same `importance` method as
# `good_jitted`. Only the choicemap differs, so the `.lower` timings below
# compare the two choicemap constructions.
bad_jitted = jax.jit(pjax.seed(kernel.importance))
bad_jitted(jrand.key(2), bad_chm, (map_over,))

In [None]:
# PRNG key passed to the lowering benchmarks below. `.lower` traces and lowers
# without executing, so the key's value should not matter for these timings.
key = jrand.key(1)

In [None]:
%%timeit 
# Time tracing + lowering (no execution) of the indexed-choicemap variant.
# NOTE(review): JAX may cache the trace across repeated `.lower` calls with
# identical argument avals — confirm this measures the intended cost.
bad_jitted.lower(key, bad_chm, (map_over,))

In [None]:
%%timeit
# Time tracing + lowering (no execution) of the directly-batched choicemap variant.
# NOTE(review): same trace-caching caveat as the previous timing cell.
good_jitted.lower(key, chm, (map_over,))