## Some resources:
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers


# What is JAX?

"JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research." - JAX docs


It includes an updated version of Autograd that can differentiate native Python and NumPy code. It can differentiate through many Python features (loops, ifs, recursion, etc.) and can keep taking nested derivatives (derivatives of derivatives). It supports both reverse-mode and forward-mode differentiation, and the two can be composed in arbitrary order.
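A minimal sketch of the claims above, differentiating through a plain Python loop and branch and composing forward and reverse mode. The functions `cubic_via_loop` and `piecewise` are hypothetical examples made up for illustration, not part of the JAX docs.

```python
import jax

def cubic_via_loop(x):
    # Plain Python loop: multiplies by x three times, i.e. computes x**3.
    result = 1.0
    for _ in range(3):
        result = result * x
    return result

def piecewise(x):
    # Plain Python branch on the input value; this works under grad
    # (it would need extra care under jit, which is not used here).
    if x > 0:
        return x ** 2
    return -x

d_cubic = jax.grad(cubic_via_loop)             # mathematically 3 * x**2
d2_cubic = jax.grad(jax.grad(cubic_via_loop))  # nested derivative, 6 * x

print(d_cubic(2.0), d2_cubic(2.0))
print(jax.grad(piecewise)(3.0), jax.grad(piecewise)(-3.0))

# Forward-over-reverse and reverse-over-forward give the same second derivative.
fwd_over_rev = jax.jacfwd(jax.jacrev(cubic_via_loop))
rev_over_fwd = jax.jacrev(jax.jacfwd(cubic_via_loop))
print(fwd_over_rev(2.0), rev_over_fwd(2.0))
```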

JAX uses XLA to compile and run NumPy code on accelerators (GPUs, TPUs). Compilation happens under the hood via just-in-time (JIT) compilation, which allows for great performance without leaving Python.

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [3]:
# Multiplying matrices
# NOTE: JAX generates random numbers differently from NumPy: every call takes an explicit PRNG key
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
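To make the note about random numbers concrete, here is a small sketch of how explicit keys behave. Reusing a key reproduces the same samples, so the usual pattern is to split the key before each draw. The variable names (`subkey`, `a`, `b`, `c`, `d`) are illustrative only.

```python
from jax import random

key = random.PRNGKey(0)

# The same key always yields the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(a, b)

# Split the key to get fresh, independent randomness for each draw.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
key, subkey = random.split(key)
d = random.normal(subkey, (3,))
print(c, d)
```

The random-numbers section of the "Common Gotchas" page linked at the top covers the reasoning behind this design.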


In [5]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU when one is available --> not true on mac
# block_until_ready() is needed because JAX uses asynchronous execution by default

97.6 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# jax.numpy functions also work on regular NumPy arrays, but this is normally slower
# because the data has to be transferred to the GPU every time
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

104 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# Ensure that NDArray is backed by device memory using device_put()
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

81.6 ms ± 883 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Three main program transformations are useful when writing numerical code:

- jit(), for speeding up your code
- grad(), for taking derivatives
- vmap(), for automatic vectorization or batching

In [9]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.24 ms ± 18.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

309 µs ± 5.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
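A rough sketch of why the jitted version is faster on repeated calls: the Python body of the function runs only while JAX traces it, and later calls with the same input shape and dtype reuse the compiled result. The `trace_log` list is a hypothetical helper added here just to make the tracing visible.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical helper: records each time the Python body runs

def selu(x, alpha=1.67, lmbda=1.05):
    trace_log.append(x.shape)  # Python side effect, executed only during tracing
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)

x = jnp.linspace(-1.0, 1.0, 8)
y1 = selu_jit(x)               # first call: traces and compiles
y2 = selu_jit(x + 1)           # same shape and dtype: reuses the compiled code
y3 = selu_jit(jnp.ones((4,)))  # new shape: traces again

print(len(trace_log))  # number of times the Python body actually ran
```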


In [11]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))


[0.25       0.19661197 0.10499357]


In [12]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]


In [13]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.035325598
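As a sanity check on the nested derivative above, the third derivative of the logistic function can also be written in closed form. The sketch below compares the two; the helper `logistic_third_derivative` is not in the original notebook and is derived by hand from s' = s(1 - s).

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic_third_derivative(x):
    # With s = logistic(x): s' = s(1 - s), s'' = s'(1 - 2s),
    # s''' = s(1 - s)(1 - 6s + 6s^2).
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s) * (1.0 - 6.0 * s + 6.0 * s ** 2)

auto = grad(grad(grad(sum_logistic)))(1.0)
closed = logistic_third_derivative(1.0)
print(auto, closed)
```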


In [14]:
# Computes the full Hessian by composing forward-mode over reverse-mode Jacobians
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
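The `hessian` helper is defined above but never called, so here is a short usage sketch on `sum_logistic`. Because that function is a sum of independent per-element terms, the Hessian should be diagonal, with entries s(1 - s)(1 - 2s) where s is the logistic function. That expected form is worked out by hand here, not taken from the notebook.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
H = hessian(sum_logistic)(x_small)
print(H.shape)
print(H)

# Hand-derived diagonal: second derivative of the logistic function.
s = 1.0 / (1.0 + jnp.exp(-x_small))
expected_diag = s * (1.0 - s) * (1.0 - 2.0 * s)
print(jnp.allclose(jnp.diag(H), expected_diag, atol=1e-6))
```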

In [15]:
# vmap: the vectorizing map
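# Split the key so that mat and batched_x come from independent random streams
key_mat, key_x = random.split(key)
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))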

def apply_matrix(v):
  return jnp.dot(mat, v)

In [16]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
430 µs ± 546 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
8.03 µs ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [18]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
12.1 µs ± 15.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
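The timings above only compare speed. A quick sketch like the following can confirm that the three batching strategies also produce the same result. The arrays are regenerated here so the block runs on its own, and the tolerance values are an assumption chosen loosely for float32.

```python
import jax.numpy as jnp
from jax import jit, vmap, random

key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

# Python loop over the batch, one matrix-vector product per row.
naive = jnp.stack([apply_matrix(v) for v in batched_x])
# Hand-written batched version as a single matrix-matrix product.
manual = jnp.dot(batched_x, mat.T)
# vmap maps apply_matrix over the leading axis of its input.
auto = jit(vmap(apply_matrix))(batched_x)

print(naive.shape, manual.shape, auto.shape)
print(jnp.allclose(naive, manual, rtol=1e-3, atol=1e-3))
print(jnp.allclose(naive, auto, rtol=1e-3, atol=1e-3))
```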
