## Coding attention mechanisms


## Implement four different variants of the attention mechanism:
* Simplified self-attention
* Self-attention
* Causal attention
* Multi-head attention

## A simple self-attention mechanism without trainable weights

The goal of self-attention is to compute a context vector for each input element that combines information from all other input elements. A context vector can be interpreted as an enriched embedding vector, created by incorporating information from all other elements in the sequence.

Consider the following input sentence, which has already been embedded into three-dimensional vectors:


In [None]:
%pip install torch numpy

In [4]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)

**Step1**. To calculate the intermediate attention scores for the second input element, we compute the dot product of x2 with every input token, including x2 itself:

In [4]:
# the second input token serves as the query
query = inputs[1]

attn_scores_for_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    attn_scores_for_2[i] = torch.dot(x_i, query)

print(attn_scores_for_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention, the dot product determines the extent to which each element in the sequence focuses on, or **"attends to"**, any other element: the higher the dot product, the higher the similarity and attention score between two elements.
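To make the similarity measure concrete, here is a minimal sketch that writes one dot product out by hand (multiply element-wise, then sum) and compares it with `torch.dot`. The names `manual_dot` and `builtin_dot` are illustrative only; the two vectors are x1 and x2 from the example inputs.

```python
import torch

# Two of the embedded tokens from the example inputs
x_1 = torch.tensor([0.43, 0.15, 0.89])  # "your"
x_2 = torch.tensor([0.55, 0.87, 0.66])  # "journey" (the query)

# Dot product written out: multiply element-wise, then sum the products
manual_dot = sum(a * b for a, b in zip(x_1, x_2))

# Same computation with the built-in function
builtin_dot = torch.dot(x_1, x_2)

print(manual_dot, builtin_dot)
```

Both values should agree with the first entry of `attn_scores_for_2`.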

**Step2**. Normalize the attention scores so that the resulting attention weights sum to 1:


In [5]:
attn_weights_2_tmp = attn_scores_for_2 / attn_scores_for_2.sum()

print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)


**In practice, it's more common to use the softmax function for normalization, because it handles extreme values better and always yields positive weights:**

In [9]:
attn_weights_2_softmax = torch.softmax(attn_scores_for_2, dim=0)

print("Attention weights: ", attn_weights_2_softmax)
print("Sum: ", attn_weights_2_softmax.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)
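As a rough illustration of what softmax does, the sketch below implements it naively (exponentiate, then divide by the sum) and checks it against `torch.softmax`. The helper `softmax_naive` is a hypothetical name used only here. The naive version can overflow for large scores, which is one reason to prefer the built-in function.

```python
import torch

def softmax_naive(x):
    # Exponentiate every score, then divide by the sum of the exponentials
    return torch.exp(x) / torch.exp(x).sum(dim=0)

# Attention scores for x2, copied from the output above
attn_scores_for_2 = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

naive = softmax_naive(attn_scores_for_2)
builtin = torch.softmax(attn_scores_for_2, dim=0)

print(naive)
print(builtin)

# With very large scores the naive version divides inf by inf,
# while torch.softmax stays finite
large_scores = torch.tensor([1000.0, 1000.0])
print(softmax_naive(large_scores))
print(torch.softmax(large_scores, dim=0))
```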


**Step3**. Calculate the context vector as the weighted sum of all input vectors: multiply each embedded input token by its corresponding attention weight, then sum the resulting vectors.

In [19]:
query = inputs[1] # x2

print(query.shape)
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2_softmax[i] * x_i

print(context_vec_2)

torch.Size([3])
tensor([0.4419, 0.6515, 0.5683])


Here are the same calculations again, written out by hand:
```
attn_scores_for_2 = inputs @ x2 =
    [0.43, 0.15, 0.89]
    [0.55, 0.87, 0.66] (x2)
    [0.57, 0.85, 0.64]
    [0.22, 0.58, 0.33]
    [0.77, 0.25, 0.10]
    [0.05, 0.80, 0.55]         *   [0.55, 0.87, 0.66] (x2)  =
                                   [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]

attn_weights_2_softmax = softmax([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]) =
                                 [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581]

context_vec_2 = sum(inputs[i] * attn_weights_2_softmax[i]) =
              = sum(
                    [0.43, 0.15, 0.89] * 0.1385,
                    [0.55, 0.87, 0.66] * 0.2379,
                    [0.57, 0.85, 0.64] * 0.2333,
                    [0.22, 0.58, 0.33] * 0.1240,
                    [0.77, 0.25, 0.10] * 0.1082,
                    [0.05, 0.80, 0.55] * 0.1581)
              = sum(
                    [0.0596, 0.0208, 0.1233],
                    [0.1308, 0.2070, 0.1570],
                    [0.1330, 0.1983, 0.1493],
                    [0.0273, 0.0719, 0.0409],
                    [0.0833, 0.0270, 0.0108],
                    [0.0079, 0.1265, 0.0870])   =   [0.4419, 0.6515, 0.5683]
```


That was the context vector for input x2.

### Computing attention weights for all input tokens

In [9]:
attn_scores = torch.empty(6, 6)

# for i, x_i in enumerate(inputs):
#    for j, x_j in enumerate(inputs):
#        attn_scores[i, j] = torch.dot(x_i, x_j)

# compute attention score tensor (equivalent to the for loop above)
attn_scores = inputs @ inputs.T

# normalize
attn_weights = torch.softmax(attn_scores, dim=-1)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


By setting `dim=-1` we apply the normalization along the last dimension of the `attn_scores` tensor. If `attn_scores` is a two-dimensional tensor ([rows, columns]), softmax normalizes across the columns so that the values in each row sum to 1.
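The effect of the `dim` argument is easy to check on a small tensor. This is a minimal sketch with made-up scores (the tensor `scores` is not part of the example above); it compares normalizing along the last dimension with normalizing along the first.

```python
import torch

# Made-up 2x3 score matrix, used only to illustrate the dim argument
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_norm = torch.softmax(scores, dim=-1)  # normalize within each row
col_norm = torch.softmax(scores, dim=0)   # normalize within each column

print("Row sums with dim=-1:   ", row_norm.sum(dim=-1))
print("Column sums with dim=0: ", col_norm.sum(dim=0))
```

With `dim=-1` every row is a probability distribution over the tokens being attended to, which is the behavior needed for attention weights.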

Now use the attention weights to compute all context vectors via matrix multiplication:

In [10]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
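As a sanity check, the second row of the matrix result should match the context vector computed earlier with an explicit loop. The sketch below recomputes both from the same inputs; it repeats the input tensor so that it runs on its own.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)

# Matrix version: all context vectors at once
attn_weights = torch.softmax(inputs @ inputs.T, dim=-1)
all_context_vecs = attn_weights @ inputs

# Loop version: weighted sum of the inputs for the second token only
context_vec_2 = torch.zeros(inputs.shape[1])
for weight, x_i in zip(attn_weights[1], inputs):
    context_vec_2 += weight * x_i

print(all_context_vecs[1])
print(context_vec_2)
print(torch.allclose(all_context_vecs[1], context_vec_2))
```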


## Self-attention with trainable weights (scaled dot-product attention)

Let's introduce three trainable weight matrices: `Wq`, `Wk`, and `Wv`.
They are used to project the embedded input tokens x(i) into query, key, and value vectors.

Let's start by computing the query, key, and value vectors for input x_2:

In [18]:
x_2 = inputs[1]         # second input element
d_in = inputs.shape[1]  # the input embedding size = 3
d_out = 2               # the output embedding size
# in GPT-like models the input and output dim are usually the same

# init three weight matrices
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)  # setting false to reduce clutter in the outputs
# TODO: set to True to update these matrices during model training

# compute the vectors
query_2 = x_2 @ W_query
key_2   = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


Even though we're computing the context vector only for x_2, we still need the key and value vectors of all input elements, because they are involved in computing the attention weights with respect to the query q_2.

In [29]:
keys = inputs @ W_key
values = inputs @ W_value

Let's compute the unnormalized attention score between the query of x_2 and its own key:

In [30]:
attn_score_22 = query_2 @ key_2
print(attn_score_22)

tensor(1.8524)


Or compute the attention scores of this query against all keys at once:

In [31]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Now compute the attention weights from the attention scores by scaling the scores and applying softmax. This time we scale by dividing by the square root of the embedding dimension of the keys:

In [32]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


**The scores are scaled by the square root of the key dimension to improve training: large dot products push softmax toward a nearly one-hot output, where gradients become very small. This scaling is why this self-attention mechanism is called scaled dot-product attention.**
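The effect of the scaling can be seen in a small experiment. The sketch below is an illustration under stated assumptions, not part of the example above: it uses random standard-normal vectors and a hypothetical key dimension of 64. Unscaled dot products spread out roughly in proportion to the square root of the dimension, and dividing by that square root brings the spread back to about 1, which keeps softmax from becoming too peaked.

```python
import torch

torch.manual_seed(0)
d_k = 64  # hypothetical key dimension, chosen only for illustration

# 10,000 independent query/key pairs with standard-normal entries
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)   # one dot product per pair
scaled = raw / d_k**0.5     # the same scores after scaling

print("std of raw scores:   ", raw.std())
print("std of scaled scores:", scaled.std())

# Softmax over one query's scores against six random keys
scores = torch.randn(d_k) @ torch.randn(d_k, 6)
peaked = torch.softmax(scores, dim=-1)
smooth = torch.softmax(scores / d_k**0.5, dim=-1)

print("largest weight without scaling:", peaked.max())
print("largest weight with scaling:   ", smooth.max())
```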

The last step is to multiply each value vector by its respective attention weight and then sum the results to obtain the context vector.

In [33]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


**A query** is analogous to a search query in a database. It represents the current item
(e.g., a word or token in a sentence) the model focuses on or tries to understand.
The query is used to probe the other parts of the input sequence to determine how
much attention to pay to them.

**The key** is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

**The value** in this context is similar to the value in a key-value pair in a database. It
represents the actual content or representation of the input items. Once the model
determines which keys (and thus which parts of the input) are most relevant to the
query (the current focus item), it retrieves the corresponding values.
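The database analogy can be sketched as a "soft lookup". The keys, values, and query below are made-up toy numbers, not learned projections: instead of returning exactly one value for an exact key match, attention returns a blend of all values, weighted by how well each key matches the query.

```python
import torch

# Toy keys and values (hypothetical numbers chosen for illustration)
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [0.7, 0.7]])
values = torch.tensor([[10.0],
                       [20.0],
                       [30.0]])

# A query that is most aligned with the first key
query = torch.tensor([5.0, 0.0])

# Match the query against every key, then normalize the match scores
weights = torch.softmax(query @ keys.T, dim=-1)

# Retrieve a weighted blend of the values instead of a single entry
retrieved = weights @ values

print(weights)
print(retrieved)
```

The first key receives the largest weight, so the retrieved value is pulled mostly toward the first value, with smaller contributions from the others.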

### Self-attention class

In [34]:
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

In [37]:
# test
torch.manual_seed(123)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
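One possible variation, shown here only as a sketch, replaces the raw `nn.Parameter` matrices with `nn.Linear` layers without bias. The class name `SelfAttentionLinear` and the `qkv_bias` argument are hypothetical choices for this example, and the random `inputs` tensor is a stand-in for the embedded sentence. Because `nn.Linear` uses a different weight initialization than `torch.rand`, the outputs will not match the values printed above.

```python
import torch
import torch.nn as nn

class SelfAttentionLinear(nn.Module):  # hypothetical name for this sketch
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # nn.Linear without bias performs the same projection as x @ W
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Scaled dot-product attention, as in the class above
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(789)
inputs = torch.rand(6, 3)  # stand-in for six 3-dimensional token embeddings
sa_linear = SelfAttentionLinear(d_in=3, d_out=2)
out = sa_linear(inputs)
print(out.shape)
```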
