## Mathematical trick in self-attention

Below is a mathematical trick that's at the heart of efficient self-attention implementations in Transformers. We will use the toy example below to understand this operation.

Currently, the 8 tokens in each batch don't communicate with each other. We want them to **"talk"** to each other, but with a specific constraint: each token should only communicate with previous tokens, not future ones. For example, the token at position 5 should only access information from positions 1, 2, 3, 4, and 5 - never from positions 6, 7, or 8, since those represent future information we're trying to predict.

In [None]:
import torch

In [None]:
# toy example
torch.manual_seed(42)
B,T,C = 4,8,2 # batch, seq_length, channels
x = torch.randn(B,T,C)
x.shape

**The simplest way for tokens to communicate is through averaging.** If I'm the 5th token, I can take my channels and average them with the channels from all previous positions (1st through 4th). This creates a feature vector that summarizes me in the context of my history.

While averaging is a weak form of interaction that loses spatial arrangement information, it's a good starting point. We'll see how to add that information back later.

In [None]:
xbow = torch.zeros(x.shape)
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b,t] = torch.mean(xprev, 0)
x[0], xbow[0]

We can make this process highly efficient using matrix multiplication. Let's demonstrate it with a toy example.

In [None]:
torch.manual_seed(42)
a = torch.ones((3,3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)

**Key trick**: Instead of using a boring matrix of all ones, we can wrap `torch.ones()` in PyTorch's `torch.tril()` function. This returns only the lower triangular portion of the matrix, zeroing out the elements above the diagonal.

This creates an incremental aggregation pattern where each position accumulates information from itself and all previous positions.

In [None]:
torch.manual_seed(42)
# create lower triangular matrix
a = torch.tril((torch.ones((3,3))))
# normalize each row to sum to 1, so the matmul computes an average
a = a / a.sum(1,keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)
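As a small sanity check on the pattern described above, the sketch below compares the triangular product with `torch.cumsum`. Without the row normalization, the lower triangular matrix computes a running sum; with it, a running mean. The variable names `running_sum`, `running_mean`, and `counts` are illustrative and not part of the original notebook.

```python
import torch

torch.manual_seed(42)
b = torch.randint(0, 10, (3, 2)).float()

# Un-normalized lower triangular mask: row i adds up rows 0..i of b.
tril = torch.tril(torch.ones(3, 3))
running_sum = tril @ b

# Dividing each row by its number of ones turns the sum into a mean.
counts = tril.sum(1, keepdim=True)  # number of visible rows: 1, 2, 3
running_mean = (tril / counts) @ b

# Both should agree with a cumulative sum along the first dimension.
print(torch.allclose(running_sum, torch.cumsum(b, dim=0)))
print(torch.allclose(running_mean, torch.cumsum(b, dim=0) / counts))
```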

The same trick can be used to average along the sequence dimension efficiently.

In [None]:
# version 2
w = torch.tril(torch.ones(T,T))
w = w / w.sum(1, keepdim=True)
xbow2 = w @ x # (T,T) @ (B,T,C) -> (B,T,C)

In [None]:
torch.allclose(xbow2, xbow)
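The shape comment `(T,T) @ (B,T,C) -> (B,T,C)` relies on broadcasting: PyTorch treats `w` as if it were repeated once per batch element. A minimal sketch of that equivalence, assuming the same toy shapes as above (the names `batched` and `per_example` are illustrative):

```python
import torch

torch.manual_seed(42)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

w = torch.tril(torch.ones(T, T))
w = w / w.sum(1, keepdim=True)

# Broadcasting: w is applied to every batch element, (B,T,T) @ (B,T,C).
batched = w @ x

# Explicit version: multiply each batch element separately, then stack.
per_example = torch.stack([w @ x[i] for i in range(B)])

print(batched.shape)
print(torch.allclose(batched, per_example))
```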

Now we will rewrite this in a third way using `softmax`. We start from a matrix of zeros and set every position where `tril` is 0 to negative infinity. Softmax is a normalization operation. When it exponentiates each element:
- The finite values (where `tril` was 1) become e^0 = 1
- The negative infinity values (where `tril` was 0) become e^(-inf) = 0

Note: Currently these affinities are just set to zero by us, but in self-attention, these affinities won't be constant. They'll be data-dependent. Tokens will look at each other and find some tokens more or less interesting based on their content. The affinities will vary depending on how much tokens "want" to attend to each other.



In [None]:
# version 3: use softmax
from torch.nn import functional as F
tril = torch.tril(torch.ones(T,T))
w = torch.zeros((T,T))
w = w.masked_fill(tril == 0, float('-inf'))
w = F.softmax(w, dim=-1)
xbow3 = w @ x
torch.allclose(xbow3, xbow)
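To make the softmax step concrete, here is a sketch on a single row of the masked matrix. The rows are made-up values for illustration: the first mimics the all-zero affinities used above, the second previews what happens once the affinities are no longer constant.

```python
import torch
from torch.nn import functional as F

# Row for the 3rd token in a length-4 sequence: it may see positions 0..2.
uniform_row = torch.tensor([0.0, 0.0, 0.0, float('-inf')])

# exp(0) = 1 for visible positions and exp(-inf) = 0 for the masked one,
# so after normalization each visible position gets weight 1/3.
uniform_weights = F.softmax(uniform_row, dim=-1)

# With non-constant affinities the visible weights differ,
# but the masked (future) position still gets exactly zero.
varied_row = torch.tensor([2.0, 0.0, 1.0, float('-inf')])
varied_weights = F.softmax(varied_row, dim=-1)

print(uniform_weights)
print(varied_weights)
```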

This is the preview for self-attention. The key takeaway from this entire section is that you can perform weighted aggregations of past elements using matrix multiplication in a lower triangular fashion. The elements in the lower triangular part determine how much each past element contributes to the current position.

**Crux of self-attention**

- It is probably the most important part of transformers to understand. _**So, please spend as much time as possible to understand this.**_
- The above version does a simple average of all the past tokens and the current token, so past and current information are just mixed together uniformly.

**Making Attention Data-Dependent**
- Now we don't actually want this to be all uniform because different tokens will find different other tokens more or less interesting, and we want that to be data-dependent.
- For example, if I'm a vowel, then maybe I'm looking for consonants in my past, and maybe I want to know what those consonants are and I want that information to flow to me. So I want to now gather information from the past, but I want to do it in a data-dependent way. **This is the problem that self-attention solves**.

**How does self-attention make interactions/affinities data-dependent?**
- The way self-attention solves this is the following: every single node/token at each position will emit two vectors - it will emit a **query** and it will emit a **key**.
    - The **query** vector roughly speaking is "what am I looking for?"
    - The **key** vector roughly speaking is "what do I contain?"
- Then the way we get affinities between these tokens in a sequence is we basically just do a dot product between the keys and the queries.
- So my query takes a dot product with the keys of all the tokens, and those dot products become `w`. If a key and a query are well aligned, their dot product is large and that token receives a high weight (and vice versa), so I get to learn more about that specific token than about any other token in the sequence.


In [None]:
# version 4: self-attention (for single head)
from torch import nn
from einops import einsum
torch.manual_seed(42)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,head_size)
q = query(x)# (B,T,head_size)
# w = q @ k.transpose(-2,-1)
w = einsum(q,k, 'B T1 C, B T2 C -> B T1 T2') # (B,T,C) @ (B,C,T) -> (B,T,T)

tril = torch.tril(torch.ones(T,T))
w = w.masked_fill(tril == 0, float('-inf'))
w = F.softmax(w, dim=-1)
out = w @ x

out.shape

In [None]:
# w is not uniform anymore
w[0]
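A few properties of the data-dependent `w` can be checked directly. The sketch below rebuilds the same single head (using plain `@` instead of `einops`, purely as a simplification) and verifies that each row is still a probability distribution and that future positions receive zero weight. The names `rows_sum_to_one`, `future_is_zero`, and `first_row` are illustrative.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(42)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k, q = key(x), query(x)

w = q @ k.transpose(-2, -1)  # (B,T,T) raw affinities
tril = torch.tril(torch.ones(T, T))
w = F.softmax(w.masked_fill(tril == 0, float('-inf')), dim=-1)

# Each row is a probability distribution over the visible positions.
rows_sum_to_one = torch.allclose(w.sum(-1), torch.ones(B, T))
# Weights on future positions (strict upper triangle) are exactly zero.
future_is_zero = bool((w * (1 - tril)).abs().sum() == 0)
# The first token can only attend to itself.
first_row = w[0, 0]

print(rows_sum_to_one, future_is_zero)
print(first_row)
```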

**Adding value**

- There's one more part to a single self-attention head: when we do the aggregation, we don't actually aggregate the raw tokens (`x`). Instead, each token emits one more vector, which we call the **value**.
- So `v` holds the elements that we aggregate instead of the raw `x`. You can think of `x` as information private to this token: I'm the fifth token, I have some identity, and my information is kept in the vector `x`. For the purposes of this single head, the query says what I'm interested in, the key says what I have, and if you find me interesting, `v` is **what I will communicate to you**. `v` is the thing that gets aggregated between the different nodes for this head.
- This is basically the self-attention mechanism.

In [None]:
from torch import nn
from einops import einsum
torch.manual_seed(42)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,head_size)
q = query(x)# (B,T,head_size)
v = value(x) # (B,T,head_size)
# affinities come from query-key dot products; equivalent to q @ k.transpose(-2,-1)
w = einsum(q,k, 'B T1 C, B T2 C -> B T1 T2') # (B,T,C) @ (B,C,T) -> (B,T,T)

tril = torch.tril(torch.ones(T,T))
w = w.masked_fill(tril == 0, float('-inf'))
w = F.softmax(w, dim=-1)
out = w @ v

out.shape
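One way to convince yourself that the triangular mask does its job is to change a future token and confirm that earlier outputs do not move. The following is a sketch under the same toy shapes; `causal_head` is a hypothetical helper written for this check, not a function from the notebook.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(42)
B, T, C, head_size = 4, 8, 32, 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def causal_head(x):
    """Masked single-head attention (hypothetical helper for this check)."""
    k, q, v = key(x), query(x), value(x)
    w = q @ k.transpose(-2, -1)
    tril = torch.tril(torch.ones(x.shape[1], x.shape[1]))
    w = F.softmax(w.masked_fill(tril == 0, float('-inf')), dim=-1)
    return w @ v

x = torch.randn(B, T, C)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(B, C)  # overwrite only the last token

with torch.no_grad():
    out, out_changed = causal_head(x), causal_head(x_changed)

# Positions 0..T-2 never see the last token, so their outputs should match.
earlier_unchanged = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6)
# The last position does see the changed token, so its output should differ.
last_changed = not torch.allclose(out[:, -1], out_changed[:, -1])
print(earlier_unchanged, last_changed)
```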

![directed-graph](images/directed-graph.webp)
- Attention is a **communication mechanism**
  - You can really think about it as a communication mechanism where you have a number of nodes in a directed graph where basically you have edges pointed between nodes like this. What happens is every node has some vector of information, and it gets to aggregate information via a weighted sum from all of the nodes that point to it.
- **There is no notion of space**.
  - Attention simply acts over a set of vectors, and by default these nodes have no idea about their positions. That's why we need positional encodings - to give the nodes information about where they are located in the sequence. This is different from convolutions, which have a very specific spatial layout and operate directly in that space.
- **Examples across the batch dimension are processed completely independently** and never "talk" to each other
  - This is achieved through batched matrix multiplication, which applies the same operation in parallel across the batch dimension. So in our analogy of a directed graph, with a batch size of 4, we actually have four separate pools of 8 nodes each. The 8 nodes within each pool can communicate, but the different pools never interact.
- **Encoder vs Decoder Block**
  - This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
  - However, in many other applications like sentiment analysis, you might want all tokens to communicate with each other fully. In those cases, you'd use an "encoder block" - which is simply attention without the triangular mask (delete the single line that does masking with `tril`), allowing all nodes to talk to each other completely.
- **Self-attention vs Cross-attention**
  - "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from `x`, but the keys and values come from some other, external source (e.g. an encoder module).
- **Scaled attention** additionally multiplies `w` by 1/sqrt(head_size). This makes it so that when the inputs Q and K have unit variance, `w` has unit variance too.
  - This scaling is crucial because the affinities in `w` feed into a softmax function. If they become too large (high variance), the softmax will become very "peaky" - it will converge toward one-hot vectors where each token only attends to a single other token. This is especially problematic at initialization, where we want the attention to be fairly diffuse so tokens can learn to attend to multiple relevant positions.




In [None]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
w = q @ k.transpose(-2, -1) * head_size**-0.5

print("k.var()", k.var())
print("q.var()", q.var())
print("w.var()", w.var())
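For comparison, the sketch below computes the affinities with and without the scaling factor. It uses larger `B` and `T` than the toy example, an assumption made only so that the variance estimate is stable. Each entry of the unscaled product is a sum of `head_size` products of independent unit-variance values, so its variance should land near `head_size`, and near 1 after scaling.

```python
import torch

torch.manual_seed(42)
# Larger B and T than the toy example, only to stabilize the estimate.
B, T, head_size = 32, 64, 16
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)

w_unscaled = q @ k.transpose(-2, -1)
w_scaled = w_unscaled * head_size**-0.5

# Expect roughly head_size before scaling and roughly 1 after.
print("unscaled var:", w_unscaled.var().item())
print("scaled var:  ", w_scaled.var().item())
```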

In [None]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

In [None]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot
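The encoder/decoder and self-/cross-attention distinctions from the notes above can be sketched with a single function. `attention_head`, its `context` and `causal` parameters, and the `memory` tensor are hypothetical names for illustration. Reusing the same projection layers for both modes is a simplification; a real model would typically have separate weights per attention layer.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(42)
C, head_size = 32, 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def attention_head(x, context=None, causal=True):
    """Hypothetical helper: one scaled attention head.

    context=None  -> self-attention (keys/values come from x)
    context given -> cross-attention (keys/values come from context)
    causal=False  -> "encoder" style, no triangular mask
    """
    source = x if context is None else context
    q, k, v = query(x), key(source), value(source)
    w = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T_q, T_kv)
    if causal:
        T_q, T_kv = w.shape[-2:]
        tril = torch.tril(torch.ones(T_q, T_kv))
        w = w.masked_fill(tril == 0, float('-inf'))
    return F.softmax(w, dim=-1) @ v

x = torch.randn(4, 8, C)        # decoder-side tokens
memory = torch.randn(4, 10, C)  # stand-in for an encoder output

decoder_out = attention_head(x, causal=True)    # masked self-attention
encoder_out = attention_head(x, causal=False)   # unmasked self-attention
cross_out = attention_head(x, context=memory, causal=False)
print(decoder_out.shape, encoder_out.shape, cross_out.shape)
```

In all three cases the output has one `head_size`-dimensional vector per query position; only the set of positions each query is allowed to aggregate from changes.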