In [1]:
import torch
from torch.nn import functional

# Understanding what @ does

The `@` operator is shorthand for `torch.matmul`, which performs batched matrix-matrix multiplication. It multiplies the matrices held in the last two dimensions and broadcasts the remaining (batch) dimensions whenever their shapes are compatible.

In [2]:
# B has batch dims (3, 1) and C has batch dim (5,); matmul broadcasts them to
# (3, 5) and multiplies the trailing (2, 2) x (2, 3) matrices.
B = torch.randn(3, 1, 2, 2)
C = torch.randn(5, 2, 3)
D = torch.matmul(B, C)
D.shape

torch.Size([3, 5, 2, 3])

In [3]:
# Multiply a single (2, 2) slice of B by a single (2, 3) slice of C by hand
B[1, 0] @ C[2]

tensor([[ 0.6361, -0.3220,  1.7799],
        [ 2.2120, -1.2736,  0.7472]])

In [4]:
# Batch entry (1, 2) of the broadcast product
D[1, 2]

tensor([[ 0.6361, -0.3220,  1.7799],
        [ 2.2120, -1.2736,  0.7472]])

In [6]:
# Swap the operands: C's (4, 2) matrices now left-multiply B's (2, 2) matrices,
# and the batch dims (6,) and (3, 1) broadcast to (3, 6).
C = torch.randn(6, 4, 2)
D = torch.matmul(C, B)
D.shape

torch.Size([3, 6, 4, 2])

# Efficient bag of words calculation

In [36]:
# Lower-triangular matrix of ones: row i has ones in columns 0..i
B = torch.ones(8, 8).tril()
# Per-row count of ones, kept as a column so it can divide B row-wise
B.sum(dim=1, keepdim=True)

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]])

In [38]:
# Divide each row of B by its row sum so that every row sums to 1
weights = torch.div(B, B.sum(dim=1, keepdim=True))
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [40]:
# Average each position with all earlier positions: row i of `weights` holds
# 1/(i+1) in columns 0..i, so `weights @ C` is a running mean along dim 1 of C.
# (`B @ C` would give unnormalized running sums, not the bag-of-words average.)
C = torch.randn(4, 8, 2)
D = weights @ C
D.shape

torch.Size([4, 8, 2])

# Beginning of self-attention

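In practice, we don't want a simple average; we want a weighted average whose weights are learned. That's why we rewrite the weighting matrix using `masked_fill` and softmax: filling the masked positions with `-inf` and applying softmax to all-zero scores reproduces the same averaging matrix, and the zeros can then be replaced by arbitrary scores.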

In [46]:
# Mask with ones on and below the diagonal, zeros above it
mask = torch.ones(8, 8).tril()
mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [49]:
# Start from all-zero scores, then overwrite every entry where the mask is 0
# with -inf (in place)
weights = torch.zeros(8, 8)
weights.masked_fill_(mask.eq(0), float("-inf"))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [52]:
# Softmax over the last dimension (each row is normalized independently)
functional.softmax(input=weights, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

## Example with non-constant initial weights

In [53]:
# Random scores this time; entries where the mask is 0 are again set to -inf
weights = torch.randn(8, 8)
weights.masked_fill_(mask.eq(0), float("-inf"))
weights

tensor([[-0.1661,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7549, -0.6352,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.5888,  0.5921, -0.5735,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.4654, -0.8489,  1.1032, -1.4104,    -inf,    -inf,    -inf,    -inf],
        [ 1.3734,  0.6371,  2.6707, -0.2718,  0.9070,    -inf,    -inf,    -inf],
        [-2.8097, -0.5699,  1.3901, -1.0474, -0.7544, -0.5896,    -inf,    -inf],
        [-0.7703, -0.4959,  0.7730,  0.2219,  0.2030,  1.0977,  1.0572,    -inf],
        [-0.9147, -0.9655, -0.5338,  0.6499, -0.9065,  0.8309,  0.4588, -1.3579]])

In [54]:
# Row-wise softmax of the masked random scores
functional.softmax(input=weights, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8006, 0.1994, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1897, 0.6178, 0.1926, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0590, 0.1092, 0.7695, 0.0623, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1678, 0.0804, 0.6141, 0.0324, 0.1053, 0.0000, 0.0000, 0.0000],
        [0.0100, 0.0940, 0.6673, 0.0583, 0.0782, 0.0922, 0.0000, 0.0000],
        [0.0399, 0.0526, 0.1870, 0.1077, 0.1057, 0.2587, 0.2484, 0.0000],
        [0.0512, 0.0487, 0.0750, 0.2449, 0.0516, 0.2935, 0.2023, 0.0329]])

In [55]:
# Sum each softmax row, keeping the result as a column
row_probabilities = functional.softmax(input=weights, dim=-1)
row_probabilities.sum(dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]])