In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import time

In [2]:
# Checkpoint ids: a small draft model and a larger target model.
draft_model_id = "Qwen/Qwen3-0.6B"
target_model_id = "Qwen/Qwen3-14B"

# Quantization settings passed to both model loaders:
# 4-bit loading, "nf4" quant type, double quantization enabled,
# and "float16" as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

In [3]:
# Tokenizer is loaded from the draft checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(draft_model_id)

# Draft model: pinned to a single GPU (cuda:0).
# NOTE(review): the draft model is loaded with the same 4-bit quantization
# config as the target; presumably a 0.6B model would fit unquantized and
# might run faster that way -- confirm before relying on the draft timings.
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_id,
    quantization_config=bnb_config,
    torch_dtype="auto",
    device_map="cuda:0",
)

# Target model: device_map="auto" so it can be split across the available GPUs.
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_id,
    quantization_config=bnb_config,
    torch_dtype="auto",
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [4]:
def prepare_prompt(prompt, tokenizer, max_prompt_len):
    """Wrap `prompt` in a single user turn and run it through the chat template.

    Args:
        prompt: Text of the user message (formatted into a string).
        tokenizer: Object exposing `apply_chat_template`.
        max_prompt_len: Passed as `max_length` together with `truncation=True`.

    Returns:
        Whatever `tokenizer.apply_chat_template` returns when asked for a
        tokenized dict of PyTorch tensors (`tokenize=True`, `return_dict=True`,
        `return_tensors="pt"`), with the generation prompt appended and
        `enable_thinking=False`.
    """
    user_turn = {"role": "user", "content": f"{prompt}"}

    # NOTE(review): truncation is applied to the templated sequence; presumably
    # an over-long prompt could lose the trailing generation prompt -- confirm.
    template_options = dict(
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=False,
        truncation=True,
        max_length=max_prompt_len,
    )
    return tokenizer.apply_chat_template([user_turn], **template_options)

In [13]:
def generate(model, tokenizer, prompt, max_prompt_len, max_new_tokens, **generate_kwargs):
    """Generate a completion for `prompt` and return only the new token ids.

    Args:
        model: Object exposing `.device` and `.generate(...)`.
        tokenizer: Tokenizer handed to `prepare_prompt` to build the inputs.
        prompt: User prompt text.
        max_prompt_len: Truncation length forwarded to `prepare_prompt`.
        max_new_tokens: Forwarded to `model.generate` as `max_new_tokens`.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            `model.generate` (for example `assistant_model=...` or
            `do_sample=False`). Defaults to none, which matches the previous
            behavior of this function.

    Returns:
        `outputs[0]` with the first `input_ids.shape[-1]` positions sliced
        off, i.e. the tokens following the prompt, assuming `model.generate`
        returns the prompt followed by the newly generated tokens.
    """
    with torch.inference_mode():
        inputs = prepare_prompt(prompt, tokenizer, max_prompt_len)
        outputs = model.generate(
            **inputs.to(model.device),
            max_new_tokens=max_new_tokens,
            **generate_kwargs,
        )
    # Drop the prompt positions so callers only see the continuation.
    prompt_token_count = inputs["input_ids"].shape[-1]
    return outputs[0][prompt_token_count:]

In [14]:
# Benchmark settings: the prompt text, the prompt truncation length,
# and the cap on newly generated tokens.
prompt = "Concisely, what is speculative decoding in relation to LLM performance?"
max_prompt_len, max_new_tokens = 512, 64

In [None]:
# Draft model: wall-clock a single generate() call, then report
# returned tokens per second and the decoded text.
start = time.perf_counter()
output = generate(
    model=draft_model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
end = time.perf_counter()

tokens_per_sec = len(output) / (end - start)
print(f"{tokens_per_sec:.1f} tok/sec")
print(tokenizer.decode(output))

28.666312 tok/sec
Speculative decoding refers to a technique used in language models (LLMs) to enhance their performance by making assumptions about the next word or context based on prior information. This approach allows the model to make predictions more accurately, especially when the context is incomplete or uncertain, thereby improving the fluency and coherence of the generated text


In [None]:
# Target model: wall-clock a single generate() call, then report
# returned tokens per second and the decoded text.
start = time.perf_counter()
output = generate(
    model=target_model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
end = time.perf_counter()

tokens_per_sec = len(output) / (end - start)
print(f"{tokens_per_sec:.1f} tok/sec")
print(tokenizer.decode(output))

15.615868 tok/sec
Speculative decoding is a technique used to accelerate the inference of large language models (LLMs) by using a smaller, faster model to generate text speculatively in parallel with the main model, improving throughput without sacrificing accuracy.<|im_end|>
