# LMU Code Coffee: [JAX](https://github.com/google/jax)
Brett Morris

### Install

Installation of jax is sometimes [nontrivial](https://github.com/google/jax#installation), but if you're feeling lucky, you can try the following on your laptop (CPU): 
```bash
pip install --upgrade "jax[cpu]"
```
jax can run on GPUs and TPUs but requires specific builds for each system. See the link above for details.

### Why jax?

It combines [_autodiff_](https://github.com/hips/autograd) and [_accelerated linear algebra_](https://www.tensorflow.org/xla) behind a _numpy-like API_, with a [_just-in-time compiled backend_](https://github.com/google/jax#compilation-with-jit), to compute fast, differentiable models. Let's break that down: 

* The automatic differentiation allows you to compute gradients of your mathematical models without explicitly deriving gradients for each function. These gradients can be used for gradient-based inference techniques like Hamiltonian Monte Carlo.
* The accelerated linear algebra (XLA) package is an optimizing compiler designed for machine learning. You write Python code and it gets just-in-time compiled the first time you call it, so later calls reuse the compiled version.
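
As a minimal sketch of those two ideas together (the function `cubic_sum` and its inputs are made up for illustration and are not part of this tutorial's model), `jax.grad` builds the derivative of a Python function and `jax.jit` compiles it with XLA:

```python
import jax
import jax.numpy as jnp

def cubic_sum(x):
    """Toy scalar function: f(x) = sum(x^3), so df/dx_i = 3 x_i^2."""
    return jnp.sum(x ** 3)

# jax.grad returns a new function that evaluates the gradient of cubic_sum;
# wrapping it in jax.jit compiles that gradient function on its first call.
grad_cubic_sum = jax.jit(jax.grad(cubic_sum))

x_demo = jnp.array([1.0, 2.0, 3.0])
gradient = grad_cubic_sum(x_demo)

# Compare the autodiff result with the derivative worked out by hand
print(gradient, 3 * x_demo ** 2)
```

No derivative was written down anywhere in the code; the hand-derived `3 * x**2` is only there as a check.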


### Getting started with jax
It looks like ordinary Python code, but you generally use jax's numpy API rather than ordinary numpy to handle array calculations.  

For example, here's how we can create an array of linearly spaced values:

In [None]:
# Import numpyro first, though we use it last: `set_host_device_count`
# below only takes effect if it runs before jax initializes its devices
import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as dist

# Set the number of cores on your machine for parallelism:
cpu_cores = 4
numpyro.set_host_device_count(cpu_cores)

# From jax, we'll import the numpy module as `jnp`
from jax import numpy as jnp

# Create a linearly spaced `DeviceArray` object, which behaves like a np.ndarray:
dev_arr = jnp.linspace(-5, 5, 1_000)

dev_arr[:10]

Note that the above code created a `DeviceArray` object (called `jax.Array` in recent jax releases), which is not an ordinary numpy array. By default jax works in 32-bit precision, so its data type is `float32`. `DeviceArray` objects have the usual built-in methods: 

In [None]:
dev_arr.mean(), dev_arr.std()
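
If you need double precision, jax can be switched to 64-bit mode. The sketch below assumes it runs at the very start of a fresh session, before any other jax arrays have been created; the name `arr64` is only for illustration:

```python
import jax

# Enable 64-bit mode; do this at startup, before creating any arrays
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp

# The same call as above now yields a double-precision array
arr64 = jnp.linspace(-5, 5, 1_000)
print(arr64.dtype)
```

The rest of this notebook keeps the default 32-bit setting.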

### Creating a synthetic dataset 
Now let's create some synthetic data which we'll fit using jax: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

x = np.array(dev_arr.copy())

# Set the parameters of the double-Gaussian 
# profile in our synthetic data
amp0 = 5
amp1 = 10
x0 = 0.5
x1 = -0.2
s0 = 1
s1 = 0.3
yerr = 0.4

y = (
    amp0 * np.exp(-0.5 * (x - x0)**2 / s0**2) + 
    amp1 * np.exp(-0.5 * (x - x1)**2 / s1**2) + 
    np.random.normal(scale=yerr, size=(len(x)))
)

plt.plot(x, y)
plt.gca().set(xlabel='x', ylabel='y');

### Fitting with numpy/scipy

We could fit the observations $(x, y)$ with numpy and scipy like this: 

In [None]:
# From scipy, we'll import the Powell minimizer
from scipy.optimize import fmin_powell

def model_numpy(p, x):
    """Numpy implementation of a double-Gaussian profile"""
    a0, x0, s0, a1, x1, s1 = p
    return (
        a0 * np.exp(-0.5 * (x - x0)**2 / s0**2) + 
        a1 * np.exp(-0.5 * (x - x1)**2 / s1**2)
    )

def chi2_numpy(p, x, y, yerr):
    """chi^2 function to minimize, using numpy/scipy"""
    return np.sum((model_numpy(p, x) - y)**2 / yerr**2)

init_guess = np.array([5, 0, 2, 10, 0, 0.8])

bestp_numpy = fmin_powell(chi2_numpy, init_guess, disp=0, args=(x, y, yerr))

The `fmin_powell` function does optimization _without_ computing gradients. The best fit solutions are: 

In [None]:
bestp_numpy   # best fit parameter solutions

Which look like this: 

In [None]:
plt.plot(x, y)
plt.plot(x, model_numpy(init_guess, x), 'b')
plt.plot(x, model_numpy(bestp_numpy, x), 'r', ls='--', lw=3)
plt.gca().set(xlabel='x', ylabel='y');

The initial guess is shown above in blue, and the best-fit model using numpy/scipy is shown in red. You can see that the best-fit model doesn't fit the observations particularly well. Now let's implement the same thing in jax. 

### Fitting with jax

Let's specify the model that we will fit to the data using the numpy module within jax. We'll also "decorate" it with the `jit` decorator, which will compile the function for us at runtime. 

In [None]:
# Get the just-in-time decorator
from jax import jit

@jit
def model_jax(p, x):     
    """
    Jax implementation of the `model_numpy` function.
    
    The use of `jnp` in place of `np` is the only difference 
    from the numpy version.
    """
    a0, x0, s0, a1, x1, s1 = p
    return (
        a0 * jnp.exp(-0.5 * (x - x0)**2 / s0**2) + 
        a1 * jnp.exp(-0.5 * (x - x1)**2 / s1**2)
    )

@jit
def chi2_jax(p, x, y, yerr):
    """chi^2 function written for minimization with jax"""
    return jnp.sum((model_jax(p, x) - y)**2 / yerr**2)
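
To make "compile at runtime" a bit more concrete, here is a small sketch (the function `sum_of_squares` and the list `trace_log` are hypothetical helpers, not part of the fit). Python side effects inside a jitted function only run while jax traces it, so they reveal when compilation is triggered: once per new input shape or dtype, not once per call.

```python
from jax import jit
from jax import numpy as jnp

trace_log = []  # records the input shape each time jax traces the function

@jit
def sum_of_squares(x):
    # This append is a Python side effect: it runs only during tracing,
    # not when the cached compiled version is reused.
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)

sum_of_squares(jnp.ones(3))   # first call: traced and compiled
sum_of_squares(jnp.ones(3))   # same shape and dtype: compiled version reused
sum_of_squares(jnp.ones(4))   # new shape: traced and compiled again

print(trace_log)
```

The log ends up with two entries rather than three, because the second call reused the version compiled for the first.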

Now we import the `minimize` function from the `scipy.optimize` API within jax: 

In [None]:
# Jax has its own scipy module which uses autodiffed gradients
from jax.scipy.optimize import minimize

bestp_jax = minimize(chi2_jax, init_guess, args=(x, y, yerr), method='bfgs')

# print the best-fit parameters
bestp_jax.x

In the above cell, we have used _gradient-based_ optimization with the [BFGS method](https://en.wikipedia.org/wiki/Broyden%E2%80%93Fletcher%E2%80%93Goldfarb%E2%80%93Shanno_algorithm). Note that we didn't have to specify the gradient of our model with respect to each free parameter; autodiff did that for us!
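
If you are curious what that gradient looks like, you can ask jax for it directly. The sketch below is self-contained and uses its own noiseless toy data; `double_gaussian`, `chi2`, and the `_demo` variables are illustrative stand-ins that mirror `model_jax` and `chi2_jax` above, not the objects used in the fit:

```python
import jax
from jax import numpy as jnp

def double_gaussian(p, x):
    """Same double-Gaussian profile as the model above."""
    a0, x0, s0, a1, x1, s1 = p
    return (
        a0 * jnp.exp(-0.5 * (x - x0)**2 / s0**2) +
        a1 * jnp.exp(-0.5 * (x - x1)**2 / s1**2)
    )

def chi2(p, x, y, yerr):
    """chi^2 of the double-Gaussian model given data (x, y, yerr)."""
    return jnp.sum((double_gaussian(p, x) - y)**2 / yerr**2)

# Noiseless toy data generated from known parameters (an assumption for this demo)
x_demo = jnp.linspace(-5, 5, 200)
p_true = jnp.array([5.0, 0.5, 1.0, 10.0, -0.2, 0.3])
y_demo = double_gaussian(p_true, x_demo)
yerr_demo = 0.4

# Evaluate chi^2 and its gradient with respect to all six parameters at once
p_guess = jnp.array([5.0, 0.0, 2.0, 10.0, 0.0, 0.8])
chi2_value, chi2_grad = jax.value_and_grad(chi2)(p_guess, x_demo, y_demo, yerr_demo)

print(chi2_value)
print(chi2_grad)   # one partial derivative per free parameter
```

`jax.value_and_grad` differentiates with respect to the first argument by default, which here is the parameter vector. A gradient-based optimizer like BFGS relies on this kind of information at each step.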

Let's plot the best-fit model: 

In [None]:
plt.plot(x, y, 'k,')
plt.plot(x, model_numpy(init_guess, x), 'b', label='init guess')
plt.plot(x, model_numpy(bestp_numpy, x), 'm', ls=':', lw=1.5, label='numpy')
plt.plot(x, model_jax(bestp_jax.x, x), 'r', ls='--', lw=3, label='jax')
plt.legend();

In the figure above, the blue curve is the initial guess, the magenta dotted curve is the best fit with Powell's method via numpy/scipy, and the red dashed curve is the best fit with jax. Finally, a good fit!

### Speed comparison

Now let's check if there's any speed difference between the two implementations:

In [None]:
print('Numpy only:')
time_numpy = %timeit -n 100 -o model_numpy(init_guess, x)
print('jax:')
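# jax dispatches work asynchronously, so wait for the result before stopping the clock
time_jax = %timeit -n 100 -o model_jax(init_guess, x).block_until_ready()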

print(f'\n\njax model evaluation is {time_numpy.average / time_jax.average :.1f}x faster\n\n')

So not only is the jax model evaluation faster, but the best-fit solution is also closer to the true answer. Great work jax!

### Posterior inference with jax/numpyro

Now let's infer posterior distributions for the parameters with a more sophisticated inference method, using _numpyro_. We will define a _model_ which specifies the prior _distributions_ for each parameter:

In [None]:
def numpyro_model():
    """
    Define a model to sample with the No U-Turn Sampler (NUTS) via numpyro.
    
    The two Gaussians are defined by an amplitude, mean, and standard deviation.
    
    To find unique solutions for the two Gaussians, we put non-overlapping bounded 
    priors on the two amplitudes, but vary the means and stddev's with identical 
    uniform priors. 
    """
    # Define priors for non-overlapping Gaussian amplitudes:
    a0 = numpyro.sample('amp0', dist.Uniform(low=0, high=8))
    a1 = numpyro.sample('amp1', dist.Uniform(low=8, high=30))
    
    # Uniform priors for means
    x0, x1 = numpyro.sample(
        'center', dist.Uniform(low=-1, high=1), 
        sample_shape=(2,)
    )
    
    # Uniform priors for the stddev's
    s0, s1 = numpyro.sample(
        'sigma', dist.Uniform(low=0, high=3), 
        sample_shape=(2,)
    )
    
    # Normally distributed likelihood
    numpyro.sample(
        "obs", dist.Normal(
            loc=model_jax([a0, x0, s0, a1, x1, s1], x), 
            scale=yerr
        ), obs=y
    )

The cell above defines the model. The cell below defines how to sample it and runs the sampler.

In [None]:
# Random numbers in jax are generated like this:
from jax.random import PRNGKey, split

rng_seed = 42
rng_keys = split(
    PRNGKey(rng_seed), 
    cpu_cores
)

# Define a sampler, using here the No U-Turn Sampler (NUTS)
# with a dense mass matrix:
sampler = NUTS(
    numpyro_model, 
    dense_mass=True
)

# Monte Carlo sampling for a number of steps and parallel chains: 
mcmc = MCMC(
    sampler, 
    num_warmup=1_000, 
    num_samples=5_000, 
    num_chains=cpu_cores
)

# Run the MCMC
mcmc.run(rng_keys)
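
If you are used to `np.random.seed`, the key handling above may look unusual. jax has no global random state: every random function takes an explicit key, and you `split` a key to obtain independent streams, here one per chain. A minimal sketch (the variable names are illustrative):

```python
from jax import numpy as jnp
from jax.random import PRNGKey, split, normal

key = PRNGKey(42)

# The same key always produces the same numbers
draw_a = normal(key, shape=(3,))
draw_b = normal(key, shape=(3,))
print(jnp.array_equal(draw_a, draw_b))

# To get independent streams (e.g. one per chain), split the key
chain_keys = split(key, 4)
chain_draws = jnp.stack([normal(k, shape=(3,)) for k in chain_keys])
print(chain_draws)
```

Each row of `chain_draws` comes from a different key, so the rows differ from one another while the whole result stays reproducible from the single seed.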

Wow, that was fast! Now let's visualize the posteriors using `arviz` and `corner`:

In [None]:
# these packages will aid in visualization:
import arviz
from corner import corner

# arviz converts a numpyro MCMC object to an `InferenceData` object based on xarray:
result = arviz.from_numpyro(mcmc)

# these are the inputs to the synthetic double-Gaussian profile (blue lines)
truths = [amp0, amp1, x0, x1, s0, s1]

# make a corner plot
corner(
    result, 
    quiet=True, 
    truths=truths
);

Note how each posterior distribution contains the "true" value from which we generated the dataset. We've accurately inferred the six parameters, in no time at all!