# Speculative decoding demo

In this quick demo we will implement and run speculative decoding, and compare it with classic autoregressive generation.

In [None]:
# Prepare environment
!pip install transformers>4.51.0 torch optimum-quanto

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoConfig
from transformers.cache_utils import DynamicCache
import time
import torch

We will use <code>Qwen/Qwen3-4B</code> as target model and <code>Qwen/Qwen3-1.7B</code> as draft model.

To make sure we can run inference on the 16GB T4 GPU Colab provides, we quantize the model weights to FP8 at runtime.

Note that we cannot use the pre-quantized FP8 versions of these models available on HuggingFace, as they require compute capability ≥ 8.9 (NVIDIA Ada Lovelace or Hopper architectures), while the T4 only has compute capability 7.5.

Downloading models from HuggingFace will take around 3-4 minutes

In [None]:
# Using regular Qwen models with runtime quantization for low memory footprint
target_model_name = "Qwen/Qwen3-4B"
draft_model_name = "Qwen/Qwen3-1.7B"

target_quantize = QuantoConfig(weights="float8")
draft_quantize = QuantoConfig(weights="float8")

# Initialize target and draft models with FP8 quantization.
# `device_map="auto"` already places the weights on the available device(s),
# so no explicit `.to(device)` is issued afterwards: moving a model that was
# dispatched this way is redundant at best and can fail if part of it was
# offloaded. `trust_remote_code` is not passed because the pinned transformers
# version is expected to ship the Qwen3 architecture natively, so there is no
# reason to allow execution of code downloaded from the Hub.
target = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    quantization_config=target_quantize,
    torch_dtype="auto",
    device_map="auto",
)
draft = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    quantization_config=draft_quantize,
    torch_dtype="auto",
    device_map="auto",
)
# A single tokenizer (the target's) is used for both models.
# NOTE(review): this assumes the draft model shares the target's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

# Preferred device, kept for any cell that refers to `device`;
# the cells below use `target.device` directly.
device = "cuda" if torch.cuda.is_available() else "cpu"

Here we prepare the input prompt

In [None]:
# Build a single-turn conversation and render it with the chat template
prompt = "What is the result of 2 * pi?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)

# Tokenize the rendered prompt as a batch of one and move it next to the target model.
# Despite its name, `input_ids` is the whole tokenizer output (a mapping):
# later cells unpack it with `**` or index it with the 'input_ids' key.
input_ids = tokenizer([text], return_tensors="pt").to(target.device)

print(f"Prepared prompt: \n\n{text}")

### Autoregressive generation with Target model

In [None]:
max_output_length = 256  # Total sequence budget: includes the prompt as well
# `max_length` bounds prompt + generated tokens, which is what the budget above
# describes and what the speculative decoding cell uses. `max_new_tokens` would
# instead allow 256 tokens on top of the prompt, making the two runs unequal.
output_ids = target.generate(**input_ids, max_length=max_output_length)

content = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True).strip('\n')

print(f"Target model Autoregressive generation output: \n\n{content}")

### Speculative decoding generation

We define helper functions

In [None]:
def greedy_decoder(logits: torch.Tensor) -> torch.Tensor:
    """Pick the highest-scoring entry along the last dimension.

    Args:
        logits: Scores whose last dimension ranges over the candidates.

    Returns:
        Indices of the maxima, with the last dimension kept as size 1.
    """
    # keepdim=True leaves the reduced dimension in place as a singleton,
    # giving the same shape as argmax followed by unsqueeze(-1).
    return logits.argmax(dim=-1, keepdim=True)

def top_k_decoder(logits: torch.Tensor, _k: int) -> torch.Tensor:
    """Return the indices of the `_k` largest entries along the last dimension.

    Args:
        logits: Scores whose last dimension ranges over the candidates.
        _k: Number of top candidates to keep.

    Returns:
        Index tensor shaped like `logits` with the last dimension replaced by
        `_k`, plus one trailing singleton dimension.
    """
    _, best_indices = logits.topk(_k, dim=-1)
    # Append a trailing axis of size 1 (same as unsqueeze(-1)).
    return best_indices[..., None]

def prune_cache(cache: DynamicCache, num_tokens_to_discard: int) -> DynamicCache:
    """Shorten a cache by `num_tokens_to_discard` tokens via `cache.crop`.

    Args:
        cache: Cache to shorten, or None when no cache exists.
        num_tokens_to_discard: How many tokens to subtract from the cache's
            current sequence length.

    Returns:
        The same cache object after `crop` was called on it, or None if
        `cache` was None.
    """
    if cache is not None:
        # Target length = what the cache currently reports minus the discarded tokens.
        kept_length = cache.get_seq_length() - num_tokens_to_discard
        cache.crop(kept_length)
    return cache

Speculative decoding implementation

In [None]:
@torch.no_grad()
def speculative_generate(
    target,
    draft,
    tokenizer,
    decoder,
    input_ids,
    max_output_length,
    gamma,
    k
):
    """Generate tokens with speculative (draft-then-verify) decoding.

    Each round, the draft model proposes up to ``gamma`` tokens one at a time.
    The target model then scores the whole proposal in a single forward pass,
    and the longest prefix of proposals that each fall inside the target's
    top-``k`` candidates is kept. The target contributes one more token right
    after that prefix.

    Args:
        target: Target model. Called as
            ``target(input_ids=..., past_key_values=..., use_cache=...)``; must
            expose ``.device`` and return an object with ``.logits`` and
            ``.past_key_values``.
        draft: Draft model with the same call convention.
        tokenizer: Provides ``pad_token_id`` and ``eos_token_id``.
        decoder: Callable turning the logits of one position into a token id
            (e.g. ``greedy_decoder``).
        input_ids: Prompt token ids; only row 0 is used.
        max_output_length: Total sequence budget, prompt included.
        gamma: Maximum number of draft tokens proposed per round.
        k: A draft token is accepted only if it is among the target's top-``k``
            tokens for its position.

    Returns:
        list[int]: The prompt followed by the generated token ids.

    Raises:
        ValueError: If the prompt leaves no room for at least one new token.
    """
    prompt_length = len(input_ids[0])
    if prompt_length >= max_output_length:
        raise ValueError(
            f"Prompt length ({prompt_length}) must be smaller than "
            f"max_output_length ({max_output_length})."
        )

    eos_token_id = tokenizer.eos_token_id
    # Some tokenizers define no pad token; fall back to EOS to fill the buffer.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_id

    # Fixed-size buffer holding prompt + generated tokens, padded on the right.
    current_position = prompt_length
    input_ids_sd = torch.full((1, max_output_length), pad_token_id, dtype=torch.long, device=target.device)
    input_ids_sd[0, :current_position] = input_ids[0].detach().clone().to(target.device)

    # NOTE(review): every forward pass below uses use_cache=False and re-encodes
    # the whole prefix, so the models presumably return past_key_values=None and
    # the cache variables / prune_cache calls have no effect. Turning caching on
    # would also require feeding only the not-yet-cached suffix to each model.
    target_cache, draft_cache = None, None

    # Prefill target model to get a first token and cache
    target_output = target(
        input_ids=input_ids_sd[..., :current_position],
        past_key_values=target_cache,
        use_cache=False
    )
    target_cache = target_output.past_key_values  # Set KV Cache for target model
    token = decoder(target_output.logits[..., -1, :])
    input_ids_sd[0, current_position] = token
    current_position += 1

    if token == eos_token_id:
        return input_ids_sd[0, :current_position].tolist()

    while current_position < max_output_length:
        # Leave one slot free for the token the target adds after verification.
        corrected_gamma = min(gamma, max_output_length - current_position - 1)

        # Generate gamma draft tokens.
        # The loop variable must not be called `k`: that would overwrite the
        # top-k acceptance threshold used during verification below.
        for draft_step in range(corrected_gamma):
            draft_output = draft(
                input_ids=input_ids_sd[..., :current_position + draft_step],
                past_key_values=draft_cache,
                use_cache=False
            )
            draft_cache = draft_output.past_key_values
            token = decoder(draft_output.logits[..., -1, :])
            input_ids_sd[0, current_position + draft_step] = token

        # Validate draft tokens in parallel
        target_output = target(
            input_ids=input_ids_sd[..., :current_position + corrected_gamma],
            past_key_values=target_cache,
            use_cache=False
        )
        target_cache = target_output.past_key_values
        # Logits at index p score the token at position p + 1, hence the -1 shift.
        target_logits = target_output.logits[..., current_position - 1: current_position + corrected_gamma - 1, :]
        target_topk = top_k_decoder(target_logits, k)

        # Compute the last accepted draft position
        n = corrected_gamma
        for i in range(corrected_gamma):
            if input_ids_sd[0, current_position + i] not in target_topk[0, i, :]:
                n = i
                break

        # Check if any of the accepted tokens is the eos token
        stop_locations = torch.nonzero(torch.eq(input_ids_sd[0, current_position: current_position + n], eos_token_id))
        if stop_locations.shape[0] > 0:
            stop_location = stop_locations[0, 0].item()
            # Blank out everything after the first accepted EOS.
            input_ids_sd[0, current_position + stop_location + 1:] = pad_token_id
            current_position += stop_location + 1
            break

        if n != corrected_gamma:
            # Need to adjust the cache of both models
            draft_cache = prune_cache(draft_cache, corrected_gamma - n)
            target_cache = prune_cache(target_cache, corrected_gamma - n + 1)

        # Next token is sampled from the target model's distribution
        target_logits = target_output.logits[..., current_position + n - 1, :]
        token = decoder(target_logits)

        # Erase the rejected drafts, then write the target's token after the accepted prefix.
        input_ids_sd[0, current_position + n: current_position + corrected_gamma] = pad_token_id
        input_ids_sd[0, current_position + n] = token
        current_position += n + 1

        if token == eos_token_id:
            break

    return input_ids_sd[0, :current_position].tolist()

In [None]:
# Speculative decoding settings
gamma = 5  # Number of draft tokens proposed per round
k = 10  # Accept drafty tokens only if they are in the top k of the target model
max_output_length = 256  # Size of the token buffer, prompt included

# Arguments are passed by name so the call reads without checking the signature.
output_ids = speculative_generate(
    target=target,
    draft=draft,
    tokenizer=tokenizer,
    decoder=greedy_decoder,
    input_ids=input_ids['input_ids'],
    max_output_length=max_output_length,
    gamma=gamma,
    k=k,
)

content = tokenizer.decode(output_ids, skip_special_tokens=True).strip('\n')

print(f"Speculative decoding output: \n\n{content}")