# Attention Mechanism



In [None]:
%load_ext watermark
%watermark -v -p numpy,pandas,polars,torch,lightning --conda

In [None]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

### Context Vector

- The context vector is the weighted sum of the input vectors; it captures the relevant information from the entire sequence for a given position, i.e. it can be thought of as an enriched embedding vector of the input.

#### Calculate Context Vector

- Attention Score: 
  - it's calculated by finding the dot product of the token's query vector and the key vector of the other tokens in the sequence.
  - The scores are normalized using softmax to produce the attention weights.
- Multiply the embedded input tokens with their corresponding attention weights and sum the resulting vectors to get the context vector.
- This is done for each position in the sequence to get the context vector for the entire sequence.

## Without Trainable Parameters (Simplified Version)

In [None]:
seed: int = 5

# Toy input: six tokens, each already embedded as a 3-D vector.
inputs: Tensor = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts (x^3)
        [0.22, 0.58, 0.33],  # with (x^4)
        [0.77, 0.25, 0.10],  # one (x^5)
        [0.05, 0.80, 0.55],  # step (x^6)
    ]
)

# Calculate the context vector for the 2nd token (x^2)
# 1: Cal the attention scores
query: Tensor = inputs[1]

# Dot product of the query with every input row; in this weight-free version
# each input row plays the role of a key. The comprehension keeps the loop
# variable local, so no stray per-row name is left behind in the notebook's
# global namespace.
attn_scores_index_1: Tensor = torch.stack(
    [torch.dot(x_i, query) for x_i in inputs]
)

print(f"{attn_scores_index_1 = }")

In [None]:
# 2: Softmax over the raw scores yields the attention weights
attn_scores_weights_1: Tensor = attn_scores_index_1.softmax(dim=-1)
print(f"{attn_scores_weights_1 = }")
# Sanity check: the normalized weights should add up to 1
attn_scores_weights_1.sum(dim=-1)

In [None]:
# Record the shapes of the two operands used in the product below
attn_scores_weights_1_shape: tuple = tuple(attn_scores_weights_1.size())
inputs_shape: tuple = tuple(inputs.size())
print(f"{attn_scores_weights_1_shape = } AND {inputs_shape = }")

# 3: The context vector is the attention-weighted sum of the input vectors.
# No transpose is required: the weight vector (one entry per token) is
# multiplied directly with the inputs matrix (one row per token).
context_vector_1: Tensor = torch.matmul(attn_scores_weights_1, inputs)
context_vector_1

#### Calculate The Attention Weights Of The Sequence

In [None]:
# Step 1: Attention scores for the whole sequence at once.
# Entry (i, j) is the dot product between input row i and input row j.
print(f"{inputs.shape = } AND {inputs.T.shape = }")
attn_scores: Tensor = torch.matmul(inputs, inputs.T)
print(f"\n{attn_scores.shape = }")

attn_scores

In [None]:
# Step 2: Attention weights, i.e. a softmax applied along the last dimension
# of the score matrix so that every row is normalized independently.
attn_weights = attn_scores.softmax(dim=-1)
print(f"\n{attn_weights.shape = }")
attn_weights

In [None]:
# Step 3: Context vectors, one per row of the weight matrix: each is the
# weighted sum of the input rows using that row's attention weights.
print(f"{attn_weights.shape = } AND {inputs.shape = }")

context_vector: Tensor = torch.matmul(attn_weights, inputs)
context_vector

<br><hr>

## Implement Self-Attention With Trainable Parameters

- AKA **Scaled Dot-Product Attention**
- Add weight matrices that are updated during training.

In [None]:
# Pick the 2nd token of the input; the trainable attention weights are
# first worked out for this single token.
x_1: Tensor = inputs[1]
print(f"{x_1.shape = }")
# Input embedding dimension (size of the token vector's last axis)
d_in: int = x_1.size(-1)
# Dimension of the projected query / key / value vectors
d_out: int = 2

In [None]:
# Fix the RNG so the randomly initialised matrices are reproducible
torch.manual_seed(seed)

# Projection matrices of shape (d_in, d_out). They stand in for trainable
# parameters; requires_grad=False only keeps the printed output uncluttered.
W_query: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

# Project the single token x_1 into query, key, and value space
query_1: Tensor = x_1 @ W_query
key_1: Tensor = x_1 @ W_key
value_1: Tensor = x_1 @ W_value

# Project ALL the input tokens into query, key, and value space
query: Tensor = inputs @ W_query
key: Tensor = inputs @ W_key
value: Tensor = inputs @ W_value

print(f"{query_1.shape = }")
query_1

In [None]:
query

In [None]:
# Show the shapes of the per-token and the full-sequence projections
print(f"{query_1.shape = } | {key_1.shape = } | {value_1.shape = }")
print(f"{query.shape = } | {key.shape = } | {value.shape = }")
print()

# Attention scores = query · key.
# Single token: its query is compared against the keys of ALL tokens.
attn_score_1: Tensor = query_1 @ key.T

# Whole sequence: every query is compared against every key.
attn_scores: Tensor = query @ key.T

print(f"{attn_score_1.shape =} | {attn_score_1 = }")
print()
print(f"{attn_scores.shape = } | {attn_scores = }")

In [None]:
# Scaled attention weights: the scores are divided by sqrt(d_out) before the
# softmax. The scaling improves the training performance and helps to avoid
# small gradients.
attn_weights_1: Tensor = (attn_score_1 / d_out**0.5).softmax(dim=-1)
attn_weights: Tensor = (attn_scores / d_out**0.5).softmax(dim=-1)


print(f"{attn_weights_1.shape =} | {attn_weights_1 = }")
print()
print(f"{attn_weights.shape = } | {attn_weights = }")

In [None]:
inputs.shape

In [None]:
# Shapes of the operands of the two products below
print(f"{attn_weights_1.shape = } | {value.shape = }")
print(f"{attn_weights.shape = } | {value.shape = }")

# Context vector(s): attention-weighted sum of the value vectors,
# first for the single token, then for the whole sequence.
context_vector_1: Tensor = torch.matmul(attn_weights_1, value)
context_vector: Tensor = torch.matmul(attn_weights, value)
print()
print(f"{context_vector_1 = }\n\n")

print(f"{context_vector = }\n\n")

### Query, Key and Value

- **Query** : The query is analogous to a `search` in a `database`. It represents the current item/token the model focuses on.
- **Key** : The key is analogous to the `index` in a `database`. It represents the item/token that the model compares the query to.
- **Value** : The value is analogous to the `value` in a `key-value` pair. It represents the actual content or representation of the item/token.

In [None]:
class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention built on raw ``nn.Parameter`` matrices.

    Parameters
    ----------
    d_in : int
        Size of the last dimension of the input (embedding dimension).
    d_out : int
        Size of the projected query / key / value vectors.
    """

    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()

        self.d_out = d_out

        # Projection matrices, each of shape (d_in, d_out), drawn from N(0, 1)
        self.W_q = nn.Parameter(torch.randn(d_in, d_out))
        self.W_k = nn.Parameter(torch.randn(d_in, d_out))
        self.W_v = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x: Tensor) -> Tensor:
        """Return the context vectors for ``x``.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(..., num_tokens, d_in)``.

        Returns
        -------
        Tensor
            Context vectors of shape ``(..., num_tokens, d_out)``.
        """
        # Use the module's OWN parameters (not notebook-level globals) so the
        # weights registered in __init__ are the ones applied and trained.
        queries: Tensor = torch.matmul(x, self.W_q)
        keys: Tensor = torch.matmul(x, self.W_k)
        values: Tensor = torch.matmul(x, self.W_v)

        # Swap only the last two axes so optional leading batch dims survive
        attn_scores: Tensor = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key axis
        attn_weights: Tensor = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)
        context_vector: Tensor = torch.matmul(attn_weights, values)
        return context_vector

In [None]:
# Re-seed so the module's randomly initialised weights are reproducible
torch.manual_seed(seed)
self_attn_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(f"{self_attn_v1 = }")
# Calling the module runs its forward pass on the whole input sequence
print(self_attn_v1(inputs))

### Update

- Improve the `SelfAttention_v1` implementation using PyTorch's `nn.Linear` layers instead of `nn.Parameter` layers.

- This is because:
  - `nn.Linear` performs an efficient matrix multiplication when the bias units are disabled.
  - `nn.Linear` has an optimized weight initialization scheme.

In [None]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention built on ``nn.Linear`` projections.

    Parameters
    ----------
    d_in : int
        Size of the last dimension of the input (embedding dimension).
    d_out : int
        Size of the projected query / key / value vectors.
    qkv_bias : bool, default=False
        Whether the query / key / value projections carry a bias term.
        With the default, each projection is a plain matrix multiplication.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False) -> None:
        super().__init__()

        self.d_out = d_out

        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return the context vectors for ``x``.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(..., num_tokens, d_in)``.

        Returns
        -------
        Tensor
            Context vectors of shape ``(..., num_tokens, d_out)``.
        """
        # Apply the module's OWN linear layers (not notebook-level globals) so
        # the weights registered in __init__ are the ones applied and trained.
        queries: Tensor = self.W_q(x)
        keys: Tensor = self.W_k(x)
        values: Tensor = self.W_v(x)

        # Swap only the last two axes so optional leading batch dims survive
        attn_scores: Tensor = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key axis
        attn_weights: Tensor = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)
        context_vector: Tensor = torch.matmul(attn_weights, values)

        return context_vector

In [None]:
# Re-seed so the module's randomly initialised weights are reproducible
torch.manual_seed(seed)
# NOTE(review): despite its name, `self_attn_v1` is rebound here to a
# SelfAttention_v2 instance — consider renaming once later cells are checked.
self_attn_v1 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(f"{self_attn_v1 = }")
# Calling the module runs its forward pass on the whole input sequence
print(self_attn_v1(inputs))