# Experiment with Model Sharding Using Parallax — No Code Changes Needed

This Colab demonstrates how to apply different sharding strategies to a pre-defined Flax Linen model using Parallax — without modifying the model definition or training loop.

We use a simple decoder-only Transformer model to showcase Parallax's flexibility and ease of integration.


In [None]:
from typing import Optional

from etils import ecolab
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import optax
# import treescope
# treescope.basic_interactive_setup(autovisualize_arrays=True)

with ecolab.adhoc('parallax_piper', reload='parallax'):
  from parallax import manual_sharding_linen

# Define a Decoder-Only Transformer Model

Note that the model definition includes no sharding annotations — the model is written in standard Flax Linen, without any changes for parallelism or partitioning.

In [None]:
class PositionalEmbedding(nn.Module):
  """Adds a learned positional embedding to a sequence of embeddings.

  The `pos_embedding` parameter always has shape `(seq_len, embed_dim)`. Only
  its first `x.shape[-2]` rows are added to the input, so inputs shorter than
  `seq_len` are handled correctly instead of failing to broadcast (or, for a
  length-1 input, silently broadcasting up to `seq_len` positions).
  """

  seq_len: int  # Maximum sequence length; number of rows in the table.
  embed_dim: int  # Size of each positional vector.

  @nn.compact
  def __call__(self, x):
    """Returns `x` plus the positional embedding of its positions.

    Args:
      x: Array of rank >= 2 whose last two axes are (sequence, embed_dim), with
        a sequence length of at most `seq_len`.

    Returns:
      An array with `x`'s shape (a leading axis of size 1 is added if `x` has
      rank 2).
    """
    pos_emb = self.param(
        "pos_embedding",
        nn.initializers.normal(stddev=0.02),
        (self.seq_len, self.embed_dim),
    )
    # Slice to the actual input length so a shorter sequence is matched with
    # the corresponding prefix of the table rather than the whole table.
    return x + pos_emb[None, : x.shape[-2], :]


class MLP(nn.Module):
  """Two-layer feed-forward network: Dense -> GELU -> Dropout -> Dense."""

  hidden_dim: int  # Output features of the first Dense layer.
  out_dim: int  # Output features of the second Dense layer.
  dropout_rate: float = 0.1  # Dropout rate applied after the activation.

  @nn.compact
  def __call__(self, x, deterministic: bool):
    """Applies the feed-forward network to `x`.

    Args:
      x: Input array; the Dense layers act on its last axis.
      deterministic: Forwarded to `nn.Dropout`; pass True to disable dropout.

    Returns:
      The output of the second Dense layer.
    """
    hidden = nn.Dense(self.hidden_dim)(x)
    activated = nn.gelu(hidden)
    regularized = nn.Dropout(rate=self.dropout_rate)(
        activated, deterministic=deterministic
    )
    return nn.Dense(self.out_dim)(regularized)


class SelfAttentionBlock(nn.Module):
  """Pre-LayerNorm Transformer block: self-attention then MLP, each residual."""

  embed_dim: int  # MLP output size; the MLP hidden size is 4x this value.
  num_heads: int  # Number of attention heads.
  dropout_rate: float = 0.1  # Used for attention, MLP and residual dropout.

  @nn.compact
  def __call__(
      self, x, deterministic: bool, mask: Optional[jnp.ndarray] = None
  ):
    """Runs the attention and MLP sub-layers on `x`.

    Args:
      x: Input activations.
      deterministic: Pass True to disable every dropout in the block.
      mask: Optional attention mask forwarded to `nn.SelfAttention`.

    Returns:
      The activations after both residual sub-layers.
    """
    # Sub-layer 1: LayerNorm -> self-attention -> dropout, added to the input.
    attn_in = nn.LayerNorm()(x)
    attention = nn.SelfAttention(
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        deterministic=deterministic,
        use_bias=True,
        broadcast_dropout=False,
    )
    attn_out = attention(attn_in, mask=mask)
    attn_out = nn.Dropout(rate=self.dropout_rate)(
        attn_out, deterministic=deterministic
    )
    after_attn = attn_out + x

    # Sub-layer 2: LayerNorm -> MLP -> dropout, added to the sub-layer 1 output.
    mlp_in = nn.LayerNorm()(after_attn)
    feed_forward = MLP(
        hidden_dim=4 * self.embed_dim,
        out_dim=self.embed_dim,
        dropout_rate=self.dropout_rate,
    )
    mlp_out = feed_forward(mlp_in, deterministic)
    mlp_out = nn.Dropout(rate=self.dropout_rate)(
        mlp_out, deterministic=deterministic
    )
    return mlp_out + after_attn


class DecoderOnlyTransformer(nn.Module):
  """Decoder-only Transformer that maps token ids to per-position logits."""

  vocab_size: int  # Number of token ids; also the size of the output logits.
  seq_len: int  # Length of the positional-embedding table.
  embed_dim: int  # Width of the token/positional embeddings.
  num_layers: int  # Number of stacked `SelfAttentionBlock`s.
  num_heads: int  # Attention heads per block.
  dropout_rate: float = 0.1  # Dropout rate used throughout the model.

  @nn.compact
  def __call__(self, input_ids, deterministic: bool = True):
    """Computes logits over the vocabulary for every input position.

    Args:
      input_ids: Integer token ids. Positions whose id is 0 are excluded from
        attention (presumably 0 is the padding id).
      deterministic: Pass False to enable dropout.

    Returns:
      Logits with a trailing axis of size `vocab_size`.
    """
    # Attention mask: a position may attend only to non-zero tokens that are
    # not in its future.
    is_token = input_ids > 0
    padding_mask = nn.make_attention_mask(is_token, is_token, dtype=jnp.bool_)
    causal_mask = nn.make_causal_mask(input_ids)
    attention_mask = nn.combine_masks(padding_mask, causal_mask)

    # Token embedding + positional embedding, followed by dropout.
    hidden = nn.Embed(self.vocab_size, self.embed_dim)(input_ids)
    hidden = PositionalEmbedding(
        seq_len=self.seq_len, embed_dim=self.embed_dim
    )(hidden)
    hidden = nn.Dropout(rate=self.dropout_rate)(
        hidden, deterministic=deterministic
    )

    # Stack of decoder blocks sharing the same mask.
    for _ in range(self.num_layers):
      block = SelfAttentionBlock(
          embed_dim=self.embed_dim,
          num_heads=self.num_heads,
          dropout_rate=self.dropout_rate,
      )
      hidden = block(hidden, deterministic=deterministic, mask=attention_mask)

    # Project back to vocabulary size.
    return nn.Dense(self.vocab_size)(hidden)



# Define Sharding Rules in a Separate, JSON-like Configuration

Sharding strategies for inputs, parameters, and outputs are specified in a standalone, JSON-like configuration (a nested Python dictionary of `PartitionSpec`s) — completely decoupled from the model code.


In [None]:
def _qkv_rules():
  """Partition rules shared by the query, key and value projections."""
  return {
      "kernel": P(None, "model", None),  # 3D: (input_dim, num_heads, head_dim)
      "bias": P("model", None),  # 2D: (num_heads, head_dim)
  }


def _layer_norm_rules():
  """Partition rules for a LayerNorm's scale and bias."""
  return {"scale": P("model"), "bias": P("model")}


def _decoder_block_rules():
  """Partition rules for one `SelfAttentionBlock` (same for every block)."""
  return {
      "SelfAttention_0": {
          "key": _qkv_rules(),
          "query": _qkv_rules(),
          "value": _qkv_rules(),
          "out": {
              # Output projection: (num_heads, head_dim, output_dim)
              "kernel": P("model", None, None),
              "bias": P("model"),
          },
      },
      "LayerNorm_0": _layer_norm_rules(),
      "LayerNorm_1": _layer_norm_rules(),
      "MLP_0": {
          "Dense_0": {"kernel": P("model", None), "bias": P("model")},
          "Dense_1": {"kernel": P(None, "model"), "bias": P("model")},
      },
  }


# Sharding rules keyed by the model's parameter tree. Each helper call returns
# a fresh dict, so no sub-dict is shared between entries.
sharding_config = {
    "mesh_axes": ("data", "model"),
    "in": P("data", None),  # Batch x Seq
    "out": P("data", None),
    "parameters": {
        "params": {
            # Input token embedding: (vocab_size, embed_dim)
            "Embed_0": {"embedding": P("data", "model")},
            # Learned positional embedding: (seq_len, embed_dim)
            "PositionalEmbedding_0": {"pos_embedding": P(None, "model")},
            # Decoder blocks 0 and 1 use identical rules.
            "SelfAttentionBlock_0": _decoder_block_rules(),
            "SelfAttentionBlock_1": _decoder_block_rules(),
            # Output head
            "Dense_0": {"kernel": P("model", None), "bias": P(None)},
        }
    },
}

# Training Utilities and Setup

In [None]:
def generate_fake_lm_data(batch_size, seq_len, vocab_size, seed=0):
  """Creates a random batch for causal language modeling.

  Args:
    batch_size: Number of sequences.
    seq_len: Number of tokens per sequence.
    vocab_size: Exclusive upper bound of the sampled token ids.
    seed: Seed for the PRNG key, so the same seed yields the same batch.

  Returns:
    `(input_ids, labels)`, both of shape `(batch_size, seq_len)`. `labels` is
    `input_ids` shifted one position to the left; the final column, which has
    no next token, is filled with 0.
  """
  rng = jax.random.PRNGKey(seed)
  input_ids = jax.random.randint(
      rng, shape=(batch_size, seq_len), minval=0, maxval=vocab_size
  )
  # Rolling wraps the first token around to the end; overwrite that column.
  shifted = jnp.roll(input_ids, -1, axis=1)
  labels = shifted.at[:, -1].set(0)
  return input_ids, labels

In [None]:
def create_train_state(apply_fn, params, learning_rate):
  """Builds a Flax `TrainState` that updates `params` with Adam.

  Args:
    apply_fn: Function stored on the state and later called as
      `state.apply_fn(params, inputs)`.
    params: Initial parameters to optimize.
    learning_rate: Learning rate passed to `optax.adam`.

  Returns:
    The value returned by `train_state.TrainState.create`.
  """
  return train_state.TrainState.create(
      apply_fn=apply_fn,
      params=params,
      tx=optax.adam(learning_rate),
  )

### Define a Single Training Step

In [None]:
@jax.jit
def train_step(state, batch_x, batch_y):
  """Runs one jitted optimization step.

  Args:
    state: Train state providing `apply_fn`, `params` and `apply_gradients`.
    batch_x: Model inputs, passed to `state.apply_fn(params, batch_x)`.
    batch_y: Integer labels compared against the logits at every position.

  Returns:
    `(new_state, metrics)` where `metrics` holds the mean cross-entropy
    `"loss"` and the mean argmax `"accuracy"` over all positions.
  """

  def compute_loss(params):
    predictions = state.apply_fn(params, batch_x)
    per_position = optax.softmax_cross_entropy_with_integer_labels(
        predictions, batch_y
    )
    # The logits are returned as auxiliary output so that accuracy can be
    # computed without a second forward pass.
    return per_position.mean(), predictions

  grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  new_state = state.apply_gradients(grads=grads)

  predicted_ids = jnp.argmax(logits, axis=-1)
  metrics = {"loss": loss, "accuracy": jnp.mean(predicted_ids == batch_y)}
  return new_state, metrics

### Define the Training Loop (Unchanged Across Sharding Strategies)

The training loop remains exactly the same, regardless of the chosen sharding strategy.


In [None]:
def _show_param_sharding(params, path):
  """Prints a label for the parameter at `path` and visualizes its sharding.

  Args:
    params: Nested parameter dict (the contents of the "params" collection).
    path: Sequence of dict keys leading to the array to visualize. The printed
      label is generated from this path so it always names the array shown.
  """
  print("Parameter sharding: " + ", ".join(path))
  leaf = params
  for key in path:
    leaf = leaf[key]
  jax.debug.visualize_array_sharding(leaf)


def train(sharding_config):
  """Shards the decoder model per `sharding_config` and trains for 10 steps.

  The model, data and hyperparameters are fixed; only the sharding rules vary
  between calls. A few representative parameter shardings are visualized
  before training starts, and loss/accuracy are printed after every step.

  Args:
    sharding_config: Sharding rules handed unchanged to
      `manual_sharding_linen.shard_linen_model`.
  """
  batch_size = 64
  seq_len = 128
  vocab_size = 1024
  learning_rate = 1e-3

  # Generate fake LM data (inputs and targets)
  X, y = generate_fake_lm_data(batch_size, seq_len, vocab_size)

  rng = jax.random.PRNGKey(0)
  model = DecoderOnlyTransformer(
      vocab_size=vocab_size,
      seq_len=seq_len,
      embed_dim=512,
      num_layers=2,
      num_heads=8,
  )

  # A small dummy batch is enough to initialize the parameters.
  dummy_input = jnp.ones((2, seq_len), dtype=jnp.int32)
  init_vars = model.init(rng, dummy_input)

  # Shard model and get pjit-wrapped apply_fn
  pjit_runner, sharded_params, mesh = manual_sharding_linen.shard_linen_model(
      model, init_vars, sharding_config
  )

  # Create state with sharded params and pjit apply function
  state = create_train_state(pjit_runner, sharded_params, learning_rate)

  # Visualize a few representative parameters. The top-level "Dense_0" is the
  # output head of the model; it is not part of any MLP_0.
  params = state.params["params"]
  _show_param_sharding(params, ("Embed_0", "embedding"))
  _show_param_sharding(
      params, ("SelfAttentionBlock_1", "MLP_0", "Dense_1", "kernel")
  )
  _show_param_sharding(
      params, ("SelfAttentionBlock_1", "MLP_0", "Dense_1", "bias")
  )
  _show_param_sharding(params, ("Dense_0", "kernel"))
  _show_param_sharding(params, ("Dense_0", "bias"))

  # Training loop inside mesh context
  with mesh:
    for epoch in range(10):
      state, metrics = train_step(state, X, y)
      print(
          f"Epoch {epoch}: Loss = {metrics['loss']:.4f}, Accuracy ="
          f" {metrics['accuracy'] * 100:.2f}%"
      )

# Train the Decoder with the Provided Sharding Rules

In [None]:
# Shard and train the model using the first set of rules (`sharding_config`).
train(sharding_config)

Sharded param ('params', 'Embed_0', 'embedding') with PartitionSpec('data', 'model')
Sharded param ('params', 'PositionalEmbedding_0', 'pos_embedding') with PartitionSpec(None, 'model')
Sharded param ('params', 'SelfAttentionBlock_0', 'LayerNorm_0', 'scale') with PartitionSpec('model',)
Sharded param ('params', 'SelfAttentionBlock_0', 'LayerNorm_0', 'bias') with PartitionSpec('model',)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'query', 'kernel') with PartitionSpec(None, 'model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'query', 'bias') with PartitionSpec('model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'key', 'kernel') with PartitionSpec(None, 'model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'key', 'bias') with PartitionSpec('model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'value', 'kernel') with PartitionSpec(None, 'model', None)

# Now Let's Change the Sharding Rules

In [None]:
def _make_block_rules(mlp_out_kernel, mlp_out_bias):
  """Builds the partition rules for one `SelfAttentionBlock`.

  Only the rules of the MLP's second Dense layer differ between the blocks of
  this configuration, so they are passed in; everything else is common.

  Args:
    mlp_out_kernel: PartitionSpec for `MLP_0/Dense_1/kernel`.
    mlp_out_bias: PartitionSpec for `MLP_0/Dense_1/bias`.

  Returns:
    A fresh nested dict of PartitionSpecs for the block.
  """

  def qkv():
    return {
        "kernel": P(None, "model", None),  # 3D: (input_dim, num_heads, head_dim)
        "bias": P("model", None),  # 2D: (num_heads, head_dim)
    }

  def layer_norm():
    return {"scale": P("model"), "bias": P("model")}

  return {
      "SelfAttention_0": {
          "key": qkv(),
          "query": qkv(),
          "value": qkv(),
          "out": {
              # Output projection: (num_heads, head_dim, output_dim)
              "kernel": P("model", None, None),
              "bias": P("model"),
          },
      },
      "LayerNorm_0": layer_norm(),
      "LayerNorm_1": layer_norm(),
      "MLP_0": {
          "Dense_0": {"kernel": P("model", None), "bias": P("model")},
          "Dense_1": {"kernel": mlp_out_kernel, "bias": mlp_out_bias},
      },
  }


# Second set of sharding rules. Entries marked "changed" differ from
# `sharding_config`.
sharding_config_2 = {
    "mesh_axes": ("data", "model"),
    "in": P("data", None),  # Batch x Seq
    "out": P("data", None),
    "parameters": {
        "params": {
            # Input token embedding: (vocab_size, embed_dim)
            "Embed_0": {"embedding": P(None, "model")},  # <-- changed
            # Learned positional embedding: (seq_len, embed_dim)
            "PositionalEmbedding_0": {"pos_embedding": P(None, "model")},
            # Decoder block 0
            "SelfAttentionBlock_0": _make_block_rules(
                mlp_out_kernel=P(None, "model"),
                mlp_out_bias=P("model"),
            ),
            # Decoder block 1
            "SelfAttentionBlock_1": _make_block_rules(
                mlp_out_kernel=P("data", None),  # <-- changed
                mlp_out_bias=P("data"),  # <-- changed
            ),
            # Output head
            "Dense_0": {
                "kernel": P("data", "model"),  # <-- changed
                "bias": P("data"),  # <-- changed
            },
        }
    },
}

# Train with New Sharding Rules — No Changes to Model or Training Code



In [None]:
# Re-run the same `train` function with the second set of rules; only the
# configuration argument differs from the previous run.
train(sharding_config_2)

Sharded param ('params', 'Embed_0', 'embedding') with PartitionSpec(None, 'model')
Sharded param ('params', 'PositionalEmbedding_0', 'pos_embedding') with PartitionSpec(None, 'model')
Sharded param ('params', 'SelfAttentionBlock_0', 'LayerNorm_0', 'scale') with PartitionSpec('model',)
Sharded param ('params', 'SelfAttentionBlock_0', 'LayerNorm_0', 'bias') with PartitionSpec('model',)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'query', 'kernel') with PartitionSpec(None, 'model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'query', 'bias') with PartitionSpec('model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'key', 'kernel') with PartitionSpec(None, 'model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'key', 'bias') with PartitionSpec('model', None)
Sharded param ('params', 'SelfAttentionBlock_0', 'SelfAttention_0', 'value', 'kernel') with PartitionSpec(None, 'model', None)