## Self Attention


Based on the paper *Attention Is All You Need* (Vaswani et al., 2017): https://arxiv.org/pdf/1706.03762.pdf

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# Fix the RNG seed so the random input below is identical on every run.
torch.manual_seed(3791)

# Toy dimensions: Batch, Time (positions per sequence), Channel (features per position).
B = 4
T = 8
C = 6

# Standard-normal input of shape (B, T, C).
x = torch.randn((B, T, C))
x

tensor([[[-0.1155,  1.4004, -0.2636,  1.1846,  1.2524, -1.4922],
         [-0.8670,  0.4435, -0.1523,  1.5791,  0.4891, -0.6652],
         [ 0.4425,  1.0421,  2.1270,  0.2821,  1.6772,  0.9382],
         [ 2.4316, -2.7186,  0.4827, -0.3789, -1.6457,  1.4676],
         [-0.2117,  0.1744, -1.4542,  0.3945, -0.0648,  1.2230],
         [-0.4369, -0.9095, -0.3158, -0.0077, -0.3685, -0.4621],
         [-0.0730, -0.1383,  0.0828, -0.8913, -0.1860,  1.4274],
         [ 0.0764, -1.2949,  1.2127, -0.9746,  0.3595, -0.8611]],

        [[-1.0264, -0.1822, -2.1604,  1.1079, -0.2544,  2.3360],
         [-0.4107,  2.4999,  0.4165, -0.5933,  0.4560, -0.1843],
         [-1.1413,  0.6564,  0.5931,  0.2935,  0.4171, -0.4676],
         [ 1.3102,  1.8109, -0.8727,  0.5862, -1.2576, -0.5111],
         [-0.2096,  0.6701, -0.8327, -0.1827,  0.1226,  0.1414],
         [ 0.0103,  0.4948,  0.3764, -0.9829, -1.1029, -0.2477],
         [ 0.8806, -0.0726, -0.4488, -0.0483,  0.3869,  0.9365],
         [ 0.4525, -1.5

In [2]:
# (T, T) lower-triangular matrix of ones: row i has ones in columns 0..i
# and zeros in every later column.
tril = torch.ones(T, T).tril()
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [3]:
# Raw (T, T) score matrix, initialised to all zeros so that no position is
# preferred over any other.
wei = torch.zeros((T, T))
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [4]:
# Row-wise softmax of the raw scores in `wei`.
# Normalising over the last axis (dim=-1) is identical to dim=1 for a 2-D
# (T, T) matrix, and it still normalises each row if the scores later gain a
# leading batch dimension, e.g. (B, T, T), where dim=1 would pick the wrong axis.
# When the cells are run in order, `wei` is all zeros at this point, so every
# row comes out uniform (1/T per position).
wei_1 = F.softmax(wei, dim=-1)
wei_1

tensor([[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [5]:
# Wherever `tril` is 0 (the entries above the diagonal), overwrite the score
# with -inf; entries on or below the diagonal keep their value.
# exp(-inf) == 0, so a softmax over this matrix gives those entries zero weight.
wei = wei.masked_fill(tril.eq(0), float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [6]:
# Weighted aggregation over the time axis: the (T, T) weights broadcast
# across the batch dimension of the (B, T, C) input.
# `wei_1` was computed before the -inf mask was applied to `wei`, so its
# weights are uniform and every row of `out[b]` is the plain mean of `x[b]`
# over all T positions.
out = wei_1 @ x

# Show the first sequence before and after aggregation, separated by a rule.
print(x[0], "-"*80, out[0], sep="\n\n")

tensor([[-0.1155,  1.4004, -0.2636,  1.1846,  1.2524, -1.4922],
        [-0.8670,  0.4435, -0.1523,  1.5791,  0.4891, -0.6652],
        [ 0.4425,  1.0421,  2.1270,  0.2821,  1.6772,  0.9382],
        [ 2.4316, -2.7186,  0.4827, -0.3789, -1.6457,  1.4676],
        [-0.2117,  0.1744, -1.4542,  0.3945, -0.0648,  1.2230],
        [-0.4369, -0.9095, -0.3158, -0.0077, -0.3685, -0.4621],
        [-0.0730, -0.1383,  0.0828, -0.8913, -0.1860,  1.4274],
        [ 0.0764, -1.2949,  1.2127, -0.9746,  0.3595, -0.8611]])

--------------------------------------------------------------------------------

tensor([[ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0.2501,  0.2149,  0.1485,  0.1892,  0.1970],
        [ 0.1558, -0