# Attention is all you need

I will be implementing the <a href="https://arxiv.org/pdf/1706.03762.pdf">**Attention is all you need**</a> paper, following the wonderful work in <a href="http://nlp.seas.harvard.edu/2018/04/03/attention.html">**The Annotated Transformer**</a> by Harvard NLP.

The goal of this notebook is to understand the Transformer architecture clearly.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

# Encoder-Decoder architecture

This is the basis of many sequence-to-sequence models.

In [2]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder wrapper.

    Keeps references to an encoder, a decoder, the source/target embedding
    callables and a generator. The generator is only stored on the instance;
    forward() does not apply it.
    """
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super().__init__()
        # Assignment order is kept so sub-module registration order is stable.
        self.encoder = encoder
        self.decoder = decoder
        self.source_embed = source_embed
        self.target_embed = target_embed
        self.generator = generator

    def forward(self, source, target, source_mask, target_mask):
        """
        Embed and encode `source`, then embed `target` and decode it against
        the encoder output. Returns whatever the decoder returns.
        """
        memory = self.encoder(self.source_embed(source), source_mask)
        embedded_target = self.target_embed(target)
        return self.decoder(embedded_target, memory, source_mask, target_mask)

In [3]:
class Generator(nn.Module):
    """
    Maps decoder output to log-probabilities over the vocabulary.

    A single nn.Linear(model_dim, vocab) followed by a log-softmax taken
    over the last dimension.
    """
    def __init__(self, vocab, model_dim):
        super().__init__()
        self.projection = nn.Linear(model_dim, vocab)

    def forward(self, x):
        """Return log_softmax(projection(x)) along the last axis."""
        logits = self.projection(x)
        return F.log_softmax(logits, dim=-1)

# Encoder-Decoder stacks
## Encoder
<img src="https://lilianweng.github.io/lil-log/assets/images/transformer-encoder.png" width=500 height=600>

In [4]:
def clones(module, N):
    """
    Build an nn.ModuleList holding N independent deep copies of `module`.
    """
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies

In [5]:
class Encoder(nn.Module):
    """
    A stack of N copies of `module` followed by one final LayerNorm.

    `module` must expose a `size` attribute; it is used to size the final
    normalization layer.
    """
    def __init__(self, module, N):
        super().__init__()
        self.layers = clones(module, N)
        self.layer_norm = LayerNorm(module.size)

    def forward(self, x, mask):
        """Feed `x` (with `mask`) through every layer in order, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        # NOTE(review): the SubLayerConnection in this file already ends with
        # a LayerNorm, so this extra normalization may be redundant when the
        # stacked module is built from it -- confirm intent.
        return self.layer_norm(hidden)

Here I apply layer normalization:

<img src="https://miro.medium.com/max/498/1*VSYtYThHQtxq1dNkVu6leQ.png">

In [6]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension.

    Computes a_2 * (x - mean) / (std + epsilon) + b_2, where mean and std are
    taken over the last dimension of x and a_2 / b_2 are learnable vectors of
    length `module_size`.

    Args:
        module_size: size of the last dimension of the inputs.
        epsilon: small constant added to the standard deviation to avoid
            division by zero.
    """
    def __init__(self, module_size, epsilon=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable scale (gain) starts at one and learnable shift (bias)
        # starts at zero, so a freshly built layer is a plain normalization.
        self.a_2 = nn.Parameter(torch.ones(module_size))
        self.b_2 = nn.Parameter(torch.zeros(module_size))
        self.epsilon = epsilon

    def forward(self, x):
        """Normalize `x` along its last dimension, then scale and shift."""
        # keepdim=True keeps the reduced axis so the statistics broadcast
        # against x.
        mean = x.mean(dim=-1, keepdim=True)
        # torch's std applies Bessel's correction by default.
        std = x.std(dim=-1, keepdim=True)
        return (self.a_2 * (x - mean) / (std + self.epsilon)) + self.b_2

In [8]:
class SubLayerConnection(nn.Module):
    """
    Residual connection followed by layer normalization.

    Computes LayerNorm(x + Dropout(sublayer(x))): dropout is applied to the
    sublayer output before it is added back to the input.
    """
    def __init__(self, size, dropout_prob):
        super().__init__()
        self.layer_norm = LayerNorm(size)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, sublayer):
        """Apply `sublayer` to `x`, add the residual, and normalize the sum."""
        residual_sum = x + self.dropout(sublayer(x))
        return self.layer_norm(residual_sum)

Each encoder block has two sublayers:
- Multiheaded Attention
- Fully connected

In [9]:
class EncoderLayer(nn.Module):
    """
    One encoder block: self-attention followed by a feed-forward sublayer.

    Each of the two sublayers is wrapped in its own SubLayerConnection.
    `size` is stored on the instance as well as being passed to the
    SubLayerConnection wrappers.
    """
    def __init__(self, multi_headed_self_attention, feed_forward, dropout_prob, size):
        super().__init__()
        self.multi_headed_self_attention = multi_headed_self_attention
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout_prob), 2)
        self.size = size

    def forward(self, x, mask):
        """Run masked self-attention on `x`, then the feed-forward sublayer."""
        def self_attention(hidden):
            # Query, key and value are all the same tensor.
            return self.multi_headed_self_attention(hidden, hidden, hidden, mask)

        attention_block = self.sublayer[0]
        feed_forward_block = self.sublayer[1]
        attended = attention_block(x, self_attention)
        return feed_forward_block(attended, self.feed_forward)

## Decoder

<img src="https://lilianweng.github.io/lil-log/assets/images/transformer-decoder.png" width=500 height=700>

In [10]:
class Decoder(nn.Module):
    """
    A stack of N copies of `module` followed by one final LayerNorm.

    `module` must expose a `size` attribute; it is used to size the final
    normalization layer.
    """
    def __init__(self, module, N):
        super().__init__()
        self.layers = clones(module, N)
        self.layer_norm = LayerNorm(module.size)

    def forward(self, x, memory, source_mask, target_mask):
        """
        Feed `x` through every layer in order, handing each one the same
        `memory`, `source_mask` and `target_mask`, then normalize the result.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, source_mask, target_mask)
        return self.layer_norm(hidden)

In [11]:
class DecoderLayer(nn.Module):
    """
    One decoder block made of three sublayers, each wrapped in its own
    SubLayerConnection:

    1. multi_headed_self_attention: attention over the decoder input,
       using target_mask.
    2. source_multi_headed_self_attention: attention whose keys and values
       are the encoder stack output (`memory`), using source_mask.
    3. feed_forward.
    """
    def __init__(self, multi_headed_self_attention, source_multi_headed_self_attention,
                 feed_forward, dropout_prob, size):
        super().__init__()
        self.multi_headed_self_attention = multi_headed_self_attention
        self.source_multi_headed_self_attention = source_multi_headed_self_attention
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout_prob), 3)
        self.size = size

    def forward(self, x, memory, source_mask, target_mask):
        """
        Apply self-attention, attention over `memory` (the encoder stack
        output), and the feed-forward sublayer, in that order.
        """
        def self_attention(hidden):
            # Query, key and value are all the decoder-side tensor.
            return self.multi_headed_self_attention(hidden, hidden, hidden, target_mask)

        def memory_attention(hidden):
            # Queries come from the decoder; keys and values from the encoder.
            return self.source_multi_headed_self_attention(hidden, memory, memory, source_mask)

        after_self_attention = self.sublayer[0](x, self_attention)
        after_memory_attention = self.sublayer[1](after_self_attention, memory_attention)
        return self.sublayer[2](after_memory_attention, self.feed_forward)