In [1]:
import jax
import genjax

@genjax.gen
def branch_1(key):
    """Switch branch 0: trace address "x1" from `genjax.Normal`.

    The draw uses Normal arguments (0.0, 1.0). The sampled value is recorded
    in the trace only; the function returns a 1-tuple with the advanced key.
    """
    normal_args = (0.0, 1.0)
    key, _x1 = genjax.trace("x1", genjax.Normal)(key, normal_args)
    return (key,)

@genjax.gen
def branch_2(key):
    """Switch branch 1: trace address "x2" from `genjax.Bernoulli`.

    The draw uses Bernoulli argument 0.3 (presumably the success probability;
    confirm against `genjax.Bernoulli`). Only the advanced key is returned,
    as a 1-tuple; the sampled value lives in the trace.
    """
    bernoulli_args = (0.3,)
    key, _x2 = genjax.trace("x2", genjax.Bernoulli)(key, bernoulli_args)
    return (key,)

# Branch 0 is `branch_1`, branch 1 is `branch_2` (list order).
switch = genjax.SwitchCombinator([branch_1, branch_2])

# Build the jitted simulator once; both calls below reuse it.
key = jax.random.PRNGKey(314159)
jitted = jax.jit(genjax.simulate(switch))

# Run with leading argument 0 (presumably the branch index) and keep only the
# advanced key.
key, _ = jitted(key, (0,))

# Run with leading argument 1 and keep the resulting trace for display.
key, tr = jitted(key, (1,))
print(tr)

SwitchTrace
  gen_fn: SwitchCombinator(
    branches = {
      0:
      BuiltinGenerativeFunction(source = <function branch_1>),
      1:
      BuiltinGenerativeFunction(source = <function branch_2>)
    }
  )
  args: ()
  return: ()
  score: f32[]
  choices: IndexedChoiceMap
    index = i32[]
    [
      BuiltinTrace
        gen_fn: BuiltinGenerativeFunction(source = <function branch_1>)
        args: ()
        return: ()
        score: f32[]
        choices: BuiltinChoiceMap
          'x1':
          DistributionTrace
            gen_fn: _Normal()
            args: (f32[], f32[])
            return: (f32[],)
            score: f32[]
            choices: ValueChoiceMap(value: f32[]),
      BuiltinTrace
        gen_fn: BuiltinGenerativeFunction(source = <function branch_2>)
        args: ()
        return: ()
        score: f32[]
        choices: BuiltinChoiceMap
          'x2':
          DistributionTrace
            gen_fn: _Bernoulli()
            args: (f32[],)
            return:

In [2]:
import jax
import genjax

@genjax.gen
def random_walk(key, prev):
    """One random-walk step: trace address "x" from `genjax.Normal`.

    The draw uses Normal arguments (prev, 1.0) and returns (key, x). Under
    `UnfoldCombinator` the returned value presumably feeds the next step's
    `prev` (TODO confirm against the combinator's kernel contract).
    """
    key, position = genjax.trace("x", genjax.Normal)(key, (prev, 1.0))
    return key, position


# Single source of truth for the walk length. It is used both as the
# combinator's static `max_length` and as the length argument to simulate,
# so the two cannot drift apart.
MAX_LENGTH = 1000

unfold = genjax.UnfoldCombinator(random_walk, MAX_LENGTH)
init = 0.5
key = jax.random.PRNGKey(314159)
# Arguments are (length, initial state). The length presumably must not
# exceed MAX_LENGTH (TODO confirm against UnfoldCombinator).
key, tr = jax.jit(genjax.simulate(unfold))(key, (MAX_LENGTH, init))
print(tr)

UnfoldTrace
  gen_fn: UnfoldCombinator(
    kernel = BuiltinGenerativeFunction(source = <function random_walk>),
    max_length = 1000
  )
  args: (i32[], f32[])
  return: (f32[],)
  score: f32[]
  choices: VectorChoiceMap
    indices: i32[1000]
    BuiltinTrace
      gen_fn: BuiltinGenerativeFunction(source = <function random_walk>)
      args: (f32[1000],)
      return: (f32[1000],)
      score: f32[1000]
      choices: BuiltinChoiceMap
        'x':
        DistributionTrace
          gen_fn: _Normal()
          args: (f32[1000], f32[1000])
          return: (f32[1000],)
          score: f32[1000]
          choices: ValueChoiceMap(value: f32[1000])


In [3]:
import jax
import jax.numpy as jnp
import genjax

@genjax.gen
def add_normal_noise(key, x):
    """Add two `genjax.Normal` draws (arguments (0.0, 1.0)) to `x`.

    The draws are traced at "noise1" and "noise2", in that order, threading
    the key through both calls. Returns (key, x + noise1 + noise2).
    """
    standard = (0.0, 1.0)
    key, eps_first = genjax.trace("noise1", genjax.Normal)(key, standard)
    key, eps_second = genjax.trace("noise2", genjax.Normal)(key, standard)
    noisy = x + eps_first + eps_second
    return (key, noisy)

# NOTE(review): in_axes=(0, 0) presumably maps axis 0 of both the key batch
# and `x` - confirm against MapCombinator.
mapped = genjax.MapCombinator(add_normal_noise, in_axes=(0, 0))

# Number of mapped elements. It drives both the input size and the number of
# per-element keys, so the two stay consistent.
NUM_ELEMENTS = 100

arr = jnp.ones(NUM_ELEMENTS)
key = jax.random.PRNGKey(314159)
# Split into NUM_ELEMENTS + 1 keys. Index 0 replaces `key`; the rest form the
# stacked per-element key batch. Slicing the split array avoids unpacking it
# into a Python list of NUM_ELEMENTS arrays and re-stacking with jnp.array.
all_keys = jax.random.split(key, NUM_ELEMENTS + 1)
key, subkeys = all_keys[0], all_keys[1:]
_, tr = jax.jit(genjax.simulate(mapped))(subkeys, (arr,))
print(tr)

MapTrace
  gen_fn: MapCombinator(
    kernel = BuiltinGenerativeFunction(source = <function add_normal_noise>),
    in_axes = (0, 0)
  )
  args: (f32[100],)
  return: (f32[100],)
  score: f32[]
  choices: VectorChoiceMap
    indices: i32[100]
    BuiltinTrace
      gen_fn: BuiltinGenerativeFunction(source = <function add_normal_noise>)
      args: (f32[100],)
      return: (f32[100],)
      score: f32[100]
      choices: BuiltinChoiceMap
        'noise1':
        DistributionTrace
          gen_fn: _Normal()
          args: (f32[100], f32[100])
          return: (f32[100],)
          score: f32[100]
          choices: ValueChoiceMap(value: f32[100]),
        'noise2':
        DistributionTrace
          gen_fn: _Normal()
          args: (f32[100], f32[100])
          return: (f32[100],)
          score: f32[100]
          choices: ValueChoiceMap(value: f32[100])
