#### Learning JAX Sharding

In [1]:
import os

# Simulate 9 host CPU devices so sharding can be explored on a single machine.
# This must be set before `import jax` below; to change device_count later you
# have to restart the kernel!
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=9"
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Collect every visible device (host CPUs here); the bare name echoes the list.
devices = list(jax.devices())
devices

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8)]

In [2]:
# Lay the 9 devices out as a 3x3 grid: mesh axis "x" runs over the grid's rows,
# mesh axis "y" over its columns.
mesh_devices = np.asarray(devices).reshape((3, 3))
mesh = Mesh(devices=mesh_devices, axis_names=("x", "y"))

In [3]:
# Split array dim 0 across mesh axis "x" and dim 1 across mesh axis "y".
sharding_spec = P("x", "y")
sharding = NamedSharding(mesh=mesh, spec=sharding_spec)

In [4]:
def func(sharding_spec, arr=None):
    """Place an array on the global ``mesh`` with ``sharding_spec`` and show its layout.

    Renders the layout with ``jax.debug.visualize_array_sharding`` and prints,
    for every device, the index (tuple of slices) of the global array it holds.

    Args:
        sharding_spec: ``PartitionSpec`` naming, per array dimension, which axis
            of the module-level ``mesh`` it is split across (``None`` means the
            dimension is replicated).
        arr: Array to shard. Defaults to ``jnp.arange(27).reshape(3, 9)``, which
            divides evenly over the 3x3 mesh along both dimensions. It should be
            1-D or 2-D, since the visualizer only renders those.

    Returns:
        The ``{device: index}`` mapping from ``devices_indices_map``, so callers
        can inspect the layout without parsing printed output.
    """
    if arr is None:
        # Original hard-coded demo array, kept as the default.
        arr = jnp.arange(27).reshape(3, 9)
    sharding = NamedSharding(mesh, sharding_spec)
    sharded_arr = jax.device_put(arr, sharding)
    jax.debug.visualize_array_sharding(sharded_arr)
    device_map = sharded_arr.sharding.devices_indices_map(sharded_arr.shape)
    for device, index in device_map.items():
        print(f"{device}: {index}")
    return device_map

In [7]:
# Compare four layouts: split both dims, split rows only, split columns only,
# and fully replicated.
sharding_lst = [
    P("x", "y"),
    P("x", None),
    P(None, "y"),
    P(None, None),
]
for spec in sharding_lst:
    print("-" * 100, f"\nSharding Spec: {spec}", sep="\n")
    func(spec)

----------------------------------------------------------------------------------------------------

Sharding Spec: PartitionSpec('x', 'y')


TFRT_CPU_0: (slice(0, 1, None), slice(0, 3, None))
TFRT_CPU_1: (slice(0, 1, None), slice(3, 6, None))
TFRT_CPU_2: (slice(0, 1, None), slice(6, 9, None))
TFRT_CPU_3: (slice(1, 2, None), slice(0, 3, None))
TFRT_CPU_4: (slice(1, 2, None), slice(3, 6, None))
TFRT_CPU_5: (slice(1, 2, None), slice(6, 9, None))
TFRT_CPU_6: (slice(2, 3, None), slice(0, 3, None))
TFRT_CPU_7: (slice(2, 3, None), slice(3, 6, None))
TFRT_CPU_8: (slice(2, 3, None), slice(6, 9, None))
----------------------------------------------------------------------------------------------------

Sharding Spec: PartitionSpec('x', None)


TFRT_CPU_0: (slice(0, 1, None), slice(None, None, None))
TFRT_CPU_1: (slice(0, 1, None), slice(None, None, None))
TFRT_CPU_2: (slice(0, 1, None), slice(None, None, None))
TFRT_CPU_3: (slice(1, 2, None), slice(None, None, None))
TFRT_CPU_4: (slice(1, 2, None), slice(None, None, None))
TFRT_CPU_5: (slice(1, 2, None), slice(None, None, None))
TFRT_CPU_6: (slice(2, 3, None), slice(None, None, None))
TFRT_CPU_7: (slice(2, 3, None), slice(None, None, None))
TFRT_CPU_8: (slice(2, 3, None), slice(None, None, None))
----------------------------------------------------------------------------------------------------

Sharding Spec: PartitionSpec(None, 'y')


TFRT_CPU_0: (slice(None, None, None), slice(0, 3, None))
TFRT_CPU_1: (slice(None, None, None), slice(3, 6, None))
TFRT_CPU_2: (slice(None, None, None), slice(6, 9, None))
TFRT_CPU_3: (slice(None, None, None), slice(0, 3, None))
TFRT_CPU_4: (slice(None, None, None), slice(3, 6, None))
TFRT_CPU_5: (slice(None, None, None), slice(6, 9, None))
TFRT_CPU_6: (slice(None, None, None), slice(0, 3, None))
TFRT_CPU_7: (slice(None, None, None), slice(3, 6, None))
TFRT_CPU_8: (slice(None, None, None), slice(6, 9, None))
----------------------------------------------------------------------------------------------------

Sharding Spec: PartitionSpec(None, None)


TFRT_CPU_0: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_1: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_2: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_3: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_4: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_5: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_6: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_7: (slice(None, None, None), slice(None, None, None))
TFRT_CPU_8: (slice(None, None, None), slice(None, None, None))
