In [None]:
%config InlineBackend.figure_format = 'svg'

In [None]:
import genjax
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from genjax import ChoiceMap as C
from genjax import gen

sns.set_theme(style="white")

# Reproducibility.
key = jax.random.PRNGKey(314159)

## What is GenJAX?

Here are a few high-level ways to think about GenJAX:

* A probabilistic programming^[New to probabilistic programming? [Don't fret, read on!](#what-is-probabilistic-programming)] system based on the concepts of [Gen](https://www.gen.dev/).

* A Bayesian modelling and inference compiler with support for device acceleration (courtesy of JAX).

* A base layer for experiments in model and inference DSL design.

There's already a well-supported implementation of [Gen in Julia](https://github.com/probcomp/Gen.jl). Why is a JAX port interesting?

There are a number of compelling technical and social reasons to explore Gen's probabilistic programming paradigm on top of JAX. Here are a few:

* JAX's excellent accelerator support - our implementation natively supports several common accelerator idioms - like automatic struct-of-array representations, and the ability to automatically batch model/inference programs onto accelerators.

* JAX's excellent support for compositional AD removes implementation and maintenance complexity for Gen's gradient interfaces - previously a difficult challenge in other implementations. In addition, JAX's support for convenient, higher-order AD opens up new opportunities to explore during inference design with gradient interfaces.

* JAX exposes compositional code transformations to library authors, and, as library authors, we can utilize code transformations to implement state-of-the-art optimizations for models and inference expressed in our system.

* A lot of domain experts and modelers are working in Python! Some of them even use JAX (hopefully more each year). Presenting an interface to Gen which is familiar, and takes advantage of JAX's native ecosystem is a compelling social reason.

Let's truncate the list there for now.

For the JAX literati, one final (hopefully tantalizing) takeaway: <u>by construction, all GenJAX modeling + inference code is JAX traceable</u> - and thus, `vmap`able, `jit`able, etc.

## What is probabilistic programming?

Perhaps you're coming to this notebook without any prior knowledge of probabilistic programming...

That's okay! Ideally, the ideas in this notebook should be self-contained^[You may miss _why generative functions (see below) are designed the way they are_ on a first read - but you'll still get the punchline if you follow the notebook to the end.].

### A Bayesian viewpoint

Here's one practical take on what probabilistic programming is all about: programming language design for expressing and solving Bayesian inference problems^[In the [Probabilistic Computing lab at MIT](http://probcomp.csail.mit.edu/), we also consider differentiable programming to be contained within the set of concerns of probabilistic programming. We won't cover differentiable programming interfaces in this notebook.]. 

Probabilistic programming is a broad field, and there are corners which may not be covered by this practical take. We'll just assume that people are interested in Bayes, and how to represent Bayes on computers in nice ways. For our purposes in this notebook, we'll stick as much as we can to the basics.

### What are we actually computing with?

The objects which we program with expose a mixture of generative and differentiable interfaces - the interfaces are designed to support common (as well as quite advanced) classes of Bayesian inference algorithms. Gen provides automation for the tricky math which these algorithms require.

We separate the design of inference (whose implementation uses the interfaces) from the implementation of the interfaces on computational objects. This allows us to build languages of objects which satisfy the interfaces, and allows their compositional usage and interoperability.

In Gen, the objects which implement the interfaces are called **generative functions**.

## What is a generative function?

Generative functions are the key concept of Gen's probabilistic programming paradigm. Generative functions are computational objects defined by a set of associated data types and methods. These types and methods describe compositional interfaces that are useful for Bayesian inference computations. 

Gen's formal description of generative functions consists of two objects:

* $P(\tau, r; x)$ - a normalized measure over tree-like data (*choice maps*) and untraced randomness^[More on this later. It's safe to say "I have no idea what that is" for now, and expect us to explain later or in another notebook.] $r$, parametrized by arguments $x$.

* $f(\tau; x)$ - a deterministic function from the above measure's sample space to a space of data types.

We can informally think of the sampling semantics of these objects as consisting of two steps:

1. First, sample a choice map from $P$.
2. Then, compute the return value using $f$.

In many of the generative function interfaces, we won't just be interested in the final sampled return value. We'll also be interested in what happened along the way: we'll record the intermediate and final results of these steps in `Trace` objects - data structures which contain the recordings of values, along with probabilistic metadata like the score of random choices selected along the way.
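To make the two-step semantics concrete, here's a minimal sketch in plain JAX (not GenJAX's implementation) of a hand-written `simulate` for a one-choice model shaped like `h` below. It samples a choice map, computes the return value from it, and records the choices, return value, and score in a dictionary standing in for a `Trace`. The name `toy_simulate`, the dict layout, and the use of a success probability `p` are all hypothetical, and there is no untraced randomness $r$ here.

```python
import jax
import jax.numpy as jnp


def toy_simulate(key, p):
    """Hypothetical stand-in for `simulate` on a one-choice model.

    Step 1: sample the choice map tau from P(tau; p).
    Step 2: compute the return value f(tau; p).
    The returned dict plays the role of a Trace.
    """
    # Step 1: sample a Bernoulli choice with success probability p.
    m0 = jax.random.bernoulli(key, p)
    choices = {"sub": {"m0": m0}}
    # Score: log P(tau; p), evaluated at the sampled choice.
    score = jnp.where(m0, jnp.log(p), jnp.log1p(-p))
    # Step 2: the return value is a deterministic function of the choices.
    retval = choices["sub"]["m0"]
    return {"choices": choices, "retval": retval, "score": score}


trace = toy_simulate(jax.random.PRNGKey(0), 0.3)
```

The score only depends on which value was sampled, which is why a trace can carry it alongside the choices.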

Below, we provide an example of an (admittedly, not too interesting) GenJAX generative function^[There's not just one generative function class - users can and are encouraged to design new types of generative functions which capture repeated modeling patterns. An excellent example of this modularity in Gen's design is [generative function combinators](/introduction/intro_to_combinators/intro_to_combinators).]. 

This generative function is part of a function-like language (the `BuiltinGenerativeFunction` language). Pay close attention to the hierarchical compositionality of generative functions in this language: invoking a generative function at an address (`gen_fn(args) @ "address"`) reads much like a function call. We'll discuss the addresses (`"sub"` and `"m0"`) a bit later.

In [None]:
@gen
def g(x):
    m0 = genjax.bernoulli(x) @ "m0"  # call a distribution at address "m0"
    return m0


@gen
def h(x):
    m0 = g(x) @ "sub"  # call another generative function at address "sub"
    return m0


h

This generative function holds a Python `Callable` object. For this generative function language, the interface methods (see the list under **Generative function interface** below) which are useful for modeling and inference are given semantics via JAX's tracing and program transformation infrastructure.

Let's examine some of these operations now.

This is our first glimpse of the **generative function interface** (GFI), the secret sauce which Gen is based around.

::: {.callout-note}

## JAX interfaces

There are a few methods you might see which are not explicitly part of Gen's GFI, but which are worth mentioning because they deal with data interfaces to JAX:

* `flatten` - which allows us to treat generative functions as [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) implementors.
* `unflatten` - same as above.

These are used to register the implementor type as a `Pytree`, which is roughly a tree-like Python structure that JAX can flatten into leaves and rebuild at API boundaries (like `jit` and `vmap`).

:::
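To see what registering a type as a Pytree involves, here's a minimal sketch using JAX's `jax.tree_util.register_pytree_node`. The `ToyTrace` class is hypothetical and much simpler than GenJAX's own types; it only illustrates why `flatten`/`unflatten` let such objects cross `jit` and `vmap` boundaries.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class ToyTrace:
    """Hypothetical trace-like container, not GenJAX's Trace class."""

    def __init__(self, retval, score):
        self.retval = retval
        self.score = score


def _flatten(tr):
    # Children are the array leaves JAX will trace; aux data is static (none here).
    return (tr.retval, tr.score), None


def _unflatten(aux_data, children):
    # Rebuild the object from its leaves.
    return ToyTrace(*children)


tree_util.register_pytree_node(ToyTrace, _flatten, _unflatten)


@jax.jit
def make_trace(x):
    # Returning a registered pytree from a jitted function works because
    # JAX can flatten it into leaves and rebuild it afterwards.
    return ToyTrace(retval=2.0 * x, score=-0.5 * x**2)


tr = make_trace(jnp.arange(3.0))
# vmap also batches every leaf of the returned pytree.
batched = jax.vmap(lambda x: ToyTrace(x, -x))(jnp.arange(4.0))
```

This is the same mechanism that later lets a batch of traces come back from a `vmap` as one object whose leaves carry a leading batch axis.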

Let's study the `simulate` method first: we'll explore its semantics, and see the types of data it produces.

In [None]:
key, sub_key = jax.random.split(key)
tr = h.simulate(sub_key, (0.3,))
tr

If you're familiar with other "trace-based" probabilistic systems - this should look familiar. 

This object is a piece of data which has captured information about the execution of the function: specifically, the subtraces of _other generative function calls_ made at addressed (`@ "address"`) call sites.

It also captures the `score` - the log probability of the normalized measure which the model program represents, evaluated at the random choices which the generative call execution produced. 

If you were paying attention above, the score is $\log P(\tau, r; x)$.

### How is `simulate` implemented for this language?

For this generative function language, we implement `simulate` using a code transformation! Here's the transformed code.

In [None]:
jaxpr = jax.make_jaxpr(h.simulate)(key, (0.3,))
jaxpr

That's a lot of code! This code is pure, numerical, and ready for acceleration. By utilizing JAX and a staging transformation, we've stripped out all Python overhead.

This is how we've implemented `simulate` for this particular generative function language.^[In general, Gen doesn't require that we follow the same "JAX code transformation" implementation for other generative function languages. The relationship with JAX in GenJAX, however, is a bit special - often, we assume that the user is working within the JAX traceable subset of Python - hence, generative function interface implementations which are JAX traceable benefit from composition with other JAX compatible modeling structures.]

## Generative function interface

There are a few more generative function interface methods worth discussing.

In this notebook, instead of carefully walking through the math which these interface methods compute, we'll defer that discussion to another notebook. Below, we give an informal discussion of what each of the interface methods computes, and roughly describe what algorithm families are supported by their usage.

### The generative function interface in GenJAX

GenJAX's generative functions define an interface which supports compositional usage of generative functions within other generative functions. The interface functions here closely mirror [the interfaces defined in Gen.jl](https://www.gen.dev/docs/stable/ref/gfi/#Generative-function-interface-1).

Below, we use the following abbreviations:

* **IS** - importance sampling
* **SMC** - sequential Monte Carlo
* **MCMC** - Markov chain Monte Carlo
* **VI** - variational inference

| Interface | Type | Inference algorithm support |
| --- | --- | --- |
| `simulate` | Generative | IS, SMC |
| `importance` | Generative | IS, SMC, VI |
| `update` | Generative and incremental | MCMC, SMC |
| `assess` | Generative and differentiable | MCMC, IS, SMC |

Below, I've roughly described each of these methods. We'll stick to the **Generative** category here; as mentioned earlier, differentiable interfaces are out of scope for this notebook.

#### Generative

* `simulate` - sample from the normalized trace measure, and return a trace which records the score.
* `importance` - given constraints for some addresses, sample from the unnormalized trace measure and return a trace along with an importance weight.
* `update` - given an existing trace, and a set of constraints and argument change values, update the trace to be consistent with the set of constraints under execution with the new arguments, and return an incremental importance weight.
* `assess` - given a complete choice map and arguments, return the normalized log probability.

## More about generative functions

Here are a few more bits of information which should help you gain context with these objects.

### Distributions are generative functions

In GenJAX, distributions are generative functions.

In [None]:
key, sub_key = jax.random.split(key)
tr = genjax.normal.simulate(sub_key, (0.0, 1.0))
tr.pprint()

This should bring a sigh of relief! Ah, distributions are generative functions - *the concepts can't be too exotic*.

Distributions implement the interface in a conceptually simple way. They don't have internal compositional choice structure (like the function-like `BuiltinGenerativeFunction` language above).

Distributions themselves expose two interfaces:

* `logpdf` - exact density evaluation.
* `sample` - exact sampling.

We can use these two interfaces to implement all the generative function interfaces for distributions.
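Here's a minimal sketch of that idea in plain JAX, assuming a hypothetical `ToyNormal` class rather than GenJAX's distribution types: `simulate`, `assess`, and `importance` are each derived from nothing more than `sample` and `logpdf`. The return conventions (tuples of choice, score, and weight) are illustrative only, not GenJAX's signatures.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyNormal:
    """Hypothetical distribution exposing only `sample` and `logpdf`."""

    def sample(self, key, mu, sigma):
        return mu + sigma * jax.random.normal(key)

    def logpdf(self, v, mu, sigma):
        return norm.logpdf(v, loc=mu, scale=sigma)

    # --- Generative function interface derived from sample + logpdf ---

    def simulate(self, key, args):
        v = self.sample(key, *args)
        return v, self.logpdf(v, *args)  # (choice, score)

    def assess(self, v, args):
        return self.logpdf(v, *args)  # normalized log density at v

    def importance(self, key, constraint, args):
        if constraint is None:
            # Unconstrained: sample from the prior; the weight is log p/p = 0.
            v, score = self.simulate(key, args)
            return v, score, jnp.zeros(())
        # Constrained: the value is fixed, and the weight is its log density.
        score = self.logpdf(constraint, *args)
        return constraint, score, score
```

Because a distribution has a single "address" and no internal structure, each method reduces to one call to `sample`, one call to `logpdf`, or both.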

### Associated data types

* **Choice maps** are the tree-like representations of the values sampled at random choices inside of generative functions.
* **Selections** are objects which allow querying a trace/choice map: filtering certain choices, projecting joint log score computations onto address contributions, even manipulating choice maps. For those familiar with functional programming, they present a lens-like interface on choice maps.
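As a rough mental model (not GenJAX's actual data structures), a choice map can be pictured as a nested dictionary keyed by addresses, and a selection as a set of address prefixes used to filter it. The helper names `flatten_choices` and `filter_choices` below are hypothetical.

```python
def flatten_choices(chm, prefix=()):
    """Yield (address, value) pairs from a nested-dict choice map."""
    for k, v in chm.items():
        if isinstance(v, dict):
            # Recurse into a sub-choice map, extending the address.
            yield from flatten_choices(v, prefix + (k,))
        else:
            yield prefix + (k,), v


def filter_choices(chm, selected):
    """Keep only choices whose address starts with a selected address prefix."""
    kept = {}
    for addr, v in flatten_choices(chm):
        if any(addr[: len(s)] == s for s in selected):
            kept[addr] = v
    return kept


choices = {"m0": True, "m1": False, "sub": {"m0": True}}
```

For example, selecting `("sub",)` keeps everything recorded under the `"sub"` call, which mirrors how addresses compose hierarchically.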

In [None]:
@gen
def h(x):
    m0 = genjax.bernoulli(x) @ "m0"
    m1 = genjax.bernoulli(x) @ "m1"
    return m0 + m1


key, sub_key = jax.random.split(key)
tr = h.simulate(sub_key, (0.3,))
tr.pprint()

In [None]:
selection = genjax.Selection.at["m1"]
selected = tr.get_choices().filter(selection)
selected

## What can I do with them?

Now, we've informally seen the interfaces and datatypes associated with generative functions.

Studying the interfaces (and improvements thereof), as well as the computational objects which satisfy them can be an entire PhD's worth of effort. 

In the remainder of this notebook, let's see how we can do machine learning with them.

Let's consider a modeling problem where we wish to perform generalized regression with outliers between two variates, taking a family of polynomials as potential curves.

One such model for this data generating process is shown below.

In [None]:
# Two branches for a branching submodel.
@gen
def model_y(x, coefficients):
    basis_value = jnp.array([1.0, x, x**2])
    polynomial_value = jnp.sum(basis_value * coefficients)
    y = genjax.normal(polynomial_value, 0.3) @ "value"
    return y


@gen
def outlier_model(x, coefficients):
    basis_value = jnp.array([1.0, x, x**2])
    polynomial_value = jnp.sum(basis_value * coefficients)
    y = genjax.normal(polynomial_value, 30.0) @ "value"
    return y


# The branching submodel.
switch = genjax.switch_combinator(model_y, outlier_model)


# A mapped kernel function which calls the branching submodel.
@genjax.vmap_combinator(in_axes=(0, None))
@gen
def kernel(x, coefficients):
    is_outlier = genjax.flip(0.1) @ "outlier"
    is_outlier = jnp.array(is_outlier, dtype=int)
    y = switch(is_outlier, (x, coefficients), (x, coefficients)) @ "y"
    return y


@gen
def model(xs):
    coefficients = (
        genjax.mv_normal(np.zeros(3, dtype=float), 2.0 * np.identity(3)) @ "alpha"
    )
    ys = kernel(xs, coefficients) @ "ys"
    return ys

There's a few implementation patterns which you might pick up on by studying this model.

1. To implement control flow, we use higher-order functions called combinators. These accept generative functions as input, and return generative functions as output.

2. Any JAX compatible code is allowed in the body of a generative function.

Courtesy of the interface, we get to design our `model` generative function in pieces.
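For intuition about what the combinators do, here's a rough pure-JAX analogue of the same data-generating process. It only samples: there are no traces, scores, or addresses, so it is not a generative function. The function names are hypothetical; `jax.lax.switch` stands in for the switch combinator and `jax.vmap` for the vmap combinator.

```python
import jax
import jax.numpy as jnp


def polynomial(x, coefficients):
    return jnp.sum(jnp.array([1.0, x, x**2]) * coefficients)


def sample_y(key, x, coefficients):
    """Pure-JAX analogue of `kernel`: flip an outlier coin, then branch."""
    flip_key, noise_key = jax.random.split(key)
    is_outlier = jax.random.bernoulli(flip_key, 0.1).astype(jnp.int32)
    mean = polynomial(x, coefficients)
    eps = jax.random.normal(noise_key)
    # Index 0 is the inlier branch (std 0.3), index 1 the outlier branch (std 30.0).
    y = jax.lax.switch(
        is_outlier,
        [lambda e: mean + 0.3 * e, lambda e: mean + 30.0 * e],
        eps,
    )
    return is_outlier, y


def sample_dataset(key, xs):
    coeff_key, y_key = jax.random.split(key)
    coefficients = jax.random.multivariate_normal(
        coeff_key, jnp.zeros(3), 2.0 * jnp.eye(3)
    )
    keys = jax.random.split(y_key, xs.shape[0])
    # Map over (key, x) pairs while sharing the coefficients across points.
    return jax.vmap(sample_y, in_axes=(0, 0, None))(keys, xs, coefficients)


outliers, ys = jax.jit(sample_dataset)(
    jax.random.PRNGKey(0), jnp.arange(0.0, 10.0, 0.5)
)
```

What the combinators add on top of this is bookkeeping: each branch and each mapped call records its random choices under an address, so the full generative function interface is available for the composite model.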

Now, let's examine the sampled observations (at addresses `("ys", i, "y", "value")`) in a trace sampled from our model.

In [None]:
data = jnp.arange(0, 10, 0.5)
key, sub_key = jax.random.split(key)
tr = jax.jit(model.simulate)(sub_key, (data,))
tr.get_sample()

Here, I'm just showing a concise, pretty-printed representation of the choice map, but it doesn't tell us much about the values we truly care about: the sampled values.

From this model, we can get these in two ways. 

The first way: we can just look at the trace return value.

In [None]:
tr.get_retval()

The second way: we can get them out of the choice map of the trace directly.

In [None]:
chm = tr.get_choices()
values = [chm["ys", i, "y", "value"] for i in range(len(data))]
values

Now, let's construct a small visualization function to show us the samples.

In [None]:
def get_ys(chm):
    # Read the observed values at ("ys", i, "y", "value") for every data index i.
    return jax.vmap(lambda v: chm["ys", v, "y", "value"])(jnp.arange(0, len(data)))


def viz(ax, x, y, **kwargs):
    # Scatter-plot the (x, y) pairs on the given axes.
    sns.scatterplot(x=x, y=y, ax=ax, **kwargs)


f, axes = plt.subplots(3, 3, figsize=(8, 8), sharex=True, sharey=True)
jitted = jax.jit(model.simulate)
for ax in axes.flatten():
    key, sub_key = jax.random.split(key)
    tr = jitted(sub_key, (data,))
    x = data
    y = tr.get_retval()
    viz(ax, x, y, marker=".")

plt.show()

These are the `("ys", i, "y", "value")` samples for 9 traces from our model, plotted against the input points `data` we passed in.

We just walked through one of the main elements of probabilistic programming: setting up a program, which represents a joint distribution over random variates, some of which we'll identify with data we expect to see in the world.

We can adjust the noise settings of our model to produce wider priors over possible sets of points - and we may want to do this if our data is noisy!

For now, let's keep the settings as is, and explore inference in GenJAX.

## Your first inference program

Now, let's say we have some data.

In [None]:
x = np.array([0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0])
y = 2.0 * x + 1.5 + x**2
y[2] = 50.0

In [None]:
fig_data, ax_data = plt.subplots(figsize=(6, 6))
viz(ax_data, x, y, color="blue")

In Bayesian inference, if we wish to consider the conditional distribution $P(\tau, r \mid \text{data}; x)$ induced by a model $P(\tau, r, \text{data}; x)$, Bayes' rule gives us a way to compute it:

$$
P(\tau, r \mid \text{data}; x) = \frac{P(\tau, r, \text{data}; x)}{\iint P(\tau', r', \text{data}; x) \, d\tau' \, dr'}
$$

The problem is that we often cannot compute the denominator (_the evidence integral_) easily. Instead, we turn to _approximate_ Bayesian inference.
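When the latent variables are discrete and few, the denominator is just a sum, and Bayes' rule can be applied exactly. Here's a small sketch, assuming we already know the curve value `mean` at a point and only ask whether that point is an outlier, using the same prior (0.1) and noise scales (0.3 and 30.0) as `kernel`, `model_y`, and `outlier_model` above. The function name `outlier_posterior` is hypothetical. Once $\tau$ includes continuous, correlated quantities (like our polynomial coefficients together with every outlier flag), this kind of enumeration is no longer available.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def outlier_posterior(y, mean, prior_outlier=0.1):
    """Exact posterior P(outlier | y) for one point with a known curve value.

    The latent has two values, so the evidence integral is a two-term sum.
    """
    log_prior = jnp.log(jnp.array([1.0 - prior_outlier, prior_outlier]))
    log_lik = jnp.array([norm.logpdf(y, mean, 0.3), norm.logpdf(y, mean, 30.0)])
    log_joint = log_prior + log_lik
    log_evidence = logsumexp(log_joint)  # log of the Bayes' rule denominator
    return jnp.exp(log_joint[1] - log_evidence)
```

For instance, a point sitting far above the curve (like `y = 50.0` where the curve gives about 4.9) is almost certainly an outlier, while a point on the curve almost certainly is not.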

Depending on how we wish to use the LHS conditional (which is called _the posterior_ in Bayesian inference) - we have different options available to us.

If we wish to approximately sample from the posterior, to get an empirical sense of its shape and properties, we will often utilize techniques which provide exact samplers for another distribution which gets asymptotically close to the target posterior if we increase certain hyperparameters.

One such algorithm is _importance sampling_, and that's what we'll write today.

Here's importance sampling (with a single sample) without a custom proposal in GenJAX.

In [None]:
observations = jax.vmap(lambda idx, v: C.n.at["ys", idx, "y", "value"].set(v))(
    jnp.arange(len(y)), y
)
key, sub_key = jax.random.split(key)
(tr, w) = model.importance(sub_key, observations, (x,))

We're introduced to another interface method! 

`importance` accepts a PRNG key, a choice map representing observations (sometimes called constraints), and model arguments. It returns a tuple containing a trace and a log _importance weight_.

The trace is consistent with the arguments and constraints passed into the invocation.

In [None]:
[observations["ys", i, "y", "value"] for i in range(0, len(y))]

Let's examine the weight now, and compare it to the score.

In [None]:
(w, tr.get_score())

Notice that these two quantities are different. 

Remember: the score is the normalized log density of the choice map measure, evaluated at the complete set of random choices recorded in the trace. We'll refer to complete traces by $\tau$.

The log importance weight `w` is slightly different.

### Importance sampling, informally

Let's discuss how importance sampling works first^[In this notebook, I'll defer discussing formal proofs concerning the asymptotic consistency of posterior estimators derived from importance sampling.].

This will provide us with an understanding as to why `w` is different from the `score`. 

More importantly, we'll understand how we can use `importance` to solve the inference task of approximately sampling from the posterior over coefficients (and ultimately, over curves) from our generative function.

::: {.callout-note}

Importance sampling is typically presented by focusing on posterior expectations $E_{x \sim P(x | y)}[f(x)]$.

In our case, we want to sample $x \sim P(x | y)$. To do this, we'll actually be considering a different procedure called *sampling importance resampling* or SIR for short.

Importantly, importance sampling is a subroutine in SIR. 

We'll discuss why importance sampling works here, and provide references to why SIR works to solve our problem.

:::
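For reference, here's a sketch of the expectation form of importance sampling (the self-normalized variant) in plain JAX. The target $N(1, 0.5^2)$, the standard normal proposal, and the function name `snis_expectation` are hypothetical choices for illustration; the log weights are the same `target_logpdf - proposal_logpdf` ratio that the procedure below computes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def snis_expectation(key, f, target_logpdf, n_samples=10_000):
    """Self-normalized importance sampling estimate of E_target[f(x)].

    Proposal: standard normal. The target only needs to be known up to
    a constant, since the weights are normalized.
    """
    xs = jax.random.normal(key, (n_samples,))
    log_w = target_logpdf(xs) - norm.logpdf(xs)  # log target - log proposal
    normalized_w = jnp.exp(log_w - logsumexp(log_w))
    return jnp.sum(normalized_w * f(xs))


# Example: target N(1, 0.5^2), so the estimate of E[x] should be close to 1.
estimate = snis_expectation(
    jax.random.PRNGKey(0),
    lambda x: x,
    lambda x: norm.logpdf(x, 1.0, 0.5),
)
```

SIR reuses exactly these normalized weights, but instead of averaging `f` under them, it draws an index from them.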

Let's start by considering two distributions which we can both sample from and evaluate densities for.

Below, I'm plotting the densities of two distributions - a 1D Gaussian mixture and a 1D Gaussian^[GenJAX allows usage of TensorFlow Distributions as generative functions. Here, we're just using the `logpdf` interface from distributions which expose exact `logpdf` evaluation - but `genjax` exports a wrapper which implements the complete generative function interface.].

In [None]:
mix = genjax.mixture(genjax.categorical, [genjax.normal, genjax.normal])
mix_args = ([0.5, 0.5], [(-3.0, 0.8), (1.0, 0.3)])
d = genjax.normal
d_args = (0.0, 1.0)


fig, ax = plt.subplots(figsize=(8, 8))
evaluation_points = np.arange(-5, 5, 0.01)


def plot_logpdf(ax, logpdf_fn, evaluation_points, **kwargs):
    logpdfs = jax.vmap(logpdf_fn)(evaluation_points)
    ax.scatter(evaluation_points, jnp.exp(logpdfs), marker=".", **kwargs)


def d_logpdf(v):
    return d.logpdf(v, *d_args)


def mix_logpdf(v):
    return mix.logpdf(v, *mix_args)


plot_logpdf(ax, d_logpdf, evaluation_points, color="red", label="1D Gaussian PDF")
plot_logpdf(
    ax,
    mix_logpdf,
    evaluation_points,
    color="blue",
    label="1D Gaussian mixture PDF",
)
ax.legend()

To gain context on importance sampling, imagine that the distribution which produces the blue curve is difficult to sample from - but it exposes a `logpdf` interface which we can use to evaluate the density at any point on the support of the distribution.

Now, suppose you hand me the distribution which made the red curve - and it is easy to sample from, and it also exposes a `logpdf` interface. 

One thing we could do is sample from the red curve and then "correct" for the fact that we're sampling from the wrong distribution.

This is the key intuition behind importance sampling.

Now, I'm going to write a procedure and ask you to just go with it ... for a moment.

In [None]:
def importance_sample(hard, easy):
    def _inner(key, hard_args, easy_args):
        sample = easy.sample(key, *easy_args)
        easy_logpdf = easy.logpdf(sample, *easy_args)
        hard_logpdf = hard.logpdf(sample, *hard_args)
        importance_weight = hard_logpdf - easy_logpdf
        return (importance_weight, sample)

    return _inner

In [None]:
hard = genjax.mixture(genjax.categorical, [genjax.normal, genjax.normal])
easy = genjax.normal
jitted = jax.jit(importance_sample(hard, easy))
key, sub_key = jax.random.split(key)
(importance_weight, sample) = jitted(sub_key, mix_args, d_args)

In [None]:
(importance_weight, sample)

Now, we can easily run this procedure many times in parallel.

In [None]:
jitted = jax.jit(jax.vmap(importance_sample(hard, easy), in_axes=(0, None, None)))

In [None]:
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
(importance_weight, sample) = jitted(sub_keys, mix_args, d_args)

In [None]:
importance_weight

We're just sampling from `easy`, then scoring the samples with `importance_weight` according to the log ratio `hard_logpdf(sample) - easy_logpdf(sample)`.

Here's the trick: from our collection of samples and weights, let's normalize the weights into a distribution, and use it to resample a single sample to return.

In [None]:
def sampling_importance_resampling(hard, easy, n_samples):
    def _inner(key, hard_args, easy_args):
        fn = importance_sample(hard, easy)
        resample_key, sub_key = jax.random.split(key)
        sub_keys = jax.random.split(sub_key, n_samples)
        vmapped = jax.vmap(fn, in_axes=(0, None, None))
        (ws, samples) = vmapped(sub_keys, hard_args, easy_args)
        logits = ws
        index = genjax.categorical.sample(resample_key, logits)
        final_sample = samples[index]
        return final_sample

    return _inner

In [None]:
hard = genjax.mixture(genjax.categorical, [genjax.normal, genjax.normal])
easy = genjax.normal
jitted = jax.jit(sampling_importance_resampling(hard, easy, 100))
key, sub_key = jax.random.split(key)
sample = jitted(sub_key, mix_args, d_args)

In [None]:
sample

Let's run this procedure a bunch of times and plot the points on the x-axis of our plot above.

In [None]:
def plot_on_x(ax, x, **kwargs):
    ax.scatter(x, np.zeros_like(x), **kwargs)


key, *sub_keys = jax.random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
fn = sampling_importance_resampling(hard, easy, 1000)
jitted = jax.jit(jax.vmap(fn, in_axes=(0, None, None)))
samples = jitted(sub_keys, mix_args, d_args)
plot_on_x(ax, samples, color="gold", marker=".", alpha=0.05)
fig

Notice what happens with the SIR samples (in gold)?

They accumulate around the places you'd expect to see if you were sampling from the `hard` distribution!

That's what importance sampling and sampling importance resampling give us: we provide a "hard" distribution with a `logpdf` interface, and another "easy" distribution with `sample` and `logpdf` interfaces^[There are more constraints. The first distribution must be _absolutely continuous_ with respect to the second: wherever the hard distribution has positive density, the easy one must too. Let's defer this discussion to a formal treatment of importance sampling.], and SIR returns an exact sampler for a distribution which approximates the hard distribution.

### Back to our generative function

Now that we've seen the ingredients and implementation of importance sampling and sampling importance resampling - let's return to our original problem.

In [None]:
fig_data

If you studied the previous section carefully, one question might jump out at you: what is the "easy" distribution for `model.importance`?

#### Builtin proposals

Generative functions defined using the `BuiltinGenerativeFunction` language come with a builtin proposal: a distribution (which we'll refer to as $Q$) induced from the prior, with sampling and `score` defined ancestrally.

Given observation constraints $u$, the importance weight which `model.importance` computes is^[This definition again considers "untraced randomness" $r$. If you wish to ignore this in the math, just remove the $Q(r; x, \tau)$ term. Even in the presence of untraced randomness, the weights which Gen computes are asymptotically consistent in expectation over $Q(r; x, \tau)$.]:

$$
\log w = \log P(\tau, r; x) - \log \left[ Q(\tau; u, x) \, Q(r; x, \tau) \right]
$$

For the `BuiltinGenerativeFunction` language, we implement $Q$ by running the generative function: when we arrive at a constrained address, we recursively call `submodel.importance` and accumulate both the log weight and the log score.

Now, if an address has no constraints, it contributes `0.0` to the weight (think about why by looking at the equation above and asking what happens when $Q$ has to generate a full $\tau$ itself). However, it still contributes to the score.
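To see the weight and score diverge, here's a minimal hand-written sketch (not GenJAX's implementation) for a hypothetical two-address model with no untraced randomness: `z ~ N(0, 1)` is unconstrained and `y ~ N(z, 0.3)` is constrained. It mimics the builtin proposal, and the difference `score - weight` comes out to exactly the log prior density of the unconstrained choice. The name `toy_importance` is hypothetical.

```python
import jax
from jax.scipy.stats import norm


def toy_importance(key, observed_y):
    """Hypothetical hand-written `importance` for z ~ N(0, 1), y ~ N(z, 0.3).

    Unconstrained addresses are sampled from their prior; constrained
    addresses are fixed to the observed value.
    """
    # "z" is unconstrained: sample from the prior. Its weight contribution
    # is log p(z) - log q(z) = 0, because q is the prior itself.
    z = jax.random.normal(key)
    z_score = norm.logpdf(z, 0.0, 1.0)
    z_weight = 0.0

    # "y" is constrained: nothing is sampled, and its weight contribution
    # is its log density at the observed value.
    y_score = norm.logpdf(observed_y, z, 0.3)
    y_weight = y_score

    score = z_score + y_score
    weight = z_weight + y_weight
    return {"z": z, "y": observed_y}, score, weight


choices, score, weight = toy_importance(jax.random.PRNGKey(0), 1.2)
```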

#### Sampling importance resampling in GenJAX

Here's SIR in GenJAX using the builtin proposal (each particle is just a single call to `model.importance`)^[To implement a variant with custom proposals, all we need to do is first `proposal.simulate`, merge the proposal choice map with the constraints, then `model.importance` followed by a final weight adjustment `w = w - proposal_tr.get_score()` - easy peasy.]:

In [None]:
def sampling_importance_resampling(model, n_samples):
    def _inner(key, observations, model_args):
        resample_key, sub_key = jax.random.split(key)
        sub_keys = jax.random.split(sub_key, n_samples)
        vmapped = jax.vmap(model.importance, in_axes=(0, None, None))
        (trs, lws) = vmapped(sub_keys, observations, model_args)
        index = genjax.categorical.sample(resample_key, lws)
        final_tr = jtu.tree_map(lambda v: v[index], trs)
        return final_tr

    return _inner

One difference between our first implementation (on just distributions) above and this one is that `Trace` instances are structured objects (but all of them are `Pytree` implementors) - meaning we need to index into the leaves when we wish to return a single sampled trace.
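The footnote above describes the custom-proposal variant in terms of GenJAX calls. Here's the same idea written out by hand in plain JAX, as a sketch for a hypothetical conjugate model `z ~ N(0, 1)`, `y ~ N(z, 0.3)` with a hypothetical proposal `q(z) = N(y, 0.5)`: propose, score under the model, subtract the proposal's log density, then resample. Because the model is conjugate, the exact posterior mean is $y / 1.09$, which gives a quick sanity check. The function name `sir_custom_proposal` is hypothetical.

```python
import jax
from jax.scipy.stats import norm


def sir_custom_proposal(key, observed_y, n_particles=100, proposal_std=0.5):
    """SIR for z ~ N(0, 1), y ~ N(z, 0.3) with proposal q(z) = N(y, proposal_std).

    Each particle: propose z from q, then weight by log p(z, y) - log q(z).
    """
    propose_key, resample_key = jax.random.split(key)
    zs = observed_y + proposal_std * jax.random.normal(propose_key, (n_particles,))
    log_joint = norm.logpdf(zs, 0.0, 1.0) + norm.logpdf(observed_y, zs, 0.3)
    log_q = norm.logpdf(zs, observed_y, proposal_std)
    log_w = log_joint - log_q  # the "w - proposal score" adjustment
    index = jax.random.categorical(resample_key, log_w)
    return zs[index]


keys = jax.random.split(jax.random.PRNGKey(0), 2000)
draws = jax.jit(jax.vmap(lambda k: sir_custom_proposal(k, 1.2)))(keys)
```

A proposal centered near the data places particles where the posterior mass is, so fewer particles are wasted than when sampling `z` from the prior.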

In [None]:
model_args = (x,)
jitted = jax.jit(
    jax.vmap(sampling_importance_resampling(model, 100), in_axes=(0, None, None))
)
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
samples = jitted(sub_keys, observations, model_args)
coefficients = samples["alpha"]

So now we have an approximate sampler for the posterior and we can use it to look at properties of the posterior - like what sort of curves are likely given our data and our model prior.

... and by the way, on an Apple M2 device, getting a representative set of samples from the posterior for this model takes only about 0.05 seconds^[Just remember: we're running this notebook on CPU - but the resulting specialized inference code can easily be moved to accelerators, courtesy of the fact that all our code is JAX traceable.].

In [None]:
%%timeit
samples = jitted(sub_keys, observations, model_args)

In [None]:
def polynomial_at_x(x, coefficients):
    basis_values = jnp.array([1.0, x, x**2])
    polynomial_value = jnp.sum(coefficients * basis_values)
    return polynomial_value


jitted = jax.jit(jax.vmap(polynomial_at_x, in_axes=(None, 0)))

In [None]:
def plot_polynomial_values(ax, x, coefficients, **kwargs):
    v = jitted(x, coefficients)
    ax.scatter(np.repeat(x, len(v)), v, **kwargs)


coefficients = samples["alpha"]
evaluation_points = np.arange(0, 5, 0.01)
for x_eval in evaluation_points:
    plot_polynomial_values(
        ax_data, x_eval, coefficients, marker=".", color="gold", alpha=0.01
    )
fig_data

Intuitively, this makes a lot of sense. Our prior over polynomials considers a wide range of curves, but if we trust our approximate sampling process, we're seeing what we should expect after observing this data: polynomials with the coefficients shown above are sampled more often under the posterior.

We can also ask for an estimate of the posterior probability that any particular point was an outlier.

For example, below is the set of samples projected onto the `("ys", i, "outlier")` addresses, keeping only point 2, which we manually set to be quite far from the curve.

In [None]:
chs = jax.vmap(lambda idx: samples.get_choices()["ys", idx, "outlier"])(jnp.arange(100))
outlier_at_2 = chs[:, 2]

In [None]:
np.sum(outlier_at_2) / len(outlier_at_2)

That seems to make sense! We pulled that point quite far away from the ground-truth curve, so we'd expect the posterior to consider point 2 very likely to be an outlier.

## Summary

We've covered a lot of ground in this notebook. Please reflect, re-read, and post issues!

* We introduced the Gen probabilistic programming framework, and GenJAX - an implementation of Gen on top of JAX.
* We discussed _generative functions_ - the main computational object of Gen.
* We discussed how to create generative functions using _generative function languages_, and several of GenJAX's builtin capabilities for constructing generative functions.
* We discussed how to use generative functions to represent joint probability distributions, which can be used to construct models of phenomena.
* We created a generative function to model a data-generating process based on sampling and evaluating random polynomials at input data - to represent a typical regression task.
* We discussed how to formulate questions about induced conditional distributions under a probabilistic model as a Bayesian inference problem.
* We discussed importance sampling and sampling importance resampling, two central techniques in approximate Bayesian inference.
* We created a sampling importance resampling routine and applied it to produce approximate posterior samples in our polynomial model.
* We investigated the approximate posterior samples, and visually checked that they match the inferences we might draw - both for the polynomials we expected to produce the data, and for which data points might be outliers.

This is just the beginning! There's a lot more to learn, but plenty to bite off with this initial notebook.