Over the past year or so, I've been using [JAX](https://jax.readthedocs.io) extensively for my research, and I've also been encouraging other astronomers to give it a try.
In particular, I've been using JAX as the computation engine for probabilistic inference tasks.
There's more to it, but one way that I like to think about JAX is as NumPy with just-in-time compilation and [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).
The just-in-time compilation features of JAX can be used to speed up your NumPy computations by removing some Python overhead and by executing them on your GPU.
Then, automatic differentiation can be used to efficiently compute the derivatives of your code with respect to its input parameters.
These derivatives can substantially improve the performance of numerical inference methods (like maximum likelihood or Markov chain Monte Carlo), and they are also useful for other tasks such as Fisher information analysis.
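
To make the "NumPy with just-in-time compilation and automatic differentiation" picture a bit more concrete, here is a minimal sketch. The function `neg_log_like` and the toy arrays are made up for illustration and are not used elsewhere in this post; the only JAX features used are `jax.jit` and `jax.value_and_grad`.

```python
import jax
import jax.numpy as jnp


def neg_log_like(params, x, y, yerr):
    # Chi-squared style objective for a straight line (illustrative only)
    m, b = params
    resid = (y - (m * x + b)) / yerr
    return 0.5 * jnp.sum(resid**2)


# value_and_grad differentiates with respect to the first argument (params),
# and jit compiles the resulting function
value_and_grad = jax.jit(jax.value_and_grad(neg_log_like))

x_toy = jnp.array([-1.0, 0.0, 1.0])
y_toy = jnp.array([-1.0, 0.0, 1.0])
yerr_toy = jnp.array([0.5, 0.5, 0.5])

value, grad = value_and_grad(jnp.array([0.0, 0.0]), x_toy, y_toy, yerr_toy)
print(value, grad)
```

The gradient returned here has the same shape as `params`, which is what gradient-based optimizers and samplers need.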

This post isn't meant to be a comprehensive introduction to JAX (take a look at [the excellent JAX docs](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) for more of that) or to automatic differentiation ([I've written some words](https://docs.exoplanet.codes/en/latest/tutorials/autodiff/) about that, and [so have many others](https://www.google.com/search?q=automatic+differentiation)), but rather an introduction to the JAX ecosystem for probabilistic inference, with some examples that will be familiar to astronomers.
From my perspective, one benefit of the JAX ecosystem compared to other similar tools available in Python (e.g. [PyMC](https://www.pymc.io) or [Stan](https://mc-stan.org)) is that it's generally more modular.
In practice, this means that you can (relatively) easily combine different JAX libraries to develop your preferred workflow.
For example, you can build a probabilistic model using [NumPyro](https://num.pyro.ai) that uses [tinygp](https://tinygp.readthedocs.io) for Gaussian Processes, and then run a Markov chain Monte Carlo (MCMC) analysis using [BlackJAX](https://github.com/blackjax-devs/blackjax).

In this post, however, I'll focus primarily on providing an introduction to [NumPyro](https://num.pyro.ai), which is a probabilistic programming library that provides an interface for defining probabilistic models and running inference algorithms.
At this point, NumPyro is probably the most mature JAX-based probabilistic programming library, and [its documentation page](https://num.pyro.ai) has a lot of examples. That said, I've found that these docs are not that user-friendly for my collaborators, so I wanted to provide a different perspective.
In the following sections, I'll present two examples:

1. The first example is a fairly simple linear regression problem that introduces some basic NumPyro concepts. In the second half of this example, we will re-implement the model from [my "Mixture Models" post](https://dfm.io/posts/mixture-models/) to account for outliers in the simulated dataset, while also introducing some more advanced elements.

2. The second example is an astronomy-specific problem that is designed to really highlight the power of these methods. In this example, we will measure the distance to the [M67 open cluster](https://en.wikipedia.org/wiki/Messier_67) using a huge hierarchical model for the observed Gaia parallaxes of stars in the direction of M67. This example includes running an MCMC sampler with thousands of parameters, which would be intractable with the tools commonly used by astronomers, but only takes a few minutes to run using NumPyro.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# We'll choose the parameters of our synthetic data.
# The fraction of points drawn from the line will be 80%, so the outlier
# probability is 20%:
true_frac = 0.8

# The linear model has unit slope and zero intercept:
true_params = [1.0, 0.0]

# The outliers are drawn from a Gaussian with zero mean and unit variance:
true_outliers = [0.0, 1.0]

# For reproducibility, let's set the random number seed and generate the data:
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Those points are all drawn from the correct model so let's replace some of
# them with outliers.
m_bkg = np.random.rand(len(x)) > true_frac
y[m_bkg] = true_outliers[0]
y[m_bkg] += np.sqrt(true_outliers[1] + yerr[m_bkg] ** 2) * np.random.randn(sum(m_bkg))

# Then save the *true* line.
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)

# Plot the data and the truth.
plt.errorbar(x, y, yerr=yerr, fmt=",k", ms=0, capsize=0, lw=1, zorder=999)
plt.scatter(x[m_bkg], y[m_bkg], marker="s", s=22, c="w", edgecolor="k", zorder=1000)
plt.scatter(x[~m_bkg], y[~m_bkg], marker="o", s=22, c="k", zorder=1000)
plt.plot(x0, y0, color="k", lw=1.5)
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.ylim(-2.5, 2.5)
plt.xlim(-2.1, 2.1);

In [None]:
import jax
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist, infer

numpyro.set_host_device_count(2)


def linear_model(x, yerr, y=None):
    # These are the parameters that we're fitting and we're required to define explicit
    # priors using distributions from the numpyro.distributions module.
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Transformed parameters (and other things!) can be tracked during sampling using
    # "deterministics" as follows:
    m = numpyro.deterministic("m", jnp.tan(theta))
    b = numpyro.deterministic("b", b_perp / jnp.cos(theta))

    # Then we specify the sampling distribution for the data, or the likelihood function.
    # Here we're using a numpyro.plate to indicate that the data are independent. This
    # isn't actually necessary here and we could have equivalently omitted the plate since
    # the Normal distribution can already handle vector-valued inputs. But, it's good to
    # get into the habit of using plates because some inference algorithms or distributions
    # can take advantage of knowing this structure.
    with numpyro.plate("data", len(x)):
        numpyro.sample("y", dist.Normal(m * x + b, yerr), obs=y)


sampler = infer.MCMC(
    infer.NUTS(linear_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y)

In [None]:
import corner
import arviz as az

inf_data = az.from_numpyro(sampler)
corner.corner(inf_data, var_names=["m", "b"], truths=true_params)
az.summary(inf_data)
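
The model above places its priors on the angle `theta` and the perpendicular offset `b_perp` (the distance from the origin to the line) rather than directly on `m` and `b`. One consequence that is easy to check numerically is that a uniform prior on the angle implies a standard Cauchy prior on the slope `m = tan(theta)`, whose quartiles are -1, 0, and 1. The following sketch uses only NumPy, and the variable names are specific to this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Draw angles from the same Uniform(-pi/2, pi/2) prior used in the model
theta_draws = rng.uniform(-0.5 * np.pi, 0.5 * np.pi, 200_000)
m_draws = np.tan(theta_draws)

# For a standard Cauchy distribution these should be close to -1, 0, and 1
quartiles = np.quantile(m_draws, [0.25, 0.5, 0.75])
print(quartiles)
```

In other words, half of the prior mass sits on lines shallower than 45 degrees and half on steeper ones, which a flat prior on `m` over a finite range would not give you.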

In [None]:
from numpyro_ext.distributions import MixtureGeneral


def linear_mixture_model(x, yerr, y=None):
    # Foreground model
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0.0, 1.0))
    m = numpyro.deterministic("m", jnp.tan(theta))
    b = numpyro.deterministic("b", b_perp / jnp.cos(theta))
    fg_dist = dist.Normal(m * x + b, yerr)

    # Background model
    bg_mean = numpyro.sample("bg_mean", dist.Normal(0.0, 1.0))
    bg_sigma = numpyro.sample("bg_sigma", dist.HalfNormal(3.0))
    bg_dist = dist.Normal(bg_mean, jnp.sqrt(bg_sigma**2 + yerr**2))

    # Mixture
    Q = numpyro.sample("Q", dist.Uniform(0.0, 1.0))
    mix = dist.Categorical(probs=jnp.array([Q, 1.0 - Q]))
    numpyro.sample("obs", MixtureGeneral(mix, [fg_dist, bg_dist]), obs=y)


sampler = infer.MCMC(
    infer.NUTS(linear_mixture_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
%time sampler.run(jax.random.PRNGKey(10), x, yerr, y=y)

In [None]:
inf_data = az.from_numpyro(sampler)
corner.corner(
    inf_data,
    var_names=["m", "b", "Q"],
    truths={
        "m": true_params[0],
        "b": true_params[1],
        "Q": true_frac,
    },
)
az.summary(inf_data)
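
Conceptually, a mixture likelihood like this one marginalizes over the unknown foreground or background label of each data point, so no discrete parameters need to be sampled. The sketch below writes that marginalized likelihood out by hand with NumPy and SciPy. It is an illustration of the math under that assumption, not the implementation used by `MixtureGeneral`, and the function `mixture_log_like` and the two-point dataset are hypothetical.

```python
import numpy as np
from scipy import stats


def mixture_log_like(y, yerr, y_line, Q, bg_mean, bg_sigma):
    # Weighted log densities of the foreground and background components
    log_fg = np.log(Q) + stats.norm.logpdf(y, y_line, yerr)
    log_bg = np.log1p(-Q) + stats.norm.logpdf(
        y, bg_mean, np.sqrt(bg_sigma**2 + yerr**2)
    )
    # Marginalize over the label of each point in a numerically stable way
    log_like = np.logaddexp(log_fg, log_bg)
    # Posterior probability that each point belongs to the foreground
    p_fg = np.exp(log_fg - log_like)
    return log_like.sum(), p_fg


# Two toy points: one close to the line and one far from it
y_demo = np.array([0.05, -1.5])
y_line_demo = np.array([0.0, 1.0])
yerr_demo = np.array([0.2, 0.2])

total_log_like, p_fg = mixture_log_like(
    y_demo, yerr_demo, y_line_demo, Q=0.8, bg_mean=0.0, bg_sigma=1.0
)
print(total_log_like, p_fg)
```

The per-point foreground probabilities computed this way are the same quantities that are often used to flag outliers after the fit.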

In [None]:
from astropy.io import fits

with fits.open("data/m67.fits.gz") as f:
    data = f[1].data

mask = np.isfinite(data["parallax"])
mask &= np.isfinite(data["parallax_error"])
mask &= data["parallax"] > 0.2
mask &= data["parallax"] < 3.0
data = data[mask]

plt.hist(data["parallax"], 120, histtype="step")
plt.xlabel("parallax [mas]")
plt.ylabel("count");

**A brief aside:**

In [None]:
from scipy import stats

L = 370.0
x = np.linspace(0, 5000, 500)
r = 0.5 * L * stats.chi2(df=6).rvs(500_000)
plt.hist(r, 100, range=(0, 5000), density=True, histtype="step", label="samples")
plt.plot(x, 0.5 * x**2 * np.exp(-x / L) / L**3, "--", label="pdf")
plt.xlabel("distance [pc]")
plt.yticks([])
plt.legend();
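
The reason the scaled chi-squared samples in the cell above match the plotted density is that $p(r) = r^2\,e^{-r/L} / (2\,L^3)$ is a Gamma distribution with shape 3 and scale $L$, while a chi-squared distribution with 6 degrees of freedom is a Gamma distribution with shape 3 and scale 2. Multiplying chi-squared draws by $L/2$ therefore rescales them to the target distribution. Here is a short numerical check of that identity using SciPy; the names `L_demo` and `r_grid` are specific to this sketch.

```python
import numpy as np
from scipy import stats

L_demo = 370.0
r_grid = np.linspace(1.0, 5000.0, 200)

# The density as written in the plotting code above
pdf_formula = 0.5 * r_grid**2 * np.exp(-r_grid / L_demo) / L_demo**3

# The same density expressed as a Gamma and as a rescaled chi-squared
pdf_gamma = stats.gamma(a=3, scale=L_demo).pdf(r_grid)
pdf_scaled_chi2 = stats.chi2(df=6, scale=0.5 * L_demo).pdf(r_grid)

print(np.max(np.abs(pdf_formula - pdf_gamma)))
print(np.max(np.abs(pdf_formula - pdf_scaled_chi2)))
```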

In [None]:
def gaia_single_model(plx_err, plx=None, plx_zp=0.0, L=370.0):
    normed = numpyro.sample("normed", dist.Chi2(6))
    distance = numpyro.deterministic("distance", 0.5 * L * normed)
    numpyro.sample("plx", dist.Normal(1000.0 / distance + plx_zp, plx_err), obs=plx)


plx = data[1200]["parallax"]
plx_err = data[1200]["parallax_error"]

sampler = infer.MCMC(
    infer.NUTS(gaia_single_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
)
%time sampler.run(jax.random.PRNGKey(0), plx_err, plx=plx)

samples_single = sampler.get_samples()
plt.hist(samples_single["distance"], 50, density=True, histtype="step")
plt.xlabel("distance to target number 1200 [pc]")
plt.ylabel("posterior density")
plt.yticks([]);

In [None]:
def gaia_cluster_model(plx_err, plx=None, plx_zp=0.0):
    log_L = numpyro.sample("log_L", dist.Normal(np.log(370.0), 2.0))
    log_dist_clust = numpyro.sample("log_dist_clust", dist.Normal(np.log(920.0), 2.0))
    log_sigma = numpyro.sample("log_sigma", dist.Normal(0.0, 2.0))
    frac_clust = numpyro.sample("frac_clust", dist.Uniform(0.0, 1.0))

    L = numpyro.deterministic("L", jnp.exp(log_L))
    numpyro.deterministic("dist_clust", jnp.exp(log_dist_clust))
    numpyro.deterministic("approx_size_clust", jnp.exp(log_sigma + log_dist_clust))

    with numpyro.plate("stars", len(plx_err)):

        # The background distance distribution is the same as the one we used above,
        # but this time we fit for the length scale L. Another difference is that we're
        # "transforming" the distribution using an affine transformation, instead of
        # sampling in the "normalized distance" and then multiplying by 0.5*L like we
        # did above. These two approaches are equivalent, but we need to specify a
        # single distribution here for use in the MixtureGeneral distribution below.
        dist_bg = dist.TransformedDistribution(
            dist.Chi2(6),
            dist.transforms.AffineTransform(
                0.0, 0.5 * L, domain=dist.constraints.positive
            ),
        )

        # The foreground distribution is a Gaussian in log distance. Like with the
        # background distribution, we use a transformation to convert from a Gaussian
        # in log distance to a distribution in distance.
        dist_fg = dist.TransformedDistribution(
            dist.Normal(log_dist_clust, jnp.exp(log_sigma)),
            dist.transforms.ExpTransform(),
        )

        # Now we "mix" the foreground and background distributions using the "cluster
        # membership fraction" parameter to specify the mixing weights.
        mixture = MixtureGeneral(
            dist.Categorical(probs=jnp.stack((frac_clust, 1 - frac_clust), axis=-1)),
            [dist_fg, dist_bg],
        )
        distance = numpyro.sample("distance", mixture)

        # Finally, we convert the distance to parallax and add the zero-point offset.
        plx_true = numpyro.deterministic("plx_true", 1000.0 / distance + plx_zp)
        numpyro.sample("plx_obs", dist.Normal(plx_true, plx_err), obs=plx)


plx = np.ascontiguousarray(data["parallax"], dtype=np.float32)
plx_err = np.ascontiguousarray(data["parallax_error"], dtype=np.float32)
sampler = infer.MCMC(
    infer.NUTS(gaia_cluster_model),
    num_warmup=2000,
    num_samples=4000,
    num_chains=2,
    progress_bar=True,
)
%time sampler.run(jax.random.PRNGKey(42), plx_err, plx=plx)

In [None]:
samples = sampler.get_samples()
plt.hist(
    samples_single["distance"],
    50,
    density=True,
    histtype="step",
    label="single target model",
)
plt.hist(
    samples["distance"][:, 1200],
    50,
    density=True,
    histtype="step",
    label="cluster model",
)
plt.xlabel("distance to target number 1200 [pc]")
plt.ylabel("posterior density")
plt.legend()
plt.yticks([]);

In [None]:
inf_data = az.from_numpyro(sampler)
corner.corner(
    inf_data,
    var_names=["L", "dist_clust", "approx_size_clust", "frac_clust"],
    labels=[
        "background length scale [pc]",
        "cluster distance [pc]",
        "cluster intrinsic size [pc]",
        "cluster membership fraction",
    ],
)
az.summary(inf_data, var_names=["L", "dist_clust", "approx_size_clust", "frac_clust"])

In [None]:
pred = infer.Predictive(gaia_cluster_model, samples)(jax.random.PRNGKey(10), plx_err)

_, bins, _ = plt.hist(plx, 50, histtype="step", lw=2, label="observed")
label = "posterior predictive"
for n in np.random.default_rng(0).integers(len(pred["plx_obs"]), size=100):
    plt.hist(
        pred["plx_obs"][n],
        bins,
        histtype="step",
        color="k",
        alpha=0.1,
        lw=0.5,
        label=label,
    )
    label = None
plt.legend()
plt.xlabel("mean parallax [mas]")
plt.ylabel("count")
plt.xlim(bins[0], bins[-1]);

In [None]:
prior_pred = infer.Predictive(gaia_cluster_model, num_samples=50)(
    jax.random.PRNGKey(11), plx_err
)
label = "prior samples"
for n in range(len(prior_pred["plx_obs"])):
    plt.hist(
        prior_pred["plx_obs"][n],
        100,
        range=(0.0, 3.0),
        histtype="step",
        color="k",
        lw=0.5,
        alpha=0.5,
        label=label,
    )
    label = None
plt.legend()
plt.xlabel("mean parallax [mas]")
plt.ylabel("count")
plt.xlim(0, 3);