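What is a generative function, and how do we use one? In GenJAX we define a generative function by decorating a Python function with `@static_gen_fn`; calling its `simulate` method with a random key and a tuple of arguments runs it and returns a trace of the execution.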

In [9]:
import genjax

# This is a simple example of a beta-bernoulli process
# We can import standard distributions from genjax
import jax
from genjax import beta
from genjax import bernoulli

# Import the static_gen_fn decorator to create generative functions
from genjax import static_gen_fn

# TODO: we should rename @static_gen_fn to @gen
@static_gen_fn
def beta_bernoulli_process(u):
    """Beta-Bernoulli model: sample a coin bias ``p``, then one coin flip ``v``.

    Args:
        u: Second shape parameter passed to the ``beta`` prior on ``p``.
            Beta shape parameters must be strictly positive, so ``u > 0``.

    Returns:
        The draw ``v`` recorded at address ``"v"``.
    """
    # Both Beta shape parameters must be > 0; a first parameter of 0.0 lies
    # outside the valid range (degenerate samples, NaN/-inf score). With 1.0
    # the prior is Beta(1, u), which is uniform on [0, 1] when u == 1.0.
    p = beta(1.0, u) @ "p"
    # NOTE(review): some genjax versions parameterize `bernoulli` by logits and
    # provide `flip` for a probability in [0, 1]; confirm which one applies here.
    v = bernoulli(p) @ "v"
    return v

# Run the model once. `simulate` takes a PRNG key plus a tuple of arguments and
# returns a trace recording the arguments, the random choices made at each
# address, and the return value.
key = jax.random.PRNGKey(314159)
tr = beta_bernoulli_process.simulate(key, (1.0,))

# Inspect the trace: arguments, choice map and return value, separated by blank lines.
print(tr.args, tr.get_choices(), tr.get_retval(), sep="\n\n")

(1.0,)
HierarchicalChoiceMap(
  trie=Trie(
    inner={'p': ChoiceValue(value=f32[]), 'v': ChoiceValue(value=i32[])}
  )
)
0


In [23]:
# GenJAX model bodies can be passed through JAX's JIT compiler. Here `jax.jit`
# is applied to the body *before* `@static_gen_fn` wraps it.
# NOTE(review): the next cell recommends jitting the final entry point instead;
# confirm that choices made inside a jitted body are still recorded in the trace.
@static_gen_fn
@jax.jit
def fast_beta_bernoulli_process(u):
    """Beta-Bernoulli model with a JIT-compiled body.

    Args:
        u: Second shape parameter passed to the ``beta`` prior on ``p``;
            must be strictly positive.

    Returns:
        The draw ``v`` recorded at address ``"v"``.
    """
    # Beta shape parameters must be > 0; 0.0 would make the prior invalid.
    p = beta(1.0, u) @ "p"
    # NOTE(review): confirm whether `bernoulli` expects a probability or logits
    # in this genjax version (`flip` takes a probability in some versions).
    v = bernoulli(p) @ "v"
    return v

# But the proper way is to @jit the final function we aim to run
jitted = jax.jit(beta_bernoulli_process.simulate)

# We can compare the speed of the three functions
key = jax.random.PRNGKey(314159)
%timeit beta_bernoulli_process.simulate(key, (1.0,))
%timeit fast_beta_bernoulli_process.simulate(key, (1.0,))
%timeit jitted(key, (1.0,))

395 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.01 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
58 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
