In [None]:
import torch

import torch.nn as nn
from torch.nn import functional as F

In [3]:
# Hyperparameters for the walkthrough.
B, T = 32, 10     # batch size, sequence length
vocab_size = 255  # number of distinct token ids
C, n_head = 32, 4  # embedding width (channels), number of attention heads
# The channels have to split evenly across the heads.
assert C % n_head == 0
hs = C // n_head  # per-head size
# Display the head size as the cell output.
hs

8

Transformer

<img src="transformer.jpg" width="300px">

$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$

In [4]:
# All layers are created up front as free-standing modules; the forward-pass
# cell wires them together by hand.

# Embeddings: one table for token ids, one for positions 0..T-1
wte = nn.Embedding(num_embeddings=vocab_size, embedding_dim=C)
wpe = nn.Embedding(num_embeddings=T, embedding_dim=C)

# Transformer: input dropout, final layer norm, and projection to vocabulary logits
transformer_dropout = nn.Dropout(p=0.2)
transformer_norm = nn.LayerNorm(normalized_shape=C)
transformer_fc = nn.Linear(in_features=C, out_features=vocab_size)

# Block: pre-attention / pre-MLP layer norms and the 4x-wide feed-forward network
block_norm_1 = nn.LayerNorm(normalized_shape=C)
block_norm_2 = nn.LayerNorm(normalized_shape=C)
block_fc = nn.Linear(in_features=C, out_features=4 * C)
block_act = nn.ReLU()
block_proj = nn.Linear(in_features=4 * C, out_features=C)
block_dropout = nn.Dropout(p=0.2)

# Attention: a single linear layer emits the three C-wide projections at once
c_attn = nn.Linear(in_features=C, out_features=3 * C)
attn_dropout = nn.Dropout(p=0.2)
# Lower-triangular ones, shaped (1, 1, T, T) so it broadcasts over batch and heads
mask = torch.ones(T, T).tril().view(1, 1, T, T)
c_proj = nn.Linear(in_features=C, out_features=C)
resid_dropout = nn.Dropout(p=0.2)

In [None]:
# Manual forward pass: embeddings -> one transformer block (attention + MLP,
# each with a residual connection) -> final norm and vocabulary projection.
# Shape assertions after each stage document the tensor layout.

# Raw input tokens: random ids in [0, vocab_size)
x = torch.randint(low=0, high=vocab_size, size=(B, T))
assert x.shape == (B, T)
# Positions of input tokens (one row, shared by the whole batch)
p = torch.arange(0, T, dtype=torch.long).view(1, T)
assert p.shape == (1, T)
# Token embeddings
x = wte(x)
assert x.shape == (B, T, C)
# Positional embeddings
p = wpe(p)
assert p.shape == (1, T, C)
# Combine token and positional embeddings w/ broadcast over the batch dimension.
# Out-of-place add, so the embedding output tracked by autograd is not mutated.
x = x + p
# Transformer-level dropout
x = transformer_dropout(x)
assert x.shape == (B, T, C)

# ---- Attention sub-layer ----
# Attention residual pathway. No copy is needed: x is only rebound below,
# never modified in place, so y keeps pointing at the pre-norm activations.
y = x
# Attention layer norm
x = block_norm_1(x)
# Attention projection: one matmul produces all three C-wide projections
x = c_attn(x)
assert x.shape == (B, T, 3 * C)
# Split the projection into key, query, and value tensors
k, q, v = x.split(C, dim=2)
assert k.shape == q.shape == v.shape == (B, T, C)
# Introduce a head "batch" dimension: (B, T, C) -> (B, n_head, T, hs)
k = k.view(B, T, n_head, hs).transpose(1, 2)
q = q.view(B, T, n_head, hs).transpose(1, 2)
v = v.view(B, T, n_head, hs).transpose(1, 2)
assert k.shape == q.shape == v.shape == (B, n_head, T, hs)
# Scaled dot-product scores: QK^T / sqrt(d_k), with d_k == hs.
# Parenthesised so the order of evaluation does not hinge on operator precedence.
att = (q @ k.transpose(-2, -1)) * hs ** -0.5
assert att.shape == (B, n_head, T, T)
# Apply the mask for causal self-attention: position t may only see positions <= t
att = torch.masked_fill(att, mask == 0, float('-inf'))
# Softmax over the key dimension so each query's weights sum to 1
att = F.softmax(att, dim=-1)
assert torch.allclose(att.sum(-1), torch.ones(B, n_head, T))
# Randomly drop some of the token-to-token communication
att = attn_dropout(att)
# Apply attention weights to the values
x = att @ v
assert x.shape == (B, n_head, T, hs)
# Merge the heads back together: (B, n_head, T, hs) -> (B, T, C)
x = x.transpose(1, 2).contiguous().view(B, T, C)
assert x.shape == (B, T, C)
# Output projection of the attention sub-layer
x = c_proj(x)
assert x.shape == (B, T, C)
x = resid_dropout(x)
# Residual connection (out-of-place; y is read-only so it needs no clone)
x = x + y

# ---- MLP sub-layer ----
# MLP residual pathway
y = x
# Layer norm -> expand to 4 * C -> ReLU -> project back to C -> dropout
x = block_norm_2(x)
x = block_fc(x)
x = block_act(x)
x = block_proj(x)
x = block_dropout(x)
# Residual connection
x = x + y

# ---- Finish transformer ----
# Final layer norm, then project every position to vocabulary logits
x = transformer_norm(x)
x = transformer_fc(x)
assert x.shape == (B, T, vocab_size)