Ok I have a generative function. What can I do with it?

In [3]:
import jax
from genjax import beta
from genjax import bernoulli
from genjax import gen

# Define a generative function
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v" # sweet
    return 2*v

# We can:
# 1] Generate a traced sample
key = jax.random.PRNGKey(0)
trace = jax.jit(beta_bernoulli_process.simulate)(key, (0.5,))
# 1.1] Print the return value
print(trace.get_retval())
print()
# 1.2] Print the choice_map, i.e. the list of internal random choices made during the execution
print(trace.get_sample())
print(trace.get_sample()["p"])
print(trace.get_sample()["v"])
print()
# 2] Create a choice_map of observations
chm = {"p": 0.5, "v": 1}
# 3] Compute log probabilities
# 3.1] Print the log probability of the trace
print(trace.get_score())
print()
# 3.2] Print the log probability of an observation under the model
print(trace.score(chm))

2

ChoiceMap(((Empty ⊕ Static(p => Value(0.9725157618522644))) ⊕ Static(v => Value(1))))
0.97251576
1

0.7831962



TypeError: 'jaxlib.xla_extension.ArrayImpl' object is not callable