# Transformer in Formula and Code

Based on [Dive Into Deep Learning (D2L)](https://d2l.ai/chapter_attention-mechanisms-and-transformers/index.html)

## Attention

### Basics
Denote by $D\stackrel{def}{=}\{(k_{1}, v_{1}), ..., (k_{m}, v_{m})\}$ a database of $m$ tuples of _keys_ and _values_, and denote a _query_ by $q$. Then we can define the _attention_ over $D$ as

$$ Attention(q, D)\stackrel{def}{=}\sum_{i=1}^m \alpha(q, k_{i})v_{i} $$ 

where $\alpha(q, k_{i})$ are scalar attention weights. 

This operation pays more attention to terms where the weight is larger, hence the name _attention_.
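
As a minimal sketch of this definition, the plain-Python snippet below computes the attention output for a toy database. The keys, the scalar values, and the hand-picked weights are illustrative assumptions rather than anything prescribed by the text; how the weights are actually derived from $q$ and $k_{i}$ is the subject of the following sections.

```python
# Toy database of (key, value) pairs; all numbers are made up for illustration.
database = [("k1", 10.0), ("k2", 20.0), ("k3", 30.0)]

# Hypothetical attention weights alpha(q, k_i), chosen by hand here.
weights = [0.7, 0.2, 0.1]

def attention(weights, database):
    """Return sum_i alpha(q, k_i) * v_i for scalar values v_i."""
    return sum(w * v for w, (_, v) in zip(weights, database))

output = attention(weights, database)
print(output)  # dominated by the first value because its weight is largest
```

Because the first weight is the largest, the output lies closest to the first value.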

### Requirements
To train a model with this function smoothly and stably, we want the attention weights to satisfy the following requirements:
- The weights $\alpha(q, k_{i})$ are nonnegative.
- The weights $\alpha(q, k_{i})$ form a convex combination, i.e., $\sum_{i}\alpha(q,k_{i})=1$ and $\alpha(q, k_{i})\ge0$.

A special case of such a convex combination is when exactly one of the weights is 1 and all others are 0, which behaves like a conventional database lookup. It is a special case, not a requirement.

#### Sum to 1
To ensure the weights sum up to 1, we can normalize them:

$$ \alpha(q, k_{i}) = \frac{\alpha(q, k_{i})}{\sum_{j}\alpha(q,k_{j})} $$

#### Non-negative
To also ensure that the weights are non-negative, we can use exponentiation:

$$ \alpha(q, k_{i}) = \frac{\exp(\alpha(q, k_{i}))}{\sum_{j}\exp(\alpha(q, k_{j}))} $$

The resulting expression is differentiable and its gradient never vanishes, both of which are desirable properties in a model.

This is the $softmax$ operation.
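
The two steps above (exponentiate, then normalize) can be sketched in a few lines of plain Python. The example scores are arbitrary, and subtracting the maximum score before exponentiating is a common numerical-stability trick added here as an assumption; it is not part of the formula above, but it cancels in the ratio and so leaves the result unchanged.

```python
import math

def softmax(scores):
    """Exponentiate and normalize a list of raw scores."""
    # Subtracting the maximum cancels in the ratio but avoids overflow
    # in math.exp for large scores.
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

raw_scores = [2.0, -1.0, 0.5]  # arbitrary example scores, may be negative
alphas = softmax(raw_scores)
print(alphas)       # every entry is positive
print(sum(alphas))  # the entries sum to 1 up to rounding
```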

### Dot Product Attention Weights

#### Gaussian kernel attention to dot product attention
Simple dot product attention follows naturally from Gaussian kernel attention: expand the kernel formula, then consider what remains once batch and layer normalization are in place. (Here queries $q$ and keys $k$ are column vectors, so we transpose $q$ to take their dot product.)

$$ \alpha(q, k_{i}) = q^\top k_{i} $$

(The Gaussian kernel formula

$$ \alpha(q, k_{i}) = -\frac{1}{2} \lVert q-k_{i} \rVert^2 $$

can be expanded as

$$ \alpha(q, k_{i}) = q^\top{k_{i}} - \frac{1}{2}\lVert k_{i} \rVert^2 - \frac{1}{2}\lVert q \rVert^2 $$

The last term, $- \frac{1}{2}\lVert q \rVert^2$, is the same for every pair $(q, k_{i})$, so it cancels when the weights are normalized.

As for the term $- \frac{1}{2}\lVert k_{i} \rVert^2$, both batch and layer normalization lead to activations with well-bounded and often constant norms $\lVert k_{i} \rVert$, so we can drop it as well without a major change to the outcome.)
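
The argument above can be checked numerically. The sketch below assumes a small hand-picked query and three keys that all have unit norm (standing in for the "constant norm" situation); these vectors are illustrative only. Under that assumption, the softmax-normalized Gaussian kernel scores and the softmax-normalized dot product scores should coincide.

```python
import math

def softmax(scores):
    """Exponentiate and normalize a list of raw scores."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

q = [0.5, -1.0]
# Example keys that all have norm 1, so -0.5 * ||k_i||^2 is the same for every key.
keys = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]

gaussian_scores = [-0.5 * sum((qi - ki) ** 2 for qi, ki in zip(q, k)) for k in keys]
dot_scores = [dot(q, k) for k in keys]

gaussian_weights = softmax(gaussian_scores)
dot_weights = softmax(dot_scores)
print(gaussian_weights)
print(dot_weights)  # matches the line above up to rounding
```

If the keys had different norms, the two sets of weights would no longer agree exactly, which is why the derivation relies on normalization keeping the norms roughly constant.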

#### Scaled dot product attention weights

Assume that all the elements of the query $ q \in \mathbb{R}^d $ and the key $ k \in \mathbb{R}^d $ are independent and identically distributed random variables with zero mean and unit variance. Their dot product will then have zero mean and a variance of $d$.

To keep the variance of the dot product at 1 regardless of the vector length $d$, we need to scale the dot product attention weights:

$$ \alpha(q, k_{i}) = q^\top k_{i} / \sqrt{d} $$
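
The variance claim follows because each product $q_{j}k_{j}$ of two independent zero-mean, unit-variance entries has zero mean and variance 1, and the dot product sums $d$ such independent terms. The sketch below is a rough empirical check using only the standard library; the choice of standard normal entries, $d = 64$, and 5000 samples are assumptions made for illustration.

```python
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

def sample_dot_product(d):
    """Dot product of two random d-dimensional vectors with N(0, 1) entries."""
    return sum(random.gauss(0.0, 1.0) * random.gauss(0.0, 1.0) for _ in range(d))

d = 64
num_samples = 5000
raw = [sample_dot_product(d) for _ in range(num_samples)]
scaled = [x / d ** 0.5 for x in raw]  # divide by sqrt(d)

raw_variance = statistics.pvariance(raw)
scaled_variance = statistics.pvariance(scaled)
print(raw_variance, scaled_variance)  # expected to be close to d and to 1
```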

Finally, we still need to normalize the weights $\alpha$ so that they are non-negative and sum to one. Applying the softmax gives the most commonly used mechanism for attention weights, the scaled dot product attention weights:

$$ \alpha(q, k_{i}) = softmax(q^{\top}k_{i}/\sqrt{d}) $$

$$ = \frac{\exp(q^{\top}k_{i}/\sqrt{d})}{\sum_{j=1}^{m}\exp(q^{\top}k_{j}/\sqrt{d})} $$

#### Batch Matrix Multiplication

In practice, we process minibatches for computational efficiency, which means computing the dot product attention for multiple queries at once. Say we compute attention for $n$ queries over $m$ key-value pairs, where every query and key has the same dimension $d$ and every value has dimension $v$. We then have three matrices:
- $Q \in \mathbb{R}^{n \times d}$, $n$ rows of $d$-dimensional query vectors
- $K \in \mathbb{R}^{m \times d}$, $m$ rows of $d$-dimensional key vectors
- $V \in \mathbb{R}^{m \times v}$, $m$ rows of $v$-dimensional value vectors

So the scaled dot product attention of matrices $Q$, $K$, $V$ can be written as

$$ softmax(\frac{QK^\top}{\sqrt{d}})V \in \mathbb{R}^{n \times v}$$

The resulting attention output matrix has dimension $n \times v$, namely one value vector row for each query row.
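
As a quick shape check of the matrix formula, here is a minimal PyTorch sketch. The sizes `n`, `m`, `d`, `v` and the random inputs are illustrative assumptions; the point is only that the weight matrix has one row per query, each row sums to 1, and the output has shape $n \times v$.

```python
import math
import torch

torch.manual_seed(0)

n, m, d, v = 2, 5, 4, 3  # illustrative sizes: queries, key-value pairs, key dim, value dim
Q = torch.randn(n, d)
K = torch.randn(m, d)
V = torch.randn(m, v)

scores = Q @ K.T / math.sqrt(d)          # (n, m): one scaled score per query-key pair
weights = torch.softmax(scores, dim=-1)  # (n, m): each row sums to 1
output = weights @ V                     # (n, v): one value vector per query

print(weights.shape, output.shape)
```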

### PyTorch implementation

Before anything, import the dependencies:

In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

Define a convenient softmax function optionally taking in a mask, so we can handle input sequences with different lengths (when they end up in the same minibatch, shorter sequences are typically padded with dummy tokens, to which we don't want our model to pay attention). 

In [None]:
def masked_softmax(X, valid_lens=None):
    """
    Perform softmax on the last axis of X, optionally masking padded positions.
    
    Parameters
    ----------
    X : 3D tensor of shape (num_of_batches, batch_size, max_seq_len)
        The input tensor.
    valid_lens : 1D or 2D tensor, optional
        The valid length(s) of the sequences. A 1D tensor of shape (num_of_batches,) gives one valid length per batch, shared by every sequence in that batch. A 2D tensor of shape (num_of_batches, batch_size) gives each sequence its own valid length. If None, no mask is applied.
    """
    
    # helper function to apply a mask to a batch of sequences
    def _sequence_mask(X, valid_len, mask_value=0):
        """
        Build a mask from valid_len and overwrite the masked positions of X with mask_value. X is modified in place and returned.
        
        Parameters
        ----------
        X : 2D tensor of shape (num_sequences, max_seq_len)
        valid_len : 1D tensor of shape (num_sequences,)
        """
        maxlen = X.size(1)
        seq_ids = torch.arange(maxlen, dtype=torch.float32, device=X.device)
        # broadcasting (1, max_seq_len) against (num_sequences, 1) gives a (num_sequences, max_seq_len) boolean mask
        mask = seq_ids[None, :] < valid_len[:, None]
        X[~mask] = mask_value
        return X
        
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
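        shape = X.shape  # (num_of_batches, batch_size, max_seq_len)
        if valid_lens.dim() == 1:
            # valid_lens is 1D with one entry per batch: turn it into one entry per sequence
            # by repeating each entry batch_size times.
            ## Intuition
            ## Suppose X has 2 batches each containing 3 sequences padded to length 4, i.e., X.shape is (2, 3, 4)
            ## Suppose valid_lens is the 1D tensor [2, 4]: sequences in the first batch have valid length 2, in the second batch 4
            ## Since the batch size is 3 (shape[1]), we want valid_lens to become [2, 2, 2, 4, 4, 4],
            ## i.e., each element of the original 1D tensor repeated 3 (shape[1]) times
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # The rest of the function is completed here following the D2L reference
            # implementation this section is based on.
            # A 2D valid_lens is flattened to one length per sequence.
            valid_lens = valid_lens.reshape(-1)
        # Flatten X to (num_sequences, max_seq_len) and fill the masked positions with a large
        # negative value, so that they get (numerically) zero weight after the softmax.
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, mask_value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)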

Let's take a closer look at this implementation. 

First the helper function `_sequence_mask(X, valid_len, mask_value=0)`:
1. The input arguments:
   1.  `X` is not the outer function's input argument `X` itself (a 3D tensor of dimension `(num_batches, batch_size, max_seq_len)`), but a version of it reshaped into a 2D tensor of dimension `(num_sequences, max_seq_len)`, where `num_sequences = num_batches * batch_size`.
   2.  `valid_len` is not the outer function's input argument `valid_lens` itself, either. While the outer function accepts either a 1D or a 2D tensor, the inner function only accepts the sanitized and, where necessary, reshaped form: a 1D tensor holding the valid length of every sequence in the input. Its dimension is therefore `(num_sequences,)`.
2. How the boolean mask is created:
   1.  Get the max sequence length, which is the 2nd dimension of input `X`
   2.  Use `torch.arange` and the max sequence length as its upper bound to create a vector of all position indices in a max length sequence, `seq_ids`
   3.  Three steps to create the boolean mask for every position of every sequence
      1. expand the `seq_ids` 1D tensor to a 2D tensor by prepending a dimension, so that it is of dimension `(1, max_seq_len)`
      2. expand the `valid_len` 1D tensor to a 2D tensor by appending a dimension, so that it is of dimension `(num_sequences, 1)`
      3. create the boolean mask by comparing element-wise whether a `seq_id` is smaller than the corresponding sequence's `valid_len`, which yields `True` if that position is within that sequence's valid length and `False` otherwise
         1. Here we see a clever use of broadcasting: element-wise comparison of two tensors with different dimensions causes both to be broadcast. Both `seq_ids` and `valid_len` are broadcast to dimension `(num_sequences, max_seq_len)`, essentially giving us a mask of the same dimension as this function's input `X`.
3. How the boolean mask is applied
   1. The boolean mask we created in step 2 is `True` at all positions within the valid length. Taking its complement `~mask` flips every entry, so all the invalid positions become `True` and we can use it to select the invalid positions in `X`.
   2. `X[~mask]` selects all the invalid positions in `X`. We then write the mask value into these positions: `X[~mask] = mask_value`.

After all this, we return the masked `X`.
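
To make the broadcasting step concrete, the following standalone sketch rebuilds the mask for a tiny example. The valid lengths `[2, 4, 1]` and the all-ones input are assumptions chosen for illustration; the snippet also shows what `torch.repeat_interleave` does to a 1D `valid_lens` before it reaches the helper.

```python
import torch

max_seq_len = 4
valid_len = torch.tensor([2, 4, 1])   # one valid length per sequence, shape (3,)

seq_ids = torch.arange(max_seq_len)   # position indices 0..3, shape (4,)
# Comparing shape (1, 4) with shape (3, 1) broadcasts both sides to (3, 4).
mask = seq_ids[None, :] < valid_len[:, None]

X = torch.ones(3, max_seq_len)
X[~mask] = 0.0                        # overwrite the padded (invalid) positions in place
print(mask)
print(X)

# Expanding per-batch lengths to per-sequence lengths for a batch size of 3:
per_sequence = torch.repeat_interleave(torch.tensor([2, 4]), 3)
print(per_sequence)                   # a 1D tensor with each length repeated 3 times
```

Row `i` of `mask` has exactly `valid_len[i]` leading `True` entries, so the number of entries left untouched in `X` equals the sum of the valid lengths.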

Define the `DotProductAttention` class as a subclass of [PyTorch's `nn.Module`](https://pytorch.org/docs/stable/notes/modules.html)