# GPT from scratch

Follows [this](https://www.youtube.com/watch?v=kCc8FmEb1nY) video tutorial.

In [2]:
# !curl -O https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [3]:
# Read the whole corpus into memory as one string. The encoding is given
# explicitly so the result does not depend on the platform's default locale.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print(f"Length of text: {len(text)}")
print(text[:100])

Length of text: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


We need to create our vocabulary.

In [5]:
# The vocabulary is the set of distinct characters in the corpus.
# Sorting gives every character a stable integer id: set iteration order
# is not guaranteed, but the sorted order is.
chars = sorted(set(text))
vocab_size = len(chars)
# character -> integer id lookup
vocab = {c: i for i, c in enumerate(chars)}
print(len(vocab)) # 65 unique chars.

65


We want to now be able to tokenize the text somehow. Here, we're building a character-level model, so we're just translating individual characters into integers.

In [6]:
# string -> integer and integer -> string lookup tables, both built from
# the sorted character list so that they are exact inverses of each other.
stoi = {c: i for i, c in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(x):
    """Take a string, return a list of integers (one id per character).

    Raises KeyError if a character is not in the vocabulary.
    """
    return [stoi[c] for c in x]


def decode(x):
    """Take an iterable of integers, return the corresponding string.

    Raises KeyError if an id is not in the vocabulary.
    """
    return ''.join(itos[i] for i in x)

In [7]:
print(encode("Hello this is a test string"))
print(decode(encode("Hello this is a test string")))

[20, 43, 50, 50, 53, 1, 58, 46, 47, 57, 1, 47, 57, 1, 39, 1, 58, 43, 57, 58, 1, 57, 58, 56, 47, 52, 45]
Hello this is a test string


We can now tokenize the entire text dataset

In [8]:
import torch

In [9]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [10]:
# let's also split out data into train and validation datasets
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

We can't pass in the entire dataset in one iteration of the transformer, so we do need to set a "block size" (AKA context length).

In [11]:
block_size = 8

Let's take a look at one block

In [12]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

### Training using chunks of the context window
A transformer makes simultaneous predictions for each of these positions. This makes training much more efficient since we treat each position as a separate prediction task instead of just predicting the next character after the last position, which greatly increases the number of samples used for training.

This is why we take `[:block_size+1]`: if we have 8 input chars, we need to look at 9 chars so that we can always predict the char that comes after each input char.

We can actually display what the prediction task is: given a block of text, we predict what token is at each position. The input passed into each prediction task is whatever tokens were in the block before this target.

In [13]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 47
When input is tensor([18, 47]), the target is: 56
When input is tensor([18, 47, 56]), the target is: 57
When input is tensor([18, 47, 56, 57]), the target is: 58
When input is tensor([18, 47, 56, 57, 58]), the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 58


So, in a chunk of 9 characters, we essentially have 8 training examples. This not only gives us more training samples, but it also makes the Transformer used to seeing context sizes from as little as 1 all the way to the context length.

### Batching

We also need to care about the batching dimension as well, since we'll be feeding the input as batches of tensors. This lets us process more inputs in parallel, especially since GPUs are great at working with batches of data.

In [14]:
torch.manual_seed(1337)

<torch._C.Generator at 0x10f29b030>

In [15]:
# number of sequences in a batch, processed in parallel.
batch_size = 4

# maximum context length for prediction
block_size = 8

In [16]:
def get_batch(split):
    """Sample one random batch of (input, target) blocks.

    `split` selects the source: "train" reads from `train_data`, any other
    value reads from `val_data`. Returns two tensors stacked from
    `batch_size` windows of `block_size` tokens each; the target window is
    the input window shifted one position to the right.
    """
    source = train_data if split == "train" else val_data

    # Draw `batch_size` random start offsets. The upper bound leaves room
    # for a full input window plus the one extra token the shifted target
    # window needs.
    highest_start = len(source) - block_size
    starts = torch.randint(highest_start, (batch_size,))

    inputs = []
    targets = []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

In [17]:
xb, yb = get_batch(split="train")

Let's now look at our batch.

In [18]:
for b in range(batch_size):
    print(f"Batch {b}")
    print('-' * 10)
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context}, the target is: {target}")
    print('-' * 10)

Batch 0
----------
When input is tensor([24]), the target is: 43
When input is tensor([24, 43]), the target is: 58
When input is tensor([24, 43, 58]), the target is: 5
When input is tensor([24, 43, 58,  5]), the target is: 57
When input is tensor([24, 43, 58,  5, 57]), the target is: 1
When input is tensor([24, 43, 58,  5, 57,  1]), the target is: 46
When input is tensor([24, 43, 58,  5, 57,  1, 46]), the target is: 43
When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]), the target is: 39
----------
Batch 1
----------
When input is tensor([44]), the target is: 53
When input is tensor([44, 53]), the target is: 56
When input is tensor([44, 53, 56]), the target is: 1
When input is tensor([44, 53, 56,  1]), the target is: 58
When input is tensor([44, 53, 56,  1, 58]), the target is: 46
When input is tensor([44, 53, 56,  1, 58, 46]), the target is: 39
When input is tensor([44, 53, 56,  1, 58, 46, 39]), the target is: 58
When input is tensor([44, 53, 56,  1, 58, 46, 39, 58]), the target i

Each of these batch samples becomes a row in the data. What we end up getting is a `(batch_size, block_size)`-shaped tensor for our training data. For batch_size=4 and block_size=8, this leads to a 4x8 tensor.

Because we predict the next token for each position in the block, we get 8 training samples out of a block of size=8. This means that in this batch, we have 32 training samples.

In [19]:
len(y)

8

### Starting neural network training with the simplest model: bigrams
Now we can start using this input for neural network training. We can start with the simplest model: the bigram.

#### What is a bigram model?
A bigram model is a model that predicts the next word given the previous word in a sentence. It assumes that the probability of the next word is *only affected by the previous word*.

For our case, we'll use a bigram character model, where the probability of a given character depends only on the character that came immediately before it.

In [20]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x10f29b030>

In [21]:
class BigramLanguageModel(nn.Module):
    """Character bigram model backed by a single lookup table.

    Row `i` of the table holds the next-token logits for token id `i`,
    so the table is square: (vocab_size, vocab_size).
    """

    def __init__(self, vocab_size):
        super().__init__()
        # One row of next-token logits per vocabulary entry.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx):
        """Return the next-token logits for every token id in `idx`.

        The output has the shape of `idx` plus a trailing dimension of
        size vocab_size, e.g. (B, T) -> (B, T, C).
        """
        return self.token_embedding_table(idx)

    def predict(self, x):
        """Return the most likely next token id for every position in `x`."""
        next_token_logits = self(x)
        # Reduce over the vocabulary dimension to pick the best-scoring id.
        return torch.argmax(next_token_logits, dim=-1)


Let's see how this would look so far.

In [22]:
model = BigramLanguageModel(vocab_size=vocab_size)
print(f"{model.token_embedding_table.weight.shape=}")

model.token_embedding_table.weight.shape=torch.Size([65, 65])


Let's make a naive prediction. Given a particular index in our vocabulary, let's see what the predicted output is:

In [23]:
test_idx = 5
logits = model(torch.tensor([test_idx]))
prediction = model.predict(torch.tensor([test_idx]))
print(f"Input index: {test_idx}, Prediction index: {prediction.item()}")
# tensor of shape (1, vocab_size), with each element being the
# logit for each token in the vocabulary at the given index.
print(f"Output logits shape: {logits.shape}")
# we take the max index of the logits to get the
# predicted token.
print(f"Input: {itos[test_idx]}, Prediction: {itos[prediction.item()]}")

Input index: 5, Prediction index: 39
Output logits shape: torch.Size([1, 65])
Input: ', Prediction: a


Let's try it on a batch of texts. Let's get the accuracy based on a random set of results. When we run `forward`, we get a result of shape `(batch_size, block_size, vocab_size)`. We get a `[vocab_size]`-shaped tensor for all positions in our `[batch_size, block_size]`-shaped batch.

In [24]:
batch_logits = model(xb)
preds = model.predict(xb)

In [25]:
# batch logits shape = (batch_size, block_size, vocab_size)
print(f"{batch_logits.shape=}")

# preds shape = (batch_size, block_size) because we
# take the argmax on the -1 dim to get the predicted token.
print(f"{preds.shape=}")
print(f"{yb.shape=}")
print(f"Accuracy: {(preds == yb).float().mean()}")

batch_logits.shape=torch.Size([4, 8, 65])
preds.shape=torch.Size([4, 8])
yb.shape=torch.Size([4, 8])
Accuracy: 0.0


Our model has an accuracy of 0% on this batch. This makes sense given that our model is randomly initialized: with 65 possible characters, random guessing is right only about 1/65 ≈ 1.5% of the time.

Of course this doesn't have any actual basis, since this is based on a randomly initialized lookup table.

Let's add a loss function to our model so we can have a way to evaluate "how good" the predictions are.

In [26]:
class BigramLanguageModel(nn.Module):
    """Bigram character model with an optional training loss.

    Each token id indexes one row of a (vocab_size, vocab_size) lookup
    table, and that row is used directly as the logits for the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # read off the logits for the next token
        # from a lookup table.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Forward pass.

        Args:
            idx: integer tensor of token ids. Must be (B, T) when
                `targets` is given.
            targets: optional integer tensor of next-token ids with the
                same shape as `idx`.

        Returns:
            A `(logits, loss)` tuple. Without targets, `logits` keeps the
            shape of `idx` plus a trailing vocab dimension, i.e. (B, T, C),
            and `loss` is None. With targets, `logits` is flattened to
            (B*T, C) and `loss` is the cross-entropy over all B*T positions.
        """
        # get the logits for the next token.
        # for a given batch with shape (B, T), we will get an output
        # of shape (B, T, C), where B = batch size, T = sequence length
        # and C = vocab size.
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # F.cross_entropy expects (N, C) logits and (N,) targets,
            # so fold the batch and time dimensions together.
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # compute the loss
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def predict(self, x):
        """Predict the most likely next token id for every position in x."""
        # forward returns (logits, loss); only the logits are needed here.
        # Without targets the loss is always None.
        logits, _ = self(x)
        # get argmax over the vocabulary dimension
        return logits.argmax(dim=-1)


Now let's get the predictions and loss. We get the logits for all `[batch_size, block_size]` characters as well as the loss averaged over all of those positions.

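We're expecting something like `-ln(1/65) ~ 4.17`, which is what the loss would be if the model assigned a uniform probability of `1/65` to every character (the best "know-nothing" prediction for a vocab size of 65). A randomly initialized model will typically score somewhat worse than this, because its random logits are confidently wrong for some characters.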

In [27]:
model = BigramLanguageModel(vocab_size=vocab_size)
logits, loss = model(xb, yb)
print(f"{logits.shape=}")
print(f"{loss=}")

logits.shape=torch.Size([32, 65])
loss=tensor(4.6630, grad_fn=<NllLossBackward0>)


Now that we can evaluate the loss, we'd like to also be able to do generation from the model.

In [28]:
class BigramLanguageModel(nn.Module):
    """Bigram character model with a training loss and sampling.

    Each token id indexes one row of a (vocab_size, vocab_size) lookup
    table, and that row is used directly as the logits for the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # read off the logits for the next token
        # from a lookup table.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Forward pass.

        Args:
            idx: integer tensor of token ids. Must be (B, T) when
                `targets` is given.
            targets: optional integer tensor of next-token ids with the
                same shape as `idx`.

        Returns:
            A `(logits, loss)` tuple. Without targets, `logits` keeps the
            shape of `idx` plus a trailing vocab dimension, i.e. (B, T, C),
            and `loss` is None. With targets, `logits` is flattened to
            (B*T, C) and `loss` is the cross-entropy over all B*T positions.
        """
        # get the logits for the next token.
        # for a given batch with shape (B, T), we will get an output
        # of shape (B, T, C), where B = batch size, T = sequence length
        # and C = vocab size.
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # F.cross_entropy expects (N, C) logits and (N,) targets,
            # so fold the batch and time dimensions together.
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # compute the loss
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def predict(self, x):
        """Predict the most likely next token id for every position in x."""
        # forward returns (logits, loss); only the logits are needed here.
        # Without targets the loss is always None.
        logits, _ = self(x)
        # get argmax over the vocabulary dimension
        return logits.argmax(dim=-1)

    def generate(self, idx, max_new_tokens):
        """Generate new tokens.

        idx is (B,T) shaped tensor with array of indices
        that are in the current context.

        At each step, we predict the next token and add it to the context.

        Algorithm:
            - We get the logits for the last token in the context for each
            sequence in the batch.
            - We apply softmax to get probabilities for the next token, for
            each sequence in the batch.
            - We sample from the distribution to get the next token for each
            sequence in the batch.
            - We add the new token to the context and repeat.

        Returns:
            A (B, T + max_new_tokens) tensor: the original context followed
            by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx)
            # focus on last time step
            # becomes (B, C) since we take the logits of the
            # last position in each sequence of the batch.
            logits = logits[:, -1, :]
            # apply softmax to get probabilities.
            # (B, C), but now values are probabilities.
            # Each row is a probability distribution whose values add
            # up to 1.
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            # (B, 1). We get the next token for each sample in the batch.
            next_token = torch.multinomial(probs, num_samples=1)
            # we add the new token to the context and repeat.
            # idx originally (B, T), next_token (B, 1), so when we concatenate
            # we get (B, T+1)
            idx = torch.cat([idx, next_token], dim=1) # (B, T+1)

        # shape of idx is (B, T+max_new_tokens) since we've added
        # max_new_tokens to the context.
        return idx


Let's see how well this works.

In [29]:
model = BigramLanguageModel(vocab_size=vocab_size)

In [30]:
max_new_tokens = 100
generated_tokens = model.generate(xb, max_new_tokens)
# batch_size of 4
# 8 initial tokens
# => (4, 8)
print(f"{xb.shape=}")
# batch_size of 4
# for each sequence, 8 initial tokens + 100 generated tokens.
# => (4, 108)
print(f"{generated_tokens.shape=}")

xb.shape=torch.Size([4, 8])
generated_tokens.shape=torch.Size([4, 108])


Let's decode both of these to see the actual letters. First, let's look at our batch.

In [31]:
# Turn every sequence of token ids in the batch back into text, reusing
# the `decode` helper instead of growing a string one character at a time.
decoded_batch = [decode(seq.tolist()) for seq in xb]

for i, seq in enumerate(decoded_batch):
    print(f"String {i}: {seq}")

String 0: Let's he
String 1: for that
String 2: nt that 
String 3: MEO:
I p


Now let's look at our generated text.

In [32]:
# Turn every generated sequence of token ids back into text, reusing the
# `decode` helper instead of growing a string one character at a time.
decoded_generated_text = [decode(seq.tolist()) for seq in generated_tokens]

# Show each original context next to the text generated from it.
for i, (seq, generated_seq) in enumerate(
    zip(decoded_batch, decoded_generated_text)
):
    print(f"String {i}: {seq} -> {generated_seq}\n")
    print('-' * 10)

String 0: Let's he -> Let's heO;$XXId TuoJck,rsLYMOY:HA!;UQZMAOBt&oNk,.?Cn.$pXWKCaL3uc.Y:lsL3IKcSjMWixZfhqVQZYjJpKF!lcE-HHsejq$ntj

----------
String 1: for that -> for thatwckIGObSYrDZOEgqxYDjeRjEISpDSsu
dFV:PDBOkCBhS,.zk'.!:TfrWx$
dXHStdYKt?aMmQuaOl;UobB;MWi?V-HFYx, !wtd

----------
String 2: nt that  -> nt that WyZvFrk,,,i:thNdlBJZIVNSqPe,,KaHtBIE?Gc!
y:TtMwKbL:FBt!H$CfVDSjY&$hzGcbSoc,HbKAcp'tBR'wLGck.jC'H,llM

----------
String 3: MEO:
I p -> MEO:
I piEgWKckxrZIckeIhlk,.JOHAa?kYauc.ErucA R,BAy?bECGmX?QJMLFFmNdXGBRfRUH;U'cA.YxJuZE$leb&I;Z3VqTK&-BqWBz

----------


As we can see, the strings that are generated are a bunch of random-looking characters. This is because we randomly generated the weights of our lookup table.

Let's set up the logic to train our model then so that it becomes less random.

We'll set up what the training logic will look like. The current results from training won't be very impressive, but we'll use the same logic and swap in more powerful training logic later.

In [33]:
lr = 1e-3
batch_size = 32
n_epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [34]:
# Minimal training loop. Note that "epoch" here means one optimizer step
# on one randomly sampled mini-batch, not a full pass over the dataset.
for epoch in range(n_epochs):
    # sample a batch of data
    xb, yb = get_batch(split="train")

    # evaluate loss
    logits, loss = model(xb, yb)
    # clear gradients left over from the previous step before
    # backpropagating, then apply one parameter update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # the printed value is the loss of this single batch, so it is noisy
    # and will not decrease monotonically.
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 4.714339733123779
Epoch 10, Loss: 4.676767349243164
Epoch 20, Loss: 4.701565265655518
Epoch 30, Loss: 4.765804290771484
Epoch 40, Loss: 4.667387008666992
Epoch 50, Loss: 4.631993293762207
Epoch 60, Loss: 4.670619487762451
Epoch 70, Loss: 4.655516624450684
Epoch 80, Loss: 4.576014995574951
Epoch 90, Loss: 4.6416544914245605


This is our general logic. The loss doesn't go down by much here. Let's try to train for more iterations.

Our training loop has the effect of updating the initial weights in our lookup table such that they'll more accurately reflect what we expect.

Let's train this for many more iterations though since the update step is pretty simple, in order to see if we can get a more useful result.

In [35]:
# Same training loop, run for 10000 steps. It keeps using the existing
# `model` and `optimizer` objects. As before, "epoch" means one optimizer
# step on one randomly sampled mini-batch, not a full pass over the dataset.
for epoch in range(10000):
    # sample a batch of data
    xb, yb = get_batch(split="train")

    # evaluate loss
    logits, loss = model(xb, yb)
    # clear gradients left over from the previous step before
    # backpropagating, then apply one parameter update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # the printed value is the loss of this single batch, so it is noisy
    # and will not decrease monotonically.
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 4.506431579589844
Epoch 1000, Loss: 3.6920104026794434
Epoch 2000, Loss: 3.0407776832580566
Epoch 3000, Loss: 2.6942100524902344
Epoch 4000, Loss: 2.6889488697052
Epoch 5000, Loss: 2.5391829013824463
Epoch 6000, Loss: 2.5220439434051514
Epoch 7000, Loss: 2.5135819911956787
Epoch 8000, Loss: 2.623948335647583
Epoch 9000, Loss: 2.479816198348999


Let's take a look at our results.

It's clear that even though the results are still pretty random, it's much better than the original results. It's beginning to at least take the structure and lengths that you expect in a Shakespeare play. Words are still pretty random but it looks a lot more like an encrypted message than a random string of characters. Heck, it even appears to get close to resembling some text.

This is a promising approach, it's impressive that we can begin to get somewhat plausible text just from a simple bigram model combined with backpropagation. We don't even include any other context besides what character came earlier! During backpropagation, the gradients are computed for the selected rows in the embedding lookup table, and then the optimizer updates these rows based on the gradients. Across epochs, this has the effect of updating the rows of the embedding lookup table so that they more appropriately represent the probability of next characters given a certain reference index character.

In [36]:
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


He bldese bad f anthtourmand rseno H:
d wel.
Se, theananonge!
I, t the ksoth rmorad'Thathany tw t theareace 'K:
Bur f mowetooutwis K:
QUS:
Yo one glle:
TOKITESe imo'd y al intr ancV:
ERCHeeanthoulallian, n, t, w;
Q:

AMENCAndoound I oyofe.
Anl ouke givapat be meat,-kers,
CHARI shears m gmavencl hs
S:
ABo e t f I'dYBEETIne, s theom gr tenouthonomoreand by maw cextht,VI,
Sig I pENGI r frtite s t brse milddor iv, wind,
Hano, t,

's ethace hekanoure saico heneef m:
Whed nce INThe llas'd Bl oungollto


We can make this much more powerful though if we can use more of the context than just what character came before the one that we're trying to predict.

### The mathematical trick in self-attention

Let's consider the following example:

In [37]:
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C)
print(f"{x.shape=}")

x.shape=torch.Size([4, 8, 2])


Currently, the 8 tokens in each sequence are not talking to each other. We want the tokens to talk to each other, but in a very specific way.

For example, if we are at the 5th token, we want to use the 1st-4th tokens as context, but not any future tokens.

What is an easy way to encapsulate the "context"? One naive way is to just take the average of the tokens before the $t^{th}$ token and use that as the context for the $t^{th}$ token.

If I'm the 5th token, for example, I want to take the information from my time step but also steps 1-4 as well, and then take a weighted average of those, in order to create a feature vector that "summarizes" the 5th token.

This obviously loses a lot of information, especially positional information. It suffers from many of the same flaws as, say, one-hot encoding or bag of words, namely that you can't capture position. For example, "the cat is jumping over the dog" and "the dog is jumping over the cat" are treated identically because they have the same words. We take the "average representation" of all the words up to and including the $t^{th}$ token.

In [38]:
# Naive reference implementation of causal averaging: position t of each
# sequence is replaced by the mean of positions 0..t of that sequence.
xbow = torch.zeros((B,T,C)) # bag of words
# for each sequence in the batch
for b in range(B):
    # for each token in the sequence
    for t in range(T):
        # get the current token and all previous tokens.
        x_prev = x[b,:t+1] # (t+1, C): positions 0..t inclusive
        # average over the time dimension -> (C,)
        x_mean = x_prev.mean(dim=0)
        # store the running mean at the current position.
        xbow[b,t] = x_mean

# we want xbow[b,t] = mean_{i<=t} x[b,i]

Let's take a look at how this looks:

In [39]:
print(f"{xbow.shape=}")

xbow.shape=torch.Size([4, 8, 2])


Let's see how this looks for, say, the 1st sequence. We see that for the first sequence, we get a result with shape `[8,2]`.

In [40]:
print(f"{x[0,:,:].shape=}")

x[0,:,:].shape=torch.Size([8, 2])


Let's figure out how we got the value at, say, `[0, 4,:]`, the 5th element in the tensor.

In [41]:
print(f"{x[0,4,:]=}") # original value
print(f"{xbow[0,4,:]=}") # averaged value

x[0,4,:]=tensor([0.3612, 1.1679])
xbow[0,4,:]=tensor([0.3525, 0.0545])


Let's see what the elements are from the 1st to the 5th elements and then average them.

In [42]:
print(f"{x[0,4,:]=}")

x[0,4,:]=tensor([0.3612, 1.1679])


In [43]:
print(f"{x[0,0:5,:]=}")
print(f"{x[0,:5].mean(dim=0)=}") # matches our averaged value.

x[0,0:5,:]=tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679]])
x[0,:5].mean(dim=0)=tensor([0.3525, 0.0545])


We took the feature vectors of the 1st to 5th tokens, averaged them, and that became our representation for the 5th token.

#### The trick: migrating this to matrix multiplication
We can actually do this quite efficiently if we use matrix multiplication.

Let's see how this can work.

In [44]:
torch.manual_seed(42)
a = torch.ones((3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f"{a=}")
print(f"{a.shape=}")
print(f"{b=}")
print(f"{b.shape=}")
print(f"{c=}")
print(f"{c.shape=}")

a=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
a.shape=torch.Size([3, 3])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
b.shape=torch.Size([3, 2])
c=tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
c.shape=torch.Size([3, 2])


Multiplying by a matrix of ones naturally adds all of the elements per column:

$$
\left[
\begin{matrix}
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
2\\
6\\
6\\
\end{matrix}
\right]
= 14
\rightarrow
\left[
\begin{matrix}
1&1&1\\
1&1&1\\
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
2\\
6\\
6\\
\end{matrix}
\right]
=
\left[
\begin{matrix}
14\\
14\\
14\\
\end{matrix}
\right]
$$

$$
\left[
\begin{matrix}
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
7\\
4\\
5\\
\end{matrix}
\right]
= 16
\rightarrow
\left[
\begin{matrix}
1&1&1\\
1&1&1\\
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
7\\
4\\
5\\
\end{matrix}
\right]
=
\left[
\begin{matrix}
16\\
16\\
16\\
\end{matrix}
\right]
$$

We can then simply average these, if we wanted, by taking any of the values per column and dividing by the number of rows that were summed (here, 3):
$$
\left[
\begin{matrix}
\frac{14}{3}&\frac{16}{3}\\
\end{matrix}
\right]
$$

We've now established that we can easily get the sums of a tensor via matrix multiplication.

Let's now look at another piece for the calculation: `torch.tril`.

In [45]:
# lower triangular matrix
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

Now let's replicate our logic from above.

In [46]:
a = torch.tril(torch.ones(3,3))

In [47]:
c = a @ b
print(f"{a=}")
print(f"{b=}")
print(f"{c=}")
print(f"{c.shape=}")

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
c.shape=torch.Size([3, 2])


Let's break down what we're looking at by looking at the output matrix.

The first output row is $\left[2,7\right]$. We get this through the matrix multiplication of:
$$
\left[
1,0,0
\right]
*
\left[
\begin{matrix}
2&7\\
6&4\\
6&5\\
\end{matrix}
\right]
\rightarrow
\left[(1*2 + 0*6 + 0*6),(1*7 + 0*4 + 0*5)\right]
\rightarrow
\left[2, 7\right]
$$

We take the sum of the first element and zero out the rest.

The second output row is $\left[8, 11\right]$. We get this through the matrix multiplication of:
$$
\left[
1,1,0
\right]
*
\left[
\begin{matrix}
2&7\\
6&4\\
6&5\\
\end{matrix}
\right]
\rightarrow
\left[(1*2 + 1*6 + 0*6),(1*7 + 1*4 + 0*5)\right]
\rightarrow
\left[8, 11\right]
$$

We calculate the second row by summing the first and second elements and zeroing out the rest.

The same logic applies for the third row.

By multiplying by the triangular matrix of ones, we can get the cumulative sum, at index $t$, of all elements up to and including $t$, which is exactly what we had done with a for loop before. This lets us perform the operation of "get the sum of all the previous terms up to the term at position $t$" in a mathematical way.

Now, to calculate the means, we can just change $a$ so that we normalize across all the rows:

In [48]:
a = torch.tril(torch.ones(3,3))
print(f"a before normalization: {a}")
a = a / a.sum(dim=-1, keepdim=True)
print(f"a after normalization: {a}")


a before normalization: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a after normalization: tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


This matrix will now let us calculate "add up all the previous terms up to and including the term at position $t$ **and** find the mean" in a mathematical way.

In [49]:
c = a @ b
print(f"{a=}")
print(f"{b=}")
print(f"{c=}")

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


We've now replaced the previous for-loop with a matrix operation.

Let's see what else we can vectorize to make things even more efficient.

We can generalize this to an arbitrary context length/block size $T$, which is what we've been using.

If we do the same matrix multiplication for our arbitrary sample batch $x$, we see that we get the same results as our previous for-loop, generalized to a matrix multiplication. By multiplying with a normalized lower triangular matrix, we get the same weighted aggregation that we previously got with a for loop.

In [50]:
wei = torch.tril(torch.ones(T, T)) # our weight matrix
wei = wei / wei.sum(1, keepdim=True) # normalize the weights
# originally, wei = (T, T)
# PyTorch will make it (1, T, T)
# it will then broadcast it to (B, T, T)
# then, for each batch element, it will multiply (T,T) @ (T,C) -> (T,C)
# so the output will be (B, T, C)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2)

True

In [51]:
print(wei.shape)
print(x.shape)
print(xbow2.shape)

torch.Size([8, 8])
torch.Size([4, 8, 2])
torch.Size([4, 8, 2])


Let's now rewrite it in one more way, which is to use the softmax. We use the same logic as before, but instead of having zeros, we use -inf and take the softmax of the weight matrix.

By taking the softmax, we make sure that the rows all add up to 1, therefore creating a probability distribution. This will be useful for us since in practice, the weights won't add up to 1. We want to be able to interpret the weights as probabilities for what we're doing later, and taking the softmax helps us with that.

This is also why we use -inf instead of 0 for the masked positions: softmax exponentiates each entry before normalizing, and $e^{-\infty} = 0$ while $e^{0} = 1$. A masked entry left at 0 would therefore still receive weight, whereas an entry set to -inf receives exactly zero weight in the probability distribution.

In [52]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
print(f"wei before: {wei}")
wei = wei.masked_fill(tril == 0, float("-inf"))
print(f"wei after fill: {wei}")
wei = F.softmax(wei, dim=-1)
# all the zeros will have equal weighting, while -inf will have 0 weighting.
print(f"wei after softmax: {wei}")


wei before: tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
wei after fill: tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
wei after softmax: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
 

In [53]:
print(wei)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


We recreated the same weight matrix as earlier, so let's do the same computation.

In [54]:
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

When we do attention, our weights will give a relative weighting of "how much do we want to attend to that specific context", and setting future tokens to -inf means "don't attend to this token". These weights will be learned from the data so that certain rows can "attend" to other rows.

We can use this lower triangular matrix + softmax method to do weighted aggregations of past elements so that we can have relative weightings of how important certain past tokens are to the understanding of the current token.

### Setting up self-attention

Now that we've set up that math, we can now start to construct self-attention.

First, let's start by updating our Bigram model definition.

In [55]:
class BigramLanguageModel(nn.Module):
    """Bigram model with token and positional embeddings.

    Each token id is mapped to an ``n_embed``-dimensional embedding, the
    embedding of its position in the context is added, and a linear head
    turns the sum into logits over the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        block_size: maximum context length; the positional embedding table
            has exactly this many rows.
        n_embed: embedding dimension of tokens and positions.
    """
    def __init__(self, vocab_size=65, block_size=8, n_embed=32):
        super().__init__()
        # remembered so that generate() can crop the context to what the
        # positional embedding table can represent.
        self.block_size = block_size

        # one n_embed-dimensional embedding per token id.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)

        # one n_embed-dimensional embedding per position in the context.
        self.positional_embedding_table = nn.Embedding(block_size, n_embed)

        # language model head. Converts from token embedding to logits.
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Forward pass.

        Args:
            idx: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of target token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are
            (B, T, vocab_size) and loss is None. With targets, logits are
            flattened to (B*T, vocab_size) and loss is their cross-entropy
            against the flattened targets.
        """
        _, T = idx.shape
        token_embeddings = self.token_embedding_table(idx) # (B, T, n_embed)
        # positional embeddings for positions 0..T-1, created on the same
        # device as the input.
        positions = torch.arange(T, device=idx.device)
        position_embeddings = self.positional_embedding_table(positions) # (T, n_embed)

        # x carries both the token identity and the position at which it
        # occurs; the (T, n_embed) term broadcasts over the batch.
        x = token_embeddings + position_embeddings # (B, T, n_embed)

        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def predict(self, x):
        """Return the greedy (argmax) token id at every position.

        Args:
            x: (B, T) tensor of token ids.

        Returns:
            (B, T) tensor; the entry at position t is the most likely token
            to follow position t.
        """
        # forward returns (logits, loss); only the logits are needed here.
        logits, _ = self(x)
        return logits.argmax(dim=-1)


    def generate(self, idx, max_new_tokens):
        """Generate new tokens.

        idx is (B,T) shaped tensor with array of indices
        that are in the current context.

        At each step, we predict the next token and add it to the context.

        Algorithm:
            - We crop the context to the last ``block_size`` tokens, since
            the positional embedding table has no rows beyond that.
            - We get the logits for the last token in the context for each
            sequence in the batch.
            - We apply softmax to get probabilities for the next token, for each
            sequence in the batch.
            - We sample from the distribution to get the next token for each sequence
            in the batch.
            - We add the new token to the context and repeat.

        Returns:
            (B, T + max_new_tokens) tensor: the original context followed by
            the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # never feed more than block_size tokens to the model.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # keep only the last time step: (B, T, C) -> (B, C).
            logits = logits[:, -1, :]
            # each row becomes a probability distribution summing to 1.
            probs = F.softmax(logits, dim=-1)
            # sample one next token per sequence: (B, 1).
            next_token = torch.multinomial(probs, num_samples=1)
            # append to the full (uncropped) running sequence.
            idx = torch.cat([idx, next_token], dim=1) # (B, T+1)

        return idx


### Implementing self-attention

Now we can get to what is the crux of self-attention.

In [56]:
# Fix the RNG seed so the random batch below is reproducible.
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time/context length/block size, embedding dimension / number of channels
# Random stand-in for a batch of embeddings: B sequences, T positions, C features each.
x = torch.randn(B, T, C)

We implement the same summation as before. Each of the 4 samples in the batch has shape `[8, 32]`. We multiply this by our `[8,8]` lower triangular softmax weight matrix. It takes a weighted running average of the entries up to and including entry $t$.

For example, for the 5th entry in the list, we treat its representation as the weighted sum of all 5 `[1, 32]` embedding tensors up to position 5:
$$
    \Sigma_{i=0}^{i=4}w_i\text{embedding}_i = w_0e_0 + w_1e_1 + w_2e_2 + w_3e_3 + w_4e_4\\
    w_i \in \R^1 \\
    \text{embedding}_i \in \R^{32}
$$

For each position, each embedding will be the token embedding for that token plus the positional embedding for that embedding. We then take a weighted sum of those.
$$\text{Embedding} = \text{Token embedding} + \text{Positional embedding}$$

In [58]:
# Lower-triangular mask: entry (t, s) is 1 when position s is not in the future of t.
tril = torch.tril(torch.ones(T, T))
# Start from all-zero affinities, hide the future with -inf, and normalise each
# row so the visible positions share the weight equally.
wei = F.softmax(
    torch.zeros((T, T)).masked_fill(tril == 0, float("-inf")),
    dim=-1,
)
print(f"{wei.shape=}")
print(f"{x.shape=}")
print(f"{wei=}")
# print(f"{x=}")
# The (T, T) weights broadcast over the batch: (T, T) @ (B, T, C) -> (B, T, C).
out = wei @ x
print(f"{out.shape=}")
# print(f"{out=}")

wei.shape=torch.Size([8, 8])
x.shape=torch.Size([4, 8, 32])
wei=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
out.shape=torch.Size([4, 8, 32])


We initialize the "affinities" or the weights of the weight matrix to a lower triangular matrix of zeros, since once we do the softmax of this, this will translate to an equally-weighted weight matrix of all tokens up to and including the tokens at position $t$.

In practice, we don't actually want these to be all equally weighted. We want different weights, corresponding to how "important" a particular token is to the context of the token at position $t$. We want higher weights to correspond to higher affinities, i.e., our model learning that a particular token is more important than other tokens to understanding the meaning of the token in question.

Put differently, different tokens will find different other tokens interesting, and we want the weights to reflect that.

Self-attention solves the problem of getting information from the past, but selectively choosing which information is deemed relevant and doing so in a data-dependent way.

### How self-attention works
Our goal is to get a weight matrix that will tell us how important each token is to understanding the token at position $t$. That is our end objective. We know that we'll have succeeded if we get an input sequence and, for any arbitrary position $t$, we can take a weighted sum of tokens up to and including $t$ that will help us "understand" token $t$.

One key innovation of self-attention is that it **decomposes the weight matrix into two matrices, the key and the query matrices**.

Every single token at each position emits two vectors, a **key** and a **query** tensor. Roughly speaking,
- **Query** vectors tell "what am I looking for?"
- **Key** vectors tell "this is what information I contain"

#### How queries and keys interact
We get affinities between tokens in a sequence by taking the dot product of the query and key vectors. For a given token, we take its query vector and compute its dot product with the key vectors of all the tokens in the sequence. The higher the dot product between token $a$'s query vector and token $b$'s key vector, the more that $b$ "answers" $a$'s query, meaning that $b$ is more important to understanding the meaning of $a$, so $a$ puts more weight on token $b$ and "attends" to it more.


### Attention heads
We perform "self-attention" in an "Attention Head", so let's see it in action.

Let's initialize key and query matrices and do the computation. Each key and query matrix will take in a token embedding and output a tensor of size "head_size".

$$\text{Embedding} \rightarrow \text{Key}, \text{Query}$$

For each token, we take its token embedding, pass it into the key and query matrices, and get a key and query representation:
$$\text{Token embedding} \in \R^{32} \rightarrow key(\text{embedding}), query(\text{embedding}) \rightarrow \text{Key} \in \R^{16}, \text{Query} \in \R^{16}$$

In [61]:
# Size of the key/query vectors produced for each token.
head_size = 16
# Bias-free linear projections from the C-dimensional embedding to head_size.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
# Both projections are applied to every position independently; no
# information is exchanged between positions at this point.
k = key(x)
q = query(x)
print(f"{x.shape=}") # (B, T, 32)
# for each 32-d embedding in the [4,8] batch sequences we get a 16-d key and query.
print(f"{k.shape=}") # (B, T, 16)
print(f"{q.shape=}") # (B, T, 16)

x.shape=torch.Size([4, 8, 32])
k.shape=torch.Size([4, 8, 16])
q.shape=torch.Size([4, 8, 16])


All 32 tokens in the $[4,8]$ matrix each get a key and query vector. These haven't interacted with each other yet, so now we need to do that interaction.

In [62]:
# Affinities between every query and every key of the same sequence.
# Only the last two dimensions of k are transposed, so per batch element we
# compute (T, 16) @ (16, T) -> (T, T); the batched matmul gives (B, T, T).
# Row i of each (T, T) slice holds the dot products of query i with all keys.
wei = q @ k.transpose(-2, -1)

In [65]:
# Lower-triangular mask used to hide future positions from each query.
tril = torch.tril(torch.ones(T, T))
# NOTE: this cell overwrites `wei` in place, so it is not idempotent.
# Re-running it without first recomputing `wei = q @ k^T` masks and softmaxes
# weights that are already normalised (the recorded "before softmax" output
# has rows that already sum to 1, which suggests this happened).
wei = wei.masked_fill(tril == 0, float("-inf"))
print(f"{wei.shape=}")
print(f"B=1 of weight matrix (product of query and key matrices) before softmax: \n{wei[0]}")
# Row-wise softmax: the -inf entries become exactly 0 and each row sums to 1.
wei = F.softmax(wei, dim=-1)
print(f"B=1 of weight matrix after softmax: \n{wei[0]}")
# Aggregate the raw embeddings x with the data-dependent weights: (B, T, T) @ (B, T, C).
out = wei @ x
print(f"{out.shape=}")

wei.shape=torch.Size([4, 8, 8])
B=1 of weight matrix (product of query and key matrices) before softmax: 
tensor([[1.0000,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.5186, 0.4814,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.3447, 0.3377, 0.3176,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4020, 0.2005, 0.2034, 0.1940,   -inf,   -inf,   -inf,   -inf],
        [0.1829, 0.2702, 0.1690, 0.2027, 0.1753,   -inf,   -inf,   -inf],
        [0.2205, 0.1504, 0.1565, 0.1490, 0.1502, 0.1734,   -inf,   -inf],
        [0.1615, 0.1488, 0.1367, 0.1314, 0.1589, 0.1285, 0.1342,   -inf],
        [0.1277, 0.1173, 0.1294, 0.1517, 0.1211, 0.1160, 0.1197, 0.1171]],
       grad_fn=<SelectBackward0>)
B=1 of weight matrix after softmax: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5093, 0.4907, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3371, 0.3348, 0.3281, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        

We can now represent the weights as the matrix product of the queries and keys, each of which is produced by a linear layer whose parameters can be learned! Now the weight elements are no longer just constants, where we take equally weighted amounts of each element, but rather a learned weighting of each element.

Let's say, for example, that we're the 5th token. The 5th token knows what content it has and its position in the sequence. Based on that info (which gets represented as the embedding of that 5th token, the sum of the token and the positional embeddings), creates a query ("hey, I'm looking for this kind of stuff").
- e.g., the query tensor might represent the query "I'm a consonant in position 5, I'm looking for vowels up to position 4"
- All nodes/previous tokens emit keys to answer the query, and when the node can answer the query being asked ("yes I am a vowel up to position 4"), the dot product between its key and the query is high.
- Let's say that we get a good answer on position = 3.
- **The resulting weight matrix, $T * T$ represents queries (as rows) and keys (as columns).**
- Therefore, we would get a dot product, $q_4 * k_2$, at position $\text{wei}[4, 2]$, that is quite high, meaning that the token at position 3 is important to understanding the token at position 5, or the token at position 5 "attends" to the token at position 3.
- We do a row-wise softmax (i.e., we normalize across the dot products of a given query against all keys), which compresses the weights to have a sum of 1 and disproportionately increases the weighing of very strong dot products and decreases the weighing of small/negative dot products.

The weight matrix is the matrix multiplication of the query and key tensor matrices. For self-attention, we need to mask out the upper right of the matrix, to avoid leaking future information. We then softmax across each row (corresponding to queries) such that the weighing is normalized across all the keys responding to a given query.


#### Adding the last component of the self-attention head: the value matrix

We now have one more part that we have to include in order to complete our "self-attention head". We need to add a "value" tensor matrix.

The value tensor holds the actual information that needs to be aggregated based on the attention scores. So, instead of $qk * x$, we actually want $qk * v$, where $v = Value(x)$.

##### Why use a value matrix instead of the raw input embeddings?
Let's take an example. Let's say that we have the sentence "The run outdoors was long and hot", and let's say that we're looking at the word "run".

Our embedding for the word "run" is going to be a combination of the token embedding (likely a word embedding coming from something like a BERT embedding or a Word2Vec model) and position embedding. The word embedding representation is likely an averaged out meaning of all the meanings that the word "run" can have (e.g., "run" as an action vs. "run" as a baseball term vs. "run" as a software development term).

The key and query help us learn more about the  context of "run" in the sentence (e.g., "how was the run?", "who was running?", "where was the run?"). The value helps us learn what the word "run" actually means in the sentence (here, "run" as the act of movement as opposed to any other definition). We get a more context-specific representation of what the word "run" means.

To take an overly simplified example, let's say that the word embedding for the word run has 3 components:
$$
\left[
\begin{matrix}
\text{Run = act of moving quickly}\\
\text{Run = act of managing something}\\
\text{Run = a score in baseball}\\
\end{matrix}
\right]
$$

Let's say that our raw word embedding for the word "run" represents the word embedding for the word "run" as an equally weighted tensor of the three.
$$
\left[
\begin{matrix}
\text{Run = act of moving quickly} = 0.33\\
\text{Run = act of managing something} = 0.33\\
\text{Run = a score in baseball} = 0.33\\
\end{matrix}
\right]
\rightarrow
\left[0.33, 0.33, 0.33\right]
$$

Let's say that our value matrix is a randomly initialized tensor with head size of 4. For this case, then, it would have a shape of $[3, 4]$.  Let's say that in this random initialization, the resulting value tensor becomes $[0.25, 0.25, 0.25, 0.25]$ and that all 3 definitions of "run" from the word embedding are still weighed equally.

Through backpropagation, what we want to see is that the "value" tensor that comes from passing the "run" embedding to the value matrix begins to more heavily weigh the "Run = act of moving quickly" parameter of the "run" embedding over the rest. The value matrix takes in the token and returns "the meaning of this token in THIS context, as opposed to other contexts". Put differently, the value tensor is a context-aware representation of the input ("what does the word "run" mean in *this* context?").

##### How the query, key, and value interact
The "value" tensor contains this representation of the word itself. It learns its own representation of what the word means. It's a richer representation than just the token or positional embeddings alone.

The query and key matrices, when multiplied, create the weight matrix. This weight matrix, let's call it $W$, can then be multiplied by a value vector, $v$, to modify the representation of $v$ so that some parameters are weighed more and some are weighed less. This lets us change the representation of $v$ based on the context that we learned from the sequence.

For example, let's say that we have a head size of 16, context length of 8, and embedding size of 32. Let's say that the 5th word is "run".

The sequence that "run" is in would be a $[8,32]$ tensor, where the 5th element, at index $[4]$, would be $[1,32]$. We'd pass this to the value matrix, which changes the sequence from $[8,32] \rightarrow [8,16]$, and the 5th element from $[1,32] \rightarrow [1,16]$.

The weight matrix would be $QK^{T} \in \R^{T*T}$. 

The final output is given by
$$
    A = \text{Softmax}(\frac{QK^{T}}{\text{Normalization constant}}) \\ 
$$

$$
    A \in \R^{T*T}  \\
    V \in \R^{T*16} = \text{A 16-D value tensor for each of the $T$ tokens in the context length}\\
$$

$$
    \text{Output} = AV = \R^{T*T} * \R^{T*16} \rightarrow \R^{T*16} \\
$$

For the 5th element, we would take the tensor given by $A[4,:] \in \R^{T}$, which contains the attention weights for the 5th word "run" with respect to all other words in the sequence (the last 3 words are zeroed out, meaning we don't consider them when figuring out the meaning of "run"). These weights determine how much each word in the context contributes to the final representation of "run".

#### Another concrete example of how query, key, and value matrices interact
For example, let's say we see the word "treadmill" as the second token in the sequence and we want to know the representation of "run" (e.g., "I'll treadmill for my run today indoors").

The query tensor for "run" is $Q_4$ (for the 5th word). The key tensor for "treadmill" is $K_1$. The dot product should be high, let's say, $Q_4K_1^{T} = 0.8$. Therefore, before the softmax, the weight matrix would have the entry $0.8$ at position $[4,1]$.

**Attention weights**

Let's say that after the softmax, "treadmill" is deemed very important to the context of "run" compared to the other words in the sequence, so that $A_{4,1} = 0.9$.

Let's say that $A[4,:] = [0.05, 0.9, 0.01, 0.01, 0.03, 0, 0, 0]$.

**Value vectors**

Let's now have some values for our value tensors:

$$
V =
\left[
\begin{matrix}
V0\\
V1\\
V2\\
V3\\
V4\\
V5\\
V6\\
V7\\
\end{matrix}
\right]
$$

$$
V = 
\left[
\begin{matrix}
[0.1, 0.2, 0.1, 0.2]\\
[0.3, 0.1, 0.4, 0.2]\\
[0.6, 0.2, 0.1, 0.1]\\
[0.4, 0.7, 0.3, 0.3]\\
[0.5, 0.2, 0.1, 0.1]\\
[0.1, 0.3, 0.5, 0.5]\\
[0.1, 0.1, 0.7, 0.3]\\
[0.2, 0.5, 0.2, 0.1]\\
\end{matrix}
\right]
$$


**Output**

Let's now calculate our output for position 5:
$$\text{Output}_4 = \Sigma_{j=0}^{j=7}A_{4,j} * V_j$$

Recall that $A[4,:] = [0.05, 0.9, 0.01, 0.01, 0.03, 0, 0, 0]$. If we do this calculation, then

$$\text{Output}_4 = (0.05 * [0.1, 0.2, 0.1, 0.2]) + (0.9 * [0.3, 0.1, 0.4, 0.2]) +
    (0.01 * [0.6, 0.2, 0.1, 0.1]) + (0.01 * [0.4, 0.7, 0.3, 0.3]) + (0.03 * [0.5, 0.2, 0.1, 0.1])
    + (0 * [0.1, 0.3, 0.5, 0.5]) + (0 * [0.1, 0.1, 0.7, 0.3]) + (0 * [0.2, 0.5, 0.2, 0.1])
$$

$$
    \text{Output}_4 = [0.3,0.115,0.372,0.197]
$$

**Interpretation**

So, we represent the output for the 5th token, our learned meaning of the word "run" in this context, as a weighted sum of our learned representations of the other words in the sequence. We're saying that here, the meaning of the 5th token is most heavily driven by the word "treadmill".

The query and key tell us which of the words (whose representations themselves are in the "value" tensors) are most important to understanding the meaning of the token that we care about.

As we can see, our output vector for the word "run", $[0.3,0.115,0.372,0.197]$, is very similar to our value tensor for the word "treadmill" $[0.3, 0.1, 0.4, 0.2]$, signifying that the word "treadmill" has an outsized importance in determining the meaning of the word "run".


### Putting it all together: the attention head

The query, key, and value matrices, to put differently, do the following conceptually:
- The query asks "I need to know X,Y,Z to understand the meaning of the word I care about"
- The key says "here is my answer to that query". If the query is "I need a word that is a noun", then if the word is a noun, then $QK^T$ should be a higher number. The magnitude of the dot product tells you how well that key answers that query.
- The attention calculation, the softmax of the normalized dot product between the queries and keys, tells you how much to "pay attention" to different words.
- When the attention tensor is multiplied with the value tensors, we get a representation of what a certain token means as a weighted sum of the value tensors of the tokens up to and including it. For the previous example, we saw that the meaning of the word "run" in the sentence "I'll treadmill for my run today indoors" is 90% explained by the word "treadmill", hence the output for the word "run" is heavily weighted by the value tensor of the word "treadmill".

This, in sum, is the attention head.

### The attention head in action

Let's see what this calculation looks like.

In [66]:
# Reproducible toy batch: 4 sequences, 8 time steps, 32 channels.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Three bias-free projections of the embedding, built in key, query, value
# order so the seeded initialisation is the same on every run.
head_size = 16
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))
print(f"{x.shape=}")
k, q, v = key(x), query(x), value(x)
print(f"{k.shape=}")
print(f"{q.shape=}")
print(f"{v.shape=}")

# Query/key affinities for every pair of positions: (B, T, T).
wei = q @ k.transpose(-2, -1)
print(f"{wei.shape=}")

# Hide the future from each query, then normalise every row.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
print(f"{wei.shape=}")
print(f"B=1 of weight matrix (product of query and key matrices) before softmax: \n{wei[0]}")
wei = F.softmax(wei, dim=-1)
print(f"B=1 of weight matrix after softmax: \n{wei[0]}")

# Each output position is the attention-weighted sum of the value vectors.
out = wei @ v
print(f"{out.shape=}")


x.shape=torch.Size([4, 8, 32])
k.shape=torch.Size([4, 8, 16])
q.shape=torch.Size([4, 8, 16])
v.shape=torch.Size([4, 8, 16])
wei.shape=torch.Size([4, 8, 8])
wei.shape=torch.Size([4, 8, 8])
B=1 of weight matrix (product of query and key matrices) before softmax: 
tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496,    -inf,    -inf,    -inf,    -inf],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,    -inf,    -inf,    -inf],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,    -inf,    -inf],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,    -inf],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)
B=1 of weight matrix after softmax: 
tensor([[1

Let's now encapsulate this in a programmatic form.

In [67]:
class AttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: dimension of the key, query and value vectors.
        n_embed: dimension of the input embeddings.
        context_length: maximum sequence length; sizes the causal mask.
    """

    def __init__(self, head_size, n_embed, context_length):
        super().__init__()
        self.head_size = head_size
        self.n_embed = n_embed
        self.context_length = context_length
        self.dropout_p = 0.1
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # registered as a buffer: it moves with the module across devices
        # but is not a trainable parameter.
        self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(self.dropout_p)

    def forward(self, x):
        """Attend over the positions of each sequence.

        Args:
            x: (B, T, n_embed) tensor with T <= context_length.

        Returns:
            (B, T, head_size) tensor; position t is a weighted sum of the
            value vectors of positions 0..t.
        """
        _, T, _ = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        # scaled attention scores: dividing by sqrt(head_size) keeps their
        # magnitude independent of head_size so the softmax does not saturate.
        wei = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5 # (B, T, T)
        # slice the mask to the actual sequence length so inputs shorter
        # than context_length (e.g. short prompts during generation) work.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        # regularise the attention weights (a no-op in eval mode).
        wei = self.dropout(wei)

        # perform the weighted sum of the values
        out = wei @ v # (B, T, head_size)
        return out

We combine the results of the single heads into multiple heads that run in parallel.

In [68]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads applied to the same input side by side.

    The per-head outputs are concatenated on the feature axis, projected to
    ``n_embed`` features by a linear layer, and passed through dropout.
    """
    def __init__(self, n_heads, head_size, n_embed, context_length):
        super().__init__()
        self.n_heads = n_heads
        self.head_size = head_size
        self.n_embed = n_embed
        self.context_length = context_length
        heads = [
            AttentionHead(head_size, n_embed, context_length)
            for _ in range(n_heads)
        ]
        self.attention_heads = nn.ModuleList(heads)
        # maps the concatenated head outputs to n_embed features.
        self.linear = nn.Linear(n_heads * head_size, n_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Run every head on ``x`` and merge the results."""
        per_head = [head(x) for head in self.attention_heads]
        # concatenate along the last (feature) dimension.
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.linear(merged))

We next need a feedforward set of linear layers:

In [69]:
class FFNN(nn.Module):
    """Two-layer feed-forward network: linear, ReLU, linear, dropout.

    The first linear layer maps ``n_embed`` features to ``hidden_dim`` and
    the second maps them back to ``n_embed``.
    """
    def __init__(self, n_embed, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(n_embed, hidden_dim)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, n_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.activation(self.linear1(x))
        return self.dropout(self.linear2(hidden))

We then combine this all into one Transformer block.

In [70]:
class TransformerBlock(nn.Module):
    """Transformer block: multi-head attention followed by a feed-forward network.

    Each sub-layer receives a layer-normalised copy of its input, and its
    output is added back onto the un-normalised input (residual connection).
    """

    def __init__(self, n_heads, n_embed, context_length, hidden_dim):
        super().__init__()
        # the head outputs are concatenated, so each head gets an equal
        # share of n_embed (integer division).
        per_head = n_embed // n_heads
        self.attention = MultiHeadAttention(n_heads, per_head, n_embed, context_length)
        self.norm1 = nn.LayerNorm(n_embed)
        self.ffnn = FFNN(n_embed, hidden_dim)
        self.norm2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals."""
        # normalise, attend, then add the result onto the input.
        attended = self.attention(self.norm1(x))
        x = x + attended
        # normalise, feed forward, then add the result onto the running sum.
        transformed = self.ffnn(self.norm2(x))
        return x + transformed

We then combine these blocks to create an end-to-end implementation of a GPT model.

In [71]:
class GPTLanguageModel(nn.Module):
    """Decoder-only Transformer language model.

    Token and positional embeddings are summed, passed through a stack of
    Transformer blocks and a final layer norm, and projected to logits over
    the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        block_size: maximum context length.
        n_embed: embedding dimension.
        n_heads: attention heads per block.
        hidden_dim: hidden width of each block's feed-forward network.
        n_layers: number of Transformer blocks.
    """
    def __init__(self, vocab_size=65, block_size=8, n_embed=32, n_heads=2, hidden_dim=64, n_layers=2):
        super().__init__()
        # remembered so generate() crops to this model's own context length
        # rather than relying on a global variable.
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.positional_embedding_table = nn.Embedding(block_size, n_embed)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(n_heads, n_embed, block_size, hidden_dim) for _ in range(n_layers)
        ])
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.ln = nn.LayerNorm(n_embed) # final layer norm.

        # recursively re-initialise every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, if targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of target token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are
            (B, T, vocab_size) and loss is None. With targets, logits are
            flattened to (B*T, vocab_size) and loss is their cross-entropy
            against the flattened targets.
        """
        _, T = idx.shape
        token_embeddings = self.token_embedding_table(idx) # (B, T, n_embed)
        position_embeddings = self.positional_embedding_table(
            torch.arange(T, device=idx.device)
        ) # (T, n_embed)
        x = token_embeddings + position_embeddings # (B, T, n_embed)
        # nn.ModuleList is a container, not a callable module, so the
        # blocks are applied one after another.
        for block in self.transformer_blocks:
            x = block(x) # (B, T, n_embed)
        x = self.ln(x) # (B, T, n_embed)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens, one at a time, after ``idx``.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the
            sampled tokens.
        """
        for _ in range(max_new_tokens):
            # the model only has positional embeddings for block_size positions.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C) of probs
            # sample from distribution
            next_token = torch.multinomial(probs, num_samples=1)
            # add new token to context
            idx = torch.cat([idx, next_token], dim=1)
        return idx