# Fine-Tune GPT-2 Medium on FineWeb-Edu

This notebook fine-tunes the **GPT-2 Medium (355M params)** pre-trained model on high-quality educational content from FineWeb-Edu.

**Hardware**: RTX 4060 (8GB VRAM) — uses gradient checkpointing + Adafactor + fp16 to fit.

In [None]:
import torch
from datasets import load_dataset, Dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
import os
from tqdm import tqdm

In [None]:
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if device == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (in bytes). The previous
    # `total_mem` attribute does not exist and raised AttributeError whenever
    # a GPU was present.
    print(f"VRAM: {gpu_props.total_memory / 1024**3:.1f} GB")

## 1. Load Pre-trained GPT-2 Medium Model & Tokenizer

Instead of training from scratch, we start with the pre-trained GPT-2 Medium (355M params) and fine-tune it.
This gives us a strong language foundation — the model already understands English, grammar, and general knowledge.

In [None]:
# Identifier of the pre-trained checkpoint that gets fine-tuned.
MODEL_NAME = "gpt2-medium"

print(f"Loading pre-trained model: {MODEL_NAME}")

# Reuse the end-of-sequence token as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
# Gradient checkpointing is switched on so training fits in 8GB VRAM.
model.gradient_checkpointing_enable()

# Summarize what was loaded: parameter count, vocabulary size, context window.
num_params = sum(weight.numel() for weight in model.parameters())
print(
    f"Model loaded: {num_params / 1e6:.1f}M parameters",
    f"Vocab size: {tokenizer.vocab_size}",
    f"Max context length: {model.config.n_positions}",
    sep="\n",
)

## 2. Load & Materialize FineWeb-Edu Dataset

We use 500K samples from the FineWeb-Edu dataset — high-quality educational web content curated by Hugging Face.

In [None]:
# Number of documents to pull from the stream.
NUM_SAMPLES = 500_000

print(f"Loading {NUM_SAMPLES:,} samples from FineWeb-Edu (streaming)...")
dataset = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)

# Drain the first NUM_SAMPLES rows of the stream into an in-memory list;
# later cells index and slice this list directly.
subset_iter = dataset.take(NUM_SAMPLES)
data_list = list(
    tqdm(subset_iter, total=NUM_SAMPLES, desc="Materializing dataset")
)

print(f"Total samples: {len(data_list):,}")
print(f"Sample keys: {data_list[0].keys()}")

## 3. Tokenize the Dataset

We tokenize with a context length of 1024 tokens (GPT-2 Medium's full capacity).

In [None]:
# Upper bound on tokens per example (the 1024-token context length used here).
MAX_LENGTH = 1024


def tokenize_function(examples):
    """Tokenize a batch of documents, truncating each to MAX_LENGTH tokens.

    Args:
        examples: Batch mapping handed over by ``Dataset.map(batched=True)``;
            only its ``"text"`` column is read.

    Returns:
        The tokenizer's batch encoding. Sequences are deliberately left
        unpadded: the language-modeling data collator pads each batch when it
        is assembled, so short documents are not stored on disk, nor pushed
        through the model, as full MAX_LENGTH rows that are mostly padding.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
    )


# Convert to HuggingFace Dataset and tokenize in batches
print("Converting to HuggingFace Dataset...")
hf_dataset = Dataset.from_list(data_list)

print("Tokenizing...")
tokenized_dataset = hf_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    # Drop the raw columns so only the tokenizer outputs remain.
    remove_columns=hf_dataset.column_names,
    desc="Tokenizing",
)

print(f"Tokenized dataset: {len(tokenized_dataset):,} samples")
# Sequences are variable-length now, so MAX_LENGTH is a cap, not a fixed size.
print(f"Max token sequence length: {MAX_LENGTH}")

## 4. Configure Training

Optimized for RTX 4060 8GB VRAM:
- **Batch size 1** + gradient accumulation 16 = effective batch of 16
- **Gradient checkpointing** trades compute for memory (~40% VRAM savings)
- **Adafactor optimizer** uses less memory than AdamW
- **FP16** mixed precision for speed + memory savings

In [None]:
output_dir = "out/models/gpt2_medium_finetuned"
os.makedirs(output_dir, exist_ok=True)

# The collator pads each batch and copies input_ids into labels. The
# one-position shift for next-token prediction is applied inside the model's
# loss computation, not by the collator.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # causal LM, not masked LM
)

# fp16 mixed precision needs a GPU. The notebook falls back to CPU when CUDA
# is missing, and TrainingArguments rejects fp16=True there, so only request
# it when CUDA is actually available.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir=output_dir,

    # Batch size — optimized for 8GB VRAM
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # effective batch size = 16

    # Learning rate — lower for fine-tuning (not training from scratch)
    learning_rate=2e-5,
    warmup_steps=500,
    weight_decay=0.01,

    # Training duration
    num_train_epochs=1,

    # Memory optimization
    fp16=use_fp16,
    gradient_checkpointing=True,
    optim="adafactor",

    # Logging & saving
    logging_steps=100,
    save_steps=2000,
    save_total_limit=3,
    report_to="none",

    # Dataloader
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Samples consumed per optimizer step (micro-batch x accumulation steps).
effective_batch_size = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
)

print("Training config:")
print(f"  Effective batch size: {effective_batch_size}")
print(f"  Total steps: {len(tokenized_dataset) // effective_batch_size}")
print(f"  Optimizer: {training_args.optim}")
print(f"  FP16: {training_args.fp16}")
print(f"  Gradient checkpointing: {training_args.gradient_checkpointing}")

## 5. Train

In [None]:
# Announce the run, then hand control to the Trainer.
print(
    "Starting fine-tuning...",
    "This will take several hours on RTX 4060. Progress is logged every 100 steps.",
    sep="\n",
)
trainer.train()

## 6. Save Model & Tokenizer

In [None]:
print("Saving fine-tuned model and tokenizer...")

# Model weights and tokenizer files are written to the same directory.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print(
    f"Model saved to: {output_dir}",
    "Fine-tuning complete!",
    sep="\n",
)

## 7. Build RAG Knowledge Base

Index a portion of the dataset into a FAISS vector store so the chatbot can retrieve relevant passages in real-time.
We use `sentence-transformers/all-MiniLM-L6-v2` for fast, lightweight embeddings.

In [None]:
import numpy as np

# Probe for the optional RAG dependencies. HAS_RAG_DEPS records whether both
# imports succeeded so the index-building cell can skip itself cleanly.
try:
    import faiss
    from sentence_transformers import SentenceTransformer
except ImportError:
    HAS_RAG_DEPS = False
    print(
        "RAG dependencies not installed. Run:",
        "  pip install faiss-cpu sentence-transformers",
        "Then re-run this cell.",
        sep="\n",
    )
else:
    HAS_RAG_DEPS = True

In [None]:
# Build and persist the retrieval index, or explain how to do it later when
# the optional dependencies are missing.
if not HAS_RAG_DEPS:
    print("Skipping RAG index build (dependencies not installed).")
    print("You can build it later by running: python build_rag_index.py")
else:
    RAG_SAMPLES = 100_000  # number of source documents to cut into passages
    MAX_PASSAGE_LEN = 500  # characters per passage
    RAG_DIR = "out/rag_index"
    os.makedirs(RAG_DIR, exist_ok=True)

    print(f"Building RAG index from {RAG_SAMPLES:,} samples...")

    # Slice every document into fixed-width character windows and keep only
    # the windows that are still longer than 50 characters once trimmed.
    passages = []
    for row in tqdm(data_list[:RAG_SAMPLES], desc="Extracting passages"):
        document = row["text"].strip()
        windows = (
            document[start:start + MAX_PASSAGE_LEN].strip()
            for start in range(0, len(document), MAX_PASSAGE_LEN)
        )
        passages.extend(window for window in windows if len(window) > 50)

    print(f"Total passages: {len(passages):,}")

    # Turn each passage into an embedding vector.
    print("Loading sentence-transformer model...")
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    print("Encoding passages (this may take a few minutes)...")
    embeddings = embedder.encode(
        passages,
        batch_size=256,
        show_progress_bar=True,
        convert_to_numpy=True,
    )

    # Inner-product index over L2-normalized vectors, i.e. cosine similarity.
    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(embeddings)  # normalizes the array in place
    index.add(embeddings)

    # Persist the index and the passage texts next to each other.
    index_path = os.path.join(RAG_DIR, "faiss_index.bin")
    passages_path = os.path.join(RAG_DIR, "passages.npy")
    faiss.write_index(index, index_path)
    # Saved as an object array, so readers must np.load(..., allow_pickle=True).
    np.save(passages_path, np.array(passages, dtype=object))

    print(
        f"RAG index saved to {RAG_DIR}/",
        f"  Index: {index.ntotal:,} vectors, {dimension}D",
        f"  Passages: {len(passages):,}",
        "Done!",
        sep="\n",
    )