# Google JAX

JAX (**J**ust **A**fter e**X**ecution) is basically NumPy for CPU, GPU, and TPU. It is a Python library developed by Google that replicates the NumPy API but can offload numerical computation to other devices. On top of that, JAX can just-in-time compile and optimize pure functions for extra speed. Users can also compute the gradient (*i.e.* derivative) of Python functions automatically. Working with JAX is almost as easy as working with NumPy. Its primary use is machine learning workloads, but it is not limited to that!

In the next section we show the power and simplicity of JAX with a simple example, and from there we build increasingly complex applications.

## Motivation

In this section we will show a motivating example using the well-known `sigmoid` function! But first, let us get the imports out of the way.

In [1]:
# JAX essential imports
import jax.numpy as jnp
from jax import jit, grad, vmap, pmap, make_jaxpr

# Other imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Use seaborn defaults for plotting
sns.set()

# Size of the array used for testing
SIZE = 100_000

We define the `sigmoid` function and its derivative `sigmoid_prime` using `numpy` as follows:

In [2]:
def sigmoid(z):
    """Sigmoid function implemented with numpy."""
    return 1 / (1 + np.exp(-z))

def sigmoid_prime(z):
    """First derivative of the sigmoid function implemented with numpy."""
    s = sigmoid(z)
    return s * (1 - s)
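
As a quick sanity check of the closed-form derivative, the sketch below compares `sigmoid_prime` against a central finite difference. The step size `eps` and the test points `z_check` are arbitrary choices made for this illustration and are not part of the original example.

```python
import numpy as np

def sigmoid(z):
    """Sigmoid function implemented with numpy (same as above)."""
    return 1 / (1 + np.exp(-z))

def sigmoid_prime(z):
    """Closed-form derivative: s * (1 - s), with s = sigmoid(z)."""
    s = sigmoid(z)
    return s * (1 - s)

# Hypothetical test points and step size, chosen only for this check.
z_check = np.linspace(-5.0, 5.0, 11)
eps = 1e-6

# Central finite difference approximation of the derivative.
numeric = (sigmoid(z_check + eps) - sigmoid(z_check - eps)) / (2 * eps)

# Largest disagreement between the numeric and the closed-form derivative.
max_abs_error = np.max(np.abs(numeric - sigmoid_prime(z_check)))
print(f"max abs error = {max_abs_error:.2e}")
```

A small error here suggests that the hand-written derivative is correct, which is the same quantity JAX will later compute for us automatically.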

Let's generate an input array, feed it to both functions, and measure the average running time.

In [3]:
# Generate an array of SIZE evenly spaced values from -10 to 10.
z = np.linspace(-10.0, 10.0, SIZE)

t_s  = %timeit -o sigmoid(z)
t_sp = %timeit -o sigmoid_prime(z)

821 µs ± 7.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
966 µs ± 26.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


On my machine the times are:

- `sigmoid`: 821 µs
- `sigmoid_prime`: 966 µs

Now let us reproduce the same example using JAX, then we will compare the running time of both examples.

In [4]:
@jit
def sigmoid_jax(z):
    """Sigmoid function implemented with JAX and JIT compiled."""
    # Notice the use of `jnp` here instead of `np`.
    return 1 / (1 + jnp.exp(-z))

# Create a JIT'ed and vectorized gradient function
sigmoid_prime_jax = jit(vmap(grad(sigmoid_jax)))

Well, there is quite a lot going on here. Let's break it down.

On the first line, we apply the `jax.jit` decorator to our `sigmoid_jax` function. This means that the first time we invoke the function, it is compiled for our backend (CPU, in the case of this document) and from then on executed natively with improved performance!
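
One way to observe that compilation happens on the first call is to count how often the Python body of a jitted function actually runs. The sketch below is only an illustration: `sigmoid_traced` and `trace_count` are hypothetical names, and the side effect is there purely to make tracing visible (side effects should normally be avoided in jitted functions).

```python
import jax.numpy as jnp
from jax import jit

trace_count = 0  # hypothetical counter, used only to make tracing visible

@jit
def sigmoid_traced(z):
    """Sigmoid whose Python body bumps a counter each time JAX traces it."""
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not in compiled code
    return 1 / (1 + jnp.exp(-z))

z_small = jnp.linspace(-1.0, 1.0, 8)

sigmoid_traced(z_small)  # first call: traced and compiled
sigmoid_traced(z_small)  # same shape and dtype: the compiled version is reused
count_after_same_shape = trace_count

# A different input shape triggers a new trace and compilation.
sigmoid_traced(jnp.linspace(-1.0, 1.0, 16))

print(count_after_same_shape, trace_count)
```

The counter only advances when JAX has to trace the function again, which is why the cost of compilation is paid once per input shape and dtype rather than on every call.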

Inside the function body, note that we are using the JAX version of NumPy via the namespace `jnp`. JAX replicates the NumPy interface very closely, so, in practice, replacing `np` with `jnp` (just adding the letter J) is often sufficient to port NumPy code!
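
To make the similarity concrete, here is a minimal side-by-side sketch. The names `sigmoid_np` and `sigmoid_jnp` are hypothetical and exist only for this comparison. One caveat worth knowing: by default JAX uses 32-bit floats, so the results are compared with a tolerance instead of exact equality.

```python
import numpy as np
import jax.numpy as jnp

def sigmoid_np(z):
    """NumPy version."""
    return 1 / (1 + np.exp(-z))

def sigmoid_jnp(z):
    """JAX version: identical code except for the `jnp` namespace."""
    return 1 / (1 + jnp.exp(-z))

z_np = np.linspace(-10.0, 10.0, 5)
z_jnp = jnp.linspace(-10.0, 10.0, 5)

out_np = sigmoid_np(z_np)
out_jnp = sigmoid_jnp(z_jnp)

# Compare with a tolerance because the two libraries may use different precision.
same = bool(np.allclose(out_np, np.asarray(out_jnp), atol=1e-6))
print(out_np.dtype, out_jnp.dtype, same)
```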

On the last line we compute the derivative of `sigmoid_jax` with respect to `z` by composing three functions:

1. The `jax.grad` function is an operator that returns the gradient of a function. By default, it takes the gradient with respect to the first parameter (there is only one in this case, `z`), but we could change that with the `argnums` argument if we wanted!

2. The gradient function is not vectorized, meaning that we can only input one "point" at a time to obtain its gradient. `jax.vmap` **vectorizes** it so that it accepts a whole array of points.

3. Finally, we use `jax.jit` on this new function to just-in-time compile (and optimize) it.
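
The three steps above can be sketched one at a time as follows. This is an illustrative decomposition, not a replacement for the one-liner: `sigmoid_plain`, `scaled_sigmoid`, and the intermediate names are hypothetical and were introduced only to show each transformation on its own, including the `argnums` option mentioned in step 1.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def sigmoid_plain(z):
    """Un-jitted sigmoid used to show each transformation separately."""
    return 1 / (1 + jnp.exp(-z))

# Step 1: grad returns a new function that works on a single scalar point.
d_sigmoid = grad(sigmoid_plain)
slope_at_zero = d_sigmoid(0.0)  # analytically sigmoid'(0) = 0.25

# Step 2: vmap maps the scalar gradient function over an array of points.
d_sigmoid_batched = vmap(d_sigmoid)

# Step 3: jit compiles the vectorized gradient function.
d_sigmoid_fast = jit(d_sigmoid_batched)

points = jnp.linspace(-2.0, 2.0, 5)
slopes = d_sigmoid_fast(points)

# Closed-form derivative s * (1 - s) for comparison.
s = sigmoid_plain(points)
analytic = s * (1 - s)

# argnums selects which parameter to differentiate with respect to.
def scaled_sigmoid(z, k):
    """Hypothetical two-parameter variant: sigmoid(k * z)."""
    return 1 / (1 + jnp.exp(-k * z))

d_wrt_k = grad(scaled_sigmoid, argnums=1)  # derivative with respect to k
slope_wrt_k = d_wrt_k(0.5, 2.0)

print(slope_at_zero, slopes, slope_wrt_k)
```

Because each transformation takes a function and returns a function, they can be nested in a single expression, which is exactly what `jit(vmap(grad(sigmoid_jax)))` does.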

Now we are going to time it.

In [5]:
# Same as before, but this time we use the JAX API.
z = jnp.linspace(-10.0, 10.0, SIZE)

t_sj  = %timeit -o sigmoid_jax(z).block_until_ready()
t_spj = %timeit -o sigmoid_prime_jax(z).block_until_ready()

44.2 µs ± 3.94 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
43.4 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


On my machine the results read:

- `sigmoid_jax`: 44.2 µs
- `sigmoid_prime_jax`: 43.4 µs

Pay attention to the use of `block_until_ready()`. As stated before, JAX supports several backends (i.e. devices), and some computations may run asynchronously from the host. For this reason, we must explicitly wait for the results to be ready; otherwise we might measure only the time needed to dispatch the work, not the time needed to compute it.
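
For readers without IPython's `%timeit`, the effect can be sketched with plain `time.perf_counter`. The function name `sigmoid_async` and the array size are assumptions made for this illustration, and the size of the gap between the two measurements depends heavily on the backend, so no particular numbers should be expected.

```python
import time

import jax.numpy as jnp
from jax import jit

@jit
def sigmoid_async(z):
    """Jitted sigmoid used to illustrate asynchronous dispatch."""
    return 1 / (1 + jnp.exp(-z))

z_big = jnp.linspace(-10.0, 10.0, 1_000_000)

# Warm-up call so that compilation time is not included in the measurements.
sigmoid_async(z_big).block_until_ready()

# Without blocking: the call may return before the result is computed.
start = time.perf_counter()
out = sigmoid_async(z_big)
dispatch_time = time.perf_counter() - start

# With blocking: we wait until the result is actually available.
start = time.perf_counter()
out = sigmoid_async(z_big).block_until_ready()
total_time = time.perf_counter() - start

print(f"without blocking: {dispatch_time * 1e6:.1f} µs")
print(f"with blocking:    {total_time * 1e6:.1f} µs")
```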

Now let's calculate the speedup:

In [6]:
print(f"Speedup sigmoid       = {t_s.average  / t_sj.average :.2f}x")
print(f"Speedup sigmoid prime = {t_sp.average / t_spj.average:.2f}x")

Speedup sigmoid       = 18.58x
Speedup sigmoid prime = 22.25x


More than 18x speedup, that's a lot! And we barely did any work. Note that, in this case, both NumPy and JAX are running on the CPU, but because of the JIT and vectorization capabilities, we end up with much faster code using the latter.

In [8]:
make_jaxpr(sigmoid_jax)(0.0)

{ lambda ; a:f32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[]. let
          d:f32[] = neg c
          e:f32[] = exp d
          f:f32[] = add e 1.0
          g:f32[] = div 1.0 f
        in (g,) }
      name=sigmoid_jax
    ] a
  in (b,) }
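
The output above is a jaxpr, the intermediate representation that JAX obtains by tracing the function: each line is one primitive operation (`neg`, `exp`, `add`, `div`), wrapped in a call because `sigmoid_jax` is jitted. The exact printout varies between JAX versions; for example, newer releases may show the wrapper as `pjit` instead of `xla_call`. As a minimal sketch, the same information can be inspected programmatically on an un-jitted function. The name `sigmoid_plain` is hypothetical and used only for this illustration.

```python
import jax.numpy as jnp
from jax import make_jaxpr

def sigmoid_plain(z):
    """Un-jitted sigmoid, so the jaxpr has no outer call wrapper."""
    return 1 / (1 + jnp.exp(-z))

# Trace the function with an example scalar input.
closed_jaxpr = make_jaxpr(sigmoid_plain)(0.0)
print(closed_jaxpr)

# Each equation in the jaxpr applies one primitive operation.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```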