# Attention Basics
In this lab we will explore the Attention mechanism.

**Remember** to enable the GPU.

## Using dot-product attention to look up and retrieve information from a database
Recall that dot-product attention is defined as

> $\text{softmax}(QK^T)V$

where $Q$ holds the queries (or query if there is only one), $K$ holds the keys, and $V$ contains the values.

In the database analogy, the attention weights

> $\text{softmax}(QK^T)$

are used to perform the **lookup**; the attention weights tell us how much we want of each sample in the database.

Multiplying $\text{softmax}(QK^T)$ with $V$ is the **retrieval** step, where we retrieve a weighted sum of the elements in $V$.
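The lookup/retrieval split can be sketched as a small PyTorch function. This is a minimal illustration, not part of the lab: the function name `dot_product_attention` and the toy tensor sizes are hypothetical choices.

```python
import torch
import torch.nn.functional as F


def dot_product_attention(Q, K, V):
    """Unscaled dot-product attention, split into lookup and retrieval.

    Q: (num_queries, d), K: (num_items, d), V: (num_items, d_v)
    """
    # Lookup: one row of attention weights per query; each row sums to 1
    weights = F.softmax(Q @ K.T, dim=-1)
    # Retrieval: each output row is a weighted sum of the rows of V
    output = weights @ V
    return output, weights


torch.manual_seed(0)
Q = torch.randn(2, 3)   # two queries
K = torch.randn(5, 3)   # five database entries (keys)
V = torch.randn(5, 4)   # one 4-dimensional value per entry
output, weights = dot_product_attention(Q, K, V)
print(weights.shape, output.shape)  # torch.Size([2, 5]) torch.Size([2, 4])
```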

In the following tasks we will learn how to actually apply this principle to look up and retrieve data in a small database of images.

## 1. Data download
We will be using the MNIST dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
import torch.utils.data.dataloader as dataloader

import torch.nn.functional as F

from tqdm.notebook import trange, tqdm

In [None]:
# Define the root directory of the dataset
data_set_root = "."

# Define transformations to be applied to the dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize([0.5], [0.5])  # Normalize with mean 0.5 and std 0.5, mapping pixel values from [0, 1] to [-1, 1]
])

# Load the MNIST dataset
dataset = datasets.MNIST(data_set_root, train=True, download=True, transform=transform);

## 2. Create database of images
Here, we build a small "database" of 100 randomly selected images. After that we will use attention to look up and retrieve samples from the database.

In [None]:
# Make it deterministic
torch.manual_seed(42)

# Specify the number of examples to select randomly
num_of_examples = 100

# Randomly select indices from the dataset
rand_perm = torch.randperm(dataset.data.shape[0])[:num_of_examples]

# Extract and concatenate the images of randomly selected examples into a tensor
# These are the values from which we want to retrieve.
image_database = torch.cat([dataset[int(i)][0].reshape(1, -1) for i in rand_perm])

print("Shape of image_database", image_database.detach().numpy().shape)

The `image_database` has shape $100 \times 784$ and holds the 100 selected images, flattened into 784-length vectors.

Let's visualise the images stored in it:

In [None]:
out = torchvision.utils.make_grid(image_database.reshape(-1, 1, 28, 28), 10, normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

## 3. Hard attention
The simplest form of attention is **hard attention**, where we always select exactly one element from the database. One way to implement hard attention (using dot-product attention) is to ensure all queries and keys are unique one-hot vectors (why?).

How do we implement this in practice? One way would be as follows:
- Assign a unique one-hot vector to each sample in the database. These will constitute the rows of the key matrix $K$.
- The query $Q$ is also going to be a one-hot vector (note that in the example below we have a single query, meaning that $Q$ is a vector. In general, you can have multiple queries, and in that case $Q$ is a matrix).
- Given the query $Q$, we can match it against all keys by calculating the matrix product $QK^T$.
- Since $K$ consists of unique one-hot vectors, this will result in *exactly* one match (i.e., there is exactly one row $K_i$ for which $Q \cdot K_i=1$; for all other rows $j \neq i$ we will have $Q \cdot K_j=0$).
- Hence, the output of $QK^T$ is another one-hot vector where the 1-element is at the index corresponding to $Q$'s location in the database.
- The product $QK^TV$ will retrieve that element from the database.
- In order to retrieve the images from the database, we set $V$ = `image_database`.
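The steps above can be checked on a toy database of five 3-dimensional "images". The database size, the query index, and the random permutation are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 5
database = torch.randn(n, 3)   # toy stand-in for image_database
q_index = 2                    # the item we want to retrieve

Q = F.one_hot(torch.tensor([q_index]), n).float()   # (1, n) one-hot query
perm = torch.randperm(n)
K = F.one_hot(torch.arange(n), n)[perm].float()     # shuffled one-hot keys
V = database[perm]                                  # values shuffled the same way

scores = Q @ K.T         # one-hot: 1 at the row whose key matches Q
retrieved = scores @ V   # selects exactly that row of V

print(scores)
print(torch.allclose(retrieved[0], database[q_index]))  # True
```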

**Note:** we don't need to apply the softmax here, because $QK^T$ is already normalized: it is a one-hot vector, so its elements already sum to 1.
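Applying the softmax to a one-hot score vector would actually spread the weight out, because the softmax gives every element a strictly positive weight. The sketch below shows this for 100 entries; the scaling factor of 100 in the second part is an arbitrary illustration of how larger scores push the softmax towards hard selection.

```python
import torch
import torch.nn.functional as F

n = 100
scores = F.one_hot(torch.tensor([10]), n).float()   # a one-hot QK^T

soft = F.softmax(scores, dim=-1)
# The matching entry only gets e / (e + 99), roughly 0.027
print(soft[0, 10].item())

# Multiplying the scores by a large factor makes the softmax nearly one-hot
sharp = F.softmax(100 * scores, dim=-1)
print(sharp[0, 10].item())
```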

Let's implement this:

In [None]:
# We arbitrarily select the sample at index 10 of the database as our query
# and use the corresponding one-hot encoding as the query (Q)
q_index = 10
Q = F.one_hot(torch.tensor([q_index]), num_of_examples)
print("Shape of query vector (Q):", Q.detach().numpy().shape)

# Let's visualise the image at this index
plt.figure(figsize = (3,3))
_ = plt.imshow(image_database[q_index].reshape(28, 28).numpy(), cmap="gray")

In [None]:
# The key matrix (K) will consist of unique one-hot vectors, one for every image
# in our dataset
K = F.one_hot(torch.arange(num_of_examples), num_of_examples)

# We already know that the query is at index 10 (q_index).
# Let's make it a little bit harder by randomly shuffling the keys.
rand_perm = torch.randperm(num_of_examples)
K = K[rand_perm]

# The keys and values must match, so we need to shuffle the values as well.
# (Recall that the values correspond to the image_database)
V = image_database[rand_perm]

print("Shape of key matrix (K)", K.detach().numpy().shape)
print("Shape of value matrix (V)", V.detach().numpy().shape)

### 3.1 Question
- Why do $K$ and $V$ have the shapes that they have?

Due to the shuffling we no longer know at which index the query is located in the database.

Let's perform the **lookup** (i.e., $QK^T$):

In [None]:
# Multiply our query with the keys
alignment_scores = torch.mm(Q, K.t()).float()

# Print to confirm that the result is a one-hot vector
print(alignment_scores)

### 3.2 Question
- You should see that `alignment_scores` is a one-hot vector with a 1 at the index at which the query is located in the database. Why?

We can now perform the **retrieval** by multiplying $QK^T$ with $V$.

In [None]:
# Perform matrix multiplication between the resulting index map and the randomly shuffled dataset
output = torch.mm(alignment_scores, V)

# Visualize the image at the specified index
plt.figure(figsize=(3, 3))
_ = plt.imshow(output.reshape(28, 28).numpy(), cmap="gray")

In summary, we used hard attention to find and retrieve the query image after shuffling the database.

## 4. Soft attention: Querying the database using random vectors
What happens if we replace the one-hot vectors with random vectors?

In that case, the lookup

> $\text{softmax}(QK^T)$

will assign a non-zero (in fact strictly positive) weight to every sample in the database. In practice, this means that when we do the retrieval

> $\text{softmax}(QK^T)V$

we are going to get some weighted average of **all** the elements in $V$.

Let's verify this.

In [None]:
# Define the embedding size for each of the vectors
vec_size = 4

# Create a random query vector
Q = torch.randn(1, vec_size)

# Create random key vectors (one for each sample in the database)
K = torch.randn(num_of_examples, vec_size)

V = image_database

# Lookup
alignment_scores = torch.mm(Q, K.t()).float()
attention_weights = F.softmax(alignment_scores, 1)

# Retrieval
output = torch.mm(attention_weights, V)

Let's visualize the result.

In [None]:
plt.figure(figsize = (3,3))
_ = plt.imshow(output.reshape(28, 28).numpy(), cmap="gray")

### 4.1 Question
- You should see that the retrieved image is mostly blurry (if not, run the code block again). What is the reason for this? (Hint: think weighted average.)

To see how much is extracted from each of the 100 samples in the database, you can plot the attention weights:

In [None]:
_ = plt.plot(attention_weights.squeeze().detach().numpy())
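The soft-attention retrieval is a weighted average over the whole database. A minimal check on random stand-in data (same sizes as in Section 4, but `V` is random here rather than `image_database`) confirms that every weight is positive and that the matrix product equals the explicit weighted sum $\sum_i w_i V_i$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(1, 4)       # one random query
K = torch.randn(100, 4)     # one random key per database entry
V = torch.randn(100, 784)   # stand-in for the 100 flattened images

weights = F.softmax(Q @ K.T, dim=1)   # (1, 100)
output = weights @ V                  # (1, 784)

# The same retrieval written as an explicit weighted sum over all entries
manual = sum(weights[0, i] * V[i] for i in range(V.shape[0]))

print((weights > 0).all().item())   # every entry contributes something
print(torch.allclose(output[0], manual, atol=1e-5))
```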

## 5. Multiple Queries
Until now we have queried the database with a single vector. We can easily perform multiple queries at the same time by stacking the queries as the rows of the matrix $Q$. In the example below, there are 8 queries, which should result in 8 retrievals from the database.

In [None]:
# Define the size for each of the vectors
vec_size = 16

# Number of Queries
num_q = 8

# Create random query vectors
Q = torch.randn(num_q, vec_size)

# Create a random key vector for each image in the dataset
K = torch.randn(num_of_examples, vec_size)

V = image_database

# Lookup
alignment_scores = torch.mm(Q, K.transpose(0, 1)).float()
attention_weights = F.softmax(alignment_scores, -1)

# Retrieval
output = torch.mm(attention_weights, V)

print("Size of retrieved data:", output.detach().numpy().shape)

In [None]:
# Let's visualise an entire batch of images!
plt.figure(figsize = (20,10))
out = torchvision.utils.make_grid(output.reshape(num_q, 1, 28, 28), 8, normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))
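Stacking queries as the rows of $Q$ is just a batched version of the single-query case. The sketch below uses random stand-in data with the same sizes as in Section 5 and checks that the batched output matches querying with each row on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(8, 16)      # 8 queries, one per row
K = torch.randn(100, 16)
V = torch.randn(100, 784)

# All 8 lookups and retrievals in one pair of matrix products
batched = F.softmax(Q @ K.T, dim=-1) @ V   # (8, 784)

# Same result, one query at a time
one_by_one = torch.cat([F.softmax(q.unsqueeze(0) @ K.T, dim=-1) @ V for q in Q])

print(batched.shape)  # torch.Size([8, 784])
print(torch.allclose(batched, one_by_one, atol=1e-5))
```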

**Note:** Because we used random vectors for the queries and keys, the lookup and retrieval don't really make a whole lot of sense. Next up, you will learn how to train an attention layer to actually do the lookup and retrieval the way it is supposed to work.

## 6. PyTorch Multi-Head Attention
Of course, PyTorch has its own implementation of *scaled* dot-product attention, which divides $QK^T$ by $\sqrt{d_k}$ (where $d_k$ is the dimension of the keys) before applying the softmax.

[PyTorch MultiheadAttention](https://pytorch.org/docs/2.1/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention)

It supports multiple attention heads, as in the Transformer paper, but we will only use one head for simplicity.
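A minimal sketch comparing a hand-written $\text{softmax}(QK^T/\sqrt{d_k})V$ with `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0). The tensor sizes are arbitrary, and the 4-D layout (batch, heads, sequence, features) is a choice made to match that function's documented input shape.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 16
Q = torch.randn(1, 1, 8, d_k)     # (batch, heads, num_queries, d_k)
K = torch.randn(1, 1, 100, d_k)   # (batch, heads, num_keys, d_k)
V = torch.randn(1, 1, 100, 32)    # (batch, heads, num_keys, d_v)

# Manual version: the 1/sqrt(d_k) factor keeps the scores from growing with d_k
weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
manual = weights @ V

# PyTorch's built-in version (its default scale is 1/sqrt(d_k))
builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, builtin, atol=1e-5))
```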

Recall from lecture 12 that in the transformer model the queries, keys, and values are actually calculated by multiplying the vector embeddings of the inputs and/or outputs (denoted $X$ and $Y$ in the slides) with *learnable* matrices $W_q$, $W_k$, and $W_v$. For instance, for cross attention we would have

> $Q=YW_q$

> $K=XW_k$

> $V=XW_v$

In the general case, if we denote the embeddings $X_q$, $X_k$, and $X_v$, the queries, keys, and values are

> $Q=X_qW_q$

> $K=X_kW_k$

> $V=X_vW_v$
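As a minimal sketch of the cross-attention case, the projections can be written with bias-free `nn.Linear` layers standing in for $W_q$, $W_k$, and $W_v$. The sequence lengths (5 and 7) and the embedding size (16) are arbitrary. Note that `nn.Linear` stores its weight as (out_features, in_features), so `layer(X)` computes `X @ layer.weight.T`; in the notation above, $W_q$ corresponds to `W_q.weight.T`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 16
Y = torch.randn(5, embed_dim)   # 5 embeddings on the query side
X = torch.randn(7, embed_dim)   # 7 embeddings on the key/value side

# Learnable projections (no bias, to match Q = Y W_q etc.)
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)

# Cross attention: queries from Y, keys and values from X
Q = W_q(Y)   # (5, 16), equal to Y @ W_q.weight.T
K = W_k(X)   # (7, 16)
V = W_v(X)   # (7, 16)

print(Q.shape, K.shape, V.shape)
print(torch.allclose(Q, Y @ W_q.weight.T))
```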



### 6.1 Question
- If we denote the input data $X$, how do you calculate $Q$, $K$, and $V$ for *self-attention*? (Easy question!)

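The code block below shows how to use PyTorch's MultiheadAttention layer. Note that it calculates $Q$, $K$, and $V$ internally using the equations above (by default it also adds learnable bias terms), and it applies a final learnable output projection to the retrieved values. So we just need to provide $X_q$, $X_k$, and $X_v$ as input: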

In [None]:
# Define the size for each of the vectors
vec_size = 16

# Number of attention heads
num_heads = 1

# Batch Size
batch_size = 4

# Create a batch of a single random query vector
X_q = torch.randn(batch_size, 1, vec_size)

# Create random key and value vectors for each image in the dataset
X_k = torch.randn(batch_size, num_of_examples, vec_size)
X_v = torch.randn(batch_size, num_of_examples, vec_size)

# Initialize a MultiheadAttention module with specified parameters
multihead_attn = nn.MultiheadAttention(vec_size, num_heads, batch_first=True)

# Perform a forward pass through the Multi-Head Attention module
# Returns the attention output and the attention weights
attn_output, attn_output_weights = multihead_attn(X_q, X_k, X_v, average_attn_weights=False)

print("X_q:", X_q.detach().numpy().shape)
print("X_k:", X_k.detach().numpy().shape)
print("X_v:", X_v.detach().numpy().shape)

# Print the shapes of the output of the forward pass from Multi-Head Attention module

# Attention weights (softmax output) shape
print("Attention weights:", attn_output_weights.detach().numpy().shape)

# Attention output shape
print("Attention output (retrieval):", attn_output.detach().numpy().shape)

### 6.2 Questions
- What does the first dimension (size 4) of these arrays represent?
- What does `attn_output` represent?
- What does `attn_output_weights` represent?
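For a single head, the layer's output can be reproduced by hand from its parameters. This sketch relies on the attributes `in_proj_weight`, `in_proj_bias`, and `out_proj`, which `nn.MultiheadAttention` exposes when queries, keys, and values share the embedding size. It shows the in-projections (with bias), the $1/\sqrt{d_k}$ scaling, and the final output projection; the input sizes mirror the cell above.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
E = 16   # embedding size (single head, so head_dim == E)
mha = nn.MultiheadAttention(E, num_heads=1, batch_first=True)

X_q = torch.randn(4, 1, E)
X_k = torch.randn(4, 100, E)
X_v = torch.randn(4, 100, E)

with torch.no_grad():
    # W_q, W_k, W_v are stacked in in_proj_weight (3E x E); biases in in_proj_bias
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

    Q = X_q @ W_q.T + b_q
    K = X_k @ W_k.T + b_k
    V = X_v @ W_v.T + b_v

    # Scaled dot-product attention followed by the output projection
    weights = F.softmax(Q @ K.transpose(1, 2) / math.sqrt(E), dim=-1)   # (4, 1, 100)
    manual = mha.out_proj(weights @ V)                                   # (4, 1, 16)

    out, w = mha(X_q, X_k, X_v)

print(torch.allclose(manual, out, atol=1e-5), torch.allclose(weights, w, atol=1e-5))
```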

### 6.3 Train a Multi-Head Attention model
Let's train an attention model that, given an image, tries to find the best matching image in `image_database`.

Specifically, our goal is to train a model that, given any of the 60,000 images in the MNIST dataset as the query, will match the query against `image_database` and retrieve the best matching image.

Note that in this particular example, we enforce that the key and value embeddings are the same, i.e., we set $X_k=X_v$. To compute the embeddings, we use a one-layer MLP (a single linear layer) that is shared between the query images and the database images.

In [None]:
class AttentionTest(nn.Module):
    def __init__(self, num_of_examples=100, embed_dim=32, num_heads=1):
        super(AttentionTest, self).__init__()

        # Simple one-layer MLP used to embed images
        # Note the embedding is part of the model (learnable)
        self.img_mlp = nn.Sequential(
            nn.Linear(784, embed_dim),   # Linear layer to embed image data into a lower-dimensional space
        )

        # Define the Multi-Head Attention mechanism
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,     # Dimensionality of the embedding space
            num_heads=num_heads,     # Number of attention heads
            batch_first=True         # Whether the input is batch-first or sequence-first
        )

    def forward(self, img, image_database):
        # Embed the query image (img)
        X_q = self.img_mlp(img)

        # Embed the images in the database
        X_v = self.img_mlp(image_database)
        X_k = X_v # Key and Value embeddings are the same in this example

        # Lookup (only the attention weights are used; attn_output is not needed here)
        attn_output, attn_output_weights = self.mha(X_q, X_k, X_v)

        # Retrieval
        # (by design our model will learn to retrieve images from image_database)
        output = torch.bmm(attn_output_weights, image_database)

        return output, attn_output_weights

### 6.4 Question
- See if you can understand what the `AttentionTest` module does.

### 6.5 Set up training

In [None]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device(0 if torch.cuda.is_available() else 'cpu')

# Define the dimensionality of the embedding space
embed_dim = 32

# Define the number of attention heads
num_heads = 1

# Define the batch size
batch_size = 64

# Duplicate the data value tensor for each batch element and move it to the specified device
image_database_batches = image_database.unsqueeze(0).expand(batch_size, num_of_examples, -1).to(device)

In [None]:
# Create a DataLoader for training the model
train_loader = dataloader.DataLoader(
    dataset,                   # Dataset to load
    shuffle=True,              # Shuffle the data for each epoch
    batch_size=batch_size,     # Batch size for training
    num_workers=4,             # Number of processes to use for data loading
    drop_last=True             # Drop the last incomplete batch if it's smaller than the batch size
)

### 6.6 Initialize Model and Optimizer

In [None]:
# Create an instance of the AttentionTest model
mha_model = AttentionTest(
    num_of_examples=num_of_examples,   # Number of images in image_database
    embed_dim=embed_dim,               # Dimensionality of the embedding space
    num_heads=num_heads                # Number of attention heads
).to(device)                           # Move the model to the specified device

# Define the Adam optimizer for training the model
optimizer = optim.Adam(
    mha_model.parameters(),  # Parameters to optimize
    lr=1e-4                   # Learning rate
)

# List to store the training loss for each training step (batch)
loss_logger = []

### 6.7 Training

In [None]:
# Set the model to training mode
mha_model.train()

# Loop through 10 epochs
for _ in trange(10, leave=False):
    # Iterate over the training data loader
    for queries, _ in tqdm(train_loader, leave=False):

        # Reshape the input queries and move them to the specified device
        X_q = queries.reshape(queries.shape[0], 1, -1).to(device)

        # Perform forward pass through the Multi-Head Attention model
        attn_output, attn_output_weights = mha_model(X_q, image_database_batches)

        # Calculate the mean squared error loss between the output and input images
        loss = (attn_output - X_q).pow(2).mean()

        # Zero the gradients, perform backward pass, and update model parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Append the current loss value to the loss logger
        loss_logger.append(loss.item())
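The loss in the training loop is the mean squared error between the retrieved image and the query image. `torch.nn.functional.mse_loss`, with its default `reduction="mean"`, computes the same quantity; the random tensors below are stand-ins with the training batch shapes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
retrieved = torch.randn(64, 1, 784)   # stand-in for the model's retrieved images
X_q = torch.randn(64, 1, 784)         # stand-in for the query images

loss_manual = (retrieved - X_q).pow(2).mean()
loss_builtin = F.mse_loss(retrieved, X_q)   # default reduction="mean"

print(torch.allclose(loss_manual, loss_builtin))
```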

Plot the loss:

In [None]:
_ = plt.plot(loss_logger[100:])
print("Minimum MSE loss %.4f" % np.min(loss_logger))

### 6.8 Using the Model
Let's see how to use the model to look up and retrieve images from the database.

In [None]:
# Set the model to evaluation mode
mha_model.eval()

# Perform forward pass without gradient computation
with torch.no_grad():
    queries, _ = next(iter(train_loader))

    # Reshape input data and move it to the specified device
    X_q = queries.reshape(queries.shape[0], 1, -1).to(device)

    # Perform forward pass through the Multi-Head Attention model
    attn_output, attn_output_weights = mha_model(X_q, image_database_batches)

In [None]:
# Select index of query image
q_index = 10

In [None]:
# Show the query image
plt.figure(figsize=(3, 3))
out = torchvision.utils.make_grid(X_q[q_index].cpu().reshape(-1, 1, 28, 28), 1,
                                  normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

Show the retrieval (i.e., the sum of all samples in `image_database`, weighted by the attention weights).

In [None]:
plt.figure(figsize=(3, 3))
out = torchvision.utils.make_grid(attn_output[q_index].cpu().reshape(-1, 1, 28, 28), 1,
                                  normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

Show attention weights (how much is retrieved from each of the 100 samples in `image_database`).

In [None]:
# Plot the attention weights for the given input
_ = plt.plot(attn_output_weights[q_index, 0].cpu().numpy().flatten())

Use the attention weights to find the 10 "closest" matches.

In [None]:
top10 = attn_output_weights[q_index, 0].argsort(descending=True)[:10]
top10_data = image_database_batches[q_index, top10].cpu()

plt.figure(figsize=(10, 4))
out = torchvision.utils.make_grid(top10_data.reshape(-1, 1, 28, 28), 10, normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))
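As an alternative to `argsort(descending=True)[:10]`, `torch.topk` returns the largest weights and their indices directly. In this sketch, the random softmax weights are a stand-in for one row of `attn_output_weights`.

```python
import torch

torch.manual_seed(0)
# Stand-in for attn_output_weights[q_index, 0]: 100 positive weights summing to 1
weights = torch.softmax(torch.randn(100), dim=0)

top10_argsort = weights.argsort(descending=True)[:10]
top10_values, top10_topk = torch.topk(weights, k=10)   # sorted largest first by default

print(torch.equal(top10_argsort, top10_topk))
print(top10_values.sum().item())   # share of the total weight held by the 10 best matches
```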

Show the query image and the retrieved image (weighted sum of all images in `image_database`) for every image in the batch.

In [None]:
# Reshape the target and returned images
target_img = X_q.reshape(batch_size, 1, 28, 28)
indexed_img = attn_output.reshape(batch_size, 1, 28, 28)

# Stack the images with the returned image on top
img_pair = torch.cat((indexed_img, target_img), 2).cpu()

# Let's visualize the pairs of images, with the returned image on top and the target on bottom
plt.figure(figsize=(10, 10))
out = torchvision.utils.make_grid(img_pair, 8, normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))