In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import Mask, gen, normal
from genjax import SelectionBuilder as S

# Fixed-seed PRNG key; it is passed to the `simulate` calls in the cells below.
key = jax.random.PRNGKey(0)

In [None]:
class HiddenIndex:
    """Stand-in object whose printed form is always ``#``."""

    def __repr__(self) -> str:
        # Every instance prints the same way, whatever it replaces.
        placeholder = "#"
        return placeholder


class Addr:
    """Printable wrapper around an address path (a sequence of components).

    Instances print as ``<[...]>`` and support ``<`` by comparing the wrapped
    address sequences.
    """

    def __init__(self, addr, show_indices):
        """Store ``addr``, optionally hiding its non-string components.

        Args:
            addr: Iterable of address components.
            show_indices: When true, ``addr`` is stored as given. Otherwise a
                new list is stored in which every component that is not a
                ``str`` is replaced by a ``HiddenIndex`` instance.
        """
        if show_indices:
            self.addr = addr
        else:
            self.addr = [
                part if isinstance(part, str) else HiddenIndex() for part in addr
            ]

    def __repr__(self):
        return "<{}>".format(self.addr)

    def __lt__(self, other):
        # Ordering is delegated to the wrapped address sequences.
        mine, theirs = self.addr, other.addr
        return mine < theirs


def pytreefy(t, show_indices=False):
    def cm_kv_inner(t, addr_path=None, flag=None):
        if addr_path is None:
            addr_path = []
        else:
            addr_path = addr_path.copy()
        match type(t):
            case genjax._src.core.generative.choice_map.XorChm:
                ret1 = cm_kv_inner(t.c1, addr_path, flag)
                ret2 = cm_kv_inner(t.c2, addr_path, flag)
                # Check for empty intersection
                set1 = set(key.__repr__() for key in ret1.keys())
                set2 = set(key.__repr__() for key in ret2.keys())
                in_common = set1.intersection(set2)
                if not in_common:
                    ret1.update(ret2)
                    return ret1
                else:
                    raise ValueError("Common keys found in XorChm")
            case genjax._src.core.generative.choice_map.OrChm:
                ret1 = cm_kv_inner(t.c1, addr_path, flag)
                ret2 = cm_kv_inner(t.c2, addr_path, flag)
                ret1.update(ret2)
                return ret1
            case genjax._src.core.generative.choice_map.StaticChm:
                addr_path.append(t.addr)
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.IdxChm:
                addr_path.append(t.addr)
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.ValueChm:
                if isinstance(t.v, genjax._src.core.generative.choice_map.FilteredChm):
                    return cm_kv_inner(t.v, addr_path, flag)
                # TODO: a better version would replace the masked values with a special symbol indicating masked values
                if flag is None:
                    return {Addr(addr_path, show_indices): t.v}
                else:
                    return {Addr(addr_path, show_indices): (t.v.T * flag.T).T}
            case genjax._src.core.generative.choice_map.MaskChm:
                if flag is None:
                    flag = t.flag
                else:
                    # broadcasting with leading axis on the left
                    flag = (flag.T * t.flag.T).T
                return cm_kv_inner(t.c, addr_path, flag)
            case genjax._src.core.generative.choice_map.EmptyChm:
                return {}
            case genjax._src.core.generative.choice_map.FilteredChm:
                ret = cm_kv_inner(t.c, addr_path)
                # TODO: this should grap the list of addresses not just the top one. this creates a bug in test 6.1
                sel = t.selection.addr
                keys = ret.keys()
                kept_keys = [
                    k for k in keys if all(x == y for x, y in zip(k.addr, sel))
                ]
                return {key: ret[key] for key in kept_keys}
            case _:
                raise NotImplementedError(str(type(t)))

    return cm_kv_inner(t)


def test_for_masking_logic():
    """Demonstrate combining a per-row mask with a per-element mask.

    A 1-D row mask and a 2-D element mask are merged with the
    transpose-multiply-transpose idiom, and the merged mask is then applied
    to a 3-D array in two ways.

    Returns:
        A tuple ``(zeroed, picked)``: ``zeroed`` is the 3-D array with every
        masked-out row set to zero (shape preserved), and ``picked`` holds
        only the rows whose mask entry is true, obtained by boolean indexing.
    """
    row_mask = jnp.array([True, False, True])
    elem_mask = jnp.array([
        [True, False, False],
        [True, False, True],
        [True, True, True],
    ])
    # Transposing puts the leading axis last so the 1-D mask broadcasts
    # against rows rather than columns.
    combined = (elem_mask.T * row_mask.T).T

    # 2-D example: the three indexing forms below are evaluated for
    # illustration only; their results are discarded.
    grid = jnp.arange(1, 10).reshape(3, 3)
    grid[row_mask, :]  # first and third rows
    grid[elem_mask]  # entries where the element mask is true
    grid[combined]  # entries where the combined mask is true

    # 3-D example: same mask applied by multiplication and by indexing.
    cube = jnp.arange(1, 28).reshape(3, 3, 3)
    zeroed = (cube.T * combined.T).T
    picked = cube[combined]
    return zeroed, picked

Test 1: print a trace from a simple model with several traced variables.

In [None]:
@gen
def model():
    """Draw three normal(0, 1) choices at the addresses "x", "y" and "z"."""
    first = normal(0.0, 1.0) @ "x"
    second = normal(0.0, 1.0) @ "y"
    third = normal(0.0, 1.0) @ "z"
    return first, second, third


# Simulate once and display the flattened choices.
pytreefy(model.simulate(key, ()).get_choices())

Test 2: print a trace from a hierarchical model

In [None]:
@gen
def outer_model():
    """Call ``model`` at the address "inner", then draw a normal(0, 1) at "w"."""
    inner_ret = model() @ "inner"
    w_val = normal(0.0, 1.0) @ "w"
    return inner_ret, w_val


# Simulate once and display the flattened (nested) choices.
pytreefy(outer_model.simulate(key, ()).get_choices())

Test 3: print a trace from models built with combinators that iterate computations

In [None]:
# repeat combinator
# `model` is the zero-argument model from Test 1 (assuming the cells run in
# order), wrapped with `.repeat(n=10)` before simulating.
pytreefy(model.repeat(n=10).simulate(key, ()).get_choices())

In [None]:
# vmap combinator


@gen
def model(v):
    """Draw a single normal(v, 1) choice at the address "x" and return it."""
    return normal(v, 1.0) @ "x"


# One argument tuple holding the float values 0.0 .. 9.0, mapped over axis 0.
vs = (1.0 * jnp.arange(10),)
pytreefy(model.vmap(in_axes=(0,)).simulate(key, vs).get_choices())

In [None]:
# iterate combinator
# `model` is the one-argument model(v) from the vmap cell (assuming the cells
# run in order); 1.0 is passed as its argument.
pytreefy(model.iterate(n=10).simulate(key, (1.0,)).get_choices())

In [None]:
# iterate final combinator
# Same model(v) and argument as the iterate cell, using `.iterate_final(n=10)`.
pytreefy(model.iterate_final(n=10).simulate(key, (1.0,)).get_choices())

Test 4: print a masked trace

In [None]:
# TODO: this should be part of the standard library, hopefully soon.
def masked_scan_combinator(step, **scan_kwargs):
    """Build a scanned version of ``step`` that takes and returns ``Mask`` values.

    ``step`` is masked, adapted with ``dimap`` so that it receives and returns
    ``Mask`` objects, scanned with ``scan_kwargs``, and finally adapted once
    more so that the caller passes a plain initial state together with masked
    input values.

    Args:
        step: Generative function exposing ``.mask()``; its return value is
            indexed at positions 0 and 1.
        **scan_kwargs: Forwarded unchanged to ``.scan(...)``.

    Returns:
        The generative function produced by the final ``dimap``.
    """

    def unwrap_step_args(masked_state, masked_inval):
        # The step is active only when both the carried state and the input
        # are flagged as valid.
        both_valid = jnp.logical_and(masked_state.flag, masked_inval.flag)
        return both_valid, masked_state.value, masked_inval.value

    def rewrap_step_output(_, masked_retval):
        # Give each of the two returned components the step's flag.
        return (
            Mask(masked_retval.flag, masked_retval.value[0]),
            Mask(masked_retval.flag, masked_retval.value[1]),
        )

    def wrap_scan_args(initial_state, masked_input_values):
        # The initial state is flagged True; the inputs keep their own flag.
        return (
            Mask(True, initial_state),
            Mask(masked_input_values.flag, masked_input_values.value),
        )

    def passthrough(_, retval):
        return retval

    masked_step = step.mask().dimap(pre=unwrap_step_args, post=rewrap_step_output)
    scanned_step = masked_step.scan(**scan_kwargs)
    return scanned_step.dimap(pre=wrap_scan_args, post=passthrough)


# Dimension of the state vector, an identity matrix of that size used as the
# second argument of mv_normal, and a randomly drawn initial state.
state_size = 3
variance = jnp.eye(state_size)
initial_state = jax.random.normal(jax.random.PRNGKey(0), (state_size,))

# The mask flag is True at indices 0 .. stop_at_index - 1 and False for the
# remaining indices up to `length`; the Mask carries no value (None).
length = 10
stop_at_index = 5
mask = Mask(jnp.arange(length) < stop_at_index, None)


@genjax.gen
def hmm_step(x, _):
    """Draw ``new_x`` from mv_normal(x, variance) at "new_x"; return (new_x, None)."""
    new_x = genjax.mv_normal(x, variance) @ "new_x"
    return new_x, None


# Wrap the step with the masked scan defined above, using n=length.
masked_hmm = masked_scan_combinator(hmm_step, n=length)

# `choices` is a notebook-level name that later cells read as well.
choices = masked_hmm.simulate(key, (initial_state, mask)).get_choices()

pytreefy(choices)

Test 5: print a nested masked trace

In [None]:
@genjax.gen
def outer_step(x, _):
    """Run ``masked_hmm`` at "hmm_x", discard its result, and return (x, None)."""
    # The inner call always uses the notebook-level initial_state and mask,
    # not the carried state `x`.
    _ = masked_hmm(initial_state, mask) @ "hmm_x"
    return x, None


masked_masked_hmm = masked_scan_combinator(outer_step, n=length)
outer_stop_at_index = 6
# NOTE(review): outer_mask is built here but never used below; the simulate
# call passes `mask` instead. Confirm which mask is intended.
outer_mask = Mask(jnp.arange(length) > outer_stop_at_index, None)


# Simulate the inner masked HMM once, then keep only the "new_x" rows whose
# flag is True to form the initial state passed to the outer scan.
outer_trace_init = masked_hmm.simulate(key, (initial_state, mask))
flag = outer_trace_init.get_choices().get_submap(...).flag
unmasked_val = outer_trace_init.get_choices().get_submap(...).c("new_x").v
outer_init = unmasked_val[flag, :]
# Rebinds the notebook-level `choices`; the final cell reads this value when
# the cells run in order.
choices = masked_masked_hmm.simulate(key, (outer_init, mask)).get_choices()
pytreefy(choices)

Test 6: print a filtered choicemap (obtained after using explicit marginalization)

In [None]:
@gen
def model(v):
    """Draw "x" from normal(v, 1), then "y" from normal(x, 1); return both."""
    x_val = normal(v, 1.0) @ "x"
    y_val = normal(x_val, 1.0) @ "y"
    return x_val, y_val


# Marginal built with the selection S["x", "y"], simulated with v = 1.0.
pytreefy(model.marginal(selection=S["x", "y"]).simulate(key, (1.0,)).get_choices())

In [None]:
# Same model as the previous cell, with the marginal selection S["y"].
a = model.marginal(selection=S["y"]).simulate(key, (1.0,)).get_choices()
pytreefy(a)

Test 7: print an empty choice map. 

In [None]:
# C.n() builds the choice map under test (expected to be empty, per the heading).
chm = C.n()
pytreefy(chm)

Test 8: print an Or choice map

In [None]:
# Two different addresses, "x" and "y", set through the builder.
chm = C["x"].set(3.0).at["y"].set(2.0)
pytreefy(chm)

Test 9: print an Or choice map with repeated entries

In [None]:
# The same address "x" is set twice.
chm = C["x"].set(3.0).at["x"].set(2.0)
# Addr defines neither __eq__ nor __hash__, so the two "x" keys are distinct
# dictionary keys and both entries are kept in the result.
pytreefy(chm)

Test 10: print an Xor choice map

In [None]:
# `^` combines two choice maps that use different addresses.
chm = C["x"].set(3.0) ^ C["y"].set(2.0)
pytreefy(chm)

Test 11: print an Xor choice map with repeated entries.

In [None]:
# Both sides of `^` use the address "x".
chm = C["x"].set(3.0) ^ C["x"].set(2.0)
# pytreefy raises ValueError when the two sides of an XorChm yield an address
# with the same repr; print the message instead of failing the cell.
try:
    pytreefy(chm)
except ValueError as e:
    print(e)

Other uses include easily inspecting the shapes and dtypes of the values (among many other possible manipulations).

In [None]:
# `choices` is whatever the most recently executed cell bound to that name
# (the nested masked trace from Test 5 when the cells run in order).
# Map over the flattened dict to show each leaf's shape, then its dtype.
print(jax.tree.map(jnp.shape, pytreefy(choices)))
print(jax.tree.map(jnp.dtype, pytreefy(choices)))