### **Import Package**

In [32]:
import re 
import math 
import time
import random 
import unicodedata
from pyvi import ViTokenizer
from collections import Counter
from sklearn.model_selection import train_test_split

import torch 
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### **Load Data**

In [5]:
eng_data_path = "datasets/en_sents.txt"
vnm_data_path = "datasets/vi_sents.txt"

In [6]:
# read data 
with open(eng_data_path, 'r', encoding='utf-8') as file: 
    eng_sentences = [line.strip() for line in file if line.strip()]
file.close()

with open(vnm_data_path, 'r', encoding='utf-8') as vi_file: 
    vi_sentences = [line.strip() for line in vi_file if line.strip()]
vi_file.close()

In [7]:
print(eng_sentences[2:4])
print(vi_sentences[2:4])
print(f"Total records: {len(vi_sentences)}")

['Read this', 'Tom persuaded the store manager to give him back his money.']
['đọc này', 'tom thuyết phục người quản lý cửa hàng trả lại tiền cho anh ta.']
Total records: 254090


### **Pre-Processing Data**

In this step, we implement to normalize data before training. This process includes steps below:

B1: Normalization Sentence (process separatly for english sentence and vietnamese sentence)
    
        - For Vietnamese: 
            + remove special quotes 
            + normalize whitespace 
            + apply Vietnamese word segmentation 
            
        - For English: 
            + convert accent characters to unaccent characters 
            + add space before punctuation
            + remove non-letter characters 
            + normalize multiple spaces 

B2: Add a pair <sos> and <eos> to mark the starting and the ending of a sentence. 

B3: Tokenize each sentence. 

B4: Train - Val - Test Split 

B5: Build vocabularies

B6: Convert sentences to index sequences (numericalization)

In [8]:
# English Normalization
def unicode_to_ascii(s):
    return "".join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def normalize_en(s): 
    s = unicode_to_ascii(s.lower().strip())   
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 
    s = re.sub(r"\s+", " ",s)
    return s 

In [9]:
# Vietnamese Normalization
def normalize_vi(s): 
    s = s.lower().strip()
    s = re.sub(r"[“”\"‘’]", "", s)
    s = re.sub(r"\s+", " ", s)
    s = ViTokenizer.tokenize(s)
    return s

In [10]:
# apply to normalize paired data 
def preprocess_translation_pairs(vi_sentences, eng_sentences):
    processed = []
    for vi, eng in zip(vi_sentences, eng_sentences): 
        vi_norm = normalize_vi(vi)
        eng_norm = normalize_en(eng)
        eng_final = f"<sos> {eng_norm} <eos>"
        processed.append([vi_norm, eng_final])
    return processed

In [11]:
processed_data = preprocess_translation_pairs(vi_sentences, eng_sentences)

In [12]:
for vi, en in processed_data[:10]:
    print(f"VI: {vi}")
    print(f"EN: {en}\n")

VI: xin vui_lòng đặt người quét rác trong tủ chổi
EN: <sos> please put the dustpan in the broom closet <eos>

VI: im_lặng một lát
EN: <sos> be quiet for a moment . <eos>

VI: đọc này
EN: <sos> read this <eos>

VI: tom thuyết_phục người quản_lý cửa_hàng trả lại tiền cho anh ta .
EN: <sos> tom persuaded the store manager to give him back his money . <eos>

VI: tình bạn bao_gồm sự hiểu_biết lẫn nhau
EN: <sos> friendship consists of mutual understanding <eos>

VI: ngày_mai bạn có đến không
EN: <sos> are you going to come tomorrow ? <eos>

VI: nhìn thấy vấn_đề này ngay lập_tức , bạn sẽ ?
EN: <sos> see to this matter right away will you ? <eos>

VI: tôi đã cho bạn_bè của tôi xem những tấm bưu_thiếp hình_ảnh .
EN: <sos> i showed my friends these picture postcards . <eos>

VI: mary là em_út trong ba chị_em
EN: <sos> mary is the youngest of the three sisters <eos>

VI: anh ấy có hai người dì ở bên mẹ .
EN: <sos> he has two aunts on his mother s side . <eos>



#### **Tokenize**

In [13]:
# we consider each word as a token 
def tokenize_pairs(processed_data):
    tokenized_data = []
    for vi, en in processed_data:
        vi_tokens = vi.split()
        en_tokens = en.split()
        tokenized_data.append((vi_tokens, en_tokens))
    return tokenized_data

tokenized_data = tokenize_pairs(processed_data)

#### **Train - Val - Test Split** 

In [14]:
train_data, temp_data = train_test_split(tokenized_data, test_size= 0.2, random_state= 42)
val_data, test_data = train_test_split(temp_data, test_size= 0.5, random_state= 42)

#### **Build Vocabulary** 

In [15]:
def build_vocab(tokenized_data, idx, min_freq, specials):
    counter = Counter()
    for vi_tokens, eng_tokens in tokenized_data: 
        if idx == 0: 
            counter.update(vi_tokens)
        else: 
            counter.update(eng_tokens)
    
    vocab = {}
    vocab_idx = 0 
    for token in specials:
        vocab[token] = vocab_idx
        vocab_idx += 1 
    
    for token, freq in counter.items():
        if freq >= min_freq and token not in vocab: 
            vocab[token] = vocab_idx
            vocab_idx += 1 
    return vocab 

In [16]:
vi_vocab = build_vocab(train_data, idx=0, min_freq=1, specials=["<pad>", "<unk>"])
eng_vocab = build_vocab(train_data, idx=1, min_freq=1, specials=["<pad>", "<unk>", "<sos>", "<eos>"])

#### **Numericalization**

In [17]:
def numericalize(tokens, vocab): 
    return [vocab.get(token, vocab["<unk>"]) for token in tokens]

train_numericalized = []
for vi_tokens, eng_tokens in train_data:
    vi_ids = numericalize(vi_tokens, vi_vocab)
    eng_ids = numericalize(eng_tokens, eng_vocab)
    train_numericalized.append((vi_ids, eng_ids))  # convert (eng_ids, vi_ids) when translate english to vietnamese

In [18]:
train_numericalized = []
for vi_tokens, eng_tokens in train_data:
    vi_ids = numericalize(vi_tokens, vi_vocab)
    eng_ids = numericalize(eng_tokens, eng_vocab)
    train_numericalized.append((vi_ids, eng_ids))  # convert (eng_ids, vi_ids) when translate english to vietnamese

val_numericalized = []
for vi_tokens, eng_tokens in val_data:
    vi_ids = numericalize(vi_tokens, vi_vocab)
    eng_ids = numericalize(eng_tokens, eng_vocab)
    val_numericalized.append((vi_ids, eng_ids))  # convert (eng_ids, vi_ids) when translate english to vietnamese

test_numericalized = []
for vi_tokens, eng_tokens in test_data:
    vi_ids = numericalize(vi_tokens, vi_vocab)
    eng_ids = numericalize(eng_tokens, eng_vocab)
    test_numericalized.append((vi_ids, eng_ids))  # convert (eng_ids, vi_ids) when translate english to vietnamese

In [None]:
all_src_tokens = set()
all_tgt_tokens = set()

for src_seq, tgt_seq in train_numericalized + val_numericalized + test_numericalized:
    all_src_tokens.update(src_seq)
    all_tgt_tokens.update(tgt_seq)

vocab_size_src = max(all_src_tokens) + 1  
vocab_size_tgt = max(all_tgt_tokens) + 1

print("Vocab src size:", vocab_size_src)   # number different tokens in source language dictionary (VN)
print("Vocab tgt size:", vocab_size_tgt)   # number different tokens in target language dictionary (ENG)

Vocab src size: 12220
Vocab tgt size: 18246


In [31]:
train_numericalized

[([2, 3, 4, 5, 6, 7, 8], [2, 4, 5, 6, 7, 8, 9, 10, 11, 3]),
 ([9, 10, 11, 12, 13, 14, 15], [2, 12, 13, 14, 15, 16, 17, 18, 3]),
 ([11, 12, 16, 17, 18, 19, 20, 21, 22, 23, 19, 24, 25, 26],
  [2, 19, 5, 14, 20, 21, 22, 23, 24, 25, 26, 13, 27, 28, 3]),
 ([27, 28, 11, 29, 3, 30], [2, 29, 30, 31, 32, 33, 34, 35, 3]),
 ([31, 32, 33, 34, 35, 36, 37, 38], [2, 36, 37, 38, 39, 30, 40, 28, 3]),
 ([39, 40, 41, 42, 38, 30], [2, 41, 42, 43, 44, 30, 35, 3]),
 ([31, 43, 44, 45, 40, 32, 46, 47, 48, 49, 50, 51, 52],
  [2, 36, 45, 46, 43, 47, 48, 39, 49, 50, 51, 3]),
 ([40, 53, 54, 55], [2, 52, 53, 54, 43, 55, 56, 35, 3]),
 ([31, 56, 11, 55, 57, 3, 58], [2, 36, 57, 58, 59, 60, 28, 3]),
 ([9, 59, 60, 61, 62, 2], [2, 12, 61, 9, 62, 3]),
 ([5, 63], [2, 63, 64, 28, 3]),
 ([40, 32, 64, 65, 66, 55], [2, 52, 65, 43, 66, 7, 8, 35, 3]),
 ([3, 11, 67, 68, 69, 70, 71, 72], [2, 19, 67, 33, 68, 69, 70, 71, 72, 73, 3]),
 ([73, 74, 75, 76], [2, 74, 75, 76, 22, 23, 77, 3]),
 ([73, 77, 78, 26], [2, 74, 75, 78, 28, 3]),
 

### **DataLoader**

In [20]:
class TranslationDataset(Dataset): 
    def __init__(self, data): 
        """
        data: list of tuples
        """
        self.data = data 
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        src, tgt = self.data[index]
        return torch.tensor(src, dtype= torch.long), torch.tensor(tgt, dtype=torch.long)


def collate_fn(batch):
    """
    batch: list of (src_tensor, tgt_tensor)
    Returns:
        src_padded: (batch_size, src_len_max)
        tgt_padded: (batch_size, tgt_len_max)
        src_mask: (batch_size, src_len_max) -> mask cho attention
    Note: use look-ahead mask in case process with Transformer 
    """
    src_batch, tgt_batch = zip(*batch) 

    # pad sequences to max length in batch 
    src_padded = pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=0)

    # Mask: True in pad position
    src_mask = (src_padded == 0)
    tgt_mask = (tgt_padded == 0)
    return src_padded, tgt_padded, src_mask, tgt_mask

In [21]:
train_dataset = TranslationDataset(train_numericalized)
val_dataset = TranslationDataset(val_numericalized)
test_dataset = TranslationDataset(test_numericalized)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle= True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle= True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle= True, collate_fn=collate_fn)

### **Seq2Seq Architecture**

In [None]:
class EncoderLSTM(nn.Module):
    def __init__(self, vocab_size_src, hidden_size, num_layers, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size_src, hidden_size, padding_idx= 0)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional= False
        )
    
    def init_hidden(self, batch_size):
        # initialize hidden state (h0) and cell state (c0) filled with zeros 
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return (h0, c0)
    
    def forward(self, src_tokens, hidden):
        emb = self.embedding(src_tokens)  # convert token IDs to embeddings 
        emb = emb.transpose(0, 1)         # switch to (seq_len, batch, hidden) for LSTM
        outputs, hidden = self.lstm(emb, hidden)
        outputs = outputs.transpose(0, 1)
        return outputs, hidden


class AttentionDecoderLSTM(nn.Module):
    def __init__(self, hidden_size, vocab_size_target, num_layers, dropout):
        super(AttentionDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size_target = vocab_size_target

        self.embedding = nn.Embedding(vocab_size_target, hidden_size, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        # combine [context; embedding] into hidden size 
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional= False
        )

        # final output projection 
        self.out = nn.Linear(hidden_size, vocab_size_target)
    
    def forward(self, input_tokens, hidden, encoder_outputs, src_mask=None):
        embedded = self.embedding(input_tokens)   # convert token IDs to embeddings 
        embedded = self.dropout(embedded)

        # dot-product attention: score = encoder_output * decoder_hidden_last 
        h_last = hidden[0][-1]
        scores = torch.bmm(encoder_outputs, h_last.unsqueeze(2)).squeeze(2)

        if src_mask is not None: 
            scores = scores.masked_fill(src_mask, float('-inf'))
        attn_weights = F.softmax(scores, dim= 1)   # softmax to get attention weights

        # compute context vector as weighted sum of encode outputs 
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1) 

        # concatenate context with embedding and transform 
        combo = torch.cat([embedded, context], dim=1) 
        combo = torch.tanh(self.attn_combine(combo))

        # LSTM expects (seq_len=1, batch, hidden)
        lstm_in = combo.unsqueeze(0) 
        output, hidden_next = self.lstm(lstm_in, hidden)
        output = output.squeeze(0)
        log_probs = F.log_softmax(self.out(output), dim=1)

        return log_probs, hidden_next, attn_weights

#### **Training**

In [None]:
class Seq2SeqTrainConfig:
    def __init__(self):
        self.n_epochs = 2
        self.learning_rate = 1e-3
        self.teacher_forcing_ratio = 0.5
        self.max_length = 20
        self.SOS_token = 2
        self.EOS_token = 3
        self.optimizer = torch.optim.Adam
        self.ignore_index = 0   # PAD token ids 
        self.clip_grad_norm = 5.0
        self.hidden_size = 256 
        self.num_layers = 3
        self.dropout = 0.1
        self.model_save_path = "seq2seq_attn.pth"
        self.loss_log_path = "loss_history.txt"

In [None]:
def train_model(encoder, decoder, train_loader, config: Seq2SeqTrainConfig): 
    loss_history = []
    encoder.to(device)
    decoder.to(device)
    encoder.train()
    decoder.train()

    encoder_optimizer = config.optimizer(encoder.parameters(), lr=config.learning_rate)
    decoder_optimizer = config.optimizer(decoder.parameters(), lr=config.learning_rate)
    criterion = nn.NLLLoss(ignore_index= config.ignore_index, reduction="sum")

    total_start = time.time()
    for epoch in range(config.n_epochs + 1): 
        epoch_start = time.time() 
        total_loss = 0.0
        total_tokens = 0
        batch_count = 0

        for src_padded, tgt_padded, src_mask, _ in train_loader: 
            src_padded = src_padded.to(device)
            tgt_padded = tgt_padded.to(device)
            src_mask = src_mask.to(device)

            batch_size = src_padded.size(0)
            batch_count += 1

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # encode 
            encode_hidden = encoder.init_hidden(batch_size)
            encode_outputs, encode_hidden = encoder(src_padded, encode_hidden)

            # decode 
            decode_input = torch.full((batch_size,), config.SOS_token, dtype=torch.long, device=device)
            decode_hidden = encode_hidden
            
            tgt_len = tgt_padded.size(1) 
            use_tf = random.random() < config.teacher_forcing_ratio 
            non_pad = (tgt_padded != config.ignore_index).sum().item()
            total_tokens += non_pad
            loss = 0.0

            for t in range(tgt_len):
                log_probs, decode_hidden, _ = decoder(decode_input, decode_hidden, encode_outputs, src_mask)
                gold = tgt_padded[:, t]
                loss += criterion(log_probs, gold)
                decode_input = gold if use_tf else log_probs.argmax(dim=1)

            # backprop 
            loss.backward()
            nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=config.clip_grad_norm)
            nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=config.clip_grad_norm)
            encoder_optimizer.step()
            decoder_optimizer.step()

            total_loss += loss.item()

        # end of epoch 
        avg_loss = total_loss / max(total_tokens, 1)
        loss_history.append(avg_loss) 
        epoch_time = time.time() - epoch_start
        tokens_per_sec = total_tokens / epoch_time

        print(f"[Epoch {epoch}/{config.n_epochs}] "
              f"Loss/token: {avg_loss:.4f} | "
              f"Tokens: {total_tokens} | "
              f"Batches: {batch_count} | "
              f"Time: {epoch_time:.2f}s | "
              f"{tokens_per_sec:.1f} tok/s")

    # measure training time 
    total_time = time.time() - total_start
    print(f"Training complete in {total_time:.2f}s") 

    # save models 
    torch.save({"encoder": encoder.state_dict(),
                "decoder": decoder.state_dict()}, config.model_save_path)
    print(f"Saved model to {config.model_save_path}")

    # save logging 
    with open(config.loss_log_path, "w") as f:
        for loss in loss_history:
            f.write(f"{loss}\n")

In [None]:
params = Seq2SeqTrainConfig()

encoder = EncoderLSTM(
    vocab_size_src= vocab_size_src, 
    hidden_size= params.hidden_size,
    num_layers= params.num_layers, 
    dropout= params.dropout
)

decoder = AttentionDecoderLSTM(
    hidden_size= params.hidden_size, 
    vocab_size_target= vocab_size_tgt, 
    num_layers= params.num_layers, 
    dropout= params.dropout
)

train_model(encoder, decoder, train_loader, params)

### **Transformer**

In [None]:
class TransformerTrainConfig:
    def __init__(self):
        # training hyper-parameters
        self.n_epochs = 2
        self.learning_rate = 1e-3
        self.teacher_forcing_ratio = 0.5
        self.clip_grad_norm = 5.0
        self.betas = (0.9, 0.98)
        self.eps = 1e-9
        
        # model architecture
        self.d_model = 512
        self.nhead = 8
        self.num_encoder_layers = 6
        self.num_decoder_layers = 6
        self.dim_feedforward = 2048
        self.dropout = 0.1
        
        # sequence parameters 
        self.max_length = 20
        self.SOS_token = 2
        self.EOS_token = 3
        self.ignore_index = 0  # PAD token id
        
        # optimizer
        self.optimizer = torch.optim.Adam
        
        # save paths
        self.model_save_path = "transformer_model.pth"
        self.loss_log_path = "transformer_loss_history.txt"

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len= 5000): 
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x): 
        return x + self.pe[:x.size(0), :]


class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, config: TransformerTrainConfig):
        super().__init__()
        self.d_model = config.d_model
        self.pad_idx = config.ignore_index
        self.config = config
        
        # embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, config.d_model, padding_idx=config.ignore_index)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, config.d_model, padding_idx=config.ignore_index)

        # positional encodings 
        self.pos_encoding = PositionalEncoding(config.d_model, config.max_length)

        # built Transformer by Pytorch 
        self.transformer = nn.Transformer(
            d_model = config.d_model,
            nhead = config.nhead, 
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=False
        )

        self.output_projection = nn.Linear(config.d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(config.dropout)
    
    def create_mask(self, src, tgt):
        tgt_seq_len = tgt.shape[1]

        src_padding_mask = (src == self.pad_idx)   # source padding mask 
        tgt_padding_mask = (tgt == self.pad_idx)   # target padding mask 

        # target look-ahead mask 
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_seq_len).to(tgt.device) 
        return src_padding_mask, tgt_padding_mask, tgt_mask
    
    def forward(self, src, tgt): 
        # create masks 
        src_padding_mask, tgt_padding_mask, tgt_mask = self.create_mask(src, tgt[:, :-1])
        
        # Embedding + scaling + positional encoding
        src_emb = self.src_embedding(src) * math.sqrt(self.config.d_model)  
        tgt_emb = self.tgt_embedding(tgt[:, :-1]) * math.sqrt(self.config.d_model)  
        
        # transpose for format
        src_emb = src_emb.transpose(0, 1)  
        tgt_emb = tgt_emb.transpose(0, 1)  
        
        #add positional encoding
        src_emb = self.pos_encoding(src_emb)
        tgt_emb = self.pos_encoding(tgt_emb)
        
        src_emb = self.dropout(src_emb)
        tgt_emb = self.dropout(tgt_emb)
        
        # transformer forward 
        output = self.transformer(
            src_emb, 
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask
        )
        
        output = output.transpose(0, 1) 
        output = self.output_projection(output) 
        
        return output

In [None]:
def train_model(model, train_loader, val_loader, config: TransformerTrainConfig):
    model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index= config.ignore_index)  # ignore padding 
    optimizer = config.optimizer(
        model.parameters(), 
        lr=config.learning_rate,
        betas=config.betas,
        eps=config.eps
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.8)
    best_val_loss = float('inf') 
    loss_history = []

    for epoch in range(config.n_epochs):
        model.train() 
        total_train_loss = 0 
        for batch_idx, (src, tgt, _, _) in enumerate(train_loader):
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()

            output = model(src, tgt)   # forward pass 
            output = output.reshape(-1, output.size(-1))  # reshape for loss
            target = tgt[:, 1:].reshape(-1)
            
            loss = criterion(output, target)
            loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            total_train_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # validation 
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for src, tgt, _, _ in val_loader:
                src, tgt = src.to(device), tgt.to(device)
                output = model(src, tgt)
                output = output.reshape(-1, output.size(-1))
                target = tgt[:, 1:].reshape(-1)

                val_loss = criterion(output, target)
                total_val_loss += val_loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        print(f'Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}')

        loss_history.append((avg_train_loss, avg_val_loss))
        
        # save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), config.model_save_path)
            print(f"Saved best model to {config.model_save_path}")
            
        scheduler.step()

    # save history 
    with open(config.loss_log_path, 'w') as f:
        f.write("Epoch,Train_Loss,Val_Loss\n")
        for i, (train_loss, val_loss) in enumerate(loss_history):
            f.write(f"{i+1},{train_loss:.6f},{val_loss:.6f}\n")

In [None]:
config = TransformerTrainConfig()

model = TransformerModel(
    src_vocab_size=vocab_size_src, 
    tgt_vocab_size=vocab_size_tgt, 
    config=config
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Config - d_model: {config.d_model}, nhead: {config.nhead}")
print(f"Config - epochs: {config.n_epochs}, lr: {config.learning_rate}")

train_model(
    model= model,
    train_loader= train_loader,
    val_loader= val_loader, 
    config= config
)
