In [1]:
!pip install sacrebleu

Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Downloading sacrebleu-2.5.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading portalocker-3.2.0-py3-none-any.whl (22 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-3.2.0 sacrebleu-2.5.1


In [2]:
import os
import math
import time
import csv
from enum import Enum

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

import sentencepiece as spm
import sacrebleu 

In [None]:
class Languages(Enum):
    # Languages handled by this notebook; FROM_LANG below picks the source side.
    ENG = 0
    KOR = 1
    
# Placeholders: the data-prep cell overwrites these once the tokenizers are loaded.
SRC_PAD_IDX    = None
TRG_PAD_IDX    = None
SRC_VOCAB_SIZE = None
TRG_VOCAB_SIZE = None
BATCH_SIZE     = 32

# Early stopping criteria
PATIENCE       = 10
DELTA          = 0.01

SAVE_DIR  = r"/kaggle/working/" # IMPORTANT: change to your output folder
FROM_LANG = Languages.ENG       # IMPORTANT: change to your source (original) language

SENTENCE_PIECE_DATA = '/kaggle/input/korean-sentencepiece/sentencepiece' # IMPORTANT: directory holding the tokenizers and the data files

# IMPORTANT: change to your tokenizer model filenames
ENG_MODEL_PATH = os.path.join(SENTENCE_PIECE_DATA,'english_unigram.model') 
KOR_MODEL_PATH = os.path.join(SENTENCE_PIECE_DATA,'korean_unigram.model')

# IMPORTANT: change to your training/testing data filenames
train_korean  = os.path.join(SENTENCE_PIECE_DATA, 'korean.txt') 
train_english = os.path.join(SENTENCE_PIECE_DATA, 'english.txt')
test_korean   = os.path.join(SENTENCE_PIECE_DATA, 'test_korean.txt')
test_english  = os.path.join(SENTENCE_PIECE_DATA, 'test_english.txt')

# The Noam scheduler (see Scheduler.lr_lambda) returns a multiplier that already
# contains the 1 / sqrt(d_model) term, so the base learning rate is left at 1.0
# and the effective learning rate is driven entirely by the scheduler.
LEARNING_RATE    = 1.0 # Base learning rate (multiplied by the Noam scheduler's factor)
SCHEDULER_FACTOR = 1.0 # Scale applied to the Noam schedule; higher factor -> higher learning rate
NUM_EPOCHS       = 100
SMOOTHING_FACTOR = 0.1  
TRAIN_MODEL_PATH = None
L2_LAMBDA        = 1e-5 # L2 regularization weight (only referenced by a commented-out line in train())

# MODEL ARCHITECTURE
HIDDEN_DIM       = 512  # Embedding dimension per token
NUM_LAYERS       = 6
NUM_HEADS        = 8
MAX_SEQ_LEN      = 64
DROPOUT_RATE     = 0.3

## Data Prep

In [None]:
class TranslationDataset(Dataset):
    """Parallel corpus of tokenised (source, target) sentence pairs.

    Source sentences are encoded as-is (no <bos>/<eos>) and clipped to
    ``max_len`` ids. Target sentences are wrapped in <bos> ... <eos> and
    clipped so the wrapped sequence is at most ``max_len`` ids long.

    Args:
        src_lines: source sentences, one per item.
        tgt_lines: target sentences, aligned index-by-index with ``src_lines``.
        src_sp: tokenizer for the source side (needs ``encode``).
        trg_sp: tokenizer for the target side (needs ``encode``, ``bos_id``, ``eos_id``).
        max_len: maximum number of ids kept per sequence (must be >= 2).

    Raises:
        ValueError: if the two sides differ in length or ``max_len`` < 2.
    """

    def __init__(self,
                 src_lines: list[str],
                 tgt_lines: list[str],
                 src_sp: "spm.SentencePieceProcessor",
                 trg_sp: "spm.SentencePieceProcessor",
                 max_len=64):
        # A length mismatch means the corpus is misaligned; fail early instead of
        # training on wrong pairs or hitting an IndexError in the middle of an epoch.
        if len(src_lines) != len(tgt_lines):
            raise ValueError(
                f"TranslationDataset: got {len(src_lines)} source lines but "
                f"{len(tgt_lines)} target lines; the corpus must be parallel"
            )
        if max_len < 2:
            raise ValueError("TranslationDataset: max_len must be >= 2 to fit <bos> and <eos>")

        self.src_sp  = src_sp
        self.trg_sp  = trg_sp
        self.max_len = max_len
        # Source tokens do not require <bos> & <eos>. They are clipped to max_len
        # because the positional encoding table only covers a fixed number of positions.
        self.src = [src_sp.encode(s.strip(), out_type=int)[:max_len] for s in src_lines]
        # Target tokens get <bos> & <eos> through add_controls()
        self.trg = [self.add_controls(trg_sp.encode(t.strip(), out_type=int)) for t in tgt_lines]

    def add_controls(self, ids):
        """Wrap ``ids`` in <bos>/<eos>, clipping the content to ``max_len - 2`` ids."""
        # ids[:self.max_len - 2] reserves room for bos & eos; only matters when len(ids) > max_len - 2
        return [self.trg_sp.bos_id()] + ids[:self.max_len - 2] + [self.trg_sp.eos_id()]

    def __getitem__(self, idx):
        # Explicit dtype: torch.tensor([]) would otherwise default to float32,
        # which cannot be used as embedding indices.
        return {
            'input_ids': torch.tensor(self.src[idx], dtype=torch.long),
            'labels': torch.tensor(self.trg[idx], dtype=torch.long)
        }

    def __len__(self):
        return len(self.src)

class MyCollate:
  '''
    Collate function for DataLoader.

    Sentences have different numbers of tokens, so every batch is right-padded
    to the length of its longest sequence. Lengths are identical inside a batch
    but may differ from one batch to the next.

    The same `pad_token_id` is used for both 'input_ids' and 'labels'.
  '''
  def __init__(self, pad_token_id):
    self.pad_token_id = pad_token_id

  def _pad_field(self, batch, key):
    # Work on copies so the tensors owned by the dataset items are never touched
    copies = [sample[key].clone().detach() for sample in batch]
    return pad_sequence(copies, batch_first=True, padding_value=self.pad_token_id)

  def __call__(self, batch):
    # Each field becomes a [batch_size, longest_len_in_batch] tensor
    padded_inputs = self._pad_field(batch, 'input_ids')
    padded_labels = self._pad_field(batch, 'labels')
    return {'input_ids': padded_inputs, 'labels': padded_labels}

def build_tokenizer_model(language: str):
    '''
    Train a SentencePiece unigram tokenizer on `<SENTENCE_PIECE_DATA>/<language>.txt`.

    The trainer is given the model prefix `<language>_unigram`.

    Advantages of sentencepiece:
    + Treats whitespace as part of the subword (denoted by '_')
    + Retains morphological meaning (adv, adj, etc.) E.g. obviously -> [obvious, ly]

    Trainer options used here:
      `model_type`:
          + bpe: merges the most frequent pairs.
          + unigram: chooses subwords with the highest likelihood (probability)
      `character_coverage`: amount of characters covered by the model
          + 1.0 covers all characters (Latin-based language) or 0.9995 (e.g. kanji -> ignore rare chars)
      `pad_id`: numeric id assigned to <pad> (by default unk_id, bos_id & eos_id are 0, 1 & 2 respectively)
    '''
    is_korean = language == 'korean'

    # Korean: 22000 pieces with 0.9995 coverage; any other language: 13000 pieces with full coverage
    coverage   = 0.9995 if is_korean else 1.000
    num_pieces = 22000 if is_korean else 13000

    # Training corpus (a comma-separated list such as "korean.txt, english.txt" would train on several files)
    corpus_path  = os.path.join(SENTENCE_PIECE_DATA, "{}.txt".format(language))
    model_prefix = "{}_unigram".format(language)

    spm.SentencePieceTrainer.train(
        input=corpus_path,
        model_prefix=model_prefix,
        vocab_size=num_pieces,
        model_type="unigram",
        character_coverage=coverage,
        pad_id=3
    )

def load_tokenizer_model(path_to_model: str) -> spm.SentencePieceProcessor:
    """Load the SentencePiece model stored at `path_to_model` and return its processor."""
    return spm.SentencePieceProcessor(model_file=path_to_model)

def get_lines(file_path: str):
    """Read a UTF-8 text file and return its non-blank lines, stripped of surrounding whitespace."""
    kept = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:  # whitespace-only lines are dropped
                kept.append(text)
    return kept

# If no tokenizer has been compiled, uncomment the following 2 lines
# build_tokenizer_model('korean')
# build_tokenizer_model('english')

# Otherwise, load the pre-trained tokenizers
eng_sp = load_tokenizer_model(ENG_MODEL_PATH)
kor_sp = load_tokenizer_model(KOR_MODEL_PATH)

# Each table maps the source language to a (source, target) pair
tokenizers = {
    Languages.ENG: (eng_sp, kor_sp),
    Languages.KOR: (kor_sp, eng_sp)
}

train_sentences = {
    Languages.ENG: (train_english, train_korean),
    Languages.KOR: (train_korean, train_english)
}

test_sentences = {
    Languages.ENG: (test_english, test_korean),
    Languages.KOR: (test_korean, test_english)
}

src_tkn, trg_tkn = tokenizers[FROM_LANG]
train_src_data, train_trg_data = train_sentences[FROM_LANG]
test_src_data, test_trg_data = test_sentences[FROM_LANG]

# Maximum number of tokens in each sentence for this specific dataset
# max_len_src = max(len(eng_sp.encode(s.strip(), out_type=int)) for s in get_lines(english_file)) # 48
# max_len_tgt = max(len(kor_sp.encode(t.strip(), out_type=int)) for t in get_lines(korean_file))  # 32

# Test the tokenizers
# print(eng_sp.encode_as_pieces('this is obviously the right answer.'))
# print(kor_sp.encode_as_pieces('이건 분명히 정답이다'))

SRC_VOCAB_SIZE = src_tkn.get_piece_size()
TRG_VOCAB_SIZE = trg_tkn.get_piece_size()

# Padding indices of both tokenizers
SRC_PAD_IDX = src_tkn.pad_id()
TRG_PAD_IDX = trg_tkn.pad_id()

# MyCollate pads 'input_ids' AND 'labels' with one single id (the source pad id below).
# That is only correct when both tokenizers agree on <pad>, so fail loudly otherwise.
if SRC_PAD_IDX != TRG_PAD_IDX:
    raise ValueError(
        f"Source pad id ({SRC_PAD_IDX}) and target pad id ({TRG_PAD_IDX}) differ; "
        "the collate function pads both sides with the source pad id"
    )

# Wrap the data: Dataset -> DataLoader
train_dataset = TranslationDataset(
    get_lines(train_src_data),
    get_lines(train_trg_data),
    src_tkn,
    trg_tkn
)
val_dataset = TranslationDataset(
    get_lines(test_src_data),
    get_lines(test_trg_data),
    src_tkn,
    trg_tkn
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE, # Number of sequences per batch
    shuffle=True,
    collate_fn=MyCollate(pad_token_id=src_tkn.pad_id())
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE, # Number of sequences per batch
    shuffle=False,         # Keep validation order fixed so per-epoch metrics are comparable
    collate_fn=MyCollate(pad_token_id=src_tkn.pad_id())
)

# Sanity check: inspect the first training batch
for batch in train_loader:
    print(batch['input_ids'].shape) # e.g. torch.Size([32, longest_src_len_in_batch])
    print(batch['input_ids'][0])    # first sample's input_ids (padded)
    print(batch['labels'].shape)    # e.g. torch.Size([32, longest_trg_len_in_batch])
    print(batch['labels'][0])       # first sample's labels (padded)
    break                           # stop after one batch

torch.Size([32, 25])
tensor([  53,  120,   49,   10, 2153, 5548,   31,   10, 2100,   18,    9,   21,
           7,   38,   10, 5548,   47, 3357,   19,   10, 1003, 5548,    4,    3,
           3])
torch.Size([32, 24])
tensor([    1,  2915,    26,     8,   348,  8020, 10644,    46,  5503,  3509,
        10644,    11,     7,    24,   614,   790,  3212,   159,     4,     2,
            3,     3,     3,     3])


## Helper

In [5]:
def create_pad_mask(input: torch.Tensor, pad_idx=None) -> torch.Tensor:
  """
  Build a boolean padding mask (True = <pad> token = masked).

  input:   [batch_size, seq_length] token ids
  pad_idx: id of the <pad> token; defaults to the global SRC_PAD_IDX
  returns: [batch_size, 1, 1, seq_length]
  """
  if pad_idx is None:
    pad_idx = SRC_PAD_IDX
  pad_mask = (input == pad_idx).unsqueeze(1).unsqueeze(1) # [B, 1, 1, L]
  return pad_mask

def create_causal_mask(seq_len, device) -> torch.Tensor:
  """
  Build a boolean look-ahead mask of shape [1, 1, seq_len, seq_len]
  (True = future position = masked).

  E.g. If seq_len == 3

  [[[[False, True,  True],
     [False, False, True],
     [False, False, False]]]]
  """
  causal_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device), diagonal=1)  # shape = [L, L]
  causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # shape = [1, 1, L, L]

  return causal_mask

def create_decoder_mask(src: torch.Tensor, trg, src_pad_idx=None, trg_pad_idx=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """
  Build the three masks used by the decoder.

  src: [batch_size, src_length] source token ids
  trg: [batch_size, trg_length] target token ids
  src_pad_idx: <pad> id of the source vocabulary; defaults to the global SRC_PAD_IDX
  trg_pad_idx: <pad> id of the target vocabulary; defaults to the global TRG_PAD_IDX

  returns: (src_pad_mask    [B, 1, 1, src_length],
            trg_pad_mask    [B, 1, 1, trg_length],
            trg_causal_mask [1, 1, trg_length, trg_length])
  """
  if src_pad_idx is None:
    src_pad_idx = SRC_PAD_IDX
  if trg_pad_idx is None:
    trg_pad_idx = TRG_PAD_IDX

  src_pad_mask = create_pad_mask(src, src_pad_idx) # [B, 1, 1, L_src]

  # The target side has its own vocabulary, so it must be compared against the
  # target <pad> id, not the source one.
  trg_pad_mask = create_pad_mask(trg, trg_pad_idx) # [B, 1, 1, L_trg]

  # The causal mask is OR-ed with trg_pad_mask downstream, so it lives on trg's device
  trg_causal_mask = create_causal_mask(trg.size(1), trg.device) # [1, 1, L_trg, L_trg]

  return src_pad_mask, trg_pad_mask, trg_causal_mask

class Scheduler(torch.optim.lr_scheduler.LambdaLR):
  """
  Noam learning-rate schedule ("Attention Is All You Need").

  The multiplier applied to the optimizer's base learning rate is
      scale_factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
  """
  def __init__(self, optimizer, d_model, warmup_steps=4000, scale_factor=1.0):
      # These attributes must exist before LambdaLR.__init__, which evaluates lr_lambda
      self.d_model = d_model
      self.warmup_steps = warmup_steps
      self.scale_factor = scale_factor # Scales the whole schedule up or down
      super().__init__(optimizer, self.lr_lambda)

  def lr_lambda(self, step):
      # Step 0 is treated as step 1 to avoid 0 ** -0.5
      current = 1 if step == 0 else step
      decay_term  = current ** -0.5
      warmup_term = current * self.warmup_steps ** -1.5
      return self.scale_factor * (self.d_model ** -0.5) * min(decay_term, warmup_term)

## Embedder

In [6]:
class Embedder(nn.Module):
    """Token-id -> dense-vector lookup table (thin wrapper around nn.Embedding)."""

    def __init__(self, vocab_size, d_model):
      """
        vocab_size: number of entries in the vocabulary
        d_model: size of the embedding vector produced for EACH token
      """
      super().__init__()
      self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
      # Output has x's shape with a trailing d_model axis appended
      embedded = self.embed(x)
      return embedded

## Positional Encoding

In [7]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_seq_len=64):
      """
        d_model: the dimension of the embedding vector for EACH token
        max_seq_len: number of positions the table covers
      """
      super().__init__()

      # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model)) = sin(pos * div_term)
      # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)) = cos(pos * div_term)
      position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1) # [max_seq_len, 1]
      div_term = torch.exp(
          torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model) # torch.arange(0, d_model, 2) generates the even indices (2i)
          ) # torch.exp() -> e^x; math.log() -> ln
      angles = position * div_term # [max_seq_len, ceil(d_model / 2)]

      pe          = torch.zeros(max_seq_len, d_model) # [max_seq_len, d_model]
      pe[:, 0::2] = torch.sin(angles) # Sine on even feature indices (0, 2, 4, ...)
      # For an odd d_model there is one fewer odd index than even ones, so trim the last column
      pe[:, 1::2] = torch.cos(angles[:, :d_model // 2]) # Cosine on odd feature indices (1, 3, 5, ...)
      pe          = pe.unsqueeze(0) # [1, max_seq_len, d_model] -> broadcast over the batch later on

      # `self.pe` does not exist yet at this point, so the message must use the local `pe`
      assert pe.shape == (1, max_seq_len, d_model), f"[PositionalEncoding.__init__()] Expected shape (1, {max_seq_len}, {d_model}), got {pe.shape}"
      # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
      # Used for tensors that need to be on the same device as the module.
      # persistent=False tells PyTorch not to add the buffer to the state dict (e.g. when we save the model)
      self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
      """
        x: [batch_size, seq_len, d_model] token embeddings
        returns: x with the first seq_len positional encodings added (same shape)

        Raises ValueError when seq_len exceeds the table size.
      """
      seq_len = x.size(1)
      if seq_len > self.pe.size(1):
          raise ValueError(
              f"[PositionalEncoding.forward()] sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}"
          )
      x = x + self.pe[:, :seq_len, :].to(x.device) # Take the first seq_len positions from the table
      return x

## Feed-Forward Network

In [8]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """
        d_model: input and output feature size
        d_ff: hidden (inner) feature size
        dropout: dropout probability applied after the activation
        """
        super().__init__()
        # Created in the same order as they are applied (keeps seeded initialisation stable)
        expand   = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), contract)

    def forward(self, x):
        # Only the last axis changes internally (d_model -> d_ff -> d_model)
        return self.ffn(x)

## Attention

In [9]:
# nn.MultiheadAttention can be used instead but for learning purposes, I built one from scratch
# https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with an optional key/value cache."""

  def __init__(self, d_model, num_heads, drop_out=0.1):
    """
      d_model: the dimension of the embedding vector for EACH token
      num_heads: the number of attention heads (must divide d_model)
      drop_out: dropout rate applied to the attention weights
    """
    super().__init__()
    assert d_model % num_heads == 0, "MultiHeadAttention.__init__(): d_model must be divisible by num_heads"
    self.d_model, self.num_heads = d_model, num_heads
    self.head_dim = d_model // num_heads

    # Input projections (creation order q, v, k is kept so seeded initialisation is unchanged)
    self.q_linear = nn.Linear(d_model, d_model)
    self.v_linear = nn.Linear(d_model, d_model)
    self.k_linear = nn.Linear(d_model, d_model)

    # Output projection: learns how to merge the concatenated heads into one representation
    self.out_proj = nn.Linear(d_model, d_model)

    self.dropout = nn.Dropout(drop_out)

  def _split_heads(self, projected, batch):
    # [B, L, D] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
    return projected.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, q, k, v, mask=None, cache=None):
    '''
     q, k, v: [B, L, D] tensors (k and v may have a different L than q)
     mask: boolean tensor broadcastable to the score matrix; True entries are masked out
     `cache`: optional dict of previously computed keys/values
     {
      "K": Tensor [B, num_heads, L_cached, d_head],
      "V": Tensor [B, num_heads, L_cached, d_head],
     }

     Returns (output [B, L_q, D], new_cache) where new_cache holds the keys/values
     used in this call (cached ones followed by the new ones).
    '''
    batch, q_len, width = q.shape

    # Trainable linear projections, then split into heads
    queries = self._split_heads(self.q_linear(q), batch)
    keys    = self._split_heads(self.k_linear(k), batch)
    values  = self._split_heads(self.v_linear(v), batch)

    # Prepend cached keys/values along the sequence axis
    if cache is not None:
        keys   = torch.cat([cache['K'], keys], dim=2)
        values = torch.cat([cache['V'], values], dim=2)
    new_cache = {'K': keys, 'V': values}

    # Scaled dot-product attention: scores = [B, num_heads, L_q, L_k]
    # The 1/sqrt(head_dim) factor keeps the softmax away from saturation (tiny gradients)
    scale  = 1 / math.sqrt(self.head_dim)
    scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

    # Masked entries (future tokens and/or <pad>) get -inf -> zero weight after softmax
    if mask is not None:
      scores = scores.masked_fill(mask, float('-inf'))

    weights = self.dropout(torch.softmax(scores, dim=-1)) # [B, num_heads, L_q, L_k]

    # Weighted sum of values, then concatenate heads back to [B, L_q, D]
    per_head = torch.matmul(weights, values)              # [B, num_heads, L_q, head_dim]
    merged   = per_head.transpose(1, 2).contiguous().view(batch, q_len, width)
    output   = self.out_proj(merged)
    assert output.shape == (batch, q_len, width)

    return output, new_cache

## Encoder

In [10]:
class EncoderLayer(nn.Module):
  """Pre-norm encoder block: self-attention and feed-forward, each with a residual connection."""

  def __init__(self, d_model, num_heads, ffn_dim, drop_out):
    super().__init__()
    # NOTE(review): the attention module is built with its own default dropout;
    # `drop_out` is only passed to the feed-forward block and the residual dropouts.
    self.attn = MultiHeadAttention(d_model, num_heads)
    self.ffn  = FeedForward(d_model, ffn_dim, drop_out)

    # Layer normalisation keeps feature values in the range where the activation
    # is sensitive to its input, which helps against vanishing gradients and
    # speeds up convergence.
    # https://medium.com/@florian_algo/batchnorm-and-layernorm-2637f46a998b
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(drop_out)
    self.dropout2 = nn.Dropout(drop_out)

  def forward(self, x, mask=None):
    """
    x: input tensor; the result has the same shape
    mask: optional boolean attention mask (True = masked), forwarded to the attention module
    """
    # 1. Self-attention on the normalised input, added back onto the residual stream
    normed = self.norm1(x)
    attended, _ = self.attn(normed, normed, normed, mask) # the returned cache is not needed here
    x = x + self.dropout1(attended)

    # 2. Feed-forward on the normalised result, added back onto the residual stream
    return x + self.dropout2(self.ffn(self.norm2(x)))

In [11]:
class TransformerEncoder(nn.Module):
  """Stack of encoder layers on top of scaled token embeddings plus positional encodings."""

  def __init__(self,
               vocab_size,
               d_model=512,
               num_layers=6,
               num_attn_heads=8,
               ffn_dim=2048,
               max_seq_len=64,
               dropout=0.1):
    """
      vocab_size: the size of the vocabulary
      d_model: the dimension of the embedding vector for EACH token
      num_layers: the number of encoder layers
      num_attn_heads: the number of attention heads
      ffn_dim: the dimension of the feed-forward network
      max_seq_len: the maximum sequence length
      dropout: the dropout rate
    """
    super().__init__()
    self.d_model = d_model

    self.embedder = Embedder(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
    stack = [EncoderLayer(d_model, num_attn_heads, ffn_dim, dropout) for _ in range(num_layers)]
    self.encoder_layers = nn.ModuleList(stack)
    self.dropout = nn.Dropout(dropout)
    self.norm = nn.LayerNorm(d_model)

  def forward(self, input_ids):
    """
      input_ids: [batch_size, seq_len] source token ids
      returns:   [batch_size, seq_len, d_model] normalised encoder states
    """
    batch_size, seq_len = input_ids.shape
    expected_shape = (batch_size, seq_len, self.d_model)

    # Padding positions must not be attended to
    src_mask = create_pad_mask(input_ids)

    # 1. Token embedding, scaled by sqrt(d_model)
    x = self.embedder(input_ids) * math.sqrt(self.d_model)
    assert x.shape == expected_shape, f"[TransformerEncoder.forward()] - Token Embedding: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # 2. Positional encoding, followed by dropout against overfitting
    x = self.dropout(self.positional_encoding(x))
    assert x.shape == expected_shape, f"[TransformerEncoder.forward()] - Pos Encoding: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # 3. Encoder layers, applied in order
    for encoder_layer in self.encoder_layers:
      x = encoder_layer(x, src_mask)
    assert x.shape == expected_shape, f"[TransformerEncoder.forward()] - Encoder Layers: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # Final normalisation (the layers are pre-norm, so their output is un-normalised)
    return self.norm(x) # [batch size, source length, hidden dim]

## Decoder

In [12]:
class DecoderLayer(nn.Module):
  """Pre-norm decoder block: masked self-attention, encoder-decoder attention, feed-forward."""

  def __init__(self, d_model, num_heads, ffn_dim, drop_out):
    super().__init__()
    # NOTE(review): both attention modules are built with their own default dropout;
    # `drop_out` is only passed to the feed-forward block and the residual dropouts.
    self.attn1 = MultiHeadAttention(d_model, num_heads) # masked self-attention
    self.attn2 = MultiHeadAttention(d_model, num_heads) # encoder-decoder attention

    self.ffn  = FeedForward(d_model, ffn_dim, drop_out)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)

    self.dropout1 = nn.Dropout(drop_out)
    self.dropout2 = nn.Dropout(drop_out)
    self.dropout3 = nn.Dropout(drop_out)

  def forward(self, x, enc_output, dec_masks, cache):
    """
    x: [batch size, target length, hidden dim]
    enc_output: [batch size, source length, hidden dim]
    dec_masks: sequence of three boolean masks, read by index:
        [0] mask for encoder-decoder attention (source padding)
        [1] target padding mask
        [2] target causal mask
    cache: key/value cache for the self-attention (or None)

    Returns (x, self_attention_cache).

    Note:
    - In "encoder-decoder attention" layers, the queries come from the previous decoder layer,
      and the memory keys and values come from the output of the encoder (https://arxiv.org/abs/1706.03762)
    """
    enc_dec_mask    = dec_masks[0]
    # A target position is hidden when it is padding OR lies in the future
    self_attn_mask  = dec_masks[1] | dec_masks[2]

    # 1. Causally-masked self-attention (the only attention whose keys/values are cached)
    normed = self.norm1(x)
    self_attended, self_attn_cache = self.attn1(normed, normed, normed, self_attn_mask, cache)
    x = x + self.dropout1(self_attended)

    # 2. Encoder-decoder attention: keys/values are the encoder output, which is the
    #    same at every decoding step, so nothing is cached here
    normed = self.norm2(x)
    cross_attended, _ = self.attn2(normed, enc_output, enc_output, enc_dec_mask)
    x = x + self.dropout2(cross_attended)

    # 3. Feed-forward
    x = x + self.dropout3(self.ffn(self.norm3(x)))

    return x, self_attn_cache

In [13]:
class TransformerDecoder(nn.Module):
  """Stack of decoder layers whose output projection reuses the target embedding matrix."""

  def __init__(self,
               vocab_size,
               d_model=512,
               num_layers=6,
               num_attn_heads=8,
               ffn_dim=2048,
               max_seq_len=64,
               dropout=0.1):
    """
      vocab_size: the size of the target vocabulary
      d_model: the dimension of the embedding vector for EACH token
      num_layers: the number of decoder layers
      num_attn_heads: the number of attention heads
      ffn_dim: the dimension of the feed-forward network
      max_seq_len: the maximum sequence length
      dropout: the dropout rate
    """
    super().__init__()

    self.d_model = d_model
    self.vocab_size = vocab_size

    self.embedder = Embedder(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
    stack = [DecoderLayer(d_model, num_attn_heads, ffn_dim, dropout) for _ in range(num_layers)]
    self.decoder_layers = nn.ModuleList(stack)

    self.dropout = nn.Dropout(dropout)
    self.norm = nn.LayerNorm(d_model)

  def forward(self, src_ids, target_ids, encoder_outputs, cache=None):
    """
      src_ids:         [batch_size, src_len] source token ids (used for the padding mask)
      target_ids:      [batch_size, seq_len] target token ids
      encoder_outputs: encoder states attended to by every layer
      cache:           optional per-layer list of self-attention caches

      returns: (logits [batch_size, seq_len, vocab_size], list of per-layer caches)
    """
    batch_size, seq_len = target_ids.shape
    expected_shape = (batch_size, seq_len, self.d_model)
    dec_masks = create_decoder_mask(src_ids, target_ids)

    # 1. Token embedding, scaled by sqrt(d_model)
    x = self.embedder(target_ids) * math.sqrt(self.d_model)
    assert x.shape == expected_shape, f"[TransformerDecoder.forward()] - Token Embedding: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # 2. Positional encoding, followed by dropout against overfitting
    x = self.dropout(self.positional_encoding(x))
    assert x.shape == expected_shape, f"[TransformerDecoder.forward()] - Pos Encoding: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # 3. Decoder layers; each layer receives its own cache entry and returns an updated one
    new_caches = []
    for layer_idx, decoder_layer in enumerate(self.decoder_layers):
      layer_cache = cache[layer_idx] if cache is not None else None
      x, updated_cache = decoder_layer(x, encoder_outputs, dec_masks, layer_cache)
      new_caches.append(updated_cache)
    assert x.shape == expected_shape, f"[TransformerDecoder.forward()] - Encoder Layers: Expect ({batch_size}, {seq_len}, {self.d_model}), got {x.shape}"

    # 4. Project onto the vocabulary with the (transposed) embedding weights
    hidden = self.norm(x)
    output = torch.matmul(hidden, self.embedder.embed.weight.transpose(0, 1)) # [batch_size, seq_len, vocab_size]
    assert output.shape == (batch_size, seq_len, self.vocab_size)

    return output, new_caches

## Transformer

In [14]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer for sequence-to-sequence translation."""

  def __init__(self,
               src_vocab_size,
               trg_vocab_size,
               max_seq_len=64,
               num_layers=6,
               num_attn_heads=8,
               dropout=0.1,
               d_model=512,
               ffn_dim=2048):
    """
      src_vocab_size: size of the source vocabulary
      trg_vocab_size: size of the target vocabulary
      max_seq_len: maximum sequence length supported by the positional encodings
      num_layers: number of encoder layers and of decoder layers
      num_attn_heads: number of attention heads
      dropout: dropout rate
      d_model: embedding / hidden dimension (new, defaults to the previous fixed value)
      ffn_dim: feed-forward inner dimension (new, defaults to the previous fixed value)
    """
    super().__init__()
    self.src_vocab_size = src_vocab_size
    self.trg_vocab_size = trg_vocab_size

    # `max_seq_len` used to be accepted but silently dropped; it is now forwarded
    # together with the other architecture settings so both halves stay in sync.
    shared = dict(
        d_model=d_model,
        num_layers=num_layers,
        num_attn_heads=num_attn_heads,
        ffn_dim=ffn_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )
    self.encoder = TransformerEncoder(src_vocab_size, **shared)
    self.decoder = TransformerDecoder(trg_vocab_size, **shared)

  def encode(self, src):
      """Encode source ids -> [batch size, source length, hidden dim]."""
      return self.encoder(src)

  def decode(self, src, trg, enc_output, cache=None):
      """Decode target ids against `enc_output` -> (logits [batch_size, target_seq_len, trg_vocab_size], caches)."""
      return self.decoder(src, trg, enc_output, cache)

  def forward(self, src, trg):
    """Full pass used for training: returns the decoder logits only."""
    assert src.device == trg.device
    enc_output = self.encode(src)
    dec_output, _ = self.decode(src, trg, enc_output)

    return dec_output

## Metrics

In [15]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: [N, num_classes] scores (top-k is taken along dim 1).
        target: N ground-truth class indices (flattened internally).
        topk: iterable of k values to evaluate.

    Returns:
        List with one 0-dim tensor per k: the percentage (0-100) of rows whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_seq_len = target.size(0)

    _, pred = output.topk(maxk, 1, True, True) # [N, maxk], best class first
    pred = pred.t()                            # [maxk, N]
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so .view(-1)
        # raises for k > 1; reshape copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_seq_len))
    return res

class AverageMeter(object):
    """Tracks the latest value and the running (weighted) average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, counted `n` times towards the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def l2_regularization(model, l2_lambda=1.0):
    """Return `l2_lambda` times the sum of squared parameter values.

    Parameters whose name contains 'bias' or 'norm' are left out.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if 'bias' in name or 'norm' in name:  # skip biases and LayerNorms
            continue
        penalty = penalty + torch.sum(param.pow(2))
    return l2_lambda * penalty

In [16]:
def log_to_csv(file_name, epoch, train_loss, val_loss, bleu_score):
    """Append one row (epoch, train_loss, val_loss, bleu_score) to the CSV file `file_name`."""
    row = [epoch, train_loss, val_loss, bleu_score]
    # newline='' lets the csv module control line endings itself
    with open(file_name, 'a', newline='') as log_file:
        csv.writer(log_file).writerow(row)

## Translate

In [17]:
def greedy_batch(
    model,
    src,
    max_seq_len,
    trg_tokenizer
):
    """Greedy-decode every row of `src` ([batch, src_len]) independently; returns one result per row."""
    batch_len, _ = src.shape
    # greedy_search expects a [1, src_len] input, hence the unsqueeze
    return [
        greedy_search(model, src[row, :].unsqueeze(0), max_seq_len, trg_tokenizer)
        for row in range(batch_len)
    ]
    
def greedy_search(
    model,
    src,
    max_seq_len,
    trg_tokenizer
) -> str:
    """
    Greedy decoding of a single sentence.

    model: object exposing `encode(src)` and `decode(src, trg, enc_output)`
    src: source ids with a leading batch axis of size 1
    max_seq_len: maximum number of tokens generated after <bos>
    trg_tokenizer: tokenizer providing `bos_id()`, `eos_id()` and `decode()`

    Returns whatever `trg_tokenizer.decode` yields for the generated ids
    (which include the leading <bos> and, if produced, the final <eos>).
    """
    # 1. Start the translation with <bos>; shape [1, 1], grows by one column per step
    generated = torch.tensor([[trg_tokenizer.bos_id()]], device=src.device)
    memory = model.encode(src) # [batch size, source length, hidden dim]

    # 2. Re-decode the whole prefix each step and keep the best-scoring next token
    for _step in range(max_seq_len):
        logits, _ = model.decode(src, generated, memory) # [1, trg_seq_len, trg_vocab_size]
        best_token = logits[:, -1, :].argmax(-1)         # id with the highest score at the last position
        generated = torch.cat((generated, best_token.unsqueeze(0)), dim=-1)

        # Stop as soon as <eos> has been generated
        if best_token.item() == trg_tokenizer.eos_id():
            break

    return trg_tokenizer.decode(generated.squeeze().tolist())
    
def beam_batch(
    model,
    src,
    trg_tokenizer,
    beam_size=3,
    max_seq_len=MAX_SEQ_LEN,
    length_penalty=0.75
):
    """Beam-search every row of `src` ([batch, src_len]) independently; returns one result per row."""
    batch_len, _ = src.shape

    def translate_row(row_idx):
        # beam_search expects a [1, src_len] input, hence the unsqueeze
        single = src[row_idx, :].unsqueeze(0)
        return beam_search(model, single, trg_tokenizer, beam_size, max_seq_len, length_penalty)

    return [translate_row(row_idx) for row_idx in range(batch_len)]
    
def beam_search(
        model, 
        src, 
        trg_tokenizer, 
        beam_size=3, 
        max_len=MAX_SEQ_LEN, 
        length_penalty=0.75
    ) -> str:
    """
    Beam search decoding for Transformer, adapted from https://arxiv.org/pdf/1609.08144 (7. Decoder - p.12)
    
    Args:
        model: Transformer model with `encode` and `decode` methods.
        src: Source tensor [1, src_len]
        trg_tokenizer: Tokenizer providing `bos_id()`, `eos_id()` and `decode()`
        beam_size: Number of beams to keep (must be >= 1)
        max_len: Maximum number of decoding steps
        length_penalty: Exponent of the length normalisation (higher favours longer sequences)

    Returns:
        The decoded text of the best-scoring hypothesis.

    Raises:
        ValueError: if `beam_size` < 1.
    """
    if beam_size < 1:
        raise ValueError("beam_search: beam_size must be >= 1")

    # bos_id / eos_id are methods on the tokenizer: they must be called to get the ids
    bos = trg_tokenizer.bos_id()
    eos = trg_tokenizer.eos_id()

    def score(hypothesis):
        # Length-normalised log-probability: log_prob / lp(Y),
        # with lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha.
        # tokens is [1, seq_len], so the length is size(1) (len() would always be 1).
        length = hypothesis["tokens"].size(1)
        return hypothesis["log_prob"] / ((5 + length) ** length_penalty / (5 + 1) ** length_penalty)

    def is_finished(hypothesis):
        return hypothesis["tokens"][0, -1].item() == eos
    
    # 1. Encode source sentence
    enc_output = model.encode(src)
    
    # 2. Initialize beam with a single hypothesis holding only <bos>
    beams = [{
        "tokens": torch.tensor([[bos]], device=src.device), # [1, 1] -> will be [1, seq_len]
        "log_prob": 0.0,
    }]
    
    completed = [] # Hypotheses that ended with <eos>

    # 3. Decoding loop
    for _ in range(max_len):
        all_candidates = []

        # Expand every live hypothesis with its `beam_size` most likely next tokens,
        # e.g. with beam_size == 3:  AB -> ABC, ABD, ABQ ; AC -> ... ; AD -> ...
        # which yields up to beam_size * beam_size candidates per step.
        for beam in beams:
            y_trg = beam["tokens"]

            # Stop expanding if the last generated item is EOS
            if is_finished(beam):
                completed.append(beam)
                continue

            # Decode the full prefix without a KV cache (same approach as greedy_search).
            # One cache shared by all beams mixes the histories of different hypotheses,
            # and feeding a single token makes the decoder encode it as position 0.
            out, _ = model.decode(src, y_trg, enc_output) # [1, trg_seq_len, trg_vocab_size]
            logits = out[:, -1, :]                     # logits of the last position [1, trg_vocab_size]
            log_probs = F.log_softmax(logits, dim=-1)  # [1, trg_vocab_size]

            # Get top-k next tokens (never more than the vocabulary holds)
            num_expand = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_ids = log_probs.topk(num_expand, dim=-1)

            # Expand beam
            for k in range(num_expand):
                all_candidates.append({
                    "tokens": torch.cat([y_trg, topk_ids[:, k].unsqueeze(0)], dim=1),
                    "log_prob": beam["log_prob"] + topk_log_probs[0, k].item(),
                })

        # Every hypothesis was already finished: nothing left to expand
        if not all_candidates:
            beams = []
            break

        # 4. Keep the `beam_size` best candidates (highest normalised score first)
        beams = sorted(all_candidates, key=score, reverse=True)[:beam_size]

        # Stop early once enough hypotheses have finished
        if len(completed) >= beam_size:
            break

    # Hypotheses that produced <eos> in the very last step were never revisited
    # by the loop above, so collect them here.
    completed.extend(beam for beam in beams if is_finished(beam))
    
    # If no hypothesis ended with <eos>, fall back to whatever has been decoded
    if not completed: 
        completed = beams

    # 5. Choose best result
    # ["tokens"][0] drops the leading batch axis of size 1: [1, seq_len] -> [seq_len]
    best = max(completed, key=score)["tokens"][0].tolist()
    
    return trg_tokenizer.decode(best)

## Train

In [18]:
TRAIN_PRINT_FREQ = 200 # Print a progress line every this many batches

def train(model, optimizer, scheduler, train_loader, criterion, device, epoch_id):
  """
    Run one training epoch (teacher forcing) and return the average loss.

    The scheduler is stepped once per batch, right after the optimizer.
    The returned average weights each batch loss by its number of sequences.
  """
  model.train()
  step_timer   = AverageMeter()
  running_loss = AverageMeter()

  for step, batch in enumerate(train_loader):
    tick = time.time()

    src = batch['input_ids'].to(device)
    trg = batch['labels'].to(device)

    # Teacher forcing: the decoder sees every position but the last one
    # and has to predict every position but the first one (<bos>).
    decoder_input = trg[:, :-1]
    expected      = trg[:, 1:].reshape(-1)               # [B * L]

    # Predict
    logits      = model(src, decoder_input)
    flat_logits = logits.reshape(-1, logits.size(-1))    # [B * L, V]

    # Loss computation and parameter update
    optimizer.zero_grad()
    loss = criterion(flat_logits, expected)
    # loss += l2_regularization(model, L2_LAMBDA) # Add L2 regularization 
    loss.backward()
    optimizer.step()
    scheduler.step()

    running_loss.update(loss.item(), src.size(0))
    step_timer.update(time.time() - tick)

    if step % TRAIN_PRINT_FREQ == 0:
      print('Epoch: [{0}][{1}/{2}]\t'
            'Time(s) {time.val:.3f} ({time.avg:.3f})\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch_id, step, len(train_loader), time=step_timer,
                loss=running_loss))
      # The time average restarts after every report; the loss average covers the whole epoch
      step_timer.reset()

  return running_loss.avg

In [19]:
VAL_PRINT_FREQ = 100 # Print running validation statistics every N batches
NUM_BLEU_EVAL = 10   # NOTE: not referenced by validate(); with translate=True BLEU is computed on every batch

def validate(model, val_loader, criterion, device, trg_tokenizer, translate=False):
  """
    Evaluate the model on one pass over `val_loader` (no gradient tracking).

    Args:
      model: network called as model(src, trg_input).
      val_loader: iterable (with len()) of batches, each a dict holding the
                  tensors 'input_ids' (source) and 'labels' (target).
      criterion: loss taking (flattened predictions, flattened targets).
      device: device the batch tensors are moved to.
      trg_tokenizer: target-language tokenizer; its decode() turns a list of
                     token ids into a sentence.
      translate: when True, every batch is additionally decoded with
                 greedy_batch() and scored with sacreBLEU (slow).

    Returns:
      (loss_meter.avg, bleu_meter.avg).
      The BLEU value is the mean of the per-batch corpus BLEU scores, which is
      not the same as corpus BLEU over the whole validation set. When
      translate is False no BLEU is recorded and the meter's initial average
      is returned.
  """
  model.eval()
  bleu_meter       = AverageMeter()
  batch_time_meter = AverageMeter()
  loss_meter       = AverageMeter()
  num_batch = len(val_loader)

  with torch.no_grad():
    for i, batch in enumerate(val_loader):
      start_time = time.time()

      src = batch['input_ids'].to(device)
      trg = batch['labels'].to(device)

      # Teacher forcing: the decoder sees positions [0, L-1) and has to predict
      # positions [1, L), i.e. the same sequence shifted left by one token.
      trg_input  = trg[:, :-1]
      trg_target = trg[:, 1:]

      # Predict
      output     = model(src, trg_input)
      preds      = output.reshape(-1, output.size(-1)) # [B * L, V]
      trg_target = trg_target.reshape(-1)              # [B * L]
      loss       = criterion(preds, trg_target)

      loss_meter.update(loss.item(), src.size(0))
      batch_time_meter.update(time.time() - start_time)

      if translate:
        # Decoding strategies overview:
        # https://medium.com/nlplanet/two-minutes-nlp-most-used-decoding-methods-for-language-models-9d44b2375612

        # 1. Translate sentences (greedy decoding).
        # CUDA kernels run asynchronously, so synchronise around the call to
        # time it correctly - but only when the batch really lives on a GPU,
        # otherwise torch.cuda.synchronize() fails on CPU-only machines.
        if src.is_cuda:
          torch.cuda.synchronize()
        start_search_t = time.time()

        trg_sentences = greedy_batch(model, src, MAX_SEQ_LEN, trg_tokenizer)

        if src.is_cuda:
          torch.cuda.synchronize()
        end_search_t = time.time()
        print(f"Greedy search took {end_search_t - start_search_t:.4f} s")

        # 2. Compute BLEU score
        # Decode the reference ids into sentences. Only the first and the last
        # column are stripped, so shorter rows still carry their trailing
        # special/padding ids - assumes trg_tokenizer.decode() drops those
        # (TODO confirm). Each row is already 1-D, so it is decoded as a list
        # even when it holds a single token.
        ref_trg_sentences = [
          trg_tokenizer.decode(row.tolist()) for row in trg[:, 1:-1]
        ]

        bleu = sacrebleu.corpus_bleu(trg_sentences, [ref_trg_sentences])
        bleu_score = float(bleu.score)

        bleu_meter.update(bleu_score)
        print(f'Test: [{i}/{num_batch}] BLEU: {bleu_score:.4f} REF: {ref_trg_sentences[0]} -> HYP: {trg_sentences[0]}')

      if i % VAL_PRINT_FREQ == 0:
        print('Test: [{0}/{1}]\t'
              'Time(s) {time.val:.3f} ({time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                  i, num_batch, time=batch_time_meter,loss=loss_meter))
        # Timing is reported per print window, so restart it after each print
        batch_time_meter.reset()
  print('* Val Avg BLEU {top1.avg:.3f}'.format(top1=bleu_meter))

  return loss_meter.avg, bleu_meter.avg

## Main (Training Loop)

In [20]:
# Make sure the output directory exists before anything is written to it
os.makedirs(SAVE_DIR, exist_ok=True)
LOG_FILE_PATH = os.path.join(SAVE_DIR, 'log.csv')

# Create the log file with header if it doesn't exist
if not os.path.exists(LOG_FILE_PATH):
    with open(LOG_FILE_PATH, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'val_loss', 'bleu_score'])

# Save a checkpoint every N epochs (kaggle output storage is limited)
CHECKPOINT_EVERY = 10


def _save_checkpoint(model, optimizer, epoch_id):
  """
    Write the model and optimizer state reached after epoch `epoch_id` (0-based)
    to SAVE_DIR/epoch_<epoch_id>.pth and return the file name.

    The stored 'epoch' field is the number of completed epochs (epoch_id + 1).
  """
  file_name = f"epoch_{epoch_id}.pth"
  print(f'Saving... {file_name}')
  torch.save({
      'epoch': epoch_id + 1,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
  }, os.path.join(SAVE_DIR, file_name))
  return file_name


# Define model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(SRC_VOCAB_SIZE, TRG_VOCAB_SIZE, MAX_SEQ_LEN, NUM_LAYERS, NUM_HEADS, DROPOUT_RATE)
model = model.to(device)

# Define optimizer, learning rate scheduler & loss function
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
scheduler = Scheduler(optimizer, d_model=HIDDEN_DIM, scale_factor=SCHEDULER_FACTOR)
criterion = nn.CrossEntropyLoss(
    ignore_index=TRG_PAD_IDX,        # padded target positions do not contribute to the loss
    label_smoothing=SMOOTHING_FACTOR # the target becomes a mix of the one-hot label and a uniform distribution
                                     # over all V classes: true class gets 1 - factor + factor / V,
                                     # every other class gets factor / V
                                     # -> tells the model that 'the labelling is not 100% accurate, be skeptical'
    )
criterion = criterion.to(device)

# Load the trained model (resume training)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
# NOTE: only model and optimizer state are stored in the checkpoint; the
# scheduler is created fresh above and the epoch counter restarts at 0.
if TRAIN_MODEL_PATH is not None:
  checkpoint = torch.load(TRAIN_MODEL_PATH, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

patience_cnt = 0
best_loss = float('inf')
# Training loop
for epoch_id in range(NUM_EPOCHS):
  # 1. Train the model first
  train_loss = train(model, optimizer, scheduler, train_loader, criterion, device, epoch_id)

  # 2. Test the current model on validation set -> performance evaluation
  #    (translate is left off, so the logged BLEU is only a placeholder)
  val_loss, bleu_score = validate(model, val_loader, criterion, device, trg_tkn)

  log_to_csv(LOG_FILE_PATH, epoch_id, train_loss, val_loss, bleu_score)
  print(f'*** EPOCH {epoch_id}: Train Loss: {train_loss}; Val Loss {val_loss}')

  # 3. Early stopping: an epoch only counts as an improvement when the
  #    validation loss beats the best one so far by more than DELTA
  if val_loss < best_loss - DELTA:
    best_loss = val_loss
    patience_cnt = 0
  else:
    patience_cnt += 1
    if patience_cnt >= PATIENCE:
        print(f'Early stopping at {epoch_id + 1}') # EARLY STOPPING
        # Save under this epoch's own name; reusing the name of the last
        # periodic checkpoint would overwrite it (or fail if none exists yet)
        _save_checkpoint(model, optimizer, epoch_id)
        break

  # 4. Periodic checkpoint
  if (epoch_id + 1) % CHECKPOINT_EVERY == 0:
    _save_checkpoint(model, optimizer, epoch_id)

## Testing

In [None]:
# Checkpoint file restored by the testing and translate cells.
# Those cells skip loading when this is None and run with freshly initialised weights.
TEST_MODEL_PATH  = '/kaggle/input/model1/pytorch/default/1/epoch_109.pth' # IMPORTANT: Change testing model's filename

In [22]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network exactly like the training cell does, so that the
# architecture (number of layers / heads) matches the checkpoint being loaded
model = Transformer(SRC_VOCAB_SIZE, TRG_VOCAB_SIZE, MAX_SEQ_LEN, NUM_LAYERS, NUM_HEADS, DROPOUT_RATE)
model = model.to(device)

# Define loss function (same settings as in training, so losses are comparable)
criterion = nn.CrossEntropyLoss(
    ignore_index=TRG_PAD_IDX,        # padded target positions do not contribute to the loss
    label_smoothing=SMOOTHING_FACTOR # the target becomes a mix of the one-hot label and a uniform distribution
                                     # over all V classes: true class gets 1 - factor + factor / V,
                                     # every other class gets factor / V
    )
criterion = criterion.to(device)

# Restore the trained weights.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
if TEST_MODEL_PATH is not None:
  checkpoint = torch.load(TEST_MODEL_PATH, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])

# translate=True -> also decode every batch and report BLEU
val_loss, bleu_score = validate(model, val_loader, criterion, device, trg_tkn, translate=True)
print(f'LOSS: {val_loss}; BLEU: {bleu_score}')

## Translate

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network exactly like the training cell does, so that the
# architecture (number of layers / heads) matches the checkpoint being loaded
model = Transformer(SRC_VOCAB_SIZE, TRG_VOCAB_SIZE, MAX_SEQ_LEN, NUM_LAYERS, NUM_HEADS, DROPOUT_RATE)
model = model.to(device)

# Restore the trained weights.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
if TEST_MODEL_PATH is not None:
  checkpoint = torch.load(TEST_MODEL_PATH, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])

# A freshly constructed module is in training mode; switch to inference mode
# so dropout is disabled while decoding
model.eval()

eng_sentence = "I am going home tomorrow."  # sentence to translate
kor_sentence = '저는 내일 집에 갑니다.'        # reference translation used for BLEU

# 1. Translate: encode the source sentence as a batch of one and decode greedily
src = torch.tensor(src_tkn.encode(eng_sentence)).unsqueeze(0)
src = src.to(device)
with torch.no_grad():
  trg_sentences = greedy_batch(model, src, MAX_SEQ_LEN, trg_tkn)

# 2. Compute BLEU score of the hypothesis against the single reference
bleu = sacrebleu.corpus_bleu(trg_sentences, [[kor_sentence]])
bleu_score = float(bleu.score)

# 3. Show the translated sentence
print(f'(BLEU: {bleu_score}) {trg_sentences[0]}')

(BLEU: 12.703318703865365) 내일은 집에 갈 예정이예요.
