In [1]:
import jax
import genjax

@genjax.gen
def random_walk(key, prev):
    """One step of a random walk: sample the next position around `prev`.

    The new position is traced at address "x" from `genjax.Normal` with
    arguments `(prev, 1.0)` (presumably mean and scale -- confirm against the
    genjax distribution docs). The PRNG key is threaded through the trace call
    and returned together with the sample, as a `(key, value)` pair.
    """
    sample_x = genjax.trace("x", genjax.Normal)
    key, next_pos = sample_x(key, (prev, 1.0))
    return key, next_pos


# Unroll `random_walk` from INIT_STATE and simulate one trace.
# The walk length appears twice: as UnfoldCombinator's max_length and as the
# first simulate argument (presumably the number of steps to run -- confirm
# against the genjax UnfoldCombinator docs). Sharing one constant keeps the
# two from drifting apart when the length is changed.
NUM_STEPS = 1000
INIT_STATE = 0.5

unfold = genjax.UnfoldCombinator(random_walk, NUM_STEPS)
init = INIT_STATE
key = jax.random.PRNGKey(314159)  # fixed seed so the printed trace is reproducible
key, tr = jax.jit(genjax.simulate(unfold))(key, (NUM_STEPS, init))
print(tr)

<class 'genjax.combinators.unfold.UnfoldTrace'>
UnfoldCombinator(
  kernel = BuiltinGenerativeFunction(source = <function random_walk>),
  max_length = 1000
)VectorChoiceMap(
  subtrace = <class 'genjax.builtin.datatypes.BuiltinTrace'>
  BuiltinGenerativeFunction(source = <function random_walk>)(
    x = <class 'genjax.distributions.distribution.DistributionTrace'>
    _Normal()(value = f32[1000])
    score: [-0.9247955  -1.8264294  -1.8033758  -1.2785668  -2.0266647  -1.6825668
 -1.2227434  -1.0590267  -3.1286416  -0.97036177 -1.1971071  -0.9330346
 -0.9295393  -0.9236511  -1.1370858  -1.0736619  -1.3489403  -1.216788
 -1.2314839  -0.964772   -1.505276   -1.0950221  -0.9406347  -0.94001603
 -1.8975222  -0.9548435  -0.9776876  -0.98727185 -1.0896006  -0.92114383
 -2.9103003  -1.3686922  -1.0431892  -0.9449425  -0.93848586 -1.9802654
 -2.7054317  -2.1110342  -1.2191311  -0.9594871  -0.921643   -1.1974747
 -0.9208673  -1.8259932  -1.3640103  -1.5634525  -2.0195951  -1.29763
 -1.1016977  

In [2]:
import jax
import jax.numpy as jnp
import genjax

@genjax.gen
def add_normal_noise(key, x):
    """Return `x` plus two Normal draws traced at "noise1" and "noise2".

    Each draw comes from `genjax.Normal` with arguments `(0.0, 1.0)`. The PRNG
    key is threaded through both trace calls in order ("noise1" first) and
    returned with the noisy value as a `(key, value)` pair.
    """
    noisy = x
    # Unrolled at trace time: same addresses, same call order, same
    # left-to-right summation as x + noise1 + noise2.
    for address in ("noise1", "noise2"):
        key, draw = genjax.trace(address, genjax.Normal)(key, (0.0, 1.0))
        noisy = noisy + draw
    return (key, noisy)

# Vectorize `add_normal_noise` with MapCombinator; in_axes=(0, 0) maps over
# axis 0 of both call inputs (presumably the per-element keys and `x`).
mapped = genjax.MapCombinator(add_normal_noise, in_axes=(0, 0))

arr = jnp.ones(100)
key = jax.random.PRNGKey(314159)  # fixed seed so the printed trace is reproducible
# One subkey per element of `arr`, plus one carried forward as `key`.
# Deriving the count from arr.shape[0] keeps it correct if `arr` is resized,
# and slicing the split result avoids unpacking into a Python list only to
# re-stack it with jnp.array (same values as before).
keys = jax.random.split(key, arr.shape[0] + 1)
key, subkeys = keys[0], keys[1:]
_, tr = jax.jit(genjax.simulate(mapped))(subkeys, (arr,))
print(tr)

<class 'genjax.combinators.map.MapTrace'>
MapCombinator(
  kernel = BuiltinGenerativeFunction(source = <function add_normal_noise>),
  in_axes = (0, 0)
)VectorChoiceMap(
  subtrace = <class 'genjax.builtin.datatypes.BuiltinTrace'>
  BuiltinGenerativeFunction(source = <function add_normal_noise>)(
    noise1 = <class 'genjax.distributions.distribution.DistributionTrace'>
    _Normal()(value = f32[100])
    score: [-1.4996879  -2.2619872  -0.9215518  -1.479387   -1.1974081  -1.3033798
 -1.8249735  -1.0947304  -0.9251105  -1.0034034  -0.9407734  -0.9724346
 -0.92531097 -1.3261323  -1.1271875  -0.9945743  -1.6666276  -1.0104258
 -1.0538225  -1.7994913  -0.92095697 -2.1788874  -2.1146896  -1.7868929
 -1.0780694  -0.9550102  -3.0018265  -0.92260957 -1.9988542  -0.9713293
 -1.3433127  -1.1105025  -1.3307436  -1.2368753  -1.656857   -0.9746938
 -0.91894704 -0.95909494 -0.98636127 -1.5532479  -1.2469623  -1.52319
 -0.99405545 -1.3928486  -2.1056511  -1.8195016  -3.106034   -0.97066283
 -1.17867