# DeepSeek Mixture of Experts (MoE) - Simple Explanation

## The Main Goal

**Problem:** Large dense neural networks are expensive to run because every input passes through ALL of their parameters.

**Solution:** Use **Mixture of Experts (MoE)** - have many small "expert" networks, but only use a FEW of them for each input!

```
Traditional Dense Network:          MoE Network:
┌─────────────────────┐            ┌─────────────────────┐
│  Every input uses   │            │  Each input only    │
│  ALL parameters     │            │  uses SELECTED      │
│  (expensive!)       │            │  experts (cheap!)   │
└─────────────────────┘            └─────────────────────┘
```

---

## Architecture Overview

```
                         Input Token
                              │
                              ▼
              ┌───────────────┴───────────────┐
              │                               │
              ▼                               ▼
      ┌──────────────┐                ┌──────────────┐
      │    SHARED    │                │    ROUTER    │
      │   EXPERTS    │                │ (Gatekeeper) │
      │ (Always ON)  │                └──────┬───────┘
      └──────┬───────┘                       │
             │                      Picks Top-K Experts
             │                               │
             │                    ┌──────────┼──────────┐
             │                    ▼          ▼          ▼
             │               ┌────────┐  ┌────────┐  ┌────────┐
             │               │Expert 1│  │Expert 5│  │Expert 9│  <- Only these
             │               │        │  │        │  │        │     activated!
             │               └────┬───┘  └────┬───┘  └────┬───┘
             │                    │          │          │
             │                    └──────────┼──────────┘
             │                               │
             │                     Weighted Sum (gating)
             │                               │
             └───────────────┬───────────────┘
                             │
                             ▼
                       ┌───────────┐
                       │    ADD    │  <- Combine all outputs
                       │ (x+s+r)   │
                       └─────┬─────┘
                             │
                             ▼
                        Final Output
```

---

## Key Components Explained

### 1. Expert FFN (Feed-Forward Network)

Each expert is a **small 2-layer neural network**:

```
Input ──▶ [Linear Layer 1] ──▶ [GELU Activation] ──▶ [Linear Layer 2] ──▶ Output
         (expand)              (non-linearity)       (compress back)
```

```python
# Simple: Input → Hidden → Output (shape flow through one expert)
# fc1:  d_model → hidden    (expand dimensions)
# GELU: activation          (add non-linearity)
# fc2:  hidden → d_model    (back to original size)
```

---

### 2. Shared Experts (Always Active)

These experts process **EVERY token** - they're always "on duty":

```
Token 1 ──▶ ┌─────────────────┐
Token 2 ──▶ │  Shared Expert  │ ──▶ ALL tokens get processed
Token 3 ──▶ │    (Always ON)  │
   ...  ──▶ └─────────────────┘
```

**Why?** They capture general patterns that apply to all inputs.

---

### 3. Routed Experts (Selectively Active)

The **Router** decides which experts to use for each token:

```
          16 Routed Experts Available
    ┌───┬───┬───┬───┬───┬───┬───┬───┐
    │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ ...
    └───┴───┴───┴───┴───┴───┴───┴───┘
          │       │   │           │
          ▼       ▼   ▼           ▼
        Token A picks: Expert 2, 4, 5, 8 (top_k = 4)
        
    ┌───┬───┬───┬───┬───┬───┬───┬───┐
    │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ ...
    └───┴───┴───┴───┴───┴───┴───┴───┘
      │           │       │   │
      ▼           ▼       ▼   ▼
    Token B picks: Expert 1, 4, 6, 7 (different experts!)
```

---

### 4. The Router (Gatekeeper)

The router calculates **affinity scores** between each token and each expert:

```
                    How Router Works
                    ════════════════
                    
Token Embedding ──▶ Dot Product with Expert Centroids ──▶ Scores
     [1024]              [16 × 1024]                      [16]
                    
     "How similar is this token to each expert's specialty?"
```

```
Score Calculation:
┌────────────────────────────────────────────┐
│  logits = token · centroids + bias         │
│                                            │
│  Token: "The cat sat..."                   │
│                                            │
│  Expert 1 (grammar):     score = 0.8       │
│  Expert 2 (math):        score = 0.1       │
│  Expert 3 (code):        score = 0.2       │
│  Expert 4 (language):    score = 0.9       │
│  ...                                       │
│                                            │
│  → Select TOP-K highest scores!            │
└────────────────────────────────────────────┘
```
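
The score calculation above can be sketched in a few lines of PyTorch. This is a minimal illustration under assumed toy sizes (`n_tokens = 3`, `d_model = 8`, `n_experts = 4`, `top_k = 2`), with random tensors standing in for real token embeddings and learned centroids:

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumptions for illustration only)
n_tokens, d_model, n_experts, top_k = 3, 8, 4, 2

tokens = torch.randn(n_tokens, d_model)                          # one row per token
centroids = torch.randn(n_experts, d_model) * d_model ** -0.5    # one row per expert
bias = torch.zeros(n_experts)                                    # router bias, starts at zero

# logits[t, e] = tokens[t] · centroids[e] + bias[e]
logits = tokens @ centroids.T + bias                             # [n_tokens, n_experts]

# Keep the top_k highest-scoring experts for every token
topk_logits, topk_idx = torch.topk(logits, top_k, dim=-1)        # both [n_tokens, top_k]

print(logits.shape)    # torch.Size([3, 4])
print(topk_idx.shape)  # torch.Size([3, 2])
```

Here `topk_idx[t]` lists the experts chosen for token `t`, and `topk_logits[t]` holds their scores, which the gating step then turns into weights.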

---

### 5. Gating Weights (Softmax)

After selecting top-k experts, we **weight their contributions**:

```
Selected Experts:    Expert 4    Expert 1    Expert 7    Expert 6
Top-k Logits:          0.9         0.8         0.6         0.5
                        │           │           │           │
                        ▼           ▼           ▼           ▼
                   ┌─────────── Softmax ───────────┐
                        │           │           │           │
                        ▼           ▼           ▼           ▼
Gate Weights:         0.30        0.27        0.22        0.20
                        │           │           │           │
                        │     (Higher score = More influence)
```
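
To see what the softmax does to the four top-k logits in the diagram, here is a small pure-Python check (no PyTorch needed). It only assumes the logit values `0.9, 0.8, 0.6, 0.5` shown above:

```python
import math

topk_logits = [0.9, 0.8, 0.6, 0.5]   # logits of the selected experts (from the diagram)

# Softmax: exponentiate, then normalize so the weights sum to 1
exps = [math.exp(v) for v in topk_logits]
total = sum(exps)
gate = [e / total for e in exps]

print([round(g, 2) for g in gate])   # [0.3, 0.27, 0.22, 0.2]
```

Because the four logits are close together, the weights come out fairly even. A larger gap between logits would let the top expert dominate.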

---

## Forward Pass Flow

```
Step 1: Input arrives
        ┌─────────────────────────┐
        │  x: [Batch, Seq, 1024]  │
        └───────────┬─────────────┘
                    │
Step 2: Shared experts process ALL tokens
        ┌───────────▼─────────────┐
        │ shared_out = Σ expert(x)│
        └───────────┬─────────────┘
                    │
Step 3: Router calculates expert preferences
        ┌───────────▼─────────────┐
        │  logits = x · centroids │
        │  + bias                 │
        └───────────┬─────────────┘
                    │
Step 4: Select top-k experts per token
        ┌───────────▼─────────────┐
        │  top_k = 8 experts      │
        │  (out of 16 available)  │
        └───────────┬─────────────┘
                    │
Step 5: Route tokens to selected experts
        ┌───────────▼─────────────┐
        │  Each expert processes  │
        │  only ITS tokens        │
        └───────────┬─────────────┘
                    │
Step 6: Weight and combine outputs
        ┌───────────▼─────────────┐
        │  routed_out = Σ(w * out)│
        └───────────┬─────────────┘
                    │
Step 7: Final output = input + shared + routed
        ┌───────────▼─────────────┐
        │  return x + shared_out  │
        │         + routed_out    │
        └─────────────────────────┘
```
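
Steps 5 and 6 are the least obvious part of the flow, so here is a hedged sketch of them in isolation. The toy sizes, the random `logits`, and the plain `nn.Linear` layers used as stand-in experts are all assumptions for illustration. The check at the end compares the sparse dispatch against a brute-force version that runs every expert on every token:

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumptions for illustration only)
n_tokens, d_model, n_experts, top_k = 6, 4, 3, 2

x_flat = torch.randn(n_tokens, d_model)
# Stand-in experts: plain linear maps instead of full FFNs
experts = [torch.nn.Linear(d_model, d_model, bias=False) for _ in range(n_experts)]

logits = torch.randn(n_tokens, n_experts)                    # pretend router output
topk_logits, topk_idx = torch.topk(logits, top_k, dim=-1)    # Step 4
gate = torch.softmax(topk_logits, dim=-1)                    # gating weights

routed_out = torch.zeros_like(x_flat)
with torch.no_grad():
    for i, expert in enumerate(experts):
        # Step 5: which tokens picked expert i, and in which top-k slot?
        row_idx, which_k = (topk_idx == i).nonzero(as_tuple=True)
        if row_idx.numel() == 0:
            continue
        out = expert(x_flat[row_idx])                        # only this expert's tokens
        w = gate[row_idx, which_k].unsqueeze(-1)             # [Ti, 1]
        # Step 6: scale by the gate weight and add back at the token positions
        routed_out.index_add_(0, row_idx, out * w)

    # Reference: run every expert on every token, weighted by a dense gate matrix
    dense_gate = torch.zeros(n_tokens, n_experts).scatter_(1, topk_idx, gate)
    reference = sum(dense_gate[:, i:i + 1] * experts[i](x_flat) for i in range(n_experts))

print(torch.allclose(routed_out, reference, atol=1e-5))      # True
```

Both versions give the same result, but the sparse loop only runs each expert on the tokens that selected it, which is where the compute savings come from.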

---

## Load Balancing (Bias Update)

**Problem:** Some experts might get ALL the tokens (overloaded), while others get NONE (underused).

**Solution:** Dynamically adjust the router bias!

```
                   Expert Load Distribution
    
    Before Balancing:                   After Balancing:
    
    Expert 1: ████████████ (120)        Expert 1: ████████ (80)
    Expert 2: ██           (20)         Expert 2: ███████  (70)
    Expert 3: █            (10)         Expert 3: ████████ (75)
    Expert 4: ████████████████ (160)    Expert 4: ███████  (75)
    
    Unbalanced!                         Balanced!
```

```
How Bias Update Works:
══════════════════════

1. Count how many tokens each expert received
2. Calculate average load
3. If expert is UNDER-loaded → INCREASE its bias (make it more attractive)
4. If expert is OVER-loaded → DECREASE its bias (make it less attractive)

Formula: bias += lr * tanh((avg - count) / avg)
         └────────────────────────────────────┘
                 Smooth, bounded update
```
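
As a worked example, the update rule can be applied by hand to the "Before Balancing" counts from the diagram. This pure-Python sketch assumes a zero starting bias and a step size of `0.01`:

```python
import math

counts = [120, 20, 10, 160]   # tokens per expert ("Before Balancing" in the diagram)
bias = [0.0] * len(counts)    # router bias, one entry per expert (assumed to start at zero)
lr = 0.01                     # step size (assumed)

avg = sum(counts) / len(counts)             # 77.5 tokens per expert
for i, count in enumerate(counts):
    violation = (avg - count) / avg         # > 0: under-loaded, < 0: over-loaded
    bias[i] += lr * math.tanh(violation)    # each step is bounded by ±lr

for i, b in enumerate(bias, start=1):
    direction = "raised" if b > 0 else "lowered"
    print(f"Expert {i}: bias {direction} to {b:+.4f}")
```

Experts 1 and 4 (over-loaded) get a lower bias, while Experts 2 and 3 (under-loaded) get a higher one. A single step moves each bias by less than `lr`, so balancing happens gradually over many steps.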

---

## Example with Numbers

```python
# Configuration
d_model = 1024        # Token dimension
n_routed_exp = 16     # 16 routed experts available  
n_shared_exp = 2      # 2 shared experts (always active)
top_k = 8             # Select 8 experts per token

# Input
batch_size = 2
seq_len = 64
# Total tokens = 2 × 64 = 128 tokens
```

```
For each of 128 tokens:
├── 2 shared experts process it (ALWAYS)
└── Router selects 8 out of 16 routed experts
    └── These 8 experts process it with weighted contributions

Computation Savings:
├── Dense model: ALL parameters for ALL tokens
└── MoE model: Only 8/16 = 50% of routed experts per token!
```
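
The savings can be made concrete by counting expert weights for this configuration. The sketch assumes every expert (routed and shared) has a hidden size of 2048 and two bias-free linear layers, and it ignores the small router:

```python
d_model, hidden = 1024, 2048      # hidden size of 2048 is an assumption
n_routed, n_shared, top_k = 16, 2, 8

# Each expert has two bias-free linear layers: d_model×hidden and hidden×d_model
params_per_expert = 2 * d_model * hidden

total_experts = n_routed + n_shared
active_experts = top_k + n_shared          # shared experts always run

total_params = total_experts * params_per_expert
active_params = active_experts * params_per_expert

print(f"params per expert  : {params_per_expert:,}")                 # 4,194,304
print(f"total expert params: {total_params:,}")                      # 75,497,472
print(f"active per token   : {active_params:,}")                     # 41,943,040
print(f"active fraction    : {active_params / total_params:.1%}")    # 55.6%
```

With 8 of 16 routed experts active the saving is modest. It grows as the number of routed experts increases while `top_k` stays fixed.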

---

## Visual Summary

```
┌─────────────────────────────────────────────────────────────────┐
│                       DeepSeek MoE Layer                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   INPUT ──────────┬─────────────────────┐                       │
│                   │                     │                       │
│                   ▼                     ▼                       │
│            ┌──────────────┐      ┌─────────────┐                │
│            │   SHARED     │      │   ROUTER    │                │
│            │   EXPERTS    │      │  + GATING   │                │
│            │  (2 always)  │      └──────┬──────┘                │
│            └──────┬───────┘             │                       │
│                   │              ┌──────┴──────┐                │
│                   │              │  TOP-K = 8  │                │
│                   │              │  Selection  │                │
│                   │              └──────┬──────┘                │
│                   │                     │                       │
│                   │              ┌──────┴──────┐                │
│                   │              │  8 ROUTED   │                │
│                   │              │  EXPERTS    │                │
│                   │              │  (selected) │                │
│                   │              └──────┬──────┘                │
│                   │                     │                       │
│                   └─────────┬───────────┘                       │
│                             │                                   │
│                       ┌─────┴─────┐                             │
│                       │    ADD    │                             │
│                       │  (x+s+r)  │                             │
│                       └─────┬─────┘                             │
│                             │                                   │
│                          OUTPUT                                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

---

## Key Takeaways

| Concept | What It Does |
|---------|--------------|
| **Shared Experts** | Always active, capture common patterns |
| **Routed Experts** | Specialized, only activated when needed |
| **Router** | Decides which experts to use per token |
| **Top-K Selection** | Limits computation by selecting few experts |
| **Gating** | Weights expert outputs by relevance |
| **Bias Update** | Keeps expert loads balanced |
| **Residual Connection** | Output = Input + Shared + Routed |

---

## Why This Matters

```
Dense equivalent (all experts run):    vs        MoE FFN:
───────────────────────────────────              ─────────
All 256 experts active                           Only 8 experts active
= 256× expert computation                        = 8× expert computation
= SLOW + EXPENSIVE                               = FAST + CHEAP!

Same total parameter count, far less compute per token!
```

This is how DeepSeek-V3 achieves **large model capacity** with **efficient inference**!


In [2]:

import math 
from contextlib import nullcontext 
from typing import Optional 

import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [3]:
def _gelu(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU (same formula as F.gelu(x, approximate="tanh"))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                       (x + 0.044715 * torch.pow(x, 3))))
    
class ExpertFFN(nn.Module):
    """
    A 2-layer MLP expert. Hidden dim is usually smaller than a dense FFN
    (e.g., 0.25 × d_model in DeepSeek-V3).
    """
    def __init__(self, d_model: int, hidden: int, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.dropout(_gelu(self.fc1(x))))



class DeepSeekMoE(nn.Module):
    """
    DeepSeek-V3 style Mixture-of-Experts (MoE) layer.

    This MoE layer incorporates both routed experts (selected by a router)
    and shared experts (applied to all inputs). It is designed based on
    the architecture described in the DeepSeek-V3 paper.

    Args:
        d_model (int): The dimension of the input and output features.
        n_routed_exp (int): The number of routed experts.
        n_shared_exp (int, optional): The number of shared experts. Defaults to 1.
        top_k (int, optional): The number of routed experts to select for each token.
                               Defaults to 8.
        routed_hidden (int, optional): The hidden dimension for routed experts.
                                      Defaults to 2048.
        shared_hidden (Optional[int], optional): The hidden dimension for shared experts.
                                                If None, uses routed_hidden. Defaults to None.
        bias_lr (float, optional): Learning rate for the router bias (updated online).
                                   Defaults to 0.01.
        fp16_router (bool, optional): Whether to use FP16 precision for router calculations.
                                     Defaults to False.
    """
    def __init__(
        self,
        d_model: int,
        n_routed_exp: int,
        n_shared_exp: int = 1,
        top_k: int = 8,
        routed_hidden: int = 2_048,
        shared_hidden: Optional[int] = None,
        bias_lr: float = 0.01,
        fp16_router: bool = False,
    ):
        super().__init__()
        # Assert that the number of selected experts (top_k) is less than or equal to the total number of routed experts.
        assert top_k <= n_routed_exp, "k must be ≤ number of routed experts"

        self.d_model = d_model
        self.n_routed = n_routed_exp
        self.n_shared = n_shared_exp
        self.top_k = top_k
        self.bias_lr = bias_lr
        self.fp16_router = fp16_router

        # Module list for the routed experts.
        self.routed = nn.ModuleList(
            [ExpertFFN(d_model, routed_hidden) for _ in range(n_routed_exp)]
        )
        # Determine the hidden dimension for shared experts. Use routed_hidden if shared_hidden is not provided.
        hidden_shared = shared_hidden or routed_hidden
        # Module list for the shared experts.
        self.shared = nn.ModuleList(
            [ExpertFFN(d_model, hidden_shared) for _ in range(n_shared_exp)]
        )

        # Register a parameter for the centroids used by the router.
        # Centroids represent the "preference" of each expert for different input features.
        self.register_parameter("centroids", nn.Parameter(torch.empty(n_routed_exp, d_model)))
        # Initialize centroids with a normal distribution.
        nn.init.normal_(self.centroids, std=d_model ** -0.5)

        # Register a buffer for the router bias. This bias is updated online
        # without using standard gradient descent, hence it's not a parameter.
        self.register_buffer("bias", torch.zeros(n_routed_exp))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the DeepSeekMoE layer.

        Args:
            x (torch.Tensor): The input tensor of shape [B, S, D], where B is
                              batch size, S is sequence length, and D is d_model.

        Returns:
            torch.Tensor: The output tensor of shape [B, S, D], which is the
                          sum of the input, shared expert outputs, and routed
                          expert outputs.
        """
        # Get dimensions of the input tensor.
        B, S, D = x.shape
        # Reshape the input to [N, D], where N = B * S (number of tokens).
        x_flat = x.reshape(-1, D)  # [N, D] with N=B*S

        # 1) Shared path: Process the input through all shared experts and sum their outputs.
        shared_out = torch.zeros_like(x)
        for exp in self.shared:
            shared_out += exp(x)
        # (Optional) Scale the shared expert output by the number of shared experts.
        # This can help in balancing the contribution of shared vs. routed experts.
        # shared_out = shared_out / max(1, self.n_shared)

        # 2) Router logits: Calculate the affinity of each token to each routed expert.
        # Use autocasting to FP16 if fp16_router is True and the device is CUDA.
        use_autocast = self.fp16_router and x.is_cuda
        device_type = "cuda" if x.is_cuda else x.device.type
        with torch.autocast(device_type=device_type, enabled=use_autocast):
            # Calculate logits by taking the dot product of the flattened input with the expert centroids.
            logits = F.linear(x_flat, self.centroids)  # [N, E]
            # Add the router bias to the logits. Ensure bias matches the logits' dtype.
            logits = logits + self.bias.to(logits.dtype)

        # Select the top_k experts with the highest logits for each token.
        topk_logits, topk_idx = torch.topk(logits, self.top_k, dim=-1)        # [N, k]
        # Apply softmax to the top_k logits to get gating weights.
        # Ensure the gate weights have the same dtype as the input for subsequent calculations.
        gate = F.softmax(topk_logits, dim=-1, dtype=x.dtype)                   # [N, k]

        # 3) Dispatch per expert: Route tokens to their selected experts and combine outputs.
        routed_out = torch.zeros_like(x_flat)                                   # [N, D]
        # Iterate through each routed expert.
        for i in range(self.n_routed):
            # Create a mask to identify which tokens selected the current expert (expert i).
            mask = (topk_idx == i)
            # Find the indices of the rows (tokens) and columns (which of the top-k) where expert i was selected.
            row_idx, which_k = mask.nonzero(as_tuple=True)                      # 1-D each
            # If no tokens selected this expert, skip.
            if row_idx.numel() == 0:
                continue
            # Select the input tokens that are routed to expert i.
            exp_in = x_flat.index_select(0, row_idx)                            # [Ti, D] where Ti is the number of tokens routed to expert i
            # Pass the selected tokens through the expert's FFN.
            out = self.routed[i](exp_in)                                        # [Ti, D]
            # Get the gating weights for the tokens routed to expert i.
            w = gate[row_idx, which_k].unsqueeze(-1)                            # [Ti, 1]
            # Scale the expert output by the gating weights and add it to the routed_out tensor
            # at the original token positions using index_add_.
            routed_out.index_add_(0, row_idx, out * w)

        # Reshape the routed output back to the original [B, S, D] shape.
        routed_out = routed_out.view(B, S, D)
        # The final output is the sum of the original input, shared expert outputs, and routed expert outputs.
        return x + shared_out + routed_out

    @torch.no_grad()
    def update_bias(self, x: torch.Tensor):
        """
        Updates the router bias based on expert load.

        This method is typically called once per optimizer step using the
        same batch of tokens that were passed through the forward method.
        It uses the current router logits (including the current bias) to
        estimate the load on each expert and adjusts the bias to encourage
        a more balanced distribution of tokens across experts.

        Args:
            x (torch.Tensor): The input tensor of shape [B, S, D], identical
                              to the input used in the corresponding forward pass.
        """
        # Calculate the total number of tokens.
        N = x.shape[0] * x.shape[1]
        # Calculate the router logits (affinity scores) for each token and expert, including the current bias.
        logits = F.linear(x.reshape(-1, self.d_model), self.centroids) + self.bias
        # Determine the top_k experts selected for each token based on the current logits.
        _, idx = torch.topk(logits, self.top_k, dim=-1)

        # Count how many times each expert was selected as one of the top_k.
        counts = torch.bincount(idx.flatten(), minlength=self.n_routed).float()
        # Calculate the average number of times an expert should ideally be selected.
        avg = counts.sum() / max(1, self.n_routed)

        # Calculate the "violation" for each expert. A positive violation means
        # the expert is under-loaded compared to the average, and its bias
        # should be increased to make it more likely to be selected in the future.
        # A negative violation means it's over-loaded, and its bias should be decreased.
        # Add a small epsilon (1e-6) to the denominator to avoid division by zero.
        violation = (avg - counts) / (avg + 1e-6)
        # Update the bias using a smooth, bounded update based on the violation.
        # torch.tanh() squashes the violation into the range [-1, 1], preventing
        # excessively large bias updates. The bias_lr controls the step size.
        self.bias.add_(self.bias_lr * torch.tanh(violation))
     

In [4]:

# Test the DeepSeekMoE class
d_model = 1024
n_routed_exp = 16
n_shared_exp = 2
top_k = 8

model = DeepSeekMoE(d_model, n_routed_exp, n_shared_exp, top_k)

# Create random input data
batch_size_new = 2
seq_len_new = 64
random_input_new = torch.randn(batch_size_new, seq_len_new, d_model)

# Pass the random input to the model's forward method
output_new = model(random_input_new)

print("New output shape:", output_new.shape)

New output shape: torch.Size([2, 64, 1024])


# DeepSeek MoE - Code Walkthrough

This document explains **every line of code** in simple terms.

---

## 1. Imports

```python
import math 
from contextlib import nullcontext 
from typing import Optional 

import torch 
import torch.nn as nn 
import torch.nn.functional as F
```

| Import | What It Does |
|--------|--------------|
| `math` | Python's math library (for `sqrt`, `pi`, etc.) |
| `nullcontext` | A "do nothing" context manager (imported but not used in the code shown here) |
| `Optional` | Type hint meaning "this can be None" |
| `torch` | PyTorch - the deep learning framework |
| `torch.nn` | Neural network building blocks (layers, modules) |
| `torch.nn.functional as F` | Functional operations (softmax, linear, etc.) |

---

## 2. GELU Activation Function

```python
def _gelu(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                       (x + 0.044715 * torch.pow(x, 3))))
```

### What is GELU?
GELU = **G**aussian **E**rror **L**inear **U**nit

It's a smooth activation function (like ReLU but smoother):

```
Input x:     -2    -1     0     1     2
GELU(x):   -0.05  -0.16   0    0.84  1.95

         ReLU:              GELU:
         │    /             │    ╱
         │   /              │  ╱
    ─────┼──/────      ────╱┼─────
         │                ╱  │
         │              ╱    │
```

### Code Breakdown:

```python
def _gelu(x: torch.Tensor) -> torch.Tensor:
#   └─────┘               └───────────────┘
#   function name         returns a tensor
#          └───────────────┘
#          x is a tensor (type hint)
```

```python
return 0.5 * x * (1.0 + torch.tanh(...))
#      └───────────────────────────────┘
#      This is the tanh approximation of GELU
#      (very close to the exact erf-based GELU, but not identical)
```

```python
math.sqrt(2.0 / math.pi)  # ≈ 0.7979 (a constant)
x + 0.044715 * torch.pow(x, 3)  # x + 0.044715 * x³
#                 └───────────┘
#                 x raised to power 3
```
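
To get a feel for how close the approximation is, here is a small pure-Python comparison. `gelu_tanh` and `gelu_exact` are hypothetical helper names introduced for this sketch; the first uses the same formula as `_gelu`, and the second uses the exact definition `x * Φ(x)` via `math.erf`:

```python
import math

def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU (same formula as _gelu, on plain floats)."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.0f}  tanh-approx={gelu_tanh(x):+.4f}  exact={gelu_exact(x):+.4f}")
```

On these five inputs the two versions agree to about three decimal places.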

---

## 3. ExpertFFN Class (Feed-Forward Network)

```python
class ExpertFFN(nn.Module):
```
- `class ExpertFFN` → Define a new class called ExpertFFN
- `nn.Module` → Inherit from PyTorch's base neural network class

### Constructor (`__init__`)

```python
def __init__(self, d_model: int, hidden: int, dropout: float = 0.0):
#            └──┘  └───────────┘ └──────────┘ └─────────────────┘
#            self   input dim    hidden dim   dropout rate (default 0)
    
    super().__init__()
#   └────────────────┘
#   Call parent class constructor (required for nn.Module)
```

```python
    self.fc1 = nn.Linear(d_model, hidden, bias=False)
#   └──────┘   └───────────────────────────────────┘
#   save as    Create a linear layer: d_model → hidden
#   attribute  bias=False means no bias term (just weights)
```

```
Linear Layer Visualization:
                    
  Input [d_model]          Weights [d_model × hidden]         Output [hidden]
  ┌───┐                    ┌─────────────────────┐            ┌───┐
  │ x₁│ ──────────────────▶│  W₁₁  W₁₂  ...  W₁ₕ │──────────▶│ y₁│
  │ x₂│                    │  W₂₁  W₂₂  ...  W₂ₕ │           │ y₂│
  │...│                    │  ...  ...  ...  ... │           │...│
  │ xₙ│                    │  Wₙ₁  Wₙ₂  ...  Wₙₕ │           │ yₕ│
  └───┘                    └─────────────────────┘            └───┘
  
  Formula: y = x @ W  (matrix multiplication)
```
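
One practical detail worth checking in code: `nn.Linear` stores its weight as `[out_features, in_features]`, so the layer actually computes `x @ W.T`. The sketch below uses assumed toy sizes (`d_model = 4`, `hidden = 6`) to confirm this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, hidden = 4, 6           # toy sizes (assumptions for illustration only)

fc1 = nn.Linear(d_model, hidden, bias=False)
x = torch.randn(3, d_model)      # 3 tokens

# PyTorch stores the weight as [out_features, in_features] = [hidden, d_model],
# so the forward pass is x @ W.T rather than x @ W.
print(fc1.weight.shape)                                       # torch.Size([6, 4])
print(torch.allclose(fc1(x), x @ fc1.weight.T, atol=1e-6))    # True
```

The picture above is still the right mental model. The stored matrix is simply the transpose of the `[d_model × hidden]` matrix drawn there.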

```python
    self.fc2 = nn.Linear(hidden, d_model, bias=False)
#   Second linear layer: hidden → d_model (back to original size)

    self.dropout = nn.Dropout(dropout)
#   Dropout layer: randomly zeros some neurons during training
#   (helps prevent overfitting)
```

### Forward Pass

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.fc2(self.dropout(_gelu(self.fc1(x))))
#          └─────────────────────────────────────────┘
#          Chain of operations (read inside-out):
```

```
Step-by-step:
┌───────────────────────────────────────────────────────────────┐
│  1. self.fc1(x)           →  x goes through first linear     │
│  2. _gelu(...)            →  apply GELU activation           │
│  3. self.dropout(...)     →  randomly drop some values       │
│  4. self.fc2(...)         →  second linear layer (output)    │
└───────────────────────────────────────────────────────────────┘

x [d_model] → fc1 → [hidden] → GELU → dropout → fc2 → [d_model]
```

---

## 4. DeepSeekMoE Class - Constructor

```python
class DeepSeekMoE(nn.Module):
    def __init__(
        self,
        d_model: int,           # Dimension of input/output (e.g., 1024)
        n_routed_exp: int,      # Number of routed experts (e.g., 16)
        n_shared_exp: int = 1,  # Number of shared experts (default 1)
        top_k: int = 8,         # How many experts to select per token
        routed_hidden: int = 2_048,  # Hidden dim for routed experts
        shared_hidden: Optional[int] = None,  # Hidden dim for shared (or None)
        bias_lr: float = 0.01,  # Learning rate for bias update
        fp16_router: bool = False,  # Use FP16 for router? (faster on GPU)
    ):
```

### Assertions and Attributes

```python
        super().__init__()
        
        assert top_k <= n_routed_exp, "k must be ≤ number of routed experts"
#       └────┘ └──────────────────┘  └─────────────────────────────────────┘
#       check  condition to check    error message if condition is False
```

```python
        # Store all parameters as instance attributes
        self.d_model = d_model        # Save for later use
        self.n_routed = n_routed_exp  # Number of routed experts
        self.n_shared = n_shared_exp  # Number of shared experts
        self.top_k = top_k            # K in "top-k" selection
        self.bias_lr = bias_lr        # Bias learning rate
        self.fp16_router = fp16_router  # FP16 flag
```

### Creating Expert Networks

```python
        self.routed = nn.ModuleList(
            [ExpertFFN(d_model, routed_hidden) for _ in range(n_routed_exp)]
        )
```

**Breaking it down:**

```python
nn.ModuleList([...])
#   └──────────────┘
#   A list of nn.Modules that PyTorch registers as submodules
#   (experts in a regular Python list would be invisible to
#    .parameters(), .to(device), and state_dict()!)
```

```python
[ExpertFFN(d_model, routed_hidden) for _ in range(n_routed_exp)]
#└────────────────────────────────────────────────────────────┘
#         List comprehension: create n_routed_exp experts
```

```
Visual:
┌─────────────────────────────────────────────────────────────┐
│  self.routed = [Expert₀, Expert₁, Expert₂, ... Expert₁₅]   │
│                    │        │        │           │          │
│                  FFN      FFN      FFN         FFN          │
│                (1024→   (1024→   (1024→      (1024→         │
│                 2048→    2048→    2048→       2048→         │
│                 1024)    1024)    1024)       1024)         │
└─────────────────────────────────────────────────────────────┘
```
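
The difference between `nn.ModuleList` and a plain list is easy to demonstrate. `WithPlainList` and `WithModuleList` are hypothetical toy classes written only for this comparison, with tiny `4 → 4` linear layers standing in for experts:

```python
import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as submodules
        self.experts = [nn.Linear(4, 4, bias=False) for _ in range(3)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: every layer is registered and tracked
        self.experts = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(3)])

n_plain = sum(p.numel() for p in WithPlainList().parameters())
n_tracked = sum(p.numel() for p in WithModuleList().parameters())

print(n_plain)    # 0  -> an optimizer would see nothing to train
print(n_tracked)  # 48 -> 3 experts × (4 × 4) weights
```

An optimizer built from `model.parameters()` would never update the experts in the plain-list version, which is why the MoE layer uses `nn.ModuleList`.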

```python
        hidden_shared = shared_hidden or routed_hidden
#       └───────────────────────────────────────────┘
#       If shared_hidden is None, use routed_hidden instead
#       (Python's "or" returns first truthy value)
        
        self.shared = nn.ModuleList(
            [ExpertFFN(d_model, hidden_shared) for _ in range(n_shared_exp)]
        )
#       Same pattern: create n_shared_exp shared experts
```

### Router Components

```python
        self.register_parameter("centroids", 
                                nn.Parameter(torch.empty(n_routed_exp, d_model)))
```

**What is this?**

```python
register_parameter("name", parameter)
#   └──────────────────────────────────┘
#   Register a learnable parameter with PyTorch
#   (will be updated during training via gradients)
```

```python
torch.empty(n_routed_exp, d_model)
#   └────────────────────────────┘
#   Create uninitialized tensor of shape [16, 1024]
#   (will be initialized next)
```

```
Centroids shape: [n_routed_exp, d_model] = [16, 1024]

Each row = one expert's "centroid" (preference vector)

         d_model = 1024 dimensions
         ←──────────────────────────→
    ┌────────────────────────────────┐  ↑
    │  Expert 0's centroid vector    │  │
    ├────────────────────────────────┤  │
    │  Expert 1's centroid vector    │  │ n_routed_exp
    ├────────────────────────────────┤  │ = 16 experts
    │           ...                  │  │
    ├────────────────────────────────┤  │
    │  Expert 15's centroid vector   │  ↓
    └────────────────────────────────┘
```

```python
        nn.init.normal_(self.centroids, std=d_model ** -0.5)
#       └───────────────────────────────────────────────────┘
#       Initialize with normal distribution
#       std = 1/√1024 ≈ 0.031 (small values)
```
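
One way to read this choice of `std`: if the token features have roughly unit variance (an assumption, not something the code enforces), a dot product over 1024 dimensions with weights of standard deviation `1/√1024` produces logits with a standard deviation near 1. A quick numerical check under that assumption:

```python
import torch

torch.manual_seed(0)
d_model, n_experts = 1024, 16

centroids = torch.empty(n_experts, d_model)
torch.nn.init.normal_(centroids, std=d_model ** -0.5)

x = torch.randn(4096, d_model)        # assumed: roughly unit-variance token features
logits = x @ centroids.T              # [4096, 16]

print(d_model ** -0.5)                # 0.03125
print(logits.std().item())            # close to 1 under the assumption above
```

Logits on this scale keep the softmax over the top-k scores from being either nearly uniform or nearly one-hot at initialization.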

```python
        self.register_buffer("bias", torch.zeros(n_routed_exp))
#       └──────────────────────────────────────────────────────┘
#       register_buffer = NOT a learnable parameter
#       (won't get gradients, but will be saved with model)
#       
#       torch.zeros(16) = [0, 0, 0, ... 0] (16 zeros)
```
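
The practical difference between a parameter and a buffer shows up when inspecting a module. `TinyRouter` below is a hypothetical stand-in with assumed toy sizes, not the `DeepSeekMoE` class itself:

```python
import torch
import torch.nn as nn

class TinyRouter(nn.Module):   # hypothetical stand-in for illustration
    def __init__(self, n_experts: int = 4, d_model: int = 8):
        super().__init__()
        # Learnable: shows up in parameters() and receives gradients
        self.centroids = nn.Parameter(torch.randn(n_experts, d_model) * d_model ** -0.5)
        # Not learnable: saved and moved with the model, but no gradients
        self.register_buffer("bias", torch.zeros(n_experts))

router = TinyRouter()

param_names = [name for name, _ in router.named_parameters()]
buffer_names = [name for name, _ in router.named_buffers()]
state_keys = list(router.state_dict().keys())

print(param_names)                 # ['centroids']
print(buffer_names)                # ['bias']
print(state_keys)                  # ['centroids', 'bias']
print(router.bias.requires_grad)   # False
```

The optimizer only sees `centroids`, while `bias` still travels with the model through `state_dict()` and `.to(device)`.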

---

## 5. Forward Pass - Step by Step

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
```

### Step 1: Get Dimensions

```python
        B, S, D = x.shape
#       └─────────────────┘
#       Unpack shape: Batch size, Sequence length, Dimension
#       e.g., x.shape = [2, 64, 1024]
#             B=2, S=64, D=1024
```

```python
        x_flat = x.reshape(-1, D)  # [N, D] with N=B*S
#               └────────────────┘
#       Flatten first two dimensions
#       [2, 64, 1024] → [128, 1024]
#       
#       -1 means "calculate this dimension automatically"
```

```
Before:                          After:
x: [B, S, D]                     x_flat: [N, D]
   [2, 64, 1024]                         [128, 1024]
   
   ┌──────────────┐              ┌─────────────┐
   │  Batch 0     │              │  Token 0    │
   │ ┌──────────┐ │              │  Token 1    │
   │ │Token 0-63│ │   flatten    │  Token 2    │
   │ └──────────┘ │   ───────▶   │    ...      │
   ├──────────────┤              │  Token 127  │
   │  Batch 1     │              └─────────────┘
   │ ┌──────────┐ │
   │ │Token 0-63│ │              N = 2×64 = 128
   │ └──────────┘ │
   └──────────────┘
```
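
A quick way to convince yourself that flattening only relabels rows is to try it on a tiny tensor. This is a sketch with toy sizes (`B=2, S=3, D=4`) chosen for illustration:

```python
import torch

B, S, D = 2, 3, 4                      # toy sizes, for illustration only
x = torch.arange(B * S * D, dtype=torch.float32).reshape(B, S, D)

x_flat = x.reshape(-1, D)              # [B*S, D]
print(x_flat.shape)                    # torch.Size([6, 4])

# Token s of batch b lands in row b * S + s of the flattened tensor.
b, s = 1, 2
same_token = torch.equal(x_flat[b * S + s], x[b, s])
print(same_token)                      # True

# Reshaping back restores the original layout exactly.
round_trip = torch.equal(x_flat.view(B, S, D), x)
print(round_trip)                      # True
```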

### Step 2: Shared Experts Path

```python
        shared_out = torch.zeros_like(x)
#       └──────────────────────────────┘
#       Create zero tensor with same shape as x: [B, S, D]
```

```python
        for exp in self.shared:
            shared_out += exp(x)
#       └────────────────────────┘
#       Loop through each shared expert
#       Add each expert's output to shared_out
```

```
Visual:
        x ──────┬──────▶ Shared Expert 0 ──▶ output₀
                │
                └──────▶ Shared Expert 1 ──▶ output₁
                
        shared_out = output₀ + output₁
```
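
The accumulation loop can be tried in isolation. In this sketch the shared experts are plain `nn.Linear` layers, which is an assumption made only to keep the example short (the real experts are FFNs):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
# Hypothetical stand-ins for the shared expert FFNs.
shared = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])

x = torch.randn(2, 5, d_model)         # [B, S, D]

shared_out = torch.zeros_like(x)
for exp in shared:
    shared_out += exp(x)               # every token goes through every shared expert

# The loop is just a running sum of the individual expert outputs.
expected = shared[0](x) + shared[1](x)
loop_matches_sum = torch.allclose(shared_out, expected, atol=1e-6)
print(shared_out.shape, loop_matches_sum)
```

Note that there is no routing and no gating weight on this path: each shared expert contributes with weight 1.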

### Step 3: Router Logits

```python
        use_autocast = self.fp16_router and x.is_cuda
#       └────────────────────────────────────────────┘
#       Use FP16 only if flag is True AND on GPU
        
        device_type = "cuda" if x.is_cuda else x.device.type
#       └───────────────────────────────────────────────────┘
#       Get device type string for autocast
```

```python
        with torch.autocast(device_type=device_type, enabled=use_autocast):
#       └──────────────────────────────────────────────────────────────────┘
#       Context manager for mixed precision (FP16 on CUDA)
#       Runs the matmul faster on modern GPUs; a no-op when enabled=False
```

```python
            logits = F.linear(x_flat, self.centroids)  # [N, E]
#                   └──────────────────────────────────┘
#           Compute dot product between tokens and expert centroids
#           
#           x_flat:    [N, D] = [128, 1024]
#           centroids: [E, D] = [16, 1024]
#           logits:    [N, E] = [128, 16]
```

```
F.linear(input, weight) = input @ weight.T

For each token, compute similarity with each expert:

Token₀ • Centroid₀ = score₀₀   ─┐
Token₀ • Centroid₁ = score₀₁    │ ← logits for Token 0
Token₀ • Centroid₂ = score₀₂    │
...                             ─┘

Result: 128 tokens × 16 experts = 2,048 scores, stored as logits [128, 16]
```
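
The identity `F.linear(input, weight) = input @ weight.T` is easy to check numerically. A small sketch with toy sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D, E = 5, 8, 3                      # tokens, model dim, experts (toy sizes)
x_flat = torch.randn(N, D)
centroids = torch.randn(E, D)

logits = F.linear(x_flat, centroids)   # [N, E]
print(logits.shape)                    # torch.Size([5, 3])

# Same result as an explicit matrix product with the transposed weight.
matches_matmul = torch.allclose(logits, x_flat @ centroids.T, atol=1e-5)

# Entry [n, e] is the dot product of token n with centroid e.
n, e = 2, 1
matches_dot = torch.allclose(
    logits[n, e], torch.dot(x_flat[n], centroids[e]), atol=1e-5)
print(matches_matmul, matches_dot)     # True True
```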

```python
            logits = logits + self.bias.to(logits.dtype)
#                   └──────────────────────────────────┘
#           Add bias to each expert's score
#           .to(logits.dtype) ensures matching data types
```
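
Putting Step 3 together, the sketch below runs the same three operations on a CPU tensor. The flag value, the bias values, and the sizes are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_flat = torch.randn(4, 8)                 # [N, D], created on the CPU here
centroids = torch.randn(3, 8)              # [E, D]
bias = torch.tensor([0.0, 0.5, -0.5])      # [E], made-up values

fp16_router = True                         # assumed flag value
use_autocast = fp16_router and x_flat.is_cuda
device_type = "cuda" if x_flat.is_cuda else x_flat.device.type

with torch.autocast(device_type=device_type, enabled=use_autocast):
    raw = F.linear(x_flat, centroids)      # [N, E]
    logits = raw + bias.to(raw.dtype)      # bias broadcasts over the N rows

# The tensors live on the CPU, so autocast is disabled and logits stay float32.
print(use_autocast, logits.dtype)          # False torch.float32

# The bias shifts a whole column: one expert's score for every token.
shift = (logits - raw)[:, 1]
print(shift)
```

Because the bias is added before top-k, raising an expert's bias makes it more likely to be selected for every token at once.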

### Step 4: Top-K Selection

```python
        topk_logits, topk_idx = torch.topk(logits, self.top_k, dim=-1)
#       └─────────────────────────────────────────────────────────────┘
#       Select top_k highest values along last dimension
#       
#       topk_logits: [N, k] = [128, 8] ← the actual scores
#       topk_idx:    [N, k] = [128, 8] ← which experts were selected
```

```
Example for one token:
logits = [0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6, ...]  (16 values)
                ↑         ↑         ↑
         torch.topk(logits, k=3) selects top 3:
         
topk_logits = [0.9, 0.8, 0.7]  ← highest scores
topk_idx    = [1,   5,   3]    ← which expert indices
```
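
The same example can be run directly. This sketch uses only the eight values shown above:

```python
import torch

logits = torch.tensor([0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6])

# topk returns values sorted from largest to smallest, plus their positions.
topk_logits, topk_idx = torch.topk(logits, k=3, dim=-1)
print(topk_logits)  # tensor([0.9000, 0.8000, 0.7000])
print(topk_idx)     # tensor([1, 5, 3])
```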

```python
        gate = F.softmax(topk_logits, dim=-1, dtype=x.dtype)
#       └──────────────────────────────────────────────────┘
#       Convert scores to probabilities (sum to 1)
#       
#       Example: [0.9, 0.8, 0.7] → softmax → [0.37, 0.33, 0.30]
#                                              └────────────────┘
#                                                   sums to 1.0
```
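
To see where the gate values come from, the softmax can be worked out by hand. This sketch uses plain Python instead of `F.softmax` so that each step is visible:

```python
import math

topk_logits = [0.9, 0.8, 0.7]

# softmax(x)_i = exp(x_i) / sum_j exp(x_j)
exps = [math.exp(v) for v in topk_logits]
total = sum(exps)
gate = [e / total for e in exps]

print([round(g, 2) for g in gate])  # [0.37, 0.33, 0.3]
print(round(sum(gate), 6))          # 1.0
```

Since the softmax runs over the selected `k` scores only, the weights of the chosen experts always sum to 1, no matter how the unselected experts scored.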

### Step 5: Expert Dispatch

```python
        routed_out = torch.zeros_like(x_flat)  # [N, D]
#       └────────────────────────────────────┘
#       Initialize output buffer with zeros
```

```python
        for i in range(self.n_routed):  # Loop through 16 experts
```

```python
            mask = (topk_idx == i)
#           └────────────────────┘
#           Boolean mask: where was expert i selected?
#           Shape: [N, k] = [128, 8]
```

```
Example (expert i=3):
topk_idx = [[1, 5, 3, 7, ...],    ← Token 0 selected experts
            [3, 2, 8, 1, ...],    ← Token 1 selected experts
            [0, 3, 5, 9, ...],    ← Token 2 selected experts
            ...]

mask = [[False, False, True, False, ...],   ← expert 3 at position 2
        [True, False, False, False, ...],   ← expert 3 at position 0
        [False, True, False, False, ...],   ← expert 3 at position 1
        ...]
```

```python
            row_idx, which_k = mask.nonzero(as_tuple=True)
#           └─────────────────────────────────────────────┘
#           Find indices where mask is True
#           
#           row_idx: which tokens selected expert i
#           which_k: at which position in top-k
```
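
Here is the mask example as runnable code, a sketch that keeps only the first four columns of `topk_idx` shown above:

```python
import torch

topk_idx = torch.tensor([[1, 5, 3, 7],
                         [3, 2, 8, 1],
                         [0, 3, 5, 9]])
i = 3
mask = (topk_idx == i)                       # [N, k] boolean
row_idx, which_k = mask.nonzero(as_tuple=True)

print(row_idx)   # tensor([0, 1, 2])  -> tokens that picked expert 3
print(which_k)   # tensor([2, 0, 1])  -> slot of expert 3 in each token's top-k

# An expert that no token picked yields empty index tensors.
unused_rows, _ = (topk_idx == 4).nonzero(as_tuple=True)
print(unused_rows.numel())  # 0
```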

```python
            if row_idx.numel() == 0:
                continue
#           └─────────────────────────┘
#           Skip if no tokens selected this expert
#           .numel() = number of elements
```

```python
            exp_in = x_flat.index_select(0, row_idx)  # [Ti, D]
#           └──────────────────────────────────────┘
#           Select only the tokens that chose this expert
#           
#           index_select(dim, indices):
#             - dim=0 means select along first dimension (rows)
#             - row_idx are the row indices to select
```

```python
            out = self.routed[i](exp_in)  # [Ti, D]
#           └──────────────────────────┘
#           Pass selected tokens through expert i's FFN
```

```python
            w = gate[row_idx, which_k].unsqueeze(-1)  # [Ti, 1]
#           └──────────────────────────────────────┘
#           Get the gating weights for these tokens
#           
#           gate[row_idx, which_k]: select specific elements
#           .unsqueeze(-1): add dimension at end [Ti] → [Ti, 1]
#                          (needed for broadcasting)
```

```python
            routed_out.index_add_(0, row_idx, out * w)
#           └────────────────────────────────────────┘
#           Add weighted output back to original positions
#           
#           index_add_(dim, indices, source):
#             - In-place addition at specific indices
#             - routed_out[row_idx] += out * w
```

```
Visual of dispatch:
                                    Expert 0
Token 0 ─────────────────────────▶ ┌───────┐
Token 5 ─────────────────────────▶ │ FFN 0 │──▶ weighted output → routed_out[0,5,...]
Token 12 ────────────────────────▶ └───────┘

                                    Expert 1
Token 1 ─────────────────────────▶ ┌───────┐
Token 3 ─────────────────────────▶ │ FFN 1 │──▶ weighted output → routed_out[1,3,...]
Token 7 ─────────────────────────▶ └───────┘

... (repeat for all 16 experts)
```
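
One way to check the dispatch loop is to compare it with a slow reference that handles one (token, slot) pair at a time. The sketch below does that with toy sizes; the experts are `nn.Linear` stand-ins, which is an assumption made for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, D, E, K = 6, 8, 4, 2                          # toy sizes
x_flat = torch.randn(N, D)
# Hypothetical stand-ins for the routed expert FFNs.
routed = nn.ModuleList([nn.Linear(D, D) for _ in range(E)])
centroids = torch.randn(E, D)

with torch.no_grad():
    logits = F.linear(x_flat, centroids)         # [N, E]
    topk_logits, topk_idx = torch.topk(logits, K, dim=-1)
    gate = F.softmax(topk_logits, dim=-1)        # [N, K]

    # Grouped dispatch: one batched call per expert.
    routed_out = torch.zeros_like(x_flat)
    for i in range(E):
        row_idx, which_k = (topk_idx == i).nonzero(as_tuple=True)
        if row_idx.numel() == 0:
            continue
        out = routed[i](x_flat.index_select(0, row_idx))   # [Ti, D]
        w = gate[row_idx, which_k].unsqueeze(-1)           # [Ti, 1]
        routed_out.index_add_(0, row_idx, out * w)

    # Naive reference: one (token, slot) pair at a time.
    reference = torch.zeros_like(x_flat)
    for n in range(N):
        for k in range(K):
            expert = routed[int(topk_idx[n, k])]
            reference[n] += gate[n, k] * expert(x_flat[n])

dispatch_matches = torch.allclose(routed_out, reference, atol=1e-5)
print(dispatch_matches)
```

The grouped version makes `E` expert calls regardless of the number of tokens, while the reference makes `N × K` calls. Both compute the same weighted sum.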

### Step 6: Combine Outputs

```python
        routed_out = routed_out.view(B, S, D)
#       └───────────────────────────────────┘
#       Reshape back: [128, 1024] → [2, 64, 1024]
```

```python
        return x + shared_out + routed_out
#       └─────────────────────────────────┘
#       Residual connection: add everything together
#       
#       output = original_input + shared_experts + routed_experts
```

```
Final combination:
┌─────────────────────────────────────────────────┐
│                                                 │
│   x (input) ──────────────────────────┐         │
│        │                              │         │
│        ├──▶ Shared Experts ──────────┐│         │
│        │                             ││         │
│        └──▶ Router + Routed Experts ─┴┴──▶ ADD ─┼──▶ output
│                                                 │
└─────────────────────────────────────────────────┘
```

---

## 6. Update Bias Method

```python
    @torch.no_grad()
#   └──────────────┘
#   Decorator: disable gradient computation
#   (we don't want gradients here, just direct updates)
    
    def update_bias(self, x: torch.Tensor):
```

```python
        N = x.shape[0] * x.shape[1]
#       └─────────────────────────────┘
#       Total number of tokens = Batch × Sequence
```

```python
        logits = F.linear(x.reshape(-1, self.d_model), self.centroids) + self.bias
#       └──────────────────────────────────────────────────────────────────────────┘
#       Recalculate router logits (same formula as in forward, without autocast)
```

```python
        _, idx = torch.topk(logits, self.top_k, dim=-1)
#       └─────────────────────────────────────────────┘
#       Get which experts were selected (we don't need the values)
#       _ means "discard this value"
```

```python
        counts = torch.bincount(idx.flatten(), minlength=self.n_routed).float()
#       └─────────────────────────────────────────────────────────────────────┘
#       Count how many times each expert was selected
#       
#       idx.flatten(): convert [128, 8] → [1024] (all selections)
#       bincount: count occurrences of each value 0-15
#       minlength: ensure output has 16 elements
#       .float(): convert to float for math operations
```

```
Example:
idx.flatten() = [1, 5, 3, 7, 3, 2, 8, 1, ...]  (1024 values)

bincount counts occurrences:
Expert 0: appeared 50 times
Expert 1: appeared 80 times
Expert 2: appeared 45 times
...
Expert 15: appeared 70 times

counts = [50, 80, 45, ..., 70]
```
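
A tiny runnable version of the counting step, using made-up selections for 3 tokens, `top_k = 2`, and 6 experts:

```python
import torch

n_routed = 6
idx = torch.tensor([[1, 5],
                    [3, 1],
                    [5, 1]])                    # [N, k] = [3, 2]

counts = torch.bincount(idx.flatten(), minlength=n_routed).float()
print(counts)        # tensor([0., 3., 0., 1., 0., 2.])

# Without minlength the result stops at the largest index that appears,
# so experts with higher indices would silently be missing.
short = torch.bincount(torch.tensor([1, 3, 1]))
print(short)         # tensor([0, 2, 0, 1])
```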

```python
        avg = counts.sum() / max(1, self.n_routed)
#       └────────────────────────────────────────┘
#       Calculate average load per expert
#       
#       Example: sum=1024 (total selections), n_routed=16
#                avg = 1024 / 16 = 64 (ideal: each expert gets 64)
```

```python
        violation = (avg - counts) / (avg + 1e-6)
#       └───────────────────────────────────────┘
#       How far is each expert from the average?
#       
#       If count < avg: violation > 0 (under-loaded, needs boost)
#       If count > avg: violation < 0 (over-loaded, needs penalty)
#       
#       1e-6 prevents division by zero
```

```
Example:
avg = 64
counts = [50, 80, 45, 70, ...]

violation[0] = (64 - 50) / 64 = +0.22  (under-loaded, increase bias)
violation[1] = (64 - 80) / 64 = -0.25  (over-loaded, decrease bias)
```
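
The same arithmetic in plain Python, as a sketch that uses the four counts shown above and takes `avg = 64` from the example (the remaining experts are not listed, so they are left out here):

```python
avg = 64.0                              # 1024 selections / 16 experts
counts = [50.0, 80.0, 45.0, 70.0]       # the four counts shown in the example

violation = [(avg - c) / (avg + 1e-6) for c in counts]
print([round(v, 2) for v in violation])   # [0.22, -0.25, 0.3, -0.09]

# Positive means under-loaded (bias should rise), negative means over-loaded.
directions = ["raise" if v > 0 else "lower" for v in violation]
print(directions)                          # ['raise', 'lower', 'raise', 'lower']
```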

```python
        self.bias.add_(self.bias_lr * torch.tanh(violation))
#       └──────────────────────────────────────────────────┘
#       Update bias in-place
#       
#       torch.tanh(): squash to the (-1, 1) range (prevents extreme updates)
#       self.bias_lr: step size (0.01 = small steps)
#       .add_(): in-place addition (modifies self.bias directly)
```

```
tanh function:
                1 ┤         ╭──────────
                  │       ╱
                0 ┼─────╱─────────────
                  │   ╱
               -1 ┤──╯
                  └─────────────────────
                  -3   -1   0    1    3

Squashes any value into the range (-1, 1)
Prevents explosive bias updates!
```
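
To watch one update in action, the sketch below pulls the counting, violation, and tanh steps into a standalone function. `bias_step` is a hypothetical helper written for this illustration (it takes the selected indices directly instead of recomputing logits), and the load pattern is deliberately extreme:

```python
import torch

@torch.no_grad()
def bias_step(bias: torch.Tensor, idx: torch.Tensor,
              bias_lr: float = 0.01) -> torch.Tensor:
    """Hypothetical standalone version of the update; mutates `bias` in place."""
    n_routed = bias.numel()
    counts = torch.bincount(idx.flatten(), minlength=n_routed).float()
    avg = counts.sum() / max(1, n_routed)
    violation = (avg - counts) / (avg + 1e-6)
    bias.add_(bias_lr * torch.tanh(violation))
    return bias

bias = torch.zeros(4)
# All 4 tokens picked expert 0 (top_k = 1): a badly skewed load.
idx = torch.zeros(4, 1, dtype=torch.long)
bias_step(bias, idx)

# Expert 0 is pushed down, the three idle experts are pushed up.
print(bias)
```

Because `tanh` stays strictly inside (-1, 1), a single step never moves a bias by more than `bias_lr`, however unbalanced the batch is.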

---

## 7. Testing the Model

```python
# Configuration
d_model = 1024         # Each token is a 1024-dim vector
n_routed_exp = 16      # 16 routed experts
n_shared_exp = 2       # 2 shared experts
top_k = 8              # Select 8 experts per token

# Create model
model = DeepSeekMoE(d_model, n_routed_exp, n_shared_exp, top_k)

# Create random input
batch_size_new = 2
seq_len_new = 64
random_input_new = torch.randn(batch_size_new, seq_len_new, d_model)
#                  └───────────────────────────────────────────────┘
#                  Random tensor from normal distribution
#                  Shape: [2, 64, 1024]

# Forward pass
output_new = model(random_input_new)

print("New output shape:", output_new.shape)
# Prints: New output shape: torch.Size([2, 64, 1024])  ← Same shape as input!
```

---

## Quick Reference: Key PyTorch Operations

| Operation | What It Does | Example |
|-----------|--------------|---------|
| `tensor.shape` | Get dimensions | `[2, 64, 1024]` |
| `tensor.reshape(-1, D)` | Flatten to 2D | `[128, 1024]` |
| `tensor.view(B, S, D)` | Reshape (must be contiguous) | `[2, 64, 1024]` |
| `F.linear(x, w)` | Matrix multiply: `x @ w.T` | `[N, D] @ [E, D].T → [N, E]` |
| `F.softmax(x, dim=-1)` | Normalize to probabilities | Sum to 1.0 |
| `torch.topk(x, k)` | Get k largest values | Values and indices |
| `tensor.index_select(0, idx)` | Select rows by index | Subset of tensor |
| `tensor.index_add_(0, idx, src)` | Add to specific rows | In-place accumulate |
| `torch.zeros_like(x)` | Zero tensor same shape as x | Initialization |
| `torch.bincount(x)` | Count occurrences | Histogram |
| `torch.tanh(x)` | Hyperbolic tangent | Squash to (-1, 1) |
| `tensor.add_(x)` | In-place addition | Modify in place |
| `@torch.no_grad()` | Disable gradients | For inference/manual updates |

---

## Summary

The code implements a **Mixture of Experts** layer where:

1. **Shared experts** always process all tokens
2. **Router** calculates which routed experts each token should use
3. **Top-k selection** picks the best experts for each token
4. **Gating** weights the contribution of each selected expert
5. **Dispatch** sends tokens to their selected experts
6. **Combine** adds everything together with a residual connection
7. **Bias update** keeps expert loads balanced over time

This achieves **large model capacity** with **efficient computation!**
