The purpose of this notebook is to give the listener/reader an accelerated introduction to several concepts native to Gen and GenJAX (an implementation of Gen on top of JAX). As prerequisites, it assumes familiarity with trace-based probabilistic programming systems and with Monte Carlo inference, especially importance sampling and MCMC methods.

In [1]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace

# Pretty printing.
console = genjax.pretty()

# Reproducibility.
key = jax.random.PRNGKey(314159)

## What is GenJAX?

GenJAX is:

* A probabilistic programming system based on the concepts of [Gen](https://www.gen.dev/).

* A model and inference compiler with support for device acceleration (courtesy of JAX).

* A base layer for experiments in model and inference DSL design.

Thanks to a few key design decisions and JAX's excellent foundation, GenJAX natively supports several common accelerator idioms, such as automatic struct-of-array representations and automatic batching of model and inference programs onto accelerators. It does this while keeping the convenience of Gen's interfaces, which allow modular construction of generative programs from smaller pieces.

<u>By construction, all GenJAX modeling + inference code is JAX jittable</u> - and thus, `vmap`able, etc.

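To give a rough feel for what "jittable and `vmap`able" buys you, here is a sketch in plain JAX rather than GenJAX. The function `sample_and_score` and the batch size are assumptions made for this illustration: it draws one Bernoulli choice, returns its log probability, and is then batched over many keys and compiled.

```python
import jax
import jax.numpy as jnp


def sample_and_score(key, p):
    """Draw one Bernoulli(p) choice and return it with its log probability."""
    choice = jax.random.bernoulli(key, p)
    score = jnp.where(choice, jnp.log(p), jnp.log1p(-p))
    return choice, score


# One key per batch element; vmap maps over the keys and broadcasts p.
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
batched = jax.jit(jax.vmap(sample_and_score, in_axes=(0, None)))
choices, scores = batched(keys, 0.3)
```

Both outputs come back as arrays with a leading batch axis, which is the struct-of-arrays layout mentioned above.
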
## What is a generative function?

*Generative functions* are the key concept of Gen's probabilistic programming paradigm. Generative functions are computational objects defined by a set of associated data types and methods. These types and methods describe compositional interfaces useful for inference computations. 

Gen's formal description of generative functions consists of two objects:

* $P(\tau, r; x)$ - a measure over dictionary-like data (*choice maps*) and untraced randomness $r$, parametrized by arguments $x$.

* $f(\tau; x)$ - a deterministic function from the above measure's target space to a space of data types.

We often think of sampling a choice map from $P$ and then computing the return value of the generative function call using $f$. Both are recorded in `Trace` objects: data structures that hold these values along with probabilistic metadata, such as the score of the random choices made along the way.
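
The relationship between $P$, $f$, and a trace can be sketched in pure Python. This is a toy stand-in, not GenJAX's `Trace` type: the names `ToyTrace` and `toy_simulate` are hypothetical, and the sketch assumes no untraced randomness $r$.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class ToyTrace:
    args: tuple
    choices: dict   # address -> sampled value (the choice map tau)
    retval: object  # f(tau; x)
    score: float    # log P(tau; x)


def toy_simulate(rng, p):
    """Sample two Bernoulli(p) choices, recording values and log probability."""
    choices, score = {}, 0.0
    for addr in ("m0", "m1"):
        value = rng.random() < p
        choices[addr] = value
        score += math.log(p if value else 1.0 - p)
    # The return value is a deterministic function of the choice map.
    retval = int(choices["m0"]) + int(choices["m1"])
    return ToyTrace((p,), choices, retval, score)


tr = toy_simulate(random.Random(0), 0.3)
```
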

Here's an example of a GenJAX generative function. This generative function is part of a function-like language - pay close attention to the hierarchical compositionality of generative functions in this language under an abstraction (`genjax.trace`) similar to a function call.

In [2]:
@genjax.gen
def g(key, x):
    key, m1 = genjax.trace("m0", genjax.Bernoulli)(key, x)
    return (key, m1)


@genjax.gen
def h(key, x):
    key, m1 = genjax.trace("m0", g)(key, x)
    return (key, m1)


h

This is a `Callable` object. The operations useful for modeling and inference (see the list under **Generative function interface** below) are given semantics via program transformations.

Let's examine these operations now.

In [3]:
console.inspect(genjax.BuiltinGenerativeFunction, methods=True)

This is our first glimpse of the **generative function interface** (GFI), the secret sauce which Gen is based around.

There are a few methods here which are not part of the GFI:

* `flatten` - which allows us to treat generative functions as [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) implementors.
* `unflatten` - same as above.

Let's study the `simulate` method first: we'll explore its semantics, and see the types of data it produces.

In [4]:
key, tr = genjax.simulate(h)(key, (0.3,))
tr

If you're familiar with other "trace-based" probabilistic systems - this should look familiar. 

This object is a piece of data that captures information about the execution of the function, specifically the subtraces of _other generative function calls_ made through `genjax.trace`.

It also captures the `score` - the log probability of the normalized measure which the model program represents, evaluated at the random choices which the generative call execution produced. If you were paying attention above, the score is $\log P(\tau, r; x)$.
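
As a worked instance, assume that the argument to `genjax.Bernoulli` is the probability of drawing true and that `h` uses no untraced randomness. Then `h` makes a single random choice $b$ (at the nested address `m0` inside `m0`), and with $x = 0.3$ the score takes one of two values:

$$
\log P(\tau; x) =
\begin{cases}
\log x = \log 0.3 \approx -1.204 & \text{if } b \text{ is true}, \\
\log (1 - x) = \log 0.7 \approx -0.357 & \text{if } b \text{ is false}.
\end{cases}
$$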

### How is `simulate` implemented for this language?

For this generative function language, we implement `simulate` using a code transformation! Here's the transformed code.

In [5]:
jaxpr = jax.make_jaxpr(genjax.simulate(h))(key, (0.3,))
jaxpr

## Generative function interface

GenJAX's generative functions define an interface which supports compositional usage of generative functions within other generative functions. The interface functions here closely mirror [the interfaces defined in Gen](https://www.gen.dev/docs/stable/ref/gfi/#Generative-function-interface-1), deviating only where an interface is redundant (or exists in Gen.jl to provide performance-optimized code paths, which may not be relevant to our implementation).

| Interface | Type | Inference algorithm support |
| --- | --- | --- |
| `simulate` | Generative | Importance sampling, SMC |
| `importance` | Generative | Importance sampling, SMC |
| `update` | Generative and incremental | MCMC, SMC |
| `assess` | Generative and differentiable | MCMC, importance sampling, SMC |
| `unzip` | Differentiable | Differentiable and involutive MCMC and SMC |

This interface supports several methods. I've roughly described them below, split into two categories: **Generative** and **Differentiable**.

#### Generative

* `simulate` - sample from the normalized trace measure, and return the score.
* `importance` - given constraints for some addresses, sample from the unnormalized trace measure and return an importance weight.
* `update` - given an existing trace, and a set of constraints and argument change values, update the trace to be consistent with the set of constraints under execution with the new arguments, and return an incremental importance weight.
* `assess` - given a complete choice map and arguments, return the normalized log probability.
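
To make the weights concrete, here is a toy sketch of `importance` and `update` for a model with two independent Bernoulli choices. It is pure Python, not GenJAX's implementation. The names `toy_importance`, `toy_update`, and `log_prob` are hypothetical, and the sketch assumes unconstrained choices are proposed from the prior and that `update` never has to sample new choices.

```python
import math
import random


def log_prob(choices, p):
    """Log probability of a complete choice map under the toy model."""
    return sum(math.log(p if v else 1.0 - p) for v in choices.values())


def toy_importance(rng, p, constraints):
    """Constrained addresses take the given value and contribute to the weight;
    unconstrained addresses are sampled from the prior and contribute nothing."""
    choices, log_weight = {}, 0.0
    for addr in ("m0", "m1"):
        if addr in constraints:
            value = constraints[addr]
            log_weight += math.log(p if value else 1.0 - p)
        else:
            value = rng.random() < p
        choices[addr] = value
    return choices, log_weight


def toy_update(old_choices, old_p, new_p, constraints):
    """Overwrite constrained addresses; the incremental weight is the log
    ratio of the new trace's probability to the old trace's probability."""
    new_choices = {**old_choices, **constraints}
    log_weight = log_prob(new_choices, new_p) - log_prob(old_choices, old_p)
    return new_choices, log_weight


choices, log_weight = toy_importance(random.Random(0), 0.3, {"m1": True})
updated, incremental = toy_update({"m0": False, "m1": True}, 0.3, 0.3, {"m0": True})
```

In this sketch, constraining `m1` to true under $p = 0.3$ gives a log weight of $\log 0.3$, and flipping `m0` from false to true gives an incremental log weight of $\log(0.3 / 0.7)$.
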


#### Differentiable

* `assess` - same as above.
* `unzip` - given a set of fixed constraints, return two callables. The first callable `score` accepts constraints which fill in the complement of the fixed constraints and arguments, and returns the normalized log probability of all the constraints. The second callable `retval` accepts constraints and arguments, and returns the return value for the generative function call consistent with the constraints and given arguments.

`unzip` produces two functions which can be compositionally used with `jax.grad` to evaluate gradients used by both differentiable and involutive MCMC and SMC.
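
The idea of a score function closed over fixed constraints can be sketched in plain JAX. This is not GenJAX's `unzip`. The model (a latent `mu` with a standard normal prior and an observation `y` drawn from a unit-variance normal centered at `mu`) and the helper `make_score` are assumptions chosen for the illustration.

```python
import jax
from jax.scipy.stats import norm

# Fixed constraints: the observed value.
fixed = {"y": 2.0}


def make_score(fixed):
    """Return a callable mapping the free choices to the joint log density."""

    def score(free):
        mu = free["mu"]
        log_prior = norm.logpdf(mu, 0.0, 1.0)
        log_likelihood = norm.logpdf(fixed["y"], mu, 1.0)
        return log_prior + log_likelihood

    return score


score = make_score(fixed)
# Gradient with respect to the free choices; analytically -mu + (y - mu).
grads = jax.grad(score)({"mu": 0.5})
```

At `mu = 0.5` the analytic gradient is $-0.5 + (2.0 - 0.5) = 1.0$, which is the kind of quantity a gradient-based MCMC kernel would consume.
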

## More about generative functions

### Distributions are generative functions

In GenJAX, distributions are generative functions.

In [6]:
key, tr = genjax.simulate(genjax.Normal)(key, (0.0, 1.0))
tr

### Associated data types

* **Choice maps** are the dictionary-like recordings of random choices in a trace.
* **Selection** is an object which allows querying a trace/choice map - selecting certain choices.

In [7]:
@genjax.gen
def h(key, x):
    key, m1 = genjax.trace("m0", genjax.Bernoulli)(key, x)
    key, m2 = genjax.trace("m1", genjax.Bernoulli)(key, x)
    return (key, m1 + m2)


key, tr = genjax.simulate(h)(key, (0.3,))
tr

In [8]:
select = genjax.BuiltinSelection.new(["m1"])
selected, _ = select.filter(tr.get_choices())
selected
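
Conceptually, a selection partitions a choice map by address. The sketch below models a flat choice map as a plain dictionary. The helper `filter_choices` and its two-part return value are assumptions made for this example and do not describe what `BuiltinSelection.filter` returns.

```python
def filter_choices(choices, selected):
    """Split a flat choice map into selected and unselected parts."""
    kept = {addr: v for addr, v in choices.items() if addr in selected}
    rest = {addr: v for addr, v in choices.items() if addr not in selected}
    return kept, rest


choices = {"m0": False, "m1": True}
kept, rest = filter_choices(choices, {"m1"})
```
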

## Great ... now what can I do with them?

While studying the interfaces and the computational objects which satisfy them _in the abstract_ can be a pleasing hobby, we can also do machine learning with them.

Let's consider a modeling problem where we wish to perform generalized regression with outliers between two variates, taking a family of polynomials as potential curves.

One such model for this data generating process is shown below.

In [9]:
@genjax.gen
def model_y(key, x, degree):
    key, coefficients = trace(
        "alpha",
        genjax.MapCombinator(genjax.Normal, in_axes=(None, None, None)),
    )(key, (0.0, 2.0))
    key, y = trace("value", genjax.Normal)(key)


@genjax.gen
def outlier_model(key, x, degree):
    key, y = trace("value", genjax.Normal)(key, (0.0, 10.0))
    return key, y


switch = genjax.SwitchCombinator([model_y, outlier_model])


@genjax.gen(genjax.MapCombinator, in_axes=(0, 0))
def kernel(key, x):
    key, is_outlier = trace("outlier", genjax.Bernoulli)(key, (0.1,))
    key, polynomial_degree = trace("degree", genjax.Geometric)(key, (0.1,))
    key, y = trace("y", switch)(key, (is_outlier, x, polynomial_degree))
    return key, y


@genjax.gen
def model(key, xs):
    pass

There are a number of implementation patterns which you might pick up on by studying this model.

1. Generative functions explicitly pass a PRNG key in and out. This conforms to JAX's PRNG usage expectations.

2. To implement control flow, we use higher-order functions called combinators. These accept generative functions as input, and return generative functions as output.

3. Any JAX compatible code is allowed in the body of a generative function.
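
The first two patterns can be seen in plain JAX, without GenJAX. In this sketch the functions `sample_pair`, `inlier_mean`, `outlier_mean`, and `branch` are hypothetical: the first threads a PRNG key in and out, and the last uses `jax.lax.switch` as a jittable stand-in for the kind of branching a switch combinator provides.

```python
import jax
import jax.numpy as jnp


def sample_pair(key):
    """Split the key before each draw and hand the evolved key back."""
    key, sub = jax.random.split(key)
    a = jax.random.normal(sub)
    key, sub = jax.random.split(key)
    b = jax.random.normal(sub)
    return key, (a, b)


def inlier_mean(x):
    return 2.0 * x + 1.0


def outlier_mean(x):
    return jnp.zeros_like(x)


@jax.jit
def branch(is_outlier, x):
    """Select a branch by index; Python `if` on traced values fails under jit."""
    index = jnp.asarray(is_outlier, dtype=jnp.int32)
    return jax.lax.switch(index, [inlier_mean, outlier_mean], x)


key = jax.random.PRNGKey(314159)
key, (a, b) = sample_pair(key)
y_inlier = branch(False, jnp.float32(3.0))
y_outlier = branch(True, jnp.float32(3.0))
```

Because each draw uses a fresh subkey, `a` and `b` are independent samples, and the caller receives a key it can safely use for the next call.
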

Courtesy of the interface, we get to design our `model` generative function in pieces. 

Now, let's examine a few sample traces from our model.

In [None]:
def trace_visualizer(tr):
    pass

## Your first inference program
