# The generative function interface [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChiSym/genjax/blob/main/docs/cookbook/active/generative_function_interface.ipynb)

In [None]:
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"

In [None]:
import jax
from jax import jit

from genjax import ChoiceMapBuilder as C
from genjax import (
    bernoulli,
    beta,
    gen,
    pretty,
)
from genjax._src.generative_functions.static import MissingAddress
from genjax.incremental import Diff, NoChange, UnknownChange

key = jax.random.key(0)
pretty()


# Define a generative function
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2 * v

1) Generate a traced sample and construct choicemaps

There's an entire cookbook entry on this in `choicemap_creation_selection`.

In [None]:
key, subkey = jax.random.split(key)
trace = jax.jit(beta_bernoulli_process.simulate)(subkey, (0.5,))
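
Conceptually, `simulate` runs the model forward, records each random choice under its address, and keeps the log probability of those choices as the score. The pure-Python sketch below mimics this for the model above. It is not GenJAX code: the helper `simulate_by_hand` and its dictionary layout are hypothetical, and it assumes that `bernoulli` treats its argument as a probability (if the library treats it as logits, the Bernoulli term changes accordingly).

```python
import math
import random


def simulate_by_hand(rng, u):
    """Hypothetical stand-in for `simulate` on the beta-bernoulli model."""
    # Sample each choice in program order.
    p = rng.betavariate(1.0, u)
    v = 1 if rng.random() < p else 0
    # log Beta(p; 1, u) has the closed form log(u) + (u - 1) * log(1 - p).
    log_prior = math.log(u) + (u - 1.0) * math.log1p(-p)
    # Assumption: Bernoulli parameterized by the probability p.
    log_lik = math.log(p) if v == 1 else math.log1p(-p)
    return {
        "choices": {"p": p, "v": v},  # address -> sampled value
        "score": log_prior + log_lik,  # log probability of all choices
        "retval": 2 * v,  # return value of the model body
    }


result = simulate_by_hand(random.Random(0), 0.5)
print(result["choices"], result["score"], result["retval"])
```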

2) Compute log probabilities

2.1 Print the log probability of the trace

In [None]:
trace.get_score()

2.2 Print the log probability of an observation encoded as a ChoiceMap under the model

`assess` returns both the log probability and the return value.

In [None]:
chm = C["p"].set(0.5) ^ C["v"].set(1)
args = (0.5,)
beta_bernoulli_process.assess(chm, args)
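
The log probability of a complete ChoiceMap is the sum of the log densities of its choices, which can be checked by hand. The sketch below does this in plain Python for `p = 0.5`, `v = 1`, `u = 0.5`. The helper names are hypothetical, and the `as_logits` flag is there because this sketch does not assume which Bernoulli parameterization the library uses; compare the result against the output of `assess` to see which one applies.

```python
import math


def beta_logpdf(x, a, b):
    # log of x^(a-1) * (1-x)^(b-1) / B(a, b), with B computed via lgamma.
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_norm


def bernoulli_logpmf(v, param, as_logits=False):
    # Interpret `param` either as a probability or as logits (assumption flag).
    prob = 1.0 / (1.0 + math.exp(-param)) if as_logits else param
    return math.log(prob) if v == 1 else math.log1p(-prob)


def log_joint(p, v, u, as_logits=False):
    # Sum of the log densities of the two choices "p" and "v".
    return beta_logpdf(p, 1.0, u) + bernoulli_logpmf(v, p, as_logits)


print(log_joint(0.5, 1, 0.5))  # equals -1.5 * log(2) under the probability reading
print(log_joint(0.5, 1, 0.5, as_logits=True))
```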

Note that the ChoiceMap must be complete, i.e. it must assign a value to every random choice in the model. Otherwise `assess` raises a `MissingAddress` error.

In [None]:
chm_2 = C["v"].set(1)
try:
    beta_bernoulli_process.assess(chm_2, args)
except MissingAddress as e:
    print(e)

3) Generate a sample conditioned on the observations

We can also use a partial ChoiceMap as a constraint/observation and generate a full trace with these constraints.

In [None]:
key, subkey = jax.random.split(key)
partial_chm = C["v"].set(1)  # Creates a ChoiceMap of observations
args = (0.5,)
trace, weight = beta_bernoulli_process.importance(
    subkey, partial_chm, args
)  # Runs importance sampling

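This returns a pair containing the new trace and a weight. The weight is the log importance weight of the trace: the log probability of the trace under the model minus the log probability of the choices that were sampled rather than constrained.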

In [None]:
trace.get_choices()

In [None]:
weight
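
To see why such a weight is useful, the plain-Python sketch below averages importance weights to estimate the log marginal likelihood of the observation `v = 1`. It is a simplified illustration, not GenJAX code. It assumes that the unconstrained choice `p` is sampled from its prior, so that the weight reduces to the log likelihood of the observation, and that `bernoulli` takes a probability. The function name `log_marginal_estimate` is hypothetical.

```python
import math
import random


def log_marginal_estimate(u, num_samples, seed=0):
    """Estimate log p(v = 1) by averaging importance weights."""
    rng = random.Random(seed)
    log_weights = []
    for _ in range(num_samples):
        p = rng.betavariate(1.0, u)  # unconstrained choice, drawn from the prior
        log_weights.append(math.log(p))  # log p(v = 1 | p), the importance weight
    # Numerically stable log-mean-exp of the weights.
    m = max(log_weights)
    total = sum(math.exp(w - m) for w in log_weights)
    return m + math.log(total) - math.log(num_samples)


estimate = log_marginal_estimate(0.5, 20000)
exact = math.log(1.0 / (1.0 + 0.5))  # E[p] under Beta(1, u) is 1 / (1 + u)
print(estimate, exact)
```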

4) Update a trace

We can also update a trace. This is useful, for instance, as a performance optimization in Metropolis-Hastings algorithms, where most of the trace often doesn't change between steps.

We first define a model for which changing the argument will force a change in the trace.

In [None]:
@gen
def beta_bernoulli_process(u):
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return 2 * v

We then create a trace to be updated, along with the constraints.

In [None]:
key, subkey = jax.random.split(key)
jitted = jit(beta_bernoulli_process.simulate)
old_trace = jitted(subkey, (1.0,))
constraint = C["v"].set(1)

The update uses a form of incremental computation.
It works by tracking the differences between the old and new values of the arguments.
As with differentiation, this is achieved by providing, for each argument, a pair containing the new value and its change relative to the old value.

If there's no change for an argument, the change is set to NoChange.

In [None]:
arg_diff = (Diff(1.0, NoChange),)

If there is any change, the change is set to UnknownChange.

In [None]:
arg_diff = (Diff(5.0, UnknownChange),)
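
The idea behind these change tags can be illustrated without GenJAX. In the sketch below, a value carries a flag saying whether it changed, and a cached result is reused when it did not. The names `TaggedArg` and `maybe_recompute` are hypothetical and only convey the principle; this is not how the library implements `Diff`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TaggedArg:
    """Hypothetical analogue of a value paired with a change tag."""

    value: float
    changed: bool


def maybe_recompute(cached_result, arg, fn):
    # Reuse the cached result when the argument is tagged as unchanged.
    if not arg.changed:
        return cached_result, False
    # Otherwise recompute from the new value.
    return fn(arg.value), True


cached = 1.0  # result previously computed for the argument 1.0
print(maybe_recompute(cached, TaggedArg(1.0, changed=False), lambda x: x * x))
print(maybe_recompute(cached, TaggedArg(5.0, changed=True), lambda x: x * x))
```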

Finally, we call the `update` method on the trace, passing it a key, the constraint, and the argument diffs.

In [None]:
key, subkey = jax.random.split(key)
new_trace, weight_diff, ret_diff, discard_choice = jit(old_trace.update)(
    subkey, constraint, arg_diff
)

We can compare the old and new values for the samples and notice that they have not changed.

In [None]:
old_trace.get_choices() == new_trace.get_choices()

We can also see that the score has changed. In fact, we can check that the following relation holds: `new_score` = `old_score` + `weight_diff`.

In [None]:
weight_diff, old_trace.get_score() + weight_diff == new_trace.get_score()
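
The same relation can be worked out by hand. The sketch below scores a fixed pair of choices under the old argument (`u = 1.0`) and the new one (`u = 5.0`) and takes the difference. The values `p = 0.3` and `v = 1` are hypothetical stand-ins for the choices in the trace, the helper `score` is not part of GenJAX, and the Bernoulli is assumed to take a probability.

```python
import math


def score(p, v, u):
    # log Beta(p; 1, u) in closed form: log(u) + (u - 1) * log(1 - p).
    log_prior = math.log(u) + (u - 1.0) * math.log1p(-p)
    # Assumption: Bernoulli parameterized by the probability p.
    log_lik = math.log(p) if v == 1 else math.log1p(-p)
    return log_prior + log_lik


p, v = 0.3, 1  # hypothetical choices, kept fixed by the update
old_score = score(p, v, 1.0)  # score under the old argument
new_score = score(p, v, 5.0)  # score under the new argument
weight_diff = new_score - old_score
print(weight_diff, math.isclose(old_score + weight_diff, new_score))
```

Since scores are floating-point numbers, comparing them with a tolerance (for example `math.isclose`) is usually more robust than an exact `==`.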

5) A few more convenient methods

5.1 `propose`

`propose` takes the same inputs as `simulate` but returns the sample, the score, and the return value.

In [None]:
key, subkey = jax.random.split(key)
sample, score, retval = jit(beta_bernoulli_process.propose)(subkey, (0.5,))
sample, score, retval

5.2 `get_gen_fn`

It returns the generative function that produced the trace.

In [None]:
trace.get_gen_fn()

5.3 `get_args`

It returns the arguments passed to the generative function used to produce the trace

In [None]:
trace.get_args()

5.4 `get_subtrace`

It takes a `StaticAddress` as argument and returns the sub-trace of the trace rooted at that address.

In [None]:
subtrace = trace.get_subtrace("p")
subtrace, subtrace.get_choices()