# Learning the attention block

I realized that if I want to write down both code and math, the best place to do it is a notebook.

In this notebook we will write down all the details of the self-attention block. It is so pervasive now that I think it is worth the effort to understand it carefully.

I have written notes on the transformer before, but I stopped at the math and never got to the code. Today the goal is to go all the way, with an explanation for every line, so that it is ultra easy to recover next time.

## Conceptual recap
Input:
- a window of `w` token embeddings, each of dim `n_embd_1`.
- a batch of `B` independent windows (sequences), processed in parallel.
- a dictionary of `N` embeddings in total. From the attention block's point of view, these embeddings are fixed inputs and not something it trains. For instance, for language models we feed in a window of `w` words at a time, and each word is converted to a word embedding of dim `n_embd_1`.
    - I don't think we actually need the number of embeddings to be finite, because we can always translate an embedding to its query, key and value vectors using a linear layer, which accepts any vector of dim `n_embd_1`. I will explain this later.
- So the final shape of the input is `(B, w, n_embd_1)`.
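To make the input bullets concrete, here is a minimal sketch of how a batch of token ids becomes a `(B, w, n_embd_1)` tensor. It assumes the dictionary is an `nn.Embedding` lookup table and that token ids are integers in `[0, N)`. The value `N = 100` and the random ids are made up for illustration.

```python
import torch
import torch.nn as nn

B, w, n_embd_1 = 8, 16, 16
N = 100  # hypothetical dictionary size, chosen only for illustration

# the "dictionary": N embeddings, each a vector of dim n_embd_1
tok_emb = nn.Embedding(N, n_embd_1)

# a batch of B windows, each holding w token ids
idx = torch.randint(0, N, (B, w))

# the lookup replaces every id with its embedding vector
x = tok_emb(idx)
print(x.shape)  # torch.Size([8, 16, 16])
```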

Parameters:
- For each embedding vector $x_i$, we have a query vector $q_i$ of dim `n_embd_2`, a key vector $k_i$ of dim `n_embd_2`, and a value vector $v_i$ of dim `n_embd_3`. These are not stored per embedding. They are computed from $x_i$ by learned linear maps, $q_i = W_Q x_i$, $k_i = W_K x_i$, $v_i = W_V x_i$, so the trainable parameters are the matrices $W_Q$, $W_K$, $W_V$ (plus biases in the code below).
    - Note: the dimensions could all be the same or different, but the important point is that $q_i$ and $k_i$ have the same dimension, because we want to form $\langle q_i, k_j\rangle$ between different embedding vectors.
- The above parameters are for a single attention head. For multi-head attention, we have `nh` independent sets of Q/K/V maps, so each embedding gets `nh` different query, key and value vectors.
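The sketch below shows the single-head parameters as linear maps. It is a minimal example: `W_q`, `W_k` and `W_v` are illustrative names, biases are dropped for clarity, and `n_embd_3` is deliberately set different from `n_embd_2`. The point it shows is that only $q$ and $k$ need matching dims.

```python
import torch
import torch.nn as nn

w, n_embd_1, n_embd_2, n_embd_3 = 16, 16, 32, 24  # n_embd_3 deliberately != n_embd_2

# learned maps from an embedding x_i to q_i, k_i, v_i (one head)
W_q = nn.Linear(n_embd_1, n_embd_2, bias=False)
W_k = nn.Linear(n_embd_1, n_embd_2, bias=False)
W_v = nn.Linear(n_embd_1, n_embd_3, bias=False)

x = torch.randn(w, n_embd_1)  # one window of w embeddings
q, k, v = W_q(x), W_k(x), W_v(x)

# q and k share a dimension, so every pair <q_i, k_j> is defined
scores = q @ k.T  # (w, w); entry (i, j) is <q_i, k_j>
print(q.shape, k.shape, v.shape, scores.shape)
```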

The idea for a single attention head is as follows. Given `w` input words, for each word we want to generate an output vector of the same dim `n_embd_1` that depends on all the words in the window. The output is a linear combination of the value vectors $v_j$ of the words in the window. More precisely, it is a projection of that linear combination whenever `n_embd_3` differs from `n_embd_1`.

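How do we get the weights of the linear combination? For word $i$, the weights are the softmax over $j$ of $\frac{1}{\sqrt{d_2}}\langle q_i, k_j\rangle$, where $d_2$ is `n_embd_2` and $j$ ranges over all `w` words in the window, including $i$ itself. The idea is that you use the query vector $q_i$ to query against the keys of all the words. In formulas, the output for word $i$ is

$$\text{out}_i = \sum_{j=1}^{w} \alpha_{ij}\, v_j, \qquad \alpha_{ij} = \frac{\exp\left(\langle q_i, k_j\rangle / \sqrt{d_2}\right)}{\sum_{l=1}^{w} \exp\left(\langle q_i, k_l\rangle / \sqrt{d_2}\right)}.$$

Why divide by $\sqrt{d_2}$? If the entries of $q_i$ and $k_j$ are independent with unit variance, then $\langle q_i, k_j\rangle$ has variance $d_2$. Dividing by $\sqrt{d_2}$ keeps the scores at unit scale, so the softmax does not saturate as $d_2$ grows.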
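The formula can be checked directly. This minimal sketch uses random $q$, $k$, $v$ for a single head with a small window. It computes the output once with an explicit loop over $i$ and $j$, following the formula term by term, and once in matrix form, as the class below does.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w, d2, d3 = 5, 8, 8
q = torch.randn(w, d2)
k = torch.randn(w, d2)
v = torch.randn(w, d3)

# loop version: one output vector per word i
out_loop = torch.empty(w, d3)
for i in range(w):
    scores = torch.stack([q[i] @ k[j] for j in range(w)]) / d2 ** 0.5  # (w,)
    alpha = F.softmax(scores, dim=0)                  # weights over j, sum to 1
    out_loop[i] = (alpha[:, None] * v).sum(dim=0)     # weighted sum of value vectors

# matrix version: all words at once; row i of `weights` holds alpha_{i, :}
weights = F.softmax(q @ k.T / d2 ** 0.5, dim=-1)      # (w, w)
out_mat = weights @ v                                 # (w, d3)

print(torch.allclose(out_loop, out_mat, atol=1e-5))   # True
```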

And that's it. For multi-head attention, you repeat the process above independently in each of the `nh` heads, concatenate the `nh` outputs, and project the concatenation into a final output of dim `n_embd_1`.
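This sketch shows the multi-head recipe, assuming the same fused layout as `c_attn` in the class below. One linear layer produces Q, K and V for all heads, and head `h` owns columns `h*n_embd_2 : (h+1)*n_embd_2`. The sizes are small values picked for illustration. The example checks that a per-head loop with `torch.cat` gives the same result as the batched `view`/`transpose` approach.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, n_embd_1, n_embd_2, nh = 6, 16, 8, 4

x = torch.randn(T, n_embd_1)
c_attn = nn.Linear(n_embd_1, 3 * nh * n_embd_2)  # Q, K, V for all heads at once
c_proj = nn.Linear(nh * n_embd_2, n_embd_1)      # mixes the concatenated heads

q, k, v = c_attn(x).chunk(3, dim=-1)  # each (T, nh * n_embd_2)

def head(qh, kh, vh):
    # single-head attention: softmax(q k^T / sqrt(d2)) v
    return F.softmax(qh @ kh.T / n_embd_2 ** 0.5, dim=-1) @ vh

# per-head loop, then concatenate along the feature dim and project
heads = []
for h in range(nh):
    s = slice(h * n_embd_2, (h + 1) * n_embd_2)
    heads.append(head(q[:, s], k[:, s], v[:, s]))
out_loop = c_proj(torch.cat(heads, dim=-1))  # (T, n_embd_1)

# batched version: move the head axis to the front and attend in parallel
qb, kb, vb = (t.view(T, nh, n_embd_2).transpose(0, 1) for t in (q, k, v))  # (nh, T, n_embd_2)
yb = F.softmax(qb @ kb.transpose(-2, -1) / n_embd_2 ** 0.5, dim=-1) @ vb   # (nh, T, n_embd_2)
out_batched = c_proj(yb.transpose(0, 1).contiguous().view(T, -1))

print(out_loop.shape, torch.allclose(out_loop, out_batched, atol=1e-5))
```

The batched form replaces the Python loop over heads with a single matmul over a leading head dimension.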

Final note: for text data, when we do next-token generation, we usually don't want earlier words to depend on later words. So when computing the weights for word $i$, we restrict it to a linear combination of only the $v_j$ with $j \le i$. In practice, this means setting the scores $\langle q_i, k_j\rangle$ for $j > i$ to $-\infty$ before the softmax, which makes those weights exactly zero. Building this mask takes some `torch.tril` business that we'll see below.
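Here is a minimal sketch of the masking step on its own, using random scores as a stand-in for $\langle q_i, k_j\rangle/\sqrt{d_2}$. `torch.tril` marks the allowed positions $j \le i$. `masked_fill` then sets the forbidden ones to `-inf`, so the softmax assigns them zero weight.

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)  # stand-in for q @ k^T / sqrt(d2)

# lower-triangular ones: entry (i, j) is 1 exactly when j <= i
mask = torch.tril(torch.ones(T, T))

# masked_fill fills where its (boolean) mask is True, i.e. the forbidden j > i
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
print(weights)  # upper triangle is all zeros; row 0 is [1, 0, 0, 0]
```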

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
# params
B = 8          # default batch size (forward reads the actual batch size from x)
w = 16         # maximum window (context) length
n_embd_1 = 16  # dim of the input token embeddings
n_embd_2 = 32  # per-head dim of the query and key vectors
n_embd_3 = 32  # per-head dim of the values; set equal to n_embd_2, so the code below uses n_embd_2 for v too
nh = 8         # number of heads

class SelfAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # attention block
        # nn.Linear takes in shape (*, in_features) and outputs shape (*, out_features)
        # for us, input shape will be (B, w, n_embd_1)
        # c_attn contains all the weights for Q, K, V for nh heads
        self.c_attn = nn.Linear(n_embd_1, 3 * nh * n_embd_2)

        # the output from all the heads has dim nh * n_embd_2
        # need to project that back down to n_embd_1
        self.c_proj = nn.Linear(nh * n_embd_2, n_embd_1)

        # some regularization for the attention weights and the output
        self.attn_drop = nn.Dropout(0.1)
        self.out_drop = nn.Dropout(0.1)

        # Optional: a mask to ensure that token i only gets weights from j <= i
        # not a trainable param, so use register_buffer (it still moves with .to(device))
        self.register_buffer("mask", torch.tril(torch.ones(w, w)).view(1, 1, w, w))

    def forward(self, x):
        # x shape (B, T, n_embd_1) with T <= w
        # read the batch size and window size from x instead of using the globals,
        # so the module works for any batch size and for windows shorter than the max w
        B, T, _ = x.size()

        # split into q, k, v that each contain nh heads
        # Tensor.chunk splits the tensor into the specified number of chunks
        # could also use Tensor.split(nh * n_embd_2, dim=-1), where we specify the size of each chunk
        q, k, v = self.c_attn(x).chunk(3, dim=-1)  # each (B, T, nh * n_embd_2)
        q = q.view(B, T, nh, n_embd_2).transpose(1, 2)  # (B, nh, T, n_embd_2)
        k = k.view(B, T, nh, n_embd_2).transpose(1, 2)
        v = v.view(B, T, nh, n_embd_2).transpose(1, 2)

        # attention weights
        # @ (torch.matmul) treats the last two dims as matrices and batches over the leading dims
        # attn_weights has shape (B, nh, T, T)
        attn_weights = (q @ k.transpose(-2, -1)) * (1.0 / (n_embd_2 ** 0.5))

        # Optional: apply causal mask
        # restrict the mask to the first T rows/cols to fit the input size
        # masked_fill needs a boolean mask and fills where it is True, so we mask
        # the zeros of the tril (positions j > i), not the ones
        attn_weights = attn_weights.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        # get output
        output = attn_weights @ v
        # right now the shape of output is (B, nh, T, n_embd_2)
        # want to concat all heads together into shape (B, T, nh * n_embd_2) and then apply projection
        # this code is a bit subtle:
        # first we transpose the output to (B, T, nh, n_embd_2)
        # transpose only swaps strides and doesn't move any data, so the result is not contiguous
        # .view needs a compatible memory layout, so .contiguous() first copies the data into that order
        # (output.transpose(1, 2).reshape(B, T, -1) would do the same in one call)
        output = output.transpose(1, 2).contiguous().view(B, T, -1)

        output = self.out_drop(self.c_proj(output))
        return output
```
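As a sanity check on the core of `forward`, this sketch compares the manual causal attention (no dropout, no `c_proj`) against PyTorch's built-in `F.scaled_dot_product_attention`. It assumes PyTorch 2.0 or later, where that function exists. Its default scale is also $1/\sqrt{\text{last dim}}$, and `is_causal=True` applies the same lower-triangular mask. The last lines show why `.contiguous()` is needed before `.view` when merging heads. `hs` here stands in for `n_embd_2`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, nh, hs = 2, 5, 4, 8  # hs plays the role of n_embd_2
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# manual path, mirroring forward() minus dropout and c_proj
att = (q @ k.transpose(-2, -1)) / hs ** 0.5
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True where j <= i
att = att.masked_fill(~causal, float('-inf')).softmax(dim=-1)
y_manual = att @ v  # (B, nh, T, hs)

# built-in version (PyTorch >= 2.0)
y_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_manual, y_builtin, atol=1e-5))

# merging heads: transpose alone leaves a non-contiguous view
yt = y_manual.transpose(1, 2)  # (B, T, nh, hs)
print(yt.is_contiguous())  # False
merged = yt.contiguous().view(B, T, nh * hs)  # yt.reshape(B, T, -1) also works
```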
