<center><h1>Transformer</h1> </center>

<center><p><a href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a></p></center>

<img src="https://production-media.paperswithcode.com/methods/new_ModalNet-21.jpg" width="600"/>

[Code: https://github.com/harvardnlp/annotated-transformer/](https://github.com/harvardnlp/annotated-transformer/)

[Code with annotation: http://nlp.seas.harvard.edu/annotated-transformer/](http://nlp.seas.harvard.edu/annotated-transformer/)


In [1]:
import copy
import math

import torch
from torch import nn

# Model Architecture

## Embeddings

### Input/Output Embedding

In [2]:
class Embedding(nn.Module):
    """
    Token embedding lookup whose output is scaled by sqrt(d_model),
    as in the embedding layers of the Transformer paper.
    """

    def __init__(self, vocab, d_model=512):
        """
        :param vocab: dictionary size of the (source or target) vocabulary.
        :param d_model: the size of each embedding vector. Default: 512
        """
        super().__init__()
        self.d_model = d_model
        # Lookup table: one learnable d_model-sized vector per token id.
        self.lut = nn.Embedding(vocab, d_model)

    def forward(self, x):
        """
        Embed token ids and scale the resulting vectors.
        :param x: (batch_size, seq_length) integer token ids
        :return: (batch_size, seq_length, d_model)
        """
        scale = math.sqrt(self.d_model)
        embedded = self.lut(x)
        return embedded * scale

### Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position encodings to the input, then applies dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Works for both even and odd d_model (for odd d_model the last column is a
    sine column with no cosine partner).
    """

    def __init__(self, d_model=512, dropout=0.1, max_len=5000):
        """
        Positional Encoding.
        :param d_model: the size of each embedding vector. Default: 512
        :param dropout: probability of an element to be zeroed. Default: 0.1
        :param max_len: max length of the sequence. Default: 5000
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        # (max_len, 1) column of positions; broadcasts against div_term below.
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed as exp(-2i * ln(10000) / d_model).
        # Has ceil(d_model / 2) entries: one per even column.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model this is one
        # fewer than len(div_term), so slice to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch axis so the table broadcasts over (batch, seq, d_model).
        pe = pe.unsqueeze(0)
        # Registered as a buffer: saved in state_dict and moved by .to(),
        # but not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Positional Encoding.
        :param x: (batch_size, seq_length, d_model); seq_length must not exceed max_len
        :return: (batch_size, seq_length, d_model)
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

## Sublayers

### Multi-Head Attention

In [4]:
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention: project query/key/value, split into h heads,
    run scaled dot-product attention per head, merge heads, project out.
    """

    def __init__(self, h=8, d_model=512, dropout=0.1):
        """
        :param h: the number of heads. Default: 8
        :param d_model: the size of each embedding vector; must be divisible by h. Default: 512
        :param dropout: probability of an element of the attention weights to be zeroed. Default: 0.1
        """
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four d_model -> d_model projections, used as: query, key, value, output.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention probabilities from the most recent forward call (for inspection).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        :param query: (batch_size, q_len, d_model)
        :param key: (batch_size, k_len, d_model)
        :param value: (batch_size, k_len, d_model)
        :param mask: optional (batch_size, 1 or q_len, k_len); zeros mark blocked positions
        :return: (batch_size, q_len, d_model)
        """
        if mask is not None:
            # Insert a head axis so the same mask applies to every head.
            mask = mask.unsqueeze(1)
        n_batches = query.size(0)

        def split_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            projected = projection(tensor)
            return projected.view(n_batches, -1, self.h, self.d_k).transpose(1, 2)

        q_heads = split_heads(self.linears[0], query)
        k_heads = split_heads(self.linears[1], key)
        v_heads = split_heads(self.linears[2], value)

        context, self.attn = attention(
            q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout
        )

        # (batch, h, len, d_k) -> (batch, len, h * d_k); contiguous() is needed
        # because view() cannot operate on the transposed (non-contiguous) tensor.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(n_batches, -1, self.h * self.d_k)
        return self.linears[3](merged)

### Feed Forward

In [5]:
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network, applied identically at every position:
    FFN(x) = W_2(dropout(ReLU(W_1(x)))).
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        """
        :param d_model: the size of each embedding vector. Default: 512
        :param d_ff: dimension of the inner layer. Default: 2048
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # expand to the inner dimension
        self.w_2 = nn.Linear(d_ff, d_model)  # project back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: (batch_size, seq_length, d_model)
        :return: (batch_size, seq_length, d_model)
        """
        hidden = torch.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

### Sublayer Connection

In [6]:
class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sublayer: x + dropout(sublayer(norm(x))).

    The paper describes a residual connection followed by layer normalization
    (post-norm); this implementation normalizes the sublayer's input instead
    (pre-norm).
    """

    def __init__(self, d_model=512, dropout=0.1):
        """
        :param d_model: the size of each embedding vector. Default: 512
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        :param x: (batch_size, seq_length, d_model)
        :param sublayer: callable mapping (batch_size, seq_length, d_model) to the
            same shape, e.g. attention or the feed-forward network.
        :return: (batch_size, seq_length, d_model)
        """
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

## Layers

### Encoder Layer

In [7]:
class EncoderLayer(nn.Module):
    """
    One encoder layer: self-attention sublayer followed by a feed-forward
    sublayer, each wrapped in a SublayerConnection.
    """

    def __init__(
            self,
            self_attn: MultiHeadedAttention,
            feed_forward: PositionwiseFeedForward,
            d_model=512,
            dropout=0.1,
    ):
        """
        :param self_attn: multi-head self-attention.
        :param feed_forward: position-wise fully connected feed-forward network.
        :param d_model: the size of each embedding vector. Default: 512
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # [0] wraps self-attention, [1] wraps the feed-forward network.
        self.sublayers = clones(SublayerConnection(d_model, dropout), 2)
        # Read by Encoder to size its final LayerNorm.
        self.d_model = d_model

    def forward(self, x, mask):
        """
        :param x: (batch_size, src_seq_length, d_model)
        :param mask: (batch_size, 1, src_seq_length)
        :return: (batch_size, src_seq_length, d_model)
        """
        attn_block, ffn_block = self.sublayers
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ffn_block(x, self.feed_forward)

### Decoder Layer

In [8]:
class DecoderLayer(nn.Module):
    """
    One decoder layer: masked self-attention, encoder-decoder (cross)
    attention over the encoder output, then a feed-forward network; each
    wrapped in a SublayerConnection.
    """

    def __init__(
            self,
            self_attn: MultiHeadedAttention,
            src_attn: MultiHeadedAttention,
            feed_forward: PositionwiseFeedForward,
            d_model=512,
            dropout=0.1
    ):
        """
        :param self_attn: multi-head self-attention.
        :param src_attn: encoder-decoder attention.
        :param feed_forward: position-wise fully connected feed-forward network.
        :param d_model: the size of each embedding vector. Default: 512
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # [0] self-attention, [1] cross-attention, [2] feed-forward.
        self.sublayers = clones(SublayerConnection(d_model, dropout), 3)
        # Read by Decoder to size its final LayerNorm.
        self.d_model = d_model

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        :param x: (batch_size, tgt_seq_length, d_model)
        :param memory: encoder output, (batch_size, src_seq_length, d_model)
        :param src_mask: (batch_size, 1, src_seq_length)
        :param tgt_mask: (batch_size, tgt_seq_length, tgt_seq_length)
        :return: (batch_size, tgt_seq_length, d_model)
        """
        self_block, cross_block, ffn_block = self.sublayers
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # Queries come from the decoder; keys/values from the encoder output.
        x = cross_block(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ffn_block(x, self.feed_forward)

## Encoder

In [9]:
class Encoder(nn.Module):
    """Stack of n_layers independent copies of an EncoderLayer plus a final LayerNorm."""

    def __init__(self, layer: EncoderLayer, n_layers=6):
        """
        :param layer: template encoder layer; it is deep-copied n_layers times.
        :param n_layers: the number of encoder layers. Default: 6
        """
        super().__init__()
        self.layers = clones(layer, n_layers)
        self.norm = LayerNorm(layer.d_model)

    def forward(self, x, mask):
        """
        :param x: (batch_size, src_seq_length, d_model)
        :param mask: (batch_size, 1, src_seq_length)
        :return: (batch_size, src_seq_length, d_model)
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        # Sublayers normalize their inputs (pre-norm), so the residual stream
        # leaving the last layer is normalized once here.
        return self.norm(hidden)

## Decoder

In [10]:
class Decoder(nn.Module):
    """Stack of n_layers independent copies of a DecoderLayer plus a final LayerNorm."""

    def __init__(self, layer: DecoderLayer, n_layers=6):
        """
        :param layer: template decoder layer; it is deep-copied n_layers times.
        :param n_layers: the number of decoder layers. Default: 6
        """
        super().__init__()
        self.layers = clones(layer, n_layers)
        self.norm = LayerNorm(layer.d_model)

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        :param x: (batch_size, tgt_seq_length, d_model)
        :param memory: encoder output, (batch_size, src_seq_length, d_model)
        :param src_mask: (batch_size, 1, src_seq_length)
        :param tgt_mask: (batch_size, tgt_seq_length, tgt_seq_length)
        :return: (batch_size, tgt_seq_length, d_model)
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        # Pre-norm sublayers leave the residual stream unnormalized; finish here.
        return self.norm(hidden)

## Generator

In [11]:
class Generator(nn.Module):
    """Projects decoder states to vocabulary size and returns log-probabilities."""

    def __init__(self, vocab, d_model=512):
        """
        :param vocab: dictionary size of the target vocabulary.
        :param d_model: the size of each embedding vector. Default: 512
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        :param x: (batch_size, tgt_seq_length, d_model)
        :return: log-probabilities, (batch_size, tgt_seq_length, vocab)
        """
        logits = self.proj(x)
        return self.softmax(logits)

## Transformer

### Encoder-Decoder

In [12]:
class EncoderDecoder(nn.Module):
    """
    Standard encoder-decoder architecture.

    Note: forward() returns decoder hidden states; the generator is stored on
    the model but is NOT applied in forward(). Call self.generator(...) on the
    output to obtain log-probabilities.
    """

    def __init__(
            self,
            encoder: Encoder,
            decoder: Decoder,
            src_embed: nn.Sequential,
            tgt_embed: nn.Sequential,
            generator: Generator,
    ):
        """
        :param encoder: encoder stack.
        :param decoder: decoder stack.
        :param src_embed: source embedding followed by positional encoding.
        :param tgt_embed: target embedding followed by positional encoding.
        :param generator: generator of log probabilities.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Encode the source, then decode the target against it.
        :param src: (batch_size, src_seq_length) token ids
        :param tgt: (batch_size, tgt_seq_length) token ids
        :param src_mask: (batch_size, 1, src_seq_length)
        :param tgt_mask: (batch_size, tgt_seq_length, tgt_seq_length)
        :return: decoder hidden states, (batch_size, tgt_seq_length, d_model)
        """
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """
        Encoding process.
        :param src: (batch_size, src_seq_length) token ids
        :param src_mask: (batch_size, 1, src_seq_length)
        :return: (batch_size, src_seq_length, d_model)
        """
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """
        Decoding process.
        :param memory: encoder output, (batch_size, src_seq_length, d_model)
        :param src_mask: (batch_size, 1, src_seq_length)
        :param tgt: (batch_size, tgt_seq_length) token ids
        :param tgt_mask: (batch_size, tgt_seq_length, tgt_seq_length)
        :return: (batch_size, tgt_seq_length, d_model)
        """
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, memory, src_mask, tgt_mask)

### Transformer

In [13]:
def transformer(
        src_vocab,
        tgt_vocab,
        n_layers=6,
        d_model=512,
        d_ff=2048,
        n_heads=8,
        dropout=0.1,
        max_len=5000,
):
    """
    Build a full encoder-decoder Transformer model.
    :param src_vocab: dictionary size of the source vocabulary.
    :param tgt_vocab: dictionary size of the target vocabulary.
    :param n_layers: the number of layers in the encoder stack and in the decoder stack (each). Default: 6
    :param d_model: the size of each embedding vector. Default: 512
    :param d_ff: dimension of the inner feed-forward layer. Default: 2048
    :param n_heads: the number of attention heads. Default: 8
    :param dropout: probability of an element to be zeroed. Default: 0.1
    :param max_len: max length of the sequence. Default: 5000
    :return: EncoderDecoder model with Xavier-initialized weight matrices.
    """
    dup = copy.deepcopy
    # Template modules; every use below gets an independent deep copy.
    attn = MultiHeadedAttention(n_heads, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout, max_len)

    # Construction order is kept fixed: module initializers draw from the
    # global RNG, so reordering would change the initial weights.
    encoder = Encoder(EncoderLayer(dup(attn), dup(ff), d_model, dropout), n_layers)
    decoder = Decoder(
        DecoderLayer(dup(attn), dup(attn), dup(ff), d_model, dropout), n_layers
    )
    src_embed = nn.Sequential(Embedding(src_vocab, d_model), dup(position))
    tgt_embed = nn.Sequential(Embedding(tgt_vocab, d_model), dup(position))
    generator = Generator(tgt_vocab, d_model)
    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

    # Re-initialize every parameter with more than one dimension (weight
    # matrices and embedding tables); 1-D parameters keep their defaults.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return model

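# Utils

The helpers below (`attention`, `LayerNorm`, `clones`) are referenced by the classes above. Python looks these names up only when a model is instantiated or run, so these cells must be executed before `transformer(...)` is called.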

## Attention

In [14]:
def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled Dot-Product Attention: softmax(Q K^T / sqrt(d_k)) V.

    Leading dimensions are broadcast by torch.matmul, so any batch/head layout works.
    :param query: (batch_size, head_num, src/tgt/tgt_seq_length, d_k)
    :param key: (batch_size, head_num, src/tgt/src_seq_length, d_k)
    :param value: (batch_size, head_num, src/tgt/src_seq_length, d_k)
    :param mask: optional, broadcastable to the scores
        (batch_size, 1, 1/tgt/1_seq_length, src/tgt/src_seq_length);
        positions where mask == 0 receive (effectively) zero attention weight.
    :param dropout: optional callable (e.g. an nn.Dropout module) applied to the
        attention probabilities. Default: None
    :return: (batch_size, head_num, src/tgt/tgt_seq_length, d_k),
        (batch_size, head_num, src/tgt/tgt_seq_length, src/tgt/src_seq_length)
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the scores' own dtype.
        # A fixed -1e9 is out of range for float16 (max magnitude ~6.5e4), so
        # masked_fill would raise under half precision / autocast. For float32
        # and float64 the softmax result is unchanged: masked weights are still
        # exp(very negative) == 0, and a fully masked row is still uniform.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

## Layer Normalization

In [15]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learnable gain (a_2)
    and bias (b_2):

        a_2 * (x - mean) / (std + eps) + b_2

    std is x.std() (unbiased, N-1 denominator), and eps is added to the std
    rather than to the variance. This is not numerically identical to
    torch.nn.LayerNorm.
    """

    def __init__(self, features, eps=1e-6):
        """
        :param features: size of the last dimension to normalize.
        :param eps: added to the standard deviation to avoid division by zero. Default: 1e-6
        """
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain, starts at 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias, starts at 0
        self.eps = eps

    def forward(self, x):
        """
        :param x: tensor whose last dimension has size `features`
        :return: tensor of the same shape as x
        """
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

## Clone

In [16]:
def clones(module, N):
    """
    Return an nn.ModuleList of N independent deep copies of `module`.

    The copies share no parameters: each starts from `module`'s current
    weights, but they are separate tensors and train separately.
    """
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))

# Summary

## Mask

In [17]:
def subsequent_mask(size):
    """
    Causal (look-ahead) mask.
    :param size: target sequence length.
    :return: bool tensor of shape (1, size, size); entry [0, i, j] is True when
        j <= i, i.e. query position i may attend to key position j.
    """
    lower_triangle = torch.ones(1, size, size).tril()
    return lower_triangle.bool()

## Data

In [18]:
# Fixed seed so the dummy batch (and the model initialization in later
# cells) is identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

vocab = 25000
batch_size = 1
seq_length = 500

# Random token ids drawn from [1, vocab), so no position equals the padding id 0.
data = torch.randint(1, vocab, size=(batch_size, seq_length)).to(device)
data[:, 0] = 1  # first token fixed to id 1 (presumably a start-of-sequence token)
src = data.clone()
tgt = data.clone()
# (batch_size, 1, src_len): True for non-padding source tokens.
src_mask = (src != 0).unsqueeze(-2)
# Decoder input drops the last token (presumably the usual shift, with
# targets being data[:, 1:]).
tgt = tgt[:, :-1]
# (batch_size, tgt_len, tgt_len): non-padding AND causal. subsequent_mask is
# built on the CPU, so move it to the same device as tgt before combining.
tgt_mask = (tgt != 0).unsqueeze(-2) & subsequent_mask(tgt.size(-1)).to(tgt.device)

## Transformer (base model)

In [19]:
from torchkeras import summary

# The paper's "base" configuration.
base_config = dict(
    src_vocab=vocab,
    tgt_vocab=vocab,
    n_layers=6,
    d_model=512,
    d_ff=2048,
    n_heads=8,
    dropout=0.1,
    max_len=5000,
)
net = transformer(**base_config).to(device)

summary(net, input_data_args=[src, tgt, src_mask, tgt_mask]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Embedding-1                           [-1, 500, 512]           12,800,000
Dropout-2                             [-1, 500, 512]                    0
LayerNorm-3                           [-1, 500, 512]                1,024
Linear-4                              [-1, 500, 512]              262,656
Linear-5                              [-1, 500, 512]              262,656
Linear-6                              [-1, 500, 512]              262,656
Dropout-7                          [-1, 8, 500, 500]                    0
Linear-8                              [-1, 500, 512]              262,656
Dropout-9                             [-1, 500, 512]                    0
LayerNorm-10                          [-1, 500, 512]                1,024
Linear-11                            [-1, 500, 2048]            1,050,624
Dropout-12                           

## Transformer (big)

In [20]:
# The paper's "big" configuration: wider model, more heads, more dropout.
big_config = dict(
    src_vocab=vocab,
    tgt_vocab=vocab,
    n_layers=6,
    d_model=1024,
    d_ff=4096,
    n_heads=16,
    dropout=0.3,
    max_len=5000,
)
net = transformer(**big_config).to(device)

summary(net, input_data_args=[src, tgt, src_mask, tgt_mask]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Embedding-1                          [-1, 500, 1024]           25,600,000
Dropout-2                            [-1, 500, 1024]                    0
LayerNorm-3                          [-1, 500, 1024]                2,048
Linear-4                             [-1, 500, 1024]            1,049,600
Linear-5                             [-1, 500, 1024]            1,049,600
Linear-6                             [-1, 500, 1024]            1,049,600
Dropout-7                         [-1, 16, 500, 500]                    0
Linear-8                             [-1, 500, 1024]            1,049,600
Dropout-9                            [-1, 500, 1024]                    0
LayerNorm-10                         [-1, 500, 1024]                2,048
Linear-11                            [-1, 500, 4096]            4,198,400
Dropout-12                           