OK, I have a generative function. What can I do with it?

In [51]:
import jax
from genjax import beta
from genjax import bernoulli
from genjax import gen

# Define a generative function
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"   # random choice at address "p"
    v = bernoulli(p) @ "v"   # genjax's bernoulli takes logits, so P(v = 1) = sigmoid(p)
    return 2 * v
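Written out, the model's joint log density is the sum of one term per random choice. One detail worth flagging: GenJAX's `bernoulli` is parameterized by logits (the scores printed in the cells below are consistent with this), so the beta draw $p$ acts as a logit and $P(v = 1) = \sigma(p)$. Using $B(1, u) = 1/u$:

```latex
\log p(p, v; u)
  = \underbrace{\log u + (u - 1)\log(1 - p)}_{\log \mathrm{Beta}(p;\,1,\,u)}
  + \underbrace{v \log \sigma(p) + (1 - v)\log\bigl(1 - \sigma(p)\bigr)}_{\log \mathrm{Bernoulli}(v;\,\mathrm{logits}=p)},
  \qquad \sigma(x) = \frac{1}{1 + e^{-x}}
```

For example, with $u = 0.5$, $p = 0.5$, $v = 1$ this gives $\log 0.5 - 0.5 \log 0.5 + \log \sigma(0.5) \approx -0.3466 - 0.4741 = -0.8207$.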

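Given our generative function, the first thing we can do is generate a traced sample: run the function and record every random choice it makes, together with its return value and its score.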

In [97]:
# 1] Generate a traced sample
key = jax.random.PRNGKey(0)
trace = jax.jit(beta_bernoulli_process.simulate)(key, (0.5,))

# 1.1] Print the return value
print(trace.get_retval())
print()

# 1.2] Print the choice map, i.e. the mapping from addresses ("p", "v") to the random choices made during execution
print(trace.get_sample())
print(trace.get_sample()["p"])
print(trace.get_sample()["v"])
print()

2

ChoiceMap(((Empty ⊕ Static(p => Value(0.9725157618522644))) ⊕ Static(v => Value(1))))
0.97251576
1


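To make the pieces of a trace concrete, here is a minimal plain-Python sketch of what `simulate` records for this particular model. It is not GenJAX code: `simulate` and `log_joint` below are hypothetical stand-ins, a dict plays the role of the choice map, and it assumes (as noted for the model above) that `bernoulli` takes logits. Plugging the printed trace (p ≈ 0.9725, v = 1) into `log_joint` reproduces the score printed in the next cell (≈ 0.7832).

```python
import math
import random

def log_joint(p, v, u):
    # log Beta(p; 1, u) = log(u) + (u - 1) * log(1 - p), because B(1, u) = 1 / u
    log_prior = math.log(u) + (u - 1) * math.log1p(-p)
    # log Bernoulli(v; logits=p) = log sigmoid(p) if v == 1 else log sigmoid(-p)
    z = p if v == 1 else -p
    return log_prior - math.log1p(math.exp(-z))

def simulate(rng, u):
    """Hypothetical plain-Python stand-in for beta_bernoulli_process.simulate."""
    p = rng.betavariate(1.0, u)                                # address "p"
    v = 1 if rng.random() < 1.0 / (1.0 + math.exp(-p)) else 0  # address "v"
    choices = {"p": p, "v": v}                                 # plays the role of the choice map
    return {"choices": choices, "retval": 2 * v, "score": log_joint(p, v, u)}

trace = simulate(random.Random(0), 0.5)
print(trace["retval"], trace["choices"], trace["score"])
```

The return value is derived from the choices, and the score is just the joint log density evaluated at them, which is why GenJAX can store all three together in one trace object.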

We can also build a ChoiceMap by hand, e.g. to encode observations, and pass it to several of the generative function's methods.

In [92]:
# 2] Create a ChoiceMap of observations, starting from the empty choice map ChoiceMap.n
from genjax import ChoiceMap
chm = ChoiceMap.n.at["p"].set(0.5).at["v"].set(1)
chm

ChoiceMap((Static(v => Value(1)) + (Static(p => Value(0.5)) + Empty)))

For instance, we can compute log probabilities:

In [98]:
# 3] Compute log probabilities

# 3.1] Print the score of the trace, i.e. the log density of its random choices
print(trace.get_score())
print()

# 3.2] Evaluate an observation encoded as a ChoiceMap under the model.
# assess returns a pair (log density of the choices, return value)
chm = ChoiceMap.n.at["p"].set(0.5).at["v"].set(1)
args = (0.5,)
print(beta_bernoulli_process.assess(chm, args))

# assess requires a complete ChoiceMap: every random choice must be constrained,
# otherwise it raises an error
chm_2 = ChoiceMap.n.at["v"].set(1)
try:
    beta_bernoulli_process.assess(chm_2, args)
except ValueError:
    print("The ChoiceMap is incomplete")

0.7831962

(Array(-0.8206506, dtype=float32), Array(2, dtype=int32))
The ChoiceMap is incomplete
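The numbers above can be checked by hand. The sketch below is a hypothetical plain-Python analogue of `assess` for this model only (`assess`, `is_complete` and `log_joint` are illustrative names, not GenJAX API), assuming `bernoulli` takes logits. It recomputes both the `assess` result and the trace score, and mimics the completeness check.

```python
import math

def log_joint(p, v, u):
    # log Beta(p; 1, u) + log Bernoulli(v; logits=p)
    log_prior = math.log(u) + (u - 1) * math.log1p(-p)
    z = p if v == 1 else -p
    return log_prior - math.log1p(math.exp(-z))

def is_complete(choices):
    # The model makes exactly two random choices, at addresses "p" and "v"
    return {"p", "v"} <= set(choices)

def assess(choices, u):
    """Hypothetical plain-Python stand-in for beta_bernoulli_process.assess."""
    if not is_complete(choices):
        raise ValueError("assess needs a value for every random choice")
    p, v = choices["p"], choices["v"]
    return log_joint(p, v, u), 2 * v  # (log density, return value)

print(assess({"p": 0.5, "v": 1}, 0.5))        # ≈ (-0.8207, 2), matching the cell above
print(log_joint(0.9725157618522644, 1, 0.5))  # ≈ 0.7832, the score of `trace`
print(is_complete({"v": 1}))                  # False
```

Because `p` is continuous, these are log densities rather than log probabilities, which is why the trace score can be positive.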


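We can also use a partial ChoiceMap as a set of constraints (observations) and generate a complete trace that agrees with them: the constrained choices are fixed to the given values, the remaining ones are sampled, and an importance weight is returned alongside the trace.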

In [99]:
# 4] Generate a trace conditioned on the observations
key = jax.random.PRNGKey(42)
partial_chm = ChoiceMap.n.at["v"].set(1)
args = (0.5,)

# importance returns a pair (trace, importance weight). Here p is sampled from its
# prior, so the weight is the log probability of the constraint v = 1 given p,
# not the score of the whole trace
tr, weight = beta_bernoulli_process.importance(key, partial_chm, args)

print(tr.get_sample())
print(weight)

ChoiceMap(((Empty ⊕ Static(p => Value(0.4484125077724457))) ⊕ Static(v => Value(1))))
-0.49386734
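As a sanity check on that weight, here is a hypothetical plain-Python sketch of importance for this model (the `importance` and `log_sigmoid` below are illustrative, not GenJAX API). It assumes the unconstrained choices are proposed from the prior, so the weight is the sum of the log densities of the constrained choices only, and that `bernoulli` takes logits. With `v = 1` constrained, the weight is `log sigmoid(p)`, and `log sigmoid(0.4484...)` reproduces the printed `-0.49387`.

```python
import math
import random

def log_sigmoid(z):
    # log(1 / (1 + exp(-z))); adequate for the small |z| that occur here
    return -math.log1p(math.exp(-z))

def importance(rng, constraints, u):
    """Hypothetical plain-Python stand-in for beta_bernoulli_process.importance."""
    weight = 0.0
    if "p" in constraints:
        p = constraints["p"]
        weight += math.log(u) + (u - 1) * math.log1p(-p)  # log Beta(p; 1, u)
    else:
        p = rng.betavariate(1.0, u)                       # sampled from the prior, no weight
    if "v" in constraints:
        v = constraints["v"]
        weight += log_sigmoid(p if v == 1 else -p)        # log Bernoulli(v; logits=p)
    else:
        v = 1 if rng.random() < math.exp(log_sigmoid(p)) else 0
    return {"p": p, "v": v}, weight

print(log_sigmoid(0.4484125077724457))  # ≈ -0.49387, the weight printed above

# Self-normalized importance sampling estimate of E[p | v = 1] with u = 0.5
rng = random.Random(0)
draws = [importance(rng, {"v": 1}, 0.5) for _ in range(10_000)]
ws = [math.exp(w) for _, w in draws]
post_mean = sum(w * c["p"] for (c, _), w in zip(draws, ws)) / sum(ws)
print(post_mean)
```

When every choice is constrained, nothing is sampled and the weight reduces to the full log density, i.e. the same number `assess` returns.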


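We can also update an existing trace, for example by constraining some of its random choices to new values or by changing its arguments. This is useful in Metropolis-Hastings (MH) algorithms, where typically only a small part of the trace changes between steps, so update can avoid recomputing the parts that did not change.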

In [115]:
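# 5] Update a trace
from genjax.incremental import Diff, NoChange

key = jax.random.PRNGKey(42)
old_trace = trace
constraint = ChoiceMap.n.at["v"].set(1)

# argdiffs wrap each argument together with change information. NoChange says that
# u is the same value the old trace was generated with (0.5), so update can treat it
# as unchanged and only account for the choices affected by the constraint
arg_diff = (Diff(0.5, NoChange),)

# weight: incremental importance weight of the new trace relative to the old one
new_trace, weight, _, _ = beta_bernoulli_process.update(
    key,
    old_trace,
    constraint,
    arg_diff,
)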
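To see why an update can be cheap, consider the case above: `u` is unchanged and only `v` is constrained. The sketch below is a hypothetical plain-Python illustration (`update_v`, `log_prior` and `log_lik` are not GenJAX API), assuming `bernoulli` takes logits and that, when no new choices are sampled, the update weight is the new score minus the old score. The `log Beta(p; 1, u)` term is identical in both traces and cancels, so only the Bernoulli term is recomputed.

```python
import math
import random

def log_prior(p, u):
    # log Beta(p; 1, u)
    return math.log(u) + (u - 1) * math.log1p(-p)

def log_lik(p, v):
    # log Bernoulli(v; logits=p)
    z = p if v == 1 else -p
    return -math.log1p(math.exp(-z))

def update_v(old_choices, new_v, u):
    """Hypothetical sketch: constrain "v" to new_v with u unchanged (NoChange).
    The prior term cancels, so the weight only involves the changed choice."""
    p, old_v = old_choices["p"], old_choices["v"]
    new_choices = {"p": p, "v": new_v}
    weight = log_lik(p, new_v) - log_lik(p, old_v)
    discard = {"v": old_v}  # the old value displaced by the constraint
    return new_choices, weight, discard

# Usage: one MH step that proposes flipping v. The flip is its own inverse, so the
# acceptance probability is min(1, exp(weight)).
rng = random.Random(0)
state = {"p": 0.9725157618522644, "v": 1}
proposal, log_alpha, _ = update_v(state, 1 - state["v"], 0.5)
if rng.random() < math.exp(min(0.0, log_alpha)):
    state = proposal
print(state, log_alpha)
```

Constraining `v` to the value it already has, as in the cell above, leaves the choices unchanged, and in this sketch the weight is exactly 0.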

