# Colab Setup

In [None]:
import os
# Process-wide environment flags, set before the heavy libraries are imported:
#   - PYDEVD_DISABLE_FILE_VALIDATION: debugger file-validation switch
#   - TOKENIZERS_PARALLELISM: prevents the tokenizers parallelism warning
os.environ.update({
    'PYDEVD_DISABLE_FILE_VALIDATION': '1',
    "TOKENIZERS_PARALLELISM": "false",
})

In [None]:
# Install dependencies (Colab environments usually need these)
!pip install datasets transformers tqdm faiss-cpu sentence-transformers adafactor

In [None]:
# Mount Google Drive at /content/drive to save the model permanently.
# The last section of the notebook copies the fine-tuned model and the RAG
# index into a folder under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Fine-Tune GPT-2 Large on FineWeb-Edu (Streaming Mode)

This notebook fine-tunes **GPT-2 Large (774M params)** using **Streaming Datasets**.
This method uses minimal RAM because it never loads the full dataset into memory.

**Hardware**: Optimized for **T4 GPU (15GB VRAM)** on Google Colab.

In [None]:
import torch
from datasets import load_dataset, IterableDataset
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from tqdm import tqdm

In [None]:
# Pick the compute device once and report what the runtime provides.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

# GPU details are only queried when a CUDA device is actually present.
if device == "cuda":
    total_vram_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {total_vram_bytes / 1024**3:.1f} GB")

## 1. Load Pre-trained GPT-2 Large Model & Tokenizer

In [None]:
MODEL_NAME = "gpt2-large"
print(f"Loading pre-trained model: {MODEL_NAME}...")

# Tokenizer: the EOS token is reused as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Model: gradient checkpointing is switched on right after loading.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
model.gradient_checkpointing_enable()

# Report the parameter count (in millions) and the configured context window.
num_params = sum(param.numel() for param in model.parameters())
print(f"Model loaded: {num_params / 1e6:.1f}M parameters")
print(f"Max context length: {model.config.n_positions}")

## 2. Prepare Streaming Dataset

Instead of downloading everything to RAM, we stream the dataset and tokenize on-the-fly.
This avoids the "Resource Exhausted" (RAM crash) error that occurs when the full dataset is loaded at once.

In [None]:
# Stream budget: number of examples taken for training, and the token cap
# applied to each example during tokenization.
NUM_SAMPLES = 1_000_000
MAX_LENGTH = 1024
print(f"Configuring stream for {NUM_SAMPLES:,} samples...")

# 1. Open the train split in streaming mode (records are fetched lazily
#    while iterating instead of being loaded up front).
raw_dataset = load_dataset(path="HuggingFaceFW/fineweb-edu", split="train", streaming=True)

# 2. Define tokenization function for streaming
def tokenize_stream(examples):
    """Tokenize one batch of streamed documents for causal-LM training.

    Used as the ``batched=True`` callback of the streaming ``map`` call.

    Args:
        examples: Batch dict from the stream; only ``examples["text"]`` is read.

    Returns:
        Dict with ``input_ids`` and ``attention_mask``, one list per document,
        each truncated to at most ``MAX_LENGTH`` tokens and left unpadded.
    """
    # No padding here: the DataCollatorForLanguageModeling used for training
    # pads each batch to its longest sequence, so short documents do not cost
    # MAX_LENGTH tokens of compute (or shuffle-buffer memory) each.
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # Crucial: return ONLY the tokenized fields so no raw column leaks through.
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
    }

# 3. Peek at a single streamed record to learn which raw columns exist.
sample = next(iter(raw_dataset))
all_columns = list(sample)
print(f"Removing raw columns: {all_columns}")

# 4. Tokenize lazily in batches of 1000 and drop every original column, so
#    only the fields returned by tokenize_stream remain.
tokenized_dataset = raw_dataset.map(
    tokenize_stream,
    batched=True,
    batch_size=1000,
    remove_columns=all_columns,
)

# 5. Shuffle through a 10,000-example buffer (fixed seed), then cap the
#    stream at NUM_SAMPLES examples.
shuffled_dataset = (
    tokenized_dataset
    .shuffle(seed=42, buffer_size=10_000)
    .take(NUM_SAMPLES)
)

print("Streaming dataset configured!")

## 3. Configure Training (Iterable Support)

We use `max_steps` instead of `num_train_epochs` because a streaming dataset has no fixed length, so the Trainer cannot derive the number of steps per epoch.

In [None]:
# Local directory for checkpoints and the final model.
output_dir = "out/models/gpt2_large_finetuned"
os.makedirs(output_dir, exist_ok=True)

# mlm=False selects causal (next-token) language modeling instead of masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# One optimizer step consumes BATCH_SIZE * GRAD_ACCUM samples, so consuming
# NUM_SAMPLES samples takes NUM_SAMPLES // (BATCH_SIZE * GRAD_ACCUM) steps.
BATCH_SIZE = 1
GRAD_ACCUM = 8
TOTAL_STEPS = NUM_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=2e-5,
    max_steps=TOTAL_STEPS,  # Must use max_steps for streaming
    # fp16 mixed precision needs a CUDA device; enabling it unconditionally
    # makes TrainingArguments fail on a CPU-only runtime.
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    optim="adafactor",
    logging_steps=100,
    save_steps=2000,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False  # Important for streaming
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=shuffled_dataset,
    data_collator=data_collator,
)

print("Training config:")
print(f"  Total Steps: {TOTAL_STEPS:,}")
print(f"  Effective Batch Size: {BATCH_SIZE * GRAD_ACCUM}")

## 4. Train

In [None]:
# Run the Trainer loop; its length is bounded by the `max_steps` value passed
# to TrainingArguments.
print("Starting streaming fine-tuning...")
trainer.train()

In [None]:
# Persist the fine-tuned model and its tokenizer side by side in output_dir;
# the final section copies this directory to Google Drive.
print("Saving model and tokenizer...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model saved to: {output_dir}")

## 5. Build RAG Index (Streaming)
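We also stream the source documents for indexing, so the raw dataset is never fully loaded. Note that the extracted passages and their embeddings are still held in memory while the index is built.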

In [None]:
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np

RAG_SAMPLES = 100_000      # number of streamed documents to index
RAG_DIR = "out/rag_index"
CHUNK_CHARS = 500          # width of each passage window, in characters
MIN_CHUNK_CHARS = 50       # chunks of this length or shorter are discarded
os.makedirs(RAG_DIR, exist_ok=True)


def chunk_text(text, size=CHUNK_CHARS, min_chars=MIN_CHUNK_CHARS):
    """Split a document into consecutive, non-overlapping character windows.

    Args:
        text: Raw document text; surrounding whitespace is stripped first.
        size: Window width in characters.
        min_chars: A chunk is kept only if it is longer than this after
            stripping.

    Returns:
        List of stripped chunks in document order (empty if none qualify).
    """
    text = text.strip()
    windows = (text[start:start + size].strip() for start in range(0, len(text), size))
    return [window for window in windows if len(window) > min_chars]


passages = []
# Streaming iteration for RAG: only the first RAG_SAMPLES records are read.
rag_stream = raw_dataset.take(RAG_SAMPLES)

print("Extracting passages from stream...")
for row in tqdm(rag_stream, total=RAG_SAMPLES):
    passages.extend(chunk_text(row["text"]))

print(f"Total passages: {len(passages):,}")

# Embed every passage with the MiniLM sentence encoder.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embedder.encode(
    passages,
    batch_size=256,
    show_progress_bar=True,
    convert_to_numpy=True,
)

# L2-normalize the vectors in place; with unit-length vectors the inner
# product used by IndexFlatIP equals cosine similarity.
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Persist the index and the passages it refers to, in matching order.
# passages.npy is an object array, so reading it back needs allow_pickle=True.
faiss.write_index(index, os.path.join(RAG_DIR, "faiss_index.bin"))
np.save(os.path.join(RAG_DIR, "passages.npy"), np.array(passages, dtype=object))
print("RAG Index Built!")

## 6. Save to Google Drive

In [None]:
import shutil
drive_root = "/content/drive/MyDrive"
drive_path = os.path.join(drive_root, "fineweb_edu_gpt2_large")

# Fail loudly when Drive is not mounted. Otherwise os.makedirs would create a
# plain local directory at this path and the copy would appear to succeed
# without anything reaching Google Drive.
if not os.path.isdir(drive_root):
    raise RuntimeError(
        f"Google Drive does not appear to be mounted ({drive_root} not found). "
        "Run the drive.mount('/content/drive') cell first."
    )
os.makedirs(drive_path, exist_ok=True)

# dirs_exist_ok=True lets a re-run overwrite a previous export in place.
print("Copying model to Google Drive...")
shutil.copytree(output_dir, os.path.join(drive_path, "model"), dirs_exist_ok=True)
shutil.copytree(RAG_DIR, os.path.join(drive_path, "rag_index"), dirs_exist_ok=True)
print(f"All files saved to: {drive_path}")