What is a generative function, and how do you use it?

In [None]:
# A simple example of a Beta-Bernoulli process.
# Standard distributions can be imported directly from GenJAX.
import jax
from genjax import beta
from genjax import bernoulli

# Import the @gen decorator to create generative functions
from genjax import gen


@gen
def beta_bernoulli_process(u):
    """Beta-Bernoulli model: draw a coin bias ``p``, then one outcome ``v``.

    Args:
        u: second concentration parameter of the Beta prior on ``p``.

    Returns:
        The sampled outcome recorded at address ``"v"``.
    """
    # Both Beta concentrations must be strictly positive, so the first one is
    # 1.0 (with u = 1.0 this prior is Uniform(0, 1)).
    p = beta(1.0, u) @ "p"
    # NOTE(review): in several GenJAX releases `bernoulli` forwards its
    # positional argument to TFP's Bernoulli as *logits*, not a probability.
    # If so, a probability-parameterized distribution (e.g. `flip`) is the
    # intended choice here — confirm against the installed GenJAX version.
    v = bernoulli(p) @ "v"
    return v


# Calling `simulate` with an explicit PRNG key runs the generative function
# and returns a trace that records its arguments, the random choices it made,
# and its return value.
key = jax.random.PRNGKey(314159)
tr = beta_bernoulli_process.simulate(key, (1.0,))

# Show the three parts of the trace, separated by blank lines.
print(tr.args, tr.get_sample(), tr.get_retval(), sep="\n\n")

In [None]:
# GenJAX functions can be accelerated with JIT.
# NOTE(review): here jax.jit wraps the raw model body *before* @gen sees it.
# Whether GenJAX's addressed choices (`@ "p"`) trace correctly through an
# inner jit depends on the GenJAX version — verify this cell runs. Jitting the
# outermost call (`jax.jit(...simulate)`, as in the following cell) is the
# recommended pattern.
@gen
@jax.jit
def fast_beta_bernoulli_process(u):
    """Beta-Bernoulli model (draw ``p``, then ``v``) with ``jax.jit`` on the body.

    Args:
        u: second concentration parameter of the Beta prior on ``p``.

    Returns:
        The sampled outcome recorded at address ``"v"``.
    """
    # Both Beta concentrations must be strictly positive.
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v


# But the proper way is to @jit the final function we aim to run
jitted = jax.jit(beta_bernoulli_process.simulate)

# We can compare the speed of the three functions
key = jax.random.PRNGKey(314159)
%timeit beta_bernoulli_process.simulate(key, (1.0,))
%timeit fast_beta_bernoulli_process.simulate(key, (1.0,))
%timeit jitted(key, (1.0,))