# Data Sharding Across Devices


In [None]:
from pathlib import Path
import sys

from typing import Literal

import jax

import os

print(os.environ["LIBTPU_INIT_ARGS"])

platform : Literal["darwin", "colab", "cuda", "tpu"] = "darwin"

try:
    import google.colab
    platform = "colab"
except ImportError:
    devices = jax.devices()
    if any(d.platform == "gpu" for d in devices):
        platform = "cuda"
    elif any(d.platform == "tpu" for d in devices):
        platform = "tpu"

print(f"Running on {platform}")

if platform == "colab":
    !git clone https://github.com/novastar53/high_performance_jax
    !cd high_performance_jax && git pull
    !git clone https://github.com/novastar53/deepkit
    !cd deepkit && git pull
    hpj_dir = str(Path().absolute() / "high_performance_jax" / "src" )
    dt_dir = str(Path().absolute() / "deepkit" / "src" )
    sys.path.append(hpj_dir)
    print(hpj_dir)
    sys.path.append(dt_dir)
    print(dt_dir)

--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE --xla_tpu_use_enhanced_launch_barrier=true


https://symbolize.stripped_domain/r/?trace=7fd2f308f251,7fd3afa4251f,7fd2efc9f01e,7fd2efc9e4b7,7fd2efc9d6ec,7fd2efc9d39d,7fd2ec0a6ff1,7fd2ec099ce8,7fd2ec098529,7fd2ec0947aa,7fd2e9e2c759,7fd2e916990d,7fd2fc09aed3,7fd2fc09abab,7fd2fbf47d45,7fd30444a4e0,7fd3b01e4e37&map= 
*** SIGFPE (@0x7fd2f308f251), see go/stacktraces#s15 received by PID 103042 (TID 103042) on cpu 18; stack trace: ***
PC: @     0x7fd2f308f251  (unknown)  (unknown)
    @     0x7fd2f5d1f565       1904  (unknown)
    @     0x7fd3afa42520  2052567888  (unknown)
    @     0x7fd2efc9f01f        416  (unknown)
    @     0x7fd2efc9e4b8        544  (unknown)
    @     0x7fd2efc9d6ed         80  (unknown)
    @     0x7fd2efc9d39e         96  (unknown)
    @     0x7fd2ec0a6ff2       1248  (unknown)
    @     0x7fd2ec099ce9        144  (unknown)
    @     0x7fd2ec09852a        592  (unknown)
    @     0x7fd2ec0947ab        272  (unknown)
    @     0x7fd2e9e2c75a        320  (unknown)
    @     0x7fd2e916990e        544  (unknown)
 

In [None]:
import os

print(os.environ["LIBTPU_INIT_ARGS"])
os.environ["LIBTPU_INIT_ARGS"] = "--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"
print(os.environ["LIBTPU_INIT_ARGS"])

from deepkit.utils import timeit
from high_performance_jax.single_chip_performance import *
import matplotlib.pyplot as plt

dtype = jnp.bfloat16
devices = jax.devices()
print("Devices:")
for i,d in enumerate(devices):
  print(f"{i+1}. {d.device_kind}")    # e.g. “TPU v3”

In [13]:
A = jnp.ones((1024, 1024))
jax.debug.visualize_array_sharding(A)

In [None]:
import numpy as np
device_array = np.array(jax.devices()).reshape((4,2))
print(device_array)
mesh = jax.sharding.Mesh(device_array, ["myaxis1", "myaxis2"])
p = jax.sharding.PartitionSpec("myaxis2", "myaxis1")
sharding = jax.sharding.NamedSharding(mesh, p)
sharded_A = jax.device_put(A, sharding)
jax.debug.visualize_array_sharding(sharded_A)

In [None]:
sharded_A.addressable_shards[0].data.shape

In [None]:
import numpy as np
import jax

dim = 8
A = jnp.ones((dim, dim, dim))

# Step 1: Create device mesh (no names needed)
device_array = np.array(jax.devices()).reshape((2, 2, 2))
# Step 3: Create positional sharding object
sharding = jax.sharding.PositionalSharding(device_array)

# Step 4: Place array on devices
sharded_A = jax.device_put(A, sharding)

# Step 5: Visualize sharding"
#jax.debug.visualize_array_sharding(sharded_A)

In [None]:
sharded_A.addressable_shards[0].data.shape

In [3]:
A = jnp.ones((1024, 1024))
B = jnp.ones((1024, 1024))

device_array = np.array(jax.devices()).reshape((2, 4))

mesh = jax.sharding.Mesh(device_array, ["myaxis1", "myaxis2"])

p_A = jax.sharding.PartitionSpec("myaxis1", "myaxis2")
p_B = jax.sharding.PartitionSpec("myaxis2", "myaxis1")

sharding_A = jax.sharding.NamedSharding(mesh, p_A)
sharding_B = jax.sharding.NamedSharding(mesh, p_B)

sharded_A = jax.device_put(A, sharding_A)
jax.debug.visualize_array_sharding(sharded_A)

sharded_B = jax.device_put(B, sharding_B)
jax.debug.visualize_array_sharding(sharded_B)

C = sharded_A + sharded_B
jax.debug.visualize_array_sharding(C)

NameError: name 'np' is not defined

In [3]:
A = jnp.ones((2**16, 2**15))
jax.debug.visualize_array_sharding(A)

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 8.00G. That was not possible. There are 7.48G free.; (0x0x0_HBM0)

In [3]:
from functools import partial
import jax
import numpy as np

ROWS = 2**16
COLS = 2**15

A = np.ones((ROWS, COLS), dtype=np.float16)
mesh = jax.sharding.Mesh(jax.devices(), ("ouraxis"))
sharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("ouraxis"))
unsharded_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

@jax.jit
def alloc_array(input):
    return jnp.array(input, dtype=jnp.bfloat16)
avg_time = timeit(alloc_array, A, task="alloc_array")

#@partial(jax.jit, out_shardings = sharded_sharding)
#def shard_array(input):
    #return input
#avg_time = timeit(shard_array, A, task='shard_array')
    
#A_sharded = jax.device_put(A, sharded_sharding)


#@partial(jax.jit, out_shardings = unsharded_sharding)
#def unshard_array(input):
#  return input

#avg_time = timeit(unshard_array, A_sharded, task='unshard_array')
print(avg_time)
#A_unsharded = unshard_array(A)

achieved_bandwidth_GB_s = A.size * 2 / 10**9 / (avg_time / 10**3)
print(achieved_bandwidth_GB_s)


/tmp/t_alloc_array_8W4EMT3AF9
RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 4.00G. That was not possible. There are 3.48G free.; (0x0x0_HBM0)


RuntimeError: No profile started

In [5]:
import os

os.environ["LIBTPU_INIT_ARGS"] = "--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"

SIZE = 16384
BATCH_PER_CHIP = 4096
LAYERS = 4

ACTIVATIONS = jnp.ones((BATCH_PER_CHIP*jax.device_count(), SIZE), dtype=jnp.bfloat16)
Ws = [jnp.ones((SIZE, SIZE), dtype=jnp.bfloat16) for i in range(LAYERS)]

mesh = jax.sharding.Mesh(jax.devices(), ('ouraxis'))
activation_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('ouraxis', None))
weight_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('ouraxis'))

ACTIVATIONS = jax.device_put(ACTIVATIONS, activation_sharding)
Ws = [ jax.device_put(W, weight_sharding) for W in Ws ]

@jax.jit
def matmul(_act, _weights):
    for _weight in _weights: 
        _act =  _act @ _weight
    return _act

average_time_ms = timeit(matmul, ACTIVATIONS, Ws, task="unshard_array")
print(f"{average_time_ms=}")
#achieved_bandwidth_GB_s = A.size * 2 / 10**9 / (average_time_ms / 10**3)
#print(f"{achieved_bandwidth_GB_s=}")

/tmp/t_unshard_array_IEJ5DA53SU


2025-05-25 07:42:56.180417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748158976.217553  100579 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748158976.228821  100579 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748158976.253570  100579 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748158976.253589  100579 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748158976.253594  100579 computation_placer.cc:177] computation placer alr

average_time_ms=437.4933
