In [None]:
!pip install bert_score --quiet
!pip install evaluate --quiet
!pip install rouge_score --quiet
!pip install git+https://github.com/google-research/bleurt.git --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m37.9 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for BLEURT (setup.py) ... [?25l[?25hdone


In [None]:
# Mount Google Drive (Colab only) so the project folder and its corpus are reachable.
from google.colab import drive

drive.mount("/content/drive")

# Work from the project folder: later cells build data paths relative to os.getcwd().
%cd "/content/drive/MyDrive/Implement Classic Papers/Attention Is All You Need"

Mounted at /content/drive
/content/drive/MyDrive/Implement Classic Papers/Attention Is All You Need


In [None]:
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from typing import List, Callable, Optional, Literal
from copy import deepcopy as dc
from transformers import AutoTokenizer
from tqdm import tqdm

# Report whether a GPU is visible (a bare expression in the middle of a cell is never displayed).
print("CUDA available:", torch.cuda.is_available())

MAIN_DIR = os.getcwd()
DATA_DIR = os.path.join(MAIN_DIR, "eng_vie")

# Parallel corpus: line i of `vi_sents` is paired with line i of `en_sents`.
with open(os.path.join(DATA_DIR, "vi_sents"), "r", encoding="utf-8") as f:
    vie_sentences = f.readlines()

with open(os.path.join(DATA_DIR, "en_sents"), "r", encoding="utf-8") as f:
    eng_sentences = f.readlines()

# Pairs are matched purely by line index, so a length mismatch would misalign or drop pairs.
assert len(vie_sentences) == len(eng_sentences), (
    f"Parallel corpus is misaligned: {len(vie_sentences)} Vietnamese vs "
    f"{len(eng_sentences)} English lines"
)

# Model Architecture

In [None]:
def clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Every copy owns its own parameters, so the clones do not share weights with each
    other or with ``module``.
    """
    return nn.ModuleList([dc(module) for _ in range(N)])

def get_masked_attention_mask(seq_len: int) -> torch.Tensor:
    """Build the causal ("subsequent position") mask for decoder self-attention.

    Args:
        seq_len: Target sequence length.

    Returns:
        Boolean tensor of shape (1, seq_len, seq_len) whose entry [0, i, j] is True when
        position i may attend to position j, i.e. when j <= i.
    """
    # Strictly-upper-triangular ones mark the future positions; invert to get allowed positions.
    future_positions = torch.ones(size=(1, seq_len, seq_len), dtype=torch.uint8).triu(diagonal=1)
    return future_positions == 0

In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention ("Attention Is All You Need", Sec. 3.2).

    Args:
        d_model: Feature size of the inputs and of the output.
        h: Number of attention heads.
        d_k: Per-head query/key size. Defaults to ``d_model // h`` (``d_model`` must then
            be divisible by ``h``).
        d_v: Per-head value size. Defaults to ``d_k``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        d_model: int = 512,
        h: int = 8,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        dropout: float=0.1
    ):
        super(MultiHeadedAttention, self).__init__()
        self.h = h
        self.d_model = d_model

        if not d_k:
            assert self.d_model % h == 0, "d_model must be divisible by h when d_k is not given"
            self.d_k = self.d_model // h
        else:
            self.d_k = d_k

        # Queries and keys must share a size for the dot product.
        self.d_q = self.d_k
        self.d_v = d_v or self.d_k

        self.Q = nn.Linear(self.d_model, self.d_q * h)
        self.K = nn.Linear(self.d_model, self.d_k * h)
        self.V = nn.Linear(self.d_model, self.d_v * h)
        self.O = nn.Linear(self.d_v * h, self.d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.attn = None  # attention weights of the last forward pass, kept for visualisation

    def self_attention(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None, dropout=None
    ):
        """Scaled dot-product attention over already head-split tensors.

        Args:
            query: (batch_size, h, q_len, d_k).
            key: (batch_size, h, kv_len, d_k).
            value: (batch_size, h, kv_len, d_v).
            mask: Broadcastable to (batch_size, h, q_len, kv_len); positions where the mask
                is 0/False get a score of -1e9 (i.e. ~zero attention weight).
            dropout: Optional module applied to the attention weights.

        Returns:
            ``(output, attention_weights)`` with shapes (batch_size, h, q_len, d_v) and
            (batch_size, h, q_len, kv_len).
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # (bs, h, q_len, kv_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, value=-1e9)
        attention_weights = scores.softmax(dim=-1)
        if dropout is not None:
            attention_weights = dropout(attention_weights)
        return torch.matmul(attention_weights, value), attention_weights

    def _split_heads(self, x: torch.Tensor, batch_size: int, head_dim: int) -> torch.Tensor:
        """Reshape (batch_size, seq_len, h * head_dim) into (batch_size, h, seq_len, head_dim)."""
        return x.view(batch_size, -1, self.h, head_dim).transpose(1, 2)

    def forward(
        self, query, key, value, mask=None
    ):
        """Project, split into heads, attend, and merge heads back.

        Shapes:
            Self-attention (encoder, decoder masked self-attention):
                query = key = value = (batch_size, seq_len, d_model).
            Encoder-decoder (cross) attention:
                query = (batch_size, tgt_len, d_model), key = value = (batch_size, src_len, d_model).
            mask: 3-D mask with a leading batch (or 1) dimension, e.g. a padding mask
                (batch_size, 1, kv_len) or a causal mask (batch_size or 1, tgt_len, tgt_len);
                a head dimension is inserted here so it broadcasts over all heads.

        Returns:
            (batch_size, q_len, d_model).
        """
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask for every head

        batch_size = query.size(0)

        query = self._split_heads(self.Q(query), batch_size, self.d_q)  # (bs, h, q_len, d_q)
        key = self._split_heads(self.K(key), batch_size, self.d_k)  # (bs, h, kv_len, d_k)
        value = self._split_heads(self.V(value), batch_size, self.d_v)  # (bs, h, kv_len, d_v)

        x, self.attn = self.self_attention(query, key, value, mask, self.dropout)  # (bs, h, q_len, d_v)
        # Merge heads: (bs, h, q_len, d_v) -> (bs, q_len, h * d_v)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_v)
        return self.O(x)  # (bs, q_len, d_model)

class PointWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: ``linear2(dropout(activation(linear1(x))))``.

    Args:
        d_model: Input/output feature size.
        d_ff: Hidden size. Defaults to ``4 * d_model``.
        dropout: Dropout probability applied after the activation.
        activation: Element-wise non-linearity. Defaults to GELU; the original paper uses
            ReLU (pass ``F.relu`` to reproduce it).
    """

    def __init__(
        self,
        d_model: int=512,
        d_ff: Optional[int] = None,
        dropout: float=0.1,
        activation: Callable = F.gelu,
    ):
        super(PointWiseFeedForward, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.linear1 = nn.Linear(self.d_model, self.d_ff)
        self.linear2 = nn.Linear(self.d_ff, self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.activation = activation

    def forward(self, x):
        """Apply the network independently at each position of x: (batch_size, seq_len, d_model)."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.linear2(x)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Uses the unbiased standard deviation (``Tensor.std`` default) and adds ``eps`` to the
    standard deviation rather than inside the square root of the variance, so results
    differ slightly from ``nn.LayerNorm``.
    """

    def __init__(self, d_model: int, eps: float=1e-8):
        super().__init__()
        self.scaling_factor = nn.Parameter(torch.ones(d_model))  # per-feature gain
        self.bias_factor = nn.Parameter(torch.zeros(d_model))  # per-feature shift
        self.eps = eps

    def forward(self, x):
        """Normalise x of shape (..., d_model) along its last dimension."""
        centred = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        normalised = centred / spread
        return self.scaling_factor * normalised + self.bias_factor

class SubLayerConnection(nn.Module):
    """Post-norm residual wrapper: ``LayerNorm(x + Dropout(sub_layer(x)))``.

    Args:
        d_model: Feature size normalised by the layer norm.
        eps: Stability constant passed to ``LayerNorm``.
        dropout: Dropout probability applied to the sub-layer output before the residual sum.
    """

    def __init__(self, d_model: int, eps: float=1e-6, dropout: float=0.1):
        super().__init__()
        self.d_model = d_model
        self.layer_norm = LayerNorm(d_model=d_model, eps=eps)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor, sub_layer: Callable):
        """Apply ``sub_layer`` to ``x``, add the residual, then normalise."""
        residual = x
        transformed = self.dropout(sub_layer(x))
        return self.layer_norm(residual + transformed)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each in a residual + layer-norm wrapper."""

    def __init__(self, d_model: int=512, h: int=8):
        super().__init__()
        self.d_model = d_model
        # Creation order is kept fixed so parameter initialisation order stays the same.
        self.multihead_attention = MultiHeadedAttention(d_model=d_model, h=h)
        self.feed_forward = PointWiseFeedForward(d_model=d_model)
        self.att_sublayer = SubLayerConnection(d_model=d_model)
        self.ff_sublayer = SubLayerConnection(d_model=d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]=None):
        """x: (batch_size, seq_len, d_model); mask: optional source padding mask."""
        def attend_to_self(inp):
            return self.multihead_attention(query=inp, key=inp, value=inp, mask=mask)

        attended = self.att_sublayer(x, attend_to_self)
        return self.ff_sublayer(attended, self.feed_forward)

class Encoder(nn.Module):
    """Stack of ``n`` independent EncoderLayer blocks applied in sequence.

    Input and output are both (batch_size, seq_len, d_model).
    """

    def __init__(self, n: int=6, h: int=8, d_model: int=512):
        super().__init__()
        self.N = n
        self.H = h
        self.d_model = d_model
        self.layers = clones(EncoderLayer(d_model=d_model, h=h), n)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]=None):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer is wrapped in a residual connection followed by layer normalisation.
    """

    def __init__(self, d_model: int=512, h: int=8):
        super().__init__()
        self.d_model = d_model
        # Creation order is kept fixed so parameter initialisation order stays the same.
        self.masked_multihead_attention = MultiHeadedAttention(d_model=d_model, h=h)
        self.multihead_attention = MultiHeadedAttention(d_model=d_model, h=h)
        self.feed_forward = PointWiseFeedForward(d_model=d_model)
        self.masked_att_sublayer = SubLayerConnection(d_model=d_model)
        self.att_sublayer = SubLayerConnection(d_model=d_model)
        self.ff_sublayer = SubLayerConnection(d_model=d_model)

    def forward(
        self,
        decoder_input: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor]=None,
        tgt_mask: Optional[torch.Tensor]=None
    ):
        """decoder_input: (bs, tgt_len, d_model); encoder_output: (bs, src_len, d_model)."""
        def attend_to_prefix(inp):
            # Causal self-attention over the target prefix.
            return self.masked_multihead_attention(query=inp, key=inp, value=inp, mask=tgt_mask)

        def attend_to_source(inp):
            # Queries from the decoder, keys/values from the encoder memory.
            return self.multihead_attention(
                query=inp, key=encoder_output, value=encoder_output, mask=src_mask
            )

        hidden = self.masked_att_sublayer(decoder_input, attend_to_prefix)
        hidden = self.att_sublayer(hidden, attend_to_source)
        return self.ff_sublayer(hidden, self.feed_forward)

class Decoder(nn.Module):
    """Stack of ``n`` independent DecoderLayer blocks applied in sequence.

    Output has the same shape as the target input: (batch_size, tgt_len, d_model).
    """

    def __init__(self, n: int=6, h: int=8, d_model: int=512):
        super().__init__()
        self.N = n
        self.H = h
        self.d_model = d_model
        self.layers = clones(DecoderLayer(d_model=d_model, h=h), n)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor]=None,
        tgt_mask: Optional[torch.Tensor]=None
    ):
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return hidden

class TokenEmbeddings(nn.Module):
    """Token-embedding lookup scaled by sqrt(d_model) (paper Sec. 3.4).

    Args:
        d_model: Embedding size.
        vocab_size: Number of rows in the embedding table.
        padding_idx: Token id whose embedding receives no gradient. Defaults to 0 (the
            previous hard-coded value). NOTE(review): the Vietnamese BARTpho tokenizer
            presumably uses a different pad id than 0 — pass ``tokenizer.pad_token_id``
            explicitly for each side; confirm.
    """

    def __init__(self, d_model: int, vocab_size: int, padding_idx: Optional[int] = 0):
        super(TokenEmbeddings, self).__init__()
        self.d_model, self.vocab_size = d_model, vocab_size
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx,
        )

    def forward(self, inputs: torch.Tensor):
        """Map ids (batch_size, seq_len) to scaled embeddings (batch_size, seq_len, d_model)."""
        return self.embeddings(inputs) * math.sqrt(self.d_model)

class Generator(nn.Module):
    """Output head: linear projection from d_model to vocabulary log-probabilities."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """(..., d_model) -> (..., vocab_size) log-probabilities over the last dimension."""
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (paper Sec. 3.5), followed by dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Odd ``d_model`` is supported: the last (even-indexed) column only gets a sine term.

    Args:
        d_model: Feature size of the embeddings the encoding is added to.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length supported.
    """

    def __init__(self, d_model: int, dropout: float=0.1, max_len: int=5000):
        super(SinusoidalPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(size=(max_len, d_model))  # (max_len, d_model)
        const_term = math.log(10000) / d_model
        # 1 / 10000^(2i / d_model) for every even column index 2i: (ceil(d_model / 2),)
        div_terms = torch.exp(-torch.arange(0, d_model, 2) * const_term)
        positions = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        angles = positions * div_terms  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; drop the extra angle when d_model is odd.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model); buffers get no gradient

    def forward(self, x: torch.Tensor):
        """Add the encoding to x of shape (batch_size, seq_len, d_model), seq_len <= max_len."""
        assert x.size(-1) == self.d_model, "last dimension of x must equal d_model"
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

class LearnablePositionalEncoding(nn.Module):
    """Learned absolute positional embeddings added to the input, followed by dropout.

    Args:
        d_model: Feature size of the embeddings the encoding is added to.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of learnable positions (longest supported sequence).
    """

    def __init__(self, d_model: int, dropout: float=0.1, max_len: int=5000):
        super(LearnablePositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.embeddings = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add position embeddings to x of shape (batch_size, seq_len, d_model)."""
        assert x.size(2) == self.d_model
        # Build the position ids on the embedding table's device; a CPU index tensor
        # would fail once the module is moved to GPU.
        positions = torch.arange(0, x.size(1), 1, device=self.embeddings.weight.device)
        x = x + self.embeddings(positions)
        return self.dropout(x)

# `encode` and `decode` are exposed separately so autoregressive decoding can encode the
# source once and reuse that memory for every new target token; `forward` chains both so
# training backpropagates end-to-end from the decoder output into the encoder.
class TransformersSeqToSeq(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Args:
        n: Number of encoder layers and of decoder layers.
        h: Number of attention heads.
        d_model: Model / embedding dimension.
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (unused when ``share_embeddings`` is True).
        share_embeddings: Reuse the source embedding table for the target side; the output
            projection then has ``src_vocab_size`` classes.
        max_tokens: Default cap on generated tokens for the decoding methods.
    """

    def __init__(
        self,
        n: int=6,
        h: int=8,
        d_model: int=512,
        src_vocab_size: int=30000,
        tgt_vocab_size: int=30000,
        share_embeddings: bool=False,
        max_tokens: int=4096
    ):
        super(TransformersSeqToSeq, self).__init__()
        self.n, self.h, self.d_model = n, h, d_model
        self.encoder = Encoder(n=n, h=h, d_model=d_model)
        self.decoder = Decoder(n=n, h=h, d_model=d_model)
        self.src_embeddings = TokenEmbeddings(d_model=d_model, vocab_size=src_vocab_size)
        if share_embeddings:
            self.tgt_embeddings = self.src_embeddings
            self.generator = Generator(d_model=d_model, vocab_size=src_vocab_size)
        else:
            self.tgt_embeddings = TokenEmbeddings(d_model=d_model, vocab_size=tgt_vocab_size)
            self.generator = Generator(d_model=d_model, vocab_size=tgt_vocab_size)

        self.positional_encoding = SinusoidalPositionalEncoding(d_model=d_model, max_len=5000)
        # Xavier-uniform initialisation for every parameter with more than one dimension
        # (weight matrices and embedding tables); biases keep their default init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.max_tokens = max_tokens

    def encode(
        self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None
    ):
        """Encode source ids (batch_size, src_len) into memory (batch_size, src_len, d_model)."""
        src_token_embeddings = self.src_embeddings(src)
        src_embeddings = self.positional_encoding(src_token_embeddings)
        return self.encoder(x=src_embeddings, mask=src_mask)

    def decode(
        self, tgt, memory, src_mask, tgt_mask
    ):
        """Decode target ids (batch_size, tgt_len) against encoder memory (batch_size, src_len, d_model).

        Returns:
            Log-probabilities over the target vocabulary, (batch_size, tgt_len, vocab_size).
        """
        tgt_token_embeddings = self.tgt_embeddings(tgt)
        tgt_embeddings = self.positional_encoding(tgt_token_embeddings)
        decoder_output = self.decoder(
            x=tgt_embeddings, encoder_output=memory, src_mask=src_mask, tgt_mask=tgt_mask
        )  # (batch_size, tgt_len, d_model)
        return self.generator(decoder_output)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None
    ):
        """Teacher-forced pass: src (batch_size, src_len), tgt (batch_size, tgt_len) -> log-probs."""
        memory = self.encode(src, src_mask)
        return self.decode(tgt=tgt, memory=memory, src_mask=src_mask, tgt_mask=tgt_mask)

    @torch.no_grad()
    def greedy_decode(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_tokens: Optional[int] = None,
        cls_token_id: int = 101,
        eos_token_id: int = 102,
    ):
        """Greedily generate target ids for a batch of source sequences.

        Args:
            src: Source ids, (batch_size, src_len) or a single (src_len,) sequence.
            src_mask: Source padding mask, e.g. (batch_size, 1, src_len).
            max_tokens: Maximum number of generated tokens (defaults to ``self.max_tokens``).
            cls_token_id: Id fed as the first decoder input.
            eos_token_id: Id that terminates a sequence.

        Returns:
            One list of token ids per batch row, starting with ``cls_token_id`` and cut
            right after the first ``eos_token_id`` when one was generated.

        Note: switches the model to eval mode and does not restore the previous mode.
        """
        device = next(self.parameters()).device
        self.eval()
        max_tokens = max_tokens or self.max_tokens

        if src.ndim == 1:
            src = src.unsqueeze(0)

        batch_size = src.size(0)
        encoder_output = self.encode(src, src_mask)

        tgt = torch.full((batch_size, 1), cls_token_id, dtype=torch.long, device=device)
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_tokens):
            if is_finished.all():
                break
            tgt_mask = get_masked_attention_mask(tgt.size(-1)).to(device)
            last_log_probs = self.decode(
                tgt=tgt, memory=encoder_output, src_mask=src_mask, tgt_mask=tgt_mask
            )[:, -1, :]
            next_token_id = last_log_probs.argmax(dim=-1)
            tgt = torch.cat([tgt, next_token_id.unsqueeze(-1)], dim=-1)
            # Finished rows keep decoding (batched), but are trimmed at their EOS below.
            is_finished = is_finished | (next_token_id == eos_token_id)

        preds = []
        for row in tgt.tolist():
            if eos_token_id in row:
                row = row[: row.index(eos_token_id) + 1]
            preds.append(row)
        return preds

    @torch.no_grad()
    def beam_search_decode(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_tokens: Optional[int] = None,
        cls_token_id: int = 101, eos_token_id: int = 102,
        beam_width: int = 4
    ):
        """Beam-search decode a single source sequence.

        Scores are summed log-probabilities (no length normalisation). Beams that have
        produced ``eos_token_id`` are kept unchanged instead of being extended further.

        Args:
            src: Source ids, (src_len,) or (1, src_len). Only one sequence is supported.
            src_mask: Source padding mask for that sequence, e.g. (1, 1, src_len).
            max_tokens: Maximum number of generated tokens (defaults to ``self.max_tokens``).
            cls_token_id: Id fed as the first decoder input.
            eos_token_id: Id that terminates a beam.
            beam_width: Number of hypotheses kept per step.

        Returns:
            1-D tensor of token ids of the highest-scoring beam (starts with ``cls_token_id``).

        Raises:
            ValueError: if ``src`` holds more than one sequence.

        Note: switches the model to eval mode and does not restore the previous mode.
        """
        device = next(self.parameters()).device
        self.eval()
        max_tokens = max_tokens or self.max_tokens

        if src.ndim == 1:
            src = src.unsqueeze(0)
        if src.size(0) != 1:
            raise ValueError(
                "beam_search_decode supports a single source sequence; got batch size {}".format(src.size(0))
            )

        encoder_output = self.encode(src, src_mask)

        beams = [(torch.tensor([cls_token_id], device=device), 0.0)]  # (sequence, summed log-prob)

        for _ in range(max_tokens):
            all_candidates = []
            for seq, score in beams:
                if seq[-1].item() == eos_token_id:
                    # Finished hypothesis: carry it forward unchanged.
                    all_candidates.append((seq, score))
                    continue
                tgt = seq.unsqueeze(0)  # add batch dimension
                tgt_mask = get_masked_attention_mask(tgt.size(-1)).to(device)
                last_log_probs = self.decode(
                    tgt, encoder_output, src_mask=src_mask, tgt_mask=tgt_mask
                )[0, -1, :]

                # Only each beam's own top-`beam_width` tokens can enter the global top
                # `beam_width`, so there is no need to expand over the whole vocabulary.
                k = min(beam_width, last_log_probs.size(-1))
                top_scores, top_ids = last_log_probs.topk(k)
                for token_score, token_id in zip(top_scores.tolist(), top_ids.tolist()):
                    all_candidates.append(
                        (torch.cat([seq, seq.new_tensor([token_id])]), score + token_score)
                    )

            all_candidates.sort(key=lambda candidate: candidate[1], reverse=True)
            beams = all_candidates[:beam_width]

            if all(seq[-1].item() == eos_token_id for seq, _ in beams):
                break

        return max(beams, key=lambda candidate: candidate[1])[0]

class CustomCrossEntropyLoss(nn.Module):
    """Summed, label-smoothed token loss that ignores a padding index.

    Two modes:
      * ``use_kl_divergence=False``: ``nn.CrossEntropyLoss`` with built-in label smoothing.
        The model already outputs log-probabilities; since log_softmax is idempotent,
        applying CrossEntropyLoss on top of them yields the same value as on logits.
      * ``use_kl_divergence=True``: KL divergence against an explicit smoothed target
        distribution. It gives ``1 - label_smoothing`` to the true class and spreads
        ``label_smoothing`` over the remaining classes, excluding the padding class.
        Rows whose label is the padding index get an all-zero target.

    Args:
        label_smoothing: Smoothing mass.
        ignore_index: Label id (padding) that contributes no loss.
        use_kl_divergence: Select the KL-divergence formulation.
    """

    def __init__(
        self,
        label_smoothing: float=0.1,
        ignore_index: int=0,
        use_kl_divergence: bool=False
    ):
        super(CustomCrossEntropyLoss, self).__init__()
        self.use_kl_divergence = use_kl_divergence
        self.ignore_index = ignore_index
        self.confidence = 1 - label_smoothing
        self.smoothing = label_smoothing
        self.true_dist = None  # last smoothed target (KL mode only), kept for inspection
        if use_kl_divergence:
            self.criterion = nn.KLDivLoss(reduction="sum")
        else:
            self.criterion = nn.CrossEntropyLoss(
                label_smoothing=label_smoothing,
                ignore_index=ignore_index,
                reduction="sum"
            )

    def forward(self, preds, labels):
        """preds: (num_tokens, num_classes) log-probabilities; labels: (num_tokens,) class ids."""
        if not self.use_kl_divergence:
            return self.criterion(preds, labels)

        num_classes = preds.size(1)
        # Smoothing mass is shared by every class except the true class and the padding class.
        negative_prob = self.smoothing / (num_classes - 2)
        true_dist = torch.full_like(preds.detach(), negative_prob)
        true_dist.scatter_(1, labels.unsqueeze(1), self.confidence)
        true_dist[:, self.ignore_index] = 0
        # Padding positions contribute nothing.
        true_dist.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0.0)
        self.true_dist = true_dist
        return self.criterion(preds, true_dist)

# Prepare Training Data

In [None]:
class VieToEngDataset(Dataset):
    """Vietnamese→English sentence corpus.

    Items are ``(vietnamese, english)`` tuples when English sentences are provided (and
    non-empty), otherwise bare Vietnamese strings (source-only / inference mode).
    """

    def __init__(
        self,
        vie_sentences: List[str],
        en_sentences: Optional[List[str]] = None,
    ):
        self.vie_sentences = vie_sentences
        self.en_sentences = en_sentences

    def __len__(self):
        return len(self.vie_sentences)

    def __getitem__(self, idx):
        source = self.vie_sentences[idx]
        if not self.en_sentences:
            return source
        return source, self.en_sentences[idx]

class VieToEngDataManager:
    """Tokenises Vietnamese→English data and builds (train, test) DataLoaders.

    Args:
        vie_sentences: Vietnamese (source) sentences.
        en_sentences: English (target) sentences aligned with ``vie_sentences``; omit for
            source-only data.
        vie_tokenizer: Source tokenizer (defaults to ``vinai/bartpho-syllable``).
        en_tokenizer: Target tokenizer (defaults to ``bert-large-uncased``).
        max_length: Truncation length passed to both tokenizers.
        seed: Seed of the generator used for the train/test split.
    """

    def __init__(
        self,
        vie_sentences: List[str],
        en_sentences: Optional[List[str]] = None,
        vie_tokenizer: Optional[Callable] = None,
        en_tokenizer: Optional[Callable] = None,
        max_length: int = 4096,
        seed: int = 42
    ):
        self.dataset = VieToEngDataset(
            vie_sentences=vie_sentences, en_sentences=en_sentences
        )
        self.train_dataset = None
        self.test_dataset = None
        self.max_length = max_length
        self.vie_tokenizer = vie_tokenizer or AutoTokenizer.from_pretrained("vinai/bartpho-syllable")
        self.en_tokenizer = en_tokenizer or AutoTokenizer.from_pretrained("bert-large-uncased")
        self.generator = torch.Generator().manual_seed(seed)

    def _collate_fn(self, batch_data: List):
        """Tokenise a batch of samples into model inputs.

        Samples are ``(vie, en)`` tuples for parallel data or bare strings for source-only data.

        Returns a dict with:
            src_input_ids: (bs, src_len); src_mask: (bs, 1, src_len) padding mask.
        and, when targets are present:
            tgt_input_ids: target ids without the last position (decoder input).
            tgt_labels: target ids shifted left by one (next-token labels).
            tgt_mask: (bs, tgt_len, tgt_len) combination of target padding and causal masks.
            ntokens: number of non-padding label tokens in the batch.
        """
        has_target = len(batch_data) > 0 and isinstance(batch_data[0], (tuple, list))
        if has_target:
            src = [sample[0] for sample in batch_data]
            tgt = [sample[1] for sample in batch_data]
        else:
            src = list(batch_data)
            tgt = None

        tokenized_src = self.vie_tokenizer(
            src, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length
        )
        batch = {
            "src_input_ids": tokenized_src["input_ids"],
            "src_mask": tokenized_src["attention_mask"].unsqueeze(-2),
        }
        if tgt is None:
            # Source-only (inference) batch: there are no target tensors to build.
            return batch

        tokenized_tgt = self.en_tokenizer(
            tgt, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length
        )
        tgt_input_ids = tokenized_tgt["input_ids"][:, :-1]
        tgt_labels = tokenized_tgt["input_ids"][:, 1:]
        masked_att_mask = get_masked_attention_mask(tgt_input_ids.size(-1))
        tgt_mask = tokenized_tgt["attention_mask"][:, :-1].unsqueeze(-2) & masked_att_mask
        ntokens = tokenized_tgt["attention_mask"][:, 1:].sum()

        batch.update(
            tgt_input_ids=tgt_input_ids,
            tgt_labels=tgt_labels,
            tgt_mask=tgt_mask,
            ntokens=ntokens,
        )
        return batch

    def get_data_loader(
        self,
        dataset: Optional[Dataset] = None,
        batch_size: Optional[int] = 32,
        shuffle: bool = True,
        test_split: Optional[float] = 0.2
    ):
        """Build DataLoader(s) over ``dataset`` (defaults to the manager's own dataset).

        With a truthy ``test_split`` the dataset is split reproducibly (seeded generator) and
        ``(train_loader, test_loader)`` is returned; the train loader drops the last partial
        batch and follows ``shuffle``, the test loader never shuffles. Otherwise a single
        loader over the whole dataset is returned.
        """
        if dataset is None:
            dataset = self.dataset
        if test_split:
            self.train_dataset, self.test_dataset = torch.utils.data.random_split(
                dataset,
                lengths=[1 - test_split, test_split],
                generator=self.generator
            )
            train_dataloader = DataLoader(
                self.train_dataset,
                batch_size,
                shuffle=shuffle,
                drop_last=True,
                collate_fn=self._collate_fn
            )
            test_dataloader = DataLoader(
                self.test_dataset,
                batch_size,
                shuffle=False,
                drop_last=False,
                collate_fn=self._collate_fn
            )
            return train_dataloader, test_dataloader

        return DataLoader(
            dataset,
            batch_size,
            shuffle=shuffle,
            drop_last=False,
            collate_fn=self._collate_fn
        )

In [None]:
# Target (English) and source (Vietnamese) tokenizers.
en_tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
vie_tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-syllable")

vie2en_datamanager = VieToEngDataManager(
    vie_sentences = vie_sentences,
    en_sentences = eng_sentences,
    vie_tokenizer = vie_tokenizer,
    en_tokenizer = en_tokenizer
)

# The data manager already wraps the parallel corpus in a VieToEngDataset; reuse it rather
# than building a second, identical copy.
vie2en_dataset = vie2en_datamanager.dataset

In [None]:
# Metric objects from Hugging Face `evaluate`, loaded once and reused for every evaluation.
BLEU_SCORER = evaluate.load("bleu")
ROUGE_SCORER = evaluate.load('rouge')

def evaluate_translation(
    references: List[str],
    predictions: List[str]
):
    """Score predicted translations against references with BLEU and ROUGE.

    Prints the corpus BLEU value and the full ROUGE result dict.

    Returns:
        ``(bleu_result, rouge_result)`` — the raw dicts returned by the two metrics.
    """
    metric_results = {}
    for metric_name, scorer in (("bleu", BLEU_SCORER), ("rouge", ROUGE_SCORER)):
        metric_results[metric_name] = scorer.compute(predictions=predictions, references=references)

    print("Bleu Score:", metric_results["bleu"]["bleu"])
    print("Rouge Score:", metric_results["rouge"])

    return metric_results["bleu"], metric_results["rouge"]

class TrainState:
    """Counters carried across training and evaluation epochs.

    Defaults live on the class; incrementing an attribute on an instance
    (e.g. ``state.step += 1``) creates a per-instance value, so separate
    TrainState objects never share counts.
    """

    # Progress counters
    step: int = 0  # batches processed
    n_updates: int = 0  # optimizer updates performed
    samples: int = 0  # training sequences seen
    tokens: int = 0  # target tokens seen

    # Early-stopping bookkeeping
    best_loss: float = math.inf
    patience_counter: int = 0

def run_train_epoch(
    data_iter: DataLoader,
    model: nn.Module,
    loss_criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    train_state,
    tgt_tokenizer,
    accumulation_no: int=1,
    track_gradients: bool = False,
):
    """Run one training epoch with gradient accumulation.

    Gradients of ``accumulation_no`` consecutive batches are accumulated before each
    optimizer update. The LR scheduler advances once per optimizer update, so warm-up is
    measured in updates. A trailing partial accumulation window is flushed on the last
    batch, so its gradients do not leak into the next epoch.

    Args:
        data_iter: DataLoader yielding the dicts built by the data manager's collate function
            (``src_input_ids``, ``src_mask``, ``tgt_input_ids``, ``tgt_mask``, ``tgt_labels``, ``ntokens``).
        model: Seq2seq model returning target log-probabilities (batch_size, tgt_len, vocab_size).
        loss_criterion: Loss expected to be summed over tokens; it is divided by ``ntokens`` here.
        optimizer: Optimizer stepped on every accumulation boundary.
        scheduler: LR scheduler stepped together with the optimizer.
        train_state: Object whose ``step``, ``n_updates``, ``samples`` and ``tokens`` counters
            are updated in place.
        tgt_tokenizer: Target tokenizer. Kept for interface compatibility; the vocabulary
            size is read from the model output instead.
        accumulation_no: Number of batches per optimizer update.
        track_gradients: Print per-module mean absolute gradients for the epoch's last update.

    Returns:
        ``(loss_per_token, train_state, per_batch_losses)`` where ``loss_per_token`` is a float.
    """
    model.train()
    total_tokens = 0
    total_loss = 0.0
    n_updates = 0
    all_losses = []
    num_batches = len(data_iter)

    device = next(model.parameters()).device

    for i, batch in tqdm(enumerate(data_iter), total=num_batches):
        src = batch["src_input_ids"].to(device)
        tgt = batch["tgt_input_ids"].to(device)
        src_mask = batch["src_mask"].to(device)
        tgt_mask = batch["tgt_mask"].to(device)
        labels = batch["tgt_labels"].to(device)
        ntokens = int(batch["ntokens"])  # non-padding label tokens in this batch

        log_probs = model(
            src=src, tgt=tgt, src_mask=src_mask, tgt_mask=tgt_mask
        )

        batch_loss = loss_criterion(
            log_probs.contiguous().view(-1, log_probs.size(-1)),
            labels.contiguous().view(-1)
        )

        loss = batch_loss / ntokens
        loss.backward()

        # Update after every `accumulation_no` batches, counting from 1 so the first update
        # really accumulates `accumulation_no` batches, and always on the last batch.
        is_last_batch = i + 1 == num_batches
        if (i + 1) % accumulation_no == 0 or is_last_batch:
            if track_gradients and is_last_batch:
                # Inspect gradients before zero_grad below clears them.
                calculate_gradient_norm(model)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            n_updates += 1
            train_state.n_updates += 1

        train_state.step += 1
        train_state.samples += src.size(0)
        train_state.tokens += ntokens

        loss_value = loss.item()
        all_losses.append(loss_value)
        total_loss += batch_loss.item()
        total_tokens += ntokens

        if i % 200 == 1:
            lr = optimizer.param_groups[0]["lr"]
            print(
                "Train Epoch Step: {:6d} | Gradient Update Step: {:3d} | Loss Per Token: {:.5f} | Learning Rate: {:6.1e}".format(
                    i, n_updates, loss_value, lr)
                )

    return total_loss / max(total_tokens, 1), train_state, all_losses

def run_eval_epoch(
    data_iter: DataLoader,
    model: nn.Module,
    loss_criterion: nn.Module,
    tgt_tokenizer,
    train_state
):
    """Evaluate the teacher-forced loss and greedy-decoding BLEU/ROUGE over ``data_iter``.

    Updates early-stopping bookkeeping on ``train_state``: ``best_loss`` and a reset of
    ``patience_counter`` to 0 when the per-token loss improves, otherwise the counter is
    incremented.

    Args:
        data_iter: DataLoader yielding the same batch dicts as in training.
        model: Seq2seq model exposing ``greedy_decode``.
        loss_criterion: Loss summed over tokens; it is divided by ``ntokens`` here.
        tgt_tokenizer: Target tokenizer used for special-token ids and for detokenisation.
        train_state: Object with ``best_loss`` and ``patience_counter`` attributes.

    Returns:
        ``(loss_per_token, per_batch_losses)`` where ``loss_per_token`` is a float.
    """
    model.eval()
    total_tokens = 0
    total_loss = 0.0
    all_losses = []
    all_preds = []
    all_reference = []
    device = next(model.parameters()).device

    # Decoding starts from the tokenizer's CLS token and stops at its SEP token.
    cls_token_id = tgt_tokenizer.convert_tokens_to_ids(tgt_tokenizer.special_tokens_map['cls_token'])
    eos_token_id = tgt_tokenizer.convert_tokens_to_ids(tgt_tokenizer.special_tokens_map['sep_token'])

    with torch.no_grad():
        for batch in tqdm(data_iter, total=len(data_iter)):
            src = batch["src_input_ids"].to(device)
            tgt = batch["tgt_input_ids"].to(device)
            src_mask = batch["src_mask"].to(device)
            tgt_mask = batch["tgt_mask"].to(device)
            labels = batch["tgt_labels"].to(device)
            ntokens = int(batch["ntokens"])

            log_probs = model(
                src=src, tgt=tgt, src_mask=src_mask, tgt_mask=tgt_mask
            )

            batch_loss = loss_criterion(
                log_probs.contiguous().view(-1, log_probs.size(-1)),
                labels.contiguous().view(-1)
            )

            all_losses.append((batch_loss / ntokens).item())
            total_loss += batch_loss.item()
            total_tokens += ntokens

            pred_tokens = model.greedy_decode(
                src=src, src_mask=src_mask,
                cls_token_id=cls_token_id,
                eos_token_id=eos_token_id
            )

            all_preds.extend(tgt_tokenizer.batch_decode(pred_tokens, skip_special_tokens=True))
            all_reference.extend(tgt_tokenizer.batch_decode(labels, skip_special_tokens=True))

    epoch_loss = total_loss / max(total_tokens, 1)
    print("Total per token eval loss:", epoch_loss)
    if epoch_loss < train_state.best_loss:
        train_state.best_loss = epoch_loss
        train_state.patience_counter = 0
    else:
        train_state.patience_counter += 1

    # Print out evaluation scores
    evaluate_translation(all_reference, all_preds)

    return epoch_loss, all_losses

def custom_lr_schedule(
    step_no: int,
    d_model: int = 512,
    warm_up: int = 4000,
) -> float:
    """Learning-rate factor of "Attention Is All You Need" (Sec. 5.3).

        lr = d_model^-0.5 * min(step^-0.5, step * warm_up^-1.5)

    Rises linearly for ``warm_up`` steps, then decays with the inverse square root of the
    step. Step 0 is treated as step 1 to avoid dividing by zero.
    """
    step = 1 if step_no == 0 else step_no
    inverse_sqrt_decay = step ** (-0.5)
    linear_warm_up = step * (warm_up ** (-1.5))
    return d_model ** (-0.5) * min(inverse_sqrt_decay, linear_warm_up)

def calculate_gradient_norm(model):
    """Print the mean absolute gradient per parameter group.

    Parameters are grouped by the first three dot-separated parts of their name
    (e.g. ``encoder.layers.0``). For each group, the mean of the per-tensor mean |grad|
    is printed. Parameters without a gradient are skipped. Despite the name, this
    reports mean absolute gradients, not norms.
    """
    grouped = {}
    for param_name, param in model.named_parameters():
        if param.grad is None:
            continue
        group = ".".join(param_name.split(".")[:3])
        grouped.setdefault(group, []).append(param.grad.abs().mean().item())

    for group, values in grouped.items():
        print(f"{group} | Mean Gradient: {np.mean(values)}")

# Training

## Training and Model Arguments

In [None]:
SEED = 42
# Seed the global RNG so weight initialisation and train-loader shuffling are reproducible.
torch.manual_seed(SEED)

model_args = dict(
    n=6, h=8, d_model=512,
    src_vocab_size=vie_tokenizer.vocab_size,
    tgt_vocab_size=en_tokenizer.vocab_size,
    share_embeddings=False,
    max_tokens=96
)

# Base lr is 1 because LambdaLR multiplies it by custom_lr_schedule, which yields the actual LR.
optimizer_args = dict(
    lr=1, betas=(0.9, 0.98), eps=1e-9
)

training_args = dict(
    batch_size=128,
    epoch_nos=30,
    test_split=0.2,
    warm_up=4000,
    accumulation_no=4,
    kl_divergence_loss=False,
    **optimizer_args,  # lr / betas / eps mirrored here for code that reads them from training_args
)

train_dataloader, test_dataloader = vie2en_datamanager.get_data_loader(
    batch_size=training_args["batch_size"], test_split=training_args["test_split"]
    )

device = "cuda" if torch.cuda.is_available() else "cpu"

model = TransformersSeqToSeq(**model_args).to(device)
optimizer = Adam(model.parameters(), **optimizer_args)
criterion = CustomCrossEntropyLoss(
    label_smoothing=0.1,
    ignore_index=en_tokenizer.pad_token_id,
    use_kl_divergence=training_args["kl_divergence_loss"]
    ).to(device)
scheduler = LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda x: custom_lr_schedule(x, d_model=model_args["d_model"], warm_up=training_args["warm_up"]))

## Train

In [None]:
# Directory for per-epoch checkpoints and the best model.
output_dir = os.path.join(MAIN_DIR, "outputs")
os.makedirs(output_dir, exist_ok=True)

train_state = TrainState()
train_loss, val_loss = [], []
early_stopping = True
patience = 5  # stop once the validation loss has failed to improve more than this many epochs in a row

for epoch_no in range(training_args["epoch_nos"]):
    print(f"Epoch: {epoch_no}")
    train_epoch_loss, _, train_step_loss = run_train_epoch(
        data_iter = train_dataloader,
        model = model,
        loss_criterion = criterion,
        optimizer = optimizer,
        scheduler = scheduler,
        train_state = train_state,
        tgt_tokenizer = en_tokenizer,
        accumulation_no = training_args["accumulation_no"],
        track_gradients = True
    )
    train_loss.extend(train_step_loss)
    torch.cuda.empty_cache()

    val_epoch_loss, val_step_loss = run_eval_epoch(
        data_iter = test_dataloader,
        model = model,
        loss_criterion = criterion,
        tgt_tokenizer = en_tokenizer,
        train_state = train_state
    )
    val_loss.extend(val_step_loss)
    torch.cuda.empty_cache()

    # Checkpoint every epoch, including the one that triggers early stopping.
    ckpt_folder = os.path.join(output_dir, "checkpoints", "epoch_{}".format(epoch_no))
    os.makedirs(ckpt_folder, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_folder, "model.pt"))

    # run_eval_epoch resets patience_counter to 0 exactly when the validation loss improved,
    # so this keeps the weights of the best epoch rather than the last one.
    if train_state.patience_counter == 0:
        torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pt"))

    if early_stopping and train_state.patience_counter > patience:
        break

## Visualize training curves

In [None]:
import matplotlib.pyplot as plt

# Both lists hold per-batch losses, but the splits have different numbers of batches, so a
# shared x-axis would misalign them; plot each split on its own axis instead.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 4))
ax_train.plot(train_loss, label="train")
ax_train.set(title="Training loss", xlabel="Training batch", ylabel="Loss per token")
ax_val.plot(val_loss, label="val", color="tab:orange")
ax_val.set(title="Validation loss", xlabel="Validation batch", ylabel="Loss per token")
fig.tight_layout()
plt.show()