In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

In [2]:
# Checkpoints used in this notebook: a 0.6B draft model and a 14B target model.
draft_model_id = "Qwen/Qwen3-0.6B"
target_model_id = "Qwen/Qwen3-14B"

# 4-bit quantization settings handed to `from_pretrained` via `quantization_config`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",         # NF4 quantization type
    bnb_4bit_use_double_quant=True,    # double quantization enabled
    bnb_4bit_compute_dtype="float16",  # compute dtype, given by name
)

In [3]:
# Loading options that are identical for both checkpoints.
_shared_load_kwargs = {
    "torch_dtype": "auto",
    "quantization_config": bnb_config,
}

# The tokenizer is loaded from the draft checkpoint; the cells below pass this
# same tokenizer to both models.
tokenizer = AutoTokenizer.from_pretrained(draft_model_id)

# Draft model: pinned to a single GPU (cuda:0).
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_id,
    device_map="cuda:0",
    **_shared_load_kwargs,
)

# Target model: device_map="auto" lets the loader spread it across the available GPUs.
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_id,
    device_map="auto",
    **_shared_load_kwargs,
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [4]:
def prepare_prompt(prompt, tokenizer, max_prompt_len):
    """Wrap ``prompt`` in a single user turn and run it through the chat template.

    Args:
        prompt: Content of the user message; it is interpolated into a string.
        tokenizer: Object exposing ``apply_chat_template``.
        max_prompt_len: Passed as ``max_length`` together with ``truncation=True``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the options below
        (``tokenize=True``, ``return_dict=True``, ``return_tensors="pt"``).
    """
    user_turn = {"role": "user", "content": f"{prompt}"}

    # NOTE(review): with truncation enabled, an over-long prompt is presumably cut
    # on the tokenizer's truncation side, which could drop the generation-prompt
    # suffix added by add_generation_prompt=True. Confirm for long inputs.
    template_options = dict(
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        truncation=True,
        max_length=max_prompt_len,
        enable_thinking=False,  # forwarded to the chat template as an extra kwarg
        return_tensors="pt",
    )
    return tokenizer.apply_chat_template([user_turn], **template_options)

In [5]:
def generate(model, tokenizer, prompt, max_prompt_len, max_new_tokens, skip_special_tokens=True):
    """Generate a chat completion for ``prompt`` and return only the new text.

    Args:
        model: Object exposing ``.device`` and ``.generate(**inputs, max_new_tokens=...)``.
        tokenizer: Passed to ``prepare_prompt`` and used to decode the output ids.
        prompt: User message text.
        max_prompt_len: Forwarded to ``prepare_prompt`` as the prompt length limit.
        max_new_tokens: Forwarded to ``model.generate``.
        skip_special_tokens: Forwarded to ``tokenizer.decode``. Defaults to True so
            that special tokens (e.g. an end-of-turn marker produced when the model
            stops before ``max_new_tokens``) do not leak into the returned text.
            Pass False to get the raw decoding.

    Returns:
        str: The decoded tokens that follow the prompt in the first output sequence.
    """
    with torch.inference_mode():
        inputs = prepare_prompt(prompt, tokenizer, max_prompt_len)
        # Record the prompt length up front; it is used to slice the prompt off
        # the output (assumes the output sequence starts with the prompt tokens).
        prompt_len = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs.to(model.device),
            max_new_tokens=max_new_tokens,
        )
    completion_ids = outputs[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=skip_special_tokens)

In [6]:
# Inputs shared by the generation cells below.
prompt = "Concisely, what is speculative decoding in relation to LLM performance?"
max_prompt_len, max_new_tokens = 512, 64  # prompt length limit / generated-token limit

In [7]:
# Draft model: answer the shared prompt and show the decoded completion.
result = generate(
    model=draft_model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
print(result)

Speculative decoding is a technique used in language models to enhance their performance by making predictions about future text or data based on the current context or previous information. This approach allows models to generate more accurate and contextually relevant responses, improving their ability to understand and respond to complex or ambiguous situations. It is often applied in tasks


In [8]:
# Target model: same prompt and limits as the draft run, for side-by-side comparison.
result = generate(
    model=target_model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
print(result)

Speculative decoding is a technique used to accelerate large language model (LLM) inference by generating text in parallel using a faster, smaller model (the "speculator") while the main model processes the input sequentially. This improves latency without sacrificing accuracy, as the main model can later verify and correct the speculator's output
