## 4.1 — Parameters in the multi-head attention vs. feed-forward module

In [2]:
import torch
import torch.nn as nn

In [3]:
# Hyperparameters of the 124M-parameter GPT-2 ("small") model, as a plain dict.
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # vocabulary size (rows of the token-embedding table)
    context_length=1024,   # maximum number of positions
    emb_dim=768,           # model width
    n_heads=12,            # attention heads per block
    n_layers=12,           # number of transformer blocks
    drop_rate=0.1,         # dropout probability
    qkv_bias=False,        # NOTE(review): the cells below build Q/K/V with bias regardless of this flag
)

### Multi-head attention (MHA) module

In [4]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into queries, keys and values, splits each into
    ``n_heads`` heads of width ``emb_dim // n_heads``, applies scaled
    dot-product attention under a causal mask (a position attends only to
    itself and earlier positions), merges the heads and applies ``out_proj``.

    Args:
        emb_dim: Input/output feature width.
        n_heads: Number of attention heads; must divide ``emb_dim``.
        qkv_bias: Whether the Q/K/V projections carry a bias term. Defaults to
            True, which keeps the original parameter count
            (``4 * (emb_dim**2 + emb_dim)``).
        drop_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``emb_dim``.
    """

    def __init__(self, emb_dim, n_heads, qkv_bias=True, drop_rate=0.0):
        super().__init__()
        if n_heads <= 0 or emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = emb_dim // n_heads
        # Same registration order as before, so random init and state_dict keys match.
        self.W_Q = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.W_K = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.W_V = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, emb_dim); an unbatched
                (seq_len, emb_dim) tensor is also accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)
        batch, seq_len, width = x.shape

        def split_heads(proj):
            # (batch, seq_len, width) -> (batch, n_heads, seq_len, head_dim)
            return proj.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_Q(x))
        keys = split_heads(self.W_K(x))
        values = split_heads(self.W_V(x))

        scores = queries @ keys.transpose(-2, -1) / self.head_dim ** 0.5
        # True above the diagonal = "future" positions that must not be attended to.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        scores = scores.masked_fill(future, float("-inf"))
        weights = self.dropout(torch.softmax(scores, dim=-1))

        context = (weights @ values).transpose(1, 2).reshape(batch, seq_len, width)
        out = self.out_proj(context)
        return out.squeeze(0) if unbatched else out

### Feed-forward (FF) module

In [5]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is ``expansion`` times wider than ``emb_dim``. The output
    width equals the input width, so the module can be added onto a residual
    stream.
    """

    def __init__(self, emb_dim, expansion=4):
        super().__init__()
        hidden_dim = expansion * emb_dim
        expand = nn.Linear(emb_dim, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, emb_dim)
        self.layers = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        out = self.layers(x)
        return out

In [6]:
# One attention and one feed-forward module at GPT-2 small width; sizes come
# from GPT_CONFIG_124M instead of repeating the literals 768 / 12.
mha = MultiHeadAttention(GPT_CONFIG_124M["emb_dim"], GPT_CONFIG_124M["n_heads"])
ffn = FeedForward(GPT_CONFIG_124M["emb_dim"])

In [7]:
# Total number of parameter elements in each module.
mha_params, ffn_params = (
    sum(param.numel() for param in module.parameters()) for module in (mha, ffn)
)

In [8]:
# Report both counts and how many more parameters the feed-forward module has.
diff = ffn_params - mha_params
summary_lines = (
    f"MHA parameters: {mha_params:,}",
    f"ff parameters:         {ffn_params:,}",
    f"ff has more params than MHA by {diff:,}",
)
print("\n".join(summary_lines))

MHA parameters: 2,362,368
ff parameters:         4,722,432
ff has more params than MHA by 2,360,064


## 4.2 — Parameter counts of the GPT-2 model sizes

### Transformer block

In [9]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies attention and then the feed-forward network. Each sublayer reads a
    LayerNorm-ed input, and its output is added back onto ``x`` (residual
    connection).

    Args:
        config: Object exposing ``emb_dim`` and ``n_heads`` attributes. An
            optional ``drop_rate`` attribute (default 0.0, i.e. no dropout)
            sets dropout on each sublayer's output before the residual add.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config.emb_dim, config.n_heads)
        self.ffn = FeedForward(config.emb_dim)
        self.ln1 = nn.LayerNorm(config.emb_dim)
        self.ln2 = nn.LayerNorm(config.emb_dim)
        # getattr keeps configs without drop_rate working; nn.Dropout has no
        # parameters, so parameter counts are unaffected.
        self.drop_resid = nn.Dropout(getattr(config, "drop_rate", 0.0))

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.drop_resid(self.attn(self.ln1(x)))
        x = x + self.drop_resid(self.ffn(self.ln2(x)))
        return x

### GPT model

In [None]:
class GPTModel(nn.Module):
    """GPT-style decoder.

    The pipeline is: token embeddings plus positional embeddings, then
    ``config.n_layers`` TransformerBlocks, a final LayerNorm, and a vocabulary
    projection whose weight is tied to the token-embedding matrix.

    Args:
        config: Object exposing ``vocab_size``, ``context_length``, ``emb_dim``
            and ``n_layers``, plus whatever TransformerBlock reads from it. An
            optional ``drop_rate`` attribute (default 0.0) adds dropout after
            the embedding sum.
    """

    def __init__(self, config):
        super().__init__()
        self.context_length = config.context_length
        self.tok_emb = nn.Embedding(config.vocab_size, config.emb_dim)
        self.pos_emb = nn.Embedding(config.context_length, config.emb_dim)
        self.drop_emb = nn.Dropout(getattr(config, "drop_rate", 0.0))
        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.emb_dim)
        self.head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

        # Weight tying: head and tok_emb share one tensor, so parameters()
        # yields it once and it is counted once in parameter totals.
        self.head.weight = self.tok_emb.weight

    def forward(self, x):
        """Map token ids to logits.

        Args:
            x: Integer token ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the configured ``context_length``
                (there is no positional embedding for those positions).
        """
        _, seq_len = x.shape
        if seq_len > self.context_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {self.context_length}"
            )
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.drop_emb(self.tok_emb(x) + self.pos_emb(positions))
        hidden = self.blocks(hidden)
        hidden = self.ln_f(hidden)
        logits = self.head(hidden)
        return logits


### GPT-2 size configurations

In [11]:
class GPTConfig:
    """Attribute-style GPT hyperparameters; defaults describe GPT-2 small (124M).

    Accepts the same keys as the ``GPT_CONFIG_124M`` dict, so that dict can be
    converted with ``GPTConfig.from_dict(GPT_CONFIG_124M)``.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        context_length: Number of learned positions (maximum sequence length).
        emb_dim: Model width.
        n_heads: Attention heads per block; must be a positive divisor of ``emb_dim``.
        n_layers: Number of transformer blocks.
        drop_rate: Dropout probability, for modules that consume it.
        qkv_bias: Whether Q/K/V projections should use a bias term; stored for
            callers that pass it on to the attention module.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``emb_dim``.
    """

    def __init__(self, vocab_size=50257, context_length=1024,
                 emb_dim=768, n_heads=12, n_layers=12,
                 drop_rate=0.1, qkv_bias=False):
        if n_heads <= 0 or emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.drop_rate = drop_rate
        self.qkv_bias = qkv_bias

    @classmethod
    def from_dict(cls, params):
        """Build a config from a mapping whose keys are GPTConfig's keyword names."""
        return cls(**params)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

In [12]:
# (emb_dim, n_heads, n_layers) for each published GPT-2 size; the remaining
# hyperparameters use GPTConfig's defaults.
_GPT2_SIZES = {
    "small":  (768, 12, 12),
    "medium": (1024, 16, 24),
    "large":  (1280, 20, 36),
    "xl":     (1600, 25, 48),
}
configs = {
    size: GPTConfig(emb_dim=width, n_heads=heads, n_layers=depth)
    for size, (width, heads, depth) in _GPT2_SIZES.items()
}

In [13]:
# Build each GPT-2 size and report its parameter count (the tied
# head/embedding weight is yielded once by parameters()).
for name, cfg in configs.items():
    # Release the previous iteration's model before allocating the next one,
    # so CPython can free it first. Otherwise two models coexist at peak
    # (e.g. large + xl; xl alone is ~1.56B fp32 parameters, about 6 GB).
    # `model` still ends up bound to the last model, as before.
    model = None
    model = GPTModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"GPT-2 {name:<6}: {total_params/1e6:>8.1f}M parameters")

GPT-2 small :    124.4M parameters
GPT-2 medium:    354.8M parameters
GPT-2 large :    774.0M parameters
GPT-2 xl    :   1557.6M parameters
