What is a generative function and how do you use it?

In [None]:
import jax
from genjax import bernoulli, beta, gen

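The following is a simple model of a beta-Bernoulli process: it draws `p` from a Beta distribution and then draws `v` from a Bernoulli distribution parameterized by `p`. We use the `@gen` decorator to create generative functions, and the `@ "address"` syntax to give each random choice a name.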

In [None]:
@gen
def beta_bernoulli_process(u):
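    p = beta(1.0, u) @ "p"  # random choice at address "p"; both Beta parameters must be positive
    v = bernoulli(p) @ "v"  # random choice at address "v", which depends on p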
    return v
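To make the model concrete, here is a rough plain-JAX sketch of the same sampling process without GenJAX. It calls `jax.random.beta` and `jax.random.bernoulli` directly, assumes the first Beta parameter is 1.0 (Beta parameters must be strictly positive), and treats `p` as a probability. Unlike the generative function, it records no trace. The name `beta_bernoulli_sample` is hypothetical.

```python
import jax
import jax.numpy as jnp


def beta_bernoulli_sample(key, u):
    """Sample p ~ Beta(1.0, u), then v ~ Bernoulli(p), using plain jax.random.

    The first concentration 1.0 is an illustrative assumption.
    """
    key_p, key_v = jax.random.split(key)  # one independent key per random choice
    p = jax.random.beta(key_p, 1.0, u)    # p lies in [0, 1]
    v = jax.random.bernoulli(key_v, p)    # boolean, True with probability p
    return p, v


p, v = beta_bernoulli_sample(jax.random.PRNGKey(314159), 1.0)
print(p, v)
```

Splitting the key before each draw keeps the two random choices independent, which is the bookkeeping GenJAX handles for you inside a generative function.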

Calling a generative function requires an explicit JAX random key, so we create one first.

In [None]:
key = jax.random.PRNGKey(314159)
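JAX randomness is a pure function of the key, which is why the key must be passed in explicitly. The sketch below, using only `jax.random`, shows that reusing a key reproduces the same draw while splitting it yields a fresh one. By the same reasoning, calling `simulate` twice with the same key should produce the same trace.

```python
import jax

key = jax.random.PRNGKey(314159)

# The same key always produces the same value.
a = jax.random.uniform(key)
b = jax.random.uniform(key)

# For a new, independent draw, split the key into subkeys instead of reusing it.
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey)

print(a, b, c)
```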

Running the function with `simulate` returns a trace, which records the arguments, the random choices made, and the return value.

In [None]:
tr = beta_bernoulli_process.simulate(key, (1.0,))
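As a mental model only, a trace can be pictured as a record holding those three pieces. The sketch below is a hypothetical, plain-Python stand-in (`ToyTrace` and `toy_simulate` are invented names, not GenJAX's classes or API), again assuming a first Beta parameter of 1.0.

```python
from dataclasses import dataclass
from typing import Any

import jax


@dataclass
class ToyTrace:
    """Hypothetical trace: the arguments, the addressed random choices, and the return value."""
    args: tuple
    choices: dict
    retval: Any


def toy_simulate(key, args):
    """Run the beta-Bernoulli program once and record what it did (illustrative only)."""
    (u,) = args
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)  # choice recorded at address "p"
    v = jax.random.bernoulli(key_v, p)  # choice recorded at address "v"
    return ToyTrace(args=args, choices={"p": p, "v": v}, retval=v)


tr = toy_simulate(jax.random.PRNGKey(314159), (1.0,))
print(tr.args)  # (1.0,)
print(tr.choices)
print(tr.retval)
```

The real GenJAX trace exposes this information through its own accessors, such as `get_sample()` and `get_retval()` used in the next cell.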

We can print the components of the trace to see what happened.

In [None]:
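print(tr.args)  # the arguments passed to the generative function
print()
print(tr.get_sample())  # the random choices, keyed by address ("p" and "v")
print()
print(tr.get_retval())  # the return value, v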

One might try to accelerate a GenJAX function by applying `jax.jit` inside the `@gen` decorator, so that the body of the generative function is compiled:

In [None]:
@gen
@jax.jit
def fast_beta_bernoulli_process(u):
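    p = beta(1.0, u) @ "p"  # random choice at address "p"; both Beta parameters must be positive
    v = bernoulli(p) @ "v"  # random choice at address "v", which depends on p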
    return v

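However, the recommended approach is to `jit` the outermost function we intend to run, here the `simulate` method itself, so that the entire call is compiled as one computation.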

In [None]:
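# Compile the entire simulate call, not just the model body
jitted = jax.jit(beta_bernoulli_process.simulate)

# Compare the speed of the three approaches.
# Call the jitted function once first so compilation time is not measured,
# and block on each result because JAX dispatches work asynchronously.
key = jax.random.PRNGKey(314159)
jax.block_until_ready(jitted(key, (1.0,)))
%timeit jax.block_until_ready(beta_bernoulli_process.simulate(key, (1.0,)))
%timeit jax.block_until_ready(fast_beta_bernoulli_process.simulate(key, (1.0,)))
%timeit jax.block_until_ready(jitted(key, (1.0,)))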
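The same "jit the outermost function" idea can be seen outside GenJAX. This plain-JAX sketch (the sampler `sample` and helper `time_it` are hypothetical, with a first Beta parameter of 1.0 assumed) compiles the whole sampler with `jax.jit` and times it against eager execution. It warms up once to exclude compilation and waits on results with `jax.block_until_ready`, since JAX returns before the computation finishes. No particular speedup is claimed; the numbers depend on the hardware.

```python
import time

import jax


def sample(key, u):
    """Plain-JAX beta-Bernoulli sampler used only for this timing illustration."""
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    return p, jax.random.bernoulli(key_v, p)


jitted_sample = jax.jit(sample)  # compile the whole function as one XLA computation


def time_it(fn, *args, n=20):
    """Average wall-clock seconds per call, waiting for each result to finish."""
    jax.block_until_ready(fn(*args))  # warm-up call: keeps one-time compilation out of the timing
    start = time.perf_counter()
    for _ in range(n):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n


key = jax.random.PRNGKey(314159)
t_eager = time_it(sample, key, 1.0)
t_jit = time_it(jitted_sample, key, 1.0)
print("eager:", t_eager)
print("jit:  ", t_jit)

# Both versions consume the same key, so they should draw (numerically) the same p.
p_eager, _ = sample(key, 1.0)
p_jit, _ = jitted_sample(key, 1.0)
```

Jitting the outer function lets XLA see every operation at once, including the key splitting, rather than compiling only one piece and running the rest op by op from Python.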