In [1]:
import os 
import math 
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import matplotlib
from IPython.display import set_matplotlib_formats 
from matplotlib.colors import to_rgba
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set_context('talk')
sns.set_style('whitegrid')  

import time 
from tqdm.auto import tqdm

JAX's NumPy-like API for creating and manipulating arrays is ``jax.numpy``.

In [2]:
import jax 
import jax.numpy as jnp 
print("jax version: ", jax.__version__)

jax version:  0.4.30


In [3]:
# Build a (2, 5) array of float32 zeros and print it
a = jnp.zeros(shape=(2, 5), dtype=jnp.float32)
print(a)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [4]:
# Integer array 0, 1, ..., 5 (the stop value 6 is excluded, as with range())
b = jnp.arange(0, 6)
print(b)

[0 1 2 3 4 5]


In [5]:
# The concrete Python type behind a JAX array
type(b)

jaxlib.xla_extension.ArrayImpl

JAX can execute the same code on different backends: CPU, GPU or TPU. ``ArrayImpl`` is JAX's array type, and each instance lives on a device of one of these backends. We can query that device with ``.devices()``.

In [6]:
# Set of devices holding `b` (only the CPU in this run, see output)
b.devices()

{CpuDevice(id=0)}

In [7]:
# jax.device_get transfers an array from its device back to the host and
# returns a plain NumPy array.

b_cpu = jax.device_get(b)
print(type(b_cpu))

<class 'numpy.ndarray'>


In [8]:
# To explicitly place a NumPy array on a JAX device,
# we can use ``jax.device_put``. Without a `device` argument it targets the
# default device: the CPU in this run (see output), the first GPU/TPU when an
# accelerator is available. The name `b_gpu` assumes such an accelerator.

b_gpu = jax.device_put(b_cpu)
print(f"Device Put:{b_gpu.__class__} on {b_gpu.devices()}")

Device Put:<class 'jaxlib.xla_extension.ArrayImpl'> on {CpuDevice(id=0)}


NumPy arrays and JAX arrays can be mixed in the same operation; the result is a JAX array:

In [9]:
# NumPy array + JAX array: works, and the result is a JAX Array (see output)
b_cpu + b_gpu

Array([ 0,  2,  4,  6,  8, 10], dtype=int32)

JAX uses *asynchronous dispatch*. For instance, if we call ``out = jnp.matmul(b, b)``, JAX immediately returns a placeholder array for ``out``, which may not yet be filled with values when the function call returns. This way, Python does not block the execution of follow-up statements. It only waits for the computation when we actually need the value of ``out``, for instance when printing it or transferring it to the CPU. When benchmarking, we force this wait explicitly with ``.block_until_ready()``. PyTorch uses a very similar principle to allow asynchronous computation.

#### Immutable Tensors 

In-place operations such as ``b[0] = 1`` are not allowed: JAX arrays are immutable, because JAX's transformations require programs to be "pure" functions.

Instead, we use ``b.at[0].set(1)``, which returns a new, updated array and leaves ``b`` unchanged:

In [10]:
# Functional update: `.at[0].set(1)` returns a new array; `b` is left unchanged
b_new = b.at[0].set(1)
print("original array: ", b)
print("new array: ", b_new)

original array:  [0 1 2 3 4 5]
new array:  [1 1 2 3 4 5]


#### Pseudo-Random Numbers in JAX

Calling ``np.random.normal()`` five times returns five different numbers. This is because every call advances the hidden global state of NumPy's pseudo-random number generator (PRNG).

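JAX's random functions are pure, so the PRNG state (a *key*) is passed explicitly and must be advanced by the user. Calling a random function twice with the same key therefore returns the same number: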

In [11]:
# Root PRNG key from seed 42; later cells derive subkeys from it with
# jax.random.split. NOTE(review): newer JAX versions recommend the typed-key
# API `jax.random.key(42)`; PRNGKey returns a legacy raw uint32 key array.
rng = jax.random.PRNGKey(42)

In [12]:
# Anti-pattern: reusing one key. jax.random functions are pure functions of
# their key, so the same key always yields the same "random" number.
jax_random_number_1, jax_random_number_2 = (
    jax.random.normal(rng),
    jax.random.normal(rng),
)
print("JAX - Random Number 1: ", jax_random_number_1)
print("JAX - Random Number 2: ", jax_random_number_2)

# For contrast: NumPy's legacy global PRNG advances its hidden state on every
# call, so consecutive draws differ.
np.random.seed(42)
np_random_number_1, np_random_number_2 = np.random.normal(), np.random.normal()
print("Numpy - Random Number 1: ", np_random_number_1)
print("Numpy - Random Number 2: ", np_random_number_2)

JAX - Random Number 1:  -0.18471177
JAX - Random Number 2:  -0.18471177
Numpy - Random Number 1:  0.4967141530112327
Numpy - Random Number 2:  -0.13826430117118466


To get a different random number every time we sample, we _split_ the PRNG key into new subkeys and use a fresh subkey for each pseudo-random draw:

In [13]:
# Split into a new carry key plus two fresh subkeys; each subkey is consumed
# exactly once, so the two draws differ.
rng, subkey1, subkey2 = jax.random.split(rng, num=3)
jax_random_number_1, jax_random_number_2 = (
    jax.random.normal(subkey1),
    jax.random.normal(subkey2),
)
print("JAX - Random Number 1: ", jax_random_number_1)
print("JAX - Random Number 2: ", jax_random_number_2)

JAX - Random Number 1:  0.107961535
JAX - Random Number 2:  -1.2226542


#### Functions with JAXPR

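JAX can trace a function into an intermediate representation called a *jaxpr*. The jaxpr shows which operations are performed on which arrays, and what the shapes and dtypes of those arrays are.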

Consider the function:

$$y = \frac{1}{\lvert x \rvert} \sum_{i} [(x_{i}+2)^2 + 3]$$

where $\lvert x \rvert$ denotes the number of elements of $x$, so $y$ is a mean.

In [14]:
def simple_graph(x):
    """Compute y = mean((x + 2)**2 + 3) over all elements of ``x``.

    Implements y = (1/|x|) * sum_i [(x_i + 2)^2 + 3], where |x| is the number
    of elements. Its gradient is dy/dx_i = 2 * (x_i + 2) / |x|.

    Args:
        x: numeric array (a JAX array in this notebook).

    Returns:
        A 0-d array holding the mean.
    """
    squared = (x + 2) ** 2
    return (squared + 3).mean()

# Example input [0., 1., 2.]; expected output mean([7, 12, 19]) = 38/3
inp = jnp.arange(3, dtype=jnp.float32)
print("Input", inp)
print("Output", simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


In [15]:
# To view the jaxpr representation, we can use `jax.make_jaxpr`.
# Tracing depends on the shape and dtype of the input,
# so we need to pass an example input.

jax.make_jaxpr(simple_graph)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = add a 2.0
    c[35m:f32[3][39m = integer_pow[y=2] b
    d[35m:f32[3][39m = add c 3.0
    e[35m:f32[][39m = reduce_sum[axes=(0,)] d
    f[35m:f32[][39m = div e 3.0
  [34m[22m[1min [39m[22m[22m(f,) }

In [16]:
# Another example of a simple graph, this time an impure one

global_list = []

# Deliberately invalid for JAX: appending to a global list is a side effect.
# Tracing only records array operations, so the append does not appear in the
# jaxpr, and while tracing the appended value is an abstract tracer.
def norm(x):
    """Return the Euclidean norm sqrt(sum(x**2)) of ``x``.

    Side effect: appends ``x`` to the module-level ``global_list``.
    """
    global_list.append(x)
    return jnp.sqrt((x ** 2).sum())

# The jaxpr contains only the array ops; the global_list.append side effect
# ran once in Python during tracing and is not part of the graph.
jax.make_jaxpr(norm)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = integer_pow[y=2] a
    c[35m:f32[][39m = reduce_sum[axes=(0,)] b
    d[35m:f32[][39m = sqrt c
  [34m[22m[1min [39m[22m[22m(d,) }

#### Automatic Differentiation

Instead of back-propagating gradients through tensors, JAX's ``jax.grad`` takes a function as input. It returns a new function that directly computes the gradient of the original function with respect to its first argument.

In [17]:
# jax.grad returns a new function computing the gradient of simple_graph
# w.r.t. its first argument; analytically 2 * (inp + 2) / 3 here.
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)
print("Gradients: ", gradients)

Gradients:  [1.3333334 2.        2.6666667]


In [18]:
# Get the jaxpr of the gradient function: the forward ops are followed by the
# ops that compute the derivative.
jax.make_jaxpr(grad_function)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = add a 2.0
    c[35m:f32[3][39m = integer_pow[y=2] b
    d[35m:f32[3][39m = integer_pow[y=1] b
    e[35m:f32[3][39m = mul 2.0 d
    f[35m:f32[3][39m = add c 3.0
    g[35m:f32[][39m = reduce_sum[axes=(0,)] f
    _[35m:f32[][39m = div g 3.0
    h[35m:f32[][39m = div 1.0 3.0
    i[35m:f32[3][39m = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j[35m:f32[3][39m = mul i e
  [34m[22m[1min [39m[22m[22m(j,) }

Often we want not only the gradients but also the function's output itself, for instance to log the loss. ``jax.value_and_grad`` returns both as a ``(value, gradient)`` tuple:

In [19]:
# One call returns the (value, gradient) tuple
val_grad_function = jax.value_and_grad(simple_graph)
val_grad_function(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))

#### Speeding Up Computation with Just-In-Time (JIT) Compilation

JAX takes full advantage of the available accelerator hardware, by compiling functions just-in-time with XLA (Accelerated Linear Algebra), using their JAXPR representation.

This is done with ``jax.jit``, which can either be applied directly to a function or used as a decorator (``@jax.jit``) on its definition.

In [20]:
# Wrap simple_graph for XLA compilation; compilation happens on the first call
jitted_function = jax.jit(simple_graph)

In [21]:
# Create a new random subkey for generating new random values

rng, normal_rng = jax.random.split(rng)
large_input = jax.random.normal(normal_rng, (1000,))
# Warm-up: the first call traces and compiles `simple_graph` with XLA.
# Blocking until the result is ready ensures that both compilation and this
# first (asynchronously dispatched) execution finish before any timing starts.
_ = jitted_function(large_input).block_until_ready()

In [22]:
%%timeit
# Un-jitted baseline. block_until_ready() waits for the asynchronously
# dispatched computation, so the timing covers the actual work.
simple_graph(large_input).block_until_ready()

65.2 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [23]:
%%timeit
# Jitted version of the same computation (compiled by the warm-up call)
jitted_function(large_input).block_until_ready()

4.35 µs ± 40.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


JAX transformations compose, so we can apply several of them to the same function. For example, we can apply ``jax.jit`` to a gradient function:

In [24]:
# Compose transformations: JIT-compile the gradient function.
jitted_grad_function = jax.jit(grad_function)
# Warm-up call (triggers tracing + compilation); block so the first execution
# also completes before the %%timeit cells measure steady-state speed.
_ = jitted_grad_function(large_input).block_until_ready()

In [25]:
%%timeit 
# Un-jitted gradient baseline; block_until_ready() waits for the result
grad_function(large_input).block_until_ready()

1.95 ms ± 54.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
%%timeit 
# Jitted gradient function (compiled by its warm-up call)
jitted_grad_function(large_input).block_until_ready()

3.87 µs ± 89.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
