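# The Self-Attention Trick

The notebook computes the same causal average (each position's vector averaged with those of all earlier positions) three ways: an explicit loop, a matrix multiply with a row-normalised lower-triangular matrix (version 2), and a masked softmax (version 3). Version 4 then replaces the uniform weights with data-dependent query/key affinities, which gives a single self-attention head.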

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [49]:
# Toy input: B sequences, each T tokens long, with C features per token.
torch.manual_seed(1337)
B = 4  # batch size
T = 8  # time steps (sequence length)
C = 2  # channels per token
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [50]:
# Naive causal "bag of words": for each sequence and each step, average the
# feature vectors of every token from position 0 up to and including that step.
# Deliberately written as an explicit double loop; later cells vectorise it.
xbow = torch.zeros(B, T, C)
for batch_idx in range(B):
    for step in range(T):
        history = x[batch_idx, : step + 1]  # (step + 1, C)
        xbow[batch_idx, step] = history.mean(dim=0)

In [51]:
torch.ones(3, 3).tril()  # lower-triangular mask: ones on and below the diagonal

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [52]:
# Averaging as a matrix multiply: after row-normalisation, row i of `a` puts
# weight 1/(i+1) on rows 0..i of `b`, so row i of `c` is the mean of the first
# i+1 rows of `b`.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)  # each row now sums to 1
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print("a=", a, "--", "b=", b, "--", "c=", c, sep="\n")

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [53]:
# Re-display the causal averaging matrix on its own, after first checking the
# property the demo relies on: every row is normalised to sum to 1.
assert torch.allclose(a.sum(dim=1), torch.ones(a.shape[0]))
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [54]:
# version 2: the loop-based running mean as a single matrix multiply with a
# row-normalised lower-triangular weight matrix.
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)  # row t averages positions 0..t
# wei is (T, T) and broadcasts over the batch: (B, T, T) @ (B, T, C) -> (B, T, C)
xbow2 = wei @ x
# Verify it matches the explicit-loop result (this check was commented out before).
torch.allclose(xbow, xbow2)

In [55]:
# version 3: the same weights produced by a softmax. Logits are zero on and
# below the diagonal and -inf above it, so softmax turns each row into a
# uniform average over the visible prefix.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim=-1)  # wei is 2-D here, so the last dim is dim 1
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

In [58]:
# version 4: self-attention. The weights are now data-dependent, computed from
# query/key projections of the input instead of being uniform.
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single head of self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
# Query/key affinities, scaled by 1/sqrt(head_size). The scaling keeps the
# logits near unit variance so the softmax does not saturate towards one-hot.
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# Causal mask: position t may only attend to positions <= t.
wei = wei.masked_fill(tril == 0, float("-inf"))
# Normalise over the key axis (the last dim) so each query's weights sum to 1.
# wei is 3-D here, so dim=1 would wrongly normalise over queries instead.
wei = F.softmax(wei, dim=-1)
# NOTE: a full attention head would aggregate a value projection of x here.
# Raw x is kept so the output stays (B, T, C) in this walkthrough.
out = wei @ x

out.shape

torch.Size([4, 8, 32])