# Colab Setup

**Step 1: Install & Auto-Restart**
Run the installation cell below. If it prints "RESTARTING RUNTIME", wait for the restart to finish, then start from **Step 2**.
The restart is needed so that the Python runtime loads the freshly installed `bitsandbytes` version (required for 4-bit quantization) rather than an older copy.

In [None]:
import os

# Set both process-wide environment switches in one call, before the heavy
# libraries are imported.
os.environ.update({
    # Presumably silences the debugger's file-validation warning -- confirm if needed.
    'PYDEVD_DISABLE_FILE_VALIDATION': '1',
    # Turns off parallelism in the tokenizers library.
    "TOKENIZERS_PARALLELISM": "false",
})

In [None]:
# Force install the exact versions required for QLoRA
# Note: Adafactor is now part of transformers/optimum or installed via source if unavailable on PyPI directly for some python versions
# We'll install it separately if needed, but 'transformers' usually handles it.
!pip install -q -U "bitsandbytes>=0.46.1" "accelerate>=1.0.0" peft transformers datasets tqdm faiss-cpu sentence-transformers

import bitsandbytes
import os
from packaging import version

required_version = "0.46.1"
current_version = bitsandbytes.__version__

if version.parse(current_version) < version.parse(required_version):
    print(f"\n[!] bitsandbytes version {current_version} is too old. RESTARTING RUNTIME...")
    os.kill(os.getpid(), 9)
else:
    print(f"\n[OK] bitsandbytes {current_version} is ready!")

In [None]:
from google.colab import drive

# Attach Google Drive so that checkpoints and the final model are stored
# outside the temporary Colab filesystem.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Fine-Tune GPT-2 Large on FineWeb-Edu (QLoRA Version)

This notebook fine-tunes **GPT-2 Large (774M params)** using **QLoRA (4-bit Quantized LoRA)**.
**Why QLoRA?** Standard GPT-2 Large training hits OOM on T4 (15GB). QLoRA reduces the base model footprint to ~500MB, allowing full training with 1024 context length comfortably.

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from tqdm import tqdm

In [None]:
# Pick the compute device and report it; on CUDA also print the name and
# total memory (in GiB) of GPU 0.
if torch.cuda.is_available():
    device = "cuda"
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    device = "cpu"
    print(f"Using device: {device}")

## 1. Load Pre-trained GPT-2 Large Model in 4-bit

In [None]:
MODEL_NAME = "gpt2-large"
print(f"Loading model {MODEL_NAME} in 4-bit...")

# The tokenizer reuses its EOS token as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Quantization settings: 4-bit NF4 weights with double quantization,
# fp16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the base model quantized, with every module placed on device 0.
model = GPT2LMHeadModel.from_pretrained(
    MODEL_NAME,
    device_map={"": 0},
    quantization_config=bnb_config,
)

# Enable gradient checkpointing, then let PEFT prepare the quantized model
# for k-bit training.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

print("Applying LoRA configuration...")

# LoRA adapter settings: rank 16, alpha 32, 5% dropout, attached only to
# the modules named "c_attn"; configured for training (not inference).
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["c_attn"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    inference_mode=False,
)

# Wrap the base model with the adapters and report the trainable parameters.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

print(f"Max context length: {model.config.n_positions}")

## 2. Prepare Streaming Dataset

In [None]:
# Number of training documents to draw from the stream, and the per-document
# token cap used when tokenizing.
NUM_SAMPLES = 1_000_000
MAX_LENGTH = 1024

print(f"Configuring stream for {NUM_SAMPLES:,} samples...")

# streaming=True: records are read lazily instead of downloading the split.
raw_dataset = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)

def tokenize_stream(examples):
    """Tokenize a batch of raw documents for causal-LM training.

    Args:
        examples: Batch mapping supplied by ``datasets.map(batched=True)``;
            only its ``"text"`` column is read.

    Returns:
        A dict with ``"input_ids"`` and ``"attention_mask"``, one list per
        document, each truncated to at most ``MAX_LENGTH`` tokens.
    """
    # Deliberately no padding here: the DataCollatorForLanguageModeling used
    # for training pads each batch to its longest sequence (and masks pad
    # labels), so padding every document to MAX_LENGTH up front would only
    # spend compute on pad tokens.
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
    )
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
    }

# Peek at a single streamed record only to learn which raw columns exist,
# so they can all be dropped after tokenization.
sample = next(iter(raw_dataset))
all_columns = list(sample)

# Tokenize lazily in batches of 1000 records, keeping only the tokenizer output.
tokenized_dataset = raw_dataset.map(
    tokenize_stream,
    batched=True,
    batch_size=1000,
    remove_columns=all_columns,
)

# Shuffle through a 10k-record buffer (fixed seed), then cap the stream at
# NUM_SAMPLES records.
shuffled_dataset = tokenized_dataset.shuffle(seed=42, buffer_size=10_000)
shuffled_dataset = shuffled_dataset.take(NUM_SAMPLES)
print("Streaming dataset configured!")

## 3. Configure Training with Drive Checkpoints

In [None]:
# Imported in this cell as well: `os` is otherwise only brought in by the
# setup cells, which may be skipped when the notebook is continued after a
# runtime restart.
import os

# Checkpoints are written to Google Drive; create the directory if needed.
output_dir = "/content/drive/MyDrive/fineweb_edu_gpt2_large/checkpoints"
os.makedirs(output_dir, exist_ok=True)

# mlm=False selects causal (next-token) language modeling in the collator.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# One optimizer step consumes BATCH_SIZE * GRAD_ACCUM samples, so this many
# steps cover NUM_SAMPLES samples once.
BATCH_SIZE = 1
GRAD_ACCUM = 8
TOTAL_STEPS = NUM_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)

training_args = TrainingArguments(
    # Where checkpoints go, how often, and how many are kept.
    output_dir=output_dir,
    save_steps=500,
    save_total_limit=3,
    # Batch geometry and run length.
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    max_steps=TOTAL_STEPS,
    # Optimization settings.
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    gradient_checkpointing=True,
    # Logging / reporting.
    logging_steps=100,
    report_to="none",
    # Keep dataset columns as-is instead of letting the Trainer prune them.
    remove_unused_columns=False,
)

# The Trainer ties together the LoRA-wrapped model, the streamed dataset
# and the causal-LM collator.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=shuffled_dataset,
)

## 4. Train (Smart Auto-Resume)

In [None]:
from transformers.trainer_utils import get_last_checkpoint

# Look for the most recent checkpoint in the output directory; resume from
# it when one exists, otherwise start training from scratch.
last_checkpoint = get_last_checkpoint(output_dir)

if last_checkpoint is None:
    print("Starting fresh QLoRA training...")
    trainer.train()
else:
    print(f"Checkpoint detected: {last_checkpoint}. Resuming...")
    trainer.train(resume_from_checkpoint=last_checkpoint)

In [None]:
final_model_dir = "/content/drive/MyDrive/fineweb_edu_gpt2_large/final_model"
print(f"Saving final LoRA adapters to: {final_model_dir}...")

# Write the trained model first, then the tokenizer, into the same directory.
for persist in (trainer.save_model, tokenizer.save_pretrained):
    persist(final_model_dir)

print("Training Complete!")

## 5. Build RAG Index

In [None]:
# `os` is imported in this cell too: it is otherwise only brought in by the
# setup cells, which may be skipped when continuing after a runtime restart.
import os

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

RAG_SAMPLES = 100_000
RAG_DIR = "/content/drive/MyDrive/fineweb_edu_gpt2_large/rag_index"
os.makedirs(RAG_DIR, exist_ok=True)

# Passages are fixed-width character windows; very short leftovers are dropped.
CHUNK_CHARS = 500       # window width, in characters
MIN_CHUNK_CHARS = 50    # a chunk is kept only if it is longer than this

passages = []
# The first RAG_SAMPLES records of the raw stream feed the index.
rag_stream = raw_dataset.take(RAG_SAMPLES)

print("Extracting passages from stream...")
for row in tqdm(rag_stream, total=RAG_SAMPLES):
    text = row["text"].strip()
    for start in range(0, len(text), CHUNK_CHARS):
        chunk = text[start:start + CHUNK_CHARS].strip()
        if len(chunk) > MIN_CHUNK_CHARS:
            passages.append(chunk)

# Embed every passage with the MiniLM sentence encoder.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embedder.encode(
    passages,
    batch_size=256,
    show_progress_bar=True,
    convert_to_numpy=True,
)

# Flat inner-product index sized to the embedding width; vectors are
# L2-normalized in place before being added.
embedding_dim = embeddings.shape[1]
index = faiss.IndexFlatIP(embedding_dim)
faiss.normalize_L2(embeddings)
index.add(embeddings)

# Persist the index and the passage texts side by side.
index_path = os.path.join(RAG_DIR, "faiss_index.bin")
passages_path = os.path.join(RAG_DIR, "passages.npy")
faiss.write_index(index, index_path)
# NOTE: the passages are stored as an object-dtype array, so reading them
# back with np.load requires allow_pickle=True.
np.save(passages_path, np.array(passages, dtype=object))
print("RAG Index Built!")