# Data Sharding Across Devices


In [3]:
from pathlib import Path
import sys

from typing import Literal

import jax

platform : Literal["darwin", "colab", "cuda"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/high_performance_jax
    !cd high_performance_jax && git pull
    !git clone https://github.com/novastar53/deepkit
    !cd deepkit && git pull
    hpj_dir = str(Path().absolute() / "high_performance_jax" / "src" )
    dt_dir = str(Path().absolute() / "deepkit" / "src" )
    sys.path.append(hpj_dir)
    print(hpj_dir)
    sys.path.append(dt_dir)
    print(dt_dir)

Running on colab
Cloning into 'high_performance_jax'...
remote: Enumerating objects: 131, done.[K
remote: Counting objects: 100% (131/131), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 131 (delta 57), reused 87 (delta 39), pack-reused 0 (from 0)[K
Receiving objects: 100% (131/131), 820.49 KiB | 6.51 MiB/s, done.
Resolving deltas: 100% (57/57), done.
Already up to date.
Cloning into 'deepkit'...
remote: Enumerating objects: 244, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 244 (delta 0), reused 0 (delta 0), pack-reused 208 (from 1)[K
Receiving objects: 100% (244/244), 232.42 KiB | 3.06 MiB/s, done.
Resolving deltas: 100% (128/128), done.
Already up to date.
/content/high_performance_jax/src
/content/deepkit/src


In [4]:
from deepkit.utils import timeit
from high_performance_jax.single_chip_performance import *
import matplotlib.pyplot as plt

dtype = jnp.bfloat16
devices = jax.devices()
print("Devices:")
for i,d in enumerate(devices):
  print(f"{i+1}. {d.device_kind}")    # e.g. “TPU v3”

Devices:
1. TPU v2
2. TPU v2
3. TPU v2
4. TPU v2
5. TPU v2
6. TPU v2
7. TPU v2
8. TPU v2


In [13]:
A = jnp.ones((1024, 1024))
jax.debug.visualize_array_sharding(A)

In [None]:
import numpy as np
device_array = np.array(jax.devices()).reshape((4,2))
print(device_array)
mesh = jax.sharding.Mesh(device_array, ["myaxis1", "myaxis2"])
p = jax.sharding.PartitionSpec("myaxis2", "myaxis1")
sharding = jax.sharding.NamedSharding(mesh, p)
sharded_A = jax.device_put(A, sharding)
jax.debug.visualize_array_sharding(sharded_A)

In [None]:
sharded_A.addressable_shards[0].data.shape

In [None]:
import numpy as np
import jax

dim = 8
A = jnp.ones((dim, dim, dim))

# Step 1: Create device mesh (no names needed)
device_array = np.array(jax.devices()).reshape((2, 2, 2))
# Step 3: Create positional sharding object
sharding = jax.sharding.PositionalSharding(device_array)

# Step 4: Place array on devices
sharded_A = jax.device_put(A, sharding)

# Step 5: Visualize sharding"
#jax.debug.visualize_array_sharding(sharded_A)

In [None]:
sharded_A.addressable_shards[0].data.shape

In [5]:
A = jnp.ones((1024, 1024))
B = jnp.ones((1024, 1024))

device_array = np.array(jax.devices()).reshape((2, 4))

mesh = jax.sharding.Mesh(device_array, ["myaxis1", "myaxis2"])

p_A = jax.sharding.PartitionSpec("myaxis1", "myaxis2")
p_B = jax.sharding.PartitionSpec("myaxis2", "myaxis1")

sharding_A = jax.sharding.NamedSharding(mesh, p_A)
sharding_B = jax.sharding.NamedSharding(mesh, p_B)

sharded_A = jax.device_put(A, sharding_A)
jax.debug.visualize_array_sharding(sharded_A)

sharded_B = jax.device_put(B, sharding_B)
jax.debug.visualize_array_sharding(sharded_B)

C = sharded_A + sharded_B
jax.debug.visualize_array_sharding(C)


NameError: name 'np' is not defined

In [3]:
A = jnp.ones((2**16, 2**15))
jax.debug.visualize_array_sharding(A)

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 8.00G. That was not possible. There are 7.48G free.; (0x0x0_HBM0)

In [6]:
from functools import partial
import jax

SIZE = 16384*2

A = jax.numpy.ones((SIZE, SIZE), dtype=jax.numpy.bfloat16)
mesh = jax.sharding.Mesh(jax.devices(), ("ouraxis"))
sharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("ouraxis"))
unsharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
A_sharded = jax.device_put(A, sharded_sharding)

@partial(jax.jit, out_shardings = unsharded_sharding)
def unshard_array(input):
  return input

avg_time = timeit(unshard_array, A_sharded, task='unshard_array')
print(avg_time)
#A_unsharded = unshard_array(A)
achieved_bandwidth_GB_s = A.size * 2 / 10**9 / (avg_time / 10**3)
print(achieved_bandwidth_GB_s)


#A_unsharded = jax.device_put(A, unsharded_sharding)

/tmp/t_unshard_array_37X0Y27JN4
66.3527
32.364676162386765


In [10]:
jax.debug.visualize_array_sharding(A_sharded)

In [18]:
jax.debug.visualize_array_sharding(A_unsharded)