OK, I have a generative function. What can I do with it?

In [None]:
import jax
from genjax import flip
from genjax import beta
from genjax import bernoulli
from genjax import static_gen_fn

# Define a generative function: a beta-bernoulli process whose prior over
# the coin bias is controlled by `u`.
@static_gen_fn
def beta_bernoulli_process(u):
    """Sample a coin bias from a Beta prior, then flip that coin once.

    Random choices are recorded at two addresses:
      - "p": the coin bias drawn from Beta(1.0, u)
      - "v": the outcome of a single flip with bias p

    Args:
        u: Second concentration parameter of the Beta prior. Presumably a
            positive float (e.g. 0.5 in the driver below); confirm at call sites.

    Returns:
        The flip outcome sampled at address "v".
    """
    # Both Beta concentration parameters must be strictly positive; an
    # alpha of 0.0 gives an invalid distribution (NaN samples / scores).
    p = beta(1.0, u) @ "p"
    # `p` is a probability in [0, 1]. GenJAX's `bernoulli` is parameterized
    # by logits, so use `flip`, which takes a probability directly.
    v = flip(p) @ "v"
    return v

# We can:
# 1] Generate a traced sample
key = jax.random.PRNGKey(0)
# Arguments passed to the generative function (here: `u`); shared by every
# call below so the traced sample and the assessment use the same model.
args = (0.5,)
trace = jax.jit(beta_bernoulli_process.simulate)(key, args)
# 1.1] Print the return value
print(trace.get_retval())
print()
# 1.2] Print the choice map, i.e. the record of the internal random choices
#      made during the execution, keyed by address ("p" and "v")
print(trace.get_choices())
print()
print(trace.get_choices().get_submap("p"))
print()
# 2] Compute log probabilities
# 2.1] Print the log probability (score) of the trace
print(trace.get_score())
print()
# 2.2] Print the log probability of an observation under the model.
# `assess` scores a complete assignment of choices. Reusing the trace's own
# choices means the result should agree with `trace.get_score()` above.
# NOTE(review): assumes the GenJAX API where `assess(choices, args)` returns
# `(score, retval)`; confirm against the installed GenJAX version.
score, _ = beta_bernoulli_process.assess(trace.get_choices(), args)
print(score)

: 