In [9]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import pickle as pkl

# load tokenized dataset

In [10]:
dataset_dir_path = '../../data/processed/tokenized_data/'
with open(dataset_dir_path + 'train_data.pkl', 'rb') as f:
    tokenized_train_data = pkl.load(f)

with open(dataset_dir_path + 'valid_data.pkl', 'rb') as f:
    tokenized_valid_data = pkl.load(f)

In [11]:
print(f'Input(de) {tokenized_train_data[0][0]}')
print(f'Output(en) {tokenized_train_data[0][1]}')

Input(de) tensor([   2,   21,   85,  256,   31,   86,   22,   93,    7,   16,  114, 5645,
        3245,    3])
Output(en) tensor([   2,   19,   25,   15, 1197,  817,   17,   58,   84,  332, 1319,    3])


# load vocab

In [12]:
vocab_dir_path = '../../data/processed/vocab/'

with open(vocab_dir_path + 'token2idx_de.pkl', 'rb') as f:
    token2idx_de= pkl.load(f)
with open(vocab_dir_path + 'token2idx_en.pkl', 'rb') as f:
    token2idx_en = pkl.load(f)
with open(vocab_dir_path + 'idx2token_de.pkl', 'rb') as f:
    idx2token_de = pkl.load(f)
with open(vocab_dir_path + 'idx2token_en.pkl', 'rb') as f:
    idx2token_en = pkl.load(f)

# making the batch

In [13]:
batch_size = 128
PAD_INDEX = token2idx_de['<pad>']
START_INDEX = token2idx_en['<start>']
END_INDEX = token2idx_en['<end>']

In [14]:
def generate_batch(data_batch):
    batch_src = []
    batch_tgt = []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
    
    batch_src = pad_sequence(batch_src, padding_value=PAD_INDEX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_INDEX)

    return batch_src, batch_tgt

In [15]:
train_iter = DataLoader(tokenized_train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(tokenized_valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [16]:
# show train_iter
# each column is a text
list(train_iter)[0]

(tensor([[   2,    2,    2,  ...,    2,    2,    2],
         [4603,    5,   14,  ...,   21,   60,   14],
         [ 840,   12,   69,  ...,  120,  880,   38],
         ...,
         [  29,    1,    1,  ...,    1,    1,    1],
         [ 148,    1,    1,  ...,    1,    1,    1],
         [   3,    1,    1,  ...,    1,    1,    1]]),
 tensor([[   2,    2,    2,  ...,    2,    2,    2],
         [ 141,    6,  248,  ...,   19,   59,    6],
         [  22,  198,   72,  ...,  118,  868,   39],
         ...,
         [3456,    1,    1,  ...,    1,    1,    1],
         [   3,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1]]))

In [17]:
for src, tgt in train_iter:
    print(src)
    print(tgt)
    print(f'src shape : {src.shape}')
    print(f'tgt.shape : {tgt.shape}')
    break

tensor([[   2,    2,    2,  ...,    2,    2,    2],
        [   5,    5,   21,  ...,    5,   14,   14],
        [3536,   12,   31,  ...,   12,   17,   17],
        ...,
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1]])
tensor([[  2,   2,   2,  ...,   2,   2,   2],
        [  6,   6,  19,  ...,   6, 274,   6],
        [ 33,  12,  36,  ..., 180,  14,  16],
        ...,
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1]])
src shape : torch.Size([24, 128])
tgt.shape : torch.Size([24, 128])


# Model

In [18]:
import math
import torch
import torch.nn as nn
from torch import Tensor

## token embedding

In [19]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        # nn.Embedding is a simple lookup table.
        # if token index is set, it will return the corresponding embedding vector.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size
    
    def forward(self, tokens: Tensor):
        # the reason for this multiply is to align the range of the values
        # It is to make the positional encoding relatively smaller. 
        # This means the original meaning in the embedding vector won’t be lost 
        # when we add them together.
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)

## positional encoging

In [20]:
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den) # extract even element (start:stop:step)
        embedding_pos[:, 1::2] = torch.cos(pos * den) # extract odd element (start:stop:step)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos) # positional encoding is not updated by learning

    def forward(self, token_embedding: Tensor):
        #print(f'token_embedding : {token_embedding.shape}')
        #print(f'positional_encoding : {self.embedding_pos.shape}') # self.embedding_pos is defined by self.register_fuffer.
        # self.embedding_pos is not updated by learning
        #print(f'positional_encoding : {self.embedding_pos[:token_embedding.size(0), :].shape}')
        return self.dropout(token_embedding + self.embedding_pos[:token_embedding.size(0), :])

## masking

In [21]:
def generate_square_subsequent_mask(seq_len, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # seq_len is the length of the sentence
    # The mask is used to prevent the model from looking ahead in the sequence.
    # The mask is a square matrix of size (seq_len, seq_len)
    # The upper triangle of the matrix is filled with -inf
    # The lower triangle of the matrix is filled with 0
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1) # upper triangle
    mask = mask.float().masked_fill(mask == 1, float('-inf')).masked_fill(mask == PAD_INDEX, float(0.0)) # 1 -> -inf, 0 -> 0
    return mask

def create_mask(src, tgt, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    seq_len_src = src.shape[0] # word number in one sentence in source
    seq_len_tgt = tgt.shape[0] # word number in one sentence in target

    mask_tgt = generate_square_subsequent_mask(seq_len_tgt, PAD_INDEX=PAD_INDEX)
    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)

    padding_mask_src = (src == PAD_INDEX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_INDEX).transpose(0, 1)

    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt

In [22]:
seq_len_src = src.shape[0]
mask_src = torch.zeros((seq_len_src, seq_len_src)).type(torch.bool)

In [23]:
mask_src

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False],
        [False, Fals

In [24]:
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer

In [25]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int, embedding_size: int, vocab_size_src: int, vocab_size_tgt: int, dim_feedforward: int = 512, dropout: float = 0.1, nhead:int = 8):
        super().__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        encoder_layer = TransformerEncoderLayer(d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)

        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        decoder_layer = TransformerDecoderLayer(d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_decoder_layers)
        self.output = nn.Linear(embedding_size, vocab_size_tgt)
    
    def forward(self, src: Tensor, tgt: Tensor, mask_src: Tensor, mask_tgt: Tensor, padding_mask_src: Tensor, padding_mask_tgt: Tensor, memory_key_padding_mask: Tensor):
        # src : (seq_len_src, batch_size)
        # tgt : (seq_len_tgt, batch_size)
        # mask_src : (seq_len_src, seq_len_src)
        # mask_tgt : (seq_len_tgt, seq_len_tgt)
        # padding_mask_src : (batch_size, seq_len_src)
        # padding_mask_tgt : (batch_size, seq_len_tgt)

        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.transformer_decoder(embedding_tgt, memory, mask_tgt, None, padding_mask_tgt, memory_key_padding_mask)
        return self.output(outs)
    
    def encode(self, src: Tensor, mask_src: Tensor):
        return self.transformer_encoder(self.positional_encoding(self.token_embedding_src(src)), mask_src)
    
    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.transformer_decoder(self.positional_encoding(self.token_embedding_tgt(tgt), memory, mask_tgt))


In [26]:
from tqdm import tqdm

In [27]:
def train(model, data, optimizer, criterion, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.train()
    losses = 0.0
    for src, tgt in tqdm(data):
        src = src.to(device)
        tgt = tgt.to(device)
        input_tgt = tgt[:-1, :] # remove last token
        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_INDEX)

        logits = model(src=src, tgt=input_tgt,
                        mask_src=mask_src, mask_tgt=mask_tgt,
                        padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
                        memory_key_padding_mask=padding_mask_src)

        optimizer.zero_grad()

        output_tgt = tgt[1:, :] # remove first token
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    
    return losses / len(data)

In [28]:
def evaluate(model, data, criterion, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    losses = 0.0

    for src, tgt in data:
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :] # remove last token
        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_INDEX)
        logits = model(src=src, tgt=input_tgt, 
                       mask_src=mask_src, mask_tgt=mask_tgt,
                       padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
                       memory_key_padding_mask=padding_mask_src)

        output_tgt = tgt[1:, :] # remove first token
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()

    return losses / len(data)

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_size_src = len(token2idx_de)
vocab_size_tgt = len(token2idx_en)
embedding_size = 240 # smaller than original 512
nhead = 8
dim_feedforward = 100
num_encoder_layers = 2
num_decoder_layers = 2
dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src,
    vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p) # Xavier initialization

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_INDEX)
optimizer = torch.optim.Adam(model.parameters())



In [30]:
import time

epoch = 100
best_loss = float('inf')
best_model = None
patience = 10
counter = 0

for loop in range(epoch):
    start_time = time.time()

    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_INDEX=PAD_INDEX
    )

    elapsed_time = time.time() - start_time
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_INDEX=PAD_INDEX
    )

    print(f'epoch: {loop+1}, train loss: {loss_train:.4f}, valid loss: {loss_valid:.4f}, elapsed time: {elapsed_time:.4f} sec')

    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
    
    if counter > patience:
        break

    counter += 1

100%|██████████| 227/227 [00:05<00:00, 40.38it/s]


epoch: 1, train loss: 3.5148, valid loss: 0.9421, elapsed time: 5.6251 sec


100%|██████████| 227/227 [00:05<00:00, 41.60it/s]


epoch: 2, train loss: 0.6187, valid loss: 0.3166, elapsed time: 5.4596 sec


100%|██████████| 227/227 [00:05<00:00, 41.59it/s]


epoch: 3, train loss: 0.2425, valid loss: 0.2244, elapsed time: 5.4608 sec


100%|██████████| 227/227 [00:05<00:00, 41.46it/s]


epoch: 4, train loss: 0.1219, valid loss: 0.1996, elapsed time: 5.4788 sec


100%|██████████| 227/227 [00:05<00:00, 41.17it/s]


epoch: 5, train loss: 0.0646, valid loss: 0.1859, elapsed time: 5.5168 sec


100%|██████████| 227/227 [00:05<00:00, 40.97it/s]


epoch: 6, train loss: 0.0332, valid loss: 0.1835, elapsed time: 5.5447 sec


100%|██████████| 227/227 [00:05<00:00, 40.86it/s]


epoch: 7, train loss: 0.0203, valid loss: 0.1827, elapsed time: 5.5601 sec


100%|██████████| 227/227 [00:05<00:00, 40.79it/s]


epoch: 8, train loss: 0.0172, valid loss: 0.1846, elapsed time: 5.5679 sec


100%|██████████| 227/227 [00:05<00:00, 40.71it/s]


epoch: 9, train loss: 0.0144, valid loss: 0.1841, elapsed time: 5.5796 sec


100%|██████████| 227/227 [00:05<00:00, 40.80it/s]


epoch: 10, train loss: 0.0142, valid loss: 0.1836, elapsed time: 5.5684 sec


100%|██████████| 227/227 [00:05<00:00, 40.64it/s]


epoch: 11, train loss: 0.0145, valid loss: 0.1809, elapsed time: 5.5886 sec


100%|██████████| 227/227 [00:05<00:00, 40.49it/s]


epoch: 12, train loss: 0.0157, valid loss: 0.1913, elapsed time: 5.6095 sec


100%|██████████| 227/227 [00:05<00:00, 40.38it/s]


epoch: 13, train loss: 0.0143, valid loss: 0.1907, elapsed time: 5.6239 sec


100%|██████████| 227/227 [00:05<00:00, 40.15it/s]


epoch: 14, train loss: 0.0149, valid loss: 0.1965, elapsed time: 5.6575 sec


100%|██████████| 227/227 [00:05<00:00, 40.52it/s]


epoch: 15, train loss: 0.0138, valid loss: 0.1999, elapsed time: 5.6050 sec


100%|██████████| 227/227 [00:05<00:00, 39.98it/s]


epoch: 16, train loss: 0.0151, valid loss: 0.1982, elapsed time: 5.6801 sec


100%|██████████| 227/227 [00:05<00:00, 40.19it/s]


epoch: 17, train loss: 0.0154, valid loss: 0.2059, elapsed time: 5.6499 sec


100%|██████████| 227/227 [00:05<00:00, 39.64it/s]


epoch: 18, train loss: 0.0138, valid loss: 0.2087, elapsed time: 5.7292 sec


100%|██████████| 227/227 [00:05<00:00, 39.07it/s]


epoch: 19, train loss: 0.0142, valid loss: 0.2100, elapsed time: 5.8134 sec


100%|██████████| 227/227 [00:05<00:00, 38.54it/s]


epoch: 20, train loss: 0.0132, valid loss: 0.2107, elapsed time: 5.8925 sec


100%|██████████| 227/227 [00:05<00:00, 38.73it/s]


epoch: 21, train loss: 0.0134, valid loss: 0.2177, elapsed time: 5.8644 sec


100%|██████████| 227/227 [00:05<00:00, 38.39it/s]

epoch: 22, train loss: 0.0133, valid loss: 0.2135, elapsed time: 5.9145 sec





# save model

In [52]:
import os
model_dir_path = '../../result/model'
torch.save(best_model.state_dict(), os.path.join(model_dir_path,  'translation_transformer.pth'))

# inference

In [None]:
def greedy_decode(model: torch.nn.Module, src: Tensor, mask_src: Tensor, seq_len_tgt: int, START_INDEX: int, END_INDEX: int):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    sor = src.to(device)
    mask_src = mask_src.to(device)

    memory = model.encode(src, mask_src)
    memory = model.transformer_encoder(model.positional_encoding(model.token_embedding_src(src)), mask_src)
    ys = torch.ones(1, 1).fill_(START_INDEX).type(torch.long).to(device) # start with <start> token

    for i in range(seq_len_tgt - 1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        mask_tgt = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)

        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1]) # last column
        _, next_word = torch.max(output, dim = 1)
        next_word = next_word.item()

        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == END_INDEX:
            break

    return ys

