# Preamble

In [1]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably so the `src` package imported below resolves when the
# notebook is run from a sub-directory of the repository — confirm.
sys.path.append("../")

In [8]:
import torch
import torch.nn as nn
from torch.nn import functional

from src.self_attention import ScaledDotProductSelfAttentionHead

# Understanding what @ does

The `@` operator is shorthand for `torch.matmul`, which performs batched matrix-matrix multiplication. It multiplies the matrices held in the last two dimensions and broadcasts all remaining (batch) dimensions against each other whenever their shapes are compatible.

In [2]:
# B carries leading (batch) dimensions (3, 1) in front of 2x2 matrices;
# C carries a leading dimension (5,) in front of 2x3 matrices.
B = torch.randn(3, 1, 2, 2)
C = torch.randn(5, 2, 3)
# `@` multiplies the trailing matrices, (2, 2) @ (2, 3) -> (2, 3), and
# broadcasts the leading dimensions, (3, 1) against (5,) -> (3, 5).
D = B @ C
D.shape

torch.Size([3, 5, 2, 3])

In [3]:
# Multiply one explicit pair by hand: the 2x2 matrix B[1, 0] with the 2x3
# matrix C[2]. The result should match the broadcast product stored at D[1, 2].
torch.matmul(B[1, 0, :], C[2, :])

tensor([[-3.2130, -0.6345, -3.4687],
        [ 0.4709,  0.0288,  0.3812]])

In [4]:
# Entry (1, 2) of the broadcast product: B's size-1 axis was expanded, so this
# slice pairs B[1, 0] with C[2].
D[1, 2, :]

tensor([[-3.2130, -0.6345, -3.4687],
        [ 0.4709,  0.0288,  0.3812]])

In [5]:
# Reverse the order: C is now the left operand.
# C holds 6 matrices of shape (4, 2); B is the (3, 1, 2, 2) tensor defined earlier.
C = torch.randn(6, 4, 2)
# Trailing matrices: (4, 2) @ (2, 2) -> (4, 2); the leading dimensions (6,) and
# (3, 1) broadcast to (3, 6).
D = C @ B
D.shape

torch.Size([3, 6, 4, 2])

# Efficient bag of words calculation

In [6]:
# Lower-triangular matrix of ones: row t has ones in columns 0..t.
B = torch.ones(8, 8).tril()
# Number of ones in each row, kept as a column vector of shape (8, 1).
B.sum(dim=1, keepdim=True)

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]])

In [7]:
# Divide every row of the lower-triangular ones matrix by its row sum, so row t
# holds 1/(t+1) in its first t+1 columns: a uniform average over positions 0..t.
weights = B / B.sum(1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [8]:
# Random input of shape (4, 8, 2), treated here as 4 batch items with 8
# positions and 2 features each.
C = torch.randn(4, 8, 2)
# The (8, 8) weights broadcast across the leading dimension; row t of each
# result is the mean of rows 0..t of the corresponding C[b].
D = weights @ C
D.shape

torch.Size([4, 8, 2])

# Beginning of self-attention

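In practice we don't want a simple average; we want a weighted average whose weights are learned. To make that possible, the same weighting matrix can be rebuilt with `masked_fill` followed by a softmax: start from a matrix of scores, set the masked positions to `-inf`, and normalise each row with softmax.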

In [4]:
# Lower-triangular mask: 1 where column <= row, 0 above the diagonal.
mask = torch.ones((8, 8)).tril()
mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [5]:
# Start from all-zero scores and overwrite, in place, every position where the
# mask is 0 with -inf; masked_fill_ returns the tensor it modified.
weights = torch.zeros(8, 8).masked_fill_(mask == 0, float("-inf"))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [11]:
# With every unmasked score equal to 0, softmax spreads the weight evenly over
# the unmasked columns of each row, and the -inf entries become exactly 0.
functional.softmax(weights, dim=-1) # softmax along the last dimension

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

## Example with non-constant initial weights

In [12]:
# Random scores replace the constant zeros used above.
weights = torch.randn(8, 8)
# In-place fill: every position where the mask is 0 (above the diagonal)
# becomes -inf, so it will receive zero weight after softmax.
weights.masked_fill_(mask==0, float("-inf"))
weights

tensor([[-0.6851,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.5143,  1.5780,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1154,  1.4088,  1.1345,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2455,  0.0146,  0.7242,  1.5113,    -inf,    -inf,    -inf,    -inf],
        [ 0.0629, -0.4085, -1.8356,  1.8961, -0.4704,    -inf,    -inf,    -inf],
        [ 0.1704,  0.2344, -1.1358, -1.4638, -0.6538,  1.9615,    -inf,    -inf],
        [ 0.7947, -0.3530,  0.2352,  0.9171, -0.4502, -0.8813,  1.2977,    -inf],
        [ 1.0420,  2.1455,  0.3926,  1.0247, -1.7611,  0.9412, -0.8895,  0.9052]])

In [13]:
# Masked positions still get weight 0, but the remaining weights in each row now
# follow the scores instead of being uniform.
functional.softmax(weights, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1098, 0.8902, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1348, 0.4915, 0.3736, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1438, 0.1142, 0.2321, 0.5099, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1161, 0.0724, 0.0174, 0.7260, 0.0681, 0.0000, 0.0000, 0.0000],
        [0.1115, 0.1189, 0.0302, 0.0218, 0.0489, 0.6687, 0.0000, 0.0000],
        [0.1943, 0.0616, 0.1110, 0.2195, 0.0559, 0.0364, 0.3212, 0.0000],
        [0.1333, 0.4019, 0.0696, 0.1310, 0.0081, 0.1205, 0.0193, 0.1163]])

In [14]:
# Sanity check: after softmax over the last dimension, every row sums to 1.
functional.softmax(weights, dim=-1).sum(dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]])

# Self Attention

In [3]:
# NOTE(review): the positional arguments look like (input feature size=512,
# head output size=64, mask size=12), judging by the 12x12 mask and the
# (32, T, 512) -> (32, T, 64) shapes shown in the cells below — confirm against
# src/self_attention.py.
self_attention = ScaledDotProductSelfAttentionHead(512, 64, 12)

In [4]:
# Inspect the head's `mask` attribute; it is displayed below as a 12x12
# lower-triangular matrix of ones.
self_attention.mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [12]:
# Random input of shape (32, 12, 512).
# NOTE(review): presumably (batch, sequence, features) — confirm against the head's forward().
x = torch.randn(32, 12, 512)
# Only the last dimension changes (512 -> 64) in the recorded output shape.
self_attention(x).shape

torch.Size([32, 12, 64])

In [13]:
# Same head, but the second axis has only 8 entries (fewer than the 12x12 mask).
x = torch.randn(32, 8, 512)
# The call still succeeds; the output's second axis follows the input (8).
self_attention(x).shape

torch.Size([32, 8, 64])

In [34]:
# Dropout layer with drop probability p=0.2. A freshly constructed module is in
# training mode, so dropout is active when it is called below.
drop = nn.Dropout(0.2)

In [35]:
# Dropout leaves the tensor shape unchanged.
drop(x).shape, x.shape

(torch.Size([32, 8, 512]), torch.Size([32, 8, 512]))

In [36]:
# Relative change per element: a dropped element gives (x - 0) / x = 1, while a
# kept element is scaled by 1 / (1 - p) = 1.25 and gives (x - 1.25x) / x = -0.25.
# The minimum therefore exposes the rescaling factor applied to kept elements.
((x - drop(x))/x).min()

tensor(-0.2500)

In [40]:
# A single input element, for comparison with its value after dropout.
x[0,0,0]

tensor(-0.2632)

In [41]:
# Each call to drop(x) samples a fresh dropout mask. In the recorded run this
# element was kept, so it shows 1.25 * x[0,0,0]; a dropped element would show 0.
drop(x)[0,0,0]

tensor(-0.3290)