# GPT-2 From Scratch — Training on OpenWebText

**Setup:** Upload your repo as a Kaggle dataset and set `REPO_DIR` in the first code cell to its path (or `git clone` it directly into `/kaggle/working`). Make sure `data/vocab/vocab.json` and `data/vocab/merges.txt` are included.

**Two-phase approach:**
1. **Phase 1 (CPU session):** Download OpenWebText (OWT) and tokenize it → save `train.bin`/`val.bin` as a Kaggle dataset
2. **Phase 2 (GPU session):** Train GPT-2 on the tokenized data

The custom BPE tokenizer is pure Python, so tokenization is slow. Run it once on CPU, save the result, then use GPU sessions for training only.

In [None]:
import os, sys

# If repo is uploaded as Kaggle dataset, adjust this path
REPO_DIR = '/kaggle/input/mygpt2'  # <-- change if needed
WORK_DIR = '/kaggle/working'

# Copy repo to working dir so we can write files
if os.path.exists(REPO_DIR) and not os.path.exists(f'{WORK_DIR}/model'):
    !cp -r {REPO_DIR}/* {WORK_DIR}/

os.chdir(WORK_DIR)
sys.path.insert(0, WORK_DIR)
print('Working dir:', os.getcwd())
print('Files:', os.listdir('.'))

In [None]:
!pip install -q datasets tqdm

## Phase 1: Download & Tokenize OpenWebText (CPU session)
Skip this phase if you already have `train.bin` and `val.bin` uploaded as a dataset.

In [None]:
# Download OpenWebText. Only its 'train' split is used; a validation split is
# carved out of it further below.
from datasets import load_dataset

ds = load_dataset('openwebtext', trust_remote_code=True)
num_docs = len(ds['train'])
print(f"Total documents: {num_docs:,}")

In [None]:
import numpy as np
from tqdm import tqdm
from tokenizer.bpe import BPETokenizer

# Load the custom BPE vocabulary and merge rules shipped with the repo.
tokenizer = BPETokenizer()
tokenizer.load_vocab_merges('data/vocab/vocab.json', 'data/vocab/merges.txt')

# Token files are written incrementally (chunk by chunk) to avoid OOM;
# make sure their output directory exists first.
os.makedirs('data/tokenized', exist_ok=True)

# Document-level split: the first 90% of documents (in dataset order) go to
# train, the remainder to validation.
num_documents = len(ds['train'])
split_idx = int(num_documents * 0.9)

def _flush_tokens(tokens, fh):
    """Append a list of token ids to the open binary file `fh` as uint16 values.

    Raises:
        ValueError: if any id is outside [0, 65535]. Such an id cannot be stored
            in the 2-bytes-per-token format; a plain uint16 cast would wrap it
            silently or fail, depending on the NumPy version.
    """
    arr = np.asarray(tokens, dtype=np.int64)
    limit = np.iinfo(np.uint16).max
    if arr.size and (arr.min() < 0 or arr.max() > limit):
        raise ValueError(
            f'token id out of uint16 range [0, {limit}]: '
            f'min={arr.min()}, max={arr.max()}'
        )
    # tofile writes raw values in the array's (native) byte order.
    arr.astype(np.uint16).tofile(fh)


def tokenize_split(dataset_slice, output_path, eot_token=50256,
                   flush_threshold=50_000_000):
    """Tokenize a slice of the dataset and append the ids to a binary file.

    Each document's 'text' field is encoded with the module-level `tokenizer`
    and followed by `eot_token` as a document separator. Ids are buffered in a
    Python list and written out as uint16 whenever the buffer holds more than
    `flush_threshold` ids, then once more at the end, so memory stays bounded.

    The file is opened in append mode, so callers must delete any stale file
    first. The file is always created, even when the slice is empty.

    Args:
        dataset_slice: iterable of mappings with a 'text' key.
        output_path: destination .bin file.
        eot_token: id appended after each document. Defaults to 50256, GPT-2's
            `<|endoftext|>` id. NOTE(review): confirm it matches this custom vocab.
        flush_threshold: buffered-id count above which the buffer is written.

    Raises:
        ValueError: if the tokenizer emits an id that does not fit in uint16.
    """
    buffer = []
    with open(output_path, 'ab') as fh:
        for example in tqdm(dataset_slice, desc=f'Tokenizing {output_path}'):
            buffer.extend(tokenizer.encode(example['text']))
            buffer.append(eot_token)  # document separator

            # Flush periodically to manage memory.
            if len(buffer) > flush_threshold:
                _flush_tokens(buffer, fh)
                buffer = []

        # Write whatever is left after the last periodic flush.
        if buffer:
            _flush_tokens(buffer, fh)

# Output files for the two splits.
train_bin = 'data/tokenized/train.bin'
val_bin = 'data/tokenized/val.bin'

# tokenize_split appends, so start from empty files: no leftovers from a
# previous partial run.
for stale in (train_bin, val_bin):
    if os.path.exists(stale):
        os.remove(stale)

n_docs = len(ds['train'])
print(f'Train docs: {split_idx:,}, Val docs: {n_docs - split_idx:,}')
tokenize_split(ds['train'].select(range(split_idx)), train_bin)
tokenize_split(ds['train'].select(range(split_idx, n_docs)), val_bin)

# Every token is stored as uint16 (2 bytes), so token count = file size // 2.
train_size = os.path.getsize(train_bin) // 2
val_size = os.path.getsize(val_bin) // 2
print(f'Train tokens: {train_size:,}')
print(f'Val tokens: {val_size:,}')

**After tokenization:** Save `data/tokenized/train.bin` and `data/tokenized/val.bin` as a new Kaggle dataset so you don't have to re-tokenize in GPU sessions.

---

## Phase 2: Train GPT-2 (GPU session)
If the tokenized data was attached as a separate Kaggle dataset, point `DATA_DIR` below at its `/kaggle/input/...` directory.

In [None]:
import torch

# Environment report: confirms whether this session actually has a GPU.
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    # The device-properties attribute is `total_memory` (bytes); there is no
    # `total_mem` attribute.
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'GPU memory: {total_bytes / 1e9:.1f} GB')

In [None]:
import torch
from torch.utils.data import DataLoader

from model.ModelConfig import ModelConfig
from model.TrainingConfig import TrainingConfig
from model.gpt2 import GPT2
from training.dataset import TextDataset
from training.loss import compute_loss
from training.optimizer import configure_optimizer
from training.scheduler import CosineAnnealingScheduler
from training.trainer import Trainer

# --- Config ---
DATA_DIR = 'data/tokenized'            # or '/kaggle/input/your-tokenized-dataset'
CHECKPOINT_DIR = '/kaggle/working/checkpoints'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# fp16 mixed precision is only meaningful on CUDA; use plain fp32 on CPU.
USE_AMP = DEVICE == 'cuda'
# torch.compile's default GPU backend (Inductor, generating Triton kernels)
# needs CUDA compute capability >= 7.0, so it fails on older cards such as the
# P100 (capability 6.0). Fall back to eager mode there.
USE_COMPILE = DEVICE != 'cuda' or torch.cuda.get_device_capability(0) >= (7, 0)

EPOCHS = 3
BATCH_SIZE = 8                          # adjust for GPU memory (T4=16GB, P100=16GB)
BLOCK_SIZE = 1024
LR = 6e-4
MIN_LR = 6e-5
WARMUP_STEPS = 2000
WEIGHT_DECAY = 0.1

print(f'Device: {DEVICE}, AMP: {USE_AMP}, Compile: {USE_COMPILE}')

In [None]:
# Datasets & loaders: wrap each token file with the configured block size.
train_dataset = TextDataset(f'{DATA_DIR}/train.bin', block_size=BLOCK_SIZE)
val_dataset = TextDataset(f'{DATA_DIR}/val.bin', block_size=BLOCK_SIZE)
for label, dset in (('Train', train_dataset), ('Val', val_dataset)):
    print(f'{label} samples: {len(dset):,}')

# Loader settings shared by both splits; only shuffling differs.
loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_opts)

In [None]:
# Model
model = GPT2(ModelConfig(), TrainingConfig())
param_count = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {param_count:,}')

# Place the eager module on the target device first, then compile it there.
# This is the order used in the PyTorch torch.compile tutorial.
model = model.to(DEVICE)

if USE_COMPILE and hasattr(torch, 'compile'):
    print('Compiling model...')
    # NOTE(review): the compiled wrapper's state_dict() keys carry an
    # '_orig_mod.' prefix; confirm Trainer's checkpoint save/load handles it.
    model = torch.compile(model)

In [None]:
# Optimizer & scheduler
optimizer = configure_optimizer(model, lr=LR, weight_decay=WEIGHT_DECAY)

# Schedule length = batches per epoch x epochs. This assumes one scheduler step
# per batch. TODO confirm against Trainer (e.g. gradient accumulation).
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * EPOCHS
scheduler = CosineAnnealingScheduler(
    optimizer,
    warmup_steps=WARMUP_STEPS,
    total_steps=total_steps,
    max_lr=LR,
    min_lr=MIN_LR,
)
print(f'Total training steps: {total_steps:,}')

In [None]:
# Train
# NOTE(review): TrainingConfig() here is a fresh default instance, separate from
# the one given to GPT2. The EPOCHS/LR/... constants above are not passed into it.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss=compute_loss,
    train_loader=train_loader,
    validation_loader=val_loader,
    train_config=TrainingConfig(),
    device=DEVICE,
    use_amp=USE_AMP,
)

# To resume from a saved checkpoint, uncomment:
# trainer.load_checkpoint(f'{CHECKPOINT_DIR}/best_model.pt')

# train() returns the best validation loss, which is reported below.
best_val_loss = trainer.train(num_epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR)
print(f'\nBest validation loss: {best_val_loss:.4f}')