## Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch

Based on Sebastian Raschka's article: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [30]:
import torch
import torch.nn.functional as F

In [5]:
input_sentence = "Life is short, eat dessert first"
# Tokenise: drop the comma, then split on runs of whitespace.
word_list = (
    input_sentence
    .replace(",", "")
    .split()
)
word_list

['Life', 'is', 'short', 'eat', 'dessert', 'first']

In [16]:
# Build a string-to-index vocabulary from the tokens.
# `set` removes duplicate tokens so the ids stay contiguous (0 .. vocab_size - 1).
# Without it, a repeated word would be enumerated twice: the dict keeps only the
# later index, an id is skipped, and the largest id can reach vocab_size, which
# would be out of range for the `nn.Embedding(vocab_size, ...)` lookup.
stoi = {token: idx for idx, token in enumerate(sorted(set(word_list)))}
vocab_size = len(stoi)
print(f"vocab size: {vocab_size}")
print(stoi)

vocab size: 6
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [17]:
# Replace every token by its vocabulary id, keeping sentence order.
input_sentence_int = torch.tensor([stoi[word] for word in word_list])
input_sentence_int

tensor([0, 4, 5, 2, 1, 3])

In [18]:
# Seed first: the embedding init and every later torch.rand call draw from this stream.
torch.manual_seed(123)
embedding_size = 16
embed = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
embed.weight.shape

torch.Size([6, 16])

In [19]:
# One embedding row per token; .detach() yields a tensor that no longer tracks
# gradients back into `embed`.
embedded_input_sentence = embed(input_sentence_int).detach()
print(embedded_input_sentence.size())
print(embedded_input_sentence[0, :])

torch.Size([6, 16])
tensor([ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
         0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692])


In [21]:
# Dimensions of the projections. Queries and keys share a size because they are
# dot-multiplied; the value size is independent and sets the context-vector size.
d = embedding_size
d_q, d_k = 24, 24
d_v = 28

In [23]:
# Projection matrices shaped (out_dim, d), so `W @ x` maps one embedding to a
# query / key / value. Keep this creation order: each torch.rand call advances the
# seeded RNG, so reordering would change every value below.
W_query = torch.nn.Parameter(torch.rand((d_q, d)))
W_key = torch.nn.Parameter(torch.rand((d_k, d)))
W_value = torch.nn.Parameter(torch.rand((d_v, d)))
print(*(W.shape for W in (W_query, W_key, W_value)))

torch.Size([24, 16]) torch.Size([24, 16]) torch.Size([28, 16])


In [24]:
# Project the second token (index 1) into query, key and value space.
x_2 = embedded_input_sentence[1]
query_2, key_2, value_2 = (W @ x_2 for W in (W_query, W_key, W_value))

print(query_2.shape, key_2.shape, value_2.shape)

torch.Size([24]) torch.Size([24]) torch.Size([28])


In [25]:
# Project all tokens at once: (out_dim, d) @ (d, seq_len) -> (out_dim, seq_len),
# then transpose so that each row holds one token's key / value.
keys = torch.matmul(W_key, embedded_input_sentence.T).T
values = torch.matmul(W_value, embedded_input_sentence.T).T

print(keys.shape, values.shape)

torch.Size([6, 24]) torch.Size([6, 28])


In [27]:
# Unnormalised attention score of query 2 against the key at index 4.
omega_24 = torch.dot(query_2, keys[4])
omega_24

tensor(13.1208, grad_fn=<DotBackward0>)

In [29]:
# Scores of query 2 against every key at once: (d_k,) @ (d_k, seq_len) -> (seq_len,)
omega_2 = torch.matmul(query_2, keys.T)
print(omega_2)

tensor([  2.1938,  -3.8873,  14.7953,   2.0779,  13.1208, -13.5755],
       grad_fn=<SqueezeBackward4>)


In [35]:
# Scaled dot-product attention: scale scores by 1/sqrt(d_k), then normalise to weights summing to 1.
attention_weights_2 = F.softmax(input=omega_2 * d_k ** -0.5, dim=0)
print(attention_weights_2)


tensor([0.0405, 0.0117, 0.5301, 0.0395, 0.3766, 0.0016],
       grad_fn=<SoftmaxBackward0>)


In [37]:
# The weights (seq_len,) contract with the rows of `values` (seq_len, d_v).
(attention_weights_2.shape, values.shape)

(torch.Size([6]), torch.Size([6, 28]))

In [38]:
# Context vector for token 2: attention-weighted sum of the value vectors,
# (seq_len,) @ (seq_len, d_v) -> (d_v,)
context_vector_2 = torch.matmul(attention_weights_2, values)
print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5785, -2.1954, -2.2932, -2.2415, -1.2018, -1.1000, -1.4869, -1.8232,
        -2.7236, -1.0135, -2.6909, -1.2424, -2.0749, -2.5979, -1.8763, -0.2468,
        -1.7071, -1.6958, -1.7728, -2.4239, -3.8261, -2.3653, -2.9899, -1.8056,
        -0.1843, -0.8150, -1.0318, -3.1751], grad_fn=<SqueezeBackward4>)


### Multi-Head Attention

In [40]:
# One projection matrix per head, stacked along a leading head axis: (h, out_dim, d).
# Creation order is kept fixed because each torch.rand call consumes the seeded RNG.
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand((h, d_q, d)))
multihead_W_key = torch.nn.Parameter(torch.rand((h, d_k, d)))
multihead_W_value = torch.nn.Parameter(torch.rand((h, d_v, d)))

In [41]:
# Stacked per-head weights vs. a single token embedding.
(multihead_W_query.shape, x_2.shape)

(torch.Size([3, 24, 16]), torch.Size([16]))

In [43]:
# matmul treats the 3-D operand as a batch of matrices and the 1-D x_2 as a vector:
# (h, d_q, d) @ (d,) -> (h, d_q), i.e. one query per head.
multihead_query_2 = torch.matmul(multihead_W_query, x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])


In [44]:
# Per-head key and value for token 2, computed the same way as the queries.
multihead_key_2, multihead_value_2 = (
    torch.matmul(W, x_2) for W in (multihead_W_key, multihead_W_value)
)
(multihead_key_2.shape, multihead_value_2.shape)


(torch.Size([3, 24]), torch.Size([3, 28]))

In [45]:
# Per-head key weights vs. the full (seq_len, d) embedded sentence.
(multihead_W_key.shape, embedded_input_sentence.shape)

(torch.Size([3, 24, 16]), torch.Size([6, 16]))

In [46]:
# (h, d_k, d) @ (d, seq_len) broadcasts over the head axis -> (h, d_k, seq_len).
torch.matmul(multihead_W_key, embedded_input_sentence.T).shape

torch.Size([3, 24, 6])

In [47]:
# Keys / values for every token and every head. The result is (h, out_dim, seq_len),
# so the token axis is the LAST one here (transposed relative to `keys` above).
multihead_keys = torch.matmul(multihead_W_key, embedded_input_sentence.T)
multihead_values = torch.matmul(multihead_W_value, embedded_input_sentence.T)
print(multihead_keys.shape, multihead_values.shape)

torch.Size([3, 24, 6]) torch.Size([3, 28, 6])


### Cross Attention

In [48]:
# Note: within the Transformer model, the second sentence comes from the encoder, while the sentence for the queries comes from the decoder
# The second sequence may have a different length (8 tokens here vs. 6), but its
# embedding size must equal `d`, since it is projected with the same W_key / W_value.
# Using `d` instead of a literal 16 keeps this cell correct if embedding_size changes.
embedded_input_sentence_2 = torch.rand(8, d) # 2nd input sequence

# NOTE: this overwrites the self-attention `keys` / `values` computed earlier.
keys = (W_key @ embedded_input_sentence_2.T).T
values = (W_value @ embedded_input_sentence_2.T).T

In [55]:
# Cross-attention scores: query 2 (first sequence) against every key of the second sequence.
omega_2 = torch.matmul(query_2, keys.T)
print(query_2.shape, keys.shape, omega_2.shape)


torch.Size([24]) torch.Size([8, 24]) torch.Size([8])


In [50]:
attention_weights_2 = F.softmax(input=omega_2 * d_k ** -0.5, dim=0)  # scaled scores -> weights over the 2nd sequence


In [53]:
# Cross-attention context vector: weighted sum of the second sequence's values -> (d_v,)
context_vector_2 = torch.matmul(attention_weights_2, values)
print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([4.3351, 3.6771, 4.2762, 3.9791, 4.1577, 3.2921, 3.7612, 4.4088, 3.5626,
        3.6418, 4.6238, 5.4345, 3.6379, 4.6668, 4.2885, 3.2560, 5.3845, 4.0421,
        3.4744, 3.7277, 4.6168, 3.8642, 4.3314, 4.7715, 3.5523, 3.3344, 5.0613,
        3.8946], grad_fn=<SqueezeBackward4>)
