In [1]:
# https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [1]:
sentence = 'Life is short, eat dessert first'

dc = {word:i for i, word in enumerate(sorted(sentence.replace(',', '').split()))}

In [2]:
dc

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [3]:
import torch

In [4]:
sent_int = torch.tensor([dc[w] for w in sentence.replace(',', '').split()])

In [5]:
sent_int

tensor([0, 4, 5, 2, 1, 3])

In [6]:
torch.manual_seed(123)

<torch._C.Generator at 0x11297e3f0>

In [7]:
embed = torch.nn.Embedding(6, 16)

In [8]:
embedded_sentence = embed(sent_int)

In [9]:
embedded_sentence = embedded_sentence.detach()

In [10]:
embedded_sentence

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

In [11]:
embedded_sentence.shape

torch.Size([6, 16])
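As a quick sanity check on what the embedding layer does, the sketch below shows that calling `torch.nn.Embedding` on integer indices is a row lookup into its weight matrix. It is self-contained and re-creates the layer with the same sizes; the names `looked_up` and `rows` are introduced here for illustration only.

```python
import torch

torch.manual_seed(123)

# Same sizes as above: 6 vocabulary entries, 16-dimensional vectors
embed = torch.nn.Embedding(6, 16)
sent_int = torch.tensor([0, 4, 5, 2, 1, 3])

# Calling the layer on integer indices ...
looked_up = embed(sent_int).detach()

# ... returns the corresponding rows of its weight matrix
rows = embed.weight[sent_int].detach()

print(torch.equal(looked_up, rows))  # True
print(looked_up.shape)               # torch.Size([6, 16])
```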

In [12]:
d = embedded_sentence.shape[1]
d

16

In [13]:
d_q, d_k, d_v = 24, 24, 28

### Defining the Weight Matrices
Now, let’s discuss the widely utilized self-attention mechanism known as the scaled dot-product attention, which is integrated into the transformer architecture.

Self-attention utilizes three weight matrices, referred to as $W_q$, $W_k$, and $W_v$, which are adjusted as model parameters during training. These matrices serve to project the inputs into query, key, and value components of the sequence, respectively.

The respective query, key, and value sequences are obtained via matrix multiplication between the weight matrices $W$ and the embedded inputs $x$:

Query sequence: $q^{(i)} = W_q x^{(i)}$ for $i \in [1, T]$

Key sequence: $k^{(i)} = W_k x^{(i)}$ for $i \in [1, T]$

Value sequence: $v^{(i)} = W_v x^{(i)}$ for $i \in [1, T]$

The index i refers to the token index position in the input sequence, which has length T.

![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/attention-matrices.png)


Here, both $q^{(i)}$ and $k^{(i)}$ are vectors of dimension $d_k$. The projection matrices $W_q$ and $W_k$ have a shape of $d_k \times d$, while $W_v$ has the shape $d_v \times d$.

(It’s important to note that $d$ represents the size of each word vector, $x$.)

Since we are computing the dot-product between the query and key vectors, these two vectors have to contain the same number of elements ($d_q = d_k$). However, the number of elements in the value vector $v^{(i)}$, which determines the size of the resulting context vector, is arbitrary.

In [14]:
W_query = torch.nn.Parameter(torch.rand(d_q, d))

W_key = torch.nn.Parameter(torch.rand(d_k, d))

W_value = torch.nn.Parameter(torch.rand(d_v, d))

### Computing the Unnormalized Attention Weights

![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/query.png)

In [15]:
x_2 = embedded_sentence[1]  # focusing on second input for getting query and subsequent result

query_2 = W_query.matmul(x_2)

key_2 = W_key.matmul(x_2)

value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


In [16]:
embedded_sentence.shape

torch.Size([6, 16])

In [17]:
W_key.shape

torch.Size([24, 16])

In [18]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

In [19]:
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
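The double transpose used for `keys` and `values` can also be written as a single product with the transposed weight matrix. The sketch below checks this on random stand-in tensors (`x` plays the role of `embedded_sentence`; the data are not the values from the cells above) and also confirms that row 1 of the batched result matches the per-token projection.

```python
import torch

torch.manual_seed(123)

T, d, d_k = 6, 16, 24
x = torch.randn(T, d)        # stand-in for embedded_sentence
W_key = torch.rand(d_k, d)   # stand-in for the key projection

# Form used above: project the transposed inputs, then transpose back
keys_a = W_key.matmul(x.T).T

# Equivalent form: row i of the result is W_key @ x[i]
keys_b = x.matmul(W_key.T)

# Per-token projection for the second token
key_2 = W_key.matmul(x[1])

print(keys_b.shape)                                   # torch.Size([6, 24])
print(torch.allclose(keys_a, keys_b, atol=1e-4))      # True
print(torch.allclose(keys_b[1], key_2, atol=1e-4))    # True
```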


Now that we have all the required keys and values, we can proceed to the next step and compute the unnormalized attention weights $\omega$, which are illustrated in the figure below:

![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/omega.png)

In [20]:
query_2.shape

torch.Size([24])

In [21]:
keys.shape

torch.Size([6, 24])

In [22]:
omega_2 = query_2.matmul(keys.T)

In [23]:
omega_2

tensor([ -7.0847,  -4.5398,   3.9887,  10.2379,   2.3206, -10.5434],
       grad_fn=<SqueezeBackward4>)
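To make explicit what the matrix product computes, the following sketch compares it with one dot product per key, so that entry $j$ of the result is the dot product of the query with the $j$-th key. It uses random stand-in tensors rather than the values above, and `omega_2_loop` is a name introduced here for illustration.

```python
import torch

torch.manual_seed(123)

T, d_k = 6, 24
query_2 = torch.randn(d_k)   # stand-in for the query of the second token
keys = torch.randn(T, d_k)   # stand-in for the key sequence

# Vectorized form used above
omega_2 = query_2.matmul(keys.T)

# Equivalent explicit form: one dot product per key
omega_2_loop = torch.stack([query_2.dot(keys[j]) for j in range(T)])

print(omega_2.shape)                                      # torch.Size([6])
print(torch.allclose(omega_2, omega_2_loop, atol=1e-4))   # True
```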

### Computing the Attention Scores

The subsequent step in self-attention is to normalize the unnormalized attention weights, $\omega$, to obtain the normalized attention weights, $\alpha$, by applying the softmax function. Additionally, $1/\sqrt{d_k}$ is used to scale $\omega$ before normalizing it through the softmax function, as shown below:


 ![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/attention-scores.png)

Scaling by $1/\sqrt{d_k}$ keeps the unnormalized weights at roughly the same magnitude regardless of the dimension $d_k$. This helps prevent the softmax inputs from becoming too large in absolute value, which would push the attention weights toward 0 or 1 and could lead to numerical instability or affect the model’s ability to converge during training.
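The effect of the scaling can be seen in a small experiment. The sketch below assumes query and key entries drawn independently from a standard normal distribution, which is an idealization and not the distribution of the projected vectors in this notebook. Under that assumption, the spread of the raw dot products grows like $\sqrt{d_k}$, while the scaled version stays close to 1.

```python
import torch

torch.manual_seed(123)

n = 10_000     # number of random query/key pairs per setting
results = {}   # d_k -> (std of raw dot products, std of scaled dot products)

for d_k in (4, 24, 256):
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=1)    # unscaled dot products
    scaled = dots / d_k**0.5     # scaled as in the attention formula
    results[d_k] = (dots.std().item(), scaled.std().item())
    print(d_k, results[d_k])     # raw std is near sqrt(d_k), scaled std is near 1
```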

In [24]:
import torch.nn.functional as F

In [25]:
attention_weights_2 = F.softmax(omega_2/d_k**0.5, dim=0)

In [26]:
attention_weights_2

tensor([0.0185, 0.0312, 0.1778, 0.6368, 0.1265, 0.0092],
       grad_fn=<SoftmaxBackward0>)
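For reference, the softmax can be written out by hand. The sketch below copies the rounded values of `omega_2` printed earlier, so the result agrees with the output above only to about three decimal places; `alpha_manual` and `alpha_builtin` are names introduced here.

```python
import torch
import torch.nn.functional as F

# Unnormalized weights copied from the earlier output (rounded to 4 decimals)
omega_2 = torch.tensor([-7.0847, -4.5398, 3.9887, 10.2379, 2.3206, -10.5434])
d_k = 24

scaled = omega_2 / d_k**0.5

# Softmax by hand: exponentiate, then divide by the sum
alpha_manual = torch.exp(scaled) / torch.exp(scaled).sum()

# Built-in version, as used in the cell above
alpha_builtin = F.softmax(scaled, dim=0)

print(torch.allclose(alpha_manual, alpha_builtin, atol=1e-6))  # True
print(alpha_manual.sum())  # sums to 1 up to floating-point rounding
```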

Finally, the last step is to compute the context vector $z^{(2)}$, which is an attention-weighted version of our original query input $x^{(2)}$, including all the other input elements as its context via the attention weights:

![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/context-vector.png)

In [27]:
values.shape

torch.Size([6, 28])

In [28]:
context_vector_2 = attention_weights_2.matmul(values)

In [29]:
context_vector_2

tensor([-0.9495, -1.4345, -2.0504, -0.3737, -1.5098, -0.5921, -0.4289, -1.9790,
        -1.7937, -0.7146, -0.9926, -2.0061, -2.1961, -1.7174, -1.0732, -0.7900,
        -1.7367, -2.2095, -0.9344, -1.5299, -0.2828, -0.5350, -1.7285, -1.5485,
        -0.2043, -0.7109, -1.5165, -1.5167], grad_fn=<SqueezeBackward4>)

In [30]:
context_vector_2.shape

torch.Size([28])

Note that this output vector has more dimensions ($d_v = 28$) than the original input vector ($d = 16$) since we specified **$d_v > d$** earlier; however, the embedding size choice is arbitrary.
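The steps walked through for the second token can be collected into one function that processes every token at once. This is a minimal sketch under the shape conventions used above, not a reference implementation: the function name `self_attention` and the random stand-in inputs are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def self_attention(x, W_query, W_key, W_value):
    """x: (T, d); W_query, W_key: (d_k, d); W_value: (d_v, d). Returns (T, d_v)."""
    queries = x.matmul(W_query.T)    # (T, d_k)
    keys = x.matmul(W_key.T)         # (T, d_k)
    values = x.matmul(W_value.T)     # (T, d_v)
    d_k = keys.shape[-1]
    omega = queries.matmul(keys.T)   # (T, T), row i holds the weights for query i
    alpha = F.softmax(omega / d_k**0.5, dim=-1)  # each row sums to 1
    return alpha.matmul(values)      # (T, d_v), one context vector per token

torch.manual_seed(123)
x = torch.randn(6, 16)               # stand-in for embedded_sentence
W_q, W_k, W_v = torch.rand(24, 16), torch.rand(24, 16), torch.rand(28, 16)

z = self_attention(x, W_q, W_k, W_v)
print(z.shape)  # torch.Size([6, 28])
```

Row 1 of `z` corresponds to the context vector computed step by step for the second token.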

## Multi-Head Attention:

How does multi-head attention relate to the self-attention mechanism (scaled dot-product attention) we walked through above?

In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously:


![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/single-head.png)


As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks.

![](https://sebastianraschka.com/images/blog/2023/self-attention-from-scratch/multi-head.png)

To illustrate this in code, suppose we have **3 attention heads**, so we now extend the $d' \times d$ dimensional weight matrices to $3 \times d' \times d$:


In [33]:
h = 3

multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))

multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))

multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

In [34]:
multihead_W_query.shape

torch.Size([3, 24, 16])

Consequently, each query element is now $3 \times d_q$ dimensional, where $d_q = 24$ (here, let’s keep the focus on the second element, `x_2`, corresponding to index position 1):

In [35]:
multihead_query_2 = multihead_W_query.matmul(x_2)

In [36]:
multihead_query_2.shape

torch.Size([3, 24])
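One way the remaining multi-head steps might look is sketched below with `torch.einsum`, computing queries, keys, values, and context vectors for all heads and all tokens at once. The inputs are random stand-ins, and concatenating the per-head outputs at the end follows the usual transformer convention, which is an assumption here since the text above does not show that step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)

h, T, d, d_q, d_k, d_v = 3, 6, 16, 24, 24, 28
x = torch.randn(T, d)                      # stand-in for embedded_sentence
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)

# Project every token with every head: (h, d_out, d) and (T, d) -> (h, T, d_out)
queries = torch.einsum("hkd,td->htk", multihead_W_query, x)
keys = torch.einsum("hkd,td->htk", multihead_W_key, x)
values = torch.einsum("hvd,td->htv", multihead_W_value, x)

# Scaled dot-product attention, applied independently per head
omega = queries.matmul(keys.transpose(1, 2))   # (h, T, T)
alpha = F.softmax(omega / d_k**0.5, dim=-1)    # rows sum to 1 within each head
z = alpha.matmul(values)                       # (h, T, d_v)

# Assumed final step: concatenate the heads for each token
z_concat = z.permute(1, 0, 2).reshape(T, h * d_v)

print(queries[:, 1, :].shape)  # torch.Size([3, 24]), the multi-head query of the second token
print(z.shape)                 # torch.Size([3, 6, 28])
print(z_concat.shape)          # torch.Size([6, 84])
```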