### Build dataset

In [None]:
!pip install -q datasets
!pip install sacrebleu

In [None]:
import os
from tokenizers import Tokenizer, pre_tokenizers, trainers, models
from datasets import load_dataset

In [None]:
# Load the 'thainq107/iwslt2015-en-vi' dataset. Later cells read its
# 'train', 'validation' and 'test' splits and the 'en' / 'vi' text columns.
ds = load_dataset('thainq107/iwslt2015-en-vi')

In [None]:
ds

In [None]:
# Word-level tokenizers, one per language, fitted with a shared trainer
# configuration.
trainer = trainers.WordLevelTrainer(
    special_tokens=['<pad>', '<unk>', '<bos>', '<eos>'],
    min_frequency=2,
    vocab_size=15000,
)

tokenizer_en, tokenizer_vi = (
    Tokenizer(models.WordLevel(unk_token='<unk>')) for _ in range(2)
)

# For each language: split on whitespace/punctuation, fit the vocabulary on
# the training split, then write the tokenizer to disk.
for _tok, _column, _path in (
    (tokenizer_en, 'en', 'tokenizer_en.json'),
    (tokenizer_vi, 'vi', 'tokenizer_vi.json'),
):
    _tok.pre_tokenizer = pre_tokenizers.Whitespace()
    _tok.train_from_iterator(ds['train'][_column], trainer)
    _tok.save(_path)

### Encoding

In [None]:
from transformers import PreTrainedTokenizerFast

MAX_LEN = 75

# Special-token names shared by both languages; these are the same strings
# the tokenizers were trained with.
_special_tokens = {
    "unk_token": "<unk>",
    "pad_token": "<pad>",
    "bos_token": "<bos>",
    "eos_token": "<eos>",
}

# Wrap the saved tokenizer files so they can be called like any
# transformers tokenizer (padding, truncation, tensor output).
tokenizer_en = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer_en.json", **_special_tokens
)
tokenizer_vi = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer_vi.json", **_special_tokens
)

def preprocess_function(examples):
    """Convert a batch of sentence pairs into fixed-length token-id lists.

    Args:
        examples: Batch mapping with parallel lists of sentences under the
            keys "en" (source) and "vi" (target).

    Returns:
        Dict with "input_ids" (source ids) and "labels" (target ids framed as
        ``<bos> ... <eos>``), every sequence padded or truncated to MAX_LEN.
    """
    src_encodings = tokenizer_en(
        examples["en"], padding="max_length", truncation=True, max_length=MAX_LEN
    )

    # Truncate the raw target to MAX_LEN - 2 and add the markers afterwards.
    # Adding "<bos>"/"<eos>" as text before truncation silently cut <eos> off
    # every long sentence, so those targets never taught the model to stop.
    tgt_encodings = tokenizer_vi(
        examples["vi"], truncation=True, max_length=MAX_LEN - 2
    )
    bos_id = tokenizer_vi.bos_token_id
    eos_id = tokenizer_vi.eos_token_id
    pad_id = tokenizer_vi.pad_token_id

    labels = []
    for token_ids in tgt_encodings["input_ids"]:
        framed = [bos_id, *token_ids, eos_id]
        labels.append(framed + [pad_id] * (MAX_LEN - len(framed)))

    return {
        "input_ids": src_encodings["input_ids"],
        "labels": labels,
    }

# Apply preprocess_function in batches; its "input_ids" and "labels" outputs
# are what the model and Trainer cells below consume.
preprocessed_ds = ds.map(preprocess_function, batched=True)

### Modeling

#### RNNs

#### Transformer

In [None]:
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

def generate_square_subsequent_mask(sz, device):
    """Build the additive causal attention mask for a length-``sz`` sequence.

    Args:
        sz: Sequence length; the mask is ``(sz, sz)``.
        device: Device on which to allocate the mask.

    Returns:
        A float32 tensor holding ``0.0`` on and below the main diagonal and
        ``-inf`` strictly above it, so position ``i`` cannot attend to any
        position ``j > i``.
    """
    blocked = torch.full(
        (sz, sz), float('-inf'), dtype=torch.float32, device=device
    )
    # triu keeps the strictly-upper triangle and writes zeros everywhere else.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt):
    """Create the attention and padding masks for one source/target batch.

    Args:
        src: Source token ids; dimension 1 is the sequence length.
        tgt: Target token ids; dimension 1 is the sequence length.

    Returns:
        Tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` of
        boolean tensors on ``src``'s device:
        - ``src_mask``: all-False square mask (nothing is hidden).
        - ``tgt_mask``: square causal mask, True where attention is blocked.
        - the padding masks: True wherever the token id equals 0.
    """
    device = src.device
    src_len, tgt_len = src.shape[1], tgt.shape[1]

    # Source positions may all attend to each other.
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)
    # -inf entries become True (blocked), 0.0 entries become False (visible).
    tgt_mask = generate_square_subsequent_mask(tgt_len, device).bool()

    # Token id 0 is the padding id in this notebook.
    return src_mask, tgt_mask, src.eq(0), tgt.eq(0)


class Seq2SeqTransformerConfig(PretrainedConfig):
    """Hyper-parameters consumed by ``Seq2SeqTransformerModel``.

    Args:
        vocab_size_src: Number of rows in the source token embedding.
        vocab_size_tgt: Number of rows in the target token embedding and
            width of the output projection.
        max_seq_length: Number of learned position embeddings per side.
        d_model: Embedding / hidden size of the Transformer.
        num_heads: Attention heads per layer.
        num_layers: Layer count used for both encoder and decoder.
        dropout: Dropout probability passed to the Transformer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        vocab_size_src=10000,
        vocab_size_tgt=10000,
        max_seq_length=50,
        d_model=256,
        num_heads=8,
        num_layers=2,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary and sequence limits.
        self.vocab_size_src = vocab_size_src
        self.vocab_size_tgt = vocab_size_tgt
        self.max_seq_length = max_seq_length

        # Transformer shape and regularisation.
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout

class Seq2SeqTransformerModel(PreTrainedModel):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Token embeddings plus learned position embeddings feed a batch-first
    ``nn.Transformer``; the linear ``generator`` maps decoder states to
    target-vocabulary logits. Token id 0 is treated as padding, both in the
    masks and in the loss.
    """

    config_class = Seq2SeqTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        self.embedding_src = nn.Embedding(
            config.vocab_size_src, config.d_model
        )
        self.embedding_tgt = nn.Embedding(
            config.vocab_size_tgt, config.d_model
        )

        self.position_embedding_src = nn.Embedding(
            config.max_seq_length, config.d_model
        )
        self.position_embedding_tgt = nn.Embedding(
            config.max_seq_length, config.d_model
        )

        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.num_heads,
            num_encoder_layers=config.num_layers,
            num_decoder_layers=config.num_layers,
            dropout=config.dropout,
            batch_first=True
        )

        self.generator = nn.Linear(
            config.d_model, config.vocab_size_tgt
        )
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # Ignore PAD token

    @staticmethod
    def _embed(token_ids, token_embedding, position_embedding):
        """Sum token embeddings with learned embeddings of positions 0..L-1."""
        positions = torch.arange(
            token_ids.shape[1], device=token_ids.device
        ).unsqueeze(0)
        return token_embedding(token_ids) + position_embedding(positions)

    def forward(self, input_ids, labels):
        """Run one teacher-forced pass and compute the training loss.

        Args:
            input_ids: Source token ids, ``(batch, src_len)``.
            labels: Target token ids, ``(batch, tgt_len)``, starting with the
                begin-of-sequence token; 0 marks padding. ``tgt_len`` must be
                at least 2 so that a shifted input/target pair exists.

        Returns:
            Dict with ``'loss'`` (cross-entropy over non-padding targets) and
            ``'logits'`` of shape ``(batch, tgt_len - 1, vocab_size_tgt)``,
            where ``logits[:, i]`` predicts ``labels[:, i + 1]``.
        """
        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored on
        # tokens [1, T). The causal mask keeps the diagonal, so feeding and
        # scoring the same unshifted sequence lets every position see its own
        # answer and the model only learns to copy its input.
        decoder_input_ids = labels[:, :-1]
        target_ids = labels[:, 1:]

        src_embedded = self._embed(
            input_ids, self.embedding_src, self.position_embedding_src
        )
        tgt_embedded = self._embed(
            decoder_input_ids, self.embedding_tgt, self.position_embedding_tgt
        )

        src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask = create_mask(
            input_ids, decoder_input_ids
        )

        outputs = self.transformer(
            src_embedded, tgt_embedded,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            # Without this, cross-attention also attends to padded source slots.
            memory_key_padding_mask=src_key_padding_mask,
        )
        logits = self.generator(outputs)
        # CrossEntropyLoss expects the class dimension second: (B, V, T-1).
        loss = self.loss_fn(logits.permute(0, 2, 1), target_ids)

        return {'loss': loss, 'logits': logits}

    def encode(self, src, src_mask):
        """Encode source ids ``(batch, src_len)`` into encoder memory.

        ``src_mask`` is passed to the encoder as its attention mask; no
        padding mask is applied here.
        """
        src_embedded = self._embed(
            src, self.embedding_src, self.position_embedding_src
        )
        return self.transformer.encoder(src_embedded, src_mask)

    def decode(self, tgt, encoder_output, tgt_mask):
        """Decode target ids ``(batch, tgt_len)`` against ``encoder_output``.

        ``tgt_mask`` is the causal self-attention mask. Returns decoder hidden
        states; apply ``self.generator`` to obtain vocabulary logits.
        """
        tgt_embedded = self._embed(
            tgt, self.embedding_tgt, self.position_embedding_tgt
        )
        return self.transformer.decoder(
            tgt_embedded, encoder_output, tgt_mask
        )

#### Test model

In [None]:
# Size the embeddings from the trained tokenizers, and reuse MAX_LEN so the
# position-embedding table always matches the padded sequence length used
# during preprocessing (previously a second hard-coded 75).
config = Seq2SeqTransformerConfig(
    vocab_size_src=len(tokenizer_en),
    vocab_size_tgt=len(tokenizer_vi),
    max_seq_length=MAX_LEN,
)

model = Seq2SeqTransformerModel(config)

In [None]:
# Smoke test: push the first training pair through the model as a batch of one.
_first_pair = preprocessed_ds['train'][0]
input_ids = torch.tensor(_first_pair['input_ids']).unsqueeze(0)
labels = torch.tensor(_first_pair['labels']).unsqueeze(0)

pred = model(input_ids=input_ids, labels=labels)

#### Trainer

In [None]:
!pip install wandb

In [None]:
# Turn off Weights & Biases reporting through its environment switch.
import os
os.environ['WANDB_DISABLED'] = 'true'

from transformers import Trainer, TrainingArguments

# Training configuration: evaluate, checkpoint and log once per epoch, and
# keep at most one checkpoint on disk (save_total_limit=1). Only the
# evaluation batch size is set here; the training batch size is left at the
# library default.
training_args = TrainingArguments(
    output_dir='./en-vi-machine-translation',
    logging_dir='logs',
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    per_device_eval_batch_size=512,
    num_train_epochs=25,
    learning_rate=2e-5,
    save_total_limit=1,
    # report_to='wandb'  (left off to match WANDB_DISABLED above)
)

# The Trainer feeds the "input_ids" and "labels" columns to model.forward and
# optimises the 'loss' entry of the dict it returns.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=preprocessed_ds['train'],
    eval_dataset=preprocessed_ds['validation']
)

trainer.train()

### Evaluate

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, device='cpu', end_symbol=3):
    """Greedily generate target token ids for a single source sequence.

    Args:
        model: Object exposing ``encode(src, src_mask)``,
            ``decode(tgt, memory, tgt_mask)`` and ``generator(hidden)``.
        src: Source token ids, shape ``(1, src_len)`` (batch of one).
        src_mask: Attention mask handed to ``model.encode``.
        max_len: Maximum length of the returned sequence, start token included.
        start_symbol: Id of the token that seeds the output.
        device: Device on which decoding runs.
        end_symbol: Id that terminates decoding once produced (default 3,
            the ``<eos>`` id used in this notebook).

    Returns:
        Long tensor of shape ``(1, n)`` with ``n <= max_len``, beginning with
        ``start_symbol`` and ending with ``end_symbol`` if it was generated.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Pure inference: skip autograd bookkeeping for every decoding step.
    with torch.no_grad():
        # The encoder output is fixed for the whole loop; move it once.
        memory = model.encode(src, src_mask).to(device)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            steps = ys.size(1)
            # Boolean causal mask: True strictly above the diagonal = blocked.
            tgt_mask = torch.ones((steps, steps), device=device).triu(diagonal=1).bool()

            out = model.decode(ys, memory, tgt_mask)
            # Only the newest position decides the next token.
            scores = model.generator(out[:, -1, :])
            next_word = int(scores.argmax(dim=1)[-1])

            ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=1)
            if next_word == end_symbol:
                break
    return ys

def translate(model, src_sentence, device):
    """Translate one English sentence to Vietnamese with greedy decoding.

    Args:
        model: Trained translation model; it is switched to eval mode.
        src_sentence: Source sentence as plain text.
        device: Device on which to run the model.

    Returns:
        The decoded target sentence, without the ``<bos>``/``<eos>``/``<pad>``
        marker tokens.
    """
    model.eval()
    input_ids = tokenizer_en([src_sentence], return_tensors='pt')['input_ids'].to(device)
    num_tokens = input_ids.shape[1]
    # All-False mask: every source token may attend to every other.
    src_mask = torch.zeros((num_tokens, num_tokens), dtype=torch.bool, device=device)

    with torch.no_grad():
        tgt_tokens = greedy_decode(
            model, input_ids, src_mask,
            max_len=num_tokens + 5,  # allow the output to run a little longer than the input
            start_symbol=tokenizer_vi.bos_token_id,
            device=device,
        )

    # Strip sequence markers before decoding; leaving "<bos>"/"<eos>" in the
    # text adds tokens that can never match a reference translation.
    markers = {
        tokenizer_vi.bos_token_id,
        tokenizer_vi.eos_token_id,
        tokenizer_vi.pad_token_id,
    }
    token_ids = [t for t in tgt_tokens[0].detach().cpu().tolist() if t not in markers]
    return tokenizer_vi.decode(token_ids)
# Quick sanity check on one sentence, run on whatever device the model is on.
translate(model, 'i go to school', model.device)

In [None]:
# evaluate on test set
from tqdm import tqdm
import sacrebleu

# Translate every test sentence and collect it next to its reference.
pred_sentences, tgt_sentences = [], []
for sample in tqdm(ds['test']):
    # Use the model's own device instead of a hard-coded 'cuda:0' so the cell
    # also runs on CPU-only machines.
    pred_sentences.append(translate(model, sample['en'], device=model.device))
    tgt_sentences.append(sample['vi'])

# Corpus-level BLEU with a single reference stream; force=True is kept from
# the original call.
bleu_score = sacrebleu.corpus_bleu(pred_sentences, [tgt_sentences], force=True)
bleu_score