Ok I have a generative function. What can I do with it?

In [1]:
import jax
from genjax import beta
from genjax import bernoulli
from genjax import gen

# Define a generative function
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2*v

Given our generative function, the first thing we can do is to generate a traced sample.

In [2]:
# 1] Generate a traced sample
key = jax.random.PRNGKey(0)
trace = jax.jit(beta_bernoulli_process.simulate)(key, (0.5,))

# 1.1] Print the return value
print(trace.get_retval())
print()

# 1.2] Print the choice_map, i.e. the list of internal random choices made during the execution
print(trace.get_sample())
# Alternative way to print the choice_map
print(trace.get_choices())
# Print specific subparts of the choice_map
print(trace.get_sample()["p"])
print(trace.get_sample()["v"])
print()

2

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=<jax.Array(0.97251576, dtype=float32)>)), c2=StaticChm(addr='v', c=ValueChm(v=<jax.Array(1, dtype=int32)>)))
XorChm(c1=StaticChm(addr='p', c=ValueChm(v=<jax.Array(0.97251576, dtype=float32)>)), c2=StaticChm(addr='v', c=ValueChm(v=<jax.Array(1, dtype=int32)>)))
0.97251576
1



Then, we can create a choice_map of observations and perform diverse operations on it.

In [3]:
# Create a choice_map of observations
from genjax import ChoiceMapBuilder as C

# We can set the value of an address in the choice_map
chm = C["p"].set(0.5) ^ C["v"].set(1)
print(chm)

# A different way to achieve the same result
chm = C["p"].set(0.5).at["v"].set(1)
print(chm)

# This also works for hierarchical addresses
chm = C["p", "v"].set(1) 
print(chm)

# We can also directly set a value in the choice_map
chm = C.v(5.0)
print(chm)

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=0.5)), c2=StaticChm(addr='v', c=ValueChm(v=1)))
OrChm(c1=StaticChm(addr='v', c=ValueChm(v=1)), c2=StaticChm(addr='p', c=ValueChm(v=0.5)))
StaticChm(addr='p', c=StaticChm(addr='v', c=ValueChm(v=1)))
ValueChm(v=5.0)


For instance, we can compute log probabilities

In [4]:
# 2] Compute log probabilities

# 2.1] Print the log probability of the trace
print(trace.get_score())
print()

# 2.2] Print the log probability of an observation encoded as a ChoiceMap under the model
# It returns both the log probability and the return value
chm = C["p"].set(0.5) ^ C["v"].set(1)
args = (0.5,)
print(beta_bernoulli_process.assess(chm, args))

# Note that the ChoiceMap should be complete, i.e. all random choices should be observed
chm_2 = C["v"].set(1)
try: 
    beta_bernoulli_process.assess(chm_2, args)
except ValueError as e:
    print(e)

0.7831962

(Array(-0.8206506, dtype=float32), Array(2, dtype=int32))
Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.


We can also use a partial ChoiceMap as a constraint/observation and generate a full trace with these constraints.

In [5]:
# 3] Generate a sample conditioned on the observations
key = jax.random.PRNGKey(42)
partial_chm =  C["v"].set(1) # Creates a ChoiceMap of observations
args = (0.5,)
trace, weight = beta_bernoulli_process.importance(key, partial_chm, args) # Runs importance sampling

# This returns a pair containing the new trace and the log probability of produced trace under the model
print(trace.get_sample())
print(weight)

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=<jax.Array(0.4484125, dtype=float32)>)), c2=StaticChm(addr='v', c=ValueChm(v=1)))
-0.49386734


We can also update a trace. This is for instance useful for performance optimizations in MH algorithms where often most of the trace doesn't change between time steps.

In [6]:
# 4] Update a trace.
from genjax.incremental import Diff, NoChange, UnknownChange
from genjax import bernoulli, gen, GenericProblem
from jax import jit

# Define a model for which changing the argument will force a change in the trace.
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2*v

key = jax.random.PRNGKey(42)
jitted = jit(beta_bernoulli_process.simulate)
old_trace = jitted(key, (0.5,))
constraint =  C["v"].set(1)
# Update uses incremental computation.
# It works by tracking the differences between the old new values for arguments.
# Just like for differentiation, it can be achieved by providing for each argument a tuple containing the new value and its change compared to the old value.

# If there's no change for an argument, the change is set to NoChange.
arg_diff = (Diff(1.0, NoChange),) 
# If there's a change, the change is set to the difference between the new and old value.
arg_diff = (Diff(3.0, 2.0),)
# If there's an unknown change, the change is set to UnknownChange.
arg_diff = (Diff(1.0, UnknownChange),)
# we bundle together the arguments and the constraint as a GenericProblem, a simple instance of an UpdateProblem.
update = GenericProblem(arg_diff, constraint)

jitted_update = jit(beta_bernoulli_process.update)

new_trace, weight_diff, ret_diff, discard_choice = jitted_update(
    key, 
    old_trace, 
    update
    )

# print the old trace with a message
print(old_trace.get_sample())
print("Old value for p:", old_trace.get_sample()["p"])
print("New value for p:", new_trace.get_sample()["p"]) 
print()
print(weight_diff)
print()
print(ret_diff)
print()
print(discard_choice)

%timeit jitted(key, (0.5,))
%timeit jitted_update(key, old_trace, update)

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=<jax.Array(0.4484125, dtype=float32)>)), c2=StaticChm(addr='v', c=ValueChm(v=<jax.Array(1, dtype=int32)>)))
Old value for p: 0.4484125
New value for p: 0.4484125

0.3956688

Diff(primal=<jax.Array(2, dtype=int32)>, tangent=_UnknownChange())

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=EmptyProblem())), c2=StaticChm(addr='v', c=ValueChm(v=<jax.Array(1, dtype=int32)>)))
42.8 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
78.4 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
# 5] A few more convenient methods
# 5.1] propose
# It uses the same inputs as simulate but returns the sample, the score and the return value
sample, score, retval = jit(beta_bernoulli_process.propose)(key, (0.5,))
print(sample)
print()
print(score)
print() 
print(retval)
print()

# 5.2] get_gen_fn
# It returns the generative function that produced the trace
gen_fn = trace.get_gen_fn()
print(gen_fn)
print()

# 5.2] get_args
# It returns the arguments passed to the generative function used to produce the trace
args = trace.get_args()
print(args)
print()

# 5.3] get_subtrace
# It takes a `StaticAddress` as argument and returns the sub-trace of a trace rooted at these addresses
subtrace = trace.get_subtrace(("p",))
print(subtrace.get_sample())

XorChm(c1=StaticChm(addr='p', c=ValueChm(v=<jax.Array(0.97251576, dtype=float32)>)), c2=StaticChm(addr='v', c=ValueChm(v=<jax.Array(1, dtype=int32)>)))

0.7831962

2

StaticGenerativeFunction(source=Closure(dyn_args=(), fn=<function beta_bernoulli_process at 0x2ca085ee0>))

(Array(0.5, dtype=float32, weak_type=True),)

ValueChm(v=<jax.Array(0.97251576, dtype=float32)>)
