# Building a Transformer Encoder from Scratch
Charlie Ardern <br>
<hr>

Below I implement the encoder block from the paper [Attention is all you need, Vaswani et al., 2017](https://arxiv.org/abs/1706.03762). The code was inspired and guided by [this book](https://www.bishopbook.com/) and by the free online materials from the University of Amsterdam's deep learning course. The annotations are sparse because the focus here is on the implementation itself; readers are assumed to already understand how transformers work.

In [13]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

<hr>

## Positional Encoding
Before feeding data to the encoder, it is often beneficial to encode the token positions. The encoder is permutation equivariant: shuffling the input tokens simply shuffles the outputs in the same way, so on its own it cannot make use of token order. The [original paper](https://arxiv.org/abs/1706.03762) adds a fixed sinusoidal encoding $\mathbf{P}$ to the embeddings $\mathbf{X}$:
$$\mathbf{X} := \mathbf{X} + \mathbf{P},$$
where:
$$ P_{ij} = \begin{cases} \sin \left[\frac{i}{10000^{j/d}}\right] \text{ for even } j\\ \cos \left[\frac{i}{10000^{(j-1)/d}}\right] \text{ for odd } j\end{cases}$$
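Here $i$ is the (zero-based) token position, $j$ indexes the embedding dimension, and $d$ is the size of the embeddings. Now let's implement this in code: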

In [14]:
class PositionalEncoder(nn.Module):
    
    def __init__(self, n_max, d):
        super().__init__()
        self.n_max = n_max
        self.d = d
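        P = torch.zeros(n_max, d)
        # Token positions as a column vector: (n_max, 1)
        pos = torch.arange(0, n_max, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(j/d) for even j, computed via exp/log: (d/2,)
        div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
        P[:, 0::2] = torch.sin(pos * div)  # even columns
        P[:, 1::2] = torch.cos(pos * div)  # odd columns
        P = P.unsqueeze(0)  # add a batch dimension: (1, n_max, d)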
        
        # Register as a buffer so P follows the module across devices;
        # persistent=False keeps it out of the state_dict
        self.register_buffer('P', P, persistent=False)
        
    def forward(self, X):
        # Add the encodings for the first n positions, where n = X.size(1)
        X = X + self.P[:, :X.size(1)]
        return X
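
As a quick sanity check, the table can be built outside of a module and compared entry by entry with the piecewise formula above. This is only a sketch: the helper name `sinusoidal_table` and the sizes `n_max=50`, `d=8` are chosen for illustration, and `d` is assumed to be even.

```python
import math
import torch

def sinusoidal_table(n_max: int, d: int) -> torch.Tensor:
    """Build the (n_max, d) table P from the formula above (assumes even d)."""
    pos = torch.arange(n_max, dtype=torch.float).unsqueeze(1)  # (n_max, 1)
    j_even = torch.arange(0, d, 2, dtype=torch.float)          # even column indices
    div = torch.exp(-math.log(10000.0) * j_even / d)           # 10000^(-j/d)
    P = torch.zeros(n_max, d)
    P[:, 0::2] = torch.sin(pos * div)  # even columns use sine
    P[:, 1::2] = torch.cos(pos * div)  # odd columns use cosine
    return P

n_max, d = 50, 8
P = sinusoidal_table(n_max, d)

# Evaluate the piecewise definition directly for one even and one odd column.
expected_even = math.sin(3 / 10000 ** (4 / d))        # i = 3, j = 4
expected_odd = math.cos(3 / 10000 ** ((5 - 1) / d))   # i = 3, j = 5

print(abs(P[3, 4].item() - expected_even) < 1e-5)
print(abs(P[3, 5].item() - expected_odd) < 1e-5)
```

Note that position 0 always gets the same row (0 in the even columns, 1 in the odd ones), since $\sin 0 = 0$ and $\cos 0 = 1$.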

<hr>

## Architecture of the Encoder
The structure of the encoder block is shown in the figure below (credit: [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).
<br>
<img src="transformer_architecture.svg">

**The transformer encoder will consist of these key stages:**
1. Multi-head attention
2. Add + layer norm
3. MLP
4. Add + layer norm
<br>

**Notation:**
- $b$ is the batch size
- $n$ is the sequence length
- $d$ is the dimension of the embeddings
<br>

Note that each stage maps $\mathbb{R}^{b\times n\times d} \to \mathbb{R}^{b\times n\times d}$ so the input and output dimensions are the same.
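
The shape claim can be traced stage by stage with PyTorch's built-in modules. This is a rough sketch rather than the implementation built in this notebook: `nn.MultiheadAttention` stands in for the attention class, and the sizes (`b=2`, `n=5`, `d=16`, `h=4`, `mlp_dim=32`) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

b, n, d, h, mlp_dim = 2, 5, 16, 4, 32  # illustrative sizes
X = torch.randn(b, n, d)

# Stand-ins for the four stages (built-in modules, not this notebook's classes)
attn = nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)
norm1 = nn.LayerNorm(d)
mlp = nn.Sequential(nn.Linear(d, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, d))
norm2 = nn.LayerNorm(d)

attn_out, _ = attn(X, X, X)   # 1. multi-head self-attention
H1 = norm1(X + attn_out)      # 2. add + layer norm
mlp_out = mlp(H1)             # 3. MLP
H2 = norm2(H1 + mlp_out)      # 4. add + layer norm

shapes = [tuple(t.shape) for t in (attn_out, H1, mlp_out, H2)]
print(shapes)  # every entry is (b, n, d)
```

Because every stage preserves the shape, the residual additions in steps 2 and 4 are well defined and blocks can be stacked without any reshaping in between.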

<hr>

## Multi-Head Attention Class
The computation steps and dimensions are shown below:<br>
<img src="MHA_steps.png" width="300px">
<br>
The MHA block is evidently rather involved so we'll build a class for it:

In [15]:
class MHA(nn.Module):
    
    def __init__(self, d, h):
        super().__init__()
        assert d % h == 0, "Dimension of embeddings must be divisible by number of heads"
        self.d = d
        self.h = h    
        self.X_to_QKV = nn.Linear(d, 3*d)
        self.stacked_heads_to_X = nn.Linear(d, d)
        self._reset_parameters()
        
    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.X_to_QKV.weight)
        nn.init.xavier_uniform_(self.stacked_heads_to_X.weight)
        self.X_to_QKV.bias.data.fill_(0)
        self.stacked_heads_to_X.bias.data.fill_(0)
        
    def forward(self, X, mask=None):
        b, n, d = X.size()
        QKV = self.X_to_QKV(X)                            # (b, n, 3d)
        QKV = QKV.reshape(b, n, self.h, 3 * d // self.h)  # (b, n, h, 3d/h)
        QKV = QKV.permute(0, 2, 1, 3)                     # (b, h, n, 3d/h)
        Q, K, V = QKV.chunk(3, dim=-1)                    # each (b, h, n, d/h)
        
        # Scaled dot product attention:
        # Divide by sqrt(d/h), the per-head dimension; logits has shape (b, h, n, n)
        logits = torch.matmul(Q, K.transpose(-2, -1)) / ((d / self.h) ** 0.5)
        
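        # Set logits at masked positions (mask == 0) to a large negative value
        # so that the softmax gives them (numerically) zero weight
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -9e15)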
        
        attention = F.softmax(logits, dim=-1)
        head_vals = torch.matmul(attention, V)
        
        head_vals = head_vals.permute(0, 2, 1, 3)    # (b, n, h, d/h)
        stacked_heads = head_vals.reshape(b, n, d)   # concatenate the heads: (b, n, d)
        X = self.stacked_heads_to_X(stacked_heads)
        return X
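
The core of the class is the scaled dot-product step, which is easier to inspect in isolation. Below is a minimal standalone sketch of that step with the same masking convention (`mask == 0` means "do not attend"). The function name `scaled_dot_product`, the causal mask, and the tensor sizes are assumptions made for the example.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product(Q, K, V, mask=None):
    """Attention over the last two dimensions; returns (values, weights)."""
    d_k = Q.size(-1)
    logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Large negative logits become zero weight after the softmax
        logits = logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(logits, dim=-1)
    return torch.matmul(attention, V), attention

b, h, n, d_k = 2, 4, 5, 8  # illustrative sizes; d_k plays the role of d/h
Q, K, V = (torch.randn(b, h, n, d_k) for _ in range(3))

# Lower-triangular mask: token i may only attend to tokens 0..i.
# Its (n, n) shape broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(n, n))

out, attention = scaled_dot_product(Q, K, V, mask=causal_mask)
print(out.shape)                       # (b, h, n, d_k)
print(attention[0, 0])                 # upper triangle is zero
```

Each row of `attention` still sums to one, because the softmax renormalises over the positions that remain unmasked.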

<hr>

## Constructing the Full Encoder

In [16]:
class EncoderBlock(nn.Module):
    
    def __init__(self, d, h, mlp_dim, dropout=0.0):
        super().__init__()
        
        # Multi-Head Attention Block
        self.mha = MHA(d, h)
        
        # Multi-Layer Perceptron Block
        self.mlp = nn.Sequential(
            nn.Linear(d, mlp_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_dim, d)
        )
        
        # Layer norms
        self.Lnorm1 = nn.LayerNorm(d)
        self.Lnorm2 = nn.LayerNorm(d)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, mask=None):
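        # 1. Attention step
        attn_out = self.mha(X, mask=mask)
        
        # 2. Add + Layer Norm (residual connection around the attention block)
        X = X + self.dropout(attn_out)
        X = self.Lnorm1(X)
        
        # 3. MLP
        mlp_out = self.mlp(X)
        
        # 4. Add + Layer Norm (residual connection around the MLP)
        X = X + self.dropout(mlp_out)
        X = self.Lnorm2(X)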
        
        return X
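
With a full block in hand, the permutation-equivariance claim from the positional encoding section can be checked numerically. The sketch below uses PyTorch's built-in `nn.TransformerEncoderLayer` as a stand-in, since by default it follows the same post-norm layout (attention, add + norm, MLP, add + norm); it is not the class defined above and its internals may differ in detail. The sizes, the fixed permutation, and the random table `P` are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, n, d, h = 2, 6, 16, 4  # illustrative sizes

# Built-in reference block (a stand-in, not the EncoderBlock from this notebook)
layer = nn.TransformerEncoderLayer(
    d_model=d, nhead=h, dim_feedforward=32, dropout=0.0, batch_first=True
)
layer.eval()

X = torch.randn(b, n, d)
perm = torch.tensor([2, 0, 5, 1, 4, 3])  # a fixed reordering of the n tokens

with torch.no_grad():
    out = layer(X)
    out_perm = layer(X[:, perm])

# Without positional information, permuting the inputs permutes the outputs.
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-4)

# Adding any position-dependent table before the block breaks this symmetry.
P = torch.randn(n, d)  # hypothetical stand-in for a positional encoding
with torch.no_grad():
    out_pos = layer(X + P)
    out_pos_perm = layer(X[:, perm] + P)

position_aware = not torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-4)
print(equivariant, position_aware)
```

The second check is the reason for adding $\mathbf{P}$ to the embeddings first: once each position carries its own offset, the same tokens in a different order produce a genuinely different output.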