In [None]:
# standard libs
import os
import math
import time

import numpy as np

# imports for plotting
import matplotlib.pyplot as plt

%matplotlib inline
from IPython.display import set_matplotlib_formats

set_matplotlib_formats("svg", "pdf")
from matplotlib.colors import to_rgba
import seaborn as sns

sns.set_theme()

# progress bar
from tqdm.auto import tqdm

  set_matplotlib_formats('svg', 'pdf')


# Jax as NumPy on accelerators

In [None]:
import jax
import jax.numpy as jnp

print("Using jax", jax.__version__)

Using jax 0.4.37


## Device Arrays

In [6]:
# A 2x5 array of float32 zeros, created by JAX
a = jnp.zeros(shape=(2, 5), dtype=jnp.float32)
print(a)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [7]:
# Integers 0..5 as a JAX array (int32 unless 64-bit mode is enabled)
b = jnp.arange(0, 6)
print(b)

[0 1 2 3 4 5]


In [8]:
# Concrete Python class backing a JAX array (not numpy.ndarray)
type(b)

jaxlib.xla_extension.ArrayImpl

In [10]:
# The device on which the array `b` is stored
b.device

CudaDevice(id=0)

In [11]:
# Copy `b` back to host memory; the result is a plain NumPy array
b_cpu = jax.device_get(b)
print(type(b_cpu))

<class 'numpy.ndarray'>


In [None]:
# Transfer the host NumPy array to JAX's default device (the GPU in the recorded run)
b_gpu = jax.device_put(b_cpu)
print(f"Device put: {b_gpu.__class__} on {b_gpu.device}")

Device put: <class 'jaxlib.xla_extension.ArrayImpl'> on cuda:0


In [15]:
# Mixing a host NumPy array with a device array yields a JAX array on the
# accelerator (see the recorded output). Compute the sum once and reuse it
# instead of evaluating `b_cpu + b_gpu` twice.
b_sum = b_cpu + b_gpu
b_sum, b_sum.device

(Array([ 0,  2,  4,  6,  8, 10], dtype=int32), CudaDevice(id=0))

In [17]:
# List all devices visible to JAX
jax.devices()

[CudaDevice(id=0)]

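## [JAX - Asynchronous Dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)

JAX dispatches operations asynchronously. A call like `jnp.matmul` returns immediately with an `Array` whose value is computed on the accelerator in the background. Python only waits for the result when it is actually needed, for example when printing it or converting it to NumPy. You can also force the wait explicitly with `block_until_ready()`, which matters, for instance, when timing code.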

In [18]:
# `jnp.matmul` returns as soon as the work is enqueued (asynchronous dispatch):
# `out` is an Array whose value may still be computed on the device in the background.
out = jnp.matmul(b, b)
# Explicitly wait until the result is available (needed e.g. for correct timing).
# `block_until_ready()` returns the same array once its value is ready.
out = out.block_until_ready()

## Immutable tensors: Pure Function Language

JAX arrays are immutable. Instead of modifying an array in place, we create a new array with the desired change, e.g. via `b.at[0].set(1)` as in the cell below. However, we said that JAX is very efficient. Isn't creating a new array in this case the opposite? While it is indeed less efficient, it can be made much more efficient with JAX's just-in-time compilation. The compiler can recognize unnecessary array duplications and replace them with in-place operations again. More on just-in-time compilation later!

In [None]:
# JAX arrays are immutable: `.at[idx].set(val)` leaves `b` unchanged and
# returns a new array with the update applied.
b_new = b.at[0].set(1)  # returns a new array
print("Original Array:", b)
print("Changed Array:", b_new)

Original Array: [0 1 2 3 4 5]
Changed Array: [1 1 2 3 4 5]


## Pseudo Random Numbers in JAX: For Pure Function

In libraries like NumPy and PyTorch, the random number generator is controlled by a seed. We set the seed initially to obtain the same samples every time we run the code (this is why the numbers are not truly random, hence "pseudo"-random). However, if we call `np.random.normal()` 5 times consecutively, we get 5 different numbers, since every call changes the hidden state of the pseudo-random number generator (PRNG). If JAX followed this approach, a function generating a pseudo-random number would modify that hidden global state; in other words, it would have a side effect outside of itself. To prevent this, JAX takes a different approach: the PRNG state (the *key*) is passed explicitly and split by the user. First, let's create a PRNG key for the seed 42:

In [22]:
# Explicit PRNG state: a key derived from seed 42, passed around by hand
rng = jax.random.PRNGKey(seed=42)

In [None]:
# A non-desirable way of generating pseudo-random numbers: reusing one key.
# jax.random.normal is a pure function of the key, so both draws are identical.
jax_random_number_1, jax_random_number_2 = jax.random.normal(rng), jax.random.normal(rng)
print("JAX - Random number 1:", jax_random_number_1)
print("JAX - Random number 2:", jax_random_number_2)

# Contrast: NumPy's global, stateful generator advances on every call,
# so two consecutive draws after seeding differ.
np.random.seed(42)
np_random_number_1, np_random_number_2 = np.random.normal(), np.random.normal()
print("NumPy - Random number 1:", np_random_number_1)
print("NumPy - Random number 2:", np_random_number_2)

JAX - Random number 1: -0.18471177
JAX - Random number 2: -0.18471177
NumPy - Random number 1: 0.4967141530112327
NumPy - Random number 2: -0.13826430117118466


In general, you want to split the PRNG key every time before generating a pseudo-random number, to avoid accidentally obtaining the exact same numbers. For instance, sampling the exact same dropout mask every time you run the network makes dropout itself quite useless.

In [None]:
# Split into a new carry key plus two fresh subkeys; each draw consumes its own subkey.
rng, subkey1, subkey2 = jax.random.split(rng, num=3)
jax_random_number_1, jax_random_number_2 = (jax.random.normal(key) for key in (subkey1, subkey2))
print("JAX new - Random number 1:", jax_random_number_1)
print("JAX new - Random number 2:", jax_random_number_2)

JAX new - Random number 1: 0.107961535
JAX new - Random number 2: -1.2226542


# Function Transformations with Jaxpr

As the JAX documentation puts it:

> The most important difference, and in some sense the root of all the rest, is that JAX is designed to be functional, as in functional programming. The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs. […] The important feature of functional programming to grok when working with JAX is very simple: don't write code with side-effects.

Essentially, we want to write the core of our JAX code as functions that do not affect anything besides their outputs. For instance, we do not want to change input arrays in place or access global variables. While this might seem limiting at first, you get used to it quite quickly, and most functions that need to fulfill these constraints can be written this way without problems. Note that not every function involved in training a neural network needs to fulfill the constraints. For instance, loading or saving models, logging, or data generation can be done in regular (possibly impure) Python functions. Only the network execution, which we want to run very efficiently on our accelerator (GPU or TPU), should strictly follow these constraints.

**jaxpr.** Conceptually, you can think of any operation that JAX performs on a function as a two-step process. First, JAX *trace-specializes* the Python function into a small, well-behaved intermediate form called a jaxpr (JAX expression). This means that JAX records which operations are performed on which arrays, and what shapes and dtypes those arrays have. Based on this representation, JAX then interprets the function with transformation-specific interpretation rules, which include automatic differentiation and compiling the function with XLA to run efficiently on the accelerator.

In [None]:
def simple_graph(x):
    """Toy function used to illustrate jaxprs and automatic differentiation.

    Computes ``mean((x + 2)**2 + 3)`` over all elements of ``x``. The statements
    are kept one-per-operation so that each maps to one jaxpr equation.

    Args:
        x: array input (a float32 vector of length 3 in the cells below).

    Returns:
        A scalar array holding the mean of the transformed elements.
    """
    x = x + 2
    x = x**2
    x = x + 3
    y = x.mean()
    return y


# Example input [0., 1., 2.] (float32), reused by the jaxpr and grad cells below
inp = jnp.arange(3, dtype=jnp.float32)
print("Input:", inp)
print("Output:", simple_graph(inp))

Input: [0. 1. 2.]
Output: 12.666667


In [26]:
# Trace `simple_graph` with abstract values of `inp`'s shape and dtype and show the jaxpr
jax.make_jaxpr(simple_graph)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0
  in (f,) }

A jaxpr representation follows the structure:
```
jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }
```

where `Var*` are constants and `Var+` are input arguments. In the cell above, this is `a:f32[3]`, i.e. an array of shape 3 with type `jnp.float32` (`inp`). The list of equations, `Eqn*`, defines the intermediate results of the function. You can see that each operation in `simple_graph` is translated to a corresponding equation, e.g. `x = x + 2` is translated to `b:f32[3] = add a 2.0`. Furthermore, you can see that the operations are specialized to the input shape: `x.mean()` is replaced in `e` and `f` by a sum followed by a division by 3. Finally, `Expr+` in the jaxpr representation are the outputs of the function. In the example, this is `f`, i.e. the final result of the function. Based on these atomic operations, JAX offers all kinds of function transformations, the most important of which we will discuss later in this section. Hence, you can consider the jaxpr representation an intermediate compilation stage of JAX. What happens if we try to look at the jaxpr representation of a function with side effects? Let's consider the following function, which, as an illustrative example, appends the input to a global list:

In [None]:
global_list = []


# Deliberately *impure* function: it mutates a global list as a side effect.
# JAX only records array operations, so when `norm` is traced (e.g. by
# jax.make_jaxpr) the append runs once in Python and leaves no trace in the jaxpr.
def norm(x):
    """Return the Euclidean (L2) norm of ``x``; also appends ``x`` to ``global_list``."""
    global_list.append(x)  # side effect invisible to JAX transformations
    return jnp.sqrt(jnp.sum(x**2))

# The resulting jaxpr lists only the array operations: the append to `global_list`
# is a Python side effect and does not appear in the traced representation.
jax.make_jaxpr(norm)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
    d:f32[] = sqrt c
  in (d,) }

## Automatic Differentiation

In frameworks like PyTorch, which build a dynamic computation graph, we compute the gradients from the loss tensor itself, e.g. by calling `loss.backward()`. JAX, however, works directly with functions. Instead of backpropagating gradients through tensors, JAX takes a function as input and returns another function that directly computes its gradients. While this might seem quite different from what you are used to in other frameworks, it is quite intuitive: the gradient with respect to the parameters is really a function of the parameters and the data.

In [28]:
# jax.grad(f) returns a *function* computing the gradient of the scalar-valued f
# with respect to its first argument; here it is then evaluated at `inp`.
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


In [29]:
# The gradient function can itself be traced: its jaxpr contains the forward
# computation followed by the operations of the backward pass.
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(3,)
      sharding=None
    ] h
    j:f32[3] = mul i e
  in (j,) }

In [30]:
# get both grad and value: value_and_grad returns a function producing the pair
# (simple_graph(inp), gradient of simple_graph at inp) in a single call
val_grad_function = jax.value_and_grad(simple_graph)
val_grad_function(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))

Further topics, such as gradients with respect to multiple inputs and PyTrees, are covered later.

## Jit: Just-In-Time