# Speculative decoding demo

In this quick demo we will implement and run speculative decoding and compare it with classic autoregressive generation

In [1]:
# Prepare environment
!pip install transformers>4.51.0 torch optimum-quanto

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoConfig
from transformers.cache_utils import DynamicCache
import time
import torch

We will use <code>Qwen/Qwen3-4B</code> as target model and <code>Qwen/Qwen3-1.7B</code> as draft model.

To make sure we can run inference on the 16GB T4 GPU Colab provides, we quantize the model parameters at runtime at FP8.

Note that we cannot use the pre-quantized versions of these models available on HuggingFace as they require compute capability > 8.9 (NVIDIA Hopper architecture).

Downloading models from HuggingFace will take around 3-4 minutes

In [3]:
# Using regular Qwen models with runtime quantization for low memory footprint
target_model_name = "Qwen/Qwen3-4B"
draft_model_name = "Qwen/Qwen3-1.7B"

target_quantize = QuantoConfig(weights="float8")
draft_quantize = QuantoConfig(weights="float8")

# Initialize target and draft models with FP8 quantization
target = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    quantization_config=target_quantize,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
draft = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    quantization_config=draft_quantize,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

# Move models to device
device = "cuda" if torch.cuda.is_available() else "cpu"
target.to(device)
draft.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/622M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): QLinear(in_features=2048, out_features=2048, bias=False)
          (k_proj): QLinear(in_features=2048, out_features=1024, bias=False)
          (v_proj): QLinear(in_features=2048, out_features=1024, bias=False)
          (o_proj): QLinear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): QLinear(in_features=2048, out_features=6144, bias=False)
          (up_proj): QLinear(in_features=2048, out_features=6144, bias=False)
          (down_proj): QLinear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attentio

Here we prepare the input prompt

In [7]:
# Prepare input
prompt = "What is the result of 2 * pi?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    enable_thinking=False
)
input_ids = tokenizer([text], return_tensors="pt").to(target.device)  # Tokenized input for models

print(f"Prepared prompt: \n\n{text}")

Prepared prompt: 

<|im_start|>user
What is the result of 2 * pi?<|im_end|>
<|im_start|>assistant
<think>

</think>




### Autoregressive generation with Target model

In [8]:
max_output_length = 256  # Includes prompt as well
output_ids = target.generate(**input_ids, max_new_tokens=max_output_length)

content = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True).strip('\n')

print(f"Target model Autoregressive generation output: \n\n{content}")

Target model Autoregressive generation output: 

user
What is the result of 2 * pi?
assistant
<think>

</think>

The result of $ 2 \times \pi $ is approximately:

$$
2\pi \approx 6.283185307
$$

So, $ 2\pi $ is about **6.283** when rounded to three decimal places.


### Speculative decoding generation

We define helper functions

In [9]:
def greedy_decoder(logits: torch.Tensor) -> torch.Tensor:
    return torch.argmax(logits, dim=-1).unsqueeze(-1)

def top_k_decoder(logits: torch.Tensor, _k: int) -> torch.Tensor:
    return torch.topk(logits, _k, dim=-1).indices.unsqueeze(-1)

def prune_cache(cache: DynamicCache, num_tokens_to_discard: int) -> DynamicCache:
    if cache is None:
        return None

    current_length = cache.get_seq_length()
    new_length = current_length - num_tokens_to_discard

    cache.crop(new_length)

    return cache

Speculative decoding implementation

In [10]:
@torch.no_grad()
def speculative_generate(
    target,
    draft,
    tokenizer,
    decoder,
    input_ids,
    max_output_length,
    gamma,
    k
):
    current_position = len(input_ids[0])
    input_ids_sd = torch.full((1, max_output_length), tokenizer.pad_token_id, dtype=torch.long, device=target.device)
    input_ids_sd[0, :current_position] = input_ids[0].detach().clone().to(target.device)

    target_cache, draft_cache = None, None

    # Prefill target model to get a first token and cache
    target_output = target(
        input_ids=input_ids_sd[..., :current_position],
        past_key_values=target_cache,
        use_cache=False
    )
    target_cache = target_output.past_key_values  # Set KV Cache for target model
    token = decoder(target_output.logits[..., -1, :])
    input_ids_sd[0, current_position] = token
    current_position += 1

    if token == tokenizer.eos_token_id:
        return input_ids_sd[0, :current_position].tolist()

    while current_position < max_output_length:
        corrected_gamma = min(gamma, max_output_length - current_position - 1)

        # Generate gamma drafty tokens
        for k in range(corrected_gamma):
            draft_output = draft(
                input_ids=input_ids_sd[..., :current_position + k],
                past_key_values=draft_cache,
                use_cache=False
            )
            draft_cache = draft_output.past_key_values
            token = decoder(draft_output.logits[..., -1, :])
            input_ids_sd[0, current_position + k] = token

        # Validate drafty tokens in parallel
        target_output = target(
            input_ids=input_ids_sd[..., :current_position + corrected_gamma],
            past_key_values=target_cache,
            use_cache=False
        )
        target_cache = target_output.past_key_values
        target_logits = target_output.logits[..., current_position - 1: current_position + corrected_gamma - 1, :]
        target_topk = top_k_decoder(target_logits, k)

        # Compute the last accepted draft position
        n = corrected_gamma
        for i in range(corrected_gamma):
            if input_ids_sd[0, current_position + i] not in target_topk[0, i, :]:
                n = i
                break

        # Check if any of the accepted tokens is the eos token
        stop_locations = torch.nonzero(torch.eq(input_ids_sd[0, current_position: current_position + n], tokenizer.eos_token_id))
        if stop_locations.shape[0] > 0:
            stop_location = stop_locations[0, 0].item()
            input_ids_sd[0, current_position + stop_location + 1:] = tokenizer.pad_token_id
            current_position += stop_location + 1
            break

        if n != corrected_gamma:
            # Need to adjust the cache of both models
            draft_cache = prune_cache(draft_cache, corrected_gamma - n)
            target_cache = prune_cache(target_cache, corrected_gamma - n + 1)

        # Next token is sampled from the target model's distribution
        target_logits = target_output.logits[..., current_position + n - 1, :]
        token = decoder(target_logits)

        input_ids_sd[0, current_position + n: current_position + corrected_gamma] = tokenizer.pad_token_id
        input_ids_sd[0, current_position + n] = token
        current_position += n + 1

        if token == tokenizer.eos_token_id:
            break

    return input_ids_sd[0, :current_position].tolist()

In [12]:
gamma = 5
k = 10  # Accept drafty tokens only if they are in the top k of the target model
max_output_length = 256

output_ids = speculative_generate(
    target,
    draft,
    tokenizer,
    greedy_decoder,
    input_ids['input_ids'],
    max_output_length,
    gamma,
    k
)

content = tokenizer.decode(output_ids, skip_special_tokens=True).strip('\n')

print(f"Speculative decoding output: \n\n{content}")

Speculative decoding output: 

user
What is the result of 2 * pi?
assistant
<think>

</think>

The result of $ 2 \times \pi $ is approximately:

$$
2 \pi \approx 6.2832
$$

This is a well-known value in mathematics, often used in calculations involving circles, waves, and other circular or periodic phenomena.
