## Shardy:MPMD intro for JAX users

Shardy:MPMD is a new MPMD partitioning system, built in MLIR and integrated on top of JAX.

This colab demonstrates how JAX users who already use `jax.jit` can pipeline their programs with MPMD. See our RFC for more details.

**Note**: This colab is read-only and cannot be executed until all of its
components are fully open-sourced.

## Overview
This colab:
1. Defines a simplified Transformer (without the encode and decode stages) in SPMD with `jax.jit` and some sharding, and then
2. Demonstrates how to pipeline it with MPMD using different schedules.



### Set up
We connect to the Pathways server, inspect the devices, and load each slice into its own mesh. Mesh names are "m0", "m1", ...

This colab assumes we have 8 devices.

In [None]:
# @title Imports and connect to Pathways server { form-width: "80px" }
pathways_server_xid = 171321046  # @param {type: "number"}

from pprint import pprint

import jax
import jax.numpy as jnp
import numpy as np

import mpmd  # Shardy MPMD python lib
import pathways_launch

# Enable the Shardy partitioner in JAX.
jax.config.update("jax_use_shardy_partitioner", True)

# Mock API standing in for connecting to Pathways on Cloud TPUs; the xid comes
# from the Colab form field above.
pathways_launch.connect(pathways_server_xid)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# @title Check devices:
devices = jax.devices()
print(f"Total num devices: {len(devices)}")
# The base mesh below is laid out as 4 "stage" rows x 2 "data" columns, so this
# colab needs exactly 8 devices. Fail early with a clear message instead of an
# opaque reshape error.
if len(devices) != 8:
  raise RuntimeError(f"This colab expects 8 devices, found {len(devices)}.")
mesh = jax.sharding.Mesh(np.array(devices).reshape(4, 2), ("stage", "data"))
print("Base mesh: ", mesh)

Total num devices: 8
Base mesh:  Mesh('stage': 4, 'data': 2, axis_types=(Auto, Auto))


In [None]:
# @title Set up basic topology and assignment
# Each row of the base mesh becomes its own single-row mesh named "m{row}".
# Both axis names are kept so shardings can still refer to "data".
topology = {
    f"m{stage_idx}": jax.sharding.Mesh(
        stage_devices.reshape(1, 2), ("stage", "data")
    )
    for stage_idx, stage_devices in enumerate(mesh.devices)
}

print("MPMD topology: ")
pprint(topology)

MPMD topology: 
{'m0': Mesh(device_ids=array([[0, 1]]), axis_names=('stage', 'data'), axis_types=(Auto, Auto)),
 'm1': Mesh(device_ids=array([[2, 3]]), axis_names=('stage', 'data'), axis_types=(Auto, Auto)),
 'm2': Mesh(device_ids=array([[4, 5]]), axis_names=('stage', 'data'), axis_types=(Auto, Auto)),
 'm3': Mesh(device_ids=array([[6, 7]]), axis_names=('stage', 'data'), axis_types=(Auto, Auto))}


### Define a basic Transformer and util functions

In [None]:
import flax.linen as nn

# Model hyperparameters.
BATCH_SIZE = 2  # leading dim of each microbatch (see dummy_input)
SEQ_LEN = 8_192  # 8 * 1024
D_MODEL = 1_024
MLP_DIM = D_MODEL * 4  # hidden width of each Block's feed-forward layer
NUM_LAYERS = 24


class Block(nn.Module):
  """One post-LayerNorm Transformer layer: attention followed by an MLP.

  Each sub-layer's output is added back to its input (residual connection)
  and the sum is passed through LayerNorm. The call is wrapped in `nn.jit`
  and rematerialized with `nn.remat`.
  """

  @nn.remat
  @nn.jit
  @nn.compact
  def __call__(self, x):
    # Attention sub-layer; only `x` is passed, so it is used for queries,
    # keys and values.
    attended = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)(x)
    x = nn.LayerNorm()(x + attended)

    # Feed-forward sub-layer: expand to MLP_DIM, apply GELU, and project back
    # to the input width.
    hidden = nn.gelu(nn.Dense(features=MLP_DIM)(x))
    projected = nn.Dense(features=x.shape[-1])(hidden)
    return nn.LayerNorm()(x + projected)


class Transformer(nn.Module):
  """NUM_LAYERS `Block`s named "block_0", "block_1", ... applied in order."""

  @nn.compact
  def __call__(self, x):
    hidden = x
    for layer_idx in range(NUM_LAYERS):
      hidden = Block(name=f"block_{layer_idx}")(hidden)
    return hidden


# Initialize the model's parameters by running it once on an all-ones input.
key = jax.random.PRNGKey(0)
transformer = Transformer()
dummy_input = jnp.ones((BATCH_SIZE, SEQ_LEN, D_MODEL))
params = transformer.init(key, dummy_input)["params"]

print("Model initialized successfully!")

Model initialized successfully!


In [None]:
from flax.training import train_state
import optax
import time

# AdamW with a fixed learning rate of 1e-3. TrainState bundles the apply
# function, the parameters and the optimizer.
optimizer = optax.adamw(learning_rate=1e-3)
state = train_state.TrainState.create(
    tx=optimizer,
    params=params,
    apply_fn=transformer.apply,
)


# Define the training step.
def train_step(state, xs, targets):
  """Trains the model for one step.

  Computes the mean-squared-error loss and its gradients for every microbatch
  in `xs` against the same `targets`. It sums both over the microbatches and
  applies the summed gradients once.

  Args:
    state: Train state exposing `params`, `apply_fn` and `apply_gradients`.
    xs: Sequence of microbatch inputs (must be non-empty).
    targets: Targets shared by every microbatch.

  Returns:
    `(new_state, summed_loss)`.
  """

  def loss_fn(params, x, targets):
    predictions = state.apply_fn({"params": params}, x)
    return jnp.mean((predictions - targets) ** 2)

  grad_fn = jax.value_and_grad(loss_fn)
  total_loss = total_grads = None
  for microbatch in xs:
    mb_loss, mb_grads = grad_fn(state.params, microbatch, targets)
    if total_loss is None:
      # First microbatch seeds the accumulators.
      total_loss, total_grads = mb_loss, mb_grads
    else:
      total_loss = total_loss + mb_loss
      total_grads = jax.tree.map(lambda acc, g: acc + g, total_grads, mb_grads)

  new_state = state.apply_gradients(grads=total_grads)
  return new_state, total_loss


def train_with_progress(train_step, inputs, num_steps=3):
  """Runs `num_steps` timed training steps after one untimed warm-up call.

  Args:
    train_step: Callable `(state, xs, targets) -> (state, loss)`.
    inputs: `(state, xs, targets)` tuple; `xs` and `targets` are reused for
      every step while the state is threaded through.
    num_steps: Number of timed steps.

  Prints the loss after every even-numbered step, the final loss, and the
  wall-clock time of the timed steps.
  """
  current_state, batch, targets = inputs
  last_loss = None

  # Warm-up: result discarded and excluded from timing.
  jax.block_until_ready(train_step(current_state, batch, targets))

  t0 = time.perf_counter()
  for step in range(1, num_steps + 1):
    current_state, last_loss = train_step(current_state, batch, targets)
    if step % 2 == 0:
      print(f"Training loss after step {step}: {last_loss}")
  # Wait for the last step to finish before stopping the clock.
  jax.block_until_ready(current_state)
  elapsed = time.perf_counter() - t0

  print(f"Final training loss: {last_loss}")
  print(f"Training took: {elapsed:.2f} seconds")

print("Model util functions initialized.")

Model util functions initialized.


### Run the Transformer

In [None]:
# Set up inputs.
# The number of microbatches is set to the number of pipeline stages (one
# mesh per stage), which is what the pipelined runs below use.
NUM_PIPELINE_STAGE = len(topology)
NUM_MB = NUM_PIPELINE_STAGE
print("Num microbatches: ", NUM_MB)

# Every microbatch is the same all-ones array; the targets are `dummy_input`.
xs = (jnp.ones_like(dummy_input),) * NUM_MB
inputs = (state, xs, dummy_input)

def get_param_sharding(x):
  """Returns the SPMD sharding on `mesh` for one leaf of the train state.

  Leaves with rank >= 1 have their first dimension sharded over the combined
  ("stage", "data") axes. Scalars and leaves without a `shape` are replicated.
  """
  if len(getattr(x, "shape", [])) == 0:
    return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  return jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec(("stage", "data"))
  )

# Data parallel + ZeRO 3: the train state is sharded on stage + data (see
# `get_param_sharding`); the microbatches and the targets are sharded on "data".
in_shardings = (
    jax.tree.map(get_param_sharding, state),
    *(
        jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
        for _ in ("xs", "targets")
    ),
)

Num microbatches:  4


In [None]:
# Simple SPMD training with micro-batching.
jitted_train_step = jax.jit(train_step, in_shardings=in_shardings)
# Compile ahead of time from `inputs`, then place the inputs according to
# `in_shardings` before running.
compiled = jitted_train_step.lower(*inputs).compile()
sharded_inputs = jax.device_put(inputs, in_shardings)

train_with_progress(compiled, sharded_inputs)

Training loss after step 2: 7.991689205169678
Final training loss: 7.980075836181641
Training took: 10.06 seconds


Profile:

![spmd_profile](https://raw.githubusercontent.com/openxla/shardy/main/rfcs/images/2025-06-18-mpmd-rfc/spmd_profile.png)

### Pipeline the transformer

In [None]:
# To use MPMD, wrap each Block in mpmd.flax.named_computation so it can be
# assigned to a mesh by name, and wrap each microbatch step in mpmd.call.
class AnnotatedTransformer(nn.Module):
  """Same layer stack as `Transformer`, with each Block named "block_{i}"."""

  @nn.compact
  def __call__(self, x):
    for layer_idx in range(NUM_LAYERS):
      named_block = mpmd.flax.named_computation(
          Block, name=f"block_{layer_idx}"
      )
      x = named_block()(x)
    return x


def mpmd_train_step(state, xs, targets):
  """Trains the model for one step with mpmd microbatching.

  Sums the MSE loss and its gradients over the microbatches in `xs`. Each
  microbatch step is wrapped in `mpmd.call` with its index as
  `call_counter`. The summed gradients are applied once at the end.

  Returns:
    `(new_state, summed_loss)`.
  """

  def loss_fn(params, x):
    predictions = state.apply_fn({"params": params}, x)
    return jnp.mean((predictions - targets) ** 2)

  # Running (loss, grads) totals, starting from zeros.
  carry = jnp.zeros(()), jax.tree.map(jnp.zeros_like, state.params)

  # Accumulation happens inside the mpmd.call, so the totals are updated as
  # each microbatch finishes. Accumulating only at the end would be bad for
  # memory.
  def microbatch_step(carry, params, x):
    loss_and_grads = jax.value_and_grad(loss_fn)(params, x)
    return jax.tree.map(lambda acc, new: acc + new, carry, loss_and_grads)

  for call_idx, microbatch in enumerate(xs):
    # mpmd.call wraps the accumulation function; call_counter tags each call
    # with its microbatch index.
    carry = mpmd.call(microbatch_step, call_counter=call_idx)(
        carry, state.params, microbatch
    )

  summed_loss, summed_grads = carry
  return state.apply_gradients(grads=summed_grads), summed_loss


# Build the annotated model and its train state the same way as the SPMD one.
annotated_transformer = AnnotatedTransformer()
annotated_params = annotated_transformer.init(key, dummy_input)["params"]
annotated_state = train_state.TrainState.create(
    tx=optimizer,
    params=annotated_params,
    apply_fn=annotated_transformer.apply,
)
# Used both for lowering and as the values placed on devices before running.
annotated_placeholder_inputs = (annotated_state, xs, dummy_input)

# Split the layers into contiguous, equally sized groups, one group per mesh;
# any remainder layers go to the last mesh. `layers_per_mesh` is clamped to 1
# so that having fewer layers than meshes does not divide by zero.
num_meshes = len(topology)
layers_per_mesh = max(1, NUM_LAYERS // num_meshes)
basic_assignment = {
    f"block_{i}": f"m{min(i // layers_per_mesh, num_meshes - 1)}"
    for i in range(NUM_LAYERS)
}


print("Name to mesh assignment:")
pprint(basic_assignment)

mpmd_config = mpmd.make_config(
    topology=topology,
    name_to_mesh_assignment=basic_assignment,
    partitioning_options=mpmd.make_partitioning_options({
        "mpmd_pipeline_schedule": "ONE_FWD_ONE_BWD",
    }),
)

Name to mesh assignment:
{'block_0': 'm0',
 'block_1': 'm0',
 'block_10': 'm1',
 'block_11': 'm1',
 'block_12': 'm2',
 'block_13': 'm2',
 'block_14': 'm2',
 'block_15': 'm2',
 'block_16': 'm2',
 'block_17': 'm2',
 'block_18': 'm3',
 'block_19': 'm3',
 'block_2': 'm0',
 'block_20': 'm3',
 'block_21': 'm3',
 'block_22': 'm3',
 'block_23': 'm3',
 'block_3': 'm0',
 'block_4': 'm0',
 'block_5': 'm0',
 'block_6': 'm1',
 'block_7': 'm1',
 'block_8': 'm1',
 'block_9': 'm1'}


In [None]:
# Similar to `get_param_sharding`, except that the "stage" axis is now used for
# pipelining, so we don't shard on it.
def get_sharding_for_pipeline_state(x):
  """Returns a sharding on `mpmd_config.sharding_mesh` for one state leaf.

  Leaves with rank >= 1 are sharded along their first dimension over "data".
  Scalars and leaves without a `shape` are replicated.
  """
  has_rank = len(getattr(x, "shape", [])) > 0
  spec = (
      jax.sharding.PartitionSpec("data")
      if has_rank
      else jax.sharding.PartitionSpec()
  )
  return jax.sharding.NamedSharding(mpmd_config.sharding_mesh, spec)


# Data parallel only: the annotated train state via
# `get_sharding_for_pipeline_state`; the microbatches and the targets on "data".
in_shardings = (
    jax.tree.map(get_sharding_for_pipeline_state, annotated_state),
    *(
        jax.sharding.NamedSharding(
            mpmd_config.sharding_mesh, jax.sharding.PartitionSpec("data")
        )
        for _ in ("xs", "targets")
    ),
)

#### Running the pipelined transformer

Now, we execute the transformer.

Note that we haven't had to annotate other parts of our program, such as the
optimizer or the loss. Nor have we had to do anything special for the gradient
computations: these are handled by the compiler and merged into an appropriate
program.

Furthermore, we haven't introduced any cross-mesh transfers explicitly. These
are created automatically at name-to-name boundaries: e.g. when going from
"block_{i}" to "block_{i+1}", if the two are assigned to different meshes, we
create the cross-mesh transfer.

We execute the transformer with various schedules; the schedule is applied at
jit-time. Schedules can also be orchestrated manually with `mpmd.jit`, but we've
found the flexibility of jit-time scheduling to be beneficial.

Note in the profile below that some of the blocks have been compiled to
multiple programs. E.g. the backward computation of blocks 0..5 has programs
p7, p10 and p14. This is because of how we've merged in the unannotated ops:
the first backward computation (p7) also initializes the gradient accumulators,
and the last one also contains the param updates, which is why they are
different.

In [None]:
schedule = "ONE_FWD_ONE_BWD"
mpmd_jitted_train = mpmd.jit(
    mpmd_train_step,
    mpmd_config=mpmd_config,
    in_shardings=in_shardings,
    # The partitioning API is a work in progress. For now the schedule is
    # hardcoded here; in the future we'll expose fine-grained control as
    # described in the RFC.
    partitioning_options=mpmd.make_partitioning_options(
        {"mpmd_pipeline_schedule": schedule}
    ),
).lower(*annotated_placeholder_inputs)
mpmd_compiled = mpmd_jitted_train.compile()

# With MPMD, we need to be more careful with state: place every input on the
# devices given by the lowering's input shardings.
pipelined_inputs = jax.device_put(
    annotated_placeholder_inputs,
    mpmd_jitted_train.function_named_shardings.input_specs,
)

print(f"Running program with schedule: {schedule}")
train_with_progress(mpmd_compiled, pipelined_inputs)

Running program with schedule: ONE_FWD_ONE_BWD
Training loss after step 2: 7.992943286895752
Final training loss: 7.976408004760742
Training took: 11.30 seconds


Profile:

![mpmd_1f1b_profile](https://raw.githubusercontent.com/openxla/shardy/main/rfcs/images/2025-06-18-mpmd-rfc/mpmd_1f1b_profile.png)

In [None]:
def run_xprof_with_schedule(schedule, assignment, stage_assignment=None):
  """Compiles `mpmd_train_step` with the given pipeline schedule and trains.

  Args:
    schedule: Value of the "mpmd_pipeline_schedule" partitioning option,
      e.g. "GPIPE" or "CIRCULAR".
    assignment: Mapping from computation name (e.g. "block_0") to mesh name.
    stage_assignment: Optional mapping from computation name to stage id.

  Reads the module-level `topology`, `in_shardings` and
  `annotated_placeholder_inputs`. Despite its name, this function does not
  itself collect a profile; it compiles, places the inputs and calls
  `train_with_progress`.
  """
  print(f"Running program with schedule: {schedule}")

  config = mpmd.make_config(
      topology=topology,
      name_to_mesh_assignment=assignment,
      name_to_stage_assignment=stage_assignment,
      partitioning_options=mpmd.make_partitioning_options(
          {"mpmd_pipeline_schedule": schedule}
      ),
  )
  lowered = mpmd.jit(
      mpmd_train_step, mpmd_config=config, in_shardings=in_shardings
  ).lower(*annotated_placeholder_inputs)
  compiled_step = lowered.compile()

  placed_inputs = jax.device_put(
      annotated_placeholder_inputs,
      lowered.function_named_shardings.input_specs,
  )

  train_with_progress(compiled_step, placed_inputs)


In [None]:
run_xprof_with_schedule(schedule="GPIPE", assignment=basic_assignment)

Running program with schedule: GPIPE
Training loss after step 2: 7.992943286895752
Final training loss: 7.976408004760742
Training took: 11.35 seconds


Profile:

![mpmd_gpipe_profile](https://raw.githubusercontent.com/openxla/shardy/main/rfcs/images/2025-06-18-mpmd-rfc/mpmd_gpipe_profile.png)

In [None]:
# Circular (round-robin) placement: block i goes to mesh m{i % num_meshes}.
# Stage ids group consecutive pairs: blocks 2k and 2k+1 both get stage k.
# NOTE(review): blocks 2k and 2k+1 land on different meshes yet share a stage
# id; confirm this is the intended stage numbering for CIRCULAR.
circular_assignment = {
    f"block_{i}": f"m{i % len(topology)}" for i in range(NUM_LAYERS)
}
stage_assignment = {f"block_{i}": i // 2 for i in range(NUM_LAYERS)}

run_xprof_with_schedule("CIRCULAR", circular_assignment, stage_assignment)

Running program with schedule: CIRCULAR
Training loss after step 2: 7.992943286895752
Final training loss: 7.976408004760742
Training took: 7.46 seconds


Profile:

![mpmd_circular_profile](https://raw.githubusercontent.com/openxla/shardy/main/rfcs/images/2025-06-18-mpmd-rfc/mpmd_circular_profile.png)

In [None]:
# Print the main func body of the original MPMD program (1F1B), i.e. the
# module-level `mpmd_jitted_train`, not the ones built inside
# `run_xprof_with_schedule`. Only the text between the first and second
# "func.func" markers is kept.
mlir_module = mpmd_jitted_train.as_text("mpmd")
truncated_mlir_module = mlir_module.split("func.func")[1]
print(f"func.func{truncated_mlir_module}")

func.func public @main(%arg0: !mpmd.mesh_tensor<"m0", tensor<i32>, sharding=<@mesh, []>>, %arg1: !mpmd.mesh_tensor<"m0", tensor<4096xf32>, sharding=<@mesh, [{"data"}]>>, %arg2: !mpmd.mesh_tensor<"m0", tensor<1024x4096xf32>, sharding=<@mesh, [{"data"}, {}]>>, %arg3: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}]>>, %arg4: !mpmd.mesh_tensor<"m0", tensor<4096x1024xf32>, sharding=<@mesh, [{"data"}, {}]>>, %arg5: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}]>>, %arg6: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}]>>, %arg7: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}]>>, %arg8: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}]>>, %arg9: !mpmd.mesh_tensor<"m0", tensor<8x2xf32>, sharding=<@mesh, [{"data"}, {}]>>, %arg10: !mpmd.mesh_tensor<"m0", tensor<1024x8x2xf32>, sharding=<@mesh, [{"data"}, {}, {}]>>, %arg11: !mpmd.mesh_tensor<"m0", tensor<1024xf32>, sharding=<@mesh, [{"data"}