# Sử dụng mô hình Transformer để giải quyết bài toán dịch máy

## Dependency

In [56]:
import os
import random
import pandas as pd
import sentencepiece as spm
import sacrebleu
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import zipfile
from tqdm import tqdm
import math
import time
from torch.cuda.amp import autocast, GradScaler

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.8.0
CUDA available: False


## Configuration and Hyperparameters

In [57]:
# Data path
INPUT_DIR = "../../data/IWSLT15/"
TRAIN_EN_PATH = f"{INPUT_DIR}train.en.txt"
TRAIN_VI_PATH = f"{INPUT_DIR}train.vi.txt"

TEST_EN_12_PATH = f"{INPUT_DIR}tst2012.en.txt"
TEST_VI_12_PATH = f"{INPUT_DIR}tst2012.vi.txt"

TEST_EN_13_PATH = f"{INPUT_DIR}tst2013.en.txt"
TEST_VI_13_PATH = f"{INPUT_DIR}tst2013.vi.txt"

SAVE_DIR = "./checkpoints"
SPM_EN_PREFIX = os.path.join(SAVE_DIR, "spm_en")
SPM_VI_PREFIX = os.path.join(SAVE_DIR, "spm_vi")

# Model hyperparameters
VOCAB_SIZE = 15000
SEED = 42
MAX_LEN=100
BATCH_SIZE=64

# Special tokens
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2
UNK_ID = 3

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device use: {DEVICE}")

# Set random seeds
os.makedirs(SAVE_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

Device use: cpu


## Data loading and Preprocessing

In [58]:
def load_parallel(src_path, tgt_path):
    """
    Đọc dữ liệu song ngữ và làm sạch:
    - loại dòng rỗng
    - loại dòng chỉ có '.'
    - luôn giữ dữ liệu theo cặp (src[i] tương ứng tgt[i])
    """

    cleaned_src = []
    cleaned_tgt = []

    bad_lines = []

    with open(src_path, "r", encoding="utf-8") as fs, \
         open(tgt_path, "r", encoding="utf-8") as ft:

        for idx, (s, t) in enumerate(zip(fs, ft), start=1):
            s = s.strip()
            t = t.strip()

            # if either side empty or is just "."
            if (not s) or (not t) or s == "." or t == ".":
                bad_lines.append(idx)
                continue

            cleaned_src.append(s)
            cleaned_tgt.append(t)

    print("Số cặp bị loại:", len(bad_lines))
    print("Các dòng bị loại:", bad_lines[:20], "..." if len(bad_lines) > 20 else "")

    return cleaned_src, cleaned_tgt


#Loading training and test data
train_src, train_tgt = load_parallel(TRAIN_EN_PATH, TRAIN_VI_PATH)
test_src, test_tgt = load_parallel(TEST_EN_12_PATH, TEST_VI_12_PATH)

# test_src = readlines(TEST_EN_13_PATH)
# test_tgt = readlines(TEST_VI_13_PATH)

print(f"Training samples: {len(train_src)}")
print(f"Test sample: {len(test_src)}")
print(f"\nExample English sentence: {train_src[1]}")
print(f"Example Vietnamese sentence: {train_tgt[1]}")

Số cặp bị loại: 153
Các dòng bị loại: [470, 8696, 9763, 10708, 21739, 26409, 29495, 38601, 39558, 41018, 48826, 50895, 51587, 54159, 56298, 57141, 57747, 58331, 66261, 68756] ...
Số cặp bị loại: 0
Các dòng bị loại: [] 
Training samples: 133164
Test sample: 1553

Example English sentence: In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .
Example Vietnamese sentence: Trong 4 phút , chuyên gia hoá học khí quyển Rachel Pike giới thiệu sơ lược về những nỗ lực khoa học miệt mài đằng sau những tiêu đề táo bạo về biến đổi khí hậu , cùng với đoàn nghiên cứu của mình -- hàng ngàn người đã cống hiến cho dự án này -- một chuyến bay mạo hiểm qua rừng già để tìm kiếm thông tin về một phân tử then chốt .


## SentencePiece Tokenization
Train BPE tokenizers for both English and Vietnamese

In [59]:
def train_spm(input_file, model_prefix, vocab_size=VOCAB_SIZE):
    """Train a SentencePiece BPE model"""
    args = (
        f"--input={input_file} "
        f"--model_prefix={model_prefix} "
        f"--vocab_size={vocab_size} "
        "--model_type=bpe "
        "--character_coverage=1.0 "
        f"--pad_id={PAD_ID} "
        f"--unk_id={UNK_ID} "
        f"--bos_id={BOS_ID} "
        f"--eos_id={EOS_ID}"
    )

    spm.SentencePieceTrainer.Train(args)
    print(f"Trained SentencePiece model: {model_prefix}.model")

def load_sp(model_path):
    """Load a trained SentencePiece model"""
    sp = spm.SentencePieceProcessor()
    sp.Load(model_path)
    return sp

In [60]:
# Train English tokenizer
tmp_en = os.path.join(SAVE_DIR, "tmp_en.txt")
if not os.path.exists(SPM_EN_PREFIX + ".model"):
    with open(tmp_en, 'w', encoding='utf-8') as f:
        for s in train_src:
            f.write(s + "\n")
    train_spm(tmp_en, SPM_EN_PREFIX)

# Train Vietnamese tokenizer
tmp_vi = os.path.join(SAVE_DIR, "tmp_vi.txt")
if not os.path.exists(SPM_VI_PREFIX + ".model"):
    with open(tmp_vi, 'w', encoding='utf-8') as f:
        for s in train_tgt:
            f.write(s + "\n")
    train_spm(tmp_vi, SPM_VI_PREFIX)

# Load tokenizers
sp_en = load_sp(SPM_EN_PREFIX + ".model")
sp_vi = load_sp(SPM_VI_PREFIX + ".model")

print(f"\nEnglish vocab size: {sp_en.GetPieceSize()}")
print(f"Vietnamese vocab size: {sp_vi.GetPieceSize()}")

# Test tokenization
test_sent = train_src[0]
tokens = sp_en.encode(test_sent)
sent = sp_en.decode(tokens)
print(f"\nExample tokenization:")
print(f"Original: {sent}")
print(f"Token IDs: {tokens[:20]}...")


English vocab size: 15000
Vietnamese vocab size: 15000

Example tokenization:
Original: Rachel Pike : The science behind a climate headline
Token IDs: [10717, 299, 1267, 214, 155, 1116, 1724, 6, 2089, 10320]...


## Dataset and Dataloader

### Dataset

In [61]:
class TranslationDataset(Dataset):
    """
    Dataset Parallel 
    """
    def __init__(self, src, tgt, sp_src, sp_tgt, max_len=MAX_LEN):
        self.src = src
        self.tgt = tgt
        self.sp_src = sp_src
        self.sp_tgt = sp_tgt
        self.max_len = max_len
    
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, index):
        src_ids = [BOS_ID] + self.sp_src.encode(self.src[index])[:self.max_len-2] + [EOS_ID]
        tgt_ids = [BOS_ID] + self.sp_tgt.encode(self.tgt[index])[:self.max_len-2] + [EOS_ID]

        return torch.tensor(src_ids), torch.tensor(tgt_ids)


### Dataloader

In [62]:
# Giữ nguyên TranslationDataset của bạn
# Tối ưu collate_fn một chút để an toàn hơn
def collate_fn(batch):
    src_list, tgt_list = zip(*batch)
    
    # Pad sequence bằng hàm có sẵn của Pytorch (nhanh hơn loop thủ công)
    from torch.nn.utils.rnn import pad_sequence
    
    # Batch first = True cho pad_sequence
    src_pad = pad_sequence(src_list, batch_first=True, padding_value=PAD_ID)
    tgt_pad = pad_sequence(tgt_list, batch_first=True, padding_value=PAD_ID)
    
    # Tách tgt_in (bỏ token cuối) và tgt_out (bỏ token đầu)
    # Lưu ý: slicing tensor vẫn giữ nguyên shape batch
    tgt_in = tgt_pad[:, :-1]
    tgt_out = tgt_pad[:, 1:]
    
    return src_pad, tgt_in, tgt_out

In [63]:
# Create training dataset and dataloader
# Split data: 90% train, 10% validation

total_samples_src = len(train_src)
total_samples_tgt = len(train_tgt)
train_size = int(0.9 * total_samples_src)

train_src_split = train_src[:train_size]
train_tgt_split = train_tgt[:train_size]
val_src = train_src[train_size:]
val_tgt = train_tgt[train_size:]

print(f"Total samples src: {total_samples_src}")
print(f"Total samples src: {total_samples_tgt}")
print(f"Training size: {len(train_tgt_split)}")
print(f"Validation size: {len(val_src)}")

train_dataset = TranslationDataset(src=train_src_split, tgt=train_tgt_split, sp_src=sp_en, sp_tgt=sp_vi)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

val_dataset = TranslationDataset(src=val_src, tgt=val_tgt, sp_src=sp_en, sp_tgt=sp_vi)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

print(f"\nTraining batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

src_tk = val_dataset[1][0].tolist()
tgt_tk = val_dataset[1][1].tolist()

print(src_tk)
print(tgt_tk)
print(sp_en.decode(src_tk))
print(sp_vi.decode(tgt_tk))

Total samples src: 133164
Total samples src: 133164
Training size: 119847
Validation size: 13317

Training batches: 1873
Validation batches: 209
[1, 134, 87, 101, 10, 1161, 276, 326, 748, 211, 127, 240, 1324, 12, 2]
[1, 383, 26, 38, 1481, 308, 167, 531, 712, 14, 2]
So this was the model which actually came out -- very amazing .
Đây là một sơ đồ rất tuyệt vời ,


In [64]:
for src, tgt_in, tgt_out in train_loader:
    print(f"SRC: {src}")
    print(f"\nTGT_IN: {tgt_in}")
    print(f"\nTGT_OUT: {tgt_out}")
    break

SRC: tensor([[    1, 11950,   619,  ...,     0,     0,     0],
        [    1,   502,    65,  ...,     0,     0,     0],
        [    1,   502,    65,  ...,     0,     0,     0],
        ...,
        [    1,   344,     8,  ...,     0,     0,     0],
        [    1,   502,   109,  ...,     0,     0,     0],
        [    1,    83,   311,  ...,     0,     0,     0]])

TGT_IN: tensor([[   1, 1017,  128,  ...,    0,    0,    0],
        [   1,  482,  167,  ...,    0,    0,    0],
        [   1,  482,  447,  ...,    0,    0,    0],
        ...,
        [   1,  522,   70,  ...,    0,    0,    0],
        [   1,  482,  536,  ...,    0,    0,    0],
        [   1,   93,  424,  ...,    0,    0,    0]])

TGT_OUT: tensor([[1017,  128,  486,  ...,    0,    0,    0],
        [ 482,  167,  172,  ...,    0,    0,    0],
        [ 482,  447, 1058,  ...,    0,    0,    0],
        ...,
        [ 522,   70,  486,  ...,    0,    0,    0],
        [ 482,  536,   38,  ...,    0,    0,    0],
        [  93, 

## TransformerMT

In [65]:
class TransformerMT(nn.Module):
    def __init__(self, sp_src_size, sp_tgt_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, pad_idx=0):
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.pad_idx = pad_idx
        
        self.src_tok_emb = nn.Embedding(sp_src_size, d_model, padding_idx=pad_idx)
        self.tgt_tok_emb = nn.Embedding(sp_tgt_size, d_model, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        # QUAN TRỌNG: batch_first=True
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True) 
                                          
        self.generator = nn.Linear(d_model, sp_tgt_size)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.src_tok_emb.weight.data.uniform_(-initrange, initrange)
        self.tgt_tok_emb.weight.data.uniform_(-initrange, initrange)
        self.generator.bias.data.zero_()
        self.generator.weight.data.uniform_(-initrange, initrange)

    def encode(self, src, src_key_padding_mask=None):
        # src: [batch, seq_len] -> embedding: [batch, seq_len, d_model]
        src_emb = self.src_tok_emb(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoder(src_emb)
        # Không cần transpose nữa
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt_emb = self.tgt_tok_emb(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)
        return self.transformer.decoder(tgt_emb, memory,
                                          tgt_mask=tgt_mask,
                                          tgt_key_padding_mask=tgt_key_padding_mask,
                                          memory_key_padding_mask=memory_key_padding_mask)

    def forward(self, src, tgt_in, src_key_padding_mask=None, tgt_key_padding_mask=None, tgt_mask=None):
        src_emb = self.src_tok_emb(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoder(src_emb)
        
        tgt_emb = self.tgt_tok_emb(tgt_in) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)

        # Transformer với batch_first=True tự động xử lý
        outs = self.transformer(src_emb, tgt_emb, 
                                tgt_mask=tgt_mask, 
                                src_key_padding_mask=src_key_padding_mask,
                                tgt_key_padding_mask=tgt_key_padding_mask,
                                memory_key_padding_mask=src_key_padding_mask)
        
        return self.generator(outs)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # compute positional encodings once
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # odd dims: last column will be zero for cos
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return self.dropout(x)
    
    
def generate_square_subsequent_mask(sz):
    mask = torch.triu(torch.ones((sz, sz), device='cpu') * float('-inf'), diagonal=1)
    return mask

## Helpers

In [66]:
model = TransformerMT(sp_src_size=VOCAB_SIZE, sp_tgt_size=VOCAB_SIZE)

### Train/Val epoch

In [67]:
def translate_sentence(model, src_tokens, sp_src, sp_tgt, device, max_len=100):
    model.eval()
    with torch.no_grad():
        # Encode src
        src_ids = torch.tensor([sp_src.encode(src_tokens)], dtype=torch.long).to(device)
        src_mask = (src_ids == PAD_ID)
        memory = model.encode(src_ids, src_key_padding_mask=src_mask)
        
        # Bắt đầu với BOS
        ys = torch.tensor([[BOS_ID]], dtype=torch.long).to(device)
        
        for i in range(max_len):
            tgt_mask = generate_square_subsequent_mask(ys.size(1)).to(device)
            
            # Decode (chú ý batch_first logic)
            out = model.decode(ys, memory, tgt_mask=tgt_mask, memory_key_padding_mask=src_mask)
            
            # Generator
            prob = model.generator(out[:, -1]) # Chỉ lấy token cuối cùng
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([ys, torch.tensor([[next_word]], device=device)], dim=1)
            if next_word == EOS_ID:
                break
        
        # Decode token IDs sang text
        final_ids = ys.squeeze(0).tolist()
        # Remove BOS/EOS if needed for cleaner output
        if final_ids[0] == BOS_ID: final_ids = final_ids[1:]
        if final_ids[-1] == EOS_ID: final_ids = final_ids[:-1]
            
        return sp_tgt.decode(final_ids)

In [68]:
sample_token = train_dataset[0][0]
sample_sent = sp_en.decode(sample_token.tolist())
print(sample_token)
print(sample_sent)

translate_sentence(model=model, src_tokens=sample_sent, sp_src=sp_en, sp_tgt=sp_vi, device=DEVICE, max_len=10)

tensor([    1, 10717,   299,  1267,   214,   155,  1116,  1724,     6,  2089,
        10320,     2])
Rachel Pike : The science behind a climate headline


'itageitageitageitageitageitageitageitageitageitage'

In [69]:
# ----------------------------
# Training/Validation loops
# ----------------------------
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, scaler):
    model.train()
    total_loss = 0.0
    
    p_bar = tqdm(dataloader, desc="Train", leave=False)
    for src, tgt_in, tgt_out in p_bar:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

        # Masks
        src_key_padding_mask = (src == PAD_ID)
        tgt_key_padding_mask = (tgt_in == PAD_ID)
        tgt_mask = generate_square_subsequent_mask(tgt_in.size(1)).to(device)

        optimizer.zero_grad()
        
        # Mixed Precision Training
        with autocast(): # Tự động chuyển float32 -> float16 khi cần
            output = model(src, tgt_in, 
                           src_key_padding_mask=src_key_padding_mask, 
                           tgt_key_padding_mask=tgt_key_padding_mask, 
                           tgt_mask=tgt_mask)
            
            # Reshape để tính loss
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_out.reshape(-1))

        # Backward với scaler
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        scaler.step(optimizer)
        scaler.update()
        
        # QUAN TRỌNG: Step scheduler tại mỗi batch
        scheduler.step()
        
        total_loss += loss.item()
        
        # Update progress bar
        p_bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    return total_loss / len(dataloader)

def val_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
      for src_batch, tgt_in_batch, tgt_out_batch in dataloader:
          src = src_batch.to(device)
          tgt_in = tgt_in_batch.to(device)
          tgt_out = tgt_out_batch.to(device)

          src_key_padding_mask = (src == 0)
          tgt_key_padding_mask = (tgt_in == 0)

          tgt_mask = generate_square_subsequent_mask(sz=tgt_in.size(1)).to(device)
          output = model(src, tgt_in,
                        src_key_padding_mask=src_key_padding_mask,
                        tgt_key_padding_mask=tgt_key_padding_mask,
                        memory_key_padding_mask=src_key_padding_mask,
                        tgt_mask=tgt_mask)
          # output shape = (batch, tgt_len, vocab)
          loss = criterion(output.view(-1, output.size(-1)), tgt_out.view(-1))
          total_loss += loss.item() * src.size(0)
      return total_loss/len(dataloader.dataset)
    
@torch.no_grad()
def evaluate_bleu(model, dataloader, sp_src, sp_tgt, device):
    """Evaluate model using SacreBLEU metric."""
    model.eval()
    preds, refs = [], []

    pbar = tqdm(dataloader, desc="Evaluating", leave=False)

    for src, tgt_in, tgt_out in pbar:
        src = src.to(device)
        tgt_in = tgt_in.to(device)

        # ===== Decode source =====
        src_ids = src[0].tolist()         # lấy 1 câu trong batch
        src_ids = [x for x in src_ids if x != 0]   # remove padding
        src_text = sp_src.decode(src_ids)

        # print("="*80)
        # print(f"SRC_sentence: {src_text}")
        # ===== Translate =====
        pred_text = translate_sentence(
            model,
            src_text,
            sp_src=sp_src,
            sp_tgt=sp_tgt,
            device=device
        )

        # print(f"PRED_sentence: {pred_text}")

        # ===== Decode reference =====
        ref_ids = tgt_in[0].tolist()
        ref_ids = [x for x in ref_ids if x != 0]
        ref_text = sp_tgt.decode(ref_ids)

        # print(f"REF_sentence: {ref_text}")

        # print("="*80)
        preds.append(pred_text)
        refs.append(ref_text)

    bleu = sacrebleu.corpus_bleu(preds, [refs])
    return bleu.score
# evaluate_bleu(model=model, dataloader=val_loader, sp_src=sp_en, sp_tgt=sp_vi, device=DEVICE)

### Early stopping

In [70]:
# ----------------------------
# Early Stopping
# ----------------------------
class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.
    """
    def __init__(self, patience=5, delta=0.0):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def step(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            return False
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True
            return False
        else:
            self.best_score = score
            self.counter = 0
            return False

### Save/load model

In [71]:
# ----------------------------
# Save / Load
# ----------------------------
def save_checkpoint(path, model, optimizer, scheduler, epoch, best=False):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optim_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
    }
    torch.save(state, path)
    if best:
        best_path = os.path.splitext(path)[0] + ".best.pt"
        torch.save(state, best_path)

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer and ckpt.get('optim_state_dict'):
        optimizer.load_state_dict(ckpt['optim_state_dict'])
    if scheduler and ckpt.get('scheduler_state_dict'):
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    return ckpt.get('epoch', 0)


### Train

In [72]:
def rate(step, model_size, factor, warmup):
    """
    Công thức Noam learning rate schedule.
    """
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

# Khi khởi tạo trong main:
# optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: rate(step, 512, 1.0, 4000))

In [None]:
# ============================
# Config
# ============================
config = {
    "lr": 1e-7,
    "epochs": 50,
    "patience": 5,
    "run_name": "transformer_mt_run",
    "save_dir": "./checkpoints",
    "use_wandb": True,
}

# ============================
# Model setup
# ============================
model = TransformerMT(VOCAB_SIZE, VOCAB_SIZE)
model.to(DEVICE)

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-7, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: rate(step, 512, 1.0, 4000))
scaler = GradScaler() # Cho mixed precision
# ============================
# Wandb (optional)
# ============================
use_wandb = config["use_wandb"]

if use_wandb:
    import wandb
    wandb.init(project="mt-project", name=config["run_name"], config=config)
    wandb.watch(model, log="gradients", log_freq=100)

# ============================
# Early stopping
# ============================
early_stopper = EarlyStopping(patience=config["patience"])
best_val_loss = float("inf")

# ============================
# Training Loop
# ============================
bleu = 0.0
for epoch in range(1, config["epochs"] + 1):
    start = time.time()

    # Train
    train_loss = train_epoch(model=model, dataloader=train_loader, criterion=criterion,
                            optimizer=optimizer, device=DEVICE, scaler=scaler, scheduler=scheduler)

    # Eval loss + BLEU
    val_loss = val_epoch(model, val_loader, criterion, DEVICE)

    epoch_time = time.time() - start


    # Print
    print(
        f"Epoch {epoch:02d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Time: {epoch_time:.1f}s"
    )
    
    if epoch % 5 == 0:
        bleu = evaluate_bleu(model, val_loader, sp_en, sp_vi, DEVICE)
        print(f"Epoch {epoch} | BLEU: {bleu}")

    # Wandb
    if use_wandb:
        wandb.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "bleu_score": bleu,
            "epoch": epoch,
            "time": epoch_time
        })

    # Save last
    latest_path = os.path.join(config["save_dir"], f"{config['run_name']}.epoch{epoch}.pt")
    save_checkpoint(latest_path, model, optimizer, scheduler, epoch, best=False)

    # Save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_path = os.path.join(config["save_dir"], f"{config['run_name']}.best.pt")
        save_checkpoint(best_path, model, optimizer, scheduler, epoch, best=True)
        print(f"New best model saved to {best_path}")

    # Early stopping
    if early_stopper.step(val_loss):
        print(f"Early stopping triggered at epoch {epoch}.")
        break



# ============================
# Save final model
# ============================
final_path = os.path.join(config["save_dir"], f"{config['run_name']}.final.pt")
save_checkpoint(final_path, model, optimizer, scheduler, epoch, best=False)
print("Training finished. Final model saved to", final_path)

if use_wandb:
    wandb.finish()

# ============================
# Quick inference demo
# ============================
print("\n=== PREDICTION DEMO ===")
model.eval() # Đảm bảo model ở chế độ eval
for _ in range(5): # Demo 5 câu thôi cho gọn
    if len(val_src) > 0:
        i = random.randrange(len(val_src))
        src_example = val_src[i]
        tgt_example = val_tgt[i]

        print(f"Source: {src_example}")
        print(f"Target: {tgt_example}")

        # Sửa tên tham số cho đúng định nghĩa hàm
        pred = translate_sentence(
            model, 
            src_example, 
            sp_src=sp_en,  # Đúng tên tham số là sp_src
            sp_tgt=sp_vi,  # Đúng tên tham số là sp_tgt
            device=DEVICE
        )
        print(f"Pred  : {pred}")
        print("-" * 50)

  scaler = GradScaler() # Cho mixed precision


  with autocast(): # Tự động chuyển float32 -> float16 khi cần
socket.send() raised exception.                                                   
