In [1]:
import torch

# Scaled Dot-Product Attention

In [2]:
# Tokenize: drop the comma, then split on whitespace into a list of words.
input_sentence = 'Life is short, eat dessert last'
input_sentence_list = input_sentence.replace(",", "").split()

# Vocabulary: each word maps to its position in sorted order.
sentence_dict = dict(
    (word, idx) for idx, word in enumerate(sorted(input_sentence_list))
)
print(sentence_dict)

{'Life': 0, 'dessert': 1, 'eat': 2, 'is': 3, 'last': 4, 'short': 5}


In [3]:
# Encode the sentence as a 1-D tensor of vocabulary ids, keeping word order.
sentence_int = torch.tensor(
    [sentence_dict[token] for token in input_sentence_list]
)
print(sentence_int)

tensor([0, 3, 5, 2, 1, 4])


### Create the embedding
We use a 16-dimensional embedding for each word. Since the sentence has 6 words, the embedded sentence has shape 6x16.

In [4]:
# Seed the global RNG so the embedding weights (and the torch.randn
# projection matrices drawn in later cells) reproduce on Restart & Run All.
torch.manual_seed(123)
# One 16-dim row per vocabulary entry. Derive the row count from the
# vocabulary instead of hard-coding 6, so a different sentence still works.
embed = torch.nn.Embedding(len(sentence_dict), 16)
# detach(): we only need the embedding values, not the autograd graph.
embed_sentence = embed(sentence_int).detach()

In [5]:
# Show the embedded sentence followed by its shape (one row per word).
print(embed_sentence, embed_sentence.shape, sep="\n")

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465]])
torch.Size([6, 16])


In [6]:
# Input feature size = number of embedding columns.
d = embed_sentence.size(1)

# Projection sizes. Queries and keys must share a size because they are
# dot-producted together; the value size may differ.
d_q, d_k, d_v = 24, 24, 28

# Random (untrained) projection matrices, drawn in q, k, v order.
W_q, W_k, W_v = (torch.randn(out_dim, d) for out_dim in (d_q, d_k, d_v))

print(W_q.shape, W_k.shape, W_v.shape)

torch.Size([24, 16]) torch.Size([24, 16]) torch.Size([28, 16])


## Step 1: Compute unnormalized attention weights

In [7]:
# Use the second word of the sentence ('is', row index 1) as the running query example.
x_2 = embed_sentence[1, :]
print(x_2)

tensor([ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
         2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293])


In [8]:
# Project the single input vector into query, key and value spaces.
q_2, k_2, v_2 = (weight.matmul(x_2) for weight in (W_q, W_k, W_v))

# One shape per line: d_q, d_k, d_v.
print(*(vec.shape for vec in (q_2, k_2, v_2)), sep="\n")

torch.Size([24])
torch.Size([24])
torch.Size([28])


In [9]:
# Project all words at once. (d_k, d) @ (d, 6) gives (d_k, 6); transposing
# back gives one key row per word. The same applies to the values.
keys = (W_k @ embed_sentence.T).T
values = (W_v @ embed_sentence.T).T
print(keys.shape, values.shape)

torch.Size([6, 24]) torch.Size([6, 28])


In [10]:
# Unnormalized score between the query of x_2 and a single key.
# NOTE: x_2 is named 1-based (it is embed_sentence[1]), but keys[4] is
# 0-based, i.e. the 5th word ('dessert'). This value equals omega_2[4]
# in the next cell.
omega_24 = torch.dot(q_2, keys[4])
print(omega_24)

tensor(58.9875)


In [11]:
# Scores of the query against every key in one shot: (d_k,) @ (d_k, 6) -> (6,).
omega_2 = q_2 @ keys.T
print(omega_2)

tensor([ -57.1016,  -85.4889,  160.1854, -144.2133,   58.9875,  -80.1705])


## Step 2: Normalization

In [12]:
# Scaled dot-product attention (Vaswani et al.): divide the scores by
# sqrt(d_k), then softmax over the words so the weights sum to 1.
# With these random projections the result is still nearly one-hot.
attention_2 = torch.softmax(omega_2 / d_k ** 0.5, dim=0)
print(attention_2)

tensor([5.4640e-20, 1.6633e-22, 1.0000e+00, 1.0353e-27, 1.0686e-09, 4.9255e-22])


## Step 3: Compute the Context Vector

In [13]:
# Context vector = attention-weighted sum of the value rows: (6,) @ (6, d_v) -> (d_v,).
context_vector_2 = attention_2 @ values
print(context_vector_2, context_vector_2.shape, sep="\n")

tensor([-4.6812,  4.3038, -5.0492, -2.6208, -2.4619, -0.3670, -1.0982,  3.0041,
        -2.2975,  3.9133, -3.7064, -1.8859,  3.9662, -4.3787, -1.7991,  4.1266,
        -2.3905,  2.7373,  2.9809,  6.5839,  0.3691, -6.0942,  3.2605, -3.9929,
         6.6571,  1.6524, -4.1800,  2.8630])
torch.Size([28])


# Multi-Head Attention

In [14]:
# Same query word as before; the input size is its embedding length.
x_2 = embed_sentence[1]

d = x_2.shape[0]
d_q, d_k, d_v = 24, 24, 28
h = 3  # number of attention heads
# One projection matrix per head, stacked along a leading head axis
# (drawn in q, k, v order).
W_q, W_k, W_v = (torch.randn(h, out_dim, d) for out_dim in (d_q, d_k, d_v))

In [15]:
print(*(weight.shape for weight in (W_q, W_k, W_v)))  # (h, out_dim, d) per projection

torch.Size([3, 24, 16]) torch.Size([3, 24, 16]) torch.Size([3, 28, 16])


In [16]:
# Batched matrix-vector product: (h, d_q, d) @ (d,) -> one query per head, (h, d_q).
q_2 = W_q @ x_2
print(q_2.shape)

torch.Size([3, 24])


In [17]:
# Stack the inputs for multi-head attention: every head gets its own copy
# of the (d, num_words) input so it can be batch-multiplied with the
# per-head weights. Use h rather than a literal 3, so changing the head
# count in the weights cell keeps the batch sizes consistent for torch.bmm.
input_stack = embed_sentence.T.repeat(h, 1, 1)
print(input_stack.shape)

torch.Size([3, 16, 6])


In [18]:
# Per-head projections: (h, d_k, d) @ (h, d, 6) -> (h, d_k, 6).
# Note that words are now columns, not rows as in the single-head case.
keys, values = (torch.bmm(weight, input_stack) for weight in (W_k, W_v))
print(keys.shape, values.shape)

torch.Size([3, 24, 6]) torch.Size([3, 28, 6])


In [19]:
# Turn the per-head queries (h, d_q) into a batch of row vectors
# (h, 1, d_q) for torch.bmm. Recompute from W_q and x_2 instead of reshaping
# q_2 in place, so re-running this cell does not keep adding dimensions.
q_2 = W_q.matmul(x_2).unsqueeze(1)
print(q_2.shape)

torch.Size([3, 1, 24])


In [20]:
# Per-head scores: (h, 1, d_k) @ (h, d_k, 6) -> (h, 1, 6).
# Squeeze only the singleton query axis (dim 1). A bare squeeze() would also
# drop the head axis when h == 1, and the dim=1 softmax that follows would fail.
omega_2 = torch.bmm(q_2, keys).squeeze(1)
print(omega_2.shape)

torch.Size([3, 6])


In [21]:
# Normalize the dot product
attention_2 = torch.nn.functional.softmax(omega_2/(d_k**0.5), dim=1)
print(attention_2.shape)

torch.Size([3, 6])


In [22]:
# Compute the context vector
context_vector_2 = torch.bmm(values, attention_2.unsqueeze(-1)).squeeze()
print(context_vector_2.shape)

torch.Size([3, 28])


### References
1. [Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html) by Sebastian Raschka, PhD.

2. [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) by Vaswani et al.