<a href="https://colab.research.google.com/github/mtwenzel/image-video-understanding/blob/master/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-Attention by Example

Partially taken from https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a#8481

In [None]:
#@title Imports
import math

import torch
from torch.nn.functional import softmax

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

## Create a tensor with input data

In a real setting, this tensor would be the result of some data encoding step (when considering the first input to the first attention layer), or the result of the previous attention layer.

If for example your input is text, and your word embeddings have 512 dimensions, each row in the tensor $x$ would have 512 entries, and the tensor would have as many rows as your sentence has words.

For images (here described along the lines of the ViT), the number of rows in $x$ is the number of tiles (patches) the image is subdivided into. ViT uses patches of $16 \times 16$ pixels, so a $224 \times 224$ image yields $14 \times 14 = 196$ patches. The length of each row is the length of the embedding vector obtained by transforming each tile with the patch encoder. [See the ViT publication for details](https://arxiv.org/abs/2010.11929).

Note that there are two sizes to keep apart: the embedding size of the original data, and the internal encoding size, which is determined by the shape of the K, Q, and V weight matrices.
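
To make these sizes concrete, here is a small sketch of the resulting shapes of $x$. The sentence length, the 224-pixel image size, and the 768-dimensional patch embedding are illustrative assumptions (they correspond to the ViT-Base configuration, ignoring the extra class token that ViT prepends), not values used in this notebook. The two helper functions are hypothetical.

```python
# Illustrative shapes only; the concrete numbers are assumptions, not notebook values.

def text_input_shape(num_words, embedding_dim):
    """One row per word, one column per embedding feature."""
    return (num_words, embedding_dim)

def image_input_shape(image_size, patch_size, embedding_dim):
    """One row per square patch; assumes image_size is divisible by patch_size."""
    patches_per_side = image_size // patch_size
    return (patches_per_side ** 2, embedding_dim)

text_shape = text_input_shape(num_words=8, embedding_dim=512)
image_shape = image_input_shape(image_size=224, patch_size=16, embedding_dim=768)
print(text_shape)   # (8, 512)
print(image_shape)  # (196, 768)
```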

In [None]:
#@title Define data matrix {run:"auto"}
#@markdown We create random(!) data of a given size. Set the number of tokens and the encoding/embedding length here.
num_tokens = 6 #@param {type:"slider", min:"1", max:"16"}
num_embedding_features = 9 #@param {type:"slider", min:"1", max:"32"}

#@markdown Random data lets us demonstrate the mechanics in the following cells, but it does not represent a real task.

x = torch.rand([num_tokens,num_embedding_features])
x

In [None]:
#@title Define second data matrix for cross attention {run:"auto"}
#@markdown To demonstrate cross attention with a smaller attention matrix, analogous to Perceiver or DETR, create a "learned queries" matrix $x_2$.
#@markdown For consistency, you can only adjust the number of tokens. The embedding dimension is kept from above.
num_ca_tokens = 3 #@param {type:"slider", min:"1", max:"10"}

x2 = torch.rand([num_ca_tokens, num_embedding_features])
x2

## Create a set of weight tensors
We are looking at single-head attention only for the moment. For multi-head attention, each weight matrix would be replicated (with independent weights) for each head. You will see this in the second half of the notebook.

Each weight tensor must have as many rows as the tokens have dimensions. Our input vectors have `num_embedding_features` dimensions. Let's create random weight matrices of the corresponding size. You are free to select the other dimension, which then becomes the internal embedding dimension.

Observe how the size of these matrices does not depend on the number of tokens anymore.

Note that the attention mechanism will therefore output a matrix with one row per input token, instead of a single token.
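
The independence from the number of tokens can be checked directly. The following is a minimal sketch, assuming an internal dimension of 7 and two hypothetical inputs of 4 and 10 tokens: one and the same weight matrix projects both.

```python
import torch

num_embedding_features = 9          # same default as the slider above
internal_embedding_dimensions = 7   # hypothetical choice for the free dimension

# One weight matrix, defined without any reference to the number of tokens
w = torch.rand([num_embedding_features, internal_embedding_dimensions])

short_input = torch.rand([4, num_embedding_features])   # 4 tokens (assumed)
long_input = torch.rand([10, num_embedding_features])   # 10 tokens (assumed)

# The same matrix projects both inputs; only the row count of the result changes
short_projected = short_input @ w
long_projected = long_input @ w
print(short_projected.shape)  # torch.Size([4, 7])
print(long_projected.shape)   # torch.Size([10, 7])
```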

In [None]:
#@title Get K, Q, V transform matrices {run:"auto"}

internal_embedding_dimensions = 7 #@param {type:"slider", min:"1", max:"32"}

w_key = torch.rand([num_embedding_features,internal_embedding_dimensions])
w_query = torch.rand([num_embedding_features,internal_embedding_dimensions])
w_value = torch.rand([num_embedding_features,internal_embedding_dimensions])

print(f'Initialized random tensors w_key {tuple(w_key.shape)}, w_query {tuple(w_query.shape)}, w_value {tuple(w_value.shape)}.')

## K, Q, and V

The actual keys, queries and values are the result of multiplying the input tensor with the weight tensors.

Their dimensions are:
* each row has as many entries as the weight tensors have columns (`internal_embedding_dimensions`, 7 with the default settings)
* the number of rows equals the number of input tokens (`num_tokens`, 6 with the default settings)
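
In practice, such projections are often written as bias-free linear layers rather than raw matrices. The sketch below assumes the default slider values and uses `torch.nn.Linear`, which is not used elsewhere in this notebook, to check that both formulations give the same keys.

```python
import torch
from torch import nn

# Assumed: default slider values
num_tokens, num_embedding_features, internal_embedding_dimensions = 6, 9, 7

x = torch.rand([num_tokens, num_embedding_features])

# nn.Linear stores its weight as (out_features, in_features),
# i.e. the transpose of the w_key layout used in this notebook.
key_layer = nn.Linear(num_embedding_features, internal_embedding_dimensions, bias=False)
w_key = key_layer.weight.detach().T

keys_from_layer = key_layer(x).detach()
keys_from_matmul = x @ w_key

print(keys_from_matmul.shape)  # torch.Size([6, 7])
print(torch.allclose(keys_from_layer, keys_from_matmul, atol=1e-5))
```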

In [None]:
# The '@' operator performs matrix multiplication in Python
# (for PyTorch tensors it is equivalent to `torch.matmul()`)
keys = x @ w_key
querys = x @ w_query
values = x @ w_value

# Queries for cross attention, computed from an input with a potentially different number of tokens
xattn_q = x2 @ w_query


print("Keys:",keys)
print("Queries:",querys)
print("Cross-attention Queries:",xattn_q)
print("Values:",values)

## Softmax Attention

The size of the square attention matrix equals the number of input tokens in both dimensions. 
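
The cells in this notebook use unscaled dot products. The original Transformer formulation additionally divides the scores by $\sqrt{d_k}$, the square root of the key dimension, before the softmax. A minimal sketch of that variant, using random stand-in tensors whose shapes are assumed from the default slider values:

```python
import math
import torch
from torch.nn.functional import softmax

num_tokens, d_k = 6, 7  # assumed: default num_tokens and internal_embedding_dimensions

# Stand-ins for the queries and keys computed in the notebook
queries = torch.rand([num_tokens, d_k])
keys = torch.rand([num_tokens, d_k])

# Scaled dot-product scores: one row per query, one column per key
scaled_scores = (queries @ keys.T) / math.sqrt(d_k)
attention = softmax(scaled_scores, dim=-1)

print(attention.shape)        # torch.Size([6, 6])
print(attention.sum(dim=-1))  # every row sums to 1 (up to rounding)
```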

In [None]:
# keys.T transposes the keys matrix
attn_scores = querys @ keys.T
attn_scores_softmax = softmax(attn_scores, dim=-1)

# For readability, round the scores to a definable number of decimal places
print(attn_scores_softmax.round(decimals = 2))

# Plot self attention matrix
sns.heatmap(attn_scores_softmax.numpy())

### Do the same for "cross attention"
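
In cross attention the queries come from a different input than the keys and values, so the attention matrix is generally not square: it has one row per query token and one column per key token. A minimal sketch with random stand-in tensors follows. All sizes are assumed default slider values, and the final multiplication with the values is an illustrative addition.

```python
import torch
from torch.nn.functional import softmax

num_tokens, num_ca_tokens, d = 6, 3, 7  # assumed default slider values

xattn_queries = torch.rand([num_ca_tokens, d])  # stand-in for queries derived from x2
keys = torch.rand([num_tokens, d])              # stand-in for keys derived from x
values = torch.rand([num_tokens, d])            # stand-in for values derived from x

xattn = softmax(xattn_queries @ keys.T, dim=-1)  # (num_ca_tokens, num_tokens)
xattn_output = xattn @ values                    # one output row per query token

print(xattn.shape)         # torch.Size([3, 6])
print(xattn_output.shape)  # torch.Size([3, 7])
```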

In [None]:
xattn_scores = xattn_q @ keys.T
xattn_scores_softmax = softmax(xattn_scores, dim=-1)

# For readability, round the scores to a definable number of decimal places
print(xattn_scores_softmax.round(decimals = 2))

# Plot cross attention matrix
sns.heatmap(xattn_scores_softmax.numpy(), square = True)

## Multiply softmax attention with V to obtain the result

In a transformer, dense layers would follow that can
* reduce a multi-head attention result
* enforce the correct dimensionality so that the output can be used as the next input.

We will see this after the following experiment with multi-head attention.
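
The weighted sum of the values can be computed either by broadcasting and summing or, equivalently, as a single matrix product of the attention matrix with the values. A minimal sketch with random stand-in tensors (sizes assumed) that checks that both routes agree:

```python
import torch
from torch.nn.functional import softmax

num_tokens, d = 6, 7  # assumed default slider values
attention = softmax(torch.rand([num_tokens, num_tokens]), dim=-1)  # rows: queries, columns: keys
values = torch.rand([num_tokens, d])

# Route 1: scale every value vector by its attention weight, then sum over the keys
weighted = values[:, None] * attention.T[:, :, None]  # (keys, queries, d)
outputs_broadcast = weighted.sum(dim=0)               # (queries, d)

# Route 2: the same weighted sum as one matrix product
outputs_matmul = attention @ values                   # (queries, d)

print(torch.allclose(outputs_broadcast, outputs_matmul, atol=1e-5))
```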

In [None]:
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
outputs = weighted_values.sum(dim=0)
print(outputs)

# Multi-Head Attention
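
Stacking the weight matrices along a new leading dimension is enough because `@` broadcasts a 2D input against a batch of matrices: the same tokens are projected once per head. A minimal sketch (sizes assumed from the default slider values) that checks the broadcast result against an explicit per-head loop:

```python
import torch

num_heads, num_tokens, num_embedding_features, d = 4, 6, 9, 7  # assumed defaults

x = torch.rand([num_tokens, num_embedding_features])
w_key = torch.rand([num_heads, num_embedding_features, d])  # one matrix per head

# Broadcasting: (tokens, features) @ (heads, features, d) -> (heads, tokens, d)
keys = x @ w_key

# Reference: project with each head's matrix separately and stack the results
keys_loop = torch.stack([x @ w_key[h] for h in range(num_heads)])

print(keys.shape)  # torch.Size([4, 6, 7])
print(torch.allclose(keys, keys_loop, atol=1e-5))
```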

In [None]:
#@title Define number of heads, and create corresponding random weight matrices K, Q, V {run:"auto"}
#@markdown The only required change is to stack multiple K, Q, V. 

num_heads = 4 #@param {type:"slider", min:"2", max:"10"}

w_key = torch.rand([num_heads, num_embedding_features,internal_embedding_dimensions])
w_query = torch.rand([num_heads,num_embedding_features,internal_embedding_dimensions])
w_value = torch.rand([num_heads,num_embedding_features,internal_embedding_dimensions])
print(f'Initialized new random tensors w_key {tuple(w_key.shape)}, w_query {tuple(w_query.shape)}, w_value {tuple(w_value.shape)}.')

In [None]:
keys = x @ w_key
querys = x @ w_query

xattn_q = x2 @ w_query

values = x @ w_value

print(keys)
#print(querys)
#print(xattn_q)
#print(values)

In [None]:
# .mT transposes the last two dimensions
# (transposing each keys matrix, independently for each head)
attn_scores = querys @ keys.mT 
attn_scores_softmax = softmax(attn_scores, dim=-1)

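def show_multihead_attention(attn_scores_softmax, column_count = 2):
  '''Plot attention matrices for different heads'''
  head_count = attn_scores_softmax.shape[0]
  # Round up so that an odd number of heads still gets one subplot each
  row_count = math.ceil(head_count / column_count)
  f, axx = plt.subplots(row_count, column_count,
                        sharex = True, sharey = True, squeeze = False,
                        figsize = (column_count * 2, row_count * 2))
  for ax, t in zip(axx.ravel(), attn_scores_softmax.numpy()):
    sns.heatmap(t, ax = ax, square = True)
  # Hide subplots that are left without a head
  for ax in axx.ravel()[head_count:]:
    ax.set_visible(False)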

show_multihead_attention(attn_scores_softmax)

### Again, do the same for "cross attention"

In [None]:
xattn_scores = xattn_q @ keys.mT
xattn_scores_softmax = softmax(xattn_scores, dim=-1)

show_multihead_attention(xattn_scores_softmax)

In [None]:
weighted_values = values[:,:,None] * attn_scores_softmax.mT[:,:,:,None]
outputs = weighted_values.sum(dim=1)
print(outputs)
outputs.shape

## Convert into expected shape for next layer

An MLP (here a single linear layer) is employed to "fix the dimensions". We require the output shape to match the original $x$ input shape. This can be achieved by creating an MLP weight matrix of the appropriate shape.

It has one dimension given by the number of heads times the internal embedding dimension, the other by the number of original embedding features.

Notice that the MLP requires a 2D tensor input, so the heads' outputs need to be flattened. The MLP can then mix the results from the different heads (per token).

Consequently, the MLP has a rather large weight matrix.
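
To see what the flattening step does, the sketch below uses random stand-in head outputs (sizes assumed from the default slider values). After `permute` and `flatten`, each token's row is the concatenation of that token's outputs from all heads, which a single weight matrix then maps back to the input width.

```python
import torch

num_heads, num_tokens, d, num_embedding_features = 4, 6, 7, 9  # assumed defaults

outputs = torch.rand([num_heads, num_tokens, d])  # stand-in for the per-head attention outputs

# (heads, tokens, d) -> (tokens, heads, d) -> (tokens, heads * d)
combined_heads = outputs.permute(1, 0, 2).flatten(1)

# The first d columns of every row hold head 0, the next d columns head 1, and so on
assert torch.equal(combined_heads[:, :d], outputs[0])
assert torch.equal(combined_heads[:, d:2 * d], outputs[1])

projection = torch.rand([num_heads * d, num_embedding_features])
result = combined_heads @ projection

print(combined_heads.shape)  # torch.Size([6, 28])
print(result.shape)          # torch.Size([6, 9])
```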

### Create MLP weights
The size of the MLP output needs to match the original $x$ input data matrix.
This makes it possible to stack attention blocks. The shape of the weight matrix is therefore fully determined: it has no free hyperparameters.
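
Because the output has the same shape as the input, the whole computation can be wrapped in a function and applied repeatedly. The sketch below is a simplified illustration with hypothetical helpers `attention_block` and `random_weights`. It omits the residual connections, normalization, and feed-forward layers of a real transformer block, and all sizes are assumed default slider values.

```python
import torch
from torch.nn.functional import softmax

def attention_block(x, w_query, w_key, w_value, w_out):
    """Hypothetical helper: multi-head self-attention followed by the output projection."""
    queries, keys, values = x @ w_query, x @ w_key, x @ w_value  # each (heads, tokens, d)
    attention = softmax(queries @ keys.mT, dim=-1)                # (heads, tokens, tokens)
    per_head = attention @ values                                 # (heads, tokens, d)
    combined = per_head.permute(1, 0, 2).flatten(1)               # (tokens, heads * d)
    return combined @ w_out                                       # (tokens, features)

num_heads, num_tokens, num_features, d = 4, 6, 9, 7  # assumed default slider values

def random_weights():
    """Hypothetical helper: fresh random weights for one block, shaped as in this notebook."""
    projections = [torch.rand([num_heads, num_features, d]) for _ in range(3)]
    return (*projections, torch.rand([num_heads * d, num_features]))

x = torch.rand([num_tokens, num_features])
hidden = attention_block(x, *random_weights())        # first block
stacked = attention_block(hidden, *random_weights())  # second block consumes the first block's output

print(hidden.shape, stacked.shape)  # both torch.Size([6, 9])
```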

In [None]:
combined_heads = outputs.permute(1,0,2).flatten(1) # The heads are combined. The MLP will mix their results.

mlp_weights = torch.rand([num_heads*internal_embedding_dimensions, num_embedding_features])

print("MLP size:", mlp_weights.shape)
result = combined_heads @ mlp_weights
print("Resulting shape: ", result.shape)
print("Resulting new input to next attention block:", result)

assert result.shape==x.shape