In [1]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace

# Pretty printing.
console = genjax.pretty(show_locals=True, width=70)

# Reproducibility.
key = jax.random.PRNGKey(314159)

## What is GenJAX?

Here are a few high-level ways to think about GenJAX:

* A probabilistic programming system based on the concepts of [Gen](https://www.gen.dev/).

* A Bayesian model and inference compiler with support for device acceleration (courtesy of JAX).

* A base layer for experiments in model and inference DSL design.

There's already a well-supported implementation of [Gen in Julia](https://github.com/probcomp/Gen.jl). Why is a JAX port interesting?

There are a number of compelling reasons to explore Gen's probabilistic programming paradigm on top of JAX. Here are a few:

* JAX's excellent accelerator support: our implementation natively supports several common accelerator idioms, such as automatic struct-of-array representations and the ability to automatically batch model and inference programs onto accelerators.

* JAX's excellent support for compositional higher-order AD allows us to explore higher-order AD in inference design, previously a difficult implementation challenge.

* JAX exposes compositional code transformations to library authors, and, as library authors, we can utilize code transformations to implement state-of-the-art optimizations for models and inference expressed in our system.

Let's truncate the list there for now.

For the JAX literati, one final (hopefully tantalizing) takeaway: <u>by construction, all GenJAX modeling + inference code is JAX traceable</u> - and thus, `vmap`able, `jit`able, etc.

## What is probabilistic programming?

Perhaps you're coming to this notebook without any prior knowledge of probabilistic programming...

That's okay! Ideally, the ideas in this notebook should be self-contained^[You may miss _why generative functions (see below) are designed the way they are_ on a first read - but you'll still get the punchline if you follow the notebook to the end.].

### A Bayesian viewpoint

Here's one take on what probabilistic programming is all about: programming language design for expressing and solving Bayesian inference problems.

In the [Probabilistic Computing lab at MIT](http://probcomp.csail.mit.edu/), we also consider differentiable programming to be contained within the set of concerns of probabilistic programming.

Probabilistic programming is a broad field, and there are corners which may not be covered by this viewpoint. Mostly, people are interested in Bayes, and how to represent Bayes on computers in nice ways.

### What are we actually computing with?

As you'll see if you continue to read below, the objects which we program with expose a mixture of generative and differentiable interfaces - the interfaces are designed to support common (as well as quite advanced) classes of Bayesian inference algorithms. Gen provides automation for the tricky math which these algorithms sometimes require.

We separate the design of inference (whose implementation uses the interfaces), from the implementation of the interfaces on computational objects. This allows us to build languages of objects which satisfy the interfaces - and allows their interoperability and usage.

In Gen, we call objects which implement the interfaces **generative functions**.

## What is a generative function?

Generative functions are the key concept of Gen's probabilistic programming paradigm. Generative functions are computational objects defined by a set of associated data types and methods. These types and methods describe compositional interfaces useful for inference computations. 

Gen's formal description of generative functions consists of two objects:

* $P(\tau, r; x)$ - a normalized measure over tree-like data (*choice maps*) and untraced randomness^[More on this later. It's safe to say "I have no idea what that is" for now, and expect us to explain later.] $r$, parametrized by arguments $x$.

* $f(\tau; x)$ - a deterministic function from the above measure's sample space to a space of data types.

We can informally think of the sampling semantics of these objects as consisting of two steps:

1. First, sample a choice map from $P$.
2. Then, compute the return value using $f$.

In many of the generative function interfaces, we won't just be interested in the final sampled return value. We'll also be interested in what happened along the way: we'll record the intermediate and final results of these steps in `Trace` objects - data structures which contain the recordings of values, along with probabilistic metadata like the score of random choices selected along the way.
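
The two-step sampling semantics and the role of a trace can be sketched in plain Python. This is a toy illustration under stated assumptions, not GenJAX's implementation: the `Trace` dataclass, the `simulate` function, and the two-coin model are hypothetical names invented for this example, and untraced randomness $r$ is ignored entirely.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class Trace:
    choices: dict   # the choice map: address -> sampled value
    retval: object  # the return value computed by f
    score: float    # log P(choices; x)


def simulate(rng, p):
    # Step 1: sample a choice map from P (two independent Bernoulli(p) flips).
    choices = {"m0": rng.random() < p, "m1": rng.random() < p}
    # Probabilistic metadata: the log probability of the sampled choices.
    score = sum(math.log(p if v else 1.0 - p) for v in choices.values())
    # Step 2: compute the return value with the deterministic function f.
    retval = int(choices["m0"]) + int(choices["m1"])
    return Trace(choices, retval, score)


tr = simulate(random.Random(0), 0.3)
print(tr)
```

The point of the sketch is only that a trace bundles three things: the recorded choices, the return value, and the score.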

Here's an example of a GenJAX generative function^[There's not just one generative function class - users can and are encouraged to design new types of generative functions which capture repeated modeling patterns. An excellent example of this modularity in Gen's design is [generative function combinators](/introduction/intro_to_combinators/intro_to_combinators).]. 

This generative function is part of a function-like language - pay close attention to the hierarchical compositionality of generative functions in this language under an abstraction (`genjax.trace`) similar to a function call. 

We'll discuss the addresses (`"sub"` and `"m0"`) a bit later.

In [2]:
@genjax.gen
def g(key, x):
    key, m0 = genjax.trace("m0", genjax.Bernoulli)(key, x)
    return (key, m0)


@genjax.gen
def h(key, x):
    key, m0 = genjax.trace("sub", g)(key, x)
    return (key, m0)


h

This generative function holds a Python `Callable` object. For this generative function language, the interface methods (see the list under **Generative function interface** below) which are useful for modeling and inference are given semantics via JAX's tracing and program transformation infrastructure.

Let's examine some of these operations now.

In [3]:
console.inspect(genjax.BuiltinGenerativeFunction, methods=True)

This is our first glimpse of the **generative function interface** (GFI), the secret sauce which Gen is based around.

::: {.callout-note}

## JAX interfaces

There are a few methods here which are not part of the GFI, but are worth mentioning because they deal with data interfaces to JAX:

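* `flatten` - decomposes a generative function into its array leaves plus static structure, which allows us to treat generative functions as [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) implementors.
* `unflatten` - the inverse of `flatten`: rebuilds the generative function from those leaves and that structure.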

These are used to register the implementor type as a `Pytree`, which is roughly a tree-like Python structure which JAX can zip/unzip at runtime API boundaries.

:::

Let's study the `simulate` method first: we'll explore its semantics, and see the types of data it produces.

In [4]:
key, tr = genjax.simulate(h)(key, (0.3,))
tr

If you're familiar with other "trace-based" probabilistic systems - this should look familiar. 

This object is a piece of data which has captured information about the execution of the function: specifically, the subtraces of _other generative function calls_ which occur in `genjax.trace` statements.

It also captures the `score` - the log probability of the normalized measure which the model program represents, evaluated at the random choices which the generative call execution produced. 

If you were paying attention above, the score is $\log P(\tau, r; x)$.

### How is `simulate` implemented for this language?

For this generative function language, we implement `simulate` using a code transformation! Here's the transformed code.

In [5]:
jaxpr = jax.make_jaxpr(genjax.simulate(h))(key, (0.3,))
jaxpr

This is how we've implemented `simulate` for this particular generative function language.^[In general, Gen doesn't require that we follow the same "code transformation" implementation for other generative function languages. GenJAX, however, is a bit special - because we restrict the user to remain within the JAX traceable subset of Python - any generative function interface implementation must also be JAX traceable. **This is a JAX requirement, not a Gen one**.]

## Generative function interface

There are a few more generative function interface methods worth discussing.

In this notebook, instead of carefully walking through the math which these interface methods compute, we'll defer that discussion to another notebook. Below, we give an informal discussion of what each of the interface methods computes, and roughly describe what algorithm families are supported by their usage.

### The generative function interface in GenJAX

GenJAX's generative functions define an interface which supports compositional usage of generative functions within other generative functions. The interface functions here closely mirror [the interfaces defined in Gen.jl](https://www.gen.dev/docs/stable/ref/gfi/#Generative-function-interface-1).

Below, we use the following abbreviations:

* **IS** - importance sampling
* **SMC** - sequential Monte Carlo
* **MCMC** - Markov chain Monte Carlo
* **VI** - variational inference

| Interface | Type | Inference algorithm support |
| --- | --- | --- |
| `simulate` | Generative | IS, SMC |
| `importance` | Generative | IS, SMC, VI |
| `update` | Generative and incremental | MCMC, SMC |
| `assess` | Generative and differentiable | MCMC, IS, SMC |
| `unzip` | Differentiable | Differentiable and involutive MCMC and SMC, VI |

The interface consists of several methods, roughly described below and split into two categories, **Generative** and **Differentiable**:

#### Generative

* `simulate` - sample from normalized trace measure, and return the score.
* `importance` - given constraints for some addresses, sample from unnormalized trace measure and return an importance weight.
* `update` - given an existing trace, and a set of constraints and argument change values, update the trace to be consistent with the set of constraints under execution with the new arguments, and return an incremental importance weight.
* `assess` - given a complete choice map and arguments, return the normalized log probability.


#### Differentiable

* `assess` - same as above.
* `unzip` - given a set of fixed constraints, return two callables. The first callable `score` accepts constraints which fill in the complement of the fixed constraints and arguments, and returns the normalized log probability of all the constraints. The second callable `retval` accepts constraints and arguments, and returns the return value for the generative function call consistent with the constraints and given arguments.

`unzip` produces two functions which can be compositionally used with `jax.grad` to evaluate gradients used by both differentiable and involutive MCMC and SMC.
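
To make the informal descriptions of `assess` and `update` more concrete, here is a minimal plain-Python sketch on a two-choice model. Everything in it is an assumption made for illustration: the model (`"x"` drawn from a normal centered at `mu`, then `"y"` drawn from a normal centered at `x`), the function signatures, and the use of flat dictionaries as choice maps. It only covers the simple case where the constraints overwrite existing addresses and no new random choices appear.

```python
import math


def normal_logpdf(v, mean, std):
    z = (v - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def assess(choices, mu):
    # Joint log density of a complete choice map:
    # "x" ~ Normal(mu, 1), then "y" ~ Normal(x, 1).
    return normal_logpdf(choices["x"], mu, 1.0) + normal_logpdf(
        choices["y"], choices["x"], 1.0
    )


def update(old_choices, old_mu, constraints, new_mu):
    # Overwrite constrained addresses, keep the rest of the old choices.
    new_choices = {**old_choices, **constraints}
    # No new addresses appear, so the incremental weight is the log ratio of
    # the new joint density to the old joint density.
    weight = assess(new_choices, new_mu) - assess(old_choices, old_mu)
    return new_choices, weight


old = {"x": 0.0, "y": 1.0}
new, weight = update(old, 0.0, {"y": 3.0}, 0.0)
print(new, weight)
```

In this example only the `"y"` term changes, so the weight is approximately $-4.5 - (-0.5) = -4.0$.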

## More about generative functions

Here are a few more bits of information which should help you gain context with these objects.

### Distributions are generative functions

In GenJAX, distributions are generative functions.

In [6]:
key, tr = genjax.simulate(genjax.Normal)(key, (0.0, 1.0))
tr

This should bring a sigh of relief! Ah, distributions are generative functions - *they can't be too exotic*.

Next to deterministic computations which trivially implement the interface, distributions are quite simple. They don't have internal compositional choice structure (like the function-like `BuiltinGenerativeFunction` language above).

Distributions themselves expose two interfaces:

* `logpdf` - exact density evaluation.
* `sample` - exact sampling.

We can use these two interfaces to implement all the generative function interfaces for distributions.
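
As a rough sketch of that claim, the plain-Python code below derives `simulate` and `importance` from nothing but `sample` and `logpdf`. The `Normal` class and the function signatures are hypothetical stand-ins for this example, not GenJAX's actual classes, and a mutable `random.Random` replaces JAX's explicit PRNG keys.

```python
import math
import random


class Normal:
    """Toy distribution exposing only `sample` and `logpdf`."""

    def sample(self, rng, mean, std):
        return rng.gauss(mean, std)

    def logpdf(self, v, mean, std):
        z = (v - mean) / std
        return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def simulate(dist, rng, args):
    # Sample a value and score it under the same distribution.
    v = dist.sample(rng, *args)
    return v, dist.logpdf(v, *args)


def importance(dist, rng, constraint, args):
    if constraint is None:
        # Nothing is constrained: sample as usual, with log weight 0.
        v, score = simulate(dist, rng, args)
        return v, score, 0.0
    # The single choice is constrained: the log weight is its log density.
    score = dist.logpdf(constraint, *args)
    return constraint, score, score


rng = random.Random(0)
v, score = simulate(Normal(), rng, (0.0, 1.0))
cv, cscore, cweight = importance(Normal(), rng, 0.5, (0.0, 1.0))
print(v, score, cweight)
```

Because a distribution has a single random choice, each interface method reduces to one call to `sample`, one call to `logpdf`, or both.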

### Associated data types

* **Choice maps** are the tree-like recordings of random choices in a trace.
* **Selection** is an object which allows querying a trace/choice map - selecting certain choices.

In [7]:
@genjax.gen
def h(key, x):
    # Variable names match the addresses ("m0", "m1") they are traced at.
    key, m0 = genjax.trace("m0", genjax.Bernoulli)(key, x)
    key, m1 = genjax.trace("m1", genjax.Bernoulli)(key, x)
    return (key, m0 + m1)


key, tr = genjax.simulate(h)(key, (0.3,))
tr

In [8]:
select = genjax.BuiltinSelection.new(["m1"])
selected, _ = select.filter(tr.get_choices())
selected
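
Conceptually, filtering a choice map with a selection keeps only the choices at the selected addresses. A minimal plain-Python sketch follows, assuming a flat dictionary as the choice map. `filter_choices` is a hypothetical helper, and returning the complement as the second value is a choice made for this example, not a statement about what `select.filter` returns.

```python
def filter_choices(choices, selected):
    """Split a flat choice map (address -> value) by a set of addresses."""
    kept = {addr: v for addr, v in choices.items() if addr in selected}
    rest = {addr: v for addr, v in choices.items() if addr not in selected}
    return kept, rest


choices = {"m0": False, "m1": True}
kept, rest = filter_choices(choices, {"m1"})
print(kept)  # {'m1': True}
print(rest)  # {'m0': False}
```

Real choice maps are tree-like rather than flat, so a fuller version would recurse into sub-maps for hierarchical addresses.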

## Great ... now what can I do with them?

Now, we've informally seen the interfaces and datatypes associated with generative functions.

Studying the interfaces and the computational objects which satisfy them _in the abstract_ can be an entire PhD's worth of effort, but we can also put them to practical use: we can do machine learning with them.

Let's consider a modeling problem where we wish to perform generalized regression with outliers between two variates, taking a family of polynomials as potential curves.

One such model for this data generating process is shown below.

In [9]:
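# Two branches for a branching submodel.
# As written, both branches accept `x` and `coefficients` but do not use them;
# they differ only in the second argument passed to `genjax.Normal`.
@genjax.gen
def model_y(key, x, coefficients):
    # Inlier branch.
    key, y = trace("value", genjax.Normal)(key, 0.0, 1.0)
    return key, y


@genjax.gen
def outlier_model(key, x, coefficients):
    # Outlier branch.
    key, y = trace("value", genjax.Normal)(key, 0.0, 10.0)
    return key, y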


# The branching submodel.
switch = genjax.SwitchCombinator([model_y, outlier_model])

# A mapped kernel function which calls the branching submodel.
@genjax.gen(genjax.MapCombinator, in_axes=(0, 0, None))
def kernel(key, x, coefficients):
    key, is_outlier = trace("outlier", genjax.Bernoulli)(key, 0.1)
    is_outlier = jnp.asarray(is_outlier, dtype=int)
    key, y = trace("y", switch)(key, is_outlier, x, coefficients)
    return key, y


@genjax.gen
def model(key, xs):
    key, coefficients = trace("alpha", genjax.MvNormal)(
        key, np.zeros(3), 2.0 * np.identity(3)
    )
    key, *sub_keys = jax.random.split(key, len(xs) + 1)
    sub_keys = jnp.array(sub_keys)
    _, ys = trace("ys", kernel)(sub_keys, xs, coefficients)
    return key, ys

There are a number of implementation patterns which you might pick up on by studying this model.

1. Generative functions explicitly pass a PRNG key in and out. This conforms to JAX's PRNG usage expectations.

2. To implement control flow, we use higher-order functions called combinators. These accept generative functions as input, and return generative functions as output.

3. Any JAX compatible code is allowed in the body of a generative function.

Courtesy of the interface, we get to design our `model` generative function in pieces. 
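
The idea that combinators are higher-order functions can be sketched without any probabilistic machinery. The plain-Python code below is only an analogy: `switch_combinator`, `map_combinator`, and the branch functions are hypothetical names for this example, they operate on ordinary functions rather than generative functions, and they use Python control flow with a mutable `random.Random` instead of explicit PRNG keys and JAX-traceable primitives.

```python
import random


def switch_combinator(branches):
    """Return a function that runs branches[index] on the remaining arguments."""
    def switched(rng, index, *args):
        return branches[index](rng, *args)
    return switched


def map_combinator(fn):
    """Return a function that applies fn once per element of xs."""
    def mapped(rng, xs, *shared):
        return [fn(rng, x, *shared) for x in xs]
    return mapped


def inlier(rng, x):
    return rng.gauss(0.0, 1.0)


def outlier(rng, x):
    return rng.gauss(0.0, 10.0)


switch = switch_combinator([inlier, outlier])


def kernel(rng, x):
    # Pick a branch index, then dispatch through the switch.
    is_outlier = int(rng.random() < 0.1)
    return switch(rng, is_outlier, x)


ys = map_combinator(kernel)(random.Random(0), [0.3, 1.0, 1.5, 2.0])
print(ys)
```

Each combinator takes functions in and hands a new function back, which is what lets the pieces of a model be written separately and then composed.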

Now, let's examine a sample trace from our model.

In [10]:
data = jnp.array([0.3, 1.0, 1.5, 2.0])
key, tr = jax.jit(model.simulate)(key, (data,))
tr

## Your first inference program