# Introduction to Gen and GenJAX

The purpose of this notebook is to give the listener/reader an accelerated introduction to several concepts native to Gen and GenJAX (an implementation of Gen on top of JAX). It assumes some familiarity with trace-based probabilistic programming systems and with Monte Carlo inference, especially importance sampling and MCMC methods.

In [1]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection

# Pretty printing.
console = genjax.go_pretty()

# Reproducibility.
key = jax.random.PRNGKey(314159)

## What is GenJAX?

### Short pitch

GenJAX is:
* A probabilistic programming system based on the concepts of [Gen](https://www.gen.dev/)
* A model + inference compiler with support for device acceleration (courtesy of JAX)
* A base layer for experiments in model + inference DSL design.

Thanks to a few key design decisions and JAX's excellent foundation, it natively supports several common accelerator idioms, such as automatic `struct of array` representations and automatic batching of model/inference programs onto accelerators. It does this while keeping the convenience of Gen's interfaces, which allow modular construction of complicated generative programs from smaller pieces.

<u>By construction, all GenJAX modeling + inference code is JAX jittable</u> - and thus, `vmap`able, etc.
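
To make the `jit`/`vmap` claim concrete, here is a minimal sketch in plain JAX (it does not use GenJAX at all). The function `sample_and_score` is a hypothetical stand-in for a tiny "model": it draws one Bernoulli sample and returns it with its log probability. Batching it over keys yields a struct-of-arrays result.

```python
import jax
import jax.numpy as jnp


def sample_and_score(key, p):
    """Hypothetical toy model: one Bernoulli(p) draw plus its log probability."""
    flip = jax.random.bernoulli(key, p)
    logp = jnp.where(flip, jnp.log(p), jnp.log1p(-p))
    return flip, logp


# One PRNG key per particle; vmap maps over the keys and broadcasts p.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
batched = jax.jit(jax.vmap(sample_and_score, in_axes=(0, None)))

# The result is a tuple of arrays, each with a leading axis of size 8.
flips, logps = batched(keys, 0.3)
```

The same composition (`jit` of `vmap`) is the kind of transformation the sentence above refers to; how GenJAX exposes it for its own interfaces is not shown here.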

## What's a generative function?

*Generative functions* are the key concept of Gen's probabilistic programming paradigm. Generative functions are computational objects defined by a set of associated data types and methods. These types and methods describe compositional interfaces useful for inference computations. 

Gen's formal description of generative functions consists of two objects:
* $P(\tau, r; x)$ - a measure over dictionary-like data (*choice maps*) and untraced randomness $r$, parametrized by arguments $x$.
* $f(\tau; x)$ - a deterministic function from the above measure's target space to a space of data types.

We often think about sampling a choice map from $P$ and then computing the return value of the generative function call using $f$. We record both in `Trace` objects: data structures that contain these values along with probabilistic metadata, such as the score of the random choices made along the way.
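
As a rough illustration of these pieces, the sketch below hand-writes a trace for a two-coin model in plain JAX. Everything here (`ToyTrace`, `toy_simulate`, the addresses `"m0"` and `"m1"`) is a hypothetical name chosen for illustration, not GenJAX's actual data structure, and the sketch has no untraced randomness $r$.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyTrace(NamedTuple):
    """Hypothetical trace record (not GenJAX's Trace type)."""
    choices: dict  # address -> sampled value (the choice map tau)
    retval: Any    # f(tau; x)
    score: Any     # log P(tau; x)


def bernoulli_logpdf(value, p):
    return jnp.where(value, jnp.log(p), jnp.log1p(-p))


def toy_simulate(key, x):
    # Sample the choice map tau from P(. ; x).
    k0, k1 = jax.random.split(key)
    m0 = jax.random.bernoulli(k0, x)
    m1 = jax.random.bernoulli(k1, x)
    choices = {"m0": m0, "m1": m1}
    # The score is the log probability of all traced choices.
    score = bernoulli_logpdf(m0, x) + bernoulli_logpdf(m1, x)
    # The return value is a deterministic function of the choices.
    retval = m0.astype(jnp.int32) + m1.astype(jnp.int32)
    return ToyTrace(choices, retval, score)


tr_toy = toy_simulate(jax.random.PRNGKey(0), 0.3)
```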

Here's an example of a GenJAX generative function using a function-like language - notice the composition under a function call abstraction (`genjax.trace`).

In [2]:
@genjax.gen
def g(key, x):
    key, m1 = genjax.trace("m0", genjax.Bernoulli)(key, (x,))
    return (key, m1)


@genjax.gen
def h(key, x):
    key, m1 = genjax.trace("m0", g)(key, (x,))
    return (key, m1)


h

This is a `Callable` object -- operations (see the list under **Generative function interface** below) which are useful for modeling and inference are given semantics via program transformations.

In [3]:
key, tr = genjax.simulate(h)(key, (0.3,))
tr

If you're familiar with other "trace-based" probabilistic systems, this should look familiar. It's a piece of data which has captured information about the execution of the function: specifically, the values chosen at traced random calls, along with the `score`, the log probability of those choices under the normalized measure which the model program represents.

`simulate` is a code transformation! Here's the transformed code.

In [4]:
jaxpr = jax.make_jaxpr(genjax.simulate(h))(key, (0.3,))
jaxpr

## Generative function interface

GenJAX's generative functions define an interface which supports compositional usage of generative functions within other generative functions. The interface functions here closely mirror [the interfaces defined in Gen](https://www.gen.dev/docs/stable/ref/gfi/#Generative-function-interface-1), deviating only when interfaces are redundant (or are implemented in Gen.jl for performance-optimized code paths which may not be relevant to our implementation).

| Interface | Type | Inference algorithm support |
| --- | --- | --- |
| `simulate` | Generative | Importance sampling, SMC |
| `importance` | Generative | Importance sampling, SMC |
| `update` | Generative/incremental | MCMC, SMC |
| `choice_vjp` | Differentiable | Differentiable Monte Carlo |
| `retval_vjp` | Differentiable | Involutive MCMC and SMC |

This interface supports several methods -- I've roughly described them and split them into the two categories **Generative** and **Differentiable** below:

#### Generative

* `simulate` - sample from the normalized trace measure, and return the score.
* `importance` - given constraints for some addresses, sample from the unnormalized trace measure and return an importance weight.
* `update` - given an existing trace, a set of constraints, and new arguments, update the trace to be consistent with the constraints under execution with the new arguments, and return an incremental importance weight.
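
The weights returned by `importance` and `update` can be sketched by hand for a tiny model. The code below is plain JAX under stated assumptions: a hypothetical model `z ~ Normal(0, 1)`, `y ~ Normal(z, sigma)`, with unconstrained choices proposed from the prior. The names `toy_importance`, `toy_update`, and `log_joint` are illustrative and are not GenJAX functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_joint(choices, sigma):
    """Log density of the toy model z ~ N(0, 1), y ~ N(z, sigma)."""
    return (norm.logpdf(choices["z"], 0.0, 1.0)
            + norm.logpdf(choices["y"], choices["z"], sigma))


def toy_importance(key, constraints, sigma):
    # "z" is unconstrained, so it is sampled from its prior.
    z = jax.random.normal(key)
    choices = {"z": z, "y": constraints["y"]}
    # With a prior proposal, the weight is the log density of the constraints.
    weight = norm.logpdf(choices["y"], z, sigma)
    return choices, log_joint(choices, sigma), weight


def toy_update(choices, old_score, new_constraints, new_sigma):
    # Overwrite constrained addresses, rescore under the new argument.
    new_choices = {**choices, **new_constraints}
    new_score = log_joint(new_choices, new_sigma)
    # Incremental weight: log ratio of new to old joint density.
    return new_choices, new_score, new_score - old_score


key_toy = jax.random.PRNGKey(0)
choices, score, weight = toy_importance(key_toy, {"y": 1.5}, 1.0)
new_choices, new_score, inc_weight = toy_update(choices, score, {"y": 2.0}, 0.5)
```

In this sketch no new random choices are introduced by the update, so the incremental weight reduces to a difference of scores; the general interface also has to account for choices that are added or removed.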

#### Differentiable

Below, the term `pullback` is used to denote mappings from cotangents of return values of a function to cotangents of primals (argument inputs).

* `choice_vjp` - compute (using reverse mode) the pullback of log joint with respect to selected random choices (selected via a `genjax.Selection` - see below) *and* the arguments to the generative function call. This is similar to JAX's `vjp` - it returns the same signature, but the pullback computes the gradient defined in the previous sentence. 

**Short**: get the gradients of the log pdf with respect to the selected random choices and with respect to the arguments.

* `retval_vjp` - compute (using reverse mode) the pullback of the return value function with respect to selected random choices (selected via a `genjax.Selection` - see below). This is similar to JAX's `vjp` - it returns the same signature, but the pullback computes the gradient defined in the previous sentence.

**Short**: get the gradients of the return value function with respect to the selected choices and with respect to the arguments.

**Note**: forward mode versions can easily be defined for both of these interfaces - there's no fundamental restriction on which mode you use.
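
Since both interfaces are described by analogy with JAX's `vjp`, a small plain-JAX sketch may help. It assumes a hypothetical one-choice model `z ~ Normal(mu, 1)` with return value `z ** 2 + mu`, and calls `jax.vjp` directly on hand-written functions; it does not call `choice_vjp` or `retval_vjp` themselves.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_joint(z, mu):
    """Log density of the single choice z ~ Normal(mu, 1)."""
    return norm.logpdf(z, mu, 1.0)


def retval_fn(z, mu):
    """Hypothetical return value function of the choice and the argument."""
    return z ** 2 + mu


z = jnp.float32(0.5)   # value of the selected random choice
mu = jnp.float32(0.0)  # argument to the model

# Analogue of choice_vjp: pull a cotangent of the log density back
# to the choice and to the argument.
logp, logp_pullback = jax.vjp(log_joint, z, mu)
dlogp_dz, dlogp_dmu = logp_pullback(jnp.ones_like(logp))

# Analogue of retval_vjp: pull a cotangent of the return value back
# to the choice and to the argument.
retval, retval_pullback = jax.vjp(retval_fn, z, mu)
dret_dz, dret_dmu = retval_pullback(jnp.ones_like(retval))
```

For this model the log density gradient is `-(z - mu)` with respect to `z` and `(z - mu)` with respect to `mu`, which is what the first pullback computes.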

## More about generative functions

### Distributions are generative functions

In GenJAX, distributions are generative functions.

In [5]:
key, tr = genjax.simulate(genjax.Normal)(key, (0.0, 1.0))
tr

### Associated data types

* **Choice maps** are the dictionary-like recordings of random choices in a trace.
* **Selection** is an object which allows querying a trace/choice map - selecting certain choices.
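
Conceptually, a selection acts like a filter over addresses. The sketch below models a choice map as a plain Python dict and is only an analogy: `toy_filter` is a hypothetical helper, and returning the unselected remainder as a second value is a choice made for this toy, not a statement about what GenJAX's `Selection.filter` returns.

```python
def toy_filter(choices, addresses):
    """Split a dict-based choice map into selected and unselected parts."""
    selected = {a: v for a, v in choices.items() if a in addresses}
    remainder = {a: v for a, v in choices.items() if a not in addresses}
    return selected, remainder


# A hand-written choice map with two addresses.
toy_choices = {"m0": True, "m1": False}
toy_selected, toy_rest = toy_filter(toy_choices, {"m1"})
```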

In [6]:
@genjax.gen
def h(key, x):
    key, m1 = genjax.trace("m0", genjax.Bernoulli)(key, (x,))
    key, m2 = genjax.trace("m1", genjax.Bernoulli)(key, (x,))
    return (key, m1 + m2)


key, tr = genjax.simulate(h)(key, (0.3,))
tr

In [7]:
select = genjax.Selection(["m1"])
selected, _ = select.filter(tr)
selected