# package

In [34]:
import copy

import torch

import torch.nn as nn

# function

In [None]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [2]:
def init_weight(layer):
    """
    Re-initialise ``layer`` in place.

    The weight is drawn from a Xavier-uniform distribution; the bias, when the
    layer has one, is set to zero. Using an explicit, bounded distribution is
    meant to keep the starting weights in a small range so training converges
    more easily.

    layer: any module exposing ``weight`` and ``bias`` attributes
           (``bias`` may be ``None``).
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = layer.bias
    if bias is None:
        # nothing more to do for bias-free layers
        return
    nn.init.constant_(bias, 0)

In [3]:
def attention(q, k, v, dk, mask, dropout):
    """
    Scaled dot-product attention.

    q:       queries, (batch_size, num_head, seqlen, dk)
    k:       keys,    (batch_size, num_head, seqlen, dk)
    v:       values,  (batch_size, num_head, seqlen, dk)
    dk:      per-head dimension; the queries are scaled by ``dk ** -0.5``
    mask:    optional boolean tensor, or ``None``. Positions where it is
             ``True`` are excluded from attention. It is unsqueezed at dim 1
             so that it broadcasts over the head dimension, i.e. it is
             expected to look like (batch_size, seqlen, seqlen).
    dropout: optional callable applied to the attention weights, or ``None``.

    returns: (batch_size, num_head, seqlen, dk)
    """
    # scale out of place so the caller's ``q`` tensor is left untouched
    scaled_q = q * dk ** -0.5

    # similarity score: (batch_size, num_head, seqlen, seqlen)
    score = torch.matmul(scaled_q, k.transpose(-2, -1))

    # the mask must be applied to the raw scores *before* softmax: a very
    # negative logit becomes a ~0 probability, and each row still sums to 1
    if mask is not None:
        score = score.masked_fill(mask.unsqueeze(1), -1e9)

    # turn scores into probabilities over the last (key) dimension
    score = torch.softmax(score, dim=-1)

    if dropout is not None:
        score = dropout(score)

    # weight the values by the attention probabilities
    return torch.matmul(score, v)

# sub-layer

In [4]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer of a transformer.

    "Position-wise" means the same two fully connected layers are applied to
    every word (position) of the sequence independently.

    Its role is to map the extracted features into the desired semantic space.
    """

    def __init__(self, inputs_dim=512, hidden_dim=2048, dropout=0.1):
        """
        inputs_dim: size of the input and output features
        hidden_dim: size of the intermediate representation
        dropout:    drop probability applied after the activation
        """
        super().__init__()

        # expand -> non-linearity -> project back
        self.fc1 = nn.Linear(inputs_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, inputs_dim)
        self.dropout = nn.Dropout(dropout)

        # re-initialise both linear layers with the shared helper
        for linear in (self.fc1, self.fc2):
            init_weight(linear)

    def forward(self, x):
        """Apply fc1 -> ReLU -> dropout -> fc2 to ``x`` and return the result."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)

In [5]:
class MultiHeadAttention(nn.Module):
    """
    implement Multi-head attention sub-layer in transformer

    purpose: extract the sequential features that are needed from the input sequence

    speed: compared to an RNN, which passes features sequentially, attention extracts features globally

    performance: like a multi-channel CNN, multi-head attention extracts different feature patterns
    """
    def __init__(self, embedding_dim, dropout=0.1, num_head=8):
        """
        embedding_dim: size of the input/output features
        dropout:       drop probability applied to the attention weights
        num_head:      number of attention heads

        Raises ValueError if ``num_head`` is not positive or if
        ``embedding_dim`` is smaller than ``num_head``.
        """
        super().__init__()

        # fail fast: otherwise num_head == 0 dies on the integer division below
        # and dk == 0 only surfaces later as a ZeroDivisionError in forward()
        if num_head <= 0:
            raise ValueError(f"num_head must be positive, got {num_head}")
        if embedding_dim < num_head:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be >= num_head ({num_head})"
            )

        # per-head dimension; when embedding_dim is not a multiple of num_head
        # the projections map to dk * num_head (< embedding_dim) features
        self.dk = embedding_dim // num_head
        self.num_head = num_head

        self.query = nn.Linear(embedding_dim, self.dk * num_head, bias=False)
        self.key = nn.Linear(embedding_dim, self.dk * num_head, bias=False)
        self.value = nn.Linear(embedding_dim, self.dk * num_head, bias=False)

        init_weight(self.query)
        init_weight(self.key)
        init_weight(self.value)

        self.dropout = nn.Dropout(dropout)

        # maps the concatenated heads back to embedding_dim
        self.output_linear = nn.Linear(self.dk * num_head, embedding_dim, bias=False)

        init_weight(self.output_linear)

    def forward(self, q, k, v, mask=None):
        """
        q, k, v: tensors whose first dimension is the batch and whose last
                 dimension is embedding_dim
        mask:    optional mask forwarded unchanged to ``attention``

        returns: tensor of shape (batch, seqlen, embedding_dim)
        """
        batch_size = q.size(0)

        # perform linear operation and split into N heads
        q = self.query(q).view(batch_size, -1, self.num_head, self.dk)
        k = self.key(k).view(batch_size, -1, self.num_head, self.dk)
        v = self.value(v).view(batch_size, -1, self.num_head, self.dk)

        # transpose from (batch, seqlen, num_head, dk) -> (batch, num_head, seqlen, dk)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # calculate attention
        att_res = attention(q, k, v, self.dk, mask, self.dropout)

        # concatenate heads
        # (batch, num_head, seqlen, dk) -> (batch, seqlen, num_head, dk) -> (batch, seqlen, dk*num_head)
        att_res = att_res.transpose(1, 2).contiguous().view(batch_size, -1, self.dk * self.num_head)

        # pass through output linear layer
        att_res = self.output_linear(att_res)

        return att_res

# blocks

In [26]:
class encoder_block(nn.Module):
    """
    One encoder block: self-attention followed by a feed-forward sub-layer.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.
    """

    def __init__(self, embedding_dim, dropout=0.1, num_head=8, hidden_dim=2048):
        """
        embedding_dim: feature size of the input and output
        dropout:       drop probability shared by every dropout in the block
        num_head:      number of attention heads
        hidden_dim:    hidden size of the feed-forward sub-layer
        """
        super().__init__()

        # --- self-attention sub-layer ---
        self.att = MultiHeadAttention(embedding_dim, dropout=dropout, num_head=num_head)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.dropout1 = nn.Dropout(dropout)

        # --- feed-forward sub-layer ---
        self.fc = FeedForward(inputs_dim=embedding_dim, hidden_dim=hidden_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run ``x`` through both sub-layers; ``mask`` is handed to the attention."""
        # self-attention: query, key and value are all x
        attended = self.norm1(self.dropout1(self.att(x, x, x, mask)) + x)

        # feed-forward on top of the attended features
        transformed = self.norm2(self.dropout2(self.fc(attended)) + attended)

        return transformed

In [27]:
class decoder_block(nn.Module):
    """
    One decoder block with three sub-layers:

    1. self-attention over the decoder input,
    2. cross-attention where the decoder state queries the encoder output,
    3. a point-wise feed-forward network.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.
    """
    def __init__(self, embedding_dim, dropout=0.1, num_head=8, hidden_dim=2048):
        """
        embedding_dim: feature size of the input and output
        dropout:       drop probability shared by every dropout in the block
        num_head:      number of attention heads
        hidden_dim:    hidden size of the feed-forward sub-layer
        """
        super().__init__()

        # self-attention sub-layer
        self.att1 = MultiHeadAttention(embedding_dim, dropout=dropout, num_head=num_head)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.dropout1 = nn.Dropout(dropout)

        # cross-attention (decoder -> encoder) sub-layer
        self.att2 = MultiHeadAttention(embedding_dim, dropout=dropout, num_head=num_head)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.dropout2 = nn.Dropout(dropout)

        # feed-forward sub-layer
        self.fc = FeedForward(inputs_dim=embedding_dim, hidden_dim=hidden_dim, dropout=dropout)
        self.norm3 = nn.LayerNorm(embedding_dim)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encode_output, x_mask, output_mask):
        """
        x:             decoder input
        encode_output: output of the encoder, attended to by the second sub-layer
        x_mask:        mask passed to the self-attention (may be None)
        output_mask:   mask passed to the cross-attention (may be None)

        returns: tensor with the same shape as ``x``
        """
        x1 = self.att1(x, x, x, x_mask)
        x1 = self.dropout1(x1) + x
        x1 = self.norm1(x1)

        # queries come from the decoder state, keys/values from the encoder, so
        # the result keeps the decoder's sequence length and the residual add
        # below is valid even when the two sequences differ in length
        x2 = self.att2(q=x1, k=encode_output, v=encode_output, mask=output_mask)
        x2 = self.dropout2(x2) + x1
        x2 = self.norm2(x2)

        x3 = self.fc(x2)
        # the third sub-layer uses its own dropout module
        x3 = self.dropout3(x3) + x2
        x3 = self.norm3(x3)

        return x3

# layer

In [36]:
class encoder(nn.Module):
    """A stack of ``num_block`` deep-copied ``block`` modules with an optional final norm."""

    def __init__(self, block, num_block, norm=None):
        """
        block:     module to be cloned for every layer of the stack
        num_block: how many clones to stack
        norm:      optional module applied to the output of the last block
        """
        super().__init__()
        self.blocks = _get_clones(block, num_block)
        self.N = num_block
        self.norm = norm

    def forward(self, x, mask):
        """Feed ``x`` through every block in order, passing the same ``mask`` to each."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)

        # final normalisation is optional
        return hidden if self.norm is None else self.norm(hidden)

In [None]:
class decoder(nn.Module):
    """A stack of ``num_block`` deep-copied ``block`` modules with an optional final norm."""

    def __init__(self, block, num_block, norm=None):
        """
        block:     module to be cloned for every layer of the stack
        num_block: how many clones to stack
        norm:      optional module applied to the output of the last block
        """
        super().__init__()
        self.blocks = _get_clones(block, num_block)
        self.N = num_block
        self.norm = norm

    def forward(self, x, mask, encode_output=None, output_mask=None):
        """
        x:             decoder input
        mask:          mask for the decoder input, handed to every block
        encode_output: encoder output to attend to. When given, every block is
                       called as ``block(x, encode_output, mask, output_mask)``,
                       which is the signature of ``decoder_block.forward``.
                       When omitted, blocks are called as ``block(x, mask)``,
                       exactly as before these parameters existed.
        output_mask:   mask for the attention over ``encode_output``; only used
                       when ``encode_output`` is given
        """
        output = x
        for block in self.blocks:
            if encode_output is None:
                # legacy two-argument call path
                output = block(output, mask)
            else:
                output = block(output, encode_output, mask, output_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

# model

In [None]:
class Transformer(nn.Module):
    """Skeleton of the full model; no sub-modules are created yet."""

    def __init__(self, N):
        """
        N: number of encoder/decoder blocks (accepted but not used yet)
        """
        super().__init__()


# test

In [31]:
# Earlier experiment, kept for reference: 4-D inputs with an explicit head dimension.
# batch, num_head, sentence_length, embedding_dim = 1, 8, 5, 3
# q = v = k = torch.randn(batch, num_head, sentence_length, embedding_dim)

# Smoke test: push one random batch through a single decoder block.
batch, sentence_length, embedding_dim = 10, 6, 100
q = torch.randn(batch, sentence_length, embedding_dim)

# default num_head=8 -> dk = 100 // 8 = 12, so the heads use 96 of the 100 features
model = decoder_block(embedding_dim)

# the same tensor is used as decoder input and as encoder output; both masks are disabled
res = model(q, q, None, None)