# GPT-2 From Scratch — Inference

Load a trained GPT-2 checkpoint and generate text with different sampling strategies.

## 1. Setup

In [None]:
# Install dependencies (uncomment on Colab).
# The %pip magic installs into the environment of the running kernel,
# which a `!pip` shell call is not guaranteed to do.
# %pip install torch tiktoken datasets numpy

In [None]:
import torch
from config import GPT2Config, get_device
from model import GPT2
from generate import generate

# Fix the RNG seed so that Restart Kernel -> Run All reproduces the same samples.
# NOTE(review): this assumes `generate` draws from torch's global RNG -- confirm
# against generate.py. Change SEED (or re-run a generation cell) for new samples.
SEED = 42
torch.manual_seed(SEED)

# Pick the compute device via the project's helper and report the choice.
device = get_device()
print(f"Device: {device}")

## 2. Load Checkpoint

In [None]:
checkpoint_path = "checkpoints/best.pt"

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only load checkpoints
# you created yourself or otherwise trust.
# Tensors are mapped to CPU on load; the model is moved to `device` below.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Rebuild the model from the configuration stored alongside the weights.
config = GPT2Config(**checkpoint["model_config"])
model = GPT2(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.eval()  # evaluation mode for inference

print(f"Loaded from step {checkpoint.get('step', '?')}")

# Only apply the float format when a validation loss was actually saved:
# formatting the '?' placeholder with ':.4f' raises ValueError.
val_loss = checkpoint.get("val_loss")
val_loss_text = "?" if val_loss is None else f"{val_loss:.4f}"
print(f"Validation loss: {val_loss_text}")

## 3. Generate Text

Sampling parameters:
- **temperature**: <1 = focused, >1 = creative (default 0.8)
- **top_k**: keep top k tokens (default 50, 0 = disabled)
- **top_p**: nucleus sampling threshold (default 0.95, 1.0 = disabled)

In [None]:
prompt = "The meaning of life is"

# Sampling options for this run, gathered in one place so they are easy to tweak.
sampling_kwargs = {
    "max_new_tokens": 200,
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.95,
}

output = generate(model, prompt=prompt, device=device, **sampling_kwargs)

print(output)

## 4. Compare Sampling Strategies

See how different temperature, top-k, and top-p settings affect output quality.

In [None]:
prompt = "In a shocking finding, scientists discovered"

# One entry per sampling preset; "label" is only used for the printed banner,
# the remaining keys are forwarded to generate().
settings = [
    dict(temperature=0.3, top_k=10, top_p=1.0, label="Conservative (T=0.3, k=10)"),
    dict(temperature=0.8, top_k=50, top_p=0.95, label="Balanced (T=0.8, k=50, p=0.95)"),
    dict(temperature=1.2, top_k=100, top_p=0.95, label="Creative (T=1.2, k=100, p=0.95)"),
]

for s in settings:
    # Everything except the display label is a sampling argument.
    sampling_kwargs = {key: value for key, value in s.items() if key != "label"}

    print(f"--- {s['label']} ---")
    output = generate(
        model,
        prompt=prompt,
        max_new_tokens=100,
        device=device,
        **sampling_kwargs,
    )
    # Generated text followed by one blank line between presets.
    print(output, end="\n\n")

## 5. Interactive Generation

Edit the prompt in the cell below and re-run it to generate from your own text.

In [None]:
prompt = "Once upon a time"  # <-- edit this

# Sampling options for this run; adjust alongside the prompt if desired.
sampling_kwargs = {
    "max_new_tokens": 200,
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.95,
}

output = generate(model, prompt=prompt, device=device, **sampling_kwargs)

print(output)