Parallel Evaluation in JAX
https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html


In [None]:
import jax
import jax.numpy as jnp
from typing import List, Set, Dict

In [None]:
# Show the devices JAX can see; pmap below distributes work over the local ones.
jax.devices()

In [None]:
# A small 1-D "image" and a 3-tap kernel for the single-device example.
# Annotated with jnp.ndarray: jnp.DeviceArray no longer exists in current JAX,
# and module-level annotations are evaluated at runtime.
img: jnp.ndarray = jnp.arange(5)
kernel: jnp.ndarray = jnp.array([0.1, 0.9, 0.1])

def convolve(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    ) -> jnp.ndarray:
    """ Valid-mode 1-D convolution of x with kernel.

    Slides `kernel` along `x` and takes a dot product at every position where
    it fully overlaps `x`. The result therefore has len(x) - len(kernel) + 1
    entries, and is empty if `x` is shorter than `kernel`.

    The kernel is not flipped, so strictly this is a cross-correlation. For
    symmetric kernels such as [0.1, 0.9, 0.1] the two coincide.

    Args:
        x: input signal (assumed 1-D).
        kernel: kernel of any length (assumed 1-D).

    Returns:
        1-D array of the windowed dot products.
    """
    width = len(kernel)
    # Python-level loop over static positions. Under vmap/pmap, len() of a
    # traced array is still a static Python int, so this traces cleanly.
    return jnp.array(
        [jnp.dot(x[i:i + width], kernel) for i in range(len(x) - width + 1)]
    )

# Single-device reference result: with the 3-tap kernel, the valid-mode
# output has len(img) - 2 entries.
convolve(img, kernel)

In [None]:
# Replicate the single-example inputs once per local device, giving them a
# leading axis of size num_devices that vmap/pmap can map over.
# (jnp.ndarray annotations: jnp.DeviceArray no longer exists in current JAX.)
num_devices = jax.local_device_count()
img_p: jnp.ndarray = jnp.stack([img] * num_devices)
kernel_p: jnp.ndarray = jnp.stack([kernel] * num_devices)
print(f'img_p {img_p}')
print(f'kernel_p {kernel_p}')
print(f'img {img}')
print(f'kernel {kernel}')

In [None]:
# pmap is comparable to vmap because both transformations map a function over
# array axes. However, where vmap vectorizes a function by pushing the mapped
# axis down into its primitive operations, pmap instead replicates the function
# and executes each replica on its own XLA device in parallel.
# pmap's mapped (leading) axis may not exceed the number of local devices;
# img_p/kernel_p were stacked to exactly num_devices rows.
# In a notebook only the last expression is displayed, so print the vmap result
# explicitly; otherwise it is computed and silently discarded.
print(jax.vmap(convolve)(img_p, kernel_p))
jax.pmap(convolve)(img_p, kernel_p)

In [None]:
import functools
from typing import Tuple, NamedTuple

# Step size for plain gradient descent on the linear-regression parameters.
LEARNING_RATE: float = 0.005


class Params(NamedTuple):
    """Parameters of the 1-D linear model y = w * x + b."""

    w: jnp.ndarray  # slope
    b: jnp.ndarray  # intercept

def init(rng) -> Params:
    """ Draw initial parameters for Linear Regression.

    `rng` is split into two sub-keys. The slope `w` and the intercept `b` are
    each a scalar sample from a standard normal, drawn with their own sub-key.
    """
    keys = jax.random.split(rng)
    w_init = jax.random.normal(keys[0], ())
    b_init = jax.random.normal(keys[1], ())
    return Params(w=w_init, b=b_init)

def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """ Mean squared error of the linear model `params.w * xs + params.b` against `ys`. """
    residuals = (params.w * xs + params.b) - ys
    return jnp.mean(residuals ** 2)

# "Name" the axis on which we want to parallel-map across devices, so that
# collectives such as jax.lax.pmean inside the function can refer to it.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    """ One data-parallel gradient descent step.

    Args:
        params: this device's replica of the model parameters.
        xs, ys: this device's shard of the training inputs / targets.

    Returns:
        (new_params, loss): the updated parameters, and the training loss
        averaged over all devices on the 'num_devices' axis.
    """

    # Performed on each device individually, on its local shard of data
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

    # Performed across devices: average gradients (and the loss, for logging)
    mean_grads = jax.lax.pmean(grads, axis_name='num_devices')
    mean_loss = jax.lax.pmean(loss, axis_name='num_devices')

    # Performed on each device individually. Applying the *averaged* gradients
    # keeps every replica's parameters identical. Using the local `grads` here
    # would let the replicas drift apart and make the pmean above useless.
    new_params = jax.tree_util.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, mean_grads
    )

    return new_params, mean_loss

In [None]:
# Create fake noisy linear data: y = w * x + b + noise
rng = jax.random.PRNGKey(42)
true_params: Params = Params(w=2, b=-1)
# Draw xs and noise from independent sub-keys. Reusing one key yields identical
# samples, which would make noise == 0.5 * xs exactly (data y = 2.5 * x - 1).
x_key, noise_key = jax.random.split(rng)
xs: jnp.ndarray = jax.random.normal(x_key, (128, 1))
noise: jnp.ndarray = 0.5 * jax.random.normal(noise_key, (128, 1))
ys: jnp.ndarray = xs * true_params.w + noise + true_params.b

# Initialize our Linear Regression params, replicated across devices
rng = jax.random.PRNGKey(1)
params: Params = init(rng)
num_devices = jax.local_device_count()
# Give every leaf a leading axis of size num_devices (one copy per device),
# the layout pmap expects for its mapped arguments.
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.array([x] * num_devices),
    params,
)

In [None]:
def split(arr: jnp.ndarray, num_shards=None) -> jnp.ndarray:
    """ Split the first axis of arr evenly across devices.

    Args:
        arr: array whose leading axis will be sharded.
        num_shards: number of shards (one per device). Defaults to the
            module-level ``num_devices``, looked up at call time.

    Returns:
        arr reshaped to (num_shards, arr.shape[0] // num_shards, *arr.shape[1:]),
        i.e. with a new leading device axis as expected by pmap.

    Raises:
        ValueError: if num_shards is not positive or does not evenly divide
            arr.shape[0]. Without this check, reshape fails with a less
            helpful error.
    """
    if num_shards is None:
        num_shards = num_devices
    if num_shards <= 0 or arr.shape[0] % num_shards != 0:
        raise ValueError(
            f'cannot split leading axis of size {arr.shape[0]} '
            f'evenly into {num_shards} shards'
        )
    return arr.reshape(num_shards, arr.shape[0] // num_shards, *arr.shape[1:])

# Shard the training data into num_devices chunks (a new leading device axis),
# then report the shapes before and after sharding, one per line.
x_split = split(xs)
y_split = split(ys)
print(
    f'xs {xs.shape}',
    f'x_split {x_split.shape}',
    f'ys {ys.shape}',
    f'y_split {y_split.shape}',
    sep='\n',
)

In [None]:
NUM_EPOCHS = 1000

print(' Starting training loop')
for epoch in range(NUM_EPOCHS):
    # Every argument carries a leading device axis; pmap runs one step per device.
    replicated_params, loss = update(replicated_params, x_split, y_split)
    if epoch % 100 == 0:
        # loss has a leading device axis; index 0 reads the first replica's value.
        print(f'\t epoch {epoch:3d} loss {loss[0]:.3f}')

# Take the parameter copy held by device 0 and move it to host memory.
# Replicas stay identical only if update() applies cross-device averaged gradients.
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))
print(f' true params y = {true_params.w:.3f} * x + {true_params.b:.3f} ')
print(f' pred params y = {params.w:.3f} * x + {params.b:.3f} ')