# Distributed Data Parallel (DDP) Demo
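This notebook demonstrates how Parallax DDP keeps model definition separate from sharding decisions: you write an ordinary single-device model and training step, and Parallax takes care of placing the model and the data across devices.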

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
import jax.sharding


from etils import ecolab
import treescope
# Enable treescope's interactive rendering, with automatic array visualization,
# for values displayed in this notebook.
treescope.basic_interactive_setup(autovisualize_arrays=True)

# Import the Parallax DDP entry points through ecolab's ad-hoc import context.
# NOTE(review): the first argument is spelled 'paralax' (one 'l') while the
# reloaded/imported module is 'parallax'; presumably 'paralax' names the ad-hoc
# source rather than the module — confirm it is not a typo.
with ecolab.adhoc('paralax', reload='parallax'):
  from parallax import DataParallelTraining
  from parallax import dereplicate

# TODO(review): disabled TensorBoard extension magic; delete if no longer needed.
#%load_ext google3.learning.brain.tensorboard.notebook.extension
# Debug aid: list the attributes of DataParallelTraining (see output below).
print(dir(DataParallelTraining))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_initialize_sharded_model_optimizer', '_replicate_state', '_setup_sharding', 'get_sharded_components', 'get_sharded_data']


In [2]:
# Display the devices JAX can see. Presumably these are the devices Parallax
# distributes over (the sharded kernel later spans TPU 0-7) — confirm.
jax.devices()

INFO:2025-04-04 15:20:43,936:jax._src.xla_bridge:865: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2025-04-04 15:20:43,937:jax._src.xla_bridge:865: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got 
INFO:2025-04-04 15:20:43,943:jax._src.xla_bridge:865: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'


# The user provides a model and an optimizer
No need to worry about sharding; just focus on the model design.

In [3]:
class MLP(nnx.Module):
  """Two-layer perceptron computing linear2(relu(linear1(x))).

  Args:
    din: Size of the input feature dimension.
    dmid: Width of the hidden layer.
    dout: Size of the output feature dimension.
    rngs: Random streams used to initialize both linear layers.
  """

  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    # Layer creation order determines how `rngs` is consumed; keep it fixed.
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    hidden = self.linear1(x)
    activated = nnx.relu(hidden)
    return self.linear2(activated)

# Build a 1 -> 64 -> 1 MLP with a fixed init seed and wrap it with AdamW.
model = MLP(din=1, dmid=64, dout=1, rngs=nnx.Rngs(0))
optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate=1e-2))

# Generate synthetic training data

In [4]:
def dataset(steps, batch_size, rng=None):
  """Yields `steps` synthetic regression batches for y = 0.8 * x**2 + 0.1 + noise.

  Args:
    steps: Number of batches to yield.
    batch_size: Number of examples per batch.
    rng: Optional int seed or `np.random.Generator` used to draw inputs and
      noise. When None (the default), NumPy's legacy global random state is
      used, exactly as before; pass a seed to make the data, and therefore the
      training curve, reproducible.

  Yields:
    Tuples `(x, y)` of float64 arrays with shape `(batch_size, 1)`. `x` is
    uniform on [-2, 2) and `y` adds Gaussian noise with std 0.1 to the
    quadratic target.
  """
  # `np.random` (global state) and `np.random.Generator` expose the same
  # `uniform`/`normal` call signatures used below, so either can be the source.
  # `default_rng` accepts an int seed and passes an existing Generator through.
  source = np.random if rng is None else np.random.default_rng(rng)
  for _ in range(steps):
    x = source.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + source.normal(0, 0.1, size=x.shape)
    yield x, y

# Define the train step

In [5]:
@nnx.jit
def train_step(model: MLP, optimizer: nnx.ModelAndOptimizer, x, y):
  """Runs one jitted optimization step and returns the mean-squared-error loss.

  Args:
    model: The MLP being trained; gradients are taken with respect to it.
    optimizer: Optimizer wrapper whose `update(grads)` applies the step.
    x: Batch of inputs fed to `model`.
    y: Targets compared against `model(x)`.

  Returns:
    The MSE loss evaluated with the parameters from before this update.
  """
  def mse(m: MLP):
    residual = y - m(x)
    return jnp.mean(residual ** 2)

  grad_fn = nnx.value_and_grad(mse)
  loss, grads = grad_fn(model)
  optimizer.update(grads)
  return loss

# Call Parallax to train the model with DDP

Before sharding, the model stays on the default device.

In [6]:
# Show where the freshly initialized first-layer kernel lives before Parallax
# touches anything.
jax.debug.visualize_array_sharding(model.linear1.kernel.value)

┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
│                                    TPU 0                                     │
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘


In [7]:
# Hand the optimizer (which was built around `model`) to Parallax and get back
# the model/optimizer objects to train with under data parallelism.
ddp = DataParallelTraining(optimizer)
sharded_model, sharded_optimizer = ddp.get_sharded_components()

print("Sharded model")
jax.debug.visualize_array_sharding(sharded_model.linear1.kernel.value)
# Re-visualize the original `model` to show its arrays were not moved.
print("The original model remains unchanged after sharding.")
jax.debug.visualize_array_sharding(model.linear1.kernel.value)

# Run train steps: 1000 NumPy batches of 16 examples. Each batch is converted
# by `ddp.get_sharded_data` before the jitted step runs on the sharded state.
# NOTE(review): the batch size presumably must divide evenly across the devices
# Parallax shards data over — confirm against Parallax.
for step, (x, y) in enumerate(dataset(1000, 16)):
  x, y = ddp.get_sharded_data(x, y)
  loss = train_step(sharded_model, sharded_optimizer, x, y)

  # Visualize the input layout once, on the first batch only.
  if step == 0:
    print('data is sharded')
    jax.debug.visualize_array_sharding(x)

  # Log every 100 steps; `loss` is this batch's pre-update loss.
  if step % 100 == 0:
    print(f'step={step}, loss={loss}')


INFO:2025-04-04 15:21:01,360:jax._src.mesh_utils:82: Reordering mesh to physical ring order on single-tray TPU v2/v3.


Sharded model
┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
│                             TPU 0,1,2,3,4,5,6,7                              │
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘
The original model remains unchanged after sharding.
┌─────────────────────────────────────────

# Dereplicate the model to keep a single copy on the CPU

In [8]:
# Collapse the training state held by `ddp` into a single model copy.
# NOTE(review): the recorded output prints the device as the bare string 'cpu',
# which looks like NumPy's `ndarray.device`; presumably `dereplicate` returns
# host NumPy-backed values — confirm.
dereped_model = dereplicate(ddp)
print(f"Dereplicated model stays at: {dereped_model.linear1.kernel.value.device}")


# Dereplication should leave both the original model and the sharded copy as
# they were; re-visualize both to check.
print("The original model remains unchanged after sharding.")
jax.debug.visualize_array_sharding(model.linear1.kernel.value)

print("Sharded model")
jax.debug.visualize_array_sharding(sharded_model.linear1.kernel.value)


Dereplicated model stays at: cpu
The original model remains unchanged after sharding.
┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
│                                    TPU 0                                     │
│                                                                              │
│                                                                              │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘
Sharded model
┌────────