# **Introducing The Attention Component of Transformers** 

According to Perplexity, **Attention**, which made a huge splash in this [seminal paper](https://arxiv.org/abs/1706.03762), can briefly be defined as:

>The attention mechanism mimics human cognitive processes by allowing a model to prioritize certain inputs over others based on their relevance to the task at hand. This is particularly useful in scenarios where the input data is large and complex, enabling the model to selectively concentrate on the most pertinent elements while ignoring less relevant information.
>
> ### Key Concepts
>
>    **Encoder-Decoder Architecture**: The attention mechanism is often employed within an encoder-decoder framework. The encoder processes the input sequence and generates a set of hidden states, while the decoder uses these states to produce the output sequence. Traditional models would pass only the final hidden state from the encoder to the decoder, but attention allows for all hidden states to be considered.
>    
>    **Attention Weights**: The mechanism assigns weights to different parts of the input, indicating their relative importance. These weights are computed dynamically for each input, using parameters learned during training, and are used to create a context vector that emphasizes significant input elements.
>
> ### How Attention Works
>
>    **Calculating Attention Scores**: For each element in the output sequence, attention scores are computed by comparing it with all elements in the input sequence. This can be done using methods such as dot-product or additive attention, where each score reflects how relevant an input element is to a particular output element.
>
>    **Creating Context Vectors**: The attention scores are normalized using a softmax function to produce a probability distribution. This distribution is then used to compute a weighted sum of the input elements, resulting in a context vector that highlights important features.
>
>    **Decoding Process**: The context vector is fed into the decoder alongside its current hidden state. This allows the decoder to generate output tokens based on both the immediate context and relevant parts of the input.
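
The three steps above (score, softmax, weighted sum) can be sketched for a single decoder step as follows. This is a minimal illustration assuming plain dot-product scoring and random tensors; `enc_states` and `dec_state` are hypothetical names, not part of any library.

```python
import torch

torch.manual_seed(0)
T, d = 5, 8                        # 5 encoder positions, 8-dim hidden states (arbitrary sizes)
enc_states = torch.randn(T, d)     # hypothetical encoder hidden states, one per input token
dec_state = torch.randn(d)         # hypothetical current decoder hidden state

# 1. Attention scores: dot product of the decoder state with every encoder state
scores = enc_states @ dec_state    # shape (T,)
# 2. Softmax turns the scores into a probability distribution over input positions
weights = scores.softmax(dim=0)    # shape (T,), sums to 1
# 3. Context vector: weighted sum of the encoder states
context = weights @ enc_states     # shape (d,)

print(weights, weights.sum())
print(context.shape)
```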

Note that Stable Diffusion's implementation of attention is quite sub-optimal; we may move to better approaches in later notebooks.

Initially, we will focus on **1d-Attention**, which was predominantly used for NLP. For Stable Diffusion, we will flatten each channel's H×W grid of pixels into a single vector, so that every pixel becomes one position in a sequence.

In [1]:
import math, torch
from torch import nn
from miniai.activations import *
import matplotlib.pyplot as plt
from diffusers.models.attention import Attention # AttentionBlock has been deprecated

In [2]:
set_seed(42)
# Creating a tensor to represent a batch of 64 16x16 images with 32 channels (NCHW)
# NLP implementations call the HxW (16x16 = 256) pixels a sequence, and usually put the sequence axis before the dimension / channel axis
x = torch.randn(64, 32, 16, 16)

In [3]:
# To replicate 1d-attention, we first need to flatten the input tensor.
# In view(), -1 stands for 'everything else'. Transposing then gives NLP's equivalent layout,
# NSD (Batch x Sequence x Dimension)
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])
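
The flattening in the cell above is lossless and easy to undo. A self-contained sketch on a fresh random tensor of the same shape: each sequence position holds one pixel's 32 channel values (position = row × 16 + column), and transposing back and reshaping recovers the NCHW tensor exactly.

```python
import torch

torch.manual_seed(0)
x = torch.randn(64, 32, 16, 16)                   # N, C, H, W
t = x.view(*x.shape[:2], -1).transpose(1, 2)      # N, H*W, C -> (64, 256, 32)

# Pixel (row 2, col 3) is sequence position 2*16 + 3 = 35
print(torch.equal(t[0, 35], x[0, :, 2, 3]))

# Undo the flattening: swap back to (N, C, H*W), then restore the H and W axes
x_back = t.transpose(1, 2).reshape(x.shape)
print(torch.equal(x, x_back))
```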

>Self-attention is a crucial mechanism in modern neural networks, especially within the context of natural language processing (NLP) and transformer architectures. It allows models to weigh the significance of different parts of an input sequence relative to each other, enabling the capture of complex dependencies and contextual relationships.
>
> ### Key Components of Self-Attention
>
> Self-attention operates using three main components derived from the input sequence:
>
>    **Query (Q)**: This vector represents the current focus or context for a specific word. It is generated through a linear transformation of the input embedding.
>
>    **Key (K)**: Each word in the input sequence has an associated key vector, which serves as a reference point. The model compares the query vector with all key vectors to determine relevance.
>
>    **Value (V)**: The value vectors hold the actual information content associated with each word. After calculating attention scores based on the similarity between queries and keys, these value vectors are weighted accordingly to produce the output.
>
> ### Process of Self-Attention
>
>    **Linear Transformations**: The input embeddings are transformed into three separate matrices—Q, K, and V—using learned weight matrices.
>
>    **Attention Scores Calculation**: The attention score for each pair of words is computed by taking the dot product of the query vector with all key vectors. This score indicates how much focus should be placed on each word when processing a particular word.
>
>    **Softmax Normalization**: The scores are normalized using a softmax function to create a probability distribution, ensuring that they sum to one.
>
>    **Weighted Sum**: The output for each word is obtained by calculating a weighted sum of the value vectors, where weights correspond to the normalized attention scores.
>
>    **Final Output**: The resulting context-aware representations are then passed through additional layers, typically including feed-forward neural networks, to produce the final output.
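
In symbols, these steps are the scaled dot-product attention of the original paper. The summary above leaves out one detail: the scores are divided by $\sqrt{d_k}$, where $d_k$ is the dimension of the key vectors, which keeps the dot products from growing with the vector size before the softmax.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

With $Q, K, V \in \mathbb{R}^{S \times d_k}$ for a sequence of length $S$ (below, $S = 256$ pixels and $d_k = 32$ channels), $QK^\top$ is an $S \times S$ matrix of scores and the softmax is taken along each row.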

In [4]:
# Number of input channels
ni = 32

In [5]:
# We now need 3 projections for 32 in channels to 32 out channels
# Creating simple linear layers (matmul plus a bias). Randomly initializing.
sk = nn.Linear(ni, ni)
sq = nn.Linear(ni, ni)
sv = nn.Linear(ni, ni)

In [6]:
# For self attention, the technical parlance refers to these projections as keys, queries and values
k = sk(t)
q = sq(t)
v = sv(t)

k.shape, q.shape, v.shape

(torch.Size([64, 256, 32]),
 torch.Size([64, 256, 32]),
 torch.Size([64, 256, 32]))

In [7]:
# Matmul with the transpose. For each of the 64 items in the batch and each of the 256 pixels, we now have 256 scores (one per pixel)
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])
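
The raw scores are not yet attention weights. A sketch of the remaining steps, using freshly initialised projections of the same sizes as above (random data, so the values themselves are meaningless): scale, softmax along the last dimension so that each pixel's 256 weights sum to 1, then take the weighted sum of the values.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
ni = 32
t = torch.randn(64, 256, ni)                        # (batch, pixels, channels)
sq, sk, sv = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
q, k, v = sq(t), sk(t), sv(t)

scores = q @ k.transpose(1, 2) / math.sqrt(ni)      # (64, 256, 256)
weights = scores.softmax(dim=-1)                    # each row is a distribution over pixels
out = weights @ v                                   # (64, 256, 32): one mixed value vector per pixel

print(weights.sum(dim=-1)[0, :4])                   # row sums, all ~1
print(out.shape)
```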

Time to put the last few cell blocks into a `SelfAttention()` class.

In [8]:
# Setting up a class for self attention. Note that this version operates on NCHW image feature maps
# (as produced by conv / ResNet blocks) and, like a ResNet block, adds its result back to the input.
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm  = nn.GroupNorm(1, ni) # One group: normalizes over all channels and pixels of each item (LayerNorm-like)
        self.q     = nn.Linear(ni, ni)
        self.k     = nn.Linear(ni, ni)
        self.v     = nn.Linear(ni, ni)
        self.proj  = nn.Linear(ni, ni) # final output projection, mixing the attended channels

    def forward(self, x):
        inp = x
        n, c, h, w = x.shape
        x = self.norm(x)
        x = x.view(n, c, -1).transpose(1, 2)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Dot products over ni channels have a std of roughly sqrt(ni); dividing by sqrt(ni) brings scores back to ~unit scale
        s = (q@k.transpose(1, 2)) / self.scale
        x = s.softmax(dim=-1) @ v
        x = self.proj(x) # Secondary projection
        x = x.transpose(1, 2).reshape(n,c,h,w) # reshaping back to the original
        return x + inp # adding outputs to the original. Diffusers does the same.
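
Why divide by `self.scale`? A quick numerical sketch: for random vectors with unit-variance entries, the dot product over `ni = 32` channels has a standard deviation of about √32 ≈ 5.7, large enough to push the softmax towards a single entry. Dividing by √ni brings it back to roughly 1. (Random data only; trained activations will not be exactly unit-variance.)

```python
import math
import torch

torch.manual_seed(0)
ni = 32
q = torch.randn(100_000, ni)
k = torch.randn(100_000, ni)

dots = (q * k).sum(dim=-1)        # 100k independent dot products over ni channels
scaled = dots / math.sqrt(ni)     # the same scaling SelfAttention applies

print(dots.std())                 # roughly sqrt(32) ~ 5.66
print(scaled.std())               # roughly 1
```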

In [9]:
sa = SelfAttention(32) # self attention layer

In [10]:
# Calling the self attention layer on the random input. Transposing and reshaping back inside forward()
# keeps the output shape identical to the input shape
ra = sa(x)
ra.shape

torch.Size([64, 32, 16, 16])

In [11]:
ra[0, 0, 0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

We need to be sure that our outputs align with Diffusers' `Attention` outputs.

In [12]:
def cp_params(a, b):
    # Make b use a's weight and bias (copies parameters from a into b)
    b.weight = a.weight
    b.bias = a.bias

In [13]:
# Diffusers attention (updated, since AttentionBlock was deprecated)
at = Attention(32, dim_head=32, out_dim=32, norm_num_groups=1, residual_connection=True)
# Copy our q, k, v, output projection and norm parameters into `at` so the outputs can be compared
src = sa.q, sa.k, sa.v, sa.proj, sa.norm
dst = at.to_q, at.to_k, at.to_v, at.to_out[0], at.group_norm
#  Pairwise zipping 
for s, d in zip(src, dst): cp_params(s, d)

In [14]:
rb = at(x)
rb[0, 0, 0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)
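
Comparing a single row by eye is a weak test. A stronger check of the attention core (scale, softmax, weighted sum) compares every element with `torch.allclose` against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later. This sketch checks only the core computation on random tensors, not the norm, projections or residual of `SelfAttention`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads=1, pixels, channels)
q, k, v = (torch.randn(4, 1, 256, 32) for _ in range(3))

# Hand-written attention core, as in SelfAttention.forward
manual = ((q @ k.transpose(-2, -1)) / math.sqrt(32)).softmax(dim=-1) @ v

# PyTorch's implementation; its default scale is 1/sqrt(last dim of q)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))
```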

The outputs are identical, so our attention block reproduces the (single-headed) diffusers `Attention` block. We can also compute all three projections with a single linear layer.

In [15]:
# Instead of three separate projections, we can create a single one. Its output dimension is 3x larger (ni*3),
# since it holds q, k and v side by side.
sqkv = nn.Linear(ni, ni*3)
st = sqkv(t)
st.shape

torch.Size([64, 256, 96])

In [16]:
# Chunking allows us to split along the last dimension to get q, k, v
q, k, v = torch.chunk(st, 3, dim=-1)
q.shape

torch.Size([64, 256, 32])
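
The fused layer is not an approximation: a `Linear(ni, 3*ni)` whose weight rows are the q, k and v weights stacked together produces exactly the three separate projections, and `torch.chunk` splits them back out. A self-contained check with random weights:

```python
import torch
from torch import nn

torch.manual_seed(0)
ni = 32
sq, sk, sv = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)

# Build one fused layer from the three separate ones: weight rows stacked as [q; k; v]
sqkv = nn.Linear(ni, 3 * ni)
with torch.no_grad():
    sqkv.weight.copy_(torch.cat([sq.weight, sk.weight, sv.weight], dim=0))
    sqkv.bias.copy_(torch.cat([sq.bias, sk.bias, sv.bias], dim=0))

t = torch.randn(64, 256, ni)
q, k, v = torch.chunk(sqkv(t), 3, dim=-1)
print(all(torch.allclose(a, b, atol=1e-5) for a, b in [(q, sq(t)), (k, sk(t)), (v, sv(t))]))
```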

In [19]:
# Based on the above, this is an alternative, more concise version of SelfAttention with
# a single projection for q, k, v.

# One larger matmul is usually somewhat more efficient than three small ones in PyTorch.
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm  = nn.BatchNorm2d(ni)
        self.qkv   = nn.Linear(ni, ni*3)
        self.proj  = nn.Linear(ni, ni)

    def forward(self, inp):
        n, c, h, w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        q, k, v = torch.chunk(self.qkv(x), 3, dim=-1) # Applying chunking to split along the last dim
        s = (q@k.transpose(1, 2)) / self.scale
        x = s.softmax(dim=-1) @ v
        x = self.proj(x).transpose(1, 2).reshape(n, c, h, w)
        return x + inp

In [20]:
sa = SelfAttention(32)
sa(x).shape

torch.Size([64, 32, 16, 16])

In [21]:
sa(x).std()

tensor(1.0094, grad_fn=<StdBackward0>)

Now that we've figured out how to apply self-attention, it is time to chuck it out the window, **because this single-headed approach is not what Stable Diffusion's UNet uses**.

Instead, we will use **Multi-Headed Attention**. According to Perplexity:

> The choice of using multi-headed attention over self-attention in training stable diffusion models is primarily driven by the need for enhanced feature representation and effective integration of textual and visual information.
> 
> ### Importance of Multi-Headed Attention
> 
> #### 1. Diverse Representation:
> 
> Multi-headed attention allows the model to capture different aspects of the input data simultaneously. Each attention head can focus on various parts of the input sequence, enabling the model to learn multiple relationships and features from the data. This diversity is crucial for tasks like text-to-image generation, where different textual prompts may relate to various visual features in an image.
> 
> #### 2. Cross-Attention Mechanism:
> 
> In stable diffusion models, cross-attention plays a vital role in merging information from text prompts with image features. The cross-attention layers enable the model to align and integrate the textual description with specific regions of the image, facilitating coherent image generation based on the provided prompt.
>
> This mechanism is essential for ensuring that generated images are consistent with their corresponding textual descriptions.
>
> #### 3. Preservation of Spatial Details:
>
> While self-attention is useful for understanding relationships within a single input (like an image), it does not effectively manage the integration of external information (like text). In contrast, multi-headed cross-attention allows for the preservation of geometric and spatial details during transformation processes, which is critical in maintaining the integrity of the original image while incorporating new elements from the text.
>
> ### Limitations of Self-Attention in This Context
> 
> Self-attention, while beneficial for capturing internal dependencies within an input sequence, does not provide the same level of flexibility when it comes to integrating external information. In stable diffusion models, relying solely on self-attention could lead to challenges in maintaining coherence between generated images and their corresponding textual prompts. The self-attention mechanism tends to focus more on preserving shape and structure rather than effectively merging different types of information.

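It is important to note that softmax exponentiates its inputs, so the largest score tends to dominate the resulting weights. With single-headed attention, each pixel would therefore attend almost exclusively to a single other pixel; with multiple heads, each head has its own softmax, so a pixel can attend strongly to a different pixel in each head.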
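
A small illustration of this effect: a moderate gap between scores becomes a large gap between weights, and doubling the scores (as larger-magnitude activations would) makes the top one dominate even more. The scores are arbitrary example values.

```python
import torch

scores = torch.tensor([1.0, 2.0, 5.0])   # arbitrary example scores
weights = scores.softmax(dim=0)
print(weights)                           # ≈ tensor([0.0171, 0.0466, 0.9362])

sharper = (scores * 2).softmax(dim=0)    # same ranking, larger gaps
print(sharper)                           # top weight ≈ 0.9972
```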

In [22]:
# This comes from the diffusers code and is the traditional way to approach the problem.
def heads_to_batch(x, heads):
    # batch, pixels, channels
    n, sl, d = x.shape
    x = x.reshape(n, sl, heads, -1) # split channels per head: e.g. 64 images x 256 pixels x 4 heads x 8 (32/4) channels
    return x.transpose(2, 1).reshape(n*heads, sl, -1) # move heads next to batch, then merge: (n*heads, sl, d/heads)

def batch_to_heads(x, heads):
    n, sl, d = x.shape # here n is batch*heads and d is the channels per head
    x = x.reshape(-1, heads, sl, d)
    return x.transpose(2, 1).reshape(-1, sl, d*heads)
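
To see why folding the heads into the batch dimension is valid, here is a sketch that computes multi-head attention twice: once by looping over the heads (each head gets a contiguous slice of `d/heads` channels), and once with the reshape/transpose trick from `heads_to_batch` and `batch_to_heads`, reimplemented inline as `to_batch`/`from_batch` so the block is self-contained. Random q, k, v stand in for real projections.

```python
import math
import torch

torch.manual_seed(0)
n, sl, d, heads = 2, 256, 32, 4
dh = d // heads                                   # channels per head (8)
q, k, v = (torch.randn(n, sl, d) for _ in range(3))

def attn(q, k, v):
    # plain scaled dot-product attention over the last two dims
    return ((q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])).softmax(dim=-1) @ v

# 1) Explicit loop: head i uses channels i*dh:(i+1)*dh
looped = torch.cat([attn(q[..., i*dh:(i+1)*dh], k[..., i*dh:(i+1)*dh], v[..., i*dh:(i+1)*dh])
                    for i in range(heads)], dim=-1)

# 2) Batched: move heads into the batch dimension, attend once, move them back
def to_batch(x):    # (n, sl, d) -> (n*heads, sl, dh), same as heads_to_batch
    return x.reshape(n, sl, heads, dh).transpose(1, 2).reshape(n * heads, sl, dh)

def from_batch(x):  # (n*heads, sl, dh) -> (n, sl, d), same as batch_to_heads
    return x.reshape(n, heads, sl, dh).transpose(1, 2).reshape(n, sl, d)

batched = from_batch(attn(to_batch(q), to_batch(k), to_batch(v)))
print(torch.allclose(looped, batched, atol=1e-5))
```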

`einops` provides `rearrange`, which uses a pattern language inspired by Einstein summation notation to express reshape, transpose and similar tensor rearranging operations.

In [23]:
from einops import rearrange

In [24]:
rearrange?

Signature:
rearrange(
    tensor: Union[~Tensor, List[~Tensor]],
    pattern: str,
    **axes_lengths,
) -> ~Tensor
Docstring:
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
stack, concatenate and other operations.

Examples for rearrange operation:

```python
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

# stack along first (batch) axis, output is
```

In [25]:
# Tensor rearrangement
# Take our rank 3 tensor with dims n (batch), s (sequence) and a channel dim that we treat as (h d), with h=8.
# After rearranging, the batch dim is n*h, the sequence length is unchanged, and the number of channels
# per item is d = 32/8 = 4.
t2 = rearrange(t, 'n s (h d) -> (n h) s d', h=8)
t.shape, t2.shape

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [26]:
# We can reverse the operations just as easily
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)
t2.shape, t3.shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [28]:
# Confirming that rearrangements return the same results.
(t==t3).all()

tensor(True)

In [29]:
class SelfAttentionMultiHead(nn.Module):
    def __init__(self, ni, nheads): # Adding an additional parameter nheads
        super().__init__()
        self.nheads = nheads
        self.scale  = math.sqrt(ni / nheads)
        self.norm   = nn.BatchNorm2d(ni)
        self.qkv    = nn.Linear(ni, ni*3)
        self.proj   = nn.Linear(ni, ni)

    def forward(self, inp): 
        n, c, h, w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        x = self.qkv(x)
        # Split the 3*ni = 96 qkv channels over the heads (for demo purposes, 4 heads of 24 channels each,
        # i.e. 8 each for q, k and v after chunking). The batch becomes nheads times larger.
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads) 
        q, k, v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1, 2)) / self.scale
        x = s.softmax(dim=-1) @ v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x).transpose(1, 2).reshape(n, c, h, w)
        return x + inp
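
One subtlety in `SelfAttentionMultiHead`: the heads are split off *before* `torch.chunk`, so each head's q, k and v come from one contiguous slice of the `qkv` output rather than from its `[q | k | v]` thirds. Because `qkv` is learned, this is equivalent up to a reordering of its weight rows, but it means weights cannot be copied one-to-one from an implementation that chunks first. The sketch below fills a fake `qkv` output with each channel's own index to show which channels each head receives; the `einops` pattern is replaced by the equivalent reshape/transpose so only PyTorch is needed.

```python
import torch

ni, nheads = 32, 4
x = torch.arange(3 * ni).float().reshape(1, 1, 3 * ni)   # fake qkv output: channel i holds value i

# Same as rearrange(x, 'n s (h d) -> (n h) s d', h=nheads)
xh = x.reshape(1, 1, nheads, -1).transpose(1, 2).reshape(nheads, 1, -1)
q, k, v = torch.chunk(xh, 3, dim=-1)

for head in range(nheads):
    print(head, q[head, 0].long().tolist(), k[head, 0].long().tolist(), v[head, 0].long().tolist())
```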

In [30]:
sa = SelfAttentionMultiHead(32, 4) # MH attention with 32 channels and 4 heads
sx = sa(x)
sx.shape

torch.Size([64, 32, 16, 16])

In [31]:
sx.mean(), sx.std()

(tensor(-0.0146, grad_fn=<MeanBackward0>),
 tensor(1.0098, grad_fn=<StdBackward0>))

In [32]:
# PyTorch already has all of this built in with nn.MultiheadAttention.
# Using batch_first=True ensures that the first dimension passed is the batch, so that everything is
# aligned closely with Diffusers.
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
# Arguments are the query, key and value inputs; nm applies its own q/k/v projections.
# Passing a different tensor as key and value enables cross attention. nmw holds the head-averaged weights.
nmx, nmw = nm(t, t, t)
nmx = nmx + t

In [33]:
nmx.mean(), nmx.std()

(tensor(-0.0007, grad_fn=<MeanBackward0>),
 tensor(1.0011, grad_fn=<StdBackward0>))
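
As noted in the comment on `nm(t, t, t)`, passing a different tensor as key and value turns this into cross-attention, which is how Stable Diffusion's UNet conditions on the text prompt. A sketch with `nn.MultiheadAttention`, whose `kdim`/`vdim` arguments let the keys and values have a different width from the queries. The 77×768 `txt` tensor is a hypothetical random stand-in, shaped like the CLIP text-encoder output used by Stable Diffusion 1.x, not real prompt data.

```python
import torch
from torch import nn

torch.manual_seed(0)
n, pixels, ni = 2, 256, 32
img = torch.randn(n, pixels, ni)          # flattened image features (queries)
txt = torch.randn(n, 77, 768)             # hypothetical text embeddings (keys/values)

xattn = nn.MultiheadAttention(ni, num_heads=8, kdim=768, vdim=768, batch_first=True)
out, w = xattn(img, txt, txt)             # query=img, key=txt, value=txt

print(out.shape)   # one 32-channel output per pixel
print(w.shape)     # per-pixel weights over the 77 text tokens, averaged across heads
out = out + img    # residual connection, as in the cells above
```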