# Speculative Decoding: Fast LLM Inference

This notebook explores speculative decoding, a technique for accelerating autoregressive LLM inference.

We'll cover:
1. **Problem**: Why is LLM generation slow?
2. **Solution**: Speculative decoding algorithm
3. **Implementation**: Draft model + target model
4. **Performance analysis**: Speedup calculations
5. **Variants**: Self-speculative, multi-token prediction

## The Problem: Memory-Bound Generation

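Autoregressive generation is slow because:
- Each new token requires a full forward pass through the model
- At small batch sizes, each pass is limited by memory bandwidth (all weights must be loaded) rather than by compute
- The GPU is underutilized (low arithmetic intensity)
- Tokens cannot be generated in parallel, since each one depends on the previous ones

**Key insight**: Let a cheap draft model propose several tokens, then verify them all in a single target-model pass and keep the ones that pass verification.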
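
To make the memory-bound claim concrete, here is a back-of-the-envelope sketch. The model size, precision, and bandwidth are hypothetical round numbers rather than measurements, and `min_decode_latency_ms` is an illustrative helper, not part of the notebook. At batch size 1, every decode step has to stream all weights from memory, which puts a floor on per-token latency regardless of available compute.

```python
def min_decode_latency_ms(num_params: float, bytes_per_param: float,
                          bandwidth_gb_per_s: float) -> float:
    """Lower bound on one decode step: time to read every weight once."""
    weight_bytes = num_params * bytes_per_param
    seconds = weight_bytes / (bandwidth_gb_per_s * 1e9)
    return seconds * 1e3

# Hypothetical setup: 7e9 parameters, 2 bytes each (fp16), 1000 GB/s bandwidth
latency = min_decode_latency_ms(7e9, 2, 1000)
print(f"{latency:.1f} ms per token, at most {1000 / latency:.0f} tokens/s")
```

Verifying several draft tokens in one pass reuses the same weight read, which is why speculation can help when this bound dominates.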

## 1. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
from typing import List, Tuple, Optional

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Standard Autoregressive Decoding (Baseline)

Traditional approach: Generate one token at a time
- Run model forward pass
- Sample next token
- Append to sequence
- Repeat

**Complexity**: O(n) forward passes for n tokens
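
The number of forward passes is linear in n, but the cost of each pass depends on whether past keys and values are cached. The baseline in this notebook re-encodes the whole sequence on every step for simplicity. The helper below (a hypothetical name, used only for illustration) counts how many token positions are pushed through the model in each case.

```python
def positions_processed(prompt_len: int, new_tokens: int, kv_cache: bool) -> int:
    """Count token positions processed by the model while decoding."""
    total = 0
    for step in range(new_tokens):
        if kv_cache and step > 0:
            total += 1                   # only the newest token is processed
        else:
            total += prompt_len + step   # the whole sequence is re-encoded
    return total

print(positions_processed(5, 10, kv_cache=False))  # 95
print(positions_processed(5, 10, kv_cache=True))   # 14
```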

In [None]:
class SimpleLanguageModel(nn.Module):
    """
    Simple language model for demonstration.
    """
    def __init__(self, vocab_size, d_model, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=d_model*4, batch_first=True),
            num_layers=num_layers
        )
        self.output = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
    
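    def forward(self, x):
        """
        Args:
            x: Token indices (batch_size, seq_len)
        Returns:
            Logits (batch_size, seq_len, vocab_size)
        """
        seq_len = x.shape[1]
        # Causal mask: position i may only attend to positions <= i.
        # Without it, verification in speculative decoding would let the
        # target model see the draft tokens it is supposed to score.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device),
            diagonal=1
        )
        x = self.embedding(x) * np.sqrt(self.d_model)
        x = self.transformer(x, mask=causal_mask)
        return self.output(x)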

def standard_decoding(model, prompt_tokens, max_new_tokens, temperature=1.0):
    """
    Standard autoregressive decoding (one token at a time).
    
    Args:
        model: Language model
        prompt_tokens: Initial tokens (batch_size, prompt_len)
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
    
    Returns:
        Generated tokens, number of forward passes
    """
    model.eval()
    generated = prompt_tokens.clone()
    num_forward_passes = 0
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Forward pass through entire sequence
            logits = model(generated)
            num_forward_passes += 1
            
            # Sample next token
            next_token_logits = logits[:, -1, :] / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)
    
    return generated, num_forward_passes

# Example usage
vocab_size = 1000
d_model = 256
model = SimpleLanguageModel(vocab_size, d_model, num_layers=4)

prompt = torch.randint(0, vocab_size, (1, 5))  # Batch=1, prompt_len=5
generated, num_passes = standard_decoding(model, prompt, max_new_tokens=10)

print(f"Standard Decoding:")
print(f"  Prompt length: {prompt.shape[1]}")
print(f"  Generated length: {generated.shape[1]}")
print(f"  New tokens: {generated.shape[1] - prompt.shape[1]}")
print(f"  Forward passes: {num_passes}")
print(f"  Tokens per pass: {(generated.shape[1] - prompt.shape[1]) / num_passes:.2f}")

## 3. Speculative Decoding Algorithm

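**Key Idea:**
1. Use a small **draft model** to quickly generate K candidate tokens
2. Score all K candidates in parallel with the **target model** (one forward pass)
3. Accept each candidate with probability min(1, p_target / p_draft), so a candidate the target finds at least as likely as the draft did is always kept
4. At the first rejection, discard the remaining candidates and resample that position from the residual distribution, proportional to max(0, p_target − p_draft)

**Benefits:**
- Multiple tokens can be produced per target-model forward pass
- The output distribution is identical to sampling from the target model alone
- Reported speedups are typically around 2-3×; higher figures are possible when acceptance is very high and the draft is cheap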

**Algorithm:**
```
1. Generate K draft tokens: x₁, x₂, ..., xₖ (fast, draft model)
2. Get target model probabilities: p_target(x₁), p_target(x₂|x₁), ...
3. Get draft model probabilities: p_draft(x₁), p_draft(x₂|x₁), ...
4. For each position i:
   - Accept xᵢ with probability min(1, p_target(xᵢ) / p_draft(xᵢ))
   - If rejected, resample from norm(max(0, p_target − p_draft)) at that position and stop
5. If all accepted, sample one bonus token from target
```
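
The accept/reject rule can be checked exactly on a toy vocabulary. The sketch below computes, for a single position, the distribution of the token that is finally emitted: either the accepted draft token, or a token resampled from the residual norm(max(0, p_target − p_draft)). The two distributions are made-up values, and `speculative_output_distribution` is an illustrative helper, not part of the notebook.

```python
def speculative_output_distribution(p_target, p_draft):
    """Exact distribution of the token emitted by one accept/reject step."""
    # Probability that token x is proposed by the draft AND accepted
    accepted = [q * min(1.0, p / q) if q > 0 else 0.0
                for p, q in zip(p_target, p_draft)]
    reject_mass = 1.0 - sum(accepted)
    # Residual distribution used after a rejection: norm(max(0, p - q))
    residual = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    z = sum(residual)
    if z == 0.0:  # identical distributions: a rejection never happens
        return accepted
    return [a + reject_mass * r / z for a, r in zip(accepted, residual)]

p = [0.6, 0.3, 0.1]   # toy target distribution (assumed values)
q = [0.2, 0.5, 0.3]   # toy draft distribution (assumed values)
print(speculative_output_distribution(p, q))
```

Up to floating-point rounding, the printed distribution equals `p`: the draft only changes how quickly tokens are produced, not which tokens are likely.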

In [None]:
def speculative_decoding(target_model, draft_model, prompt_tokens, 
                         max_new_tokens, k=4, temperature=1.0):
    """
    Speculative decoding with draft model.
    
    Args:
        target_model: Large, high-quality model
        draft_model: Small, fast model
        prompt_tokens: Initial tokens
        max_new_tokens: Number of tokens to generate
        k: Number of speculative tokens to generate
        temperature: Sampling temperature
    
    Returns:
        Generated tokens, stats dict
    """
    target_model.eval()
    draft_model.eval()
    
    generated = prompt_tokens.clone()
    num_target_forward = 0
    num_draft_forward = 0
    num_accepted = []
    
    with torch.no_grad():
        while (generated.shape[1] - prompt_tokens.shape[1]) < max_new_tokens:
            # Step 1: Draft model generates K speculative tokens
            draft_tokens = generated.clone()
            draft_probs_list = []
            
            for _ in range(k):
                logits = draft_model(draft_tokens)
                num_draft_forward += 1
                
                next_token_logits = logits[:, -1, :] / temperature
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                
                draft_probs_list.append(probs)
                draft_tokens = torch.cat([draft_tokens, next_token], dim=1)
            
            # Step 2: Target model verifies all K tokens in parallel
            target_logits = target_model(draft_tokens)
            num_target_forward += 1
            
            # Step 3: Verification - check which tokens to accept
            # Freeze the prefix length: `generated` grows inside the loop,
            # so indexing with generated.shape[1] would skip positions.
            prefix_len = generated.shape[1]
            accepted = 0
            for i in range(k):
                draft_token = draft_tokens[:, prefix_len + i]
                
                # Target distribution for this position (logits at the previous position)
                target_logits_at_pos = target_logits[:, prefix_len + i - 1, :]
                target_probs = F.softmax(target_logits_at_pos / temperature, dim=-1)
                
                # Draft distribution the token was sampled from
                draft_probs = draft_probs_list[i]
                
                # Acceptance criterion: min(1, p_target / p_draft)
                # (batch size 1 is assumed throughout)
                target_prob = target_probs[0, draft_token.item()]
                draft_prob = draft_probs[0, draft_token.item()]
                acceptance_prob = min(1.0, (target_prob / (draft_prob + 1e-10)).item())
                
                # Accept or reject
                if torch.rand(1).item() < acceptance_prob:
                    generated = torch.cat([generated, draft_token.unsqueeze(0)], dim=1)
                    accepted += 1
                else:
                    # Rejection: resample from the residual norm(max(0, p_target - p_draft))
                    adjusted_probs = torch.clamp(target_probs - draft_probs, min=0)
                    adjusted_probs = adjusted_probs / (adjusted_probs.sum() + 1e-10)
                    resampled_token = torch.multinomial(adjusted_probs, num_samples=1)
                    generated = torch.cat([generated, resampled_token], dim=1)
                    break
            
            # If all tokens accepted, sample bonus token from target
            if accepted == k:
                bonus_logits = target_logits[:, -1, :] / temperature
                bonus_probs = F.softmax(bonus_logits, dim=-1)
                bonus_token = torch.multinomial(bonus_probs, num_samples=1)
                generated = torch.cat([generated, bonus_token], dim=1)
                accepted += 1
            
            num_accepted.append(accepted)
            
            # Check if we've generated enough
            if (generated.shape[1] - prompt_tokens.shape[1]) >= max_new_tokens:
                break
    
    stats = {
        'num_target_forward': num_target_forward,
        'num_draft_forward': num_draft_forward,
        'num_accepted_per_round': num_accepted,
        'avg_accepted': np.mean(num_accepted) if num_accepted else 0,
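        # Count only the tokens that are actually returned (output is truncated to max_new_tokens)
        'tokens_per_target_forward': min(generated.shape[1] - prompt_tokens.shape[1], max_new_tokens) / num_target_forward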
    }
    
    return generated[:, :prompt_tokens.shape[1] + max_new_tokens], stats

# Create draft model (smaller, faster)
draft_model = SimpleLanguageModel(vocab_size, d_model=128, num_layers=2)

# Run speculative decoding
prompt = torch.randint(0, vocab_size, (1, 5))
generated_spec, stats = speculative_decoding(
    target_model=model,
    draft_model=draft_model,
    prompt_tokens=prompt,
    max_new_tokens=10,
    k=4
)

print(f"\nSpeculative Decoding:")
print(f"  Prompt length: {prompt.shape[1]}")
print(f"  Generated length: {generated_spec.shape[1]}")
print(f"  New tokens: {generated_spec.shape[1] - prompt.shape[1]}")
print(f"  Target forward passes: {stats['num_target_forward']}")
print(f"  Draft forward passes: {stats['num_draft_forward']}")
print(f"  Avg tokens accepted per round: {stats['avg_accepted']:.2f}")
print(f"  Tokens per target forward: {stats['tokens_per_target_forward']:.2f}")
print(f"  Target-pass reduction vs standard (draft cost ignored): {10 / stats['num_target_forward']:.2f}×")

## 4. Acceptance Rate Analysis

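The speedup depends on the **acceptance rate**: how often a token proposed by the draft model survives verification by the target model. The toy models in this notebook are randomly initialized, so their acceptance numbers illustrate the mechanics only.

**Factors affecting acceptance:**
- Draft model quality (a draft whose distribution is closer to the target's → higher acceptance; larger drafts usually help)
- Task difficulty (easier, more predictable text → higher acceptance)
- Temperature (it changes how much the two distributions overlap; the direction and size of the effect depend on the model pair, so measure it)
- K value (larger K → a smaller fraction of drafted tokens is kept, because everything after the first rejection is discarded)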
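
For a single position, the rule "accept with probability min(1, p_target / p_draft)" implies that the chance of accepting a draft token is the overlap between the two distributions, Σ min(p_target(x), p_draft(x)). The sketch below computes that overlap for made-up distributions; `expected_acceptance` is an illustrative helper, not part of the notebook.

```python
def expected_acceptance(p_target, p_draft):
    """Probability that one draft token is accepted: sum over x of min(p, q)."""
    return sum(min(p, q) for p, q in zip(p_target, p_draft))

print(expected_acceptance([0.6, 0.3, 0.1], [0.6, 0.3, 0.1]))  # identical: always accepted
print(expected_acceptance([0.6, 0.3, 0.1], [0.2, 0.5, 0.3]))  # partial overlap
print(expected_acceptance([1.0, 0.0], [0.0, 1.0]))            # disjoint: never accepted
```

Anything that pulls the draft distribution toward the target's raises this overlap, which is the sense in which draft quality and task difficulty matter.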

In [None]:
def analyze_acceptance_rate(target_model, draft_models, prompt, k_values=(2, 4, 8)):
    """
    Analyze how acceptance rate varies with different configurations.
    """
    results = []
    
    # Test different K values
    print("=== Acceptance Rate Analysis ===")
    print()
    
    for k in k_values:
        for model_name, draft_model in draft_models.items():
            _, stats = speculative_decoding(
                target_model=target_model,
                draft_model=draft_model,
                prompt_tokens=prompt,
                max_new_tokens=20,
                k=k
            )
            
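            # Target-pass reduction only; draft-model cost is not included
            speedup = 20 / stats['num_target_forward']
            # Per-round counts include the bonus token (k + 1 when every draft
            # is accepted), so cap at k before computing a draft acceptance rate
            drafts_accepted = np.mean([min(a, k) for a in stats['num_accepted_per_round']])
            results.append({
                'k': k,
                'model': model_name,
                'avg_accepted': stats['avg_accepted'],
                'speedup': speedup
            })
            
            print(f"K={k}, Model={model_name}:")
            print(f"  Avg accepted per round (incl. bonus): {stats['avg_accepted']:.2f}")
            print(f"  Draft tokens accepted: {drafts_accepted / k * 100:.1f}%")
            print(f"  Target-pass reduction: {speedup:.2f}×")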
            print()
    
    return results

# Create draft models of different sizes
draft_models = {
    'Tiny (d=64, L=1)': SimpleLanguageModel(vocab_size, d_model=64, num_layers=1),
    'Small (d=128, L=2)': SimpleLanguageModel(vocab_size, d_model=128, num_layers=2),
    'Medium (d=192, L=3)': SimpleLanguageModel(vocab_size, d_model=192, num_layers=3)
}

prompt = torch.randint(0, vocab_size, (1, 10))
results = analyze_acceptance_rate(model, draft_models, prompt, k_values=[2, 4, 6])

## 5. Performance Visualization

Visualize the trade-offs between K, acceptance rate, and speedup.

In [None]:
# Visualize acceptance rate and speedup
import pandas as pd

df = pd.DataFrame(results)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Average accepted tokens
ax = axes[0]
for model_name in df['model'].unique():
    model_data = df[df['model'] == model_name]
    ax.plot(model_data['k'], model_data['avg_accepted'], marker='o', label=model_name, linewidth=2)

# Add diagonal line showing k value
k_vals = df['k'].unique()
ax.plot(k_vals, k_vals, 'k--', alpha=0.3, label='Perfect acceptance (K tokens)')

ax.set_xlabel('K (number of speculative tokens)', fontsize=11)
ax.set_ylabel('Average tokens accepted per round', fontsize=11)
ax.set_title('Accepted Tokens per Round vs K', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Speedup
ax = axes[1]
for model_name in df['model'].unique():
    model_data = df[df['model'] == model_name]
    ax.plot(model_data['k'], model_data['speedup'], marker='o', label=model_name, linewidth=2)

ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='No speedup')
ax.set_xlabel('K (number of speculative tokens)', fontsize=11)
ax.set_ylabel('Speedup factor', fontsize=11)
ax.set_title('Speedup vs K', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nExpected trends for trained models (the random-weight toy models above may not show them):")
print("1. Better-matched draft models → higher acceptance rate")
print("2. Higher K → more speculation, but a smaller fraction of drafted tokens is kept")
print("3. Optimal K depends on draft model quality and draft cost")
print("4. Reported speedups are often around 2-3× for a well-matched draft model")

## 6. Speedup Calculation

Theoretical speedup analysis:

**Let:**
- α = acceptance rate (probability a draft token is accepted, assumed to be the same and independent at every position)
- K = number of speculative tokens
- T_draft = time for draft model forward pass
- T_target = time for target model forward pass

**Expected tokens per round:**
$$E[\text{tokens}] = \sum_{i=0}^{K-1} (i+1) \cdot \alpha^i (1-\alpha) + (K+1) \cdot \alpha^K$$

**Speedup (ignoring draft cost):**
$$\text{Speedup} \approx E[\text{tokens}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

**With draft cost:**
$$\text{Speedup} = \frac{E[\text{tokens}]}{1 + K \cdot (T_{\text{draft}} / T_{\text{target}})}$$
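
The closed form can be derived with a tail-sum argument, under the same simplifying assumption as the formulas above (each draft token is accepted independently with probability α). A round emits more than j tokens exactly when the first j draft tokens are all accepted, which happens with probability α^j:

$$E[\text{tokens}] = \sum_{j=0}^{K} P(\text{tokens} > j) = \sum_{j=0}^{K} \alpha^j = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

As a worked example with illustrative values α = 0.8, K = 4, and T_draft / T_target = 0.1:

$$E[\text{tokens}] = \frac{1 - 0.8^{5}}{1 - 0.8} = \frac{0.67232}{0.2} \approx 3.36, \qquad \text{Speedup} \approx \frac{3.36}{1 + 4 \cdot 0.1} \approx 2.40$$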

In [None]:
def theoretical_speedup(alpha, k, draft_cost_ratio=0.1):
    """
    Calculate theoretical speedup for speculative decoding.
    
    Args:
        alpha: Acceptance rate (0 to 1)
        k: Number of speculative tokens
        draft_cost_ratio: T_draft / T_target
    
    Returns:
        Expected speedup
    """
    # Expected tokens produced per round (accepted drafts plus one resampled or bonus token)
    if alpha == 1.0:
        expected_tokens = k + 1
    else:
        expected_tokens = (1 - alpha**(k+1)) / (1 - alpha)
    
    # Time cost per round (normalized to T_target = 1)
    time_per_round = 1 + k * draft_cost_ratio
    
    # Speedup = tokens per unit time
    speedup = expected_tokens / time_per_round
    
    return speedup, expected_tokens

# Visualize speedup surface
alphas = np.linspace(0.1, 0.95, 50)
ks = [2, 4, 6, 8, 10]

plt.figure(figsize=(12, 5))

# Plot 1: Speedup vs acceptance rate
plt.subplot(1, 2, 1)
for k in ks:
    speedups = [theoretical_speedup(alpha, k)[0] for alpha in alphas]
    plt.plot(alphas, speedups, label=f'K={k}', linewidth=2)

plt.xlabel('Acceptance Rate (α)', fontsize=11)
plt.ylabel('Speedup Factor', fontsize=11)
plt.title('Theoretical Speedup vs Acceptance Rate', fontsize=12, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axhline(y=1.0, color='r', linestyle='--', alpha=0.3)

# Plot 2: Expected tokens vs acceptance rate
plt.subplot(1, 2, 2)
for k in ks:
    expected = [theoretical_speedup(alpha, k)[1] for alpha in alphas]
    plt.plot(alphas, expected, label=f'K={k}', linewidth=2)

plt.xlabel('Acceptance Rate (α)', fontsize=11)
plt.ylabel('Expected Tokens Per Round', fontsize=11)
plt.title('Expected Tokens Accepted Per Round', fontsize=12, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print some example speedups
print("\n=== Example Speedup Calculations ===")
print()
for alpha in [0.5, 0.7, 0.9]:
    print(f"Acceptance rate: {alpha*100:.0f}%")
    for k in [4, 8]:
        speedup, expected = theoretical_speedup(alpha, k, draft_cost_ratio=0.1)
        print(f"  K={k}: {speedup:.2f}× speedup ({expected:.2f} tokens/round)")
    print()

## 7. Variants of Speculative Decoding

### 7.1 Self-Speculative Decoding
- Use **same model** for draft and target
- Draft uses fewer layers (early exit)
- No need for separate draft model

### 7.2 Multi-Token Prediction
- Train model to predict multiple future tokens
- Use predictions as draft candidates
- No separate draft model needed

### 7.3 Staged Speculative Decoding
- Multiple draft models of increasing size
- Filter candidates through stages
- Better quality-speed trade-off

In [None]:
class SelfSpeculativeModel(nn.Module):
    """
    Model with early exit for self-speculative decoding.
    """
    def __init__(self, vocab_size, d_model, num_layers=6, draft_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.num_layers = num_layers
        self.draft_layers = draft_layers
        
        # Transformer layers
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=d_model*4, batch_first=True)
            for _ in range(num_layers)
        ])
        
        # Output heads
        self.draft_output = nn.Linear(d_model, vocab_size)
        self.final_output = nn.Linear(d_model, vocab_size)
    
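    def forward(self, x, early_exit=False):
        """
        Args:
            x: Token indices (batch_size, seq_len)
            early_exit: If True, exit after draft_layers
        """
        seq_len = x.shape[1]
        # Causal mask so each position only attends to earlier positions
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device),
            diagonal=1
        )
        x = self.embedding(x) * np.sqrt(self.d_model)
        
        # Run through layers
        for i, layer in enumerate(self.layers):
            x = layer(x, src_mask=causal_mask)
            
            # Early exit for draft
            if early_exit and i == self.draft_layers - 1:
                return self.draft_output(x)
        
        return self.final_output(x)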

# Example
self_spec_model = SelfSpeculativeModel(vocab_size, d_model=256, num_layers=6, draft_layers=3)

x = torch.randint(0, vocab_size, (1, 10))
draft_output = self_spec_model(x, early_exit=True)
full_output = self_spec_model(x, early_exit=False)

print("Self-Speculative Model:")
print(f"  Total layers: {self_spec_model.num_layers}")
print(f"  Draft layers: {self_spec_model.draft_layers}")
print(f"  Draft output shape: {draft_output.shape}")
print(f"  Full output shape: {full_output.shape}")
print(f"\nDraft pass runs ~{self_spec_model.draft_layers / self_spec_model.num_layers * 100:.0f}% of the layers")

## 8. Production Considerations

### When to Use Speculative Decoding

**Best for:**
- Long-form generation (stories, articles, code)
- Latency-sensitive serving at small batch sizes
- Memory-bound inference
- Setups where a well-matched draft model is available

**Not ideal for:**
- Very short generations (< 20 tokens)
- Compute-bound scenarios (for example, large batches)
- Setups where the draft model is too slow or unavailable
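
Whether a draft model is "too slow" can be estimated with the simple cost model used in this notebook: speculation pays off only when the expected tokens per round exceed the round's cost, 1 + K · (T_draft / T_target). The sketch below finds the break-even acceptance rate by bisection. It assumes i.i.d. acceptance, and the function names are illustrative, not part of the notebook.

```python
def expected_tokens(alpha: float, k: int) -> float:
    """Expected tokens per round, assuming i.i.d. acceptance with rate alpha."""
    return sum(alpha ** j for j in range(k + 1))

def break_even_alpha(k: int, draft_cost_ratio: float, tol: float = 1e-6) -> float:
    """Smallest alpha at which the modeled speedup reaches 1x (bisection)."""
    cost = 1.0 + k * draft_cost_ratio
    if expected_tokens(1.0, k) < cost:
        return float("inf")  # even perfect acceptance cannot pay for the draft
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if expected_tokens(mid, k) >= cost:
            hi = mid
        else:
            lo = mid
    return hi

for ratio in (0.1, 0.5, 1.5):  # hypothetical T_draft / T_target values
    print(f"cost ratio {ratio}: break-even alpha = {break_even_alpha(4, ratio):.3f}")
```

With K = 4, a draft that costs a tenth of the target breaks even at a fairly low acceptance rate, while a draft slower than the target never does.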

### Practical Implementation Tips
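
One practical way to tune K is to estimate the acceptance rate from logged rounds and then pick the K that maximizes the modeled speedup. The sketch below assumes the log stores the number of accepted draft tokens per round (excluding any bonus token), that acceptance is i.i.d., and that the draft cost ratio is known. The log values and function names are hypothetical.

```python
def estimate_alpha(accepted_per_round, k):
    """Estimate per-token acceptance from logged accepted-draft counts.

    Each round evaluates its accepted drafts plus, if it stopped early,
    the one draft token that was rejected.
    """
    accepted = sum(accepted_per_round)
    evaluated = sum(a + (1 if a < k else 0) for a in accepted_per_round)
    return accepted / evaluated if evaluated else 0.0

def best_k(alpha, draft_cost_ratio, k_max=16):
    """K in [1, k_max] that maximizes expected tokens per unit of modeled time."""
    def speedup(k):
        tokens = sum(alpha ** j for j in range(k + 1))
        return tokens / (1.0 + k * draft_cost_ratio)
    return max(range(1, k_max + 1), key=speedup)

rounds = [4, 2, 0, 4, 3, 1, 4, 2]   # hypothetical log, collected with K=4
alpha_hat = estimate_alpha(rounds, k=4)
print(f"alpha ~ {alpha_hat:.2f}, suggested K = {best_k(alpha_hat, 0.1)}")
```

The suggestion is only as good as the cost model, so it is worth confirming with a latency measurement at the chosen K.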

In [None]:
print("=== Speculative Decoding: Production Guide ===")
print()

print("1. Draft Model Selection:")
print("   - Size: a small fraction of the target (roughly 5-20% of its parameters is a common starting point)")
print("   - Illustrative size pairings:")
print("     • Target: LLaMA 70B → Draft: LLaMA 7B (10%)")
print("     • Target: GPT-3 175B → Draft: GPT-3 6.7B (~4%)")
print("   - Same model family helps (a shared tokenizer is required for token-level verification)")
print()

print("2. Optimal K Selection:")
print("   - Typical range: K = 4 to 8")
print("   - Higher K if:")
print("     • Draft model very good")
print("     • Draft model very fast")
print("     • Easy task (high acceptance)")
print("   - Lower K if:")
print("     • Draft model slower")
print("     • Low acceptance rate")
print("     • Need low latency")
print()

print("3. Memory Considerations:")
print("   - Need both models in memory")
print("   - Each model keeps its own KV cache; roll back entries for rejected tokens")
print("   - Consider model quantization (draft in INT8)")
print()

print("4. Performance Tuning:")
print("   - Measure acceptance rate in production")
print("   - Adjust K based on observed acceptance")
print("   - A/B test different draft models")
print("   - Monitor latency vs. throughput")
print()

print("5. Rough Speedup Ranges (illustrative, not measured here; benchmark your own workload):")
print("   - Code generation: 2.5-3.5×")
print("   - Story writing: 2.0-3.0×")
print("   - Chat responses: 1.5-2.5×")
print("   - Technical Q&A: 1.8-2.8×")
print()

print("6. Implementation Checklist:")
print("   ✓ Verify output distribution matches standard decoding")
print("   ✓ Handle edge cases (end of sequence, special tokens)")
print("   ✓ Implement efficient batch verification")
print("   ✓ Monitor GPU utilization")
print("   ✓ Add fallback to standard decoding if needed")
print("   ✓ Log acceptance rates and speedups")

## 9. Comparison with Other Speedup Techniques

In [None]:
import pandas as pd

comparison_data = {
    'Technique': [
        'Speculative Decoding',
        'Continuous Batching',
        'FlashAttention',
        'Quantization (INT8)',
        'KV Cache Optimization',
    ],
    'Speedup': [
        '2-3×',
        '5-10× (throughput)',
        '1.5-2×',
        '1.5-2×',
        '2-3×'
    ],
    'Memory Impact': [
        '+Draft model',
        'Higher utilization',
        '-20-30%',
        '-50-75%',
        'Enables longer context'
    ],
    'Quality Impact': [
        'None (same dist)',
        'None',
        'None',
        'Minimal',
        'None'
    ],
    'Best Use Case': [
        'Long generation',
        'High throughput',
        'All inference',
        'Deployment',
        'Long context'
    ],
    'Complexity': [
        'Medium',
        'High',
        'Medium',
        'Low',
        'Medium'
    ]
}

df = pd.DataFrame(comparison_data)
print("\n=== LLM Inference Optimization Techniques (rough, workload-dependent figures) ===")
print()
print(df.to_string(index=False))
print()
print("Note: These techniques are complementary and can be combined.")
print("Gains do not simply multiply (some improve latency, others throughput), so measure the combined effect.")

## 10. Summary

### Key Concepts

1. **The Problem**
   - LLM generation is memory-bound
   - One token at a time is slow
   - GPU underutilized

2. **The Solution**
   - Generate K draft tokens (fast)
   - Verify all in parallel (one target pass)
   - Accept correct predictions
   - Mathematically equivalent to standard decoding

3. **Key Parameters**
   - **K**: Number of speculative tokens (typically 4-8)
   - **Draft model**: 5-20% size of target model
   - **Acceptance rate**: Determines speedup (aim for >60%)

4. **Expected Speedup**
   - **Typical**: around 2-3× in reported results
   - **Best case**: 4-5× (very high acceptance, cheap draft)
   - **Worst case**: below 1× (a poorly matched draft wastes its own compute)

5. **Production Considerations**
   - Need both models in memory
   - Monitor acceptance rates
   - Tune K based on task
   - Combine with other optimizations

6. **When to Use**
   - ✅ Long-form generation
   - ✅ Memory-bound inference
   - ✅ When draft model available
   - ❌ Very short outputs
   - ❌ Compute-bound serving (for example, large batches)

### Real-World Applications
- **Code generation**: completion tools in the style of GitHub Copilot or Code Llama are natural candidates
- **Chat systems**: Faster response generation
- **Content creation**: Articles, stories, documentation
- **Translation**: Long document translation

### Further Reading
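- Original papers: "Fast Inference from Transformers via Speculative Decoding" (Leviathan et al., 2023) and "Accelerating Large Language Model Decoding with Speculative Sampling" (Chen et al., 2023)
- Big-little decoding: "Speculative Decoding with Big Little Decoder" (Kim et al., 2023)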
- Multi-token prediction: "Better & Faster Large Language Models via Multi-token Prediction" (Gloeckle et al., 2024)