In this article, a Transformer network is implemented in PyTorch, based on the paper "Attention Is All You Need".
The main motivation for developing such a network is that step-by-step sequential processing is a bottleneck when working with sequential data. What a Transformer network tries to do is combine ideas from CNNs and RNNs, especially the attention mechanism, to create a network that removes this bottleneck.
The network consists of several parts.
1. Self-attention:
    an attention mechanism relating different parts of a sequence to each other.
2. Encoder:
    The encoder block takes inputs (x1,...,xn) and outputs (z1,...,zn), and the decoder generates (y1,...,yn) based on the input Z. The encoder consists of N=6 identical layers.
    Each encoder layer consists of two parts, a multi-head attention and a feed-forward network, with residual connections around them. The residual connection is a normalization layer applied to the output of each block plus the output of the previous block.
    Note that the input dimension (embedding dimension) of the encoder is d=512, and to make use of the residual connections the output of each block must have the same dimension as its input, so the output dimension of each block is also d=512.
3. Decoder:
    The decoder is like the encoder, except that it has an additional multi-head attention layer that relates the decoder input to the encoder features.
    

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


First of all, we need a scaled dot product. This layer computes softmax(Q * K.T / sqrt(d_k)) * V. It scores every query against every key (the queries and keys, i.e. the "questions" and "answers", are created during training), turns the scores into weights with a softmax, and multiplies the weights with the values.
The shapes are as follows: Q, K: (batch_size, sequence_number, d_k) and V: (batch_size, sequence_number, d_v).

In [3]:
class ScaledDotProduct(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Only the last two dimensions of the inputs take part in the matrix
    products; any leading dimensions are treated as batch dimensions, so both
    ``(batch, seq, d)`` and ``(batch, heads, seq, d)`` inputs are accepted.
    """

    def __init__(self):
        super(ScaledDotProduct, self).__init__()
        # Softmax over the key axis: each query's weights sum to 1.
        self.activation = nn.Softmax(dim=-1)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        """Apply attention.

        Args:
            Q: queries, shape ``(..., seq_q, d_k)``.
            K: keys, shape ``(..., seq_k, d_k)``.
            V: values, shape ``(..., seq_k, d_v)``.
            mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``.
                Positions where it is ``0``/``False`` get no attention
                weight. A query row that is masked everywhere yields NaN,
                because its softmax has no finite entry.

        Returns:
            Tensor of shape ``(..., seq_q, d_v)``.
        """
        # matmul broadcasts over leading dimensions; bmm accepted 3-D input only.
        scores = torch.matmul(Q, K.transpose(-1, -2))
        # A Python float scale avoids allocating a tensor on every call and
        # works unchanged on any device.
        scores = scores / (K.size(-1) ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.activation(scores)
        return torch.matmul(weights, V)

In order to test this class, I create the following parameters. Work through the computation on paper to see how it works.

In [43]:
def test_scaled_dot_product():
    """Smoke test: run ScaledDotProduct on constant tensors and print the output size."""
    batch_size, sequence_number = 2, 5
    d_k, d_v = 4, 6
    qk_shape = (batch_size, sequence_number, d_k)
    v_shape = (batch_size, sequence_number, d_v)
    # Constant-valued inputs keep the result easy to check by hand.
    Q = torch.ones(qk_shape, dtype=torch.float)
    K = torch.full(qk_shape, 2, dtype=torch.float)
    V = torch.full(v_shape, 3, dtype=torch.float)
    attention = ScaledDotProduct()
    print(attention(Q, K, V).size())


test_scaled_dot_product()

torch.Size([2, 5, 6])


OK, we have implemented the scaled dot product. Now we need to implement multi-head attention on top of it.

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from per-head projections and ScaledDotProduct.

    Head ``i`` projects the inputs with its own ``WQ[i]``, ``WK[i]`` and
    ``WV[i]`` and runs scaled dot-product attention on the result. The head
    outputs are concatenated on the last dimension and projected back to
    ``d_model`` with ``WO``.

    Args:
        num_head: number of attention heads.
        d_model: size of the last dimension of the inputs and of the output.
        dk: per-head query/key size. Defaults to 10, the value that used to
            be hard-coded.
        dv: per-head value size. Defaults to 12, the value that used to be
            hard-coded.
    """

    def __init__(self, num_head, d_model, dk=10, dv=12):
        super(MultiHeadAttention, self).__init__()
        self.dk = dk
        self.dv = dv
        self.num_head = num_head
        # Allocated uninitialised; reset_parameters() fills every weight.
        self.WQ = nn.Parameter(torch.empty(self.num_head, d_model, self.dk))
        self.WK = nn.Parameter(torch.empty(self.num_head, d_model, self.dk))
        self.WV = nn.Parameter(torch.empty(self.num_head, d_model, self.dv))
        self.WO = nn.Parameter(torch.empty(self.num_head * self.dv, d_model))
        self.reset_parameters()
        self.scaled_dot_product = ScaledDotProduct()

    def reset_parameters(self):
        """Xavier-initialise all four projection weights in place."""
        nn.init.xavier_uniform_(self.WQ)
        nn.init.xavier_uniform_(self.WK)
        nn.init.xavier_uniform_(self.WV)
        # WO used to be skipped here and kept its unit-variance randn values.
        nn.init.xavier_uniform_(self.WO)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: tensor whose last dimension is ``d_model``.
            K: tensor whose last dimension is ``d_model``.
            V: tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the leading dimensions of the attention output and a
            last dimension of ``d_model``.
        """
        # matmul broadcasts the 2-D per-head weight over the batch, so the
        # weights no longer need to be copied once per batch element.
        heads = [
            self.scaled_dot_product(
                torch.matmul(Q, wq), torch.matmul(K, wk), torch.matmul(V, wv)
            )
            for wq, wk, wv in zip(self.WQ, self.WK, self.WV)
        ]
        out = torch.cat(heads, dim=-1)
        return torch.matmul(out, self.WO)

In order to test the code above.

In [44]:
def test_multi_head_attention():
    """Smoke test: run MultiHeadAttention on constant tensors and print the output size."""
    num_head, d_model = 8, 128
    batch_size, sequence_number = 2, 5
    # The module is built before the inputs, as in the original cell.
    attention = MultiHeadAttention(num_head, d_model)
    shape = (batch_size, sequence_number, d_model)
    Q = torch.ones(shape, dtype=torch.float)
    K = torch.full(shape, 2, dtype=torch.float)
    V = torch.full(shape, 3, dtype=torch.float)
    print(attention(Q, K, V).size())


test_multi_head_attention()

torch.Size([2, 5, 128])


The next block is a simple feed forward network.

In [14]:
class FeedForward(nn.Module):
    """Feed-forward network: ``Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model)``.

    Args:
        d_model: size of the last dimension of the input and of the output.
        d_ff: size of the hidden layer (default 2048).
    """

    def __init__(self, d_model, d_ff=2048):
        super(FeedForward, self).__init__()
        self.d_model = d_model
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the network along the last dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``; any number of
                leading dimensions is allowed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # nn.Linear already operates on the last dimension of an N-d input.
        # The previous view(-1, d_model) round trip was redundant and raised
        # on non-contiguous tensors.
        return self.layer2(self.relu(self.layer1(x)))

In [17]:
def test_feed_forward():
    """Smoke test: run FeedForward on a 4-D tensor of ones and print the output size."""
    d_model, d_ff = 512, 2045
    batch_size, num_head, sequence_number = 2, 8, 5
    shape = (batch_size, num_head, sequence_number, d_model)
    x = torch.ones(shape, dtype=torch.float)
    network = FeedForward(d_model, d_ff)
    print(network(x).size())


test_feed_forward()

torch.Size([2, 8, 5, 512])


Now we have all the building blocks needed to create the encoder and the decoder. The residual connections are applied too.

In [47]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and feed-forward, each followed by add-and-normalise.

    Args:
        num_head: number of attention heads passed to MultiHeadAttention.
        d_model: size of the last dimension of the input and of the output.
        d_ff: hidden size passed to FeedForward.
    """

    def __init__(self, num_head, d_model, d_ff):
        super(EncoderBlock, self).__init__()
        self.num_head, self.d_model, self.d_ff = num_head, d_model, d_ff
        self.multi_head_attention1 = MultiHeadAttention(num_head, d_model)
        self.feed_forward1 = FeedForward(d_model, d_ff)
        # NOTE(review): the referenced paper describes layer normalisation at
        # this point; BatchNorm1d is kept so existing behaviour and
        # checkpoints are unchanged. Confirm which one is intended.
        self.batch_norm1 = nn.BatchNorm1d(d_model)
        self.batch_norm2 = nn.BatchNorm1d(d_model)

    def batch_norm(self, x, norm=None):
        """Normalise ``x`` over its last dimension with a BatchNorm1d layer.

        ``x`` is flattened to ``(-1, d_model)``, normalised, and restored to
        its original shape.

        Args:
            x: tensor whose last dimension is ``d_model``.
            norm: normalisation layer to apply; defaults to ``self.batch_norm1``.
                (This helper used to reference the non-existent attribute
                ``self.batch_norm1d`` and raised ``AttributeError``.)

        Returns:
            Tensor with the same shape as ``x``.
        """
        if norm is None:
            norm = self.batch_norm1
        # reshape, unlike view, also accepts non-contiguous tensors.
        flat = x.reshape(-1, self.d_model)
        return norm(flat).reshape(x.size())

    def forward(self, x):
        """Run the layer on ``x`` and return a tensor of the same shape."""
        attended = self.multi_head_attention1(x, x, x)
        x1 = self.batch_norm(torch.add(x, attended), self.batch_norm1)

        fed = self.feed_forward1(x1)
        x2 = self.batch_norm(torch.add(fed, x1), self.batch_norm2)
        return x2

and in order to test this:

In [48]:
def test_encoder_block():
    """Smoke test: run one EncoderBlock on a tensor of ones and print the output size."""
    num_head, d_model, d_ff = 8, 512, 2045
    batch_size, sequence_number = 2, 5
    x = torch.ones((batch_size, sequence_number, d_model), dtype=torch.float)
    block = EncoderBlock(num_head, d_model, d_ff)
    print(block(x).size())


test_encoder_block()

torch.Size([2, 5, 512])


In the next block, we implement the stack of encoder blocks. N blocks are stacked on top of each other to create the Encoder.

In [49]:
class Encoder(nn.Module):
    """A stack of ``N`` EncoderBlock layers applied one after another."""

    def __init__(self, N, num_head, d_model, d_ff):
        super().__init__()
        layers = nn.ModuleList()
        for _ in range(N):
            layers.append(EncoderBlock(num_head, d_model, d_ff))
        self.block_list = layers

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        hidden = x
        for layer in self.block_list:
            hidden = layer(hidden)
        return hidden

In [50]:
def test_encoder():
    """Smoke test: run a 6-layer Encoder on a tensor of ones and print the output size."""
    N = 6
    num_head, d_model, d_ff = 8, 512, 2045
    batch_size, sequence_number = 2, 5
    x = torch.ones((batch_size, sequence_number, d_model), dtype=torch.float)
    stack = Encoder(N, num_head, d_model, d_ff)
    print(stack(x).size())


test_encoder()

torch.Size([2, 5, 512])


The exact same process is applied to the decoder. Note that the decoder's inputs are the input sentence and the output of the encoder.

In [51]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, attention over the encoder output, then feed-forward.

    Each of the three sub-layers is followed by a residual addition and a
    BatchNorm1d over the last dimension. No mask is passed to either
    attention call.

    Args:
        num_head: number of attention heads passed to MultiHeadAttention.
        d_model: size of the last dimension of the inputs and of the output.
        d_ff: hidden size passed to FeedForward.
    """

    def __init__(self, num_head, d_model, d_ff):
        super(DecoderBlock, self).__init__()
        self.num_head, self.d_model, self.d_ff = num_head, d_model, d_ff
        self.multi_head_attention1 = MultiHeadAttention(num_head, d_model)
        self.multi_head_attention2 = MultiHeadAttention(num_head, d_model)
        self.feed_forward1 = FeedForward(d_model, d_ff)
        self.batch_norm1 = nn.BatchNorm1d(d_model)
        self.batch_norm2 = nn.BatchNorm1d(d_model)
        self.batch_norm3 = nn.BatchNorm1d(d_model)

    def _add_norm(self, residual, sublayer_out, norm):
        """Return ``norm(residual + sublayer_out)`` with the shape of the sum."""
        summed = torch.add(residual, sublayer_out)
        # BatchNorm1d normalises dimension 1 of a 2-D input, so flatten all
        # leading dimensions first. reshape, unlike view, also accepts
        # non-contiguous tensors.
        flat = summed.reshape(-1, self.d_model)
        return norm(flat).reshape(summed.size())

    def forward(self, x, encoder_out):
        """Run the layer.

        Args:
            x: decoder-side input whose last dimension is ``d_model``.
            encoder_out: encoder output, used as keys and values of the
                second attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x1 = self._add_norm(x, self.multi_head_attention1(x, x, x), self.batch_norm1)
        cross = self.multi_head_attention2(x1, encoder_out, encoder_out)
        x2 = self._add_norm(x1, cross, self.batch_norm2)
        x3 = self._add_norm(x2, self.feed_forward1(x2), self.batch_norm3)
        return x3

In [54]:
def test_decoder_block():
    """Smoke test: run one DecoderBlock on constant tensors and print the output size.

    This function was previously named ``test_encoder_block``, which shadowed
    the encoder-block test of the same name defined earlier in the file.
    """
    num_head = 8
    d_model = 512
    d_ff = 2045
    batch_size = 2
    sequence_number = 5
    shape = (batch_size, sequence_number, d_model)
    x = torch.full(shape, 1, dtype=torch.float)
    encoder_out = torch.full(shape, 2, dtype=torch.float)
    decoder_block = DecoderBlock(num_head, d_model, d_ff)
    out = decoder_block(x, encoder_out)
    print(out.size())


test_decoder_block()

torch.Size([2, 5, 512])


In [59]:
class Decoder(nn.Module):
    """A stack of ``N`` DecoderBlock layers that all receive the same encoder output."""

    def __init__(self, N, num_head, d_model, d_ff):
        super().__init__()
        layers = nn.ModuleList()
        for _ in range(N):
            layers.append(DecoderBlock(num_head, d_model, d_ff))
        self.block_list = layers

    def forward(self, x, encoder_out):
        """Feed ``x`` through every block in order, passing ``encoder_out`` to each."""
        hidden = x
        for layer in self.block_list:
            hidden = layer(hidden, encoder_out)
        return hidden

In [61]:
def test_decoder():
    """Smoke test: run a 6-layer Decoder on constant tensors and print the output size."""
    N = 6
    num_head, d_model, d_ff = 8, 512, 2045
    batch_size, sequence_number = 2, 5
    shape = (batch_size, sequence_number, d_model)
    x = torch.ones(shape, dtype=torch.float)
    encoder_out = torch.full(shape, 2, dtype=torch.float)
    stack = Decoder(N, num_head, d_model, d_ff)
    print(stack(x, encoder_out).size())


test_decoder()

torch.Size([2, 5, 512])


Now it's time to put everything together and create the whole model.

In [62]:
class Transformer(nn.Module):
    """Encoder-decoder model: an Encoder stack followed by a Decoder stack.

    Args:
        N: number of blocks in both the encoder and the decoder.
        num_head: number of attention heads in every block.
        d_model: size of the last dimension of the inputs and of the output.
        d_ff: hidden size of the feed-forward networks.
    """

    def __init__(self, N, num_head, d_model, d_ff):
        super(Transformer, self).__init__()
        self.encoder = Encoder(N, num_head, d_model, d_ff)
        self.decoder = Decoder(N, num_head, d_model, d_ff)

    def forward(self, x, tgt=None):
        """Encode ``x`` and decode against the encoder output.

        Args:
            x: encoder input.
            tgt: optional decoder input. When omitted, ``x`` is fed to the
                decoder as well, which is the previous behaviour.

        Returns:
            The decoder output.
        """
        encoder_out = self.encoder(x)
        decoder_in = x if tgt is None else tgt
        return self.decoder(decoder_in, encoder_out)

In [66]:
def test_transformer():
    """Smoke test: run the full Transformer on a tensor of ones and print the output size."""
    N = 6
    num_head, d_model, d_ff = 8, 512, 2045
    batch_size, sequence_number = 2, 5
    x = torch.ones((batch_size, sequence_number, d_model), dtype=torch.float)
    model = Transformer(N, num_head, d_model, d_ff)
    print(model(x).size())


test_transformer()

torch.Size([2, 5, 512])
