In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import Mask, bernoulli, categorical, gen, normal, or_else, pretty
from genjax._src.core.generative.choice_map import MaskChm

# Turn on genjax's pretty-printing for displayed values (presumably changes how
# traces / choice maps render in the notebook output).
pretty()
# Root PRNG key used by the examples below.
key = jax.random.PRNGKey(seed=0)

One classic trick is to encode all the options as an array and pick the desired value from it with a dynamic index.

Here's a first example:

In [None]:
@gen
def model(i, means, vars):
    """Sample "x" from a normal whose parameters are picked by a dynamic index.

    All candidate parameters are passed in as arrays, and the (possibly traced)
    index `i` selects one entry from each, avoiding Python-level branching.

    Args:
        i: index into `means` and `vars`.
        means: array of candidate means.
        vars: array of candidate variances. genjax's `normal` is parameterized
            by a standard deviation (scale), so the square root is taken here;
            passing variances directly is only correct when they equal 1.

    Returns:
        The sampled value at address "x".
    """
    x = normal(means[i], jnp.sqrt(vars[i])) @ "x"
    return x


model.simulate(key, (7, jnp.arange(10, dtype=jnp.float32), jnp.ones(10)))

Now, what if there's a value we may or may not want to get depending on a dynamic value?

In this case, we can use masking. Let's look at an example in JAX.

In [None]:
# A 3x3 matrix holding 0..8 in row-major order, used to illustrate masking.
non_masked = jnp.reshape(jnp.arange(9), (3, 3))

non_masked

In [None]:
# Index arrays (rows, cols) of the upper-triangular entries of a 3x3 matrix,
# main diagonal included; indexing with them keeps exactly those entries.
mask = jnp.mask_indices(3, jnp.triu, 0)
non_masked[mask]

We can use similar logic for generative functions in GenJAX. 

Let's create an HMM using the scan combinator.

In [None]:
state_size = 10
length = 10
# Second argument of each mv_normal transition below (presumably its covariance
# matrix): the identity, i.e. independent unit-scale noise per coordinate.
variance = jnp.eye(state_size)
# Use a key different from jax.random.PRNGKey(0): that key is also the
# notebook's global `key`, which is later used to simulate the HMM, and JAX
# keys should not be reused for draws meant to be independent.
initial_state = jax.random.normal(jax.random.PRNGKey(1), (state_size,))


@genjax.gen
def hmm_step(x, _):
    """One HMM transition: sample the next state around the current state `x`.

    Follows the scan step shape `(carry, input) -> (new_carry, output)`; the
    per-step input is ignored and no per-step output is produced (`None`).
    """
    new_x = genjax.mv_normal(x, variance) @ "new_x"
    return new_x, None


# Unroll `hmm_step` for `length` steps, threading the state through as carry.
hmm = hmm_step.scan(n=length)

When we run it, we get a full trace. 

In [None]:
# Compile `simulate` once with JIT, then draw a full (unmasked) HMM trace.
hmm_args = (initial_state, None)
jitted = jax.jit(hmm.simulate)
trace = jitted(key, hmm_args)
trace.get_choices()

Let us now define a masked scan combinator, so that we can stop the scan early and return only partial results.

In [None]:
# TODO: this should be part of the standard library, hopefully soon.
def masked_scan_combinator(step, **scan_kwargs):
    """Build a scan over `step` whose iterations can be switched off by a mask.

    Args:
        step: a generative function such that `step.scan(n=N)` is valid, i.e.
            it maps `(state, input_value)` to `(new_state, output)`.
        **scan_kwargs: forwarded unchanged to `.scan(...)` (e.g. `n=...`).

    Returns:
        A generative function accepting `(initial_state, masked_input_values)`,
        where `masked_input_values` has `.flag` (per-step booleans) and
        `.value` (per-step inputs), and returning the pair
        `(masked_final_state, masked_returnvalue_sequence)`.

    This operates like `step.scan`, except that the inputs can be masked. A
    step's flag is the AND of the carried state flag and that step's input
    flag. Both halves of the step's result are re-wrapped with the flag
    reported by `step.mask()`, so the state's flag is carried forward.
    Assuming `.mask()` reports the flag it was given, once a step is masked,
    every later step is masked too.
    """

    def _unpack_inputs(masked_state, masked_inval):
        # Combine both flags and give `step.mask()` its (flag, state, input).
        active = jnp.logical_and(masked_state.flag, masked_inval.flag)
        return active, masked_state.value, masked_inval.value

    def _split_output(args, masked_retval):
        # `step` returns (new_state, output); give each half its own Mask.
        flag = masked_retval.flag
        new_state = masked_retval.value[0]
        output = masked_retval.value[1]
        return Mask(flag, new_state), Mask(flag, output)

    mstep = step.mask().dimap(pre=_unpack_inputs, post=_split_output)

    # `scanned` takes (Mask(True, initial_state), Mask(active_flags, input_vals))
    # and returns (masked_final_state, masked_returnvalue_sequence).
    scanned = mstep.scan(**scan_kwargs)

    def _wrap_inputs(initial_state, masked_input_values):
        # The initial state is always active; re-wrap the inputs as a Mask.
        wrapped_state = Mask(True, initial_state)
        wrapped_inputs = Mask(masked_input_values.flag, masked_input_values.value)
        return wrapped_state, wrapped_inputs

    def _pass_through(args, retval):
        return retval

    return scanned.dimap(pre=_wrap_inputs, post=_pass_through)

To get only partial results from the HMM instead, we can use masking as follows:

In [None]:
stop_at_index = 5
# One flag per scan step (the scan runs `length` steps), not per state
# dimension: steps 0..stop_at_index-1 are active, the remaining ones are masked.
mask = Mask(jnp.arange(length) < stop_at_index, None)
masked_hmm = masked_scan_combinator(hmm_step, n=length)
choices = masked_hmm.simulate(key, (initial_state, mask)).get_choices()
choices

Let's now use it in a bigger computation where the masking index is dynamic and comes from a sampled value.

In [None]:
@gen
def larger_model(init, probs):
    """Sample a cut-off index `i`, then run the masked HMM for its first `i` steps.

    Args:
        init: initial HMM state.
        probs: probabilities over the cut-off index. genjax's `categorical`
            is parameterized by logits, so they are converted with `log`.

    Returns:
        The masked HMM's return value, traced at address "x".
    """
    i = categorical(jnp.log(probs)) @ "i"
    # One flag per scan step; steps with index >= i are masked out.
    mask = Mask(jnp.arange(length) < i, None)
    x = masked_hmm(init, mask) @ "x"
    return x


# Independent keys for the initial state and the simulation, plus a fresh
# `key` for later cells (so no key is used twice).
key, init_key, subkey = jax.random.split(key, 3)
init = jax.random.normal(init_key, (state_size,))
# A distribution over the cut-off index, one entry per scan step.
probs = jnp.arange(length) / jnp.sum(jnp.arange(length))
choices = larger_model.simulate(subkey, (init, probs)).get_choices()
choices

We have already seen how to use conditionals in GenJAX models in the `conditionals` notebook. Behind the scenes, it uses the same masking logic.

In [None]:
@gen
def cond_model(p):
    """Choose between two Bernoulli branches according to the sign of `p`.

    Both branches sample at the same address "v1". `or_else` receives the
    predicate `p > 0`, the first branch's arguments, and the second branch's
    arguments. Its result is traced under the address "cond".
    """
    if_branch = gen(lambda p: bernoulli(p) @ "v1")
    else_branch = gen(lambda p: bernoulli(-p) @ "v1")
    branches = or_else(if_branch, else_branch)
    v = branches(p > 0, (p,), (p,)) @ "cond"
    return v


choices = cond_model.simulate(key, (0.5,)).get_choices()
isinstance(choices.c.c1, MaskChm)