# 06 Distributed Computing

Original Documentation: https://docs.jax.dev/en/latest/sharded-computation.html


In [2]:
import jax
import jax.numpy as jnp

Single-program, multiple-data (SPMD) is a parallelism technique in which the same computation (e.g., the forward pass of a neural network) runs in parallel on different input data (e.g., different slices of a batch) on different devices.

JAX has three SPMD strategies:

- Automatic sharding in JIT contexts.
- Explicit sharding, which is similar to automatic sharding but makes the sharding of an array part of the JAX type system, so it can be queried at trace time.
  - The compiler still decides how to apply SPMD, but is constrained by user-supplied shardings.
- Fully manual sharding with per-device code and explicit communication collectives.


In [3]:
# Tell JAX to create 8 logical CPUs, corresponding to 8 OS threads.
jax.config.update("jax_num_cpu_devices", 8)
print(jax.devices())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


## Data sharding

`jax.Array` represents an array in physical memory spanning one or more devices. The type is designed with distributed data and computation in mind.

Every `jax.Array` has a `jax.sharding.Sharding` object, which describes which shard of the global data is required by each device.

In the simplest case, an array lives entirely on a single device:


In [4]:
a = jnp.arange(32.0).reshape(4, 8)
print(a.devices())
print(a.sharding)

{CpuDevice(id=0)}
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


We can use `jax.debug.visualize_array_sharding()` to visualize how `a` is stored in the memory of a single device:


In [6]:
jax.debug.visualize_array_sharding(a)

To create an array with a non-trivial sharding, we can define a `jax.sharding` specification and pass it, together with the array, to `jax.device_put()`:


In [7]:
# Make a (2, 4) grid of devices and label X-axis and Y-axis.
mesh = jax.make_mesh((2, 4), ("x", "y"))

# Create a specification to shard an array across the mesh.
# P("x", "y") means split the first dimension of an array across
# X-axis and second dimension across Y-axis.
sharding = jax.sharding.NamedSharding(mesh, jax.P("x", "y"))

# Shard the array across the mesh. It will be split into 2 chunks
# across the X-axis and 4 chunks across the Y-axis.
a = jnp.arange(32.0).reshape(4, 8)
a_sharded = jax.device_put(a, sharding)

jax.debug.visualize_array_sharding(a_sharded)

## Automatic parallelism in JIT

Once we have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function.

The XLA compiler includes heuristics for optimizing computations across multiple devices. We just need to specify how our input and output data are sharded.


In [8]:
@jax.jit
def f_elementwise(x):
    """Apply 2 * sin(x) + 1 to an array in element-wise fashion"""
    return 2 * jnp.sin(x) + 1


result = f_elementwise(a_sharded)
print("Shardings match:", a_sharded.sharding == result.sharding)

Shardings match: True


Here is another example:


In [9]:
@jax.jit
def f_contract(x):
    return jnp.sum(x, axis=0)


result = f_contract(a_sharded)
jax.debug.visualize_array_sharding(result)

The result is partially replicated: the first two elements are held by devices 0 and 4, the next two elements by devices 1 and 5, and so on.


## Explicit sharding

The key idea is that the JAX-level type of a value includes a description of how the data is sharded.

The JAX-level type is effectively the information we have access to in a JIT function.

We can query the JAX-level type of any JAX value, NumPy array, or Python scalar with `jax.typeof()`:


In [10]:
a = jnp.arange(8)

# Use jax.typeof() to query the JAX-level type
print("JAX-level type:", jax.typeof(a))


@jax.jit
def f(a):
    # jax.typeof() also works in JIT contexts
    print("JAX-level type:", jax.typeof(a))
    return 2 * a


f(a)

JAX-level type: int32[8]
JAX-level type: int32[8]


Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

To see sharding in the type:


In [None]:
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit))

a = jnp.arange(8).reshape(4, 2)

# P("X", None) shards axis 0 across the X mesh (2 shards) and replicates each shard
# across the Y mesh (4 replicas per shard). Y replicas hold identical data, so any
# computation that doesn't differentiate across Y will be redundant across those devices.
a_sharded = jax.device_put(a, jax.NamedSharding(mesh, jax.P("X", None)))

print(f"a type: {jax.typeof(a)}")
print(f"a_sharded type: {jax.typeof(a_sharded)}")

a type: int32[4,2]
a_sharded type: int32[4@X,2]


`int32[4@X,2]` means a 4 by 2 array of int32s, where the first dimension is sharded across the X-axis of the mesh. The array is replicated across the other mesh axes.

In this case, each device will receive a 2 by 2 array. The computation will be replicated across the Y-axis of the mesh.

This is an example of adding two sharded arrays:


In [None]:
# Explicitly shard along X-axis
a0 = jax.device_put(jnp.arange(4).reshape(4, 1), jax.NamedSharding(mesh, jax.P("X", None)))

# Explicitly shard along Y-axis
a1 = jax.device_put(jnp.arange(8).reshape(1, 8), jax.NamedSharding(mesh, jax.P(None, "Y")))


@jax.jit
def add(a, b):
    """Add two arrays using NumPy-style broadcasting."""
    ans = a + b
    print(f"a sharding: {jax.typeof(a)}")
    print(f"b sharding: {jax.typeof(b)}")
    print(f"ans sharding: {jax.typeof(ans)}")
    return ans


# Result will be sharded across both X and Y axes.
add(a0, a1)

a sharding: int32[4@X,1]
b sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]


Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 1,  2,  3,  4,  5,  6,  7,  8],
       [ 2,  3,  4,  5,  6,  7,  8,  9],
       [ 3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)

Shardings propagate deterministically at trace time and we can query them at trace time.

## Manual parallelism

In automatic parallelism, we write a function as if we were operating on the full dataset, and `jax.jit` splits the computation across multiple devices.

In manual parallelism with `jax.shard_map()`, we write a function that will handle a single shard of data and `shard_map` will construct the full function:


In [None]:
mesh = jax.make_mesh((8,), ("X",))


def f_elementwise(x):
    """Apply 2 * sin(x) + 1 to an array in element-wise fashion"""
    return 2 * jnp.sin(x) + 1


a = jnp.arange(32)

# jax.shard_map() will take a per-shard function and create the full function.
f_elementwise_sharded = jax.shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=jax.P("X"),  # Specify what axis to shard the input array along
    out_specs=jax.P("X"),  # Specify what axis to shard the returned output array along
)

af_sharded = f_elementwise_sharded(a)
print(af_sharded)

[ 1.          2.682942    2.818595    1.28224    -0.513605   -0.9178486
  0.44116902  2.3139732   2.9787164   1.824237   -0.08804226 -0.99998045
 -0.07314587  1.840334    2.9812148   2.3005757   0.42419338 -0.92279494
 -0.50197446  1.2997544   2.8258905   2.6733112   0.98229736 -0.69244087
 -0.81115675  0.7352965   2.525117    2.912752    1.5418116  -0.32726777
 -0.97606325  0.19192469]


Note that `jax.shard_map()` can also be used inside a JIT-compiled function if needed.

The function only sees a single portion of the data:


In [16]:
x = jnp.arange(32)
print(f"Global shape: {x.shape=}")


def f(x):
    print(f"Device-local shape: {x.shape=}")
    return x * 2


y = jax.shard_map(f, mesh=mesh, in_specs=jax.P("X"), out_specs=jax.P("X"))(x)

Global shape: x.shape=(32,)
Device-local shape: x.shape=(4,)


Because the function sees only the device-local portion of the data, aggregation-like functions require some extra thought.

For example, consider summation:


In [17]:
x = jnp.arange(32)


def f(x):
    # keepdims=True will return a 1D array instead of a scalar
    return jnp.sum(x, keepdims=True)


result = jax.shard_map(f, mesh=mesh, in_specs=jax.P("X"), out_specs=jax.P("X"))(x)
print(result)

[  6  22  38  54  70  86 102 118]


Since `f` operates individually on each shard of the data, we never have any logic to sum across shards. We can explicitly handle this with collective operations like `jax.lax.psum()`:


In [18]:
x = jnp.arange(32)


def f(x):
    shard_sum = x.sum()

    # Performs an all-reduce style sum over all shards along X-axis
    return jax.lax.psum(shard_sum, "X")


# Note that out_specs=jax.P(); since the result is a scalar, we do not shard
# the result across the X-axis.
result = jax.shard_map(f, mesh=mesh, in_specs=jax.P("X"), out_specs=jax.P())(x)
print(result)

496


## Comparing the three approaches

Consider the forward pass of a single neural-network layer:


In [19]:
key = jax.random.key(1701)


@jax.jit
def layer(x, W, b):
    """Applies a dense nn layer with sigmoid activation."""
    return jax.nn.sigmoid(x @ W + b)


key, x_key, W_key, b_key = jax.random.split(key, 4)

x = jax.random.normal(x_key, (32,))
W = jax.random.normal(W_key, (32, 4))
b = jax.random.normal(b_key, (4,))

We can automatically perform this in a distributed manner by using `jax.jit()` and passing appropriately sharded data:


In [21]:
mesh = jax.make_mesh((8,), ("X",))
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, jax.P("X")))
W_replicated = jax.device_put(W, jax.NamedSharding(mesh, jax.P()))  # P() = replicate

# Every device sees 4 input elements. The weights are replicated across all devices.
# After each device computes its partial result, the partial results are combined
# automatically: the compiler inserts the required collective communication, so no
# explicit jax.lax.psum() call is needed.
print(layer(x_sharded, W_replicated, b))

[0.01422223 0.9701333  0.00124826 0.08572862]


The weights are replicated on every device, and the input is sharded across the devices. The per-device partial results are then combined with an all-reduce.

We can also use explicit sharding:


In [22]:
explicit_mesh = jax.make_mesh((8,), ("X",), axis_types=(AxisType.Explicit,))
x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, jax.P("X")))
W_replicated = jax.device_put(W, jax.NamedSharding(explicit_mesh, jax.P()))


@jax.jit
def layer_auto(x, W, b):
    """Applies a dense nn layer with sigmoid activation."""
    print("x sharding:", jax.typeof(x))
    print("W sharding:", jax.typeof(W))
    print("b sharding:", jax.typeof(b))
    out = layer(x, W, b)
    print("out sharding:", jax.typeof(out))
    return out


# Every device sees 4 input elements. The weights are replicated across all devices.
# After each device computes its partial result, the partial results are combined
# automatically: the compiler inserts the required collective communication, so no
# explicit jax.lax.psum() call is needed.
print(layer_auto(x_sharded, W_replicated, b))

x sharding: float32[32@X]
W sharding: float32[32,4]
b sharding: float32[4]
out sharding: float32[4]
[0.01422223 0.9701333  0.00124826 0.08572862]


The key difference here is that the compiler still partitions the computation automatically, but the sharding is now part of the JAX-level type and can be inspected at trace time.

Finally, we can do the same thing with manual parallelism (i.e., `jax.shard_map()`):


In [23]:
from functools import partial


@jax.jit
@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(
        jax.P("X"),  # Shard `x` along X-axis
        jax.P("X", None),  # Shard the rows of `W` along X-axis, columns are replicated
        jax.P(None),  # `b` is replicated across all devices
    ),
    out_specs=jax.P(None),  # Output is replicated across all devices
)
def layer_sharded(x, W, b):
    # Each device computes its local partial (x_i @ W_i) in R^M.
    # psum(..., "X") then does an elementwise all-reduce (sum) across devices,
    # yielding the full x @ W on every device. Adding b and sigmoid are local.
    return jax.nn.sigmoid(jax.lax.psum(x @ W, "X") + b)


print(layer_sharded(x, W, b))


[0.01422223 0.9701333  0.00124826 0.08572862]


## Distributed Training

We can shard the training data across devices, keep a replicated copy of the model weights on each device, and incrementally update the weights during training.


In [None]:
key = jax.random.key(1701)
mesh = jax.make_mesh((8,), ("X",))

B = 128 * jax.device_count()
D = 32
LEARNING_RATE = 0.01

# Generate dataset. Use separate keys so that X and w_true are independent.
x_key, w_key = jax.random.split(key)
X = jax.random.normal(x_key, (B, D))
w_true = jax.random.normal(w_key, (D, 1))  # True weights for label generation
y = (X @ w_true > 0).astype(jnp.float32)  # Labels: 1 if positive, 0 else

# Shard the data along the X-axis; replicate the parameters on every device
data_sharding = jax.NamedSharding(mesh, jax.P("X", None))
params_sharding = jax.NamedSharding(mesh, jax.P())

X_sharded = jax.device_put(X, data_sharding)
y_sharded = jax.device_put(y, data_sharding)

params = {"w": jnp.zeros((D, 1)), "b": jnp.zeros((1,))}
params_sharded = jax.tree.map(lambda a: jax.device_put(a, params_sharding), params)


@partial(
    jax.jit,
    in_shardings=(params_sharding, data_sharding, data_sharding),
    out_shardings=params_sharding,
)
def train_step(p, x, y):
    """
    Perform one step of training with data-parallel sharding.

    Args:
    - p: Model parameters replicated across all devices.
    - x: Input features of the batch, sharded along the batch dimension (X-axis).
         In this example, each device receives 128 examples per step (128 * 8 / 8).
         Since there are 32 features, each device effectively holds a 128x32 matrix.
    - y: Target labels for the batch, sharded the same way as `x`.
    """

    def loss_fn(p):
        logits = x @ p["w"] + p["b"]
        preds = jax.nn.sigmoid(logits)
        return jnp.mean((preds - y) ** 2)

    # Update weights & bias with loss gradient direction
    grad = jax.grad(loss_fn)(p)
    return jax.tree.map(lambda a, g: a - LEARNING_RATE * g, p, grad)


# Train for 100 steps, feeding the updated parameters back in at every step
for _ in range(100):
    params_sharded = train_step(params_sharded, X_sharded, y_sharded)

The model’s weights are replicated on every device, so each device holds a full copy and updates it in sync.

The input batch is sharded along the X-axis, so each device processes a different slice of examples, computes local gradients, and JAX automatically all-reduces those gradients so every device applies the same weight update.
