# Efficient Large Language Model Inference with SpecPC
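This notebook demonstrates how to patch a pre-trained LLM with SpecPC, using a smaller draft model alongside the main model. We then run the patched model on a synthetic needle-in-a-haystack retrieval prompt: a long document of repeated distractor sentences with a single hidden key.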

In [None]:
import os
import random
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# (Optional) Expose only GPU 0 to this process (if one is available).
# NOTE(review): this runs after `import torch`; it presumably still takes effect
# because CUDA is not initialized at import time — confirm, or move it above the import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed Python's global `random` module for reproducibility.
# create_niah_example shuffles with its own random.Random(42) instance, so its
# output does not depend on this seed.
random.seed(42)

from draft_approx_llm import SpecPCConfig, patch_model

## Helper Function for Prompt Generation

In [None]:
def create_niah_example(n_repeat: int = 100, key: str = "123456789") -> str:
    """Build a synthetic retrieval prompt with one hidden key among distractors.

    Each of five fixed distractor sentences is repeated ``n_repeat`` times, a
    single ``"The key: <key>"`` line is added, and the lines are shuffled with
    a fixed seed so the result is identical on every call. The lines are joined
    with newlines and followed by the question ``"What is the key?"``.

    Args:
        n_repeat: Number of copies of each distractor sentence.
        key: Value embedded in the hidden "The key: ..." line.

    Returns:
        The full prompt text.
    """
    distractors = (
        "The cat sat on the mat.",
        "The quick brown fox jumps over the lazy dog.",
        "A journey of a thousand miles begins with a single step.",
        "To be or not to be, that is the question.",
        "All that glitters is not gold.",
    )

    # Distractors grouped by sentence, with the key line last; this pre-shuffle
    # order determines the shuffled result, so it must stay fixed.
    lines = [text for text in distractors for _ in range(n_repeat)]
    lines.append(f"The key: {key}")

    # A private, fixed-seed generator keeps the shuffle independent of the
    # global random state.
    rng = random.Random(42)
    rng.shuffle(lines)

    document = "\n".join(lines)
    return document + "\n\nWhat is the key?"

## Model and Tokenizer Loading

In [None]:
# Checkpoint names for the main model and the smaller draft model; adjust as needed.
MAIN_MODEL_NAME = "Qwen/Qwen2.5-14B-Instruct"
DRAFT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

# Loading options shared by both models.
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# The tokenizer is taken from the main checkpoint. The main model is loaded
# first, then the draft model, both with the same options.
tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_NAME)
model, draft_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)
    for checkpoint in (MAIN_MODEL_NAME, DRAFT_MODEL_NAME)
)

## SpecPC Configuration and Patching

In [None]:
# SpecPC hyperparameters; tune these to trade performance against quality.
specpc_settings = dict(
    max_capacity_prompt=1024,
    window_size=64,
    pool_type="max",
    kernel_size=64,
    reduction_type="max",
    lookahead_tokens=1,
    neighbor_tokens=64,
    starting_layer_index=8,
    weighted_query=True,
)
sparse_config = SpecPCConfig(**specpc_settings)

# Patch the main model for SpecPC, passing the draft model and the config above.
patched_model = patch_model(model, draft_model, sparse_config)

## Prepare Prompt and Tokenized Input

In [None]:
# Build the synthetic retrieval prompt with the default settings.
sample_input = create_niah_example()

# Tokenize to PyTorch tensors, then move them to the patched model's device.
inputs = tokenizer(sample_input, return_tensors="pt")
inputs = inputs.to(patched_model.device)

# Report the prompt length in tokens (second dimension of input_ids).
prompt_token_count = inputs.input_ids.shape[1]
print(f"Input token count: {prompt_token_count}")

## Generate Output

In [None]:
# Greedy decoding of at most 32 new tokens, returned as a structured output object.
# NOTE(review): temperature/top_p/top_k are set to None presumably to override
# sampling defaults from the model's generation config while do_sample=False — confirm.
generation_options = dict(
    max_new_tokens=32,
    do_sample=False,
    temperature=None,
    top_p=None,
    top_k=None,
    return_dict_in_generate=True,
)

# Run generation under inference mode, so autograd is disabled.
with torch.inference_mode():
    gen_outputs = patched_model.generate(**inputs, **generation_options)

## Decode and Display Model Output

In [None]:
# With return_dict_in_generate=True, generate() returns an output object whose
# token ids are stored in `.sequences` (there is no `.output_ids` attribute).
# Decode the first (and only) sequence in the batch, dropping special tokens.
# NOTE(review): for decoder-only models `.sequences` is documented to include the
# prompt tokens followed by the newly generated ones, so the printed text
# presumably begins with the prompt — confirm.
decoded_output = tokenizer.decode(gen_outputs.sequences[0], skip_special_tokens=True).strip()
print(decoded_output)