In [4]:
# %%
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding


In [6]:

# A 64x64 grid holding 0..4095 in row-major order; no placement is requested,
# so the visualization shows where JAX put it by default.
x = jnp.arange(64 * 64).reshape((64, 64))
jax.debug.visualize_array_sharding(x)

In [7]:
# x was created without any placement request; the recorded output shows it
# sits on a single device (SingleDeviceSharding on TPU 0).
x.sharding

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

In [58]:

# Lay out 8 devices as a 1-D device mesh wrapped in a PositionalSharding;
# reshape(4, 2) views the same devices as a 4x2 grid.
# create_device_mesh picks the device order (TPU 6/7 precede TPU 4/5 in the
# recorded output), presumably to follow the physical TPU topology.
# NOTE(review): PositionalSharding is deprecated in newer JAX releases in favor
# of Mesh + NamedSharding (used further down) — confirm against the installed version.
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
sharding.reshape(4, 2)

PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])

In [67]:

# Place x on the 4x2 device grid, collapsing the grid's first axis with
# replicate(axis=0): presumably the columns of x are split two ways and each
# half is copied onto the 4 devices of its grid column — the plot shows the result.
y = jax.device_put(x, device=sharding.reshape(4, 2).replicate(axis=0))
jax.debug.visualize_array_sharding(y)

In [26]:
# Elementwise arithmetic on the sharded array y.
# NOTE(review): this cell's recorded run (In [26]) predates the current `y`
# cell (In [67]), so the output was produced from an earlier `y`; re-run the
# notebook top to bottom to refresh it.
z = y + 1
z

Array([[   1,    2,    3, ...,   62,   63,   64],
       [  65,   66,   67, ...,  126,  127,  128],
       [ 129,  130,  131, ...,  190,  191,  192],
       ...,
       [3905, 3906, 3907, ..., 3966, 3967, 3968],
       [3969, 3970, 3971, ..., 4030, 4031, 4032],
       [4033, 4034, 4035, ..., 4094, 4095, 4096]], dtype=int32)

In [27]:
# Show how z is laid out across devices, e.g. whether `y + 1` kept y's layout.
# NOTE(review): recorded as In [27], i.e. before the current `y` cell (In [67]).
jax.debug.visualize_array_sharding(z)

In [28]:
@jax.jit
def f(x):
    """Return the elementwise negation of ``x``, compiled with ``jax.jit``."""
    return jnp.negative(x)

# Apply the jitted negation to the sharded array z and display the result.
a = f(z)
a

Array([[   -1,    -2,    -3, ...,   -62,   -63,   -64],
       [  -65,   -66,   -67, ...,  -126,  -127,  -128],
       [ -129,  -130,  -131, ...,  -190,  -191,  -192],
       ...,
       [-3905, -3906, -3907, ..., -3966, -3967, -3968],
       [-3969, -3970, -3971, ..., -4030, -4031, -4032],
       [-4033, -4034, -4035, ..., -4094, -4095, -4096]], dtype=int32)

In [29]:
# Show the device layout the jitted computation produced for its output a.
jax.debug.visualize_array_sharding(a)

In [43]:
# Shard a 64x64 array over the 4x2 PositionalSharding grid, reshape the
# already-sharded array to (64, 8, 8, 1), and count the shards this process
# holds (8 in the recorded output).
t = x.reshape(64, 64)  # x is already 64x64 (see the first cell), so this only gives it a new name
k = jax.device_put(t, sharding.reshape(4, 2)).reshape(64, 8, 8, 1)
len(k.addressable_shards)


8

In [46]:
# Identity function mapped with jax.pmap: each device receives one slice of the
# input's leading axis and hands it back unchanged.
@jax.pmap
def g(x):
    """Return ``x`` unchanged; under pmap this runs once per leading-axis slice."""
    return x

# pmap maps over the leading axis, so it is sized 8 here: one slice per TPU
# core, matching the Unstacked(8) entry over 8 devices in the recorded output.
# NOTE(review): this assumes at least 8 local devices.
p = g(t.reshape((8, 8, 64, 1)))
p.sharding

PmapSharding(sharding_spec=ShardingSpec((Unstacked(8), NoSharding(), NoSharding(), NoSharding()), (ShardedAxis(axis=0),)), devices=[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)])

In [50]:
# Named-axis sharding: a 4x2 device mesh whose axes are called 'a' (size 4) and
# 'b' (size 2), as the recorded output confirms. Mesh, PartitionSpec and
# NamedSharding come from the top import cell; mesh_utils was already imported
# there, so nothing is re-imported mid-notebook.
P = PartitionSpec  # short alias for partition specs

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('a', 'b'))
# NOTE: this rebinds `sharding`, which earlier cells use as a PositionalSharding
# (`sharding.reshape(...)`, `.replicate(...)`); re-running those cells after this
# one will likely fail because a NamedSharding does not offer those methods.
sharding = NamedSharding(mesh, P('a', 'b'))
sharding

NamedSharding(mesh={'a': 4, 'b': 2}, spec=PartitionSpec('a', 'b'))

In [70]:
# The same idea with named axes: rows unsplit (None), columns split over mesh
# axis 'b' (size 2); presumably each column half is replicated across the 4
# devices along 'a', matching the replicate(0) layout built earlier — compare
# the two plots. Note that this rebinds y.
y = jax.device_put(x, device=NamedSharding(mesh, P(None, 'b')))
jax.debug.visualize_array_sharding(y)