## 2. Attention

The purpose of this notebook is twofold: (1) gain an understanding of the attention mechanism, and (2) ensure that our implementation is correct by cross-checking it against a built-in attention function.

We will use the following progression:

* Taking an average
* Vectorizing our average
* Learning weights for the vectorized weighted average
* Cross-checking with a built-in attention function
* Expanding to multi-headed attention

As a reminder, the training goal is to predict the next token. In the data, every position in the context is paired with the token that follows it:

In [1]:
import torch

torch.set_printoptions(precision=4, sci_mode=False)
torch.manual_seed(538)

batch_size = 2
context_size = 3

data_batch = torch.randint(high=10, size=(batch_size, context_size + 1), dtype=torch.float32)
x_batch = data_batch[:, :context_size]
y_batch = data_batch[:, 1:context_size+1]

print(f"data:\n{data_batch}")

print("-" * 38)
print(f"| {'Batch': <6} | {'Context':<15} | {'Target'} |")
print("-" * 38)
for b in range(batch_size):
    for t in range(context_size):
        context = x_batch[b, : t + 1]
        target = y_batch[b, t]
        print(f"| {b: <6} | {str(context.tolist()):<15} | {target:<6} |")

data:
tensor([[7., 9., 4., 7.],
        [4., 9., 3., 0.]])
--------------------------------------
| Batch  | Context         | Target |
--------------------------------------
| 0      | [7.0]           | 9.0    |
| 0      | [7.0, 9.0]      | 4.0    |
| 0      | [7.0, 9.0, 4.0] | 7.0    |
| 1      | [4.0]           | 9.0    |
| 1      | [4.0, 9.0]      | 3.0    |
| 1      | [4.0, 9.0, 3.0] | 0.0    |


### 2.1 Taking an Average

A naive method for predicting the next token would be to take an average of the features of the tokens that come before it. Let's look at the following example.

For continuity with our actual attention implementation, we will add in the embedding dimension to our data.

In [2]:
batch_size = 2
context_size = 3
n_embd = 4

x = torch.randint(high=10, size=(batch_size, context_size, n_embd), dtype=torch.float32)
x

tensor([[[4., 9., 0., 0.],
         [7., 0., 5., 3.],
         [2., 1., 4., 9.]],

        [[0., 7., 5., 4.],
         [5., 1., 1., 4.],
         [1., 5., 6., 5.]]])

We therefore generate a prediction at every position, so each sequence effectively yields several training examples (one per position), on top of the multiple independent sequences in the batch.

In [3]:
y_output = torch.zeros((batch_size, context_size, n_embd))
for b in range(batch_size):
    for t in range(context_size):
        x_prev = x[b, : t + 1, :]
        y_output[b, t, :] = x_prev.mean(dim=0)
print(f"Output ({batch_size=}, {context_size=}):\n {y_output}")

Output (batch_size=2, context_size=3):
 tensor([[[4.0000, 9.0000, 0.0000, 0.0000],
         [5.5000, 4.5000, 2.5000, 1.5000],
         [4.3333, 3.3333, 3.0000, 4.0000]],

        [[0.0000, 7.0000, 5.0000, 4.0000],
         [2.5000, 4.0000, 3.0000, 4.0000],
         [2.0000, 4.3333, 4.0000, 4.3333]]])


As a reminder, slicing excludes the end index: `[:4]` takes all elements up to but not including index `4`. That is why the loop above uses `x[b, : t + 1, :]`, so that position `t` itself is included in the average.
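A quick illustration of this half-open convention on a hypothetical tensor `seq` (not part of the notebook's data):

```python
import torch

seq = torch.arange(6)  # hypothetical sequence: tensor([0, 1, 2, 3, 4, 5])
print(seq[:4])         # tensor([0, 1, 2, 3]): index 4 is excluded

# Slicing up to t + 1 is what includes position t itself.
t = 2
prefix = seq[: t + 1]
print(prefix)          # tensor([0, 1, 2])
```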

### 2.2 Vectorizing our moving average

We can remove the for loop and vectorize the previous operation with the following technique: build a lower-triangular square matrix of size `(context_size, context_size)` whose rows each sum to 1, then multiply that matrix by our input.

This is numerically equivalent to what we had before.

In [4]:
mask = torch.tril(torch.ones(context_size, context_size))
weights = mask / mask.sum(dim=1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

What we have is a square `(context_size, context_size)` matrix multiplied by a `(context_size, n_embd)` matrix. PyTorch broadcasts this matrix multiplication across the batch dimension, so the result is a `(batch_size, context_size, n_embd)` tensor equivalent to what we had before.
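To make the broadcasting explicit, here is a small sketch with random stand-in data (the sizes `B`, `T`, `C` are illustrative assumptions): the broadcast product matches a per-sequence loop and an equivalent `torch.einsum`.

```python
import torch

torch.manual_seed(0)
B, T, C = 2, 3, 4  # illustrative sizes: batch, context, embedding
x = torch.randn(B, T, C)
mask = torch.tril(torch.ones(T, T))
weights = mask / mask.sum(dim=1, keepdim=True)  # (T, T)

# (T, T) @ (B, T, C): the 2-D weight matrix is broadcast across the batch dimension.
out_broadcast = weights @ x  # (B, T, C)

# The same computation spelled out: one (T, T) @ (T, C) product per sequence.
out_loop = torch.stack([weights @ x[b] for b in range(B)])

# And as an einsum: out[b, t, c] = sum_s weights[t, s] * x[b, s, c].
out_einsum = torch.einsum("ts,bsc->btc", weights, x)

print(out_broadcast.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(out_broadcast, out_loop), torch.allclose(out_broadcast, out_einsum))
```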

In [5]:
y_output_2 = weights @ x
print(f"My outputs are equivalent: {torch.allclose(y_output, y_output_2)}")
y_output_2

My outputs are equivalent: True


tensor([[[4.0000, 9.0000, 0.0000, 0.0000],
         [5.5000, 4.5000, 2.5000, 1.5000],
         [4.3333, 3.3333, 3.0000, 4.0000]],

        [[0.0000, 7.0000, 5.0000, 4.0000],
         [2.5000, 4.0000, 3.0000, 4.0000],
         [2.0000, 4.3333, 4.0000, 4.3333]]])

Why does this work?
* Each entry of the output is the dot product of a row of the triangular matrix with a column of the input matrix, and each input column runs along the context (sequence) dimension.
* The first row of the triangular matrix zeros out all but the very first element of each input column. Likewise, the second row zeros out all but the first two elements, and so on.
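As a concrete check, here is the second row of the weight matrix applied to the first sequence, using the values printed for `x[0]` in section 2.1 (copied by hand):

```python
import torch

# First sequence x[0] from the output above, shape (context_size, n_embd) = (3, 4)
x0 = torch.tensor([[4., 9., 0., 0.],
                   [7., 0., 5., 3.],
                   [2., 1., 4., 9.]])
row1 = torch.tensor([0.5, 0.5, 0.0])  # second row of `weights`

# Each output entry is row1 dotted with one column of x0;
# the trailing 0.0 drops position 2, so only positions 0 and 1 are averaged.
print(row1 @ x0)           # tensor([5.5000, 4.5000, 2.5000, 1.5000])
print(x0[:2].mean(dim=0))  # the same values
```

This reproduces the second row of the first batch in the outputs above.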

What is happening here?
* This is a moving (running) average, with uniform weights across all prior positions of the input, including the current one.
* Row `t` (counting from 0) assigns weight `1/(t+1)` to each of positions `0..t`: the first row holds a 1, the second row holds 1/2s, the third row holds 1/3s, and so on.
* But what if we wanted to take a non-uniform weighted average, and ultimately learn the weights to use in the average?
* This is what the attention mechanism does.
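For contrast, here is a hand-picked (hypothetical) non-uniform causal weight matrix. It is still lower triangular with rows summing to 1, but it favors more recent positions. The same `@` computes the weighted average; attention's job is to produce such weights from the data instead of fixing them by hand.

```python
import torch

# Hypothetical hand-picked causal weights: lower triangular, each row sums to 1,
# but more recent positions get more weight than earlier ones.
weights = torch.tensor([[1.0, 0.0, 0.0],
                        [0.25, 0.75, 0.0],
                        [0.1, 0.2, 0.7]])
print(weights.sum(dim=1))  # each row sums to 1

x = torch.tensor([[1.0], [2.0], [3.0]])  # toy 1-D features for 3 positions
out = weights @ x
print(out)  # weighted averages: 1.0, 1.75, 2.6
```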

### 2.3 Self-Attention Mechanism

We are going to learn the weights by comparing queries with keys. The terminology comes from databases and hash tables: the query is your request, the key is the index of the data to be returned, and the value is what you ultimately care about. The analogy does not map 1:1 onto LLMs, but this is the terminology used.
* Change 1: we learn linear projections of the input sequence (queries, keys, and values) instead of averaging the raw input.
* Change 2: the weight matrix is built from the similarity (dot product) of each query with each key. Previously it was uniform.
* Change 3: we replace the masked-out zeros with -infinity before applying a softmax, because e^(-infinity) = 0, so future positions receive zero weight.
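One way to see how these changes generalize section 2.2: when every allowed score is equal (here, all zeros), the masked softmax reproduces exactly the uniform averaging matrix. Attention lets the scores vary instead. A minimal sketch:

```python
import torch
import torch.nn.functional as F

T = 3
mask = torch.tril(torch.ones(T, T))

# All-zero scores: every allowed position is equally similar.
scores = torch.zeros(T, T)
scores = scores.masked_fill(mask == 0, float("-inf"))  # future positions -> -inf
softmax_weights = F.softmax(scores, dim=-1)             # exp(-inf) = 0

uniform_weights = mask / mask.sum(dim=1, keepdim=True)  # the matrix from section 2.2
print(softmax_weights)
print(torch.allclose(softmax_weights, uniform_weights))  # True
```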

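In [31]:
import math
import torch.nn.functional as F

batch_size = 2
context_size = 3
n_embd = 4

# Toy queries and keys (small random integers) to illustrate the weight computation.
q = torch.randint(high=6, size=(batch_size, context_size, n_embd), dtype=torch.float32)
k = torch.randint(high=6, size=(batch_size, context_size, n_embd), dtype=torch.float32)
print(q)
print(k)

# Similarity scores (batch_size, context_size, context_size), scaled by 1 / sqrt(d_k).
weights = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])
# Causal mask: position t may only attend to positions 0..t.
weights = weights.masked_fill(torch.tril(torch.ones(context_size, context_size)) == 0, float("-inf"))
# Softmax over the key dimension turns each row into weights that sum to 1.
weights = F.softmax(weights, dim=-1)
print(weights)
# With value vectors v of shape (batch_size, context_size, n_embd), the output would be weights @ v.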

tensor([[[2., 2., 4., 5.],
         [2., 2., 0., 4.],
         [1., 5., 1., 2.]],

        [[1., 5., 1., 3.],
         [4., 5., 4., 5.],
         [3., 3., 4., 5.]]])
tensor([[[5., 4., 4., 3.],
         [0., 5., 2., 2.],
         [4., 4., 1., 5.]],

        [[5., 1., 2., 4.],
         [5., 5., 1., 1.],
         [0., 4., 2., 1.]]])
tensor([[[    1.0000,     0.0000,     0.0000],
         [    0.9975,     0.0025,     0.0000],
         [    0.4683,     0.0634,     0.4683]],

        [[    1.0000,     0.0000,     0.0000],
         [    0.3775,     0.6225,     0.0000],
         [    0.9707,     0.0293,     0.0000]]])
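As a worked example, the last row of batch 0 in the weights above can be recomputed by hand from the printed `q` and `k` values (copied below). The last position is not masked, so all three keys take part, and positions 0 and 2 get equal weight because their scores tie:

```python
import math
import torch
import torch.nn.functional as F

# Values copied from the printed q and k above (batch 0).
q_last = torch.tensor([1., 5., 1., 2.])  # q[0, 2]: query at the last position
k0 = torch.tensor([[5., 4., 4., 3.],
                   [0., 5., 2., 2.],
                   [4., 4., 1., 5.]])     # k[0]: keys at positions 0, 1, 2

scores = k0 @ q_last / math.sqrt(k0.shape[-1])  # dot products 35, 31, 35, scaled by 1/2
probs = F.softmax(scores, dim=-1)
print(scores)  # tensor([17.5000, 15.5000, 17.5000])
print(probs)   # tensor([0.4683, 0.0634, 0.4683])
```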



In [6]:
import math
import torch.nn as nn
import torch.nn.functional as F

# Learned linear projections of the input into queries, keys, and values.
query_layer = nn.Linear(n_embd, n_embd)
key_layer = nn.Linear(n_embd, n_embd)
value_layer = nn.Linear(n_embd, n_embd)

q = query_layer(x)  # (batch_size, context_size, n_embd)
k = key_layer(x)
v = value_layer(x)

# Scaled similarity scores, causal mask, then softmax over the keys.
weights = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])
weights = weights.masked_fill(torch.tril(torch.ones(context_size, context_size)) == 0, float("-inf"))
weights = F.softmax(weights, dim=-1)
output = weights @ v  # weighted average of the value vectors
print(f"Input ({batch_size=}, {context_size=}, {n_embd=}):\n{v}")  # the values v being averaged
print(f"Weights ({batch_size=}, {context_size=},{context_size=}) :\n{weights}")
print(f"Output ({batch_size=}, {context_size=}, {n_embd=}):\n{output}")

Input (batch_size=2, context_size=3, n_embd=4):
tensor([[[-5.6161, -2.5599, -2.4308, -3.7171],
         [-3.9803,  0.2815,  0.6249, -0.4333],
         [ 0.9179,  2.0273, -3.0268,  2.1266]],

        [[-2.8479, -1.6375, -4.8781, -1.9841],
         [-1.6361,  0.9025, -0.0676,  0.4190],
         [-2.4767, -0.8501, -4.1768, -1.1245]]], grad_fn=<ViewBackward0>)
Weights (batch_size=2, context_size=3,context_size=3) :
tensor([[[    1.0000,     0.0000,     0.0000],
         [    1.0000,     0.0000,     0.0000],
         [    0.9943,     0.0000,     0.0057]],

        [[    1.0000,     0.0000,     0.0000],
         [    0.9991,     0.0009,     0.0000],
         [    0.9887,     0.0001,     0.0112]]], grad_fn=<SoftmaxBackward0>)
Output (batch_size=2, context_size=3, n_embd=4):
tensor([[[-5.6161, -2.5599, -2.4308, -3.7171],
         [-5.6161, -2.5599, -2.4308, -3.7171],
         [-5.5789, -2.5338, -2.4342, -3.6839]],

        [[-2.8479, -1.6375, -4.8781, -1.9841],
         [-2.8469, -1.6353, -4.8

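* Change 4: We apply a scaling factor of 1 / sqrt(d_k), where d_k is the key dimension (`k.shape[-1]`). Without it, the dot-product logits grow with d_k, and a softmax over large logits collapses to nearly one-hot weights; scaling keeps the softmax output more spread out.
* Here is an example, with d_k = `n_embd`.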

In [7]:
logits = torch.tensor([0, 1, 2, 3, 10, -1e4]).float()
print(f"{'no scaling:':<15} {F.softmax(logits, dim=0)}")
print(f"{'with scaling:':<15} {F.softmax(logits / math.sqrt(n_embd), dim=0)}")

no scaling:     tensor([    0.0000,     0.0001,     0.0003,     0.0009,     0.9986,     0.0000])
with scaling:   tensor([0.0063, 0.0104, 0.0172, 0.0283, 0.9378, 0.0000])
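Why do the logits grow with d_k? If the entries of q and k are independent with zero mean and unit variance, each product q_i k_i has variance 1, so a dot product over d_k dimensions has variance d_k; dividing by sqrt(d_k) brings it back to about 1. A quick empirical sketch with random (hypothetical) queries and keys:

```python
import math
import torch

torch.manual_seed(0)
n_samples = 10_000
results = {}

for d_k in (4, 64, 512):
    # Hypothetical queries and keys with independent, unit-variance entries.
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    logits = (q * k).sum(dim=-1)      # one q.k dot product per row
    scaled = logits / math.sqrt(d_k)  # the 1 / sqrt(d_k) scaling
    results[d_k] = (logits.var().item(), scaled.var().item())
    print(f"d_k={d_k:<4} var(q.k)={results[d_k][0]:8.2f}   var(q.k / sqrt(d_k))={results[d_k][1]:.2f}")
```

The unscaled variance tracks d_k, while the scaled variance stays near 1 regardless of dimension.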


### 2.4 Checking our work

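Let's do a cross-check using PyTorch's built-in `F.scaled_dot_product_attention` to make sure we're doing this correctly. Because it is a function rather than a module, it has no learnable projections of its own: we pass in the already-projected query, key, and value tensors (in that order), and `is_causal=True` applies the same lower-triangular mask we built by hand.

And we can see that we're getting the same results here (up to floating-point tolerance, as checked by `torch.allclose`).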

In [8]:
expected_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(f"Output ({batch_size=}, {context_size=}, {n_embd=}):\n{expected_output}")
print(f"Your implementation is correct: {torch.allclose(output, expected_output)}")

Output (batch_size=2, context_size=3, n_embd=4):
tensor([[[-5.6161, -2.5599, -2.4308, -3.7171],
         [-5.6161, -2.5599, -2.4308, -3.7171],
         [-5.5789, -2.5338, -2.4342, -3.6839]],

        [[-2.8479, -1.6375, -4.8781, -1.9841],
         [-2.8469, -1.6353, -4.8739, -1.9820],
         [-2.8436, -1.6283, -4.8696, -1.9741]]], grad_fn=<UnsafeViewBackward0>)
Your implementation is correct: True
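Note that the boolean `attn_mask` convention is the opposite of our `masked_fill` call: in `F.scaled_dot_product_attention`, `True` marks positions that may be attended to, whereas we fill the positions where `tril(...) == 0`. As a sketch with random stand-in tensors (the sizes are assumptions), an explicit lower-triangular boolean mask should match `is_causal=True`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 3, 4  # illustrative sizes
q, k, v = (torch.randn(B, T, C) for _ in range(3))

# Boolean attn_mask: True means "this query position may attend to this key position".
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_mask, out_causal, atol=1e-6))  # True
```

Pass one or the other, not both: `is_causal=True` is the shorthand for exactly this mask.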


### 2.5 Multi-headed Attention

Now we can expand this to multi-headed attention. The idea is to have several attention heads that process the input completely independently. We could create the heads separately, loop through them, and concatenate their outputs. However, there is a slightly more efficient way of doing this. Let's make our input data slightly bigger to illustrate, and use an example with 4 heads.

* It's not obvious until you draw it out, but we create the keys/queries/values for all heads together: one `n_embd -> n_embd` projection each, whose output channels are then split evenly across the heads.
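The loop-and-concatenate approach and the batched approach compute the same thing. The sketch below uses random stand-in tensors (not the notebook's `q`, `k`, `v`) and assumes each head owns a contiguous slice of `n_embd // n_heads` channels; slicing the output of one big `nn.Linear` this way is equivalent to giving each head its own smaller projection.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, n_heads = 2, 3, 12, 4  # illustrative sizes
hs = C // n_heads               # head size

# Stand-ins for the outputs of the query/key/value projections.
q, k, v = (torch.randn(B, T, C) for _ in range(3))
mask = torch.tril(torch.ones(T, T))

def causal_attention(q, k, v):
    """Single-head causal attention on (..., T, hs) tensors."""
    w = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    w = w.masked_fill(mask == 0, float("-inf"))
    return F.softmax(w, dim=-1) @ v

# Option 1: loop over heads, each using its own slice of channels, then concatenate.
heads = [
    causal_attention(q[..., h * hs:(h + 1) * hs],
                     k[..., h * hs:(h + 1) * hs],
                     v[..., h * hs:(h + 1) * hs])
    for h in range(n_heads)
]
out_loop = torch.cat(heads, dim=-1)  # (B, T, C)

# Option 2: fold heads into a batch dimension, attend once, then unfold.
def split(t):
    return t.view(B, T, n_heads, hs).transpose(1, 2)  # (B, n_heads, T, hs)

out_batched = causal_attention(split(q), split(k), split(v))
out_batched = out_batched.transpose(1, 2).contiguous().view(B, T, C)

print(torch.allclose(out_loop, out_batched, atol=1e-6))  # True
```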

In [9]:
batch_size = 2
context_size = 3
n_embd = 12
n_heads = 4

x = torch.randint(high=10, size=(batch_size, context_size, n_embd), dtype=torch.float32)
print(f"Input ({batch_size=}, {context_size=}, {n_embd=}):\n {x}")

# One projection per role covers all heads at once; each head later gets n_embd // n_heads channels.
query_layer = nn.Linear(n_embd, n_embd)
key_layer = nn.Linear(n_embd, n_embd)
value_layer = nn.Linear(n_embd, n_embd)

q = query_layer(x)
k = key_layer(x)
v = value_layer(x)


Input (batch_size=2, context_size=3, n_embd=12):
 tensor([[[8., 1., 2., 5., 0., 1., 6., 3., 0., 6., 4., 2.],
         [2., 4., 5., 5., 7., 5., 5., 2., 0., 9., 8., 0.],
         [4., 4., 2., 6., 9., 8., 8., 6., 5., 9., 9., 7.]],

        [[8., 1., 6., 2., 7., 5., 9., 5., 7., 3., 7., 8.],
         [1., 0., 2., 9., 6., 0., 5., 4., 9., 4., 3., 4.],
         [6., 0., 6., 8., 6., 7., 8., 4., 7., 0., 0., 9.]]])


Then we can turn the head into a batch dimension, so we now have two independent batch dimensions: the first corresponds to the batch of sequences, and the second to the attention head.

So the final shape going into the attention calculation is `(batch_size, n_heads, context_size, n_embd // n_heads)`.

In [10]:
q = q.view(batch_size, context_size, n_heads, n_embd // n_heads).transpose(1, 2)
k = k.view(batch_size, context_size, n_heads, n_embd // n_heads).transpose(1, 2)
v = v.view(batch_size, context_size, n_heads, n_embd // n_heads).transpose(1, 2)
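The reshape above can be traced shape by shape. In this sketch `q` is a hypothetical stand-in filled with `arange`, so every channel is distinguishable; it confirms that head `h` sees exactly channels `h*hs:(h+1)*hs` of every position.

```python
import torch

B, T, C, n_heads = 2, 3, 12, 4  # same sizes as the cell above
hs = C // n_heads
q = torch.arange(B * T * C, dtype=torch.float32).view(B, T, C)  # hypothetical stand-in

step1 = q.view(B, T, n_heads, hs)  # split channels into heads: (2, 3, 4, 3)
step2 = step1.transpose(1, 2)      # heads become a batch dim:  (2, 4, 3, 3)
print(step1.shape, step2.shape)

# Head h sees exactly channels h*hs : (h+1)*hs of every position.
h = 1
print(torch.equal(step2[:, h], q[:, :, h * hs:(h + 1) * hs]))  # True
```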

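Then we can apply exactly the same attention computation as before. The `(context_size, context_size)` causal mask broadcasts across both the batch and head dimensions.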

In [11]:
# Same computation as the single-head case; q, k, v are now (batch_size, n_heads, context_size, head_size).
weights = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])
weights = weights.masked_fill(torch.tril(torch.ones(context_size, context_size)) == 0, float("-inf"))
weights = F.softmax(weights, dim=-1)
output = weights @ v
print(f"Input ({v.shape}):\n{v}")
print(f"Weights ({weights.shape}) :\n{weights}")
print(f"Output ({output.shape}):\n{output}")

Input (torch.Size([2, 4, 3, 3])):
tensor([[[[ 0.5700,  0.3222, -2.3747],
          [ 1.5730, -2.4145, -0.8959],
          [ 2.6774, -0.4923, -2.4715]],

         [[-0.4135, -3.4527, -1.9973],
          [-5.2061,  0.1708,  1.9398],
          [-4.1394, -3.6896, -1.5541]],

         [[-1.8024,  0.6671,  1.5068],
          [ 0.4360,  2.8751, -2.0109],
          [-1.1852,  3.0334, -2.8473]],

         [[ 2.2916, -0.5117,  0.0289],
          [ 1.1969, -1.2733, -0.3100],
          [ 2.9144, -0.3376,  0.8424]]],


        [[[ 5.0462, -1.4799, -5.0024],
          [ 0.8774, -0.2471, -2.7377],
          [ 3.0366, -0.0454, -4.3973]],

         [[-2.6934, -2.8800, -2.7414],
          [-3.6354, -2.3350, -0.2569],
          [-1.9497, -5.5444, -2.4319]],

         [[-1.5781,  2.0507,  0.2764],
          [ 0.1471,  0.4645, -0.8665],
          [-1.0712,  3.6178,  0.2539]],

         [[ 0.8271, -0.1410,  2.5384],
          [ 2.2272,  1.3619,  2.4220],
          [ 2.5946,  1.0372,  5.2898]]]], grad_fn=<Tr

Just as before, we can perform our sanity check to ensure we did it right.

In [12]:
expected_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(f"Output ({batch_size=}, {n_heads=}, {context_size=}, {n_embd=}): {expected_output.shape}")
print(f"Your implementation is correct: {torch.allclose(output, expected_output)}")

Output (batch_size=2, n_heads=4, context_size=3, n_embd=12): torch.Size([2, 4, 3, 3])
Your implementation is correct: True


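After computing the attention output, we reassemble the attention heads side by side: we move the head dimension back next to the head-size dimension and merge the two, which concatenates the heads along the embedding dimension and returns a 3-dimensional tensor.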

* Original shape: `(batch_size, n_heads, context_size, n_embd // n_heads)`
* Final shape: `(batch_size, context_size, n_embd)`

In [13]:
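# Transpose to (batch_size, context_size, n_heads, head_size); transpose leaves the memory
# non-contiguous, so .contiguous() is needed before .view() can merge the last two dims.
output = output.transpose(1, 2).contiguous().view(batch_size, context_size, n_embd)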
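To see why `.contiguous()` is needed, here is a sketch with a hypothetical stand-in for the attention output: after the transpose, `.view` cannot merge the head and head-size dimensions, while copying first (or using `.reshape`) works, and head `h` lands in channels `h*hs:(h+1)*hs`.

```python
import torch

B, T, C, n_heads = 2, 3, 12, 4       # same sizes as the multi-head example
hs = C // n_heads
out = torch.randn(B, n_heads, T, hs)  # hypothetical stand-in for the attention output

t = out.transpose(1, 2)               # (B, T, n_heads, hs): only the strides change
print(t.is_contiguous())              # False

# .view cannot merge n_heads and hs while they are not adjacent in memory.
try:
    t.view(B, T, C)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)                    # True

merged = t.contiguous().view(B, T, C)  # copy into the new layout, then merge

# Head h occupies channels h*hs:(h+1)*hs of the merged tensor.
print(torch.equal(merged[..., hs:2 * hs], out[:, 1]))  # True

# .reshape copies only when it has to, so it can stand in for .contiguous().view(...).
print(torch.equal(t.reshape(B, T, C), merged))  # True
```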