In [1]:
%config InlineBackend.figure_format = 'svg'

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import dataclasses
import genjax
from genjax import Diff
from typing import List, Tuple, Any

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)

The `update` interface method for generative functions defines an update operation on traces produced by generative functions. 

`update` lets the user provide new constraints as well as new arguments. It returns an updated trace which is consistent with the new constraints, together with an incremental importance weight which measures the difference between the score of the new trace and the score of the old trace under the model. `update` is used to implement many families of iterative MCMC inference algorithms.

The specification of `update` only requires that a modeling language support the above behavior - nonetheless, modeling languages can implement `update` with custom optimizations to reduce the cost of repeatedly calling `update` (e.g. inside an iterative MCMC inference procedure).

In this notebook, we'll focus on these optimization opportunities within the implementation of `update` for the `BuiltinGenerativeFunction` language. We'll describe a system which supports incremental computing by propagating change information (called `Diff` in the codebase).^[Think of a value of `Diff` type as representing a new value $v^\prime$ using a decomposition $v^\prime = v \oplus dv$ where $dv$ is the change to the value and $v$ is the original value.]

While we'll be focused on the distribution and builtin languages, this system is also applicable to the combinator implementations of `update`. In another notebook, we'll see how the incremental computing system can be used to efficiently compute `update` for `UnfoldCombinator`.

## What is `update` used for?

Before we discuss how `update` can be optimized by a generative function implementor, it's worth constructing a simple example which shows how `update` is used and why optimizing it is worthwhile.

One common usage of `update` is in MCMC algorithm kernels. MCMC is often repeatedly applied to generate a chain of samples: any optimization opportunities that we identify and take advantage of will provide runtime gains which are multiplied over the length of the chain.

Let's examine this scenario using a pedagogical example - remember that the potential optimization pattern (based upon random variable dependency information) we'll describe extends to all generative functions.

### Pedagogical example

Consider the following generative function:

In [3]:
@genjax.gen
def model(x):
    a = genjax.trace("a", genjax.Normal)(x, 1.0)
    b = genjax.trace("b", genjax.Normal)(x, 1.0)
    c = genjax.trace("c", genjax.Normal)(a + b, 1.0)
    return c

The variable dependency graph is shown below.

```{mermaid}
flowchart LR
  x[Argument] --> a[a]
  x --> b[b]
  a --> c[c]
  b --> c
  c --> r[Return]
```

Now, when we simulate a trace from this model - we get choices for `"a"`, `"b"`, and `"c"`.

In [4]:
key, tr = model.simulate(key, (2.0,))
tr

Iterative inference techniques like Metropolis-Hastings (and other MCMC methods) start with an initial trace, propose an update to the trace using a proposal, and then compute a criterion for accepting or rejecting the update.

In Metropolis-Hastings, the criterion involves an _accept-reject ratio_ computation - which requires computing the probability of transitioning from the current trace to the new trace, as well as the probability of transitioning from the new trace back to the current trace, under a kernel defined by the algorithm.
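To make the accept-reject computation concrete before looking at the library code, here is a minimal sketch in plain JAX, without GenJAX. The names `log_target` and `mh_step`, the standard normal target, and the random walk proposal are all assumptions chosen for illustration. The quantity `weight` plays the role that the `update` weight plays later in this notebook.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_target(x):
    # Hypothetical target density: a standard normal, in log space.
    return norm.logpdf(x, 0.0, 1.0)


def mh_step(key, x, step_size=0.5):
    key, prop_key, accept_key = jax.random.split(key, 3)
    # Forward proposal: a Gaussian random walk centered on the current state.
    x_new = x + step_size * jax.random.normal(prop_key)
    fwd = norm.logpdf(x_new, x, step_size)  # log q(x_new | x)
    bwd = norm.logpdf(x, x_new, step_size)  # log q(x | x_new)
    # Score difference between the proposed state and the current state.
    weight = log_target(x_new) - log_target(x)
    # Log accept-reject ratio.
    alpha = weight - fwd + bwd
    accept = jnp.log(jax.random.uniform(accept_key)) < alpha
    return key, jnp.where(accept, x_new, x), accept


key = jax.random.PRNGKey(0)
x0 = jnp.array(0.3)
key, x1, accepted = mh_step(key, x0)
```

For this symmetric proposal the forward and backward terms cancel, but keeping them explicit mirrors the general form of the ratio.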

The library implementation of Metropolis-Hastings is shown below - `MetropolisHastings.apply` shows the main content of the algorithm (it's safe to ignore other methods for now).

In [5]:
@dataclasses.dataclass
class MetropolisHastings(genjax.MCMCKernel):
    selection: genjax.Selection
    proposal: genjax.GenerativeFunction

    def flatten(self):
        return (), (self.selection, self.proposal)

    def apply(self, key, trace: genjax.Trace, proposal_args: Tuple):
        model = trace.get_gen_fn()
        model_args = trace.get_args()
        proposal_args_fwd = (trace.get_choices(), *proposal_args)
        key, proposal_tr = self.proposal.simulate(key, proposal_args_fwd)
        fwd_weight = proposal_tr.get_score()
        diffs = jtu.tree_map(Diff.no_change, model_args)
        key, (_, weight, new, discard) = model.update(
            key, trace, proposal_tr.get_choices(), diffs
        )
        proposal_args_bwd = (new, *proposal_args)
        key, (bwd_weight, _) = self.proposal.importance(key, discard, proposal_args_bwd)
        alpha = weight - fwd_weight + bwd_weight
        key, sub_key = jax.random.split(key)
        check = jnp.log(jax.random.uniform(sub_key)) < alpha
        return (
            key,
            jax.lax.cond(
                check,
                lambda *args: (new, True),
                lambda *args: (trace, False),
            ),
        )

    def reversal(self):
        return self

This computation involves `update` - which _incrementally_ updates a trace to be consistent with new arguments and constraints, and computes an importance weight (the difference between the trace's new score and the old score).

::: {.callout-important}

In the invocation of `update`, there's an interesting not-yet-explained argument: `diffs` - a tuple of `Diff` values, which represent _changes_ to the original arguments of the call which produced the trace which we are attempting to update. We'll come back to these values in a moment.

:::

If we naively evaluate the required log probability by re-evaluating the entire model - we're performing extra computation. We can see this by considering a specific target address - let's consider `"a"`. If the update changes `"a"`, what other generative function calls do we need to visit to compute the correct update - both to the trace, and the importance weight? 

The graph below shows the answer.

```{mermaid}
flowchart LR
  x[Argument] --> a[<b>a</b>]
  x --> b[b]
  a --> c[<b>c</b>]
  b --> c
  c --> r[Return]
  style a fill:#f9f,stroke:#333,stroke-width:4px
  style c fill:#f9f,stroke:#333,stroke-width:4px
```

An update to `"a"` requires that we re-evaluate the log probability at `"c"` because the return value of the generative function call at `"a"` flows into the generative function call at `"c"` - but we do not need to re-visit `"b"` because none of the values which flow into `"b"` have changed. 

When computing the weight difference, unchanged sites thus contribute nothing.^[The important idea is that tracking what values have changed allows us to identify what parts of the computation graph are required - and what parts do not need to be re-visited or re-computed.]
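The cancellation can be checked numerically with a small sketch in plain JAX. The helper names `full_score` and `incremental_weight` and the specific choice values are assumptions for illustration. The densities match the `model` defined earlier: unit-variance normals for `"a"`, `"b"`, and `"c"`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def full_score(x, a, b, c):
    # Joint log density of the model: visits every address.
    return (
        norm.logpdf(a, x, 1.0)
        + norm.logpdf(b, x, 1.0)
        + norm.logpdf(c, a + b, 1.0)
    )


def incremental_weight(x, a_old, a_new, b, c):
    # Visits only "a" and "c". The term for "b" is identical in the
    # new and old scores, so it cancels and is never evaluated.
    new = norm.logpdf(a_new, x, 1.0) + norm.logpdf(c, a_new + b, 1.0)
    old = norm.logpdf(a_old, x, 1.0) + norm.logpdf(c, a_old + b, 1.0)
    return new - old


# Hypothetical trace values; only "a" is changed by the update.
x, a_old, a_new, b, c = 2.0, 1.5, 2.5, 1.8, 4.0
naive = full_score(x, a_new, b, c) - full_score(x, a_old, b, c)
incremental = incremental_weight(x, a_old, a_new, b, c)
```

Both routes give the same weight up to floating point error, but the incremental one evaluates four densities instead of six. In larger models with many unaffected sites, the gap grows accordingly.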

## Change information

The specification of `update` doesn't require that an implementation track or use the change information - but generative function implementations can choose to optimize their `update` implementation. 

With that in mind, several of the languages which GenJAX exposes can be instructed to perform optimized `update` computations using `Diff` values.

A `Diff` value consists of a base value `v` and a value of `Change` type, which represents the change to the base value. The new argument value for `update` is given by $v \oplus dv$ where `dv :: Change`.

The $\oplus$ operation must be appropriately defined for the change type lattice - we implement this operation for common change types in GenJAX, but users can define their own change types for `Pytree` data classes.

In [6]:
genjax.Diff.new(5.0, genjax.NoChange)
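As a rough mental model, the pairing of a value with a change tag can be sketched in pure Python. The names `ToyChange`, `ToyDiff`, and `join` are hypothetical and are not the GenJAX types. The sketch also assumes a two-point lattice, and assumes that an untagged value defaults to "unknown change".

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any


class ToyChange(Enum):
    # A two-point change lattice: NO_CHANGE sits below UNKNOWN_CHANGE.
    NO_CHANGE = 0
    UNKNOWN_CHANGE = 1


@dataclass(frozen=True)
class ToyDiff:
    val: Any
    # Conservative default: assume the value may have changed.
    change: ToyChange = ToyChange.UNKNOWN_CHANGE

    def has_changed(self) -> bool:
        return self.change is not ToyChange.NO_CHANGE


def join(*changes: ToyChange) -> ToyChange:
    # Least upper bound: the result is NO_CHANGE only if every input is.
    if any(c is ToyChange.UNKNOWN_CHANGE for c in changes):
        return ToyChange.UNKNOWN_CHANGE
    return ToyChange.NO_CHANGE


unchanged = ToyDiff(5.0, ToyChange.NO_CHANGE)
maybe_changed = ToyDiff(1.0)
combined = join(unchanged.change, maybe_changed.change)
```

The `join` operation is what lets change information flow through a computation: an output can only be tagged as unchanged when all of its inputs are.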

### Diffs for distributions

Let's explore the basics with distributions.

In [7]:
key, dist_tr = genjax.Normal.simulate(key, (0.0, 1.0))
dist_tr

In [8]:
# dist_tr.update is equivalent to model.update(key, tr, ...)
key, (ret_diff, w, tr, d) = dist_tr.update(
    key,
    genjax.EmptyChoiceMap(),
    (
        genjax.Diff.new(1.0, genjax.UnknownChange),
        genjax.Diff.new(1.0, genjax.NoChange),
    ),
)

The return values do not change.

In [9]:
(dist_tr.get_retval(), ret_diff.val)

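The weight is non-zero because the first argument is marked `UnknownChange` (its value moves from `0.0` to `1.0`), implying that we must re-evaluate the log probability of the existing choice under the new arguments.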

In [10]:
w

What does the code look like when there is no new constraint and neither argument changes?

In [11]:
# dist_tr.update is equivalent to model.update(key, tr, ...)
jaxpr = jax.make_jaxpr(dist_tr.update)(
    key,
    genjax.EmptyChoiceMap(),
    (
        genjax.Diff.new(0.0, genjax.NoChange),
        genjax.Diff.new(1.0, genjax.NoChange),
    ),
)
jaxpr

As expected, no computation is required - so the flattened arguments are just forwarded to the return.
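The same effect can be reproduced in plain JAX under one assumption: the change tag is a Python-level value that is resolved while JAX traces the function. In the sketch below, `toy_update` and its `mean_changed` flag are hypothetical stand-ins, not the GenJAX implementation.

```python
import jax
from jax.scipy.stats import norm


def toy_update(x, old_score, new_mean, mean_changed: bool):
    # `mean_changed` is a plain Python bool, so this branch is decided
    # at trace time and leaves no trace in the staged program.
    if not mean_changed:
        return 0.0, old_score
    new_score = norm.logpdf(x, new_mean, 1.0)
    return new_score - old_score, new_score


unchanged_jaxpr = jax.make_jaxpr(
    lambda x, s, m: toy_update(x, s, m, False)
)(0.5, -1.0, 0.0)
changed_jaxpr = jax.make_jaxpr(
    lambda x, s, m: toy_update(x, s, m, True)
)(0.5, -1.0, 1.0)

n_unchanged = len(unchanged_jaxpr.jaxpr.eqns)
n_changed = len(changed_jaxpr.jaxpr.eqns)
```

Printing `unchanged_jaxpr` shows a program with no equations that simply forwards `old_score`, while `changed_jaxpr` contains the arithmetic for the density evaluation.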

### Diff propagation in the `BuiltinGenerativeFunction` language

The builtin language exposes a new set of complexities. 

Now, programs represent generative functions. If we were designing a compiler, we'd need to explore an analysis which tracks the dataflow of values, and determines how changes to values propagate. With this analysis, we can identify what call sites are affected by changes, and implement our optimizations accordingly.

This is a rough description of how you would implement change propagation and optimization in a traditional compiler. But with JAX, we can just write an interpreter to do the analysis and optimization for us. The interpreter will operate on values lifted to `Diff` types, with the set of change values defined by a lattice of change types.

[This interpreter is the propagation interpreter.](https://github.com/probcomp/genjax/blob/main/src/genjax/_src/core/propagate.py)^[We discuss this interpreter in more detail in the [Implementing the builtin modeling language](/advanced/impl_builtin_language/impl_builtin_language.ipynb) notebook.] We won't show the full complexity of this interpreter in the notebook - but we'll briefly walk through some of the components.

This interpreter operates on `Diff` values, and supports customization via propagation rules which can be provided to the interpreter.

#### Propagation rules

In [12]:
genjax.diff_propagation_rules

What is the `fallback_rule`? This rule is defined to operate on any primitive operation which is not explicitly keyed into the `rule_set`. In the `rule_set` for `Diff` propagation, we only provide overloads for special JAX operations - otherwise, the primitive uses the `fallback_rule` (shown below as `fallback_diff_rule`) to attempt to propagate `Diff` values.

In [None]:
def fallback_diff_rule(prim: Any, incells: List[Diff], outcells: Any, **params):
    if all(map(lambda v: v.top(), incells)):
        in_vals = list(map(lambda v: v.get_val(), incells))
        out = prim.bind(*in_vals, **params)
        if all(map(lambda v: v.get_change() == NoChange, incells)):
            new_out = jtu.tree_map(lambda v: Diff.new(v, change=NoChange), out)
        else:
            new_out = jtu.tree_map(lambda v: Diff.new(v), out)
        if not prim.multiple_results:
            new_out = [new_out]
    else:
        new_out = outcells
    return incells, new_out, None

Here, the `incells` define edges propagating into the primitive operation `prim` in the `Jaxpr` representation of our program.

`v.top()` is a call to a method which returns information about the base value and the `Change` value in the `Diff` - the `Change` values are defined on a `Change` type lattice. For now, we'll just consider two values: `NoChange` and `UnknownChange`.

In practice, `v.top()` just checks that the base value is not `None` (i.e. propagation has succeeded in learning the type of this edge into the `prim` operation).

Propagation with this rule set ensures that once the type of the base value is known, so is the `Change` value.

If we know the base values of all `incells` edges, we can `bind` the primitive using `in_vals`, and then decide how to propagate a new `Change` value out. If all of the incoming edge `Change` values are `NoChange`, the outgoing edges are tagged `NoChange`. If any of them is `UnknownChange` - we set all the outgoing edges to `UnknownChange` (that's the default when `Diff.new(...)` is called without a `change` argument).
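The behavior of the fallback rule can be imitated with a small `Jaxpr` interpreter in plain JAX. This is a simplified sketch, not the GenJAX propagation interpreter: `propagate` is a hypothetical name, change tags are plain booleans, and the loop makes a single forward pass instead of iterating to a fixpoint.

```python
import jax
import jax.numpy as jnp

try:  # Newer JAX releases expose the IR types here.
    from jax.extend.core import Literal
except ImportError:  # Older JAX releases.
    from jax.core import Literal


def propagate(closed_jaxpr, args):
    """Evaluate a jaxpr on (value, changed) pairs, one pair per input."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(var):
        # Literals are constants baked into the program: never changed.
        if isinstance(var, Literal):
            return var.val, False
        return env[var]

    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = (const, False)
    for var, arg in zip(jaxpr.invars, args):
        env[var] = arg

    for eqn in jaxpr.eqns:
        ins = [read(v) for v in eqn.invars]
        in_vals = [val for val, _ in ins]
        out = eqn.primitive.bind(*in_vals, **eqn.params)
        outs = out if eqn.primitive.multiple_results else [out]
        # Fallback-style rule: an output is changed if any input is.
        changed = any(flag for _, flag in ins)
        for var, val in zip(eqn.outvars, outs):
            env[var] = (val, changed)
    return [read(v) for v in jaxpr.outvars]


def f(a, b):
    u = jnp.sin(a)
    v = b * 2.0
    return u + v, v


closed = jax.make_jaxpr(f)(1.0, 3.0)
# Mark `a` as changed and `b` as unchanged.
results = propagate(closed, [(jnp.asarray(1.0), True), (jnp.asarray(3.0), False)])
flags = [flag for _, flag in results]
```

Here the first output depends on `a` and is tagged as changed, while the second output depends only on `b` and keeps its unchanged tag. A downstream consumer could use that tag to skip work.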

#### What about `trace`?

You may be wondering - is `genjax.trace` also handled by the `fallback_rule`? The answer is no.

The semantics of the generative function interface methods sometimes require that we record state during the interpretation of `genjax.trace`. 

To support this pattern, we don't handle `genjax.trace` with a rule from the rule set - we use a stateful handler (similar to an effect handler, but staged out at trace time) to provide overloads for `genjax.trace` (and `genjax.cache`).

::: {.callout-important}

When we say "stateful" here -- you might be concerned about mutation, and JAX's distaste for it. These stateful handlers don't mutate JAX-owned values (tracers), so they're compatible with JAX's pure programming model. 

The handlers provide a state monad-like local context during JAX tracing. This context models reads and writes to state - but because we stage out the transformation, they are partially evaluated away during code generation, and they do not occur in our runtime modeling and inference code.

Now, there are restrictions in how we write these stateful handlers - we discuss this further in [Implementing the builtin modeling language](/advanced/impl_builtin_language/impl_builtin_language.ipynb).

:::
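The idea that handler state lives only at trace time can be illustrated with a short sketch in plain JAX. `ToyScoreHandler`, `program`, and `scored` are hypothetical names, and the handler here only accumulates a log density. The real handlers record more state than this.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyScoreHandler:
    # Ordinary Python object: it is mutated while JAX traces the program,
    # but the object itself never appears in the staged computation.
    def __init__(self):
        self.weight = 0.0

    def observe(self, value, mean):
        self.weight = self.weight + norm.logpdf(value, mean, 1.0)
        return value


def program(handler, a, c):
    a = handler.observe(a, 0.0)
    c = handler.observe(c, a)
    return a + c


@jax.jit
def scored(a, c):
    handler = ToyScoreHandler()  # Fresh handler for each trace.
    retval = program(handler, a, c)
    # The accumulated state is returned as an ordinary output.
    return retval, handler.weight


retval, weight = scored(0.5, 1.5)
expected = norm.logpdf(0.5, 0.0, 1.0) + norm.logpdf(1.5, 0.5, 1.0)
```

The compiled function contains only the arithmetic that the handler accumulated. The Python attribute writes happen once, during tracing, and the function stays pure from JAX's point of view because the state is threaded out through the return value.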

#### Applying the interpreter

Below, we show how we invoke the interpreter in the implementation of `update` for the `BuiltinGenerativeFunction` language.

There's quite a bit of complexity here. If you want a further unpacking of this section, we recommend the [Implementing the builtin modeling language](/advanced/impl_builtin_language/impl_builtin_language.ipynb) notebook.

Just remember that this transformation is also staged out - so objects like the stateful `handler` are used to

In [None]:
# `update_transform` implements the semantics of `update` using a program transformation.
def update_transform(source_fn, **kwargs):
    @functools.wraps(source_fn)
    def _inner(key, previous_trace, constraints, argdiffs):
        vals = jtu.tree_map(strip_diff, argdiffs, is_leaf=check_is_diff)
        jaxpr, (_, _, out_tree) = stage(source_fn)(*vals, **kwargs)
        jaxpr, consts = jaxpr.jaxpr, jaxpr.literals

        # Interpreter also accepts a stateful handler -
        # this encapsulates any extra state that we wish to
        # return out of the transformed function.
        handler = Update.new(key, previous_trace, constraints)

        # The interpreter used as a context.
        # We specify that the base type lattice for values
        # is `Diff`, and we also specify propagation rules.
        with PropagationInterpreter.new(
            Diff, diff_propagation_rules, handler
        ) as interpreter:
            flat_argdiffs, _ = jtu.tree_flatten(argdiffs, is_leaf=check_is_diff)

            # The interpreter iterates over the program `Jaxpr`,
            # applying the propagation rules, until a fixpoint
            # is reached.
            final_env, _ = interpreter(
                jaxpr,
                [Diff.new(v, change=NoChange) for v in consts],
                flat_argdiffs,
                [Diff.unknown(var.aval) for var in jaxpr.outvars],
            )

            # We get the final retval values out of the environment.
            flat_retval_diffs = safe_map(final_env.read, jaxpr.outvars)
            retval_diffs = jtu.tree_unflatten(
                out_tree,
                flat_retval_diffs,
            )

            # Now, we get the values which we must return from
            # the generative function interface - these come from
            # the stateful handler.
            retvals = tuple(map(strip_diff, flat_retval_diffs))
            retvals = jtu.tree_unflatten(out_tree, retvals)
            w = handler.weight
            constraints = handler.choice_state
            cache = handler.cache_state
            discard = handler.discard
            key = handler.key

        # Return everything that `update` in `BuiltinGenerativeFunction`
        # needs to implement its specification.
        return (
            key,
            (
                retval_diffs,
                w,
                (
                    source_fn,
                    vals,
                    retvals,
                    constraints,
                    previous_trace.get_score() + w,
                ),
                discard,
            ),
            cache,
        )

    return _inner

#### Combining the ingredients

Without digging into the implementation of this functionality, we can spot check that it's working by examining a few example `update` computations.

## `cache`: change aware memoization

The `BuiltinGenerativeFunction` language exposes a primitive called `cache` that interacts with the change tracking system to support memoization of deterministic computations (even deterministic computations which depend on random choices).