```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)
```

The generative function interface includes methods that let generative functions take part in differentiable programming. These methods are designed to compose with `jax.grad`, so they support gradient computations (including higher-order ones) for inference algorithms that need gradients or gradient estimators. In this notebook, we'll describe some of these interfaces, along with their current (but not permanent) limitations. We'll walk through an implementation of MAP estimation, as well as the Metropolis-adjusted Langevin algorithm (MALA), using these interfaces.

## Gradient interfaces

Because JAX has strong support for higher-order automatic differentiation (AD), GenJAX exposes interfaces that compose directly with JAX's gradient transformations. The main interface method for obtaining `jax.grad`-compatible functions from a generative function is `unzip`.

`unzip` takes a key and a fixed choice map, and returns a new key along with two closures:

* The first is a "score" closure. It accepts a choice map as its first argument, followed by arguments matching the non-`PRNGKey` argument types of the generative function's signature. It returns the exact joint score of the generative function, which it computes using an interface called `assess`.^[Caveat: `assess` is only required to return an estimate of the joint score, not necessarily the _exact_ score. If `jax.grad` is applied to such estimates, the result is not, in general, a correct gradient estimator. See the important callout below!]
* The second is a "retval" closure. It accepts the same arguments as the score closure, executes the generative function constrained by the union of the fixed choice map and the user-provided choice map, and returns the return value of that execution. This return value is also obtained by invoking `assess`.
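The score/retval pattern can be sketched in plain JAX. This is not GenJAX's implementation: it assumes choice maps are plain Python dicts from address to value, the names `log_joint` and `unzip_sketch` are hypothetical, and key threading is omitted. It shows the central idea: the score closure merges the fixed choice map with a user-provided one and evaluates the joint log density, so it is an ordinary function that `jax.grad` can differentiate with respect to the free choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical model: mu ~ Normal(0, 1); y ~ Normal(mu * scale, 1).
# Choice maps are plain dicts mapping addresses to values (an assumption).
def log_joint(choices, scale):
    """Exact joint log density and return value (a stand-in for `assess`)."""
    mu, y = choices["mu"], choices["y"]
    logp = norm.logpdf(mu, 0.0, 1.0) + norm.logpdf(y, mu * scale, 1.0)
    return logp, mu * scale

def unzip_sketch(fixed):
    """Hypothetical helper: build (score, retval) closures over `fixed`."""
    def score(choices, *args):
        # Union of fixed and user-provided choices (user entries win on overlap).
        merged = {**fixed, **choices}
        return log_joint(merged, *args)[0]

    def retval(choices, *args):
        merged = {**fixed, **choices}
        return log_joint(merged, *args)[1]

    return score, retval

# Fix the observation y, leave mu free, and differentiate the score w.r.t. mu.
score, retval = unzip_sketch({"y": jnp.array(2.0)})
grad_mu = jax.grad(score)({"mu": jnp.array(0.5)}, 1.0)
print(grad_mu["mu"])  # -mu + scale * (y - mu * scale) = 1.0
```

Because `score` is an ordinary JAX function of its inputs, it also composes with `jax.jit`, `jax.vmap`, and nested `jax.grad` calls.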

So really, `unzip` is syntactic sugar over another interface called `assess`.

### `assess` for exact density evaluation

`assess` is a generative function interface method that computes log joint density estimates for generative functions. It requires a choice map _that provides a value for every choice encountered during execution_; otherwise, it raises an error.^[These errors are raised at JAX trace time, so you'll get an exception before anything runs.]
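The contract `assess` enforces can be illustrated with a plain-JAX stand-in. The function `assess_sketch`, its dict-based choice map, and the regression model are hypothetical, not GenJAX's API. Every address the model visits must be present in the choice map; a missing address raises a Python exception while JAX traces the function, before any compiled code runs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def assess_sketch(choices, x):
    """Hypothetical stand-in for `assess` on a small linear-regression model.

    Returns (log joint density, return value). Every visited address must be
    present in `choices`; a missing one raises KeyError during tracing.
    """
    logp = 0.0
    slope = choices["slope"]                       # visited address "slope"
    logp += norm.logpdf(slope, 0.0, 1.0)
    y = choices["y"]                               # visited address "y"
    logp += norm.logpdf(y, slope * x, 0.5).sum()
    return logp, slope * x

x = jnp.array([0.0, 1.0, 2.0])
full = {"slope": jnp.array(1.0), "y": jnp.array([0.1, 0.9, 2.2])}
logp, ret = jax.jit(assess_sketch)(full, x)
print("log joint:", logp)

# An incomplete choice map fails while jax.jit traces the function.
missing_error = None
try:
    jax.jit(assess_sketch)({"slope": jnp.array(1.0)}, x)  # "y" is missing
except KeyError as err:
    missing_error = err
print("missing choice:", missing_error)
```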

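If a generative function also draws from untraced randomness, `assess` computes an estimate: the exponential of the returned value is an unbiased estimate of the joint density, so its expectation over the untraced randomness equals the true joint density. (The returned log value itself is not an unbiased estimate of the log density.)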
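To make "estimate" concrete, here is a plain-JAX sketch with a hypothetical model (not GenJAX code): a latent `z ~ Normal(theta, 1)` is drawn as untraced randomness, and only `x ~ Normal(z, 1)` is scored. Integrating out `z` gives the exact marginal `x ~ Normal(theta, sqrt(2))`. Averaging the exponentiated one-sample estimates recovers the exact density, while averaging the log estimates directly falls below the exact log density, as Jensen's inequality predicts.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_density_estimate(key, x, theta):
    """Hypothetical one-sample estimate of log p(x; theta).

    z ~ Normal(theta, 1) is drawn without being traced; x ~ Normal(z, 1)
    is scored given the sampled z.
    """
    z = theta + jax.random.normal(key)  # untraced randomness
    return norm.logpdf(x, z, 1.0)

x, theta = 1.5, 0.0
keys = jax.random.split(jax.random.PRNGKey(0), 200_000)
estimates = jax.vmap(log_density_estimate, in_axes=(0, None, None))(keys, x, theta)

# Exact marginal log density: x ~ Normal(theta, sqrt(2)).
exact_log_p = norm.logpdf(x, theta, jnp.sqrt(2.0))
mean_of_exp = jnp.log(jnp.mean(jnp.exp(estimates)))  # close to exact_log_p
mean_of_log = jnp.mean(estimates)                     # noticeably smaller
print(exact_log_p, mean_of_exp, mean_of_log)
```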

::: {.callout-important}

## Correctness of gradient estimators

When used on generative functions that include untraced randomness, naively applying `jax.grad` to the closures returned by the interfaces described in this notebook **does not compute** unbiased estimators of the true gradients.

In short: don't use these interfaces with untraced randomness. We're working on alternatives.

:::
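The gradient bias can be seen in a small plain-JAX sketch (not GenJAX code) of a hypothetical model where `z ~ Normal(theta, 1)` is untraced (and sampled via reparameterization, an assumption) and `x ~ Normal(z, 1)` is scored. Averaging `jax.grad` of the one-sample log estimate over many draws converges to `x - theta`, while the gradient of the exact marginal log density is `(x - theta) / 2`: the naive estimator is off by a factor of two.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_density_estimate(theta, key, x):
    """Hypothetical one-sample log estimate; z is untraced but reparameterized."""
    z = theta + jax.random.normal(key)
    return norm.logpdf(x, z, 1.0)

def exact_log_density(theta, x):
    """Exact marginal: integrating out z gives x ~ Normal(theta, sqrt(2))."""
    return norm.logpdf(x, theta, jnp.sqrt(2.0))

x, theta = 1.5, 0.0
keys = jax.random.split(jax.random.PRNGKey(1), 100_000)

# Naive approach: jax.grad of the log estimate, averaged over draws of z.
naive_grads = jax.vmap(jax.grad(log_density_estimate), in_axes=(None, 0, None))(theta, keys, x)
naive = jnp.mean(naive_grads)                  # approx. x - theta = 1.5
exact = jax.grad(exact_log_density)(theta, x)  # (x - theta) / 2 = 0.75
print(naive, exact)
```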

## MAP estimation

## Exposing learnable modules with `TrainCombinator`

## Automatic differentiation variational inference

In this section, we'll show how to use the gradient interfaces to implement [automatic differentiation variational inference](https://arxiv.org/abs/1603.00788) (ADVI).
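For orientation, the core ADVI loop is stochastic gradient ascent on a reparameterized Monte Carlo ELBO with a Gaussian variational family. The sketch below uses plain JAX, not the GenJAX gradient interfaces; the conjugate model and data are hypothetical, and because the latent is already unconstrained, ADVI's transformation to unconstrained space is the identity here. The conjugate posterior gives an exact answer to compare against.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

data = jnp.array([0.8, 1.3, 0.4, 1.1, 0.9])  # hypothetical observations

def log_joint(mu):
    """Hypothetical model: mu ~ Normal(0, 1); each x_i ~ Normal(mu, 1)."""
    return norm.logpdf(mu, 0.0, 1.0) + norm.logpdf(data, mu, 1.0).sum()

def elbo(params, key, num_samples=64):
    """Reparameterized Monte Carlo ELBO for q(mu) = Normal(m, exp(log_s))."""
    m, log_s = params
    eps = jax.random.normal(key, (num_samples,))
    mu = m + jnp.exp(log_s) * eps  # reparameterization trick
    log_q = norm.logpdf(mu, m, jnp.exp(log_s))
    return jnp.mean(jax.vmap(log_joint)(mu) - log_q)

@jax.jit
def step(params, key, lr=0.05):
    """One gradient-ascent step on the ELBO."""
    grads = jax.grad(elbo)(params, key)
    return jax.tree_util.tree_map(lambda p, g: p + lr * g, params, grads)

params = (jnp.array(0.0), jnp.array(0.0))  # (m, log_s)
for k in jax.random.split(jax.random.PRNGKey(2), 500):
    params = step(params, k)

# Conjugate posterior: precision 1 + n, mean sum(x) / (1 + n).
post_mean = data.sum() / (1 + data.size)
post_std = 1.0 / jnp.sqrt(1.0 + data.size)
print(params[0], jnp.exp(params[1]), post_mean, post_std)
```

Since the Gaussian family contains the exact posterior for this model, the fitted `m` and `exp(log_s)` should land near `post_mean` and `post_std`, up to Monte Carlo noise from the stochastic gradients.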