# The Sparse Mixture-of-Experts Layer



The [Sparse Mixture-of-Experts](https://arxiv.org/abs/1701.06538) (MoE) layer computes a **sparse** combination of expert modules for a given input.

In this notebook, we provide a walkthrough for implementing a sparse MoE layer and then show how to include it in a character-level language model.

This notebook is inspired by the makeMoE blog post, available at: https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch.

*The following implementations are merely illustrative and didactic, often at the cost of performance.*

## Part I: The Sparse MoE layer

> MoE layers combine the output of $n$ expert networks $\{E_1, \dots, E_n\}$ using a gating function $G$. The following equation defines an MoE layer. It maps an input vector $x$ to a weighted sum of the experts' outputs, where the gating function assigns the weights.
>
> $$
> y = \sum_{i=1}^{n} G(x)_i E_i(x)
> $$
>
> Sparse MoE layers rely on **sparse gating functions**, which select only a subset of experts. This can be achieved with a sparse gating vector, *i.e.*, by assigning some experts zero weight. Importantly, the *computation for experts assigned zero weight can be skipped*.

At the core of the sparse MoE layer is its sparse gating (or routing) function, which assigns a weight to each expert.
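As a minimal sketch of the weighted sum above, the plain-Python snippet below combines three toy experts using a hand-written sparse gate vector. The experts (`double`, `negate`, `add_one`), the gate values, and the helper `sparse_moe` are illustrative assumptions and not part of the notebook's implementation. Experts with zero weight are never called, which is where the compute saving comes from.

```python
from typing import Callable

Vector = list[float]


def double(x: Vector) -> Vector:
    return [2.0 * v for v in x]


def negate(x: Vector) -> Vector:
    return [-v for v in x]


def add_one(x: Vector) -> Vector:
    return [v + 1.0 for v in x]


def sparse_moe(
    x: Vector,
    experts: list[Callable[[Vector], Vector]],
    gate: list[float],
) -> tuple[Vector, list[int]]:
    """Return sum_i gate[i] * experts[i](x) and the indices of the experts that ran."""
    y = [0.0] * len(x)
    called = []
    for i, (expert, weight) in enumerate(zip(experts, gate)):
        if weight == 0.0:
            continue  # zero weight: skip the expert's computation entirely
        called.append(i)
        out = expert(x)
        y = [acc + weight * o for acc, o in zip(y, out)]
    return y, called


x = [1.0, 2.0]
gate = [0.75, 0.0, 0.25]  # hand-written sparse gate: expert 1 is not selected
y, called = sparse_moe(x, [double, negate, add_one], gate)
print(y)       # [2.0, 3.75]
print(called)  # [0, 2]
```

Only experts 0 and 2 are evaluated; `negate` contributes nothing and costs nothing.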

> The sparse gating function $G : \mathbb{R}^N \to \mathbb{R}^E$ maps an input token to a sparse vector holding the weight of each expert, where $N$ is the hidden size and $E$ is the number of experts (written $n$ above).

A common approach is to learn a linear transformation $W_g \in \mathbb{R}^{E \times N}$ that maps the input token representation to an affinity score for each expert, and then to keep only the $k$ highest scores:

$$
\begin{aligned}
G(x) &= \operatorname{Softmax}(\operatorname{KeepTopK}(W_g x, k)), \quad \text{where} \\
\operatorname{KeepTopK}(z, k)_i &=
\begin{cases}
  z_i & \text{if $z_i$ is in the top $k$ elements of $z$} \\
  -\infty & \text{otherwise}
\end{cases}
\end{aligned}
$$
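To make the formula concrete, here is a small worked example in plain Python (standard library only), kept separate from the PyTorch implementation. The score vector `z` stands in for $W_g x$ and its values are made up; the helper names simply mirror the formula, and ties are broken by position in this sketch.

```python
import math


def keep_top_k(z: list[float], k: int) -> list[float]:
    """Keep the k largest entries of z and replace the rest with -inf."""
    top = sorted(range(len(z)), key=lambda i: z[i], reverse=True)[:k]
    keep = set(top)
    return [v if i in keep else -math.inf for i, v in enumerate(z)]


def softmax(z: list[float]) -> list[float]:
    """Numerically stable softmax; entries equal to -inf get weight 0."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


z = [2.0, 1.0, 0.5, -1.0]  # made-up affinity scores for 4 experts
gates = softmax(keep_top_k(z, k=2))
print([round(g, 4) for g in gates])  # [0.7311, 0.2689, 0.0, 0.0]
```

The two selected experts share all of the probability mass, and the other two receive exactly zero weight.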

In [None]:
import torch
import torch.nn as nn

def topk_scores(gate: nn.Linear, k: int, x: torch.Tensor):
    raw_scores = gate(x)
    scores = keep_topk(raw_scores, k)
    scores = torch.softmax(scores, dim=-1, dtype=torch.float32)
    return scores

def keep_topk(scores: torch.Tensor, k: int) -> torch.Tensor:
    _, topk_indices = scores.topk(k, dim=-1)
    assignment = torch.zeros_like(scores, dtype=torch.bool)
    assignment.scatter_(-1, topk_indices, True)
    return torch.where(assignment, scores, float("-inf"))

The other component of a sparse MoE layer is its set of experts.

> An expert $E : \mathbb{R}^N \to \mathbb{R}^N$ learns a (usually) non-linear transformation of the input $x$.

In this notebook, we will use a two-layer MLP.

In [None]:
def make_expert(hidden_size):
    return nn.Sequential(
        nn.Linear(hidden_size, 4 * hidden_size),
        nn.GELU(),
        nn.Linear(4 * hidden_size, hidden_size),
    )

One important aspect of constructing the experts is ensuring that they only compute the output representations for selected inputs.

In [None]:
def run_expert(expert: nn.Module, expert_scores: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Select the inputs to compute and totally skip the expert if it was not selected.
    selected_x = x[expert_scores > 0]
    if len(selected_x) == 0:
        return None

    weights = expert_scores[expert_scores > 0].unsqueeze(-1)

    return weights * expert(selected_x)

Putting everything together, we can run a sparse MoE layer by first obtaining the scores and then computing the expert outputs for the selected tokens.

In [None]:
class MoELayer(nn.Module):
    def __init__(self, *, hidden_size: int, num_experts: int, k: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [make_expert(hidden_size) for _ in range(num_experts)]
        )
        self.k = k

    def forward(self, x: torch.Tensor):
        # Flatten all leading dimensions (e.g., batch and sequence length) into one token axis
        x_flat = x.view(-1, x.size(-1))

        scores = topk_scores(self.gate, self.k, x_flat)

        outputs = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            expert_scores = scores[:, i]

            expert_outputs = run_expert(expert, expert_scores, x_flat)
            if expert_outputs is None:
                continue

            outputs[expert_scores > 0] += expert_outputs

        # Recover original dimensions
        outputs = outputs.view_as(x)

        return outputs
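The loop in `forward` can be read as a dispatch step: each expert collects the tokens that gave it a non-zero score. A plain-Python sketch of that bookkeeping follows, using a made-up score matrix for four tokens and three experts. `dispatch` is a hypothetical helper introduced only for illustration.

```python
def dispatch(scores: list[list[float]]) -> dict[int, list[int]]:
    """Map each expert index to the indices of the tokens routed to it."""
    num_experts = len(scores[0])
    routes: dict[int, list[int]] = {e: [] for e in range(num_experts)}
    for token, row in enumerate(scores):
        for expert, weight in enumerate(row):
            if weight > 0:  # same test as `expert_scores > 0` in the layer
                routes[expert].append(token)
    return routes


# One row per token, one column per expert (made-up values, k = 1).
scores = [
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
]
routes = dispatch(scores)
print(routes)  # {0: [0, 2], 1: [], 2: [1, 3]}
```

Expert 1 receives no tokens here, which corresponds to the case where `run_expert` returns `None` and the layer skips it. Note that with `k = 1` the softmax runs over a single kept score, so every selected weight is exactly 1.0; in that setting the task loss provides no gradient to the gate through these weights.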

## Part II: The training data

With the sparse MoE layer implemented, we will move on to developing a character-level model that uses it. The model will be trained on text from Shakespeare, available at [https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt).

We will start by downloading and inspecting the data.

In [None]:
# Downloading the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("Length in characters: ", len(text))

# let's look at the first 100 characters
print("First 100 characters", "-" * 50)
print(text[:100])

--2024-10-02 10:29:14--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2024-10-02 10:29:15 (156 MB/s) - ‘input.txt.3’ saved [1115394/1115394]

Length in characters:  1115394
First 100 characters --------------------------------------------------
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


From here, we can create the model vocabulary, which maps the characters into token ids.

In [None]:
class Vocabulary:

  def __init__(self, text: str):
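    # Note: set iteration order is not stable across Python processes, so the
    # ids can change between runs; sorted(set(text)) would make them reproducible.
    chars = set(text)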
    self.ctoi = {c: i for i, c in enumerate(chars)}
    self.itoc = {v: k for k, v in self.ctoi.items()}

  def encode(self, text: str) -> list[int]:
    return [self.ctoi[t] for t in text]

  def decode(self, ids: list[int]) -> str:
    return "".join(self.itoc[i] for i in ids)

  def __len__(self):
    return len(self.ctoi)

  def __str__(self):
    return f"Vocabulary[size={len(self)}]{self.ctoi}"
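As a small illustration of the character-to-id mapping, the sketch below builds the same two dictionaries for a short made-up string. It sorts the characters first, which is an assumption of this sketch (the class above iterates over the set directly), so that the ids shown are reproducible.

```python
text = "hello world"  # made-up sample text

# Sorting is an assumption of this sketch; it makes the ids deterministic.
chars = sorted(set(text))
ctoi = {c: i for i, c in enumerate(chars)}
itoc = {i: c for c, i in ctoi.items()}

ids = [ctoi[c] for c in "hello"]
decoded = "".join(itoc[i] for i in ids)

print(chars)    # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(ids)      # [3, 2, 4, 4, 5]
print(decoded)  # hello
```

Encoding followed by decoding returns the original string, which is the same round-trip property the notebook checks with an `assert`.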

We can build the vocabulary from the whole text and run a light round-trip test to check that encoding and decoding work as expected.

In [None]:
vocab = Vocabulary(text)
print(vocab)

# Test if encoding/decoding work
first_chars = text[:1000]
encoded = vocab.encode(first_chars)
decoded = vocab.decode(encoded)
assert first_chars == decoded

Vocabulary[size=65]{'N': 0, 'j': 1, '\n': 2, 'U': 3, '$': 4, 'c': 5, '-': 6, 'y': 7, 'k': 8, 't': 9, 'v': 10, 'K': 11, 'S': 12, 'w': 13, 'z': 14, 'f': 15, 'p': 16, 'h': 17, '?': 18, 'M': 19, 'I': 20, 'Q': 21, 'X': 22, 'b': 23, '!': 24, 'F': 25, 'O': 26, 'r': 27, ' ': 28, 'u': 29, 'W': 30, '.': 31, ':': 32, 'Z': 33, ',': 34, 'd': 35, 'o': 36, 'H': 37, 'V': 38, 'B': 39, 'T': 40, 'n': 41, 'a': 42, 'L': 43, '3': 44, 'J': 45, 'x': 46, 'D': 47, 'g': 48, 'Y': 49, ';': 50, 's': 51, 'l': 52, 'm': 53, 'R': 54, 'G': 55, 'P': 56, '&': 57, 'i': 58, 'E': 59, 'e': 60, "'": 61, 'C': 62, 'q': 63, 'A': 64}


Afterwards, we can encode the data as a tensor and create a train and validation split.

In [None]:
val_ratio = 0.05

# Encode all text
encoded_text = torch.tensor(vocab.encode(text))
print("Total tokens:", encoded_text.size(0))
# Split into train and validation
val_size = int(val_ratio*len(encoded_text))

encoded_train = encoded_text[:-val_size]
encoded_val = encoded_text[-val_size:]
print("Train tokens:", encoded_train.size(0))
print("Val tokens: ", encoded_val.size(0))

Total tokens: 1115394
Train tokens: 1059625
Val tokens:  55769


From here, we can create the dataset class to iterate over the data. Each record is a chunk of adjacent tokens, and the model learns to predict each next token from the ones before it.

In [None]:
import torch.utils.data

# Each training example spans seq_len contiguous tokens
seq_len = 64


class ShakespearDataset(torch.utils.data.Dataset):

  def __init__(self, data: torch.Tensor, seq_len: int):
    self.data = data
    self.seq_len = seq_len


  def __getitem__(self, i: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Take one extra token so the targets can be the inputs shifted by one position
    data = self.data[i:i + self.seq_len + 1]
    x = data[:-1]
    y = data[1:]
    return x, y

  def __len__(self):
    return len(self.data) - self.seq_len - 1

train = ShakespearDataset(encoded_train, seq_len)
val = ShakespearDataset(encoded_val, seq_len)
print("Train size:", len(train), ", val size:", len(val))

Train size: 1059560 , val size: 55704
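The input/target shift can be checked by hand on a tiny sequence. The sketch below mirrors the slicing in `__getitem__` with plain lists; the data (`range(10)`) and the short `seq_len` are made up for readability, and `get_item` is a stand-in for the method.

```python
data = list(range(10))  # stand-in for a sequence of token ids
seq_len = 4


def get_item(i: int) -> tuple[list[int], list[int]]:
    """Return (inputs, targets) where targets are the inputs shifted by one."""
    chunk = data[i : i + seq_len + 1]  # one extra token for the shift
    return chunk[:-1], chunk[1:]


x, y = get_item(2)
print(x)  # [2, 3, 4, 5]
print(y)  # [3, 4, 5, 6]

# Number of start positions that yield a full window of seq_len + 1 tokens.
full_windows = len(data) - seq_len
print(full_windows)  # 6
```

At every position `t`, `y[t]` is the token that follows `x[t]`. The dataset class reports one window fewer (`len(data) - seq_len - 1`), which simply leaves the last full window unused.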


## Part III: The Mixture-of-Experts Model

Finally, we can create the transformer model, with the sparse MoE layer replacing the usual feed-forward network.

We start by defining the model configuration.

*This network implementation is merely illustrative.*

In [None]:
import dataclasses

@dataclasses.dataclass
class Config:
  seq_len: int = seq_len
  n_layers: int = 6
  n_heads: int = 6
  hidden_size: int = 192
  head_size: int = 32
  vocab_size: int = len(vocab)

  k: int = 1
  normalize_scores: bool = True
  n_experts: int = 8

Afterwards, we define the self-attention mechanism.

In [None]:
import einops
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):

  def __init__(self, config: Config):
    super().__init__()
    self.config = config
    # We have n_head projections, each projecting to head_dim for the queries, keys and values.
    self.queries_proj = nn.Linear(config.hidden_size, config.n_heads * config.head_size)
    self.keys_proj = nn.Linear(config.hidden_size, config.n_heads * config.head_size)
    self.values_proj = nn.Linear(config.hidden_size, config.n_heads * config.head_size)
    self.out_proj = nn.Linear(config.n_heads * config.head_size, config.hidden_size)

  def forward(self, x: torch.Tensor, is_causal: bool = True, use_sdpa: bool = True):
    _, T, _ = x.size()
    n_heads, head_dim = self.config.n_heads, self.config.head_size

    # Compute projections and split by queries, keys and values.
    q, k, v = self.queries_proj(x), self.keys_proj(x), self.values_proj(x)

    if use_sdpa:
      q = einops.rearrange(q, "b t (head proj) -> b head t proj", head=n_heads, proj=head_dim)
      k = einops.rearrange(k, "b t (head proj) -> b head t proj", head=n_heads, proj=head_dim)
      v = einops.rearrange(v, "b t (head proj) -> b head t proj", head=n_heads, proj=head_dim)

      out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
    else:
      q = einops.rearrange(q, "b t (head proj) -> b t head proj", head=n_heads, proj=head_dim)
      k = einops.rearrange(k, "b t (head proj) -> b t head proj", head=n_heads, proj=head_dim)
      v = einops.rearrange(v, "b t (head proj) -> b t head proj", head=n_heads, proj=head_dim)

      # Compute unnormalized attention scores
      a = einops.einsum(q, k, "b query head proj, b key head proj -> b head query key") # (b, n_heads, T, T)
      a = a / head_dim ** 0.5 # (b, n_heads, T, T)
      # Mask future tokens
      if is_causal:
        mask = torch.tril(torch.ones((T, T), dtype=torch.bool, device=x.device))
        a = torch.masked_fill(a, mask == 0, float("-inf")) # (b, n_heads, T, T)
      # Compute attention weights
      a = torch.softmax(a, dim=-1) # (b, n_heads, T, T)

      # Compute attention values
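      # The batch axis must use the same name ("b") in both operands; a different
      # name would be treated as a separate axis and summed over.
      out = einops.einsum(a, v, "b head query key, b key head proj -> b head query proj")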

    # Reorder to have all head values for token together
    out = einops.rearrange(out, "b head time proj -> b time (head proj)")

    # Compute the final projection
    out = self.out_proj(out)
    return out

config = Config()
attn = MultiHeadAttention(config)
input = torch.rand(4, 8, config.hidden_size)
out = attn(input, is_causal=True)
out.size()

torch.Size([4, 8, 192])
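To see what the causal mask does to the attention weights, the plain-Python sketch below applies a row-wise softmax in which query `t` only sees keys `0..t`. This matches filling future positions with `-inf` before the softmax. The all-zero score matrix is a made-up input chosen so the result is easy to verify, and `causal_attention_weights` is a helper introduced for this sketch.

```python
import math


def causal_attention_weights(scores: list[list[float]]) -> list[list[float]]:
    """Row-wise softmax where query t attends only to keys 0..t."""
    weights = []
    for t, row in enumerate(scores):
        visible = row[: t + 1]  # future keys are masked out
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        masked_tail = [0.0] * (len(row) - t - 1)
        weights.append([e / total for e in exps] + masked_tail)
    return weights


scores = [[0.0] * 3 for _ in range(3)]  # equal scores for every query/key pair
for row in causal_attention_weights(scores):
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

With equal scores, each token spreads its attention uniformly over itself and the tokens before it, and never over later tokens.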

In [None]:
class Block(nn.Module):

  def __init__(self, config: Config):
    super().__init__()
    self.attn = MultiHeadAttention(config)
    self.attn_norm = nn.LayerNorm(config.hidden_size)
    self.mlp = MoELayer(
        hidden_size=config.hidden_size,
        num_experts=config.n_experts,
        k=config.k,
    )
    self.mlp_norm = nn.LayerNorm(config.hidden_size)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    h = self.attn_norm(x)
    h = self.attn(h, use_sdpa=True)
    x = x + h
    h = self.mlp_norm(x)
    h = self.mlp(h)
    x = x + h
    return x


class Transformer(nn.Module):

  def __init__(self, config: Config):
    super().__init__()
    self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
    self.pos_emb = nn.Embedding(config.seq_len, config.hidden_size)
    self.blocks = nn.ModuleList(Block(config) for _ in range(config.n_layers))
    self.final_norm = nn.LayerNorm(config.hidden_size)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    tok_emb = self.tok_emb(tokens)
    pos_emb = self._get_pos_embeddings(tokens)

    hidden_states = tok_emb + pos_emb

    for block in self.blocks:
      hidden_states = block(hidden_states)

    hidden_states = self.final_norm(hidden_states)
    logits = self.lm_head(hidden_states)
    return logits

  def _get_pos_embeddings(self, tokens: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len = tokens.size()
    pos_idx = torch.arange(seq_len, device=tokens.device)
    pos_emb = self.pos_emb(pos_idx)
    return pos_emb
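The parameter count of this architecture can be worked out by hand, which also shows how few parameters a single token actually uses. The sketch below uses the values from `Config` and the vocabulary size of 65 reported earlier; `linear` is a small helper introduced here, and the "active" figure is an estimate that assumes each token passes through exactly `k` experts per layer.

```python
hidden, n_layers, n_heads, head_size = 192, 6, 6, 32
vocab_size, seq_len, n_experts, k = 65, 64, 8, 1


def linear(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


layer_norm = 2 * hidden  # scale and shift
attn = 3 * linear(hidden, n_heads * head_size) + linear(n_heads * head_size, hidden)
expert = linear(hidden, 4 * hidden) + linear(4 * hidden, hidden)
gate = linear(hidden, n_experts)
block = attn + gate + n_experts * expert + 2 * layer_norm

embeddings = vocab_size * hidden + seq_len * hidden  # token + position tables
lm_head = linear(hidden, vocab_size, bias=False)
total = embeddings + n_layers * block + layer_norm + lm_head

# Parameters touched by one token when only k experts run in each layer.
active = total - n_layers * (n_experts - k) * expert

print(total)   # 15142704
print(active)  # 2716080
```

The total agrees with the `Total parameters` value printed by the training cell, while under the stated assumption only about 2.7M of those parameters take part in processing any single token.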


In [None]:
def train_step(
    model: Transformer,
    optimizer: torch.optim.Optimizer,
    scaler,
    device_type: str,
    x: torch.Tensor,
    y: torch.Tensor,
):
  with torch.autocast(device_type=device_type, dtype=torch.float16):
    logits = model(x)
    logits = einops.rearrange(logits, "b s e -> (b s) e")
    y = einops.rearrange(y, "b s -> (b s)")
    loss = F.cross_entropy(logits, y)

  scaler.scale(loss).backward()

  scaler.step(optimizer)

  scaler.update()

  optimizer.zero_grad(set_to_none=True)
  return loss.item()
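A useful reference point when reading the training loss is the cross-entropy of a model that assigns equal probability to every character, which is the natural log of the vocabulary size. The sketch below computes that baseline for the 65-character vocabulary built earlier; `perplexity` is a small helper introduced here for illustration.

```python
import math

vocab_size = 65  # size reported by the Vocabulary cell

# Cross-entropy (in nats) of a uniform guess over the vocabulary.
uniform_loss = math.log(vocab_size)
print(round(uniform_loss, 3))  # 4.174


def perplexity(loss: float) -> float:
    """Convert a cross-entropy loss in nats to perplexity."""
    return math.exp(loss)


print(round(perplexity(uniform_loss)))  # 65
```

`F.cross_entropy` also works in nats, so its values are directly comparable: a loss below roughly 4.17 means the model is doing better than uniform guessing.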

In [None]:
import functools
import torch
import torch.cuda

import torch.optim
import time

from typing import Literal
torch.manual_seed(42)


device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
  print("WARNING: Using cpu device, training will be slow")

train_dataloader = torch.utils.data.DataLoader(
    train, batch_size=64, shuffle=True, pin_memory=True, num_workers=2, prefetch_factor=2,
)

config = Config()
model = Transformer(config)

model_params = sum(p.numel() for p in model.parameters())
print("Total parameters:", (model_params / 1e6), "M")

optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=0.1, betas=(0.9, 0.95),
)
scaler = torch.amp.GradScaler()
model.to(device)

num_epochs = 1
step_time = 0
step_count = 0
token_count = 0
start_time = time.time()

num_steps = 10000
dataiter = iter(train_dataloader)
for step in range(1, num_steps + 1):
  x, y = next(dataiter)
  start = time.time()
  x, y = x.to(device), y.to(device)
  token_count += x.numel()
  loss = train_step(model, optimizer, scaler, device, x, y)
  step_time += time.time() - start
  step_count += 1
  if step % 10 == 0:
    completed = step / num_steps * 100
    step_avg = step_time / step_count
    ellapsed = time.time() - start_time
    tok_per_sec = token_count / ellapsed
    fmt_ellapsed = time.strftime("%H:%M:%S", time.gmtime(ellapsed))
    print(
        f"Step {step}/{num_steps} ({completed:.0f}%): "
        f"TrainTime={fmt_ellapsed}, "
        f"Loss={loss:.4f}, "
        f"TokensPerSec={tok_per_sec:.0f}, "
        f"AvgStepTime={step_avg:.2}s"
    )

Total parameters: 15.142704 M
Step 10/10000 (0%): TrainTime=00:00:01, Loss=2.8247, TokensPerSec=24996, AvgStepTime=0.15s
Step 20/10000 (0%): TrainTime=00:00:02, Loss=2.6095, TokensPerSec=28185, AvgStepTime=0.14s
Step 30/10000 (0%): TrainTime=00:00:04, Loss=2.5605, TokensPerSec=29962, AvgStepTime=0.13s
Step 40/10000 (0%): TrainTime=00:00:05, Loss=2.4793, TokensPerSec=30736, AvgStepTime=0.13s
Step 50/10000 (0%): TrainTime=00:00:06, Loss=2.4451, TokensPerSec=29551, AvgStepTime=0.13s
Step 60/10000 (1%): TrainTime=00:00:08, Loss=2.4691, TokensPerSec=30381, AvgStepTime=0.13s
Step 70/10000 (1%): TrainTime=00:00:09, Loss=2.4879, TokensPerSec=30438, AvgStepTime=0.13s
Step 80/10000 (1%): TrainTime=00:00:11, Loss=2.4157, TokensPerSec=29593, AvgStepTime=0.14s
Step 90/10000 (1%): TrainTime=00:00:12, Loss=2.3974, TokensPerSec=29990, AvgStepTime=0.13s
Step 100/10000 (1%): TrainTime=00:00:13, Loss=2.3939, TokensPerSec=30506, AvgStepTime=0.13s
Step 110/10000 (1%): TrainTime=00:00:14, Loss=2.3415, Token

In [None]:
@torch.no_grad()
def generate(
  model: nn.Module,
  vocab: Vocabulary,
  input: str,
  new_tokens: int,
):
  idx = vocab.encode(input)
  idx = torch.tensor(idx, dtype=torch.long, device=device)
  idx = einops.rearrange(idx, "... -> 1 ...")

  for t in range(new_tokens):
    ctx = idx[:, -seq_len:]

    logits = model(ctx)

    next_logits = logits[:, -1, :]

    next_probs = torch.softmax(next_logits, dim=-1)

    next_idx = torch.multinomial(next_probs, num_samples=1)
    # append sampled index to the running sequence and continue
    idx = torch.cat((idx, next_idx), dim=1)

  idx = idx.tolist()[0]
  output = vocab.decode(idx)
  return output

output = generate(
    model,
    vocab,
    "H",
    1000,
)
print(output)

HARD III:
Here letters thirvingman:
Directitude and victory purpose!

First Senatol:
Men's eyes will or powerful unfold floveth.

DUKE VINCENTIO:
Pray, no, no, not to me well.

PETRUCHIO:
Faith, marquess, and come to your grace powerful
We have unducationed by the dgown, carce with and give
not into a village hap into your souls,
In warnals dagger, and got love, kill'd contried
Than the brunts. For what I call attend,
Insug hacks are true kneeping, to make the queen,
We wish she sers of the Tare of onestable,
Nor that is the match of pale beloud,
Of the dukedom, where he is my body's destisity.

LADY ANNE:
I will tell me, if I remember it were all
on the hundred, and disease it is,
Nor done immoder, up with Romeo.

First Senator:
You have done, no denerate commends.
Do you have with her? here chesicourse he has discovered thy watery,
With that runfidence into harms, and does he
Doth log from your suit on your great at the
shed but as is taunting, our before thou
Then denying the educat