In [None]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim as optim
from jaxtyping import Float, Int
import numpy as np

# Attention Head in math

$$
	A(X) = \sigma\bigg(X W_Q W_K^T X^T + M\bigg) X W_V W_O^T
$$

- $W_Q, W_K, W_V, W_O$ can be made with `nn.Linear` and will all have dimension `d_model` $\times$ `d_head`
- $M$ is a lower triangular mask matrix, look up how to do this

In [38]:
def create_mask(n_context: int) -> Float[torch.Tensor, "n_context n_context"]:
	# this should return a `n_context, n_context` matrix,
	# with zeros below and on the diag, and `-float("inf")` below
	# output = ...
	mask = np.full((n_context, n_context), -np.inf, dtype=np.float64)
	mask[np.tril_indices_from(mask)] = 0
	return torch.Tensor(mask)

In [39]:
# NOTE: it's intentional that `n_context` is not in the `GPTConfig`
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 10_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU

	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

(2463488,)


In [None]:
# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg: GPTConfig = cfg
  
		self.W_Q = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_K = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_V = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_O = nn.Linear(cfg.d_head, cfg.d_model, bias=False)

	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		n_context = x.shape[0]
		p1 = self.W_Q(x)
		p2 = self.W_K(x).transpose(0,1)
		M = create_mask(n_context)
		p3 = self.W_V(x)
		p4 = self.W_O(p3)
		return F.softmax(p1 @ p2 + M) @ p4

x.shape = torch.Size([10, 128])
y.shape = torch.Size([10, 128])


  return F.softmax(p1 @ p2 + M) @ p4


In [None]:
# TESTING CODE
cfg = GPTConfig()
attn_head = AttentionHead(cfg)
seq_len: int = 10
x = torch.randn(seq_len, cfg.d_model)
print(f"{x.shape = }")
y = attn_head(x)
print(f"{y.shape = }")

In [25]:
class MultiHeadedAttention(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.heads = nn.ModuleList([AttentionHead(cfg) for _ in range(self.cfg.n_heads)])


	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		for head in self.heads:
			x = x + head.forward(x)
		return x


class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.W_MD = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True)
		self.W_MU = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True)
		self.ReLU = self.cfg.act_fn()
  
	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		p1 = self.W_MU(x).transpose(0,1)
		p2 = self.ReLU(p1.transpose(0,1))
		p3 = self.W_MD(p2)
		return p3


class TransformerBlock(nn.Module):

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.MHA = MultiHeadedAttention(cfg)
        self.MLP = MLP(cfg)

    def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
        x = x + self.MHA.forward(x)
        x = x + self.MLP.forward(x)
        return x


class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.transformerBlocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(self.cfg.n_layers)])
		self.embedding = nn.Embedding(cfg.d_vocab, cfg.d_model)

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		embedded = self.embedding(x)
		for transformerBlock in self.transformerBlocks:
			embedded = transformerBlock.forward(embedded)
		unembedded = (self.embedding.weight @ embedded.transpose(0,1)).transpose(0,1)
		return unembedded

In [26]:
# TESTING CODE
cfg = GPTConfig()
transformer = Transformer(cfg)
seq_len: int = 10
x = torch.randint(1, 9999, (10,))
print(f"{x.shape = }")
y = transformer(x)
print(f"{y.shape = }")

x.shape = torch.Size([10])
y.shape = torch.Size([10, 10000])


  return F.softmax(p1 @ p2 + M) @ p4


In [None]:
#Training Loop

optimizer = optim.Adam(transformer.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

#for epoch in range(0,1):
#epoch_loss = 0
for sample in dataset:
    inputs = sample[:-1]
    targets = sample[1:]

    optimizer.zero_grad()
    outputs = transformer(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()