# Building GPT2 from scratch

Sources:
* GPT2 min: https://github.com/The-Pocket/PocketFlow-Tutorial-Video-Generator/blob/main/docs/llm/transformer.md
* KV caching: https://github.com/The-Pocket/PocketFlow-Tutorial-Video-Generator/blob/main/docs/llm/kv_cache.md

## Model

In [9]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import numpy as np
import torchvision
import torch
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from dataclasses import dataclass
import importlib
import shared_utilities
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
import os

# Set before the data module does any tokenization in this session.
# NOTE(review): presumably this silences the HuggingFace `tokenizers`
# parallelism/fork warning when DataLoader workers are spawned — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Force Python to re-read shared_utilities.py from disk so edits made to that
# file while the notebook kernel is running are picked up.
importlib.reload(shared_utilities)

# Import after the reload so these names refer to the freshly loaded classes.
from shared_utilities import HFTextDataModule, LightningModel

In [2]:
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class GPTConfig:
    """Hyperparameters shared by every module of the GPT-2 model in this file.

    ``vocab_size`` and ``block_size`` have no defaults and must always be
    supplied; the remaining fields fall back to the defaults below.
    """

    # Number of rows in the token embedding table and width of the output logits.
    vocab_size: int
    # Number of rows in the position embedding table and side length of the
    # causal mask, i.e. the longest sequence the model can attend over.
    block_size: int
    # Number of transformer Blocks stacked in GPT2.h.
    n_layer: int = 12
    # Number of attention heads; CausalSelfAttention asserts it divides n_embd.
    n_head: int = 12
    # Width of the embeddings and of the residual stream.
    n_embd: int = 768
    # Probability handed to every nn.Dropout created from this config.
    dropout: float = 0.1

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    A single fused linear layer produces queries, keys and values. When a
    cache of earlier keys/values is supplied, the new keys/values are appended
    to it so that only the new positions need to be projected.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Fused q/k/v projection followed by the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.resid_drop = nn.Dropout(config.dropout)
        # Lower-triangular matrix of ones, stored with two leading singleton
        # dims so it broadcasts over (batch, head).
        causal = torch.ones(config.block_size, config.block_size).tril()
        self.register_buffer("bias", causal[None, None, :, :])

    def forward(self, x, past_kv=None):
        """Attend over ``x`` (and any cached positions).

        Args:
            x: tensor of shape (B, T, C); T is the number of *new* positions,
                usually 1 during cached generation.
            past_kv: optional ``(k, v)`` pair from an earlier call, each of
                shape (B, n_head, T_past, head_dim).

        Returns:
            ``(y, present_kv)`` where ``y`` has shape (B, T, C) and
            ``present_kv`` is the ``(k, v)`` pair covering cached plus new
            positions.
        """
        batch, new_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, new_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        if past_kv is not None:
            cached_k, cached_v = past_kv
            # Grow keys/values along the sequence dimension.
            k = torch.cat((cached_k, k), dim=-2)
            v = torch.cat((cached_v, v), dim=-2)
        present_kv = (k, v)

        total_len = k.size(-2)
        scale = 1.0 / math.sqrt(head_dim)
        scores = (q @ k.transpose(-2, -1)) * scale
        # Rows of the mask that belong to the new query positions; columns
        # cover every key position seen so far.
        visible = self.bias[:, :, total_len - new_len:total_len, :total_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, new_len, width)
        return self.resid_drop(self.c_proj(merged)), present_kv

class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        self.fc = nn.Linear(config.n_embd, hidden)
        self.proj = nn.Linear(hidden, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply fc -> GELU -> proj -> dropout; the last dimension is preserved."""
        return self.drop(self.proj(F.gelu(self.fc(x))))

class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run the block without a KV cache and return the updated residual stream."""
        # The attention module also returns its keys/values; they are not
        # needed on this (uncached) path.
        attended, _unused_kv = self.attn(self.ln_1(x), past_kv=None)
        residual = x + attended
        return residual + self.mlp(self.ln_2(residual))

class GPT2(nn.Module):
    """Decoder-only GPT-2 language model with tied input/output embeddings.

    Two forward paths are provided and produce the same logits for the same
    tokens: ``forward`` (no cache, used for training) and
    ``forward_with_cache`` (returns per-layer keys/values so generation only
    has to process newly sampled tokens).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
        self.wpe = nn.Embedding(config.block_size, config.n_embd)  # position embeddings
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the token embedding matrix.
        self.lm_head.weight = self.wte.weight

    def _check_context(self, total_len):
        """Raise ValueError if ``total_len`` positions do not fit in the context window."""
        if total_len > self.config.block_size:
            raise ValueError(
                f"sequence length {total_len} exceeds block_size {self.config.block_size}"
            )

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the loss) without a KV cache.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.
            targets: optional token ids of shape (B, T); when given, the
                cross-entropy loss against the logits is returned as well.

        Returns:
            ``(logits, loss)``; ``logits`` has shape (B, T, vocab_size) and
            ``loss`` is ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: if T exceeds ``config.block_size``.
        """
        B, T = idx.size()
        self._check_context(T)
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        x = self.wte(idx) + self.wpe(pos)
        x = self.drop(x)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) if targets is not None else None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        The KV cache is used while the running sequence still fits in the
        context window. Position embeddings are absolute, so once the window
        is full the cache cannot simply be extended or trimmed; from then on
        each step recomputes the last ``block_size`` tokens from scratch.

        The module's train/eval mode is left untouched: call ``model.eval()``
        first, otherwise dropout stays active while sampling.

        Args:
            idx: integer token ids of shape (B, T) used as the prompt.
            max_new_tokens: number of tokens to sample.
            temperature: divisor applied to the logits (clamped to >= 1e-8).
            top_k: if given, sample only among the ``top_k`` most likely tokens.

        Returns:
            Tensor of shape (B, T + max_new_tokens).
        """
        block_size = self.config.block_size
        past_kv = None

        for _ in range(max_new_tokens):
            cached_len = past_kv[0][0].size(-2) if past_kv else 0
            if cached_len == 0 or cached_len >= block_size:
                # Nothing cached yet, or the window is full: (re)build the
                # cache from the most recent block_size tokens.
                past_kv = None
                idx_cond = idx[:, -block_size:]
            else:
                # Only the newest token has to be processed.
                idx_cond = idx[:, -1:]

            logits, past_kv = self.forward_with_cache(idx_cond, past_kv)
            logits = logits[:, -1, :] / max(temperature, 1e-8)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float('inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)

        return idx

    def forward_with_cache(self, idx, past_kv=None):
        """Compute logits for ``idx`` while reusing cached keys/values.

        Args:
            idx: integer token ids of shape (B, T) holding only the positions
                that are not yet in the cache.
            past_kv: ``None``/empty, or a list with one ``(k, v)`` pair per
                layer as returned by a previous call.

        Returns:
            ``(logits, new_past_kv)``; ``logits`` has shape (B, T, vocab_size)
            and ``new_past_kv`` covers cached plus new positions.

        Raises:
            ValueError: if cached plus new positions exceed ``config.block_size``.
        """
        B, T = idx.size()

        # New tokens continue from the position right after the cached ones.
        pos_offset = past_kv[0][0].size(-2) if past_kv else 0
        self._check_context(pos_offset + T)
        pos = torch.arange(pos_offset, pos_offset + T, dtype=torch.long, device=idx.device)

        x = self.wte(idx) + self.wpe(pos)
        x = self.drop(x)

        new_past_kv = []
        for i, block in enumerate(self.h):
            layer_past = past_kv[i] if past_kv else None
            # Same pre-norm residual structure as Block.forward: normalise,
            # attend, then add the attention output back onto the stream.
            attn_out, present = block.attn(block.ln_1(x), layer_past)
            x = x + attn_out
            new_past_kv.append(present)
            x = x + block.mlp(block.ln_2(x))

        x = self.ln_f(x)
        logits = self.lm_head(x)

        return logits, new_past_kv

In [3]:
# Build the WikiText-2 data module; arguments are grouped by what they control.
dm = HFTextDataModule(
    # Which corpus to load.
    dataset_name="wikitext",
    dataset_config="wikitext-2-raw-v1",
    # Shape of each batch: sequences of block_size tokens, batch_size per step.
    block_size=256,
    batch_size=24,
    # DataLoader worker settings.
    num_workers=8,
    prefetch_factor=4,
    persistent_workers=True,
)
# Prepare first, then set up; the cell output shows tokenization running and
# the resulting split statistics.
dm.prepare_data()
dm.setup()

Tokenizing dataset...

✅ Dataset loaded:
   Vocab size: 50,257
   Train tokens: 2,391,884
   Val tokens: 247,289
   Test tokens: 283,287
   Train samples: 2,391,628
   Val samples: 247,033


In [4]:
# Header banner.
print("=" * 80, "INSPECTING FIRST 5 SAMPLES FROM WIKITEXT-2", "=" * 80, sep="\n")

# Basic statistics about the prepared data module.
print("\nDataset Info:")
print(f"  Vocab size: {len(dm.tokenizer):,}")
print(f"  Train samples: {len(dm.train):,}")
print(f"  Block size: {dm.block_size}")

# Show the first five (input, target) pairs of the training split.
print("\n" + "=" * 80)
for i in range(5):
    inputs, targets = dm.train[i]

    print(f"\nSample {i+1}:")
    print(f"  Input shape:  {inputs.shape}", f"  Target shape: {targets.shape}", sep="\n")

    # Decode only the first 50 token ids of each tensor back to text.
    input_text = dm.tokenizer.decode(inputs[:50])
    target_text = dm.tokenizer.decode(targets[:50])

    print("\n  Input text (first 50 tokens):", f"    {input_text}", sep="\n")
    print("\n  Target text (first 50 tokens, shifted by 1):", f"    {target_text}", sep="\n")

    print("-" * 80)

INSPECTING FIRST 5 SAMPLES FROM WIKITEXT-2

Dataset Info:
  Vocab size: 50,257
  Train samples: 2,391,628
  Block size: 256


Sample 1:
  Input shape:  torch.Size([256])
  Target shape: torch.Size([256])

  Input text (first 50 tokens):
     = Valkyria Chronicles III = 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) ,

  Target text (first 50 tokens, shifted by 1):
     Valkyria Chronicles III = 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly
--------------------------------------------------------------------------------

Sample 2:
  Input shape:  torch.Size([256])
  Target shape: torch.Size([256])

  Input text (first 50 tokens):
     Valkyria Chronicles III = 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly

  Target text (first 50 tokens, shifted by 1):
    alkyria Chronicles

In [5]:
%%capture --no-display

L.pytorch.seed_everything(123)
# Allow float32 matmuls to use faster reduced-precision kernels where the
# hardware supports them (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high')

# Take the context length from the data module so the model's position table
# and causal mask always match the length of the sequences it is fed.
config = GPTConfig(vocab_size=len(dm.tokenizer), block_size=dm.block_size)
pytorch_model = GPT2(config)
# Optional speed-up, currently disabled: compile the model before training.
# pytorch_model = torch.compile(pytorch_model, mode='reduce-overhead')

lightning_model = LightningModel(
    model=pytorch_model,
    learning_rate=3e-4,
    warmup_steps=500
)

# Keep the three checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="gpt2-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
)

# Disabled because LearningRateMonitor cannot be used without a logger
# (the Trainer below is created with logger=False).
#lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = L.Trainer(
    max_epochs=1,
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",        # bfloat16 autocast to reduce memory use
    accumulate_grad_batches=4,     # effective batch size = 24 * 4 = 96 sequences
    callbacks=[checkpoint_callback],
    logger=False, # skipping due to problems with CSVLogger(save_dir="logs/", name="gpt2_training"),
    gradient_clip_val=1.0,
    log_every_n_steps=50,
    val_check_interval=0.25,  # Validate 4 times per epoch
    enable_progress_bar=True,
    enable_model_summary=True
)

Seed set to 123
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [6]:
# Train with the Trainer configured above (max_epochs=1, validation at every
# quarter of the epoch); checkpoints are written by the ModelCheckpoint callback.
trainer.fit(model=lightning_model, datamodule=dm)

Tokenizing dataset...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/borja-dosuna/miniconda3/envs/vit-env/lib/python3.13/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:242: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type | Params | Mode  | FLOPs
-----------------------------------------------
0 | model | GPT2 | 123 M  | train | 0    
-----------------------------------------------
123 M     Trainable params
0         Non-trainable params
123 M     Total params
495.400   Total estimated model params size (MB)
139       Modules in train mode
0         Modules in eval mode
0         Total Flops



✅ Dataset loaded:
   Vocab size: 50,257
   Train tokens: 2,391,884
   Val tokens: 247,289
   Test tokens: 283,287
   Train samples: 2,391,628
   Val samples: 247,033


Sanity Checking: |                                                                                            …

Creating val dataloader...
Val dataloader created with 10294 batches
Creating train dataloader...
Train dataloader created with 99652 batches


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.


In [10]:
# NOTE(review): as the recorded output below shows, this call raises
# MisconfigurationException because no `test_step()` is defined on the model
# passed in. LightningModel (imported from shared_utilities) needs a
# `test_step()` before the test split can be evaluated this way.
trainer.test(model=lightning_model, datamodule=dm)

MisconfigurationException: No `test_step()` method defined to run `Trainer.test`.