# Module 9.6: Speculative Decoding Deep Dive

**Goal**: Implement and optimize speculative decoding

**Time**: 90 minutes

**Concepts Covered**:
- Draft + target model setup
- Parallel verification
- Acceptance/rejection logic
- Speedup measurement
- Optimal k selection

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
import torch
import torch.nn.functional as F

class SpeculativeDecoder:
    """Speculative decoding: a fast draft model proposes tokens, a target model verifies them.

    Both models are called as ``model(token_ids)`` and must return an object
    exposing a ``.logits`` tensor indexed as ``[batch, position, vocab]``.
    The logits at position ``j`` are treated as the prediction for the token
    at position ``j + 1`` (causal language-model convention).
    """

    def __init__(self, draft_model, target_model):
        self.draft_model = draft_model  # Small, fast model
        self.target_model = target_model  # Large, accurate model

    @torch.no_grad()
    def generate_draft(self, input_ids: torch.Tensor, k: int = 5, *,
                       do_sample: bool = False) -> torch.Tensor:
        """Generate k tokens with the draft model.

        Args:
            input_ids: Prompt token ids, shape ``(batch, seq_len)``.
            k: Number of draft tokens to propose.
            do_sample: If False (default) pick the arg-max token at each step.
                If True, sample from the draft model's softmax distribution,
                which is what the acceptance rule in ``verify_draft`` assumes.

        Returns:
            Tensor of shape ``(batch, k)`` with the proposed token ids
            (shape ``(batch, 0)`` when ``k <= 0``).
        """
        if k <= 0:
            # torch.stack([]) would raise; return an empty draft instead.
            return input_ids.new_empty((input_ids.size(0), 0))

        draft_tokens = []
        current_ids = input_ids

        for _ in range(k):
            logits = self.draft_model(current_ids).logits[:, -1, :]
            if do_sample:
                next_token = torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze(1)
            else:
                next_token = torch.argmax(logits, dim=-1)
            draft_tokens.append(next_token)
            current_ids = torch.cat([current_ids, next_token.unsqueeze(1)], dim=1)

        return torch.stack(draft_tokens, dim=1)

    @torch.no_grad()
    def verify_draft(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor):
        """Verify draft tokens with the target model.

        Each draft token ``x`` is accepted with probability
        ``min(1, p(x) / q(x))`` where ``p`` is the target distribution and
        ``q`` the draft distribution at that position. At the first rejection
        a replacement token is sampled from the normalized residual
        ``max(0, p - q)`` and verification stops.

        NOTE: this rule reproduces the target distribution only when the
        draft tokens were sampled from ``q``; with greedy drafts it is an
        approximation.

        Args:
            input_ids: Prompt token ids, shape ``(1, seq_len)``, ``seq_len >= 1``.
            draft_tokens: Proposed token ids, shape ``(1, k)``.

        Returns:
            Tensor of shape ``(1, n)`` with ``1 <= n <= k`` holding the
            accepted tokens (the last one is the resampled token if a
            rejection occurred), or ``None`` if ``draft_tokens`` is empty.

        Raises:
            ValueError: If the batch size is not 1 or the prompt is empty.
        """
        num_draft = draft_tokens.size(1)
        if num_draft == 0:
            return None
        if input_ids.size(0) != 1:
            # Rows accept a different number of tokens, so a single
            # accept/reject decision per step only makes sense for one row.
            raise ValueError("verify_draft supports batch size 1 only")
        prompt_len = input_ids.size(1)
        if prompt_len == 0:
            raise ValueError("verify_draft requires a non-empty prompt")

        # One forward pass per model over prompt + draft. Position j predicts
        # token j + 1, so the distributions for the k draft tokens live at
        # positions prompt_len - 1 ... prompt_len + k - 2.
        full_sequence = torch.cat([input_ids, draft_tokens], dim=1)
        window = slice(prompt_len - 1, prompt_len - 1 + num_draft)
        target_dists = F.softmax(self.target_model(full_sequence).logits[:, window, :], dim=-1)
        draft_dists = F.softmax(self.draft_model(full_sequence).logits[:, window, :], dim=-1)

        accepted_tokens = []
        for i in range(num_draft):
            draft_token = draft_tokens[:, i]
            target_probs = target_dists[:, i, :]
            draft_probs = draft_dists[:, i, :]

            # Probability each model assigns to the proposed token.
            token_index = draft_token.unsqueeze(1)
            target_token_prob = target_probs.gather(1, token_index)
            draft_token_prob = draft_probs.gather(1, token_index)

            # Acceptance probability min(1, p(x) / q(x)); epsilon guards q(x) == 0.
            accept_prob = torch.clamp(target_token_prob / (draft_token_prob + 1e-10), max=1.0)

            # .item() keeps the comparison on the host regardless of device.
            if torch.rand(1).item() < accept_prob.item():
                accepted_tokens.append(draft_token)
                continue

            # Rejected: resample from the residual distribution max(0, p - q).
            residual = torch.clamp(target_probs - draft_probs, min=0)
            if residual.sum().item() <= 0:
                # p and q coincide numerically; fall back to the target distribution.
                residual = target_probs
            residual = residual / residual.sum(dim=-1, keepdim=True)
            new_token = torch.multinomial(residual, 1)
            accepted_tokens.append(new_token.squeeze(1))
            break

        return torch.stack(accepted_tokens, dim=1)

# Recap of the technique, emitted as one newline-joined message.
print("\n".join([
    "Speculative Decoding:",
    "- Draft model: Fast, generates k tokens",
    "- Target model: Accurate, verifies draft",
    "- Speedup: 2-3x for compatible models",
    "- Optimal k: Usually 3-5 tokens",
]))

## Key Takeaways

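- The draft model proposes `k` tokens cheaply; the target model scores all of them in a single forward pass.
- A draft token `x` is accepted with probability `min(1, p_target(x) / p_draft(x))`.
- At the first rejection, a replacement token is sampled from the normalized residual `max(0, p_target - p_draft)` and the remaining draft tokens are discarded.

✅ **Module Complete**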

## Next Steps

Continue to the next module in the course.