# Discovering JAX

Google JAX is a machine learning framework for **transforming numerical functions**. Its primary transformations are:

1. `grad`: automatic differentiation
2. `jit`: just-in-time compilation
3. `vmap`: auto-vectorization
4. `pmap`: SPMD programming

### Sources and interesting links

* Wikipedia: <https://en.wikipedia.org/wiki/Google_JAX>
* <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>
* <https://colinraffel.com/blog/you-don-t-know-jax.html>

## Import packages

In [5]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

## JAX Numpy

`jax.numpy` and `numpy` objects and operations are often interchangeable thanks to duck typing, a property of Python that lets code operate on any object that provides the required attributes and methods, regardless of its concrete type.

JAX arrays pay off most when the computation is wrapped in `jit`, so that its operations are compiled together. `numpy` runs only on CPUs, while `jax.numpy` also runs on GPUs and TPUs.
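As a quick check of where computations will run, the sketch below lists the devices JAX can see. It assumes only that `jax` is installed; the result depends on the machine, so no output is shown.

```python
import jax
import jax.numpy as jnp

# list the devices visible to JAX (CPU, GPU or TPU, depending on the machine)
devices = jax.devices()
print(devices)

# name of the backend used by default, for example "cpu"
backend = jax.default_backend()
print(backend)

# arrays created with jax.numpy are placed on the default device
a = jnp.ones(3)
print(a)
```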

Unlike NumPy arrays, JAX arrays are immutable. This is a condition for using `jit`, which expects functions without side effects. In practice, changing one value in a JAX array means creating a new array with `.at[...].set(...)`.

In [27]:
# NumPy:
x = np.arange(10)
x[0] = 10
print(x)

[10  1  2  3  4  5  6  7  8  9]


In [28]:
# JAX:
y = jnp.arange(10)
y = y.at[0].set(10)
print(y)

[10  1  2  3  4  5  6  7  8  9]
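To make the immutability concrete, here is a minimal sketch of what happens when NumPy-style assignment is attempted on a JAX array, and of how `.at[...].set(...)` leaves the original array unchanged. The names `y_new` and `in_place_worked` are introduced only for this illustration.

```python
import jax.numpy as jnp

y = jnp.arange(10)

# in-place assignment is rejected because JAX arrays are immutable
try:
    y[0] = 10
    in_place_worked = True
except TypeError:
    in_place_worked = False
print(in_place_worked)

# .at[...].set(...) returns a new array and leaves the original untouched
y_new = y.at[0].set(10)
print(y[0], y_new[0])
```

The notebook cell above rebinds the name `y` to the new array, which is why it looks like an in-place update.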


In [29]:
# NumPy and JAX are interchangeable in many places:
print(jnp.exp(x))    # JAX exponential function applied to a NumPy array
print(np.exp(y))     # NumPy exponential function applied to a JAX array

[2.2026465e+04 2.7182817e+00 7.3890562e+00 2.0085537e+01 5.4598152e+01
 1.4841316e+02 4.0342880e+02 1.0966332e+03 2.9809580e+03 8.1030840e+03]
[2.20264658e+04 2.71828183e+00 7.38905610e+00 2.00855369e+01
 5.45981500e+01 1.48413159e+02 4.03428793e+02 1.09663316e+03
 2.98095799e+03 8.10308393e+03]
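Interchangeable inputs do not mean interchangeable outputs: the library whose function is called decides the type of the result. The sketch below checks this on small arrays; the variable names `from_jax` and `from_numpy` are chosen for the illustration.

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(3)     # NumPy array
y = jnp.arange(3)    # JAX array

from_jax = jnp.exp(x)    # jax.numpy function: the result is a JAX array
from_numpy = np.exp(y)   # numpy function: the result is a NumPy array

print(type(from_jax))
print(type(from_numpy))
```

This also accounts for the different precision in the two printed results above: `jnp.exp` returns `float32` by default, while `np.exp` on an integer array returns `float64`.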


In [38]:
y.at[2].get()

DeviceArray(2, dtype=int32)

In [43]:
z=jnp.roll(y,2)
print(z)
y[2]+z[2]

[ 8  9 10  1  2  3  4  5  6  7]


DeviceArray(12, dtype=int32)

## Gradient calculation

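An example of automatic differentiation. `grad` applies to functions that return a scalar, which is why the example below sums the logistic values before differentiating an array input.

By default, the gradient is taken with respect to the first argument; this can be controlled via the `argnums` argument to `jax.grad`.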

In [14]:
# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

def sum_logistic(x):
    return jnp.sum(logistic(x))

# obtain the gradient functions
grad_logistic = grad(logistic)
grad_sum_logistic = grad(sum_logistic)

# evaluate the gradient of the logistic function at x = 1 
print( grad_logistic(1.0) )

# evaluate the gradient of the sum_logistic function at [0,1,2]
# Note that the array is created with jnp
x_small = jnp.arange(3.)
print( grad_sum_logistic(x_small) )

0.19661194
[0.25       0.19661194 0.10499358]
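As a sanity check, the automatic gradient can be compared with the closed-form derivative of the logistic function, $\sigma'(x) = \sigma(x)\,(1 - \sigma(x))$. The sketch below redefines the two functions so that it runs on its own; the tolerance passed to `jnp.allclose` is an arbitrary choice.

```python
import jax.numpy as jnp
from jax import grad

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

def sum_logistic(x):
    return jnp.sum(logistic(x))

x_small = jnp.arange(3.)

# gradient obtained by automatic differentiation
autodiff = grad(sum_logistic)(x_small)

# closed-form derivative of the logistic function, evaluated elementwise
closed_form = logistic(x_small) * (1 - logistic(x_small))

matches = bool(jnp.allclose(autodiff, closed_form, atol=1e-6))
print(matches)
```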


In [20]:
x = np.arange(4.)
grad_sum_logistic(x)

DeviceArray([0.25      , 0.19661194, 0.10499358, 0.04517668], dtype=float32)
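The `argnums` argument mentioned at the start of this section can be sketched on a function of two arguments. `weighted_square` is a hypothetical function made up for this illustration.

```python
from jax import grad

# hypothetical two-argument function: f(w, x) = w * x**2
def weighted_square(w, x):
    return w * x ** 2

# default (argnums=0): derivative with respect to w, i.e. x**2
dw = grad(weighted_square)(3.0, 2.0)

# argnums=1: derivative with respect to x, i.e. 2 * w * x
dx = grad(weighted_square, argnums=1)(3.0, 2.0)

# a tuple of indices returns a tuple of derivatives
dw_both, dx_both = grad(weighted_square, argnums=(0, 1))(3.0, 2.0)

print(dw, dx)
```

At `w = 3` and `x = 2`, the two derivatives are `x**2 = 4` and `2 * w * x = 12`.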

## JIT: Just-In-Time compilation

JIT is an important ingredient of JAX's numerical efficiency, but it takes a little bit of time and effort (and pain) to understand how to use it efficiently and without errors.

Without `jit`, JAX executes a function one operation at a time, dispatching each operation separately every time the function is called. `jit` instead traces the function the first time it is called, compiles the whole computation with XLA, and reuses the compiled version on later calls with arguments of the same shape and dtype. This saves time but imposes a certain number of constraints on the arguments (in particular, array shapes must be known at compile time) and on the type of operations performed in the function (it must be free of side effects, and Python control flow cannot depend on array values).

Check <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html>.
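One constraint on the operations inside a jitted function is that a Python `if` cannot branch on the value of an argument, because during compilation the argument is an abstract placeholder rather than a number. The sketch below shows the failure and one possible workaround with `jnp.where`; the function names `relu_branch` and `relu_where` are made up for the illustration.

```python
import jax.numpy as jnp
from jax import jit

@jit
def relu_branch(x):
    # a Python `if` needs a concrete True/False, but under jit x is a tracer
    if x > 0:
        return x
    return 0.0

try:
    relu_branch(1.0)
    branch_failed = False
except Exception as err:  # the exact exception class depends on the JAX version
    branch_failed = True
    print(type(err).__name__)

@jit
def relu_where(x):
    # jnp.where selects elementwise and is compatible with jit
    return jnp.where(x > 0, x, 0.0)

print(relu_where(-2.0), relu_where(3.0))
```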



In [35]:
# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((20000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
for i in range(2):
    print('without jit: ',cube(x)[0,0])
    print('with jit: ',jit_cube(x)[0,0])

without jit:  1.0
with jit:  1.0
without jit:  1.0
with jit:  1.0
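The cell above only prints values, so it does not by itself show a speed difference. A rough way to measure one is sketched below, assuming a smaller array to keep it quick; `timed` is a helper written for this illustration. Timings depend on the machine, so none are shown.

```python
import time
import jax.numpy as jnp
from jax import jit

def cube(x):
    return x * x * x

jit_cube = jit(cube)

# smaller array than in the cell above, to keep the sketch quick to run
x = jnp.ones((2000, 1000))

def timed(fn, arg):
    """Run fn(arg), wait for the result, and return (result, elapsed seconds)."""
    start = time.perf_counter()
    # JAX dispatches work asynchronously, so wait before stopping the clock
    out = fn(arg).block_until_ready()
    return out, time.perf_counter() - start

out_jit, t_first = timed(jit_cube, x)   # first call: tracing, compilation, execution
_, t_second = timed(jit_cube, x)        # later call: reuses the compiled function
out_plain, t_plain = timed(cube, x)     # no jit: each multiplication dispatched separately

print(f"jit, first call : {t_first:.4f} s")
print(f"jit, second call: {t_second:.4f} s")
print(f"without jit     : {t_plain:.4f} s")
```

The first jitted call is typically the slowest because it includes compilation, which is why it is timed separately from the second.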


## VMAP: Vectorization

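To be developed. Vectorization consists of organizing operations so that they apply to whole arrays at once, which makes them easy to parallelize. `vmap` does this automatically by mapping a function written for a single example over a batch axis of its inputs.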
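Pending a fuller treatment, here is a minimal sketch of `vmap`: a dot product written for two vectors is mapped over the rows of two matrices and compared with an explicit Python loop. The function `dot` and the arrays `A` and `B` are made up for the illustration.

```python
import jax.numpy as jnp
from jax import vmap

# function written for a single pair of 1-D vectors
def dot(a, b):
    return jnp.dot(a, b)

A = jnp.arange(6.).reshape(3, 2)   # rows: [0, 1], [2, 3], [4, 5]
B = jnp.ones((3, 2))

# vmap maps dot over the leading axis of both arguments (the default in_axes=0)
batched = vmap(dot)(A, B)

# same computation with an explicit Python loop over the rows
looped = jnp.stack([dot(a, b) for a, b in zip(A, B)])

print(batched)
print(looped)
```

Both versions return one dot product per row, here 1, 5 and 9; the `vmap` version does so without a Python-level loop.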