Over the past year or so, I've been using [JAX](https://jax.readthedocs.io) extensively for my research, and I've also been encouraging other astronomers to give it a try.
In particular, I've been using JAX as the computation engine for probabilistic inference tasks.

## An introduction to JAX

There's more to it, but one way that I like to think about JAX is as NumPy with just-in-time compilation and [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).
The just-in-time compilation features of JAX can be used to speed up your NumPy computations by removing some Python overhead and by executing them on your GPU.
These optimizations can be useful for a wide range of applications, but JAX especially shines when using automatic differentiation for improving your inference algorithms.

While it's not true that naively using JAX will automatically speed up your code, in many cases it will.

In [None]:
import jax

# Switch on JAX's "jax_enable_x64" option for the rest of the session.
# NOTE(review): JAX is understood to default to 32-bit floats unless this flag
# is set -- confirm against the JAX documentation for the installed version.
jax.config.update("jax_enable_x64", True)

In [None]:
import jax.numpy as jnp

# Five evenly spaced points covering the closed interval [0, 10].
x = jnp.linspace(0.0, 10.0, num=5)

# Element-wise sine, written exactly as it would be with NumPy; as the last
# expression in the cell its value is what the notebook displays.
jnp.sin(x)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# We'll choose the parameters of our synthetic data.
# The fraction of points drawn from the linear ("foreground") model is 80%,
# i.e. each point has a 20% chance of being replaced by an outlier:
true_frac = 0.8

# The linear model has unit slope and zero intercept:
true_params = [1.0, 0.0]

# The outliers are drawn from a Gaussian with zero mean and unit variance:
true_outliers = [0.0, 1.0]

# For reproducibility, let's set the random number seed and generate the data.
# NOTE: the order of the np.random calls below determines the dataset, so it
# must not be changed.
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Those points are all drawn from the correct model so let's replace some of
# them with outliers. `m_bkg` is True for the points that become outliers
# (uniform draw above `true_frac`).
m_bkg = np.random.rand(len(x)) > true_frac
y[m_bkg] = true_outliers[0]
# The outlier scatter combines the intrinsic outlier variance with the
# per-point measurement variance.
y[m_bkg] += np.sqrt(true_outliers[1] + yerr[m_bkg] ** 2) * np.random.randn(m_bkg.sum())

# Then save the *true* line, evaluated on a fine grid slightly wider than the
# sampled x range. `np.vander(x0, 2)` has columns [x0, 1], matching the
# [slope, intercept] ordering of `true_params`.
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)

# Plot the data and the truth.
# Error bars only (no marker) for every point; markers are layered on top.
plt.errorbar(
    x,
    y,
    yerr=yerr,
    fmt=",k",
    ms=0,
    capsize=0,
    lw=1,
    zorder=999,
)

# Outliers: open (white-filled) squares with a black outline.
plt.scatter(
    x[m_bkg],
    y[m_bkg],
    marker="s",
    s=22,
    c="w",
    edgecolor="k",
    zorder=1000,
)

# Points still drawn from the line: filled black circles.
plt.scatter(
    x[np.logical_not(m_bkg)],
    y[np.logical_not(m_bkg)],
    marker="o",
    s=22,
    c="k",
    zorder=1000,
)

# The true line.
plt.plot(x0, y0, color="k", lw=1.5)

plt.xlabel("$x$")
plt.ylabel("$y$")
plt.ylim(-2.5, 2.5)
plt.xlim(-2.1, 2.1);

In [None]:
import jax
import jax.numpy as jnp

import numpyro

# Ask NumPyro to expose two host (CPU) devices. This matches `num_chains=2` in
# the MCMC runs in this notebook.
# NOTE(review): presumably this is placed before the remaining imports and any
# JAX computation so that it takes effect before the backend is initialised --
# confirm against the NumPyro documentation.
numpyro.set_host_device_count(2)
from numpyro import distributions as dist, infer


def linear_model(x, yerr, y=None):
    """NumPyro model for a straight line ``m * x + b`` with Gaussian noise.

    The line is sampled through an angle ``theta`` (uniform on
    ``(-pi/2, pi/2)``) and an offset ``b_perp`` (standard normal) instead of
    through the slope and intercept directly. The slope ``m = tan(theta)`` and
    intercept ``b = b_perp / cos(theta)`` are recorded as deterministic sites.

    Parameters
    ----------
    x
        Positions at which the line is evaluated.
    yerr
        Scale (standard deviation) of the Normal likelihood.
    y
        Values forwarded as ``obs=`` to the ``"obs"`` site; ``None`` by
        default.
    """
    half_pi = 0.5 * jnp.pi
    theta = numpyro.sample("theta", dist.Uniform(-half_pi, half_pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Deterministic sites keep the transformed slope and intercept alongside
    # the sampled parameters.
    slope = numpyro.deterministic("m", jnp.tan(theta))
    intercept = numpyro.deterministic("b", b_perp / jnp.cos(theta))

    predicted = slope * x + intercept
    numpyro.sample("obs", dist.Normal(predicted, yerr), obs=y)


# Set up MCMC for `linear_model` with the NUTS kernel: two chains, each with
# 2000 warm-up steps followed by 2000 samples.
sampler = infer.MCMC(
    infer.NUTS(linear_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
# `%time` is an IPython magic, so this line only runs inside IPython/Jupyter.
# The arguments after the PRNG key (`x, yerr, y=y`) match the signature of
# `linear_model`.
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y)

In [None]:
import corner
import arviz as az

# Convert the NumPyro run into an ArviZ InferenceData object, which both the
# corner plot and the summary table consume.
inf_data = az.from_numpyro(sampler)

# Key the true values by variable name rather than passing a positional list,
# so the plotted truths cannot silently get out of step with `var_names`.
# `true_params` is ordered [slope, intercept].
corner.corner(
    inf_data,
    var_names=["m", "b"],
    truths={"m": true_params[0], "b": true_params[1]},
)

# Last expression in the cell: the summary table is what the notebook displays.
az.summary(inf_data)

In [None]:
from numpyro_ext.distributions import MixtureGeneral


def linear_mixture_model(x, yerr, y=None):
    """NumPyro model mixing a straight-line foreground with a Gaussian background.

    Each observation is modelled as coming either from the line
    ``m * x + b`` (with scale ``yerr``) or from a background Normal with mean
    ``bg_mean`` and scale ``sqrt(bg_sigma**2 + yerr**2)``. The sampled weight
    ``Q`` is the mixture probability of the foreground component.

    Parameters
    ----------
    x
        Positions at which the line is evaluated.
    yerr
        Measurement scale, used by both mixture components.
    y
        Values forwarded as ``obs=`` to the ``"obs"`` site; ``None`` by
        default.
    """
    # Foreground: the line, sampled via an angle and an offset.
    half_pi = 0.5 * jnp.pi
    theta = numpyro.sample("theta", dist.Uniform(-half_pi, half_pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0.0, 1.0))
    slope = numpyro.deterministic("m", jnp.tan(theta))
    intercept = numpyro.deterministic("b", b_perp / jnp.cos(theta))
    foreground = dist.Normal(slope * x + intercept, yerr)

    # Background: a broad Gaussian whose width is added in quadrature to the
    # measurement scale.
    bg_mean = numpyro.sample("bg_mean", dist.Normal(0.0, 1.0))
    bg_sigma = numpyro.sample("bg_sigma", dist.HalfNormal(3.0))
    bg_scale = jnp.sqrt(bg_sigma**2 + yerr**2)
    background = dist.Normal(bg_mean, bg_scale)

    # Mixture: weights are ordered [foreground, background], matching the
    # order of the component list below.
    fg_weight = numpyro.sample("Q", dist.Uniform(0.0, 1.0))
    weights = jnp.array([fg_weight, 1.0 - fg_weight])
    mixing = dist.Categorical(probs=weights)
    components = [foreground, background]
    numpyro.sample("obs", MixtureGeneral(mixing, components), obs=y)


# Set up MCMC for `linear_mixture_model` with the NUTS kernel: two chains, each
# with 2000 warm-up steps followed by 2000 samples. This rebinds the
# module-level name `sampler`.
sampler = infer.MCMC(
    infer.NUTS(linear_mixture_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
# `%time` is an IPython magic, so this line only runs inside IPython/Jupyter.
# The arguments after the PRNG key (`x, yerr, y=y`) match the signature of
# `linear_mixture_model`.
%time sampler.run(jax.random.PRNGKey(10), x, yerr, y=y)

In [None]:
# Wrap the mixture-model run as ArviZ InferenceData (rebinds `inf_data`).
inf_data = az.from_numpyro(sampler)

# Corner plot of slope, intercept and foreground weight, with the values used
# to generate the synthetic data marked as truths (keyed by variable name).
corner.corner(
    inf_data,
    var_names=["m", "b", "Q"],
    truths=dict(m=true_params[0], b=true_params[1], Q=true_frac),
)

# Last expression in the cell: the summary table is what the notebook displays.
az.summary(inf_data)

In [None]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt

# Read the table stored in the second HDU (index 1) of the M67 FITS file.
with fits.open("data/m67.fits.gz") as f:
    data = f[1].data

# Keep only rows with a finite parallax and parallax error, and with a
# parallax strictly between 0.5 and 2.0.
mask = (
    np.isfinite(data["parallax"])
    & np.isfinite(data["parallax_error"])
    & (data["parallax"] > 0.5)
    & (data["parallax"] < 2.0)
)
data = data[mask]

# Histogram of the surviving parallaxes using 120 bins, drawn as an outline.
plt.hist(data["parallax"], bins=120, histtype="step")
plt.xlabel("parallax [mas]")
plt.ylabel("count");