### Part 0: JAX/NumPy API

JAX on Princeton Research Computing systems

https://github.com/PrincetonUniversity/intro_ml_libs/tree/master/jax

`pip install -U "jax[cuda12]"`

In [1]:
import jax.numpy as jnp
import numpy as onp
import matplotlib.pyplot as plt
import jax

try:
    import rich
except ModuleNotFoundError:
    print("rich not found, install it with pip install rich")
    !pip install rich

In [None]:
x = jnp.array([1.0, 2.0, 3.0])

print(f"{x=}")
print(f"{type(x)=}")

Most NumPy functions are available in the `jax.numpy` namespace.

In [None]:
jnp.square(x)

Calling an `onp` function on a JAX array converts the result to a NumPy array.

In [None]:
onp.square(x)

In [None]:
type(onp.square(x))

Unlike NumPy arrays, JAX arrays are immutable,

In [None]:
x_onp = onp.arange(0.0, 10.0)
x_onp[:5] = -1.0
x_onp

In [None]:
x_jnp = jnp.arange(0.0, 10.0)
x_jnp[:5] = -1.0  # raises TypeError: JAX arrays do not support in-place item assignment
x_jnp

meaning any modification requires creating a new array rather than altering the original.

In [None]:
x_jnp = jnp.arange(0.0, 10.0)
x_jnp = x_jnp.at[:5].set(-1.0)
x_jnp

### Part 1: Jax as a tool for computing gradients

In [9]:
from jax import grad

Consider the function

$f(x, y, z) = \sin(x) + e^y + \sqrt{z}.$

In [10]:
def f(X):
    x, y, z = X
    return jnp.sin(x) + jnp.exp(y) + jnp.sqrt(z)

It has partial derivatives

$\frac{\partial f}{\partial x} = \cos(x)$

$\frac{\partial f}{\partial y} = e^y$

$\frac{\partial f}{\partial z} = \frac{1}{2\sqrt{z}},$

which can be computed exactly with `jax.grad`:

In [None]:
dfdX = grad(f)

# Note: at z = 0 the z-component is inf, because 1 / (2 * sqrt(z)) diverges there
dfdX(jnp.array([0.0, 0.0, 0.0]))
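As a sanity check, one can compare `jax.grad` against the closed-form partial derivatives listed above. This is a minimal sketch; `analytic_grad` and the test point `X0` are illustrative names and values chosen here (with $z > 0$ so that every derivative is finite).

```python
import jax.numpy as jnp
from jax import grad


def f(X):
    x, y, z = X
    return jnp.sin(x) + jnp.exp(y) + jnp.sqrt(z)


def analytic_grad(X):
    # Closed-form partial derivatives: cos(x), exp(y), 1 / (2 sqrt(z))
    x, y, z = X
    return jnp.array([jnp.cos(x), jnp.exp(y), 1.0 / (2.0 * jnp.sqrt(z))])


X0 = jnp.array([0.5, -1.0, 4.0])  # arbitrary test point with z > 0
auto = grad(f)(X0)
exact = analytic_grad(X0)

print(auto)
print(exact)
print(jnp.allclose(auto, exact))
```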

Now let's add a parameter $h$:

$f(x, y, z; h) = \sin(x-h) + e^{y-h} + \sqrt{z-h}.$

This gives a new partial derivative

$\frac{\partial f}{\partial h} = -\cos(x-h) - e^{y-h} - \frac{1}{2\sqrt{z-h}}.$

We can use automatic differentiation to compute $\frac{\partial f}{\partial h}$ by specifying `argnums`:

In [None]:
def f(X, h):
    x, y, z = X
    return jnp.sin(x - h) + jnp.exp(y - h) + jnp.sqrt(z - h)


grad(f, argnums=1)(jnp.array([0.0, 0.0, 0.25]), 0.0)
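`argnums` also accepts a tuple, in which case `grad` returns one gradient per selected argument. The sketch below checks the result against the formula for $\partial f / \partial h$ given above; the names `X0`, `h0` and `expected` are illustrative.

```python
import jax.numpy as jnp
from jax import grad


def f(X, h):
    x, y, z = X
    return jnp.sin(x - h) + jnp.exp(y - h) + jnp.sqrt(z - h)


X0 = jnp.array([0.0, 0.0, 0.25])
h0 = 0.0

# argnums=(0, 1) returns a tuple: (gradient w.r.t. X, derivative w.r.t. h)
dfdX, dfdh = grad(f, argnums=(0, 1))(X0, h0)

# Closed-form df/dh from the text, evaluated at the same point
x, y, z = X0
expected = -jnp.cos(x - h0) - jnp.exp(y - h0) - 1.0 / (2.0 * jnp.sqrt(z - h0))

print(dfdX)
print(dfdh, expected)
```

Because $h$ enters every term as $-h$, the derivative with respect to $h$ is minus the sum of the components of the gradient with respect to $X$.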

Now, let's talk Jacobians and Hessians.

In [13]:
from jax import jacobian, hessian

Let's consider a vector-valued function

$$
\vec{f}(u, v) =
\left[\begin{array}{c} 
f_1(u,v)\\
f_2(u,v)
\end{array}\right]=
\left[\begin{array}{c} 
e^u \cos(v)\\
e^u \sin(v)
\end{array}\right].
$$ 

In [14]:
def fvec(X):
    u, v = X
    return jnp.array([jnp.exp(u) * jnp.cos(v), jnp.exp(u) * jnp.sin(v)])

The Jacobian of $\vec{f}$ is written

$$
\mathbf{J}_{\vec{f}}=
\left[\begin{array}{cc} 
\partial f_1 / \partial u & \partial f_1 / \partial v \\
\partial f_2 / \partial u & \partial f_2 / \partial v
\end{array}\right]=
\left[\begin{array}{cc} 
e^u \cos(v) & -e^u \sin(v)\\
e^u \sin(v) & e^u \cos(v)
\end{array}\right]
$$ 

and can be exactly computed with `jax.jacobian`:

In [None]:
fjac = jacobian(fvec)
fjac(jnp.array([0.0, 0.0]))

The function `jax.jacobian` is an alias for `jax.jacrev`, which computes Jacobians using reverse-mode automatic differentiation. In contrast, `jax.jacfwd` uses forward-mode differentiation. Reverse mode (`jax.jacrev`) is more efficient for wide Jacobians (more inputs than outputs), while forward mode (`jax.jacfwd`) is better suited to tall Jacobians (more outputs than inputs).
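Both modes should return the same matrix; only the cost differs. A small sketch comparing them with the analytic Jacobian written above (the evaluation point `X0` is an arbitrary choice made here):

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev


def fvec(X):
    u, v = X
    return jnp.array([jnp.exp(u) * jnp.cos(v), jnp.exp(u) * jnp.sin(v)])


X0 = jnp.array([0.3, 1.2])  # arbitrary evaluation point

J_rev = jacrev(fvec)(X0)  # reverse mode: one pass per output
J_fwd = jacfwd(fvec)(X0)  # forward mode: one pass per input

# Analytic Jacobian from the text
u, v = X0
J_exact = jnp.array(
    [
        [jnp.exp(u) * jnp.cos(v), -jnp.exp(u) * jnp.sin(v)],
        [jnp.exp(u) * jnp.sin(v), jnp.exp(u) * jnp.cos(v)],
    ]
)

print(jnp.allclose(J_rev, J_fwd), jnp.allclose(J_rev, J_exact))
```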

The Hessian of $\vec{f}$,
$$
\mathbf{H}_{\vec{f}}=
\left[
\left[\begin{array}{cc} 
\partial^2 f_1 / \partial u^2 & \partial^2 f_1 / \partial u \partial v\\
\partial^2 f_1 / \partial v \partial u & \partial^2 f_1 / \partial v^2
\end{array}\right],
\left[\begin{array}{cc} 
\partial^2 f_2 / \partial u^2 & \partial^2 f_2 / \partial u \partial v\\
\partial^2 f_2 / \partial v \partial u & \partial^2 f_2 / \partial v^2
\end{array}\right]
\right]=
\left[
\left[\begin{array}{cc} 
e^u \cos(v) & -e^u \sin(v)\\
-e^u \sin(v) & -e^u \cos(v)
\end{array}\right],
\left[\begin{array}{cc} 
e^u \sin(v) & e^u \cos(v)\\
e^u \cos(v) & -e^u \sin(v)
\end{array}\right]
\right],
$$

can be computed exactly with `jax.hessian`:

In [None]:
fhes = hessian(fvec)
fhes(jnp.array([0.0, 0.0]))
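For a vector-valued function the result is a rank-3 array with one $2 \times 2$ block per output component. The sketch below checks the shape, compares the values at the origin with the analytic expression above, and shows that composing `jacfwd` with `jacrev` gives the same thing.

```python
import jax.numpy as jnp
from jax import hessian, jacfwd, jacrev


def fvec(X):
    u, v = X
    return jnp.array([jnp.exp(u) * jnp.cos(v), jnp.exp(u) * jnp.sin(v)])


X0 = jnp.array([0.0, 0.0])
H = hessian(fvec)(X0)

# At (u, v) = (0, 0): e^u = 1, cos(v) = 1, sin(v) = 0
H_exact = jnp.array(
    [
        [[1.0, 0.0], [0.0, -1.0]],  # Hessian of f_1
        [[0.0, 1.0], [1.0, 0.0]],   # Hessian of f_2
    ]
)

# Forward-over-reverse composition gives the same array
H_composed = jacfwd(jacrev(fvec))(X0)

print(H.shape)
print(jnp.allclose(H, H_exact), jnp.allclose(H, H_composed))
```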

### Part 2: Banjamin

JAX has a much more evolved approach to **random number generation** than NumPy, designed to allow parallel random number generators. It is also needed to satisfy JAX's pure-function approach.

To use any stochastic function in JAX, you need to specify a key, which is a seed that the function can then use internally. So let's create one:

In [None]:
rng_key = jax.random.PRNGKey(42) # Because 42 is the answer
print(rng_key)

In [None]:
key1, key2, rng_key = jax.random.split(rng_key, 3)
print(key1, key2)

In [None]:
# Create some random matrices
A = jax.random.normal(key1, [500,1000])
B = jax.random.normal(key2, [1000, 500])
A, B
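The practical consequence of explicit keys is that a stochastic function is deterministic given its key. A short sketch (the seed and shapes are arbitrary choices made here):

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same draws
a1 = jax.random.normal(key, (3,))
a2 = jax.random.normal(key, (3,))

# Splitting produces new, independent keys and therefore different draws
k1, k2 = jax.random.split(key)
b1 = jax.random.normal(k1, (3,))
b2 = jax.random.normal(k2, (3,))

print(a1, a2)
print(b1, b2)
```

This is why the cells above split `rng_key` before drawing `A` and `B`: reusing one key for both would make the two matrices correlated.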

JAX enables operations to execute on CPU/GPU/TPU using the same code thanks to XLA. XLA (Accelerated Linear Algebra) is an open-source compiler for machine learning.

When you execute JAX code without JIT, you run through the code at the Python level, until you encounter the low level XLA interface, which is hidden behind the numpy API. At that point, the XLA bit of code is executed, and the result is returned to Python which continues to run through the next commands.

Let us look at the original NumPy computation:

In [None]:
A_np = onp.array(A)
B_np = onp.array(B)

def func_onp(A, B):
  C = onp.dot(A, B)
  C = onp.where(C>0, C, 0)
  return C


%timeit func_onp(A_np, B_np)

In [45]:
def func(A, B):
  C = jnp.dot(A, B)
  C = jnp.where(C>0, C, 0)
  return C

In [None]:
%timeit func(A, B).block_until_ready()

This can be pretty slow because the execution is still driven by Python.

The idea of JAX is that you can `jit` a big function to turn it into a single compiled XLA graph that runs without needing Python:

In [None]:
jitted_func = jax.jit(func) # returns another function
%time jitted_func(A, B).block_until_ready(); # First execution won't be fast

In [None]:
%timeit jitted_func(A,B).block_until_ready() #  Next calls are fast
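Outside a notebook, the `%time` / `%timeit` magics are not available. The sketch below does the same comparison with `time.perf_counter` on smaller matrices (sizes chosen here for illustration) and checks that compiling does not change the result. Actual timings depend on the hardware, so none are quoted.

```python
import time

import jax
import jax.numpy as jnp


def func(A, B):
    C = jnp.dot(A, B)
    C = jnp.where(C > 0, C, 0)
    return C


key1, key2 = jax.random.split(jax.random.PRNGKey(42))
A = jax.random.normal(key1, (50, 100))
B = jax.random.normal(key2, (100, 50))

jitted_func = jax.jit(func)

start = time.perf_counter()
first = jitted_func(A, B).block_until_ready()   # includes tracing and compilation
t_first = time.perf_counter() - start

start = time.perf_counter()
second = jitted_func(A, B).block_until_ready()  # reuses the compiled executable
t_second = time.perf_counter() - start

eager = func(A, B)
print(f"first call: {t_first:.4f} s, second call: {t_second:.4f} s")
print(jnp.allclose(first, eager, atol=1e-4))
```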

JAX is based on a **purely functional** approach, with no side effects. A pure function is a function where the output only depends on the inputs of the function.

Let's see what we mean by that by creating a simple jitted function:

In [None]:
@jax.jit
def my_func(x):
  print("Baguette")
  return 2*jnp.abs(x)


y = my_func(0) # The first time I execute it I get:

In [24]:
y = my_func(1) #  Second time: I see no print!

In [None]:
y = my_func(2.0) # Third time: The print is back???
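What happens in the three calls above is that the Python body of a jitted function only runs while JAX traces it, once per new combination of input shapes and dtypes. The sketch below makes this visible by replacing the `print` with an append to a list; `trace_log` is a hypothetical helper introduced here, not part of JAX.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records every time Python runs the body


@jax.jit
def my_func(x):
    trace_log.append(x.dtype)  # side effect: only happens during tracing
    return 2 * jnp.abs(x)


my_func(0)    # integer input: traced and compiled
my_func(1)    # same type and shape: cached version is reused, no tracing
my_func(2.0)  # float input: a new trace is needed
my_func(3.0)  # cached again

print(len(trace_log))
print(trace_log)
```

Side effects such as printing or appending to a list are therefore unreliable inside jitted code, which is why JAX asks for pure functions.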

You can apply `grad`, `jit` and `vmap` **with respect to objects** such as dictionaries!

In [None]:
def myfunc(cosmo):

  return 2* cosmo["sigma8"]**2 + 1

cosmo = {"sigma8": 0.8, "omega_m":0.3}

jax.jit(jax.grad(myfunc))(cosmo)
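The gradient comes back with the same structure as the input: one entry per key. A minimal sketch using the same function, where the expected values follow from $\partial (2\sigma_8^2 + 1) / \partial \sigma_8 = 4\sigma_8$:

```python
import jax


def myfunc(cosmo):
    return 2 * cosmo["sigma8"] ** 2 + 1


cosmo = {"sigma8": 0.8, "omega_m": 0.3}
grads = jax.grad(myfunc)(cosmo)

# grads is a dict with the same keys as cosmo
print(grads)
print(grads["sigma8"])   # 4 * 0.8
print(grads["omega_m"])  # myfunc does not depend on omega_m, so this is zero
```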

### Part 3: Matt

A very useful functionality of JAX is its automated vectorization with `jax.vmap`.
Here, let's take a function that won't trivially work if we just feed in a 2D array of inputs: a weighted mean function.

In [None]:
a = jnp.array([1.0, 4.0, 0.5])
b = jnp.arange(5, 10, dtype=jnp.float32)


def weighted_mean(a, b):
    output = []
    for idx in range(1, b.shape[0] - 1):
        output.append(jnp.mean(a + b[idx - 1 : idx + 2]))
    return jnp.array(output)


print(f"a shape: {a.shape}")
print(f"b shape: {b.shape}")
output = weighted_mean(a, b)
print(f"output: {output.shape}")
print(f"output: {output}")

We see this works as expected with both inputs `a` and `b` as arrays. But what if we wanted to compute this for a large set of 1D arrays? 

This is where we can use `jax.vmap` to vectorize this calculation without us making any alterations to the function itself! To do this we simply make a stack of $n \times d$ arrays for each input.

In [None]:
# Let's include the batch dim to the inputs
n = 10 # number of elements in the stack
stacked_a = jnp.stack([a] * n)
stacked_b = jnp.stack([b] * n)

# let's show that the input arrays can be different
for i in range(n):
    stacked_b = stacked_b.at[i].set(b + i)

print(f"stacked_a shape: {stacked_a.shape}")
print(f"stacked_b shape: {stacked_b.shape}")

If we try to use our original function on the stacked inputs, it raises an error:

In [None]:
try:
    output = weighted_mean(stacked_a, stacked_b)
except Exception as err:  # shapes (10, 3) and (3, 5) cannot be broadcast together
    print(f"Error: the function doesn't take this input ({type(err).__name__})")

But we can use `jax` to perform the vectorization behind the scenes, which lets us call our original function on the stacked dataset:

In [None]:
stacked_output = jax.vmap(weighted_mean)(stacked_a, stacked_b)
print(f"stacked output shape: {stacked_output.shape}")
print(f"stacked output:")
print(stacked_output)
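Two things are worth checking here: that `vmap` gives the same numbers as an explicit Python loop over the rows, and that an input shared by every row does not need to be stacked at all. The sketch below uses `in_axes=(None, 0)` to broadcast `a` while mapping over the rows of `stacked_b`; the batch size of 4 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def weighted_mean(a, b):
    output = []
    for idx in range(1, b.shape[0] - 1):
        output.append(jnp.mean(a + b[idx - 1 : idx + 2]))
    return jnp.array(output)


a = jnp.array([1.0, 4.0, 0.5])
b = jnp.arange(5, 10, dtype=jnp.float32)
stacked_b = jnp.stack([b + i for i in range(4)])  # shape (4, 5)

# in_axes=(None, 0): reuse the same `a` for every row, map over axis 0 of stacked_b
batched = jax.vmap(weighted_mean, in_axes=(None, 0))(a, stacked_b)

# Reference: plain Python loop over the rows
looped = jnp.stack([weighted_mean(a, row) for row in stacked_b])

print(batched.shape)
print(jnp.allclose(batched, looped))
```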

In [None]:
def plus_one(n, stacked_b, b):
    for i in range(n):
        stacked_b = stacked_b.at[i].set(b + i)
    return stacked_b

# n drives a Python `range`, so it must be the static argument (position 0)
jit_plus_one = jax.jit(plus_one, static_argnums=0)

In [None]:
# let's compare the runtimes
import time

n = [1000, 2000, 4000, 8000, 16000, 32000, 64000]
time_loop = []
time_vmap = []
for iters in n:
    start = time.time()
    for i in range(iters):
        # block so that asynchronous dispatch does not hide the compute time
        weighted_mean(a, b).block_until_ready()
    time_loop.append(time.time() - start)

    batch_size = iters
    batched_a = jnp.stack([a] * batch_size)
    batched_b = jnp.stack([b] * batch_size)
    # offset each row by its index (setup only, not part of the timing)
    batched_b = plus_one(batch_size, batched_b, b)
    start = time.time()
    jax.vmap(weighted_mean)(batched_a, batched_b).block_until_ready()
    time_vmap.append(time.time() - start)

In [None]:
fig = plt.figure(figsize=(12, 8))
plt.plot(n, time_loop, label="for loop", alpha=0.5, c="firebrick")
plt.scatter(n, time_loop, c="firebrick")
plt.plot(n, time_vmap, label="vmap", alpha=0.5, c="cornflowerblue")
plt.scatter(n, time_vmap, c="cornflowerblue")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("number of function calls", fontsize=20)
plt.ylabel("time (s)", fontsize=20)
plt.legend(fontsize=18)
plt.show();

As mentioned before, we can also `vmap` over objects such as dictionaries:

In [None]:
cosmo = {'sigma8': jnp.arange(5.), 'omega_m': jnp.arange(5.)}
x = 1.
def myfunc(dct, x):
  return 2*dct['sigma8']**2 + dct['omega_m'] + x
out = jax.vmap(myfunc, in_axes=({'sigma8': 0, 'omega_m': None}, None))(cosmo, x)
print(out)
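The `in_axes` argument mirrors the structure of the inputs: `0` means "map over the leading axis" and `None` means "pass the whole thing to every call". The sketch below contrasts the mixed case used above with the case where both dictionary entries are mapped; the names `out_mixed` and `out_both` are illustrative.

```python
import jax
import jax.numpy as jnp

cosmo = {"sigma8": jnp.arange(5.0), "omega_m": jnp.arange(5.0)}
x = 1.0


def myfunc(dct, x):
    return 2 * dct["sigma8"] ** 2 + dct["omega_m"] + x


# sigma8 is mapped, omega_m is passed whole: each call returns a length-5 vector
out_mixed = jax.vmap(myfunc, in_axes=({"sigma8": 0, "omega_m": None}, None))(cosmo, x)

# both entries are mapped together: each call returns a scalar
out_both = jax.vmap(myfunc, in_axes=({"sigma8": 0, "omega_m": 0}, None))(cosmo, x)

print(out_mixed.shape)  # one row per sigma8 value, one column per omega_m value
print(out_both)         # 2 * i**2 + i + 1 for i = 0..4
```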

This is great: it means that we are able to turn any function into a vectorized version without editing the underlying function itself. This works for scalar inputs, but also for vector and matrix inputs. Just add a new leading dimension to each object and stack them up!

What if we want more speed? Well, `jax` can also distribute tasks automatically across devices. Let us see what devices we have available:

In [None]:
jax.devices()

We can see what device our object is attached to currently

In [None]:
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()

We can use a cool tool to visualise the partitioning

In [None]:
jax.debug.visualize_array_sharding(arr)

As expected, all the data is on one device for now.

But here we can use `jax.sharding` to allocate different devices to the data:

In [None]:
from jax.sharding import PartitionSpec as P

n = jax.device_count()
print(f"Sharding over {n} devices")

# Split axis 0 across all n devices (the 4 rows of `arr` must divide evenly by n)
mesh = jax.make_mesh((n, 1), ("x", "y"))
sharding = jax.sharding.NamedSharding(mesh, P("x", "y"))
print(sharding)

Now let's see where our data is held:

In [None]:
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

Great!

Now you can use `jit` and let the XLA compiler in JAX handle the load management across devices:

In [None]:
@jax.jit
def f_contract(x):
    return x.sum(axis=0)


result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
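Sharding should change where the work happens, not the answer. The sketch below checks this by running the same jitted reduction on a sharded and an unsharded copy of the data. It makes two assumptions that differ from the cells above: the mesh is built directly with `jax.sharding.Mesh` from the device list rather than with `jax.make_mesh`, and the array is given `4 * n` rows so that it divides evenly for any device count.

```python
import numpy as onp

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()

# 4 * n rows, so axis 0 always divides evenly across n devices
arr = jnp.arange(4 * n * 8, dtype=jnp.float32).reshape(4 * n, 8)

# Mesh built directly from the device list
mesh = Mesh(onp.array(jax.devices()).reshape(n, 1), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))
arr_sharded = jax.device_put(arr, sharding)


@jax.jit
def f_contract(x):
    return x.sum(axis=0)


sharded_result = f_contract(arr_sharded)
single_result = f_contract(arr)

print(sharded_result.shape)
print(jnp.allclose(sharded_result, single_result))
```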