In [1]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from jaxtyping import Float, Int

In [28]:
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 10_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU

	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.relu = nn.ReLU()
		self.d_vocab = cfg.d_vocab
		self.d_model = cfg.d_model
		self.d_head = cfg.d_head
		self.wq = nn.Linear(self.d_model, self.d_head)
		self.wk = nn.Linear(self.d_model, self.d_head)
		self.wv = nn.Linear(self.d_model, self.d_head)
		self.wo = nn.Linear(self.d_head, self.d_model)

	
	
	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		def masking_matrix(n_context):
			mask = torch.zeros((n_context, n_context))  # Start with all 0s
			mask[torch.triu(torch.ones((n_context, n_context)), diagonal=1) == 1] = -float('inf')  # Set above diagonal to -inf
			return mask
		
		M = masking_matrix(x.shape[0])
		# softmax_argument = x*self.wq*torch.transpose(self.wk)*torch.transpose(x) + M
		wk_out = torch.transpose(self.wk(x), 0, 1)
		# print("WK shape ", wk_out.shape)
		wq_out = self.wq(x)
		# print("WQ shape ", wq_out.shape)
		softmax_out = F.softmax((wq_out@wk_out + M), dim=-1)
		# print("Softmax shape ", softmax_out.shape)
		wv_out = self.wv(x)
		# print("WV shape ", wv_out.shape)
		wo_out = self.wo(wv_out)
		# wo_out = self.wo(wv_out)
		result = softmax_out@wo_out
		# print("Final A Shape ", result.shape)
		return result
		
		
class MultiHeadedAttention(nn.Module):

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.d_model = cfg.d_model
        self.d_head = cfg.d_head
        
        # Multi-head attention
        self.attention_heads = nn.ModuleList([AttentionHead(cfg) for _ in range(self.n_heads)])

        # Linear projection to project summed outputs back to d_model
        self.wo = nn.Linear(self.d_model, self.d_model)  # Fix the output size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_outputs = [head(x) for head in self.attention_heads]
        #print("Head output shape: ", head_outputs[0].shape)

        summed_heads = torch.sum(torch.stack(head_outputs), dim=0)  # Sum over heads -> (n_context, d_head)

        summed_heads += x  # Element-wise addition (ensures same shape)

        # Project back to d_model
        output = self.wo(summed_heads)  # (n_context, d_model)

        return output



class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()

		# config
		self.d_model = cfg.d_model
		self.d_mlp = cfg.d_mlp

		# affine transformations
		self.lin1 = nn.Linear(self.d_model, self.d_mlp)
		# with nonlinearities in between
		self.relu = nn.ReLU()
		self.lin2 = nn.Linear(self.d_mlp, self.d_model)

	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		# apply things in sequence
		out = self.lin1(x.flatten(start_dim=1))
		out = self.relu(out)
		out = self.lin2(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, cfg: GPTConfig):
		super().__init__()
		# uses `MultiHeadedAttention` and `MLP`
		self.multiheadattn = MultiHeadedAttention(cfg)
		self.mlp = MLP(cfg)

	def forward(self, x: Float[torch.Tensor, "n_context d_vocab"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		out = self.multiheadattn(x)
		out = self.mlp(out) + x
		return out
		


class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.embedding = nn.Embedding(cfg.d_vocab, cfg.d_model)
		self.unembedding = nn.Linear(cfg.d_model, cfg.d_vocab)
		self.transformer_blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
		# uses `MultiHeadedAttention` and `MLP`
		# uses nn.Embedding for the embedding, transpose of it for the unembedding

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		out = self.embedding(x)
		#print(out.shape)
		for block in self.transformer_blocks:
			out = block(out)
		out = F.softmax(self.unembedding(out), dim=-1)
		return out

(2463488,)


In [27]:
gpt_config = GPTConfig()
gpt = Transformer(gpt_config)
x = torch.randint(0, gpt_config.d_vocab, (12,))
print(x.shape)
print(gpt(x).shape)

torch.Size([12])
torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
Head output shape:  torch.Size([12, 128])
torch.Size([12, 10000])


In [26]:
# Attention Testing
# gpt_config = GPTConfig()
# attn_head = AttentionHead(gpt_config)
# x = torch.randn(256, gpt_config.d_model)
# print(x)
# print(x.shape)
# attn_head.forward(x)
# multi_head = MultiHeadedAttention(gpt_config)
# multi_head.forward(x).shape

## Training the Transformer

In [25]:
import re
some_text = """
In reality, of course, we don't construct such chains explicitly, but instead we want them to learn from data.

To put something in a markov chain or neural network, we need to turn it into numbers. this is straightforward for images: each pixel is already a number! 

In computers, text is stored as a sequence of numbers. Our neural network, in principle, can learn to predict the next number in the sequence. However, each number usually represents a single letter, or even just part of a letter. what do you think happens when we throw something like this into a markov chain?
ChatGPT said:
When we throw text data into a Markov chain, especially when it's broken down into units like single letters or even parts of letters, the chain starts to model the statistical transitions between these units. Let's break this down:

A Markov chain, by its nature, works by predicting the next state based on the current state, without any memory of the past beyond that. So, if we consider each letter or part of a letter as a state in the chain, the Markov model will simply calculate the probabilities of transitioning from one letter or part of a letter to the next.

If we use individual letters as states, what happens is that the Markov chain doesn't get the full linguistic context. It may understand that certain letters follow other letters more frequently (e.g., "e" often follows "t" in English), but it doesn't understand words or meaning. The model, at best, can approximate patterns based on letter frequencies. This leads to the production of text that might look like it's composed of English letters but doesn't follow grammatical or semantic rules.

Let's look at a basic example: If we train a Markov chain on a corpus of text and use individual letters as the states, the chain might learn something like:

“t” → “h” → “e” → “r” → “e” → “s” → “h”
This might produce something that looks a bit like the word “there,” but it might also just generate a sequence of random characters that have a similar statistical distribution, with no real coherence or meaning.

Introducing a Larger Context
To avoid generating gibberish and bring more coherence to the generated text, we need to move beyond individual characters or even parts of characters. In more advanced models, the chain would work with larger units of information: words or even entire phrases, where the chain can look at a broader context. This is essentially what neural networks, and more specifically, Recurrent Neural Networks (RNNs) and Long Short-Term Memory networks (LSTMs), do. They capture dependencies over larger sequences of text, considering not just the current word or character but the entire sequence that preceded it.

Neural Networks and Word Embeddings
Neural networks, especially those built for natural language processing, handle text much differently than a simple Markov chain. In a neural network, text is typically converted into numerical representations that capture semantic meaning, such as word embeddings. Each word (or even sub-word units) is mapped to a high-dimensional vector that represents its meaning relative to other words.

When these word embeddings are fed into a neural network, the network doesn't just learn the sequence of letters or words in isolation. It learns relationships between words and their contexts. For instance, in a neural network trained on language, the words "king" and "queen" might have similar embeddings because they share semantic relationships, such as being royalty.

Unlike a Markov chain, which might learn that certain letters follow each other based purely on statistical frequency, neural networks can learn the underlying patterns in meaning, sentence structure, and grammar. This means that they can generate more coherent text, because they understand not just which letters or words are likely to follow each other, but why they do.

Transitioning from Markov Chains to Neural Networks
Now, if we take the principles of a Markov chain and combine them with the power of neural networks, we get more advanced text generation models like GPT (Generative Pretrained Transformers), which are capable of generating fluent and contextually rich text. A Markov chain by itself has limitations, because it doesn't understand long-range dependencies well, whereas a neural network, particularly the transformer architecture, can model very long sequences of text by paying attention to different parts of the input at once, rather than just relying on the immediately preceding state.

The Problem of Predicting the Next Number
With that said, the problem of turning text into numbers isn't trivial. In a Markov chain, a system might simply predict the next unit (letter, part of a letter, or word) based on the previous one. But for a neural network, the model must map the input into vectors (numbers) that it can process. In order to do this, we typically use tokenization.

For word-based models, each word in the vocabulary is assigned a unique number, and when we input text into the model, each word is converted into its corresponding number. The neural network, then, learns to predict the next word in the sequence by considering the numbers of previous words. This is done by training the model on large datasets, where it learns to predict the next token based on the preceding tokens in the text.

However, if the text is represented in such a way that each character or part of a character is turned into a separate number, the network is then tasked with learning how to sequence these numbers, and while it can potentially learn patterns and regularities, it's much harder for the model to understand larger contextual dependencies. This is because the model lacks the broader semantic understanding that comes with grouping these parts into larger units (like words or phrases).

A Simple Example
Imagine a basic example where we input "I love programming" into a Markov chain or a neural network model that uses characters as the basic units. The network might break it down into:

"I" → " " → "l" → "o" → "v" → "e" → " " → "p" → "r" → "o" → "g" → "r" → "a" → "m" → "m" → "i" → "n" → "g"
If we trained a Markov chain to predict the next character, it might learn that after "I", a space often follows. After "l", "o" is likely to come next, and so on. While the chain could generate valid sequences of letters, they wouldn't necessarily form recognizable words or sentences. For example, it might output something like:

"I love p r o g"

which, while grammatically incomplete, still has some structure. However, it can't understand the semantics of "programming" as a whole concept.

Now, if we use a neural network that works on whole words or sub-word units (like Byte Pair Encoding, or BPE, which splits words into frequent sub-word units), the model can better capture the semantics and coherence of the sentence. In this case, the network would learn that "I" and "love" are words that often go together, and it would also know the word "programming" as a single entity rather than trying to predict each character one by one.

Deep Learning for Text Generation
In deep learning models, such as transformers or GPT-like models, we don't just predict the next word based on the immediate previous word. Instead, these models use self-attention mechanisms to look at all the words in a sentence at once. This allows them to generate more coherent and contextually appropriate sequences. For example, when generating text, the model might consider the context of the entire sentence, rather than simply predicting the next word based on the previous one. This kind of context-aware generation leads to much more fluent, natural text.

So, when we throw text into a Markov chain, we're dealing with a system that doesn't understand the relationships between words or even characters in any deep sense. It just knows the frequency of transitions. But a neural network model, by virtue of its architecture and learning approach, is able to model more complex dependencies and structures in the text, making it far more effective at generating meaningful, contextually appropriate language.

Conclusion
In conclusion, when we throw text into a Markov chain, especially one that deals with individual letters or parts of letters, we're left with a statistical model that generates sequences based on probabilities, but lacks the capacity to understand or produce meaningful language. The Markov chain would likely create some sequences that resemble text in shape, but they wouldn't be coherent or meaningful in the same way that a neural network trained on language can generate.

Neural networks, particularly when using word embeddings or advanced models like transformers, are far more capable of capturing the complex relationships in language, allowing them to generate text that is not only syntactically correct but semantically rich as well. So, while a Markov chain can be a useful model for learning simple transitions, neural networks have the edge when it comes to understanding and generating real language.
"""

def create_word_index(text):
    words = re.findall(r'\b\w+\b', text.lower())
    sorted_words = sorted(set(words))
    word_to_index = {word: idx for idx, word in enumerate(sorted_words)}

    return word_to_index

def text_to_tensor(vocab_dict, text):
    # Remove punctuation and tokenize words
    words = re.findall(r'\b\w+\b', text.lower())

    # Convert words to their corresponding integer indices
    int_sequence = [vocab_dict[word] for word in words if word in vocab_dict]

    # Convert list to a PyTorch tensor
    return torch.tensor(int_sequence, dtype=torch.long)

print(text_to_tensor(create_word_index(some_text), some_text))

tensor([151, 275, 227,  ..., 125, 274, 173])


In [24]:
def train(
        model: nn.Module, 
        optimizer: optim.Optimizer, 
        data_list: torch.tensor 
    ) -> None:
    for index, (data) in enumerate(data_list):
        # reset the gradients
        optimizer.zero_grad()
        
        # forward pass
        # input: everything but the last element of the vectorized data
        input = data[:-1]
        # target: everything but the first element of the vectorized data
        target = data[1:]

        output = model(input)
        # calculate the loss
        loss = nn.CrossEntropyLoss()(output, target)
        # backward pass, update gradient
        loss.backward()
        optimizer.step()



In [23]:
def main_train(
    model: nn.Module,
    lr: float = 0.001
) -> None:
    # vectorize text
    vectorized_text = text_to_tensor(create_word_index(some_text), some_text)
    data_list = torch.split(vectorized_text, 5)

    # set up the optimizer, based on the parameters of the model
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # train the model
    train(
        model = model,
        optimizer = optimizer,
        data_list = data_list)

    return model

In [29]:
gptconfig = GPTConfig()

trained_model = main_train(
    model = Transformer(gptconfig),
    lr = 0.001
)

input_data = text_to_tensor(create_word_index("we don't construct such chains"), some_text)
output = trained_model(input_data)
print(output)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SoftmaxBackward0>)
