In [None]:
import jax
from genjax import bernoulli, beta, gen, pretty

pretty()

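The following is a simple model of a beta-Bernoulli process: a probability `p` is drawn from a Beta distribution, and a binary outcome `v` is then drawn from a Bernoulli distribution parameterized by `p`. We use the `@gen` decorator to create generative functions.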

In [None]:
@gen
def beta_bernoulli_process(u):
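    # Beta concentration parameters must be strictly positive, so the first one is 1.0 (not 0.0)
    p = beta(1.0, u) @ "p"  # random choice recorded at address "p"
    v = bernoulli(p) @ "v"  # random choice recorded at address "v"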
    return v
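To make the generative process concrete, here is a minimal sketch of the same two-step sampling written directly with `jax.random`, outside GenJAX. The helper name `beta_bernoulli_by_hand` is hypothetical, and it treats `p` as the success probability, which is an assumption: check whether your GenJAX version's `bernoulli` takes a probability or logits. Under that assumption, `p ~ Beta(1, u)` has mean `1 / (1 + u)`, so the marginal probability that `v` is true is also `1 / (1 + u)`, which we can check by averaging many samples with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def beta_bernoulli_by_hand(key, u):
    """Hypothetical helper: draw p ~ Beta(1, u), then v ~ Bernoulli(p)."""
    key_p, key_v = jax.random.split(key)  # one independent key per random choice
    p = jax.random.beta(key_p, 1.0, u)    # success probability
    v = jax.random.bernoulli(key_v, p)    # boolean outcome, True with probability p
    return p, v


# Draw 10,000 independent samples by vectorizing over a batch of keys.
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
ps, vs = jax.vmap(beta_bernoulli_by_hand, in_axes=(0, None))(keys, 1.0)

# The empirical frequency of v should be close to 1 / (1 + u) = 0.5 for u = 1.
print(jnp.mean(vs))
```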

To run the generative function, we first need a JAX pseudo-random number generator (PRNG) key.

In [None]:
key = jax.random.PRNGKey(0)
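JAX randomness is a pure function of the key: the same key always yields the same draw, so each random operation should consume a fresh key obtained from `jax.random.split`, as the cells below do. A minimal illustration using only `jax.random`:

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing a key reproduces exactly the same "random" number.
a = jax.random.uniform(key)
b = jax.random.uniform(key)
print(a == b)  # True

# Splitting yields a new key to carry forward and a subkey to consume.
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey)  # a fresh draw, independent of a
print(a, c)
```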

Calling `simulate` with a key and a tuple of arguments returns a trace, which records the arguments, the random choices made, and the return value.

In [None]:
key, subkey = jax.random.split(key)
tr = beta_bernoulli_process.simulate(subkey, (1.0,))
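Conceptually, a trace bundles those three things together. The toy sketch below mimics that shape in plain JAX with a hypothetical `ToyTrace` named tuple and a hypothetical `toy_simulate` function; it illustrates the idea only and is not GenJAX's actual trace implementation.

```python
from typing import NamedTuple

import jax


class ToyTrace(NamedTuple):
    """Hypothetical stand-in for a trace: what went in, what was sampled, what came out."""
    args: tuple
    choices: dict
    retval: jax.Array


def toy_simulate(key, args):
    """Hypothetical simulate: run the beta-Bernoulli process and record everything."""
    (u,) = args
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    v = jax.random.bernoulli(key_v, p)
    # Random choices are stored under the same addresses used in the model: "p" and "v".
    return ToyTrace(args=args, choices={"p": p, "v": v}, retval=v)


toy_tr = toy_simulate(jax.random.PRNGKey(0), (1.0,))
print(toy_tr.args)  # (1.0,)
print(toy_tr.choices)
print(toy_tr.retval)
```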

We can print the components of the trace to see what happened:

In [None]:
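print(tr.args)  # arguments passed to simulate
print()
print(tr.get_sample())  # random choices, keyed by address ("p" and "v")
print()
print(tr.get_retval())  # return value of the generative function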

GenJAX generative functions can be accelerated with JAX's `jit` compilation.

The non-optimal way is to apply `jax.jit` to the model body, underneath the `@gen` decorator:

In [None]:
@gen
@jax.jit
def fast_beta_bernoulli_process(u):
    # Beta concentration parameters must be strictly positive, so the first one is 1.0 (not 0.0)
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v

The better way is to `jit` the outermost function we actually run, here `simulate`:

In [None]:
jitted = jax.jit(beta_bernoulli_process.simulate)
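Wrapping the outer call in `jax.jit` means the Python body is traced once per combination of input shapes and dtypes, and later calls reuse the compiled program. The sketch below, in plain JAX with a hypothetical `sample_pair` function, makes this visible with a Python-side counter that only increments while JAX is tracing. The returned dictionary is a pytree, so `jit` can return it directly.

```python
import jax

trace_count = 0


def sample_pair(key, u):
    """Hypothetical function: sample p ~ Beta(1, u) and v ~ Bernoulli(p)."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    v = jax.random.bernoulli(key_v, p)
    return {"p": p, "v": v}


jitted_pair = jax.jit(sample_pair)

key = jax.random.PRNGKey(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    out = jitted_pair(subkey, 1.0)  # same input shapes each time, so no retracing

print(trace_count)  # 1
```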

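We can then compare the speed of the three versions: the plain `simulate`, the `simulate` of the model whose body is jitted, and the jitted `simulate`.
To compare fairly, we first run the two compiled versions once, so that compilation time is not included in the timings.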

In [None]:
key, subkey = jax.random.split(key)
fast_beta_bernoulli_process.simulate(subkey, (1.0,))
key, subkey = jax.random.split(key)
jitted(subkey, (1.0,))

In [None]:
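# JAX dispatches work asynchronously, so block until each result is ready before the timer stops
key, subkey = jax.random.split(key)
%timeit jax.block_until_ready(beta_bernoulli_process.simulate(subkey, (1.0,)))
key, subkey = jax.random.split(key)
%timeit jax.block_until_ready(fast_beta_bernoulli_process.simulate(subkey, (1.0,)))
key, subkey = jax.random.split(key)
%timeit jax.block_until_ready(jitted(subkey, (1.0,)))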
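Outside IPython, where `%timeit` is unavailable, a comparable measurement can be sketched with `time.perf_counter` and `jax.block_until_ready`. The `time_call` helper and the `draw` function below are hypothetical; the helper makes one warm-up call so that compilation is excluded, mirroring the warm-up cell above, and it blocks on every result because JAX may return control before the computation has finished.

```python
import time

import jax


def time_call(fn, *args, n=50):
    """Hypothetical helper: mean seconds per call of fn(*args), excluding compilation."""
    jax.block_until_ready(fn(*args))  # warm-up call: compiles if fn is jitted
    start = time.perf_counter()
    for _ in range(n):
        jax.block_until_ready(fn(*args))  # wait until the result is actually computed
    return (time.perf_counter() - start) / n


def draw(key, u):
    """Hypothetical beta-Bernoulli draw used only for timing."""
    key_p, key_v = jax.random.split(key)
    p = jax.random.beta(key_p, 1.0, u)
    return jax.random.bernoulli(key_v, p)


key = jax.random.PRNGKey(0)
eager_s = time_call(draw, key, 1.0)
jitted_s = time_call(jax.jit(draw), key, 1.0)
print(f"eager: {eager_s * 1e6:.1f} us/call, jitted: {jitted_s * 1e6:.1f} us/call")
```

Timing a single key repeatedly is fine here because only the cost of the computation is being measured, not the randomness of its output.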