# Attention Scoring Functions

Attention mechanisms are vital in deep learning, especially in tasks like NLP and computer vision. They enable models to focus on relevant parts of input data.  
This notebook explores key attention scoring functions:  
- Dot product  
- Scaled dot product  
- Additive attention  

Studying them provides insight into the foundations of models such as Transformers ([reference](https://arxiv.org/abs/1706.03762)).

### Libraries/packages used  

The following libraries are used in this notebook:  

* [<code style="color:blue;">numpy:</code>](https://numpy.org/doc/stable/) A library for numerical computations, enabling array manipulation and mathematical operations  
* [<code style="color:blue;">matplotlib:</code>](https://matplotlib.org/stable/) A plotting library used for creating visualizations of attention scores and patterns  
* [<code style="color:blue;">pytorch:</code>](https://pytorch.org/docs/stable/index.html) A popular deep learning framework for building neural networks
* [<code style="color:blue;">d2l (dive into deep learning):</code>](https://d2l.ai/)  A library that simplifies implementation of deep learning concepts

These packages are essential for implementing and visualizing attention scoring functions. To run this code on your local machine, please install the necessary libraries using the following commands:

* `pip install torch`
* `pip install d2l==1.0.3`
* `pip install numpy`
* `pip install matplotlib`


In [None]:
!pip install d2l==1.0.3

<a id='TOC'></a>
## Table of contents

1. <a href="#intro">Introduction</a><br>  
2. <a href="#dot-product">Dot product attention</a>  
3. <a href="#convenience">Convenience functions</a>  
   3.1. <a href="#masked-softmax">Masked softmax operation</a>  
   3.2. <a href="#batch-matrix">Batch matrix multiplication</a><br>  
4. <a href="#dotproductattention-class">Scaled dot product attention</a><br>
5. <a href="#Additive">Additive attention</a><br>

<a id='intro'></a>
## 1. Introduction 
[Back to table of contents](#TOC)

In attention pooling, distance-based kernels such as the Gaussian kernel can be used to model interactions between queries and keys. While effective, distance functions are slightly more expensive to compute than dot products. To simplify computation, *attention scoring functions* $a$ are widely used instead. Their outputs are passed through a softmax operation to obtain the attention weights, which in turn define a weighted average of the values.

![Computing the output of attention pooling as a weighted average of values, where weights are computed with the attention scoring function $\mathit{a}$ and the softmax operation.](img/attention-output.svg)
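
The pipeline in the figure can be sketched in a few lines of PyTorch. This is a minimal, unbatched sketch. The helper names `attention_pool` and `gaussian_score` are hypothetical and not part of `d2l`. It uses the Gaussian-kernel score $-\frac{1}{2}\|\mathbf{q} - \mathbf{k}_i\|^2$ as the scoring function $a$, but any function of a query and the keys could be plugged in instead.

```python
import torch

def attention_pool(query, keys, values, score_fn):
    """Weighted average of values, with weights = softmax of scores (sketch)."""
    # query: (d,), keys: (m, d), values: (m, v)
    scores = score_fn(query, keys)           # (m,) one score per key
    weights = torch.softmax(scores, dim=-1)  # nonnegative, sums to 1
    return weights @ values, weights         # (v,), (m,)

def gaussian_score(query, keys):
    # a(q, k_i) = -1/2 * ||q - k_i||^2, computed for all keys at once
    return -0.5 * ((query - keys) ** 2).sum(dim=-1)

torch.manual_seed(0)
q = torch.randn(3)     # one query of dimension 3
K = torch.randn(5, 3)  # five keys
V = torch.randn(5, 2)  # five values of dimension 2
out, w = attention_pool(q, K, V, gaussian_score)
print(out.shape, w.sum())
```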


<a id='dot-product'></a>
## 2. Dot product attention  
[Back to table of contents](#TOC)

The attention function (without exponentiation) based on the Gaussian kernel is given as:

$$
a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2  = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2  -\frac{1}{2} \|\mathbf{q}\|^2.
$$

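First, note that the final term, $-\frac{1}{2}\|\mathbf{q}\|^2$, depends on $\mathbf{q}$ only. As such, it is identical for all $(\mathbf{q}, \mathbf{k}_i)$ pairs. Normalizing the attention weights to sum to $1$, as the softmax operation does, makes this term disappear entirely: adding the same constant to every score multiplies the numerator and the denominator of the softmax by the same factor, which then cancels.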
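
A quick numerical check of this cancellation, as a sketch with random data and arbitrarily chosen sizes: dropping $-\frac{1}{2}\|\mathbf{q}\|^2$ from every score leaves the softmax weights unchanged up to floating-point error.

```python
import torch

torch.manual_seed(0)
d, m = 4, 6
q = torch.randn(d)
K = torch.randn(m, d)

# Full Gaussian score: -1/2 ||q - k_i||^2
full = -0.5 * ((q - K) ** 2).sum(dim=-1)
# The same score with the query-only term -1/2 ||q||^2 dropped
no_q_term = K @ q - 0.5 * (K ** 2).sum(dim=-1)

w_full = torch.softmax(full, dim=-1)
w_reduced = torch.softmax(no_q_term, dim=-1)
print(torch.allclose(w_full, w_reduced, atol=1e-5))  # True
```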

Second, note that both batch and layer normalization (to be discussed later) lead to activations that have well-bounded, and often constant, norms $\|\mathbf{k}_i\|$. This is the case, for instance, whenever the keys $\mathbf{k}_i$ were generated by a layer norm. As such, we can drop the term $-\frac{1}{2}\|\mathbf{k}_i\|^2$ from the definition of $a$ without any major change in the outcome.
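
To illustrate, the sketch below assumes the keys come from `nn.LayerNorm` without its learnable scale and shift (`elementwise_affine=False`). Every key then has zero mean and unit variance across its $d$ entries, and hence a norm of about $\sqrt{d}$. Under that assumption, the softmax weights computed from the full Gaussian score and from the plain dot product $\mathbf{q}^\top \mathbf{k}_i$ agree closely.

```python
import torch
from torch import nn

torch.manual_seed(0)
d, m = 64, 8
# Assumption for this sketch: keys produced by a layer norm with no learnable affine part
layer_norm = nn.LayerNorm(d, elementwise_affine=False)
K = layer_norm(torch.randn(m, d) * 3 + 1)
q = torch.randn(d)

print(K.norm(dim=-1))  # every entry is close to sqrt(d) = 8

gaussian = -0.5 * ((q - K) ** 2).sum(dim=-1)  # full Gaussian score
dot = K @ q                                   # both norm terms dropped
w_gaussian = torch.softmax(gaussian, dim=-1)
w_dot = torch.softmax(dot, dim=-1)
print(torch.allclose(w_gaussian, w_dot, atol=1e-4))
```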

Last, we need to keep the order of magnitude of the arguments in the exponential function under control. Assume that all the elements of the query $\mathbf{q} \in \mathbb{R}^d$ and the key $\mathbf{k}_i \in \mathbb{R}^d$ are independent and identically distributed random variables with zero mean and unit variance. Then each elementwise product has zero mean and unit variance, so the dot product, being a sum of $d$ such independent terms, has zero mean and variance $d$. To keep the variance of the dot product at $1$ regardless of the vector length, we rescale it by $1/\sqrt{d}$, which gives the `scaled dot product attention` scoring function. We thus arrive at the first widely used attention scoring function, which is employed, e.g., in Transformers:

$$ a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}.$$
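
The variance argument can be checked empirically. This sketch draws i.i.d. standard normal entries, which is one concrete choice of a zero-mean, unit-variance distribution, and compares the sample variance of the raw and the scaled dot products.

```python
import torch

torch.manual_seed(0)
n_samples, d = 100_000, 64
# i.i.d. entries with zero mean and unit variance, as assumed in the text
q = torch.randn(n_samples, d)
k = torch.randn(n_samples, d)

dots = (q * k).sum(dim=-1)       # one dot product per sample
scaled = dots / d ** 0.5         # rescaled by 1/sqrt(d)
print(dots.var().item())         # roughly d = 64
print(scaled.var().item())       # roughly 1
```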


Note that the attention weights $\alpha$ still need to be normalized. We obtain them by applying the softmax operation over all $m$ keys:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d})}{\sum_{j=1}^{m} \exp(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d})}.$$
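
For a single query, the normalized weights can be computed directly from this formula. In the sketch below, which uses random data, the explicit exponentiate-and-normalize version matches `torch.softmax`, which is what one would use in practice.

```python
import math
import torch

torch.manual_seed(0)
d, m = 8, 5
q = torch.randn(d)     # one query
K = torch.randn(m, d)  # m keys, one per row

scores = K @ q / math.sqrt(d)  # a(q, k_j) for j = 1..m
# Softmax written out exactly as in the formula
alpha_manual = torch.exp(scores) / torch.exp(scores).sum()
alpha = torch.softmax(scores, dim=-1)  # built-in equivalent
print(alpha, alpha.sum())
```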

<a id='convenience'></a>
## 3. Convenience functions  
[Back to table of contents](#TOC)

There are a few utility functions needed to make the attention mechanism efficient to deploy. This includes tools for dealing with strings of variable lengths (common for natural language processing) and tools for efficient evaluation on minibatches (batch matrix multiplication). 

<a id='masked-softmax'></a>
### 3.1 Masked softmax operation  
[Back to table of contents](#TOC)

One of the most popular applications of the attention mechanism is to sequence models. Hence we need to be able to deal with sequences of different lengths. In some cases, such sequences may end up in the same minibatch, necessitating padding with dummy tokens for shorter sequences. These special tokens do not carry meaning. For instance, assume that we have the following three sentences:

```
Dive  into  Deep    Learning 
Learn to    code    <blank>
Hello world <blank> <blank>
```

Since we do not want the blanks to contribute to our attention model, we need to restrict $\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i$ to $\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i$, with the weights normalized over the first $l$ positions only, where $l \leq n$ is the actual length of the sentence. Since this is such a common problem, it has a name: the `masked softmax operation`.

The implementation below cheats ever so slightly: rather than truncating the sum, it sets the attention scores of the padded positions ($i > l$) to a large negative number, such as $-10^{6}$, before applying the softmax. Their weights then become $0$ in practice, so the corresponding values $\mathbf{v}_i$ contribute nothing to the outputs or the gradients. This is done because linear algebra kernels and operators are heavily optimized for GPUs, and it is faster to be slightly wasteful in computation than to write code with conditional (if then else) statements.
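
The following sketch, for a single query with $n = 4$ positions and an actual length $l = 2$ (arbitrary toy values), checks that the masking trick gives the same output as truncating the sum. Because $\exp(-10^6)$ underflows to exactly zero in floating point, the padded positions receive zero weight, and the remaining weights equal the softmax over the first $l$ scores.

```python
import torch

torch.manual_seed(0)
n, l = 4, 2                 # padded length n, actual sentence length l
scores = torch.randn(n)     # a(q, k_i) for i = 1..n
values = torch.randn(n, 3)  # v_i, including the values at padded positions

# Masking: replace the scores at padded positions (i > l) with -1e6
masked = scores.clone()
masked[l:] = -1e6
weights_masked = torch.softmax(masked, dim=-1)
out_masked = weights_masked @ values

# Truncation: use only the first l positions
out_trunc = torch.softmax(scores[:l], dim=-1) @ values[:l]
print(torch.allclose(out_masked, out_trunc, atol=1e-6))  # True
```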


In [None]:
# Import libraries
import math
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

#### Example usage:
To demonstrate how this function operates, consider a minibatch of two examples, each a $2 \times 4$ matrix, whose valid lengths are $2$ and $3$, respectively. As a result of the masked softmax operation, all entries beyond the valid length in each row receive an attention weight of zero.

In [None]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

If we need more fine-grained control to specify the valid length for each of the two vectors of every example, we simply use a two-dimensional tensor of valid lengths. This yields:


In [None]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
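
Both kinds of valid lengths end up as one length per row of the flattened input. In `masked_softmax`, a 1D tensor is expanded with `torch.repeat_interleave`, while a 2D tensor is flattened with `reshape(-1)`. The mask in `_sequence_mask` is then built by broadcasting a row of positions against a column of lengths. The standalone sketch below reproduces these steps for the two examples above, where the last axis has size 4.

```python
import torch

maxlen = 4  # size of the last axis in the examples above

# 1D valid lengths [2, 3]: each length is repeated for the 2 rows of its example
lens_1d = torch.repeat_interleave(torch.tensor([2, 3]), 2)  # [2, 2, 3, 3]
# 2D valid lengths [[1, 3], [2, 4]]: already one length per row, just flatten
lens_2d = torch.tensor([[1, 3], [2, 4]]).reshape(-1)        # [1, 3, 2, 4]

# Positions with shape (1, maxlen) vs. lengths with shape (rows, 1):
# broadcasting yields a (rows, maxlen) mask that is True where position < length
positions = torch.arange(maxlen)[None, :]
mask_1d = positions < lens_1d[:, None]
mask_2d = positions < lens_2d[:, None]
print(mask_2d)
```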

<a id='batch-matrix'></a>
### 3.2 Batch matrix multiplication  
[Back to table of contents](#TOC)

Another commonly used operation is to multiply batches of matrices by one another. This comes in handy when we have minibatches of queries, keys, and values, where efficient batch matrix multiplication (BMM) is critical. More specifically, assume that

$$
\begin{aligned}
\mathbf{Q} &= [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n] \in \mathbb{R}^{n \times a \times b}, \\
\mathbf{K} &= [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n] \in \mathbb{R}^{n \times b \times c}.
\end{aligned}
$$

Then batch matrix multiplication computes the matrix product of each pair $(\mathbf{Q}_i, \mathbf{K}_i)$ separately:

$$\textrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{n \times a \times c}.$$

#### Example usage:

In [None]:
Q = torch.ones((2, 3, 4))
K = torch.ones((2, 4, 6))
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
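
Beyond the shape check, one can confirm that `torch.bmm` computes exactly the per-pair products $\mathbf{Q}_i \mathbf{K}_i$ from the definition. This sketch compares it against an explicit Python loop and against `torch.einsum`, which expresses the same contraction.

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)
K = torch.randn(2, 4, 6)

bmm_out = torch.bmm(Q, K)
# Reference: multiply the i-th pair of matrices separately, then stack
loop_out = torch.stack([Q[i] @ K[i] for i in range(Q.shape[0])])
# The same computation as an einsum over the shared inner dimension b
einsum_out = torch.einsum('nab,nbc->nac', Q, K)
print(torch.allclose(bmm_out, loop_out, atol=1e-6),
      torch.allclose(bmm_out, einsum_out, atol=1e-6))
```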

<a id='dotproductattention-class'></a>
## 4. Scaled dot product attention 
[Back to table of contents](#TOC)

The dot product scoring function requires that both the query and the key have the same vector length, say $d$. This restriction is easy to lift by replacing 
$\mathbf{q}^\top \mathbf{k}$ with $\mathbf{q}^\top \mathbf{M} \mathbf{k}$, where $\mathbf{M}$ is a suitably chosen matrix that translates between both spaces. For now, assume that the dimensions match.
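
As a sketch of the $\mathbf{q}^\top \mathbf{M} \mathbf{k}$ variant, the code below scores one query against several keys of a different dimension. Here $\mathbf{M}$ is simply a random matrix standing in for a learned parameter, which is an assumption for illustration. Rewriting $\mathbf{q}^\top \mathbf{M} \mathbf{k}_i$ as $(\mathbf{M}^\top \mathbf{q})^\top \mathbf{k}_i$ lets all keys be scored with one matrix-vector product.

```python
import torch

torch.manual_seed(0)
d_q, d_k, m = 5, 3, 4
q = torch.randn(d_q)        # query of dimension d_q
K = torch.randn(m, d_k)     # m keys of a different dimension d_k
M = torch.randn(d_q, d_k)   # stand-in for a learned matrix (assumption)

scores = K @ (M.T @ q)      # entry i is q^T M k_i
# Reference: evaluate q^T M k_i one key at a time
scores_loop = torch.stack([q @ M @ K[i] for i in range(m)])
print(scores.shape)  # torch.Size([4])
```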

In practice, we often think of minibatches for efficiency,
such as computing attention for $n$ queries and $m$ key-value pairs,
where queries and keys are of length $d$
and values are of length $v$. The scaled dot product attention 
of queries $\mathbf Q\in\mathbb R^{n\times d}$,
keys $\mathbf K\in\mathbb R^{m\times d}$,
and values $\mathbf V\in\mathbb R^{m\times v}$
thus can be written as 
$$
\mathrm{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d}}\right) \mathbf{V} \in \mathbb{R}^{n \times v}.
$$
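
The matrix formula translates almost verbatim into code. As a cross-check, the sketch below compares it with `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later, where that function is available. With no mask, no dropout, and its default $1/\sqrt{d}$ scaling, it should compute the same quantity. The sizes are arbitrary, and the two extra leading dimensions are added only because the built-in expects batched inputs.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, m, d, v = 3, 5, 8, 4
Q = torch.randn(n, d)  # n queries
K = torch.randn(m, d)  # m keys
V = torch.randn(m, v)  # m values

# softmax(Q K^T / sqrt(d)) V, written out directly
weights = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1)  # (n, m)
out = weights @ V                                        # (n, v)

# Built-in fused version (PyTorch >= 2.0), on inputs of shape (1, 1, len, dim)
out_ref = F.scaled_dot_product_attention(Q[None, None], K[None, None],
                                         V[None, None])[0, 0]
print(out.shape, torch.allclose(out, out_ref, atol=1e-5))
```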

We compute this efficiently with batch matrix multiplication. Additionally, the implementation below applies dropout to the attention weights for regularization, which helps reduce overfitting. The $1/\sqrt{d}$ factor keeps the inputs of the softmax at unit variance, as derived above, regardless of the dimension $d$.
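
One detail is worth seeing in isolation: `nn.Dropout` on the attention weights is only active in training mode. In this sketch with stand-in weights, training mode zeroes some weights and rescales the survivors by $1/(1-p)$, so a row no longer sums to exactly $1$, while evaluation mode leaves the weights untouched. This is why the examples in this notebook call `attention.eval()` before inspecting outputs.

```python
import torch
from torch import nn

torch.manual_seed(0)
weights = torch.softmax(torch.randn(2, 1, 10), dim=-1)  # stand-in attention weights
dropout = nn.Dropout(p=0.5)

dropout.train()
dropped = dropout(weights)
# Surviving weights are rescaled by 1 / (1 - p) = 2; the others become 0
kept = dropped != 0
print(torch.allclose(dropped[kept], 2 * weights[kept]))

dropout.eval()
evaluated = dropout(weights)
print(torch.equal(evaluated, weights))  # dropout is the identity in eval mode
```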


In [None]:
class DotProductAttention(nn.Module):  #@save
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

To demonstrate how the `DotProductAttention` class works, we use a toy minibatch of size $2$. Each example has one query of dimension $2$, $10$ keys of dimension $2$, and $10$ values of dimension $4$, and the valid lengths are $2$ and $6$, respectively. The same keys, values, and valid lengths are reused for additive attention later in this notebook. Given this setup, we expect the output to be a $2 \times 1 \times 4$ tensor, that is, one row per example in the minibatch.

In [None]:
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Let's check that the attention weights actually vanish beyond the second and sixth columns, respectively, since we set the valid lengths to $2$ and $6$.


In [None]:
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

<a id='Additive'></a>
## 5. Additive attention  
[Back to table of contents](#TOC)

When the queries $\mathbf{q}$ and keys $\mathbf{k}$ have different dimensions, we can either use a matrix $\mathbf{M}$ to match them via $\mathbf{q}^\top \mathbf{M} \mathbf{k}$ or use additive attention as the scoring function. As its name indicates, additive attention combines the query and the key by addition, which can lead to minor computational savings. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is:

$$
a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \textrm{tanh}(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},
$$
where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^h$ are learnable parameters. This result is passed through a softmax for normalization and nonnegativity. An alternative interpretation is that the query and key are concatenated and passed through a multi-layer perceptron (MLP) with a hidden layer, using $\tanh$ as the activation function and no bias terms.
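
The MLP interpretation follows from stacking the two weight matrices: $\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k} = [\mathbf{W}_q\ \mathbf{W}_k] \begin{bmatrix}\mathbf{q} \\ \mathbf{k}\end{bmatrix}$. The sketch below checks this numerically for one query and one key, with random matrices standing in for the learned parameters and with sizes ($q = 20$, $k = 2$, $h = 8$) borrowed from the toy example in this section.

```python
import torch

torch.manual_seed(0)
q_dim, k_dim, h = 20, 2, 8
q = torch.randn(q_dim)
k = torch.randn(k_dim)
W_q = torch.randn(h, q_dim)  # stand-ins for the learnable parameters
W_k = torch.randn(h, k_dim)
w_v = torch.randn(h)

# Scoring function as written: w_v^T tanh(W_q q + W_k k)
score = w_v @ torch.tanh(W_q @ q + W_k @ k)

# Same score as a one-hidden-layer MLP without biases on the concatenation [q; k]
W_hidden = torch.cat([W_q, W_k], dim=1)  # (h, q_dim + k_dim)
score_mlp = w_v @ torch.tanh(W_hidden @ torch.cat([q, k]))
print(torch.allclose(score, score_mlp, atol=1e-5))
```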

In [None]:
class AdditiveAttention(nn.Module):  #@save
    """Additive attention."""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
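
The broadcasting step in `forward` is the least obvious part of the class. The sketch below traces the tensor shapes with the toy sizes used in this section. It uses `nn.Linear` with explicit input sizes instead of `nn.LazyLinear`, so that the weights exist before the first call; it illustrates the shapes and is not a replacement for the class.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, n_q, n_k, h = 2, 1, 10, 8
queries = torch.randn(batch, n_q, 20)
keys = torch.randn(batch, n_k, 2)
W_q = nn.Linear(20, h, bias=False)
W_k = nn.Linear(2, h, bias=False)
w_v = nn.Linear(h, 1, bias=False)

q_proj, k_proj = W_q(queries), W_k(keys)              # (2, 1, 8), (2, 10, 8)
# (2, 1, 1, 8) + (2, 1, 10, 8) broadcasts to (2, 1, 10, 8)
features = q_proj.unsqueeze(2) + k_proj.unsqueeze(1)
scores = w_v(torch.tanh(features)).squeeze(-1)        # (2, 1, 10)
print(features.shape, scores.shape)
```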

Let's see how `AdditiveAttention` works. In our toy example, we pick queries, keys, and values of shape 
$(2, 1, 20)$, $(2, 10, 2)$, and $(2, 10, 4)$, respectively. This is identical to our choice for `DotProductAttention`, except that now the queries are $20$-dimensional. Likewise, we use $(2, 6)$ as the valid lengths for the sequences in the minibatch.


In [None]:
queries = torch.normal(0, 1, (2, 1, 20))

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Inspecting the attention weights, we see behavior that is qualitatively quite similar to that of `DotProductAttention`: only the weights within the chosen valid lengths $(2, 6)$ are nonzero.


In [None]:
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

**Summary:** In this section, we introduced the two primary attention scoring functions, dot product attention and additive attention, both of which are effective tools for aggregating sequences of varying lengths. In particular, scaled dot product attention is a core component of modern Transformer architectures. When queries and keys are vectors of different dimensions, we can use the additive attention scoring function instead. Optimizing these layers has been a significant area of progress in recent years. For example, [NVIDIA's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html) and [Megatron](https://arxiv.org/abs/1909.08053) rely on efficient variants of the attention mechanism.