# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [1]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

## Speculative Decoding

In [5]:
class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)

        # Ensure tokenizers are compatible
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def initialize_target_model(self, model_name: str):
        """Initialize the larger target model with caching enabled and proper pad token."""
        print(f"Loading target model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement target model initialization
        # 1. Set the pad token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load the model with appropriate settings for inference
        model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        model.eval()

        # 3. Enable any optimizations that might help with performance
        model = torch.compile(model)

        return model, tokenizer

    def initialize_draft_model(self, model_name: str):
        """
        Initialize a smaller, faster draft model with proper pad token.
        Uses lower precision and additional optimizations.
        """
        print(f"Loading draft model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement draft model initialization
        # 1. Set the pad token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token


        # 2. Load the model with appropriate settings for inference
        model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.eval()

        # 3. Enable any optimizations that might help with performance
        #model = torch.compile(model)
        model.half()

        return model, tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens in one forward call using the draft model.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Number of tokens to speculate.

        Returns:
            Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
        """
        # TODO: Implement draft token generation
        # 1. Use the draft model to generate tokens
        with torch.inference_mode(): #torch.no_grad():
            draft_tokens = self.draft_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_speculative_tokens
            )

        # 2. Extract only the new tokens (not including the input)
        #print('INPUT TOKENS', input_ids)
        #print('DRAFT TOKENS', draft_tokens)
        new_tokens = draft_tokens[:, input_ids.shape[1]:]

        # 3. Return the newly generated tokens
        return new_tokens

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor) -> Tuple[List[int], int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: List of accepted token IDs.
            accepted_position: Index of the first rejected token (if all accepted, equals draft_tokens.shape[1]).
        """
        # TODO: Implement efficient verification of draft tokens
        # 1. Run target model on input_ids concatenated with draft_tokens
        extended_input = torch.cat([input_ids, draft_tokens], dim=1)
        extended_mask = torch.cat([attention_mask, torch.ones_like(draft_tokens, device=self.device)], dim=1)

        # 2. Extract the logits for positions where draft tokens would be predicted
        with torch.inference_mode(): #torch.no_grad():
            logits = self.target_model(extended_input, attention_mask=extended_mask).logits

            # 3. Compare target model predictions with draft tokens
            start_idx = input_ids.shape[1]-1
            end_idx = start_idx + draft_tokens.shape[1]
            target_predictions = logits[:, start_idx:end_idx, :]
            matches = (target_predictions.argmax(dim=-1) == draft_tokens)


        #print('target predictions', target_predictions.argmax(dim=-1))
        #print('draft tokens', draft_tokens)
        #print('First mismatch', first_mismatch_idx)
        # 4. Determine how many consecutive tokens were accepted before first mismatch
        accepted_position = torch.where(~matches)[1].min().item() if not matches.all() else draft_tokens.shape[1]
        accepted_tokens = draft_tokens[:, :accepted_position].tolist() #[0]

        return accepted_tokens, accepted_position

    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 15) -> str: #15 for pythia, 9 for gpt2
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text.
        """
        # Tokenize prompt
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        prompt_length = input_ids.shape[1]

        # Initialize counters for performance tracking
        total_tokens_generated = prompt_length
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        # TODO: Implement the core speculative decoding loop
        # 1. Generate draft tokens using the draft model
        while total_tokens_generated - prompt_length < max_tokens:
            draft_tokens = self.generate_draft_tokens(input_ids, attention_mask, num_speculative_tokens)
            total_draft_tokens_proposed += draft_tokens.shape[1]

            # 2. Verify draft tokens using the target model
            accepted_tokens, accepted_position = self.verify_tokens_vectorized(input_ids, draft_tokens, attention_mask)

            # 3. Accept verified tokens and append to the sequence
            total_draft_tokens_accepted += len(accepted_tokens[0])
            input_ids = torch.cat([input_ids, draft_tokens[:, :accepted_position]], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(draft_tokens[:, :accepted_position], device=self.device)], dim=1)
            total_tokens_generated += accepted_position

            # 4. For rejected tokens or if all tokens are accepted, generate a new token with the target model
            if accepted_position < draft_tokens.shape[1]:
                with torch.inference_mode(): #torch.no_grad():
                    logits = self.target_model(input_ids, attention_mask=attention_mask).logits[:, -1, :]  # Get the logits for the last token
                    next_token = logits.argmax(dim=-1, keepdim=True)  # Select most likely token

                input_ids = torch.cat([input_ids, next_token], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(next_token, device=self.device)], dim=1)
                total_tokens_generated += 1

            # 5. Stop when max_tokens is reached or an EOS token is generated
            if self.target_tokenizer.eos_token_id in input_ids[0]:
                break

        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {total_tokens_generated - prompt_length} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {(total_tokens_generated - prompt_length) / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        return results

## Test

In [6]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped
Loading draft model: EleutherAI/pythia-160m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 106 tokens in 1.85 seconds
Tokens per second: 57.27
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.77 seconds
Tokens per second: 59.81
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.80 seconds
Tokens per second: 58.80
Draft token acceptance rate: 87.50%
Average speculative decoding time: 1.81 seconds
Average speculative tokens per second: 58.58
Average baseline decoding time: 2.71 seconds
Average baseline tokens per second: 36.99
Speedup: 1.50x
Latency reduction: 33.27%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 114 tokens in 1.86 seconds
Tokens per second: 61.24
Draft token acceptance rate: 94.17%
Generated 114 tokens in 1.84 seconds
Tokens per second: 61.95
Draft token acceptance rate: 94.17%
Generated 114 toke

## Bonus

In [12]:
target_model_name = "openai-community/gpt2-large"  # Larger target model
draft_model_name = "openai-community/gpt2"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: openai-community/gpt2-large
Loading draft model: openai-community/gpt2

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 105 tokens in 1.62 seconds
Tokens per second: 64.69
Draft token acceptance rate: 96.30%
Generated 105 tokens in 1.72 seconds
Tokens per second: 61.03
Draft token acceptance rate: 96.30%
Generated 105 tokens in 1.77 seconds
Tokens per second: 59.48
Draft token acceptance rate: 96.30%
Average speculative decoding time: 1.71 seconds
Average speculative tokens per second: 61.58
Average baseline decoding time: 3.03 seconds
Average baseline tokens per second: 33.02
Speedup: 1.78x
Latency reduction: 43.70%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 102 tokens in 3.13 seconds
Tokens per second: 32.61
Draft token acceptance rate: 56.17%
Generated 102 tokens in 2.42 seconds
Tokens per second: 42.22
Draft token acceptance rate: 56.17%
Generated 102 tokens in 2.34 s

## Report

My approach to implementing speculative decoding was to follow the skeleton code provided, and then focus on tuning the number of speculative tokens generated. As for setting up the speculative decoding implementation, it uses a two-stage approach to accelerate text generation. First, a smaller, faster draft model first predicts multiple tokens in a single forward pass. These draft tokens are then verified in parallel by a larger target model, which compares its own predicted probabilities to the draft outputs. If the draft tokens match the target model’s expectations, they are accepted, reducing the number of expensive target model forward passes. When a token is rejected, the target model directly generates the next token. This process should optimize efficiency with the smaller, quicker draft model, while still maintaining the quality of the target model, ultimately leading to faster decoding. As for my approach of tuning the number of speculative tokens generated, I focused on using as many speculative tokens as possible while still maintaining high draft token acceptance rate.


One of the optimizations I implemented was using torch.compile() on the target model and torch.half() on the draft model. Adding compiler optimizations should optimize the computational graph to process draft tokens more efficiently, while reducing overhead in final token acceptance. Using half-precision on the draft model should reduce memory consumption and transfer, which should also lead to higher speedups. **However it is very important to note that since I am using torch.compile(), there needs to be one warmup run of the test before I am able to achieve good results.** This is because the computational graph needs to be traced initially so it knows how to optimize the kernel execution, so there needs to be one warmup run before recording results, so that once CUDA kernels are already compiled and cached, inference runs much faster.


Another optimization I implemented was tuning the number of speculative tokens generated. I found that with the baseline pythia models, I could increase the number of speculative tokens to around 15-20 and achieve high acceptance. I found that when I increased the number of tokens to 20, I was getting higher speedup, but then the acceptance rate would drop slightly, so I ended up using 15 speculative tokens which worked well for these models. For the gpt-2 models (used in the bonus section), I found that it was quite easy to achieve significant speedup using anywhere from 10-20 speculative tokens generated, but the acceptance rate varied significantly. For example, when using 15-20 speculative tokens, the speedup could reach up to 2x for the first prompt and around 1.7x for the remaining two prompts, but the acceptance rate for the second two prompts would be ~45-55% (the first prompt always had >85% acceptance, regardless of the number of speculative tokens generated). However for these gpt-2 models, decreasing the speculative tokens to 9 led to the highest acceptance rates for all prompts. Still though, the gpt-2 models struggled with the second prompt, even after finding the optimal number of speculative tokens.


To reiterate the results, I found that there was a tradeoff between achieving higher speedup or achieving higher draft token acceptance. Ultimately, I found that it was easier to increase the acceptance rate as opposed to the speedup, so this was the metric I chose to optimize. I ended up achieving 87.50%, 94.17%, and 89.17% draft token acceptance rates for the three respective prompts, with 1.50x, 1.45x, and 1.41x respective speedups with num_speculative_tokens=15. Also again I will mention that in order to achieve these results, the test needs to be run one time as a warm up so that the graph can be optimized properly.   


One significant challenge I found was that in my initial implementation, I was only able to achieve 1/(# speculative tokens generated)% acceptance. For example, with the baseline of 15 speculative tokens generated, I was consistently only getting 1/15 = 6.67% draft token acceptance. No matter what optimizations I tried, I wasn’t able to change this acceptance rate. Then I started printing out what were the actual predicted tokens and what were the draft tokens, and I saw that the predicted tokens were shifted one from the draft tokens, which is why I was never able to get more than 1 accepted draft token. Thus, I decremented the predicted tokens by 1 (in the verify_tokens_vectorized() function: start_idx = input_ids.shape[1]-1), and this solved my problem. Now I was properly comparing predicted tokens with draft tokens, and my acceptance rate shot up to 80+%.
