In [None]:
import jax.numpy as jnp
import jax.random as random
from genjax import ChoiceMapBuilder as C
from genjax import bernoulli, gen, mix, normal, or_else, pretty, repeat, scan, vmap
from jax import jit

# NOTE(review): `pretty` is imported from genjax; presumably it switches on
# GenJAX's pretty-printed cell output — confirm against the GenJAX docs.
pretty()
# JAX PRNG key built from seed 0. The cells shown below pass this same key to
# every `simulate` call.
key = random.PRNGKey(0)
# TODO: add more basic examples (including choice map creation) and high-level explanations

Example of choice map creation:

In [None]:
# A choice map holding a single choice, 0.5, at the address "p"
chm = C["p"].set(0.5)
# `^` combines choice maps, giving one with choices at several addresses
chm = C["p"].set(0.5) ^ C["v"].set(1)
# Alternatively, chain `.at[...]` onto an existing choice map
chm2 = C["p"].set(0.5).at["v"].set(1)
# Choices can also be added to a choice map one at a time
chm2 = chm2 ^ C["p2"].set(0.6)
# The loop uses its own "q" prefix: "p" + str(i) would produce "p2" at i == 2,
# which collides with the choice added just above.
for i in range(10):
    chm2 = chm2 ^ C[f"q{i}"].set(i)

Accessing the right elements of a trace can become non-trivial once generative functions are composed hierarchically with combinators.
Below are minimal examples showing how to index into the sampled choices for each combinator.

In [None]:
# For `or_else` combinator
@gen
def model(p):
    """Sample a Bernoulli choice from one of two branches picked by the sign of `p`.

    `or_else(branch_1, branch_2)` is called with the condition `p > 0` followed
    by one argument tuple per branch, and the result is traced at address "s".
    `branch_1` records its choice at "v1", `branch_2` at "v2".
    """
    branch_1 = gen(lambda p: bernoulli(p) @ "v1")
    branch_2 = gen(lambda p: bernoulli(-p) @ "v2")
    v = or_else(branch_1, branch_2)(p > 0, (p,), (p,)) @ "s"
    return v


trace = jit(model.simulate)(key, (0.5,))
# With p = 0.5 the condition `p > 0` holds, so the first branch is the one that
# runs: its choice is found at the outer address "s" followed by the branch's
# own address "v1" (not "v2", which belongs to the branch that was not taken).
trace.get_sample()[("s", "v1")]

In [None]:
# `vmap` combinator: two stacked `vmap` layers around a per-pixel model
@vmap(in_axes=(0,))
@vmap(in_axes=(0,))
@gen
def sample_image(pixel):
    """For a single pixel, sample the choice "new_pixel" from `normal(pixel, 1.0)`."""
    return normal(pixel, 1.0) @ "new_pixel"


image = jnp.zeros((2, 3), dtype=jnp.float32)
trace = sample_image.simulate(key, (image,))
# Two `...` entries (one per `vmap` layer) precede the address of the inner choice.
trace.get_sample()[..., ..., "new_pixel"]

In [None]:
# `scan` combinator: the step function below is wrapped with `scan(n=10)`
@scan(n=10)
@gen
def hmm(x, c):
    """One step: sample "z" from `normal(x, 1.0)`, then "y" from `normal(z, 1.0)`.

    Returns the pair `(y, None)`. The second argument `c` is not used.
    """
    latent = normal(x, 1.0) @ "z"
    observation = normal(latent, 1.0) @ "y"
    return observation, None


trace = hmm.simulate(key, (0.0, None))
hmm_choices = trace.get_sample()
# "z" is read with `...` in the leading position, "y" with the integer index 3.
hmm_choices[..., "z"], hmm_choices[3, "y"]

In [None]:
# `repeat` combinator: the model below is wrapped with `repeat(n=10)`
@repeat(n=10)
@gen
def model(y):
    """Sample "x" from `normal(y, 0.01)`, then "y" from `normal(x, 0.01)`; return the latter.

    Note: this rebinds the notebook-level name `model`, replacing the function
    defined in the `or_else` cell.
    """
    first = normal(y, 0.01) @ "x"
    second = normal(first, 0.01) @ "y"
    return second


trace = model.simulate(key, (0.3,))
repeat_choices = trace.get_sample()
# `...` in the leading position, followed by the address of the inner choice.
repeat_choices[..., "x"]

In [None]:
# `mix` combinator
@gen
def mixture_model(p):
    """Return the sum of the choice "z" and the value of a three-way mixture traced at "a".

    Each mixture component is a normal on `p` (second argument 1.0, 2.0 and
    3.0) with its own address: "x1", "x2" and "x3".
    """
    base = normal(p, 1.0) @ "z"
    # NOTE(review): these values sum to 1 and look like probabilities; if `mix`
    # treats its first argument as logits, the effective weights differ — confirm.
    logits = (0.3, 0.5, 0.2)
    # Every component receives the same argument tuple.
    component_args = (p,)
    mixture = mix(
        gen(lambda p: normal(p, 1.0) @ "x1"),
        gen(lambda p: normal(p, 2.0) @ "x2"),
        gen(lambda p: normal(p, 3.0) @ "x3"),
    )
    mixed = mixture(logits, component_args, component_args, component_args) @ "a"
    return mixed + base


trace = mixture_model.simulate(key, (0.4,))
# The combinator uses the fixed address "mixture_component", nested under "a"
# (presumably holding which component was selected — confirm).
trace.get_sample()["a", "mixture_component"]