# Mathematical Trick for Self-Attention

In [30]:
import torch

In [31]:
# Fix the global RNG seed so every random tensor and layer init below is reproducible.
SEED = 42
torch.manual_seed(SEED);  # trailing `;` hides the Generator repr (it embeds a memory address)

<torch._C.Generator at 0x105d63510>

In [32]:
# B: batch size, T: time steps (sequence length),
# C: channels = number of features stored for each token (last dim of x)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)  # random toy features for every (batch, time) position
x.shape

torch.Size([4, 8, 2])

## Bag Of Words (bow)

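Here "bag of words" means averaging: each token's vector is replaced by the mean of itself and all tokens that precede it.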

### 1. For Loop
Using nested Python `for` loops: straightforward, but inefficient because nothing is vectorized.

In [33]:
# Reference implementation: visit every (batch, time) position explicitly.
x_bow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        # mean over positions 0..t (inclusive) of sequence b
        x_bow[b, t] = x[b, : t + 1].mean(dim=0)

x_bow[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

### 2. Vectorization

In [34]:
# start from equal (all-ones) affinities between every pair of time steps
weights = torch.ones((T, T))
weights

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [35]:
# causal mask: zero the upper triangle so row t only keeps columns 0..t (no peeking at the future)
weights = weights.tril()
weights

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [36]:
# Normalize each row to sum to 1, so `weights @ x` becomes a running mean over the past.
# `dim=` (rather than the NumPy alias `axis=`) matches the rest of the notebook.
weights = weights / weights.sum(dim=-1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [37]:
# batched matmul: (T, T) @ (B, T, C) broadcasts over B -> (B, T, C)
x_bow2 = torch.matmul(weights, x)
x_bow2[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [38]:
# loop and matmul versions must agree up to floating-point tolerance
x_bow.allclose(x_bow2)

True

### 3. Softmax

In [39]:
# lower-triangular mask of ones: row t has ones in columns 0..t
tril = torch.ones((T, T)).tril()
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [40]:
# affinities start at zero; positions outside the mask (the future) become -inf,
# so the following softmax gives them zero probability
weights = torch.zeros((T, T)).masked_fill(tril == 0, float("-inf"))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [41]:
# Row-wise softmax: exp(0) on allowed positions and exp(-inf) = 0 on masked ones
# reproduce the uniform running average from the previous section.
# NOTE: not idempotent -- re-running this cell on its own softmaxes the probabilities
# again and leaks weight onto future positions; re-run the masking cell first.
weights = torch.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [42]:
# same aggregation, now with softmax-derived weights: (T, T) @ (B, T, C) -> (B, T, C)
x_bow3 = torch.matmul(weights, x)
x_bow3[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [43]:
x_bow.allclose(x_bow3)  # softmax version must match the loop reference

True

## Self-Attention

In [67]:
from torch import nn

B, T, C = 4, 8, 32  # batch, time, channels (features per token)
x = torch.randn(B, T, C)

# single head self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# key: "what do I contain?" -- the information each token offers to the others
k = key(x)    # (B, T, head_size)
# query: "what am I looking for?" -- what each token wants to find in the others
q = query(x)  # (B, T, head_size)

# Affinities via dot product: (B, T, hs) @ (B, hs, T) -> (B, T, T).
# Scaling by 1/sqrt(head_size) keeps the scores from growing with head_size
# (roughly unit variance if q and k entries have unit variance); otherwise the
# softmax becomes too peaky and each token attends to almost a single other token.
weights = q @ k.transpose(-2, -1) * head_size ** -0.5

# Causal mask (decoder blocks only): a token may not attend to future positions.
# An encoder block skips this and lets every token attend to every other token.
tril = torch.tril(torch.ones((T, T)))
weights = weights.masked_fill(tril == 0, float("-inf"))
weights = torch.nn.functional.softmax(weights, dim=-1)

# value: "what I will communicate if you find me interesting" --
# aggregate the projected values instead of the raw x
v = value(x)       # (B, T, head_size)
out = weights @ v  # (B, T, T) @ (B, T, hs) -> (B, T, head_size)

out.shape

torch.Size([4, 8, 16])

In [68]:
# Attention weights of the head above: still causal and row-normalized by the softmax,
# but unlike the bag-of-words average each row now depends on the q·k affinities.
weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8990, 0.1010, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3339, 0.4727, 0.1934, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4491, 0.0536, 0.3944, 0.1029, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0218, 0.5909, 0.1793, 0.1461, 0.0619, 0.0000, 0.0000, 0.0000],
         [0.4292, 0.1283, 0.0066, 0.0170, 0.0012, 0.4177, 0.0000, 0.0000],
         [0.1479, 0.0274, 0.1908, 0.0884, 0.1558, 0.0718, 0.3179, 0.0000],
         [0.4090, 0.0620, 0.0469, 0.0045, 0.0068, 0.1389, 0.2396, 0.0923]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2627, 0.7373, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0712, 0.2229, 0.7059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4681, 0.1490, 0.3433, 0.0396, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0136, 0.1054, 0.0638, 0.4138, 0.4035, 0.0000, 0.0000, 0.0000],
         [0.0189, 0.222

# Notes

- `Encoder`: no `tril` mask; all tokens can communicate with each other.
- `Decoder`: triangular (causal) masking; each token can only communicate with itself and past tokens.
- `Self-Attention`: `q`, `k`, `v` are all produced from the same `x`.
- `Cross-Attention`: `q` is produced from `x`, while `k` and `v` come from another, external source (e.g. an encoder module).
- `Scaled Attention`: the attention scores are multiplied by $1/\sqrt{d_k}$ (i.e. divided by $\sqrt{d_k}$), where $d_k$ = `head_size`, before the softmax.

### Softmax Peakiness

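Softmax exponentiates its inputs, so larger logits receive disproportionately more probability mass. Multiplying the logits by a constant greater than 1 sharpens the distribution toward the maximum, approaching one-hot. The example below shows this with a factor of 8, and it is why attention scores are scaled by $1/\sqrt{d_k}$.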

In [78]:
# small example logits, softmaxed as-is
t = torch.tensor([0.1, -0.2, -0.3, 0.5])
s1 = t.softmax(dim=0)
s1

tensor([0.2562, 0.1898, 0.1717, 0.3822])

In [80]:
torch.var(s1)  # spread of the un-scaled softmax (unbiased sample variance)

tensor(0.0091)

In [82]:
# the same logits scaled up by 8: the softmax collapses toward the largest entry
s2 = (t * 8).softmax(dim=0)
s2

tensor([0.0390, 0.0035, 0.0016, 0.9559])

In [83]:
torch.var(s2)  # much larger spread than s1: the distribution is nearly one-hot

tensor(0.2218)