**This tutorial is based on (and largely copied from) http://peterbloem.nl/blog/transformers.**

# Imports

In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

# Self-attention

Self-attention maps a sequence of vectors $x_1, ..., x_l$ to an output sequence of vectors $y_1, ..., y_l$ by taking weighted averages of the input:

$$y_i = \sum_j w_{ij}x_j$$

Here, $w_{ij}$ captures the interaction between inputs $x_i$ and $x_j$. The simplest choice is a softmax over the dot products of the inputs:

$$w'_{ij} = x_i^\text{T}x_j$$

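$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_{k}\exp(w'_{ik})}$$

The softmax makes the weights in each row non-negative and makes them sum to one, so each $y_i$ is indeed a weighted average of the inputs.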



<img src="https://raw.githubusercontent.com/leox1v/dl20/b3d5b5556d1b2bd360a4abeef4fd82f056ab0301/imgs/self-attention.svg?token=AD5WN2SV46NJT37O73GVHZS7VU2YK" width="600" valign="center"/>

In [None]:
# Our input x is a sequence of l vectors of dimension d. 
# Also, we want to process it in a batch of size b later on.
# So our dimension is [b, l, d].

# Let's start by using a random tensor for x.
b, l, d = 8, 4, 10
x = torch.rand(size=(b, l, d))
print(f'x: {x.shape}')

# To compute w', we use the batch matrix multiplication bmm.
# This results in dimension [b, l, l].
w_prime = torch.bmm(x, x.transpose(1, 2))

# By applying the softmax over the last dimension of w_prime, we obtain w.
w = F.softmax(w_prime, dim=-1)
print(f'w: {w.shape}')

# Now, to obtain the sequence y (of dimension [b, l, d]), we take the average of x weighted by w.
y = torch.bmm(w, x)
print(f'y: {y.shape}')

x: torch.Size([8, 4, 10])
w: torch.Size([8, 4, 4])
y: torch.Size([8, 4, 10])
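To see the weighted average at work on numbers small enough to check by hand, here is a minimal sketch with two orthogonal vectors in $\mathbb{R}^2$ (the inputs are chosen purely for illustration):

```python
import math
import torch
import torch.nn.functional as F

# Two orthogonal 2-d inputs, x_1 = [1, 0] and x_2 = [0, 1], as a batch of one.
x = torch.tensor([[[1.0, 0.0],
                   [0.0, 1.0]]])            # shape [1, 2, 2]

w_prime = torch.bmm(x, x.transpose(1, 2))   # raw dot products, here the identity matrix
w = F.softmax(w_prime, dim=-1)              # row-wise softmax
y = torch.bmm(w, x)                         # weighted averages of the inputs

# By hand: w_11 = e / (e + 1) and w_12 = 1 / (e + 1),
# so y_1 = w_11 * x_1 + w_12 * x_2 = [e / (e + 1), 1 / (e + 1)].
print(w.sum(dim=-1))  # each row of w sums to one
print(y[0, 0])
```

Because $x_1$ has a larger dot product with itself than with $x_2$, the output $y_1$ leans towards $x_1$ but still mixes in some of $x_2$.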


## Query, Key, Value
In this basic form of self-attention, a single vector $x_i$ is used for three different tasks:
1. It is compared to every other vector to compute the weights for its own output $y_i$. -> **query**
2. It is compared to every other vector to compute the weights for the $j$-th output $y_j$. -> **key**
3. It is used as part of the weighted sum that produces each output vector. -> **value**

To disentangle these three roles of $x_i$, we introduce a separate (learnable) linear transformation for each. In particular, we need three $d \times d$ weight matrices $W_q, W_k, W_v$:

$$q_i = W_qx_i \qquad \text{(Query)}$$

$$k_i = W_kx_i \qquad \text{(Key)}$$

$$v_i = W_vx_i \qquad \text{(Value)}$$

This gives the self-attention layer some controllable parameters, and allows it to modify the incoming vectors to suit the three roles they must play.

<img src="https://raw.githubusercontent.com/leox1v/dl20/b3d5b5556d1b2bd360a4abeef4fd82f056ab0301/imgs/key-query-value.svg?token=AD5WN2VWUZ4MAZY642K5OGK7VU3GQ" alt="drawing" width="500"/>
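As a minimal single-head sketch (the sizes and the bias-free `nn.Linear` layers are illustrative assumptions), the three projections can be plugged into the computation from the first code cell. `nn.Linear` computes $xW^\text{T}$ on row vectors, which is the same map as $Wx_i$ on a column vector $x_i$. The dot products are left unscaled here, as in the formulas so far:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
b, l, d = 8, 4, 10
x = torch.rand(size=(b, l, d))

# One learnable d x d projection per role (bias-free, like the equations above).
W_q = nn.Linear(d, d, bias=False)
W_k = nn.Linear(d, d, bias=False)
W_v = nn.Linear(d, d, bias=False)

queries, keys, values = W_q(x), W_k(x), W_v(x)       # each [b, l, d]

# The raw weights now compare the query of x_i with the key of x_j: w'_ij = q_i^T k_j.
w_prime = torch.bmm(queries, keys.transpose(1, 2))  # [b, l, l]
w = F.softmax(w_prime, dim=-1)

# The outputs are weighted averages of the values instead of the raw inputs.
y = torch.bmm(w, values)                            # [b, l, d]
print(y.shape)
```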

## Scaling the dot product

The softmax function is sensitive to very large input values: they push it into a regime where one weight is close to 1 and all others are close to 0, where the gradient almost vanishes, which slows down learning. The magnitude of the dot product grows with the embedding dimension $d$, so it helps to scale the dot product down depending on this value:

$$w'_{ij}= \frac{q_i^\text{T}k_j}{\sqrt{d}}$$

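We use $\sqrt{d}$ in the denominator because that is the Euclidean length of the vector in $\mathbb{R}^d$ whose entries are all $1$ (more generally, a vector whose entries all equal $c$ has length $\sqrt{d}\,c$). Equivalently, if the entries of $q_i$ and $k_j$ are independent with mean 0 and variance 1, then $q_i^\text{T}k_j$ has variance $d$, and dividing by $\sqrt{d}$ brings it back to variance 1.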
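A quick empirical check of how the dot product grows with $d$: the sketch below draws query and key entries independently from a standard normal distribution (an idealization, not what a trained model produces), so each raw dot product is a sum of $d$ products with variance 1 each:

```python
import torch

torch.manual_seed(0)
n = 10_000   # number of random (query, key) pairs per dimension
results = {}

for d in (16, 64, 256):
    # Query and key entries drawn independently from N(0, 1).
    q = torch.randn(n, d)
    k = torch.randn(n, d)
    dots = (q * k).sum(dim=-1)   # n raw dot products q^T k
    scaled = dots / d ** 0.5     # the same dot products divided by sqrt(d)
    results[d] = (dots.var().item(), scaled.var().item())
    print(f"d={d:4d}  var(q.k)={results[d][0]:8.2f}  var(q.k/sqrt(d))={results[d][1]:.2f}")
```

The raw variance tracks $d$, while the scaled variance stays near 1 regardless of the dimension.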


## Multi-head attention

We can increase the representational power of self-attention by combining several self-attention mechanisms. Instead of using only a single set of three transformation matrices $W_q, W_k, W_v$, we use many of them, indexed by $r$: $W^r_q, W^r_k, W^r_v$. Each set is called an *attention head*.

Using the individual attention heads, we produce multiple output vectors $y^r_i$ for a single input vector $x_i$. We can then concatenate the $y^r_i$ vectors and pass them through another linear transformation to reduce the dimension back to $d$.

Note for the implementation:
While we think of the attention heads as $h$ separate sets of three matrices (each of shape $d\times d$), we implement them by stacking the matrices so that we have only a single set of three matrices of shape $d\times hd$. This way we can compute the concatenated queries, keys, and values of all heads with a single matrix multiplication each.
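The stacking trick can be checked directly. In this sketch (sizes chosen arbitrarily, query projection only), $h$ separate bias-free `nn.Linear(d, d)` layers are copied into one `nn.Linear(d, h * d)`, and each slice of the stacked output is compared with the corresponding head:

```python
import torch
from torch import nn

torch.manual_seed(0)
b, l, d, h = 2, 4, 6, 8
x = torch.rand(size=(b, l, d))

# h separate d x d query projections, one per head ...
per_head = [nn.Linear(d, d, bias=False) for _ in range(h)]

# ... and one stacked projection from d to h*d dimensions.
stacked = nn.Linear(d, h * d, bias=False)
with torch.no_grad():
    # nn.Linear stores its weight as [out_features, in_features], so stacking the
    # heads' weights along dim 0 gives the [h*d, d] weight of the stacked layer.
    stacked.weight.copy_(torch.cat([p.weight for p in per_head], dim=0))

# One matrix multiplication yields all heads at once; view splits them apart again.
all_heads = stacked(x).view(b, l, h, d)

for r in range(h):
    # Slice r of the stacked output equals the output of head r on its own.
    assert torch.allclose(all_heads[:, :, r, :], per_head[r](x), atol=1e-6)
print("stacked projection matches the per-head projections")
```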

## Implementation of a SelfAttention Module

In [2]:
# Let's implement a SelfAttention torch module.

class SelfAttention(nn.Module):
    """
    A multi-head self-attention module.
    
    Args:
        d: The embedding dimension.
        heads: The number of attention heads.
    """
    def __init__(self, d: int, heads: int=8):
        super().__init__()
        self.d, self.h = d, heads
        
        # Query, key and value projections for all heads, stacked into
        # a single d -> h*d linear map per role.
        self.Wq = nn.Linear(d, d * heads, bias=False)
        self.Wk = nn.Linear(d, d * heads, bias=False)
        self.Wv = nn.Linear(d, d * heads, bias=False)
        
        # This unifies the outputs of the different heads into 
        # a single d-dimensional vector.
        self.unifyheads = nn.Linear(heads * d, d)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: The input embedding of shape [b, l, d].
            
        Returns:
            Self attention tensor of shape [b, l, d].
        """
        b, l, d = x.size()
        h = self.h
        
        # Transform the input embeddings x of shape [b, l, d] to queries, keys, values.
        # The output shape is [b, l, d*h], which we view as [b, l, h, d]. Then,
        # we fold the heads into the batch dimension to arrive at [b*h, l, d].
        queries = self.Wq(x).view(b, l, h, d).transpose(1, 2).contiguous().view(b*h, l, d)
        keys = self.Wk(x).view(b, l, h, d).transpose(1, 2).contiguous().view(b*h, l, d)
        values = self.Wv(x).view(b, l, h, d).transpose(1, 2).contiguous().view(b*h, l, d)
        
        # Compute the product of queries and keys and divide it by sqrt(d).
        # The tensor w' has shape [b*h, l, l] and contains the raw weights.
        #----------------
        # TODO
        w_prime = torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(d)
        #----------------

        # Compute w by normalizing w' over the last dimension.
        # Shape: [b*h, l, l]
        #----------------
        # TODO
        w = F.softmax(w_prime, dim=-1) 
        #----------------
        
        
        # Apply the self attention to the values ([b*h, l, d]) and split
        # the heads out of the batch dimension again.
        # Shape: [b, h, l, d]
        #----------------
        # TODO
        out = torch.bmm(w, values).view(b, h, l, d)
        #----------------
        
        
        # Swap h, l back.
        # Shape: [b, l, h*d]
        out = out.transpose(1, 2).contiguous().view(b, l, h * d)
        
        # Unify heads to arrive at shape [b, l, d].
        return self.unifyheads(out)


In [3]:
# Test it out.
b, l, d, h = 2, 4, 6, 8
sa = SelfAttention(d=d, heads=h)
x = torch.rand(size=(b, l, d))
sa(x).shape

torch.Size([2, 4, 6])
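The reshaping in `forward` is the easiest part to get wrong. As a sanity check, this sketch compares the folded computation (heads moved into the batch dimension, as in `SelfAttention.forward`) with an explicit loop over the heads. Random tensors stand in for the outputs of `Wq`, `Wk` and `Wv`, and `fold` is a hypothetical helper introduced only for this example:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, l, d, h = 2, 4, 6, 8

# Stand-ins for the outputs of Wq, Wk, Wv: shape [b, l, h*d].
q_all, k_all, v_all = (torch.rand(b, l, h * d) for _ in range(3))

def fold(t: torch.Tensor) -> torch.Tensor:
    """[b, l, h*d] -> [b*h, l, d], the same reshaping SelfAttention.forward uses."""
    return t.view(b, l, h, d).transpose(1, 2).contiguous().view(b * h, l, d)

# Batched version: all heads in one bmm, then unfold back to [b, l, h*d].
w = F.softmax(torch.bmm(fold(q_all), fold(k_all).transpose(1, 2)) / d ** 0.5, dim=-1)
batched = torch.bmm(w, fold(v_all)).view(b, h, l, d).transpose(1, 2).contiguous().view(b, l, h * d)

# Reference version: loop over the heads explicitly and concatenate the results.
outs = []
for r in range(h):
    cols = slice(r * d, (r + 1) * d)   # the columns that belong to head r
    q, k, v = q_all[..., cols], k_all[..., cols], v_all[..., cols]
    w_r = F.softmax(torch.bmm(q, k.transpose(1, 2)) / d ** 0.5, dim=-1)
    outs.append(torch.bmm(w_r, v))     # [b, l, d]
looped = torch.cat(outs, dim=-1)       # [b, l, h*d]

print(torch.allclose(batched, looped, atol=1e-6))
```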

# Transformers

The transformer architecture consists of multiple transformer blocks that typically look like this: 

<img src="https://raw.githubusercontent.com/leox1v/dl20/b3d5b5556d1b2bd360a4abeef4fd82f056ab0301/imgs/transformer-block.svg?token=AD5WN2SZYWM6XGH5SXMZM7S7VU3H4" alt="drawing" width="500"/>


It combines a self-attention layer, [layer normalization](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html), a feed-forward layer, and another layer normalization. Additionally, it uses residual connections around the self-attention and the feed-forward layers: each layer's input is added to its output before the corresponding layer normalization.
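Written as equations, the block described above computes the following, with layer normalization applied after each residual sum. Here $\operatorname{FF}$ is sketched as two linear maps with a ReLU in between, which is what the `nn.Sequential` in the code cell below uses; $n_\text{mlp}$ corresponds to its `n_mlp` argument:

$$
\begin{aligned}
x' &= \operatorname{LayerNorm}_1\big(x + \operatorname{SelfAttention}(x)\big),\\
y &= \operatorname{LayerNorm}_2\big(x' + \operatorname{FF}(x')\big),\\
\operatorname{FF}(z) &= W_2\,\operatorname{ReLU}(W_1 z + b_1) + b_2,
\end{aligned}
$$

where $W_1 \in \mathbb{R}^{n_\text{mlp} d \times d}$ and $W_2 \in \mathbb{R}^{d \times n_\text{mlp} d}$ act on each position $z$ of the sequence separately.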

In [None]:
class TransformerBlock(nn.Module):
    """
    A Transformer block consisting of self-attention and a feed-forward layer.
    
    Args:
        d (int): The embedding dimension.
        heads (int): The number of attention heads.
        n_mlp (int): The hidden size of the feed-forward layer, as a multiple of d.
    """
    def __init__(self, d: int, heads: int=8, n_mlp: int=4):
        super().__init__()
        
        # The self attention layer.
        #----------------
        # TODO
        self.attention = SelfAttention(d, heads=heads)
        #----------------
        
        # The two layer norms.
        #----------------
        # TODO
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)
        #----------------
        
        # The feed-forward layer.
        #----------------
        # TODO
        self.ff = nn.Sequential(
            nn.Linear(d, n_mlp*d),
            nn.ReLU(),
            nn.Linear(n_mlp*d, d)
        )
        #----------------
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: The input embedding of shape [b, l, d].
            
        Returns:
            Transformer output tensor of shape [b, l, d].
        """
        #----------------
        # TODO
        x_prime = self.attention(x)
        x = self.norm1(x_prime + x)
        
        x_prime = self.ff(x)
        return self.norm2(x_prime + x)
        #----------------
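Because a block maps a `[b, l, d]` tensor to another `[b, l, d]` tensor, blocks can be chained, e.g. `nn.Sequential(*[TransformerBlock(d, heads) for _ in range(depth)])`. For comparison, here is a sketch that stacks PyTorch's built-in `nn.TransformerEncoderLayer` instead of the class above. It is a different implementation: it splits $d$ across the heads (head size `d // heads`), uses biases in its attention projections, and applies dropout (disabled here). The sizes are arbitrary:

```python
import torch
from torch import nn

torch.manual_seed(0)
b, l, d, heads, depth = 2, 4, 16, 8, 3

# Built-in post-norm encoder blocks (norm_first=False is the default), with a
# feed-forward hidden size of 4 * d to mirror n_mlp=4 above.
blocks = nn.Sequential(*[
    nn.TransformerEncoderLayer(d_model=d, nhead=heads, dim_feedforward=4 * d,
                               dropout=0.0, batch_first=True)
    for _ in range(depth)
])

x = torch.rand(size=(b, l, d))
out = blocks(x)   # the shape is preserved through every block
print(out.shape)
```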
        
        