# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [1]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

## Speculative Decoding

In [2]:
class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        self.device = device
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

        # Ensure PAD token is set
        if self.target_tokenizer.pad_token is None:
            self.target_tokenizer.pad_token = self.target_tokenizer.eos_token
        if self.draft_tokenizer.pad_token is None:
            self.draft_tokenizer.pad_token = self.draft_tokenizer.eos_token

    def initialize_target_model(self, model_name: str):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cuda:0"
        ).eval()
        model.config.use_cache = True
        return model, tokenizer

    def initialize_draft_model(self, model_name: str):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cuda:0"
        ).eval()
        model.config.use_cache = True
        return model, tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_speculative_tokens) -> torch.Tensor:
        with torch.no_grad():
            output = self.draft_model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_speculative_tokens,
                pad_token_id=self.draft_tokenizer.pad_token_id,
                do_sample=False
                # do_sample=True,
                # top_k=25,
                # temperature=0.25
            )
        return output[:, -num_speculative_tokens:]

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[List[int], int]:
        combined_ids = torch.cat([input_ids, draft_tokens], dim=1)
        combined_mask = torch.cat([attention_mask, torch.ones_like(draft_tokens, dtype=torch.long)], dim=1)

        with torch.no_grad():
            outputs = self.target_model(input_ids=combined_ids, attention_mask=combined_mask)

        # Extract logits corresponding to the draft token positions.
        logits = outputs.logits[:, input_ids.shape[1] - 1:-1, :]
        predicted_tokens = logits.argmax(dim=-1)

        # Vectorized comparison between draft and predicted tokens.
        correct = (draft_tokens == predicted_tokens)
        mismatches = (~correct[0]).nonzero(as_tuple=True)[0]

        if mismatches.numel() > 0:
            first_mismatch = mismatches[0].item()
            accepted_tokens = draft_tokens[0, :first_mismatch].tolist()
            return accepted_tokens, first_mismatch
        else:
            accepted_tokens = draft_tokens[0].tolist()
            return accepted_tokens, draft_tokens.shape[1]

    def speculative_decode(self, prompt: str, max_tokens: int = 100, num_speculative_tokens: int = 15) -> str:
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True).to(self.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        initial_length = input_ids.shape[1]

        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        while (input_ids.shape[1] - initial_length) < max_tokens:
            draft_tokens = self.generate_draft_tokens(input_ids, attention_mask, num_speculative_tokens)
            total_draft_tokens_proposed += draft_tokens.shape[1]
            accepted_tokens, first_rejected = self.verify_tokens_vectorized(input_ids, draft_tokens, attention_mask)
            total_draft_tokens_accepted += len(accepted_tokens)

            # Append accepted tokens
            if accepted_tokens:
                accepted_tensor = torch.tensor([accepted_tokens], dtype=torch.long).to(self.device)
                input_ids = torch.cat([input_ids, accepted_tensor], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(accepted_tensor, dtype=torch.long)], dim=1)
                # Stop if the EOS token is encountered in accepted tokens.
                if accepted_tensor[0, -1] == self.target_tokenizer.eos_token_id:
                    break

            # If any draft token was rejected, fall back to a single token generation.
            if first_rejected < draft_tokens.shape[1]:
                with torch.no_grad():
                    output = self.target_model.generate(
                        input_ids, max_new_tokens=1,
                        attention_mask=attention_mask,
                        pad_token_id=self.target_tokenizer.pad_token_id,
                        do_sample=False
                        # do_sample=True,
                        # top_k=25,
                        # temperature=0.25
                    )
                input_ids = torch.cat([input_ids, output[:, -1:]], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.long).to(self.device)], dim=1)
                # Stop if the EOS token is encountered in the fallback token.
                if output[0, -1] == self.target_tokenizer.eos_token_id:
                    break

        elapsed_time = time.time() - start_time
        generated_tokens = input_ids.shape[1] - initial_length
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {generated_tokens} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {generated_tokens / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)


    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        with torch.no_grad():
          torch.cuda.empty_cache()

        return results

    def bonus_benchmark(self, prompt: str, max_tokens: int = 100,
                        num_runs: int = 3, num_speculative_tokens: int = 15, compare_baseline: bool = True) -> Dict:
      """
      Benchmark the speculative decoder against baseline decoding.

      Args:
          prompt: Input text.
          max_tokens: Maximum number of tokens to generate.
          num_runs: Number of benchmark runs.
          compare_baseline: Whether to compare with baseline (non-speculative) decoding.

      Returns:
          Dictionary with benchmark results.
      """
      results = {
          "speculative": {"times": [], "tokens_per_second": []},
          "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
      }

      # Benchmark speculative decoding.
      for _ in range(num_runs):
          start_time = time.time()
          output = self.speculative_decode(prompt, max_tokens=max_tokens, num_speculative_tokens=num_speculative_tokens)
          elapsed = time.time() - start_time
          prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
          output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
          tps = output_tokens / elapsed
          results["speculative"]["times"].append(elapsed)
          results["speculative"]["tokens_per_second"].append(tps)

      # Benchmark baseline decoding.
      if compare_baseline:
          for _ in range(num_runs):
              inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
              input_ids = inputs["input_ids"].to(self.device)
              attention_mask = inputs["attention_mask"].to(self.device)
              start_time = time.time()
              with torch.no_grad():
                  output_ids = self.target_model.generate(
                      input_ids,
                      attention_mask=attention_mask,
                      max_length=input_ids.shape[1] + max_tokens,
                      do_sample=False,
                      pad_token_id=self.target_tokenizer.pad_token_id
                  )
              elapsed = time.time() - start_time
              output_tokens = output_ids.shape[1] - input_ids.shape[1]
              tps = output_tokens / elapsed
              results["baseline"]["times"].append(elapsed)
              results["baseline"]["tokens_per_second"].append(tps)

      for method in results.keys():
          if results[method] is not None:
              avg_time = sum(results[method]["times"]) / num_runs
              avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
              results[method]["avg_time"] = avg_time
              results[method]["avg_tokens_per_second"] = avg_tps

      if compare_baseline:
          speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
          results["speedup"] = speedup
          results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
          # print(f"Speculative decoding speedup: {speedup:.2f}x")
          # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

      return results

## 3.2 - Test

In [4]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")


Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 106 tokens in 2.04 seconds
Tokens per second: 52.06
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.92 seconds
Tokens per second: 55.24
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.97 seconds
Tokens per second: 53.74
Draft token acceptance rate: 87.50%
Average speculative decoding time: 1.98 seconds
Average speculative tokens per second: 53.65
Average baseline decoding time: 2.65 seconds
Average baseline tokens per second: 37.86
Speedup: 1.34x
Latency reduction: 25.28%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 114 tokens in 2.07 seconds
Tokens per second: 54.98
Draft token acceptance rate: 94.17%
Generated 114 tokens in 2.06 seconds
Tokens per second: 55.41
Draft token acceptance rate: 94.17%
Generated 114 tokens in 2.14 seconds
Tokens per second: 53.37
Draft token acceptance rate: 94.17%
Average speculative decod

## 3.3 Experiments

In [3]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]


max_tokens_list = [250, 500, 1000]
speculative_tokens = [10, 25, 100, 250]
# Run benchmark on test prompts
for max_token in max_tokens_list:
  for speculative_token in speculative_tokens:
    print(f"\nBenchmarking with {max_token} max tokens and {speculative_token} speculative tokens:")
    for i, prompt in enumerate(test_prompts):
      print(f"\nPrompt: {prompt}")

      results = decoder.bonus_benchmark(
          prompt=prompt,
          max_tokens=max_token,
          num_speculative_tokens=speculative_token,
          num_runs=3,
          compare_baseline=True
      )

      print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
      print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

      if results["baseline"] is not None:
          print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
          print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
          print(f"Speedup: {results['speedup']:.2f}x")
          print(f"Latency reduction: {results['latency_reduction']:.2f}%")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


pytorch_model.bin:  88%|########8 | 2.58G/2.93G [00:00<?, ?B/s]

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


model.safetensors:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]


Benchmarking with 250 max tokens and 10 speculative tokens:

Prompt: The future of Artificial Intelligence is
Generated 251 tokens in 6.75 seconds
Tokens per second: 37.17
Draft token acceptance rate: 96.15%
Generated 251 tokens in 5.90 seconds
Tokens per second: 42.52
Draft token acceptance rate: 96.15%
Generated 251 tokens in 5.66 seconds
Tokens per second: 44.32
Draft token acceptance rate: 96.15%
Average speculative decoding time: 6.11 seconds
Average speculative tokens per second: 41.31
Average baseline decoding time: 6.41 seconds
Average baseline tokens per second: 39.04
Speedup: 1.05x
Latency reduction: 4.61%

Prompt: Write a short story about a robot learning to feel emotions:
Generated 259 tokens in 5.98 seconds
Tokens per second: 43.30
Draft token acceptance rate: 99.23%
Generated 259 tokens in 6.69 seconds
Tokens per second: 38.70
Draft token acceptance rate: 99.23%
Generated 259 tokens in 6.07 seconds
Tokens per second: 42.65
Draft token acceptance rate: 99.23%
Average spe

## Bonus

In [13]:
# Settings to achieve 85%+ draft acceptance rate and about 1.3x speedup
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.bonus_benchmark(
        prompt=prompt,
        max_tokens=100,
        num_speculative_tokens=16,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")


Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 113 tokens in 2.17 seconds
Tokens per second: 52.09
Draft token acceptance rate: 87.50%
Generated 113 tokens in 2.06 seconds
Tokens per second: 54.77
Draft token acceptance rate: 87.50%
Generated 113 tokens in 2.07 seconds
Tokens per second: 54.50
Draft token acceptance rate: 87.50%
Average speculative decoding time: 2.10 seconds
Average speculative tokens per second: 53.76
Average baseline decoding time: 2.88 seconds
Average baseline tokens per second: 35.36
Speedup: 1.37x
Latency reduction: 26.85%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 105 tokens in 2.68 seconds
Tokens per second: 39.20
Draft token acceptance rate: 92.86%
Generated 105 tokens in 2.62 seconds
Tokens per second: 40.07
Draft token acceptance rate: 92.86%
Generated 105 tokens in 2.35 seconds
Tokens per second: 44.60
Draft token acceptance rate: 92.86%
Average speculative decod

In [19]:
# Test speculative decoding on new models
target_model_name = "gpt2-large"
draft_model_name = "gpt2-medium"


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.bonus_benchmark(
        prompt=prompt,
        max_tokens=250,
        num_speculative_tokens=20,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")


Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 262 tokens in 7.02 seconds
Tokens per second: 37.35
Draft token acceptance rate: 93.21%
Generated 262 tokens in 6.74 seconds
Tokens per second: 38.86
Draft token acceptance rate: 93.21%
Generated 262 tokens in 7.14 seconds
Tokens per second: 36.69
Draft token acceptance rate: 93.21%
Average speculative decoding time: 6.97 seconds
Average speculative tokens per second: 37.62
Average baseline decoding time: 7.30 seconds
Average baseline tokens per second: 34.30
Speedup: 1.05x
Latency reduction: 4.57%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 267 tokens in 12.67 seconds
Tokens per second: 21.07
Draft token acceptance rate: 50.20%
Generated 267 tokens in 12.75 seconds
Tokens per second: 20.94
Draft token acceptance rate: 50.20%
Generated 267 tokens in 12.39 seconds
Tokens per second: 21.55
Draft token acceptance rate: 50.20%
Average speculative dec