```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)
```

The `update` interface method defines an update operation on traces produced by generative functions. Given an existing trace, `update` accepts new constraints (new values for some of the random choices) and new arguments. It returns an updated trace that is consistent with the new constraints, together with an incremental importance weight. Many MCMC inference algorithms are implemented on top of `update`.
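The plain-Python code below is a minimal sketch of what `update` computes. It uses a hypothetical two-address model (`x ~ Normal(mu, 1)`, `y ~ Normal(x, 1)`) and a dict-based trace. The helpers `make_trace` and `update` are illustrative stand-ins, not GenJAX's API. The model always visits the same two addresses, so every choice in the new trace is either kept or constrained. Under that assumption, the incremental weight reduces to the difference in log joint density between the new and old traces.

```python
import math


def normal_logpdf(v, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at v.
    return -0.5 * ((v - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def model_score(mu, choices):
    # Hypothetical model: x ~ Normal(mu, 1), y ~ Normal(x, 1).
    # Returns the log joint density of the choices under argument mu.
    x, y = choices["x"], choices["y"]
    return normal_logpdf(x, mu, 1.0) + normal_logpdf(y, x, 1.0)


def make_trace(mu, choices):
    # Illustrative trace: arguments, random choices, and log joint score.
    return {"args": (mu,), "choices": dict(choices), "score": model_score(mu, choices)}


def update(trace, new_args, constraints):
    # Constrained addresses override the old choices; others are kept.
    new_choices = {**trace["choices"], **constraints}
    new_trace = make_trace(*new_args, new_choices)
    # No address is freshly sampled, so the incremental importance
    # weight is just the change in log joint density.
    weight = new_trace["score"] - trace["score"]
    return new_trace, weight


old = make_trace(0.0, {"x": 0.5, "y": 1.0})
new, w = update(old, (0.0,), {"x": 1.0})
print(w)  # approximately -0.25
```

This naive version rescores every address even when only `x` changed. That redundant work is what the optimizations in this notebook aim to avoid.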

Depending on the implementation, `update` offers several optimization opportunities. In this notebook, we focus on those opportunities within the implementation of `update` for the `BuiltinGenerativeFunction` language. We'll describe a system that supports incremental computation via `Diff` propagation.^[Think of a value of `Diff` type as representing a change to a value.] Although we focus on the builtin language, the same system also applies to the combinator implementations of `update`. We'll also show how this system supports an efficient incremental `update` for the `Unfold` combinator.
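The plain-Python sketch below illustrates the idea behind diff propagation. The `Diff` class, `apply_with_diffs`, and the toy `program` are hypothetical illustrations, not GenJAX's actual `Diff` machinery.

- Each value carries a `changed` flag.
- A computation whose inputs are all unchanged reuses its cached result instead of rerunning.
- The cache is keyed by a fixed call-site name, which assumes the program's structure is the same on every run.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class Diff:
    # Hypothetical Diff: a value tagged with whether it may have changed
    # relative to the previous run of the program.
    value: Any
    changed: bool


def apply_with_diffs(cache: Dict[str, Any], site: str, fn: Callable, *args: Diff) -> Diff:
    # If every input is unchanged and this call site has a cached result,
    # skip the computation and report "no change".
    if site in cache and not any(a.changed for a in args):
        return Diff(cache[site], False)
    # Otherwise recompute and conservatively mark the output as changed.
    out = fn(*(a.value for a in args))
    cache[site] = out
    return Diff(out, True)


calls = []


def expensive_square(v):
    calls.append("square")  # record that real work happened
    return v * v


def expensive_double(v):
    calls.append("double")  # record that real work happened
    return 2 * v


def program(cache: Dict[str, Any], x: Diff, y: Diff) -> Diff:
    # Computes x**2 + 2*y, routing each step through diff propagation.
    a = apply_with_diffs(cache, "a", expensive_square, x)
    b = apply_with_diffs(cache, "b", expensive_double, y)
    return apply_with_diffs(cache, "c", lambda p, q: p + q, a, b)


cache = {}
first = program(cache, Diff(3, True), Diff(4, True))  # first run computes everything
calls.clear()
second = program(cache, Diff(3, False), Diff(5, True))  # only y changed
print(calls)  # ['double']: x**2 was reused from the cache
calls.clear()
third = program(cache, Diff(3, False), Diff(5, False))  # nothing changed
print(calls)  # []
```

Marking every recomputed output as changed is the conservative choice. It loosely mirrors the `NoChange`/`UnknownChange` distinction in Gen.jl. A finer-grained scheme could compare old and new outputs, or track which parts of a structure changed.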

## What is `update` used for?

Before we discuss how `update` can be optimized, it's worth constructing a simple example that shows why the optimization is worthwhile, especially when MCMC kernels are applied repeatedly over the course of inference.