In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import gen, normal

# Root PRNG key (fixed seed) used by the simulation calls in this notebook.
key = jax.random.PRNGKey(seed=0)

Let's write a helper that flattens a choice map into an address → value dictionary, so that traces can be inspected more compactly.

In [None]:
class HiddenIndex:
    """Placeholder printed as ``#`` in place of a concrete index in an address.

    All placeholders compare equal to one another (and hash alike). Without
    this, comparing two address lists that agree up to a placeholder (as
    ``Addr.__lt__`` does) finds the placeholders unequal and then tries
    ``HiddenIndex() < HiddenIndex()``, which raises ``TypeError``.
    """

    def __repr__(self):
        return "#"

    def __eq__(self, other):
        # Only placeholders are equal to placeholders; defer for anything else.
        if isinstance(other, HiddenIndex):
            return True
        return NotImplemented

    def __hash__(self):
        # Equal objects must hash equally: use one constant for every instance.
        return hash(HiddenIndex)


class Addr:
    """Printable, orderable wrapper around an address path (a list of components).

    When ``show_indices`` is false, every non-string component of the path is
    swapped for a ``HiddenIndex`` placeholder, so only the string parts of the
    address remain visible. When it is true, the given list is stored as-is.
    """

    def __init__(self, addr, show_indices):
        if show_indices:
            self.addr = addr
        else:
            self.addr = [
                part if isinstance(part, str) else HiddenIndex() for part in addr
            ]

    def __repr__(self):
        return f"<{self.addr}>"

    def __lt__(self, other):
        # Lexicographic ordering of the underlying component lists.
        return self.addr < other.addr


def cm_kv(t, show_indices=False):
    ret = {}

    def cm_kv_inner(t, addr_path=None):
        if addr_path is None:
            addr_path = []
        else:
            addr_path = addr_path.copy()
        match type(t):
            case (
                genjax._src.core.generative.choice_map.XorChm
                | genjax._src.core.generative.choice_map.OrChm
            ):
                cm_kv_inner(t.c1, addr_path)
                cm_kv_inner(t.c2, addr_path)
            case genjax._src.core.generative.choice_map.StaticChm:
                addr_path.append(t.addr)
                cm_kv_inner(t.c, addr_path)
            case genjax._src.core.generative.choice_map.IdxChm:
                addr_path.append(t.addr)
                cm_kv_inner(t.c, addr_path)
            case genjax._src.core.generative.choice_map.ValueChm:
                ret[Addr(addr_path, show_indices)] = t.v
            case _:
                raise NotImplementedError(str(type(t)))

    cm_kv_inner(t)
    return ret

One classic trick is to encode all the options as an array and pick the desired value from it with a dynamic (traced) index.

Here's a first example:

In [None]:
@gen
def model(i, means, stds):
    """Sample ``x ~ normal(means[i], stds[i])`` for a dynamically chosen ``i``.

    Every candidate parameter is passed in as an array and the (possibly
    traced) index ``i`` selects among them, so the choice needs no
    Python-level branching. ``stds`` (formerly ``vars``, which shadowed the
    builtin) is passed as the second argument of ``normal``, which is its
    scale parameter.
    """
    x = normal(means[i], stds[i]) @ "x"
    return x


# Keep the index in range and the parameter arrays the same length: JAX does
# not raise on out-of-bounds indexing, so a bad index would silently return a
# wrong (in-bounds or fill) value instead of failing.
model.simulate(key, (3, jnp.arange(10, dtype=jnp.float32), jnp.ones(10)))

Now, what if there is a value we may or may not want to produce, depending on a dynamic (runtime) value?

In this case, we can use masking. Let's look at an example in JAX.

In [None]:
# A 3x3 matrix holding 0..8 in row-major order.
non_masked = jnp.reshape(jnp.arange(9), (3, 3))

non_masked

In [None]:
# Row/column indices of the upper triangle of a 3x3 matrix (k=0 keeps the
# main diagonal).
mask = jnp.mask_indices(3, jnp.triu, k=0)

# Indexing with those index arrays keeps only the upper-triangular entries,
# returned as a flat array in row-major order.
non_masked[mask]

We can use similar logic for generative functions in GenJAX. 

Let's create an HMM using the scan combinator.

In [None]:
N = 300  # dimensionality of the latent state
n_repeats = 100  # how many times the whole HMM is repeated
n_steps = 100  # number of time steps unrolled by the scan combinator
# Second argument to `genjax.mv_normal` for every transition (identity matrix);
# presumably its covariance matrix — confirm against the genjax API.
variance = jnp.eye(N)
# Seed differs from `key` (PRNGKey(0)) so the initial state and the later
# simulation do not consume the same PRNG key.
initial_state = jax.random.normal(jax.random.PRNGKey(1), (N,))


@gen
def hmm_step(x, _):
    """One HMM transition: draw ``new_x`` from ``mv_normal(x, variance)``.

    Returns ``(carry, output)`` as the scan combinator expects: the new state
    is carried to the next step and no per-step output is emitted.
    """
    new_x = genjax.mv_normal(x, variance) @ "new_x"
    return new_x, None


hmm = hmm_step.scan(n=n_steps)

When we simulate it, we get a full trace containing the choices made at every step.

In [None]:
# JIT-compile `simulate` of the HMM repeated `n_repeats` times, run it once,
# then flatten the sampled choices for display (indices shown as `#`).
jitted = jax.jit(hmm.repeat(n=n_repeats).simulate)
trace = jitted(key, (initial_state, None))
cm_kv(trace.get_choices(), show_indices=False)

To get partial results instead, we can use masking as follows: