In [1]:
import os
# Ask XLA's host (CPU) backend to expose 2 virtual devices so the pmap
# examples below can run on a single machine. This has to happen before JAX
# initialises its backends, so keep this cell ahead of any `import jax`.
os.environ.update(XLA_FLAGS="--xla_force_host_platform_device_count=2")

## Basic Data Parallel Training Example

In [10]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import value_and_grad, jit

def init_params():
    """Build deterministic parameters for the 2 -> 2 affine model.

    Returns:
        ``(W, b)`` where ``W`` is a (2, 2) matrix of ones and ``b`` is a
        length-2 vector of zeros.
    """
    return jnp.ones((2, 2)), jnp.zeros(2)

def forward(params, x):
    """Apply the affine model ``x @ W + b``.

    Args:
        params: ``(W, b)`` pair, e.g. as returned by ``init_params``.
        x: input array whose last axis matches ``W``'s first axis.

    Returns:
        ``x @ W + b`` (bias broadcast over the leading axes of ``x``).
    """
    weight, bias = params
    return jnp.matmul(x, weight) + bias

def loss_fn(params, x, y):
    """Mean squared error of the model's prediction against targets ``y``.

    Args:
        params: ``(W, b)`` pair passed straight to ``forward``.
        x: model inputs.
        y: targets, broadcast-compatible with ``forward(params, x)``.

    Returns:
        Scalar mean of the squared residuals.
    """
    residual = forward(params, x) - y
    return jnp.mean(jnp.square(residual))

@partial(jax.pmap, axis_name="i")
def train_step(params, x, y):
    """One synchronous data-parallel SGD step, executed once per device by `pmap`.

    ``pmap`` maps axis 0 of every argument, so inside this body each device
    sees its own slice of ``params``, ``x`` and ``y`` with that axis removed.
    Gradients are averaged over the ``"i"`` axis, so every replica applies the
    same update and the parameter copies stay identical.

    Args:
        params: ``(W, b)`` pytree with a leading device axis (e.g. produced by
            ``jax.device_put_replicated``).
        x: inputs with a leading device axis; each device gets one shard.
        y: targets laid out like ``x``.

    Returns:
        ``(new_params, loss)``. ``new_params`` keeps the leading device axis.
        ``loss`` is each device's local (not cross-device averaged) loss,
        stacked along the device axis.
    """
    learning_rate = 0.01

    # Loss and gradients on this device's shard only.
    loss, grads = value_and_grad(loss_fn)(params, x, y)

    # All-reduce *mean* over the pmap axis so each replica sees the global
    # gradient. (`jax.tree_map` is deprecated/removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the stable spelling.)
    grads = jax.tree_util.tree_map(lambda g: jax.lax.pmean(g, axis_name="i"), grads)

    # Plain SGD, applied leaf-by-leaf to the (W, b) pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss

In [11]:
from jax.debug import visualize_array_sharding as viz

# pmap (inside `train_step`) maps over these local devices.
devices = jax.local_devices()
n_devices = len(devices)

B, C = 8, 2  # global batch size, feature dimension
assert B % n_devices == 0, f"batch size {B} must split evenly over {n_devices} devices"

key = jax.random.key(0)
x = jax.random.normal(key, (B, C))
y = jax.random.normal(jax.random.key(1), (B, C))

# pmap maps the leading axis of every argument onto devices. That axis must
# not exceed the device count and must match `replicated_params`' leading
# axis (n_devices), so split the global batch into one shard per device.
x_sharded = x.reshape(n_devices, B // n_devices, C)
y_sharded = y.reshape(n_devices, B // n_devices, C)

# Replicate model params across devices (adds a leading device axis).
params = init_params()
replicated_params = jax.device_put_replicated(params, devices)

# Perform one training step; `loss` holds one local loss per device.
new_params, loss = train_step(replicated_params, x_sharded, y_sharded)
loss

## Basic All-to-All Example

In [13]:
from functools import partial
import jax 
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

print(jax.devices())
n_devices = jax.device_count()

mesh = Mesh(jax.devices(), ("devices",))
spec = PartitionSpec("devices")
sharding = NamedSharding(mesh, spec)

# One (3, 3) weight matrix per device: W[i] is filled with the value i.
# With 2 devices this is exactly [zeros, ones].
W = jnp.stack([i * jnp.ones((3, 3)) for i in range(n_devices)])
W = jax.device_put(W, sharding)

# Axis 0 is mapped by pmap (one slice per device). Axis 1 is split by
# all_to_all, so its size must also equal the device count.
x = jax.random.normal(jax.random.key(0), shape=(n_devices, n_devices, 5, 3))
print(x)


@partial(jax.pmap, axis_name="i")
def step(x):
    """Route chunks to devices with all_to_all, apply per-device weights, route back.

    Device ``d`` holds ``x[d]``. The first ``all_to_all`` sends chunk ``j`` of
    axis 0 to device ``j``. Each device multiplies what it received by its own
    weight ``W[j]``, which is captured from the enclosing scope rather than
    passed in. The second ``all_to_all`` sends every chunk back to the device
    it came from.

    Returns:
        Array laid out like ``x`` with ``out[d, j] == x[d, j] @ W[j]``.
    """
    device_index = jax.lax.axis_index("i")
    # Device j receives chunk j from every device, concatenated in
    # source-device order along axis 0.
    y = jax.lax.all_to_all(x, axis_name="i", split_axis=0, concat_axis=0)
    y = y @ W[device_index]
    # Inverse exchange restores the original (source device, chunk) layout.
    return jax.lax.all_to_all(y, axis_name="i", split_axis=0, concat_axis=0)


y = step(x)
print(y.shape)
print(y)

[CpuDevice(id=0), CpuDevice(id=1)]
[[[[ 1.6226422   2.0252647  -0.43359444]
   [-0.07861735  0.1760909  -0.97208923]
   [-0.49529874  0.4943786   0.6643493 ]
   [-0.9501635   2.1795304  -1.9551506 ]
   [ 0.35857072  0.15779513  1.2770847 ]]

  [[ 1.5104648   0.970656    0.59960806]
   [ 0.0247007  -1.9164772  -1.8593491 ]
   [ 1.728144    0.04719035  0.814128  ]
   [ 0.13132767  0.28284705  1.2435943 ]
   [ 0.6902801  -0.80073744 -0.74099   ]]]


 [[[-1.5388287   0.30269185 -0.02071605]
   [ 0.11328721 -0.2206547   0.07052256]
   [ 0.8532958  -0.8217738  -0.01461421]
   [-0.15046217 -0.9001352  -0.7590727 ]
   [ 0.33309513  0.80924904  0.04269255]]

  [[-0.57767123 -0.41439894 -1.9412533 ]
   [ 1.3161184   0.7542728   0.16170931]
   [-0.03483307 -1.3306409   0.39362028]
   [ 0.48259583  0.80382955 -0.6337168 ]
   [ 1.038756   -0.74159133 -0.4299588 ]]]]
(2, 2, 5, 3)
[[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        