How can I build a mixture of models in GenJAX?

In [1]:
# We use the `mixture_combinator`. 
# Note that the trace is the join of the traces of the different components.

import jax.numpy as jnp
from genjax import mixture_combinator, gen, normal, flip, inverse_gamma
from jax import random

# We first define the three components of the mixture model as generative functions.
@gen
def mixture_component_1(p):
    """First mixture component: one ``normal(p, 1.0)`` draw at address "x", returned as-is."""
    return normal(p, 1.0) @ "x"

@gen
def mixture_component_2(p):
    """Second mixture component: one ``flip(p)`` draw at address "b", returned as-is."""
    return flip(p) @ "b"

@gen
def mixture_component_3(p):
    """Third mixture component: one ``inverse_gamma(p, 0.1)`` draw at address "y", returned as-is."""
    return inverse_gamma(p, 0.1) @ "y"

@gen
def mixture_model(p):
    """Mixture over the three components defined above, offset by a latent ``z``.

    Draws ``z`` from ``normal(p, 1.0)`` at address "z", then samples from the
    mixture at address "a". Every component is called with ``(p,)``.

    Returns:
        ``a + z``, the mixture sample plus the latent offset.
    """
    z = normal(p, 1.0) @ "z"
    # The mixture combinator takes the *logits* of the component weights.
    # The intended weights 0.3 / 0.5 / 0.2 are probabilities, so move them to
    # log space; passed raw they would be read as logits and (under the usual
    # softmax reading) give noticeably flatter weights than intended.
    logits = jnp.log(jnp.array([0.3, 0.5, 0.2]))
    # One argument tuple per component, in the same order as the components.
    arg_1 = (p,)
    arg_2 = (p,)
    arg_3 = (p,)
    mixture = mixture_combinator(
        mixture_component_1,
        mixture_component_2,
        mixture_component_3,
    )
    a = mixture(logits, arg_1, arg_2, arg_3) @ "a"
    return a + z

# Simulate the mixture model once with p = 0.4 under a fixed PRNG seed.
# NOTE(review): the output recorded for this cell is a TypeError about an
# unexpected `logits` keyword; presumably a GenJAX version mismatch -- TODO confirm.
key = random.PRNGKey(23)
tr = mixture_model.simulate(key, (0.4,))
choices = tr.get_sample()
print("return value:", tr.get_retval())
print("value for z:", choices["z"])
# The combinator stores its component under the fixed address
# "mixture_component", nested below the call-site address "a".
print("value for the mixture_component:", choices["a", "mixture_component"])

TypeError: ExactDensityFromCallables.sample() got an unexpected keyword argument 'logits'

In [13]:
from genjax import gen, normal, bernoulli
import jax

@gen
def model1():
    """Branch returning a single value: one ``normal(0.0, 1.0)`` draw at address "x"."""
    x = normal(0.0, 1.0) @ "x"
    return x

@gen
def model2():
    """Branch returning a pair: one ``bernoulli(0.5)`` draw at address "z", duplicated."""
    draw = bernoulli(0.5) @ "z"
    return draw, draw

@gen
def model3(x):
    """Branch returning a pair: a ``bernoulli(x)`` draw at address "y" and that draw plus one."""
    draw = bernoulli(x) @ "y"
    incremented = draw + 1
    return draw, incremented

# Produces a sum type: the branches do not all share one return type
# (model1 returns a single value, model2 and model3 return pairs).
# Re-create the PRNG key here (same seed as the first cell) so this cell does
# not depend on state left behind by an earlier cell.
key = jax.random.PRNGKey(23)
switch_model = model1.switch([model2, model3])
# Arguments: the branch index, followed by one argument tuple per branch.
tr = jax.jit(switch_model.simulate)(key, (1, (), (), (0.5,)))
print(tr.get_retval())

def collapsing_a_sum_type(key, idx):
    """Simulate the switch over model1/model2/model3 and reduce its result to one number.

    The switch is simulated with branch index ``idx``; the returned sum-typed
    value is then collapsed with ``jax.lax.switch`` keyed on its ``idx`` field.

    Args:
        key: PRNG key passed to ``simulate``.
        idx: Branch index forwarded to the switch.

    Returns:
        The value produced by whichever collapsing branch ``jax.lax.switch`` selects.
    """
    switched = model1.switch([model2, model3])
    trace = jax.jit(switched.simulate)(key, (idx, (), (), (0.5,)))
    # Named `retval` rather than `sum` so the builtin is not shadowed.
    retval = trace.get_retval()

    # NOTE(review): presumably values[0] is the single-value variant and
    # values[1] the pair-valued one -- confirm against the Sum semantics.
    def collapse_first():
        return retval.values[0] + 3.0

    def collapse_second():
        pair = retval.values[1]
        return 1.0 + pair[0] + pair[1]

    return jax.lax.switch(retval.idx, [collapse_first, collapse_second])

# Collapsing a sum type, once per branch index.
# NOTE(review): the switch has three branches; confirm that 3 is a valid index
# (it would be out of range if branch indices are 0-based).
collapse_jit = jax.jit(collapsing_a_sum_type)
for branch_idx in (1, 2, 3):
    x = collapse_jit(key, branch_idx)
    print(x)

from genjax import Sum

# From sum type to mask type
def uncertain_idx(idx):
    """Build a ``Sum`` over ``[1, 2, 3]`` tagged with ``idx`` and return its entry at position 2.

    The notebook's recorded output shows the result printing as a ``Mask``.
    """
    return Sum(idx, [1, 2, 3])[2]

# Run the sum-to-mask conversion under jit with index 1 and show the result.
uncertain_idx_jit = jax.jit(uncertain_idx)
mask = uncertain_idx_jit(1)
print(mask)

Sum(...)
1.0
1.0
1.0
Mask(...)
