In [305]:
import torch
from torch.nn import Module, Parameter, Linear, Dropout, ModuleList, Embedding
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import sentencepiece
import os
import time

In [2]:
# Corpus location. Judging by the file names these are tokenised, BPE-encoded
# WMT16 English/German files sharing one vocabulary file.
path = "./data/WMT16"
train_path_en = os.path.join(path, "train.tok.clean.bpe.32000.en")
train_path_de = os.path.join(path, "train.tok.clean.bpe.32000.de")
vocab_path = os.path.join(path, "vocab.bpe.32000")
test_path_en = os.path.join(path, "newstest2016.tok.bpe.32000.en")
# The German test file gets its own name; it used to overwrite test_path_en,
# which left the English test path unreachable.
test_path_de = os.path.join(path, "newstest2016.tok.bpe.32000.de")

# Ids of the special tokens. Regular vocabulary entries are numbered from 1,
# so 0 is free for padding.
# NOTE(review): the 370xx ids are presumably chosen to lie above the largest
# vocabulary id -- confirm against the line count of the vocab file.
SOS_id = 37005
EOS_id = 37006
UNK_id = 37007
PAD_id = 0

In [3]:
class CustomDataset(Dataset):
    """Parallel English/German corpus that yields token-id sequences.

    Each of the two text files holds one sentence per line; line ``i`` of the
    English file is paired with line ``i`` of the German file. The vocabulary
    file holds one token per line and the token on line ``k`` (0-based)
    receives id ``k + 1``.

    Args:
        path_en: path of the English sentence file.
        path_de: path of the German sentence file.
        vocab_path: path of the shared vocabulary file.
        sos_id: id stored for ``"<sos>"`` and prepended to every sequence.
        eos_id: id stored for ``"<eos>"`` and appended to every sequence.
        unk_id: id stored for ``"<unk>"``; replaces out-of-vocabulary words.
        pad_id: id stored for ``"<pad>"``. Defaults to 0, which regular
            tokens never use because their ids start at 1.
    """

    def __init__(self, path_en, path_de, vocab_path, sos_id, eos_id, unk_id, pad_id=0):
        super(CustomDataset, self).__init__()

        self.data_en = self._read_lines(path_en)
        self.data_de = self._read_lines(path_de)

        # Regular tokens are numbered from 1 in file order.
        self.vocab = dict()
        for idx, word in enumerate(self._read_lines(vocab_path)):
            self.vocab[word] = idx + 1

        # Special tokens come from the constructor arguments (not from module
        # globals) so the dataset is usable with any id scheme.
        self.vocab["<sos>"] = sos_id
        self.vocab["<eos>"] = eos_id
        self.vocab["<unk>"] = unk_id
        self.vocab["<pad>"] = pad_id

    @staticmethod
    def _read_lines(file_path):
        """Return the lines of ``file_path`` without their newline characters.

        Only ``"\\n"`` is treated as a separator. The empty string produced by
        a trailing newline is dropped; a final line without a trailing newline
        is kept.
        """
        # Explicit UTF-8: the platform default encoding may not decode
        # non-ASCII (e.g. German) characters.
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.read().split("\n")
        if lines and lines[-1] == "":
            lines.pop()
        return lines

    def __len__(self):
        """Number of sentence pairs, taken from the English side."""
        return len(self.data_en)

    def _encode(self, sentence):
        """Map a space-separated sentence to ``[<sos>, ids..., <eos>]`` (int32)."""
        unk = self.vocab["<unk>"]
        tokens = [self.vocab["<sos>"]]
        tokens.extend(self.vocab.get(word, unk) for word in sentence.split(" "))
        tokens.append(self.vocab["<eos>"])
        return torch.IntTensor(tokens)

    def __getitem__(self, idx):
        """Return ``(tokens_en, tokens_de)`` for sentence pair ``idx``.

        Both elements are 1-D int32 tensors whose lengths may differ.
        """
        return (self._encode(self.data_en[idx]), self._encode(self.data_de[idx]))

def custom_collate(batch):
    """Pad a list of (English, German) token tensors into two batch tensors.

    Args:
        batch: list of ``(seq_en, seq_de)`` pairs of 1-D token tensors.

    Returns:
        ``(sequence_en, sequence_de, seq_len_en, seq_len_de)`` where the two
        sequences are int32 tensors of shape ``[batch, longest_length]`` that
        are zero-filled to the right of each sentence, and the two length
        lists hold the unpadded length of every sentence.
    """
    num_pairs = len(batch)

    # Unpadded lengths, one entry per sentence pair.
    seq_len_en = [len(pair[0]) for pair in batch]
    seq_len_de = [len(pair[1]) for pair in batch]

    # Zero-initialised so that everything right of a sentence stays 0.
    sequence_en = torch.zeros((num_pairs, max(seq_len_en)), dtype=torch.int32)
    sequence_de = torch.zeros((num_pairs, max(seq_len_de)), dtype=torch.int32)

    for row, (src, tgt) in enumerate(batch):
        sequence_en[row, :seq_len_en[row]] = src
        sequence_de[row, :seq_len_de[row]] = tgt

    return sequence_en, sequence_de, seq_len_en, seq_len_de

In [4]:
# Number of (English, German) sentence pairs per mini-batch handed to the DataLoader.
batch_size = 4

In [43]:
# Training corpus wrapped as a Dataset, then batched, shuffled and padded
# (via custom_collate) by a DataLoader running four worker processes.
train_dataset = CustomDataset(
    path_en=train_path_en,
    path_de=train_path_de,
    vocab_path=vocab_path,
    sos_id=SOS_id,
    eos_id=EOS_id,
    unk_id=UNK_id,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=custom_collate,
)

In [277]:
class MultiheadAttention(Module):
    """Scaled dot-product attention with ``num_heads`` parallel heads.

    Args:
        d_model: feature size of the inputs and of the output.
        num_heads: number of heads; must divide ``d_model`` exactly.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout):
        super(MultiheadAttention, self).__init__()

        self.d_model = d_model
        # Integer arithmetic: avoids float rounding on the per-head size.
        self.d_k = d_model // num_heads
        if self.d_k * num_heads != d_model:
            raise ValueError("d_model cannot be divided by num_heads.")
        self.num_heads = num_heads

        self.query = Linear(d_model, d_model)
        self.key = Linear(d_model, d_model)
        self.value = Linear(d_model, d_model)

        self.dropout = Dropout(dropout)

        self.output = Linear(d_model, d_model)

    def forward(self, query, key, value, future_mask=None, pad_mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``[batch, q_time, d_model]``.
            key, value: ``[batch, k_time, d_model]``.
            future_mask: optional bool tensor broadcastable to
                ``[batch * num_heads, q_time, k_time]``; ``True`` entries are
                excluded from attention.
            pad_mask: optional bool tensor ``[batch, 1, k_time]``; ``True``
                marks key positions that must not be attended to.

        Returns:
            Tensor of shape ``[batch, q_time, d_model]``.
        """
        assert len(query.size()) == 3, "input is not batch"
        batch_size = query.size(0)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Split the feature axis into heads and stack the heads along the
        # batch axis: [batch * num_heads, time, d_k] (head-major order).
        query = torch.cat(torch.split(query, self.d_k, dim=2), dim=0)
        key = torch.cat(torch.split(key, self.d_k, dim=2), dim=0)
        value = torch.cat(torch.split(value, self.d_k, dim=2), dim=0)

        # attention_score: [batch * num_heads, q_time, k_time]
        attention_score = torch.matmul(query, key.transpose(1, 2)) / (self.d_k ** 0.5)

        # Masked entries become -inf so that softmax assigns them weight 0.
        if future_mask is not None:
            attention_score = attention_score.masked_fill(mask=future_mask, value=-float("inf"))
        if pad_mask is not None:
            # Repeat per head, in the same head-major order used above.
            pad_mask = torch.cat([pad_mask] * self.num_heads, dim=0)
            attention_score = attention_score.masked_fill(mask=pad_mask, value=-float("inf"))

        attention_score = F.softmax(attention_score, dim=2)
        attention_score = self.dropout(attention_score)

        # Weighted sum of values: [batch * num_heads, q_time, d_k]
        output = torch.matmul(attention_score, value)

        # Undo the head stacking: chunks of `batch_size` rows are the heads.
        # The chunk size must be an int (true division produced a float,
        # which torch.split rejects).
        output = torch.cat(torch.split(output, batch_size, dim=0), dim=2)

        return self.output(output)

In [104]:
def create_mask(sequence, pad_id):
    """Build the attention masks for a batch of token ids.

    Args:
        sequence: integer tensor of shape ``[batch, seq_len]``.
        pad_id: token id that marks padding.

    Returns:
        ``(future_mask, pad_mask)``, both bool tensors on the same device as
        ``sequence``:

        * ``future_mask`` -- ``[seq_len, seq_len]``, ``True`` strictly above
          the diagonal so position ``i`` cannot see positions ``j > i``::

              [F T T]
              [F F T]
              [F F F]

        * ``pad_mask`` -- ``[batch, 1, seq_len]``, ``True`` where ``sequence``
          equals ``pad_id``; the singleton axis broadcasts over query
          positions.
    """
    batch_size = sequence.size()[0]
    seq_len = sequence.size()[1]

    # Built with torch on the input's device: a CPU-only mask cannot be
    # applied to scores that live on another device.
    future_mask = torch.ones(seq_len, seq_len, device=sequence.device).triu(diagonal=1).bool()

    pad_mask = (sequence == pad_id).view(batch_size, 1, seq_len)

    return future_mask, pad_mask

In [282]:
class LayerNorm(Module):
    """Layer normalisation over the last (feature) axis.

    Every feature vector is shifted to zero mean and scaled to unit variance,
    then transformed by the learnable per-feature ``gamma`` (scale, initialised
    to 1) and ``beta`` (shift, initialised to 0).

    Args:
        d_model: size of the feature axis.
        epsilon: constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, epsilon=1e-6):
        super(LayerNorm, self).__init__()

        self.d_model = d_model
        self.epsilon = epsilon
        self.gamma = Parameter(torch.ones(d_model))
        self.beta = Parameter(torch.zeros(d_model))

    def forward(self, inputs):
        """Normalise ``inputs`` (``[..., d_model]``); the shape is preserved."""
        # keepdim keeps the reduced axis so the statistics broadcast back
        # against `inputs`.
        mean = inputs.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in standard layer normalisation;
        # the unbiased estimator would also be NaN for a single feature.
        var = inputs.var(dim=-1, unbiased=False, keepdim=True)

        return self.gamma * (inputs - mean) / torch.sqrt(var + self.epsilon) + self.beta

In [120]:
class FeedForward(Module):
    """Two-layer position-wise network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: size of the input and output features.
        d_ff: size of the hidden layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff

        self.dropout = Dropout(p=dropout)
        self.linear1 = Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = Linear(in_features=d_ff, out_features=d_model)

    def forward(self, inputs):
        """Map ``inputs`` of shape ``[batch, time, d_model]`` to the same shape."""
        hidden = self.dropout(F.relu(self.linear1(inputs)))
        return self.linear2(hidden)

In [276]:
class PositionalEncoding(Module):
    """Fixed sinusoidal position table, zeroed at padding positions.

    Even feature index ``i`` at position ``pos`` holds
    ``sin(pos / 10000 ** (i / d_model))`` and feature ``i + 1`` holds the
    matching cosine.

    Args:
        d_model: feature size of the encoding (odd sizes are supported; the
            last sine column then has no cosine partner).
        max_len: number of positions in the precomputed table.
        pad_id: token id whose positions receive an all-zero encoding.
    """

    def __init__(self, d_model, max_len, pad_id):
        super(PositionalEncoding, self).__init__()

        self.d_model = d_model
        self.max_len = max_len
        self.pad_id = pad_id

        # Vectorised table construction; angles are computed in float64 and
        # stored as float32.
        position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
        exponent = torch.arange(0, d_model, 2, dtype=torch.float64) / d_model
        angle = position / torch.pow(10000.0, exponent)

        self.pe = torch.zeros([max_len, d_model])
        self.pe[:, 0::2] = torch.sin(angle).float()
        # For odd d_model there is one cosine column fewer than sine columns.
        self.pe[:, 1::2] = torch.cos(angle[:, : d_model // 2]).float()

    def forward(self, inputs):
        """Return the encodings for a batch of token ids.

        Args:
            inputs: integer tensor ``[batch, time]``; ``time`` must not exceed
                ``max_len``.

        Returns:
            Float tensor ``[batch, time, d_model]`` on the device of
            ``inputs``, with zeros wherever ``inputs == pad_id``.
        """
        batch_size = inputs.size()[0]
        seq_len = inputs.size()[1]

        # pad_mask: [batch, time, 1], broadcast over the feature axis.
        pad_mask = (inputs == self.pad_id).view(batch_size, seq_len, 1)

        # [max_len, d_model] -> [batch, time, d_model], using the actual batch
        # size (not a fixed count) and following the input's device.
        pe = self.pe[:seq_len, :].to(inputs.device)
        pe = pe.unsqueeze(0).expand(batch_size, -1, -1)

        # masked_fill is out-of-place, so the shared table is never modified.
        return pe.masked_fill(mask=pad_mask, value=0)

In [301]:
class Encoder(Module):
    """One encoder layer: self-attention and feed-forward sublayers.

    Each sublayer output goes through dropout, is added to the sublayer input
    (residual connection) and is then layer-normalised.

    Args:
        d_model: feature size of the layer input and output.
        d_ff: hidden size of the feed-forward sublayer.
        num_heads: number of attention heads.
        dropout: dropout probability used in every sublayer.
    """

    def __init__(self, d_model, d_ff, num_heads, dropout):
        super(Encoder, self).__init__()

        self.self_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        # FeedForward requires the dropout rate as its third argument.
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)

    def forward(self, inputs, pad_mask=None):
        """Encode ``inputs``.

        Args:
            inputs: ``[batch, time, d_model]``.
            pad_mask: optional bool tensor ``[batch, 1, time]`` marking padded
                key positions.

        Returns:
            Tensor of shape ``[batch, time, d_model]``.
        """
        # sublayer 1: self attention
        output = self.self_attention(inputs, inputs, inputs, pad_mask=pad_mask)
        output = self.dropout(output)
        output_ = self.norm1(output + inputs)

        # sublayer 2: feed forward
        output = self.feedforward(output_)
        output = self.dropout(output)
        output = self.norm2(output + output_)

        return output

In [303]:
class Decoder(Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sublayer output goes through dropout, is added to the sublayer input
    (residual connection) and is then layer-normalised.

    Args:
        d_model: feature size of the layer input and output.
        d_ff: hidden size of the feed-forward sublayer.
        num_heads: number of attention heads.
        dropout: dropout probability used in every sublayer.
    """

    def __init__(self, d_model, d_ff, num_heads, dropout):
        super(Decoder, self).__init__()

        self.self_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.cross_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.norm2 = LayerNorm(d_model)
        # FeedForward requires the dropout rate as its third argument.
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm3 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)

    def forward(self, inputs, encoder_output, future_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Decode ``inputs`` while attending to ``encoder_output``.

        Args:
            inputs: ``[batch, tgt_time, d_model]``.
            encoder_output: ``[batch, src_time, d_model]``.
            future_mask: optional bool mask applied in self-attention to hide
                later target positions.
            src_pad_mask: optional bool tensor ``[batch, 1, src_time]`` applied
                to the encoder positions in cross-attention.
            tgt_pad_mask: optional bool tensor ``[batch, 1, tgt_time]`` applied
                to the target positions in self-attention.

        Returns:
            Tensor of shape ``[batch, tgt_time, d_model]``.
        """
        # sublayer 1: self attention over the target sequence
        output = self.self_attention(inputs, inputs, inputs, future_mask=future_mask, pad_mask=tgt_pad_mask)
        output = self.dropout(output)
        output_ = self.norm1(output + inputs)

        # sublayer 2: encoder-decoder attention (queries come from the decoder)
        output = self.cross_attention(output_, encoder_output, encoder_output, pad_mask=src_pad_mask)
        output = self.dropout(output)
        output_ = self.norm2(output + output_)

        # sublayer 3: feed forward
        output = self.feedforward(output_)
        output = self.dropout(output)
        output = self.norm3(output + output_)

        return output

In [344]:
class EncoderStack(Module):
    """Token embedding plus positional encoding followed by encoder layers.

    Args:
        d_model: feature size of the embeddings and of every layer.
        d_ff: hidden size of each feed-forward sublayer.
        shared_embedding: embedding module used to look up the input tokens.
        num_heads: number of attention heads per layer.
        num_layers: number of encoder layers.
        max_len: number of positions in the positional-encoding table.
        dropout: dropout probability.
        pad_id: id of the padding token.
    """

    def __init__(self, d_model, d_ff, shared_embedding, num_heads, num_layers, max_len, dropout, pad_id):
        super(EncoderStack, self).__init__()

        # Build a separate Encoder per layer. Multiplying a one-element list
        # would register the same instance num_layers times, so every layer
        # would share one set of weights.
        self.layers = ModuleList([Encoder(d_model, d_ff, num_heads, dropout) for _ in range(num_layers)])
        self.embedding = shared_embedding
        self.pe = PositionalEncoding(d_model, max_len, pad_id)
        self.dropout = Dropout(dropout)

    def forward(self, inputs, pad_mask=None):
        """Encode a batch of token ids.

        Args:
            inputs: integer tensor ``[batch, time]``.
            pad_mask: optional bool tensor ``[batch, 1, time]`` passed to
                every layer.

        Returns:
            Tensor of shape ``[batch, time, d_model]``.
        """
        embedding = self.embedding(inputs)
        pe = self.pe(inputs)

        output = self.dropout(embedding + pe)

        for layer in self.layers:
            output = layer(output, pad_mask)

        return output

In [345]:
class DecoderStack(Module):
    """Embedding, positional encoding, decoder layers and output projection.

    The output projection reuses the embedding matrix (tied weights).

    Args:
        d_model: feature size of the embeddings and of every layer.
        d_ff: hidden size of each feed-forward sublayer.
        shared_embedding: embedding module used both to look up target tokens
            and, transposed, to project back to the vocabulary.
        num_heads: number of attention heads per layer.
        num_layers: number of decoder layers.
        max_len: number of positions in the positional-encoding table.
        dropout: dropout probability.
        pad_id: id of the padding token.
    """

    def __init__(self, d_model, d_ff, shared_embedding, num_heads, num_layers, max_len, dropout, pad_id):
        super(DecoderStack, self).__init__()

        # Build a separate Decoder per layer. Multiplying a one-element list
        # would register the same instance num_layers times, so every layer
        # would share one set of weights.
        self.layers = ModuleList([Decoder(d_model, d_ff, num_heads, dropout) for _ in range(num_layers)])
        self.embedding = shared_embedding
        self.pe = PositionalEncoding(d_model, max_len, pad_id)
        self.dropout = Dropout(dropout)

    def forward(self, inputs, encoder_output, future_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Decode target token ids into per-position vocabulary probabilities.

        Args:
            inputs: integer tensor ``[batch, tgt_time]``.
            encoder_output: ``[batch, src_time, d_model]``.
            future_mask: optional bool mask hiding later target positions;
                it must broadcast against the self-attention scores.
            src_pad_mask: optional bool tensor ``[batch, 1, src_time]``.
            tgt_pad_mask: optional bool tensor ``[batch, 1, tgt_time]``.

        Returns:
            Tensor ``[batch, tgt_time, vocab]`` of softmax probabilities.
        """
        embedding = self.embedding(inputs)
        pe = self.pe(inputs)

        output = self.dropout(embedding + pe)

        for layer in self.layers:
            output = layer(output, encoder_output, future_mask, src_pad_mask, tgt_pad_mask)

        # [batch, time, d_model] -> [batch, time, vocab]. F.linear multiplies
        # by the transpose of the [vocab, d_model] embedding matrix. The
        # weight is used directly (not `.data`) so gradients from the output
        # projection reach the tied embedding.
        output = F.linear(output, self.embedding.weight)
        output = F.softmax(output, dim=2)

        return output

In [346]:
class Transformer(Module):
    """Encoder-decoder model whose two stacks share one token embedding.

    Args:
        d_model: feature size used throughout the model.
        d_ff: hidden size of the feed-forward sublayers.
        vocab_size: number of rows of the shared embedding.
        num_heads: number of attention heads per layer.
        num_layers: number of layers in each stack.
        max_len: number of positions in the positional-encoding tables.
        dropout: dropout probability.
        pad_id: id of the padding token.
    """

    def __init__(self, d_model, d_ff, vocab_size, num_heads, num_layers, max_len, dropout, pad_id):
        super().__init__()

        self.shared_embedding = Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

        # Both stacks are built from the same positional arguments.
        stack_args = (d_model, d_ff, self.shared_embedding, num_heads, num_layers, max_len, dropout, pad_id)
        self.encoder = EncoderStack(*stack_args)
        self.decoder = DecoderStack(*stack_args)

    def forward(self, inputs, target, future_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Run the encoder on ``inputs`` and the decoder on ``target``.

        Args:
            inputs, target: token-id tensors of shape ``[batch, time]``.
            future_mask: optional mask forwarded to the decoder only.
            src_pad_mask: optional ``[batch, 1, time]`` mask forwarded to both
                the encoder and the decoder.
            tgt_pad_mask: optional ``[batch, 1, time]`` mask forwarded to the
                decoder only.

        Returns:
            Whatever the decoder stack returns for ``target``.
        """
        memory = self.encoder(inputs, src_pad_mask)
        return self.decoder(target, memory, future_mask, src_pad_mask, tgt_pad_mask)