# Mathematical Trick for Self-Attention

In [1]:
import torch

In [2]:
# Fix the global RNG so subsequent random tensors are reproducible on every run.
# The trailing `;` hides the returned torch.Generator's repr in the cell output.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x105d63510>

In [16]:
# B: batch size (number of independent sequences)
# T: time steps (sequence length)
# C: channels (size of the feature vector at each time step)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

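## Bag of Words (BoW)

Bag of words here means averaging: each position is replaced by the mean of its own vector and all previous vectors in the same sequence. The order of the tokens within that window is ignored.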

### 1. For Loop
Nested Python loops over batch and time. This is easy to read, but slow, because none of the work is vectorized.

In [19]:
# Naive running mean: entry (b, t) becomes the average of x[b, 0..t].
x_bow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        # rows 0..t of sequence b: the current token and every token before it
        x_bow[b, t] = x[b, : t + 1].mean(dim=0)

x_bow[0]

tensor([[ 1.7744, -0.9216],
        [ 1.3684, -0.6293],
        [ 0.5205, -0.3002],
        [ 0.5101,  0.1133],
        [ 0.5133,  0.5130],
        [ 0.3409,  0.2722],
        [ 0.3187,  0.3860],
        [ 0.4422,  0.3952]])

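### 2. Vectorization

The same running average can be computed as a single matrix multiplication. Use a lower-triangular weight matrix whose rows are normalized to sum to 1, so that row $t$ averages positions $0..t$.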

In [11]:
# A (T, T) matrix of ones: row t will hold the weights used to build output position t.
weights = torch.ones((T, T))
weights

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [13]:
# the future is not relevant for predictions, only look at the past
weights = torch.tril(weights)
weights

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [14]:
# Divide each row by its sum so the row's weights add up to 1 (a plain average).
row_sums = weights.sum(dim=1, keepdim=True)
weights = weights / row_sums
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [18]:
# (T, T) matmul (B, T, C): torch.matmul broadcasts the weights across the batch -> (B, T, C)
x_bow2 = torch.matmul(weights, x)
x_bow2[0]

tensor([[ 1.7744, -0.9216],
        [ 1.3684, -0.6293],
        [ 0.5205, -0.3002],
        [ 0.5101,  0.1133],
        [ 0.5133,  0.5130],
        [ 0.3409,  0.2722],
        [ 0.3187,  0.3860],
        [ 0.4422,  0.3952]])

In [20]:
# The vectorized result should equal the loop result up to floating-point tolerance.
x_bow.allclose(x_bow2)

True

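### 3. Softmax

Instead of normalizing explicitly, set the future positions to $-\infty$ and apply a row-wise softmax. Since $e^{-\infty} = 0$, the future gets zero weight, and equal remaining logits give a uniform average. This is the form self-attention uses, with the zeros replaced by data-dependent affinities.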

In [23]:
# Lower-triangular mask of ones: entry (t, s) is 1 exactly when s <= t.
tril = torch.ones((T, T)).tril()
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [24]:
# All-zero affinities, with every position outside the mask (the future)
# replaced by -inf so that a softmax gives it zero weight.
weights = torch.zeros((T, T)).masked_fill(tril == 0, float("-inf"))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [26]:
# Row-wise softmax over the masked affinities: exp(-inf) = 0 removes the future
# positions, and the remaining equal logits (all 0) become a uniform average.
# The masked logits are rebuilt from `tril` instead of reading `weights`, so this
# cell is idempotent. Re-running `weights = softmax(weights)` would apply softmax
# a second time, giving the masked entries non-zero weight (leaking the future)
# and breaking the uniform average.
masked_logits = torch.zeros((T, T)).masked_fill(tril == 0, float("-inf"))
weights = torch.nn.functional.softmax(masked_logits, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [28]:
# Apply the softmax-derived weights; matmul broadcasts over the batch -> (B, T, C).
x_bow3 = torch.matmul(weights, x)
x_bow3[0]

tensor([[ 1.7744, -0.9216],
        [ 1.3684, -0.6293],
        [ 0.5205, -0.3002],
        [ 0.5101,  0.1133],
        [ 0.5133,  0.5130],
        [ 0.3409,  0.2722],
        [ 0.3187,  0.3860],
        [ 0.4422,  0.3952]])

In [29]:
x_bow.allclose(x_bow3)  # the softmax formulation must reproduce the loop result

True