In [1]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import pickle as pkl

# load tokenized dataset

In [2]:
dataset_dir_path = '../../data/processed/tokenized_data/'
with open(dataset_dir_path + 'train_data.pkl', 'rb') as f:
    tokenized_train_data = pkl.load(f)

with open(dataset_dir_path + 'valid_data.pkl', 'rb') as f:
    tokenized_valid_data = pkl.load(f)

In [3]:
print(f'Input(de) {tokenized_train_data[0][0]}')
print(f'Output(en) {tokenized_train_data[0][1]}')

Input(de) tensor([   2,   21,   85,  256,   31,   86,   22,   93,    7,   16,  114, 5645,
        3245,    3])
Output(en) tensor([   2,   19,   25,   15, 1197,  817,   17,   58,   84,  332, 1319,    3])


# load vocab

In [4]:
vocab_dir_path = '../../data/processed/vocab/'

with open(vocab_dir_path + 'token2idx_de.pkl', 'rb') as f:
    token2idx_de= pkl.load(f)
with open(vocab_dir_path + 'token2idx_en.pkl', 'rb') as f:
    token2idx_en = pkl.load(f)
with open(vocab_dir_path + 'idx2token_de.pkl', 'rb') as f:
    idx2token_de = pkl.load(f)
with open(vocab_dir_path + 'idx2token_en.pkl', 'rb') as f:
    idx2token_en = pkl.load(f)

# making the batch

In [5]:
batch_size = 128
PAD_INDEX = token2idx_de['<pad>']
START_INDEX = token2idx_en['<start>']
END_INDEX = token2idx_en['<end>']

In [6]:
def generate_batch(data_batch):
    batch_src = []
    batch_tgt = []
    for src, tgt in data_batch:
        batch_src.append(src)
        batch_tgt.append(tgt)
    
    batch_src = pad_sequence(batch_src, padding_value=PAD_INDEX)
    batch_tgt = pad_sequence(batch_tgt, padding_value=PAD_INDEX)

    return batch_src, batch_tgt

In [7]:
train_iter = DataLoader(tokenized_train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(tokenized_valid_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)

In [8]:
# show train_iter
# each column is a text
list(train_iter)[0]

(tensor([[    2,     2,     2,  ...,     2,     2,     2],
         [    5,     5,     5,  ...,    76,     5,     5],
         [   12,   229,    12,  ...,    77, 14797,    12],
         ...,
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1]]),
 tensor([[   2,    2,    2,  ...,    2,    2,    2],
         [   6,    6,    6,  ...,   83, 9033,  110],
         [  12,  268,   12,  ...,  156,   43,   14],
         ...,
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1]]))

In [9]:
for src, tgt in train_iter:
    print(src)
    print(tgt)
    print(f'src shape : {src.shape}')
    print(f'tgt.shape : {tgt.shape}')
    break

tensor([[   2,    2,    2,  ...,    2,    2,    2],
        [   5,   60,    5,  ...,    5,  593,    5],
        [  12,   31,   12,  ..., 3370,  169,  525],
        ...,
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1]])
tensor([[   2,    2,    2,  ...,    2,    2,    2],
        [   6,   59,    6,  ...,    6,  275,    6],
        [  12,   36,   12,  ..., 1607,  129,  693],
        ...,
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1]])
src shape : torch.Size([36, 128])
tgt.shape : torch.Size([39, 128])


# Model

In [10]:
import math
import torch
import torch.nn as nn
from torch import Tensor

## token embedding

In [11]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        # nn.Embedding is a simple lookup table.
        # if token index is set, it will return the corresponding embedding vector.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size
    
    def forward(self, tokens: Tensor):
        # the reason for this multiply is to align the range of the values
        # It is to make the positional encoding relatively smaller. 
        # This means the original meaning in the embedding vector won’t be lost 
        # when we add them together.
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)

## positional encoging

In [12]:
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den) # extract even element (start:stop:step)
        embedding_pos[:, 1::2] = torch.cos(pos * den) # extract odd element (start:stop:step)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos) # positional encoding is not updated by learning

    def forward(self, token_embedding: Tensor):
        #print(f'token_embedding : {token_embedding.shape}')
        #print(f'positional_encoding : {self.embedding_pos.shape}') # self.embedding_pos is defined by self.register_fuffer.
        # self.embedding_pos is not updated by learning
        #print(f'positional_encoding : {self.embedding_pos[:token_embedding.size(0), :].shape}')
        return self.dropout(token_embedding + self.embedding_pos[:token_embedding.size(0), :])

## masking

In [13]:
def generate_square_subsequent_mask(seq_len, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # seq_len is the length of the sentence
    # The mask is used to prevent the model from looking ahead in the sequence.
    # The mask is a square matrix of size (seq_len, seq_len)
    # The upper triangle of the matrix is filled with -inf
    # The lower triangle of the matrix is filled with 0
    mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1) # upper triangle
    mask = mask.float().masked_fill(mask == 1, float('-inf')).masked_fill(mask == PAD_INDEX, float(0.0)) # 1 -> -inf, 0 -> 0
    return mask

def create_mask(src, tgt, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    seq_len_src = src.shape[0] # word number in one sentence in source
    seq_len_tgt = tgt.shape[0] # word number in one sentence in target

    mask_tgt = generate_square_subsequent_mask(seq_len_tgt, PAD_INDEX=PAD_INDEX)
    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)

    padding_mask_src = (src == PAD_INDEX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_INDEX).transpose(0, 1)

    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt

In [14]:
seq_len_src = src.shape[0]
mask_src = torch.zeros((seq_len_src, seq_len_src)).type(torch.bool)

In [15]:
mask_src

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [16]:
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer

In [17]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int, embedding_size: int, vocab_size_src: int, vocab_size_tgt: int, dim_feedforward: int = 512, dropout: float = 0.1, nhead:int = 8):
        super().__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        encoder_layer = TransformerEncoderLayer(d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)

        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        decoder_layer = TransformerDecoderLayer(d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_decoder_layers)
        self.output = nn.Linear(embedding_size, vocab_size_tgt)
    
    def forward(self, src: Tensor, tgt: Tensor, mask_src: Tensor, mask_tgt: Tensor, padding_mask_src: Tensor, padding_mask_tgt: Tensor, memory_key_padding_mask: Tensor):
        # src : (seq_len_src, batch_size)
        # tgt : (seq_len_tgt, batch_size)
        # mask_src : (seq_len_src, seq_len_src)
        # mask_tgt : (seq_len_tgt, seq_len_tgt)
        # padding_mask_src : (batch_size, seq_len_src)
        # padding_mask_tgt : (batch_size, seq_len_tgt)

        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.transformer_decoder(embedding_tgt, memory, mask_tgt, None, padding_mask_tgt, memory_key_padding_mask)
        return self.output(outs)
    
    def encode(self, src: Tensor, mask_src: Tensor):
        return self.transformer_encoder(self.positional_encoding(self.token_embedding_src(src)), mask_src)
    
    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.transformer_decoder(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)


In [18]:
from tqdm import tqdm

In [19]:
def train(model, data, optimizer, criterion, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.train()
    losses = 0.0
    for src, tgt in tqdm(data):
        src = src.to(device)
        tgt = tgt.to(device)
        input_tgt = tgt[:-1, :] # remove last token
        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_INDEX)

        logits = model(src=src, tgt=input_tgt,
                        mask_src=mask_src, mask_tgt=mask_tgt,
                        padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
                        memory_key_padding_mask=padding_mask_src)

        optimizer.zero_grad()

        output_tgt = tgt[1:, :] # remove first token
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    
    return losses / len(data)

In [20]:
def evaluate(model, data, criterion, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    losses = 0.0

    for src, tgt in data:
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :] # remove last token
        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_INDEX)
        logits = model(src=src, tgt=input_tgt, 
                       mask_src=mask_src, mask_tgt=mask_tgt,
                       padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
                       memory_key_padding_mask=padding_mask_src)

        output_tgt = tgt[1:, :] # remove first token
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()

    return losses / len(data)

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_size_src = len(token2idx_de)
vocab_size_tgt = len(token2idx_en)
embedding_size = 240 # smaller than original 512
nhead = 8
dim_feedforward = 100
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src,
    vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p) # Xavier initialization

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_INDEX)
optimizer = torch.optim.Adam(model.parameters())



In [22]:
import time

epoch = 1000
best_loss = float('inf')
best_model = None
patience = 10
counter = 0

for loop in range(epoch):
    start_time = time.time()

    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_INDEX=PAD_INDEX
    )

    elapsed_time = time.time() - start_time
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_INDEX=PAD_INDEX
    )

    print(f'epoch: {loop+1}, train loss: {loss_train:.4f}, valid loss: {loss_valid:.4f}, elapsed time: {elapsed_time:.4f} sec')

    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
    
    if counter > patience:
        break

    counter += 1

100%|██████████| 227/227 [00:07<00:00, 32.20it/s]


epoch: 1, train loss: 4.0758, valid loss: 1.2186, elapsed time: 7.0531 sec


100%|██████████| 227/227 [00:06<00:00, 32.97it/s]


epoch: 2, train loss: 0.7586, valid loss: 0.3521, elapsed time: 6.8901 sec


100%|██████████| 227/227 [00:06<00:00, 32.57it/s]


epoch: 3, train loss: 0.2786, valid loss: 0.2396, elapsed time: 6.9728 sec


100%|██████████| 227/227 [00:06<00:00, 32.56it/s]


epoch: 4, train loss: 0.1370, valid loss: 0.2100, elapsed time: 6.9752 sec


100%|██████████| 227/227 [00:06<00:00, 32.54it/s]


epoch: 5, train loss: 0.0720, valid loss: 0.1964, elapsed time: 6.9805 sec


100%|██████████| 227/227 [00:06<00:00, 32.53it/s]


epoch: 6, train loss: 0.0344, valid loss: 0.1893, elapsed time: 6.9824 sec


100%|██████████| 227/227 [00:07<00:00, 32.10it/s]


epoch: 7, train loss: 0.0183, valid loss: 0.1952, elapsed time: 7.0750 sec


100%|██████████| 227/227 [00:07<00:00, 31.81it/s]


epoch: 8, train loss: 0.0138, valid loss: 0.1898, elapsed time: 7.1396 sec


100%|██████████| 227/227 [00:07<00:00, 31.68it/s]


epoch: 9, train loss: 0.0129, valid loss: 0.1876, elapsed time: 7.1683 sec


100%|██████████| 227/227 [00:07<00:00, 32.01it/s]


epoch: 10, train loss: 0.0128, valid loss: 0.1916, elapsed time: 7.0956 sec


100%|██████████| 227/227 [00:07<00:00, 31.71it/s]


epoch: 11, train loss: 0.0137, valid loss: 0.1995, elapsed time: 7.1612 sec


100%|██████████| 227/227 [00:07<00:00, 31.96it/s]


epoch: 12, train loss: 0.0134, valid loss: 0.1974, elapsed time: 7.1064 sec


100%|██████████| 227/227 [00:07<00:00, 31.19it/s]


epoch: 13, train loss: 0.0121, valid loss: 0.2042, elapsed time: 7.2815 sec


100%|██████████| 227/227 [00:07<00:00, 30.78it/s]


epoch: 14, train loss: 0.0133, valid loss: 0.2023, elapsed time: 7.3792 sec


100%|██████████| 227/227 [00:07<00:00, 30.33it/s]


epoch: 15, train loss: 0.0120, valid loss: 0.2105, elapsed time: 7.4874 sec


100%|██████████| 227/227 [00:07<00:00, 30.60it/s]


epoch: 16, train loss: 0.0111, valid loss: 0.2098, elapsed time: 7.4211 sec


100%|██████████| 227/227 [00:07<00:00, 30.17it/s]


epoch: 17, train loss: 0.0129, valid loss: 0.2133, elapsed time: 7.5283 sec


100%|██████████| 227/227 [00:07<00:00, 30.37it/s]


epoch: 18, train loss: 0.0120, valid loss: 0.2184, elapsed time: 7.4772 sec


100%|██████████| 227/227 [00:07<00:00, 30.61it/s]


epoch: 19, train loss: 0.0108, valid loss: 0.2243, elapsed time: 7.4194 sec


100%|██████████| 227/227 [00:07<00:00, 30.75it/s]


epoch: 20, train loss: 0.0125, valid loss: 0.2198, elapsed time: 7.3863 sec


# save model

In [23]:
import os
model_dir_path = '../../result/model'
torch.save(best_model.state_dict(), os.path.join(model_dir_path,  'translation_transformer.pth'))

# inference

In [24]:
def greedy_decode(model: torch.nn.Module, src: Tensor, mask_src: Tensor, seq_len_tgt: int, START_INDEX: int, END_INDEX: int, PAD_INDEX: int):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    src = src.to(device)
    mask_src = mask_src.to(device)

    memory = model.encode(src, mask_src)
    memory = model.transformer_encoder(model.positional_encoding(model.token_embedding_src(src)), mask_src)
    ys = torch.ones(1, 1).fill_(START_INDEX).type(torch.long).to(device) # start with <start> token

    for i in range(seq_len_tgt - 1):
        print(f'||| {i}th decoding step')

        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        mask_tgt = (generate_square_subsequent_mask(ys.size(0), PAD_INDEX=PAD_INDEX).type(torch.bool)).to(device)

        output = model.decode(ys, memory, mask_tgt)
        print(f'||| output shape : {output.shape}') # (seq_len_tgt, batch_size, vocab_size_tgt)
        print(f'||| output a : {output}')
        output = output.transpose(0, 1)
        output = model.output(output[:, -1]) # last column
        print(f'||| output b : {output}')

        import pickle as pkl
        with open('../../data/external/output.pkl', 'wb') as f:
            pkl.dump(output, f)

        _, next_word = torch.max(output, dim = 1)
        next_word = next_word.item()
        print(f'||| next word is {idx2token_en[next_word]}')

        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == END_INDEX:
            print(f'||| next word is <end> token')
            break

    return ys



In [25]:
# load local modules
import sys
script_abs_path = '/home/sasatake/repos/python/transformer/script'
sys.path.append(script_abs_path)

from load_dataset import Multi30k
from tokenizer import Tokenizer

In [26]:
def convert_text_to_indices(text, vocab, tokenizer, lang):
    return [vocab['<start>']] + [vocab[token] for token in tokenizer.tokenize(lang, text.strip('\n'))] + [vocab['<end>']]

In [27]:
def translate(model : torch.nn.Module, text: str, vocab_src, vocab_tgt, tokenizer_src, seq_len_tgt, START_INDEX, END_INDEX, PAD_INDEX):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()

    tokens = convert_text_to_indices(text=text, vocab=vocab_src, tokenizer=tokenizer_src, lang='de')
    print(f'||| src tokens : {tokens}')
    num_tokens = len(tokens)
    src = torch.LongTensor(tokens).reshape(num_tokens, 1)
    print(f'||| src : {src}')
    mask_src = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

    predicts = greedy_decode(model=model, src=src, mask_src=mask_src, seq_len_tgt=seq_len_tgt, START_INDEX=START_INDEX, END_INDEX=END_INDEX, PAD_INDEX=PAD_INDEX).flatten()

    return ''.join([vocab_tgt[int(token)] for token in predicts]).replace('<end>', '').replace('<start>', '') # remove <start> and <end> token

In [28]:
#seq_len_tgt = max([len(x[1] for x in tokenized_train_data)]) # max length of target sentence
seq_len_tgt = 30

text = 'Eine Gruppe von Menschen steht vor einem Iglu .'
tokenizer = Tokenizer()

ret = translate(model=best_model, text=text, vocab_src=token2idx_de, vocab_tgt=idx2token_en, tokenizer_src=tokenizer, seq_len_tgt=seq_len_tgt, START_INDEX=START_INDEX, END_INDEX=END_INDEX, PAD_INDEX=PAD_INDEX)
print(f'result : {ret}')

||| src tokens : [2, 14, 38, 24, 48, 30, 28, 6, 6182, 3]
||| src : tensor([[   2],
        [  14],
        [  38],
        [  24],
        [  48],
        [  30],
        [  28],
        [   6],
        [6182],
        [   3]])
||| 0th decoding step
||| output shape : torch.Size([1, 1, 240])
||| output a : tensor([[[-1.1178e+00,  2.5233e+00, -8.1720e-01,  1.3011e-01,  9.6152e-01,
           3.8175e-01,  2.0252e+00,  5.6775e-01, -1.3518e+00,  1.2847e+00,
           1.5160e+00,  1.2637e+00,  1.3025e-01, -2.1975e-01, -9.9687e-01,
           6.9770e-01, -6.2064e-01,  2.7259e-01, -2.1661e+00,  5.8886e-02,
          -3.0626e+00,  2.9953e+00,  9.8934e-02, -1.3914e+00, -2.6361e-01,
           7.4671e+00, -2.4018e+00,  8.8206e-02,  2.3075e-01, -2.6926e-02,
           2.8573e+00, -3.0730e+00, -5.2096e-01, -7.5036e-01, -4.4510e+00,
           2.8844e+00, -6.5234e-02,  2.0652e+00, -1.2125e+00,  1.3450e+00,
           9.0281e-01, -4.2879e+00,  2.4239e-01,  8.8024e-01, -1.9037e+00,
          -1.6230