# Self-Attention Tutorial

---
## Table of Contents

* [Self-Attention Overview](#self_attention_overview)
    * [Process](#process)
    * [Pseudocode](#pseudocode)
* [Self-Attention Pipeline](#self_attention_pipeline)
    * [Imports](#imports)
    * [Mini-Example](#mini)
    * [Random Image](#image)
    * [Image with Patches](#patches)
    * [Model](#model)
    * [Helper Functions](#functions)
* [Attention Run](#runs)

# Self-Attention Overview<a class="anchor" id="self_attention_overview"></a>

Self-attention has become one of the central building blocks of deep learning and has proven effective in many domains, including image processing. In computer vision it underpins Vision Transformers (ViTs), which use self-attention either as a complete replacement for convolutional layers or alongside them. The trade-off is cost. Every token (an image patch or a feature-map position) is compared with every other token, so the work needed to compute attention grows quadratically with the number of tokens. Self-attention models also often need more parameters and compute than a comparable Convolutional Neural Network (CNN).
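A little arithmetic makes the quadratic cost concrete. The sketch below counts tokens and attention scores for one image. It assumes non-overlapping square patches and one extra [CLS] token as in ViT/DINO. `attention_matrix_size` is an illustrative helper, not part of any library. The 375x500 image with 16x16 patches matches the patch examples later in this notebook.

```python
def attention_matrix_size(height, width, patch_size, num_heads=1, cls_token=True):
    """Return (num_tokens, num_scores) for one image (illustrative helper)."""
    # Assumes non-overlapping square patches; leftover border pixels are dropped
    num_patches = (height // patch_size) * (width // patch_size)
    num_tokens = num_patches + (1 if cls_token else 0)  # optional [CLS] token, as in ViT/DINO
    # One score per (query, key) pair, per head
    return num_tokens, num_heads * num_tokens ** 2

# A 375x500 image, 16x16 patches, 6 heads
tokens, scores = attention_matrix_size(375, 500, patch_size=16, num_heads=6)
print(tokens, scores)  # 714 3058776

# Halving the patch size roughly quadruples the tokens and multiplies the scores by about 16
tokens_8, scores_8 = attention_matrix_size(375, 500, patch_size=8, num_heads=6)
print(tokens_8, scores_8)
```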

## Process <a class="anchor" id="process"></a>

1. **Feature Extraction: The Foundation**

In hybrid architectures, the image is first passed through a convolutional neural network (CNN), often a pre-trained one, which acts as a feature extractor. Its early layers capture low-level details such as edges, textures, and color variations. The resulting feature maps are the input on which self-attention operates. Pure Vision Transformers skip the CNN. They instead split the image into fixed-size patches and embed each patch linearly, as in the patch examples later in this notebook.
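Self-attention operates on a sequence of vectors. A feature map of shape (channels, height, width) is therefore typically flattened into height × width tokens, each a channels-dimensional vector. In this minimal NumPy sketch, a random array stands in for real CNN output, and the 256×7×7 shape is an arbitrary assumption:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a CNN feature map: 256 channels on a 7x7 spatial grid (assumed shape)
feature_map = rng.standard_normal((256, 7, 7))

channels, height, width = feature_map.shape
# Move channels last, then merge the spatial axes: one token per spatial position
tokens = feature_map.transpose(1, 2, 0).reshape(height * width, channels)
print(tokens.shape)  # (49, 256)

# Token r * width + c holds the feature vector at grid position (r, c)
assert np.array_equal(tokens[1 * width + 3], feature_map[:, 1, 3])
```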

2. **Unveiling the Power of Three: Queries, Keys, and Values**

Each element (spatial position) of the feature maps is then mapped to three new vectors by three learned linear projections, usually written W_Q, W_K, and W_V. Each vector plays a distinct role in the self-attention mechanism:

Query (Q): What this element is looking for. Think of it as the question the element asks about the rest of the image.

Key (K): What this element offers to be matched against. The queries of other elements are compared with this key to decide how relevant this element is to them.

Value (V): The actual content this element contributes. Once the query-key comparisons are done, the values are what gets mixed into the output.
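Here is a minimal NumPy sketch of the three projections. The sizes are illustrative, and the random matrices `W_q`, `W_k`, `W_v` stand in for weights that a real model learns during training:

```python
import numpy as np

rng = np.random.default_rng(0)

num_tokens, d_model, d_k, d_v = 49, 256, 64, 64       # assumed sizes
tokens = rng.standard_normal((num_tokens, d_model))  # e.g. flattened feature-map tokens

# Learned during training in a real model; random here for illustration
W_q = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_k = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_v = rng.standard_normal((d_model, d_v)) / np.sqrt(d_model)

# Every token is projected by the same three matrices
Q = tokens @ W_q  # what each token is looking for
K = tokens @ W_k  # what each token can be matched on
V = tokens @ W_v  # the content each token passes on
print(Q.shape, K.shape, V.shape)  # (49, 64) (49, 64) (49, 64)
```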

3. **The Heart of the Mechanism: Attention Scores**

The core of self-attention is a score for every pair of elements in the feature maps. The score for query i and key j measures how relevant element j is to element i. It is usually computed as the dot product of the two vectors: a larger dot product means the query and key are better aligned, so element j's value should contribute more to element i's output. Scaled dot-product attention, as used in Transformers and in the code below, divides the dot product by sqrt(d_k), where d_k is the key dimension. This keeps the scores from growing with the vector size.
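The score matrix takes only a few lines of NumPy to sketch. The sizes are toys, and random vectors stand in for real queries and keys. Each entry `scores[i, j]` is the scaled dot product of query `i` with key `j`:

```python
import numpy as np

rng = np.random.default_rng(1)
num_tokens, d_k = 4, 8  # tiny, assumed sizes
Q = rng.standard_normal((num_tokens, d_k))
K = rng.standard_normal((num_tokens, d_k))

# scores[i, j] = (q_i . k_j) / sqrt(d_k): one score per (query, key) pair
scores = Q @ K.T / np.sqrt(d_k)
print(scores.shape)  # (4, 4)

# The same entry computed explicitly for query 2 and key 0
manual = np.dot(Q[2], K[0]) / np.sqrt(d_k)
assert np.isclose(scores[2, 0], manual)
```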

4. **A Dance of Weights: Assigning Importance**

A softmax over each query's row of scores then turns the raw scores into weights. Every weight is positive, the weights for one query sum to 1, and higher scores map to higher weights. The weights therefore say how much attention each element pays to every other element, which lets the model focus on the features most relevant to the task.
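This NumPy sketch shows the softmax step on hand-picked toy scores. Subtracting each row's maximum before exponentiating does not change the result, but it avoids overflow. The second row shows that equal scores give equal weights:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - np.max(x, axis=axis, keepdims=True)  # avoids overflow in exp
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

scores = np.array([[2.0, 1.0, 0.1],
                   [0.0, 0.0, 0.0]])  # toy scores: two queries, three keys
weights = softmax(scores, axis=-1)
print(weights.round(3))
```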

5. **The Weighted Sum: Aggregating Knowledge**

Each element's output is the weighted sum of all value vectors, using that element's attention weights. The result is a new representation of the element. It emphasizes information from the most relevant positions and suppresses the rest.
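Toy numbers make the weighted sum easy to follow. In this NumPy sketch, a one-hot weight row copies a single value vector, while spread-out weights blend several:

```python
import numpy as np

V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])  # one value vector per token (toy numbers)

weights = np.array([[1.0, 0.0, 0.0],     # attends only to token 0
                    [0.5, 0.5, 0.0],     # splits attention between tokens 0 and 1
                    [0.25, 0.25, 0.5]])  # weighted toward token 2

# Each output row is a weighted average of the rows of V
output = weights @ V
print(output)
```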

6. **A Meticulous Journey: Repetition and Refinement**

Conceptually, steps 2-5 are carried out for every element of the feature maps, so each element gets its own refined representation built from the whole image. In practice this is not a sequential, piece-by-piece pass. Stacking all queries, keys, and values into matrices lets every element be processed at once with a few matrix multiplications, so every element attends to every other element in parallel.
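You can check directly that the conceptual per-element loop and the batched computation agree. This NumPy sketch uses random toy inputs (the sizes are assumptions) and shows that both produce the same output:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(2)
n, d = 5, 4  # assumed sizes
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

# Conceptual view: handle one element (query) at a time
looped = np.zeros((n, d))
for i in range(n):
    w = softmax(Q[i] @ K.T / np.sqrt(d))  # weights of element i over all keys
    looped[i] = w @ V

# Practical view: all elements at once with matrix products
vectorized = softmax(Q @ K.T / np.sqrt(d), axis=-1) @ V

print(np.allclose(looped, vectorized))  # True
```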

**Beyond the Basics: Multi-Head Attention**

The explanation above covers single-head self-attention. Multi-head attention runs several self-attention operations in parallel, each with its own query, key, and value projections (typically of size d_model / num_heads). Each head can attend to different relationships within the image. The heads' outputs are concatenated and passed through a final linear projection. In practice this tends to give richer representations than a single head of the same total size.
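Here is a compact NumPy sketch of multi-head self-attention. It assumes random matrices in place of learned weights, plus an output projection `W_o` (as in the original Transformer) that mixes the concatenated heads. The function name and sizes are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads):
    """Multi-head self-attention for one sequence x of shape (n, d_model)."""
    n, d_model = x.shape
    head_dim = d_model // num_heads

    def split(t):
        # (n, d_model) -> (num_heads, n, head_dim): each head gets its own slice
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    # Each head computes its own (n, n) attention matrix
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim), axis=-1)
    heads = weights @ V  # (num_heads, n, head_dim)
    # Concatenate heads back to (n, d_model), then mix them with the output projection
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)
    return concat @ W_o, weights

rng = np.random.default_rng(3)
n, d_model, num_heads = 6, 12, 3  # assumed sizes; d_model must be divisible by num_heads
x = rng.standard_normal((n, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

out, weights = multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads)
print(out.shape, weights.shape)  # (6, 12) (3, 6, 6)
```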

Self-attention lets a model relate every part of an image to every other part, rather than only to a local neighbourhood as a convolution does. This ability to model long-range relationships has contributed to advances in tasks such as object recognition, image segmentation, and image captioning.

## Pseudocode <a class="anchor" id="pseudocode"></a>

This pseudocode provides a high-level overview of the logic behind self-attention, without focusing on specific programming language syntax. Remember, this is a simplified representation, and the actual implementation can vary depending on the chosen framework and specific application.

```
Input: data - a 2D array of shape (num_elements, feature_dim), one row per feature-map element
       W_q, W_k, W_v - learned projection matrices of shape (feature_dim, d_k), (feature_dim, d_k), (feature_dim, d_v)
Output: output - a 2D array of shape (num_elements, d_v) holding the attention-weighted output

# Function: self_attention(data, W_q, W_k, W_v)
function self_attention(data, W_q, W_k, W_v):
  # Get dimensions of the data
  num_elements, feature_dim = data.shape
  d_k = number of columns of W_k
  d_v = number of columns of W_v

  # Project every element into a query, key, and value vector
  # (the projection matrices are learned during training; Q, K, V depend on the input)
  query = dot_product(data, W_q)    # (num_elements, d_k)
  key = dot_product(data, W_k)      # (num_elements, d_k)
  value = dot_product(data, W_v)    # (num_elements, d_v)

  # Scaled dot-product scores between every query and every key
  attention_scores = dot_product(query, transpose(key)) / sqrt(d_k)

  # Softmax over each row, so the weights for one query sum to 1
  attention_weights = softmax(attention_scores, along each row)

  # Initialize empty array for output
  output = new array(num_elements, d_v)

  # Weighted sum of value vectors for every element
  for i in range(num_elements):
    for j in range(d_v):
      output[i, j] = 0
      for k in range(num_elements):
        output[i, j] += attention_weights[i, k] * value[k, j]

  # Return the attention-weighted output
  return output
```
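Below is a runnable NumPy sketch of the same idea. It assumes scaled dot-product scores, toy sizes, and random projection matrices in place of learned weights. The triple loop collapses into a single matrix product:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(data, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over the rows of `data`."""
    query = data @ W_q                               # (num_elements, d_k)
    key = data @ W_k                                 # (num_elements, d_k)
    value = data @ W_v                               # (num_elements, d_v)
    d_k = key.shape[-1]
    attention_scores = query @ key.T / np.sqrt(d_k)  # (num_elements, num_elements)
    attention_weights = softmax(attention_scores, axis=-1)
    # Replaces the triple loop: output[i, j] = sum_k weights[i, k] * value[k, j]
    return attention_weights @ value

rng = np.random.default_rng(0)
data = rng.standard_normal((4, 6))  # 4 elements, 6 features (toy sizes)
W_q, W_k, W_v = (rng.standard_normal((6, 3)) for _ in range(3))  # stand-ins for learned weights
output = self_attention(data, W_q, W_k, W_v)
print(output.shape)  # (4, 3)
```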

# Self-Attention Pipeline <a class="anchor" id="self_attention_pipeline"></a>

## Imports <a class="anchor" id="imports"></a>

In [None]:
from argparse import Namespace
from dataclasses import dataclass, field
import math
import os
from pathlib import Path
import re
from typing import Dict, Union

from IPython.display import display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms

from notebooks.transformers import vision_transformer as vits

print(f'Torch version is: {torch.__version__}')
print(f'Is CUDA available: {torch.cuda.is_available()}')

## Mini-Example <a class="anchor" id="mini"></a>

In [None]:
def softmax(x, axis=-1):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Example concatenated embedding
concatenated_embedding = np.array([[1, 2, 3, 7, 8],
                                   [4, 5, 6, 9, 10]])
print(f"shape: {concatenated_embedding.shape}")
print(f"shape -1: {concatenated_embedding.shape[-1]}")

# Per-head size, assuming 2 attention heads (only one head is computed in this example)
num_heads = 2
d_model = concatenated_embedding.shape[-1]  # Dimension of the concatenated embedding
print(f"d_model: {d_model}")
d_k = d_v = d_model // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

# Generate Query, Key, and Value matrices
Q = concatenated_embedding.dot(np.random.rand(d_model, d_k))  # Query matrix
K = concatenated_embedding.dot(np.random.rand(d_model, d_k))  # Key matrix
V = concatenated_embedding.dot(np.random.rand(d_model, d_v))  # Value matrix
print(f"Q: {Q}\nK: {K}\nV: {V}")
print("-----")

# Calculate attention scores
attention_scores = np.matmul(Q, K.T) / np.sqrt(d_k)
print(f"scores: {attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = softmax(attention_scores, axis=-1)
print(f"weights: {attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = np.matmul(attention_weights, V)
print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(output)

## 64x64 Random Images and Positional Encoding <a class="anchor" id="image"></a>

In [None]:
def positional_encoding(embedding):
    """Add sinusoidal positional encoding along the sequence axis (axis 1) of the embeddings."""
    print(f"Embedding Dimension: {embedding.shape}")
    seq_len = embedding.shape[1]
    channel_dim = embedding.shape[-1] # channels
    print("Seq len:", seq_len)
    print("Channel dim:", channel_dim)
    pos_enc = np.zeros((seq_len, channel_dim))
    for pos in range(seq_len):
        for i in range(0, channel_dim, 2):
            # i is already the even index "2i" of the formula; use true (float) division
            # so each sin/cos channel pair gets its own frequency: pos / 10000^(2i / d)
            angle = pos / (10000 ** (i / channel_dim))
            pos_enc[pos, i] = math.sin(angle)
            if i + 1 < channel_dim:  # Ensure not to exceed the last dimension
                pos_enc[pos, i + 1] = math.cos(angle)
    # The (seq_len, channel_dim) encoding broadcasts across the batch axis
    return embedding + pos_enc

image_size = 64 # Larger sizes quickly exhaust memory: dim grows with image_size**2 and the Q/K/V matrices below are (dim x d_k)
num_images = 3

# Parameters
num_channels = 3  # 3 channels for RGB images
num_heads = 2
dim = image_size * image_size * num_channels  # Dimension of the input (total number of features)
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

images = np.random.rand(num_images, image_size, image_size, num_channels)
print(f"IMAGES SHAPE: {images.shape}") # (num_images, height, width, channels)
print("-----")

# Reshape images to have a sequence length of image_size * image_size
images_flattened = images.reshape(num_images, -1, num_channels)
print(f"FLATTENED SHAPE: {images_flattened.shape}")
print(f"FLATTENED: {images_flattened}")
print("-----")

# Add positional encoding
print(f"POSITIONAL ENCODING OUTPUT (USING FLATTENED IMAGE)")
images_with_pos = positional_encoding(images_flattened)
print(f"Input images after positional encoding: (shape={images_with_pos.shape})")
print(images_with_pos)
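The `positional_encoding` function above is meant to implement the sinusoidal encoding from the original Transformer. For position `pos`, even channel index `2i`, and channel dimension `d`, it sets PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). The vectorized NumPy sketch below implements that formula without Python loops, which helps for long sequences such as the 4096 pixels per image here. The name `sinusoidal_encoding` is illustrative.

```python
import numpy as np

def sinusoidal_encoding(seq_len, dim):
    """Sinusoidal positional encoding of shape (seq_len, dim)."""
    positions = np.arange(seq_len)[:, np.newaxis]     # (seq_len, 1)
    even_idx = np.arange(0, dim, 2)                   # the "2i" channel indices
    angles = positions / (10000 ** (even_idx / dim))  # (seq_len, ceil(dim / 2))
    enc = np.zeros((seq_len, dim))
    enc[:, 0::2] = np.sin(angles)                     # even channels
    enc[:, 1::2] = np.cos(angles[:, : dim // 2])      # odd channels (one fewer if dim is odd)
    return enc

enc = sinusoidal_encoding(seq_len=4096, dim=3)
print(enc.shape)  # (4096, 3)
print(enc[0])     # position 0: sin(0) = 0 on even channels, cos(0) = 1 on odd channels
```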

#### With Basic np.matmul

In [None]:
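# Reshape images_with_pos to align with the matrix multiplication.
# Note: this flattens each whole image into a single token (sequence length 1),
# so every attention weight below is trivially 1.0; the following cells use
# per-pixel and per-patch tokens instead.
images_with_pos_reshaped = images_with_pos.reshape(num_images, -1, dim)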

# Generate Query, Key, and Value matrices
Q = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_k))  # Query matrix
K = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_k))  # Key matrix
V = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_v))  # Value matrix
print(f"Q: {Q}\nK: {K}\nV: {V}")
print("-----")

# Calculate attention scores (scaled dot product)
scale = d_k ** -0.5

attention_scores = np.matmul(Q, K.transpose((0, 2, 1))) * scale
print(f"scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = softmax(attention_scores, axis=-1)
print(f"weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = np.matmul(attention_weights, V)
print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(f"OUTPUT:\n{output}")

#### With Linear Projection

In [None]:
# Add positional encoding
images_with_pos = positional_encoding(images_flattened)
print("Input images after positional encoding:")
print(images_with_pos)
print(f"IMAGES WITH POS SHAPE: {images_with_pos.shape}")

# Linear transformation to generate Q, K, and V matrices
linear_qkv = nn.Linear(num_channels, d_k * 3, bias=False) # Linear(in_features=3, out_features=18432, bias=False)
qkv = linear_qkv(torch.tensor(images_with_pos, dtype=torch.float32))
print(f"QKV AFTER LINEAR: {qkv.shape}")
qkv = qkv.reshape(num_images, -1, 3, num_heads, d_k // num_heads).permute(2, 0, 3, 1, 4)
print(f"QKV AFTER RESHAPE: {qkv.shape}")

Q, K, V = qkv[0], qkv[1], qkv[2]

print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")
# Calculate attention scores
scale = (d_k // num_heads) ** -0.5  # per-head dimension, matching the reshape above
attention_scores = (torch.matmul(Q, K.transpose(-2, -1))) * scale
print(f"scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")
# Apply softmax to obtain attention weights
attention_weights = softmax(attention_scores.detach().numpy(), axis=-1)
print(f"weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")
# Compute the weighted sum using attention weights
output = np.matmul(attention_weights, V.detach().numpy())

print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(output)

## Image with patches example <a class="anchor" id="patches"></a>

In [None]:
def positional_encoding(embedding):
    """Add sinusoidal positional encoding along the sequence axis (axis 1) of the embeddings."""
    seq_len = embedding.shape[1]
    channel_dim = embedding.shape[-1]
    pos_enc = np.zeros((seq_len, channel_dim))
    for pos in range(seq_len):
        for i in range(0, channel_dim, 2):
            # i is the even index "2i" of the formula: angle = pos / 10000^(2i / d)
            angle = pos / (10000 ** (i / channel_dim))
            pos_enc[pos, i] = math.sin(angle)
            if i + 1 < channel_dim:
                pos_enc[pos, i + 1] = math.cos(angle)
    return embedding + pos_enc

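def extract_patches(images, patch_size):
    """Split each (H, W, C) image into flattened, non-overlapping patch_size x patch_size patches.

    Pixels at the bottom/right edges that do not fill a whole patch are dropped. Patches from
    all images are stacked into one array of shape (num_patches, patch_size * patch_size * C).
    """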
    patches = []
    for image in images:
        for i in range(0, image.shape[0] - patch_size + 1, patch_size):
            for j in range(0, image.shape[1] - patch_size + 1, patch_size):
                patch = image[i:i+patch_size, j:j+patch_size, :]
                patches.append(patch.flatten())
    return np.array(patches)
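`extract_patches` walks each image with nested Python loops. The reshape-based sketch below is equivalent: it produces the same row-major patch order and drops leftover border pixels the same way. The name `patchify` is illustrative. For a 375x500 RGB image with 16x16 patches, it gives 23 x 31 = 713 patches of 768 values each.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened non-overlapping patches, row by row."""
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    # Drop the bottom/right pixels that do not fill a whole patch
    image = image[: rows * patch_size, : cols * patch_size]
    # (rows, p, cols, p, C) -> (rows, cols, p, p, C) -> (rows * cols, p * p * C)
    patches = image.reshape(rows, patch_size, cols, patch_size, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape(rows * cols, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((375, 500, 3))  # same height, width, channels as the example below
patches = patchify(image, patch_size=16)
print(patches.shape)  # (713, 768)

# Patch index r * cols + c covers pixel rows r*16:(r+1)*16 and columns c*16:(c+1)*16
cols = 500 // 16
assert np.array_equal(patches[2 * cols + 5], image[32:48, 80:96, :].flatten())
```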


In [None]:
# Example images (for demonstration, you would use your actual images here)
image_w = 500
image_h = 375

# Parameters
num_channels = 3 # Assuming RGB images
num_heads = 8
patch_size = 16 # Patch size (similar to DINO)
dim = num_channels * patch_size * patch_size  # Dimension of each patch
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

image = np.random.rand(1, image_h, image_w, num_channels)[0]
print(f"IMAGE SHAPE: {image.shape}") # (height, width, channels)
print("-----")

# Extract patches from the input image
image_patches = extract_patches([image], patch_size)  # Pass a list containing the single image
print(f"PATCHES SHAPE: {image_patches.shape}")
print("-----")

# Apply positional encoding across the sequence of patches (one position per patch),
# so that different patches receive different encodings
num_patches = image_patches.shape[0]
pos_enc = positional_encoding(image_patches.reshape(1, num_patches, dim))
print("Positional encoding:")
print(f"SHAPE: {pos_enc.shape}")
print("-----")

# positional_encoding has already added the encoding to the patches; shape is (batch, num_patches, dim)
image_patches_with_pos = pos_enc.reshape(1, -1, dim)
print("Image patches after positional encoding:")
print(f"SHAPE: {image_patches_with_pos.shape}")

In [None]:
# Linear transformation to generate Q, K, and V matrices
linear_qkv = nn.Linear(dim, d_k * 3, bias=False)
print(linear_qkv)

# Apply linear transformation to generate Q, K, and V matrices
qkv = linear_qkv(torch.tensor(image_patches_with_pos, dtype=torch.float32))
qkv = qkv.reshape(1, -1, 3, num_heads, d_k // num_heads).permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]
print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")

# Calculate attention scores
scale = (d_k // num_heads) ** -0.5  # per-head dimension, matching the reshape above
# Transpose K before performing matrix multiplication
attention_scores = (torch.matmul(Q, K.transpose(-2, -1))) * scale
print(f"Attention scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = torch.matmul(attention_weights, V)
print(f"Output after self-attention: (shape: {output.shape})")
print(output)

#### Approach 2

In [None]:
image_size = (500, 375)
num_images = 1
patch_size = 16  # Define the patch size

num_channels = 3
num_heads = 2
dim = patch_size * patch_size * num_channels  # Dimension of each patch
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

images = np.random.rand(num_images, image_size[1], image_size[0], num_channels)  # image_size is (width, height); arrays are (N, height, width, channels)
print(f"IMAGE SHAPE: {images.shape}") # (num_images, height, width, channels)
print("-----")

# Extract patches from images
patches = extract_patches(images, patch_size)
print(f"PATCHES SHAPE: {patches.shape}")
print("-----")

# Add positional encoding across the sequence of patches (one position per patch)
num_patches = patches.shape[0]
patches_with_pos = positional_encoding(patches.reshape(num_images, num_patches // num_images, dim))
print(f"PATCHES WITH POS: {patches_with_pos.shape}")
print("-----")

# Reshape patches_with_pos to match the expected input shape for linear layer
patches_with_pos_reshaped = patches_with_pos.reshape(num_images, -1, dim)
print(f"PATCHES WITH POS RESHAPE: {patches_with_pos_reshaped.shape}")
print("-----")

In [None]:
# Linear transformation to generate Q, K, and V matrices
linear_qkv = nn.Linear(dim, d_k * 3, bias=False)
qkv = linear_qkv(torch.tensor(patches_with_pos_reshaped, dtype=torch.float32))
qkv = qkv.reshape(num_images, -1, 3, num_heads, d_k // num_heads).permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]
print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")

# Calculate attention scores
scale = (d_k // num_heads) ** -0.5  # per-head dimension, matching the reshape above
attention_scores = (torch.matmul(Q, K.transpose(-2, -1))) * scale
print(f"Attention scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights (shape: {attention_weights.shape}):\n{attention_weights}")

# Compute the weighted sum using attention weights
output = torch.matmul(attention_weights, V)
print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(output)

## Model <a class="anchor" id="model"></a>

In [None]:
arch = "vit_small" # num_heads = 6
checkpoint_key = "teacher"
patch_size = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
if arch in torchvision_models.__dict__.keys():
    model = (torchvision_models.__dict__[arch](
        num_classes=0))
    model.fc = nn.Identity()
else:
    model = vits.__dict__[arch](
        patch_size=patch_size, num_classes=0)

url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"

state_dict = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/dino/" + url)
model.load_state_dict(state_dict, strict=True)

for p in model.parameters():
    p.requires_grad = False

model.to(device)
model.eval()

## Helper Functions <a class="anchor" id="functions"></a>

### Transform

In [None]:
transform = pth_transforms.Compose([
    # pth_transforms.Resize((480, 480)),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

### Get Attention

In [None]:
def get_attention(img_name, path, patch_size):
    img_path = f"{path}/{img_name}"
    img = Image.open(img_path).convert('RGB')
    print(f"Original Image Size: {img.size}")

    img = transform(img)
    print(f"Transformed Image Size: {img.shape}")

    # Crop so that height and width are divisible by the patch size (tensor is C, H, W)
    h, w = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
    img = img[:, :h, :w].unsqueeze(0)
    print(f"Image by Patch Size: {img.shape}")

    # Size of the patch grid
    h_featmap = img.shape[-2] // patch_size
    w_featmap = img.shape[-1] // patch_size

    y, attentions, test_scores, test_weights = model.get_last_selfattention(img.to(device))

    nh = attentions.shape[1]  # Number of heads

    # Keep the attention from the [CLS] token (query 0) to every patch token (keys 1:)
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    # Upsample each patch-level map back to pixel resolution
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
    
    return y, attentions, test_scores, test_weights
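For every head, `get_attention` keeps the attention from the [CLS] query (row 0) to each patch key (columns 1 onward). It reshapes that row onto the patch grid and upsamples it to pixel resolution. The NumPy sketch below reproduces the same indexing, with random normalized weights standing in for the model's output and an assumed 23 x 31 patch grid. `np.repeat` along both spatial axes mimics nearest-neighbour upsampling by an integer factor.

```python
import numpy as np

rng = np.random.default_rng(0)
num_heads, patch_size = 6, 16
h_featmap, w_featmap = 23, 31           # assumed patch grid, e.g. a 368x496 crop
num_tokens = 1 + h_featmap * w_featmap  # [CLS] + one token per patch

# Stand-in for the model's attention weights: (batch, heads, queries, keys), rows sum to 1
raw = rng.random((1, num_heads, num_tokens, num_tokens))
attentions = raw / raw.sum(axis=-1, keepdims=True)

# Row 0 is the [CLS] query; drop column 0 so only patch keys remain
cls_attn = attentions[0, :, 0, 1:].reshape(num_heads, h_featmap, w_featmap)

# Nearest-neighbour upsampling: each patch score fills a patch_size x patch_size block
maps = cls_attn.repeat(patch_size, axis=1).repeat(patch_size, axis=2)
print(maps.shape)  # (6, 368, 496)
```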

### Plot Attentions Function

In [None]:
def plot_img(img, img_per_row=3):
    """Plot a sequence of 2D arrays (e.g. one attention map per head) in a grid."""
    num_images = len(img)
    ncol = min(img_per_row, num_images)
    nrow = math.ceil(num_images / ncol)

    fig, axes = plt.subplots(
        nrow, ncol,
        squeeze=False,  # always return a 2D array of axes, even for a single row
        gridspec_kw=dict(wspace=0.05, hspace=0.05,
                         top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                         left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
        figsize=(ncol + 10, nrow + 10)
    )

    # Hide every axis first so unused grid cells stay blank
    for ax in axes.flat:
        ax.axis("off")

    for i, image in enumerate(img):
        axes[i // ncol, i % ncol].imshow(image)

    plt.show()

# Attention Run <a class="anchor" id="runs"></a>

### Variables

In [None]:
full_path = "/usr/src/ai_tutorials/data/imagenet/val/n01440764"
img_name = "ILSVRC2012_val_00000293.JPEG"

### Display image

In [None]:
img = Image.open(full_path + '/' + img_name)
print(f"IMG SIZE (width, height): {img.size}")
display(img)

### Run attention pipeline

In [None]:
y, img_attentions, scores, weights = get_attention(img_name=img_name, path=full_path, patch_size=patch_size)

In [None]:
type(img_attentions), img_attentions.shape

In [None]:
type(y), y.shape

In [None]:
scores.shape, weights.shape

In [None]:
# ATTENTION SCORES: torch.Size([1, 6, 714, 714])
# tensor([[ 3.1271, -1.5889, -1.9582, -1.7589, -1.5342],
#         [ 3.2784, -0.5752, -0.8337, -0.8466, -0.6552],
#         [ 3.2076, -0.8093, -0.8502, -0.7562, -0.6061],
#         [ 3.2198, -0.7407, -0.7294, -0.5392, -0.4825],
#         [ 2.9159, -0.9545, -1.0094, -0.9100, -0.4406]], device='cuda:0')

# ATTENTION WEIGHTS (SOFTMAX): torch.Size([1, 6, 714, 714])
# tensor([[0.0636, 0.0006, 0.0004, 0.0005, 0.0006],
#         [0.0644, 0.0014, 0.0011, 0.0010, 0.0013],
#         [0.0659, 0.0012, 0.0011, 0.0013, 0.0015],
#         [0.0625, 0.0012, 0.0012, 0.0015, 0.0015],
#         [0.0478, 0.0010, 0.0009, 0.0010, 0.0017]], device='cuda:0')

In [None]:
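# First 5 query rows of head 0, keeping every key column: the softmax must
# normalize over the full row to reproduce the model's attention weights
first_score = scores[0, 0, :5]
first_score[:, :5]

In [None]:
# Compute softmax along the last (key) dimension, then show the first 5 columns
attn_weights = first_score.softmax(dim=-1)
print(attn_weights[:, :5])

In [None]:
first_weight = weights[0, 0, :5, :5]
first_weight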
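Softmax normalizes over every key in a row, so the model's weights can only be reproduced from full score rows. This NumPy sketch uses random stand-in scores to show that slicing after the softmax matches, while applying softmax to a truncated 5x5 block does not:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.standard_normal((714, 714))  # stand-in for one head's (queries, keys) scores

full_then_slice = softmax(scores, axis=-1)[:5, :5]  # what the model's weights contain
slice_then_full = softmax(scores[:5, :5], axis=-1)  # normalizes over only 5 keys

print(np.allclose(full_then_slice, softmax(scores[:5], axis=-1)[:, :5]))  # True
print(np.allclose(full_then_slice, slice_then_full))                      # False
```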

### Plot attentions

In [None]:
plot_img(img=img_attentions)