# JAX Distributed Arrays

This is a modified version of the official JAX documentation on [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).

In [None]:
! apt update; apt install -y graphviz

In [None]:
! pip install jaxtyping
! pip install graphviz

In [None]:
#@title Imports
from typing import Optional

import os
import functools

import jax
import jax.numpy as jnp
import graphviz
import numpy as np
import tabulate

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding

# This makes it more readable to inspect shardings.
np.set_printoptions(precision=3)
print(jax.__version__)


def show_shards_on_devices(array: jax.Array, mesh: jax.sharding.Mesh, show_devices: bool=False):
    """Print a table showing which slice of `array` lives on each mesh device.

    The table has the same 2D layout as `mesh.device_ids`; every cell holds
    the slice of `array` that `array.sharding` assigns to that device.

    Args:
        array: Sharded array to inspect. Every device id in `mesh` must also
            appear in `array.sharding`, otherwise the lookup raises KeyError.
        mesh: Two-dimensional device mesh that defines the table layout.
        show_devices: If True, prefix every cell with the device that holds
            the slice.
    """
    # Map device id -> (device, slice of the array assigned to it). Each index
    # is a tuple of slices, so indexing with it directly works for any array
    # rank. Keeping the device object avoids treating a device id as a
    # position in `jax.local_devices()`, which is only right when the two
    # happen to coincide.
    idx_map = array.sharding.devices_indices_map(array.shape)
    shard_from_id = {
        device.id: (device, array[index]) for device, index in idx_map.items()
    }

    # Build a table to visualize which slice goes on which device.
    n_rows, n_cols = mesh.device_ids.shape
    data = []
    for row in range(n_rows):
        row_data = []
        for col in range(n_cols):
            device, data_slice = shard_from_id[mesh.device_ids[row, col]]
            if show_devices:
                cell_data = f"{device}\n\n{data_slice}"
            else:
                cell_data = data_slice
            row_data.append(cell_data)
        data.append(row_data)
    table = tabulate.tabulate(data, [], tablefmt="fancy_grid")
    print(table)

In [None]:
#@title Notebook setting
USE_MOCK_DEVICES = False #@param {type:"boolean"}
import chex

if USE_MOCK_DEVICES:
    print('Using 8 mock devices.')
    # Forces XLA to use 8 CPU threads as host devices, so that the notebook
    # can run without real accelerators.
    # NOTE(review): presumably this has to happen before JAX initialises its
    # backends (i.e. before the first device query) - confirm.
    chex.set_n_cpu_devices(8)

# Query the device count once and reuse it for the check below.
N_DEVICES = len(jax.local_devices())

if N_DEVICES < 8:
    # RuntimeError is still an Exception subclass, so existing handlers keep
    # working, but it states the failure more precisely.
    raise RuntimeError("Notebook requires 8 devices to run")

# JAX Distributed Arrays

In [None]:
#@title jax.Array { run: "auto", form-width: "1000px" }
mesh_shape = "(4, 2)"                          #@param ["(4, 2)", "(2, 4)", "(1, 8)", "(8, 1)"]
axis_names = "('data', 'model')"               #@param ["('data', 'model')"]
partition_spec = "PartitionSpec()" #@param ["PartitionSpec()", "PartitionSpec('data', None)", "PartitionSpec(None, 'model')", "PartitionSpec('model', None)", "PartitionSpec('data', 'model')", "PartitionSpec('model', 'data')", "PartitionSpec(None, 'data')", "PartitionSpec(('data', 'model'), None)", "PartitionSpec(('model', 'data'), None)"]
show_device_ids = False #@param {type:"boolean"}

import ast

# The form fields above are strings. Parse them as Python literals instead of
# calling eval(), so that a mistyped field can never execute arbitrary code.
mesh_shape = ast.literal_eval(mesh_shape)
axis_names = ast.literal_eval(axis_names)

# `partition_spec` looks like "PartitionSpec('data', None)". Check that it is
# a plain PartitionSpec(...) call and evaluate only its literal positional
# arguments.
spec_call = ast.parse(partition_spec, mode="eval").body
if not (
    isinstance(spec_call, ast.Call)
    and isinstance(spec_call.func, ast.Name)
    and spec_call.func.id == "PartitionSpec"
    and not spec_call.keywords
):
    raise ValueError(f"Unsupported partition spec: {partition_spec!r}")
partition_spec = jax.sharding.PartitionSpec(
    *(ast.literal_eval(arg) for arg in spec_call.args))

# Input data
input_data = jnp.arange(8 * 2).reshape(8, 2)

# Arrange the devices in the requested mesh shape and name the mesh axes.
devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = jax.sharding.Mesh(devices, axis_names=axis_names)

# Place the input on the devices according to the chosen partition spec.
sharding = jax.sharding.NamedSharding(mesh, partition_spec)
M = jax.device_put(input_data, sharding)

print("Input Array:\n%s" % M)
print("shape: %s" % repr(M.shape))
print()
jax.debug.visualize_array_sharding(M)

show_shards_on_devices(M, mesh, show_device_ids)

In [None]:
# Start from a one-dimensional positional sharding over 8 devices.
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

# 4x4 matrix holding the integers 0..15.
x = jnp.arange(2*8).reshape(4, 4)


# Place the same matrix twice: reshape the device grid to (1, 8) and to
# (8, 1), each time calling replicate() on the axis that holds the 8 devices.
y = jax.device_put(x, sharding.reshape(1, 8).replicate(1))
z = jax.device_put(x, sharding.reshape(8, 1).replicate(0))
# NOTE: `mesh` below is whatever mesh is currently defined in the notebook;
# the helper only uses it to lay out the printed table.
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
show_shards_on_devices(y, mesh, True)

print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
show_shards_on_devices(z, mesh, True)

# Multiply the two placed matrices and inspect where the result ends up.
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
show_shards_on_devices(w, mesh, True)

## Batch data parallelism

In [None]:
from typing import Tuple

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


def mlp(x: Float[Array,  "B   H_1"],
        w1: Float[Array, "H_1 H_2"],
        w2: Float[Array, "H_2 H_1"]) ->Float[Array,  "B H_1"]:
    """Apply a two-layer perceptron with a ReLU between the layers.

    Computes ``max(x . w1, 0) . w2``.

    Args:
        x: Input batch.
        w1: Weights of the first layer.
        w2: Weights of the second layer.

    Returns:
        The output of the second layer.
    """
    # [B, H_1] @ [H_1, H_2] -> [B, H_2], negative activations clamped to zero.
    hidden = jnp.maximum(jnp.dot(x, w1), 0)
    # [B, H_2] @ [H_2, H_1] -> [B, H_1]
    return jnp.dot(hidden, w2)


def cross_entropy(logits: Float[Array,  "B C"],
                  targets: Float[Array, "B"]) -> Float[Array, "B"]:
    """Return the negated sum of ``targets * log_softmax(logits)`` per row.

    The log-softmax and the sum are both taken over the last axis, so the
    result has one value per leading index.

    NOTE(review): `targets` is annotated as ``"B"``, but it is multiplied
    element-wise with the ``"B C"`` log-probabilities, so it has to broadcast
    against that shape - confirm the intended target format (one-hot rows vs.
    class indices).
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    weighted = targets * log_probs
    return -jnp.sum(weighted, axis=-1)


def loss_fn(x, w1, w2, targets):
    """Return the cross-entropy of the MLP outputs, summed over the batch.

    Runs `mlp` on `x` with the weights `w1` and `w2`, scores the result
    against `targets` with `cross_entropy`, and adds up all values into a
    single scalar.
    """
    logits = mlp(x, w1, w2)
    per_example = cross_entropy(logits, targets)
    return jnp.sum(per_example)


def init_model(
    key: jax.random.PRNGKey,
    H_1: int,
    H_2: int) -> Tuple[Float[Array, "H_1 H_2"], Float[Array, "H_2 H_1"]]:
    """Initialise the two weight matrices of the MLP.

    Each matrix is drawn from a standard normal distribution and divided by
    the square root of its number of rows.

    Args:
        key: PRNG key; it is split so that each matrix gets its own subkey.
        H_1: Number of rows of `W_1` and columns of `W_2`.
        H_2: Number of columns of `W_1` and rows of `W_2`.

    Returns:
        The pair ``(W_1, W_2)`` with shapes ``(H_1, H_2)`` and ``(H_2, H_1)``.
    """
    k1, k2 = jax.random.split(key)
    W_1 = jax.random.normal(k1, (H_1, H_2)) / jnp.sqrt(H_1)
    # Draw W_2 from the second subkey: sampling both matrices from `k1` would
    # make them dependent (and identical whenever H_1 == H_2).
    W_2 = jax.random.normal(k2, (H_2, H_1)) / jnp.sqrt(H_2)
    return W_1, W_2


def init_data(
    key: jax.random.PRNGKey,
    B: int, H_1: int, C: int,
    ) -> Tuple[Float[Array, "B H_1"], Float[Array, "B C"]]:
    """Draw a random batch of inputs together with integer labels.

    Args:
        key: PRNG key; it is split into one subkey for the inputs and one
            for the labels.
        B: Number of examples in the batch.
        H_1: Number of features per example.
        C: Exclusive upper bound of the label values.

    Returns:
        A pair ``(inputs, labels)``. `inputs` is standard-normal with shape
        ``(B, H_1)``; `labels` holds integers from ``[0, C)`` with shape
        ``(B, 1)``.

    NOTE(review): the return annotation declares the labels as ``"B C"``,
    while the array built here has shape ``(B, 1)``.
    """
    input_key, label_key = jax.random.split(key)
    inputs = jax.random.normal(input_key, (B, H_1))
    labels = jax.random.randint(label_key, (B, 1), minval=0, maxval=C)
    return inputs, labels

In [None]:
# Set the shapes, to be small to be able to visualize
# what's happening.
B, H_1, H_2, C = N_DEVICES * 2, 4, 4, 8

# Give the model and the data their own subkeys: passing the same key to both
# initialisers would make the "random" weights and inputs dependent.
key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)
w1, w2 = init_model(model_key, H_1, H_2)
x, y = init_data(data_key, B, H_1, C)

mesh_shape = (8, 1)

# First specify sharding with positional sharding on 8 devices
sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(*mesh_shape)

# Technically one doesn't need mesh for PositionalSharding, but we'll be
# using it to visualize what goes to each device.
mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh(mesh_shape), axis_names=('data', 'model'))


# Shard batch on the first dimension, each mini-batch should be of size B/8.
(x, y) = jax.device_put((x, y), sharding)

# Put parameters on devices, replicating them. Each
# device has the same copy of the model.
w1, w2 = jax.device_put((w1, w2), sharding.replicate())

In [None]:
# This is our batch of data.
x

In [None]:
# This is how it's sharded across devices, on its row axis.
show_shards_on_devices(x, mesh, show_devices=True)

In [None]:
# Jit-compile a function that returns the loss together with its gradients
# with respect to arguments 1 and 2 of `loss_fn`, i.e. `w1` and `w2`.
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=(1, 2)))

In [None]:
# Step size for the gradient-descent updates below.
lr = 0.01

# Run 100 steps of plain gradient descent on the weights, reusing the same
# batch (x, y) at every step.
for i in range(100):
    loss, (dw1, dw2) = loss_and_grad_fn(x, w1, w2, y)
    w1 = w1 - lr*dw1
    w2 = w2 - lr*dw2
    # Report the loss on every 10th step.
    if i % 10 == 0:
        print(f"Step: {i} {loss:.2f}")

In [None]:
# Increase the sizes, to make the task more computationally intensive.
B, H_1, H_2, C = 8192, 1024*4, 1024*4, 32_000

# Give the model and the data their own subkeys: passing the same key to both
# initialisers would make the "random" weights and inputs dependent.
key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)
w1, w2 = init_model(model_key, H_1, H_2)
x, y = init_data(data_key, B, H_1, C)

# Put a full copy of everything on device 0 (single-device baseline).
device_0 = jax.devices()[0]
x_d0, y_d0 = jax.device_put((x, y), device_0)
w1_d0, w2_d0 = jax.device_put((w1, w2), device_0)

# Shard data, replicate params. `sharding` is the sharding object that is
# already defined in the notebook at this point.
(x, y) = jax.device_put((x, y), sharding)
w1, w2 = jax.device_put((w1, w2), sharding.replicate())

In [None]:
# Jit-compiled loss and gradients with respect to `w1` and `w2` (arguments 1
# and 2 of `loss_fn`); this is the function timed in the cells below.
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=(1, 2)))

In [None]:
# Time the loss-and-gradient call on the inputs placed with `sharding`;
# block_until_ready() is called on the loss so the timing waits for the result.
%timeit -n 5 -r 10 loss_and_grad_fn(x, w1, w2, y)[0].block_until_ready()

In [None]:
# Same measurement with the copies of the inputs that were put on device 0.
%timeit -n 5 -r 10 loss_and_grad_fn(x_d0, w1_d0, w2_d0, y_d0)[0].block_until_ready()

## Batch and model parallelism.

In [None]:
# Let's now create a (4, 2) device mesh: 4 devices along the 'data' axis and
# 2 along the 'model' axis. Arrays are placed on it with NamedSharding.
import numpy as np

mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices,  ('data', 'model'))

# Create input data: small integer matrices, so that every shard can be
# recognised in the printed tables.
B, H_1, H_2 = 8, 2, 4

X = np.arange(B * H_1).reshape(B, H_1)
w1 = 3 * np.arange(H_1 * H_2).reshape(H_1, H_2)
w2 = 7 * np.arange(H_2 * H_1).reshape(H_2, H_1)

# Shard input data on first batch dimension, replicate on model dimension,
# as mini-batch for data will be multiplied by each shard of weight w1.
X  = jax.device_put(
    X,
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data', None)))

print("Placement of input_data, split on batch dimension")
print("""
 [[ 0,  1],
  [ 2,  3],
  --------
  [ 4,  5],
  [ 6,  7],
  --------
  [ 8,  9],
  [10, 11],
  --------
  [12, 13],
  [14, 15]]
""")
show_shards_on_devices(X , mesh)

# Shard w_1 on columns, replicate on data dimension.
w_1 = jax.device_put(
    w1,
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'model')))

print("Placement of w_1, split on columns")
print("""
[[ 0  3  | 6  9]
 [12 15  | 18 21]]
""")
show_shards_on_devices(w_1, mesh)

# Shard w_2 on rows, replicate on data dimension.
w_2 = jax.device_put(
    w2,
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model', None)))

print("Placement of w_2, split on rows")
print("""
[[ 0  7]
 [14 21]
 -------
 [28 35]
 [42 49]]
""")
show_shards_on_devices(w_2, mesh)

# Calculate sharded product: the same two-layer computation as `mlp`, written
# out on the sharded arrays.
Y = jnp.maximum(X @ w_1, 0)
Z = Y @ w_2

# Show how the result is laid out across the mesh.
show_shards_on_devices(Z, mesh)

In [None]:
# Increase the sizes, to make the task more computationally intensive.
B, H_1, H_2, C = 8192, 512, 1024, 32_000

# Give the model and the data their own subkeys: passing the same key to both
# initialisers would make the "random" weights and inputs dependent.
key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)
w1, w2 = init_model(model_key, H_1, H_2)
x, t = init_data(data_key, B, H_1, C)

# Shard input data on first batch dimension, replicate on model dimension,
# as mini-batch for data will be multiplied by each shard of weight w1.
(x, t) = jax.device_put(
    (x, t),
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data', None)))

# Shard w_1 on columns, replicate on data dimension.
w_1 = jax.device_put(
    w1,
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'model')))

# Shard w_2 on rows, replicate on data dimension.
w_2 = jax.device_put(
    w2,
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model', None)))

In [None]:
# Shapes of the batch, the targets and the two sharded weight matrices.
[arr.shape for arr in (x, t, w_1, w_2)]

In [None]:
# Time the loss-and-gradient call with the batch sharded along 'data' and the
# weights sharded along 'model'.
%timeit -n 5 -r 10 loss_and_grad_fn(x, w_1, w_2, t)[0].block_until_ready()

In [None]:
B, H_1, H_2 = 8, 4, 8

# `key` and `C` are carried over from the notebook state. Split the key so
# that the model and the data get their own subkeys: passing the same key to
# both initialisers would make the "random" weights and inputs dependent.
model_key, data_key = jax.random.split(key)
w1, w2 = init_model(model_key, H_1, H_2)
X, _ = init_data(data_key, B, H_1, C)

# Shard every array on both mesh axes: rows over 'data', columns over 'model'.
sharding_2d = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data', 'model'))

X = jax.device_put(X, sharding_2d)
w_1 = jax.device_put(w1, sharding_2d)
w_2 = jax.device_put(w2, sharding_2d)
z = mlp(X, w_1, w_2)

# Print the shape of each array followed by its per-device layout.
print("X", X.shape)
show_shards_on_devices(X, mesh)
print("w_1", w_1.shape)
show_shards_on_devices(w_1, mesh)
print("w_2", w_2.shape)
show_shards_on_devices(w_2, mesh)
print("Z", z.shape)
show_shards_on_devices(z, mesh)

In [None]:
# The curious can always export and inspect the generated XLA computation.
# Lower and compile `mlp` for the sharded inputs, then take the first HLO
# module of the compiled executable.
compiled = jax.jit(mlp).lower(X, w_1, w_2).compile()
hlo_module = compiled.runtime_executable().hlo_modules()[0]
# Convert the HLO module into a Graphviz DOT description.
# NOTE(review): `runtime_executable()` and `jax.interpreters.xla.xe` look like
# internal JAX APIs - confirm they exist in the JAX version in use.
dot_graph = jax.interpreters.xla.xe.hlo_module_to_dot_graph(hlo_module)

In [None]:
# Wrap the DOT description in a Graphviz source object with PNG output; the
# trailing expression makes the notebook display it.
graph = graphviz.Source(dot_graph, format='png')
graph

In [None]:
# Visualize how the MLP output `z` is laid out across the devices.
jax.debug.visualize_array_sharding(z)

# Additional Materials

## TPU Layout.

You might have noticed while calling `jax.debug.visualize_array_sharding` that TPU device numbers are in a particular order: [TPU 0, TPU 1, TPU 2, TPU 3, TPU 6, TPU 7, TPU 4, TPU 5](https://github.com/google/jax/blob/404e3061b6368daed3efa3ee7b99128327254ac2/jax/experimental/mesh_utils.py#L61).

```
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
```


This is because the 2x2 slice that you will most probably be using is physically laid out in a ring, and the aforementioned layout enables the most efficient communication.

```
TPU 2, TPU 3 -----→ TPU 6, TPU 7
    ↑                   |
    |                   |
    |                   |
    |                   ↓
TPU 0, TPU 1 ←----- TPU 4, TPU 5
```

As for the device numbers, they can be recovered from the TPU coordinates.
