In [3]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from jaxtyping import Float, Int

In [27]:
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 10_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU

	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.relu = nn.ReLU()
		self.d_vocab = cfg.d_vocab
		self.d_model = cfg.d_model
		self.d_head = cfg.d_head
		self.wq = nn.Linear(self.d_model, self.d_head)
		self.wk = nn.Linear(self.d_model, self.d_head)
		self.wv = nn.Linear(self.d_model, self.d_head)
		self.wo = nn.Linear(self.d_head, self.d_model)

	
	
	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		def masking_matrix(n_context):
			mask = torch.zeros((n_context, n_context))  # Start with all 0s
			mask[torch.triu(torch.ones((n_context, n_context)), diagonal=1) == 1] = -float('inf')  # Set above diagonal to -inf
			return mask
		
		M = masking_matrix(x.shape[0])
		# softmax_argument = x*self.wq*torch.transpose(self.wk)*torch.transpose(x) + M
		wk_out = torch.transpose(self.wk(x), 0, 1)
		# print("WK shape ", wk_out.shape)
		wq_out = self.wq(x)
		# print("WQ shape ", wq_out.shape)
		softmax_out = F.softmax((wq_out@wk_out + M), dim=-1)
		# print("Softmax shape ", softmax_out.shape)
		wv_out = self.wv(x)
		# print("WV shape ", wv_out.shape)
		wo_out = self.wo(wv_out)
		# wo_out = self.wo(wv_out)
		result = softmax_out@wo_out
		# print("Final A Shape ", result.shape)
		return result
		
		
class MultiHeadedAttention(nn.Module):

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.d_model = cfg.d_model
        self.d_head = cfg.d_head
        
        # Multi-head attention
        self.attention_heads = nn.ModuleList([AttentionHead(cfg) for _ in range(self.n_heads)])

        # Linear projection to project summed outputs back to d_model
        self.wo = nn.Linear(self.d_model, self.d_model)  # Fix the output size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_outputs = [head(x) for head in self.attention_heads]
        print("Head output shape: ", head_outputs[0].shape)

        summed_heads = torch.sum(torch.stack(head_outputs), dim=0)  # Sum over heads -> (n_context, d_head)

        summed_heads += x  # Element-wise addition (ensures same shape)

        # Project back to d_model
        output = self.wo(summed_heads)  # (n_context, d_model)

        return output



class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()

		# config
		self.d_model = cfg.d_model
		self.d_mlp = cfg.d_mlp

		# affine transformations
		self.lin1 = nn.Linear(self.d_model, self.d_mlp)
		# with nonlinearities in between
		self.relu = nn.ReLU()
		self.lin2 = nn.Linear(self.d_mlp, self.d_model)

	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		# apply things in sequence
		out = self.lin1(x.flatten(start_dim=1))
		out = self.relu(out)
		out = self.lin2(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, cfg: GPTConfig):
		super().__init__()
		# uses `MultiHeadedAttention` and `MLP`
		self.multiheadattn = MultiHeadedAttention(cfg)
		self.mlp = MLP(cfg)

	def forward(self, x: Float[torch.Tensor, "n_context d_vocab"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		out = self.multiheadattn(x)
		out = self.mlp(out) + x
		return out
		


class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.embedding = nn.Embedding(cfg.d_vocab, cfg.d_model)
		self.unembedding = nn.Linear(cfg.d_model, cfg.d_vocab)
		self.transformer_blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
		# uses `MultiHeadedAttention` and `MLP`
		# uses nn.Embedding for the embedding, transpose of it for the unembedding

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		out = self.embedding(x)
		print(out.shape)
		for block in self.transformer_blocks:
			out = block(out)
		out = F.softmax(self.unembedding(out), dim=-1)
		return out

(2463488,)


In [28]:
gpt_config = GPTConfig()
gpt = Transformer(gpt_config)
x = torch.randint(0, gpt_config.d_vocab, (12,))
print(x.shape)
print(gpt(x).shape)

torch.Size([12])
torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
torch.Size([12, 10000])


In [33]:
# Attention Testing
# gpt_config = GPTConfig()
# attn_head = AttentionHead(gpt_config)
# x = torch.randn(256, gpt_config.d_model)
# print(x)
# print(x.shape)
# attn_head.forward(x)
# multi_head = MultiHeadedAttention(gpt_config)
# multi_head.forward(x).shape

## Training the Transformer

In [25]:
import re
some_text = """
In reality, of course, we don't construct such chains explicitly, but instead we want them to learn from data.

To put something in a markov chain or neural network, we need to turn it into numbers. this is straightforward for images: each pixel is already a number! 

In computers, text is stored as a sequence of numbers. Our neural network, in principle, can learn to predict the next number in the sequence. However, each number usually represents a single letter, or even just part of a letter. what do you think happens when we throw something like this into a markov chain?
"""

def create_word_index(text):
    words = re.findall(r'\b\w+\b', text.lower())
    sorted_words = sorted(set(words))
    word_to_index = {word: idx for idx, word in enumerate(sorted_words)}

    return word_to_index

def text_to_tensor(vocab_dict, text):
    # Remove punctuation and tokenize words
    words = re.findall(r'\b\w+\b', text.lower())

    # Convert words to their corresponding integer indices
    int_sequence = [vocab_dict[word] for word in words if word in vocab_dict]

    # Convert list to a PyTorch tensor
    return torch.tensor(int_sequence, dtype=torch.long)

print(text_to_tensor(create_word_index(some_text), some_text))

tensor([21, 45, 37,  9, 64, 12, 53,  8, 52,  6, 15,  3, 22, 64, 63, 56, 60, 27,
        17, 10, 60, 44, 49, 21,  0, 30,  5, 38, 33, 32, 64, 31, 60, 61, 25, 23,
        36, 58, 24, 51, 16, 20, 13, 41, 24,  1,  0, 35, 21,  7, 54, 24, 50,  2,
         0, 47, 37, 36, 39, 33, 32, 21, 43,  4, 27, 60, 42, 55, 34, 35, 21, 55,
        47, 19, 13, 35, 62, 46,  0, 48, 28, 38, 14, 26, 40, 37,  0, 28, 65, 11,
        67, 57, 18, 66, 64, 59, 49, 29, 58, 23,  0, 30,  5])


In [21]:
def train(
        model: nn.Module, 
        optimizer: optim.Optimizer, 
        data_list: torch.tensor 
    ) -> None:
    for index, (data) in enumerate(data_list):
        # reset the gradients
        optimizer.zero_grad()
        
        # forward pass
        # input: everything but the last element of the vectorized data
        input = data[:-1]
        # target: everything but the first element of the vectorized data
        target = data[1:]

        output = model(input)
        # calculate the loss
        loss = nn.CrossEntropyLoss()(output, target)
        # backward pass, update gradient
        loss.backward()
        optimizer.step()



In [38]:
def main_train(
    model: nn.Module,
    lr: float = 0.001
) -> None:
    # vectorize text (I am assuming we should do this with various different text blocks, rather than repeatedly vectorizing some_text)
    vectorized_text_1 = text_to_tensor(create_word_index(some_text), some_text)
    vectorized_text_2 = text_to_tensor(create_word_index(some_text), some_text)
    vectorized_text_3 = text_to_tensor(create_word_index(some_text), some_text)
    data_list = [vectorized_text_1, vectorized_text_2, vectorized_text_3]

    # set up the optimizer, based on the parameters of the model
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # train the model
    train(
        model = model,
        optimizer = optimizer,
        data_list = data_list)

    return model

In [39]:
gptconfig = GPTConfig()

trained_model = main_train(
    model = Transformer(gptconfig),
    lr = 0.001
)

tensor([21, 45, 37,  9, 64, 12, 53,  8, 52,  6, 15,  3, 22, 64, 63, 56, 60, 27,
        17, 10, 60, 44, 49, 21,  0, 30,  5, 38, 33, 32, 64, 31, 60, 61, 25, 23,
        36, 58, 24, 51, 16, 20, 13, 41, 24,  1,  0, 35, 21,  7, 54, 24, 50,  2,
         0, 47, 37, 36, 39, 33, 32, 21, 43,  4, 27, 60, 42, 55, 34, 35, 21, 55,
        47, 19, 13, 35, 62, 46,  0, 48, 28, 38, 14, 26, 40, 37,  0, 28, 65, 11,
        67, 57, 18, 66, 64, 59, 49, 29, 58, 23,  0, 30,  5])
tensor([21, 45, 37,  9, 64, 12, 53,  8, 52,  6, 15,  3, 22, 64, 63, 56, 60, 27,
        17, 10, 60, 44, 49, 21,  0, 30,  5, 38, 33, 32, 64, 31, 60, 61, 25, 23,
        36, 58, 24, 51, 16, 20, 13, 41, 24,  1,  0, 35, 21,  7, 54, 24, 50,  2,
         0, 47, 37, 36, 39, 33, 32, 21, 43,  4, 27, 60, 42, 55, 34, 35, 21, 55,
        47, 19, 13, 35, 62, 46,  0, 48, 28, 38, 14, 26, 40, 37,  0, 28, 65, 11,
        67, 57, 18, 66, 64, 59, 49, 29, 58, 23,  0, 30])
torch.Size([102, 128])
Head output shape:  torch.Size([102, 128])
tensor([[-0.4664