# From scratch implementation of the Transformer model in PyTorch.

Further reading: [Attention is all you need](https://arxiv.org/abs/1706.03762) and [Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)

<img src="https://d2l.ai/_images/transformer.svg" alt="The transformer architecture" >

In [51]:
import math
import copy
import numpy as np
from torch import nn
import torch.nn.functional as F

## Embeddings

The embedding layer is the first layer of the Transformer model. The embedding layer converts the input tokens and target tokens into vectors of size $d_{\text{model}}$.

In [52]:
class Embeddings(nn.Module):
    """Learned lookup table that turns token ids into dense vectors.

    Args:
        d_model: Width of every embedding vector.
        vocab: Number of rows in the table (one per token id).

    NOTE(review): the reference paper additionally multiplies the embeddings
    by sqrt(d_model); this module returns the raw lookup. Confirm whether
    that omission is intentional.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        # One trainable row of width d_model per vocabulary entry.
        self.embedding = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)

    def forward(self, x):
        """Return the embedding vectors for the token ids in ``x``."""
        looked_up = self.embedding(x)
        return looked_up

## Positional Encoding

Because self-attention is permutation-invariant, Transformer models rely on a positional encoding to capture the position of each word in the sentence.
The positional encoding is a function of the position of the word in the sentence: it is a vector of size $d_{\text{model}}$ that is added to the embedding of the input.
Positional encodings can either be learned or fixed. In this implementation, we use a fixed positional encoding.
The fixed positional encoding adds sine and cosine functions of different frequencies to the embedding of the input. These sinusoids of different frequencies give every position a distinct encoding that captures where the word sits in the sentence.
The transformer model will then learn to extract the position of the word from the positional encoding.

In [53]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first embeddings.

    Position ``pos`` is encoded as ``sin(pos * w_i)`` on even feature indices
    and ``cos(pos * w_i)`` on odd ones, with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature size of the embeddings.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        # even indices
        pe[:, 0::2] = torch.sin(angles)
        # odd indices; when d_model is odd there is one fewer odd column than
        # even columns, so drop the surplus frequency instead of crashing.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as (max_len, 1, d_model) so the buffer keeps the shape that
        # previously saved state_dicts contain.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # register_buffer is a special kind of attribute that is not
        # considered a model parameter, so that it will not be returned
        # by model.parameters()
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus its positional encodings, followed by dropout.

        Args:
            x: Embeddings laid out as (batch, seq_len, d_model). This is the
                layout MultiHeadAttention in this file uses (it reads the
                batch size from dim 0).

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # Index the table by sequence position (dim 1), not by batch element:
        # (seq_len, 1, d_model) -> (1, seq_len, d_model) broadcasts over batch.
        x = x + self.pe[:seq_len].transpose(0, 1)
        return self.dropout(x)

## Multi-Head Attention

The multi-head attention layer is a key component of the Transformer model, composed of $h$ scaled dot-product attention heads. It takes query, key, and value vectors of size $d_{\text{model}}$ as inputs and projects each of them into $h$ smaller vectors of size $d_k = d_{\text{model}} / h$ before computing scaled dot-product attention on each head. The outputs of the $h$ heads are concatenated and transformed back to a vector of size $d_{\text{model}}$. In the original Transformer, the multi-head attention layer is followed by a residual connection and layer normalization ("Post-Norm"). More recent implementations instead place the layer normalization at the beginning of each sublayer ("Pre-Norm"), which is what the `ResidualConnection` module below does.

In [54]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    Queries, keys and values are linearly projected, split into ``heads``
    slices of width ``d_k = d_model // heads``, attended independently via
    the module-level ``attention`` function, then concatenated and projected
    back to ``d_model``.

    Args:
        heads: Number of attention heads.
        d_model: Feature size of the inputs and of the output.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``d_model`` evenly (the head split would otherwise mis-shape or
            fail later inside ``view``).
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number "
                f"of heads (got {heads})"
            )
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v`` and return a (bs, q_len, d_model) tensor.

        Args:
            q, k, v: Tensors whose first dimension is the batch and whose last
                dimension is ``d_model``.
            mask: Optional mask forwarded to ``attention``; positions where it
                equals 0 are excluded.
        """
        bs = q.size(0)
        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.heads, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.heads, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.heads, self.d_k)
        # transpose to get dimensions bs * heads * sl * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # calculate dot product attention
        attended = attention(q, k, v, self.d_k, mask, self.dropout)
        # concatenate heads and put through final linear layer
        concat = attended.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

def attention(q, k, v, d_k, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        q, k, v: Query, key and value tensors; the last two dimensions are
            (length, d_k). MultiHeadAttention passes them as
            (batch, heads, length, d_k).
        d_k: Per-head feature size; scores are divided by ``sqrt(d_k)``.
        mask: Optional tensor; a head dimension is inserted at dim 1 and
            positions where it equals 0 receive no attention weight.
            Presumably shaped (batch, q_len or 1, k_len) - confirm with callers.
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention-weighted sum of ``v``, with the leading shape of ``q``.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        mask = mask.unsqueeze(1)
        # Fill with the most negative value of the score dtype. A fixed -1e9
        # overflows in float16 and can be *larger* than a legitimate score,
        # which would let a masked position win the softmax.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)

## Position-wise Feed-Forward Networks

The position-wise feed-forward network is a two-layer feed-forward network with a non-linear activation function in between (originally ReLU; more recent models often use GELU or Mish, and this implementation uses GELU). The position-wise feed-forward network is applied to each position separately and identically. The position-wise feed-forward network is followed by a residual connection and layer normalization.

In [55]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``w_2(dropout(gelu(w_1(x))))``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w_2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply GELU and dropout, project back to ``d_model``."""
        hidden = self.w_1(x)
        hidden = F.gelu(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

## Encoder and Decoder Stacks

The encoder and decoder stacks are composed of $N$ identical layers. Each layer is composed of a multi-head attention layer, a position-wise feed-forward network, and residual connection and layer normalization.

In [56]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: Feature size normalised by the internal LayerNorm.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super(ResidualConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    # Implemented as forward() rather than by overriding __call__, so that
    # nn.Module.__call__ keeps running registered hooks; instances are still
    # invoked as ``residual(x, sublayer)``.
    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the residual.

        Args:
            x: Input tensor whose last dimension is ``size``.
            sublayer: Callable mapping a tensor to a tensor of the same shape.
        """
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward, each in a residual wrapper.

    Args:
        size: Feature size, forwarded to the residual wrappers and exposed as
            ``self.size``.
        self_attn: Attention module called as ``self_attn(q, k, v, mask)``.
        feed_forward: Module applied to the output of the attention sublayer.
        dropout: Dropout probability used by both residual wrappers.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # ModuleList rather than Sequential: the wrappers are indexed and
        # called with two arguments, never chained. Parameter names
        # (sublayer.0.*, sublayer.1.*) are the same with either container.
        self.sublayer = nn.ModuleList([
            ResidualConnection(size, dropout),
            ResidualConnection(size, dropout),
        ])

    def forward(self, x, mask):
        """Run self-attention over ``x`` under ``mask``, then the feed-forward network."""
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

## Decoder Layer

The decoder is composed of a stack of $N$ identical layers. Each layer is composed of three sublayers: a masked multi-head self-attention layer, a multi-head attention layer over the encoder output, and a position-wise feed-forward network, each wrapped in a residual connection with layer normalization. In the self-attention sublayer, the decoder's own hidden states are used as the query, key, and value vectors, and the target mask restricts which positions each position may attend to. In the encoder-decoder attention sublayer, the output of the self-attention sublayer is used as the query vector, while the encoder output is used as the key and value vectors. The result is then used as the input to the position-wise feed-forward network.

In [57]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each of the three sublayers runs inside its own residual wrapper.

    Args:
        size: Feature size, forwarded to the residual wrappers and exposed as
            ``self.size``.
        self_attn: Attention module applied to the decoder input itself.
        src_attn: Attention module that queries the encoder output.
        feed_forward: Module applied after both attention sublayers.
        dropout: Dropout probability used by all three residual wrappers.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # ModuleList rather than Sequential: the wrappers are indexed and
        # called with two arguments, never chained. Parameter names
        # (sublayer.0.* ... sublayer.2.*) are the same with either container.
        self.sublayer = nn.ModuleList([
            ResidualConnection(size, dropout),
            ResidualConnection(size, dropout),
            ResidualConnection(size, dropout),
        ])

    def forward(self, x, memory, src_mask, tgt_mask):
        """Decode ``x`` against the encoder output ``memory``.

        Args:
            x: Decoder-side input.
            memory: Encoder output, used as keys and values in ``src_attn``.
            src_mask: Mask passed to ``src_attn``.
            tgt_mask: Mask passed to ``self_attn``.
        """
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

## Encoder and Decoder Stacks

The encoder and decoder stacks are composed of $N$ identical layers. Each layer is composed of a multi-head attention layer, a position-wise feed-forward network, and residual connection and layer normalization.

In [58]:
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that no copy shares parameters with another.
        replicas.append(copy.deepcopy(module))
    return replicas

class Encoder(nn.Module):
    """Stack of ``N`` deep-copied encoder layers followed by a final LayerNorm.

    Args:
        layer: Prototype layer; must expose ``size`` and be callable as
            ``layer(x, mask)``.
        N: Number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        normalised = self.norm(hidden)
        return normalised

class Decoder(nn.Module):
    """Stack of ``N`` deep-copied decoder layers followed by a final LayerNorm.

    Args:
        layer: Prototype layer; must expose ``size`` and be callable as
            ``layer(x, memory, src_mask, tgt_mask)``.
        N: Number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Feed ``x`` through every layer with the same ``memory`` and masks, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        normalised = self.norm(hidden)
        return normalised

## Generator

The final output layer of the model is a linear layer and a softmax function.

In [59]:
class Generator(nn.Module):
    """Output head: linear projection to vocabulary size, then log-softmax.

    Args:
        d_model: Feature size of the incoming hidden states.
        vocab: Number of output classes (vocabulary size).
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab)

    def forward(self, x):
        """Return log-probabilities over the vocabulary along the last dimension."""
        logits = self.proj(x)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

## Full Model

In [60]:
class Transformer(nn.Module):
    """Encoder-decoder container wiring together the model's components.

    Args:
        encoder: Callable as ``encoder(embedded_src, src_mask)``.
        decoder: Callable as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_embed: Maps source token ids to the encoder input.
        tgt_embed: Maps target token ids to the decoder input.
        generator: Output head; stored on the model but not applied by
            ``forward``.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Take in and process masked src and target sequences."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run it through the decoder against ``memory``."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

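Sanity check: the encoder output should keep the batch and sequence dimensions of the input and add a trailing feature dimension of size $d_{\text{model}}$ (here, an input of shape `(10, 32)` gives an output of shape `(10, 32, 512)`).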

In [61]:
# Base configuration: d_model=512, 8 heads, d_ff=2048, 6 layers per stack,
# vocabulary of 10000 tokens on both sides. Dropout is 0.1 inside the layers
# and 0 in the positional encodings.
model = Transformer(
    encoder=Encoder(
        layer=EncoderLayer(
            size=512,
            self_attn=MultiHeadAttention(heads=8, d_model=512),
            feed_forward=PositionwiseFeedForward(d_model=512, d_ff=2048),
            dropout=0.1,
        ),
        N=6,
    ),
    decoder=Decoder(
        layer=DecoderLayer(
            size=512,
            self_attn=MultiHeadAttention(heads=8, d_model=512),
            src_attn=MultiHeadAttention(heads=8, d_model=512),
            feed_forward=PositionwiseFeedForward(d_model=512, d_ff=2048),
            dropout=0.1,
        ),
        N=6,
    ),
    src_embed=nn.Sequential(
        Embeddings(d_model=512, vocab=10000),
        PositionalEncoding(d_model=512, dropout=0),
    ),
    tgt_embed=nn.Sequential(
        Embeddings(d_model=512, vocab=10000),
        PositionalEncoding(d_model=512, dropout=0),
    ),
    generator=Generator(d_model=512, vocab=10000),
)

In [62]:
# Random token ids in [0, 10000) with shape (10, 32); only the encoder half
# of the model is exercised, with no mask.
long_input_data = torch.LongTensor(10, 32).random_(0, 10000)
output = model.encode(src=long_input_data, src_mask=None)
# The encoder keeps the two input dimensions and appends d_model=512.
assert tuple(output.shape) == (10, 32, 512)