Refs: https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a

In [1]:
import numpy as np
#import pandas as pd
import torch
from torch.nn.functional import softmax

In [2]:
## Step 1: input
# Three input vectors with four features each, one vector per row (float32).
x = torch.tensor(
    [
        [1, 0, 1, 0],
        [0, 2, 0, 2],
        [1, 1, 1, 1],
    ],
    dtype=torch.float32,
)
print(x)

tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.]])


In [3]:
## Step 2: initialise weights
# Fixed 4x3 float32 matrices used to derive keys, queries and values.
w_key = torch.tensor(
    [[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0]],
    dtype=torch.float32,
)
w_query = torch.tensor(
    [[1, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 1]],
    dtype=torch.float32,
)
w_value = torch.tensor(
    [[0, 2, 0], [0, 3, 0], [1, 0, 3], [1, 1, 0]],
    dtype=torch.float32,
)

# Show each matrix under its label (label on one line, tensor on the next).
print("w_key:", w_key, sep="\n")

print("w_query:", w_query, sep="\n")

print("w_value:", w_value, sep="\n")

w_key:
tensor([[0., 0., 1.],
        [1., 1., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])
w_query:
tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 1.]])
w_value:
tensor([[0., 2., 0.],
        [0., 3., 0.],
        [1., 0., 3.],
        [1., 1., 0.]])


In [4]:
## Step 3: Derive key, query and value

# Multiply the input by each weight matrix; every row of a result is the
# key / query / value belonging to the matching row of x.
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)

# Show each result under its label (label on one line, tensor on the next).
print("keys:", keys, sep="\n")

print("querys:", querys, sep="\n")

print("values:", values, sep="\n")

keys:
tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
querys:
tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])
values:
tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])


In [5]:
## Step 4: Calculate attention scores
# Dot product of every query row with every key row: entry [i, j] pairs
# query i with key j. The raw products are kept as-is (no scaling applied).
attn_scores = querys.matmul(keys.t())

print(attn_scores)

tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])


In [6]:
## Step 5: Calculate softmax
# Softmax over the last dimension, i.e. across the scores in each row.
attn_scores_softmax = attn_scores.softmax(dim=-1)
print(attn_scores_softmax)

tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])


In [7]:
## Step 6: Multiply scores with values
# Broadcast values (j, 1, d) against the transposed weights (j, i, 1):
# weighted_values[j, i] is value row j scaled by attn_scores_softmax[i, j].
values_per_source = values.unsqueeze(1)
weights_per_source = attn_scores_softmax.t().unsqueeze(-1)
weighted_values = values_per_source * weights_per_source
print(weighted_values)

tensor([[[6.3379e-02, 1.2676e-01, 1.9014e-01],
         [6.0337e-06, 1.2067e-05, 1.8101e-05],
         [2.9539e-04, 5.9077e-04, 8.8616e-04]],

        [[9.3662e-01, 3.7465e+00, 0.0000e+00],
         [1.9640e+00, 7.8561e+00, 0.0000e+00],
         [1.7611e+00, 7.0443e+00, 0.0000e+00]],

        [[9.3662e-01, 2.8099e+00, 1.4049e+00],
         [3.5972e-02, 1.0792e-01, 5.3958e-02],
         [2.3834e-01, 7.1501e-01, 3.5750e-01]]])


In [8]:
## Step 7: Sum weighted values
# Collapse the leading dimension of weighted_values by summation.
outputs = torch.sum(weighted_values, dim=0)
print(outputs)

# Expected result (rows are output 1, output 2, output 3):
#   [1.9366, 6.6831, 1.5951]
#   [2.0000, 7.9640, 0.0540]
#   [1.9997, 7.7599, 0.3584]

tensor([[1.9366, 6.6831, 1.5951],
        [2.0000, 7.9640, 0.0540],
        [1.9997, 7.7599, 0.3584]])
