In [None]:
import jax
jax.config.update('jax_array', True)  # required for jax<=0.4.0

### `jax.numpy` on TPU (or GPU, or CPU)

In [None]:
import jax.numpy as jnp
from jax import random

In [None]:
x = random.normal(random.PRNGKey(0), (8192, 8192))
x

In [None]:
print(x.shape)
print(x.dtype)

In [None]:
y = jnp.dot(x, jnp.cos(x.T))
z = y[[0, 2, 1, 0], ..., ::-1, None]
print(z[:3, :3])

In [None]:
%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()

In [None]:
import numpy as np
x_cpu = np.array(x)

In [None]:
%timeit -n 1 -r 2 np.dot(x_cpu[:2048, :2048], x_cpu[:2048, :2048])

### Automatic differentiation

In [None]:
from jax import grad

In [None]:
def f(x):
    """Piecewise demo function: ``2 * x**3`` for positive ``x``, else ``3 * x``.

    The branch is an ordinary Python conditional on the value of ``x``.
    """
    return 2 * x ** 3 if x > 0 else 3 * x

In [None]:
x = -3.14

print(grad(f)(x))
print(grad(f)(-x))

In [None]:
print(grad(grad(grad(f)))(-x))

In [None]:
def predict(params, inputs):
    """Run a fully connected ReLU network and return the last layer's output.

    Args:
        params: non-empty sequence of ``(W, b)`` pairs, one per layer.
        inputs: array fed to the first layer via ``jnp.dot(inputs, W) + b``.

    Returns:
        The affine output of the final layer (no activation is applied to it).

    Raises:
        ValueError: if ``params`` contains no layers.
    """
    if not params:
        raise ValueError("params must contain at least one (W, b) layer")

    # Only the layers before the last one are followed by a ReLU; splitting
    # the last layer off avoids computing an activation that is never used.
    *hidden_layers, (W_out, b_out) = params
    activations = inputs
    for W, b in hidden_layers:
        activations = jnp.maximum(jnp.dot(activations, W) + b, 0)
    return jnp.dot(activations, W_out) + b_out

def loss(params, batch):
    """Sum of squared errors between the network output and the targets.

    ``batch`` is an ``(inputs, targets)`` pair; the squared residuals are
    summed over every element.
    """
    x, y_true = batch
    residual = predict(params, x) - y_true
    return jnp.sum(residual ** 2)

In [None]:
def init_layer(key, n_in, n_out):
    """Sample the weights and bias of one dense layer.

    The key is split in two: one half draws a normal ``(n_in, n_out)`` weight
    matrix that is divided by ``sqrt(n_in)``, the other draws a normal
    ``(n_out,)`` bias vector.

    Returns:
        A ``(W, b)`` tuple.
    """
    w_key, b_key = random.split(key)
    fan_in_scale = jnp.sqrt(n_in)
    weights = random.normal(w_key, (n_in, n_out)) / fan_in_scale
    bias = random.normal(b_key, (n_out,))
    return weights, bias

def init_model(key, layer_sizes, batch_size):
    """Build random parameters and one random ``(inputs, targets)`` batch.

    Args:
        key: PRNG key; split so that each layer and each batch array gets
            its own subkey.
        layer_sizes: widths of consecutive layers; each adjacent pair
            defines one ``(W, b)`` layer.
        batch_size: number of rows in the sampled inputs and targets.

    Returns:
        ``(params, (inputs, targets))`` where ``params`` is a list of
        ``(W, b)`` tuples, ``inputs`` has ``layer_sizes[0]`` columns and
        ``targets`` has ``layer_sizes[-1]`` columns.
    """
    key, *layer_keys = random.split(key, len(layer_sizes))
    fan_ins, fan_outs = layer_sizes[:-1], layer_sizes[1:]
    params = [init_layer(k, n_in, n_out)
              for k, n_in, n_out in zip(layer_keys, fan_ins, fan_outs)]

    # Three-way split; the first subkey is not used.
    _, input_key, target_key = random.split(key, 3)
    inputs = random.normal(input_key, (batch_size, layer_sizes[0]))
    targets = random.normal(target_key, (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)

layer_sizes = [784, 2048, 2048, 2048, 10]
batch_size = 128

params, batch = init_model(random.PRNGKey(0), layer_sizes, batch_size)

In [None]:
print(loss(params, batch))

In [None]:
# Learning rate for the plain gradient-descent updates below.
step_size = 1e-5

# 30 full-batch gradient-descent steps. `grad(loss)` differentiates with
# respect to its first argument (`params`), so `grads` has the same
# list-of-(W, b) structure as `params` and can be zipped with it.
for _ in range(30):
    grads = grad(loss)(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

In [None]:
print(loss(params, batch))

Lots more autodiff...
* forward- and reverse-mode, totally composable
* fast Jacobians and Hessians
* complex number support (holomorphic and non-holomorphic)
* exponentially-faster very-high-order autodiff
* precise control over stored intermediate values

### End-to-end optimized compilation with `jit`

In [None]:
from jax import jit

In [None]:
loss_jit = jit(loss)

In [None]:
print(loss_jit(params, batch))

In [None]:
%timeit -n 5 -r 5 loss(params, batch).block_until_ready()

In [None]:
%timeit -n 5 -r 5 loss_jit(params, batch).block_until_ready()

In [None]:
# Wrap the gradient function in `jit` once, outside the loop, so every
# iteration calls the same jitted function.
gradfun = jit(grad(loss))

# Same gradient-descent update as the un-jitted loop, 30 more steps.
for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
    
print(loss_jit(params, batch))

Limitations with jit:
* value-dependent Python control flow is disallowed; use e.g. `lax.cond` or `lax.scan` instead
* must be functionally pure, **like all JAX code**

### Batching with `vmap`

In [None]:
from jax import vmap

In [None]:
def l1_distance(x, y):
    """Return the L1 (sum of absolute differences) distance of two vectors.

    Written for a single pair of 1-D arrays; the assertion guards that.
    """
    assert x.ndim == y.ndim == 1
    abs_diff = jnp.abs(x - y)
    return jnp.sum(abs_diff)

In [None]:
xs = random.normal(random.PRNGKey(0), (20, 3))
ys = random.normal(random.PRNGKey(1), (20, 3))

In [None]:
dists = jnp.stack([l1_distance(x, y) for x, y in zip(xs, ys)])
print(dists)

In [None]:
dists = vmap(l1_distance)(xs, ys)
print(dists)

In [None]:
from jax import make_jaxpr
make_jaxpr(l1_distance)(xs[0], ys[0])

In [None]:
make_jaxpr(vmap(l1_distance))(xs, ys)

In [None]:
def pairwise_distances(xs, ys):
    """Compute the L1 distance between every row of ``xs`` and every row of ``ys``.

    Two nested ``vmap``s are used: the inner one maps over rows of ``xs``
    with the second argument held fixed, the outer one maps over rows of
    ``ys`` with the first argument held fixed. The outer mapped axis comes
    first in the result, so entry ``[j, i]`` is the distance between
    ``xs[i]`` and ``ys[j]``.
    """
    over_xs = vmap(l1_distance, (0, None))
    over_xs_and_ys = vmap(over_xs, (None, 0))
    return over_xs_and_ys(xs, ys)

In [None]:
make_jaxpr(pairwise_distances)(xs, ys)

In [None]:
perexample_grads = jit(vmap(grad(loss), in_axes=(None, 0)))

In [None]:
(dW, db), *_ = perexample_grads(params, batch)
dW.shape

Use `vmap` to plumb batch dimensions through anything: vectorize your code, library code, autodiff-generated code...

### Explicit SPMD parallelism with `pmap`

In [None]:
from jax import pmap

In [None]:
jax.devices()

In [None]:
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (8192, 8192)))(keys)
mats.shape

In [None]:
result = pmap(jnp.dot)(mats, mats)
print(result.shape)

In [None]:
timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()

In [None]:
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def allreduce_sum(x):
    """Sum ``x`` across the ``pmap``-mapped axis named ``'i'``."""
    summed = lax.psum(x, axis_name='i')
    return summed

allreduce_sum(jnp.ones(8))

### **NEW**: Implicit parallelism with `jit`!

In [None]:
import jax

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))

In [None]:
jax.debug.visualize_array_sharding(x)

Sharding an array across multiple devices:

In [None]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)

In [None]:
x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)

A sharding is an array of sets of devices:

In [None]:
sharding

Shardings can be reshaped, just like arrays:

In [None]:
sharding.shape

In [None]:
sharding.reshape(8, 1)

In [None]:
sharding.reshape(4, 2)

An array `x` can be sharded with a sharding if the sharding is _congruent_ with `x.shape`, meaning the sharding's shape has the same length as `x.shape` and each element of the sharding's shape evenly divides the corresponding element of `x.shape`.

For example:

In [None]:
sharding = sharding.reshape(4, 2)
print(sharding)

In [None]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

Different `sharding`s result in different distributed layouts:

In [None]:
sharding = sharding.reshape(1, 8)
print(sharding)

In [None]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

Sometimes we might want to _replicate_ some slices, so that several devices each hold a copy of the same data.

We can express replication by calling the sharding reducer method `replicate`:

In [None]:
sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))

In [None]:
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)

The `replicate` method acts similarly to the familiar NumPy array reduction methods like `.sum()` and `.prod()`.

In [None]:
print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape)

In [None]:
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)

## Computation follows sharding

JAX uses a computation-follows-data layout policy, which extends to shardings:

In [None]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

x = jax.device_put(x, sharding.reshape(4, 2))
print('Input sharding:')
jax.debug.visualize_array_sharding(x)

In [None]:
y = jnp.sin(x)

In [None]:
print('Output sharding:')
jax.debug.visualize_array_sharding(y)

For an elementwise operation like `jnp.sin`, the compiler avoids communication and chooses the output sharding to be the same as the input.

A richer example:

In [None]:
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('LHS sharding:')
jax.debug.visualize_array_sharding(y)
print('RHS sharding:')
jax.debug.visualize_array_sharding(z)

In [None]:
w = jnp.dot(y, z)

In [None]:
print('Output sharding:')
jax.debug.visualize_array_sharding(w)

The compiler chose an output sharding that maximally parallelizes the computation and avoids communication.

How can we be sure it's actually running in parallel? We can do a simple timing experiment:

In [None]:
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)

In [None]:
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))

In [None]:
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()

In [None]:
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()

## Examples: neural networks

We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

In [None]:
def predict(params, inputs):
  """Run a fully connected ReLU network and return the last layer's output.

  Args:
    params: non-empty sequence of ``(W, b)`` pairs, one per layer.
    inputs: array fed to the first layer via ``jnp.dot(inputs, W) + b``.

  Returns:
    The affine output of the final layer (no activation is applied to it).

  Raises:
    ValueError: if ``params`` contains no layers.
  """
  if not params:
    raise ValueError("params must contain at least one (W, b) layer")

  # Only the layers before the last one are followed by a ReLU; splitting
  # the last layer off avoids computing an activation that is never used.
  *hidden_layers, (W_out, b_out) = params
  activations = inputs
  for W, b in hidden_layers:
    activations = jnp.maximum(jnp.dot(activations, W) + b, 0)
  return jnp.dot(activations, W_out) + b_out

def loss(params, batch):
  """Squared error summed over the last axis, then averaged.

  ``batch`` is an ``(inputs, targets)`` pair. The squared residuals are
  summed along the final axis and the mean of those sums is returned.
  """
  x, y_true = batch
  squared_error = (predict(params, x) - y_true) ** 2
  per_example = jnp.sum(squared_error, axis=-1)
  return jnp.mean(per_example)

In [None]:
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))

In [None]:
def init_layer(key, n_in, n_out):
    """Sample the weights and bias of one dense layer.

    The key is split in two: one half draws a normal ``(n_in, n_out)`` weight
    matrix that is divided by ``sqrt(n_in)``, the other draws a normal
    ``(n_out,)`` bias vector.

    Returns:
        A ``(W, b)`` tuple.
    """
    w_key, b_key = jax.random.split(key)
    fan_in_scale = jnp.sqrt(n_in)
    weights = jax.random.normal(w_key, (n_in, n_out)) / fan_in_scale
    bias = jax.random.normal(b_key, (n_out,))
    return weights, bias

def init_model(key, layer_sizes, batch_size):
    """Build random parameters and one random ``(inputs, targets)`` batch.

    Args:
        key: PRNG key; split so that each layer and each batch array gets
            its own subkey.
        layer_sizes: widths of consecutive layers; each adjacent pair
            defines one ``(W, b)`` layer.
        batch_size: number of rows in the sampled inputs and targets.

    Returns:
        ``(params, (inputs, targets))`` where ``params`` is a list of
        ``(W, b)`` tuples, ``inputs`` has ``layer_sizes[0]`` columns and
        ``targets`` has ``layer_sizes[-1]`` columns.
    """
    key, *layer_keys = jax.random.split(key, len(layer_sizes))
    fan_ins, fan_outs = layer_sizes[:-1], layer_sizes[1:]
    params = [init_layer(k, n_in, n_out)
              for k, n_in, n_out in zip(layer_keys, fan_ins, fan_outs)]

    # Three-way split; the first subkey is not used.
    _, input_key, target_key = jax.random.split(key, 3)
    inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
    targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)

### 8-way batch data parallelism

In [None]:
sharding = PositionalSharding(jax.devices()).reshape(8, 1)

In [None]:
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())

In [None]:
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(params[0][0])

In [None]:
loss_jit(params, batch)

In [None]:
# Learning rate for the gradient-descent updates below.
step_size = 1e-5

# 30 gradient-descent steps through the jitted `gradfun`. The model code is
# unchanged; `params` and `batch` are simply passed in as they were placed
# with `jax.device_put`.
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))

In [None]:
jax.debug.visualize_array_sharding(params[0][0])

In [None]:
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()

In [None]:
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])

In [None]:
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()

### 4-way batch (data) parallelism and 2-way model (weight) parallelism

In [None]:
sharding = sharding.reshape(4, 2)

In [None]:
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])

In [None]:
# Reset only the parameters (they were trained in the previous section).
# `batch` is intentionally not rebound: it was just placed with
# `sharding.replicate(1)`, and overwriting it with a freshly initialized
# batch would discard that placement. `init_model` is deterministic in its
# key, so the discarded batch holds the same values as the one in use.
params, _ = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)

(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# First layer: fully replicated.
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

# Second layer: replicated over the sharding's first axis, split over its
# second axis.
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

# Third layer: the transposed version of the second layer's sharding for the
# weights, fully replicated bias.
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

# Last layer: fully replicated.
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

In [None]:
jax.debug.visualize_array_sharding(W2)

In [None]:
jax.debug.visualize_array_sharding(W3)

In [None]:
print(loss_jit(params, batch))

In [None]:
# Learning rate for the gradient-descent updates below.
step_size = 1e-5

# 30 gradient-descent steps through the jitted `gradfun`, using the
# per-layer placements chosen for `params` in the preceding cell. After the
# first step `params` is a list rather than a tuple; its (W, b) structure is
# unchanged.
for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

In [None]:
print(loss_jit(params, batch))

In [None]:
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)

In [None]:
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()

We didn't change our model code at all! Write your code for one device, run it on _N_...

Compose with `grad`, `vmap`, `jit`, ...