How can I build a mixture of models in GenJAX?

In [None]:
# We use the `mixture_combinator`.
# Note that the trace is the join of the traces of the different components.

from genjax import flip, gen, inverse_gamma, mixture_combinator, normal
from jax import random


# We first define the three components of the mixture model as generative functions.
@gen
def mixture_component_1(p):
    """First mixture component: a single normal draw.

    Samples `normal(p, 1.0)` at address "x" and returns the sampled value.
    """
    return normal(p, 1.0) @ "x"


@gen
def mixture_component_2(p):
    """Second mixture component: a single flip draw.

    Samples `flip(p)` at address "b" and returns the sampled value.
    """
    return flip(p) @ "b"


@gen
def mixture_component_3(p):
    """Third mixture component: a single inverse-gamma draw.

    Samples `inverse_gamma(p, 0.1)` at address "y" and returns the sampled
    value.
    """
    return inverse_gamma(p, 0.1) @ "y"


@gen
def mixture_model(p):
    """Sample a normal draw plus one draw from a three-component mixture.

    Draws `z` from `normal(p, 1.0)` at address "z", then samples the mixture
    of the three components at address "a" and returns `a + z`. Every
    component receives the same single argument `p`.
    """
    z = normal(p, 1.0) @ "z"

    # The mixture combinator takes the logits of the mixture components first...
    # NOTE(review): these values sum to 1 and look like probabilities; if the
    # combinator treats them as unnormalised logits the effective weights
    # would differ from (0.3, 0.5, 0.2) -- confirm the intended weights.
    logits = (0.3, 0.5, 0.2)

    # ...followed by one tuple of arguments per component, in component order.
    component_args = ((p,), (p,), (p,))

    mixture = mixture_combinator(
        mixture_component_1, mixture_component_2, mixture_component_3
    )
    a = mixture(logits, *component_args) @ "a"
    return a + z


# Simulate one trace of the mixture model with a fixed PRNG key and p = 0.4,
# then inspect the return value and individual sampled choices.
key = random.PRNGKey(23)
tr = mixture_model.simulate(key, (0.4,))
print("return value:", tr.get_retval())
print("value for z:", tr.get_sample()["z"])
# The combinator uses a fixed address "mixture_component" inside the mixture's
# sub-trace (here under "a").
# NOTE(review): presumably this address holds the sampled component index
# rather than the chosen component's own draw -- confirm against the
# mixture_combinator documentation.
print("value for the mixture_component:", tr.get_sample()["a", "mixture_component"])