## Introduction to Interpreted GenJAX

This notebook will give a tour of the interpreted dialect of GenJAX, the probabilistic computing system developed by the [MIT Probabilistic Computing Laboratory](http://probcomp.csail.mit.edu/). This dialect is meant to offer access to the Gen model of automatic inference in a way that avoids the constraints imposed by the acceleration technology in JAX. JAX acceleration offers an immense benefit in the speed at which inference can be done, but requires more up-front design work on the modeling and inference technique.

In particular, you have to commit to a fixed size for the vectors and tensors used in your model, and avoid native Python control flow in favor of a representation for branching computation that is compatible with the numerical linear algebra that the accelerated code must use.
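To make the control-flow constraint concrete, here is a minimal sketch in plain JAX rather than GenJAX. The function names `noise_scale_python` and `noise_scale_jax` and the values 30.0 and 0.3 are illustrative assumptions, not part of any library. The sketch shows a native Python `if` failing once `jax.jit` traces it, while `jnp.where` expresses the same choice in a form the accelerated code accepts.

```python
import jax
import jax.numpy as jnp


def noise_scale_python(is_outlier):
    # Native Python branching: fine when run eagerly, but under jax.jit the
    # argument is an abstract tracer and `if` cannot inspect its value.
    if is_outlier:
        return 30.0
    return 0.3


@jax.jit
def noise_scale_jax(is_outlier):
    # Array-friendly branching: both alternatives are plain values and
    # jnp.where selects between them.
    return jnp.where(is_outlier, 30.0, 0.3)


eager = noise_scale_python(True)
compiled = noise_scale_jax(jnp.array(True))

try:
    jax.jit(noise_scale_python)(jnp.array(True))
    python_branch_traced = True
except Exception:
    # JAX refuses to convert a traced value to a Python bool.
    python_branch_traced = False
```

The interpreted dialect used in the rest of this notebook sidesteps this restriction, which is why ordinary `if` statements and `for` loops appear in the models.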

For the present, we'll set those concerns aside and proceed with a simple inference task in the easy-going interpreted dialect.

In [None]:
import genjax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from genjax import interpreted

key = jax.random.PRNGKey(314159)
console = genjax.console(enforce_checkify=True, width=60)

A few notes about the prefatory material above. GenJAX uses JAX, and JAX provides its own flavor of numpy, which we will call `jnp` in this notebook. The main difference between jnp and numpy is that vectors and tensors are immutable once constructed: instead of changing them, you must make a copy with the updates you require. We could use regular numpy in the interpreted dialect, but it turns out that it's not difficult to work with jnp even in the interpreted case.
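A small sketch of the immutability difference, using only `numpy` and `jax.numpy` (the variable names are ours): NumPy permits in-place assignment, while a `jnp` array must be updated through the `.at[...].set(...)` idiom, which returns a new array.

```python
import jax.numpy as jnp
import numpy as np

a_np = np.zeros(3)
a_np[1] = 5.0  # NumPy arrays can be modified in place

a_jnp = jnp.zeros(3)
b_jnp = a_jnp.at[1].set(5.0)  # returns an updated copy; a_jnp is unchanged

try:
    a_jnp[1] = 5.0  # in-place assignment is rejected for jnp arrays
    in_place_ok = True
except TypeError:
    in_place_ok = False
```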

GenJAX allows the production of reproducible scientific work through the use of a "splittable" random number generator. You can set the initial seed of the generator as we have done above. Then, when we need random numbers within a function, we will split the generator and hand one fork to the function and keep the other fork at the topmost level. In this way, provided that the notebook cells are evaluated from top to bottom, all the random choices will be made in the same way each time. We encourage this technique for your own work to produce reproducible scientific communication.
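The split-and-hand-off pattern can be sketched in plain JAX as follows. The helper `draw` and the seed `0` are illustrative assumptions. Replaying the same sequence of splits from the same seed reproduces the same draw.

```python
import jax


def draw(key):
    # A function that needs randomness receives its own key.
    return jax.random.uniform(key)


# Top level: keep one fork, hand the other to the function.
key = jax.random.PRNGKey(0)
key, sub_key = jax.random.split(key)
first = draw(sub_key)
key, sub_key = jax.random.split(key)
second = draw(sub_key)

# Starting again from the same seed and splitting the same way
# yields the same value as the first draw.
replay_key = jax.random.PRNGKey(0)
replay_key, replay_sub_key = jax.random.split(replay_key)
replayed_first = draw(replay_sub_key)
```

Reusing `sub_key` for a second call instead of splitting again would repeat the same random value, which is why each call gets a fresh fork.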

Let's create our first generative function:

In [None]:
@interpreted
def g(x):
    b = genjax.flip(x) @ "b"
    return b

The `@interpreted` decoration is a bridge between an ordinary Python function and the [Generative Function](https://www.gen.dev/docs/stable/ref/gfi/#Generative-Functions) interface, which is at the heart of the Gen model of probabilistic programming. You can regard the `@` sign as something like the $\sim$ operator in statistics literature (although we use it backwards) to describe b as a random variable with a Bernoulli distribution (like tossing a coin with $x$ as the probability of heads). The GenJAX system uses the string name `"b"` to record the name of the random variable; it cannot see the name of the Python variable to which you assigned the value, but of course it's convenient to use the same name in both cases. (It's worth remembering that there's also a `genjax.bernoulli` function but its argument is log-odds.)
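Taking the note about log-odds at face value, the relationship between a probability and its log-odds can be sketched with the standard library alone. The helper names `logit` and `sigmoid` are ours and are not GenJAX functions. They show which number would correspond to a coin with probability 0.3 under a log-odds parameterization.

```python
import math


def logit(p):
    # Log-odds of a probability p in (0, 1).
    return math.log(p / (1.0 - p))


def sigmoid(z):
    # Inverse of logit: maps log-odds back to a probability.
    return 1.0 / (1.0 + math.exp(-z))


p = 0.3
log_odds = logit(p)            # negative, because p < 0.5
recovered = sigmoid(log_odds)  # round-trips back to p
```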

The decoration has equipped our function $g$ with the method `simulate`, which will draw a value from the distribution:

In [None]:
key, sub_key = jax.random.split(key)
tr = g.simulate(sub_key, (0.3,))
console.print(tr)

Running the function has produced a tree structure called a _trace_ which records the result of random choices made during the function's execution. As generative functions call other generative functions, the trace will become more elaborate. The value itself is in an array; to see it, we can call 

In [None]:
tr.get_retval()

If we'd like to draw a larger sample, we'll need more subkeys to supply the randomness, which we can get like this:

In [None]:
n_samples = 20
key, *sub_keys = jax.random.split(key, n_samples + 1)
[g.simulate(k, (0.3,)).get_retval().item() for k in sub_keys]

### Compound Models
Our little function `g` didn't do much: it was just a wrapper for `genjax.flip`, which is already a generative function. Let's create some more involved functions that mix distributions in interesting ways. We're going to consider a simple dataset based on a quadratic function with some outliers. The task is to write an inference algorithm that can infer the coefficients of the hidden polynomial while automatically classifying outliers (thus denying them the opportunity to skew the distribution).

In [None]:
def polynomial(coefficients):
    """Given coefficients of a polynomial a_0, a_1, ..., return a function
    computing a_0 x^0 + a_1 x^1 + ..."""

    def f(x):
        powers_of_x = jnp.array([x**i for i in range(len(coefficients))])
        return jnp.sum(coefficients * powers_of_x)

    return f


@interpreted
def model_y(x, f):
    """Given x and f, model f(x) plus a small amount of gaussian noise."""
    y = genjax.normal(f(x), 0.3) @ "value"
    return y


@interpreted
def outlier_model(x, f):
    """Like model_y, except this time we allow a huge variance in the noise,
    to model an outlying value"""
    y = genjax.normal(f(x), 30.0) @ "value"
    return y

Now we have generative functions for a polynomial model with inliers and outliers. The next step is to _generate_ candidate polynomials with inlying and outlying points by welding these small generative functions together into a more elaborate model which represents our prior belief about the structure of the data we might observe. This involves flipping an (unfair) coin to determine whether we have an inlier or outlier:

In [None]:
@interpreted
def kernel(xs, f):
    y = []
    for i, x in enumerate(xs):
        is_outlier = genjax.flip(0.1) @ ("outlier", i)

        if is_outlier:
            model = outlier_model
        else:
            model = model_y

        y.append(model(x, f) @ ("y", i))

    return jnp.array(y)

We pause here to note one universal feature of GenJAX: you must give every traced value a unique name. Since we are generating a vector of $y$ values, we label each of them with the tuple $(\mathbf{y}, i)$. 
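As a toy illustration of the naming rule (this is not how GenJAX is implemented), the hypothetical `ToyTrace` class below records each value under an address and rejects a second use of the same address. Including the loop index in the address, as in `("y", i)`, is what keeps the names distinct.

```python
class ToyTrace:
    """Hypothetical stand-in for a trace: a mapping from address to value."""

    def __init__(self):
        self.choices = {}

    def record(self, address, value):
        # Each address may be used only once within a single trace.
        if address in self.choices:
            raise ValueError(f"address {address!r} used twice")
        self.choices[address] = value
        return value


trace = ToyTrace()
for i, y in enumerate([0.1, 0.4, 0.9]):
    trace.record(("y", i), y)  # the index makes each address unique

try:
    trace.record(("y", 0), 1.0)  # reusing an address is an error
    collided = False
except ValueError:
    collided = True
```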

Finally, we draw the polynomial coefficients from a multivariate normal distribution. This is fancier than we need: with a diagonal covariance matrix, it amounts to individual independent selections with no cross-correlation. It does make the code easy to write, though, and it lets us batch the selection of all the polynomial's coefficients into one random variable, `alpha`. It's at this point that we are committing to a polynomial of degree $\le 2$.
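The claim about the diagonal covariance can be checked numerically with the standard library. This sketch assumes the matrix passed to the multivariate normal is a covariance, so a diagonal entry of 2.0 corresponds to a standard deviation of $\sqrt{2}$ per coefficient. The helper names and the test point `x` are ours.

```python
import math


def normal_logpdf(x, mean, std):
    # Log density of a univariate normal distribution.
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def diag_mvn_logpdf(x, mean, variances):
    # Log density of a multivariate normal with covariance diag(variances),
    # written out from the general formula: the log-determinant is the sum
    # of log variances and the quadratic form is a sum of scaled squares.
    k = len(x)
    log_det = sum(math.log(v) for v in variances)
    quad = sum((xi - mi) ** 2 / v for xi, mi, v in zip(x, mean, variances))
    return -0.5 * (k * math.log(2.0 * math.pi) + log_det + quad)


x = [0.5, -1.0, 2.0]
mean = [0.0, 0.0, 0.0]
variances = [2.0, 2.0, 2.0]

joint = diag_mvn_logpdf(x, mean, variances)
independent = sum(
    normal_logpdf(xi, mi, math.sqrt(v)) for xi, mi, v in zip(x, mean, variances)
)
```

The two log densities agree up to floating-point error, which is the sense in which the multivariate draw is equivalent to three independent draws.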

In [None]:
@interpreted
def model(xs):
    coefficients = genjax.mv_normal(jnp.zeros(3), 2.0 * jnp.identity(3)) @ "alpha"
    f = polynomial(coefficients)
    ys = kernel(xs, f) @ "ys"
    return ys

In [None]:
data = jnp.arange(0, 10, 0.5)
key, sub_key = jax.random.split(key)
tr = model.simulate(sub_key, (data,))
tr.get_retval()

We hope that this is an example generated by a degree 2 polynomial possibly with some outliers... but is it? It's common in the Gen world to "visualize the prior." Let's do that.

In [None]:
def visualize_prior(key, w, h):
    f, axes = plt.subplots(w, h, figsize=(8, 8), sharex=True, sharey=True)
    for ax in axes.flatten():
        key, sub_key = jax.random.split(key)
        tr = model.simulate(sub_key, (data,))
        ax.scatter(data, tr.get_retval())


key, sub_key = jax.random.split(key)
visualize_prior(sub_key, 3, 3)

Not bad. With the preliminaries out of the way, let's turn to the inference task that motivated this notebook. The idea is to observe some data in the wild--a "ground truth"--and test our hypothesis that it is a polynomial-with-outliers by inferring the parameters of our model. To do that, behind our back, we will select some polynomial coefficients and then manually inject an obvious outlying value:

In [None]:
xs = jnp.array([0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0])
ys = jnp.array(2.0 * xs + 1.5 + xs**2)
ys = ys.at[2].set(50.0)
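In terms of the coefficient convention used by `polynomial` ($a_0, a_1, a_2$), the hidden polynomial above is $1.5 + 2x + x^2$, so inference should ideally recover coefficients near $(1.5, 2.0, 1.0)$. The standalone sketch below restates this in plain Python. The helper `evaluate_polynomial` is ours, not part of the notebook's code.

```python
def evaluate_polynomial(coefficients, x):
    # a_0 * x^0 + a_1 * x^1 + a_2 * x^2 + ...
    return sum(a * x**i for i, a in enumerate(coefficients))


true_coefficients = [1.5, 2.0, 1.0]
xs = [0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0]
ys = [evaluate_polynomial(true_coefficients, x) for x in xs]

clean_y2 = ys[2]  # value the polynomial itself gives at x = 1.1
ys[2] = 50.0      # the injected outlier replaces it
```

At $x = 1.1$ the polynomial gives $1.5 + 2.2 + 1.21 = 4.91$, so the injected value of 50.0 lies far outside what the inlier noise of 0.3 could plausibly produce.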

In [None]:
fig_data, ax_data = plt.subplots(figsize=(6, 6))
ax_data.scatter(xs, ys, color="blue")

### Observations
The next step is to _constrain_ the y values to the observed data while letting the _model parameters_ roam freely. To do this, we construct a [ChoiceMap](https://probcomp.github.io/genjax/genjax/library/core/datatypes.html#genjax.core.ChoiceMap). At this point, the significance of the names of the random variables we have chosen with the `@` operator becomes clear, as well as the influence of the structure of the functions we have written. If you look back at those functions, you will see that the $(\mathbf{y}, i)$ values are nested under the variable $\mathbf{ys}$, and that each of those is assigned a $\mathbf{value}$ in the `outlier_model` and `model_y` functions.

In [None]:
observations = genjax.choice_map()
for i, y in enumerate(ys):
    observations["ys", "y", i, "value"] = y

GenJAX provides numerous sophisticated inference algorithms designed to operate seamlessly with generative functions. For this introduction we will content ourselves with a simple one: sampling importance resampling (SIR). We won't spend too much time understanding this technique--it is described in greater mathematical detail in other notebooks here--but will simply say that we generate a batch of candidate traces, score them against the ground truth, and then use the scores as the weights of a random categorical choice (a kind of survival of the fittest), and repeat until, hopefully, the model explains our data.

In [None]:
def sampling_importance_resampling(model, n_samples):
    def _inner(key, observations, model_args):
        """Generate a list of importance samples. Each such sample returns a tuple of
        (trace, log_weight or "score"). Treat the list as a weighted ensemble of
        choices, and draw one. This is the result of one SIR step."""
        resample_key, sub_key = jax.random.split(key)
        sub_keys = jax.random.split(sub_key, n_samples)
        tr_lw_pairs = [
            model.importance(sub_key, observations, model_args) for sub_key in sub_keys
        ]
        lws = [tr_lw[1] for tr_lw in tr_lw_pairs]
        index = genjax.categorical.sample(resample_key, lws)
        return tr_lw_pairs[index][0]

    return _inner
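To see the same idea without any GenJAX machinery, here is a toy sketch of one SIR step using only the standard library. Everything in it is an assumption made for illustration: the problem (inferring a single mean `mu` with a normal prior of standard deviation 2.0 and observation noise 0.5), the function names, and the made-up observations. Because the candidates are proposed from the prior, each candidate's log weight is just its log likelihood.

```python
import math
import random


def normal_logpdf(x, mean, std):
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def log_likelihood(mu, observations, noise_std=0.5):
    # Score one candidate against the observed data.
    return sum(normal_logpdf(y, mu, noise_std) for y in observations)


def sir_step(rng, observations, n_particles):
    # 1. Propose candidates from the prior.
    particles = [rng.gauss(0.0, 2.0) for _ in range(n_particles)]
    # 2. Weight each candidate by how well it explains the data.
    log_weights = [log_likelihood(mu, observations) for mu in particles]
    # 3. Exponentiate stably by subtracting the maximum log weight.
    top = max(log_weights)
    weights = [math.exp(lw - top) for lw in log_weights]
    # 4. Draw one candidate with probability proportional to its weight.
    return rng.choices(particles, weights=weights, k=1)[0]


rng = random.Random(0)
observations = [1.8, 2.1, 2.4]
posterior_samples = [sir_step(rng, observations, 200) for _ in range(50)]
posterior_mean = sum(posterior_samples) / len(posterior_samples)
```

In this toy problem the resampled values cluster near the average of the observations, since candidates far from the data receive negligible weight.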

The next step is to "sample from the posterior." Here we will see if the SIR algorithm is converging toward an explanation of the data:

In [None]:
N = 20
model_args = (xs,)
key, *sub_keys = jax.random.split(key, N + 1)
samples = [
    sampling_importance_resampling(model, 2 * N)(sub_key, observations, model_args)
    for sub_key in sub_keys
]
coefficients = [s["alpha"] for s in samples]
coefficients

In [None]:
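def plot_polynomial_values(ax_inf, xs, coefficients, **kwargs):
    """Plot the polynomial with the given coefficients over xs on ax_inf."""
    f = polynomial(coefficients)
    # Faint lines let many samples overlap; extra kwargs are passed to plot.
    kwargs.setdefault("alpha", 0.2)
    ax_inf.plot(xs, [f(x) for x in xs], **kwargs)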

In [None]:
fig_inf, ax_inf = plt.subplots(figsize=(6, 6))
ax_inf.scatter(xs, ys, color="blue")
for cs in coefficients:
    plot_polynomial_values(ax_inf, xs, cs)

Not too bad. That took a while, but it does appear that the sampling procedure has accomplished both of its goals: it has found plausible polynomial coefficients, and further, it has declined to allow the outlier to overly influence the result. To verify that, we can calculate what fraction of the samples classified $y_2$ as an outlier like this:

In [None]:
outlier_at_2 = [s["ys", "outlier", 2] for s in samples]

In [None]:
np.sum(outlier_at_2) / len(outlier_at_2)

We invite you to dive deeper into GenJAX by learning more about how easily JAX can be used to accelerate the computation we just performed, as well as the many state-of-the-art inference techniques the Gen system offers.