# JAX Quickstart

**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance general numerical computing and machine learning research.**

It is an extensible system for composable function transformations using Python and NumPy. JAX is fast, easy to use, and uses a functional programming model, which aligns well with mathematics.

With an updated version of **[Autograd](https://github.com/hips/autograd)**, JAX
can automatically differentiate native Python and NumPy code. It can
differentiate through a large subset of Python’s features, including loops, `if` statements, recursion, and closures.

JAX can even **take derivatives of derivatives of
derivatives**. It supports **reverse-mode and forward-mode differentiation**, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses
**[XLA](https://www.tensorflow.org/xla)** to compile and run your NumPy code on accelerators, like GPUs and TPUs. By default, compilation happens under the hood: library calls are just-in-time compiled and executed one operation at a time. JAX also lets you **just-in-time (JIT) compile** your own Python functions into XLA-optimized kernels using a one-function API, so that XLA can optimize a whole sequence of operations together. Because XLA generates the compiled code, you get this speedup without writing kernels by hand.

Compilation and automatic differentiation (autodiff) can be composed arbitrarily, so you
can express sophisticated algorithms and get maximal performance without having to leave Python.

JAX is much more than just an accelerator-backed NumPy. It comes with a few key _program transformations_ (transforms) that are useful when writing numerical code. 

In this quickstart, you'll focus on the following transforms:

 - [jit()](https://jax.readthedocs.io/en/latest/jax.html#jax.jit): for speeding up your code
 - [grad()](https://jax.readthedocs.io/en/latest/jax.html#jax.grad): for taking derivatives
 - [vmap()](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap): for automatic vectorization or batching
 - [pmap()](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap): for executing functions on multiple accelerators in parallel (on multiple GPUs or even TPU cores)

The general API is Pythonic and user-friendly—you pass a function into a transform and get a function out. You'll see it in the examples below.

Since JAX is still being developed, do check out the [Common Gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) notebook.

## Import the libraries

- Import the transforms covered in this quickstart (`grad`, `jit`, `vmap`, and `pmap`), `jax.random` for random number generation, and `jax.numpy` as the accelerator-backed NumPy:

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax import random

## Multiplying matrices

First, let's look at how to do N-dimensional array (NDArray) multiplication in JAX, which should be as easy as in NumPy.

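One big difference between NumPy and JAX is how random numbers are generated. Unlike the global, stateful `numpy.random` functions, JAX's [pseudo-random number generator (PRNG)](https://github.com/google/jax/blob/master/design_notes/prng.md) requires you to create a PRNG key explicitly and pass it to every random function.

JAX's PRNG combines the Threefry counter-based PRNG with a functional, array-oriented key-splitting model.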

Let's look at some examples:

- Generate a random array with [jax.random.normal()](https://jax.readthedocs.io/en/latest/jax.random.html):

In [None]:
# Create a PRNG key (seed set to 0)
key = random.PRNGKey(0)

# Create a normally distributed 10-element NDArray with the PRNG key
x = random.normal(key, (10,))

# Print the new array
print(x)
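Because JAX's PRNG is functional, a key is just a value: reusing the same key reproduces the same numbers, and you call `random.split` to derive fresh keys for new draws. Here's a minimal sketch (the variable names `a`, `b`, `c`, and `subkey` are illustrative):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Sampling twice with the same key gives identical numbers (there is no hidden global state)
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Split the key to get a new subkey for an independent draw
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(a)
print(b)
print(c)
```

A common convention is to keep one key for future splitting and consume only the subkey, so that no key is ever used twice.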

- Multiply two matrices and measure the compute time:

In [None]:
# Set a matrix size to 3000
size = 3000

# Instantiate a Normal matrix as JAX NumPy dtype `float32` with the PRNG key
x = random.normal(key, (size, size), dtype=jnp.float32)

# Multiply the matrix by its transpose, measure the time it takes to compute
%timeit jnp.dot(x, x.T).block_until_ready()

The JAX NumPy calculation runs on the accelerator, if one is available! Notice how much time it took to carry out the matrix multiply with JAX NumPy.

To measure the true cost of the operation, call `block_until_ready()`, which waits until the computation is complete. Without it, you would mostly measure how long it takes to dispatch the work, because [JAX uses asynchronous execution by default](https://jax.readthedocs.io/en/latest/async_dispatch.html).
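To see asynchronous dispatch in action, you can time the same matrix multiply with and without `block_until_ready()`. This is a rough sketch using Python's `time.perf_counter`; the exact numbers depend on your hardware, and on a CPU the gap may be small:

```python
import time

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (2000, 2000), dtype=jnp.float32)

# Warm up once so that one-time compilation isn't included in either timing
jnp.dot(x, x.T).block_until_ready()

# Without blocking, the clock stops as soon as the work has been dispatched
start = time.perf_counter()
y = jnp.dot(x, x.T)
dispatch_only = time.perf_counter() - start
y.block_until_ready()  # let this computation finish before the next measurement

# With blocking, the clock stops only after the result has been computed
start = time.perf_counter()
y = jnp.dot(x, x.T).block_until_ready()
dispatch_and_compute = time.perf_counter() - start

print(f"dispatch only:      {dispatch_only:.4f} s")
print(f"dispatch + compute: {dispatch_and_compute:.4f} s")
```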

- The next example shows how you can use JAX NumPy functions to work on regular NumPy arrays:

In [None]:
# Import the original CPU-backed NumPy
import numpy as onp

# Create a Normal matrix of dtype `float32` with ordinary NumPy 
x = onp.random.normal(size=(size, size)).astype(onp.float32)

# Dot multiply again with JAX NumPy, measure the compute time
%timeit jnp.dot(x, x.T).block_until_ready()

This is slower than before because `x` is an ordinary NumPy array in host (CPU) memory, so JAX has to transfer it to the device every time `jnp.dot` is called.

With JAX, however, you can explicitly place data, such as your NDArray, on the device ahead of time. Read more about [data and computation placement](https://jax.readthedocs.io/en/latest/faq.html#faq-data-placement) in the JAX FAQ.

- Ensure that an NDArray is backed by device memory by using [jax.device_put()](https://jax.readthedocs.io/en/latest/jax.html?highlight=device_put#jax.device_put):

In [None]:
from jax import device_put

# As before, create a Normal matrix with ordinary NumPy and dtype `float32`
x = onp.random.normal(size=(size, size)).astype(onp.float32)

# Transfer NDArray to device
x = device_put(x)

# As before, dot multiply with JAX NumPy, measure the compute time 
%timeit jnp.dot(x, x.T).block_until_ready()

Notice the time improvement here.

The output of `device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. 

Note that the behavior of `device_put` is equivalent to the function `jit(lambda x: x)`, but it's _faster_.
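The sketch below shows `device_put` with and without an explicit target device, alongside the equivalent `jit(lambda x: x)`. It assumes only that at least one device is visible (`jax.devices()[0]`), which is always true because JAX falls back to the CPU:

```python
import numpy as onp
import jax
from jax import device_put, jit

x_host = onp.arange(6, dtype=onp.float32).reshape(2, 3)

# Transfer to the default device
x_dev = device_put(x_host)

# You can also name a target device explicitly
x_dev0 = device_put(x_host, jax.devices()[0])

# The identity function under `jit` also returns an on-device array with the same values
x_jit = jit(lambda x: x)(x_host)

# Converting back with onp.asarray copies the values to host memory
print(onp.asarray(x_dev))
```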

If you have a GPU (or TPU!), these calls run on the accelerator and have the potential to be much faster than on CPU.

- Run the dot multiply again using ordinary NumPy instead of JAX NumPy:

In [None]:
# As before, create a Normal matrix with ordinary NumPy and dtype `float32`
x = onp.random.normal(size=(size, size)).astype(onp.float32)

# Compute the dot product with ordinary NumPy instead of JAX
# and measure the time it takes
%timeit onp.dot(x, x.T)

This was just an introduction to multiplying matrices with JAX.

Let's go over each of JAX's key program transforms—`jit`, `grad`, `vmap`, and `pmap`. You'll also be composing these program transformations in interesting ways!

## Using `jit` to speed up functions

JAX runs transparently on the GPU/TPU (or CPU, if you don't have one).

However, in the above examples, JAX dispatches kernels to the accelerator one operation at a time. As a result, JAX NumPy can sometimes be slower than ordinary NumPy, especially on the CPU. If you have a sequence of operations, you can speed up the calculation by using the [`@jit`](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) decorator to compile multiple operations together with [XLA](https://www.tensorflow.org/xla).

Let's try that.

- In the example below, you'll define a Scaled Exponential Linear Unit (SELU) activation function (which you may be familiar with from the world of deep learning). Then, you'll apply SELU to a large array and time the computation with JAX NumPy:

In [None]:
# Define a SELU activation function
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Instantiate a Normally-distributed NDArray with 1 million elements
x = random.normal(key, (1000000,))

# Call SELU and time it, waiting until the computation is complete
%timeit selu(x).block_until_ready()

- You can speed up the calculation with `jit`, which will JIT-compile the first time `selu` is called (and it will be cached thereafter):

In [None]:
# Apply `jit()` on the SELU function
selu_jit = jit(selu)

# Call jitted SELU, measure the compute time again
%timeit selu_jit(x).block_until_ready()

Notice how much faster it is when you JIT-compile.
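The "compile once, then cache" behavior can be observed directly. Python side effects inside a jitted function run only while JAX traces it, so a counter shows how often tracing happens. In this sketch, `selu_traced` and `trace_count` are hypothetical names used for illustration:

```python
import jax.numpy as jnp
from jax import jit

trace_count = []  # appended to only while JAX traces the function

@jit
def selu_traced(x, alpha=1.67, lmbda=1.05):
  trace_count.append(1)  # Python side effect: runs at trace time, not on every call
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-2.0, 2.0, 5)
selu_traced(x)                          # first call: traced and compiled
selu_traced(x + 1.0)                    # same shape and dtype: cached version is reused
selu_traced(jnp.linspace(-2.0, 2.0, 7)) # new shape: traced and compiled again

print(len(trace_count))  # 2
```

Because the cache is keyed on input shapes and dtypes, calling a jitted function with many different shapes triggers a recompilation for each one.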

## Taking derivatives with `grad`

In addition to evaluating numerical functions, you can also transform them. One such transformation is [autodiff](https://en.wikipedia.org/wiki/Automatic_differentiation). 

Just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the [grad()](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) function in JAX. 

Taking derivatives is as easy as calling `grad`.

- Let's explore autodiff with the example below, where you first apply `jit` to a function that sums the logistic (sigmoid) function over an array, and then take its gradient:

In [None]:
# Define a function that sums the elementwise logistic (sigmoid) of `x`
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

# Create an array of evenly spaced values: [0., 1., 2.]
x_small = jnp.arange(3.)

# JIT-compile the function, then take its derivative
derivative_fn = grad(jit(sum_logistic))
print(derivative_fn(x_small))

- Let's verify that your result is correct with finite differences:

In [None]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

Note how close the results are.
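For this particular function, you can also check the gradient analytically: the derivative of the sigmoid σ(x) is σ(x)(1 − σ(x)), and because `sum_logistic` sums independent elements, its gradient is that derivative applied elementwise. A short sketch of the check:

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)

# Closed form: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
s = 1.0 / (1.0 + jnp.exp(-x_small))
analytic = s * (1.0 - s)

autodiff = grad(sum_logistic)(x_small)
print(autodiff)
print(analytic)
```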

`grad` and `jit` compose and you can mix them arbitrarily. In the above example, you jitted `sum_logistic` and then took its derivative. You can go even further and take as many gradients of gradients as you want.

- Invoke `grad` and `jit` multiple times:

In [None]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

For more advanced autodiff, you can use [vjp()](https://jax.readthedocs.io/en/latest/jax.html?highlight=jax.vjp#jax.vjp) for reverse-mode vector-Jacobian products and [jvp()](https://jax.readthedocs.io/en/latest/jax.html?highlight=jax.jvp#jax.jvp) for forward-mode Jacobian-vector products. `jvp` pushes a tangent vector forward through a function, while `vjp` returns a function that pulls a cotangent vector back. (Reverse-mode differentiation is often referred to as backpropagation, or backprop, in machine learning.)

They can also be composed arbitrarily with one another, and with other JAX transformations, such as `jit`. 
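Here's a minimal sketch of both primitives on a small elementwise function, `f(x) = sin(x) * x` (an arbitrary example). Because `f` acts elementwise, its Jacobian is diagonal, so pushing forward or pulling back a vector of ones gives the same derivative, `cos(x) * x + sin(x)`:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
v = jnp.ones_like(x)

# Forward mode: returns f(x) and the Jacobian-vector product J @ v in one pass
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: returns f(x) and a function that computes v^T @ J
y_again, vjp_fn = jax.vjp(f, x)
(vjp_out,) = vjp_fn(v)  # vjp_fn returns one cotangent per input, as a tuple

print(jvp_out)
print(vjp_out)
```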

- Below is a simple example of a function that efficiently computes full Hessian matrices by composing `jit` with JAX's [jacfwd()](https://jax.readthedocs.io/en/latest/jax.html#jax.jacfwd) and [jacrev()](https://jax.readthedocs.io/en/latest/jax.html#jax.jacrev), which compute full Jacobian matrices with forward- and reverse-mode autodiff, respectively. This composition mirrors how [jax.hessian()](https://jax.readthedocs.io/en/latest/_modules/jax/api.html#hessian) is defined.

In [None]:
from jax import jacfwd, jacrev

# Define some small random arrays, using a separate PRNG key for each
size = 5
key_x, key_y = random.split(key)
x = random.normal(key_x, (size,), dtype=jnp.float32)
y = random.normal(key_y, (size, size), dtype=jnp.float32)

# Define a quadratic form x^T y x with JAX NumPy
def fun(x):
  return jnp.dot(jnp.dot(x, y), x)

# Build a Hessian function: forward-mode over reverse-mode, then JIT-compile
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

# Define the Hessian computation
hess = hessian(fun)

# Compute and show the output
print(hess(x))
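Since `fun` computes the quadratic form xᵀyx, its Hessian has the closed form y + yᵀ. That makes it easy to sanity-check the hand-built Hessian against both the closed form and the built-in `jax.hessian`. A self-contained sketch:

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit, random

key_x, key_y = random.split(random.PRNGKey(0))
x = random.normal(key_x, (5,))
y = random.normal(key_y, (5, 5))

def fun(x):
  # Quadratic form x^T y x
  return jnp.dot(jnp.dot(x, y), x)

hess_manual = jit(jacfwd(jacrev(fun)))(x)  # forward-over-reverse, as above
hess_builtin = jax.hessian(fun)(x)         # JAX's built-in Hessian
hess_exact = y + y.T                       # closed-form Hessian of x^T y x

print(hess_manual.shape)  # (5, 5)
```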

Hopefully this gives you a taste of what you can do with JAX's autodiff system. To learn more, check out [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

## Auto-vectorization with `vmap`

JAX has another transformation in its API that you might find useful: [vmap()](https://jax.readthedocs.io/en/latest/jax.html#jax.vmap), the vectorizing map. For example, it can turn a function that multiplies a matrix by a single vector into one that multiplies the matrix by a whole batch of vectors (a matrix-matrix multiply). It does this by adding a batch dimension to every primitive operation in the function, which saves you from having to carry batch dimensions around in your own code. Compared to vectorizing by hand, this is often more convenient, especially for complex functions such as neural networks.

`vmap` has the familiar semantics of mapping a function along array axes. It works a lot like a regular Python `map`. But, instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. 

When composed with `jit`, it can be just as fast as adding the batch dimensions by hand.

You'll start with a simple example, batching matrix-vector products, which is also easy to do by hand. The same technique applies to more complicated functions.

- Set up a matrix-vector product:

In [None]:
# Define a matrix - a random 150x100 NDArray with the PRNG key, as before
mat = random.normal(key, (150, 100))

# Define batched inputs - a random 10x100 NDArray
batched_x = random.normal(key, (10, 100))

# Define a matrix-vector product
def apply_matrix(v):
  return jnp.dot(mat, v)

- Given a function such as `apply_matrix`, you can loop over a batch dimension in Python (but, usually, the performance of doing so is poor):

In [None]:
# Define a naively batched dot product
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

# Measure the time to perform the computation, use `block_until_ready()`, as before
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

- You probably know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently as follows:

In [None]:
# Apply the @jit decorator to compile normal dot product ops with XLA
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

# Check how much time it takes to compute
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

- However, suppose you had a more complicated function without batching support. You can use `vmap` to add batching support automatically:

In [None]:
# Apply the @jit decorator to compile auto-vectorized ops with XLA
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

# Check how much time it takes to compute
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
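`vmap` also lets you choose which arguments are batched through its `in_axes` parameter. In the sketch below, `apply_matrix` is rewritten (hypothetically) to take the matrix as an explicit argument; `in_axes=(None, 0)` tells `vmap` to share the matrix across the batch and to map over axis 0 of the vectors. The result is checked against manual batching:

```python
import jax.numpy as jnp
from jax import jit, random, vmap

key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

# A variant of apply_matrix that takes the matrix as an argument
def apply_matrix(m, v):
  return jnp.dot(m, v)

# in_axes=(None, 0): don't map over `m`, map over axis 0 of `v`
batched = jit(vmap(apply_matrix, in_axes=(None, 0)))
out = batched(mat, batched_x)

print(out.shape)  # (10, 150)
```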

Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, forward- and reverse-mode autodiff for fast Jacobian (`jax.jacfwd` and `jax.jacrev`) and Hessian (`jax.hessian`) matrix calculations, and any other JAX transformation.
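A common composition is `vmap(grad(...))`, which computes per-example gradients in a single vectorized call. The sketch below assumes a hypothetical squared-error loss for a linear model; its gradient with respect to `w` for one example is 2(w·x − y)x, which the test uses as a check:

```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap

def loss(w, x, y):
  # Squared error of a linear model on a single example
  return (jnp.dot(w, x) - y) ** 2

key_w, key_x, key_y = random.split(random.PRNGKey(0), 3)
w = random.normal(key_w, (3,))
xs = random.normal(key_x, (8, 3))  # a batch of 8 examples
ys = random.normal(key_y, (8,))

# Gradient w.r.t. `w`, vectorized over the batch of (x, y) pairs, then compiled
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))(w, xs, ys)

print(per_example_grads.shape)  # (8, 3)
```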

## Parallelization with `pmap`

Suppose you want to run your computations on multiple XLA accelerators—GPUs or TPUs. JAX has an API for that called [pmap()](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) that lets you write single-program multiple-data (SPMD) programs. Google Cloud TPU pods are multi-host platforms and you can utilize them with this transform.

Similar to `jit`, when you apply `pmap` to a function, the function is compiled with XLA. Then, it executes in parallel on multiple devices.

Like `vmap`, this transform maps a function over array axes. But, unlike `vmap`, `pmap` replicates the function and executes each replica on its own XLA device in parallel. 

The transform also allows you to use all-reduce sum and other parallel SPMD collective operations.
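Here's a minimal sketch of an all-reduce sum with `jax.lax.psum`. The `axis_name` passed to `pmap` names the mapped axis so that collectives can refer to it. Each device normalizes its row by the sum across all devices. The sketch sizes the input with `jax.local_device_count()`, so it also runs on a single CPU device:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One row per device; +1 keeps every entry nonzero so the division is safe
x = jnp.arange(n_devices * 3.0).reshape(n_devices, 3) + 1.0

# psum adds `v` across all devices that share the axis name 'devices'
normalize = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name='devices'),
                     axis_name='devices')

out = normalize(x)
print(out)
```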

### Enable Cloud TPUs in Google Colab

To take advantage of JAX's `pmap`, you'll be using Cloud TPUs in Google Colab:

- Change your Google Colab runtime by clicking **Edit** > **Notebook settings**
- Select **Hardware accelerator: TPU**
- If it says "Unable to connect to the runtime", click on the **Reconnect** button in the top right corner and wait until it has a green check mark (✓)

Then, run the cell below to further configure the TPU support:

In [None]:
# Run this cell inside Google Colab with TPUs enabled
import requests
import os

# Import JAX libraries again, just in case (runtime switching resets the state)
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap, random

if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# This is required to use TPU Driver as JAX's backend
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

The output should show the gRPC address (IP address and port number) of the Colab TPU.

- To list available devices for the backend, use [jax.devices()](https://jax.readthedocs.io/en/latest/jax.html#jax.devices):

In [None]:
jax.devices()

(The default backend is 'gpu' or 'tpu', if available. Otherwise it's 'cpu'.)
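To check which backend JAX picked and how many devices it sees, you can use a few small helper functions. A short sketch:

```python
import jax

# The platform JAX selected as its default, for example 'cpu', 'gpu', or 'tpu'
print(jax.default_backend())

# All devices for the default backend, plus the global and per-process device counts
devices = jax.devices()
print(devices)
print(jax.device_count(), jax.local_device_count())
```

The local device count is the upper limit on the size of the axis you can map over with a single `pmap` on one host.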

Let's go over some `pmap` examples.

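- If you have several XLA devices available (for example, the 8 TPU cores in Colab), you can use `pmap` to map a function along a leading array axis. The size of that axis can't exceed the number of available devices:

In [None]:
# Use one array element per available device (8 on a Colab TPU)
n_devices = jax.local_device_count()
y = pmap(lambda x: x ** 2)(jnp.arange(n_devices))

print(y)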

- Run a dot product multiply on several devices:

In [None]:
# Create a PRNG key and split it into one new key per device
keys = random.split(random.PRNGKey(0), jax.local_device_count())

# Generate a random 5000x5000 matrix on each device with `pmap`
size = 5000
matrices = pmap(lambda key: random.normal(key, (size, size)))(keys)

# Multiply each matrix by itself, in parallel on the XLA accelerators
result = pmap(jnp.dot)(matrices, matrices)

# Compute the mean of each device's result and show the output
print(pmap(jnp.mean)(result))

- Measure the compute time:

In [None]:
%timeit -n 5 -r 5 pmap(jnp.dot)(matrices, matrices).block_until_ready()

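- Compose autodiff (`grad`) with nested `pmap`s. Because the inner `pmap` runs inside the outer one over a 4x2 array, this example needs 4 × 2 = 8 devices: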

In [None]:
# Define an arbitrary array and reshape it
x = jnp.arange(8.).reshape((4, 2))

# Apply the @pmap decorator to run calculations across devices
@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  # Call `grad`!
  return grad(lambda w: jnp.sum(g(w)))(x)

f(x)

This example is intentionally simple. For data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/).
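The core of data-parallel training is a gradient step in which each device computes gradients on its own shard of the batch and the results are averaged with a collective such as `jax.lax.pmean`. The sketch below assumes a hypothetical linear model with a mean-squared-error loss, and sizes everything with `jax.local_device_count()` so it runs on any number of devices:

```python
import functools

import jax
import jax.numpy as jnp
from jax import grad, pmap, random

n_devices = jax.local_device_count()

def loss(w, x, y):
  # Mean squared error of a linear model
  return jnp.mean((x @ w - y) ** 2)

@functools.partial(pmap, axis_name='batch')
def parallel_grad(w, x, y):
  g = grad(loss)(w, x, y)
  # Average the gradients across devices so every replica gets the same update
  return jax.lax.pmean(g, axis_name='batch')

key_w, key_x, key_y = random.split(random.PRNGKey(0), 3)
w = random.normal(key_w, (3,))
# Shard 4 examples per device along a new leading device axis
x = random.normal(key_x, (n_devices, 4, 3))
y = random.normal(key_y, (n_devices, 4))

# Replicate the weights so that each device receives its own copy
w_replicated = jnp.broadcast_to(w, (n_devices,) + w.shape)
grads = parallel_grad(w_replicated, x, y)
print(grads.shape)
```

With equally sized shards, the averaged per-device gradient equals the gradient of the loss over the full batch.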

This is just a taste of what JAX can do. We're really excited to see what you do with it!