## 👀 Self-attention

#### 📚 Libraries

In [1]:
import torch
from torch.nn import functional as F

#### 🧶 For loop implementation

In [2]:
torch.manual_seed(42)  # fixed seed so every printed tensor below is reproducible

# Toy dimensions: batch size, sequence length (time steps), channels per step.
b = 4
t = 8
c = 2
x = torch.randn(b, t, c)
x.shape

torch.Size([4, 8, 2])

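**Cumulative average**

For every position $t$, average all tokens up to and including $t$: $\text{xbow}[b, t] = \frac{1}{t+1}\sum_{i=0}^{t} x[b, i]$.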

In [3]:
# Goal: xbow[b, t] = mean_{i <= t} x[b, i], i.e. the running mean of every
# step up to and including position t, computed with explicit loops.
# Loop indices get their own names so they don't overwrite `batch`/`time`
# in the notebook namespace.
xbow = torch.zeros((b, t, c))
for batch_idx in range(b):
    for time_idx in range(t):
        prefix = x[batch_idx, : time_idx + 1]  # (time_idx + 1, c): steps 0..time_idx
        xbow[batch_idx, time_idx] = torch.mean(prefix, dim=0)

In [4]:
batch = 0  # index of the sequence within the batch to visualize

join = "└─"
join_char = "──"
line = "│ "

# Every glyph string above is 2 characters wide, so each connector is 2 * t
# characters long; pad by one space per side, keeping 18 as the minimum so
# the layout for t = 8 is unchanged while larger t no longer misaligns.
mid_width = max(18, 2 * t + 2)

print(f"{'x':^17}{'':^{mid_width}}{'xbow':^17}")
for time in range(t):
    x_time = str(x[batch, time].numpy().round(3))
    xbow_time = str(xbow[batch, time].numpy().round(3))

    # Rows 0..time contribute to xbow[batch, time]: draw one vertical bar per
    # remaining row, then an elbow whose horizontal run grows with `time`.
    n_joins = time + 1
    n_lines = t - n_joins
    whole_string = line * n_lines + join + join_char * (n_joins - 1)

    print(f"{x_time:^17}{whole_string:^{mid_width}}{xbow_time:^17}")

        x                                xbow       
  [1.927 1.487]   │ │ │ │ │ │ │ └─   [1.927 1.487]  
 [ 0.901 -2.106]  │ │ │ │ │ │ └───  [ 1.414 -0.309] 
 [ 0.678 -1.235]  │ │ │ │ │ └─────  [ 1.169 -0.618] 
 [-0.043 -1.605]  │ │ │ │ └───────  [ 0.866 -0.864] 
 [-0.752  1.649]  │ │ │ └─────────  [ 0.542 -0.362] 
 [-0.392 -1.404]  │ │ └───────────  [ 0.386 -0.535] 
 [-0.728 -0.559]  │ └─────────────  [ 0.227 -0.539] 
 [-0.769  0.762]  └───────────────  [ 0.103 -0.376] 


#### 🦐 Matrix multiplication implementation

In [5]:
torch.manual_seed(42)

# Lower-triangular ones, normalized so each row sums to 1: row i holds
# 1 / (i + 1) in columns 0..i.
ones_lower = torch.tril(torch.ones(3, 3))
d = ones_lower / ones_lower.sum(dim=1, keepdim=True)
e = torch.randint(0, 10, (3, 2)).float()
# Row i of f is therefore the mean of rows 0..i of e.
f = d @ e

# Print each tensor under its label, separated by a blank line.
for idx, (label, mat) in enumerate([("d =", d), ("e =", e), ("f =", f)]):
    if idx > 0:
        print()
    print(label)
    print(mat)

d =
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

e =
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

f =
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [6]:
# Prefix-averaging weights: row i holds 1 / (i + 1) in columns 0..i, 0 elsewhere.
lower = torch.tril(torch.ones(t, t))
wei = lower / lower.sum(dim=1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [7]:
# wei is (t, t); matmul broadcasts it across the batch dimension:
# (t, t) @ (b, t, c) -> (b, t, c)
xbow2 = wei @ x

In [8]:
# Side-by-side view of the loop result and the matmul result for one sequence,
# followed by an all-batch numerical equality check.
row_fmt = "{:^17}{:^18}{:^17}"
print(row_fmt.format("xbow", "", "xbow2"))
for loop_row, matmul_row in zip(xbow[batch], xbow2[batch]):
    left = str(loop_row.numpy().round(3))
    right = str(matmul_row.numpy().round(3))
    print(row_fmt.format(left, "<--->", right))

print()
print("xbow == xbow2:", torch.allclose(xbow, xbow2))

      xbow                               xbow2      
  [1.927 1.487]        <--->         [1.927 1.487]  
 [ 1.414 -0.309]       <--->        [ 1.414 -0.309] 
 [ 1.169 -0.618]       <--->        [ 1.169 -0.618] 
 [ 0.866 -0.864]       <--->        [ 0.866 -0.864] 
 [ 0.542 -0.362]       <--->        [ 0.542 -0.362] 
 [ 0.386 -0.535]       <--->        [ 0.386 -0.535] 
 [ 0.227 -0.539]       <--->        [ 0.227 -0.539] 
 [ 0.103 -0.376]       <--->        [ 0.103 -0.376] 

xbow == xbow2: True


#### 🚚 Softmax implementation

In [9]:
tril = torch.tril(torch.ones(t, t))
# All-zero logits with the entries above the diagonal masked to -inf. A
# row-wise softmax then gives those entries weight exp(-inf) = 0 and splits
# the weight evenly over the remaining equal logits, so row i averages 0..i.
logits = torch.zeros((t, t)).masked_fill(tril == 0, float("-inf"))
wei = F.softmax(logits, dim=-1)
xbow3 = wei @ x

In [10]:
def _fmt_row(vec):
    """Render a tensor as the string of its values rounded to 3 decimals."""
    return str(vec.numpy().round(3))


print("{:^17}{:^18}{:^17}".format("xbow", "", "xbow3"))
for step in range(t):
    left, right = _fmt_row(xbow[batch, step]), _fmt_row(xbow3[batch, step])
    print("{:^17}{:^18}{:^17}".format(left, "<--->", right))

print()
print("xbow == xbow3:", torch.allclose(xbow, xbow3))

      xbow                               xbow3      
  [1.927 1.487]        <--->         [1.927 1.487]  
 [ 1.414 -0.309]       <--->        [ 1.414 -0.309] 
 [ 1.169 -0.618]       <--->        [ 1.169 -0.618] 
 [ 0.866 -0.864]       <--->        [ 0.866 -0.864] 
 [ 0.542 -0.362]       <--->        [ 0.542 -0.362] 
 [ 0.386 -0.535]       <--->        [ 0.386 -0.535] 
 [ 0.227 -0.539]       <--->        [ 0.227 -0.539] 
 [ 0.103 -0.376]       <--->        [ 0.103 -0.376] 

xbow == xbow3: True
