# Transformer Details and Implementation

A theory-to-practice study of the Transformer model, based on the relevant parts of the textbook [Dive Into Deep Learning (D2L)](https://d2l.ai/chapter_attention-mechanisms-and-transformers/index.html).

Note that the implementation is not for production. 

## Attention

### Definition

Denote a database of $m$ tuples of _keys_ and _values_ $D\stackrel{def}{=}\{(k_{1}, v_{1}), ..., (k_{m}, v_{m})\}$, and denote a _query_ by $q$. Then we can define the _Attention_ over $D$ as

$$ Attention(q, D)\stackrel{def}{=}\sum_{i=1}^m \alpha(q, k_{i})v_{i} $$ 

where $\alpha(q, k_{i})$ are scalar attention weights. 

This operation pays more attention to terms where the weight is larger, hence the name _attention_.
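As a minimal sketch of this definition, here is attention pooling in plain Python. It assumes scalar keys, values, and query. The weight function `closeness` is a hypothetical choice for illustration only, and its weights are deliberately not normalized yet.

```python
def attention(q, database, alpha):
    """Attention(q, D) = sum_i alpha(q, k_i) * v_i over a list of (key, value) pairs."""
    return sum(alpha(q, k) * v for k, v in database)

# A toy database D of (key, value) pairs
D = [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]

# Hypothetical weight function: keys closer to the query get larger weights
def closeness(q, k):
    return 1.0 / (1.0 + abs(q - k))

print(attention(2.0, D, closeness))  # 0.5*10 + 1.0*20 + 0.5*30 = 40.0
```

Here the weights sum to 2, so the result is not a weighted average of the values. The prerequisites in the next section deal with exactly this.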

### Prerequisites

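Depending on the model, we may impose different constraints on the attention weights. Common choices, from weakest to strongest, are:
- The weights $\alpha(q, k_{i})$ are nonnegative.
- The weights $\alpha(q, k_{i})$ form a convex combination, i.e., $\sum_{i}\alpha(q,k_{i})=1$ and $\alpha(q, k_{i})\ge0$.
- Exactly one of the weights is 1 and all others are 0. This special case is akin to a traditional database lookup. However, a hard selection like this is not differentiable, so to train a model smoothly and stably we usually aim for the convex combination instead.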
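Here is a small sketch of the strongest case. It assumes string keys and a hypothetical `one_hot` weight function. With one-hot weights, attention reduces to a traditional database lookup that returns the value of the matching key.

```python
def attention(q, database, alpha):
    """Attention(q, D) = sum_i alpha(q, k_i) * v_i."""
    return sum(alpha(q, k) * v for k, v in database)

D = [("apple", 1.0), ("banana", 2.0), ("cherry", 3.0)]

def one_hot(q, k):
    # Weight 1 for the exactly matching key, 0 for all others
    return 1.0 if q == k else 0.0

print(attention("banana", D, one_hot))  # 2.0

# These weights also satisfy the two weaker properties
weights = [one_hot("banana", k) for k, _ in D]
assert all(w >= 0 for w in weights) and sum(weights) == 1.0
```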

#### Sum to 1
To ensure the weights sum up to 1, we can normalize them:

$$ \alpha(q, k_{i}) = \frac{\alpha(q, k_{i})}{\sum_{j}\alpha(q,k_{j})} $$

#### Non-negative
To also ensure that the weights are non-negative, we can use exponentiation:

$$ \alpha(q, k_{i}) = \frac{\exp(\alpha(q, k_{i}))}{\sum_{j}\exp(\alpha(q, k_{j}))} $$

The result is differentiable and its gradient never vanishes. Both are desirable properties in a model.

This is the softmax operation.
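Below is a minimal plain-Python sketch that compares plain normalization with softmax on scores that include a negative value. Before exponentiating, it subtracts the maximum score. This is a standard numerical-stability trick that the text above does not mention. It leaves the result unchanged, because the common factor $\exp(-\max)$ cancels in the ratio.

```python
import math

def normalize(scores):
    """Plain normalization: divide each score by the sum of all scores."""
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    """Exponentiate, then normalize (max subtracted for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, -1.0, 0.5]
print(normalize(scores))  # sums to 1, but contains a negative weight
print(softmax(scores))    # sums to 1 and all weights are positive
```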

### Dot Product Attention

#### Gaussian kernel attention to dot product attention


Simple dot product attention follows naturally from Gaussian kernel attention. To get there, we expand the Gaussian kernel and then consider the effect of batch and layer normalization. The result is the dot product score below. (Queries $q$ and keys $k_{i}$ are column vectors, so we transpose $q$ to take their dot product.)

$$ \alpha(q, k_{i}) = q^\top k_{i} $$

To see where this comes from, start from the Gaussian kernel score

$$ \alpha(q, k_{i}) = -\frac{1}{2} \lVert q-k_{i} \rVert^2 $$

which expands to

$$ \alpha(q, k_{i}) = q^\top k_{i} - \frac{1}{2}\lVert k_{i} \rVert^2 - \frac{1}{2}\lVert q \rVert^2 $$

The last term, $-\frac{1}{2}\lVert q \rVert^2$, is the same for all $(q, k_{i})$ pairs. It therefore cancels out in the softmax normalization, since adding the same constant to every score does not change the softmax.

Now consider the term $-\frac{1}{2}\lVert k_{i} \rVert^2$. Both batch and layer normalization lead to activations with well-bounded and often constant norms $\lVert k_{i} \rVert$. After normalization, we can therefore drop this term as well without a major change to the outcome.
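Here is a quick numerical check of the derivation in plain Python, using arbitrary example vectors chosen for illustration only. The check does three things:
- It verifies the expansion.
- It shows that dropping $-\frac{1}{2}\lVert q \rVert^2$ leaves the softmax weights unchanged.
- It shows that when all keys have the same norm (here forced to 1), the plain dot product gives exactly the Gaussian kernel's weights.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sub(a, b):
    return [x - y for x, y in zip(a, b)]

def softmax(scores):
    m = max(scores)  # subtracting the max does not change the result
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

q = [0.3, -1.2, 0.8]
keys = [[1.0, 0.0, -0.5], [-0.4, 0.9, 0.2], [0.6, -1.1, 0.7]]

# Gaussian kernel score -1/2 ||q - k_i||^2 and its expansion
gauss = [-0.5 * dot(sub(q, k), sub(q, k)) for k in keys]
expanded = [dot(q, k) - 0.5 * dot(k, k) - 0.5 * dot(q, q) for k in keys]
assert all(abs(g - e) < 1e-12 for g, e in zip(gauss, expanded))

# Dropping -1/2 ||q||^2 (the same for every key) leaves the softmax unchanged
without_q = [dot(q, k) - 0.5 * dot(k, k) for k in keys]
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax(gauss), softmax(without_q)))

# With equal key norms (unit norm here), q^T k_i gives the Gaussian kernel's weights
unit_keys = [[x / math.sqrt(dot(k, k)) for x in k] for k in keys]
gauss_unit = [-0.5 * dot(sub(q, k), sub(q, k)) for k in unit_keys]
dot_scores = [dot(q, k) for k in unit_keys]
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax(gauss_unit), softmax(dot_scores)))
print(softmax(dot_scores))
```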

#### Scaled dot product attention weights

Assume that all the elements of the query $q \in \mathbb{R}^d$ and the key $k \in \mathbb{R}^d$ are independent and identically distributed random variables with zero mean and unit variance. Then each product $q_j k_j$ has zero mean and unit variance. The dot product $q^\top k = \sum_{j=1}^{d} q_j k_j$ sums $d$ such independent terms, so it has zero mean and a variance of $d$. Here variance measures how spread out the values are: the higher the dimension $d$ of the vectors, the more spread out their dot products become.

However, we want the variance of the dot product to stay at 1 regardless of the dimension $d$. There are two reasons:
1. Stable training: keeping the variance around a constant (like 1) helps stabilize the gradient updates during training. Scores with a large variance push the softmax into saturated regions. There it is close to one-hot and its gradients become extremely small, which slows down or destabilizes training.
2. Stable model: with a constant variance, the dot products don't change drastically with the vector dimension, which makes the model's behavior more predictable and stable.

To achieve a constant variance of 1, we divide the dot product by the square root of the dimension, since $\operatorname{Var}(x/\sqrt{d}) = \operatorname{Var}(x)/d = 1$:

$$ \alpha(q, k_{i}) = q^\top k_{i} / \sqrt{d} $$
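Here is an empirical sketch of the variance argument in plain Python. It assumes entries drawn from a standard normal distribution (any zero-mean, unit-variance distribution would do), with $d = 64$ chosen arbitrarily:

```python
import math
import random

random.seed(0)

def sample_dot(d):
    """Dot product of two random d-dim vectors with i.i.d. N(0, 1) entries."""
    q = [random.gauss(0.0, 1.0) for _ in range(d)]
    k = [random.gauss(0.0, 1.0) for _ in range(d)]
    return sum(a * b for a, b in zip(q, k))

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

d = 64
raw = [sample_dot(d) for _ in range(5000)]
scaled = [x / math.sqrt(d) for x in raw]

print(f"Var(q^T k)         = {variance(raw):.1f}")     # close to d = 64
print(f"Var(q^T k/sqrt(d)) = {variance(scaled):.2f}")  # close to 1
```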

Finally, we still need the weights $\alpha$ to be non-negative and to sum to one. Applying the softmax gives us the most commonly used attention weights, the scaled dot product attention weights:

$$ \alpha(q, k_{i}) = \operatorname{softmax}\left(\frac{q^\top k_{i}}{\sqrt{d}}\right) = \frac{\exp(q^\top k_{i}/\sqrt{d})}{\sum_{j=1}^{m}\exp(q^\top k_{j}/\sqrt{d})} $$
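Below is a minimal plain-Python sketch of this formula for a single query. The function name `scaled_dot_product_weights` and the example vectors are hypothetical. The key most aligned with the query receives the largest weight, and the weights sum to 1.

```python
import math

def scaled_dot_product_weights(q, keys):
    """alpha(q, k_i) = softmax_i(q^T k_i / sqrt(d)) for a single query q."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

q = [1.0, 0.0, 1.0, 0.0]
keys = [[1.0, 0.0, 1.0, 0.0],    # same direction as q: score 2 / sqrt(4) = 1
        [0.0, 1.0, 0.0, 1.0],    # orthogonal to q: score 0
        [-1.0, 0.0, -1.0, 0.0]]  # opposite direction: score -1
weights = scaled_dot_product_weights(q, keys)
print([round(w, 3) for w in weights])
```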

#### Matrix Multiplication

Calculating attention weights one query at a time is very slow. GPUs are highly optimized for matrix multiplication, so we stack the queries, keys, and values into matrices and calculate the dot product attention for multiple queries at once.

Say we compute attention for $n$ queries on $m$ key-value pairs, each query and key has the same dimension $d$, each value has dimension $v$. So we have three matrices: 
- $Q \in \mathbb{R}^{n \times d}$, $n$ rows of $d$-dimensional query vectors
- $K \in \mathbb{R}^{m \times d}$, $m$ rows of $d$-dimensional key vectors
- $V \in \mathbb{R}^{m \times v}$, $m$ rows of $v$-dimensional value vectors

So the scaled dot product attention of matrices $Q$, $K$, $V$ can be written as

$$ softmax(\frac{QK^\top}{\sqrt{d}})V \in \mathbb{R}^{n \times v}$$

The resulting attention output matrix has dimension $n \times v$, i.e., one $v$-dimensional output row for each query row.
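Here is a PyTorch sketch with random inputs and arbitrary sizes. It checks that the matrix form gives the same result as computing attention one query at a time, and that the output has shape $n \times v$:

```python
import math
import torch

torch.manual_seed(0)
n, m, d, v = 3, 5, 4, 2  # queries, key-value pairs, query/key dim, value dim
Q = torch.randn(n, d)
K = torch.randn(m, d)
V = torch.randn(m, v)

# Matrix form: softmax(Q K^T / sqrt(d)) V, with softmax over the m keys (last axis)
weights = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1)  # (n, m)
out_matrix = weights @ V                                 # (n, v)

# One query at a time, for comparison
out_loop = torch.stack([
    torch.softmax(K @ q / math.sqrt(d), dim=0) @ V       # (m,) @ (m, v) -> (v,)
    for q in Q
])

print(out_matrix.shape)  # torch.Size([3, 2])
```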

In practice, each GPU has limited memory, so we usually can't calculate all queries at once either. Instead, we divide the query, key, and value matrices into mini-batches. This is elaborated later in the "Batch Matrix Multiplication" section of the implementation part.

### PyTorch implementation

Here we look at a sample implementation of dot product attention in PyTorch. Practical training demands some extra implementation details, so we'll first look at the function `masked_softmax()`, which applies a masked softmax to the calculated attention scores. Then we'll look at the `DotProductAttention` class.

Before anything, import the dependencies:

```python
import math
import torch
from torch import nn
from d2l import torch as d2l
```

#### The `masked_softmax` Function


Define a convenient softmax function which can optionally mask each sequence at the positions that go beyond its valid length. This allows us to handle input sequences with different lengths (when they end up in the same batch, shorter sequences are typically padded with dummy tokens, to which we don't want our model to pay attention). 

```python
def masked_softmax(X, valid_lens=None):
    """
    Perform softmax on the last axis of X, optionally masking out the positions
    beyond each sequence's valid length.

    Parameters
    ----------
    X : 3D tensor of shape (num_batches, batch_size, max_seq_len)
        The input batches of sequences.
    valid_lens : 1D or 2D tensor, optional
        The valid length of each sequence in the input batches. A 1D tensor of
        shape (num_batches,) assumes all sequences in the same batch share one
        valid length. A 2D tensor of shape (num_batches, batch_size) allows
        different valid lengths within the same batch.
    """

    # Helper function to apply a mask to a set of sequences
    def _sequence_mask(X, valid_len, mask_value=0):
        """
        Return X with mask_value written into every position that lies beyond
        that sequence's valid length, as given in valid_len. X is modified in place.

        Parameters
        ----------
        X : 2D tensor of shape (num_sequences, max_seq_len)
            A matrix with all input sequences as rows.
        valid_len : 1D tensor of shape (num_sequences,)
            The valid length of each input sequence.
        """
        max_seq_len = X.size(1)
        # (1, max_seq_len) < (num_sequences, 1) broadcasts to (num_sequences, max_seq_len)
        mask = torch.arange(max_seq_len, dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = mask_value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            # A 1D valid_lens holds one length per batch, but the helper expects
            # one length per sequence: repeat each entry batch_size (shape[1]) times.
            # Example: X.shape == (2, 3, 4) and valid_lens == [2, 4]
            #   -> [2, 2, 2, 4, 4, 4]
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # A 2D valid_lens already holds one length per sequence: flatten it
            valid_lens = valid_lens.reshape(-1)

        # Replace all invalid positions (beyond each sequence's valid length) with
        # a large negative number, so that their exponentials are effectively zero
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, mask_value=-1e6)

        return nn.functional.softmax(X.reshape(shape), dim=-1)
```

Let's take a closer look at this implementation. 

First, the helper function `_sequence_mask(X, valid_len, mask_value=0)`:
1. The input arguments are preprocessed:
   1.  `X` is not directly the outer function's input `X`, which is a 3D tensor of shape `(num_batches, batch_size, max_seq_len)`. Instead, it is a reshaped 2D version of that tensor with shape `(num_sequences, max_seq_len)`, where `num_sequences = num_batches * batch_size`.
   2.  `valid_len` is not directly the outer function's input `valid_lens` either. The outer function accepts either a 1D or a 2D tensor. The inner function accepts only its normalized form: a 1D tensor holding the valid length of every sequence in the input. Its shape is therefore `(num_sequences,)`.
2. How the boolean mask is created:
   1.  Get the maximum sequence length `max_seq_len`, which is the second dimension of the input `X`.
   2.  Call `torch.arange` with `max_seq_len` as its upper bound. This creates a vector of all position indices `0, 1, ..., max_seq_len - 1` in a longest sequence.
   3.  In the same line of code, three steps create the boolean mask for every position of every sequence:
      1.  Expand the 1D position-index tensor to 2D by prepending a dimension with `[None, :]`, so that its shape is `(1, max_seq_len)`.
      2. Expand the 1D `valid_len` tensor to 2D by appending a dimension with `[:, None]`, so that its shape is `(num_sequences, 1)`.
      3. Create the boolean mask by comparing, element-wise, each position index with each sequence's `valid_len`. The result is `True` if the position is within that sequence's valid length, and `False` otherwise.
         1. Here we see a clever use of broadcasting: comparing two tensors of different shapes element-wise broadcasts both of them. Both the position-index tensor and `valid_len` are broadcast to shape `(num_sequences, max_seq_len)`. This gives us a mask with the same shape as the helper's input `X`.
3. How the boolean mask is applied:
   1. The mask from step 2 is `True` at all positions within the valid length. Its complement `~mask` flips the values, so all the invalid positions become `True`. We can then use it to select the invalid positions in `X`.
   2. `X[~mask]` selects all the invalid positions in `X`, and `X[~mask] = mask_value` writes the mask value into them.

Finally, the helper returns the masked `X` matrix.
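The mask construction can be reproduced in isolation. This sketch uses made-up valid lengths for three sequences and integer position indices. The implementation above uses `float32` position indices, which compare the same way.

```python
import torch

max_seq_len = 4
valid_len = torch.tensor([1, 3, 2])  # one valid length per sequence, shape (num_sequences,)

positions = torch.arange(max_seq_len)  # position indices 0, 1, 2, 3
# (1, max_seq_len) < (num_sequences, 1) broadcasts to (num_sequences, max_seq_len)
mask = positions[None, :] < valid_len[:, None]
print(mask)  # row i is True in its first valid_len[i] positions

X = torch.ones(3, max_seq_len)
X[~mask] = -1e6  # overwrite the invalid (padded) positions in place
print(X)
```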

Then let's see what happens outside the helper function. In particular, we look at how the `valid_lens` tensor, if given, is normalized into a 1D tensor that holds the valid length of each sequence across all input batches.
1. If no `valid_lens` is given, simply apply softmax to the input and return the result.
2. If `valid_lens` is given, first record the input's shape with `shape = X.shape`. Per our function definition, this shape is `(num_batches, batch_size, max_seq_len)`. Then check whether `valid_lens` is 1D or 2D:
   1. A 1D tensor means that all sequences in a batch share the same valid length, given by the corresponding entry of the tensor. However, the helper function expects the valid length of each sequence, not of each batch. We therefore repeat each entry `batch_size` times (the second dimension of `X`'s shape) with `torch.repeat_interleave(valid_lens, shape[1])`.
   2. Otherwise, it's a 2D tensor holding the valid length of each sequence inside each batch. To get the 1D tensor we need, we simply flatten it with `valid_lens.reshape(-1)`.

Now `valid_lens` is ready to be passed to the helper function, which returns a masked version of `X`. When calling the helper, we specify a large negative mask value `-1e6`. Its exponential underflows to zero in the softmax, so the masked positions receive practically zero weight and the model doesn't pay any attention to them.
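Here is a small sketch of both normalization paths and of the effect of the `-1e6` mask value. The valid lengths are made up, and `batch_size = 3` matches the example in the code comment:

```python
import torch

batch_size = 3  # number of sequences per batch (X.shape[1])

# 1D: one valid length per batch -> repeat it for every sequence in that batch
valid_lens_1d = torch.tensor([2, 4])
print(torch.repeat_interleave(valid_lens_1d, batch_size))  # tensor([2, 2, 2, 4, 4, 4])

# 2D: one valid length per sequence -> flatten row by row
valid_lens_2d = torch.tensor([[1, 2, 3], [4, 3, 2]])
print(valid_lens_2d.reshape(-1))  # tensor([1, 2, 3, 4, 3, 2])

# Why -1e6 works: after exponentiation it underflows to exactly 0
scores = torch.tensor([0.5, 1.0, -1e6, -1e6])
print(torch.softmax(scores, dim=-1))
```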

> Note: this implementation is directly from the textbook. From software engineering's perspective, the interface design is brittle without error handling, and the code lacks readability. I'll look into other implementations (e.g. [The Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html#batches-and-masking)) later.

#### The `DotProductAttention` Class

Let's return to the scaled dot product attention for batches. Say we compute attention for
- $n$ queries of dimension $d$: $Q \in \mathbb{R}^{n \times d}$, on
- $m$ key-value pairs, where the keys have the same dimension as the queries: $K \in \mathbb{R}^{m \times d}$, and
- the values in each of the $m$ pairs all have length $v$: $V \in \mathbb{R}^{m \times v}$

$$ softmax(\frac{QK^\top}{\sqrt{d}})V \in \mathbb{R}^{n \times v} $$ 

Define the `DotProductAttention` class as a subclass of [PyTorch's `nn.Module`](https://pytorch.org/docs/stable/notes/modules.html). `Dropout` is used for model regularization.

```python
class DotProductAttention(nn.Module):
    """Scaled dot product attention.

    Parameters
    ----------
    dropout : float between 0 and 1
        The dropout rate (probability of an element being set to zero) applied
        to the attention weights before multiplying them with the values.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, n, d)
    # Shape of keys: (batch_size, m, d)
    # Shape of values: (batch_size, m, v)
    # Shape of valid_lens: (batch_size,) or (batch_size, n)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Attention scores QK^T / sqrt(d), shape (batch_size, n, m)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # Keep the weights as an attribute so they can be inspected later
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output shape: (batch_size, n, v)
        return torch.bmm(self.dropout(self.attention_weights), values)
```

Let's look into the `forward` function. It basically implements the formula, with a few twists:
- mask out invalid key positions with `masked_softmax`
- apply dropout to the attention weights before multiplying them with the values
- use batch matrix multiplication

#### Batch Matrix Multiplication

First, a detour to batch matrix multiplication (BMM). 

For performance, we divide the whole input into $n$ mini-batches and multiply the matrices batch by batch: BMM performs one ordinary matrix multiplication for each pair of corresponding sub-matrices. Suppose we stack the input matrices $A$ and $B$ into $n$ batches:
- $A = [A_1, A_2, ..., A_n] \in \mathbb{R}^{n \times a \times b}$
- $B = [B_1, B_2, ..., B_n] \in \mathbb{R}^{n \times b \times c}$

Then batch matrix multiplication (BMM) computes the matrix product of each pair:
$$ BMM(A, B) = [A_1B_1, A_2B_2, ..., A_nB_n] \in \mathbb{R}^{n \times a \times c} $$
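The sketch below verifies this definition with `torch.bmm` on random tensors, with sizes chosen arbitrarily. BMM gives the same result as one ordinary matrix multiplication per batch element.

```python
import torch

torch.manual_seed(0)
n, a, b, c = 4, 2, 3, 5
A = torch.randn(n, a, b)
B = torch.randn(n, b, c)

out = torch.bmm(A, B)                                    # (n, a, c)
expected = torch.stack([A[i] @ B[i] for i in range(n)])  # one matmul per batch element
print(out.shape)  # torch.Size([4, 2, 5])
```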

Our dot product attention also works on batches. The tensors $Q$ `queries`, $K$ `keys`, and $V$ `values` are therefore 3D rather than 2D. The first dimension is `batch_size`, and the other dimensions are as discussed earlier.
- $Q$ `queries`: (batch_size, n, d)
- $K$ `keys`: (batch_size, m, d)
- $V$ `values`: (batch_size, m, v)

To calculate the attention scores $QK^\top$, we swap the 2nd and 3rd dimensions of $K$ with `keys.transpose(1, 2)`. Then we multiply $Q$ with the transposed $K$ batch by batch: `torch.bmm(queries, keys.transpose(1, 2))`.

Later, after applying dropout to the calculated weights, BMM is also used to compute the attention values `torch.bmm(self.dropout(self.attention_weights), values)`.
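Putting the pieces together, here is a sketch of the batched computation with a padding mask. It is written inline rather than with the classes above, so that it runs on its own. It uses `Tensor.masked_fill` instead of `masked_softmax` and cross-checks the result against `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or later:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n, m, d, v = 2, 3, 4, 8, 5
queries = torch.randn(batch_size, n, d)
keys = torch.randn(batch_size, m, d)
values = torch.randn(batch_size, m, v)
valid_lens = torch.tensor([2, 4])  # number of valid key-value pairs per example

scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)  # (batch_size, n, m)
# Boolean mask, True where the key position is valid: (batch_size, 1, m) broadcasts over n
mask = torch.arange(m)[None, None, :] < valid_lens[:, None, None]
weights = torch.softmax(scores.masked_fill(~mask, -1e6), dim=-1)
out = torch.bmm(weights, values)                                   # (batch_size, n, v)

# Cross-check against PyTorch's built-in version (PyTorch >= 2.0).
# Its boolean attn_mask uses True for positions that may be attended to.
ref = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
print(out.shape, torch.allclose(out, ref, atol=1e-5))
```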

### Additive Attention (Bahdanau Attention)

Additive attention is a mechanism for sequence-to-sequence models that improves the performance of neural machine translation systems. An RNN encoder-decoder encodes the entire input sequence into a single fixed-size vector. Additive attention addresses this limitation by allowing the decoder to focus on different parts of the input sequence at each step of output generation.

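Its core attention weight calculation is a feed-forward neural network with a single hidden layer. The learnable parameters are $w_v \in \mathbb{R}^{h}$, $W_q \in \mathbb{R}^{h \times d_q}$, and $W_k \in \mathbb{R}^{h \times d_k}$. Instead of taking a simple dot product, it adds the projected query and key and applies a non-linearity (tanh activation). Both are projected into the same $h$-dimensional space, so the query and key may have different dimensions $d_q$ and $d_k$.

$$ \alpha(q,k) = w_v^\top \tanh(W_q q + W_k k) \in \mathbb{R} $$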

In a sequence-to-sequence model, the context variable $c$ is dynamically updated as a function of both the original text (encoder hidden states $h_t$) and the already generated text (decoder hidden states $s_{t'-1}$).

$$ c_{t'} = \sum_{t=1}^T \alpha(s_{t'-1}, h_t)h_t $$

We use the decoder's previous hidden state $s_{t'-1}$ as the query and the encoder's hidden states $h_t$ as both the keys and the values, then calculate the additive attention weights $\alpha$. The model was later modified to also use the tokens the decoder has already generated as context, i.e., the attention sum does not stop at $T$ but proceeds up to $t'-1$.
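Here is a minimal PyTorch sketch of the two formulas above. It is loosely modeled on D2L's `AdditiveAttention` but is not identical to it. The class name `AdditiveScore`, the sizes, and the random inputs are hypothetical. The decoder state and the encoder states deliberately have different dimensions, which additive attention handles through $W_q$ and $W_k$:

```python
import torch
from torch import nn

class AdditiveScore(nn.Module):
    """Unnormalized scores w_v^T tanh(W_q q + W_k k) for every query-key pair."""
    def __init__(self, query_dim, key_dim, num_hiddens):
        super().__init__()
        self.W_q = nn.Linear(query_dim, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_dim, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)

    def forward(self, queries, keys):
        # queries: (batch, n, query_dim), keys: (batch, m, key_dim)
        # Broadcast to (batch, n, m, num_hiddens) so every query meets every key
        features = torch.tanh(self.W_q(queries).unsqueeze(2) + self.W_k(keys).unsqueeze(1))
        return self.w_v(features).squeeze(-1)  # (batch, n, m)

torch.manual_seed(0)
batch, T, enc_dim, dec_dim = 2, 6, 16, 12
h = torch.randn(batch, T, enc_dim)       # encoder states h_1..h_T (keys and values)
s_prev = torch.randn(batch, 1, dec_dim)  # decoder state s_{t'-1} (one query per example)

score = AdditiveScore(dec_dim, enc_dim, num_hiddens=8)
alpha = torch.softmax(score(s_prev, h), dim=-1)  # (batch, 1, T), sums to 1 over t
c = torch.bmm(alpha, h)                          # (batch, 1, enc_dim): sum_t alpha_t h_t
print(alpha.shape, c.shape)
```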

## Multi-Head Attention

## Self-Attention and Positional Encoding

## The Transformer Architecture