# ViT Self-Attention Tutorial

---
## Table of Contents

* [Self-Attention Overview](#self_attention_overview)
    * [Process](#process)
    * [Pseudocode](#pseudocode)
* [Self-Attention Pipeline](#self_attention_pipeline)
    * [Imports](#imports)
    * [Mini-Example](#mini)
    * [Random Image](#image)
    * [Image with Patches](#patches)
    * [Model](#model)
    * [Helper Functions](#functions)
* [Attention Run](#runs)

# Self-Attention Overview<a class="anchor" id="self_attention_overview"></a>

Self-attention has become one of the central building blocks of deep learning and has proven effective in many domains, including image processing. In computer vision, its best-known use is the Vision Transformer (ViT), which uses self-attention either as a complete replacement for convolutional layers or alongside them. The approach has trade-offs. Compared with a Convolutional Neural Network (CNN), a self-attention model often needs more parameters and more computation. In addition, the attention matrix grows quadratically with the number of patches, because every patch attends to every other patch.

## Process <a class="anchor" id="process"></a>

**1. Define parameters and initialize data:**

- **Image size, number of images, and patch size:** These define the input image and how it's divided into smaller patches.
- **Number of channels, heads, and dimensions:** These determine the complexity of the attention mechanism and the size of internal representations.

**2. Preprocess data:**

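- **Extract patches:** The image is divided into patches of the defined size. Standard ViTs use non-overlapping patches, where the stride equals the patch size, and that is what the `extract_patches` helper below does. Overlapping patches are a possible variant.
- **Add positional encoding:** Self-attention by itself is permutation-invariant: it treats the patches as an unordered set. Adding a positional encoding to each patch representation tells the model where in the image each patch came from.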

**3. Reshape and linear transformation:**

- **Reshape patches:** Each patch is flattened into a vector, giving a `(num_patches, patch_dim)` matrix that a linear layer can process.
- **Linear transformation:** A linear layer transforms the patch representations into three sets of vectors: query (Q), key (K), and value (V). These vectors will be used to calculate attention scores.

**4. Calculate attention scores:**

- **Scaled dot product:** Multiply Q by the transpose of K and divide the result by the square root of the key dimension `d_k`. Without this scaling, dot products grow with `d_k` and push the softmax into a saturated region where gradients are very small.
- **Interpretation:** Entry `(i, j)` of the score matrix measures how well the query of patch `i` matches the key of patch `j`. A higher score indicates a stronger relationship.

**5. Apply softmax and weighted sum:**

- **Softmax:** The attention scores are normalized row by row with softmax. This gives attention weights between 0 and 1 that sum to 1 for each query patch.
- **Weighted sum:** The value vectors (V) are combined using these weights, so patches with high weights for a given query contribute the most to that query's output.
- **Interpretation:** Each output row is a weighted combination of the value vectors of all patches, not only neighboring ones. It captures the information most relevant to the corresponding query patch.
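In matrix form, steps 3 to 5 are the scaled dot-product attention of the original Transformer. For this summary only, $X$ denotes the matrix of position-encoded patch vectors, and $W_Q$, $W_K$, $W_V$ denote the learned projection matrices. The softmax is applied to each row.

$$
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V
$$

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$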

## Pseudocode <a class="anchor" id="pseudocode"></a>

This pseudocode provides a high-level overview of the logic behind self-attention, without focusing on specific programming language syntax. Remember, this is a simplified representation, and the actual implementation can vary depending on the chosen framework and specific application.

```
1. Define parameters and functions

patch_size = 16
num_channels = 3
num_heads = 8
dim = patch_size * patch_size * num_channels  # Length of one flattened patch

FUNCTION extract_patches(image, patch_size):
  # Divides image into non-overlapping patches and returns a (num_patches, dim) tensor

FUNCTION positional_encoding(patches):
  # Adds a position vector to each patch (one per patch index) and returns encoded patches

2. Preprocess data

patches = extract_patches(image, patch_size)
patches_with_pos = positional_encoding(patches)  # Encode per patch; flattening first would lose patch positions

3. Reshape and linear transformation

num_patches = patches_with_pos.shape[0]
patches_with_pos_reshaped = patches_with_pos.reshape(num_patches, -1)  # (num_patches, dim) for the linear layer

FUNCTION linear_qkv(x):
  # Linear transformation to generate Q, K, V vectors, each split into num_heads heads
  return Q, K, V

Q, K, V = linear_qkv(patches_with_pos_reshaped)

4. Calculate attention scores

d_k = dim // num_heads  # Dimension of Q, K, and V per head after splitting
scale = 1 / math.sqrt(d_k)  # Scaling factor

attention_scores = Q.dot(K.transpose()) * scale  # Computed separately for each head

5. Apply softmax and weighted sum

attention_weights = softmax(attention_scores, axis=-1)  # Softmax along the last dimension (over keys)
output = attention_weights.dot(V)
```
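The pseudocode can be turned into a small runnable NumPy sketch. It is a single-head version under stated assumptions, not the tutorial's pipeline:

- Random matrices scaled by `1 / sqrt(dim)` stand in for learned projection weights.
- The positional encoding is the standard sinusoidal one.
- The 64x48 image size, the seed, the head size `d_k = 64`, and the helper name `sinusoidal_encoding` are illustrative choices.

```python
import numpy as np

def extract_patches(image, patch_size):
    """Split an (H, W, C) image into non-overlapping, flattened patches."""
    h, w, _ = image.shape
    patches = []
    for i in range(0, h - patch_size + 1, patch_size):
        for j in range(0, w - patch_size + 1, patch_size):
            patches.append(image[i:i + patch_size, j:j + patch_size, :].reshape(-1))
    return np.stack(patches)  # (num_patches, patch_size * patch_size * C)

def sinusoidal_encoding(num_positions, dim):
    """Sin/cos positional encoding with one row per position (hypothetical helper)."""
    pos = np.arange(num_positions)[:, None]
    even = np.arange(0, dim, 2)[None, :]
    angles = pos / (10000 ** (even / dim))
    enc = np.zeros((num_positions, dim))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles[:, : dim // 2])
    return enc

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # Subtract the max for stability
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
patch_size, num_channels = 16, 3
dim = patch_size * patch_size * num_channels  # 768
image = rng.random((64, 48, num_channels))    # Hypothetical 64x48 RGB image

# Steps 1-2: patches plus one positional vector per patch
patches = extract_patches(image, patch_size)  # (12, 768)
x = patches + sinusoidal_encoding(patches.shape[0], dim)

# Step 3: random matrices stand in for learned Q/K/V projections (single head)
d_k = 64  # Assumed head size
W_q, W_k, W_v = (rng.normal(scale=dim ** -0.5, size=(dim, d_k)) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# Steps 4-5: scaled dot-product attention
weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (12, 12), each row sums to 1
output = weights @ V                                # (12, 64)
print(patches.shape, weights.shape, output.shape)
```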

# Self-Attention Pipeline <a class="anchor" id="self_attention_pipeline"></a>

## Imports <a class="anchor" id="imports"></a>

```python
from argparse import Namespace
from dataclasses import dataclass, field
import math
import os
from pathlib import Path
import re
from typing import Dict, Union

from IPython.display import display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms

from ai_tutorials import vision_transformer as vits

print(f'Torch version is: {torch.__version__}')
print(f'Is CUDA available: {torch.cuda.is_available()}')
```

```
Torch version is: 2.2.0a0+6a974be
Is CUDA available: True
```


## Mini-Example <a class="anchor" id="mini"></a>

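Self-attention lets every element of a sequence build its new representation as a weighted mix of all the elements, with the weights computed from the data itself. The element can be a word in machine translation or sentiment analysis, or a patch in a ViT.

The snippet below walks through one simplified self-attention step with NumPy and random projection matrices. It uses a toy embedding of two tokens with five features each. There is a single attention head, whose size `d_k = d_model // num_heads` is chosen as if there were two heads.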

```python
# Softmax converts scores into probabilities that sum to 1.
# This NumPy version works on arrays; for PyTorch tensors, use torch.nn.functional.softmax.
def softmax(x, axis=-1):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))  # Subtract the max for numerical stability
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Example concatenated embedding: 2 tokens, 5 features each
concatenated_embedding = np.array([[1, 2, 3, 7, 8],
                                   [4, 5, 6, 9, 10]])

print("Example Concatenated Embedding:")
print(concatenated_embedding)
print(f"Shape of concatenated embedding: {concatenated_embedding.shape}")
print(f"Embedding dimension: {concatenated_embedding.shape[-1]}")

# Assuming the number of attention heads and other parameters
num_heads = 2
d_model = concatenated_embedding.shape[-1]  # Dimension of the concatenated embedding
print(f"Total embedding dimension (d_model): {d_model}")
d_k = d_v = d_model // num_heads
print(f"Dimension of keys and values (d_k and d_v): {d_k}")
print("-----")

# Generate Query, Key, and Value matrices (random projections, so the numbers differ on every run)
Q = concatenated_embedding.dot(np.random.rand(d_model, d_k))  # Query matrix
K = concatenated_embedding.dot(np.random.rand(d_model, d_k))  # Key matrix
V = concatenated_embedding.dot(np.random.rand(d_model, d_v))  # Value matrix
print("Query, Key, and Value Matrices:")
print(f"Q: {Q}\nK: {K}\nV: {V}")
print("-----")

# Calculate attention scores
attention_scores = np.matmul(Q, K.T) / np.sqrt(d_k)
print("Attention Scores:")
print(f"Scores: {attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = softmax(attention_scores, axis=-1)
print("Attention Weights:")
print(f"Weights: {attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = np.matmul(attention_weights, V)
print("Output after Weighted Sum - Self-Attention:")
print(f"Shape of output: {output.shape}")
print(output)
```

```
Example Concatenated Embedding:
[[ 1  2  3  7  8]
 [ 4  5  6  9 10]]
Shape of concatenated embedding: (2, 5)
Embedding dimension: 5
Total embedding dimension (d_model): 5
Dimension of keys and values (d_k and d_v): 2
-----
Query, Key, and Value Matrices:
Q: [[18.3785915  12.9155287 ]
 [29.45144842 19.88750162]]
K: [[5.01985598 3.13856459]
 [9.50817132 6.72543567]]
V: [[ 5.59360392  9.77897866]
 [10.96456081 15.34470326]]
-----
Attention Scores:
Scores: [[ 93.89961106 184.98574822]
 [148.67643998 292.58772576]]
-----
Attention Weights:
Weights: [[2.76562476e-40 1.00000000e+00]
 [3.16317123e-63 1.00000000e+00]]
-----
Output after Weighted Sum - Self-Attention:
Shape of output: (2, 2)
[[10.96456081 15.34470326]
 [10.96456081 15.34470326]]
```
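Both output rows in that run equal the second row of `V` because the attention weights are effectively one-hot. The raw scores are large, since an unnormalized integer embedding is multiplied by random projections with entries in [0, 1). A softmax over scores that are roughly 90 or more apart puts almost all the weight on the larger one.

The sketch below reuses the printed scores, copied by hand, so treat them as example values. It shows that shrinking them, here by a hypothetical factor of 50, yields a much softer distribution.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Attention scores copied from the run above (another run will give different values)
printed_scores = np.array([[93.89961106, 184.98574822],
                           [148.67643998, 292.58772576]])

sharp = softmax(printed_scores)        # Nearly one-hot: each row puts ~all weight on key 2
soft = softmax(printed_scores / 50.0)  # Hypothetical smaller scores: weight is spread out more
print(np.round(sharp, 4))
print(np.round(soft, 4))
```

This is one reason the Q and K projections are usually kept small, and why the scores are divided by `sqrt(d_k)`.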


## 64x64 Random Images and Positional Encoding <a class="anchor" id="image"></a>

```python
def positional_encoding(embedding):
    """Add sinusoidal positional encoding to embeddings of shape (batch, seq_len, channels)."""
    print(f"Embedding Dimension: {embedding.shape}")
    seq_len = embedding.shape[1]
    channel_dim = embedding.shape[-1]  # channels
    print("Seq len:", seq_len)
    print("Channel dim:", channel_dim)
    pos_enc = np.zeros((seq_len, channel_dim))
    for pos in range(seq_len):
        for i in range(0, channel_dim, 2):
            # Channels i (sine) and i + 1 (cosine) share the frequency 1 / 10000^(i / channel_dim)
            div_term = 10000 ** (i / channel_dim)
            pos_enc[pos, i] = math.sin(pos / div_term)
            if i + 1 < channel_dim:  # Ensure not to exceed the last dimension
                pos_enc[pos, i + 1] = math.cos(pos / div_term)
    return embedding + pos_enc  # Broadcasts over the batch dimension

image_size = 64  # Larger images quickly exhaust memory
num_images = 3

# Parameters
num_channels = 3  # 3 channels for RGB images
num_heads = 2
dim = image_size * image_size * num_channels  # Dimension of the input (total number of features)
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

images = np.random.rand(num_images, image_size, image_size, num_channels)
print(f"IMAGES SHAPE: {images.shape}")  # 3 images, 64 by 64, 3 channels
print("-----")

# Reshape images to have a sequence length of image_size * image_size (one token per pixel)
images_flattened = images.reshape(num_images, -1, num_channels)
print(f"FLATTENED SHAPE: {images_flattened.shape}")
print(f"FLATTENED: {images_flattened}")
print("-----")

# Add positional encoding
print("POSITIONAL ENCODING OUTPUT (USING FLATTENED IMAGE)")
images_with_pos = positional_encoding(images_flattened)
print(f"Input images after positional encoding: (shape={images_with_pos.shape})")
print(images_with_pos)
```
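The double loop in `positional_encoding` is easy to read but slow for long sequences. The encoding uses sine on even channels and cosine on odd channels, and each pair shares the frequency `1 / 10000^(i / d)`. The sketch below is a vectorized version of the same encoding. `sinusoidal_encoding` and `loop_encoding` are hypothetical helper names, and the comparison only shows that the two agree.

```python
import math
import numpy as np

def sinusoidal_encoding(seq_len, channel_dim):
    """Vectorized sin/cos positional encoding of shape (seq_len, channel_dim)."""
    pos = np.arange(seq_len)[:, None]                  # (seq_len, 1)
    even = np.arange(0, channel_dim, 2)[None, :]       # Even channel indices i
    angles = pos / (10000 ** (even / channel_dim))     # (seq_len, ceil(channel_dim / 2))
    enc = np.zeros((seq_len, channel_dim))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles[:, : channel_dim // 2])  # Odd channels reuse the even frequency
    return enc

def loop_encoding(seq_len, channel_dim):
    """Reference loop implementation, same formula as positional_encoding above."""
    enc = np.zeros((seq_len, channel_dim))
    for pos in range(seq_len):
        for i in range(0, channel_dim, 2):
            div_term = 10000 ** (i / channel_dim)
            enc[pos, i] = math.sin(pos / div_term)
            if i + 1 < channel_dim:
                enc[pos, i + 1] = math.cos(pos / div_term)
    return enc

print(np.allclose(sinusoidal_encoding(4096, 3), loop_encoding(4096, 3)))
```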

#### With Basic np.matmul

```python
# Reshape images_with_pos to (num_images, 1, dim). Because dim = H * W * C, each image becomes a
# single token, so the attention matrix below is 1x1 and every weight is trivially 1.
images_with_pos_reshaped = images_with_pos.reshape(num_images, -1, dim)

# Generate Query, Key, and Value matrices
Q = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_k))  # Query matrix
K = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_k))  # Key matrix
V = np.matmul(images_with_pos_reshaped, np.random.rand(dim, d_v))  # Value matrix
print(f"Q: {Q}\nK: {K}\nV: {V}")
print("-----")

# Calculate attention scores
scale = d_k ** -0.5  # Same as dividing by np.sqrt(d_k)
attention_scores = np.matmul(Q, K.transpose((0, 2, 1))) * scale
print(f"scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights.
# Uses the NumPy softmax defined earlier; nn.functional.softmax expects a torch tensor and a dim argument.
attention_weights = softmax(attention_scores, axis=-1)
print(f"weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = np.matmul(attention_weights, V)
print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(f"OUTPUT:\n{output}")
```

#### With Linear Projection

```python
# Add positional encoding
images_with_pos = positional_encoding(images_flattened)
print("Input images after positional encoding:")
print(images_with_pos)
print(f"IMAGES WITH POS SHAPE: {images_with_pos.shape}")

# Linear transformation to generate Q, K, and V matrices (one token per pixel, 3 input features)
linear_qkv = nn.Linear(num_channels, d_k * 3, bias=False)  # Linear(in_features=3, out_features=18432, bias=False)
with torch.no_grad():  # Inference only; avoids storing an autograd graph for these large tensors
    qkv = linear_qkv(torch.tensor(images_with_pos, dtype=torch.float32))
print(f"QKV AFTER LINEAR: {qkv.shape}")
# (num_images, seq_len, 3 * d_k) -> (3, num_images, num_heads, seq_len, head_dim)
head_dim = d_k // num_heads
qkv = qkv.reshape(num_images, -1, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
print(f"QKV AFTER RESHAPE: {qkv.shape}")

Q, K, V = qkv[0], qkv[1], qkv[2]

print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")
# Calculate attention scores, scaled by the per-head dimension of Q and K
scale = head_dim ** -0.5
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
print(f"scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")
# Apply softmax to obtain attention weights
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
print(f"weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")
# Compute the weighted sum using attention weights
output = torch.matmul(attention_weights, V)

print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(output)
```
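As a sanity check on the manual multi-head computation, PyTorch 2.x provides a fused `torch.nn.functional.scaled_dot_product_attention`. Its default scale is `1 / sqrt(q.size(-1))`, which is the per-head dimension. The sketch below compares the two on small random tensors. The sizes are arbitrary choices so that it runs quickly on a CPU, and the lower-case names avoid overwriting the notebook's `Q`, `K`, `V`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 2, 4, 10, 16  # Arbitrary small sizes
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Manual version, as in the cells above
scores = q @ k.transpose(-2, -1) * head_dim ** -0.5
manual = scores.softmax(dim=-1) @ v

# Fused PyTorch implementation (available since PyTorch 2.0)
fused = F.scaled_dot_product_attention(q, k, v)

print(manual.shape, torch.allclose(manual, fused, atol=1e-4))
```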

## Image with patches example <a class="anchor" id="patches"></a>

```python
def positional_encoding(embedding):
    """Add sinusoidal positional encoding to embeddings of shape (batch, seq_len, channels)."""
    seq_len = embedding.shape[1]
    channel_dim = embedding.shape[-1]
    pos_enc = np.zeros((seq_len, channel_dim))
    for pos in range(seq_len):
        for i in range(0, channel_dim, 2):
            div_term = 10000 ** (i / channel_dim)  # Shared frequency for channels i and i + 1
            pos_enc[pos, i] = math.sin(pos / div_term)
            if i + 1 < channel_dim:
                pos_enc[pos, i + 1] = math.cos(pos / div_term)
    return embedding + pos_enc

def extract_patches(images, patch_size):
    """Split (H, W, C) images into non-overlapping, flattened patches.

    Rows and columns that do not fill a whole patch are dropped.
    Returns an array of shape (num_images * num_patches, patch_size * patch_size * C).
    """
    patches = []
    for image in images:
        for i in range(0, image.shape[0] - patch_size + 1, patch_size):
            for j in range(0, image.shape[1] - patch_size + 1, patch_size):
                patch = image[i:i+patch_size, j:j+patch_size, :]
                patches.append(patch.flatten())
    return np.array(patches)
```
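The nested loops in `extract_patches` can also be written with a single reshape and transpose. The sketch below shows one way to do this for a single `(H, W, C)` image. It keeps the same row-by-row patch order and flattening order as the loop, and it crops any remainder the same way. `extract_patches_vectorized` is a hypothetical name.

```python
import numpy as np

def extract_patches_vectorized(image, patch_size):
    """Vectorized non-overlapping patch extraction for one (H, W, C) image."""
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    cropped = image[: rows * patch_size, : cols * patch_size, :]     # Drop incomplete patches
    blocks = cropped.reshape(rows, patch_size, cols, patch_size, c)  # Split both spatial axes
    blocks = blocks.transpose(0, 2, 1, 3, 4)                         # (rows, cols, p, p, C)
    return blocks.reshape(rows * cols, patch_size * patch_size * c)

def extract_patches_loop(image, patch_size):
    """Loop version for comparison (same logic as extract_patches, one image)."""
    patches = []
    for i in range(0, image.shape[0] - patch_size + 1, patch_size):
        for j in range(0, image.shape[1] - patch_size + 1, patch_size):
            patches.append(image[i:i + patch_size, j:j + patch_size, :].flatten())
    return np.array(patches)

demo_image = np.random.rand(375, 500, 3)  # Same size as the example image below
fast = extract_patches_vectorized(demo_image, 16)
print(fast.shape, np.array_equal(fast, extract_patches_loop(demo_image, 16)))
```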


```python
# Example image (for demonstration; use your actual images here)
image_w = 500
image_h = 375

# Parameters
num_channels = 3  # Assuming RGB images
num_heads = 8
patch_size = 16  # Same patch size as the DINO ViT-S/16 checkpoint loaded later
dim = num_channels * patch_size * patch_size  # Dimension of each flattened patch
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

image = np.random.rand(image_h, image_w, num_channels)
print(f"IMAGE SHAPE: {image.shape}")  # (height, width, channels) = (375, 500, 3)
print("-----")

# Extract patches from the input image
image_patches = extract_patches([image], patch_size)  # Pass a list containing the single image
print(f"PATCHES SHAPE: {image_patches.shape}")  # (num_patches, dim)
print("-----")

# Add positional encoding per patch: treat the patches as a sequence of length num_patches
num_patches = image_patches.shape[0]
image_patches_with_pos = positional_encoding(image_patches.reshape(1, num_patches, dim))
print("Image patches after positional encoding:")
print(f"SHAPE: {image_patches_with_pos.shape}")  # (1, num_patches, dim)
```

```python
# Linear transformation to generate Q, K, and V matrices
linear_qkv = nn.Linear(dim, d_k * 3, bias=False)
print(linear_qkv)

# Apply the linear transformation, then split into Q, K, V and attention heads
head_dim = d_k // num_heads  # Features per head after splitting d_k across num_heads
with torch.no_grad():  # Inference only; no autograd graph needed
    qkv = linear_qkv(torch.tensor(image_patches_with_pos, dtype=torch.float32))
qkv = qkv.reshape(1, -1, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]
print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")

# Calculate attention scores, scaled by the per-head dimension
scale = head_dim ** -0.5
# Transpose K before performing matrix multiplication
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
print(f"Attention scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = torch.matmul(attention_weights, V)
print(f"Output after self-attention: (shape: {output.shape})")
print(output)
```
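The output above still has one slice per head, with shape `(1, num_heads, num_patches, head_dim)`. In a full ViT attention block, the heads are concatenated back into one vector per patch and passed through an output projection. The sketch below shows that step under stated assumptions:

- It uses the sizes from the cell above: 8 heads, 713 patches and 12 features per head.
- It uses random data in place of the real output.
- `proj` is a hypothetical output projection with random weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, H, N, D = 1, 8, 713, 12  # Batch, heads, patches, features per head (as in the cell above)
per_head_output = torch.randn(B, H, N, D)

# (B, H, N, D) -> (B, N, H, D) -> (B, N, H * D): concatenate the heads for each patch
merged = per_head_output.transpose(1, 2).reshape(B, N, H * D)

proj = nn.Linear(H * D, H * D)  # Hypothetical output projection
with torch.no_grad():
    out = proj(merged)
print(merged.shape, out.shape)
```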

#### Approach 2

```python
image_size = (500, 375)  # (width, height), the PIL convention
num_images = 1
patch_size = 16  # Define the patch size

num_channels = 3
num_heads = 2
dim = patch_size * patch_size * num_channels  # Dimension of each flattened patch
print(f"dim: {dim}")
d_k = d_v = dim // num_heads
print(f"d_k and d_v: {d_k}")
print("-----")

# NumPy image arrays are laid out as (num_images, height, width, channels)
images = np.random.rand(num_images, image_size[1], image_size[0], num_channels)
print(f"IMAGE SHAPE: {images.shape}")  # 1 image, 375 by 500, 3 channels
print("-----")

# Extract patches from images
patches = extract_patches(images, patch_size)
print(f"PATCHES SHAPE: {patches.shape}")
print("-----")

# Add positional encoding per patch: each image is a sequence of patch vectors of length dim
patches_with_pos = positional_encoding(patches.reshape(num_images, -1, dim))
print(f"PATCHES WITH POS: {patches_with_pos.shape}")
print("-----")

# Reshape patches_with_pos to match the expected input shape for the linear layer
patches_with_pos_reshaped = patches_with_pos.reshape(num_images, -1, dim)
print(f"PATCHES WITH POS RESHAPE: {patches_with_pos_reshaped.shape}")
print("-----")
```

```python
# Linear transformation to generate Q, K, and V matrices
linear_qkv = nn.Linear(dim, d_k * 3, bias=False)
with torch.no_grad():  # Inference only; no autograd graph needed
    qkv = linear_qkv(torch.tensor(patches_with_pos_reshaped, dtype=torch.float32))
head_dim = d_k // num_heads  # Features per head after splitting d_k across num_heads
qkv = qkv.reshape(num_images, -1, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]
print(f"Q:\n{Q}\nK:\n{K}\nV:\n{V}")
print("-----")

# Calculate attention scores, scaled by the per-head dimension
scale = head_dim ** -0.5
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
print(f"Attention scores (shape: {attention_scores.shape}):\n{attention_scores}")
print("-----")

# Apply softmax to obtain attention weights
attention_weights = nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights (shape: {attention_weights.shape}):\n{attention_weights}")
print("-----")

# Compute the weighted sum using attention weights
output = torch.matmul(attention_weights, V)
print(f"Output after weighted sum - self-attention: (shape: {output.shape})")
print(output)
```
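The cells above flatten the patches and apply a linear layer directly. ViT implementations such as DINO's reference code usually do the patch embedding differently. They use a `Conv2d` whose kernel size and stride both equal the patch size, then prepend a [CLS] token, which is the token that `get_attention` below reads attention from.

The sketch below checks that such a convolution computes the same thing as a linear layer on flattened patches. It uses arbitrary small sizes and random weights. The patches must be flattened in channel-first `(C, p, p)` order to match the convolution weights, which differs from the `(p, p, C)` order produced by `extract_patches`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch_size, in_chans, embed_dim = 16, 3, 32         # Arbitrary small sizes
demo_img = torch.randn(1, in_chans, 48, 64)         # (B, C, H, W): a 3x4 grid of patches

# Patch embedding as a strided convolution
conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
with torch.no_grad():
    conv_tokens = conv(demo_img).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

    # Same computation as a linear layer on flattened (C, p, p) patches
    patch_grid = demo_img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # (B, C, rows, cols, p, p)
    flat = patch_grid.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, in_chans * patch_size * patch_size)
    linear_tokens = flat @ conv.weight.reshape(embed_dim, -1).T + conv.bias

    # Prepend a [CLS] token (zeros here; a learnable nn.Parameter in a real model)
    cls_token = torch.zeros(1, 1, embed_dim)
    tokens = torch.cat([cls_token, conv_tokens], dim=1)

print(conv_tokens.shape, tokens.shape, torch.allclose(conv_tokens, linear_tokens, atol=1e-4))
```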

## Model <a class="anchor" id="model"></a>

```python
arch = "vit_small"  # num_heads = 6
checkpoint_key = "teacher"
patch_size = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

```python
if arch in torchvision_models.__dict__.keys():
    model = torchvision_models.__dict__[arch](num_classes=0)
    model.fc = nn.Identity()
else:
    model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)

# Pretrained DINO ViT-S/16 backbone weights
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"

state_dict = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/dino/" + url)
model.load_state_dict(state_dict, strict=True)

# Freeze the weights: the model is only used for inference
for p in model.parameters():
    p.requires_grad = False

model.to(device)  # Uses the GPU when available, otherwise the CPU
model.eval()
```

## Helper Functions <a class="anchor" id="functions"></a>

### Transform

```python
transform = pth_transforms.Compose([
    # pth_transforms.Resize((480, 480)),
    pth_transforms.ToTensor(),
    # Normalize with the ImageNet per-channel mean and standard deviation
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
```

### Get Attention

```python
def get_attention(img_name, path, patch_size):
    img_path = f"{path}/{img_name}"
    img = Image.open(img_path).convert('RGB')
    print(f"Original Image Size: {img.size}")  # PIL size is (width, height)

    img = transform(img)
    print(f"Transformed Image Size: {img.shape}")  # (channels, height, width)

    # Crop height and width so they are divisible by the patch size
    h, w = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
    img = img[:, :h, :w].unsqueeze(0)
    print(f"Image by Patch Size: {img.shape}")

    h_featmap = img.shape[-2] // patch_size  # Number of patch rows
    w_featmap = img.shape[-1] // patch_size  # Number of patch columns

    # get_last_selfattention is defined in the tutorial's ai_tutorials.vision_transformer module
    # and returns four values in this version
    y, attentions, test_scores, test_weights = model.get_last_selfattention(img.to(device))

    nh = attentions.shape[1]  # Number of heads

    # Keep only the [CLS] token's attention (query 0) to every image patch (keys 1:)
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    # Upsample each head's patch-level map back to pixel resolution
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

    return y, attentions, test_scores, test_weights
```
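The indexing in `get_attention` is easy to get wrong. The sketch below applies the same post-processing to a random stand-in attention tensor for a hypothetical 3x4 patch grid. That grid gives one [CLS] token plus 12 patch tokens, with 6 heads as in ViT-S. Row 0 of each head's attention matrix is the [CLS] query. Dropping column 0 leaves one weight per patch, which can be laid out on the patch grid and upsampled.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, h_featmap, w_featmap, patch_size = 6, 3, 4, 16
num_tokens = 1 + h_featmap * w_featmap                 # [CLS] + patches
demo_scores = torch.randn(1, num_heads, num_tokens, num_tokens)
demo_attn = demo_scores.softmax(dim=-1)                # Stand-in for the model's last attention map

cls_to_patches = demo_attn[0, :, 0, 1:]                # (heads, num_patches)
maps = cls_to_patches.reshape(num_heads, h_featmap, w_featmap)
upsampled = F.interpolate(maps.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0]

print(maps.shape, upsampled.shape)
```

Each patch value is copied into a `patch_size x patch_size` block. The rows no longer sum to exactly 1 because the [CLS] column was dropped.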

### Plot Attentions Function

```python
def plot_img(img, img_per_col=3):
    """Plot a sequence of 2D maps in a grid; img_per_col sets the number of grid columns."""
    num_images = len(img)

    ncol = min(img_per_col, num_images)
    nrow = math.ceil(num_images / ncol)

    fig, axes = plt.subplots(
        nrow, ncol,
        squeeze=False,  # Always return a 2D array of axes, even for a single row
        gridspec_kw=dict(wspace=0.05, hspace=0.05,
                         top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                         left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
        figsize=(ncol + 10, nrow + 10)
    )

    for i, image in enumerate(img):
        ax = axes[i // ncol, i % ncol]
        ax.imshow(image)
    for ax in axes.flat:
        ax.axis("off")  # Also hides any unused grid cells

    plt.show()
```

# Attention Run <a class="anchor" id="runs"></a>

### Variables

```python
full_path = "/usr/src/ai_tutorials/data/coco_2017/val2017"
img_name = "000000000139.jpg"
```

### Display image

```python
img = Image.open(os.path.join(full_path, img_name))
print(f"IMG SIZE (width, height): {img.size}")
display(img)
```

### Run attention pipeline

```python
y, img_attentions, scores, weights = get_attention(img_name=img_name, path=full_path, patch_size=patch_size)
```

```python
type(img_attentions), img_attentions.shape
```

```python
type(y), y.shape
```

```python
scores.shape, weights.shape
```

```python
# Recorded output from an earlier run (first 5 rows and columns of head 0)

# ATTENTION SCORES: torch.Size([1, 6, 1041, 1041])
# tensor([[ 3.4057, -3.8283, -3.6789, -3.8322, -4.2521],
#         [ 2.5470, -2.0419, -1.8337, -2.0044, -2.2881],
#         [ 2.5867, -1.9636, -1.7265, -1.9115, -2.1393],
#         [ 2.4455, -1.8812, -1.6428, -1.8065, -2.0402],
#         [ 2.4473, -1.8908, -1.6959, -1.8473, -2.0370]], device='cuda:0')

# ATTENTION WEIGHTS (SOFTMAX): torch.Size([1, 6, 1041, 1041])
# tensor([[5.1433e-02, 3.7115e-05, 4.3097e-05, 3.6970e-05, 2.4294e-05],
#         [2.3387e-02, 2.3772e-04, 2.9273e-04, 2.4679e-04, 1.8583e-04],
#         [2.3418e-02, 2.4740e-04, 3.1357e-04, 2.6064e-04, 2.0753e-04],
#         [2.0375e-02, 2.6918e-04, 3.4163e-04, 2.9003e-04, 2.2961e-04],
#         [2.0554e-02, 2.6846e-04, 3.2624e-04, 2.8040e-04, 2.3195e-04]],
#        device='cuda:0')
```

```python
# First 5 queries (rows) of head 0; keep all 1041 key columns so a softmax sees full rows
first_score = scores[0, 0, :5]
first_score[:, :5]
```

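```python
# Compute softmax along the last dimension (dim=-1), i.e. across all keys of each row,
# then show the first 5 columns to compare with the stored weights
attn_weights = first_score.softmax(dim=-1)
print(attn_weights[:, :5])
```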

```python
# First 5 rows and columns of head 0's stored attention weights
first_weight = weights[0, 0, :5, :5]
first_weight
```
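Softmax normalizes over the whole row, so recomputing the weights from a truncated score slice gives different numbers. The sketch below uses a random score matrix the size of one head above, 1041 tokens, taken from the recorded shapes. It shows that softmax over full rows reproduces the stored weights, while softmax over only the first 5 columns does not. The `demo_` names avoid overwriting the notebook's `scores` and `weights`.

```python
import torch

torch.manual_seed(0)
demo_scores = torch.randn(1041, 1041)            # Stand-in for one head's attention scores
demo_weights = demo_scores.softmax(dim=-1)       # What the model stores as attention weights

full_rows = demo_scores[:5].softmax(dim=-1)[:, :5]   # Softmax over all 1041 keys, then slice
truncated = demo_scores[:5, :5].softmax(dim=-1)      # Softmax over only 5 keys

print(torch.allclose(full_rows, demo_weights[:5, :5]))  # True
print(torch.allclose(truncated, demo_weights[:5, :5]))  # False
```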

### Plot attentions

```python
plot_img(img=img_attentions)
```