```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=70)

# Reproducibility.
key = jax.random.PRNGKey(314159)
```

The generative function interface exposes methods that make generative functions usable in differentiable programming. These methods are designed to compose with `jax.grad`, allowing gradient computations (including higher-order gradients) for inference algorithms that rely on gradients or gradient estimators. In this notebook, we'll describe some of these interfaces, along with their current (but not permanent) limitations. Using them, we'll walk through implementations of MAP estimation and the Metropolis-adjusted Langevin algorithm (MALA).

## `assess` for exact density evaluation

`assess` is a generative function interface method that computes the log joint density of a set of choices under a generative function. `assess` requires the user to provide a choice map _that completely fills all choices encountered during execution_; otherwise, it raises an error.^[These errors are raised at JAX trace time, so you'll get an exception before the computation runs.]
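To make this concrete, here is a minimal sketch in plain JAX (not GenJAX code) of what `assess` computes for a hypothetical two-choice model. The names `assess_sketch` and `ADDRESSES` are hypothetical, and a plain Python dict stands in for a `ChoiceMap`. The completeness check runs in ordinary Python, so under `jax.jit` it would fire while tracing; because the result is an ordinary JAX expression, `jax.grad` differentiates it with respect to the choices.

```python
import jax
from jax.scipy.stats import norm

# Hypothetical model (a sketch, not a GenJAX API):
#   x ~ Normal(0, 1)
#   y ~ Normal(x, 0.5)
ADDRESSES = ("x", "y")

def assess_sketch(choices):
    """Return log p(x, y) for a dict that fills every address."""
    missing = [a for a in ADDRESSES if a not in choices]
    if missing:
        # Plain Python check: under jax.jit it fails during tracing.
        raise KeyError(f"choice map is missing addresses: {missing}")
    x, y = choices["x"], choices["y"]
    return norm.logpdf(x, 0.0, 1.0) + norm.logpdf(y, x, 0.5)

choices = {"x": 0.3, "y": 0.1}
score = assess_sketch(choices)           # log joint density
grads = jax.grad(assess_sketch)(choices)  # dict of d(score)/d(choice)
```

Here `grads["x"]` is `-x + (y - x) / 0.25` and `grads["y"]` is `-(y - x) / 0.25`, the analytic gradients of the Gaussian log densities.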

If a generative function also draws from untraced randomness, `assess` instead returns an estimate of the log joint density: the exponential of this estimate, averaged over the distribution of the untraced randomness, equals the true joint density.
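The sketch below illustrates this with a hypothetical model in plain JAX (not GenJAX code): a latent `z` is drawn without being traced, and only `y` appears in the choices. The function name `assess_untraced_sketch` is hypothetical. Averaging the exponentiated estimates over many untraced draws recovers the marginal density of `y`, while averaging the log estimates directly falls below the true log density (Jensen's inequality).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical model (a sketch, not a GenJAX API):
#   z ~ Normal(0, 1)   untraced: never appears in the choices
#   y ~ Normal(z, 1)   traced at address "y"
def assess_untraced_sketch(key, choices):
    z = jax.random.normal(key)  # untraced randomness
    return norm.logpdf(choices["y"], z, 1.0)

y = 0.7
keys = jax.random.split(jax.random.PRNGKey(0), 200_000)
estimates = jax.vmap(lambda k: assess_untraced_sketch(k, {"y": y}))(keys)

# Exact marginal density: y ~ Normal(0, sqrt(2)).
exact = norm.pdf(y, 0.0, jnp.sqrt(2.0))
mc = jnp.mean(jnp.exp(estimates))   # close to `exact`
mean_log = jnp.mean(estimates)      # noticeably below log(exact)
```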

::: {.callout-important}

The gradient interfaces described in this notebook **do not compute** correct or useful estimates when used on generative programs which include untraced randomness.

:::
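One way to see why gradients break down here is a sketch on the same hypothetical model in plain JAX (not GenJAX code; `assess_untraced_sketch` and `grad_y_estimate` are hypothetical names). Differentiating the per-draw log estimate with respect to `y` gives `-(y - z)`, whose average over `z` is `-y`, whereas the gradient of the true log marginal density is `-y / 2`. Averaging more samples does not close this gap.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical model: z ~ Normal(0, 1) untraced, y ~ Normal(z, 1) traced.
def assess_untraced_sketch(key, choices):
    z = jax.random.normal(key)  # untraced randomness
    return norm.logpdf(choices["y"], z, 1.0)

def grad_y_estimate(key, y):
    # Gradient of one noisy log-density estimate with respect to y.
    return jax.grad(lambda yy: assess_untraced_sketch(key, {"y": yy}))(y)

y = 0.7
keys = jax.random.split(jax.random.PRNGKey(0), 200_000)
mean_grad = jnp.mean(jax.vmap(lambda k: grad_y_estimate(k, y))(keys))

# Gradient of the exact log marginal, log Normal(y; 0, sqrt(2)).
true_grad = jax.grad(lambda yy: norm.logpdf(yy, 0.0, jnp.sqrt(2.0)))(y)
# mean_grad is close to -y = -0.7, but true_grad is -y / 2 = -0.35.
```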