In [1]:
# Self-attention warm-up: build a small random batch of token vectors to experiment with.
import torch

torch.manual_seed(1337)  # fixed seed so the random tensor (and the shown output) is reproducible
B, T, C = 4, 8, 2  # B = batch size, T = time steps (sequence length), C = channels per step
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [2]:
# Goal: xbow[b, t] = mean_{i<=t} x[b, i] -- every position is replaced by the
# average of itself and all earlier positions in the same sequence.
# This explicit double loop is the slow reference version ("bow" = bag of words).
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        prefix = x[b, : t + 1]  # (t+1, C): positions 0..t of sequence b
        xbow[b, t] = prefix.mean(dim=0)

In [3]:
# The loop above can be replaced by a single matrix multiplication: a lower-triangular
# matrix whose rows are normalised to sum to one turns `a @ b` into running averages,
# i.e. row i of c is the mean of rows 0..i of b.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))  # lower-triangular: ones on and below the diagonal, zeros above
a = a / a.sum(dim=1, keepdim=True)  # normalise so each row sums to one
b = torch.randint(0, 10, (3, 2)).float()  # small integer-valued matrix, easy to check by hand
c = torch.matmul(a, b)  # equivalent to a @ b
print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [4]:
# Same running average, vectorised: build a (T, T) lower-triangular matrix whose
# row t has weight 1/(t+1) on positions 0..t, then one matmul averages them.
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # normalise each row to sum to one
# wei is (T, T) while x is (B, T, C): matmul broadcasts wei across the batch
# dimension, so every batch element is averaged in parallel.
xbow2 = wei @ x  # (T, T) @ (B, T, C) --> (B, T, C)
# Sanity check: the vectorised result must match the explicit loop version (xbow).
assert torch.allclose(xbow, xbow2)

In [5]:
# Third formulation, via softmax: start from all-zero affinities, set positions
# above the diagonal ("the future") to -inf, then softmax each row.
# exp(0) = 1 for allowed positions and exp(-inf) = 0 for masked ones, so each row
# becomes a uniform average over positions 0..t -- the same weights as the previous cell.
from torch.nn import functional as F
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # forbid attending to later positions
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x  # (T, T) @ (B, T, C) --> (B, T, C)
# Sanity check: must agree with the explicit loop version (xbow).
assert torch.allclose(xbow, xbow3)

In [6]:
# Self-attention with a single head: the affinities between positions are now
# data-dependent (query . key) instead of the fixed uniform weights used above.
import torch.nn as nn
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)

# Raw affinities: (B, T, hs) @ (B, hs, T) --> (B, T, T), scaled by 1/sqrt(head_size).
# The scaling must happen BEFORE the softmax. If q and k entries are roughly
# unit-variance, their dot products have variance ~head_size. Unscaled, the softmax
# would saturate (tunnel vision on the peak) instead of staying diffuse.
wei = q @ k.transpose(-2, -1) * head_size**-0.5

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # each token attends only to itself and earlier tokens
wei = F.softmax(wei, dim=-1)

v = value(x)   # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, hs) --> (B, T, head_size)

out.shape

torch.Size([4, 8, 16])

In [9]:
# Peek at the randomly initialised weight of a tiny linear layer (3 inputs -> 1 output);
# the printed weight has shape (out_features, in_features) = (1, 3).
linear_layer = nn.Linear(3, 1)
print(linear_layer.weight)

Parameter containing:
tensor([[-0.2744,  0.4940,  0.0129]], requires_grad=True)
