In [71]:
import torch
import torch.nn.functional as F

# A mathematical trick that is used by self-attention inside a transformer
# and that sits at the heart of an efficient self-attention implementation.

# Fixed seed so the random toy input (and every printed number) is reproducible.
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
print("x:", x.shape)

# Goal: let the T tokens of each sequence "talk" to each other, but causally.
# The token at position 5 must not see positions 6, 7, 8, ...; it may only use
# positions 4, 3, 2, ... Information flows from the past into the present only,
# because the future is exactly what we are about to predict.
#
# The simplest form of communication is a plain average over the prefix:
# the token at step t takes its own C channels plus the channels of every
# earlier step and averages them. The result is a feature vector that
# summarizes the token in the context of its history.

# xbow[b, t] = mean over x[b, 0..t] (current step and everything before it).
xbow = torch.zeros((B, T, C))
for b, sequence in enumerate(x):
  print("---- b:", b)
  for t in range(T):
    prefix = sequence[:t+1]            # rows 0..t of this sequence -> (t+1, C)
    prefix_mean = prefix.mean(dim=0)   # collapse the time axis -> (C,)
    xbow[b, t] = prefix_mean
    print(f"-- t: {t}, cur_and_prev: {prefix.shape}, mean: {prefix_mean}")

# Compare the averaged features with the raw input for the first sequence.
print(xbow[0])
print(x[0])
    

x: torch.Size([4, 8, 2])
---- b: 0
-- t: 0, cur_and_prev: torch.Size([1, 2]), mean: tensor([ 0.1808, -0.0700])
-- t: 1, cur_and_prev: torch.Size([2, 2]), mean: tensor([-0.0894, -0.4926])
-- t: 2, cur_and_prev: torch.Size([3, 2]), mean: tensor([ 0.1490, -0.3199])
-- t: 3, cur_and_prev: torch.Size([4, 2]), mean: tensor([ 0.3504, -0.2238])
-- t: 4, cur_and_prev: torch.Size([5, 2]), mean: tensor([0.3525, 0.0545])
-- t: 5, cur_and_prev: torch.Size([6, 2]), mean: tensor([ 0.0688, -0.0396])
-- t: 6, cur_and_prev: torch.Size([7, 2]), mean: tensor([ 0.0927, -0.0682])
-- t: 7, cur_and_prev: torch.Size([8, 2]), mean: tensor([-0.0341,  0.1332])
---- b: 1
-- t: 0, cur_and_prev: torch.Size([1, 2]), mean: tensor([ 1.3488, -0.1396])
-- t: 1, cur_and_prev: torch.Size([2, 2]), mean: tensor([0.8173, 0.4127])
-- t: 2, cur_and_prev: torch.Size([3, 2]), mean: tensor([-0.1342,  0.4395])
-- t: 3, cur_and_prev: torch.Size([4, 2]), mean: tensor([0.2711, 0.4774])
-- t: 4, cur_and_prev: torch.Size([5, 2]), mean: 

In [78]:
# Short version using: tril, softmax, matrix mul.
#
# Same causal average as the explicit loop, expressed as a single matmul:
#   1. start from all-zero "affinities" between every pair of positions,
#   2. set every future position (above the diagonal) to -inf,
#   3. softmax each row so the allowed positions share the weight equally,
#   4. multiply the (T, T) weights into x to average over the time axis.

# Lower-triangular mask of ones, built here so the cell does not depend on a
# `tril` left over in the kernel from some other cell.
tril = torch.tril(torch.ones(T, T))

wei = torch.zeros((T,T))
print(wei)
# Row t keeps columns 0..t; everything where the mask is 0 becomes -inf.
wei = wei.masked_fill(tril == 0, float("-inf"))
print(wei)
# exp(-inf) == 0, so each row normalizes to 1 over its allowed columns only.
wei = F.softmax(wei, dim=-1)
print(wei)
# (T, T) @ (B, T, C): the weight matrix is broadcast over the batch -> (B, T, C).
xbow2 = wei @ x
# Should match the looped reference result up to floating-point tolerance.
print(torch.allclose(xbow, xbow2))

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,