In this article Transformer Network is implemented in Pytorch based on "Attention is all you need" paper. 
The first motivation of developing such network that is the sequential analysis is a bottleneck in sequential datas. What a transformer network is trying to do, is combining CNN and RNN specially attention mechanism, create a network that solves this bottelneck.
The network consist of some parts.
1. self attention:
    an attention mechanism relating to diffrent parts of a sequence.
2. Encoder
    the encoder block inputs (x1,...,xn) and outputs (z1,...,zn) and the decoder generates (y1,...,yn) based on input Z. Also encoder block consists of N=6 identical layers.
    the encoder consists of two parts. a multi-head attention and a feed forward network. and some residual connections between them. The residual is a normalization layer over output of each block plus output of the privious block.
    note that input dimentions (embedding dimention) of encoder is d=512 and to make use of residual connections we have to set output of each block to have same dimention of input so dimention of output of each block is d=512.
3. Decoder:
    The decoder is like encoder consists except it has a new multi-head layer to relate input of decoder to encoder features.
    

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


First of all we need a scaled dot product. Typically what this layer does is that it computes softmax(Q * K.T / sqrt(d_k)) * V. This computes the value of output in regard to each query and key (question and answers which are created during training). multiplying with value.
The shapes are like this. Q, K:(batch_size, sequence_number, dk) and V:(batch_size, sequence_number, dv).

In [2]:
class ScaledDotProduct(nn.Module):
    def __init__(self):
        super(ScaledDotProduct, self).__init__()
        self.activation = nn.Softmax(dim = -1)
        
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        x = torch.bmm(Q, K.transpose(-1, -2))
        dk = torch.tensor(K.size(-1))
        x = x.div(torch.sqrt(dk))
        x = self.activation(x)
        x = torch.bmm(x, V)
        return x

in order to test This class I create this paramters. Work on paper to see how it works.

In [3]:
def test_scaled_dot_product():
    batch_size = 2
    sequence_number = 5
    d_k = 4
    d_v = 6
    Q = torch.full((batch_size, sequence_number, d_k), 1, dtype= torch.float)
    K = torch.full((batch_size, sequence_number, d_k), 2, dtype= torch.float)
    V = torch.full((batch_size, sequence_number, d_v), 3, dtype= torch.float)
    scaled_dot_product = ScaledDotProduct()
    product = scaled_dot_product(Q, K, V)
#     torch_versio = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
    print(product)
#     print(torch_versio)
test_scaled_dot_product()

tensor([[[3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.]],

        [[3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3., 3.]]])


ok, we implemented scaled dot product. Now we need to implement multi-head attention using this.

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_head, d_model):
        super(MultiHeadAttention, self).__init__()
        self.dk = 10
        self.dv = 12
        self.num_head = num_head
        self.WQ = nn.Parameter(torch.randn(self.num_head, d_model, self.dk))
        self.WK = nn.Parameter(torch.randn(self.num_head, d_model, self.dk))
        self.WV = nn.Parameter(torch.randn(self.num_head, d_model, self.dv))
        self.WO = nn.Parameter(torch.randn(self.num_head * self.dv, d_model))
        self.reset_parameters()
        self.scaled_dot_product = ScaledDotProduct()
        
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.WQ)
        nn.init.xavier_uniform_(self.WK)
        nn.init.xavier_uniform_(self.WV)
        
        
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        heads = list()
        for i in range(self.num_head):
            WQi, WKi, WVi = self.WQ[i, :, :], self.WK[i, :, :], self.WV[i, :, :]
            q = torch.bmm(Q, WQi.unsqueeze(0).repeat(Q.size(0), 1, 1))
            k = torch.bmm(K, WKi.unsqueeze(0).repeat(Q.size(0), 1, 1))
            v = torch.bmm(V, WVi.unsqueeze(0).repeat(Q.size(0), 1, 1))
            heads.append(self.scaled_dot_product(q,k,v))
        out = torch.cat(heads, dim=-1)
        out = torch.bmm(out, self.WO.unsqueeze(0).repeat(Q.size(0), 1, 1))
        return out

In order to test the code above.

In [7]:
def test_multi_head_attention():
    num_head = 8
    d_model = 128
    batch_size = 2
    sequence_number = 5
    multi_head_attention = MultiHeadAttention(num_head, d_model)
    Q = torch.full((batch_size, sequence_number, d_model), 1, dtype= torch.float)
    K = torch.full((batch_size, sequence_number, d_model), 2, dtype= torch.float)
    V = torch.full((batch_size, sequence_number, d_model), 3, dtype= torch.float)
    out = multi_head_attention(Q, K, V)
    print(out)
    # use torch version to see if we were correct
#     torch_multi_head_attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_head, batch_first=True)
#     torch_out = torch_multi_head_attention(Q, K, V, need_weights=False)
#     print(torch_out[0])
#     print(torch.eq(out, torch_out[0]).all())
test_multi_head_attention()

tensor([[[-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940]],

        [[-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940],
         [-18.4419,  -6.6514,   9.7013,  ..., -13.9123, -19.1616,  14.5940]]],
       grad_fn=<BmmBackward0>)


In [8]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.d_model = d_model
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        batch_size, sequence_num, _ = x.size()
        x = x.view(-1, self.d_model)
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = x.view(batch_size, sequence_num, self.d_model)