In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

import torch
import time
from transformers.utils import logging

# Set Transformers logging to error-only (no warnings or info)
logging.set_verbosity_error()

In [2]:
draft_model_id = "google/gemma-3-1b-pt"
target_model_id = "google/gemma-3-27b-pt"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

In [3]:
# Keep draft model on single GPU
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_id,
    device_map="cuda:0",
    torch_dtype="auto",
    # quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(draft_model_id)

# Split target model across GPUs
target_model = AutoModelForImageTextToText.from_pretrained(
    target_model_id,
    device_map="cpu",
    torch_dtype="auto",
    # quantization_config=bnb_config,
)

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [4]:
# Params
k = 4
prompt = "Once upon a time"
max_tokens = 16
use_cache = False

# Set pad token if needed
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Determine how many tokens are in original prompt
generated = tokenizer(prompt, return_tensors="pt")
N = generated["input_ids"].shape[1]

# Pad prompt to align shape with compiled model
generated = tokenizer(
    prompt,
    padding="max_length",
    max_length=max_tokens,
    return_tensors="pt",
)["input_ids"]

cached_prompt = None
successful_spec_decodes = 0
draft_model_process_time = []
target_model_process_time = []


draft_model.eval()
target_model.eval()
with torch.inference_mode():
    # Run until model is reached max number of tokens
    print("\n# Speculative Decoding Initiating #")
    while N < max_tokens:
        print(f"\nSequence length: {N}")
        print(f"Sequence text: '{tokenizer.decode(generated[0][-N+1:])}'")
        generated = generated.to(draft_model.device)

        # Step 1: Generate draft output sequence of length k
        spec_sequences = [{"input_ids": generated[0]}]
        spec_tokens = []
        start = time.perf_counter()
        for i in range(k):
            # Use generated sequence stored in spec_tokens
            draft_logits = draft_model(spec_sequences[i]["input_ids"][None, :]).logits
            # Greedy speculative decoding using argmax (for now..)
            spec_token = draft_logits[:, -1, :].argmax(dim=-1)
            spec_tokens.append(spec_token)
            # Concatenate most recent speculative token to previous sequence
            spec_sequence = torch.cat([spec_sequences[i]["input_ids"][1:], spec_token])
            spec_sequences.append({"input_ids": spec_sequence})
        spec_tokens = torch.cat(spec_tokens)
        end = time.perf_counter()

        # Track draft model speed
        draft_model_process_time.append(end - start)
        print(f"Draft Model: {k / (end - start):.3f} tok/s")

        # Step 2: Pad speculatize token sequences for use with target model
        batched_input = tokenizer.pad(spec_sequences, return_tensors="pt")

        # Step 3: Process batched input
        start = time.perf_counter()
        if cached_prompt is None and use_cache:
            # Cache prompt in target model
            target_output = target_model(
                **batched_input.to(target_model.device), use_cache=use_cache
            )
            target_logits = target_output.logits
            cached_prompt = target_output.past_key_values
        else:
            # Decode using cached prompt
            target_logits = target_model(
                **batched_input.to(target_model.device),
                past_key_values=cached_prompt,
                use_cache=use_cache,
            ).logits
        target_logits = target_model(**batched_input.to(target_model.device)).logits
        target_tokens = target_logits[:, -1, :].argmax(dim=-1)
        end = time.perf_counter()

        # Track target model speed
        target_model_process_time.append(end - start)
        print(f"Target Model: {1 / (end - start):.3f} tok/s")

        # Step 4: Evaluate speculated tokens
        matches = spec_tokens.to(target_model.device) == target_tokens[:-1]
        accepted_matches = matches.cumprod(dim=0).sum().item()

        if accepted_matches > 0:
            accepted_tokens = spec_tokens[:accepted_matches]
        else:
            # If no accepted tokens, use first token from target model
            accepted_tokens = target_tokens[:1].to(draft_model.device)

        # Step 5: Update generated sequence with accepted tokens
        generated = torch.cat([generated, accepted_tokens[None, :]], dim=-1)
        generated = generated[:, -max_tokens:]
        N += len(accepted_tokens)
        successful_spec_decodes += accepted_matches

        print(f"Accepted speculative tokens: {accepted_matches}", end=", ")
        if accepted_matches > 0:
            print(f"Speculative Text: '{tokenizer.decode(accepted_tokens)}'")
        print(f"SpecDecode: {len(accepted_tokens) / (end - start):.3f} tok/s")

# Print final sequence
print(f"\nSequence length: {N}")
print(f"Sequence text: {tokenizer.decode(generated[0][-N:])}")
print("\n# Speculative Decoding Complete #")
draft_model_perf = (k * len(draft_model_process_time)) / sum(draft_model_process_time)
print(f"\tDraft Model Perf: {draft_model_perf:.3f} tok/s")
target_model_perf = len(target_model_process_time) / sum(target_model_process_time)
print(f"\tTarget Model Perf: {target_model_perf:.3f} tok/s")
spec_decode_perf = N / sum(target_model_process_time)
print(f"\tSpec Decode Perf: {spec_decode_perf:.3f} tok/s")
print(f"\tSpec Decode Speedup: {spec_decode_perf / target_model_perf:.1f}")
acceptance_ratio = successful_spec_decodes / (k * max_tokens + 1e-9)
print(f"\tSpec Decode Acceptance Ratio: {acceptance_ratio:.3f}")


# Speculative Decoding Initiating #

Sequence length: 5
Sequence text: 'Once upon a time'
Draft Model: 14.011 tok/s
Target Model: 0.033 tok/s
Accepted speculative tokens: 4, Speculative Text: ', there was a'
SpecDecode: 0.132 tok/s

Sequence length: 9
Sequence text: 'Once upon a time, there was a'
Draft Model: 36.711 tok/s
Target Model: 0.033 tok/s
Accepted speculative tokens: 2, Speculative Text: ' little girl'
SpecDecode: 0.066 tok/s

Sequence length: 11
Sequence text: 'Once upon a time, there was a little girl'
Draft Model: 36.086 tok/s
Target Model: 0.033 tok/s
Accepted speculative tokens: 0, SpecDecode: 0.033 tok/s

Sequence length: 12
Sequence text: 'Once upon a time, there was a little girl who'
Draft Model: 36.847 tok/s
Target Model: 0.033 tok/s
Accepted speculative tokens: 0, SpecDecode: 0.033 tok/s

Sequence length: 13
Sequence text: 'Once upon a time, there was a little girl who loved'
Draft Model: 36.339 tok/s
Target Model: 0.033 tok/s
Accepted speculative tokens: 4, Specu