# Speculative decoding demo

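In this quick demo, we implement speculative decoding, run it, and compare it with classic autoregressive generation. In speculative decoding, a small *draft* model cheaply proposes several tokens, and a large *target* model verifies all of them in a single forward pass.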

In [None]:
# Prepare environment
!pip install transformers>4.51.0 torch accelerate

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoConfig
from transformers.cache_utils import DynamicCache
import time
import torch

We will use <code>Qwen/Qwen3-14B</code> as the target model and <code>Qwen/Qwen3-1.7B</code> as the draft model.

Downloading the models from Hugging Face may take several minutes.

In [None]:
# Load the target and draft models in bfloat16 on the first GPU.
# No quantization is applied: bfloat16 needs roughly 2 bytes per parameter.
target_model_name = "Qwen/Qwen3-14B"
draft_model_name = "Qwen/Qwen3-1.7B"

# Qwen3 is supported natively by transformers>4.51.0, so no remote code needs to be trusted.
target = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
draft = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)

# The speculative decoding code below passes raw token ids between the two models.
# It therefore assumes both share this tokenizer's vocabulary, which is expected
# for two checkpoints of the same Qwen3 family.
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

Next, we prepare the input prompt.

In [None]:
# Build the chat-formatted prompt (with the `enable_thinking` template flag set), then tokenize it.
prompt = "What is the result of 2 * pi?"
messages = [dict(role="user", content=prompt)]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,
)

# NOTE: despite its name, `input_ids` holds the tokenizer's dict-like batch output.
# Later cells unpack it with `**` and read `input_ids['input_ids']`.
input_ids = tokenizer([text], return_tensors="pt").to(target.device)

print(f"Prepared prompt: \n\n{text}")

### Autoregressive generation with the target model

In [None]:
%%time 
max_output_length = 200  # Includes prompt as well
output_ids = target.generate(**input_ids, max_new_tokens=max_output_length)

content = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True).strip('\n')

print(f"Target model Autoregressive generation output: \n\n{content}")

### Speculative decoding generation

First, we define some helper functions.

In [None]:
def greedy_decoder(logits: torch.Tensor) -> torch.Tensor:
    """Pick the highest-scoring id along the last axis (the vocabulary axis for model logits).

    A trailing singleton dimension is kept, so logits of shape ``(..., vocab)``
    yield token ids of shape ``(..., 1)``.
    """
    best_ids = logits.argmax(dim=-1)
    return best_ids[..., None]

def top_k_decoder(logits: torch.Tensor, _k: int) -> torch.Tensor:
    """Return the ids of the ``_k`` highest-scoring entries along the last axis.

    Ids are ordered from highest to lowest score and carry a trailing singleton
    dimension: logits of shape ``(..., vocab)`` give ids of shape ``(..., _k, 1)``.
    """
    _, top_ids = logits.topk(_k, dim=-1)
    return top_ids[..., None]

def prune_cache(cache: DynamicCache, num_tokens_to_discard: int) -> DynamicCache:
    """Drop the last ``num_tokens_to_discard`` cached tokens via ``DynamicCache.crop``.

    The cache object itself is cropped and returned, which allows the
    ``cache = prune_cache(cache, n)`` idiom. Passing ``None`` returns ``None``.
    """
    if cache is not None:
        cache.crop(cache.get_seq_length() - num_tokens_to_discard)
    return cache

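Next comes the speculative decoding implementation itself. A draft token is accepted if it is among the target model's top-k predictions at that position. The first rejected token is replaced by the target model's own choice.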

In [None]:
@torch.inference_mode()
def speculative_generate(
    target,
    draft,
    tokenizer,
    decoder,
    input_ids,
    max_output_length,
    gamma,
    k
):
    """Generate tokens with speculative decoding: the draft proposes, the target verifies.

    Each round, the draft model proposes up to ``gamma`` tokens one at a time. The
    target model then scores all of them in a single forward pass. Draft tokens
    are accepted left to right while each one is among the target's top-``k``
    predictions at its position. The first non-accepted position is filled with
    the target's own choice (from ``decoder``), so every round emits at least one
    token.

    With ``k=1`` and a greedy ``decoder``, every emitted token is the target's
    argmax. The output should then match plain greedy decoding, up to numerical
    differences between batched and incremental forward passes. With ``k > 1``,
    the output may deviate from it.

    Args:
        target: Large model whose predictions decide every emitted token.
        draft: Small model used to propose tokens cheaply. Must share the
            target's vocabulary, since token ids are exchanged directly.
        tokenizer: Provides ``pad_token_id`` and ``eos_token_id``.
        decoder: Maps the logits of one position to a single token id
            (e.g. ``greedy_decoder``).
        input_ids: Prompt token ids; only row 0 is used.
        max_output_length: Maximum total length (prompt + generated tokens).
        gamma: Maximum number of draft tokens proposed per round. ``0`` degrades
            to plain autoregressive decoding with the target model.
        k: A draft token is accepted when it is in the target's top-``k``.

    Returns:
        List of token ids: the prompt followed by the generated tokens.

    Raises:
        ValueError: If ``max_output_length`` leaves no room for a generated token.
    """
    prompt_length = len(input_ids[0])
    if max_output_length <= prompt_length:
        raise ValueError(
            f"max_output_length ({max_output_length}) must exceed the prompt length ({prompt_length})"
        )

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    eos_token_id = tokenizer.eos_token_id

    current_position = prompt_length
    input_ids_sd = torch.full((1, max_output_length), pad_token_id, dtype=torch.long, device=target.device)
    input_ids_sd[0, :current_position] = input_ids[0].detach().clone().to(target.device)

    draft_cache = None

    # Prefill the target model to get the first token and its KV cache.
    target_output = target(
        input_ids=input_ids_sd[..., :current_position],
    )
    target_cache = target_output.past_key_values
    token = decoder(target_output.logits[..., -1, :])
    input_ids_sd[0, current_position] = token
    current_position += 1

    if token.item() == eos_token_id:
        return input_ids_sd[0, :current_position].tolist()

    while current_position < max_output_length:
        # Reserve one slot for the token the target contributes every round. When only
        # that slot is left, corrected_gamma is 0 and the round is a plain target step.
        corrected_gamma = max(0, min(gamma, max_output_length - current_position - 1))

        # 1) Draft: propose corrected_gamma tokens autoregressively. Every token the
        #    draft cache has not seen yet is fed in: the whole prefix on the first call,
        #    two tokens after a fully accepted round (the last draft and the bonus
        #    token), and one token otherwise.
        for i in range(corrected_gamma):
            cached_length = 0 if draft_cache is None else draft_cache.get_seq_length()
            draft_output = draft(
                input_ids=input_ids_sd[..., cached_length : current_position + i].to(draft.device),
                past_key_values=draft_cache,
            )
            draft_cache = draft_output.past_key_values
            token = decoder(draft_output.logits[..., -1, :])
            input_ids_sd[0, current_position + i] = token.to(target.device)

        # 2) Verify: one target pass over the last verified token plus all drafts.
        #    The target cache holds exactly current_position - 1 tokens here, so
        #    output logit j predicts the token at current_position + j.
        target_output = target(
            input_ids=input_ids_sd[..., current_position - 1 : current_position + corrected_gamma],
            past_key_values=target_cache,
        )
        target_cache = target_output.past_key_values

        # 3) n = number of leading drafts that lie in the target's top-k.
        n = 0
        if corrected_gamma > 0:
            target_topk = top_k_decoder(target_output.logits[..., :corrected_gamma, :], k)
            drafted = input_ids_sd[0, current_position : current_position + corrected_gamma]
            in_topk = (target_topk[0].reshape(corrected_gamma, -1) == drafted.unsqueeze(-1)).any(dim=-1)
            rejected = torch.nonzero(~in_topk)
            n = rejected[0, 0].item() if rejected.numel() > 0 else corrected_gamma

        # Stop if any accepted draft token is EOS.
        stop_locations = torch.nonzero(torch.eq(input_ids_sd[0, current_position: current_position + n], eos_token_id))
        if stop_locations.shape[0] > 0:
            stop_location = stop_locations[0, 0].item()
            input_ids_sd[0, current_position + stop_location + 1:] = pad_token_id
            current_position += stop_location + 1
            break

        # 4) Roll both caches back so they hold exactly the verified tokens
        #    (positions < current_position + n). The rejected draft and everything
        #    after it must go, but the last accepted token must stay. Otherwise the
        #    next round runs on a context that is missing a token.
        verified_length = current_position + n
        if target_cache.get_seq_length() > verified_length:
            target_cache = prune_cache(target_cache, target_cache.get_seq_length() - verified_length)
        if draft_cache is not None and draft_cache.get_seq_length() > verified_length:
            draft_cache = prune_cache(draft_cache, draft_cache.get_seq_length() - verified_length)

        # 5) Take the target's own prediction at the first non-accepted position.
        token = decoder(target_output.logits[..., n, :])

        input_ids_sd[0, current_position + n: current_position + corrected_gamma] = pad_token_id
        input_ids_sd[0, current_position + n] = token
        current_position += n + 1

        if token.item() == eos_token_id:
            break

    return input_ids_sd[0, :current_position].tolist()

In [None]:
%%time 
gamma = 5
k = 10  # Accept drafty tokens only if they are in the top k of the target model
max_output_length = 200 + len(input_ids['input_ids'][0])

output_ids = speculative_generate(
    target,
    draft,
    tokenizer,
    greedy_decoder,
    input_ids['input_ids'],
    max_output_length,
    gamma,
    k
)

content = tokenizer.decode(output_ids, skip_special_tokens=True).strip('\n')

print(f"Speculative decoding output: \n\n{content}")