# How do Mixture of Experts Layers Work? Part 2

## Introduction

In [Part 1](https://blog.vikrampawar.com/how-mixture-of-experts-works-part-1.html), we introduced Mixture of Experts (MoE) layers and built a naive implementation: a simple neural network with an expert router and two experts, trained on a synthetic dataset tailored to the model. The model learned to route each datapoint to one of two regression models.

In this post, we'll build upon that basic design and explore how to scale MoE layers across multiple GPUs.

## Why Mixture of Experts?

The main reason to use Mixture of Experts is that it lets the parameter count grow faster than the compute. In a dense model, every parameter is used for every input: double the parameters, double the FLOPs. MoE breaks this relationship by activating only a subset of parameters for each input. In modern MoE models, each token is routed to K of N experts (typically K = 1 or 2). Per-token FLOPs in the expert layers are therefore K/N of what a dense network with the same total parameter count would require, without a large penalty in model quality.

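As a result, you can train much larger models while keeping the compute budget under control. Ignoring the shared (non-expert) layers, a 400B-parameter MoE model with 8 experts and top-2 routing activates 2/8 of its parameters per token, so its per-token compute is roughly that of a 100B dense model while it has 4x as many parameters to draw on. All 400B parameters still have to be held in memory.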
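The arithmetic behind that comparison can be checked with a short sketch. The helper `active_params` and its `shared_params` argument are illustrative names, not part of any library, and the first call assumes for simplicity that all 400B parameters sit in the expert layers.

```python
def active_params(expert_params, n_experts, top_k, shared_params=0.0):
    """Parameters touched per token: shared layers plus K of N experts."""
    per_expert = expert_params / n_experts
    return shared_params + top_k * per_expert


# Idealized case from the text: every parameter belongs to an expert.
total = 400e9
active = active_params(expert_params=total, n_experts=8, top_k=2)
print(f"active per token: {active / 1e9:.0f}B of {total / 1e9:.0f}B")

# Hypothetical split: 40B shared (attention, embeddings) + 360B in experts.
active_with_shared = active_params(360e9, n_experts=8, top_k=2, shared_params=40e9)
print(f"with shared layers: {active_with_shared / 1e9:.0f}B active")
```

Once shared layers are counted, the active fraction sits above K/N, which is why the 4x figure is best read as an upper bound on the savings.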

## The Parallelism Challenge

The parallelism challenge in MoE is about what happens when you try to distribute experts across multiple GPUs. In a naive implementation, you might run every token through every expert and then weight the outputs. This is dense computation masquerading as MoE: you don't get any compute savings.
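A minimal sketch of that naive variant makes the cost visible. The function name `dense_moe`, the shapes, and the single-matrix experts are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def dense_moe(x, router_w, expert_w):
    """Naive 'soft' MoE: every expert runs on every token.

    x: (tokens, d), router_w: (d, n_experts), expert_w: (n_experts, d, d).
    """
    gates = jax.nn.softmax(x @ router_w, axis=-1)        # (tokens, n_experts)
    all_out = jnp.einsum("td,edh->teh", x, expert_w)     # (tokens, n_experts, d)
    return jnp.einsum("te,teh->th", gates, all_out)      # gate-weighted sum


k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
tokens, d, n_experts = 8, 4, 4
x = jax.random.normal(k1, (tokens, d))
router_w = jax.random.normal(k2, (d, n_experts))
expert_w = jax.random.normal(k3, (n_experts, d, d))

y = dense_moe(x, router_w, expert_w)

# Expert matmul FLOPs: 2*d*d per token per expert.
dense_flops = tokens * n_experts * 2 * d * d   # all N experts per token
sparse_flops = tokens * 1 * 2 * d * d          # what top-1 routing would need
print(y.shape, dense_flops, sparse_flops)
```

The expert FLOPs here scale with N, not K, so adding experts makes each token proportionally more expensive.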

To get actual sparse computation, you need to only send each token to its assigned K experts. But when experts live on different GPUs, this creates a coordination problem. Token A on GPU 0 might need Expert 3 on GPU 2. Token B on GPU 1 might need Expert 1 on GPU 0. Every GPU potentially needs to send tokens to every other GPU, creating a many-to-many communication pattern.
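To make the many-to-many pattern concrete, here is a small pure-Python sketch that builds the send-count matrix for one routing step. It assumes one expert per GPU and uses a random choice as a stand-in for the router; `send_counts` is an illustrative name.

```python
import random

random.seed(0)
n_gpus = 4            # one expert per GPU, so expert e lives on GPU e
tokens_per_gpu = 6

# send_counts[src][dst] = number of tokens GPU `src` must send to GPU `dst`
send_counts = [[0] * n_gpus for _ in range(n_gpus)]
for src in range(n_gpus):
    for _ in range(tokens_per_gpu):
        expert = random.randrange(n_gpus)   # stand-in for the router's top-1 choice
        send_counts[src][expert] += 1

for src, row in enumerate(send_counts):
    print(f"GPU {src} sends {row}")

# Column sums are the per-expert load, which is generally uneven.
received = [sum(row[dst] for row in send_counts) for dst in range(n_gpus)]
print("tokens received per GPU:", received)
```

Every off-diagonal entry is traffic that crosses the interconnect, and the column sums show how uneven the per-GPU workload can become.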

[DIAGRAM: The Parallelism Challenge]
*Suggestion: A before/after view of token-to-expert routing across GPUs. Left side shows GPUs with their local batch slices, right side shows tokens redistributed to their assigned experts. Arrows crossing between GPUs to show the many-to-many pattern.*

This introduces several challenges. First, all-to-all communication is expensive, especially across nodes. Second, if the router sends many tokens to one expert, that GPU becomes a bottleneck while others sit idle. Third, each GPU needs memory to buffer incoming and outgoing tokens during the shuffle. Finally, you need two shuffles: one to dispatch tokens to experts, and another to combine results back into their original positions.

## Expert Parallelism Explained

Expert parallelism solves the routing problem by distributing experts across GPUs and using all-to-all communication to shuffle tokens to their assigned experts. The sharding strategy has three components:

- Expert weights are sharded on the expert dimension, so that GPU i holds expert i.
- Non-expert layers (attention, embeddings, normalization) are replicated across all GPUs.
- Data batches are sharded on the batch dimension, so that GPU i processes batch slice i.

This combines data parallelism with expert parallelism.
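As a rough illustration, the three layouts can be written as `NamedSharding` objects over a one-dimensional mesh. The tensor names and sizes are placeholders, and the sketch assumes one expert per device and a batch that divides evenly across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("devices",))
n_dev = devices.size

n_experts, d_model = n_dev, 8     # assumption: one expert per device
batch, seq = 2 * n_dev, 4         # batch must divide evenly across devices

expert_sharding = NamedSharding(mesh, P("devices", None, None))  # split on expert axis
replicated = NamedSharding(mesh, P())                            # full copy everywhere
batch_sharding = NamedSharding(mesh, P("devices", None, None))   # split on batch axis

expert_w = jax.device_put(jnp.zeros((n_experts, d_model, d_model)), expert_sharding)
attn_w = jax.device_put(jnp.zeros((d_model, d_model)), replicated)
tokens = jax.device_put(jnp.zeros((batch, seq, d_model)), batch_sharding)

for name, arr in [("expert_w", expert_w), ("attn_w", attn_w), ("tokens", tokens)]:
    per_device = arr.sharding.shard_shape(arr.shape)
    print(f"{name}: global {arr.shape}, per-device {per_device}")
```

On a single device the mesh has size one and nothing is actually split, but the same code describes the layout on any device count.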

**MoE Sharding Strategy**

![moe-sharding-strategy](moe-sharding-strategy-v2.png)

## Token Routing - Dispatch and Combine

The token flow works in four phases:

1. **Local gather.** Each GPU independently gathers its tokens into expert-specific buffers based on routing decisions. After this step, each GPU has organized its tokens by which expert they need.
2. **All-to-all dispatch.** The expert buffers are redistributed across GPUs so that each expert's tokens are collected on their respective GPU. GPU 0 receives all tokens destined for expert 0, GPU 1 receives all tokens for expert 1, and so on.
3. **Expert compute.** Each GPU runs its local expert on the tokens it received.
4. **All-to-all combine.** The processed tokens are routed back to their original GPUs and reassembled in the correct sequence positions.
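The four phases can be traced in a toy simulation that uses plain Python lists in place of GPUs. Everything here is a simplification for illustration: routing is top-1, `route` is a stand-in router, and expert `e` just multiplies its input by `e + 1`.

```python
n_gpus = 3  # one expert per GPU


def expert_fn(e, x):
    return x * (e + 1)          # toy expert


def route(x):
    return int(x) % n_gpus      # stand-in for the router's top-1 decision


# Local batch slice held by each GPU.
local_tokens = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]

# Phase 1: each GPU groups its tokens by destination expert,
# remembering each token's original position for the combine step.
send = [[[] for _ in range(n_gpus)] for _ in range(n_gpus)]   # send[src][dst]
for src, toks in enumerate(local_tokens):
    for pos, x in enumerate(toks):
        send[src][route(x)].append((pos, x))

# Phase 2: all-to-all dispatch. GPU dst receives send[src][dst] from every src.
recv = [[send[src][dst] for src in range(n_gpus)] for dst in range(n_gpus)]

# Phase 3: each GPU runs its local expert on everything it received.
done = [[[(pos, expert_fn(dst, x)) for pos, x in chunk] for chunk in recv[dst]]
        for dst in range(n_gpus)]

# Phase 4: all-to-all combine. Results return to their source GPU and position.
outputs = [[None] * len(toks) for toks in local_tokens]
for dst in range(n_gpus):
    for src in range(n_gpus):
        for pos, y in done[dst][src]:
            outputs[src][pos] = y

print(outputs)
```

Carrying the original position alongside each token is what lets the combine step put results back in sequence order.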

![token-dispatch-combine](token-dispatch-combine.png)

In JAX, the all-to-all communication emerges implicitly from sharding constraints. You specify how tensors should be partitioned (e.g., tokens sharded by batch, expert weights sharded by expert index), and XLA's compiler inserts the necessary collectives when an operation requires data that lives on another device. For MoE, this means the dispatch and combine shuffles happen automatically when the sharding layout changes between "tokens grouped by batch" and "tokens grouped by expert."
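A minimal sketch of this idea, assuming a buffer laid out as `(source shard, expert, slot, feature)`: the only thing the function does is constrain the output to a different layout, and the compiler decides how to move the data. The names `by_source`, `by_expert`, and `dispatch` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("devices",))
n_dev = devices.size

# buffers[src, expert, slot, :]: tokens that source shard `src` gathered for `expert`
n_experts, capacity, d_model = n_dev, 2, 4
by_source = NamedSharding(mesh, P("devices", None, None, None))  # grouped by batch shard
by_expert = NamedSharding(mesh, P(None, "devices", None, None))  # grouped by expert


@jax.jit
def dispatch(buffers):
    # Only the layout changes; XLA chooses whichever collective realizes it.
    return jax.lax.with_sharding_constraint(buffers, by_expert)


buffers = jnp.arange(n_dev * n_experts * capacity * d_model, dtype=jnp.float32)
buffers = buffers.reshape(n_dev, n_experts, capacity, d_model)
buffers = jax.device_put(buffers, by_source)

routed = dispatch(buffers)
print(routed.sharding)

# The compiled HLO text shows which collectives, if any, were inserted.
hlo_text = dispatch.lower(buffers).compile().as_text()
print("all-to-all" in hlo_text)
```

Which collective appears depends on the backend and device count, so the HLO check is a way to inspect the result, not a guarantee. On a single device no communication is needed at all.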

The dispatch and combine phases come with real costs. MoE trades compute savings (activating only K of N experts) for communication overhead: each token must be sent to each of its K assigned experts and the result returned, so up to roughly `2 × K × num_tokens × hidden_dim` activation values, each several bytes wide, are moved per MoE layer in the forward pass. All-to-all is also a synchronization barrier: every GPU must wait for the slowest one to finish sending and receiving, so load imbalance amplifies latency. Finally, each GPU needs buffer memory to stage outgoing tokens (grouped by destination) and incoming tokens (from all other GPUs), increasing peak memory usage beyond what the expert weights alone require.
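A back-of-the-envelope version of that estimate, with the top-k factor and element width made explicit. The helper name and the example sizes are hypothetical.

```python
def all_to_all_bytes(num_tokens, hidden_dim, top_k=1, bytes_per_element=2):
    """Approximate bytes moved per MoE layer in the forward pass: dispatch plus combine."""
    one_way = num_tokens * top_k * hidden_dim * bytes_per_element
    return 2 * one_way


# Hypothetical sizes: 8192 tokens per step, hidden size 4096, top-2, bf16 activations.
volume = all_to_all_bytes(num_tokens=8192, hidden_dim=4096, top_k=2, bytes_per_element=2)
print(f"{volume / 2**20:.0f} MiB per MoE layer")
```

This is an upper bound, since tokens whose expert is on the same GPU never cross the interconnect. The backward pass adds a comparable amount of traffic for the gradients.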

## Capacity Factor and Token Dropping

The router doesn't distribute tokens evenly: one expert might receive 80% of tokens while another gets 5%. To handle this, MoE implementations define an expert capacity: `capacity = (batch_tokens / num_experts) × capacity_factor`, where the capacity factor (typically 1.0 to 2.0) controls how much slack each expert has. With top-K routing, each token produces K assignments, so the token count is scaled by K as well. When more tokens are routed to an expert than its capacity allows, the excess tokens are dropped: they skip the expert entirely and pass through via the residual connection.
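The effect of the capacity factor on dropping can be seen in a small sketch. It assumes top-1 routing, rounds capacity down, and fills experts in arrival order; the function names are illustrative.

```python
def expert_capacity(batch_tokens, num_experts, capacity_factor):
    """capacity = (batch_tokens / num_experts) * capacity_factor, rounded down."""
    return int(batch_tokens / num_experts * capacity_factor)


def assign_with_capacity(expert_ids, num_experts, capacity):
    """Keep tokens in arrival order until an expert is full; drop the rest."""
    counts = [0] * num_experts
    kept, dropped = [], []
    for token, e in enumerate(expert_ids):
        if counts[e] < capacity:
            counts[e] += 1
            kept.append(token)
        else:
            dropped.append(token)   # would fall through on the residual connection
    return kept, dropped


# 16 tokens, 4 experts, and a skewed router that favours expert 0.
expert_ids = [0] * 10 + [1] * 3 + [2] * 2 + [3] * 1
for cf in (1.0, 1.5, 2.0):
    cap = expert_capacity(len(expert_ids), 4, cf)
    kept, dropped = assign_with_capacity(expert_ids, 4, cap)
    print(f"capacity_factor={cf}: capacity={cap}, dropped={len(dropped)}")
```

Raising the factor from 1.0 to 2.0 cuts the drops for the overloaded expert, while the three lightly loaded experts carry more unused slots.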

The capacity factor controls a tradeoff: higher values mean fewer dropped tokens but more buffer memory and wasted compute on padding; lower values keep memory tight but risk dropping tokens. To minimize dropping without inflating capacity, MoE models use auxiliary load balancing losses that penalize uneven routing, pushing the router toward balanced assignments.
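One common form of such a loss is the one used in Switch Transformers: `alpha * N * sum_i f_i * P_i`, where `f_i` is the fraction of tokens whose top-1 expert is `i` and `P_i` is the mean router probability for expert `i`. The sketch below assumes that formulation; other MoE models use different variants, and the default `alpha` is only a placeholder.

```python
import jax
import jax.numpy as jnp


def load_balancing_loss(router_logits, alpha=0.01):
    """Switch-Transformer-style auxiliary loss: alpha * N * sum_i f_i * P_i."""
    n_experts = router_logits.shape[-1]
    probs = jax.nn.softmax(router_logits, axis=-1)             # (tokens, N)
    top1 = jnp.argmax(router_logits, axis=-1)                  # (tokens,)
    f = jnp.mean(jax.nn.one_hot(top1, n_experts), axis=0)      # fraction routed, (N,)
    p = jnp.mean(probs, axis=0)                                # mean probability, (N,)
    return alpha * n_experts * jnp.sum(f * p)


# Balanced: each of 4 experts is the clear favourite for a quarter of the tokens.
balanced = 5.0 * jnp.tile(jnp.eye(4), (2, 1))
# Collapsed: every token prefers expert 0.
collapsed = jnp.tile(jnp.array([5.0, 0.0, 0.0, 0.0]), (8, 1))

print(load_balancing_loss(balanced, alpha=1.0))
print(load_balancing_loss(collapsed, alpha=1.0))
```

The loss is lowest when routing is uniform and grows as tokens concentrate on one expert. Gradients reach the router only through `P_i`, since the argmax behind `f_i` is not differentiable.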

## Putting It Together: A Sparse MoE Layer

The implementation below demonstrates sparse token routing, with sharding constraints marking where the layout changes between batch-sharded and expert-sharded. The `Experts` class holds all expert weights in a single tensor of shape `(n_experts, n_embed, n_embed)` and applies them with one batched `einsum` over the expert dimension. The `MOE` class implements the full dispatch-compute-combine pattern: a router produces per-token logits, `jax.lax.top_k` selects the top-k experts for each token, and a masked softmax computes the expert weights. During dispatch, a `fori_loop` iterates through tokens and copies each to its assigned experts' input buffers. Each expert then processes only its gathered tokens. During combine, another `fori_loop` scatters the expert outputs back to their original positions, weighted by the router's softmax scores. The training example uses synthetic data where the target transformation depends on the sign of the first input feature, so the model has to learn to route tokens to the correct expert; the reported loss drops from 2.69 to 0.0004. Note that with `top_k = 2` and `n_experts = 2`, every token visits both experts, so the routing is learned through the softmax weights, not through hard selection.

```python
import os

# Simulate two devices on a single host so the sharding code can run on CPU.
# This must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

import flax.nnx as nnx
import optax

# Set up device mesh - all devices along a single "devices" axis
mesh = Mesh(jax.devices(), ("devices",))

# Both specs shard the leading axis across devices:
# axis 0 = expert for expert-major tensors, axis 0 = token for token-major tensors.
expert_spec = PartitionSpec("devices",)
batch_spec = PartitionSpec("devices",)


@dataclass(unsafe_hash=True)
class Config:
    name: str = "MoE"
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    top_k: int = 2
    load_factor: float = 1.00  # the "capacity factor" from the text
    n_experts: int = 2
    n_embed: int = 3
    n_mlp_hidden: int = 6
    mlp_bias: bool = True

config = Config()


class Experts(nnx.Module):
    def __init__(self, config, rngs):
        # Shard expert weights on the expert dimension (axis 0)
        init = nnx.with_partitioning(
            nnx.initializers.normal(stddev=0.02),
            sharding=("devices", None, None),
        )
        self.w1 = nnx.Param(init(rngs.default(),
            (
                config.n_experts,
                config.n_embed,
                config.n_embed
            )
        ))

    def __call__(self, x):
        # x: (n_experts, tokens_per_expert, n_embed)
        # Apply sharding constraint before expert computation
        x = jax.lax.with_sharding_constraint(x, expert_spec)
        # Each expert processes its slice: einsum over expert dimension
        y = jnp.einsum('eti,eio->eto', x, self.w1.value)
        y = jax.lax.with_sharding_constraint(y, expert_spec)
        return y


class MOE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        # Router is replicated (not sharded)
        self.router_gate = nnx.Linear(
            config.n_embed,
            config.n_experts,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.normal(stddev=0.02),
                sharding=(None, None)  # replicated 2D kernel
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros,
                sharding=(None,)
            ),
            use_bias=config.mlp_bias,
            dtype=config.dtype,
            rngs=rngs,
        )
        self.experts = Experts(config, rngs)
        self.top_k = config.top_k
        self.n_experts = config.n_experts
        self.load_factor = config.load_factor
        self.add_noise = False
        self.rngs = rngs

    def __call__(self, x):
        B, T, C = x.shape
        x_flat = x.reshape(-1, C)

        # Router produces expert logits for each token
        logits = self.router_gate(x_flat)  # (B*T, n_experts)

        # Select top-k experts per token
        top_k_logits, expert_indices = jax.lax.top_k(logits, self.top_k)

        # Create sparse logits for softmax (only top-k positions have real values)
        zeros = jnp.full_like(logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            zeros, expert_indices, top_k_logits, axis=-1, inplace=False
        )
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # --- DISPATCH: gather tokens into expert buffers ---
        expert_capacity = int(self.load_factor * self.top_k * B * T // self.n_experts)
        expert_inputs = jnp.zeros((self.n_experts, expert_capacity, C), dtype=x.dtype)
        input_counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)

        def update_expert_inputs(i, carry):
            expert_inputs, counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = counters[expert_idx]
                # Slots past expert_capacity are out of bounds; mode="drop"
                # skips those writes, which drops the overflowing token.
                expert_inputs = expert_inputs.at[expert_idx, token_pos].set(
                    x_flat[i], mode="drop")
                counters = counters.at[expert_idx].add(1)
            return expert_inputs, counters

        expert_inputs, input_counters = jax.lax.fori_loop(
            0, B * T, update_expert_inputs, (expert_inputs, input_counters)
        )

        # Apply sharding constraint to trigger all-to-all dispatch
        # This moves tokens from batch-sharded to expert-sharded layout
        expert_inputs = jax.lax.with_sharding_constraint(expert_inputs, expert_spec)

        # --- COMPUTE: each expert processes its tokens ---
        expert_outputs = self.experts(expert_inputs)

        # --- COMBINE: scatter results back to original positions ---
        output_counters = jnp.zeros((self.n_experts,), dtype=jnp.int32)
        y_pred = jnp.zeros_like(x_flat)

        def update_expert_outputs(i, carry):
            y_pred, output_counters = carry
            for j in range(self.top_k):
                expert_idx = expert_indices[i, j]
                token_pos = output_counters[expert_idx]
                # Reads clamp out-of-bounds indices, so mask dropped tokens
                # explicitly instead of picking up another token's output.
                kept = token_pos < expert_capacity
                contribution = (
                    expert_outputs[expert_idx, token_pos] * expert_weights[i, expert_idx]
                )
                y_pred = y_pred.at[i].add(jnp.where(kept, contribution, 0.0))
                output_counters = output_counters.at[expert_idx].add(1)
            return y_pred, output_counters

        y_pred, output_counters = jax.lax.fori_loop(
            0, B * T, update_expert_outputs, (y_pred, output_counters)
        )

        # Apply sharding constraint to trigger all-to-all combine
        # This moves results from expert-sharded back to batch-sharded layout
        y_pred = jax.lax.with_sharding_constraint(y_pred, batch_spec)

        y_pred = y_pred.reshape(B, T, C)
        return y_pred


def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y - y_pred)**2)
    return loss, y_pred


@nnx.jit
def step(model, optimizer, x, y):
    (loss, y_pred), grads = nnx.value_and_grad(
        loss_fn, has_aux=True)(model, x, y)
    # Flax >= 0.11: the optimizer no longer holds the model, so pass it explicitly.
    optimizer.update(model, grads)
    return loss, grads, y_pred


# Training setup
D, B, T, C = 1000, config.n_experts, 5, config.n_embed

default = jax.random.key(69)
gate_noise = jax.random.key(42)
rngs = nnx.Rngs(default=default, gate_noise=gate_noise)

# Create the model inside a mesh context so the annotated variables can be sharded.
# With the legacy `with mesh:` context, recent Flax versions raise
# "ValueError: An auto mesh context or metadata is required if creating a
# variable with annotation sharding_names=(None,)"; see
# https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.
# jax.set_mesh is assumed here to provide the mesh context that guide asks for.
with jax.set_mesh(mesh):
    model = MOE(config, rngs)
    model.train(add_noise=False)
    tx = optax.adam(1e-2)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    # Create data sharding for batch dimension
    data_sharding = NamedSharding(mesh, PartitionSpec("devices",))

    # Synthetic data: target depends on the sign of the first feature (routing signal)
    x = jax.random.normal(jax.random.key(1000), (D, B, T, C))
    expert_ids = (x[:, :, :, 0] > 0).astype(jnp.int32)[..., None]
    t = [
        jax.random.normal(jax.random.key(2000), (C, C)),
        jax.random.normal(jax.random.key(3000), (C, C)),
    ]

    def transform(xi, eid):
        return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

    y = jax.vmap(transform)(x, expert_ids)

    # Training loop
    indices = list(range(D))
    for e in range(100):
        for i in indices:
            # Shard data across devices
            x_batch = jax.device_put(x[i], data_sharding)
            y_batch = jax.device_put(y[i], data_sharding)
            loss, grads, y_pred = step(model, optimizer, x_batch, y_batch)
            if i % 1000 == 0:
                print(e, i, loss)
```

## Practical Considerations

**Outline:**
- When to use expert parallelism vs. other strategies
- Tips for debugging distributed MoE
