In [2]:
import math
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
@dataclass
class GPTConfig:
    block_size: int = 16      # small for debugging
    vocab_size: int = 100     # small vocab for debugging
    n_layer: int = 2
    n_head: int = 2
    n_embd: int = 32
    dropout: float = 0.0      # turn off dropout for determinism


### Attention building blocks (Head + MultiHeadAttention)


In [4]:
class Head(nn.Module):
    """One head of causal self-attention with debug storage.

    Input : (B, T, C = n_embd)
    Output: (B, T, head_size)
    """
    def __init__(self, head_size: int, config: GPTConfig):
        super().__init__()
        self.head_size = head_size

        self.key   = nn.Linear(config.n_embd, head_size, bias=False)
        self.query = nn.Linear(config.n_embd, head_size, bias=False)
        self.value = nn.Linear(config.n_embd, head_size, bias=False)

        # Causal mask for this head: (block_size, block_size)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("tril", mask)

        # for debugging info
        self.debug = {}

    def forward(self, x):
        # x: (B, T, C)
        B, T, C = x.shape

        k = self.key(x)      # (B, T, head_size)
        q = self.query(x)    # (B, T, head_size)
        v = self.value(x)    # (B, T, head_size)

        # Attention logits: (B, T, T)
        logits = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)

        # Causal masking: only attend to t' <= t
        mask = self.tril[:T, :T] == 0
        logits_masked = logits.masked_fill(mask, float('-inf'))

        # Softmax over last dimension
        att = F.softmax(logits_masked, dim=-1)  # (B, T, T)

        # Weighted sum of values
        out = att @ v                           # (B, T, head_size)

        # Store debug tensors (on CPU) for later inspection / saving
        self.debug = {
            "x_in": x.detach().cpu(),               # (B, T, C)
            "q": q.detach().cpu(),                  # (B, T, hs)
            "k": k.detach().cpu(),                  # (B, T, hs)
            "v": v.detach().cpu(),                  # (B, T, hs)
            "logits_raw": logits.detach().cpu(),    # before mask
            "logits_masked": logits_masked.detach().cpu(),
            "att": att.detach().cpu(),              # softmax weights
            "out": out.detach().cpu(),              # head output
        }

        return out

In [5]:
class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = config.n_head
        head_size = config.n_embd // config.n_head

        # Independent heads, each with its own K/Q/V and mask
        self.heads = nn.ModuleList([Head(head_size, config) for _ in range(config.n_head)])
        # Final projection back to model dimension
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Concatenate head outputs on the channel dimension
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, n_embd)
        out = self.proj(out)
        out = self.dropout(out)
        return out

### MLP

In [6]:
class MLP(nn.Module):
    """Feed-forward network used inside each transformer block."""
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_dim = 4 * config.n_embd  # GPT-2 uses 4x expansion
        self.c_fc = nn.Linear(config.n_embd, hidden_dim)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

### Transformer Decoder block

In [7]:
class Block(nn.Module):
    """Transformer Decoder block: pre-LN + self-attention + MLP with residuals."""
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-LN + residual connection around attention
        x = x + self.attn(self.ln_1(x))
        # Pre-LN + residual connection around MLP
        x = x + self.mlp(self.ln_2(x))
        return x

### GPT model


In [8]:
class GPT(nn.Module):
    """GPT-2 style language model."""
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # Transformer body
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),   # positional embeddings
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),                     # final layer norm
        ))
        # Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: share token embedding weights with lm_head
        self.transformer.wte.weight = self.lm_head.weight

        # GPT-2 style parameter initialization
        self.apply(self._init_weights)
        # Special scaled init for residual projections (as in GPT-2)
        for name, p in self.named_parameters():
            if name.endswith("c_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        idx: LongTensor of shape (B, T)
        targets: Optional LongTensor of shape (B, T)
        Returns:
            logits: (B, T, vocab_size)
            loss: scalar or None
        """
        B, T = idx.size()
        if T > self.config.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {T}, "
                f"block size is only {self.config.block_size}"
            )

        # Token and positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
        tok_emb = self.transformer.wte(idx)        # (B, T, n_embd)
        pos_emb = self.transformer.wpe(pos)        # (T, n_embd)
        x = tok_emb + pos_emb                      # (B, T, n_embd)

        # Transformer blocks
        for block in self.transformer.h:
            x = block(x)

        # Final layernorm
        x = self.transformer.ln_f(x)

        # Language modeling head
        logits = self.lm_head(x)                   # (B, T, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

In [9]:
torch.manual_seed(0)

# Tiny config for debugging
config = GPTConfig(
    block_size=8,
    vocab_size=50,
    n_layer=2,
    n_head=2,
    n_embd=16,
    dropout=0.0,
)

model = GPT(config)
model.eval()  # turn off dropout etc.

# Fixed tiny input
# Shape: (B=1, T=5)
idx = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)

with torch.no_grad():
    logits, loss = model(idx, targets=idx)

print("=== Forward pass ===")
print("Input idx:", idx)
print("Logits shape:", logits.shape)
print("Loss:", loss.item())

# Access debug info for block 0, head 0
block0 = model.transformer.h[0]
head00 = block0.attn.heads[0]
dbg = head00.debug

print("\n=== Debug: Block 0, Head 0 ===")
print("x_in shape:", dbg["x_in"].shape)
print("q shape:", dbg["q"].shape)
print("k shape:", dbg["k"].shape)
print("v shape:", dbg["v"].shape)
print("logits_raw shape:", dbg["logits_raw"].shape)
print("logits_masked shape:", dbg["logits_masked"].shape)
print("att (softmax) shape:", dbg["att"].shape)
print("out shape:", dbg["out"].shape)

# Print small slices so it doesn't spam
print("\nq[0, :, :]:\n", dbg["q"][0])
print("\natt[0, :, :]:\n", dbg["att"][0])
print("\nout[0, :, :]:\n", dbg["out"][0])

# Save everything you might want to mirror in C++
torch.save({
    "config": asdict(config),
    "idx": idx.cpu(),
    "state_dict": model.state_dict(),
    "logits": logits.cpu(),
    "loss": loss.cpu(),
    "block0_head0_debug": dbg,
}, "gpt_debug_forward.pt")

print("\nSaved debug tensors to gpt_debug_forward.pt")

=== Forward pass ===
Input idx: tensor([[1, 2, 3, 4, 5]])
Logits shape: torch.Size([1, 5, 50])
Loss: 3.7472751140594482

=== Debug: Block 0, Head 0 ===
x_in shape: torch.Size([1, 5, 16])
q shape: torch.Size([1, 5, 8])
k shape: torch.Size([1, 5, 8])
v shape: torch.Size([1, 5, 8])
logits_raw shape: torch.Size([1, 5, 5])
logits_masked shape: torch.Size([1, 5, 5])
att (softmax) shape: torch.Size([1, 5, 5])
out shape: torch.Size([1, 5, 8])

q[0, :, :]:
 tensor([[ 0.0433, -0.1484, -0.0681,  0.1077,  0.1292, -0.0265, -0.0527,  0.0224],
        [ 0.0530, -0.2151, -0.0018,  0.0051, -0.0335,  0.0242,  0.0804, -0.0540],
        [ 0.0262,  0.0207,  0.0550,  0.0878,  0.0659, -0.0131, -0.0234, -0.0209],
        [ 0.0111, -0.0153, -0.0934, -0.0422,  0.0706, -0.0545, -0.0838, -0.0405],
        [-0.0250, -0.1596,  0.0500, -0.0421, -0.0925, -0.0238,  0.0721, -0.2070]])

att[0, :, :]:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5020, 0.4980, 0.0000, 0.0000, 0.0000],
        [0.3339, 0.3