In [6]:
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

# Start from a plain 4x8 float array holding 0.0 .. 31.0. No sharding is
# requested, so JAX keeps the whole array on a single device.
arr = jnp.reshape(jnp.arange(32.0), (4, 8))

# Inspect where the data lives: the set of devices holding it, the sharding
# object describing the layout, and a text diagram of that layout.
print("Array devices:", arr.devices())
print("Array sharding:", arr.sharding)
jax.debug.visualize_array_sharding(arr)

Array devices: {CpuDevice(id=0)}
Array sharding: SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)


In [7]:
# A (2, 4) mesh needs 8 devices. By default the CPU backend exposes only one,
# so fail early with an actionable hint instead of an error from make_mesh.
if jax.device_count() < 8:
    raise RuntimeError(
        f"This demo needs 8 devices but JAX sees {jax.device_count()}. On CPU, set "
        "XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax."
    )

# Arrange the devices as a 2x4 mesh with named axes 'x' and 'y'. Then split
# `arr` so its rows are partitioned over 'x' and its columns over 'y'.
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
print(sharding)
arr_sharded = jax.device_put(arr, sharding)
# device_put changes only placement, never values.
np.testing.assert_array_equal(np.asarray(arr_sharded), np.asarray(arr))

print(arr_sharded)
# 4 rows / 2 along 'x' and 8 columns / 4 along 'y': each device holds one 2x2 tile.
jax.debug.visualize_array_sharding(arr_sharded)

NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=device)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


In [8]:
# 1. Automatic sharding via jax.jit()
# Only the input carries a sharding annotation. jit compiles f_contract for
# that input, and the compiler decides how the output is laid out.
@jax.jit
def f_contract(x):
  """Sum `x` over its leading axis, e.g. a (4, 8) array -> shape (8,)."""
  return x.sum(axis=0)

result = f_contract(arr_sharded)
# Summing away the 'x'-sharded axis leaves a 1-D result. The diagram shows the
# layout the compiler picked (presumably split along 'y'; confirm from output).
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]


In [None]:
# 2. Explicit sharding
# TODO: not yet written.

In [None]:
# 3. Fully manual sharding with jax.shard_map()
# TODO: not yet written.