# How do Mixture of Experts Layers Work? Part 2

## Introduction


In [Part 1](https://blog.vikrampawar.com/how-mixture-of-experts-works-part-1.html), we introduced the idea of Mixture of Experts (MoE) layers and built a naive implementation: a simple neural network with a router and two experts, trained on a synthetic dataset tailored to the model. The model learned to route each data point to one of two regression models.

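In this post, we'll build on that basic design and look at what it takes to scale it: sharding experts across devices, routing tokens between them, and capping how many tokens each expert accepts.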

## Why Mixture of Experts?

A big reason for the popularity of MoE layers in modern language models is sparse computation. Each token is routed to only K of the N experts in a layer, where K is small: K = 1 in Switch Transformers [4], and K = 2 in GShard [3] and Mixtral [7]. As a result, the per-token FLOPs in the expert layers are roughly K/N times those of a dense layer that runs every expert on every token, often without a large loss in model quality. This sparsity pairs naturally with multi-GPU setups, where each expert (or group of experts) is placed on its own GPU. This is called expert parallelism. However, sending tokens to their experts (and hence to other GPUs) requires all-to-all communication between devices, which can add significant overhead. Much of this cost can be mitigated through design choices such as expert capacity limits and load balancing.
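To make the K/N ratio concrete, here is a minimal sketch of top-k routing in JAX. The sizes, the random router matrix `router_w`, and the choice N = 4, K = 2 are illustrative assumptions: the router scores every expert, and each token keeps only its K highest-scoring ones.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration only.
num_tokens, dim, num_experts, k = 8, 16, 4, 2

k_x, k_w = jax.random.split(jax.random.key(0))
x = jax.random.normal(k_x, (num_tokens, dim))
router_w = jax.random.normal(k_w, (dim, num_experts))  # assumed linear router

# The router scores every expert, but each token keeps only its top-k.
logits = x @ router_w                              # (num_tokens, num_experts)
top_vals, top_idx = jax.lax.top_k(logits, k)       # (num_tokens, k) each
gates = jax.nn.softmax(top_vals, axis=-1)          # renormalise over chosen experts

# Expert evaluations: dense runs all N experts per token, sparse only K.
dense_evals = num_tokens * num_experts
sparse_evals = top_idx.size
print(dense_evals, sparse_evals, sparse_evals / dense_evals)  # 32 16 0.5
```

Renormalising the softmax over only the selected logits gives the same weights as the `-inf` masking used in the `SimpleMoE` code later in this post.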

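With MoE layers, it becomes possible to train models with far more parameters while keeping per-token FLOPs roughly fixed: adding experts grows the parameter count, but each token is still processed by only K of them.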
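A back-of-the-envelope calculation shows how parameters and compute decouple. The layer sizes below are hypothetical, and each expert is assumed to be a two-matrix feed-forward block without biases.

```python
# Hypothetical layer sizes, chosen only for round numbers.
d_model, d_ff = 1024, 4096
num_experts, top_k = 8, 2

# Assume each expert is a two-matrix feed-forward block
# (d_model -> d_ff -> d_model), ignoring biases.
params_per_expert = 2 * d_model * d_ff

total_params = num_experts * params_per_expert   # stored in the layer
active_params = top_k * params_per_expert        # actually used by one token

# A matmul costs about 2 FLOPs (multiply + add) per weight per token.
flops_per_token = 2 * active_params

print(f"total params:    {total_params:,}")      # 67,108,864
print(f"active params:   {active_params:,}")     # 16,777,216
print(f"FLOPs per token: {flops_per_token:,}")   # 33,554,432
print(f"active fraction: {active_params / total_params}")  # 0.25
```

Doubling `num_experts` here would double `total_params` while leaving `flops_per_token` unchanged, which is the trade-off MoE layers exploit.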

## The Parallelism Challenge

**Outline:**
- Why naive implementations don't scale
- The token-to-expert routing problem across devices
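
One way to see why the naive approach does not scale: it evaluates every expert on every token and only afterwards multiplies the unselected outputs by zero. The sketch below (hypothetical sizes, each expert a single linear map) checks that this gives the same result as evaluating only the selected experts. The sparse version gathers a weight matrix per token purely to show the equivalence; practical implementations instead group tokens by expert.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; each "expert" is a single linear map for simplicity.
num_tokens, dim, num_experts, k = 6, 4, 4, 2
kx, kw, kr = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (num_tokens, dim))
expert_w = jax.random.normal(kw, (num_experts, dim, dim))
router_w = jax.random.normal(kr, (dim, num_experts))

top_vals, top_idx = jax.lax.top_k(x @ router_w, k)   # (T, k)
gates = jax.nn.softmax(top_vals, axis=-1)            # (T, k)

# Naive: run every expert on every token, then zero out unselected experts.
all_out = jnp.einsum("td,nde->nte", x, expert_w)     # (N, T, dim): N x the work
rows = jnp.arange(num_tokens)[:, None]
dense_gates = jnp.zeros((num_tokens, num_experts)).at[rows, top_idx].set(gates)
naive = jnp.einsum("nte,tn->te", all_out, dense_gates)

# Sparse: evaluate only the k selected experts per token.
sel_w = expert_w[top_idx]                            # (T, k, dim, dim)
sparse = jnp.einsum("td,tkde,tk->te", x, sel_w, gates)

print(jnp.allclose(naive, sparse, atol=1e-5))
```

The intermediate `all_out` grows with the number of experts N, while the useful work grows only with K.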

## Expert Parallelism Explained

**Outline:**
- Sharding experts across GPUs
- How tokens flow: local → global → expert → back
- Diagram/visualization of the data flow
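
As a minimal sketch of sharding experts, the stacked expert weights can be given a `NamedSharding` whose leading axis is split over a device-mesh axis. The axis name `"expert"` and `experts_per_device` are assumptions for illustration. The code also runs on a single device, where the mesh simply has size 1.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh whose only axis is named "expert" (hypothetical name).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("expert",))

num_devices = len(devices)
experts_per_device = 2                     # hypothetical
num_experts = num_devices * experts_per_device
dim = 8

# Stacked expert weights: the leading axis indexes the expert.
expert_w = jnp.ones((num_experts, dim, dim))

# Split the expert axis across the "expert" mesh axis; replicate the rest.
sharding = NamedSharding(mesh, P("expert", None, None))
expert_w = jax.device_put(expert_w, sharding)

# Each device now holds an (experts_per_device, dim, dim) slice.
for shard in expert_w.addressable_shards:
    print(shard.device, shard.data.shape)
```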

## Token Routing - Dispatch and Combine

**Outline:**
- Dispatch: sending tokens to their assigned experts
- Combine: gathering results back to original positions
- How JAX's sharding annotations trigger the right communication
- Understanding the cost implications
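
The dispatch and combine steps can be sketched on a single device with one-hot dispatch tensors, in the style of GShard [3]. This is an illustrative assumption, not the implementation used later in this post: routing is top-1, the sizes are hypothetical, and each expert buffer has `C = T` slots so no token is dropped. In an expert-parallel setup the expert axis of `expert_in` would be sharded across devices, and the dispatch and combine einsums are where all-to-all communication would be needed.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; top-1 routing keeps the bookkeeping simple.
T, D, N = 8, 4, 3            # tokens, model dim, experts
C = T                        # slots per expert buffer; T slots means nothing is dropped

kx, kr, kw = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (T, D))
logits = x @ jax.random.normal(kr, (D, N))
expert_w = jax.random.normal(kw, (N, D, D))

expert_idx = jnp.argmax(logits, axis=-1)                    # (T,)
gate = jnp.max(jax.nn.softmax(logits, axis=-1), axis=-1)    # (T,)

# Slot of each token inside its expert's buffer: running count per expert.
onehot = jax.nn.one_hot(expert_idx, N)                      # (T, N)
position = jnp.sum((jnp.cumsum(onehot, axis=0) - 1) * onehot, axis=-1).astype(jnp.int32)

# dispatch[t, n, c] = 1 if token t sits in slot c of expert n.
dispatch = onehot[:, :, None] * jax.nn.one_hot(position, C)[:, None, :]  # (T, N, C)

# Dispatch: gather tokens into per-expert buffers.
expert_in = jnp.einsum("tnc,td->ncd", dispatch, x)          # (N, C, D)
expert_out = jnp.einsum("ncd,nde->nce", expert_in, expert_w)

# Combine: scatter results back to token order, scaled by the gate.
combine = dispatch * gate[:, None, None]
y = jnp.einsum("tnc,nce->te", combine, expert_out)          # (T, D)

# Reference: apply each token's chosen expert directly.
ref = gate[:, None] * jnp.einsum("td,tde->te", x, expert_w[expert_idx])
print(jnp.allclose(y, ref, atol=1e-5))
```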

## Capacity Factor and Token Dropping

**Outline:**
- Expert capacity limits
- What happens when experts overflow
- Balancing computation vs. accuracy
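
Expert capacity is commonly set as in Switch Transformers [4]: split the tokens evenly across experts and scale by a capacity factor. The sketch below uses hypothetical helpers `expert_capacity` and `keep_mask` to compute that capacity and mark which tokens overflow when routing is imbalanced. Tokens are admitted first come, first served in token order, which is one simple policy among several.

```python
import math

import jax
import jax.numpy as jnp


def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float) -> int:
    # Even share of tokens per expert, scaled by the capacity factor.
    return math.ceil(capacity_factor * num_tokens / num_experts)


def keep_mask(expert_idx: jax.Array, num_experts: int, capacity: int) -> jax.Array:
    # Position of each token within its expert's buffer, in token order.
    onehot = jax.nn.one_hot(expert_idx, num_experts)
    position = jnp.sum((jnp.cumsum(onehot, axis=0) - 1) * onehot, axis=-1)
    return position < capacity  # False means the token overflows and is dropped


# Hypothetical, deliberately imbalanced routing: 6 of 8 tokens pick expert 0.
expert_idx = jnp.array([0, 0, 0, 1, 0, 0, 2, 0])
cap = expert_capacity(num_tokens=8, num_experts=4, capacity_factor=1.25)
keep = keep_mask(expert_idx, num_experts=4, capacity=cap)
print(cap, keep)  # cap is 3; tokens 4, 5 and 7 overflow expert 0
```

In the Switch Transformer design, dropped tokens skip the expert computation and are carried forward by the residual connection. Raising the capacity factor reduces drops at the cost of more padded slots, and therefore more memory and compute per expert.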

## Putting It Together: A Parallel MoE Layer

**Outline:**
- Full implementation with proper sharding
- Code walkthrough with JAX/Flax

In [None]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx

class Router(nnx.Module):
    """Linear layer producing one logit per expert."""
    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        self.w1 = nnx.Linear(dim, num_experts, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.w1(x)

class Expert(nnx.Module):
    """A single linear regression expert."""
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)

class SimpleMoE(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        num_experts = 2
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = nnx.List([
            Expert(dim, rngs=rngs)
            for _ in range(num_experts)
        ])
        # With only 2 experts, top_k = 2 keeps every expert active (dense routing).
        self.top_k = 2

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        gate_logits = self.router(x)                     # (..., num_experts)
        top_k_logits, expert_indices = jax.lax.top_k(gate_logits, self.top_k)

        # Mask non-selected experts with -inf so softmax gives them zero weight.
        masked = jnp.full_like(gate_logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            masked, expert_indices, top_k_logits, axis=-1, inplace=False
        )
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # Load-balancing penalty: num_experts * sum of squared mean router probabilities,
        # averaged over all tokens. It equals 1 when probability mass is spread evenly
        # and grows towards num_experts as routing collapses onto one expert.
        num_experts = gate_logits.shape[-1]
        router_probs = jax.nn.softmax(gate_logits, axis=-1).reshape(-1, num_experts)
        mean_probs = jnp.mean(router_probs, axis=0)
        lb_loss = num_experts * jnp.sum(mean_probs ** 2)

        # Every expert runs on every token; the weights zero out unselected ones.
        outputs = [e(x) for e in self.experts]

        result = jnp.zeros_like(x)
        for i, o in enumerate(outputs):
            result += o * expert_weights[..., i:i+1]

        return result, lb_loss, expert_weights
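
The `lb_loss` term above is a simple proxy. For comparison, the auxiliary loss described in the Switch Transformer paper [4] multiplies two per-expert quantities: f_i, the fraction of tokens whose top-1 expert is i, and P_i, the mean router probability for expert i, giving N · Σ f_i P_i, which is then scaled by a small coefficient before being added to the main loss. Below is a minimal sketch of that formula; the function name is hypothetical, and it assumes top-1 assignment when computing f_i.

```python
import jax
import jax.numpy as jnp


def switch_load_balancing_loss(gate_logits: jax.Array) -> jax.Array:
    """N * sum_i f_i * P_i over all tokens (top-1 assignment assumed).

    gate_logits: (..., num_experts); leading axes are flattened into tokens.
    """
    num_experts = gate_logits.shape[-1]
    probs = jax.nn.softmax(gate_logits.reshape(-1, num_experts), axis=-1)
    # f_i: fraction of tokens whose top-1 expert is i (no gradient flows here).
    f = jnp.mean(jax.nn.one_hot(jnp.argmax(probs, axis=-1), num_experts), axis=0)
    # P_i: mean router probability for expert i (this term carries the gradient).
    p = jnp.mean(probs, axis=0)
    return num_experts * jnp.sum(f * p)


# Balanced: 8 tokens, each expert strongly preferred by exactly 2 of them.
balanced = 10.0 * jnp.tile(jnp.eye(4), (2, 1))
# Collapsed: all 8 tokens strongly prefer expert 0.
collapsed = 10.0 * jnp.tile(jnp.eye(4)[:1], (8, 1))
print(switch_load_balancing_loss(balanced), switch_load_balancing_loss(collapsed))
```

The loss is about 1 when load is even and approaches N when one expert receives everything.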

In [None]:
import optax

# D batches, each of shape (B, T, C): B sequences of T tokens with C features.
D, B, T, C = 10000, 2, 5, 3

model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Synthetic data: the sign of the first feature decides which linear map generates y.
x = jax.random.normal(jax.random.key(1000), (D * B * T, C))
expert_ids = (x[:, 0] > 0).astype(jnp.int32)
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]

def transform(xi, eid):
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(transform)(x, expert_ids)

def loss_fn(model, x, y):
    y_pred, lb_loss, gates = model(x)
    loss = jnp.mean((y - y_pred) ** 2)  # + lb_loss (load balancing disabled here)
    return loss, gates

@nnx.jit
def step(model, optimizer, x, y):
    (loss, gates), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    optimizer.update(model, grads)
    return loss, gates, grads

x = x.reshape(D, B, T, C)
y = y.reshape(D, B, T, C)

for epoch in range(10):
    for i in range(D):
        loss, gates, grads = step(model, optimizer, x[i], y[i])
        if i % 1000 == 0:
            print(f"epoch {epoch} step {i} loss {float(loss):.4f}")
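
After training, one way to check whether the router recovered the synthetic split is to compare each token's highest-weighted expert with the ground-truth cluster `x[..., 0] > 0`. Because the router may assign either expert to either cluster, the helper below (a hypothetical `routing_agreement`, written for the two-expert case only) scores agreement up to swapping the labels.

```python
import jax
import jax.numpy as jnp


def routing_agreement(gates: jax.Array, true_ids: jax.Array) -> jax.Array:
    """Fraction of tokens whose top expert matches the true cluster,
    allowing the two expert labels to be swapped.

    gates: (..., 2) expert weights; true_ids: (...) integer cluster ids in {0, 1}.
    """
    pred = jnp.argmax(gates, axis=-1).reshape(-1)
    acc = jnp.mean(pred == true_ids.reshape(-1))
    return jnp.maximum(acc, 1.0 - acc)


# Toy check: the router has the labels swapped but separates perfectly.
toy_gates = jnp.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])
toy_ids = jnp.array([1, 0, 0])
print(routing_agreement(toy_gates, toy_ids))
```

With the notebook's variables, `_, _, gates = model(x[0])` followed by `routing_agreement(gates, (x[0][..., 0] > 0).astype(jnp.int32))` would score the first batch.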

## Practical Considerations

**Outline:**
- When to use expert parallelism vs. other strategies
- Tips for debugging distributed MoE

## References

1. Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991). *Adaptive Mixtures of Local Experts*. Neural Computation.
2. Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., & Dean, J. (2017). *Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer*. ICLR.
3. Lepikhin, D., Lee, H., Xu, Y., Chen, D., Firat, O., Huang, Y., ... & Chen, Z. (2021). *GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding*. ICLR.
4. Fedus, W., Zoph, B., & Shazeer, N. (2022). *Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity*. JMLR.
5. Du, N., Huang, Y., Dai, A. M., Tong, S., Lepikhin, D., Xu, Y., ... & Le, Q. V. (2022). *GLaM: Efficient Scaling of Language Models with Mixture-of-Experts*. ICML.
6. Riquelme, C., Puigcerver, J., Mustafa, B., Neumann, M., Jenatton, R., Susano Pinto, A., ... & Houlsby, N. (2021). *Scaling Vision with Sparse Mixture of Experts*. NeurIPS.
7. Jiang, A. Q., Sablayrolles, A., Roux, A., Mensch, A., Savary, B., Bamford, C., ... & El Sayed, W. (2024). *Mixtral of Experts*. arXiv.