In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Introduction

JAX has a `jit` function that allows us to just-in-time compile functions
written using JAX's NumPy and SciPy-wrapped functions.
JIT stands for "just-in-time" compilation,
which stands in contrast to AOT (ahead-of-time) compilation.
Using `jit` usually gives speed-ups on repeated calls,
because the function is compiled on its first call and the compiled version is reused afterwards.

In this notebook, we are going to explore the gains that we expect to get
by using JAX's just-in-time compilation function `jit`.
Because JIT compilation is usually simply applied _on top of_ existing functions,
we'll explore it primarily through examples rather than exercises.

## JIT example from the JAX docs

Coming up with an example where JIT compilation could be useful is quite a challenge,
so let's start off with an example from the JAX docs.

The function in question is the SELU function,
which is an activation function applied elementwise
to the outputs of a neural network layer.

In [None]:
import jax.numpy as np


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

Timing the function _without_ JIT compilation:

In [None]:
from time import time

from jax import random

key = random.PRNGKey(44)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

Now, let's try JIT-compiling the function.

In [None]:
from jax import jit

selu_jit = jit(selu)

%timeit selu_jit(x).block_until_ready()

As we can see, the JIT-compiled function is about 3X faster than the non-JIT-compiled function
(the exact ratio will vary with hardware).
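One detail worth keeping in mind when reading timings like these: `jit` compiles the function the first time it is called with a given input shape and dtype, and reuses the compiled version afterwards. The sketch below is only an illustration under a few assumptions: it uses `jnp` as the alias for `jax.numpy`, and `perf_counter` as a rough stand-in for `%timeit`. It times the first and second calls separately; the absolute numbers depend on your hardware.

```python
from time import perf_counter

import jax.numpy as jnp
from jax import jit, random


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jit(selu)
x = random.normal(random.PRNGKey(44), (1_000_000,))

# First call: includes tracing and compilation.
start = perf_counter()
selu_jit(x).block_until_ready()
first_call = perf_counter() - start

# Second call: reuses the compiled function for this shape and dtype.
start = perf_counter()
selu_jit(x).block_until_ready()
second_call = perf_counter() - start

print(f"first call: {first_call:.4f} s, second call: {second_call:.4f} s")
```

`%timeit` runs the function many times, so the one-off compilation cost is mostly averaged away in the measurement above it.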

More importantly, any function that you write using JAX-wrapped NumPy,
JAX-wrapped SciPy,
and JAX's own `lax` submodule
can be JIT-compiled to gain speed-ups.
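As a small illustration of mixing the three, the sketch below builds a discrete cumulative distribution from a vector of logits. The function name `cumulative_probs` and the example values are made up for this illustration; `logsumexp` comes from `jax.scipy.special` and `cumsum` from `jax.lax`.

```python
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.special import logsumexp


def cumulative_probs(logits):
    """Softmax followed by a running sum, i.e. a discrete CDF."""
    probs = jnp.exp(logits - logsumexp(logits))  # jax.numpy + jax.scipy
    return lax.cumsum(probs)  # jax.lax


cumulative_probs_jit = jit(cumulative_probs)

logits = jnp.array([0.5, -1.0, 2.0, 0.0])
cdf = cumulative_probs_jit(logits)
print(cdf)
```

The JIT-compiled and plain versions should agree up to floating-point error; only the execution path differs.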

## Re-examining the Gaussian random walk

Let's revisit the Gaussian random walk that we implemented
as a case study in what happens when we use JAX's idioms to write our code.

### Pure Python version of the Gaussian random walk

In [None]:
import numpy as onp


def gaussian_random_walk_python(num_realizations, num_timesteps):
    rws = []
    for i in range(num_realizations):
        rw = []
        prev_draw = 0
        for t in range(num_timesteps):
            prev_draw = onp.random.normal(loc=prev_draw)
            rw.append(prev_draw)
        rws.append(rw)
    return rws

In [None]:
from time import time

N_REALIZATIONS = 1_000
N_TIMESTEPS = 10_000
start = time()
trajectories_python = gaussian_random_walk_python(N_REALIZATIONS, N_TIMESTEPS)
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

for trajectory in trajectories_python[:20]:
    plt.plot(trajectory)
sns.despine()

### JAX implementation without JIT

Now, let's take a look at the JAX-based implementation.

In [None]:
from jax import lax, random

key = random.PRNGKey(44)
keys = random.split(key, N_TIMESTEPS)


def new_draw(prev_val, key):
    """lax.scannable function for drawing a new draw from the GRW."""
    new = prev_val + random.normal(key)
    return new, prev_val


def grw_draw(key, num_steps):
    """One GRW draw over a bunch of steps."""
    keys = random.split(key, num_steps)
    final, draws = lax.scan(new_draw, 0.0, keys)
    return final, draws

In [None]:
from functools import partial

from jax import vmap


def gaussian_random_walk_jax(num_realizations, num_timesteps):
    """Multiple GRW draws."""
    # Note: `key` is read from the enclosing notebook scope, not passed in.
    keys = random.split(key, num_realizations)
    grw_k_steps = partial(grw_draw, num_steps=num_timesteps)
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories

In [None]:
from jax import random

key = random.PRNGKey(42)
start = time()
final_jax, trajectories_jax = gaussian_random_walk_jax(
    N_REALIZATIONS, N_TIMESTEPS
)
trajectories_jax.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
%timeit gaussian_random_walk_jax(N_REALIZATIONS, N_TIMESTEPS)[1].block_until_ready()

In [None]:
for trajectory in trajectories_jax[:20]:
    plt.plot(trajectory)
sns.despine()

### JAX implementation _with_ JIT compilation

Now we're going to JIT-compile our Gaussian random walk function and see how long the program takes to run.

In [None]:
from jax import jit


def gaussian_random_walk_jit(num_realizations, num_timesteps):
    keys = random.split(key, num_realizations)
    # JIT-compile the single-trajectory function, then vectorize it over the keys.
    grw_k_steps = jit(partial(grw_draw, num_steps=num_timesteps))
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories

In [None]:
start = time()
final_jit, trajectories_jit = gaussian_random_walk_jit(
    N_REALIZATIONS, N_TIMESTEPS
)
trajectories_jit.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
%timeit gaussian_random_walk_jit(N_REALIZATIONS, N_TIMESTEPS)[1].block_until_ready()

In [None]:
for trajectory in trajectories_jit[:20]:
    plt.plot(trajectory)
sns.despine()

It may appear that JIT-compilation doesn't do much here,
but we can assure you that there's a great explanation for this phenomenon.

Within the Gaussian random walk, we used `lax.scan`,
which already gives us a compiled looping operation.
The docs spell it out in jargon:

> Also unlike that Python version, scan is a JAX primitive and is lowered to a single XLA While HLO. 
> That makes it useful for reducing compilation times for jit-compiled functions, 
> since native Python loop constructs in an @jit function are unrolled, 
> leading to large XLA computations.

If we were to use a for-loop instead of `lax.scan`,
then we would be missing out on the performance gain.
So when we add in JIT-compilation _on top of_ using `lax.scan`,
the added gain is smaller than it would be if we hadn't used `lax.scan`.
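To make the "unrolled" point in the quote concrete, here is a minimal sketch that counts how many operations JAX records when tracing a running sum written two ways. The function names `cumsum_loop` and `cumsum_scan` are hypothetical, and the running sum is a simplified stand-in for the random walk. `jax.make_jaxpr` returns the traced program, and the length of its equation list is used here as a rough proxy for program size.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumsum_loop(xs):
    """Running sum written with a native Python for-loop."""
    total = 0.0
    outputs = []
    for i in range(xs.shape[0]):
        total = total + xs[i]
        outputs.append(total)
    return jnp.stack(outputs)


def cumsum_scan(xs):
    """The same running sum written with lax.scan."""

    def step(carry, x):
        new = carry + x
        return new, new

    _, outputs = lax.scan(step, 0.0, xs)
    return outputs


xs = jnp.arange(50.0)

# Count the operations that tracing records for each version.
n_eqns_loop = len(jax.make_jaxpr(cumsum_loop)(xs).jaxpr.eqns)
n_eqns_scan = len(jax.make_jaxpr(cumsum_scan)(xs).jaxpr.eqns)
print(f"for-loop: {n_eqns_loop} equations, scan: {n_eqns_scan} equations")
```

The for-loop version records operations for every iteration, so its traced program grows with the length of `xs`, whereas the scan version records the loop body once.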

In both cases, the runtime is essentially constant.

JIT-compilation gave us about a 1-2X speedup over non-JIT-compiled code,
and was at least 20X faster than the pure Python version.
That shouldn't surprise you one bit :).
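One variation that the cells above do not try is applying `jit` to the whole multi-realization function rather than only to the inner `grw_draw`. The sketch below is an illustration under stated assumptions, not a benchmark: `grw_jit_whole` is a hypothetical name, the PRNG key is passed in explicitly instead of being read from the notebook scope, and the two size arguments are marked as static with `static_argnums` because they determine array shapes. Small sizes are used so that it runs quickly.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random, vmap


def new_draw(prev_val, key):
    """lax.scannable function for drawing a new draw from the GRW."""
    new = prev_val + random.normal(key)
    return new, prev_val


def grw_draw(key, num_steps):
    """One GRW draw over a bunch of steps."""
    keys = random.split(key, num_steps)
    return lax.scan(new_draw, 0.0, keys)


# Hypothetical variant: jit the outer function, with sizes treated as static.
@partial(jit, static_argnums=(1, 2))
def grw_jit_whole(key, num_realizations, num_timesteps):
    keys = random.split(key, num_realizations)
    grw_k_steps = partial(grw_draw, num_steps=num_timesteps)
    return vmap(grw_k_steps)(keys)


final, trajectories = grw_jit_whole(random.PRNGKey(42), 10, 100)
print(final.shape, trajectories.shape)
```

Calling `grw_jit_whole` again with the same sizes reuses the compiled program; changing either size triggers a new compilation, since static arguments are part of the cache key.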

## A few pointers on syntax

Firstly, if we subscribe to the Zen of Python's notion that "flat is better than nested",
then by following the idioms listed here (closures/partials, `vmap` and `lax.scan`),
we'll likely only ever go one closure deep into our programs.
Notice how we basically never wrote any for-loops in our array code;
they were handled elegantly by the looping constructs `vmap` and `lax.scan`. 

Secondly, using `jit`, we get further optimizations on our code for free.
A prerequisite of `jit` is that _every_ function called inside the function being `jit`-ed
must be written in a "pure functional" style,
i.e. with no side effects and no mutation of global state.
Put plainly, everything that you use _inside_ the function should be passed in
(with the exception of imports, of course).
If you write a program using the idioms used here
(closures to wrap state, `vmap`/`lax.scan` in lieu of loops,
explicit random number generation using PRNGKeys),
then you will be able to JIT compile the program with ease.
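To see why purity matters, consider a minimal sketch of what can go wrong when a `jit`-ed function reads global state. The names `scale`, `scale_impure` and `scale_pure` are hypothetical and exist only for this illustration. The impure version reads `scale` when it is traced, so a later change to the global is not picked up as long as the cached compiled version is reused.

```python
import jax.numpy as jnp
from jax import jit

scale = 2.0  # global state (hypothetical, for illustration only)


@jit
def scale_impure(x):
    # Reads `scale` from the global scope; its value is baked in at trace time.
    return x * scale


@jit
def scale_pure(x, factor):
    # Everything the function needs is passed in as an argument.
    return x * factor


x = jnp.ones(3)
before = scale_impure(x)

scale = 10.0
# Same input shape and dtype, so the cached compiled function is reused
# and the updated global is not seen.
after = scale_impure(x)
pure_result = scale_pure(x, scale)

print(before, after, pure_result)
```

Passing the value in as an argument, as `scale_pure` does, keeps the function's output dependent only on its inputs, which is the property `jit` relies on.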