In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import Mask, gen, normal

# Base PRNG key (seed 0); it is passed to the `simulate` calls further down.
key = jax.random.PRNGKey(0)

Let's write a printing function to be able to visualize traces a bit more compactly.

In [None]:
class HiddenIndex:
    """Placeholder shown instead of a non-string address component.

    `Addr` substitutes an instance of this class for every address component
    that is not a `str` when indices are hidden, so such components print as
    a single `#`.
    """

    _SYMBOL = "#"

    def __repr__(self) -> str:
        return self._SYMBOL


class Addr:
    """Printable, orderable wrapper around a choice-map address path.

    Args:
        addr: Sequence of address components (strings and/or indices).
        show_indices: When true, `addr` is stored as given. When false, a new
            list is stored in which every non-string component is replaced by
            a `HiddenIndex` placeholder.
    """

    def __init__(self, addr, show_indices):
        if show_indices:
            # Keep the caller's object untouched (no copy is made).
            self.addr = addr
        else:
            self.addr = [
                part if isinstance(part, str) else HiddenIndex() for part in addr
            ]

    def __repr__(self):
        return "<" + f"{self.addr}" + ">"

    def __lt__(self, other):
        # Ordering is delegated to the underlying address sequences.
        return self.addr < other.addr


def cm_kv(t, show_indices=False):
    def cm_kv_inner(t, addr_path=None, flag=None):
        if addr_path is None:
            addr_path = []
        else:
            addr_path = addr_path.copy()
        match type(t):
            case genjax._src.core.generative.choice_map.XorChm:
                ret1 = cm_kv_inner(t.c1, addr_path, flag)
                ret2 = cm_kv_inner(t.c2, addr_path, flag)
                # Check for empty intersection
                set1 = set(ret1)
                set2 = set(ret2)
                in_common = set1.intersection(set2)
                if not in_common:
                    return ret1.update(ret2)
                else:
                    raise ValueError("Common keys found in XorChm")
            case genjax._src.core.generative.choice_map.OrChm:
                ret1 = cm_kv_inner(t.c1, addr_path, flag)
                ret2 = cm_kv_inner(t.c2, addr_path, flag)
                return ret1.update(ret2)
            case genjax._src.core.generative.choice_map.StaticChm:
                addr_path.append(t.addr)
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.IdxChm:
                addr_path.append(t.addr)
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.ValueChm:
                # TODO: needs to replace 0 with  special symbol for masked
                if flag is None:
                    return {Addr(addr_path, show_indices): t.v}
                else:
                    return {Addr(addr_path, show_indices): flag * t.v}
            case genjax._src.core.generative.choice_map.MaskChm:
                if flag is None:
                    flag = t.flag
                else:
                    flag = flag * t.flag
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.EmptyChm:
                pass
            case genjax._src.core.generative.choice_map.FilteredChm:
                # TODO: needs testing
                ret = cm_kv_inner(t.c, addr_path)
                sel = t.selection
                keys = list(ret.keys())
                filtered__out_keys = [
                    tup for tup in keys if all(x == y for x, y in zip(tup, sel))
                ]
                for k in filtered__out_keys:
                    del ret[k]
                return ret
            case _:
                raise NotImplementedError(str(type(t)))

    return cm_kv_inner(t)

One classic trick is to encode all the options as an array and pick the desired value from the array with a dynamic index.

Here's a first example:

In [None]:
@gen
def model(
    i, means, vars
):  # provide all the possible values and the dynamic index to pick from them
    """Sample "x" from a normal whose parameters are selected by index `i`.

    `means` and `vars` hold one candidate parameter per option; entry `i` of
    each is passed to `normal`.
    NOTE(review): `vars[i]` is the second argument of `normal` -- confirm
    whether that argument is a variance or a standard deviation.
    """
    chosen_mean = means[i]
    chosen_spread = vars[i]
    return normal(chosen_mean, chosen_spread) @ "x"


# Pick option 7 out of ten: the means are 0.0..9.0 and the second parameter is
# 1.0 for every option.
model.simulate(key, (7, jnp.arange(10, dtype=jnp.float32), jnp.ones(10)))

Now, what if there's a value we may or may not want to get depending on a dynamic value?

In this case, we can use masking. Let's look at an example in JAX.

In [None]:
# A 3x3 matrix holding the integers 0..8 in row-major order.
non_masked = jnp.arange(9).reshape(3, 3)

# Bare expression so the notebook displays the matrix.
non_masked

In [None]:
# Build the indices of the upper-triangular part of a 3x3 matrix ...
mask = jnp.mask_indices(3, jnp.triu)

# ... and use them to select just those entries of `non_masked`.
non_masked[mask]

We can use similar logic for generative functions in GenJAX. 

Let's create an HMM using the scan combinator.

In [None]:
# Dimension of the state vector.
state_size = 10
# Number of steps the scan below unrolls.
length = 10
# Identity matrix passed as the second argument of `genjax.mv_normal` in `hmm_step`.
variance = jnp.eye(state_size)
# NOTE(review): this draws from PRNGKey(0), the same seed as the global `key`
# that is later passed to `simulate` -- confirm the key reuse is intended.
initial_state = jax.random.normal(jax.random.PRNGKey(0), (state_size,))


@genjax.gen
def hmm_step(x, _):
    """Single scan step: sample the next state around the current state `x`.

    The second argument (the per-step scanned input) is ignored. Returns the
    pair `(next_state, None)`: the new carry and no per-step output.
    """
    next_state = genjax.mv_normal(x, variance) @ "new_x"
    return (next_state, None)


# Chain `hmm_step` for `length` steps with the scan combinator.
hmm = hmm_step.scan(n=length)

When we run it, we get a full trace. 

In [None]:
# JIT-compile the simulation of the whole chain.
jitted = jax.jit(hmm.simulate)
# Arguments are (initial carry, scanned inputs); `hmm_step` ignores its second
# argument, so None is passed for the scanned inputs.
trace = jitted(key, (initial_state, None))
# Show the sampled choices compactly.
cm_kv(trace.get_choices())

Let us now define a masked version of the scan combinator so that we can return intermediate results.

In [None]:
# TODO: this should be part of the standard library, hopefully soon.
def masked_scan_combinator(step, **scan_kwargs):
    """
    Build a scan over `step` whose per-step inputs may be masked.

    `step` must be a generative function for which `step.scan(n=N)` is valid.
    The returned generative function takes
    `(initial_state, masked_input_values_array)` and returns the pair
    `(masked_final_state, masked_returnvalue_sequence)`. It behaves like
    `step.scan`, except that the input values can be masked.
    """

    def unpack_masked_args(masked_state, masked_inval):
        # A step is active only when both its carried state and its input are.
        active = jnp.logical_and(masked_state.flag, masked_inval.flag)
        return (active, masked_state.value, masked_inval.value)

    def split_masked_retval(args, masked_retval):
        # Re-wrap each component of the masked (state, output) pair on its own.
        return (
            Mask(masked_retval.flag, masked_retval.value[0]),
            Mask(masked_retval.flag, masked_retval.value[1]),
        )

    masked_step = step.mask().dimap(
        pre=unpack_masked_args,
        post=split_masked_retval,
    )

    # The scanned function expects a pair (
    #     Mask(True, initial_state),
    #     Mask(bools_indicating_active, input_vals)
    # )
    # and produces a pair (masked_final_state, masked_returnvalue_sequence).
    scanned = masked_step.scan(**scan_kwargs)

    def wrap_inputs(initial_state, masked_input_values):
        # The initial state is always active; the inputs keep their own flags.
        return (
            Mask(True, initial_state),
            Mask(masked_input_values.flag, masked_input_values.value),
        )

    def keep_retval(args, retval):
        return retval

    return scanned.dimap(pre=wrap_inputs, post=keep_retval)

To get partial results from the HMM instead, we can use masking as follows:

In [None]:
# Only the steps with index below `stop_at_index` are active.
stop_at_index = 5
# The mask carries one flag per scan step, so it is sized by the number of
# steps (`length`), not by the state dimension (`state_size`); the two only
# happen to be equal in this notebook.
mask = Mask(jnp.arange(length) < stop_at_index, None)

# Scan for `length` steps, matching the unmasked `hmm` defined earlier.
masked_hmm = masked_scan_combinator(hmm_step, n=length)

choices = masked_hmm.simulate(key, (initial_state, mask)).get_choices()

cm_kv(choices)

In [None]:
# Display the raw choice map of the unmasked trace for comparison.
trace.get_choices()