# Introduction to JAX

Written by Ben Moseley


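## What is JAX?

JAX is a Python library that pairs a NumPy-like array API with composable program transformations — automatic differentiation, vectorisation, just-in-time compilation and multi-device parallelisation — each of which is introduced in the sections below.

<img src="what-is-jax.png" width="80%" alt="What is JAX?">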

In [None]:
import os

# Ask XLA to expose the host CPU as 2 separate devices so the multi-device example at
# the end of the notebook can run without GPUs/TPUs. It is set before importing jax so
# it is in place when XLA initialises. Append to (rather than overwrite) any XLA_FLAGS
# the user may already have set in their environment.
os.environ["XLA_FLAGS"] = " ".join(
    flag
    for flag in (os.environ.get("XLA_FLAGS", ""), "--xla_force_host_platform_device_count=2")
    if flag
)

import jax
import matplotlib.pyplot as plt

# Arrays with JAX NumPy

In [None]:
import jax.numpy as jnp

# jax.numpy mirrors the NumPy API; the double brackets make x a 2-D (1, 3) row vector.
x = jnp.array([[0.0, 2.0, 4.0]])
print(x, x.shape)

# Matrix product: (1, 3) @ (3, 1) -> (1, 1), i.e. the squared norm of x.
print(x @ x.T)

# Elementwise product with NumPy-style broadcasting: (1, 3) * (3, 1) -> (3, 3) outer product.
print(x * x.T)

# Autodifferentiation with JAX

In [None]:
def fn(x):
    """Example function to differentiate: the hyperbolic tangent, applied elementwise."""
    return jnp.tanh(x)


# 500 evenly spaced sample points on [-5, 5]
x = jnp.linspace(-5, 5, 500)

fig, ax = plt.subplots()
ax.plot(x, fn(x), label="f(x)")
ax.legend()
ax.set_xlabel("x")
plt.show()

In [None]:
# gradient
# jax.grad(fn) returns a new function that computes df/dx for a scalar input x;
# applying grad twice gives the second derivative d^2f/dx^2.
dfdx_fn = jax.grad(fn)
d2fdx2_fn = jax.grad(jax.grad(fn))

# grad only differentiates scalar-output functions, so evaluate it over all sample
# points with jax.vmap rather than a Python loop of one call per point.
fig, ax = plt.subplots()
ax.plot(x, fn(x), label="f(x)")
ax.plot(x, jax.vmap(dfdx_fn)(x), label="df/dx")
ax.plot(x, jax.vmap(d2fdx2_fn)(x), label="d$^2$f/dx$^2$")
ax.legend()
ax.set_xlabel("x")
plt.show()

In [None]:
# JAX transforms programs by first tracing them into a simple intermediate language
# called "jaxpr"; printing one shows the primitive operations fn(x) is built from.
print(jax.make_jaxpr(fn)(x))

In [None]:
# Jacobian
# jacfwd builds the full Jacobian using forward-mode autodiff. Because fn acts
# elementwise on the 500 sample points, the (500, 500) result is a diagonal matrix.
jacobian_fn = jax.jacfwd(fn)
j = jacobian_fn(x)
print(j, j.shape, sep="\n")

In [None]:
def _plot_with_derivative(xs, f, dfdx):
    """Plot f(x) and df/dx against the sample points xs."""
    fig, ax = plt.subplots()
    ax.plot(xs, f, label="f(x)")
    ax.plot(xs, dfdx, label="df/dx")
    ax.legend()
    ax.set_xlabel("x")
    plt.show()


# vector-Jacobian product (reverse mode): pulls a cotangent vector back through fn.
# An all-ones cotangent yields the column sums of the Jacobian, which equal df/dx
# here only because fn is elementwise (its Jacobian is diagonal).
f, vjp_fn = jax.vjp(fn, x)
dfdx, = vjp_fn(jnp.ones_like(x))
_plot_with_derivative(x, f, dfdx)

# Jacobian-vector product (forward mode): pushes a tangent vector forward through fn.
# An all-ones tangent yields the row sums of the Jacobian -- again df/dx for elementwise fn.
f, dfdx = jax.jvp(fn, (x,), (jnp.ones_like(x),))
_plot_with_derivative(x, f, dfdx)

# Vectorisation with JAX

In [None]:
def forward_fn(w, b, x):
    """Dense layer with tanh activation, written for a single input vector: tanh(w @ x + b)."""
    return jnp.tanh(w @ x + b)


# JAX randomness is explicit: split one root key into independent subkeys.
key = jax.random.key(seed=0)
key1, key2, key3 = jax.random.split(key, 3)
x = jax.random.normal(key1, (3,))      # one input vector
w = jax.random.normal(key2, (10, 3))   # weights mapping 3 inputs -> 10 outputs
b = jax.random.normal(key3, (10,))     # biases
y = forward_fn(w, b, x)
print(x.shape)
print(y.shape)

In [None]:
# Vectorise forward_fn over a batch of inputs: in_axes=(None, None, 0) shares w and b
# across the batch and maps over the leading axis of x, so forward_fn itself is unchanged.
forward_batch_fn = jax.vmap(forward_fn, in_axes=(None, None, 0))

x_batch = jax.random.normal(key, (1000,3))
y_batch = forward_batch_fn(w, b, x_batch)
print(x_batch.shape)
print(y_batch.shape)

# Just-in-time compilation with JAX

In [None]:
def fn(x):
    return x + x*x + x*x*x

jit_fn = jax.jit(fn)

x = jax.random.normal(key, (1000,1000))
%timeit fn(x).block_until_ready()
%timeit jit_fn(x).block_until_ready()

# Putting it all together: linear regression

In [None]:
# Synthetic 1-D training data: y = 5x + 1 plus standard Gaussian noise.
x_batch = jnp.linspace(0, 1, 100).reshape((100, 1))  # 100 inputs in [0, 1], shape (100, 1)
y_label_batch = 5*x_batch + 1 + jax.random.normal(key, (100, 1))

fig, ax = plt.subplots()
ax.scatter(x_batch, y_label_batch, label="training data")
ax.legend()
ax.set(xlabel="x", ylabel="y")
plt.show()

In [None]:
def init():
    "Returns initial model parameters"
    # theta = (w, b): a (1, 1) weight matrix and a (1,) bias, both initialised to zero
    w = jnp.array(0.).reshape((1,1))
    b = jnp.array(0.).reshape((1,))
    theta = (w,b)
    return theta

def forward(theta, x):
    "Returns model prediction, for a single example input"
    # as used here, x is one row of x_batch, i.e. shape (1,), giving an output of shape (1,)
    w, b = theta
    x = w @ x + b
    return x

# batched version of forward: theta is shared, x is mapped over its leading axis
forward_batch = jax.vmap(forward, in_axes=(None, 0))

def loss(theta, x_batch, y_label_batch):
    "Computes mean squared error between model prediction and training data"
    y_batch = forward_batch(theta, x_batch)
    return jnp.mean((y_batch-y_label_batch)**2)

# gradient of loss wrt model parameters (argnums=0 -> theta).
# NOTE: despite the name, this returns BOTH values: (loss, dloss/dtheta).
grad = jax.value_and_grad(loss, argnums=0)

def step(lrate, theta, x_batch, y_label_batch):
    "Performs one gradient descent step on model parameters, given training data"
    lossval, dtheta = grad(theta, x_batch, y_label_batch)
    # theta and dtheta are matching pytrees (tuples of arrays): update every leaf
    theta = jax.tree_util.tree_map(lambda p, g: p - lrate * g, theta, dtheta)
    return theta, lossval

jit_step = jax.jit(step)# makes step go brr: the whole update compiles to one XLA computation


# initialise model parameters
theta = init()

# run gradient descent
for _ in range(1000):
    theta, lossval = jit_step(0.1, theta, x_batch, y_label_batch)

fig, ax = plt.subplots()
ax.scatter(x_batch, y_label_batch, label="training data")
ax.plot(x_batch, forward_batch(theta, x_batch), color="tab:orange", lw=3, label="model prediction")
ax.legend()
ax.set(xlabel="x", ylabel="y")
plt.show()
print(theta)
print(f"final training loss: {float(lossval):.4f}")

# Extra: multi-device parallelisation with JAX

In [None]:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

print(jax.devices())

x = jax.random.normal(key, (8192, 8192))

# Arrange every available device in a 1-D mesh with one named axis, then split the
# array's first dimension across that axis (second dimension left unsplit).
# (Replaces the deprecated PositionalSharding API and no longer assumes exactly 2 devices;
# 8192 must be divisible by the device count.)
n_devices = len(jax.devices())
mesh = Mesh(mesh_utils.create_device_mesh((n_devices,)), axis_names=("devices",))
sharding = NamedSharding(mesh, P("devices", None))
x = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(x)# shards array across first dimension

y = x**2
jax.debug.visualize_array_sharding(y)# "computation follows sharding" paradigm

y = jnp.mean(x**2, axis=0, keepdims=True)# compiler also inserts communication as necessary!
jax.debug.visualize_array_sharding(y)# result is replicated across devices