In [3]:
import torch
from torch import nn as nn
import copy

In [4]:
class Transformer(nn.Module):
    """Encoder-decoder wrapper.

    The encoder turns the source input into a context, and the decoder
    consumes the target input together with that context.

    Args:
        encoder: callable invoked as ``encoder(x, mask)``.
        decoder: callable invoked as ``decoder(z, context)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x, z, mask):
        """Encode ``x`` under ``mask`` and decode ``z`` against the result.

        Args:
            x: source-side input handed to the encoder.
            z: target-side input handed to the decoder.
            mask: mask forwarded to the encoder only.

        Returns:
            Whatever the decoder returns.
        """
        context = self.encoder(x, mask)
        return self.decoder(z, context)


In [5]:
class Encoder(nn.Module):
    """Stack of ``n_layer`` independent copies of one encoder layer.

    Args:
        encoder_layer: an ``nn.Module`` called as ``layer(x, mask)``; it is used
            as a prototype and deep-copied, so the copies do not share weights.
        n_layer: number of copies to stack.
    """

    def __init__(self, encoder_layer, n_layer):
        super(Encoder, self).__init__()
        # nn.ModuleList (not a plain list) so every copy is registered as a
        # sub-module: its parameters then show up in .parameters(), follow
        # .to(device) and are saved in state_dict().
        self.layers = nn.ModuleList(
            copy.deepcopy(encoder_layer) for _ in range(n_layer)
        )

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order, passing ``mask`` to each.

        Returns:
            The output of the last layer (``x`` itself when ``n_layer`` is 0).
        """
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out

In [13]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each wrapped in a
    ``ResidualConnectionLayer``.

    Args:
        multi_head_attention_layer: module called with keyword arguments
            ``query``, ``key``, ``value`` and ``mask``.
        position_wise_feed_forward_layer: module called with a single tensor.
        norm_layer: prototype normalization module; it is deep-copied once per
            residual wrapper so the two wrappers do not share parameters.
    """

    def __init__(self, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention_layer = multi_head_attention_layer
        self.position_wise_feed_forward_layer = position_wise_feed_forward_layer
        # nn.ModuleList so the norm layers inside the wrappers are registered
        # (visible to .parameters(), .to(device) and state_dict()).
        self.residual_connection_layers = nn.ModuleList(
            [ResidualConnectionLayer(copy.deepcopy(norm_layer)) for _ in range(2)]
        )

    def forward(self, x, mask):
        """Apply the attention sub-layer, then the feed-forward sub-layer.

        Args:
            x: input tensor; used as query, key and value of the attention.
            mask: forwarded unchanged to the attention layer.

        Returns:
            The output of the second residual wrapper.
        """
        # Sub-layer 1: self-attention (query = key = value).
        out = self.residual_connection_layers[0](
            x,
            lambda t: self.multi_head_attention_layer(query=t, key=t, value=t, mask=mask),
        )
        # Sub-layer 2: the feed-forward block must receive the attention output
        # both as its input and as its residual, not the original x.
        out = self.residual_connection_layers[1](out, self.position_wise_feed_forward_layer)
        return out


In [7]:
import numpy as np
from torch.nn import functional as F
import math
import pandas as pd

In [8]:
def calculate_attention(self, query, key, value, mask):
    """Scaled dot-product attention.

    Args:
        self: unused; kept so existing call sites keep working.
        query: tensor of shape ``(..., len_q, d_k)``.
        key: tensor of shape ``(..., len_k, d_k)``.
        value: tensor of shape ``(..., len_k, d_v)``.
        mask: ``None``, or a tensor broadcastable to ``(..., len_q, len_k)``;
            positions where ``mask == 0`` are excluded from attention.

    Returns:
        Tensor of shape ``(..., len_q, d_v)``: for every query position, the
        attention-weighted sum of the value vectors.
    """
    d_k = key.size(-1)
    # (..., len_q, len_k) similarity scores, scaled by sqrt(d_k).
    attention_score = torch.matmul(query, key.transpose(-2, -1))
    attention_score = attention_score / math.sqrt(d_k)
    if mask is not None:
        # A large negative score becomes ~0 probability after softmax.
        attention_score = attention_score.masked_fill(mask == 0, -1e9)
    # Normalize over the key axis (last dim) so each query's weights sum to 1.
    attention_prob = F.softmax(attention_score, dim=-1)
    out = torch.matmul(attention_prob, value)
    return out

In [9]:
# NOTE: a later cell in this notebook defines MultiHeadAttentionLayer again;
# whichever definition runs last is the one bound to the name.
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention.

    Args:
        d_model: width of the model; reshaped into ``h`` heads of
            ``d_model // h`` features each.
        h: number of attention heads.
        qkv_fc_layer: prototype projection module; deep-copied three times to
            get independent query, key and value projections.
        fc_layer: output projection applied after the heads are merged.
    """

    def __init__(self, d_model, h, qkv_fc_layer, fc_layer):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h
        self.query_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.key_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.value_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.fc_layer = fc_layer

    def calculate_attention(self, query, key, value, mask):
        """Scaled dot-product attention over the last two dimensions.

        Positions where ``mask == 0`` are excluded; ``mask`` may be ``None``.
        Defined as a method because ``forward`` calls it through ``self``.
        """
        d_k = key.size(-1)
        attention_score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        # Normalize over the key axis so each query's weights sum to 1.
        attention_prob = F.softmax(attention_score, dim=-1)
        return torch.matmul(attention_prob, value)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, merge heads and project again.

        Args:
            query, key, value: tensors whose first dimension is the batch and
                whose projections can be viewed as ``(n_batch, -1, d_model)``.
            mask: optional mask; a head axis is inserted at dim 1 so it
                broadcasts across heads.

        Returns:
            Tensor of shape ``(n_batch, -1, d_model)`` passed through ``fc_layer``.
        """
        n_batch = query.shape[0]

        def transform(x, fc_layer):
            # (n_batch, seq, d_model) -> (n_batch, h, seq, d_model // h)
            out = fc_layer(x)
            out = out.view(n_batch, -1, self.h, self.d_model // self.h)
            out = out.transpose(1, 2)
            return out

        query = transform(query, self.query_fc_layer)
        key = transform(key, self.key_fc_layer)
        value = transform(value, self.value_fc_layer)

        if mask is not None:
            mask = mask.unsqueeze(1)

        out = self.calculate_attention(query, key, value, mask)
        # (n_batch, h, seq, d_k) -> (n_batch, seq, h, d_k); transpose leaves the
        # tensor non-contiguous, so .contiguous() is needed before .view().
        out = out.transpose(1, 2)
        out = out.contiguous().view(n_batch, -1, self.d_model)
        out = self.fc_layer(out)
        return out
            

In [10]:
# NOTE: this redefines the MultiHeadAttentionLayer class from an earlier cell
# of this notebook; the definition executed last wins.
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention built from injected projection layers.

    Args:
        d_model: model width, split evenly into ``h`` heads
            (``d_model // h`` features per head).
        h: number of heads.
        qkv_fc_layer: template projection; three deep copies become the
            independent query / key / value projections.
        fc_layer: final projection applied to the concatenated heads.
    """

    def __init__(self, d_model, h, qkv_fc_layer, fc_layer):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h
        self.query_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.key_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.value_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.fc_layer = fc_layer

    def calculate_attention(self, query, key, value, mask):
        """Return softmax(Q K^T / sqrt(d_k)) V, skipping positions where mask == 0.

        ``forward`` invokes this through ``self``, so it has to exist on the
        class; without it the call raises AttributeError.
        """
        d_k = key.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Softmax over the last (key) axis: one distribution per query position.
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, value)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query, key, value: batch-first tensors; after projection each is
                viewed as ``(n_batch, -1, h, d_model // h)``.
            mask: optional; gets a singleton head axis at dim 1 so one mask is
                shared by all heads.

        Returns:
            ``fc_layer`` applied to the merged heads, shape ``(n_batch, -1, d_model)``.
        """
        n_batch = query.shape[0]

        def transform(x, fc_layer):
            # Project, split the feature axis into heads, move heads to dim 1.
            out = fc_layer(x)
            out = out.view(n_batch, -1, self.h, self.d_model // self.h)
            out = out.transpose(1, 2)
            return out

        query = transform(query, self.query_fc_layer)
        key = transform(key, self.key_fc_layer)
        value = transform(value, self.value_fc_layer)

        if mask is not None:
            mask = mask.unsqueeze(1)

        out = self.calculate_attention(query, key, value, mask)
        # Undo the head split: heads back next to the features, then flatten.
        out = out.transpose(1, 2)
        out = out.contiguous().view(n_batch, -1, self.d_model)
        out = self.fc_layer(out)
        return out



In [11]:
class PositionWiseFeedForwadLayer(nn.Module):
    """Two-layer feed-forward block: ``second_fc_layer(relu(first_fc_layer(x)))``.

    Args:
        first_fc_layer: first projection module.
        second_fc_layer: second projection module.
    """

    def __init__(self, first_fc_layer, second_fc_layer):
        super(PositionWiseFeedForwadLayer, self).__init__()
        self.first_fc_layer = first_fc_layer
        self.second_fc_layer = second_fc_layer

    def forward(self, x):
        """Apply the block to ``x`` and return the second layer's output."""
        out = self.first_fc_layer(x)
        out = F.relu(out)
        # No activation after the second projection: a trailing ReLU would
        # restrict the output to non-negative values.
        out = self.second_fc_layer(out)
        return out

In [12]:
class ResidualConnectionLayer(nn.Module):
    """Adds a skip connection around a sub-layer, then normalizes the sum.

    Args:
        norm_layer: module applied to ``sub_layer(x) + x``.
    """

    def __init__(self, norm_layer):
        super().__init__()
        self.norm_layer = norm_layer

    def forward(self, x, sub_layer):
        """Return ``norm_layer(sub_layer(x) + x)``.

        Args:
            x: input tensor, also used as the residual.
            sub_layer: callable taking ``x`` and returning a tensor that can
                be added to ``x``.
        """
        residual_sum = sub_layer(x) + x
        return self.norm_layer(residual_sum)

In [15]:
from torch.autograd import Variable

In [16]:
def subsequent_mask(size):
    """Build a mask that hides later positions.

    Args:
        size: sequence length.

    Returns:
        Boolean tensor of shape ``(1, size, size)`` where entry ``[0, i, j]``
        is ``True`` when ``j <= i`` and ``False`` when ``j > i``.
    """
    atten_shape = (1, size, size)
    # Ones strictly above the diagonal mark the positions to hide.
    mask = np.triu(np.ones(atten_shape), k=1).astype('uint8')
    # Invert: True = allowed to attend.
    return torch.from_numpy(mask) == 0

def make_std_mask(tgt, pad):
    """Combine a padding mask with the mask from ``subsequent_mask``.

    Args:
        tgt: tensor of target token ids; the last dimension is the sequence
            (presumably shaped ``(batch, seq_len)`` - confirm with the caller).
        pad: id of the padding token.

    Returns:
        Boolean tensor that is ``True`` only where the attended position is
        not padding and ``subsequent_mask`` allows it. It broadcasts to
        ``(..., seq_len, seq_len)``.
    """
    # True where the token is real; the extra axis at -2 lets this broadcast
    # over query positions.
    pad_mask = (tgt != pad).unsqueeze(-2)
    # torch.autograd.Variable is a deprecated no-op wrapper, so the tensor is
    # used directly. type_as matches the pad mask's tensor type.
    future_mask = subsequent_mask(tgt.size(-1)).type_as(pad_mask)
    return pad_mask & future_mask