In [43]:
"""Run this cell twice"""

import os

# Ask XLA to expose 8 virtual host (CPU) devices so that multi-device sharding
# can be demonstrated on one machine; the device list printed below should
# contain 8 entries.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# An empty value hides CUDA GPUs from this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# NOTE(review): both variables are deliberately set before jax is imported.
# If jax was already imported/initialised in this kernel they presumably have
# no effect until the kernel is restarted -- confirm that "run this cell
# twice" is really what is needed rather than a kernel restart.
import jax
from jax.sharding import NamedSharding, PartitionSpec
import jax.numpy as jnp
import numpy as np

# Sanity check of the device configuration requested above.
print(jax.local_devices())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [4]:
"""Single device"""

# 32 consecutive floats laid out as a 4 x 8 matrix; no sharding is requested.
arr = jnp.arange(32.0).reshape(4, 8)
# Report which devices hold the array and how it is sharded, then draw the layout.
print(arr.devices())
print(arr.sharding)
jax.debug.visualize_array_sharding(arr)

{CpuDevice(id=0)}
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


In [44]:
"""Create mesh and shard data to multiple devices"""

# Arrange the devices in a 2 x 4 grid whose axes are named "x" and "y".
mesh = jax.make_mesh((2, 4), ("x", "y"))
# Partition array dimension 0 over mesh axis "x" and dimension 1 over "y".
sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
print(sharding)

# Place `arr` according to `sharding`; printing shows the same values as before,
# only the placement (visualised below) differs.
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


In [33]:
"""JIT compiles functions to be performed in parallel"""


@jax.jit
def f_elementwise(x):
    """Return ``2 * sin(x) + 1`` computed element by element (JIT-compiled)."""
    sine = jnp.sin(x)
    return 2 * sine + 1


# Apply the jitted function to the sharded array, then print the input and
# output shardings side by side so they can be compared.
result = f_elementwise(arr_sharded)
print(arr_sharded.sharding)
print(result.sharding)
jax.debug.visualize_array_sharding(result)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)


In [14]:
"""Reducing across devices - replicated across 0, 4 etc."""


@jax.jit
def f_contract(x):
    """Sum ``x`` along its leading axis (axis 0) under JIT."""
    return x.sum(axis=0)


# Reduce the sharded array over axis 0, then show the values and the sharding
# of the 1-D result.
result = f_contract(arr_sharded)
print(result)
print(result.sharding)
jax.debug.visualize_array_sharding(result)

[48. 52. 56. 60. 64. 68. 72. 76.]
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('y',), memory_kind=unpinned_host)


In [45]:
"""Sharding with constraints"""


@jax.jit
def f_contract_2(x):
    """Sum ``x`` over axis 0 and constrain how the result is sharded.

    The 1-D result is constrained to be partitioned over axis ``"x"`` of the
    module-level ``mesh``.
    """
    column_sums = jnp.sum(x, axis=0)
    target_sharding = NamedSharding(mesh, PartitionSpec("x"))
    return jax.lax.with_sharding_constraint(column_sums, target_sharding)


# Same reduction as before, but the function pins the output sharding; print
# the values and the resulting sharding to confirm the constraint was applied.
result = f_contract_2(arr_sharded)
print(result)
print(result.sharding)
jax.debug.visualize_array_sharding(result)

[48. 52. 56. 60. 64. 68. 72. 76.]
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x',), memory_kind=unpinned_host)


In [46]:
"""Sharding with a layer"""


@jax.jit
def layer(x, weights, biases):
    """Dense layer: sigmoid applied to ``x @ weights + biases`` (JIT-compiled)."""
    pre_activation = jnp.matmul(x, weights) + biases
    return jax.nn.sigmoid(pre_activation)


# No sharding
# Plain NumPy inputs from a seeded generator: a length-32 input vector, a
# (32, 4) weight matrix and a length-4 bias.
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

result = layer(x, weights, bias)
print(result)
jax.debug.visualize_array_sharding(result)

# Sharding
# NOTE: `mesh` and `sharding` are rebound here to a 1-D, 8-device mesh, so any
# later code reading these globals sees the new values.
# Spec "x" partitions the leading (length-32) dimension of both the input
# vector and the weight matrix; `bias` is passed without explicit placement.
mesh = jax.make_mesh((8,), ("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

# Same computation on the sharded inputs; the printed values should match the
# unsharded result above.
result = layer(x_sharded, weights_sharded, bias)
print(result)
jax.debug.visualize_array_sharding(result)


# Sharding with constraint
@jax.jit
def layer_auto(x, weights, bias):
    """Apply ``layer`` after constraining ``x`` and ``weights`` to ``sharding``.

    ``sharding`` is read from module scope; ``bias`` is forwarded unchanged.
    """
    constrain = jax.lax.with_sharding_constraint
    constrained_x = constrain(x, sharding)
    constrained_weights = constrain(weights, sharding)
    return layer(constrained_x, constrained_weights, bias)


# Run the variant that applies the sharding constraints inside the jitted
# function; the printed values should match the two earlier results.
result = layer_auto(x_sharded, weights_sharded, bias)
print(result)
jax.debug.visualize_array_sharding(result)

[0.02138916 0.8931118  0.5989196  0.9774251 ]


[0.02138916 0.8931118  0.5989196  0.9774251 ]


[0.02138916 0.8931118  0.5989196  0.9774251 ]


In [47]:
"""Shard across different axes"""

mesh = jax.make_mesh((4, 2), ("a", "b"))
x = jax.random.normal(jax.random.key(0), (8192, 8192))

# Each entry pairs the label that is printed with the partition spec being
# demonstrated; the entries are visited in this order.
axis_layouts = [
    ("a, b", PartitionSpec("a", "b")),
    ("b, a", PartitionSpec("b", "a")),
    ("a, None", PartitionSpec("a", None)),
    ("None, b", PartitionSpec(None, "b")),
    ("None, a", PartitionSpec(None, "a")),
    ("(a, b), None", PartitionSpec(("a", "b"), None)),
]

for layout_label, layout_spec in axis_layouts:
    # Re-place the same matrix under the current spec and draw the layout.
    x_sharded = jax.device_put(x, NamedSharding(mesh, layout_spec))
    print(layout_label)
    jax.debug.visualize_array_sharding(x_sharded)

a, b


b, a


a, None


None, b


None, a


(a, b), None


In [49]:
"""Sharded matmul"""

# NOTE: relies on `x` and `mesh` as currently bound in the kernel (the square
# matrix and the ("a", "b") mesh created just before this point).
# Left operand: rows partitioned over mesh axis "a", columns not partitioned.
y = jax.device_put(x, NamedSharding(mesh, PartitionSpec("a", None)))
# Right operand: columns partitioned over mesh axis "b", rows not partitioned.
z = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "b")))
print("LHS sharding")
jax.debug.visualize_array_sharding(y)
print("RHS sharding")
jax.debug.visualize_array_sharding(z)

# Multiply the two sharded operands and inspect how the product is laid out.
w = jnp.dot(y, z)
print("Result sharding")
jax.debug.visualize_array_sharding(w)

LHS sharding


RHS sharding


Result sharding


In [103]:
"""Neural network example"""


def predict(params, inputs):
    """Forward pass of a fully connected network.

    Args:
        params: iterable of ``(W, b)`` pairs, one per layer, applied in order.
        inputs: array fed to the first layer.

    Returns:
        The affine output ``dot(h, W) + b`` of the last layer, where ``h`` is
        the ReLU of the previous layer's output (or ``inputs`` for the first
        layer). No ReLU is applied to the returned value. When ``params`` is
        empty the network is the identity and ``inputs`` is returned as-is.
    """
    # Bind `outputs` up front: with an empty `params` the loop body never runs
    # and the return below would otherwise raise UnboundLocalError.
    outputs = inputs
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        # The activation only feeds the next layer; the final `outputs` is
        # returned without it.
        inputs = jax.nn.relu(outputs)
    return outputs


def loss(params, batch):
    """Mean over the batch of the summed squared prediction error.

    ``batch`` is an ``(inputs, targets)`` pair; squared errors are summed over
    the last axis and then averaged.
    """
    inputs, targets = batch
    residual = predict(params, inputs) - targets
    per_example = jnp.sum(residual ** 2, axis=-1)
    return jnp.mean(per_example)


# JIT-compiled loss, used for reporting the loss value.
loss_jit = jax.jit(loss)
# JIT-compiled gradient of the loss with respect to its first argument (`params`).
gradfun = jax.jit(jax.grad(loss))


def init_layer(key, n_in, n_out):
    """Draw random parameters for one dense layer.

    Returns ``(W, b)`` where ``W`` has shape ``(n_in, n_out)`` and is standard
    normal scaled by ``1 / sqrt(n_in)``, and ``b`` has shape ``(n_out,)`` and
    is standard normal. ``key`` is split so the two draws use separate keys.
    """
    weight_key, bias_key = jax.random.split(key)
    fan_in_scale = jnp.sqrt(n_in)
    weights = jax.random.normal(weight_key, (n_in, n_out)) / fan_in_scale
    biases = jax.random.normal(bias_key, (n_out,))
    return weights, biases


def init_model(key, layer_sizes, batch_size):
    """Create random layer parameters and one random ``(inputs, targets)`` batch.

    One ``(W, b)`` pair is created per consecutive pair of ``layer_sizes``.
    ``inputs`` has shape ``(batch_size, layer_sizes[0])`` and ``targets`` has
    shape ``(batch_size, layer_sizes[-1])``.

    Returns:
        ``(params, (inputs, targets))``.
    """
    # First sub-key carries on as `key`; the remaining ones seed the layers.
    key, *layer_keys = jax.random.split(key, len(layer_sizes))
    params = [
        init_layer(layer_key, n_in, n_out)
        for layer_key, n_in, n_out in zip(
            layer_keys, layer_sizes[:-1], layer_sizes[1:]
        )
    ]

    key, input_key, target_key = jax.random.split(key, 3)
    inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
    targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)


# Widths of consecutive layers: the first entry is the input width and the
# last the target width (see init_model); this yields four (W, b) pairs.
layer_sizes = [4, 8, 8, 8, 2]
# Number of examples in the randomly generated batch.
batch_size = 64

In [115]:
"""8-way data parallelism"""

# Fresh parameters and data from a fixed key so the run is repeatable.
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

# NOTE: rebinds the globals `mesh` and `sharding` to a 1-D, 8-device mesh.
mesh = jax.make_mesh((8,), ("batch",))
# Partition the leading dimension over the "batch" mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("batch"))
# An empty PartitionSpec partitions no dimension (used for the parameters).
replicated_sharding = NamedSharding(mesh, PartitionSpec())

# The (inputs, targets) pair is placed with the batch sharding; the list of
# (W, b) pairs is placed with the replicated sharding.
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
# Loss before any update.
print(loss_jit(params, batch))

step_size = 1e-2

# Ten steps of plain gradient descent, printing the loss after each step.
for _ in range(10):
    grads = gradfun(params, batch)
    # `grads` mirrors the structure of `params`, so the pairs line up in zip.
    params = [
        (W - step_size * dW, b - step_size * db)
        for (W, b), (dW, db) in zip(params, grads)
    ]

    print(loss_jit(params, batch))

3.5906165
2.4103546
2.338912
2.3126652
2.291673
2.2735362
2.2575002
2.243159
2.230052
2.2181578
2.207295


In [125]:
"""4-way data parallelism and 2-way model parallelism"""

# Same key as the data-parallel run, so parameters and data start identical
# and the printed losses can be compared between the two runs.
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

# NOTE: rebinds the global `mesh` to a 4 x 2 grid with axes "batch" and "model".
mesh = jax.make_mesh((4, 2), ("batch", "model"))
# Partition the leading dimension of inputs and targets over "batch"; the
# second dimension is not partitioned.
batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("batch", None)))
jax.debug.visualize_array_sharding(batch[0])

replicated_sharding = NamedSharding(mesh, PartitionSpec())
# Four layers, matching the five entries of `layer_sizes`.
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# Layer 1: not partitioned.
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

# Layer 2: W2's second dimension and b2 are partitioned over "model".
W2 = jax.device_put(W2, NamedSharding(mesh, PartitionSpec(None, "model")))
b2 = jax.device_put(b2, NamedSharding(mesh, PartitionSpec("model")))

# Layer 3: W3's first dimension is partitioned over "model"; b3 is not partitioned.
W3 = jax.device_put(W3, NamedSharding(mesh, PartitionSpec("model", None)))
b3 = jax.device_put(b3, replicated_sharding)

# Layer 4: not partitioned.
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

# Reassemble the parameter list in layer order from the re-placed arrays.
params = [(W1, b1), (W2, b2), (W3, b3), (W4, b4)]

# Show the layouts of the three arrays that were partitioned over "model".
print("W2")
jax.debug.visualize_array_sharding(W2)
print("b2")
jax.debug.visualize_array_sharding(b2)
print("W3")
jax.debug.visualize_array_sharding(W3)
# Loss before any update.
print(loss_jit(params, batch))

step_size = 1e-2

# Ten steps of plain gradient descent, printing the loss after each step.
for _ in range(10):
    grads = gradfun(params, batch)
    # `grads` mirrors the structure of `params`, so the pairs line up in zip.
    params = [
        (W - step_size * dW, b - step_size * db)
        for (W, b), (dW, db) in zip(params, grads)
    ]

    print(loss_jit(params, batch))

W2


b2


W3


3.5906165
2.4103546
2.338912
2.312665
2.291673
2.2735362
2.2575
2.2431588
2.230052
2.2181578
2.2072952
