# **Attention Mechanisms**

In [1]:
import torch

## The Need for Attention - In a Nutshell

At the time of writing (mid-2025), the seminal paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) is nearly 8 years old. Its abstract states:

>The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. 
>
>Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data. 

The paper eschewed the encoder-decoder RNNs commonly used in translation in favour of an architecture that can access every token of a sequence directly and weigh each token's relevance relative to the others. This shift helped usher in the age of LLMs as we know them.

Transformers built on attention mechanisms addressed several limitations of RNNs, namely:

1. **Short-Range Dependency**, i.e. difficulty capturing connections between distant words in a sequence.
2. **Limited Parallelism**, i.e. slow processing, because each step of a recurrent network depends on the output of the previous step.
3. **Focus on Local Context**, i.e. weighting immediate neighbours most heavily and potentially missing critical information from other parts of the input sequence.
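To make the parallelism point concrete, the sketch below contrasts a recurrent loop with a single attention-style matrix multiply. The RNN is a hypothetical minimal Elman-style cell with random, untrained weights (`W_xh`, `W_hh` are illustrative names, not from any library); only the data dependencies and shapes matter here.

```python
import torch

torch.manual_seed(0)

# Toy sequence: 6 tokens, each a random 3-dimensional embedding (illustration only)
x = torch.rand(6, 3)

# --- RNN-style processing (hypothetical minimal Elman cell, random weights) ---
# Each hidden state depends on the previous one, so this loop over
# time steps cannot be parallelized across the sequence.
W_xh = torch.rand(3, 4)
W_hh = torch.rand(4, 4)
h = torch.zeros(4)
hidden_states = []
for t in range(x.shape[0]):
    h = torch.tanh(x[t] @ W_xh + h @ W_hh)
    hidden_states.append(h)
hidden_states = torch.stack(hidden_states)  # shape (6, 4)

# --- Attention-style processing ---
# All pairwise token interactions come from one matrix multiply, so the
# first token interacts with the last one directly, with no sequential chain.
scores = x @ x.T  # shape (6, 6)

print(hidden_states.shape)  # torch.Size([6, 4])
print(scores.shape)         # torch.Size([6, 6])
```

In the RNN, information from token 1 only reaches token 6 after passing through five sequential updates; in the score matrix, entry `[0, 5]` relates them in one step.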

## 1. Self Attention

The primary function of self attention is to generate "context-aware" vectors from the input sequence.

![Self-attention](./images/self-attention.webp)

_Figure from Deep Learning with Python by François Chollet_


> The "self" refers to the mechanism's ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image. - S. Raschka

### 1.1 Simple Self-Attention Without Trainable Weights

The goal is to compute a context vector for each input element that combines information from all other input elements.
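The cells that follow implement three steps, which can be summarized as below. The notation is an assumption chosen to match the `x^1 ... x^6` comments in the code: $x^{(i)}$ is the $i$-th embedded token, $T$ the sequence length, $\omega_{ij}$ the unnormalized attention score, $\alpha_{ij}$ the attention weight, and $z^{(i)}$ the context vector for token $i$.

$$
\omega_{ij} = x^{(i)} \cdot x^{(j)}, \qquad
\alpha_{ij} = \frac{\exp(\omega_{ij})}{\sum_{k=1}^{T} \exp(\omega_{ik})}, \qquad
z^{(i)} = \sum_{j=1}^{T} \alpha_{ij}\, x^{(j)}
$$

Section 1.1 computes these for a single query ($i = 2$, "journey"); Section 1.2 computes them for every $i$ at once.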

In [2]:
# Input tensor representing a sequence whose tokens have already been embedded into 3-dimensional vectors
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your    (x^1)
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts  (x^3)
     [0.22, 0.58, 0.33], # with    (x^4)
     [0.77, 0.25, 0.10], # one     (x^5)
     [0.05, 0.80, 0.55]] # step    (x^6)
)

In [None]:
# Calculate intermediate attention scores by taking the dot product of the query ("journey") with every input token, including itself
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    # dot product without transpose since these are 1 dim vectors
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
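The loop above can be replaced by a single matrix-vector product, since multiplying the `(6, 3)` input matrix by the `(3,)` query takes the dot product of every row with the query. This sketch redefines `inputs` so it runs on its own and checks that both versions agree.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your    (x^1)
     [0.55, 0.87, 0.66],  # journey (x^2)
     [0.57, 0.85, 0.64],  # starts  (x^3)
     [0.22, 0.58, 0.33],  # with    (x^4)
     [0.77, 0.25, 0.10],  # one     (x^5)
     [0.05, 0.80, 0.55]]  # step    (x^6)
)
query = inputs[1]

# Loop version, as in the cell above
scores_loop = torch.stack([torch.dot(x_i, query) for x_i in inputs])

# Vectorized version: (6, 3) @ (3,) -> (6,), one dot product per row
scores_vec = inputs @ query

print(torch.allclose(scores_loop, scores_vec))  # True
```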


In [4]:
# Illustration of dot products at work
res = 0.

for idx, element in enumerate(inputs[0]):
    res += element * query[idx]

print(res)
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


> Beyond viewing the dot product operations as a mathematical tool...(it) is a measure of similarity because it quantifies how closely two vectors are aligned... In the context of self attention mechanisms, the dot product determines the extent to which each element in a sequence focuses on, or 'attends to' any other element.
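The "alignment" reading of the dot product is easiest to see by comparing it with cosine similarity, which divides out the vector lengths. Note that the simple self-attention in this notebook uses the raw dot product; cosine similarity (via `torch.nn.functional.cosine_similarity`) appears here only to illustrate the interpretation.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([0.43, 0.15, 0.89])  # "Your"
b = torch.tensor([0.55, 0.87, 0.66])  # "journey"

# Raw dot product: grows with both alignment and vector length
dot = torch.dot(a, b)

# Dividing by the norms leaves only the alignment (cosine of the angle)
cos_manual = dot / (a.norm() * b.norm())
cos_builtin = F.cosine_similarity(a, b, dim=0)

# Scaling a vector doubles the dot product but leaves the cosine unchanged
dot_scaled = torch.dot(2 * a, b)
cos_scaled = F.cosine_similarity(2 * a, b, dim=0)

print(dot, dot_scaled)
print(cos_manual, cos_builtin, cos_scaled)
```

So the dot product mixes "how aligned" with "how large"; this is one reason the trainable version of attention later rescales the scores.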

In [5]:
# Simple normalization of the attention scores so that they sum to 1
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)
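Dividing by the sum works here only because every score happens to be positive. The sketch below uses hypothetical scores that include a negative value: the simple normalization still sums to 1, but it produces a negative weight and a weight above 1, whereas softmax keeps every weight in (0, 1).

```python
import torch

# Hypothetical scores, one of them negative (dot products can be negative)
scores = torch.tensor([2.0, -1.0, 0.5])

simple = scores / scores.sum()            # divides by 1.5
softmaxed = torch.softmax(scores, dim=0)

# simple: sums to 1, but contains a negative weight and one larger than 1
print(simple, simple.sum())
# softmaxed: sums to 1 and every weight is strictly positive
print(softmaxed, softmaxed.sum())
```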


In [6]:
# In practice softmax is preferred: it handles extreme values more gracefully and always yields positive weights
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights: ", attn_weights_2_naive)
print("Sum: ", attn_weights_2_naive.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In [7]:
# The naive version is numerically unstable (overflow/underflow for large or small inputs), so prefer PyTorch's extensively optimized implementation
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights: ", attn_weights_2)
print("Sum: ", attn_weights_2.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)
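The instability is easy to trigger: in float32, `exp(1000)` overflows to `inf`, and `inf / inf` is `nan`. A standard fix, sketched below as `softmax_stable` (a hypothetical helper, not a claim about how PyTorch's kernels are written internally), subtracts the maximum before exponentiating. This leaves the result mathematically unchanged but caps the largest exponent at `exp(0) = 1`.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)) for any constant c;
    # choosing c = max(x) keeps every exponent <= 0, so nothing overflows.
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large = torch.tensor([1000.0, 1001.0, 1002.0])

print(softmax_naive(large))        # tensor([nan, nan, nan])
print(softmax_stable(large))       # finite, matches torch.softmax
print(torch.softmax(large, dim=0))
```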


In [8]:
# The final step creates the context vector:
# multiply each embedded input token by its attention weight,
# then sum the results.
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
    
print(context_vec_2)


tensor([0.4419, 0.6515, 0.5683])
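As with the scores, the weighted sum is a single vector-matrix product: the `(6,)` weight vector times the `(6, 3)` input matrix sums the rows of `inputs`, each scaled by its weight. The sketch recomputes the weights so it is self-contained.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your    (x^1)
     [0.55, 0.87, 0.66],  # journey (x^2)
     [0.57, 0.85, 0.64],  # starts  (x^3)
     [0.22, 0.58, 0.33],  # with    (x^4)
     [0.77, 0.25, 0.10],  # one     (x^5)
     [0.05, 0.80, 0.55]]  # step    (x^6)
)
query = inputs[1]
attn_weights_2 = torch.softmax(inputs @ query, dim=0)

# Loop version, as in the cell above
context_loop = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_loop += attn_weights_2[i] * x_i

# Vectorized: (6,) @ (6, 3) -> (3,), a weighted sum of the input rows
context_vec = attn_weights_2 @ inputs

print(context_vec)
```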


### 1.2 Computing Attention Weights for all Input Tokens

Let's extend the computation to calculate attention weights and context vectors for all inputs.

In [21]:
# Repeat the dot-product step for every pair of inputs to compute the unnormalized attention score matrix
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs): # additional loop to compute the dot products for all pairs of inputs
        attn_scores[i, j] = torch.dot(x_i, x_j)
        
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [20]:
# Replace the for loops with matrix multiplication
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
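The two cells above print the same matrix; the sketch below checks that programmatically. It also shows a property visible in the output: because the dot product is commutative, the score matrix is symmetric, while the row-wise softmax in the next step generally is not (each row is normalized by a different sum).

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your    (x^1)
     [0.55, 0.87, 0.66],  # journey (x^2)
     [0.57, 0.85, 0.64],  # starts  (x^3)
     [0.22, 0.58, 0.33],  # with    (x^4)
     [0.77, 0.25, 0.10],  # one     (x^5)
     [0.05, 0.80, 0.55]]  # step    (x^6)
)

# Loop version
scores_loop = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        scores_loop[i, j] = torch.dot(x_i, x_j)

# Matrix-multiplication version
scores_matmul = inputs @ inputs.T

print(torch.allclose(scores_loop, scores_matmul))    # True
# x_i . x_j == x_j . x_i, so the score matrix is symmetric
print(torch.allclose(scores_matmul, scores_matmul.T))  # True
# Row-wise softmax normalizes each row by its own sum, breaking the symmetry
weights = torch.softmax(scores_matmul, dim=-1)
print(torch.allclose(weights, weights.T))            # False
```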


In [22]:
# Normalize each row so that its values sum to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [30]:
# Quick check
row_3_sum = attn_weights[2].sum(dim=0)
all_row_sum = attn_weights.sum(dim=-1)

print(f"Row 3 sum: {row_3_sum}")
print(f"All row sum: {all_row_sum}")

Row 3 sum: 1.0000001192092896
All row sum: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
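The row-3 sum of `1.0000001192092896` is float32 rounding, not a bug, so an exact `== 1` comparison can fail. A tolerance-based check such as `torch.allclose` is the more robust way to verify the normalization; a minimal self-contained sketch:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your    (x^1)
     [0.55, 0.87, 0.66],  # journey (x^2)
     [0.57, 0.85, 0.64],  # starts  (x^3)
     [0.22, 0.58, 0.33],  # with    (x^4)
     [0.77, 0.25, 0.10],  # one     (x^5)
     [0.05, 0.80, 0.55]]  # step    (x^6)
)
attn_weights = torch.softmax(inputs @ inputs.T, dim=-1)
row_sums = attn_weights.sum(dim=-1)

# Compare within a small tolerance instead of testing exact equality
print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # True
```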


In [31]:
# Using the attention weights to compute all context vectors via matmul
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
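The three steps of this section (scores, softmax, weighted sum) can be collected into one function. The sketch below names it `simple_self_attention` (a hypothetical helper, not from any library) and uses `transpose(-2, -1)` instead of `.T` so that it also works on a batch of sequences shaped `(batch, seq_len, d)`; batching is an assumption added here, not something the cells above require.

```python
import torch

def simple_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Parameter-free self-attention (hypothetical helper).

    x: (..., seq_len, d) -> context vectors of shape (..., seq_len, d)
    """
    scores = x @ x.transpose(-2, -1)         # (..., seq_len, seq_len) dot products
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ x                       # weighted sums of the input rows

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your    (x^1)
     [0.55, 0.87, 0.66],  # journey (x^2)
     [0.57, 0.85, 0.64],  # starts  (x^3)
     [0.22, 0.58, 0.33],  # with    (x^4)
     [0.77, 0.25, 0.10],  # one     (x^5)
     [0.05, 0.80, 0.55]]  # step    (x^6)
)

context = simple_self_attention(inputs)   # (6, 3)

# Two copies of the sequence stacked into a batch: (2, 6, 3)
batch = torch.stack([inputs, inputs])
batched_context = simple_self_attention(batch)

print(context)
print(batched_context.shape)  # torch.Size([2, 6, 3])
```

Because this version has no trainable parameters, the context vectors are fixed by the embeddings alone; the next section introduces weight matrices so the model can learn what to attend to.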


## 3. Implementing Self Attention With Trainable Weights