# Implicit All-to-All Communication via Sharding Constraints

Explicit calls to `jax.lax.all_to_all` are not needed in recent versions of JAX: the required collective communication is inserted automatically by the compiler, based on the sharding constraints of the arrays involved.

In [6]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Lay the available devices out as a 2x2 grid whose mesh axes are named "a" and "b".
device_grid = np.asarray(jax.devices()).reshape((2, 2))
mesh = Mesh(device_grid, ("a", "b"))

# Reference operand: a 4x4 matrix of normal samples drawn from a fixed key.
key = jax.random.key(0)
x = jax.random.normal(key, shape=(4, 4))
print(x)

# Two layouts of the same data:
#   y -> dimension 0 (rows) partitioned over mesh axis "a", dimension 1 unpartitioned
#   z -> dimension 0 unpartitioned, dimension 1 (columns) partitioned over mesh axis "b"
y_sharding = NamedSharding(mesh, PartitionSpec("a", None))
z_sharding = NamedSharding(mesh, PartitionSpec(None, "b"))

y = jax.device_put(x, y_sharding)
jax.debug.visualize_array_sharding(y)

z = jax.device_put(x, z_sharding)
jax.debug.visualize_array_sharding(z)

# Multiply the two differently-sharded operands without calling any collective
# explicitly, then show how the result ends up laid out across devices.
# NOTE(review): with y split by rows on "a" and z split by columns on "b", each
# device appears to already hold what it needs for its output tile, so this
# particular product may not require any collective at all -- confirm by
# inspecting the compiled HLO before citing it as an all-to-all example.
a = jnp.matmul(y, z)
jax.debug.visualize_array_sharding(a)
print(a)

# Cell result: whether the sharded product agrees with the plain x @ x.
jnp.allclose(a, jnp.matmul(x, x))


[[ 1.6226422   2.0252647  -0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923 -0.49529874  0.4943786 ]
 [ 0.6643493  -0.9501635   2.1795304  -1.9551506 ]
 [ 0.35857072  0.15779513  1.2770847   1.5104648 ]]


[[ 2.67335     1.7171221  -2.7521129   1.6026733 ]
 [-0.03722524  1.8502135  -0.04303275  1.2207012 ]
 [ 1.6575959  -0.11029654  2.4360166  -7.736466  ]
 [ 1.9996572  -0.40228555  4.4788065  -0.1655684 ]]


Array(True, dtype=bool)