In [1]:
import os
# Ask XLA to expose 8 virtual host (CPU) devices so the multi-device examples
# below can run on a single machine. This is set before `import jax` so the
# flag is already in the environment when JAX is loaded.
os.environ.update(XLA_FLAGS="--xla_force_host_platform_device_count=8")

## Basic Data Parallel Training Example

In [10]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import value_and_grad, jit

def init_params(in_dim=2, out_dim=2):
    """Create the initial parameters of a single linear layer.

    Args:
        in_dim: Number of input features (rows of ``W``). Defaults to 2.
        out_dim: Number of output features (columns of ``W`` and length of
            ``b``). Defaults to 2.

    Returns:
        A ``(W, b)`` tuple: ``W`` is an ``(in_dim, out_dim)`` array of ones
        and ``b`` is an ``(out_dim,)`` array of zeros.
    """
    W = jnp.ones((in_dim, out_dim))
    b = jnp.zeros((out_dim,))
    return W, b

def forward(params, x):
    """Apply the affine map ``x @ W + b``.

    Args:
        params: A ``(W, b)`` pair holding the weight matrix and bias vector.
        x: Input array whose last dimension is matrix-multiplied with ``W``.

    Returns:
        The result of ``x @ W + b``.
    """
    weight, bias = params
    projected = x @ weight
    return projected + bias

def loss_fn(params, x, y):
    """Mean squared error between the model output for ``x`` and ``y``.

    Args:
        params: Parameters passed straight through to ``forward``.
        x: Model inputs.
        y: Targets, subtracted element-wise from the prediction.

    Returns:
        The mean of the squared residuals over all elements.
    """
    residual = forward(params, x) - y
    return jnp.mean(residual ** 2)

@partial(jax.pmap, axis_name="i")
def train_step(params, x, y):
    """Run one data-parallel SGD step; executed once per device under ``pmap``.

    Args:
        params: This device's replica of the model parameters.
        x: This device's shard of the inputs.
        y: This device's shard of the targets.

    Returns:
        ``(new_params, loss)``. ``new_params`` are the parameters after one
        SGD step using the gradients averaged over the ``"i"`` axis, so
        replicas that started identical stay identical. ``loss`` is the
        local, per-device loss; it is NOT averaged across devices.
    """
    learning_rate = 0.01

    # Per-device loss and gradients, computed on the local shard only.
    loss, grads = value_and_grad(loss_fn)(params, x, y)

    # Average (pmean) the gradients over the "i" device axis so that every
    # replica applies the same update.
    # `jax.tree_util.tree_map` is used instead of the deprecated `jax.tree_map`.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.pmean(g, axis_name="i"), grads
    )

    # Plain SGD update applied leaf-by-leaf to the parameter pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss

In [11]:
from jax.debug import visualize_array_sharding as viz

# Devices that pmap will spread the batch across
devices = jax.local_devices()
n_devices = len(devices)

B, C = 8, 2

# pmap maps over the leading axis, so the global batch has to split evenly
# into one shard per device.
if B % n_devices != 0:
    raise ValueError(
        f"Batch size {B} is not divisible by the number of devices ({n_devices})"
    )

key = jax.random.key(0)
x = jax.random.normal(key, (B, C))
y = jax.random.normal(jax.random.key(1), (B, C))

# Shard the data: (B, C) -> (n_devices, B // n_devices, C). Each device then
# receives a properly batched (B // n_devices, C) slice rather than relying on
# B happening to equal the device count.
x_sharded = x.reshape(n_devices, B // n_devices, C)
y_sharded = y.reshape(n_devices, B // n_devices, C)

# Replicate model params across devices (adds a leading n_devices axis)
params = init_params()
replicated_params = jax.device_put_replicated(params, devices)
# Optional: inspect the placement of one replica, e.g. viz(replicated_params[0][0])

# Perform one training step; `loss` holds one local loss value per device
new_params, loss = train_step(replicated_params, x_sharded, y_sharded)

## Basic All-to-All Example