In [1]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim as optim
from jaxtyping import Float, Int
import numpy as np

# Attention Head in math

$$
	A(X) = \sigma\bigg(X W_Q W_K^T X^T + M\bigg) X W_V W_O^T
$$

- $W_Q, W_K, W_V, W_O$ can be made with `nn.Linear` and will all have dimension `d_model` $\times$ `d_head`
- $M$ is a lower triangular mask matrix, look up how to do this

In [2]:
def create_mask(n_context: int) -> Float[torch.Tensor, "n_context n_context"]:
	# this should return a `n_context, n_context` matrix,
	# with zeros below and on the diag, and `-float("inf")` below
	# output = ...
	mask = np.full((n_context, n_context), -np.inf, dtype=np.float64)
	mask[np.tril_indices_from(mask)] = 0
	return torch.Tensor(mask)

In [19]:
# NOTE: it's intentional that `n_context` is not in the `GPTConfig`
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 40_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU

	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

(6303488,)


In [20]:
# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg: GPTConfig = cfg
  
		self.W_Q = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_K = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_V = nn.Linear(cfg.d_model, cfg.d_head, bias=False)
		self.W_O = nn.Linear(cfg.d_head, cfg.d_model, bias=False)

	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		n_context = x.shape[0]
		p1 = self.W_Q(x)
		p2 = self.W_K(x).transpose(0,1)
		M = create_mask(n_context)
		p3 = self.W_V(x)
		p4 = self.W_O(p3)
		return F.softmax(p1 @ p2 + M) @ p4

In [21]:
# TESTING CODE
cfg = GPTConfig()
attn_head = AttentionHead(cfg)
seq_len: int = 10
x = torch.randn(seq_len, cfg.d_model)
print(f"{x.shape = }")
y = attn_head(x)
print(f"{y.shape = }")

x.shape = torch.Size([10, 128])
y.shape = torch.Size([10, 128])


  return F.softmax(p1 @ p2 + M) @ p4


In [22]:
class MultiHeadedAttention(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.heads = nn.ModuleList([AttentionHead(cfg) for _ in range(self.cfg.n_heads)])


	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		for head in self.heads:
			x = x + head.forward(x)
		return x


class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.W_MD = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True)
		self.W_MU = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True)
		self.ReLU = self.cfg.act_fn()
  
	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		p1 = self.W_MU(x).transpose(0,1)
		p2 = self.ReLU(p1.transpose(0,1))
		p3 = self.W_MD(p2)
		return p3


class TransformerBlock(nn.Module):

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.MHA = MultiHeadedAttention(cfg)
        self.MLP = MLP(cfg)

    def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
        x = x + self.MHA.forward(x)
        x = x + self.MLP.forward(x)
        return x


class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.cfg = cfg
		self.transformerBlocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(self.cfg.n_layers)])
		self.embedding = nn.Embedding(cfg.d_vocab, cfg.d_model)

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		embedded = self.embedding(x)
		for transformerBlock in self.transformerBlocks:
			embedded = transformerBlock.forward(embedded)
		unembedded = (self.embedding.weight @ embedded.transpose(0,1)).transpose(0,1)
		return unembedded

In [23]:
# TESTING CODE
cfg = GPTConfig()
transformer = Transformer(cfg)
seq_len: int = 10
x = torch.randint(1, 9999, (10,))
print(f"{x.shape = }")
y = transformer(x)
print(f"{y.shape = }")

x.shape = torch.Size([10])
y.shape = torch.Size([10, 40000])


  return F.softmax(p1 @ p2 + M) @ p4


# Making a dataset

In [24]:
from pathlib import Path
import requests
import unicodedata

In [25]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

In [26]:
def process_text(
	text: str,
	allowed_punctuation: str = "-.,;:!?()\"" + "".join(str(x) for x in range(10)),
	punctuation_convert: dict[str,str] = {'—': '-'},
	numbers_allowed: bool = True,
) -> str:
	
	# replace some special characters which unicode won't normalize properly
	for char, replacement in punctuation_convert.items():
		text = text.replace(char, replacement)

	# if a line has ".jpg" in it, remove that line (this is specific to Don Quixote)
	text = '\n'.join(
		line 
		for line in text.split('\n')
		if '.jpg' not in line
	)

	# Normalize the string to decompose Unicode characters
	text = unicodedata.normalize('NFKD', text)

	# Encode to ASCII bytes, then decode back to string, ignoring errors
	text = text.encode('ascii', 'ignore').decode('ascii')

	# remove newlines and tabs
	text = text.replace('\n', ' ').replace('\t', ' ')


	# put spaces around allowed punctuation
	for char in allowed_punctuation:
		text = text.replace(char, f' {char} ')


	# remove leading and trailing spaces
	text = text.strip()

	# remove multiple spaces
	while '  ' in text:
		text = text.replace('  ', ' ')


	# remove all characters except (alphanumeric, allowed_punctuation, ' ')
	text = ''.join(
		(
			char 
			if (
				(char.isalnum() and numbers_allowed) or (char.isalpha())
				or char in allowed_punctuation 
				or char == ' '
			)
			else ' '
		)
		for char in text 
	)

	# convert to lowercase
	text = text.lower()

	text = text.strip()

	return text


In [27]:
def tokenize(
	text: str,
	process: bool = False,
) -> list[str]:
	if process:
		text = process_text(text)
	return [token for token in text.split(' ') if token]

def split_list(list_to_split, split_size):
    sublists = []
    for i in range(0, len(list_to_split), split_size):
        i2 = i + split_size if i + split_size < len(list_to_split) else len(list_to_split)
        sublists.append(list_to_split[i:i2])
        
    return sublists

In [28]:

DATA_RAW: list[str] = get_many_books([84, 15, 18, 82, 996, 2600])
DATA: str = " ".join(process_text(x, allowed_punctuation="", numbers_allowed=False) for x in DATA_RAW)

#print(DATA[:1000])

DATA_TOKENIZED: list[str] = tokenize(DATA)
TOKEN_SET: set[str] = set(DATA_TOKENIZED)
TOKEN_ALPHABETICAL: list[str] = sorted(list(TOKEN_SET))
TOKEN_TO_INDEX: dict[str, int] = {token: i for i, token in enumerate(TOKEN_ALPHABETICAL)}
#INDEX_TO_TOKEN: dict[int, str] = {i: token for i, token in enumerate(TOKEN_ALPHABETICAL)}
MODEL_DATA: list[int] = [TOKEN_TO_INDEX[token] for token in DATA_TOKENIZED]
MODEL_DATA_CHUNKS: list[list[int]] = split_list(MODEL_DATA, 200)

Getting book 84...
	426785 characters read
Getting book 15...
	1241025 characters read
Getting book 18...
	1192776 characters read
Getting book 82...
	1124986 characters read
Getting book 996...
	2342262 characters read
Getting book 2600...
	3273998 characters read


# Training Loop

In [None]:
#Training Loop

optimizer = optim.Adam(transformer.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

dataset = MODEL_DATA_CHUNKS

#for epoch in range(0,1):
#epoch_loss = 0
for sample in dataset:
    sample = torch.tensor(sample, dtype=torch.long)
    inputs = sample[:-1]
    targets = sample[1:]

    optimizer.zero_grad()
    outputs = transformer(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

  return F.softmax(p1 @ p2 + M) @ p4
