In [None]:
import genjax
import jax
import jax.numpy as jnp
import jax.random as random
from genjax import ChoiceMapBuilder as C
from genjax import (
    bernoulli,
    beta,
    gen,
    mix,
    normal,
    or_else,
    pretty,
    repeat,
    scan,
    vmap,
)

pretty()  # presumably switches genjax objects to its rich notebook display
key = jax.random.PRNGKey(0)  # seed-0 PRNG key shared by the cells below

Choice maps are dictionary-like data structures that accumulate the random choices made by a generative function and *traced* by the system, i.e. the choices marked with an address such as `@ "p"` in the generative function.

They also serve as sets of constraints/observations during inference: given the constraints, inference provides plausible values that complete the choice map into a full trace of the generative model (one value per traced random choice).

In [None]:
@gen
def beta_bernoulli_process(u):
    # Draw a parameter from Beta(1, u) at address "p", then a Bernoulli
    # draw parameterized by it at address "v"; return twice the draw.
    prob = beta(1.0, u) @ "p"
    flip = bernoulli(prob) @ "v"
    return flip * 2

Simulating from a model produces a trace, which contains a choice map.

In [None]:
# Re-create the seed-0 key (same as the setup cell) so this cell is self-contained,
# then simulate the model under jit with u = 0.5.
key = jax.random.PRNGKey(0)
trace = jax.jit(beta_bernoulli_process.simulate)(key, (0.5,))

From that trace, we can recover the choice map with either of two equivalent methods:

In [None]:
trace.get_sample(), trace.get_choices()  # the two accessors, displayed side by side

We can also access specific parts of the choice map by address.

In [None]:
trace.get_sample()["p"]  # index the choice map by address to read the value drawn at "p"

We can also build a choice map of observations ourselves and perform various operations on it.
For example, we can set the value at an address in the choice map.

In [None]:
# Build a choice map constraining "p" to 0.5 and "v" to 1.
chm = C["p"].set(0.5) ^ C["v"].set(1)  # ^ acts as a union of two choice maps
chm

A different way to achieve the same result.

In [None]:
# Same constraints, built by chaining: set "p", then add "v" through `.at[...]`.
chm = C["p"].set(0.5).at["v"].set(1)

Note that the first is an `Xor` choice map while the second is an `Or` choice map. One nuance is that the `Xor` choice map assumes its two sub-choice maps are disjoint.

This also works for hierarchical addresses:

In [None]:
# Hierarchical address: the value 1 is stored at "v" nested under "p".
chm = C["p", "v"].set(1)
chm

We can also directly set a value in the choice map:

In [None]:
# Create a choice map directly from a value.
chm = C.v(5.0)
chm

We can also create an empty choice map:

In [None]:
# An empty choice map.
chm = C.n()
chm

Choice maps can also be built by iteratively adding choices.

In [None]:
# Python loop: union in one single-address choice map per step,
# giving addresses "p0" ... "p9" holding the values 0 ... 9.
chm = C.n()
for idx in range(10):
    chm = chm ^ C[f"p{idx}"].set(idx)

# A more JAX-friendly way to do this: vmap over the integer addresses 0..9,
# storing each index (as a float) at its own address.
# Note: this rebinds `chm`, replacing the loop's result.
chm = jax.vmap(lambda i: C[i].set(i.astype(float)))(jnp.arange(10))
chm

For nested `vmap` combinators, creating a choice map can be trickier.

In [None]:
# Per-pixel model: one normal draw centered on the pixel value, at "new_pixel".
pixel_model = gen(lambda pixel: normal(pixel, 1.0) @ "new_pixel")
# Outer vmap maps over axis 0 of the image (rows), inner vmap over each row's pixels.
sample_image = vmap(in_axes=(0,))(vmap(in_axes=(0,))(pixel_model))

image = jnp.zeros([4, 4], dtype=jnp.float32)
trace = sample_image.simulate(key, (image,))
trace.get_sample()

Creating a few values for the choice map is simple.

In [None]:
# Constrain pixels (1, 2) and (0, 2) to 1.0, then run importance with these
# constraints; `importance` returns a (trace, weight) pair.
chm = C[1, 2, "new_pixel"].set(1.0) ^ C[0, 2, "new_pixel"].set(1.0)

tr, w = jax.jit(sample_image.importance)(key, chm, (image,))
w

But because of the nested `vmap`, the address hierarchy can sometimes lead to unintuitive results. 

In [None]:
# Add a constraint on pixel (1, 3) to the previous choice map and re-run importance.
chm = chm ^ C[1, 3, "new_pixel"].set(1.0)  # seemingly adding a new constraint
tr, w = jax.jit(sample_image.importance)(key, chm, (image,))
w  # Yet we obtain the same weight as before

In [None]:
# Image dimensions used to build index arrays for the address.
N = 4
F = 4


# Constrain "new_pixel" to 1.0 using array-valued indices at both vmap levels.
# NOTE(review): this cell does not show how C[...] combines several array indices
# (broadcast vs. outer product). Under NumPy-style broadcasting, the (F,) and (F, N)
# index arrays below would pair up only diagonal entries (j, j). Confirm that this
# really covers all F x N pixels. Also note `chm3` is not used by any later cell.
chm3 = C[
    jnp.arange(F), jnp.repeat(jnp.arange(N)[jnp.newaxis], F, axis=0), "new_pixel"
].set(1.0)

Accessing the right elements in the trace can become non-trivial when one creates hierarchical generative functions. 
Here are minimal examples and solutions for selection.

In [None]:
# For `or_else` combinator
@gen
def model(p):
    # or_else(a, b)(cond, a_args, b_args): presumably runs `a` when cond holds,
    # `b` otherwise; the chosen branch's choices sit under address "s".
    if_positive = gen(lambda p: bernoulli(p) @ "v1")
    otherwise = gen(lambda p: bernoulli(-p) @ "v2")
    return or_else(if_positive, otherwise)(p > 0, (p,), (p,)) @ "s"


trace = jax.jit(model.simulate)(key, (0.5,))
# With p = 0.5 the condition p > 0 holds, so read the "v1" choice under "s".
trace.get_sample()["s", "v1"]

In [None]:
# For `vmap` combinator
# Nested vmap over a 2 x 3 image: the outer level maps over rows, the inner over pixels.
sample_image = vmap(in_axes=(0,))(
    vmap(in_axes=(0,))(gen(lambda pixel: normal(pixel, 1.0) @ "new_pixel"))
)

image = jnp.zeros([2, 3], dtype=jnp.float32)
trace = sample_image.simulate(key, (image,))
# `...` fills the index slot at each vmap level, selecting every entry there
# (presumably yielding a (2, 3) array of "new_pixel" values).
trace.get_choices()[..., ..., "new_pixel"]

In [None]:
# For `scan_combinator`
@scan(n=10)
@gen
def hmm(x, c):
    # One step: latent "z" drawn around the carried state x, observation "y"
    # drawn around z; y is returned as the next carry, with no per-step output.
    latent = normal(x, 1.0) @ "z"
    obs = normal(latent, 1.0) @ "y"
    return obs, None


trace = hmm.simulate(key, (0.0, None))
choices = trace.get_choices()
# All ten "z" values (`...` spans the scan steps), and the "y" of step 3 only.
choices[..., "z"], choices[3, "y"]

In [None]:
# For `repeat_combinator`
# `repeat(n=10)` runs the model 10 times on the same argument.
# Note: this rebinds `model`, replacing the `or_else` example above.
@repeat(n=10)
@gen
def model(y):
    x = normal(y, 0.01) @ "x"
    y = normal(x, 0.01) @ "y"
    return y


trace = model.simulate(key, (0.3,))
# `...` spans the repeat index, selecting "x" from every repetition.
trace.get_choices()[..., "x"]

In [None]:
# For `mixture_combinator`
@gen
def mixture_model(p):
    z = normal(p, 1.0) @ "z"
    # `mix` takes mixture *logits* (unnormalized log-weights) as its first argument.
    # The weights 0.3 / 0.5 / 0.2 sum to 1, so they are probabilities: take their log.
    # Passing them raw as logits would give component probabilities of about
    # 0.32 / 0.39 / 0.29 instead.
    logits = jnp.log(jnp.array([0.3, 0.5, 0.2]))
    # Each component receives its own argument tuple.
    arg_1 = (p,)
    arg_2 = (p,)
    arg_3 = (p,)
    a = (
        mix(
            gen(lambda p: normal(p, 1.0) @ "x1"),
            gen(lambda p: normal(p, 2.0) @ "x2"),
            gen(lambda p: normal(p, 3.0) @ "x3"),
        )(logits, arg_1, arg_2, arg_3)
        @ "a"
    )
    return a + z


trace = mixture_model.simulate(key, (0.4,))
# The combinator uses a fixed address "mixture_component" for the components of the mixture model.
trace.get_sample()["a", "mixture_component"]

Similarly, creating a batch of traces with `jax.vmap` does not, in general, produce a valid batched trace, so calling `get_choices()` on it may fail.

In [None]:
@genjax.gen
def random_walk_step(prev, _):
    # One step: the next position is a unit-variance normal draw around `prev`,
    # carried forward with no per-step output.
    step = genjax.normal(prev, 1.0) @ "x"
    return step, None


random_walk = random_walk_step.scan(n=1000)

init = 0.5
keys = jax.random.split(key, 10)


# Batch 10 walks with plain jax.vmap: keys are mapped, the argument tuple is shared.
batched_simulate = jax.vmap(random_walk.simulate, in_axes=(0, None))
trs = batched_simulate(keys, (init, None))
# Demonstration: calling get_choices() on the vmapped batch may fail; show the error.
try:
    trs.get_choices()
except Exception as err:
    print(err)

However, with one extra step we can recover the choices of the individual traces:

In [None]:
jax.vmap(lambda tr: tr.get_choices())(trs)  # map get_choices over the batch, one trace at a time

Note that this limitation depends on the model: for some classes of models, the simpler approach works anyway.

In [None]:
# `model` here is its latest definition above (the `repeat` example).
# Batch 10 simulations: keys are mapped over, the argument tuple is shared.
jitted = jax.jit(jax.vmap(model.simulate, in_axes=(0, None)))
keys = random.split(key, 10)
traces = jitted(keys, (0.5,))


# Unlike the scan random walk above, get_choices() is called directly on the batch.
traces.get_choices()