# Data Sharding Across Devices


In [2]:
from pathlib import Path
import sys

from typing import Literal

import jax

import os

#print(os.environ["LIBTPU_INIT_ARGS"])

platform : Literal["darwin", "colab", "cuda", "tpu"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"
    elif any(d.platform == "tpu" for d in devices):
        platform = "tpu"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/high_performance_jax
    !cd high_performance_jax && git pull
    !git clone https://github.com/novastar53/deepkit
    !cd deepkit && git pull
    hpj_dir = str(Path().absolute() / "high_performance_jax" / "src" )
    dt_dir = str(Path().absolute() / "deepkit" / "src" )
    sys.path.append(hpj_dir)
    print(hpj_dir)
    sys.path.append(dt_dir)
    print(dt_dir)

Running on darwin


In [3]:
import os

# NOTE: LIBTPU_INIT_ARGS is read when the TPU runtime initializes. Outside Colab the
# platform-detection cell has already called jax.devices(), so uncommenting the lines
# below here would presumably have no effect -- set the variable before the first JAX
# backend call (or before starting the kernel) instead.
#print(os.environ["LIBTPU_INIT_ARGS"])
#os.environ["LIBTPU_INIT_ARGS"] = "--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"
#print(os.environ["LIBTPU_INIT_ARGS"])

from deepkit.utils import timeit
from high_performance_jax.single_chip_performance import *
import matplotlib.pyplot as plt

# Explicit imports placed after the star import so these bindings always win; the
# notebook should not depend on the star import to provide `jax`/`jnp`.
import jax
import jax.numpy as jnp

dtype = jnp.bfloat16
devices = jax.devices()
print("Devices:")
for i, d in enumerate(devices, start=1):
    print(f"{i}. {d.device_kind}")  # e.g. "TPU v3"

Devices:
1. cpu


In [14]:
# Baseline: a 1024x1024 array created without any sharding, so the visualization
# shows a single device holding the whole array.
A = jnp.ones(shape=(2**10, 2**10))
jax.debug.visualize_array_sharding(A)

In [15]:
import numpy as np

# The 4x2 mesh below needs exactly 8 devices. Fail with a clear message instead of a
# cryptic reshape error. On a CPU-only machine, 8 virtual devices can be emulated by
# setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before JAX initializes.
assert len(jax.devices()) == 8, (
    f"this cell needs 8 devices for a 4x2 mesh, found {len(jax.devices())}"
)

device_array = np.array(jax.devices()).reshape((4, 2))
print(device_array)
# Mesh axes: "myaxis1" has size 4 (rows of device_array), "myaxis2" has size 2 (columns).
mesh = jax.sharding.Mesh(device_array, ["myaxis1", "myaxis2"])
# Array dim 0 is split over "myaxis2" (2 ways) and dim 1 over "myaxis1" (4 ways), so
# each device holds a (rows/2, cols/4) tile -- (512, 256) for a 1024x1024 A.
p = jax.sharding.PartitionSpec("myaxis2", "myaxis1")
sharding = jax.sharding.NamedSharding(mesh, p)
sharded_A = jax.device_put(A, sharding)
jax.debug.visualize_array_sharding(sharded_A)

[[CudaDevice(id=0) CudaDevice(id=1)]
 [CudaDevice(id=2) CudaDevice(id=3)]
 [CudaDevice(id=4) CudaDevice(id=5)]
 [CudaDevice(id=6) CudaDevice(id=7)]]


In [None]:
# Shape of the data held by the first locally addressable shard (one device's tile).
(
    sharded_A
    .addressable_shards[0]
    .data
    .shape
)

In [None]:
import numpy as np
import jax
import jax.numpy as jnp

dim = 8
A = jnp.ones((dim, dim, dim))

# Step 1: arrange the 8 devices into a 2x2x2 device grid.
device_array = np.array(jax.devices()).reshape((2, 2, 2))

# Step 2: name the grid axes. jax.sharding.PositionalSharding (used here previously) is
# deprecated and removed in newer JAX releases. A Mesh plus a PartitionSpec that maps
# array dim i to mesh axis i gives the same layout.
mesh_3d = jax.sharding.Mesh(device_array, ("x", "y", "z"))

# Step 3: build the sharding object.
sharding = jax.sharding.NamedSharding(mesh_3d, jax.sharding.PartitionSpec("x", "y", "z"))

# Step 4: place the array on the devices; each device holds a (dim/2, dim/2, dim/2) block.
sharded_A = jax.device_put(A, sharding)

# Step 5: jax.debug.visualize_array_sharding only renders 1-D and 2-D arrays, so for this
# 3-D array inspect the per-device shard shape instead, e.g.
# sharded_A.addressable_shards[0].data.shape

In [None]:
# Per-device block of the 3-D sharded array: shape of the first addressable shard's data.
(
    sharded_A
    .addressable_shards[0]
    .data
    .shape
)

In [None]:
# Elementwise add of two operands on the same 2x4 mesh but with transposed partition
# specs: A splits rows over "myaxis1" (size 2) and columns over "myaxis2" (size 4);
# B does the reverse.
A = jnp.ones((1024, 1024))
B = jnp.ones((1024, 1024))

device_array = np.array(jax.devices()).reshape((2, 4))
mesh = jax.sharding.Mesh(device_array, ("myaxis1", "myaxis2"))

p_A = jax.sharding.PartitionSpec("myaxis1", "myaxis2")
sharding_A = jax.sharding.NamedSharding(mesh, p_A)
sharded_A = jax.device_put(A, sharding_A)
jax.debug.visualize_array_sharding(sharded_A)

p_B = jax.sharding.PartitionSpec("myaxis2", "myaxis1")
sharding_B = jax.sharding.NamedSharding(mesh, p_B)
sharded_B = jax.device_put(B, sharding_B)
jax.debug.visualize_array_sharding(sharded_B)

# The output layout of a mixed-sharding sum is chosen by the compiler; visualize it.
C = sharded_A + sharded_B
jax.debug.visualize_array_sharding(C)

In [None]:
# A large 2**16 x 2**15 array in the default float dtype (8 GiB if that is float32),
# created without a sharding and visualized.
A = jnp.ones((1 << 16, 1 << 15))
jax.debug.visualize_array_sharding(A)

In [11]:
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

ROWS = 2**15
COLS = 2**15

#A = np.ones((ROWS, COLS), dtype=np.float16)
A = jnp.ones((ROWS, COLS), dtype=jnp.bfloat16)
# One-axis mesh over every device (tuple, not a bare string, for clarity).
mesh = jax.sharding.Mesh(jax.devices(), ("ouraxis",))
# Rows split across all devices.
sharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("ouraxis"))
# Fully replicated: every device holds the whole array.
unsharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

# Optional baselines (uncomment to benchmark instead):
#@jax.jit
#def alloc_array(input):
#    return jnp.array(input, dtype=jnp.bfloat16)
#avg_time = timeit(alloc_array, A, task="alloc_array")

#@partial(jax.jit, out_shardings = sharded_sharding)
#def shard_array(input):
#    return input
#avg_time = timeit(shard_array, A, task='shard_array')

# The benchmark input must start out row-sharded. This definition was previously
# commented out, which made `A_sharded` undefined (NameError) on a fresh kernel.
A_sharded = jax.device_put(A, sharded_sharding)


@partial(jax.jit, out_shardings=unsharded_sharding)
def unshard_array(input):
    """Identity under jit whose output is forced to be fully replicated.

    Because `out_shardings` is the replicated sharding, a row-sharded input has to
    be gathered onto every device, so timing this measures that collective.
    """
    return input


avg_time = timeit(unshard_array, A_sharded, task='unshard_array')
print(avg_time)

# Simple bandwidth proxy: full array size over elapsed time. Assumes timeit returns
# milliseconds (hence / 10**3). nbytes avoids hard-coding 2 bytes per element.
achieved_bandwidth_GB_s = A.nbytes / 10**9 / (avg_time / 10**3)
print(achieved_bandwidth_GB_s)


/tmp/t_unshard_array_IFIG6IZ4RA


2025-05-25 15:51:05.127682: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 6.96GiB (7471197388 bytes) by rematerialization; only reduced to 8.50GiB (9126805520 bytes), down from 8.50GiB (9126805520 bytes) originally


32.6342
65.804697158196


In [None]:
import os
import jax
import jax.numpy as jnp

# NOTE: LIBTPU_INIT_ARGS must be set before the TPU runtime initializes; once earlier
# cells have called jax.devices(), setting it here is presumably too late.
#os.environ["LIBTPU_INIT_ARGS"] = "--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"

SIZE = 16384
BATCH_PER_CHIP = 4096
LAYERS = 4

ACTIVATIONS = jnp.ones((BATCH_PER_CHIP * jax.device_count(), SIZE), dtype=jnp.bfloat16)
Ws = [jnp.ones((SIZE, SIZE), dtype=jnp.bfloat16) for _ in range(LAYERS)]

mesh = jax.sharding.Mesh(jax.devices(), ("ouraxis",))
# Activations: rows (batch) split across devices, feature dim replicated.
activation_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("ouraxis", None))
# Weights: dim 0 split across devices. That is the contraction dim of `_act @ _weight`,
# so the compiler must move weight or partial-result data across devices every layer.
weight_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("ouraxis"))

ACTIVATIONS = jax.device_put(ACTIVATIONS, activation_sharding)
Ws = [jax.device_put(W, weight_sharding) for W in Ws]


@jax.jit
def matmul(_act, _weights):
    """Apply the weight matrices in order: _act @ W0 @ W1 @ ... (no bias or nonlinearity)."""
    for _weight in _weights:
        _act = _act @ _weight
    return _act


# Task label previously read "unshard_array" (copy-paste), mislabelling this run's trace.
average_time_ms = timeit(matmul, ACTIVATIONS, Ws, task="matmul")
print(f"{average_time_ms=}")

# Each layer is an (M x SIZE) @ (SIZE x SIZE) matmul costing 2*M*SIZE*SIZE FLOPs.
# Aggregate over all devices; assumes timeit returns milliseconds (TODO confirm).
total_flops = 2 * ACTIVATIONS.shape[0] * SIZE * SIZE * LAYERS
achieved_tflops = total_flops / 10**12 / (average_time_ms / 10**3)
print(f"{achieved_tflops=}")

In [None]:
import jax
import jax.nn as nn
import jax.numpy as jnp

BATCH = 1
HEADS = 4 
SEQUENCE = 2048
HEAD_DIM = 128

Q = jax.random.normal( jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
K = jax.random.normal( jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
V = jax.random.normal( jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))

def attention(_Q, _K, _V):
    _weights_unnormalized = jnp.einsum("BSHD,BTHD->BHST", _Q, _K)
    _weights = nn.softmax(_weights_unnormalized)
    output = jax.numpy.einsum("BHST,BTHD->BSHD", _weights, _V)
    return output

attention(Q, K, V)

attn_value = nn.dot_product_attention(Q, K, V)