# Self-Attention 

Ref: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

We can think of self-attention as a mechanism that enhances the information content of an input embedding by including information about the input’s context. In other words, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output.

### Embedding an Input Sentence

In [1]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


Next, we use this dictionary to assign an integer index to each word:

In [2]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


Now, using the integer-vector representation of the input sentence, we can use an embedding layer to encode the inputs into a real-vector embedding. Here, we will use a 16-dimensional embedding such that each input word is represented by a 16-dimensional vector. Since the sentence consists of 6 words, this will result in a $6 \times 16$-dimensional embedding:

In [5]:
torch.manual_seed(123)

d = 16  # size of word vector embedding
embed = torch.nn.Embedding(len(dc), d)  # vocabulary of 6 words, each mapped to a d-dimensional vector
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])
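An embedding layer is essentially a trainable lookup table: row $i$ of its weight matrix is the vector for token index $i$. The short sketch below (same sizes and seed as the cell above; the one-hot comparison is an illustration added here, not part of the original notebook) checks that calling the layer is the same as indexing its weight matrix, or multiplying it by one-hot vectors.

```python
import torch

torch.manual_seed(123)
vocab_size, d = 6, 16                              # same sizes as in the cell above
embed = torch.nn.Embedding(vocab_size, d)
sentence_int = torch.tensor([0, 4, 5, 2, 1, 3])   # integer encoding from above

# Calling the layer ...
via_layer = embed(sentence_int).detach()
# ... is the same as picking rows of its weight matrix
via_index = embed.weight[sentence_int].detach()
# ... or multiplying one-hot vectors by the weight matrix
one_hot = torch.nn.functional.one_hot(sentence_int, num_classes=vocab_size).float()
via_matmul = one_hot @ embed.weight.detach()

print(torch.equal(via_layer, via_index))       # True
print(torch.allclose(via_layer, via_matmul))   # True
```

The lookup form is what makes embeddings cheap: no matrix multiplication is needed at inference time, only row selection.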


## Scaled dot-product self-attention mechanism

Self-attention uses three weight matrices, referred to as $\mathbf{W}_q$, $\mathbf{W}_k$, and $\mathbf{W}_v$, which are learned as model parameters during training. These matrices project the inputs into the query, key, and value components of the sequence, respectively.

The respective query, key and value sequences are obtained via matrix multiplication between the weight matrices $\mathbf{W}$ and the embedded inputs $\mathbf{x}$:

- Query sequence: $\mathbf{q}^{(i)} = \mathbf{W}_q \mathbf{x}^{(i)}$ for $i \in [1, T]$
- Key sequence: $\mathbf{k}^{(i)} = \mathbf{W}_k \mathbf{x}^{(i)}$ for $i \in [1, T]$
- Value sequence: $\mathbf{v}^{(i)} = \mathbf{W}_v \mathbf{x}^{(i)}$ for $i \in [1, T]$

where $i$ refers to the token index position in the input sequence, which has length $T$.

Let $d$ be the size of each word vector $\mathbf{x}$. The query and key vectors $\mathbf{q}^{(i)}$ and $\mathbf{k}^{(i)}$ both have dimension $d_k$, so the projection matrices $\mathbf{W}_q$ and $\mathbf{W}_k$ have shape $d_k \times d$, while $\mathbf{W}_v$ has shape $d_v \times d$.

<img src="./img/attention-matrices.png" alt="alt text"  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

Since we compute the dot product between the query and key vectors, these two vectors must contain the same number of elements ($d_q = d_k$). However, the number of elements in the value vector $\mathbf{v}^{(i)}$, which determines the size of the resulting context vector, is arbitrary.
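To make the dimension constraint concrete, here is a small sketch on random stand-in data (not the notebook's embeddings): changing $d_v$ only changes the size of the context vector, while a query and key of different sizes cannot be combined by a dot product at all.

```python
import torch

torch.manual_seed(0)
d, T, d_k = 16, 6, 24                 # illustrative sizes matching the notebook
X = torch.rand(T, d)                  # random stand-in for embedded_sentence
W_q = torch.rand(d_k, d)
W_k = torch.rand(d_k, d)

q_2 = W_q @ X[1]                      # query of the 2nd token, shape (d_k,)
K = X @ W_k.T                         # keys of all tokens, shape (T, d_k)
weights = torch.softmax(K @ q_2 / d_k**0.5, dim=0)   # shape (T,)

# d_v is free: it only sets the size of the context vector
context_shapes = {}
for d_v in (8, 28):
    W_v = torch.rand(d_v, d)
    V = X @ W_v.T                     # values, shape (T, d_v)
    context_shapes[d_v] = tuple((weights @ V).shape)
print(context_shapes)                 # {8: (8,), 28: (28,)}

# d_q != d_k: the dot product is undefined and PyTorch raises an error
k_wrong = torch.rand(20, d) @ X[0]    # a key of size 20 instead of 24
try:
    q_2 @ k_wrong
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)                # True
```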

In [6]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

### Computing the Unnormalized Attention Weights

Now, let's suppose we are interested in computing the attention vector for the second input element, which acts as the query here:

<img src="./img/query.png" alt="alt text" 
  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

In [47]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We can then generalize this to compute the remaining key and value elements for all inputs as well, since we will need them in the next step when we compute the unnormalized attention weights $\omega$:

In [48]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
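The expression `W_key.matmul(embedded_sentence.T).T` applies $\mathbf{k}^{(i)} = \mathbf{W}_k \mathbf{x}^{(i)}$ to every token at once. A quick sketch on random stand-in data shows that it matches the more common row-major form `X @ W_key.T` and a token-by-token loop:

```python
import torch

torch.manual_seed(123)
d, d_k, T = 16, 24, 6
X = torch.rand(T, d)                          # random stand-in for embedded_sentence
W_key = torch.rand(d_k, d)

keys_a = W_key.matmul(X.T).T                  # form used in the notebook: (d_k x d)(d x T), transposed to (T, d_k)
keys_b = X @ W_key.T                          # row-major form: (T x d)(d x d_k)
keys_c = torch.stack([W_key @ x for x in X])  # token by token: k^(i) = W_k x^(i)

print(keys_a.shape)                           # torch.Size([6, 24])
print(torch.allclose(keys_a, keys_b), torch.allclose(keys_a, keys_c))
```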


### Unnormalized attention weights $\omega$

<img src="./img/omega.png" alt="alt text" 
  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

Here, $\omega_{i,j}$ is the dot product between the $i$-th query and the $j$-th key: $\omega_{i,j} = \mathbf{q}^{(i)\top}\mathbf{k}^{(j)}$.

The unnormalized attention weight between the query (the 2nd input element) and the 5th input element (index position 4) is:

In [49]:
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(11.1466, grad_fn=<DotBackward0>)


Let's compute the $\omega$ values for all input tokens, as illustrated in the figure above:

In [50]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)
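The same definition $\omega_{i,j} = \mathbf{q}^{(i)\top}\mathbf{k}^{(j)}$ extends to every query at once. In this sketch (random stand-in data), the full $T \times T$ matrix of unnormalized weights is one matrix product, and its second row matches the key-by-key computation for the 2nd token:

```python
import torch

torch.manual_seed(123)
d, d_k, T = 16, 24, 6
X = torch.rand(T, d)                         # random stand-in for embedded_sentence
W_query, W_key = torch.rand(d_k, d), torch.rand(d_k, d)

queries = X @ W_query.T                      # (T, d_k): one query per token
keys = X @ W_key.T                           # (T, d_k): one key per token

# omega_{2,j} = q^(2) . k^(j), one key at a time (index 1 = 2nd token)
omega_2_loop = torch.stack([queries[1].dot(keys[j]) for j in range(T)])

# Every query against every key in one matrix product: omega[i, j] = q^(i) . k^(j)
omega = queries @ keys.T                     # (T, T)
print(omega.shape)                           # torch.Size([6, 6])
print(torch.allclose(omega[1], omega_2_loop))
```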


## Computing the Attention Scores

The subsequent step in self-attention is to normalize the unnormalized attention weights $\omega$ to obtain the normalized attention weights $\alpha$ by applying the softmax function. Additionally, $\omega$ is scaled by $\frac{1}{\sqrt{d_k}}$ before it is normalized through the softmax function, as shown below:


<img src="./img/attention-scores.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

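The scaling by $\frac{1}{\sqrt{d_k}}$ keeps the dot products at roughly the same magnitude regardless of the key dimension. If the components of $\mathbf{q}$ and $\mathbf{k}$ are independent with zero mean and unit variance, their dot product has variance $d_k$, so dividing by $\sqrt{d_k}$ brings the variance back to 1. Without this scaling, a large $d_k$ would push the softmax toward a nearly one-hot output with very small gradients, which could lead to numerical instability or slow down the model's convergence during training.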
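The variance argument can be checked empirically. This sketch assumes, purely for illustration, that query and key components are i.i.d. standard normal; it estimates the variance of $\mathbf{q}^\top\mathbf{k}$ before and after dividing by $\sqrt{d_k}$:

```python
import torch

torch.manual_seed(0)
n = 20_000                                   # number of random (q, k) pairs per setting
variances = {}
for d_k in (4, 64, 256):
    # Assumption for illustration: components are i.i.d. with mean 0 and variance 1
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=1)                # n dot products q . k
    variances[d_k] = (dots.var().item(), (dots / d_k**0.5).var().item())
    print(f"d_k={d_k:4d}  var(q.k)={variances[d_k][0]:8.2f}  "
          f"var(q.k/sqrt(d_k))={variances[d_k][1]:.2f}")
```

The unscaled variance grows roughly linearly with $d_k$, while the scaled one stays close to 1.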

In [51]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)
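As a sanity check, the softmax can be written out by hand. The sketch below reuses the rounded `omega_2` values printed earlier, so it should reproduce the attention weights above up to rounding; subtracting the maximum before `exp` is the usual numerically stable variant and does not change the result.

```python
import torch
import torch.nn.functional as F

# The (rounded) omega_2 values printed above
omega_2 = torch.tensor([8.5808, -7.6597, 3.2558, 1.0395, 11.1466, -0.4800])
d_k = 24

scaled = omega_2 / d_k**0.5
# Softmax by hand; subtracting the max first avoids overflow in exp
exp_scores = torch.exp(scaled - scaled.max())
manual = exp_scores / exp_scores.sum()
builtin = F.softmax(scaled, dim=0)

print(manual)                              # approximately the attention_weights_2 printed above
print(torch.allclose(manual, builtin))     # True
print(float(builtin.sum()))                # 1.0 up to floating-point rounding
```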


### Attention-weighted version of the original query $\mathbf{x}^{(2)}$

Finally, we compute the context vector $\mathbf{z}^{(2)} = \sum_{j=1}^{T} \alpha_{2,j}\,\mathbf{v}^{(j)}$, which is an attention-weighted version of our original query input $\mathbf{x}^{(2)}$ that includes all the other input elements as its context via the attention weights:

<img src="./img/context-vector.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

In [52]:
context_vector_2 = attention_weights_2.matmul(values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


Note that this output vector has more dimensions ($d_v = 28$) than the original input vector ($d = 16$) since we specified $d_v > d$ earlier; however, the embedding size choice is arbitrary.
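The steps above (projection, dot products, scaled softmax, weighted sum) can be collected into the matrix form $\mathbf{Z} = \mathrm{softmax}\left(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k}\right)\mathbf{V}$, which computes the context vectors of all tokens at once. The helper `self_attention` below is a hypothetical name for this sketch; it uses random float64 stand-in data so that the two computation paths can be compared tightly, and it checks that row 1 equals the explicit weighted sum for the 2nd token.

```python
import torch
import torch.nn.functional as F

def self_attention(X, W_query, W_key, W_value):
    """Single-head scaled dot-product self-attention for all tokens at once.

    X: (T, d); W_query, W_key: (d_k, d); W_value: (d_v, d). Returns (T, d_v).
    """
    Q = X @ W_query.T                           # (T, d_k)
    K = X @ W_key.T                             # (T, d_k)
    V = X @ W_value.T                           # (T, d_v)
    d_k = K.shape[-1]
    A = F.softmax(Q @ K.T / d_k**0.5, dim=-1)   # (T, T), each row sums to 1
    return A @ V                                # row i is the context vector z^(i)

torch.manual_seed(123)
d, d_q, d_k, d_v, T = 16, 24, 24, 28, 6
X = torch.rand(T, d, dtype=torch.float64)       # random stand-in for embedded_sentence
W_query = torch.rand(d_q, d, dtype=torch.float64)
W_key = torch.rand(d_k, d, dtype=torch.float64)
W_value = torch.rand(d_v, d, dtype=torch.float64)

Z = self_attention(X, W_query, W_key, W_value)

# Step-by-step version for the 2nd token, as in the cells above
query_2 = W_query @ X[1]
keys, values = X @ W_key.T, X @ W_value.T
weights_2 = F.softmax(query_2 @ keys.T / d_k**0.5, dim=0)
context_2 = sum(weights_2[j] * values[j] for j in range(T))   # z^(2) = sum_j alpha_2j v^(j)

print(Z.shape)                                  # torch.Size([6, 28])
print(torch.allclose(Z[1], context_2))
```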

# Multi-head attention

In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously:

<img src="./img/single-head.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks.

<img src="./img/multi-head.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

To illustrate this in code, suppose we have 3 attention heads, so we now extend the $d' \times d$-dimensional weight matrices to $3 \times d' \times d$:

In [53]:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

Consequently, each query element is now $3 \times d_q$-dimensional, where $d_q = 24$ (here, let's keep the focus on the 2nd element, corresponding to index position 1):

In [77]:
x_2.shape, multihead_W_query.shape

(torch.Size([16]), torch.Size([3, 24, 16]))

In [54]:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])
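The single `matmul` call above is equivalent to applying each head's own query matrix separately. A short sketch on random stand-in data:

```python
import torch

torch.manual_seed(123)
h, d_q, d = 3, 24, 16
multihead_W_query = torch.rand(h, d_q, d)    # one (d_q x d) projection per head
x_2 = torch.rand(d)                          # random stand-in for embedded_sentence[1]

# matmul treats the 1-D x_2 as a column vector and broadcasts it over the head dimension
batched = multihead_W_query.matmul(x_2)      # (h, d_q)
per_head = torch.stack([multihead_W_query[i] @ x_2 for i in range(h)])

print(batched.shape)                         # torch.Size([3, 24])
print(torch.allclose(batched, per_head))
```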


In [79]:
multihead_W_key.shape

torch.Size([3, 24, 16])

In [55]:
# Obtain the keys and values for x_2 in the same way
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)

These key and value elements belong only to the query element itself. But, as before, we also need the keys and values for the other sequence elements in order to compute the attention scores for the query. We can do this by repeating the input sequence embeddings 3 times along a new leading dimension, i.e., once per attention head:

In [56]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])


In [57]:
# The input is now three-dimensional: (heads, embedding dim, tokens); show embedding dimension 2 for every head
stacked_inputs[:,2]

tensor([[-0.3035, -0.2587,  1.0042, -2.1338, -0.1690, -1.4779],
        [-0.3035, -0.2587,  1.0042, -2.1338, -0.1690, -1.4779],
        [-0.3035, -0.2587,  1.0042, -2.1338, -0.1690, -1.4779]])

Now, we can compute all the keys and values using torch.bmm() (batch matrix multiplication):

In [70]:
multihead_W_key.shape, stacked_inputs.shape

(torch.Size([3, 24, 16]), torch.Size([3, 16, 6]))

In [82]:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])
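The explicit `repeat` is not strictly required: `torch.matmul` broadcasts a 2-D right operand over the head dimension, and `torch.einsum` can produce the (heads, tokens, d_k) layout directly. This sketch (random stand-in data) checks that all three give the same keys:

```python
import torch

torch.manual_seed(123)
h, d_k, d, T = 3, 24, 16, 6
multihead_W_key = torch.rand(h, d_k, d)
embedded = torch.rand(T, d)                                 # random stand-in for embedded_sentence

stacked_inputs = embedded.T.repeat(h, 1, 1)                 # (h, d, T): explicit copy per head
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs)       # (h, d_k, T)
keys_broadcast = torch.matmul(multihead_W_key, embedded.T)  # (d, T) broadcast over heads, no copy
keys_einsum = torch.einsum("hkd,td->htk", multihead_W_key, embedded)  # (h, T, d_k) directly

print(keys_bmm.shape, keys_einsum.shape)   # torch.Size([3, 24, 6]) torch.Size([3, 6, 24])
print(torch.allclose(keys_bmm, keys_broadcast),
      torch.allclose(keys_bmm.transpose(1, 2), keys_einsum))
```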


We now have tensors that represent the three attention heads in their first dimension. The third and second dimensions refer to the number of words and the embedding size, respectively. To make the values and keys more intuitive to interpret, we will swap the second and third dimensions, resulting in tensors with the same dimensional structure as the original input sequence, embedded_sentence:

In [59]:
# Swap the last two dimensions: (heads, size, tokens) -> (heads, tokens, size)
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])


Then, we follow the same steps as previously: compute the unnormalized attention weights $\omega$, normalize them with the scaled softmax to obtain the attention weights $\alpha$, and use these to obtain an $h \times d_v$ (here: $3 \times d_v$) dimensional context vector $\mathbf{z}$ for the input element $\mathbf{x}^{(2)}$. The only difference from the single-head case is the extra head dimension:

In [83]:
query_2.shape, keys.shape

(torch.Size([24]), torch.Size([6, 24]))

In [98]:
multihead_query_2.shape, multihead_keys.shape, multihead_keys.transpose(1, 2).shape

(torch.Size([3, 24]), torch.Size([3, 6, 24]), torch.Size([3, 24, 6]))

A plain matmul of multihead_query_2 with the transposed keys would not give per-head results: torch.matmul treats the [3, 24] query tensor as one 3 × 24 matrix and broadcasts it against each head's keys, producing a [3, 3, 6] tensor that pairs every head's keys with the queries of all three heads. Instead, we give each head's query a singleton dimension, so that torch.bmm() computes one (1 × 24)·(24 × 6) product per head:

In [100]:
# (3, 1, 24) @ (3, 24, 6) -> (3, 1, 6) -> (3, 6): one row of omega per head
multihead_omega_2 = torch.bmm(
    multihead_query_2.unsqueeze(1), multihead_keys.transpose(1, 2)
).squeeze(1)
print(multihead_omega_2.shape)

torch.Size([3, 6])


The softmax normalizes over the 6 input tokens (the last dimension), separately for each head:

In [102]:
multihead_weights_2 = F.softmax(multihead_omega_2 / d_k**0.5, dim=-1)
print(multihead_weights_2.shape)

torch.Size([3, 6])


In [103]:
context_vector_2.shape, values.shape

(torch.Size([28]), torch.Size([6, 28]))

In [106]:
multihead_weights_2.shape, multihead_values.shape

(torch.Size([3, 6]), torch.Size([3, 6, 28]))

In [105]:
# (3, 1, 6) @ (3, 6, 28) -> (3, 1, 28) -> (3, 28): one context vector per head
multihead_context_vector_2 = torch.bmm(
    multihead_weights_2.unsqueeze(1), multihead_values
).squeeze(1)

print(multihead_context_vector_2.shape)

torch.Size([3, 28])

The result contains one $d_v = 28$-dimensional context vector per head, i.e., an $h \times d_v = 3 \times 28$ tensor for the input element $\mathbf{x}^{(2)}$.
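For all tokens at once, the per-head context vectors are commonly concatenated into one $h \cdot d_v$-dimensional vector per token (in the original transformer, a further output projection follows, which this sketch omits). The helper names `multihead_self_attention` and `single_head` are hypothetical, and the data are random float64 stand-ins; the check confirms that each head behaves exactly like single-head attention with its own weight matrices.

```python
import torch
import torch.nn.functional as F

def multihead_self_attention(X, W_q, W_k, W_v):
    """X: (T, d); W_q, W_k: (h, d_k, d); W_v: (h, d_v, d). Returns (T, h * d_v)."""
    Q = torch.einsum("hkd,td->htk", W_q, X)                  # (h, T, d_k)
    K = torch.einsum("hkd,td->htk", W_k, X)                  # (h, T, d_k)
    V = torch.einsum("hvd,td->htv", W_v, X)                  # (h, T, d_v)
    d_k = K.shape[-1]
    A = F.softmax(Q @ K.transpose(1, 2) / d_k**0.5, dim=-1)  # (h, T, T), softmax over keys
    Z = A @ V                                                # (h, T, d_v)
    return Z.transpose(0, 1).reshape(X.shape[0], -1)         # concatenate heads: (T, h * d_v)

def single_head(X, Wq, Wk, Wv):
    A = F.softmax((X @ Wq.T) @ (X @ Wk.T).T / Wk.shape[0]**0.5, dim=-1)
    return A @ (X @ Wv.T)

torch.manual_seed(123)
h, d, d_k, d_v, T = 3, 16, 24, 28, 6
X = torch.rand(T, d, dtype=torch.float64)        # random stand-in for embedded_sentence
W_q = torch.rand(h, d_k, d, dtype=torch.float64)
W_k = torch.rand(h, d_k, d, dtype=torch.float64)
W_v = torch.rand(h, d_v, d, dtype=torch.float64)

out = multihead_self_attention(X, W_q, W_k, W_v)
expected = torch.cat([single_head(X, W_q[i], W_k[i], W_v[i]) for i in range(h)], dim=-1)

print(out.shape)                                 # torch.Size([6, 84])
print(torch.allclose(out, expected))
```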

### Summary of the self-attention mechanism

<img src="./img/summary_self_attention.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

# Cross-Attention

In self-attention, we work with a single input sequence. In cross-attention, we mix or combine two different input sequences. In the original transformer architecture, these are the sequence returned by the encoder module and the input sequence being processed by the decoder.

Note that in cross-attention, the two input sequences $\mathbf{x_{1}}$ and $\mathbf{x_{2}}$ can have different numbers of elements. However, their embedding dimensions must match.

The figure below illustrates the concept of cross-attention. If we set $\mathbf{x_{1}}$ = $\mathbf{x_{2}}$, this is equivalent to self-attention.

<img src="./img/cross-attention.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

<img src="./img/cross-attention-summary.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

In code it will look like this:

In [109]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]
print("embedded_sentence.shape:", embedded_sentence.shape)

d_q, d_k, d_v = 24, 24, 28

W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
print("query.shape", query_2.shape)

keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

embedded_sentence.shape: torch.Size([6, 16])
query.shape torch.Size([24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


The only part that changes in cross-attention is that we now have a second input sequence, for example, a second sentence with 8 instead of 6 input elements. The keys and values are computed from this second sequence, while the query still comes from the first one:

In [110]:
embedded_sentence_2 = torch.rand(8, 16) # 2nd input sequence

keys = W_key.matmul(embedded_sentence_2.T).T
values = W_value.matmul(embedded_sentence_2.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([8, 24])
values.shape: torch.Size([8, 28])


Notice that compared to self-attention, the keys and values now have 8 instead of 6 rows. Everything else stays the same.
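Completing the computation, a sketch of cross-attention on random stand-in data: queries come from the first sequence (6 tokens), keys and values from the second (8 tokens), so each query now gets 8 attention weights, while the context vectors still have size $d_v$. `cross_attention` is a hypothetical helper name; passing the same sequence twice recovers self-attention.

```python
import torch
import torch.nn.functional as F

def cross_attention(X1, X2, W_query, W_key, W_value):
    """Queries from X1 (T1, d); keys and values from X2 (T2, d).

    Returns the attention weights (T1, T2) and context vectors (T1, d_v).
    """
    Q = X1 @ W_query.T                                  # (T1, d_k)
    K = X2 @ W_key.T                                    # (T2, d_k)
    V = X2 @ W_value.T                                  # (T2, d_v)
    A = F.softmax(Q @ K.T / K.shape[-1]**0.5, dim=-1)   # (T1, T2)
    return A, A @ V

torch.manual_seed(123)
d, d_k, d_v = 16, 24, 28
X1 = torch.rand(6, d)                     # stand-in for embedded_sentence (6 tokens)
X2 = torch.rand(8, d)                     # stand-in for embedded_sentence_2 (8 tokens)
W_query, W_key, W_value = torch.rand(d_k, d), torch.rand(d_k, d), torch.rand(d_v, d)

weights, context = cross_attention(X1, X2, W_query, W_key, W_value)
print(weights.shape, context.shape)       # torch.Size([6, 8]) torch.Size([6, 28])

# With X2 = X1, cross-attention reduces to self-attention
weights_self, _ = cross_attention(X1, X1, W_query, W_key, W_value)
print(weights_self.shape)                 # torch.Size([6, 6])
```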