# Bayesian inference in [JAX](https://github.com/google/jax)
Brett Morris

## Install

<div class="alert alert-block alert-warning">
    <strong>WARNING</strong>: It's possible that installing these dependencies in your usual working environment will break your other packages. Don't do that!
</div>


Please use `conda` or [`venv`](https://docs.python.org/3/library/venv.html) to create an isolated python environment for this tutorial.

For conda:

```bash
conda create -n jax-demo python=3.12
conda activate jax-demo
```

For venv:

```bash
python -m venv /path/to/new/virtual/environment
source /path/to/new/virtual/environment/bin/activate
```


#### Installing jax

Try the following on your laptop (CPU): 
```bash
python -m pip install --upgrade "jax[cpu]"
```
jax can run on GPUs and TPUs but requires specific builds for each architecture. Check out the [jax installation docs](https://github.com/google/jax#installation) for details.
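A quick way to confirm that the installation worked is to ask jax which backend and devices it found. This is a minimal sketch using `jax.devices()` and `jax.default_backend()`; the exact output depends on your machine. Run it in a separate Python session rather than at the top of the tutorial notebook, because querying devices initializes jax, and the `numpyro.set_host_device_count()` call used later only takes effect before that happens.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)        # installed jax version
print(jax.default_backend())  # e.g. 'cpu' for the CPU-only install
print(jax.devices())          # list of devices jax can see

# A tiny computation to confirm jax works end to end
print(jnp.arange(5).sum())
```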

#### Other dependencies

Other installations needed for this tutorial can be installed with: 

```bash
python -m pip install numpy scipy matplotlib numpyro arviz corner ipywidgets
```

## Why jax?

jax combines [just-in-time code compilation](https://docs.jax.dev/en/latest/jit-compilation.html), [automatic differentiation](https://docs.jax.dev/en/latest/automatic-differentiation.html), and [_accelerated linear algebra_](https://github.com/openxla/xla) behind a _numpy-like API_, so you can write blazing fast, differentiable models. Let's break that down: 

* Automatic differentiation allows you to compute gradients of your mathematical models without explicitly deriving gradients for each function. These gradients can be used in gradient-based inference techniques like gradient descent optimization, or Hamiltonian Monte Carlo.
* XLA (Accelerated Linear Algebra) is an optimizing compiler designed for machine learning. You write Python code, and jax just-in-time compiles it with XLA for your hardware (CPU, GPU, or TPU) at runtime.
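Here's a minimal sketch of both ideas on a one-dimensional Gaussian, assuming only `jax` is installed: `jax.grad` builds the derivative function automatically, we compare it with the derivative worked out by hand, and `jax.jit` compiles the gradient function without changing its result.

```python
import jax
import jax.numpy as jnp

def f(x):
    """Gaussian profile with unit amplitude, zero mean, and unit width."""
    return jnp.exp(-0.5 * x**2)

# jax.grad transforms f into a new function that returns df/dx
df_dx = jax.grad(f)

# Derivative worked out by hand, for comparison: d/dx exp(-x^2/2) = -x exp(-x^2/2)
def df_dx_analytic(x):
    return -x * jnp.exp(-0.5 * x**2)

x_eval = 1.3
print(df_dx(x_eval), df_dx_analytic(x_eval))

# jit-compile the gradient function: compiled on the first call, same numbers
fast_df_dx = jax.jit(df_dx)
print(fast_df_dx(x_eval))
```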


## Getting started with jax

First, let's do a bunch of imports:

In [None]:
# The bare necessities:
import numpy as np
import matplotlib.pyplot as plt

# import and configure numpyro before using jax: `set_host_device_count`
# (below) only takes effect if it runs before jax initializes its devices
import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as dist

# Set the number of cores to use on your machine for 
# parallel computing on your CPU:
cpu_cores = 4
numpyro.set_host_device_count(cpu_cores)

# From jax, we'll import the numpy module as `jnp`:
from jax import numpy as jnp

# `grad` computes gradients!
from jax import grad

# this module produces random numbers in jax:
from jax import random

# arviz is a statistical toolkit for analyzing 
# bayesian inference models and their posteriors
import arviz

# corner makes corner plots
from corner import corner

`jax` has been designed to mimic the methods and arguments you're already accustomed to in `numpy`. For many functions, you can simply use the `jax.numpy` module rather than ordinary `numpy` to handle array calculations in `jax`. 

In some instances, converting your numpy model implementations to jax implementations can be as simple as replacing `np` with `jnp` in your code. 🪄
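As a sketch of what that swap looks like, the hypothetical `gaussian_profile` function below takes the array module (`np` or `jnp`) as an argument and gives the same values either way, up to jax's default 32-bit precision. It also shows one difference to watch for: jax arrays are immutable, so in-place assignment such as `y[50] = 0.0` raises a `TypeError`, and you use the `.at[...]` syntax instead, which returns a modified copy.

```python
import numpy as np
import jax.numpy as jnp

def gaussian_profile(xp, x, amp, mu, sigma):
    """Gaussian profile written against an array module `xp` (numpy or jax.numpy)."""
    return amp * xp.exp(-0.5 * (x - mu)**2 / sigma**2)

x_demo = np.linspace(-5, 5, 101)
y_demo_np = gaussian_profile(np, x_demo, 2.0, -0.5, 0.6)
y_demo_jnp = gaussian_profile(jnp, jnp.asarray(x_demo), 2.0, -0.5, 0.6)

# Same values up to float32 precision (jax defaults to 32-bit floats)
print(np.allclose(y_demo_np, y_demo_jnp, atol=1e-5))

# jax arrays are immutable: `.at[...].set(...)` returns an updated copy
y_demo_edited = y_demo_jnp.at[50].set(0.0)
print(y_demo_jnp[50], y_demo_edited[50])  # the original array is unchanged
```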

For example, here's how we can create an array of linearly spaced values:

In [None]:
# Create a linearly spaced `Array` object, which behaves like a np.ndarray:
array = jnp.linspace(-5, 5, 20_000)

array[:10]

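Note that the above code created an `Array` object, which is not an ordinary numpy array. By default, jax uses 32-bit precision, so this array has data type `float32` (64-bit precision can be enabled with `jax.config.update("jax_enable_x64", True)` at the start of a session). `Array` objects have the usual built-in methods: 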

In [None]:
array.mean(), array.std()
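To check the type and precision yourself, or to hand a jax array to a library that expects an `np.ndarray`, a minimal sketch:

```python
import numpy as np
import jax.numpy as jnp

arr = jnp.linspace(-5, 5, 20_000)

print(type(arr))   # a jax Array, not an np.ndarray
print(arr.dtype)   # float32 unless 64-bit mode is enabled

# np.asarray returns the same values as an ordinary numpy array
arr_np = np.asarray(arr)
print(type(arr_np), arr_np.dtype)
```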

### Creating a synthetic dataset 
Now let's create some synthetic data which we'll fit using jax: 

In [None]:
np.random.seed(0)

x = array

# Set the parameters of the double-Gaussian 
# profile in our synthetic data
amp0 = 2
amp1 = 1.5
x0 = -0.5
x1 = 0.6
s0 = 0.6
s1 = 0.9
yerr = 0.3

y_first_term = (
    amp0 * np.exp(-0.5 * (x - x0)**2 / s0**2)
)

y_second_term = (
    amp1 * np.exp(-0.5 * (x - x1)**2 / s1**2)
)

y_noise = np.random.normal(scale=yerr, size=len(x))

y = y_first_term + y_second_term + y_noise

fig, ax = plt.subplots()
plt.plot(x, y, ',', color='silver')
plt.plot(x, y_first_term, label='first term')
plt.plot(x, y_second_term, label='second term')
plt.plot(x, y_first_term + y_second_term, label='first + second')
plt.legend()
ax.set(
    xlabel='x', 
    ylabel='y'
);

### Fitting with numpy/scipy

We could fit the observations $(x, y)$ with numpy and scipy like this: 

In [None]:
def model_numpy(p, x):
    """Numpy implementation of a double-Gaussian profile"""
    a0, x0, s0, a1, x1, s1 = p
    return (
        # gaussian 0
        a0 * np.exp(-0.5 * (x - x0)**2 / s0**2) + 

        # gaussian 1
        a1 * np.exp(-0.5 * (x - x1)**2 / s1**2)
    )

def chi2_numpy(p, x, y, yerr):
    """chi^2 function to minimize with scipy"""
    return np.sum((model_numpy(p, x) - y)**2 / yerr**2)

init_guess = np.float32([1.5, -0.2, 0.7, 1.5, 0.2, 0.7])

from scipy.optimize import minimize as scipy_minimize
bestp_numpy = scipy_minimize(chi2_numpy, init_guess, args=(x, y, yerr), method='BFGS').x

We didn't pass a gradient (the `jac` argument) to scipy's `minimize`, so its BFGS method approximates the gradient with finite differences. The best-fit solutions are: 

In [None]:
bestp_numpy   # best fit parameter solutions

Which look like this: 

In [None]:
plt.plot(x, y, ',', color='silver')
plt.plot(x, model_numpy(init_guess, x), 'purple', ls=':', label='init')
plt.plot(x, model_numpy(bestp_numpy, x), 'dodgerblue', lw=2, label='numpy')
plt.legend()
plt.gca().set(xlabel='x', ylabel='y');

The initial guess is shown above as the purple dotted curve, and the best-fit model using numpy/scipy is shown in blue. You can see that the best-fit model doesn't fit the observations particularly well. Now let's implement the same thing in jax. 

### Fitting with jax

Let's specify the model that we will fit to the data using the numpy module within jax. We'll also "decorate" it with the `jit` decorator, which will compile the function for us at runtime. 

In [None]:
# The just-in-time decorator:
from jax import jit

@jit
def model_jax(p, x):     
    """
    Jax implementation of the `model_numpy` function.
    
    The use of `jnp` in place of `np` is the only difference 
    from the numpy version.
    """
    a0, x0, s0, a1, x1, s1 = p
    return (
        # first gaussian
        a0 * jnp.exp(-0.5 * (x - x0)**2 / s0**2) + 
        
        # second gaussian
        a1 * jnp.exp(-0.5 * (x - x1)**2 / s1**2)
    )

@jit
def chi2_jax(p, x, y, yerr):
    """chi^2 function written for minimization with jax"""
    return jnp.sum((model_jax(p, x) - y)**2 / yerr**2)
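`jit` compiles a function the first time it sees a particular combination of input shapes and dtypes, then reuses the compiled version on later calls. The sketch below makes that visible with a hypothetical `gaussian` function and a `trace_count` list: plain Python side effects inside a jitted function run only while jax traces it, not on every call.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: Python code in a jitted function runs while jax
# traces it, so appending here counts how many times tracing happened.
trace_count = []

@jax.jit
def gaussian(x, amp, mu, sigma):
    trace_count.append(1)  # side effect: runs at trace time only
    return amp * jnp.exp(-0.5 * (x - mu)**2 / sigma**2)

x_demo = jnp.linspace(-5, 5, 1_000)

y1 = gaussian(x_demo, 2.0, -0.5, 0.6)       # first call: trace and compile
y2 = gaussian(x_demo, 1.5, 0.6, 0.9)        # same shapes/dtypes: reuse compiled code
y3 = gaussian(x_demo[:10], 2.0, -0.5, 0.6)  # new input shape: trace and compile again

print(len(trace_count))  # 2
```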

Now we import the `minimize` function from jax's own `scipy.optimize` module: 

In [None]:
# jax has its own scipy module which uses autodiffed gradients
from jax.scipy.optimize import minimize as jax_minimize

bestp_jax = jax_minimize(chi2_jax, init_guess, args=(x, y, yerr), method='bfgs').x

# print the best-fit parameters
bestp_jax

In the above cell, we used _gradient-based_ optimization with the [BFGS method](https://en.wikipedia.org/wiki/Broyden%E2%80%93Fletcher%E2%80%93Goldfarb%E2%80%93Shanno_algorithm). Note that we didn't have to specify the gradient of our model with respect to each free parameter; jax computed it for us with automatic differentiation!
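The gradient that jax's `minimize` relies on is also available directly through `jax.grad`. As a sketch on a hypothetical single-Gaussian dataset (not the double Gaussian above), that gradient can be passed as the `jac` argument of scipy's own `minimize`, so scipy's BFGS uses exact gradients rather than finite-difference estimates. The explicit conversions are there because scipy works with float64 numpy values while jax defaults to float32.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

# Hypothetical single-Gaussian dataset (not the tutorial's double Gaussian)
rng = np.random.default_rng(1)
true_p = np.array([2.0, 0.3, 0.8])  # amplitude, mean, stddev
yerr_demo = 0.1
x_grid = np.linspace(-3, 3, 200)
y_grid = (true_p[0] * np.exp(-0.5 * (x_grid - true_p[1])**2 / true_p[2]**2)
          + rng.normal(scale=yerr_demo, size=x_grid.size))
x_obs, y_obs = jnp.asarray(x_grid), jnp.asarray(y_grid)

@jax.jit
def chi2(p, x, y, yerr):
    model = p[0] * jnp.exp(-0.5 * (x - p[1])**2 / p[2]**2)
    return jnp.sum((model - y)**2 / yerr**2)

# Exact gradient of chi^2 with respect to p (argument 0), via autodiff
chi2_grad = jax.jit(jax.grad(chi2))

# Convert jax's float32 outputs into the float64 values scipy expects
fit_result = minimize(
    lambda p, *args: float(chi2(p, *args)),
    x0=np.array([1.5, 0.0, 1.0]),
    args=(x_obs, y_obs, yerr_demo),
    jac=lambda p, *args: np.asarray(chi2_grad(p, *args), dtype=np.float64),
    method='BFGS',
)
print(fit_result.x)
```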

Let's plot the best-fit model: 

In [None]:
plt.plot(x, y, ',', color='silver')
plt.plot(x, model_numpy(init_guess, x), 'purple', ls=':', label='init guess')
plt.plot(x, model_numpy(bestp_numpy, x), 'dodgerblue', lw=2, label='numpy')
plt.plot(x, model_jax(bestp_jax, x), 'r', lw=2, label='jax')
plt.legend();
plt.gca().set(xlabel='x', ylabel='y');

In the figure above, the purple dotted curve is the initial guess, the blue curve is the BFGS best fit via numpy/scipy, and the red curve is the best fit with jax.

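### Automatic differentiation

`grad` can also show how sensitive the model is to each parameter. The cell below computes the partial derivative of the model $f$ with respect to each parameter $\theta_i$, evaluated at the jax best-fit values, at every $x$. Since `grad` requires a scalar-valued function, we take the gradient at a single point and use `vmap` to vectorize it over the `x` array.
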
In [None]:
from jax import vmap

def model_jax_wrapper(a0, x0, s0, a1, x1, s1, x):
    return model_jax([a0, x0, s0, a1, x1, s1], x)

# `in_axes` tells `vmap` over which axes to 
# map input parameters. We only want to vmap 
# over the `x` variable, the others get None:
in_axes = tuple(len(bestp_jax) * [None]) + (0,)

fig, ax = plt.subplots()
for n, name in enumerate('a0 x0 s0 a1 x1 s1'.split()):
    # vmap vectorizes the operation over `x`:
    df_dtheta = vmap(
        
        # take the gradient with respect to each model parameter:
        grad(model_jax_wrapper, argnums=n), 
        in_axes=in_axes

    # evaluate (d f / d theta_i) at the best-fit values for 
    # each parameter theta_i:
    )(*bestp_jax, x)
    # solid lines for the first Gaussian's parameters, dotted for the second's
    ls = '-' if name.endswith('0') else ':'
    plt.plot(x, df_dtheta, label=f"d f / d {name}", ls=ls)
ax.set(
    xlabel='x',
    ylabel=r'$\partial f / \partial \theta_i$'
)
plt.legend()
plt.show()
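As an alternative to looping over parameters, `jax.jacfwd` computes all six derivatives at once as a Jacobian matrix with one row per $x$ value and one column per parameter. A minimal sketch, using a hypothetical `double_gaussian` function of the same form as `model_jax` and parameter values near the synthetic-data inputs, checked against the `vmap(grad(...))` approach used above:

```python
import jax
import jax.numpy as jnp

def double_gaussian(p, x):
    """Double-Gaussian profile, same form as `model_jax`."""
    a0, x0, s0, a1, x1, s1 = p
    return (a0 * jnp.exp(-0.5 * (x - x0)**2 / s0**2)
            + a1 * jnp.exp(-0.5 * (x - x1)**2 / s1**2))

# Hypothetical parameter values, close to the synthetic-data inputs
p_demo = jnp.array([2.0, -0.5, 0.6, 1.5, 0.6, 0.9])
x_grid = jnp.linspace(-5, 5, 500)

# Full Jacobian d f(x_k) / d p_i in one call
jac = jax.jacfwd(double_gaussian, argnums=0)(p_demo, x_grid)
print(jac.shape)  # (500, 6)

# Same result point by point: gradient w.r.t. p at a scalar x, vmapped over x
jac_vmap = jax.vmap(jax.grad(double_gaussian), in_axes=(None, 0))(p_demo, x_grid)
print(jnp.allclose(jac, jac_vmap, atol=1e-5))
```

Forward mode (`jacfwd`) is a natural fit here because there are only six inputs and many outputs.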

### Speed comparison

Now let's check if there's any speed difference between the two implementations:

In [None]:
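# Use a plain numpy array for the numpy model, so it runs no jax operations:
x_np = np.asarray(x)
# Warm up the jitted model so compilation isn't included in the timing:
model_jax(init_guess, x).block_until_ready()

print('Numpy only:')
time_numpy = %timeit -n 100 -o model_numpy(init_guess, x_np)
print('\n\njax:')
# jax dispatches work asynchronously; block_until_ready() waits for the result:
time_jax = %timeit -n 100 -o model_jax(init_guess, x).block_until_ready()

print(f'\n\njax model evaluation is {time_numpy.average / time_jax.average :.1f}x faster\n\n')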

So not only is the jax model evaluation faster, but the best-fit solution is closer to the true answer. Great work jax!
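In a plain Python script, where `%timeit` isn't available, the standard-library `timeit` module does the same job. The sketch below, with hypothetical single-Gaussian functions, shows the two precautions that matter when timing jax: call the jitted function once before timing so compilation isn't counted, and call `block_until_ready()` because jax dispatches computations asynchronously. The speedup you measure will depend on your hardware.

```python
import timeit
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def gaussian_jax(x, amp, mu, sigma):
    return amp * jnp.exp(-0.5 * (x - mu)**2 / sigma**2)

def gaussian_numpy(x, amp, mu, sigma):
    return amp * np.exp(-0.5 * (x - mu)**2 / sigma**2)

x_demo_np = np.linspace(-5, 5, 20_000)
x_demo_jax = jnp.asarray(x_demo_np)

# Warm-up call: compile once, outside the timed region
gaussian_jax(x_demo_jax, 2.0, -0.5, 0.6).block_until_ready()

n_loops = 100
t_numpy = timeit.timeit(
    lambda: gaussian_numpy(x_demo_np, 2.0, -0.5, 0.6), number=n_loops
) / n_loops
# block_until_ready() waits for the asynchronous computation to finish
t_jax = timeit.timeit(
    lambda: gaussian_jax(x_demo_jax, 2.0, -0.5, 0.6).block_until_ready(), number=n_loops
) / n_loops

print(f"numpy: {t_numpy * 1e6:.1f} us per call, jax: {t_jax * 1e6:.1f} us per call")
```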

### Posterior inference with jax/numpyro

Now let's infer posterior distributions for the parameters with _numpyro_, which uses jax gradients to run Hamiltonian Monte Carlo. We will define a _model_ which specifies _distributions_ that represent each parameter:

In [None]:
def numpyro_model():
    """
    Define a model to sample with the No U-Turn Sampler (NUTS) via numpyro.
    
    The two Gaussians are defined by an amplitude, mean, and standard deviation.
    
    Both Gaussians get identical uniform priors on their amplitudes, means, 
    and stddev's, so the model can't distinguish the two Gaussians if they 
    swap labels.
    """
    # Uniform priors for amplitudes
    a0, a1 = numpyro.sample(
        'amp', dist.Uniform(low=1, high=3),
        sample_shape=(2,)
    )

    # Uniform priors for the means
    x0, x1 = numpyro.sample('x', dist.Uniform(low=-4, high=4), sample_shape=(2,))

    # Uniform priors for the stddev's
    s0, s1 = numpyro.sample(
        'sigma', dist.Uniform(low=0.5, high=1.5), 
        sample_shape=(2,)
    )

    # save the model computed at each step
    model = numpyro.deterministic('model', model_jax([a0, x0, s0, a1, x1, s1], x))

    # Normally distributed likelihood
    numpyro.sample(
        "obs", dist.Normal(
            loc=model, 
            scale=yerr
        ), obs=y
    )

Now we use the No U-Turn Sampler for gradient-based inference:

In [None]:
# Random numbers in jax are generated like this:
rng_seed = 0
rng_keys = random.split(
    random.key(rng_seed), 
    int(cpu_cores)
)

# Define a sampler, using here the No U-Turn Sampler (NUTS)
# with the default (diagonal) mass matrix:
sampler = NUTS(
    numpyro_model, 
)

# Monte Carlo sampling for a number of steps and parallel chains: 
mcmc = MCMC(
    sampler, 
    num_warmup=1_000,
    num_samples=1_000, 
    num_chains=cpu_cores,
    progress_bar=False,
)

# Run the MCMC
mcmc.run(rng_keys)

In [None]:
# arviz converts a numpyro MCMC object to an `InferenceData` object based on xarray:
result = arviz.from_numpyro(mcmc)

# these are the inputs to the synthetic double-gaussian profile (blue lines)
truths = {'amp': [amp0, amp1], 'sigma': [s0, s1], 'x': [x0, x1]}

# make a corner plot
corner(
    result, 
    var_names=['amp', 'sigma', 'x'],
    quiet=True, 
    truths=truths
);

In [None]:
model_draws = result.posterior['model'].to_numpy()  # shape: (chains, steps, y data points)

n_chains, n_samples, n_data_points = model_draws.shape

plt.plot(x, y, ',', color='gray')
for i in range(n_chains):
    for j in np.random.randint(0, n_samples, size=10):
        plt.plot(x, model_draws[i, j], color=f'C{i}')

Let's check the results, especially the convergence metric called the [Gelman-Rubin statistic](https://ui.adsabs.harvard.edu/abs/1992StaSc...7..457G/abstract), $\hat{r}$.

In [None]:
arviz.summary(result, var_names=['~model'])

As chains approach convergence, $\hat{r}$ approaches unity from above. Good convergence is roughly $\hat{r} \lesssim 1.01$.
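To make $\hat{r}$ less of a black box, here is a sketch of a simplified, textbook version of the statistic for one scalar parameter: it compares the variance between chain means with the average variance within chains. This is an illustration only; arviz's default `rhat` is a rank-normalized, split-chain variant, so its values will differ somewhat from this hypothetical `gelman_rubin` function.

```python
import numpy as np

def gelman_rubin(chains):
    """Simplified R-hat (no chain splitting or rank normalization).

    chains: array of shape (n_chains, n_draws) for one scalar parameter.
    """
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    # Between-chain variance B and mean within-chain variance W
    B = n_draws * chains.mean(axis=1).var(ddof=1)
    W = chains.var(axis=1, ddof=1).mean()
    # Pooled estimate of the posterior variance
    var_hat = (n_draws - 1) / n_draws * W + B / n_draws
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(42)
mixed = rng.normal(size=(4, 1000))                        # all chains agree
stuck = mixed + np.array([0.0, 0.0, 3.0, 3.0])[:, None]   # two chains in another mode

r_mixed = gelman_rubin(mixed)
r_stuck = gelman_rubin(stuck)
print(r_mixed, r_stuck)
```

The `stuck` case mimics the label switching seen above: chains sitting in different modes inflate the between-chain variance, which pushes $\hat{r}$ well above 1.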

The bimodal posteriors look like poor convergence. We can see in the plots above that the bimodality stems from the first and second Gaussian terms switching indices: $x_0$ swaps with $x_1$, and so on. 

#### Reparameterization

Below, we avoid this bimodality in the posteriors by constraining $x_0 < x_1$, so no swapping can happen. We enforce this constraint by reparameterizing the sampling parameters. 

We follow the triangular sampling approach in [Turk (1990)](https://doi.org/10.1016/B978-0-08-050753-8.50015-2) and [Kipping (2013)](https://ui.adsabs.harvard.edu/abs/2013MNRAS.435.2152K/abstract). Rather than sampling uniformly in $x_0, x_1$, we sample uniformly from two new parameters $q_0 \sim \mathcal{U}(0, 1)$ and $q_1 \sim \mathcal{U}(0, 1)$ (named `q1` and `q2`, respectively, in the code below). The transformations below give us constrained samples with $x_0 < x_1$:
\begin{eqnarray}
x_0 &=& \Delta ~ \sqrt{q_0} q_1 + x_{\rm min},\\
\delta x &=& \Delta \left(1 - \sqrt{q_0}\right),~~ {\rm and}\\
x_1 &=& x_0 + \delta x
\end{eqnarray}
where the sampling bounds in $x$-space are $(x_{\rm min}, x_{\rm max})$, and $\Delta = x_{\rm max} -  x_{\rm min}$ is the full range.
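Substituting the first two lines into the third shows why the constraint holds:
\begin{eqnarray}
x_1 &=& \Delta ~ \sqrt{q_0} q_1 + x_{\rm min} + \Delta \left(1 - \sqrt{q_0}\right)\\
&=& x_{\rm min} + \Delta \left[1 - \sqrt{q_0} \left(1 - q_1\right)\right].
\end{eqnarray}
Because $0 \le \sqrt{q_0} \le 1$ and $0 \le q_1 \le 1$, we get $x_{\rm min} \le x_0$, $x_1 \le x_{\rm max}$, and $\delta x \ge 0$, so $x_0 \le x_1$. These are the standard formulas for drawing points uniformly inside a triangle, here the triangle with vertices $(x_{\rm min}, x_{\rm min})$, $(x_{\rm min}, x_{\rm max})$, and $(x_{\rm max}, x_{\rm max})$ in the $(x_0, x_1)$ plane.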

Here's what that looks like in an example:

In [None]:
q1 = np.random.uniform(0, 1, size=1000)
q2 = np.random.uniform(0, 1, size=1000)

x_min = -4
x_max = 4
Delta = x_max - x_min
demo_x0 = Delta * (np.sqrt(q1) * q2) + x_min
demo_delta_x = Delta * (1 - np.sqrt(q1))

plt.scatter(demo_x0, demo_x0 + demo_delta_x)
plt.gca().set(xlabel='x0', ylabel='x1')

Now let's use this reparameterization in our numpyro model:

In [None]:
def numpyro_model_reparam():
    """
    Define a model to sample with the No U-Turn Sampler (NUTS) via numpyro.
    
    The two Gaussians are defined by an amplitude, mean, and standard deviation.
    
    To find unique solutions for the two Gaussians, we reparameterize the 
    means so that x0 < x1, which prevents the two Gaussians from swapping 
    labels. The amplitudes and stddev's keep identical uniform priors.
    """
    # Uniform priors for amplitudes
    a0, a1 = numpyro.sample(
        'amp', dist.Uniform(low=1, high=3), 
        sample_shape=(2,)
    )
    
    # Uniform priors on q1 and q2, which map onto ordered means x0 < x1
    q1 = numpyro.sample('q1', dist.Uniform(low=0, high=1))
    q2 = numpyro.sample('q2', dist.Uniform(low=0, high=1))
    x_min = -4
    x_max = 4
    
    # reparameterize in the style of Turk (1990), or 
    # Kipping (2013) Equations 13 & 14
    Delta = x_max - x_min
    x0 = numpyro.deterministic(
        'x0',
        Delta * (jnp.sqrt(q1) * q2) + x_min
    )
    delta_x = numpyro.deterministic(
        'delta_x',
        Delta * (1 - jnp.sqrt(q1))
    )
    x1 = numpyro.deterministic('x1', x0 + delta_x)
    
    # Uniform priors for the stddev's
    s0, s1 = numpyro.sample(
        'sigma', dist.Uniform(low=0.5, high=1.5), 
        sample_shape=(2,)
    )

    model = numpyro.deterministic('model', model_jax([a0, x0, s0, a1, x1, s1], x))
    
    # Normally distributed likelihood
    numpyro.sample(
        "obs", dist.Normal(
            loc=model, 
            scale=yerr
        ), obs=y
    )

The above cell defines the model. Now the cell below defines how to sample the model, and runs the sampler.

In [None]:
# Random numbers in jax are generated like this:
rng_seed = 0
rng_keys = random.split(
    random.PRNGKey(rng_seed), 
    cpu_cores
)

# Define a sampler, using here the No U-Turn Sampler (NUTS)
# with a dense mass matrix:
sampler = NUTS(
    numpyro_model_reparam, 
    dense_mass=True
)

# Monte Carlo sampling for a number of steps and parallel chains: 
mcmc = MCMC(
    sampler, 
    num_warmup=1_000, 
    num_samples=5_000, 
    num_chains=cpu_cores,
    progress_bar=False,
)

# Run the MCMC
mcmc.run(rng_keys)

Wow, that was fast! Now let's visualize the posteriors using `arviz` and `corner`:

In [None]:
# arviz converts a numpyro MCMC object to an `InferenceData` object based on xarray:
result = arviz.from_numpyro(mcmc)

# these are the inputs to the synthetic double-gaussian profile (blue lines)
truths = {'amp': [amp0, amp1], 'sigma': [s0, s1], 'x0': x0, 'x1': x1}

# make a corner plot
corner(
    result, 
    var_names=['amp', 'sigma', 'x0', 'x1'],
    quiet=True, 
    truths=truths
);

Note how all posterior distributions contain the "true" value and the chains have converged. 

Let's see how the samples look in the data space:

In [None]:
model_draws = result.posterior['model'].to_numpy()  # shape: (chains, steps, y data points)

n_chains, n_samples, n_data_points = model_draws.shape

plt.plot(x, y, ',', color='gray')
for i in range(n_chains):
    for j in np.random.randint(0, n_samples, size=10):
        plt.plot(x, model_draws[i, j], color=f'C{i}')

Let's check convergence:

In [None]:
arviz.summary(result, var_names=['~model'])

We've accurately and robustly inferred the six parameters, in no time at all!