In [1]:
import genjax
import jax
import jax.tree_util as jtu
import seaborn as sns

sns.set_theme(style="white")

# Pretty printing.
console = genjax.console(width=70)

# Reproducibility.
key = jax.random.PRNGKey(314159)

One key property of the generative function interface is that it separates model code from inference code. It provides an abstraction layer: model pieces can be developed modularly, and inference pieces can be written against the interface without depending on how any particular model implements it.

Implementing the interface on objects, and composing those objects in various ways (e.g. by specializing the implementation of the interface functions to support the intended composition), is a valid way to construct new generative functions. In fact, this is the pattern that generative function combinators follow: they accept generative functions as input and produce new generative functions whose implementations are specialized to represent some specific pattern of computation.

Explicitly constructing generative functions out of objects, however, can often feel unwieldy. Part of the way that GenJAX (and [Gen.jl](https://github.com/probcomp/Gen.jl)) alleviates this burden is by exposing languages _which construct generative functions from programs_. This drastically increases the expressivity available to the programmer.

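In GenJAX, here's an example of the `StaticGenerativeFunction` language. (Earlier versions of GenJAX called this language `BuiltinGenerativeFunction` and its trace type `BuiltinTrace`; some of the code and prose below still uses those older names.)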

In [2]:
@genjax.Static
def model(x):
    y = genjax.trace("y", genjax.normal)(x, 1.0)
    z = genjax.trace("z", genjax.normal)(y + x, 1.0)
    return z

When we apply one of the interface functions to this object, we get the associated data types that we expect.

In [4]:
key, sub_key = jax.random.split(key)
tr = model.simulate(sub_key, (1.0,))
tr

StaticTrace(gen_fn=StaticGenerativeFunction(source=<function model at 0x169dbdbd0>), args=(1.0,), retval=Array(-0.0015921, dtype=float32), address_choices=Trie(inner={'y': DistributionTrace(gen_fn=TFPDistribution(make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>), args=(1.0, 1.0), value=Array(-0.67220366, dtype=float32), score=Array(-2.317071, dtype=float32)), 'z': DistributionTrace(gen_fn=TFPDistribution(make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>), args=(Array(0.32779634, dtype=float32), 1.0), value=Array(-0.0015921, dtype=float32), score=Array(-0.9731869, dtype=float32))}), cache=Trie(inner={}), score=Array(-3.290258, dtype=float32))

How exactly does this work? In this notebook, you're going to find out. You'll also get a chance to explore some of the capabilities that JAX exposes to library designers. Ideally, you'll also get a sense of some of the limitations of JAX (and GenJAX), both of which are restricted to programs that are amenable to GPU/TPU acceleration.

## The magic of JAX

Let's examine the generative function object:

In [6]:
console.print(model)

All the decorator `genjax.Static` does is wrap the function in this object, which holds a reference to the function we defined above.

But we clearly need to get inside that function somehow, because the trace records data that comes from intermediate results of the function's execution.

That's where JAX comes in. JAX provides a way to trace pure, numerical Python programs, which lets us construct program transformations that return new functions computing different semantics from the original function.^[Program tracing is an approach with roots in automatic differentiation. If you're interested in this technique, I cannot recommend [Autodidax: JAX core from scratch](https://jax.readthedocs.io/en/latest/autodidax.html) enough. It will introduce you to enough interesting PL ideas to keep you occupied for months, if not years.]

Let's use one of JAX's interpreters to construct an intermediate representation of the function that our generative function object holds a reference to:

In [7]:
jaxpr = jax.make_jaxpr(model.source)(1.0)
jaxpr

let _normal = { lambda ; a:key<fry>[]. let
    b:f32[1] = pjit[
      name=_normal_real
      jaxpr={ lambda ; c:key<fry>[]. let
          d:f32[1] = pjit[
            name=_uniform
            jaxpr={ lambda ; e:key<fry>[] f:f32[] g:f32[]. let
                h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
                i:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
                j:u32[1] = random_bits[bit_width=32 shape=(1,)] e
                k:u32[1] = shift_right_logical j 9
                l:u32[1] = or k 1065353216
                m:f32[1] = bitcast_convert_type[new_dtype=float32] l
                n:f32[1] = sub m 1.0
                o:f32[1] = sub i h
                p:f32[1] = mul n o
                q:f32[1] = add p h
                r:f32[1] = max h q
              in (r,) }
          ] c -0.9999999403953552 1.0
          s:f32[1] = erf_inv d
          t:f32[1] = mul 1.4142135381698608 s
        in (t,) }
    ] a
  in (b,) } in
let _

So `jax.make_jaxpr` takes a function `f :: A -> B` and returns a new function `make_jaxpr(f) :: A -> Jaxpr`, where `Jaxpr` is the program representation above.

When Python runs this function, JAX lifts the inputs to objects called `Tracer`s. JAX keeps an internal stack of interpreters (called _traces_ in JAX's source) which intercept operations on `Tracer` instances, including infix operations like `+`, and modify their behavior. JAX also defines a set of _primitives_ (the `jax.numpy` functions are ultimately implemented in terms of them), each of which exposes a method called `bind`. `bind` takes in `Tracer` arguments, looks through them (and the interpreter stack), and selects the interpreter that should handle the call. That interpreter's `process_primitive` method then runs, invoking the semantics the interpreter defines for that primitive.

`jax.make_jaxpr` uses this process to walk the program and construct the intermediate representation shown above.
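To make the "the IR is just data" point concrete, here's a small sketch that is independent of GenJAX. It stages a made-up toy function `f` with `jax.make_jaxpr` and lists the primitive applied by each equation, using the `eqns`, `primitive.name`, `invars`, and `outvars` fields that JAX exposes on its IR objects.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A toy numerical function (not part of GenJAX).
    return jnp.sin(x) * 2.0 + x


# `make_jaxpr(f)` has type A -> ClosedJaxpr: calling it traces `f` with
# abstract `Tracer` inputs instead of computing a number.
closed_jaxpr = jax.make_jaxpr(f)(jnp.float32(1.0))
jaxpr = closed_jaxpr.jaxpr

# Each equation applies one primitive to input variables (or literals)
# and binds the result(s) to fresh output variables.
for eqn in jaxpr.eqns:
    print(eqn.primitive.name, eqn.invars, "->", eqn.outvars)

primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]
```

The printed primitive names should include `sin`, `mul`, and `add`, one per arithmetic operation in `f`.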

Now, the point of having this representation is that we can transform it further! We can lower it to other representations (including XLA, the compiler JAX uses to produce high-performance code). We could also write _another interpreter_ which walks this representation, invokes other primitives with `bind`, and so on, deferring further transformation to the next interpreter in line.

This (admittedly rough) description is the secret behind JAX's compositional transformations.
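Here's a minimal sketch of "another interpreter" in that sense, loosely following the pattern in JAX's "Writing custom Jaxpr interpreters" tutorial. It evaluates a `Jaxpr` equation by equation and re-`bind`s each primitive, so that whatever interpreter sits underneath (plain evaluation, `jax.jit`, `jax.grad`) handles it. To show that the semantics really are ours to choose, it replaces every `sin` with `cos`. The name `swap_sin_for_cos` and the substitution itself are made up for illustration.

```python
import jax
import jax.numpy as jnp

try:  # `Literal` lives in `jax.extend.core` in newer JAX releases.
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def swap_sin_for_cos(closed_jaxpr, *args):
    """Evaluate `closed_jaxpr`, but run `cos` wherever the program used `sin`."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}  # maps each jaxpr variable to its current value

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name == "sin":
            outvals = [jnp.cos(*invals)]  # our new semantics for `sin`
        else:
            # Defer to whichever JAX interpreter is active underneath us.
            out = eqn.primitive.bind(*invals, **eqn.params)
            outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jnp.sin(x) * 2.0 + x


x = jnp.float32(0.5)
closed = jax.make_jaxpr(f)(x)


def g(x):
    return swap_sin_for_cos(closed, x)[0]


print(g(x))           # cos(0.5) * 2 + 0.5
print(jax.jit(g)(x))  # same value, but compiled through XLA
print(jax.grad(g)(x))  # -sin(0.5) * 2 + 1
```

Because the interpreter only ever calls `bind` (or `jax.numpy` functions), `jax.jit` and `jax.grad` compose with it for free.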

## New semantics via program transformations

Let's examine the representation once more.

In [8]:
jaxpr

let _normal = { lambda ; a:key<fry>[]. let
    b:f32[1] = pjit[
      name=_normal_real
      jaxpr={ lambda ; c:key<fry>[]. let
          d:f32[1] = pjit[
            name=_uniform
            jaxpr={ lambda ; e:key<fry>[] f:f32[] g:f32[]. let
                h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
                i:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
                j:u32[1] = random_bits[bit_width=32 shape=(1,)] e
                k:u32[1] = shift_right_logical j 9
                l:u32[1] = or k 1065353216
                m:f32[1] = bitcast_convert_type[new_dtype=float32] l
                n:f32[1] = sub m 1.0
                o:f32[1] = sub i h
                p:f32[1] = mul n o
                q:f32[1] = add p h
                r:f32[1] = max h q
              in (r,) }
          ] c -0.9999999403953552 1.0
          s:f32[1] = erf_inv d
          t:f32[1] = mul 1.4142135381698608 s
        in (t,) }
    ] a
  in (b,) } in
let _

In the full `Jaxpr` (the printed output above is truncated), you'll notice an intrinsic called `trace`, which looks suspiciously similar to `genjax.trace` above.

`trace` is a custom primitive that GenJAX defines. By defining a new primitive, we can place a stub in the intermediate representation, which we can then transform further to implement the semantics we wish to express.
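As a sketch of what placing a stub looks like, here's a toy primitive defined with JAX's `Primitive` class. The name `toy_trace` and its `addr` parameter are hypothetical; this is not GenJAX's actual definition of `trace`. The primitive only needs concrete semantics (`def_impl`) and abstract semantics (`def_abstract_eval`, giving the output shape/dtype) to show up in a `Jaxpr`.

```python
import jax
import jax.numpy as jnp

try:  # `Primitive` lives in `jax.extend.core` in newer JAX releases.
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical stand-in for GenJAX's `trace` primitive.
toy_trace_p = Primitive("toy_trace")

# Concrete semantics: pass the value through unchanged.
toy_trace_p.def_impl(lambda x, *, addr: x)

# Abstract semantics: the output has the same shape/dtype as the input.
toy_trace_p.def_abstract_eval(lambda aval, *, addr: aval)


def toy_trace(addr, x):
    return toy_trace_p.bind(x, addr=addr)


def model(x):
    y = toy_trace("y", x + 1.0)
    return toy_trace("z", y * 2.0)


closed_jaxpr = jax.make_jaxpr(model)(jnp.float32(1.0))
print(closed_jaxpr)  # the stubs appear as `toy_trace[addr=...]` equations

stub_addrs = [
    eqn.params["addr"]
    for eqn in closed_jaxpr.jaxpr.eqns
    if eqn.primitive.name == "toy_trace"
]
```

Run eagerly, the stub is an identity, so `model(1.0)` just computes `(1.0 + 1.0) * 2.0`; the interesting semantics come from whichever interpreter later walks the `Jaxpr` and handles `toy_trace` specially.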

### A high level view

Now we need to transform it! This is where some serious design decisions enter the picture.

One thing you might notice about the `Jaxpr` above is that the arity of the function is fixed, and so is the arity of its return value. But when we call `simulate` on our `model`, we get back something that is not a single `f32[]` value: it's actually a [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) containing a lot more data, so we'd expect many more return values in the `Jaxpr`.^[JAX flattens/unflattens Pytree instances on each side of the IR boundary: the IR is strongly typed, but it only natively supports a few base types and a few composite array types.]

What gives?

Here's where JAX's support for compositional application of interpreters comes into play. 

Instead of attempting to modify the IR above to change the arity of everything (a process the authors expect would be quite painful and bug-prone), we can write another interpreter that walks the IR and evaluates it, while keeping track of the state that we want to put into the `BuiltinTrace` at the end of the interface invocation.

Then, we can _stage out that interpreter_ to support JIT compilation, etc. I'll describe the process below in pseudo-types:

We start with `f :: A -> B`, and we stage it to get a new function `f' :: Type[A] -> Jaxpr`. Then we write an interpreter `I` with signature `I :: (Jaxpr, A) -> (B, State)`. The application of `I` itself can also be staged.

This is really nice: we don't have to munge the IR manually; we just write an interpreter that performs the transformation for us. That's the power JAX provides!
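Here's a toy sketch of that recipe, with hypothetical names throughout (`toy_trace_p`, `interpret`) and no claim to match GenJAX's code. It stages `model :: A -> B` into a `Jaxpr`, then defines `interpret :: (Jaxpr, A) -> (B, State)`, where `State` is a dict from addresses to the values recorded at the stub primitive. Finally, it stages the application of the interpreter with `jax.jit`.

```python
import jax
import jax.numpy as jnp

try:  # These live in `jax.extend.core` in newer JAX releases.
    from jax.extend.core import Literal, Primitive
except ImportError:
    from jax.core import Literal, Primitive

# Hypothetical stub primitive standing in for GenJAX's `trace`.
toy_trace_p = Primitive("toy_trace")
toy_trace_p.def_impl(lambda x, *, addr: x)
toy_trace_p.def_abstract_eval(lambda aval, *, addr: aval)


def model(x):
    y = toy_trace_p.bind(x + 1.0, addr="y")
    return toy_trace_p.bind(y * 2.0, addr="z")


def interpret(closed_jaxpr, *args):
    """I :: (Jaxpr, A) -> (B, State), with State = {address: value}."""
    jaxpr = closed_jaxpr.jaxpr
    env, state = {}, {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive is toy_trace_p:
            # Handle the stub ourselves: record the value into State.
            (val,) = invals
            state[eqn.params["addr"]] = val
            outvals = [val]
        else:
            out = eqn.primitive.bind(*invals, **eqn.params)
            outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars], state


# Step 1: stage f :: A -> B into a Jaxpr.
closed_jaxpr = jax.make_jaxpr(model)(jnp.float32(1.0))


# Step 2: the application of I can itself be staged and compiled.
def run(x):
    return interpret(closed_jaxpr, x)


retval, state = jax.jit(run)(jnp.float32(1.0))

# Staging the interpreter leaves only array math: the stubs are gone.
staged = jax.make_jaxpr(run)(jnp.float32(1.0))
staged_names = [eqn.primitive.name for eqn in staged.jaxpr.eqns]
```

Note that the return value now carries extra data (the recorded `state`) without anyone editing the original IR's arity by hand.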

### Interpreter design decisions

With the high-level view in mind, we'll examine two of the interface implementations. The first is `simulate` - likely the easiest implementation to understand^[For this notebook, we're going to ignore the inference math that we wish to support!]. The second is `update`.

In GenJAX, a single interpreter is reused across all of the interface functions. Because we've chosen to reuse it (and to parametrize the transformation semantics by configuring the interpreter, rather than by writing a separate implementation for each interface function), you're going to see some complexity right out of the gate.

This complexity exists because we wish to expose _incremental computing optimizations_ in `update`. To support this customization, the interpreter is best described as a _propagation interpreter_, similar to Julia's abstract interpretation machinery (if you're familiar with it). A propagation interpreter treats the `Jaxpr` as an undirected graph and performs interpretation by iterating until a fixpoint condition is satisfied.

The high-level pattern from the previous section still holds! But if you've written interpreters in the style of [Structure and Interpretation of Computer Programs](https://en.wikipedia.org/wiki/Structure_and_Interpretation_of_Computer_Programs) before, this interpreter might be a slight shock to the system.

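Here's a boiled-down form of `simulate_transform`. It isn't self-contained: `stage`, `propagate`, `Bare`, `bare_propagation_rules`, `Simulate`, and `safe_map` are GenJAX and JAX internals that aren't imported here.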

In [10]:
def simulate_transform(f, **kwargs):
    def _inner(key, args):
        # Step 1: stage out the function to a `Jaxpr`.
        closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(key, *args, **kwargs)
        jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals

        # Step 2: create a `Simulate` instance, which we parametrize
        # the propagation interpreter with.
        #
        # `Bare` is an instance of something called a `Cell` - the
        # objects which the propagation interpreter reasons about.
        handler = Simulate()
        final_env, ret_state = propagate(
            Bare,
            bare_propagation_rules,
            jaxpr,
            [Bare.new(v) for v in consts],
            list(map(Bare.new, flat_args)),
            [Bare.unknown(var.aval) for var in jaxpr.outvars],
            handler=handler,
        )

        # Step 3: when the interpreter finishes, we read the values
        # out of its environment.
        flat_out = safe_map(final_env.read, jaxpr.outvars)
        flat_out = [v.get_val() for v in flat_out]
        key_and_returns = jtu.tree_unflatten(out_tree, flat_out)
        key, *retvals = key_and_returns
        retvals = tuple(retvals)

        # Here's the handler state - remember the signature from
        # above `I :: (Jaxpr, A) -> (B, State)`, these fields
        # below are the `State`.
        score = handler.score
        chm = handler.choice_state
        cache = handler.cache_state

        # This returns all the things which we want to put
        # into `BuiltinTrace`.
        return key, (f, args, retvals, chm, score), cache

    return _inner

To show that this is the key to how we implement `simulate`, I've copied the `simulate` method of `BuiltinGenerativeFunction` below:

In [None]:
def simulate(self, key, args, **kwargs):
    assert isinstance(args, tuple)
    key, (f, args, r, chm, score), cache = simulate_transform(self.source, **kwargs)(
        key, args
    )
    return key, BuiltinTrace(self, args, r, chm, cache, score)

We'll discuss `propagate` in a moment, but first, a few high-level points.

Note that the `simulate` method can be staged out and used with JAX's transformations:

In [12]:
jitted = jax.jit(model.simulate)
key, sub_key = jax.random.split(key)
tr = jitted(sub_key, (1.0,))
console.print(tr)

That's because `simulate_transform`, and the `propagate` interpreter itself, are JAX-traceable.

Apart from the sampled values (we used a fresh key), the main difference between the trace we generated at the top of the notebook and this one is that `jax.jit` lifts the `1.0` argument to a `Tracer` during tracing (so it comes back as a JAX array), whereas the non-jitted interpreter just uses the Python `float` value.

And again, we can also stage out our `simulate` implementation and get a `Jaxpr` back:

In [13]:
jax.make_jaxpr(model.simulate)(key, (1.0,))

let _normal = { lambda ; a:key<fry>[]. let
    b:f32[1] = pjit[
      name=_normal_real
      jaxpr={ lambda ; c:key<fry>[]. let
          d:f32[1] = pjit[
            name=_uniform
            jaxpr={ lambda ; e:key<fry>[] f:f32[] g:f32[]. let
                h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
                i:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
                j:u32[1] = random_bits[bit_width=32 shape=(1,)] e
                k:u32[1] = shift_right_logical j 9
                l:u32[1] = or k 1065353216
                m:f32[1] = bitcast_convert_type[new_dtype=float32] l
                n:f32[1] = sub m 1.0
                o:f32[1] = sub i h
                p:f32[1] = mul n o
                q:f32[1] = add p h
                r:f32[1] = max h q
              in (r,) }
          ] c -0.9999999403953552 1.0
          s:f32[1] = erf_inv d
          t:f32[1] = mul 1.4142135381698608 s
        in (t,) }
    ] a
  in (b,) } in
let _

This gives us pure array math code. You have to admit, that's pretty elegant!

## How does `propagate` work?

In this section, we're going to get into the nitty-gritty of `propagate` itself. What exactly is this interpreter doing? Let's examine the context surrounding the call to `propagate`:

```python
def simulate_transform(f, **kwargs):
    def _inner(key, args):
        closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(
            key, *args, **kwargs
        )
        jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
        handler = Simulate()
        final_env, ret_state = propagate(
            # A lattice type
            Bare,
            
            # Lattice propagation rules
            bare_propagation_rules,
            
            # The Jaxpr which we wish to interpret
            jaxpr,
            
            # Trace-time constants
            [Bare.new(v) for v in consts],
            
            # Input cells
            list(map(Bare.new, flat_args)),
            
            # Output cells
            [Bare.unknown(var.aval) for var in jaxpr.outvars],
            
            # How we handle `trace`.
            handler=handler,
        )
        ...

    return _inner
```

First, we stage our model function into a `Jaxpr`. During staging, everything (e.g. custom datatypes that implement the `Pytree` interface) gets flattened into array leaves.
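Here's a small sketch of that flattening, using only `jax.tree_util` and `jax.make_jaxpr`; the `args` structure and the function `f` are made up for illustration. A nested structure with three array leaves becomes a `Jaxpr` with three input variables, and the `treedef` records how to rebuild the structure on the other side of the IR boundary.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# A nested container of arrays: dicts and tuples are both Pytrees.
args = {"mean": jnp.float32(0.0), "scales": (jnp.float32(1.0), jnp.float32(2.0))}

leaves, treedef = jtu.tree_flatten(args)
print(len(leaves))  # the array leaves the IR actually sees
print(treedef)      # the structure needed to rebuild `args`


def f(params):
    return params["mean"] + params["scales"][0] * params["scales"][1]


# Staging flattens the input Pytree: one jaxpr input variable per leaf.
closed_jaxpr = jax.make_jaxpr(f)(args)
print(len(closed_jaxpr.jaxpr.invars))

# Unflattening rebuilds the original structure from the leaves.
rebuilt = jtu.tree_unflatten(treedef, leaves)
```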

After staging, we collect the data we want to initialize our interpreter's environment with, and here we encounter our first bit of complexity.

What is `Bare`? And what is a `Cell`? Let's start with the latter: a `Cell` is an abstract type that represents a _lattice value_.

To understand what a _lattice value_ is, it's worth getting a high-level picture of what `propagate` attempts to do. `propagate` is an interpreter based on mixed concrete/abstract interpretation. It treats the `Jaxpr` as a graph: the operations are nodes, and the SSA values (the named variables in the `Jaxpr`, like `d` and `s` above) are edges.
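Here's a minimal sketch of a lattice value, assuming the simplest useful lattice: a cell is either _unknown_ (only its shape and dtype are known, which staging always gives us) or _known_ (it also holds a concrete value). `ToyCell`, `unknown`, `new`, `leq`, and `join` are hypothetical names chosen to echo the `Bare.unknown`/`Bare.new` calls above; this is not GenJAX's `Cell` class.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Optional

import numpy as np


@dataclasses.dataclass(frozen=True, eq=False)
class ToyCell:
    """A lattice value with two information levels: unknown (bottom) < known.

    `shape` and `dtype` play the role of the abstract value; `val` holds
    the concrete value, or None if it is unknown.
    """

    shape: tuple
    dtype: Any
    val: Optional[np.ndarray] = None

    @classmethod
    def unknown(cls, shape, dtype) -> ToyCell:
        return cls(tuple(shape), np.dtype(dtype))

    @classmethod
    def new(cls, val) -> ToyCell:
        arr = np.asarray(val)
        return cls(arr.shape, arr.dtype, arr)

    def is_known(self) -> bool:
        return self.val is not None

    def leq(self, other: ToyCell) -> bool:
        """Partial order: unknown <= anything; known(a) <= known(b) iff a == b."""
        if not self.is_known():
            return True
        return other.is_known() and bool(np.array_equal(self.val, other.val))

    def join(self, other: ToyCell) -> ToyCell:
        """Least upper bound: keep whichever cell carries more information."""
        if not self.is_known():
            return other
        if not other.is_known():
            return self
        if np.array_equal(self.val, other.val):
            return self
        # A complete lattice would return a "top" (conflict) element here.
        raise ValueError("conflicting concrete values for the same cell")
```

In this two-level lattice a cell can only move up once (from unknown to known), which is what guarantees that an interpreter which only ever joins new information into cells will terminate.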

The interpreter iterates over the graph, attempting to update information about the edges by applying _propagation rules_ (hence the name `propagate`) that we define (`bare_propagation_rules`, above).

A propagation rule accepts a list of input cells (the SSA edges which flow into the operation) and a list of output cells. It returns a new, modified list of input cells and a new, modified list of output cells, as well as a state value. (In this notebook, we won't discuss the state value; it's unneeded for the interfaces we will describe.)
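To illustrate that signature, here's a hypothetical propagation rule for `out = a + b`, written over plain Python floats with `None` standing for "unknown". It's a sketch of the idea, not one of GenJAX's `bare_propagation_rules`. Because information can flow from outputs back to inputs as well as forward, the graph behaves as if it were undirected.

```python
from typing import List, Optional, Tuple

Value = Optional[float]  # None means "unknown" (the bottom of a two-level lattice)


def add_rule(
    incells: List[Value], outcells: List[Value]
) -> Tuple[List[Value], List[Value], None]:
    """Toy propagation rule for `out = a + b`.

    * a, b known   -> out = a + b  (forward)
    * out, a known -> b = out - a  (backward)
    * out, b known -> a = out - b  (backward)
    """
    (a, b), (out,) = incells, outcells
    if a is not None and b is not None and out is None:
        out = a + b
    elif out is not None and a is not None and b is None:
        b = out - a
    elif out is not None and b is not None and a is None:
        a = out - b
    # The third return value is the rule's state; unused here.
    return [a, b], [out], None
```

When the rule can't learn anything new (e.g. everything is unknown), it simply hands the cells back unchanged.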

The interpreter keeps a queue of nodes and an environment that maps SSA values to lattice values. It pops a node off the queue, looks up the existing lattice values for the node's input and output SSA values, attempts to update them using a propagation rule, and stores the update in the environment. After attempting the update, it also _determines whether the update has changed the information level of any of the cells_ (as measured by the partial order on lattice values). If it has, every node that the corresponding SSA value flows into is pushed back onto the queue.

This is an iterative algorithm that attempts to compute an information fixpoint of a state transition function. The state is the collection of all cells in the `Jaxpr` (the environment), and we customize the transition function through the propagation rules.
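Here's a heavily simplified sketch of that worklist loop, written against a real `Jaxpr` but with several assumptions: cells are just "absent from `env`" (unknown) or "present" (known), the only rule is forward evaluation via `bind`, and there is no handler for custom primitives. `toy_propagate` is a hypothetical name; see the linked GenJAX source for the real interpreter.

```python
from collections import deque

import jax
import jax.numpy as jnp

try:  # `Literal` lives in `jax.extend.core` in newer JAX releases.
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def toy_propagate(closed_jaxpr, *args):
    """Worklist fixpoint over a Jaxpr with forward-only propagation rules.

    A variable missing from `env` is "unknown"; present means "known".
    """
    jaxpr = closed_jaxpr.jaxpr
    env = dict(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))

    def read(v):
        return v.val if isinstance(v, Literal) else env.get(v)

    # For each variable, the equations it flows into (revisited on change).
    consumers = {}
    for i, eqn in enumerate(jaxpr.eqns):
        for v in eqn.invars:
            if not isinstance(v, Literal):
                consumers.setdefault(v, []).append(i)

    queue = deque(range(len(jaxpr.eqns)))
    while queue:
        eqn = jaxpr.eqns[queue.popleft()]
        invals = [read(v) for v in eqn.invars]
        # The rule fires only when all inputs are known and it can add info.
        if any(x is None for x in invals) or all(v in env for v in eqn.outvars):
            continue
        out = eqn.primitive.bind(*invals, **eqn.params)
        outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            if var not in env:  # information increased: unknown -> known
                env[var] = val
                queue.extend(consumers.get(var, []))

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jnp.sin(x) * 2.0 + x


x0 = jnp.float32(0.5)
cj = jax.make_jaxpr(f)(x0)
out = toy_propagate(cj, x0)
jitted_out = jax.jit(lambda x: toy_propagate(cj, x)[0])(x0)
```

Because JAX emits equations in topological order and these rules only run forward, this particular version converges in a single sweep; the worklist earns its keep once rules can also push information backward, from outputs to inputs.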

I'm not going to inline any of the implementation of this interpreter into this notebook. I'll refer the reader to [the implementation of the interpreter](https://github.com/probcomp/genjax/blob/main/src/genjax/core/propagate.py).^[Note that the ideas behind this interpreter are quite widespread - but the original implementation (which the GenJAX authors modified) came from [Oryx](https://github.com/jax-ml/oryx), and that implementation initially came from Roy Frostig (as far as we can tell).]

### What happens in `simulate`?

Great, so how do we use this interpreter idea to implement the `simulate_transform` described above?