# GPT-2 From Scratch: Training

Train GPT-2 Small (124M parameters) on WikiText-2 from scratch.

**Architecture**: 12 layers, 768 hidden dim, 12 attention heads, 512 context length  
**Dataset**: WikiText-2 (~2.4M tokens from Wikipedia)  
**Hardware**: Works on CUDA (T4/Colab), MPS (Apple Silicon), or CPU
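
As a rough check on the 124M figure, the parameter count can be derived from the architecture numbers above. This is a sketch under assumptions the notebook does not state: the GPT-2 BPE vocabulary of 50,257 tokens, a 4x MLP expansion, biases on every linear layer, and an output head that shares its weights with the token embedding. The helper name `count_parameters` is made up for illustration.

```python
# Rough parameter count for the architecture listed above.
# Assumptions (not stated in this notebook): 50,257-token GPT-2 BPE vocabulary,
# 4x MLP expansion, biases on all linear layers, tied input/output embeddings.
VOCAB_SIZE = 50_257
CONTEXT_LENGTH = 512
D_MODEL = 768
N_LAYERS = 12


def count_parameters(vocab_size, context_length, d_model, n_layers):
    layer_norm = 2 * d_model                          # scale + shift
    attn_qkv = d_model * 3 * d_model + 3 * d_model    # fused Q, K, V projection
    attn_out = d_model * d_model + d_model            # attention output projection
    mlp_up = d_model * 4 * d_model + 4 * d_model      # expand to 4 * d_model
    mlp_down = 4 * d_model * d_model + d_model        # project back to d_model
    per_block = 2 * layer_norm + attn_qkv + attn_out + mlp_up + mlp_down

    # Token embedding plus learned position embedding.
    embeddings = vocab_size * d_model + context_length * d_model
    final_layer_norm = layer_norm
    return embeddings + n_layers * per_block + final_layer_norm


total = count_parameters(VOCAB_SIZE, CONTEXT_LENGTH, D_MODEL, N_LAYERS)
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")
```

Under these assumptions the token embedding alone accounts for roughly 38.6M parameters, and shortening the context only changes the position-embedding term.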

## 1. Setup

In [None]:
!pip install -q torch tiktoken datasets numpy

In [None]:
!git clone https://github.com/manojkgorle/compute-guzzler-1.git gpt2-vc 2>/dev/null || echo "Already cloned"
%cd gpt2-vc

In [None]:
import torch
from config import GPT2Config, TrainConfig, get_device
from model import GPT2
from data import create_dataloaders
from train import train

device = get_device()
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    # The attribute is `total_memory` (bytes); `total_mem` does not exist.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
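
For readers without the repository at hand, device selection along these lines might look like the sketch below. `pick_device` is a hypothetical stand-in, not the repository's `get_device`, which may behave differently. The guarded import is only there so the snippet also runs where PyTorch is not installed.

```python
# Hypothetical stand-in for config.get_device(); the real helper may differ.
try:
    import torch
except ImportError:  # lets the sketch run even where PyTorch is missing
    torch = None


def pick_device():
    """Return "cuda", "mps", or "cpu", in that order of preference."""
    if torch is None:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)  # absent in older PyTorch builds
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(pick_device())
```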

## 2. Configuration

In [None]:
config = GPT2Config()
train_config = TrainConfig(
    max_epochs=30,
    batch_size=8,
    learning_rate=3e-4,
    device=device,
)

print(f"Context length: {config.context_length}")
print(f"Batch size: {train_config.batch_size} "
      f"(effective: {train_config.batch_size * train_config.gradient_accumulation_steps})")
print(f"Epochs: {train_config.max_epochs}")
print(f"Learning rate: {train_config.learning_rate}")
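
To get a feel for what these settings imply, the arithmetic below estimates tokens per optimizer update and updates per epoch. It is a back-of-the-envelope sketch: `GRAD_ACCUM_STEPS` is an illustrative value (the real default lives in `TrainConfig`), the ~2.4M token figure is treated as the training set size, and the loader is assumed to cut the corpus into non-overlapping windows.

```python
import math

# Values taken from this notebook.
BATCH_SIZE = 8
CONTEXT_LENGTH = 512
MAX_EPOCHS = 30
DATASET_TOKENS = 2_400_000  # approximate figure quoted in the introduction

# Assumption: illustrative only; check TrainConfig for the actual default.
GRAD_ACCUM_STEPS = 4


def training_budget(dataset_tokens, batch_size, context_length, accum_steps, epochs):
    tokens_per_micro_batch = batch_size * context_length
    tokens_per_update = tokens_per_micro_batch * accum_steps
    # Assumption: non-overlapping windows, with the incomplete tail dropped.
    micro_batches_per_epoch = dataset_tokens // tokens_per_micro_batch
    updates_per_epoch = math.ceil(micro_batches_per_epoch / accum_steps)
    return {
        "tokens_per_micro_batch": tokens_per_micro_batch,
        "tokens_per_update": tokens_per_update,
        "micro_batches_per_epoch": micro_batches_per_epoch,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": updates_per_epoch * epochs,
    }


budget = training_budget(
    DATASET_TOKENS, BATCH_SIZE, CONTEXT_LENGTH, GRAD_ACCUM_STEPS, MAX_EPOCHS
)
for name, value in budget.items():
    print(f"{name}: {value:,}")
```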

## 3. Initialize Model

In [None]:
model = GPT2(config)

## 4. Prepare Data

In [None]:
train_loader, val_loader = create_dataloaders(config, train_config)
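
The loaders feed the model pairs of input and target sequences, where the target is the input shifted by one token. The sketch below shows that idea on a toy token stream. `make_examples` is a hypothetical helper, and the non-overlapping windowing is an assumption; `create_dataloaders` may use a different stride or handle the tail differently.

```python
def make_examples(token_ids, context_length):
    """Split a token stream into (input, target) pairs for next-token prediction.

    Assumption: non-overlapping windows; the real loader may use a stride.
    """
    examples = []
    # Each window needs context_length + 1 tokens: the extra one is the last target.
    for start in range(0, len(token_ids) - context_length, context_length):
        window = token_ids[start : start + context_length + 1]
        examples.append((window[:-1], window[1:]))
    return examples


tokens = list(range(10))  # stand-in for real token ids
pairs = make_examples(tokens, context_length=4)
for inputs, targets in pairs:
    print(inputs, "->", targets)
```

Each target position holds the token that follows the corresponding input position, so a single window of 512 tokens yields 512 next-token predictions.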

## 5. Train

On CUDA, `train` automatically enables:
- `torch.compile` (kernel fusion)
- `float16` autocast (mixed precision via tensor cores)
- `GradScaler` (prevents float16 gradient underflow)
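
To see why gradient scaling matters, the snippet below rounds a tiny value to half precision with and without scaling. It uses only the standard library (`struct` format `"e"` is IEEE 754 half precision) rather than PyTorch, and the scale of 2^16 is an illustrative constant; `GradScaler` adjusts its scale dynamically during training.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


SCALE = 2.0 ** 16  # illustrative loss scale, not read from the training code
tiny_grad = 1e-8   # smaller than float16's smallest subnormal (~6e-8)

# Stored directly in float16, the value underflows to zero.
unscaled = to_float16(tiny_grad)

# Scaled up before the float16 round trip, then divided back in full precision.
recovered = to_float16(tiny_grad * SCALE) / SCALE

print(f"stored directly:      {unscaled}")
print(f"scaled then unscaled: {recovered:.3e}")
```

Multiplying the loss by a large constant shifts small gradients into float16's representable range; dividing by the same constant before the optimizer step restores their true magnitude.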

In [None]:
train(model, train_loader, val_loader, train_config, config)

## 6. Quick Generation Test

Generate text from the freshly trained model as a sanity check that it has learned something.

In [None]:
from generate import generate

prompts = [
    "The meaning of life is",
    "In a shocking finding, scientists discovered",
    "The history of the United States",
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    print("-" * 60)
    output = generate(
        model, prompt=prompt,
        max_new_tokens=100, temperature=0.8,
        top_k=50, top_p=0.95, device=device,
    )
    print(output)
    print("=" * 60)
    print()
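
The `temperature`, `top_k`, and `top_p` arguments control how the next token is drawn from the model's output distribution. The pure-Python sketch below illustrates one common way to combine them on a made-up logit vector. `sample_next_token` is a hypothetical helper; the repository's `generate` may apply the filters in a different order or operate on tensors.

```python
import math
import random


def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.95, rng=random):
    """Pick one token index from raw logits using temperature, top-k, and top-p."""
    # Temperature: values below 1 sharpen the distribution.
    scaled = [x / temperature for x in logits]

    # Top-k: keep only the k highest-scoring token indices.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    ranked = ranked[:top_k]

    # Softmax over the survivors (subtract the max for numerical stability).
    peak = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - peak) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    cumulative, cutoff = 0.0, len(ranked)
    for position, p in enumerate(probs):
        cumulative += p
        if cumulative >= top_p:
            cutoff = position + 1
            break
    ranked, probs = ranked[:cutoff], probs[:cutoff]

    # random.choices renormalizes the remaining weights internally.
    return rng.choices(ranked, weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # made-up scores for a 5-token vocabulary
random.seed(0)
print([sample_next_token(toy_logits, top_k=3) for _ in range(10)])
```

Lower `temperature`, smaller `top_k`, or smaller `top_p` all make the output more conservative; setting `top_k=1` reduces the procedure to greedy decoding.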