In [None]:
# -*- coding: utf-8 -*-
# Cần dòng trên nếu chạy trên môi trường không mặc định UTF-8

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Hugging Face Tokenizers
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing # For SOS/EOS tokens

import numpy as np
import random
import math
import time
import os
import re
import requests # For downloading GloVe
from tqdm import tqdm # Progress bar for download
import zipfile # For unzipping GloVe
import traceback # For printing detailed errors

# --- Constants ---
# File-system locations and remote resources used by the rest of the script.
DATA_DIR = 'data'
EN_FILE = os.path.join(DATA_DIR, 'en_sents')
VI_FILE = os.path.join(DATA_DIR, 'vi_sents')
TOKENIZER_DIR = 'tokenizers' # Directory where trained tokenizers are stored
GLOVE_DIR = 'glove_data'   # Directory holding the GloVe files
GLOVE_ZIP_URL = 'http://nlp.stanford.edu/data/glove.6B.zip' # GloVe download link (example)
GLOVE_ZIP_FILENAME = 'glove.6B.zip' # Name of the zip archive
GLOVE_FILENAME = 'glove.6B.300d.txt' # Selected GloVe file (e.g. the 300-dimensional one)
GLOVE_PATH = os.path.join(GLOVE_DIR, GLOVE_FILENAME)
MODEL_SAVE_PATH = 'seq2seq-gru-glove-hf.pt' # File name for the saved model

# Special tokens (in the style used by Hugging Face Tokenizers)
UNK_TOKEN = "[UNK]"
PAD_TOKEN = "[PAD]"
SOS_TOKEN = "[SOS]" # Or [CLS], whichever you prefer
EOS_TOKEN = "[EOS]" # Or [SEP], whichever you prefer
SPECIAL_TOKENS = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN]

# --- Hyperparameters ---
MAX_VOCAB_SIZE = 20000     # Vocabulary size limit (example value)
EMBEDDING_DIM_GLOVE = 300 # Must match the GloVe file in use
EMBEDDING_DIM_VI = 256   # Embedding size for Vietnamese (may differ from GloVe)
HIDDEN_DIM = 512
NUM_LAYERS = 2
DROPOUT_RATE = 0.3
LEARNING_RATE = 0.001
BATCH_SIZE = 128       # Lower this on Out-of-Memory (OOM) errors
NUM_EPOCHS = 10      # Number of training epochs (example value)
CLIP = 1.0                # Gradient clipping value
TEACHER_FORCING_RATIO = 0.5 # Ratio of teacher forcing used during training
FREEZE_GLOVE = True       # Whether to freeze the GloVe embedding
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# --- Utility Functions ---
def normalize_string(s):
    """Return ``s`` lower-cased, trimmed, with whitespace runs collapsed.

    Every run of whitespace characters inside the string is replaced by a
    single space; leading and trailing whitespace is removed.
    """
    trimmed = s.lower().strip()
    # Further cleaning steps can be added here if needed.
    return re.sub(r"\s+", " ", trimmed)

def read_raw_data(file_path):
    """Read a UTF-8 text file and return its normalized, non-empty lines.

    Each line is passed through ``normalize_string`` exactly once; lines that
    are empty after normalization are dropped.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of normalized lines, or ``None`` if the file is missing or any
        other error occurs while reading (the error is printed).
    """
    try:
        lines = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                # Normalize once and reuse the result for both the emptiness
                # check and the stored value.
                cleaned = normalize_string(raw_line)
                if cleaned:
                    lines.append(cleaned)
        # NOTE(review): blank lines are dropped per file; if this is used on
        # parallel corpora, confirm both files have blanks at the same rows.
        print(f"Successfully read {len(lines)} lines from {file_path}")
        return lines
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return None
    except Exception as e:
        print(f"An error occurred reading {file_path}: {e}")
        return None


def _remove_if_exists(path):
    """Delete ``path`` when it exists; used to discard partial downloads."""
    if os.path.exists(path):
        os.remove(path)


def download_file(url, dest_folder, filename):
    """Make sure the GloVe text file at ``GLOVE_PATH`` is available.

    If ``GLOVE_PATH`` already exists nothing is downloaded. Otherwise the zip
    archive is fetched from ``url`` into ``dest_folder/filename`` (unless that
    archive is already present) and extracted into ``dest_folder``.

    Args:
        url: Address of the zip archive.
        dest_folder: Directory for the archive and its extracted content;
            created if missing.
        filename: File name of the archive; must end with ``.zip``.

    Returns:
        ``True`` when ``GLOVE_PATH`` exists on return, ``False`` on any
        download, validation or extraction failure (the reason is printed).
    """
    os.makedirs(dest_folder, exist_ok=True)
    zip_dest_path = os.path.join(dest_folder, filename)  # Path to the zip file
    tmp_path = zip_dest_path + '.tmp'

    # Only download/unzip when the target text file is not there yet.
    if os.path.exists(GLOVE_PATH):
        print(f"{GLOVE_FILENAME} already exists. Skipping download/unzip.")
        return True

    if os.path.exists(zip_dest_path):
        print(f"{filename} (zip file) already exists.")
    else:
        print(f"Downloading {filename} from {url}...")
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()  # Check for download errors
            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024  # 1 Kibibyte
            downloaded = 0
            # Write to a temp file so an interrupted transfer never leaves a
            # truncated archive under the final name. The context managers
            # close both the file and the progress bar even on errors.
            with open(tmp_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True,
                desc=f"Downloading {filename}",
            ) as progress:
                for data in response.iter_content(block_size):
                    progress.update(len(data))
                    f.write(data)
                    downloaded += len(data)
            if total_size != 0 and downloaded != total_size:
                print("ERROR, something went wrong during download")
                _remove_if_exists(tmp_path)
                return False
            os.rename(tmp_path, zip_dest_path)  # Rename after successful download
            print(f"Downloaded {filename} successfully.")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {filename}: {e}")
            _remove_if_exists(tmp_path)  # Clean up temp file
            return False
        except Exception as e:
            print(f"An unexpected error occurred during download: {e}")
            _remove_if_exists(tmp_path)
            return False

    # --- Unzip ---
    # The original code tested an undefined name here; the archive name is
    # the ``filename`` argument.
    if not filename.endswith('.zip'):
        print(f"Expected a zip file, but got {filename}")
        return False
    if not os.path.exists(zip_dest_path):
        return False  # Should not happen unless the archive vanished

    print(f"Attempting to unzip {filename}...")
    try:
        with zipfile.ZipFile(zip_dest_path, 'r') as zip_ref:
            zip_ref.extractall(dest_folder)
        print(f"Unzipped {filename} to {dest_folder}")
    except zipfile.BadZipFile:
        print(f"Error: Downloaded file {filename} is not a valid zip file or is corrupted.")
        return False
    except Exception as e:
        print(f"An error occurred during unzip: {e}")
        return False

    # The archive may not contain the file we are looking for.
    if not os.path.exists(GLOVE_PATH):
        print(f"Error: Target GloVe file {GLOVE_FILENAME} not found after unzipping.")
        return False
    return True


# --- Tokenizer Training/Loading ---
def train_or_load_tokenizer(lang, sentences, vocab_size):
    """Return a WordPiece tokenizer for ``lang``, loading or training it.

    A tokenizer saved at ``TOKENIZER_DIR/<lang>_tokenizer.json`` is loaded
    when present; if loading fails the file is removed and training is
    retried. Otherwise a new tokenizer is trained on the non-empty
    ``sentences``, given an SOS/EOS post-processor when both tokens are in
    the vocabulary, and saved. Right-side padding is enabled on the result
    when the PAD token is known.

    Args:
        lang: Language tag used in the file name and in log messages.
        sentences: Iterable of training sentences.
        vocab_size: Vocabulary size passed to the WordPiece trainer.

    Returns:
        The tokenizer, or ``None`` when there is nothing to train on or the
        trained tokenizer cannot be saved.
    """
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    tokenizer_path = os.path.join(TOKENIZER_DIR, f'{lang}_tokenizer.json')

    if not os.path.exists(tokenizer_path):
        print(f"Training tokenizer for {lang}...")
        tokenizer = Tokenizer(WordPiece(unk_token=UNK_TOKEN))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=SPECIAL_TOKENS)

        # Empty sentences carry no information for the trainer.
        corpus = [sent for sent in sentences if sent]
        if not corpus:
             print(f"Error: No sentences provided for training {lang} tokenizer.")
             return None

        tokenizer.train_from_iterator(corpus, trainer=trainer)

        sos_id = tokenizer.token_to_id(SOS_TOKEN)
        eos_id = tokenizer.token_to_id(EOS_TOKEN)

        if sos_id is None or eos_id is None:
            print(f"Warning: SOS or EOS token not found in {lang} tokenizer vocab after training.")
        else:
            # Wrap every encoded sentence as "[SOS] ... [EOS]".
            tokenizer.post_processor = TemplateProcessing(
                single=f"{SOS_TOKEN} $A {EOS_TOKEN}",
                special_tokens=[(SOS_TOKEN, sos_id), (EOS_TOKEN, eos_id)],
            )
            print(f"Set post-processor for {lang} tokenizer.")

        try:
            tokenizer.save(tokenizer_path)
            print(f"Saved tokenizer for {lang} to {tokenizer_path}")
        except Exception as e:
             print(f"Error saving tokenizer for {lang} to {tokenizer_path}: {e}")
             return None # Indicate failure
    else:
        print(f"Loading existing tokenizer for {lang} from {tokenizer_path}")
        try:
            tokenizer = Tokenizer.from_file(tokenizer_path)
        except Exception as e:
            print(f"Error loading tokenizer from {tokenizer_path}: {e}")
            print("Attempting to retrain...")
            # Drop the unreadable file so the recursive call takes the
            # training branch.
            os.remove(tokenizer_path)
            return train_or_load_tokenizer(lang, sentences, vocab_size)

    pad_id = tokenizer.token_to_id(PAD_TOKEN)
    if pad_id is None:
        print(f"Warning: {PAD_TOKEN} not found in {lang} tokenizer. Padding will not work.")
    else:
         tokenizer.enable_padding(pad_id=pad_id, pad_token=PAD_TOKEN, direction='right') # Ensure direction
         print(f"Enabled padding for {lang} tokenizer (PAD ID: {pad_id}).")

    return tokenizer

# --- GloVe Loading ---
def load_glove_embeddings(glove_path, embedding_dim, tokenizer):
    """Build an embedding matrix for ``tokenizer``'s vocabulary from GloVe.

    Rows of tokens that appear in the GloVe file are filled with their GloVe
    vector; all other rows stay zero, except that UNK/SOS/EOS rows that are
    still zero receive small random values. The PAD row is always zero.

    Args:
        glove_path: Path of the GloVe text file (one "word v1 v2 ..." entry
            per line).
        embedding_dim: Expected vector length; entries with another length
            are ignored.
        tokenizer: Object providing ``get_vocab()``, ``get_vocab_size()`` and
            ``token_to_id()``.

    Returns:
        A float tensor of shape ``(vocab_size, embedding_dim)``, or ``None``
        when the file is missing or cannot be read.
    """
    print(f"Loading GloVe embeddings from {glove_path}...")
    if not os.path.exists(glove_path):
        print(f"Error: GloVe file not found at {glove_path}")
        return None

    word_to_idx = tokenizer.get_vocab() # word -> index mapping
    vocab_size = tokenizer.get_vocab_size()
    print(f"Tokenizer vocab size (for embedding matrix): {vocab_size}")
    # float32 matches the dtype of the returned tensor and halves the memory
    # of the intermediate matrix.
    embeddings = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
    found_words = 0

    try:
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # A blank line must not abort the whole load.
                    continue
                row = word_to_idx.get(parts[0])
                if row is None:
                    continue # Word is not in our tokenizer vocab
                if len(parts) - 1 != embedding_dim:
                    continue # Dimension mismatch
                try:
                    vector = np.array(parts[1:], dtype=np.float32)
                except ValueError:
                    continue # Skip malformed lines
                embeddings[row] = vector
                found_words += 1
    except Exception as e:
        print(f"An error occurred while reading GloVe file: {e}")
        return None

    print(f"Loaded {found_words}/{len(word_to_idx)} words from GloVe file into embedding matrix.")

    # Handle special tokens.
    pad_idx = tokenizer.token_to_id(PAD_TOKEN)

    # Keep the PAD row at zero so padded positions contribute nothing.
    if pad_idx is not None:
        embeddings[pad_idx] = np.zeros(embedding_dim)
        print(f"Set PAD token embedding (Index: {pad_idx}) to zeros.")

    # Give UNK / SOS / EOS small random vectors when GloVe did not supply one.
    for label, token in (("UNK", UNK_TOKEN), ("SOS", SOS_TOKEN), ("EOS", EOS_TOKEN)):
        token_idx = tokenizer.token_to_id(token)
        if token_idx is not None and np.all(embeddings[token_idx] == 0):
            print(f"Initializing {label} token embedding (Index: {token_idx}) randomly.")
            embeddings[token_idx] = np.random.randn(embedding_dim) * 0.01

    return torch.tensor(embeddings, dtype=torch.float)


# --- Dataset using Hugging Face Tokenizer ---
class TranslationDatasetHF(Dataset):
    """Parallel-sentence dataset that tokenizes on access.

    Each item is a ``(src_ids, trg_ids)`` pair of 1-D ``torch.long`` tensors
    produced by the source and target tokenizers' ``encode`` method.
    """

    def __init__(self, src_sentences, trg_sentences, src_tokenizer, trg_tokenizer):
        """Store the sentence pairs and tokenizers.

        Raises:
            ValueError: If either sequence is empty or their lengths differ.
        """
        if not (src_sentences and trg_sentences):
            raise ValueError("Source or target sentences list is empty.")
        if len(src_sentences) != len(trg_sentences):
             raise ValueError("Source and target sentences must have the same length.")

        self.src_sentences, self.trg_sentences = src_sentences, trg_sentences
        self.src_tokenizer, self.trg_tokenizer = src_tokenizer, trg_tokenizer

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.src_sentences)

    def __getitem__(self, idx):
        """Return the token-id tensors for the pair at position ``idx``."""
        src_text, trg_text = self.src_sentences[idx], self.trg_sentences[idx]

        # The tokenizers are expected to add SOS/EOS through their
        # post-processor, if one is configured.
        src_ids = self.src_tokenizer.encode(src_text).ids
        trg_ids = self.trg_tokenizer.encode(trg_text).ids

        # Fall back to a single PAD id when an encoding comes back empty.
        if not src_ids:
            src_ids = [self.src_tokenizer.token_to_id(PAD_TOKEN)]
        if not trg_ids:
            trg_ids = [self.trg_tokenizer.token_to_id(PAD_TOKEN)]

        src_tensor = torch.tensor(src_ids, dtype=torch.long)
        trg_tensor = torch.tensor(trg_ids, dtype=torch.long)
        return src_tensor, trg_tensor

# --- Collate Function for HF Tokenizer Output ---
def collate_fn_hf(batch, pad_idx_src, pad_idx_trg):
    """Pad a batch of ``(src, trg)`` tensor pairs to rectangular tensors.

    Args:
        batch: List of ``(src_tensor, trg_tensor)`` pairs of 1-D tensors.
        pad_idx_src: Fill value for the source side.
        pad_idx_trg: Fill value for the target side.

    Returns:
        ``(src_padded, trg_padded)``, each batch-first and padded to the
        longest sequence of its side within the batch.
    """
    src_seqs = [src_item for src_item, _ in batch]
    trg_seqs = [trg_item for _, trg_item in batch]

    return (
        pad_sequence(src_seqs, batch_first=True, padding_value=pad_idx_src),
        pad_sequence(trg_seqs, batch_first=True, padding_value=pad_idx_trg),
    )

# --- Model Definition (Encoder adapted for GloVe) ---
class EncoderGRU(nn.Module):
    """Unidirectional, batch-first GRU encoder with optional pre-trained embeddings."""

    def __init__(self, input_dim, emb_dim, hidden_dim, num_layers, dropout, embedding_weights=None, freeze_emb=True, pad_idx=0):
        """Create the embedding, GRU and dropout layers.

        Args:
            input_dim: Source vocabulary size (used only for a random embedding).
            emb_dim: Embedding size; replaced by the pre-trained width when
                ``embedding_weights`` is given and the two differ.
            hidden_dim: GRU hidden size.
            num_layers: Number of stacked GRU layers.
            dropout: Dropout probability (inter-layer GRU dropout is applied
                only when ``num_layers > 1``).
            embedding_weights: Optional 2-D tensor of pre-trained vectors.
            freeze_emb: Whether pre-trained vectors are kept fixed.
            pad_idx: Index of the padding token.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        if embedding_weights is None:
            print("Encoder: Initializing embedding layer randomly.")
            self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
            print(f"Encoder embedding layer created with shape: {self.embedding.weight.shape}")
        else:
            print("Encoder: Initializing embedding layer with pre-trained weights.")
            # The GRU input size must follow the pre-trained vector width.
            if emb_dim != embedding_weights.shape[1]:
                 print(f"Warning: emb_dim ({emb_dim}) does not match pre-trained embedding dimension ({embedding_weights.shape[1]}). Adjusting emb_dim.")
                 emb_dim = embedding_weights.shape[1]

            self.embedding = nn.Embedding.from_pretrained(embedding_weights, freeze=freeze_emb, padding_idx=pad_idx)
            print(f"Encoder embedding layer created with shape: {self.embedding.weight.shape}, Frozen: {freeze_emb}")

        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            emb_dim,
            hidden_dim,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src_seq):
        """Encode ``src_seq`` ([batch_size, src_len]) into the final hidden state.

        Returns:
            Hidden state of shape [num_layers, batch_size, hidden_dim].
        """
        # [batch_size, src_len] -> [batch_size, src_len, emb_dim]
        embedded = self.dropout(self.embedding(src_seq))
        # Only the final hidden state is passed on; per-step outputs are unused.
        _, hidden = self.gru(embedded)
        return hidden

class DecoderGRU(nn.Module):
    """Batch-first GRU decoder that predicts one target token per call."""

    def __init__(self, output_dim, emb_dim, hidden_dim, num_layers, dropout, pad_idx=0):
        """Create the embedding, GRU, output projection and dropout layers.

        Args:
            output_dim: Target vocabulary size.
            emb_dim: Embedding size.
            hidden_dim: GRU hidden size.
            num_layers: Number of stacked GRU layers.
            dropout: Dropout probability (inter-layer GRU dropout is applied
                only when ``num_layers > 1``).
            pad_idx: Index of the padding token.
        """
        super().__init__()
        self.output_dim, self.hidden_dim, self.num_layers = output_dim, hidden_dim, num_layers

        print("Decoder: Initializing embedding layer randomly.")
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=pad_idx)
        print(f"Decoder embedding layer created with shape: {self.embedding.weight.shape}")

        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            emb_dim,
            hidden_dim,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_step, hidden_state):
        """Run one decoding step.

        Args:
            input_step: Token ids for the current step, shape [batch_size].
            hidden_state: GRU state, shape [num_layers, batch_size, hidden_dim].

        Returns:
            ``(prediction, new_hidden_state)`` where ``prediction`` has shape
            [batch_size, output_dim].
        """
        # [batch_size] -> [batch_size, 1] -> [batch_size, 1, emb_dim]
        step_embedded = self.dropout(self.embedding(input_step.unsqueeze(1)))
        gru_out, new_hidden_state = self.gru(step_embedded, hidden_state)
        # Drop the length-1 time axis before projecting onto the vocabulary.
        prediction = self.fc_out(gru_out.squeeze(1))
        return prediction, new_hidden_state

# Seq2Seq Model (No changes needed)
class Seq2SeqGRU(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with optional teacher forcing."""

    def __init__(self, encoder, decoder, device):
        """Store the sub-modules and check that their hidden shapes agree."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # The encoder's final state is fed straight into the decoder, so
        # both must share hidden size and layer count.
        assert encoder.hidden_dim == decoder.hidden_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.num_layers == decoder.num_layers, \
            "Number of layers in encoder and decoder must be equal!"

    def forward(self, src_seq, trg_seq, teacher_forcing_ratio=0.5):
        """Decode ``trg_seq`` given ``src_seq``.

        Args:
            src_seq: Source batch passed to the encoder.
            trg_seq: Target batch [batch_size, trg_len]; column 0 supplies
                the first decoder input.
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token instead of the model's own prediction.

        Returns:
            Tensor [batch_size, trg_len, output_dim]; position 0 along the
            time axis is left at zero.
        """
        batch_size, trg_len = trg_seq.shape[0], trg_seq.shape[1]
        vocab_size = self.decoder.output_dim

        step_outputs = torch.zeros(batch_size, trg_len, vocab_size).to(self.device)

        # The decoder starts from the encoder's final hidden state.
        hidden = self.encoder(src_seq)
        step_input = trg_seq[:, 0]

        for t in range(1, trg_len):
            logits, hidden = self.decoder(step_input, hidden)
            step_outputs[:, t, :] = logits

            # One coin flip per time step, shared by the whole batch.
            if random.random() < teacher_forcing_ratio:
                step_input = trg_seq[:, t]
            else:
                step_input = logits.argmax(1)

        return step_outputs


# --- Weight Initialization ---
def _is_pretrained_embedding(module, pretrained):
    """Tell whether ``module`` is an embedding holding pre-trained vectors.

    Frozen embeddings are always treated as pre-trained. A trainable one is
    treated as pre-trained when its weights are still equal to ``pretrained``.
    """
    weight = module.weight
    if not weight.requires_grad:
        return True
    if pretrained is None or tuple(weight.shape) != tuple(pretrained.shape):
        return False
    reference = pretrained.detach().cpu().to(weight.dtype)
    return torch.equal(weight.detach().cpu(), reference)


def init_weights(m):
    """Initialize the parameters of ``m`` and all of its sub-modules.

    Weights with more than one dimension get Xavier-uniform values and
    parameters whose name contains ``bias`` are set to zero. Embeddings that
    hold pre-trained vectors are left untouched, and the padding row of any
    re-initialized embedding is reset to zero.

    Works both when called directly on a model and when passed to
    ``model.apply``. In the latter case sub-modules are visited on their own,
    where parameter names carry no ``encoder.`` prefix; that is why the
    pre-trained check is done per module rather than by parameter name.

    Args:
        m: The ``nn.Module`` to initialize.
    """
    # Module-level variable set by the script's main block; absent (None)
    # when the function is used elsewhere.
    pretrained = globals().get('glove_embeddings')

    for module in m.modules():
        if isinstance(module, nn.Embedding):
            if _is_pretrained_embedding(module, pretrained):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.padding_idx is not None:
                # Xavier init also fills the padding row; restore the zeros.
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0)
            continue

        # recurse=False: each parameter is handled by the module owning it.
        for name, param in module.named_parameters(recurse=False):
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

# --- Training Loop Function ---
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean loss over processed batches.

    Batches whose loss is NaN, or that raise an exception, are skipped (the
    exception is printed with its traceback).

    Args:
        model: Seq2seq model called as ``model(src, trg, teacher_forcing_ratio=...)``.
        iterator: Iterable yielding ``(src, trg)`` batches.
        optimizer: Optimizer updating the model's parameters.
        criterion: Loss taking flattened logits and flattened targets.
        clip: Maximum gradient norm.

    Returns:
        Average loss, or ``float('inf')`` when no batch was processed.
    """
    model.train()
    running_loss = 0
    n_ok = 0
    print("Starting training epoch...")
    progress = tqdm(iterator, desc="Training", leave=False)
    for batch_idx, batch in enumerate(progress):
        try:
            src, trg = batch
            src, trg = src.to(DEVICE), trg.to(DEVICE)

            optimizer.zero_grad()
            # logits: [batch_size, trg_len, output_vocab_size]
            logits = model(src, trg, teacher_forcing_ratio=TEACHER_FORCING_RATIO)

            # Drop time step 0 (the <sos> position) on both sides and flatten
            # batch and time so the loss sees one row per predicted token.
            flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
            flat_targets = trg[:, 1:].reshape(-1)

            loss = criterion(flat_logits, flat_targets)
            if torch.isnan(loss):
                print(f"Warning: NaN loss detected at batch {batch_idx}. Skipping update.")
                continue

            loss.backward()
            # Clip gradients to prevent exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            n_ok += 1
            progress.set_postfix(loss=batch_loss)

        except Exception as e:
            print(f"\nError during training batch {batch_idx}: {e}")
            print("Skipping this batch.")
            traceback.print_exc()
            optimizer.zero_grad() # Do not carry partial gradients over
            continue

    if n_ok == 0:
        print("Warning: No batches were processed successfully in this training epoch.")
        return float('inf')

    return running_loss / n_ok

# --- Evaluation Loop Function ---
def evaluate(model, iterator, criterion):
    """Compute the mean loss over ``iterator`` without updating the model.

    The model is put in eval mode and run under ``torch.no_grad()`` with
    teacher forcing disabled. Batches whose loss is NaN, or that raise an
    exception, are skipped.

    Args:
        model: Seq2seq model called as ``model(src, trg, teacher_forcing_ratio=0)``.
        iterator: Iterable yielding ``(src, trg)`` batches.
        criterion: Loss taking flattened logits and flattened targets.

    Returns:
        Average loss, or ``float('inf')`` when no batch was processed.
    """
    model.eval()
    running_loss = 0
    n_ok = 0
    print("Starting evaluation...")
    progress = tqdm(iterator, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch_idx, batch in enumerate(progress):
            try:
                src, trg = batch
                src, trg = src.to(DEVICE), trg.to(DEVICE)

                # The decoder feeds on its own predictions here.
                logits = model(src, trg, teacher_forcing_ratio=0)

                # Same reshaping as in training: skip position 0, flatten.
                flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
                flat_targets = trg[:, 1:].reshape(-1)

                loss = criterion(flat_logits, flat_targets)
                if torch.isnan(loss):
                    print(f"Warning: NaN loss detected during evaluation batch {batch_idx}. Skipping.")
                    continue

                batch_loss = loss.item()
                running_loss += batch_loss
                n_ok += 1
                progress.set_postfix(loss=batch_loss)

            except Exception as e:
                print(f"\nError during evaluation batch {batch_idx}: {e}")
                print("Skipping this batch.")
                traceback.print_exc()
                continue

    if n_ok == 0:
         print("Warning: No batches were processed successfully in evaluation.")
         return float('inf')

    return running_loss / n_ok

# --- Helper function for timing ---
def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as integers, both truncated.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

# --- Inference Function ---
def translate_sentence_hf(sentence: str, src_tokenizer: Tokenizer, trg_tokenizer: Tokenizer, model: Seq2SeqGRU, device, max_len=50):
    """Greedily translate one sentence with a trained model.

    The sentence is normalized and encoded, the encoder state seeds the
    decoder, and tokens are generated one at a time until EOS is predicted or
    ``max_len`` tokens have been produced. The model is switched to eval mode.

    Args:
        sentence: Source-language text.
        src_tokenizer: Tokenizer for the source language.
        trg_tokenizer: Tokenizer for the target language.
        model: Model exposing ``encoder`` and ``decoder``.
        device: Device the input tensors are moved to.
        max_len: Maximum number of generated tokens.

    Returns:
        The decoded translation, or a short message string when the input is
        unusable or an error occurs.
    """
    model.eval()

    if not (isinstance(sentence, str) and sentence.strip()):
        return "Input sentence is empty or invalid."

    # Pre-process the source sentence the same way as the training data.
    src_encoded = src_tokenizer.encode(normalize_string(sentence))
    if not src_encoded or not src_encoded.ids:
        return "Input sentence is empty after tokenization."

    # Add a batch axis of size 1.
    src_tensor = torch.LongTensor(src_encoded.ids).unsqueeze(0).to(device)

    sos_id = trg_tokenizer.token_to_id(SOS_TOKEN)
    eos_id = trg_tokenizer.token_to_id(EOS_TOKEN)
    if sos_id is None or eos_id is None:
        print("Error: Target SOS or EOS token ID not found in tokenizer.")
        return "Error during translation setup."

    try:
        with torch.no_grad():
            hidden = model.encoder(src_tensor)
            next_input = torch.LongTensor([sos_id]).to(device)
            generated_ids = []

            for _ in range(max_len):
                logits, hidden = model.decoder(next_input, hidden)
                token_id = logits.argmax(1).item()

                if token_id == eos_id:
                    break # EOS ends the translation and is not kept

                generated_ids.append(token_id)
                next_input = torch.LongTensor([token_id]).to(device)

        return trg_tokenizer.decode(generated_ids, skip_special_tokens=True)

    except Exception as e:
        print(f"Error during translation inference: {e}")
        traceback.print_exc()
        return "Error during translation."


# ================================================================
#                       MAIN EXECUTION BLOCK
# ================================================================

if __name__ == '__main__': # Ensure this runs only when script is executed directly

    # 1. Download GloVe data if necessary
    if not download_file(GLOVE_ZIP_URL, GLOVE_DIR, GLOVE_ZIP_FILENAME):
        print("GloVe preparation failed. Please ensure the GloVe file exists or download is possible.")
        exit()

    # 2. Load Raw Data
    print("\n--- Loading Raw Data ---")
    en_sents_raw = read_raw_data(EN_FILE)
    vi_sents_raw = read_raw_data(VI_FILE)

    if en_sents_raw is None or vi_sents_raw is None or len(en_sents_raw) != len(vi_sents_raw):
        print("Exiting due to data loading error or length mismatch.")
        exit()
    if not en_sents_raw:
         print("Error: No data loaded. Check data files.")
         exit()

    # 3. Split Data
    print("\n--- Splitting Data ---")
    combined = list(zip(en_sents_raw, vi_sents_raw))
    random.shuffle(combined) # Shuffle before splitting
    train_size = int(0.8 * len(combined))
    valid_size = int(0.1 * len(combined))
    # Ensure sizes are valid
    if train_size == 0 or valid_size == 0 or len(combined) - train_size - valid_size == 0:
        print("Error: Dataset too small to split into train/validation/test sets.")
        # Handle small datasets differently, e.g., use only train/val or cross-validation
        # For this example, we'll exit if the split is invalid
        exit()

    train_data = combined[:train_size]
    valid_data = combined[train_size:train_size+valid_size]
    test_data = combined[train_size+valid_size:]

    en_train_sents, vi_train_sents = zip(*train_data)
    en_valid_sents, vi_valid_sents = zip(*valid_data)
    en_test_sents, vi_test_sents = zip(*test_data)

    print(f"Train size: {len(en_train_sents)}")
    print(f"Validation size: {len(en_valid_sents)}")
    print(f"Test size: {len(en_test_sents)}")


    # 4. Train/Load Tokenizers
    print("\n--- Preparing Tokenizers ---")
    tokenizer_en = train_or_load_tokenizer("en", en_train_sents, vocab_size=MAX_VOCAB_SIZE)
    tokenizer_vi = train_or_load_tokenizer("vi", vi_train_sents, vocab_size=MAX_VOCAB_SIZE)

    if tokenizer_en is None or tokenizer_vi is None:
         print("Exiting due to tokenizer preparation error.")
         exit()

    INPUT_VOCAB_SIZE = tokenizer_en.get_vocab_size()
    OUTPUT_VOCAB_SIZE = tokenizer_vi.get_vocab_size()
    PAD_IDX_EN = tokenizer_en.token_to_id(PAD_TOKEN)
    PAD_IDX_VI = tokenizer_vi.token_to_id(PAD_TOKEN)

    # Check if PAD tokens exist, otherwise padding will fail
    if PAD_IDX_EN is None or PAD_IDX_VI is None:
        print(f"Error: PAD token ('{PAD_TOKEN}') not found in one or both tokenizers. Cannot proceed.")
        exit()

    print(f"English Vocab Size: {INPUT_VOCAB_SIZE}")
    print(f"Vietnamese Vocab Size: {OUTPUT_VOCAB_SIZE}")


    # 5. Load GloVe Embeddings
    print("\n--- Loading GloVe Embeddings ---")
    glove_embeddings = load_glove_embeddings(GLOVE_PATH, EMBEDDING_DIM_GLOVE, tokenizer_en)
    if glove_embeddings is None:
         print("Warning: Failed to load GloVe embeddings. Encoder embedding will be random.")
         # Decide if you want to exit or continue with random embeddings
         # exit()


    # 6. Create Datasets and DataLoaders
    print("\n--- Creating Datasets and DataLoaders ---")
    try:
        train_dataset = TranslationDatasetHF(en_train_sents, vi_train_sents, tokenizer_en, tokenizer_vi)
        valid_dataset = TranslationDatasetHF(en_valid_sents, vi_valid_sents, tokenizer_en, tokenizer_vi)
        test_dataset  = TranslationDatasetHF(en_test_sents, vi_test_sents, tokenizer_en, tokenizer_vi)
    except ValueError as e:
         print(f"Error creating dataset: {e}")
         exit()


    collate_with_padding_hf = lambda batch: collate_fn_hf(batch, PAD_IDX_EN, PAD_IDX_VI)

    train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_with_padding_hf)
    valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_with_padding_hf)
    test_iterator  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_with_padding_hf)
    print("DataLoaders created.")


    # 7. Initialize Model, Optimizer, Criterion
    print("\n--- Initializing Model ---")
    encoder = EncoderGRU(INPUT_VOCAB_SIZE, EMBEDDING_DIM_GLOVE, HIDDEN_DIM, NUM_LAYERS, DROPOUT_RATE,
                         embedding_weights=glove_embeddings, freeze_emb=FREEZE_GLOVE, pad_idx=PAD_IDX_EN).to(DEVICE)
    # Decoder embedding dim can be different if needed
    decoder = DecoderGRU(OUTPUT_VOCAB_SIZE, EMBEDDING_DIM_VI, HIDDEN_DIM, NUM_LAYERS, DROPOUT_RATE,
                         pad_idx=PAD_IDX_VI).to(DEVICE)
    model = Seq2SeqGRU(encoder, decoder, DEVICE).to(DEVICE)

    # Apply custom weight initialization
    model.apply(init_weights)
    print("Model initialized.")

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX_VI)
    print("Optimizer and Criterion set.")


    # --- Training Execution ---
    # Lowest validation loss seen so far; starting at +inf lets the first
    # finite validation loss count as an improvement.
    best_valid_loss = float('inf')

    print("\n--- Starting Training Loop ---")
    for epoch in range(NUM_EPOCHS):
        start_time = time.time()

        print(f"\nEpoch: {epoch+1:02}/{NUM_EPOCHS}")
        train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
        valid_loss = evaluate(model, valid_iterator, criterion)

        # Wall-clock duration of the train + evaluate pass for this epoch.
        epoch_mins, epoch_secs = epoch_time(start_time, time.time())

        # Checkpoint only on a strict improvement of the validation loss.
        if not (valid_loss < best_valid_loss):
            print(f"\t   Validation loss did not improve from {best_valid_loss:.3f}")
        else:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"\t-> Saved Best Model (Val Loss: {valid_loss:.3f})")

        print(f'\tTime: {epoch_mins}m {epoch_secs}s')

        # A loss of +inf is reported as "no batches processed"; otherwise the
        # loss is clamped to 700 before exponentiating so math.exp cannot overflow.
        if train_loss == float('inf'):
            print('\tTrain Loss: Inf (No batches processed)')
        else:
            train_ppl = math.exp(min(train_loss, 700))
            print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {train_ppl:7.3f}')

        if valid_loss == float('inf'):
            print('\t Val. Loss: Inf (No batches processed)')
        else:
            valid_ppl = math.exp(min(valid_loss, 700))
            print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {valid_ppl:7.3f}')


    print("\n--- Training Finished ---")


    # --- Load Best Model and Translate Examples ---
    # Restores the checkpoint stored at MODEL_SAVE_PATH and prints two sample
    # translations: one fixed English sentence and one random test-set pair.
    # This step is best-effort: failures are reported and never re-raised.
    print(f"\n--- Loading Best Model for Inference ({MODEL_SAVE_PATH}) ---")
    try:
        # map_location places the saved tensors on DEVICE while loading.
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        # Put the model in eval mode explicitly so dropout is disabled no matter
        # which mode the last train/evaluate call left it in.
        model.eval()
        print("Model loaded successfully.")

        # Example 1: Custom Sentence
        example_sentence_en = "a man in a blue shirt is playing a guitar ." # Example sentence
        print(f"\nSource EN: {example_sentence_en}")
        translation = translate_sentence_hf(example_sentence_en, tokenizer_en, tokenizer_vi, model, DEVICE)
        print(f'Predicted VI: {translation}')

        # Example 2: Sentence from Test Set
        # Sample only indices that exist in BOTH parallel lists. The original
        # guard looked at `test_data` alone, so an empty or shorter sentence
        # list ended in a ValueError/IndexError reported by the generic handler.
        num_test_pairs = min(len(en_test_sents), len(vi_test_sents)) if test_data else 0
        if num_test_pairs > 0:
            print("\nTranslating a sentence from the test set:")
            test_idx = random.randrange(num_test_pairs)
            src_test_sent = en_test_sents[test_idx]
            trg_test_sent = vi_test_sents[test_idx] # Ground truth for comparison
            print(f"Source EN: {src_test_sent}")
            print(f"Actual VI: {trg_test_sent}")
            translation_test = translate_sentence_hf(src_test_sent, tokenizer_en, tokenizer_vi, model, DEVICE)
            print(f'Predicted VI: {translation_test}')
        else:
            print("\nTest set is empty, cannot translate test example.")

    except FileNotFoundError:
        # No checkpoint on disk (e.g. no epoch ever improved the validation loss).
        print(f"\nModel file '{MODEL_SAVE_PATH}' not found. Cannot perform inference.")
    except Exception as e:
        # Final demo step of the script: report the failure with a full
        # traceback instead of crashing after training has already finished.
        print(f"\nAn error occurred during final inference: {e}")
        traceback.print_exc()


Using device: cpu
glove.6B.300d.txt already exists. Skipping download/unzip.

--- Loading Raw Data ---
Successfully read 254090 lines from data/en_sents
Successfully read 254090 lines from data/vi_sents

--- Splitting Data ---
Train size: 203272
Validation size: 25409
Test size: 25409

--- Preparing Tokenizers ---
Loading existing tokenizer for en from tokenizers/en_tokenizer.json
Enabled padding for en tokenizer (PAD ID: 0).
Loading existing tokenizer for vi from tokenizers/vi_tokenizer.json
Enabled padding for vi tokenizer (PAD ID: 0).
English Vocab Size: 20000
Vietnamese Vocab Size: 9660

--- Loading GloVe Embeddings ---
Loading GloVe embeddings from glove_data/glove.6B.300d.txt...
Tokenizer vocab size (for embedding matrix): 20000
Loaded 14781/20000 words from GloVe file into embedding matrix.
Set PAD token embedding (Index: 0) to zeros.
Initializing UNK token embedding (Index: 3) randomly.
Initializing SOS token embedding (Index: 1) randomly.
Initializing EOS token embedding (Inde

                                                                              

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 4.918)
	Time: 91m 13s
	Train Loss: 4.985 | Train PPL: 146.186
	 Val. Loss: 4.918 |  Val. PPL: 136.756

Epoch: 02/10
Starting training epoch...


                                                                              

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 4.344)
	Time: 136m 17s
	Train Loss: 3.977 | Train PPL:  53.383
	 Val. Loss: 4.344 |  Val. PPL:  76.977

Epoch: 03/10
Starting training epoch...


                                                                          

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 3.972)
	Time: 54m 56s
	Train Loss: 3.387 | Train PPL:  29.586
	 Val. Loss: 3.972 |  Val. PPL:  53.114

Epoch: 04/10
Starting training epoch...


                                                                         

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 3.718)
	Time: 32m 58s
	Train Loss: 3.014 | Train PPL:  20.363
	 Val. Loss: 3.718 |  Val. PPL:  41.186

Epoch: 05/10
Starting training epoch...


                                                                        

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 3.476)
	Time: 30m 47s
	Train Loss: 2.750 | Train PPL:  15.641
	 Val. Loss: 3.476 |  Val. PPL:  32.321

Epoch: 06/10
Starting training epoch...


                                                                        

Starting evaluation...


                                                                        

	-> Saved Best Model (Val Loss: 3.340)
	Time: 30m 3s
	Train Loss: 2.539 | Train PPL:  12.672
	 Val. Loss: 3.340 |  Val. PPL:  28.214

Epoch: 07/10
Starting training epoch...


Training:  81%|████████  | 1284/1589 [24:12<12:58,  2.55s/it, loss=2.01]