# Multi-Token Prediction (MTP) - DeepSeek's Approach

## What is Multi-Token Prediction?

**Traditional LLMs:** Predict **ONE token at a time** (slow!)

**Multi-Token Prediction:** Predict **MULTIPLE tokens at once** (fast!)

```
Traditional (Next-Token Prediction):
┌─────────────────────────────────────────────────────────────────┐
│                                                                 │
│   "The cat sat on the" ──▶ Model ──▶ "mat"                      │
│   "The cat sat on the mat" ──▶ Model ──▶ "and"                  │
│   "The cat sat on the mat and" ──▶ Model ──▶ "slept"            │
│                                                                 │
│   3 tokens = 3 forward passes (SLOW!)                           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Multi-Token Prediction (MTP):
┌─────────────────────────────────────────────────────────────────┐
│                                                                 │
│   "The cat sat on the" ──▶ Model ──▶ "mat", "and", "slept"      │
│                                                                 │
│   3 tokens ≈ 1 forward pass if all drafts accepted (FAST!)      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

---

## Why Does This Matter?

```
The Bottleneck Problem:
═══════════════════════

In autoregressive generation, each token depends on ALL previous tokens.

Token 1 ──▶ Token 2 ──▶ Token 3 ──▶ Token 4 ──▶ ...
   │           │           │           │
   ▼           ▼           ▼           ▼
Forward     Forward     Forward     Forward
Pass 1      Pass 2      Pass 3      Pass 4

Problem: Sequential dependency = Can't parallelize during inference!
```
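Here is the sequential dependency as a toy greedy decoding loop. `toy_model` is a hypothetical stand-in for one forward pass (it just continues an integer sequence). Each generated token costs one call, because each call needs the token produced by the call before it.

```python
def toy_model(tokens):
    """Hypothetical stand-in for one forward pass: returns the next token id.

    Here the 'model' simply continues an arithmetic sequence.
    """
    return tokens[-1] + 1


def greedy_decode(prompt, num_new_tokens):
    """Plain autoregressive decoding: one forward pass per generated token."""
    tokens = list(prompt)
    forward_passes = 0
    for _ in range(num_new_tokens):
        next_token = toy_model(tokens)  # needs every token produced so far
        forward_passes += 1
        tokens.append(next_token)
    return tokens, forward_passes


tokens, passes = greedy_decode([10, 11, 12], num_new_tokens=3)
print(tokens, passes)  # [10, 11, 12, 13, 14, 15] 3
```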

**MTP Solution:** Train the model to predict multiple future tokens simultaneously.

---

## How DeepSeek Implements MTP

### Architecture Overview

```
                              Input Tokens
                                   │
                                   ▼
                    ┌──────────────────────────┐
                    │                          │
                    │     Main Transformer     │
                    │      (Shared Trunk)      │
                    │                          │
                    └────────────┬─────────────┘
                                 │
                      Hidden States [h₁, h₂, h₃, ...]
                                 │
            ┌────────────────────┼────────────────────┐
            │                    │                    │
            ▼                    ▼                    ▼
    ┌──────────────┐     ┌──────────────┐     ┌──────────────┐
    │   MTP Head   │     │   MTP Head   │     │   MTP Head   │
    │   (k = 1)    │     │   (k = 2)    │     │   (k = 3)    │
    │  Next Token  │     │  +2 Token    │     │  +3 Token    │
    └──────┬───────┘     └──────┬───────┘     └──────┬───────┘
           │                    │                    │
           ▼                    ▼                    ▼
        Token₁               Token₂               Token₃
      (position t+1)       (position t+2)       (position t+3)
```

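This picture simplifies DeepSeek's design. In DeepSeek-V3 the MTP modules are not independent parallel heads. They run one after another, and each depth consumes the hidden state produced by the previous depth. The main model's own output head handles t+1, so the first MTP module predicts t+2.

### Key Components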

```
1. Shared Transformer Trunk:
   ┌─────────────────────────────────────────┐
   │  Same backbone processes input once     │
   │  Produces rich hidden representations   │
   │  Most computation happens here          │
   └─────────────────────────────────────────┘

2. Multiple Prediction Heads:
   ┌─────────────────────────────────────────┐
   │  Small heads (one transformer layer)    │
   │  Each head predicts a different future  │
   │  position: t+1, t+2, t+3, ...           │
   └─────────────────────────────────────────┘
```

---

## DeepSeek's MTP Module Design

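```
For each prediction depth k (DeepSeek-V3 style):
(The main model itself is depth 0 and predicts t+1.)

  Hidden state h^(k-1)               Token at position t+k
  (trunk output when k = 1,          (ground truth in training,
   depth k-1 output otherwise)        draft token at inference)
           │                                  │
           ▼                                  ▼
    ┌─────────────┐                    ┌─────────────┐
    │   RMSNorm   │                    │  Embedding  │  ◄── Shared with the main model
    └──────┬──────┘                    │ + RMSNorm   │
           │                           └──────┬──────┘
           └────────────────┬─────────────────┘
                            ▼
                     ┌─────────────┐
                     │   Concat    │  ◄── [RMSNorm(h) ; RMSNorm(emb)], 2·d_model wide
                     └──────┬──────┘
                            ▼
                     ┌─────────────┐
                     │ Projection  │  ◄── Linear M_k: 2·d_model → d_model
                     └──────┬──────┘
                            ▼
                     ┌─────────────┐
                     │  MTP Block  │  ◄── One transformer layer
                     │  (1 layer)  │      (self-attention + FFN)
                     └──────┬──────┘
                            │  h^(k) is passed on to depth k+1
                            ▼
                     ┌─────────────┐
                     │   Output    │  ◄── Shared output head
                     │   Head      │      (project to vocabulary)
                     └──────┬──────┘
                            ▼
               Prediction for position t+k+1
```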
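For reference, the per-depth computation described in the DeepSeek-V3 technical report can be written roughly as follows (our transcription). Here i is the input position (written t in the diagrams) and k is the depth. $h_i^0$ is the main model's final hidden state, $M_k$ is depth k's projection matrix and $\mathrm{TRM}_k$ is its transformer block. OutHead is the output head shared with the main model.

```latex
\begin{aligned}
h_i'^{\,k} &= M_k\left[\operatorname{RMSNorm}\!\left(h_i^{k-1}\right);\ \operatorname{RMSNorm}\!\left(\operatorname{Emb}(t_{i+k})\right)\right] \\
h_{1:T-k}^{k} &= \operatorname{TRM}_k\!\left(h_{1:T-k}'^{\,k}\right) \\
P_{i+k+1}^{k} &= \operatorname{OutHead}\!\left(h_i^{k}\right)
\end{aligned}
```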

### The Chaining Mechanism

```
How predictions flow (DeepSeek-V3 style):

Step 1: Main model produces hidden states and predicts t+1
        h⁰    = Transformer(input_tokens)
        pred₁ = OutputHead(h⁰)

Step 2: MTP depth 1 combines h⁰ with the embedding of token t+1
        h¹    = MTP₁(h⁰, embed(token at t+1))
        pred₂ = OutputHead(h¹)

Step 3: MTP depth 2 combines h¹ (not h⁰) with the embedding of token t+2
        h²    = MTP₂(h¹, embed(token at t+2))
        pred₃ = OutputHead(h²)

... and so on for D prediction depths

The token at t+k is the ground-truth token during training
(teacher forcing) and the previously drafted token at inference.
```
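During training, the chain reduces to index bookkeeping. The sketch below uses a hypothetical helper and 0-indexed positions, following the DeepSeek-V3 layout where the main head covers t+1. For each depth k it lists which token embedding the module consumes and which token it is trained to predict at input position i.

```python
def mtp_alignment(seq_len, depth):
    """For each MTP depth k = 1..depth, return (i, i + k, i + k + 1) triples:
    input position, position of the token whose embedding is fed in, and
    position of the target token. Only in-sequence targets are kept."""
    plan = {}
    for k in range(1, depth + 1):
        valid = range(0, seq_len - k - 1)  # need i + k + 1 <= seq_len - 1
        plan[k] = [(i, i + k, i + k + 1) for i in valid]
    return plan


tokens = ["The", "cat", "sat", "on", "the", "mat"]
for k, triples in mtp_alignment(len(tokens), depth=2).items():
    i, emb_pos, tgt_pos = triples[0]
    print(f"depth {k}: position {i} uses embed({tokens[emb_pos]!r}) "
          f"to predict {tokens[tgt_pos]!r}; {len(triples)} positions")
# depth 1: position 0 uses embed('cat') to predict 'sat'; 4 positions
# depth 2: position 0 uses embed('sat') to predict 'on'; 3 positions
```

Deeper modules see fewer training positions, because their targets run off the end of the sequence sooner.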

---

## Training vs Inference

### During Training

```
┌─────────────────────────────────────────────────────────────────┐
│                         TRAINING                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Input: "The cat sat on the mat and slept"                      │
│                                                                 │
│  For position "the":                                            │
│  ┌─────────────────────────────────────────┐                    │
│  │  Head 1 target: "mat"      (t+1)        │                    │
│  │  Head 2 target: "and"      (t+2)        │                    │
│  │  Head 3 target: "slept"    (t+3)        │                    │
│  └─────────────────────────────────────────┘                    │
│                                                                 │
│  Loss = w₁·Loss₁ + w₂·Loss₂ + w₃·Loss₃                          │
│         (weighted sum of all prediction losses)                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
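Below is a minimal PyTorch sketch of the loss in the box above. It assumes each head k returns logits of shape (B, T, V) and is trained against the token k positions ahead. This follows the box's convention, where head 1 is the ordinary next-token head. The function name and the optional per-head `weights` are illustrative.

```python
import torch
import torch.nn.functional as F


def multi_token_loss(head_logits, tokens, weights=None):
    """head_logits[k-1]: (B, T, V) logits from head k, which targets t+k.
    tokens: (B, T) input ids. Positions without a target are dropped."""
    weights = weights or [1.0] * len(head_logits)
    total = torch.zeros(())
    for k, (logits, w) in enumerate(zip(head_logits, weights), start=1):
        pred = logits[:, :-k, :]   # positions 0 .. T-k-1 have a target
        target = tokens[:, k:]     # the token k steps ahead of each position
        loss_k = F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
        total = total + w * loss_k
    return total


torch.manual_seed(0)
B, T, V = 2, 8, 50
tokens = torch.randint(0, V, (B, T))
head_logits = [torch.randn(B, T, V) for _ in range(3)]
loss = multi_token_loss(head_logits, tokens)
print(loss.item())
```

With all-zero logits, every head's cross-entropy equals log V, so the unweighted total is 3·log V. That makes a handy sanity check.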

### During Inference (Speculative Decoding)

```
┌─────────────────────────────────────────────────────────────────┐
│                        INFERENCE                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Step 1: Generate k draft tokens using MTP heads                │
│          draft = [token₁, token₂, token₃, ...]                  │
│                                                                 │
│  Step 2: Verify ALL drafts in ONE forward pass                  │
│          verified = MainModel.verify(draft)                     │
│                                                                 │
│  Step 3: Accept correct predictions, reject wrong ones          │
│          If token₁ ✓, token₂ ✓, token₃ ✗                        │
│          Accept: token₁, token₂                                 │
│          Replace token₃ with the verifier's own token           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
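Step 3's accept/reject rule can be sketched for greedy decoding as follows. `verifier_tokens` stands for the main model's own greedy choice at each draft position. A single forward pass over prompt + draft provides these choices, plus one extra token after the last draft. The function name is illustrative, and sampling-based speculative decoding replaces the exact-match test with a probabilistic acceptance rule.

```python
def accept_drafts(draft, verifier_tokens):
    """Greedy speculative-decoding acceptance.

    draft:           k tokens proposed by the MTP modules.
    verifier_tokens: k + 1 tokens the main model would emit at the same
                     positions (the last one follows the final draft token).
    Returns the tokens to append: the longest matching prefix of the draft
    plus one token from the verifier (a correction or a bonus token).
    """
    assert len(verifier_tokens) == len(draft) + 1
    accepted = []
    for drafted, verified in zip(draft, verifier_tokens):
        if drafted != verified:
            accepted.append(verified)  # main model's token replaces the miss
            return accepted
        accepted.append(drafted)
    accepted.append(verifier_tokens[-1])  # every draft matched: free extra token
    return accepted


print(accept_drafts(["mat", "and", "jumped"], ["mat", "and", "slept", "."]))
# ['mat', 'and', 'slept']
print(accept_drafts(["mat", "and", "slept"], ["mat", "and", "slept", "peacefully"]))
# ['mat', 'and', 'slept', 'peacefully']
```

Even in the worst case a verification pass still yields one token, so this scheme never makes decoding slower in token count. Any slowdown comes only from the extra compute of drafting.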

---

## Why MTP Makes Inference Faster

### The Speedup Intuition

```
Traditional Autoregressive:
═══════════════════════════

Generate 12 tokens = 12 sequential forward passes

Pass 1 ──▶ Pass 2 ──▶ Pass 3 ──▶ ... ──▶ Pass 12
  │          │          │                   │
  ▼          ▼          ▼                   ▼
 T1         T2         T3         ...      T12

Time = 12 × (forward pass time)


With MTP (k=4 prediction depth):
════════════════════════════════

Generate 12 tokens ≈ 3-4 forward passes (with verification)

Pass 1 ──────────────▶ Pass 2 ──────────────▶ Pass 3
  │                      │                      │
  ▼                      ▼                      ▼
T1,T2,T3,T4           T5,T6,T7,T8           T9,T10,T11,T12
(draft+verify)        (draft+verify)        (draft+verify)

Time ≈ 3-4 × (forward pass time)

Speedup: up to ~3-4x, and only if nearly every draft is accepted
```
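The diagram above assumes every draft is accepted. A common back-of-the-envelope model from the speculative-decoding literature gives a more realistic figure. It assumes each draft token is accepted independently with probability p and that drafting stops at the first rejection. One verification pass then yields on average 1 + p + p² + … + pᵏ = (1 − pᵏ⁺¹) / (1 − p) tokens for k drafts. This sketch ignores the cost of running the MTP modules, so read it as an upper bound on the speedup.

```python
def expected_tokens_per_pass(p, k):
    """Expected tokens per verification pass with k drafts, each accepted
    independently with probability p (a simplifying assumption)."""
    if p == 1.0:
        return k + 1  # every draft accepted, plus the verifier's bonus token
    return (1 - p ** (k + 1)) / (1 - p)


for p in (0.6, 0.85, 0.9):
    print(p, [round(expected_tokens_per_pass(p, k), 2) for k in (1, 2, 4)])
```

With one MTP depth (k = 1) and p between 0.85 and 0.9, the model gives 1.85 to 1.9 tokens per pass. That is in the same range as the ~1.8x tokens-per-second gain DeepSeek-V3 reports once real overheads are included.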

### Acceptance Rate

```
The key metric: How often are draft tokens correct?

High Acceptance Rate (good):
┌─────────────────────────────────────────┐
│  Draft: [mat, and, slept, peacefully]   │
│  Verify: [✓,   ✓,   ✓,     ✓]           │
│  Accept ALL 4 tokens!                   │
└─────────────────────────────────────────┘

Low Acceptance Rate (less speedup):
┌─────────────────────────────────────────┐
│  Draft: [mat, or,  jumped, quickly]     │
│  Verify: [✓,   ✗,   -,      -]          │
│  Accept only 1 token, regenerate rest   │
└─────────────────────────────────────────┘

DeepSeek-V3 reports: 85-90% acceptance rate for the second predicted token (one MTP depth)
```
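The acceptance rate is simply accepted draft tokens divided by proposed draft tokens. Here is a tiny helper applied to the two illustrative boxes above. Tokens after the first miss count as proposed but not accepted. The function name and log format are hypothetical.

```python
def acceptance_rate(verify_logs):
    """verify_logs: one list per verification step, with True for an accepted
    draft token and False for a rejected or discarded one."""
    proposed = sum(len(step) for step in verify_logs)
    accepted = sum(sum(step) for step in verify_logs)
    return accepted / proposed


high = [[True, True, True, True]]
low = [[True, False, False, False]]  # tokens after the first miss are discarded
print(acceptance_rate(high), acceptance_rate(low), acceptance_rate(high + low))
# 1.0 0.25 0.625
```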

---

## Benefits of MTP

### 1. Faster Inference

```
┌─────────────────────────────────────────────────────────────────┐
│  Tokens per Second (illustrative numbers):                      │
│                                                                 │
│  Standard:    ████████████████ 50 tok/s                         │
│  With MTP:    █████████████████████████████ 90 tok/s            │
│                                                                 │
│  DeepSeek-V3 reports ~1.8x tokens/s with one MTP module         │
└─────────────────────────────────────────────────────────────────┘
```

### 2. Better Representations (Training Benefit)

```
Why training with MTP helps:
════════════════════════════

Predicting multiple tokens forces the model to:

  ┌─────────────────────────────────────────┐
  │  1. Plan ahead (not just next token)    │
  │  2. Learn longer-range dependencies     │
  │  3. Build richer hidden representations │
  │  4. Better understand context           │
  └─────────────────────────────────────────┘

Result: Even single-token prediction improves!
```

### 3. Memory-Bandwidth Efficiency

```
┌─────────────────────────────────────────────────────────────────┐
│  Memory-Bandwidth Savings:                                      │
│                                                                 │
│  Traditional: Read weights + KV cache once per token            │
│  MTP: Read weights + KV cache once per verified chunk           │
│                                                                 │
│  Fewer reads per token = Less memory bandwidth per token        │
│  (The KV cache itself is not smaller, and MTP modules add       │
│   extra parameters.)                                            │
└─────────────────────────────────────────────────────────────────┘
```

---

## DeepSeek's Specific Implementation Details

### MTP Head Structure

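```python
# Simplified structure of each MTP head (one prediction depth).
# nn.RMSNorm requires PyTorch >= 2.4; swap in nn.LayerNorm on older versions.
import torch
import torch.nn as nn


class MTPHead(nn.Module):
    def __init__(self, d_model, n_heads, shared_embed, shared_output):
        super().__init__()
        self.embed = shared_embed      # nn.Embedding shared with the main model
        self.output = shared_output    # nn.Linear(d_model, vocab_size), also shared
        self.norm_h = nn.RMSNorm(d_model)
        self.norm_e = nn.RMSNorm(d_model)
        self.proj = nn.Linear(2 * d_model, d_model)  # concat input -> d_model
        self.transformer_block = nn.TransformerEncoderLayer(
            d_model, n_heads, batch_first=True)

    def forward(self, hidden_state, token_ids, attn_mask=None):
        # hidden_state: (B, T, d_model) from the trunk or the previous depth
        # token_ids:    (B, T) ids of token t+k for each position t
        token_embed = self.embed(token_ids)
        # Combine the normalized hidden state with the normalized token embedding
        x = torch.cat([self.norm_h(hidden_state), self.norm_e(token_embed)], dim=-1)
        x = self.proj(x)
        h = self.transformer_block(x, src_mask=attn_mask)
        logits = self.output(h)
        return logits, h  # h feeds the next depth
```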

### Training Objective

```
Total Loss = L_main + (λ / D) · (L₁ + L₂ + ... + L_D)

Where:
  L_main = Main model's next-token prediction loss (t+1)
  L_k    = Loss of MTP depth k, which predicts position t+k+1
  D      = Number of MTP depths
  λ      = One weight shared by every MTP depth

DeepSeek-V3 uses: D = 1 (one extra token), with λ = 0.3 for the
first 10T training tokens and λ = 0.1 for the remaining 4.8T
```
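DeepSeek-V3 averages the per-depth MTP losses over the D depths and scales them by a single weight λ. That weight is 0.3 for the first 10T training tokens and 0.1 for the remaining 4.8T. Here is a minimal sketch of that weighting, with illustrative function names and made-up loss values.

```python
def mtp_lambda(tokens_seen):
    """MTP loss weight schedule reported for DeepSeek-V3:
    0.3 for the first 10T training tokens, 0.1 afterwards."""
    return 0.3 if tokens_seen < 10e12 else 0.1


def total_loss(main_loss, mtp_losses, lam):
    """L_main + (lam / D) * sum of the D per-depth MTP losses."""
    depth = len(mtp_losses)
    return main_loss + lam / depth * sum(mtp_losses)


print(total_loss(2.0, [3.0], mtp_lambda(5e12)))        # D = 1, early in training
print(total_loss(2.0, [3.0, 4.0], mtp_lambda(12e12)))  # hypothetical D = 2, later
```

Dividing by D keeps the total MTP contribution on the same scale however many depths are used. λ then controls how strongly the auxiliary objective pulls on the shared trunk.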

---

## The Full Pipeline

```
┌─────────────────────────────────────────────────────────────────┐
│              Multi-Token Prediction Pipeline                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  TRAINING:                                                      │
│  ─────────                                                      │
│                                                                 │
│  Input ──▶ Transformer ──┬──▶ Head₁ ──▶ Loss (t+1)              │
│                          ├──▶ Head₂ ──▶ Loss (t+2)              │
│                          ├──▶ Head₃ ──▶ Loss (t+3)              │
│                          └──▶ Head₄ ──▶ Loss (t+4)              │
│                                                                 │
│  All heads trained jointly, shared trunk learns better!         │
│                                                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  INFERENCE (Speculative Decoding):                              │
│  ─────────────────────────────────                              │
│                                                                 │
│  ┌─────────┐     ┌──────────────┐     ┌────────────┐            │
│  │  Draft  │ ──▶ │    Verify    │ ──▶ │   Accept   │            │
│  │ k tokens│     │ (1 fwd pass) │     │ or Reject  │            │
│  └─────────┘     └──────────────┘     └────────────┘            │
│                                                                 │
│  Accepted tokens: Skip generation, move forward fast!           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

---

## Key Takeaways

| Aspect | Traditional | With MTP |
|--------|-------------|----------|
| **Tokens per forward pass** | 1 | k (multiple) |
| **Inference speed** | Baseline | Faster (DeepSeek-V3 reports ~1.8x tokens/s) |
| **Training signal** | Next token only | Multiple future tokens |
| **Representation quality** | Good | Better (plans ahead) |
| **Memory bandwidth** | One weight/KV read per token | One read per verified chunk |
| **Complexity** | Simple | Slightly more complex |

---

## Why This Works: The Intuition

```
Think of it like writing:
═════════════════════════

Slow writer (traditional):
  "The" ──▶ think ──▶ "cat" ──▶ think ──▶ "sat" ──▶ think ──▶ ...

Fast writer (MTP):
  "The" ──▶ think ──▶ "cat sat on" ──▶ verify ──▶ continue...

The fast writer:
  1. Has a plan in mind (multiple tokens)
  2. Writes in chunks
  3. Only pauses to verify occasionally

Same idea for LLMs:
  - MTP heads draft multiple tokens quickly
  - Main model verifies in one pass
  - Accept correct predictions, retry wrong ones
```

---

## Summary

Multi-Token Prediction (MTP) as implemented by DeepSeek:

1. **Trains multiple prediction heads** to forecast future tokens
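2. **Can reuse the MTP modules for speculative decoding** at inference (or discard them and run the main model alone)
3. **Speeds up decoding** (DeepSeek-V3 reports ~1.8x tokens per second) when acceptance rates are high
4. **Improves model quality** by learning longer-range dependencies
5. **Reduces memory traffic per token** by verifying several tokens in one forward pass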

This is a key technique that makes DeepSeek-V3 both **fast** and **capable**!

## Code Implementation 

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

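```python
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable per-feature scale."""

    def __init__(self, d_model: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # eps goes inside the square root (it was previously added to keepdim)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
```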

```python
class SimpleMTP(nn.Module):
    """Minimal DeepSeek-style MTP: num_heads sequential depths with a shared
    embedding / output head and one projection + transformer layer per depth."""

    def __init__(self, d_model: int, vocab_size: int, num_heads: int = 3, n_head: int = 2):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads  # number of MTP depths D
        self.n_head = n_head        # attention heads inside each transformer layer

        # shared modules
        self.rms_norm = RMSNorm(d_model=d_model)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)
        # tie the output head to the embedding matrix
        self.unembed.weight = self.embed.weight

        # one projection + transformer layer per depth
        self.projections = nn.ModuleList(
            [nn.Linear(2 * d_model, d_model) for _ in range(num_heads)])
        # batch_first=True so (B, 1, d_model) means batch B, sequence length 1;
        # the default (False) would make attention mix different batch items
        self.transformers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, batch_first=True)
             for _ in range(num_heads)])

    def forward(self, token_ids: torch.LongTensor, init_hidden: torch.Tensor = None):
        """
        Args:
            token_ids: (batch, seq_len) integer IDs of the input tokens.
            init_hidden: optional (batch, seq_len, d_model) base hidden states
                (e.g. the main model's output); if None, token embeddings are used.

        Returns:
            logits: (batch, T - D - 1, D, vocab_size), where T = seq_len and
                D = num_heads; logits[:, i, k] predicts token_ids[:, i + k + 2].
                Requires T >= D + 2.
        """
        T = token_ids.size(1)
        embeds = self.embed(token_ids)  # (B, T, d_model)
        h0_seq = embeds if init_hidden is None else init_hidden

        outputs = []  # one (B, num_heads, vocab_size) tensor per position i
        # keep positions whose deepest target (i + num_heads + 1) is inside the sequence
        max_i = T - self.num_heads - 2

        for i in range(max_i + 1):
            h_prev = h0_seq[:, i, :]  # depth-0 hidden state at position i, (B, d_model)
            logits_k = []
            for k in range(self.num_heads):
                # embedding of the ground-truth token at i + k + 1 (teacher forcing)
                tok_embed = embeds[:, i + k + 1, :]  # (B, d_model)
                # normalize both inputs, concatenate, project back to d_model
                x = torch.cat([self.rms_norm(h_prev), self.rms_norm(tok_embed)], dim=-1)
                x = self.projections[k](x)  # (B, d_model)
                # transformer layer on a length-1 sequence (attention is trivial here)
                x = self.transformers[k](x.unsqueeze(1)).squeeze(1)  # (B, d_model)
                logits_k.append(self.unembed(x))  # (B, vocab_size)
                h_prev = x  # chain: the next depth starts from this depth's hidden state
            outputs.append(torch.stack(logits_k, dim=1))  # (B, num_heads, vocab_size)

        # (B, T - num_heads - 1, num_heads, vocab_size)
        return torch.stack(outputs, dim=1)
```
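`SimpleMTP` applies its transformer layers one position at a time, so self-attention never mixes information across positions. The sketch below is a hypothetical vectorized variant; `CausalMTP` and `mtp_loss` are our names, not DeepSeek's code. It runs each depth over the whole sequence with a causal mask, which is closer to how the DeepSeek-V3 report describes the MTP modules. It also computes the λ/D-averaged MTP loss. It assumes PyTorch 2.x, and LayerNorm stands in for RMSNorm so it runs on older 2.x releases.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalMTP(nn.Module):
    """Hypothetical vectorized MTP: D sequential depths, each a causal
    transformer layer over the whole sequence (a sketch, not DeepSeek's code)."""

    def __init__(self, d_model, vocab_size, depth=2, n_head=2):
        super().__init__()
        self.depth = depth
        self.embed = nn.Embedding(vocab_size, d_model)           # shared embedding
        self.out_head = nn.Linear(d_model, vocab_size, bias=False)
        self.out_head.weight = self.embed.weight                  # tied output head
        # LayerNorm stands in for RMSNorm here (nn.RMSNorm needs PyTorch >= 2.4)
        self.norm_h = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(depth))
        self.norm_e = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(depth))
        self.proj = nn.ModuleList(nn.Linear(2 * d_model, d_model) for _ in range(depth))
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_head, dim_feedforward=4 * d_model,
                                       dropout=0.0, batch_first=True)
            for _ in range(depth))

    def forward(self, tokens, h0):
        """tokens: (B, T) ids; h0: (B, T, d_model) main-model hidden states.
        Returns D logit tensors; entry k-1 has shape (B, T-k, V) and its
        position i predicts tokens[:, i + k + 1]."""
        T = tokens.size(1)
        h = h0
        all_logits = []
        for k in range(1, self.depth + 1):
            L = T - k                                    # positions i with a token at i + k
            e = self.embed(tokens[:, k:])                # Emb(t_{i+k}), (B, L, d)
            x = torch.cat([self.norm_h[k - 1](h[:, :L]), self.norm_e[k - 1](e)], dim=-1)
            x = self.proj[k - 1](x)
            # additive causal mask: -inf above the diagonal blocks attention to the future
            mask = torch.triu(torch.full((L, L), float("-inf"), device=x.device), diagonal=1)
            h = self.blocks[k - 1](x, src_mask=mask)     # h^(k), (B, L, d)
            all_logits.append(self.out_head(h))
        return all_logits


def mtp_loss(all_logits, tokens, lam=0.3):
    """(lam / D) * sum over depths of cross-entropy against t_{i+k+1}."""
    losses = []
    for k, logits in enumerate(all_logits, start=1):
        target = tokens[:, k + 1:]                       # (B, T-k-1)
        pred = logits[:, :target.size(1)]                # last position has no target
        losses.append(F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1)))
    return lam / len(losses) * sum(losses)
```

For example, `CausalMTP(d_model=16, vocab_size=30, depth=2)` applied to tokens of shape (2, 10) returns logits of shapes (2, 9, 30) and (2, 8, 30). The sequence must be longer than `depth + 1` so that every depth has at least one target.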