# Attention Mechanism



In [1]:
# Record the Python/IPython versions, the versions of the key packages and the
# active conda environment so the notebook's results can be reproduced.
%load_ext watermark
%watermark -v -p numpy,pandas,polars,torch,lightning --conda

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

numpy    : 1.26.4
pandas   : 2.2.1
polars   : 0.20.18
torch    : 2.2.2
lightning: 2.2.1

conda environment: torch_p11



In [2]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Rich console whose named styles ("info", "warning", "error") map to fixed
# hex colours, so messages can be styled by severity.
custom_theme = Theme(
    styles={"info": "#76FF7B", "warning": "#FBDDFE", "error": "#FF0000"},
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas display settings: show up to 1,000 rows/columns and wide cell contents
pd.set_option("display.max_rows", 1_000)
pd.set_option("display.max_columns", 1_000)
pd.set_option("display.max_colwidth", 600)

# Silence every warning for the rest of the session
warnings.filterwarnings(action="ignore")


# Black code formatter (Optional): needs the `lab_black` extension to be
# installed; drop this line if it is not available in the environment.
%load_ext lab_black

# Auto reload imports: mode 2 re-imports modules before each cell runs, so
# edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

### Context Vector

- A context vector is the weighted sum of the input vectors that captures the relevant information from the entire sequence for a given position, i.e. it can be thought of as an enriched embedding vector of the input token at that position.

#### Calculate Context Vector

- Attention Score: 
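  - It is calculated as the dot product of the token's query vector and the key vector of every token in the sequence (including the token itself). In this simplified example there are no trainable weights, so the input embeddings themselves serve as the queries, keys and values.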
  - The scores are normalized using softmax to produce the attention weights.
- Multiply the embedded input tokens with their corresponding attention weights and sum the resulting vectors to get the context vector.
- This is done for each position in the sequence to get the context vector for the entire sequence.

In [4]:
# Toy example: a six-token sentence where every token has a 3-D embedding.
inputs: Tensor = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts (x^3)
        [0.22, 0.58, 0.33],  # with (x^4)
        [0.77, 0.25, 0.10],  # one (x^5)
        [0.05, 0.80, 0.55],  # step (x^6)
    ]
)

# Goal: the context vector of the 2nd token (x^2), which acts as the query.
# 1: Attention scores = dot product of the query with every row of `inputs`
#    (each row plays the role of a key, including the query token itself).
query: Tensor = inputs[1]
attn_scores_index_1: Tensor = torch.empty(len(inputs))

for idx, x_1 in enumerate(inputs):
    # One score per token: how strongly this token matches the query
    attn_scores_index_1[idx] = x_1.dot(query)

print(f"{attn_scores_index_1 = }")

attn_scores_index_1 = tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [5]:
# 2: Softmax turns the raw scores into attention weights (positive, summing to 1)
attn_scores_weights_1: Tensor = attn_scores_index_1.softmax(dim=-1)
print(f"{attn_scores_weights_1 = }")
# Sanity check shown as the cell output: the weights should add up to 1
attn_scores_weights_1.sum(dim=-1)

attn_scores_weights_1 = tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


tensor(1.)

In [6]:
inputs_shape: tuple = tuple(inputs.size())
attn_scores_weights_1_shape: tuple = tuple(attn_scores_weights_1.size())
print(f"{inputs_shape = } AND {attn_scores_weights_1_shape = }")

# 3: The context vector is the attention-weighted sum of the input rows.
# `inputs.T` is (3, 6) and the weights are (6,), so the product is one
# 3-D vector: each embedding dimension is averaged with the same weights.
context_vector_1: Tensor = torch.matmul(inputs.T, attn_scores_weights_1)
context_vector_1

inputs_shape = (6, 3) AND attn_scores_weights_1_shape = (6,)


tensor([0.4419, 0.6515, 0.5683])

#### Calculate The Attention Weights Of The Sequence

In [7]:
# Step 1: Attention scores for every (query, key) pair in a single matmul.
# Entry [i, j] is the dot product of token i with token j.
print(f"{inputs.shape = } AND {inputs.T.shape = }")
attn_scores: Tensor = torch.matmul(inputs, inputs.T)
print(f"\n{attn_scores.shape = }")

attn_scores

inputs.shape = torch.Size([6, 3]) AND inputs.T.shape = torch.Size([3, 6])

attn_scores.shape = torch.Size([6, 6])


tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [8]:
# Step 2: Attention weights: softmax over the last dimension, so every row
# (one row per query token) sums to 1.
attn_weights: Tensor = attn_scores.softmax(dim=-1)
print(f"\n{attn_weights.shape = }")
attn_weights


attn_weights.shape = torch.Size([6, 6])


tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [9]:
# Step 3: Context vectors: row i is the sum of all input rows weighted by
# attn_weights[i], giving one enriched embedding per token.
print(f"{attn_weights.shape = } AND {inputs.shape = }")

context_vector: Tensor = torch.matmul(attn_weights, inputs)
context_vector

attn_weights.shape = torch.Size([6, 6]) AND inputs.shape = torch.Size([6, 3])


tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])