# How do Mixture of Experts Layers Work? Part 2


In [Part 1](https://blog.vikrampawar.com/how-mixture-of-experts-works-part-1.html), we introduced the idea of Mixture of Experts layers and attempted a naive implementation: a simple neural network with an expert router and two experts. The model learned to route each datapoint to one of two regression models, using a synthetic dataset tailored to that setup.

In this post, we'll build upon that basic design. 

## Why Mixture of Experts?

A big reason for the popularity of Mixture of Experts layers in modern language models is sparse computation. Each token is routed to K of N experts, usually with K = 1 or 2. Per-token FLOPs in the expert layers therefore drop to roughly K/N of those of a dense network with the same parameter count, without a large performance penalty. This becomes especially useful on multi-GPU setups where each expert is assigned to its own GPU, an arrangement called expert parallelism. Routing tokens to each expert (and hence GPU) can, however, incur significant overhead due to all-to-all communication between GPUs. Careful system design can mitigate much of this overhead.
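To make the K/N figure concrete, here is a rough back-of-the-envelope sketch. The layer sizes (`d_model`, `d_hidden`), the two-matmul expert shape, and the `expert_flops_per_token` helper are illustrative assumptions, not taken from any particular model, and the (small) cost of the router itself is ignored.

```python
def expert_flops_per_token(d_model: int, d_hidden: int) -> int:
    """Approximate FLOPs for one token through one two-layer feed-forward expert.

    Assumes two matmuls (d_model -> d_hidden -> d_model) and
    2 FLOPs per multiply-add. Both are simplifying assumptions.
    """
    return 2 * (d_model * d_hidden) + 2 * (d_hidden * d_model)


# Hypothetical sizes, chosen only for illustration
d_model, d_hidden = 1024, 4096
num_experts, top_k = 8, 2

per_expert = expert_flops_per_token(d_model, d_hidden)
all_experts = num_experts * per_expert  # dense baseline: every expert runs on every token
routed = top_k * per_expert             # sparse routing: only the top-k experts run

ratio = routed / all_experts
print(f"routed / dense FLOPs per token: {ratio:.2f}")  # 0.25, i.e. top_k / num_experts
```

The ratio depends only on `top_k` and `num_experts`, not on the layer sizes, which is why the saving is usually quoted simply as K/N.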

In [1]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx

class Router(nnx.Module):
    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        self.w1 = nnx.Linear(dim, num_experts, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.w1(x)

class Expert(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)

class SimpleMoE(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        num_experts = 2
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = nnx.List([
            Expert(dim, rngs=rngs)
            for _ in range(num_experts)
        ])
        self.top_k = 2

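    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        # Router scores, one logit per expert: (..., num_experts)
        gate_logits = self.router(x)
        top_k_logits, expert_indices = jax.lax.top_k(gate_logits, self.top_k)
        # Keep only the top-k logits; the rest become -inf so softmax gives them weight 0
        masked = jnp.full_like(gate_logits, float('-inf'))
        sparse_logits = jnp.put_along_axis(
            masked, expert_indices, top_k_logits, axis=-1, inplace=False
        )
        expert_weights = jax.nn.softmax(sparse_logits, axis=-1)

        # Auxiliary load-balancing term, computed on the raw logits and averaged over all tokens
        num_experts = gate_logits.shape[-1]
        mean_gates = jnp.mean(gate_logits.reshape(-1, num_experts), axis=0)
        lb_loss = num_experts * jnp.sum(mean_gates ** 2)

        # Every expert runs on every token here (dense compute); only the weights are sparse
        outputs = [e(x) for e in self.experts]

        # Weighted sum of the expert outputs
        result = jnp.zeros_like(x)
        for i, o in enumerate(outputs):
            result += o * expert_weights[..., i:i+1]

        return result, lb_loss, expert_weights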
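
Note that `SimpleMoE` sets `top_k = 2` with `num_experts = 2`, so the top-k mask selects every expert and the gating reduces to an ordinary softmax. The standalone sketch below shows what the masking does once k is smaller than the number of experts. The `top_k_gates` helper and the example logits are hypothetical, and this version builds the mask from one-hot rows rather than `jnp.put_along_axis`; it is meant as an illustration of the idea, not a replacement for the cell above.

```python
import jax
import jax.numpy as jnp


def top_k_gates(gate_logits: jax.Array, k: int) -> jax.Array:
    """Softmax over the k largest logits per row; all other experts get weight 0."""
    num_experts = gate_logits.shape[-1]
    _, idx = jax.lax.top_k(gate_logits, k)  # (..., k) indices of the selected experts
    # Boolean mask of selected experts, built by summing one-hot rows over the k axis
    mask = jax.nn.one_hot(idx, num_experts).sum(axis=-2) > 0
    masked = jnp.where(mask, gate_logits, -jnp.inf)
    return jax.nn.softmax(masked, axis=-1)


# Hypothetical router logits: 2 tokens, 4 experts
logits = jnp.array([[2.0, 1.0, 0.5, -1.0],
                    [0.0, 3.0, -2.0, 1.0]])

sparse = top_k_gates(logits, k=2)  # two non-zero weights per token
dense = top_k_gates(logits, k=4)   # k == num_experts: same as a plain softmax
print(sparse)
```

With `k=2`, each row keeps non-zero weight only on its two highest-scoring experts, and those two weights still sum to 1.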

In [2]:
import optax 

# D batches per epoch, each of shape (B, T, C); C is also the model dimension
D, B, T, C = 10000, 2, 5, 3

model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Synthetic inputs: D * B * T points with C features each
x = jax.random.normal(jax.random.key(1000), (D * B * T, C))

# Ground-truth routing: the sign of the first feature picks the transform
expert_ids = (x[:, 0] > 0).astype(jnp.int32)
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]
def transform(xi, eid):
    # eid == 1 applies t[0]; eid == 0 applies t[1]
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(transform)(x, expert_ids)

def loss_fn(model, x, y):
    y_pred, lb_loss, gates = model(x)
    loss = jnp.mean((y - y_pred)**2) # + lb_loss
    return loss, gates

@nnx.jit
def step(model, state, x, y):
    (loss, gates), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    state.update(model, grads)
    return loss, gates, grads

x = x.reshape(D, B, T, C)
y = y.reshape(D, B, T, C)

for e in range(10):
    for i in range(D):
        loss, gates, grads = step(model, state, x[i], y[i])
        if i % 1000 == 0:
            print(i, loss)

0 2.8386207
1000 1.4691024
2000 0.6264829
3000 0.22666237
4000 0.32224867
5000 0.27109587
6000 0.036007397
7000 0.042003997
8000 0.359906
9000 0.019899525
0 0.07967947
1000 0.048338573
2000 0.07231833
3000 0.0011955692
4000 0.061102558
5000 0.027696146
6000 0.0011745306
7000 0.0014423532
8000 0.25469956
9000 0.0031980982
0 0.037122283
1000 0.021085124
2000 0.020036932
3000 0.0002973358
4000 0.028141744
5000 0.0031959899
6000 0.0005252546
7000 0.0006971828
8000 0.23215021
9000 0.0010879862
0 0.019682594
1000 0.009787401
2000 0.00533594
3000 0.00018117022
4000 0.013458978
5000 0.00086001644
6000 0.00029388236
7000 0.0004293031
8000 0.21831232
9000 0.00060997024
0 0.012715654
1000 0.004591771
2000 0.0015240598
3000 0.00012839105
4000 0.0068611316
5000 0.0005576007
6000 0.00019687101
7000 0.00030134188
8000 0.20550406
9000 0.00043756963
0 0.009739019
1000 0.002218916
2000 0.0005265196
3000 0.00010175501
4000 0.0037555827
5000 0.000450459
6000 0.00014797169
7000 0.00023094774
8000 0.1936279