In [1]:
import math
from dataclasses import dataclass

import numpy as np
import sacrebleu
import sentencepiece as spm
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchtext.datasets import Multi30k
from tqdm import tqdm

from transformer import Transformer
from utils import *

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
SRC, TRG = "de", "en"

train_iter = Multi30k(split='train', language_pair=(SRC, TRG))
f_de = open("Multi30k_de_text.txt", "w")
f_en = open("Multi30k_en_text.txt", "w")
for pair in train_iter:
    f_de.write(pair[0]+'\n')
    f_en.write(pair[1]+'\n')
f_de.close()
f_en.close()

In [3]:
en_vocab_size = 8200
de_vocab_size = 10000
vocab_sizes = {"en": en_vocab_size, "de": de_vocab_size}

In [4]:
# train sentencepiece models to get tokenizers
spm.SentencePieceTrainer.train\
(f'--input=Multi30k_de_text.txt --model_prefix=Multi30k_de --user_defined_symbols= --vocab_size={de_vocab_size}')
spm.SentencePieceTrainer.train\
(f'--input=Multi30k_en_text.txt --model_prefix=Multi30k_en --user_defined_symbols= --vocab_size={en_vocab_size}')

# make SentencePieceProcessor instances and load the model files
de_sp = spm.SentencePieceProcessor()
de_sp.load('Multi30k_de.model')
en_sp = spm.SentencePieceProcessor()
en_sp.load('Multi30k_en.model')

tokenizers = {"en": en_sp.encode_as_ids, "de": de_sp.encode_as_ids}
detokenizers = {"en":en_sp.decode_ids, "de":de_sp.decode_ids}

sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=Multi30k_de_text.txt --model_prefix=Multi30k_de --user_defined_symbols= --vocab_size=10000
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: Multi30k_de_text.txt
  input_format: 
  model_prefix: Multi30k_de
  model_type: UNIGRAM
  vocab_size: 10000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bo

In [5]:
# indexes of special symbols
UNK, BOS, EOS, PAD = 0, 1, 2, 3

train_iter = Multi30k(split='train', language_pair=(SRC, TRG))
valid_iter = Multi30k(split='valid', language_pair=(SRC, TRG))
test_iter  = Multi30k(split='test',  language_pair=(SRC, TRG))

train_set = [(x.rstrip('\n'), y.rstrip('\n')) for x, y in train_iter if x!='']
valid_set = [(x.rstrip('\n'), y.rstrip('\n')) for x, y in valid_iter if x!='']
# test_set  = [(x.rstrip('\n'), y.rstrip('\n')) for x, y in test_iter if x!='']
print(len(train_set), len(valid_set))
for i in range(10):
   print(train_set[i])

29000 1014
('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.', 'Two young, White males are outside near many bushes.')
('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.', 'Several men in hard hats are operating a giant pulley system.')
('Ein kleines Mädchen klettert in ein Spielhaus aus Holz.', 'A little girl climbing into a wooden playhouse.')
('Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.', 'A man in a blue shirt is standing on a ladder cleaning a window.')
('Zwei Männer stehen am Herd und bereiten Essen zu.', 'Two men are at the stove preparing food.')
('Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.', 'A man in green holds a guitar while the other man observes his shirt.')
('Ein Mann lächelt einen ausgestopften Löwen an.', 'A man is smiling at a stuffed lion')
('Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.', 'A trendy girl talking on her cell

In [6]:
max_seq_len = 50
def tokenize_dataset(dataset):
    'tokenize a dataset and add [BOS] and [EOS] to the beginning and end of the sentences'
    return [(torch.tensor([BOS]+tokenizers[SRC](src_text)[0:max_seq_len-2]+[EOS]),
             torch.tensor([BOS]+tokenizers[TRG](trg_text)[0:max_seq_len-2]+[EOS]))
            for src_text, trg_text in dataset]
          
train_tokenized = tokenize_dataset(train_set)
valid_tokenized = tokenize_dataset(valid_set)
# test_tokenized  = tokenize_dataset(test_set)

In [7]:
class TranslationDataset(Dataset):
    'create a dataset for torch.utils.data.DataLoader() '
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def pad_sequence(batch):
    'collate function for padding sentences such that all \
    the sentences in the batch have the same length'
    src_seqs  = [src for src, trg in batch]
    trg_seqs  = [trg for src, trg in batch]
    src_padded = torch.nn.utils.rnn.pad_sequence(src_seqs,
                                batch_first=True, padding_value = PAD)
    trg_padded = torch.nn.utils.rnn.pad_sequence(trg_seqs,
                                batch_first=True, padding_value = PAD)
    return src_padded, trg_padded

In [8]:
batch_size = 128

class Dataloaders:
    'Dataloaders contains train_loader, test_loader and valid_loader for training and evaluation '
    def __init__(self):
        self.train_dataset = TranslationDataset(train_tokenized)
        self.valid_dataset = TranslationDataset(valid_tokenized)
        # self.test_dataset  = TranslationDataset(test_tokenized)
        
        # each batch returned by dataloader will be padded such that all the texts in
        # that batch have the same length as the longest text in that batch
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size,
                                                shuffle=True, collate_fn = pad_sequence)
        
        
        self.valid_loader = torch.utils.data.DataLoader(self.valid_dataset, batch_size=batch_size,
                                                shuffle=True, collate_fn=pad_sequence)

In [9]:
def make_batch_input(x, y):
        src = x.to(device)
        trg_in = y[:, :-1].to(device)
        trg_out = y[:, 1:].contiguous().view(-1).to(device)
        src_pad_mask = (src == PAD).view(src.size(0), 1, 1, src.size(-1))
        trg_pad_mask = (trg_in == PAD).view(trg_in.size(0), 1, 1, trg_in.size(-1))
        return src, trg_in, trg_out, src_pad_mask, trg_pad_mask

In [10]:
def make_model():
    model = Transformer(num_encoder_layers=6,
                        num_decoder_layers=6,
                        d_model=512,
                        num_heads=8,
                        dff=2048,
                        input_vocab_size=vocab_sizes[SRC],
                        target_vocab_size=vocab_sizes[TRG],
                        max_seq_len=max_seq_len,
                        dropout_rate=0.1).to(device)

    # initialize model parameters
    # it seems that this initialization is very important!
    for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

In [11]:
data_loaders = Dataloaders()

model = make_model()

warmup_steps = 3 * len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: 512 **(-0.5) * min([(step+1)**(-0.5), (step+1) * warmup_steps ** (-1.5)])

optimizer = torch.optim.Adam(model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
early_stop_count = 2


In [12]:
def train_epoch(model, dataloaders):
    model.train()
    grad_norm_clip = 1.0
    losses, acc, count = [], 0, 0
    num_batches = len(dataloaders.train_loader)
    pbar = tqdm(enumerate(dataloaders.train_loader), total=num_batches)
    for idx, (x, y)  in  pbar:
        optimizer.zero_grad()
        src, trg_in, trg_out, src_pad_mask, trg_pad_mask = make_batch_input(x, y)
        pred, _ = model(src, trg_in)
        pred = pred.view(-1, pred.size(-1))
        loss = loss_fn(pred, trg_out).to(device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f'train loss={loss.item():.3f}, lr={scheduler.get_last_lr()[0]:.5f}')
    return np.mean(losses)


def train(model, dataloaders, epochs):
    global early_stop_count
    best_valid_loss = float('inf')
    train_size = len(dataloaders.train_loader)*batch_size
    for ep in range(epochs):
        train_loss = train_epoch(model, dataloaders)
        valid_loss = validate(model, dataloaders.valid_loader)
        
        print(f'ep: {ep}: train_loss={train_loss:.5f}, valid_loss={valid_loss:.5f}')
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
        else:
            if scheduler.last_epoch>2*warmup_steps:
                early_stop_count -= 1
                if early_stop_count<=0:   
                    return train_loss, valid_loss
    return train_loss, valid_loss
      
               
def validate(model, dataloder):
    'compute the validation loss'
    model.eval()
    losses = []
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloder):
            src, trg_in, trg_out, src_pad_mask, trg_pad_mask = make_batch_input(x,y)
            pred, _ = model(src, trg_in)
            pred = pred.view(-1, pred.size(-1))
            losses.append(loss_fn(pred, trg_out).item())
    return np.mean(losses)

In [13]:
train_loss, valid_loss = train(model, data_loaders, epochs=10)

train loss=4.345, lr=0.00025: 100%|██████████| 227/227 [00:28<00:00,  7.87it/s]


ep: 0: train_loss=5.77885, valid_loss=4.05680


train loss=3.101, lr=0.00053: 100%|██████████| 227/227 [00:28<00:00,  7.97it/s]


ep: 1: train_loss=3.42724, valid_loss=2.97579


train loss=2.543, lr=0.00082: 100%|██████████| 227/227 [00:28<00:00,  7.94it/s]


ep: 2: train_loss=2.64849, valid_loss=2.45689


train loss=2.119, lr=0.00074: 100%|██████████| 227/227 [00:28<00:00,  7.91it/s]


ep: 3: train_loss=2.13015, valid_loss=2.09979


train loss=1.583, lr=0.00066: 100%|██████████| 227/227 [00:28<00:00,  7.88it/s]


ep: 4: train_loss=1.68488, valid_loss=1.81377


train loss=1.259, lr=0.00060: 100%|██████████| 227/227 [00:28<00:00,  7.85it/s]


ep: 5: train_loss=1.36163, valid_loss=1.63384


train loss=1.092, lr=0.00056: 100%|██████████| 227/227 [00:28<00:00,  7.85it/s]


ep: 6: train_loss=1.13870, valid_loss=1.56924


train loss=0.901, lr=0.00052: 100%|██████████| 227/227 [00:28<00:00,  7.90it/s]


ep: 7: train_loss=0.97735, valid_loss=1.53548


train loss=0.855, lr=0.00049: 100%|██████████| 227/227 [00:28<00:00,  7.87it/s]


ep: 8: train_loss=0.84089, valid_loss=1.47775


train loss=0.704, lr=0.00047: 100%|██████████| 227/227 [00:28<00:00,  7.87it/s]


ep: 9: train_loss=0.71335, valid_loss=1.48312


In [14]:
# def translate(model, x):
#     'translate source sentences into the target language, without looking at the answer'
#     with torch.no_grad():
#         dB = x.size(0)
#         y = torch.tensor([[BOS]*dB]).view(dB, 1).to(device)
#         memory = model.encoder(x)
#         for i in range(max_seq_len):
#             logits, _ = model.decoder(y, memory)
#             logits = nn.Softmax(1)(logits)
#             last_output = logits.argmax(-1)[:, -1]
#             last_output = last_output.view(dB, 1)
#             y = torch.cat((y, last_output), 1).to(device)
#     return y
     
# def remove_pad(sent):
#     '''truncate the sentence if BOS is in it,
#      otherwise simply remove the padding tokens at the end'''
#     if sent.count(EOS)>0:
#       sent = sent[0:sent.index(EOS)+1]
#     while sent and sent[-1] == PAD:
#             sent = sent[:-1]
#     return sent

# def decode_sentence(detokenizer, sentence_ids):
#     'convert a tokenized sentence (a list of numbers) to a literal string'
#     if not isinstance(sentence_ids, list):
#         sentence_ids = sentence_ids.tolist()
#     sentence_ids = remove_pad(sentence_ids)
#     return detokenizer(sentence_ids).replace("", "")\
#            .replace("", "").strip().replace(" .", ".")

In [15]:
# def translate_this_sentence(text: str):
#     'translate the source sentence in string formate into target language'
#     input = torch.tensor([[BOS] + tokenizers[SRC](text) + [EOS]]).to(device)
#     output = translate(model, input)
#     return decode_sentence(detokenizers[TRG], output[0])

# translate_this_sentence("Eine Gruppe von Menschen steht vor einem Iglu.")