In [None]:
!pip install transformers torch accelerate

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


# Choose your models - the draft should be much smaller than the target, and both
# must share the same tokenizer/vocabulary for speculative decoding.
# Gemma-7B-it is the large target model (gated on Hugging Face: log in first).
target_model_name = "google/gemma-7b-it"
# Gemma-2B-it is the small draft model from the same model family.
draft_model_name = "google/gemma-2b-it"

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# fp16 speeds up GPU inference; on CPU half precision is typically slow (and some
# PyTorch versions lack fp16 CPU kernels), so fall back to fp32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load tokenizer (must be the same for both models)
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

# Load target model (the large, high-quality model)
print("Loading target model...")
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    dtype=model_dtype,
    device_map="auto"
)

# Load draft model (the small, fast model)
print("Loading draft model...")
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    dtype=model_dtype,
    device_map="auto"
)

# Speculative decoding has the target verify the draft's token ids directly,
# so both models must use the same vocabulary.
assert target_model.config.vocab_size == draft_model.config.vocab_size, (
    "Draft and target models must share a vocabulary for speculative decoding"
)

print("Models loaded successfully!")

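To access gated models like Gemma, you need to log in to Hugging Face **before** running the model-loading cell above; otherwise the download will fail. You must also accept the model's license terms on its Hugging Face model page with the same account.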

**Instructions:**

1.  **Get a Hugging Face API Token:** Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) and create a new access token (ensure it has at least 'read' permissions).
2.  **Log in:**
    *   **Option 1 (Recommended in Colab):** Run the following code in a new cell and paste your token when prompted:

        ```python
        from huggingface_hub import login
        login()
        ```

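    *   **Option 2 (Environment Variable):** Set the `HF_TOKEN` environment variable before running any code that accesses Hugging Face. For example:

        ```python
        import os
        os.environ["HF_TOKEN"] = "hf_YOUR_TOKEN_HERE"
        ```

        **Warning:** never commit or share a notebook that contains your real token. Prefer setting `HF_TOKEN` in your shell before starting Jupyter, or storing it in Colab's Secrets panel.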

In [None]:
# A single completion-style prompt, shared by the baseline and speculative runs
prompt = "Quantum entanglement is a phenomenon where"

# Encode as PyTorch tensors and move them onto the selected `device`
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# print() joins its arguments with a single space, giving the same text as an f-string
print("Input prompt:", prompt)
print("Input token count:", inputs.input_ids.shape[-1])

In [None]:
import time

# Standard generation (no speculation): the target model alone.
# Greedy decoding is requested explicitly so the speculative run can be compared
# like-for-like (and should reproduce this text).
print("\n--- Standard Generation (Baseline) ---")

# Warm-up: the first generate() call pays one-off costs (e.g. GPU kernel loading,
# memory allocation) that would otherwise inflate the baseline time and the speedup.
target_model.generate(
    **inputs,
    max_new_tokens=5,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id
)

if torch.cuda.is_available():
    torch.cuda.synchronize()  # don't let queued GPU work leak into the timing
start_time = time.perf_counter()

baseline_output = target_model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,  # Greedy decoding
    pad_token_id=tokenizer.eos_token_id
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
baseline_time = time.perf_counter() - start_time
baseline_text = tokenizer.decode(baseline_output[0], skip_special_tokens=True)

# Generation can stop at EOS before max_new_tokens, so count what was actually produced.
baseline_new_tokens = baseline_output.shape[1] - inputs["input_ids"].shape[1]

print(f"Generated text:\n{baseline_text}\n")
print(f"New tokens generated: {baseline_new_tokens}")
print(f"Time taken: {baseline_time:.2f} seconds")
print(f"Tokens per second: {baseline_new_tokens / baseline_time:.2f}")

In [None]:
import time
import warnings  # used below to silence generate() warnings


# Speculative decoding - just add the assistant_model parameter!
# The draft model proposes a few tokens per step and the target model verifies them.
# Greedy decoding is used, as in the baseline, so the timings are comparable and the
# output should reproduce the baseline text.
print("\n--- Speculative Decoding ---")

# Draft tokens proposed per verification step. Assisted generation reads this value
# from the assistant model's generation_config. The "constant" schedule keeps it from
# being adapted between calls (e.g. by the warm-up run below).
draft_model.generation_config.num_assistant_tokens = 10
draft_model.generation_config.num_assistant_tokens_schedule = "constant"

speculative_kwargs = dict(
    max_new_tokens=50,
    do_sample=False,  # greedy, same decoding strategy as the baseline
    pad_token_id=tokenizer.eos_token_id,
    assistant_model=draft_model,  # This enables speculative decoding!
    # NOTE(review): still passed per call as before, in case the installed
    # transformers version reads it from the call -- confirm.
    num_assistant_tokens=10,
)
warmup_kwargs = {**speculative_kwargs, "max_new_tokens": 5}

with warnings.catch_warnings():
    # NOTE: this silences *all* warnings raised inside the block.
    warnings.simplefilter("ignore")

    # Warm-up so the draft model's one-off first-call costs are not timed.
    target_model.generate(**inputs, **warmup_kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't let queued GPU work leak into the timing
    start_time = time.perf_counter()
    speculative_output = target_model.generate(**inputs, **speculative_kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    speculative_time = time.perf_counter() - start_time

speculative_text = tokenizer.decode(speculative_output[0], skip_special_tokens=True)
# Generation can stop at EOS before max_new_tokens, so count what was actually produced.
speculative_new_tokens = speculative_output.shape[1] - inputs["input_ids"].shape[1]

print(f"Generated text:\n{speculative_text}\n")
print(f"New tokens generated: {speculative_new_tokens}")
print(f"Time taken: {speculative_time:.2f} seconds")
print(f"Tokens per second: {speculative_new_tokens / speculative_time:.2f}")
# With greedy decoding on both sides the texts should match. Small numerical
# differences (e.g. in fp16) can occasionally make them diverge.
print(f"Matches baseline output: {speculative_text == baseline_text}")

# Calculate speedup; report honestly when speculation did not help.
speedup = baseline_time / speculative_time
if speedup >= 1:
    print(f"\nSpeedup: {speedup:.2f}x faster!")
else:
    print(f"\nSpeedup: {speedup:.2f}x (slower than the baseline)")