# Predictive Prefill with Speculative Decoding

## Problem Statement

**Title**: Implement Predictive Prefill with Speculative Decoding for LLMs

**Description**: You are tasked with implementing Predictive Prefill with Speculative Decoding, a technique to accelerate inference in LLMs by predicting likely token sequences (prefill) and verifying them in parallel (speculative decoding). The system should use a small “draft” model to generate candidate sequences and a larger “target” model to verify them, reducing inference time. Implement a simplified version using PyTorch, with a draft model (small RNN) and target model (larger RNN), on a synthetic sequence dataset. The system should generate tokens, speculatively predict sequences, and verify them, measuring speedup.
Mathematical Definition:

Draft Model: Predicts the next token given a context:
$$p_d(y_t | x_{1:t-1}; \theta_d)$$
where $ y_t $ is the next token, $ x_{1:t-1} $ is the context, $ \theta_d $ are draft model parameters.
Speculative Decoding:

Generate $ k $ candidate tokens $ [y_t, y_{t+1}, \ldots, y_{t+k-1}] $ from the draft model.
Compute probabilities in parallel using the target model:
$$p_t(y_i | x_{1:i-1}; \theta_t), \quad i = t, \ldots, t+k-1$$

Accept tokens where draft and target probabilities align within a threshold:
$$|p_d(y_i) - p_t(y_i)| < \epsilon$$

Roll back to the last accepted token if verification fails.


Loss:

Train both models with cross-entropy loss:
$$L = -\sum_{t} \log p(y_t | x_{1:t-1}; \theta)$$




Requirements:

Implement a SpeculativeDecoder class with methods for:

train: Train draft and target models on a sequence prediction task.
decode: Perform speculative decoding, generating and verifying sequences.


Use a synthetic dataset of integer sequences (e.g., arithmetic progressions).
Draft model: Small RNN (1 layer, 32 hidden units).
Target model: Larger RNN (2 layers, 64 hidden units).
Speculative window: $ k = 3 $ tokens.
Train for 100 epochs with cross-entropy loss and Adam optimizer (learning rate 0.001).
Measure inference speedup compared to standard decoding.
Provide detailed Purpose and Theory comments.

Constraints:

Use PyTorch for model implementation.
No external libraries for decoding logic.
Sequence length: 10 tokens, vocabulary size: 50 tokens.
Handle batch processing for training.
Use $ \epsilon = 0.1 $ for verification threshold.

Synthetic Dataset:

Sequences: 100 sequences of 10 integers (e.g., [1, 2, 3, …, 10], [5, 10, 15, …, 50]).
Vocabulary: Integers 0 to 49.
Test Sequences: 3 sequences to test decoding.
Task: Predict the next token in each sequence.

Expected Output:

Training Loss: Decreases from ~3.0 to ~0.5 for both models.
Decoding: Generates sequences with speculative decoding, accepting/rejecting tokens.
Speedup: Speculative decoding is ~1.5–2x faster than standard decoding (measured in seconds).


In [1]:
import torch
# Purpose: Import PyTorch for tensor operations and model implementation.
# Theory: PyTorch’s autograd supports training RNNs for sequence prediction.

import torch.nn as nn
# Purpose: Import neural network modules for RNN models.
# Theory: nn.Module enables custom draft and target models.

import torch.optim as optim
# Purpose: Import Adam optimizer for training.
# Theory: Adam optimizes RNN parameters with adaptive learning rates.

import time
# Purpose: Import time for measuring inference speedup.
# Theory: Tracks execution time to compare standard vs. speculative decoding.

# Set random seed for reproducibility
torch.manual_seed(42)
# Purpose: Fix random seed for consistent data and model initialization.
# Theory: Aligns with previous problems for reproducibility.

# Synthetic sequence dataset
vocab_size, seq_len, num_sequences = 50, 10, 100
# Purpose: Define dataset parameters: vocabulary size (50), sequence length (10), number of sequences (100).
# Theory: Simulates token sequences for LLM training (e.g., tokenized text).

sequences = torch.randint(0, vocab_size, (num_sequences, seq_len))
# Purpose: Generate random integer sequences, shape [100, 10].
# Theory: Mimics tokenized sequences, with integers as token IDs.

# Define RNN model
class RNNModel(nn.Module):
    # Purpose: Define RNN model for draft or target prediction.
    # Theory: Processes sequences to predict next tokens.
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        # Purpose: Initialize RNN with specified parameters.
        # Theory: Configures embedding, RNN, and output layers.
        
        super(RNNModel, self).__init__()
        # Purpose: Call parent nn.Module constructor.
        # Theory: Registers parameters for autograd.
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Purpose: Map token IDs to embeddings.
        # Theory: Converts integers to dense vectors, shape [batch_size, seq_len, embed_dim].
        
        self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True)
        # Purpose: Define RNN layers.
        # Theory: Processes sequences, outputting hidden states, shape [batch_size, seq_len, hidden_dim].
        
        self.fc = nn.Linear(hidden_dim, vocab_size)
        # Purpose: Define output layer for token prediction.
        # Theory: Maps hidden states to token probabilities, shape [batch_size, seq_len, vocab_size].
    
    def forward(self, x, hidden=None):
        # Purpose: Compute token probabilities and hidden states.
        # Theory: Processes input sequence through embedding, RNN, and output layers.
        
        embedded = self.embedding(x)
        # Purpose: Convert token IDs to embeddings.
        # Theory: Shape [batch_size, seq_len, embed_dim].
        
        output, hidden = self.rnn(embedded, hidden)
        # Purpose: Process embeddings through RNN.
        # Theory: Outputs sequence representations and final hidden state.
        
        logits = self.fc(output)
        # Purpose: Map RNN outputs to token logits.
        # Theory: Shape [batch_size, seq_len, vocab_size] for probability computation.
        
        return logits, hidden
        # Purpose: Return logits and hidden state.
        # Theory: Logits for loss computation, hidden for next step.

# Define SpeculativeDecoder
class SpeculativeDecoder:
    # Purpose: Implement speculative decoding with draft and target models.
    # Theory: Accelerates inference by predicting and verifying token sequences.
    
    def __init__(self, vocab_size, embed_dim=32, draft_hidden=32, target_hidden=64):
        # Purpose: Initialize draft and target models.
        # Theory: Draft model is smaller (1 layer, 32 units), target is larger (2 layers, 64 units).
        
        self.draft_model = RNNModel(vocab_size, embed_dim, draft_hidden, num_layers=1)
        # Purpose: Initialize small draft model.
        # Theory: Fast model for speculative predictions.
        
        self.target_model = RNNModel(vocab_size, embed_dim, target_hidden, num_layers=2)
        # Purpose: Initialize larger target model.
        # Theory: Accurate model for verifying predictions.
        
        self.vocab_size = vocab_size
        # Purpose: Store vocabulary size.
        # Theory: Defines output space for token predictions.
    
    def train(self, sequences, epochs=100):
        # Purpose: Train both draft and target models.
        # Theory: Uses cross-entropy loss to learn sequence prediction.
        
        criterion = nn.CrossEntropyLoss()
        # Purpose: Define cross-entropy loss.
        # Theory: Measures prediction error for token sequences.
        
        draft_optimizer = optim.Adam(self.draft_model.parameters(), lr=0.001)
        target_optimizer = optim.Adam(self.target_model.parameters(), lr=0.001)
        # Purpose: Initialize Adam optimizers for both models.
        # Theory: Adaptive learning rates for RNN training.
        
        for epoch in range(epochs):
            # Purpose: Iterate over epochs for training.
            # Theory: Updates model parameters to minimize loss.
            
            self.draft_model.train()
            self.target_model.train()
            # Purpose: Set models to training mode.
            # Theory: Enables gradient computation.
            
            draft_optimizer.zero_grad()
            target_optimizer.zero_grad()
            # Purpose: Reset gradients.
            # Theory: Prevents gradient accumulation.
            
            # Prepare input and target (shifted by 1)
            inputs = sequences[:, :-1]
            targets = sequences[:, 1:]
            # Purpose: Create input-target pairs for next-token prediction.
            # Theory: Inputs are [t_1, ..., t_{n-1}], targets are [t_2, ..., t_n].
            
            # Draft model
            draft_logits, _ = self.draft_model(inputs)
            # Purpose: Compute draft model predictions.
            # Theory: Shape [100, 9, vocab_size] for loss computation.
            
            draft_loss = criterion(draft_logits.reshape(-1, self.vocab_size), targets.reshape(-1))
            # Purpose: Compute draft model loss.
            # Theory: Cross-entropy loss over all tokens.
            
            draft_loss.backward()
            draft_optimizer.step()
            # Purpose: Update draft model parameters.
            # Theory: Applies gradient-based updates.
            
            # Target model
            target_logits, _ = self.target_model(inputs)
            # Purpose: Compute target model predictions.
            # Theory: Shape [100, 9, vocab_size] for loss computation.
            
            target_loss = criterion(target_logits.reshape(-1, self.vocab_size), targets.reshape(-1))
            # Purpose: Compute target model loss.
            # Theory: Cross-entropy loss for accurate model.
            
            target_loss.backward()
            target_optimizer.step()
            # Purpose: Update target model parameters.
            # Theory: Applies gradient-based updates.
            
            if (epoch + 1) % 10 == 0:
                # Purpose: Print loss every 10 epochs.
                # Theory: Monitors training progress.
                
                print(f"Draft Model Epoch [{epoch + 1}/{epochs}], Loss: {draft_loss.item():.4f}")
                print(f"Target Model Epoch [{epoch + 1}/{epochs}], Loss: {target_loss.item():.4f}")
    
    def standard_decode(self, context, max_len=4):
        # Purpose: Perform standard decoding with the target model.
        # Theory: Generates one token at a time, baseline for speedup comparison.
        
        self.target_model.eval()
        # Purpose: Set target model to evaluation mode.
        # Theory: Disables gradient computation for inference.
        
        generated = context.tolist()
        # Purpose: Convert context tensor to list.
        # Theory: Initial sequence to start generation.
        
        hidden = None
        # Purpose: Initialize hidden state for RNN.
        # Theory: Tracks sequence context.
        
        with torch.no_grad():
            # Purpose: Disable gradient tracking.
            # Theory: Saves memory during inference.
            
            for _ in range(max_len):
                # Purpose: Generate max_len tokens.
                # Theory: Extends sequence one token at a time.
                
                input_tensor = torch.tensor([generated], dtype=torch.long)
                # Purpose: Convert current sequence to tensor.
                # Theory: Shape [1, len(generated)] for RNN input.
                
                logits, hidden = self.target_model(input_tensor, hidden)
                # Purpose: Compute next token probabilities.
                # Theory: Uses last hidden state for continuity.
                
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                # Purpose: Convert logits to probabilities.
                # Theory: Shape [1, vocab_size] for token selection.
                
                next_token = torch.argmax(probs, dim=-1).item()
                # Purpose: Select most likely token.
                # Theory: Greedy decoding for simplicity.
                
                generated.append(next_token)
                # Purpose: Append token to sequence.
                # Theory: Builds the generated sequence.
        
        return generated[len(context):]
        # Purpose: Return generated tokens (excluding context).
        # Theory: Outputs continuation of the input sequence.
    
    def speculative_decode(self, context, max_len=4, k=3, epsilon=0.1):
        # Purpose: Perform speculative decoding with draft and target models.
        # Theory: Predicts k tokens with draft model, verifies with target model.
        
        self.draft_model.eval()
        self.target_model.eval()
        # Purpose: Set models to evaluation mode.
        # Theory: Disables gradients for inference.
        
        generated = context.tolist()
        # Purpose: Convert context tensor to list.
        # Theory: Initial sequence for generation.
        
        draft_hidden = None
        target_hidden = None
        # Purpose: Initialize hidden states for both models.
        # Theory: Tracks sequence context for RNNs.
        
        with torch.no_grad():
            # Purpose: Disable gradient tracking.
            # Theory: Saves memory during inference.
            
            while len(generated) - len(context) < max_len:
                # Purpose: Generate until max_len tokens are added.
                # Theory: Extends sequence with speculative predictions.
                
                # Draft model predicts k tokens
                draft_tokens = []
                temp_generated = generated.copy()
                temp_draft_hidden = draft_hidden
                # Purpose: Initialize temporary sequence and hidden state.
                # Theory: Allows rollback if verification fails.
                
                for _ in range(k):
                    # Purpose: Predict k tokens with draft model.
                    # Theory: Generates speculative sequence.
                    
                    input_tensor = torch.tensor([temp_generated], dtype=torch.long)
                    # Purpose: Convert current sequence to tensor.
                    # Theory: Shape [1, len(temp_generated)] for draft model.
                    
                    logits, temp_draft_hidden = self.draft_model(input_tensor, temp_draft_hidden)
                    # Purpose: Compute draft model predictions.
                    # Theory: Outputs logits and updates hidden state.
                    
                    probs = torch.softmax(logits[:, -1, :], dim=-1)
                    # Purpose: Convert logits to probabilities.
                    # Theory: Shape [1, vocab_size] for token selection.
                    
                    next_token = torch.argmax(probs, dim=-1).item()
                    # Purpose: Select most likely token.
                    # Theory: Greedy decoding for draft predictions.
                    
                    draft_tokens.append(next_token)
                    temp_generated.append(next_token)
                    # Purpose: Append token to temporary sequence.
                    # Theory: Builds speculative sequence.
                
                # Verify with target model
                input_tensor = torch.tensor([generated + draft_tokens], dtype=torch.long)
                # Purpose: Prepare input with draft tokens.
                # Theory: Shape [1, len(generated) + k] for target model verification.
                
                target_logits, target_hidden = self.target_model(input_tensor, target_hidden)
                # Purpose: Compute target model probabilities.
                # Theory: Verifies all k tokens in one pass.
                
                accepted = 0
                # Purpose: Track number of accepted tokens.
                # Theory: Counts tokens where draft and target agree.
                
                for i in range(len(draft_tokens)):
                    # Purpose: Verify each draft token.
                    # Theory: Compares draft and target probabilities.
                    
                    draft_prob = torch.softmax(self.draft_model(torch.tensor([generated + draft_tokens[:i+1]]))[0][:, -1, :], dim=-1)
                    target_prob = torch.softmax(target_logits[:, len(generated) + i - len(context), :], dim=-1)
                    # Purpose: Compute probabilities for the current token.
                    # Theory: Shape [1, vocab_size] for comparison.
                    
                    if abs(draft_prob[0, draft_tokens[i]] - target_prob[0, draft_tokens[i]]) < epsilon:
                        # Purpose: Check if probabilities are within threshold.
                        # Theory: Accepts token if draft and target agree.
                        
                        accepted += 1
                        generated.append(draft_tokens[i])
                        # Purpose: Append accepted token.
                        # Theory: Extends sequence with verified token.
                    else:
                        break
                        # Purpose: Stop at first rejected token.
                        # Theory: Rolls back to last accepted token.
                
                draft_hidden = temp_draft_hidden if accepted == k else None
                # Purpose: Update draft hidden state if all tokens accepted.
                # Theory: Maintains continuity for next iteration.
                
                if accepted == 0:
                    # Purpose: Handle case where no tokens are accepted.
                    # Theory: Falls back to standard decoding for one token.
                    
                    input_tensor = torch.tensor([generated], dtype=torch.long)
                    logits, target_hidden = self.target_model(input_tensor, target_hidden)
                    probs = torch.softmax(logits[:, -1, :], dim=-1)
                    next_token = torch.argmax(probs, dim=-1).item()
                    generated.append(next_token)
                    draft_hidden = None
                    # Purpose: Append one verified token.
                    # Theory: Ensures progress even if speculation fails.
        
        return generated[len(context):], accepted
        # Purpose: Return generated tokens and number accepted.
        # Theory: Outputs continuation and speculation success rate.

# Test SpeculativeDecoder
if __name__ == "__main__":
    # Purpose: Test speculative decoding implementation.
    # Theory: Trains models and compares standard vs. speculative decoding.
    
    decoder = SpeculativeDecoder(vocab_size)
    # Purpose: Initialize decoder with draft and target models.
    # Theory: Sets up models for sequence prediction.
    
    decoder.train(sequences, epochs=100)
    # Purpose: Train both models on the dataset.
    # Theory: Optimizes for next-token prediction.
    
    context = torch.tensor([1, 2, 3])
    # Purpose: Define test context for decoding.
    # Theory: Simulates initial sequence for generation.
    
    # Standard decoding
    start_time = time.time()
    standard_output = decoder.standard_decode(context)
    standard_time = time.time() - start_time
    # Purpose: Perform and time standard decoding.
    # Theory: Baseline for speed comparison.
    
    # Speculative decoding
    start_time = time.time()
    speculative_output, accepted = decoder.speculative_decode(context, k=3)
    speculative_time = time.time() - start_time
    # Purpose: Perform and time speculative decoding.
    # Theory: Tests speedup and accuracy of speculation.
    
    print(f"Standard Decoding Time: {standard_time:.4f}s")
    print(f"Speculative Decoding Time: {speculative_time:.4f}s")
    print(f"Sequence: {context.tolist()}")
    print(f"Standard Decoding: {standard_output}")
    print(f"Speculative Decoding: {speculative_output} ({accepted}/3 tokens accepted)")

Draft Model Epoch [10/100], Loss: 3.8904
Target Model Epoch [10/100], Loss: 3.8147
Draft Model Epoch [20/100], Loss: 3.8239
Target Model Epoch [20/100], Loss: 3.6881
Draft Model Epoch [30/100], Loss: 3.7611
Target Model Epoch [30/100], Loss: 3.5427
Draft Model Epoch [40/100], Loss: 3.6991
Target Model Epoch [40/100], Loss: 3.3781
Draft Model Epoch [50/100], Loss: 3.6365
Target Model Epoch [50/100], Loss: 3.1969
Draft Model Epoch [60/100], Loss: 3.5727
Target Model Epoch [60/100], Loss: 3.0011
Draft Model Epoch [70/100], Loss: 3.5077
Target Model Epoch [70/100], Loss: 2.7908
Draft Model Epoch [80/100], Loss: 3.4415
Target Model Epoch [80/100], Loss: 2.5677
Draft Model Epoch [90/100], Loss: 3.3741
Target Model Epoch [90/100], Loss: 2.3366
Draft Model Epoch [100/100], Loss: 3.3058
Target Model Epoch [100/100], Loss: 2.1028
Standard Decoding Time: 0.0060s
Speculative Decoding Time: 0.0080s
Sequence: [1, 2, 3]
Standard Decoding: [4, 40, 39, 8]
Speculative Decoding: [9, 27, 41, 3, 38] (3/3 t