In [1]:
import torch
SEED = 1337  # fixed seed so the random tensors below are reproducible
torch.manual_seed(SEED)

<torch._C.Generator at 0x1fd22b9f690>

In [2]:
# toy example

# x is a batch of B sequences; every sequence has T time steps
# and every time step is a C-dimensional vector
B = 4  # batch size
T = 8  # time steps
C = 2  # channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [3]:
# each example in the batch is a sequence of length T (8), and each element of that
# sequence is a vector in a C-dimensional space (C = 2); x[0] is the first sequence, shape (T, C)
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [4]:
# now each token doesn't communicate with the other tokens in the sequence
# but what we want is to make each token aware of the previous tokens in the sequence
# so for example token 5 should be aware of token 4, 3, 2, 1 but not 6, 7, 8 
# because we cannot use information from the future when we try to predict the future

# what is the EASIEST way for tokens to communicate?
# if a token wants some information about the past, we can simply average all the previous tokens together with the current token
# so each token becomes more like a feature vector that summarizes the current token in the context of its history
# the disadvantage of this approach is that we lose the order of the tokens in the sequence
# but for now let's try this approach

In [5]:
# so in a context of the first element in a batch:
# first token is gonna be the way it is:
# [ 0.1808, -0.0700]

# the second is gonna be the average of the first and the second token:
# [ 0.1808, -0.0700],
# [-0.3596, -0.9152] =========> [-0.0894, -0.4926]

# the third:
# [ 0.1808, -0.0700],
# [-0.3596, -0.9152],
# [ 0.6258,  0.0255], =========> [ 0.1490, -0.3199]
# .....

## Brute force: a simple for loop

In [6]:
# brute force: for every sequence and every position t, average tokens 0..t
x_bow = torch.zeros_like(x)  # running averages, same shape as x
for b in range(B):
    seq = x[b]  # one sequence of the batch
    for t in range(T):
        history = seq[: t + 1]  # current token plus everything before it => (t+1, C)
        x_bow[b, t] = torch.mean(history, dim=0)  # average over the time axis => (C,)
x_bow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

## Matrix multiplication

In [7]:
# the same running average can be expressed as a matrix multiplication;
# start with a small toy example to see how
torch.manual_seed(42)
a = torch.ones((3, 3))
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)  # a is all ones, so every row of c is the column-wise sum of b

print(f'{a=}\n')
print(f'{b=}\n')
print(f'{c=}')

a=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

c=tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [8]:
# so far every row of c is the sum over ALL rows of b
# how do we sum only the previous + current rows (tokens)?
# zero out the upper triangle of a, i.e. use a lower-triangular matrix of ones
a = a.tril()
c = torch.matmul(a, b)  # row i of c = b[0] + ... + b[i]

print(f'{a=}\n')
print(f'{b=}\n')
print(f'{c=}')

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

c=tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [9]:
# last step: turn the running sums into running averages by dividing
# every row of a by the sum of that row
# (i.e. we simply normalize each row of a so that it sums to 1)
row_totals = a.sum(dim=1, keepdim=True)  # one total per row, kept as a column for broadcasting
a = a / row_totals
c = torch.matmul(a, b)

print(f'{a=}\n')
print(f'{b=}\n')
print(f'{c=}')

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [10]:
# apply the same recipe to our (T, T) case

wei = torch.ones((T, T)).tril()  # lower-triangular matrix of ones
wei = torch.div(wei, wei.sum(dim=1, keepdim=True))  # (T, T) / (T, 1): normalize every row
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [20]:
# (T, T) @ (B, T, C): wei is broadcast over the batch dimension,
# i.e. (B, T, T) @ (B, T, C) => x_bow2 = (B, T, C)
x_bow2 = torch.matmul(wei, x)
torch.allclose(x_bow, x_bow2, atol=1e-6)  # compare against the for-loop result

True

## Softmax

In [26]:
import torch.nn.functional as F

# why is exactly this formulation used in self-attention?
# here we set all the raw weights to 0 by hand (wei = torch.zeros((T, T)))
# so after normalization every allowed token gets the same (uniform) weight
# but soon these weights are going to be data dependent
# and softmax is a good way to normalize them

tril = torch.ones(T, T).tril()
future = tril == 0  # True strictly above the diagonal, i.e. at the future positions
wei = torch.zeros(T, T).masked_fill(future, float('-inf'))  # it is not allowed to communicate with future tokens!
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [27]:
# softmax over each row: the -inf entries get weight 0 and the 0 entries share the weight equally
# note: this overwrites wei, so re-running this cell alone applies softmax a second time
wei = torch.softmax(wei, dim=-1)
wei

# this is the same matrix as the row-normalized lower-triangular one built earlier

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [28]:
# aggregate the tokens with the softmax-derived weights and compare with the for-loop result
x_bow3 = torch.matmul(wei, x)
torch.allclose(x_bow, x_bow3, atol=1e-6)

True