In [10]:
import pandas as pd
import torch
import torch.nn.functional as F

In [11]:
# Toy vocabulary: every word owns one row of the embedding matrix below.
words = ["festival", "of", "lights", "diwali"]
embedding_dim = 4

# Seed the global torch RNG so the random draws that follow are reproducible.
torch.manual_seed(0)

# One random vector of length `embedding_dim` per word.
embeddings = torch.randn(len(words), embedding_dim)

# Show each word next to its embedding row.
for row, word in enumerate(words):
    print(word, embeddings[row])

festival tensor([-1.1258, -1.1524, -0.2506, -0.4339])
of tensor([ 0.8487,  0.6920, -0.3160, -2.1152])
lights tensor([ 0.3223, -1.2633,  0.3500,  0.3081])
diwali tensor([ 0.1198,  1.2377,  1.1168, -0.2473])


In [12]:
# Random projection matrices for queries, keys and values. The generator is
# consumed left to right, so the draws happen in the order Wq, Wk, Wv.
Wq, Wk, Wv = (torch.randn(embedding_dim, embedding_dim) for _ in range(3))

# Project every word embedding into query / key / value space.
Q = embeddings.matmul(Wq)
K = embeddings.matmul(Wk)
V = embeddings.matmul(Wv)

# Report the shape of each projection.
for label, projected in (("Query", Q), ("Key", K), ("Value", V)):
    print(label, projected.shape)

Query torch.Size([4, 4])
Key torch.Size([4, 4])
Value torch.Size([4, 4])


In [13]:
# Dot product of every query row with every key row:
# scores[i, j] = Q[i] · K[j].
scores = Q.matmul(K.t())
print("Raw attention scores", scores, sep="\n")

Raw attention scores
tensor([[ -4.3225, -28.3179,   6.1734,  -1.7958],
        [  6.5034,  42.2719,  -3.3050,  -1.0317],
        [ -1.6087, -13.9601,   2.6460,  -1.0218],
        [  3.5789,  22.8167,  -4.8524,   1.0598]])


In [14]:
# Scaled dot-product attention: divide the raw scores by sqrt(embedding_dim)
# before normalising.
scale = pow(embedding_dim, 0.5)
scaled_scores = scores.div(scale)

# Softmax along dim=1 turns each row into a distribution over the columns.
attention_weights = scaled_scores.softmax(dim=1)

print("Attention weights", attention_weights, sep="\n")

Attention weights
tensor([[5.1356e-03, 3.1627e-08, 9.7670e-01, 1.8167e-02],
        [1.7099e-08, 1.0000e+00, 1.2679e-10, 3.9513e-10],
        [9.3148e-02, 1.9368e-04, 7.8175e-01, 1.2491e-01],
        [6.6453e-05, 9.9991e-01, 9.8103e-07, 1.8859e-05]])


In [15]:
# Each context row is the attention-weighted sum of the rows of V.
context = attention_weights.matmul(V)

print("Context vectors", context, sep="\n")

Context vectors
tensor([[-0.5406,  0.6538,  0.7513,  1.6047],
        [ 2.6491,  2.8450, -3.0581, -0.0337],
        [-0.3255,  0.5911,  0.5609,  1.6361],
        [ 2.6490,  2.8448, -3.0580, -0.0336]])


In [16]:
# Pull out the attention row for "festival" and list, word by word,
# the weight it places on each entry of the vocabulary.
festival_index = words.index("festival")
festival_row = attention_weights[festival_index].tolist()

for word, weight in zip(words, festival_row):
    print("festival attends to", word, "with weight", weight)

festival attends to festival with weight 0.0051356288604438305
festival attends to of with weight 3.162701744940932e-08
festival attends to lights with weight 0.9766977429389954
festival attends to diwali with weight 0.018166586756706238


In [17]:
# Labelled view of the attention matrix: the row label is the word doing the
# attending, the column label is the word being attended to.
# (pandas is imported in the notebook's import cell.)
df = pd.DataFrame(
    attention_weights.numpy(),
    index=words,
    columns=words,
)

# Bare last expression so Jupyter renders the table.
df

Unnamed: 0,festival,of,lights,diwali
festival,0.005135629,3.162702e-08,0.9766977,0.01816659
of,1.709879e-08,1.0,1.267937e-10,3.951302e-10
lights,0.09314755,0.0001936833,0.7817459,0.124913
diwali,6.645304e-05,0.9999138,9.810271e-07,1.885862e-05
