In [None]:
!pip install https://storage.googleapis.com/jax-wheels/cuda92/jax-0.0-py3-none-any.whl

# JAX Quickstart
dougalm@, phawkins@, mattjj@, frostig@, alexbw@


#### [JAX](http://go/jax) is NumPy on the CPU, GPU and TPU, with great automatic differentiation for high-performance machine learning research.

With its updated version of [Autograd](https://github.com/hips/autograd), JAX
can automatically differentiate native Python and NumPy code. It can
differentiate through a large subset of Python’s features, including loops, ifs,
recursion, and closures, and it can even take derivatives of derivatives of
derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily
to any order.

What’s new is that JAX uses
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)
to compile and run your NumPy code on accelerators, like GPUs and TPUs.
Compilation happens under the hood by default, with library calls getting
just-in-time compiled and executed. But JAX even lets you just-in-time compile
your own Python functions into XLA-optimized kernels using a one-function API.
Compilation and automatic differentiation can be composed arbitrarily, so you
can express sophisticated algorithms and get maximal performance without having
to leave Python.


## The basics of JAX

In [None]:
from __future__ import print_function, division
import numpy as onp
from tqdm import tqdm
import jax.numpy as np
from jax import grad, jit, vmap
from jax import device_put
from jax import random

### Multiplying Matrices

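We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you ask for random numbers. NumPy draws from a global, hidden random state, while JAX's random functions take an explicit PRNG key as their first argument. We needed to make this change to support some of the great features we talk about below.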

In [None]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
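Because a JAX key is just a value, calling a sampler twice with the same key returns the same numbers. To get fresh randomness, you split the key and use a new subkey. Here is a minimal sketch using `random.split`; the variable names are our own:

```python
import numpy as onp
from jax import random

key = random.PRNGKey(0)

# Reusing the same key gives exactly the same samples.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(onp.array_equal(a, b))  # True

# For new samples, split the key and consume a fresh subkey.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
print(onp.array_equal(a, c))  # False
```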

Let's dive right in and multiply two big matrices.

In [None]:
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T)  # runs on the GPU
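In current JAX releases, operations are dispatched to the device asynchronously. A bare `%timeit` can therefore measure how long it takes to launch the matmul rather than how long the matmul takes to run. The sketch below times more carefully, assuming a JAX version whose arrays provide `block_until_ready()`. The matrix is smaller than above so that it runs quickly on a CPU:

```python
import time
import jax.numpy as np
from jax import random

size = 1000  # smaller than the cell above, to keep this quick on CPU
x = random.normal(random.PRNGKey(0), (size, size), dtype=np.float32)

# Warm-up call so one-time setup costs are not included in the timing.
np.dot(x, x.T).block_until_ready()

n_runs = 5
start = time.perf_counter()
for _ in range(n_runs):
  # block_until_ready() waits until the result has actually been computed.
  result = np.dot(x, x.T).block_until_ready()
elapsed = (time.perf_counter() - start) / n_runs
print("%.2f ms per matmul" % (elapsed * 1e3))
```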

JAX NumPy functions work on raw NumPy arrays as well.

In [None]:
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T)

That's slower because the data has to be transferred to the GPU every time. You can ensure that an NDArray is backed by device memory using `device_put`.

In [None]:
x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T)

The output of `device_put` still acts like an NDArray. By the way, the implementation of `device_put` is just `device_put = jit(lambda x: x)`.
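A common pattern is to transfer the data once, compute on the device repeatedly, and then copy the final result back to host memory. The sketch below does this, using `onp.asarray` for the copy back. The array sizes are arbitrary:

```python
import numpy as onp
import jax.numpy as np
from jax import device_put

x_host = onp.random.normal(size=(500, 500)).astype(onp.float32)

# Transfer once; later JAX operations reuse the device buffer.
x_dev = device_put(x_host)
y_dev = np.dot(x_dev, x_dev.T)

# Copy the result back into a regular NumPy array on the host.
y_host = onp.asarray(y_dev)
print(type(y_host), y_host.shape)
```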

All of the calls above are faster than the original NumPy running on the CPU:

In [None]:
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numeric code. For now, there are three main ones:

 - `jit`, for speeding up your code
 - `grad`, for taking derivatives
 - `vmap`, for automatic vectorization or batching

Let's go over these one by one. We'll also end up composing them in interesting ways.

### Using `jit` to speed up functions

JAX runs transparently on the GPU (or on the CPU if you don't have a GPU; TPU support is coming soon!). However, in the above example, JAX dispatches kernels to the GPU one operation at a time, so a sequence of operations pays that per-operation overhead at every step. Fortunately, JAX has a `@jit` decorator that compiles a whole function with XLA, which lets it fuse multiple operations together. Let's try that.

In [None]:
def selu_raw(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

x = np.zeros(1000000)
%timeit selu_raw(x)

We can speed it up with `jit`. The function is compiled the first time the jitted version is called, and the compiled version is cached for later calls.

In [None]:
selu = jit(selu_raw)
%timeit selu(x)
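You can also apply `jit` as a decorator when the function is defined. The sketch below defines a jitted SELU directly (the name `selu_jit` is our own) and checks that it agrees with the un-jitted `selu_raw`:

```python
import numpy as onp
import jax.numpy as np
from jax import jit

def selu_raw(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

@jit
def selu_jit(x, alpha=1.67, lmbda=1.05):
  # Same computation; XLA compiles it on the first call for each input shape/dtype.
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

x = np.linspace(-3.0, 3.0, 7)
print(onp.allclose(selu_jit(x), selu_raw(x), rtol=1e-5, atol=1e-6))  # True
```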

### Taking derivatives with `grad`

We don't just want to compute with NumPy arrays; we also want to transform numeric programs, for example by taking their derivatives. In JAX, just like in Autograd, there is a one-function API for taking derivatives: the `grad` function.

In [None]:
def sum_logistic(x):
  return np.sum(1.0 / (1.0 + np.exp(-x)))

x_small = np.ones((3,))
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

Let's verify with finite differences that our result is correct.

In [None]:
def first_finite_differences(f, x):
  eps = 1e-3
  return np.array([(f(x + eps * basis_vect) - f(x - eps * basis_vect)) / (2 * eps)
                   for basis_vect in onp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))
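The logistic function sigmoid(x) = 1 / (1 + exp(-x)) also has a closed-form derivative to compare against: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)). Each entry of the gradient of `sum_logistic` is therefore sigmoid(x_i) * (1 - sigmoid(x_i)). Here is a quick sketch of that check:

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def sum_logistic(x):
  return np.sum(1.0 / (1.0 + np.exp(-x)))

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

x_small = np.ones((3,))
autodiff = grad(sum_logistic)(x_small)
# Closed form: d/dx_i sum_j sigmoid(x_j) = sigmoid(x_i) * (1 - sigmoid(x_i)).
closed_form = sigmoid(x_small) * (1.0 - sigmoid(x_small))
print(autodiff, closed_form)  # both about 0.1966 in every entry
```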

Taking derivatives is as easy as calling `grad`. `grad` and `jit` compose and can be mixed arbitrarily. In the example below, we compute the third derivative of `sum_logistic` at `1.0`, jit-compiling two of the intermediate derivative functions along the way:

In [None]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
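With s = sigmoid(x), we have sigmoid' = s(1 - s). Differentiating twice more gives sigmoid''' = s(1 - s)(1 - 6s + 6s^2), which gives a sanity check on the nested call above:

```python
import numpy as onp
import jax.numpy as np
from jax import grad, jit

def sum_logistic(x):
  return np.sum(1.0 / (1.0 + np.exp(-x)))

third = grad(jit(grad(jit(grad(sum_logistic)))))(1.0)

s = 1.0 / (1.0 + onp.exp(-1.0))  # sigmoid(1.0), computed with plain NumPy
closed_form = s * (1 - s) * (1 - 6 * s + 6 * s ** 2)
print(float(third), closed_form)  # both approximately -0.0353
```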

For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. We used them with `vmap` (which we'll describe in a moment) to write `jax.jacfwd` and `jax.jacrev` for computing full Jacobian matrices. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:

In [None]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jacfwd(jacrev(fun))
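Here is a small sketch of these APIs on the cubic f(x) = sum_i x_i**3, an example of our own. Its gradient is 3 * x**2 and its Hessian is diag(6 * x). `jax.jvp` pushes a tangent vector forward through the function, and `jax.vjp` returns a function that pulls a cotangent back. The sketch redefines `hessian` as the `jacfwd(jacrev(...))` composition from the cell above so that it runs on its own:

```python
import jax
import jax.numpy as np
from jax import jacfwd, jacrev

def f(x):
  return np.sum(x ** 3)

def hessian(fun):
  return jacfwd(jacrev(fun))

x = np.array([1.0, 2.0, 3.0])

# Forward mode (Jacobian-vector product): with v = e_0 this gives df/dx_0 = 3 * x_0**2.
v = np.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode (vector-Jacobian product): for scalar f, a cotangent of 1 yields the full gradient.
y2, f_vjp = jax.vjp(f, x)
(g,) = f_vjp(np.ones_like(y2))

H = hessian(f)(x)
print(jv)  # value 3 * 1**2 = 3
print(g)   # values 3 * x**2 = [3, 12, 27]
print(H)   # diagonal matrix with 6, 12, 18 on the diagonal
```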

### Auto-vectorization with `vmap`

JAX has one more transformation in its API that you might find useful: `vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with `jit`, it can be just as fast as adding the batch dimensions by hand.

We're going to work with a simple example and promote matrix-vector products into matrix-matrix products using `vmap`. Although this is trivial to do by hand in this specific case, the same technique can apply to more complicated functions.

In [None]:
def apply_matrix(v):
  return np.dot(mat, v)

In [None]:
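# Split the key so mat and batched_x are drawn from independent random streams.
key, mat_key, x_key = random.split(key, 3)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(x_key, (10, 100))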

In [None]:
def naively_batched_apply_matrix(v_batched):
  return np.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x)

In [None]:
@jit
def batched_apply_matrix(v_batched):
  return np.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x)

In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
  # vmap(f) returns a new function that maps f over the leading axis of its input.
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x)
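All three versions should compute the same thing. The self-contained sketch below checks this, calling `vmap` as `vmap(fun)(args)` and splitting the key so that `mat` and the batch are drawn independently:

```python
import numpy as onp
import jax.numpy as np
from jax import jit, vmap, random

key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
  # Matrix-vector product for a single, unbatched vector v of shape (100,).
  return np.dot(mat, v)

naive = np.stack([apply_matrix(v) for v in batched_x])
manual = np.dot(batched_x, mat.T)
auto = jit(vmap(apply_matrix))(batched_x)

print(auto.shape)  # (10, 150)
print(onp.allclose(naive, auto, rtol=1e-3, atol=1e-3),
      onp.allclose(manual, auto, rtol=1e-3, atol=1e-3))
```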

Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other JAX transformation. Now, let's put it all together.
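As one example of composing transformations, `vmap(grad(...))` computes a separate gradient for each example in a batch. The sketch below does this for a squared-error linear model. `loss`, `w`, `xs`, and `ys` are hypothetical names chosen for illustration. `in_axes=(None, 0, 0)` tells `vmap` to share `w` across the batch while mapping over the rows of `xs` and `ys`:

```python
import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, random

def loss(w, x, y):
  # Hypothetical example: squared error of a linear model on one (x, y) pair.
  return (np.dot(x, w) - y) ** 2

w_key, x_key, y_key = random.split(random.PRNGKey(0), 3)
w = random.normal(w_key, (5,))
xs = random.normal(x_key, (8, 5))
ys = random.normal(y_key, (8,))

# grad differentiates w.r.t. the first argument (w), vmap maps over the batch,
# and jit compiles the composed function.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))(w, xs, ys)
print(per_example_grads.shape)  # (8, 5)

# Closed form for comparison: d/dw (x.w - y)**2 = 2 * (x.w - y) * x
expected = 2.0 * (xs @ w - ys)[:, None] * xs
```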



We've now used the whole of the JAX API: `grad` for derivatives, `jit` for speedups and `vmap` for auto-vectorization.
We used NumPy to specify all of our computation, and borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.