In [None]:
# CELL 1: Environment Setup & GPU Check
# Prints a banner, records the Kaggle input/output locations, makes the code
# dataset importable, reports the GPU (if any) and creates the output folders.
print("=" * 70)
print("KAGGLE TRANSFORMER NMT PIPELINE (SENTENCEPIECE)")
print("=" * 70)

import os
import sys
import torch

# Kaggle locations. Code and data are read from the same attached dataset.
CODE_PATH = '/kaggle/input/nlp-py-v2'  # Your code dataset
DATA_PATH = '/kaggle/input/nlp-py-v2'  # Data files
OUTPUT_PATH = '/kaggle/working'

# Searched before the rest of sys.path, so the dataset's modules are imported.
sys.path.insert(0, CODE_PATH)

# GPU report
print("\n" + "=" * 70, "  GPU INFORMATION", "=" * 70, sep="\n")
if not torch.cuda.is_available():
    print("  WARNING: No GPU detected! Training will be slow.")
else:
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f" CUDA version: {torch.version.cuda}")
print("=" * 70)

# Output folders; exist_ok makes re-running the cell harmless.
os.makedirs(f'{OUTPUT_PATH}/checkpoints', exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/results', exist_ok=True)
print("\n Output directories created in /kaggle/working/")

In [None]:
# List the datasets attached under /kaggle/input, to check the dataset slugs
# used in the paths of the other cells.
import os
print(os.listdir(path="/kaggle/input"))


In [None]:
# CELL 2: Install SentencePiece (only when it is not already importable)
print("\n" + "="*70)
print("INSTALLING SENTENCEPIECE")
print("="*70)

import importlib.util
import subprocess
import sys

# Kaggle images usually ship sentencepiece already, and notebooks with internet
# disabled cannot reach PyPI (pip would fail with CalledProcessError). So only
# call pip when the package cannot be found.
if importlib.util.find_spec("sentencepiece") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sentencepiece"])
    print(" SentencePiece installed successfully!")
else:
    print(" SentencePiece already available - skipping pip install.")


In [None]:
# CELL 3: Verify Files
# Reports which project modules, SentencePiece files and data files exist in
# the attached dataset. It only prints; it does not stop later cells.
print("\n" + "="*70)
print("VERIFYING FILES")
print("="*70)

# SentencePiece model/vocab files (looked up in DATA_PATH)
required_sp_files = [
    'vi_sp.model',
    'en_sp.model',
    'vi_sp.vocab',
    'en_sp.vocab'
]

# Project modules imported by later cells (looked up in CODE_PATH)
required_files = [
    'transformer_components.py',
    'transformer_encoder_decoder.py',
    'complete_transformer.py',
    'training_module.py',
    'dataloader_module.py',        # imported by Cell 7 (create_dataloaders_with_bucketing)
    'inference_evaluation_v2.py',  # Updated inference file
    'tokenizer_sentencepiece.py',  # New tokenizer file
]

print("\n Python files:")
all_good = True
for f in required_files:
    file_path = f'{CODE_PATH}/{f}'
    exists = os.path.exists(file_path)
    status = "True" if exists else "False"
    print(f"  {status} {f}")
    if not exists:
        all_good = False
        if f in ['tokenizer_sentencepiece.py', 'inference_evaluation_v2.py']:
            print(f"        Missing new file - you need to add this to your Kaggle dataset")

# Check SentencePiece model files
print(f"\nSentencePiece models:")
sp_files_exist = True
for sp_file in required_sp_files:
    file_path = f'{DATA_PATH}/{sp_file}'
    exists = os.path.exists(file_path)
    status = "True" if exists else "False"
    if exists:
        file_size = os.path.getsize(file_path) / (1024*1024)
        print(f"  {status} {sp_file} ({file_size:.2f} MB)")
    else:
        print(f"  {status} {sp_file}")
        sp_files_exist = False

if not sp_files_exist:
    print("\n  SentencePiece models not found!")
    print("   You need to:")
    print("   1. Run tokenizer_sentencepiece.py locally to create .model files")
    print("   2. Upload vi_sp.model and en_sp.model to your Kaggle dataset")

# Check data files
print(f"\n Data files:")
data_files = ['processed_data.pkl']
for data_file in data_files:
    file_path = f'{DATA_PATH}/{data_file}'
    exists = os.path.exists(file_path)
    status = "True" if exists else "False"
    if exists:
        file_size = os.path.getsize(file_path) / (1024*1024)
        print(f"  {status} {data_file} ({file_size:.2f} MB)")
    else:
        print(f"  {status} {data_file}")

if sp_files_exist:
    print("\n ALL SENTENCEPIECE FILES READY!")
else:
    print("\n MISSING SENTENCEPIECE FILES!")
    print("\n TO CREATE SENTENCEPIECE MODELS:")
    print("   Run this command locally with your CSV:")
    print("   python tokenizer_sentencepiece.py --csv_path data.csv --vi_col src --en_col tgt")

# all_good was previously computed but never reported
if not all_good:
    print("\n MISSING PYTHON FILES! Add the files marked False above to your Kaggle code dataset.")


In [None]:
# CELL 4: Configuration
# Base hyper-parameters for the first training run. The resume cell later
# replaces CONFIG with CONFIG_RESUME.
print("\n" + "="*70)
print("⚙️  CONFIGURATION")
print("="*70)

CONFIG = {
    # Model settings
    'model_size': 'tiny',              # tiny/small/base
    # Training settings
    'num_epochs': 5,
    'batch_size': 32,                  # Increased due to smaller vocab
    'warmup_steps': 2000,              # Reduced for smaller vocab
    'label_smoothing': 0.1,
    'save_every_batches': 89000,       # Mid-epoch checkpoint interval, in batches
    'use_amp': True,                   # Mixed precision training

    # Evaluation settings
    'beam_size': 5,
}

print("\n Training Configuration:")
for key, value in CONFIG.items():
    print(f"  {key:20s}: {value}")

print("\n Key Improvements with SentencePiece:")
print("   Vocab size: ~64K (was 426K) - 85% reduction!")
print("   No UNK tokens - handles all new words")
print("   Faster training - smaller vocab")
print("   Better for Vietnamese - handles accents well")
print("="*70)

In [None]:
# CELL 5: Load Tokenizers
# Loads the Vietnamese and English SentencePiece models from DATA_PATH and
# runs a quick encode / pieces / decode smoke test on each.
print("\n" + "="*70)
print(" LOADING SENTENCEPIECE TOKENIZERS")
print("="*70)

from tokenizer_sentencepiece import SentencePieceTokenizer

# Load tokenizers
print("\n Loading vocabularies...")
vi_tokenizer = SentencePieceTokenizer(f'{DATA_PATH}/vi_sp.model')
en_tokenizer = SentencePieceTokenizer(f'{DATA_PATH}/en_sp.model')

print(f"\n Tokenizers loaded successfully!")
print(f"   Vietnamese vocab: {len(vi_tokenizer):,} tokens")
print(f"   English vocab: {len(en_tokenizer):,} tokens")
print(f"   Total vocab: {len(vi_tokenizer) + len(en_tokenizer):,} tokens")

# Test tokenizers
print(f"\n Testing tokenizers:")
# Properly encoded Vietnamese. The previous literal was mojibake (UTF-8 read as
# Mac Roman), so the test fed garbage characters instead of real diacritics.
test_vi = "xin chào, tôi là sinh viên"
test_en = "hello, i am a student"

vi_ids = vi_tokenizer.encode(test_vi)
en_ids = en_tokenizer.encode(test_en)

print(f"\n  VI: '{test_vi}'")
print(f"      Tokens: {vi_tokenizer.encode_as_pieces(test_vi)}")
print(f"      Decoded: '{vi_tokenizer.decode(vi_ids)}'")

print(f"\n  EN: '{test_en}'")
print(f"      Tokens: {en_tokenizer.encode_as_pieces(test_en)}")
print(f"      Decoded: '{en_tokenizer.decode(en_ids)}'")

# Test with COMPLETELY NEW WORD
new_word = "ChatGPT DeepMind"
print(f"\n   NEW WORD TEST: '{new_word}'")
pieces = en_tokenizer.encode_as_pieces(new_word)
print(f"      Pieces: {pieces}")
# NOTE(review): this message is printed unconditionally; it does not inspect
# the ids for the unk id.
print(f"       No <unk>! Broken into subwords")


In [None]:
# CELL 6: Create Models
# Builds one Transformer per translation direction (VI->EN and EN->VI), with
# source/target vocab sizes taken from the tokenizers, and moves both to device.
print("\n" + "="*70)
print("  STAGE 2: CREATE MODELS")
print("="*70)

from complete_transformer import create_model, print_model_info

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n  Device: {device}")

# Create VI->EN model
# NOTE(review): pad_idx=0 assumes the SentencePiece models use id 0 for padding — confirm.
print(f"\n🔨 Creating VI->EN {CONFIG['model_size']} model...")
model_vi_en, model_config = create_model(
    src_vocab_size=len(vi_tokenizer),
    tgt_vocab_size=len(en_tokenizer),
    model_size=CONFIG['model_size'],
    pad_idx=0
)

print_model_info(model_vi_en, CONFIG['model_size'])
model_vi_en = model_vi_en.to(device)
print(f"\n VI->EN Model moved to {device}")

# Create EN->VI model (vocabularies swapped)
print(f"\n Creating EN->VI {CONFIG['model_size']} model...")
model_en_vi, model_config_en_vi = create_model(
    src_vocab_size=len(en_tokenizer),
    tgt_vocab_size=len(vi_tokenizer),
    model_size=CONFIG['model_size'],
    pad_idx=0
)

print_model_info(model_en_vi, CONFIG['model_size'])
model_en_vi = model_en_vi.to(device)
print(f"\n EN->VI Model moved to {device}")

# Total of the previous (pre-SentencePiece) vocabularies, kept for comparison
print(f"\n Comparison with old vocab:")
old_total = 174608 + 251853
new_total = len(vi_tokenizer) + len(en_tokenizer)
reduction = (1 - new_total / old_total) * 100
print(f"   Old total vocab: {old_total:,}")
print(f"   New total vocab: {new_total:,}")
print(f"   Reduction: {reduction:.1f}%")

In [None]:
# RESUME TRAINING FROM SAVED "BEST MODEL" CHECKPOINTS
# Lists the available best-model checkpoints for both directions and selects
# the files tagged CHECKPOINT_TAG to resume from.

print("\n" + "="*70)
print(" RESUMING TRAINING FROM CHECKPOINT")
print("="*70)

import torch
import os
import glob

# STEP 1: LOCATE YOUR CHECKPOINT FILES

print("\n Looking for checkpoints...")

# Path to your checkpoints (adjust if different)
CHECKPOINT_DIR_VI_EN = '/kaggle/input/checkpoint-nlp'
CHECKPOINT_DIR_EN_VI = '/kaggle/input/checkpoint-nlp'  # was '//kaggle/...' (stray slash)

# Suffix of the files to resume from: best_model_<direction>_<tag>.pt
CHECKPOINT_TAG = 14

# List every best-model checkpoint, so other candidates are visible too
vi_en_checkpoints = glob.glob(os.path.join(CHECKPOINT_DIR_VI_EN, 'best_model_vi_en*.pt'))
en_vi_checkpoints = glob.glob(os.path.join(CHECKPOINT_DIR_EN_VI, 'best_model_en_vi*.pt'))

if vi_en_checkpoints:
    print(f"\n Found {len(vi_en_checkpoints)} VI->EN checkpoints:")
    for cp in sorted(vi_en_checkpoints):
        print(f"   - {os.path.basename(cp)}")
else:
    print("\n No VI->EN checkpoints found!")

if en_vi_checkpoints:
    print(f"\n Found {len(en_vi_checkpoints)} EN->VI checkpoints:")
    for cp in sorted(en_vi_checkpoints):
        print(f"   - {os.path.basename(cp)}")
else:
    print("\n No EN->VI checkpoints found!")

# STEP 2: SELECT CHECKPOINT TO RESUME FROM

# Informational only: the epoch actually resumed comes from whatever
# load_checkpoint() returns for the selected file, not from this value.
RESUME_FROM_EPOCH = 5
resume_checkpoint_vi_en = os.path.join(CHECKPOINT_DIR_VI_EN, f'best_model_vi_en_{CHECKPOINT_TAG}.pt')
resume_checkpoint_en_vi = os.path.join(CHECKPOINT_DIR_EN_VI, f'best_model_en_vi_{CHECKPOINT_TAG}.pt')

print(f"\n Selected checkpoints:")
print(f"   VI->EN: {os.path.basename(resume_checkpoint_vi_en)}")
print(f"   EN->VI: {os.path.basename(resume_checkpoint_en_vi)}")

# STEP 3: LOAD CHECKPOINTS
# Create optimizers, LR schedulers and AMP grad scalers for both directions.
# They are passed to load_checkpoint() below, which is expected to restore
# their saved state.

from training_module import load_checkpoint
from torch.cuda.amp import GradScaler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("\n" + "="*70)
print("📥 LOADING CHECKPOINTS")
print("="*70)

# Adam(betas=(0.9, 0.98), eps=1e-9). lr=1.0 is presumably a base factor that
# TransformerLRScheduler scales/overrides — confirm in training_module.
optimizer_vi_en = torch.optim.Adam(
    model_vi_en.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)
optimizer_en_vi = torch.optim.Adam(
    model_en_vi.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)

from training_module import TransformerLRScheduler

# NOTE(review): warmup_steps is read from CONFIG as it stands *now*. On a first
# run that is the original 2000, because CONFIG_RESUME (4000) is only assigned
# to CONFIG at the end of this cell. If load_checkpoint restores the scheduler
# state, the checkpoint's settings may win anyway — confirm.
scheduler_vi_en = TransformerLRScheduler(
    optimizer_vi_en, d_model=model_config['d_model'], warmup_steps=CONFIG['warmup_steps']
)
scheduler_en_vi = TransformerLRScheduler(
    optimizer_en_vi, d_model=model_config_en_vi['d_model'], warmup_steps=CONFIG['warmup_steps']
)

# AMP loss scalers (None when mixed precision is disabled)
scaler_vi_en = GradScaler() if CONFIG['use_amp'] else None
scaler_en_vi = GradScaler() if CONFIG['use_amp'] else None

# Load VI->EN / EN->VI checkpoints.
# load_checkpoint returns (model, optimizer, scheduler, epoch, batch, history,
# scaler). STEP 5 below resumes at start_epoch_* + 1, i.e. it treats the value
# as the last COMPLETED epoch. So the from-scratch fallback is 0; it used to be
# 1, which made a fresh run start at epoch 2.


def _empty_history():
    """Return a fresh, empty training-history dict in the layout the training loop uses."""
    return {
        'train_loss': [], 'train_ppl': [],
        'val_loss': [], 'val_ppl': [],
        'lr': [], 'epoch_times': [],
        'batch_checkpoints': []
    }


print("\n Loading VI->EN checkpoint...")
if os.path.exists(resume_checkpoint_vi_en):
    (model_vi_en, optimizer_vi_en, scheduler_vi_en, start_epoch_vi_en,
     start_batch_vi_en, history_vi_en, scaler_vi_en) = load_checkpoint(
        model_vi_en,
        resume_checkpoint_vi_en,
        device,
        optimizer_vi_en,
        scheduler_vi_en,
        scaler_vi_en
    )
    print(f" Loaded VI->EN from Epoch {start_epoch_vi_en}")
else:
    print(f" Checkpoint not found: {resume_checkpoint_vi_en}")
    print("   Will train from scratch!")
    start_epoch_vi_en = 0  # no epoch completed yet -> training starts at epoch 1
    start_batch_vi_en = 0
    history_vi_en = _empty_history()

print("\n Loading EN->VI checkpoint...")
if os.path.exists(resume_checkpoint_en_vi):
    (model_en_vi, optimizer_en_vi, scheduler_en_vi, start_epoch_en_vi,
     start_batch_en_vi, history_en_vi, scaler_en_vi) = load_checkpoint(
        model_en_vi,
        resume_checkpoint_en_vi,
        device,
        optimizer_en_vi,
        scheduler_en_vi,
        scaler_en_vi
    )
    print(f" Loaded EN->VI from Epoch {start_epoch_en_vi}")
else:
    print(f" Checkpoint not found: {resume_checkpoint_en_vi}")
    print("   Will train from scratch!")
    start_epoch_en_vi = 0  # no epoch completed yet -> training starts at epoch 1
    start_batch_en_vi = 0
    history_en_vi = _empty_history()

# STEP 4: VERIFY CHECKPOINT INFO

print("\n" + "="*70)
print("CHECKPOINT INFO")
print("="*70)


def _print_history_summary(title, start, hist):
    """Print the resume epoch and the latest metrics recorded in *hist*.

    Each metric line is printed only when its history list is non-empty. A None
    history (possible if a checkpoint stored none) is treated as empty.
    """
    hist = hist or {}
    print(f"\n{title} Model:")
    print(f"   Resuming from: Epoch {start}")
    print(f"   Previous epochs: {len(hist.get('train_loss', []))}")
    if hist.get('train_loss'):
        print(f"   Last train loss: {hist['train_loss'][-1]:.4f}")
        print(f"   Last train PPL: {hist['train_ppl'][-1]:.2f}")
    if hist.get('val_loss'):
        print(f"   Last val loss: {hist['val_loss'][-1]:.4f}")
        print(f"   Last val PPL: {hist['val_ppl'][-1]:.2f}")
    if hist.get('lr'):
        print(f"   Current LR: {hist['lr'][-1]:.8f}")


_print_history_summary("🇻🇳→🇬🇧 VI->EN", start_epoch_vi_en, history_vi_en)
_print_history_summary("🇬🇧→🇻🇳 EN->VI", start_epoch_en_vi, history_en_vi)

# STEP 5: UPDATE CONFIG FOR CONTINUING TRAINING

print("\n" + "="*70)
print("  UPDATED TRAINING CONFIG")
print("="*70)

# start_epoch_* are treated as the last completed epoch of each direction, so
# training resumes at the next one. max() keeps both directions on a single
# epoch counter if their checkpoints differ.
start_epoch = max(start_epoch_vi_en, start_epoch_en_vi) + 1

# Hyper-parameters for the resumed run
CONFIG_RESUME = {
    'model_size': CONFIG['model_size'],
    'num_epochs': 16,             # last epoch to train (inclusive)
    'batch_size': 64,
    'warmup_steps': 4000,         # NOTE(review): the schedulers above were already built with the old value
    'label_smoothing': 0.05,
    'save_every_batches': 89000,  # mid-epoch checkpoint interval, in batches
    'use_amp': True,
    'beam_size': 5,
    'start_epoch': start_epoch,   # first epoch of the resumed run
}

print(f"\n Resume Configuration:")
for key, value in CONFIG_RESUME.items():
    print(f"   {key:20s}: {value}")

# Report the differences actually present. The previous hand-written notes had
# drifted from the values (e.g. "89000 → 500", "5 → 15").
print(f"\n Changes from original:")
for key, value in CONFIG_RESUME.items():
    if key in CONFIG and CONFIG[key] != value:
        print(f"    {key}: {CONFIG[key]} → {value}")
print(f"    epochs to train: {CONFIG_RESUME['num_epochs'] - start_epoch + 1} "
      f"(epoch {start_epoch} → {CONFIG_RESUME['num_epochs']})")

print(f"\n Expected improvement:")
print(f"   Current PPL: ~22 (plateau)")
print(f"   After 10 more epochs: ~16-18 (much better!)")

# STEP 6: LEARNING-RATE CHECK (OPTIONAL)
# Reads the LR each restored scheduler last produced. A value below 1e-5 only
# triggers a warning listing manual options; nothing here changes the LR.

print("\n" + "="*70, " LEARNING RATE ADJUSTMENT (OPTIONAL)", "="*70, sep="\n")

current_lr_vi_en, current_lr_en_vi = (
    scheduler_vi_en.get_last_lr()[0],
    scheduler_en_vi.get_last_lr()[0],
)

print(
    "\nCurrent Learning Rates:",
    f"   VI->EN: {current_lr_vi_en:.8f}",
    f"   EN->VI: {current_lr_en_vi:.8f}",
    sep="\n",
)

if any(lr < 1e-5 for lr in (current_lr_vi_en, current_lr_en_vi)):
    print(
        "\n  WARNING: Learning rate is very low!",
        "   This might be why you're plateauing.",
        "\n Options:",
        "   1. Continue with current LR (conservative)",
        "   2. Reset LR to higher value (aggressive)",
        "   3. Use Cosine Annealing to restart LR (recommended)",
        sep="\n",
    )
    # To reset manually, assign a new value to param_group['lr'] for every
    # param group of optimizer_vi_en and optimizer_en_vi.
else:
    print("\nLearning rate looks OK, continuing with current schedule")


# STEP 7: CREATE HISTORY DICT FOR TRAINING LOOP
# The training loop (Cell 8) appends to these lists every epoch. A history
# restored from an older checkpoint may lack some of them, which would raise
# KeyError only after a full epoch of training, so every expected key is
# filled in here.

HISTORY_KEYS = ('train_loss', 'train_ppl', 'val_loss', 'val_ppl',
                'lr', 'epoch_times', 'batch_checkpoints')


def _complete_history(hist):
    """Return *hist* with an empty list added for every key of HISTORY_KEYS it lacks.

    A falsy *hist* (None or {}) is replaced by a new dict. A non-empty dict is
    updated in place and returned, so it remains the same object as the
    history_vi_en / history_en_vi variable it came from.
    """
    hist = hist if hist else {}
    for key in HISTORY_KEYS:
        hist.setdefault(key, [])
    return hist


history = {
    'vi_en': _complete_history(history_vi_en),
    'en_vi': _complete_history(history_en_vi),
}

print("\n" + "="*70)
print("READY TO RESUME TRAINING!")
print("="*70)
print(f"\n Will start from Epoch {CONFIG_RESUME['start_epoch']}")
print(f" Will train until Epoch {CONFIG_RESUME['num_epochs']}")
print(f"  Estimated time: ~{(CONFIG_RESUME['num_epochs'] - CONFIG_RESUME['start_epoch'] + 1) * 15} minutes"
      f" (assuming ~15 min/epoch)")
print("\n Next: Run your training cell (Cell 8) with these loaded checkpoints!")
print("="*70)

# Export variables for training loop
CONFIG = CONFIG_RESUME  # Use updated config

In [None]:
# REDUCE DROPOUT + BOOST LR
print("\n" + "="*70)
print(" ANTI-PLATEAU OPTIMIZATION STRATEGY")
print("="*70)

import torch

# STEP 1: REDUCE DROPOUT (0.1 → 0.05)
print("\n STEP 1: REDUCING DROPOUT")
print("-" * 70)

def adjust_dropout(model, new_dropout, model_name="Model"):
    """Set the drop probability of every ``torch.nn.Dropout`` module in *model*.

    The change is made in place and applies from the next forward pass. Only
    ``nn.Dropout`` instances are changed. Dropout expressed another way (a
    float attribute such as ``nn.MultiheadAttention.dropout``, or a functional
    ``F.dropout`` call) is left untouched.

    Args:
        model: ``torch.nn.Module`` to modify.
        new_dropout: new drop probability, must be within [0, 1].
        model_name: label used in the printed summary.

    Returns:
        int: number of ``nn.Dropout`` modules updated (0 if the model has none).

    Raises:
        ValueError: if *new_dropout* is outside [0, 1]. Assigning ``p``
            directly skips the range check that the ``nn.Dropout``
            constructor performs.
    """
    if not 0.0 <= new_dropout <= 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {new_dropout}")

    dropouts = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
    # Report the first layer's previous value (None when there are no layers)
    old_dropout = dropouts[0].p if dropouts else None
    for module in dropouts:
        module.p = new_dropout

    print(f"   {model_name}:")
    print(f"      • Updated {len(dropouts)} dropout layers")
    if old_dropout is None:
        # Formatting None with :.2f used to raise TypeError here
        print(f"      • No nn.Dropout layers found - nothing changed")
    else:
        print(f"      • Dropout: {old_dropout:.2f} → {new_dropout:.2f}")

    return len(dropouts)

# Expected current dropout: 0.1 (adjust_dropout prints the actual old value)
# New dropout: 0.05 (a 50% reduction)
NEW_DROPOUT = 0.05

print(f"\n Target Dropout: {NEW_DROPOUT}")
print(f"   (Reduced from 0.1 to allow better learning)")

# Apply to both models, in place
count_vi_en = adjust_dropout(model_vi_en, NEW_DROPOUT, "VI→EN")
count_en_vi = adjust_dropout(model_en_vi, NEW_DROPOUT, "EN→VI")

print(f"\n Dropout reduction complete!")
print(f"   This will reduce regularization and allow")
print(f"   model to learn training data better.")

# STEP 2: BOOST LEARNING RATE (2.5x)
# Multiplies the LR of every optimizer param group by BOOST_FACTOR.
#
# NOTE(review): the training loop calls scheduler_*.step() after every batch.
# If TransformerLRScheduler recomputes the LR from its step count on each
# step(), as a Noam-style schedule usually does, it overwrites this manual
# value on the first step and the boost has no lasting effect. Confirm in
# training_module. If so, scale the scheduler's own factor instead.
print("\n" + "-" * 70)
print(" STEP 2: BOOSTING LEARNING RATE")
print("-" * 70)

# Current LR of param group 0, used for reporting
current_lr_vi_en = optimizer_vi_en.param_groups[0]['lr']
current_lr_en_vi = optimizer_en_vi.param_groups[0]['lr']

print(f"\n Current Learning Rates:")
print(f"   VI→EN: {current_lr_vi_en:.8f}")
print(f"   EN→VI: {current_lr_en_vi:.8f}")
print(f"   (Too low! Model stuck in plateau)")

# Boost factor
BOOST_FACTOR = 2.5  # increase 2.5x

new_lr_vi_en = current_lr_vi_en * BOOST_FACTOR
new_lr_en_vi = current_lr_en_vi * BOOST_FACTOR

# Scale each param group by its own LR. For a single-group optimizer this is
# the same as before, but groups with different LRs are no longer all forced
# to group 0's boosted value.
for param_group in optimizer_vi_en.param_groups:
    param_group['lr'] *= BOOST_FACTOR

for param_group in optimizer_en_vi.param_groups:
    param_group['lr'] *= BOOST_FACTOR

print(f"\n New Learning Rates (boosted {BOOST_FACTOR}x):")
print(f"   VI→EN: {new_lr_vi_en:.8f} ({current_lr_vi_en:.8f} × {BOOST_FACTOR})")
print(f"   EN→VI: {new_lr_en_vi:.8f} ({current_lr_en_vi:.8f} × {BOOST_FACTOR})")

print(f"\n LR boost complete!")
print(f"   Higher LR will help escape plateau")

# STEP 3: OPTIONAL - SETUP COSINE ANNEALING WITH WARM RESTARTS
#
# CosineAnnealingWarmRestarts measures T_0 in scheduler.step() calls, and the
# training loop (Cell 8) steps the scheduler once per BATCH. So T_0=3 restarts
# every 3 batches, not every 3 epochs. To restart every 3 epochs, use
# T_0 = 3 * len(train_loader_vi_en) (the loaders are built in Cell 7).

print("\n" + "-" * 70)
print(" STEP 3: COSINE ANNEALING SCHEDULER (OPTIONAL)")
print("-" * 70)

USE_COSINE_SCHEDULER = False  #  SET False if you want manual LR only

if USE_COSINE_SCHEDULER:
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    # A new scheduler takes its peak LR from each group's 'initial_lr', which
    # it only sets if missing. An earlier scheduler may already have stored
    # the optimizer's placeholder lr=1.0 there, which would make the cosine
    # peak 1.0. Pin it to the current (boosted) LR first.
    for optimizer in (optimizer_vi_en, optimizer_en_vi):
        for param_group in optimizer.param_groups:
            param_group['initial_lr'] = param_group['lr']

    # Create new schedulers
    scheduler_vi_en = CosineAnnealingWarmRestarts(
        optimizer_vi_en,
        T_0=3,        # restart after 3 scheduler steps (= batches, see note above)
        T_mult=1,     # Keep restart period constant
        eta_min=1e-6  # Minimum LR
    )

    scheduler_en_vi = CosineAnnealingWarmRestarts(
        optimizer_en_vi,
        T_0=3,
        T_mult=1,
        eta_min=1e-6
    )

    print(f"\n Cosine Annealing enabled!")
    print(f"   • Restart period: 3 scheduler steps (stepped once per batch)")
    print(f"   • Min LR: 1e-6")
    print(f"   • LR will cycle: high → low → restart")
    print(f"\n   This helps explore different learning rates")
    print(f"   and escape local minima")
else:
    print(f"\n  Keeping original Transformer LR scheduler")
    print(f"   Using manual LR boost only")

# SUMMARY & EXPECTATIONS

print("\n" + "="*70)
print(" OPTIMIZATION SUMMARY")
print("="*70)

print(f"\n Changes Applied:")
print(f"    Dropout: 0.1 → {NEW_DROPOUT} (-50%)")
print(f"    LR VI→EN: {current_lr_vi_en:.8f} → {new_lr_vi_en:.8f} (+{(BOOST_FACTOR-1)*100:.0f}%)")
print(f"    LR EN→VI: {current_lr_en_vi:.8f} → {new_lr_en_vi:.8f} (+{(BOOST_FACTOR-1)*100:.0f}%)")
if USE_COSINE_SCHEDULER:
    print(f"    Scheduler: Transformer → Cosine Annealing")

# VERIFICATION: read the values back from the live modules and optimizers.
# next(..., None) prints None instead of raising StopIteration when a model
# has no nn.Dropout layer.
print("\n VERIFICATION:")
print(f"   Dropout in VI→EN: {next((m.p for m in model_vi_en.modules() if isinstance(m, torch.nn.Dropout)), None)}")
print(f"   Dropout in EN→VI: {next((m.p for m in model_en_vi.modules() if isinstance(m, torch.nn.Dropout)), None)}")
print(f"   LR VI→EN: {optimizer_vi_en.param_groups[0]['lr']:.8f}")
print(f"   LR EN→VI: {optimizer_en_vi.param_groups[0]['lr']:.8f}")
print(f"   Scheduler: {'CosineAnnealingWarmRestarts' if USE_COSINE_SCHEDULER else 'TransformerLRScheduler'}")

In [None]:
# CELL 7: Load Data & Create DataLoaders
# Loads the preprocessed corpus and builds bucketed train/val/test loaders for
# both directions. EN->VI reuses the same splits with src and tgt swapped.
print("\n" + "="*70, " STAGE 3: CREATE DATALOADERS", "="*70, sep="\n")

import pickle
from dataloader_module import create_dataloaders_with_bucketing

print("\n Loading processed data...")
# pickle.load can execute arbitrary code: acceptable only because this file
# comes from our own dataset. Never point it at untrusted input.
with open(f'{DATA_PATH}/processed_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

print("\n Creating VI->EN dataloaders...")
train_loader_vi_en, val_loader_vi_en, test_loader_vi_en = create_dataloaders_with_bucketing(
    processed_data, batch_size=CONFIG['batch_size'], num_workers=2
)

print(f"\n VI->EN DataLoaders ready!")
print(f"   Train batches: {len(train_loader_vi_en)}")
print(f"   Val batches: {len(val_loader_vi_en)}")
print(f"   Test batches: {len(test_loader_vi_en)}")

print("\n Creating EN->VI dataloaders...")
# Same split objects, with source and target roles swapped
reversed_data = {
    split: {'src': processed_data[split]['tgt'], 'tgt': processed_data[split]['src']}
    for split in ('train', 'validation', 'test')
}

train_loader_en_vi, val_loader_en_vi, test_loader_en_vi = create_dataloaders_with_bucketing(
    reversed_data, batch_size=CONFIG['batch_size'], num_workers=2
)

print(f"\n EN->VI DataLoaders ready!")
print(f"   Train batches: {len(train_loader_en_vi)}")
print(f"   Val batches: {len(val_loader_en_vi)}")
print(f"   Test batches: {len(test_loader_en_vi)}")

In [None]:
# CELL 8: RESUME TRAINING WITH IMPROVED HYPERPARAMETERS
# Sets up the label-smoothing losses for both directions from CONFIG (which
# the resume cell replaced with CONFIG_RESUME).
print("\n" + "="*70)
print("🚀 RESUMING BIDIRECTIONAL TRAINING (IMPROVED)")
print("="*70)

from training_module import (
    TransformerLRScheduler,
    LabelSmoothingLoss,
    calculate_perplexity,
    save_checkpoint,
    validate
)
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import time


print("\n  Setting up loss functions with NEW smoothing...")
# vocab_size is each direction's TARGET vocabulary. pad_idx=0 matches the
# `!= 0` non-padding token count used in the training loop.
criterion_vi_en = LabelSmoothingLoss(
    vocab_size=len(en_tokenizer),
    pad_idx=0,
    smoothing=CONFIG['label_smoothing']  # Now 0.05 instead of 0.1
)
criterion_en_vi = LabelSmoothingLoss(
    vocab_size=len(vi_tokenizer),
    pad_idx=0,
    smoothing=CONFIG['label_smoothing']
)

print(f"Label smoothing: {CONFIG['label_smoothing']} (was 0.1)")

# RECREATE DATALOADERS WITH NEW BATCH SIZE
# Cell 7 already builds the loaders with CONFIG['batch_size']. Comparing
# against a hard-coded 32 rebuilt them a second time for nothing. Rebuild only
# when the existing loaders report a different (or undeterminable) batch size.


def _loader_batch_size(loader):
    """Best-effort batch size of *loader*, or None if it cannot be determined.

    DataLoader.batch_size is None when a batch_sampler drives batching, so the
    sampler's own batch_size attribute (if it has one) is tried next.
    """
    size = getattr(loader, 'batch_size', None)
    if size is None:
        size = getattr(getattr(loader, 'batch_sampler', None), 'batch_size', None)
    return size


_current_batch_size = _loader_batch_size(train_loader_vi_en)
if _current_batch_size != CONFIG['batch_size']:
    print(f"\n Recreating dataloaders with batch_size={CONFIG['batch_size']}...")

    from dataloader_module import create_dataloaders_with_bucketing

    train_loader_vi_en, val_loader_vi_en, test_loader_vi_en = create_dataloaders_with_bucketing(
        processed_data,
        batch_size=CONFIG['batch_size'],
        num_workers=2
    )

    train_loader_en_vi, val_loader_en_vi, test_loader_en_vi = create_dataloaders_with_bucketing(
        reversed_data,
        batch_size=CONFIG['batch_size'],
        num_workers=2
    )

    previous = _current_batch_size if _current_batch_size is not None else "unknown"
    print(f" New batch size: {CONFIG['batch_size']} (was {previous})")
    print(f"   Batches per epoch: {len(train_loader_vi_en)}")

# TRAINING INFO

print("\n" + "="*70)
print(" TRAINING CONFIGURATION")
print("="*70)
# Clamp at 0: a start epoch past num_epochs means the loop below never runs
epochs_to_train = max(0, CONFIG['num_epochs'] - CONFIG['start_epoch'] + 1)
print(f"\n Resuming from: Epoch {CONFIG['start_epoch']}")
print(f" Training until: Epoch {CONFIG['num_epochs']}")
print(f" Epochs to train: {epochs_to_train}")
if epochs_to_train == 0:
    print("  WARNING: start_epoch is past num_epochs - the training loop will not run.")
print(f"\n Previous Performance:")
if history['vi_en']['val_ppl']:
    print(f"  VI→EN PPL: {history['vi_en']['val_ppl'][-1]:.2f}")
if history['en_vi']['val_ppl']:
    print(f"  EN→VI PPL: {history['en_vi']['val_ppl'][-1]:.2f}")
print(f"\n Expected after 10 more epochs: PPL ~16-18")

# Create checkpoint directories
os.makedirs(f'{OUTPUT_PATH}/checkpoints/vi_en', exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/checkpoints/en_vi', exist_ok=True)

print(f"\n Training start: {time.strftime('%H:%M:%S')}")
print("="*70)

# Best validation loss seen so far (inf when no history), for "is_best" checks
best_val_loss_vi_en = min(history['vi_en']['val_loss']) if history['vi_en']['val_loss'] else float('inf')
best_val_loss_en_vi = min(history['en_vi']['val_loss']) if history['en_vi']['val_loss'] else float('inf')

# TRAINING LOOP - RESUME FROM start_epoch
#
# Every iteration trains one VI->EN batch and then one EN->VI batch, so both
# directions advance in lockstep. If one loader is exhausted before the other,
# it is restarted and that direction skips the current iteration.


def _next_batch(iterator, loader):
    """Return (batch, iterator); on exhaustion return (None, a fresh iter(loader))."""
    try:
        return next(iterator), iterator
    except StopIteration:
        return None, iter(loader)


def _train_step(model, optimizer, scheduler, scaler, criterion, src, tgt):
    """Run one teacher-forced optimisation step; return (loss, non_pad_tokens).

    The decoder input is tgt[:, :-1] and the prediction target is tgt[:, 1:].
    Gradients are clipped to a global norm of 1.0. With CONFIG['use_amp'] the
    GradScaler scales the loss and unscales the gradients before clipping.
    The LR scheduler is stepped once per call. Tokens equal to 0 are counted
    as padding.
    """
    optimizer.zero_grad()
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]

    if CONFIG['use_amp']:
        with autocast():
            output = model(src, tgt_input)
            loss = criterion(output, tgt_output)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        output = model(src, tgt_input)
        loss = criterion(output, tgt_output)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    scheduler.step()

    num_tokens = (tgt_output != 0).sum().item()
    return loss.item(), num_tokens


for epoch in range(CONFIG['start_epoch'], CONFIG['num_epochs'] + 1):
    epoch_start_time = time.time()

    print(f"\n{'='*70}")
    print(f"EPOCH {epoch}/{CONFIG['num_epochs']}")
    print(f"{'='*70}")

    # Train mode
    model_vi_en.train()
    model_en_vi.train()

    # Token-weighted loss sums: total_* cover the whole epoch, batch_* only the
    # window since the last mid-epoch checkpoint.
    total_loss_vi_en = 0
    total_tokens_vi_en = 0
    total_loss_en_vi = 0
    total_tokens_en_vi = 0
    batch_loss_vi_en = 0
    batch_tokens_vi_en = 0
    batch_loss_en_vi = 0
    batch_tokens_en_vi = 0

    # Iterators
    iter_vi_en = iter(train_loader_vi_en)
    iter_en_vi = iter(train_loader_en_vi)
    max_batches = max(len(train_loader_vi_en), len(train_loader_en_vi))

    progress_bar = tqdm(range(max_batches), desc=f'Epoch {epoch}')

    for batch_idx in progress_bar:
        # TRAIN VI->EN
        batch, iter_vi_en = _next_batch(iter_vi_en, train_loader_vi_en)
        if batch is not None:
            src, tgt, _, _ = batch
            loss_value, num_tokens = _train_step(
                model_vi_en, optimizer_vi_en, scheduler_vi_en, scaler_vi_en,
                criterion_vi_en, src.to(device), tgt.to(device),
            )
            total_loss_vi_en += loss_value * num_tokens
            total_tokens_vi_en += num_tokens
            batch_loss_vi_en += loss_value * num_tokens
            batch_tokens_vi_en += num_tokens

        # TRAIN EN->VI
        batch, iter_en_vi = _next_batch(iter_en_vi, train_loader_en_vi)
        if batch is not None:
            src, tgt, _, _ = batch
            loss_value, num_tokens = _train_step(
                model_en_vi, optimizer_en_vi, scheduler_en_vi, scaler_en_vi,
                criterion_en_vi, src.to(device), tgt.to(device),
            )
            total_loss_en_vi += loss_value * num_tokens
            total_tokens_en_vi += num_tokens
            batch_loss_en_vi += loss_value * num_tokens
            batch_tokens_en_vi += num_tokens

        # Update progress (running epoch averages)
        current_loss_vi_en = total_loss_vi_en / max(total_tokens_vi_en, 1)
        current_loss_en_vi = total_loss_en_vi / max(total_tokens_en_vi, 1)
        current_ppl_vi_en = calculate_perplexity(current_loss_vi_en)
        current_ppl_en_vi = calculate_perplexity(current_loss_en_vi)
        current_lr = scheduler_vi_en.get_last_lr()[0]

        progress_bar.set_postfix({
            'VI→EN_loss': f'{current_loss_vi_en:.4f}',
            'VI→EN_ppl': f'{current_ppl_vi_en:.1f}',
            'EN→VI_loss': f'{current_loss_en_vi:.4f}',
            'EN→VI_ppl': f'{current_ppl_en_vi:.1f}',
            'LR': f'{current_lr:.6f}'
        })

        # Mid-epoch checkpoint; its loss covers only the batches since the last one
        if (batch_idx + 1) % CONFIG['save_every_batches'] == 0:
            avg_loss_vi_en = batch_loss_vi_en / max(batch_tokens_vi_en, 1)
            avg_loss_en_vi = batch_loss_en_vi / max(batch_tokens_en_vi, 1)

            print(f"\n Checkpoint at batch {batch_idx + 1}/{max_batches}")

            save_checkpoint(
                model_vi_en, optimizer_vi_en, scheduler_vi_en,
                epoch, batch_idx + 1, avg_loss_vi_en, None,
                f'{OUTPUT_PATH}/checkpoints/vi_en',
                history=history['vi_en'], scaler=scaler_vi_en
            )

            save_checkpoint(
                model_en_vi, optimizer_en_vi, scheduler_en_vi,
                epoch, batch_idx + 1, avg_loss_en_vi, None,
                f'{OUTPUT_PATH}/checkpoints/en_vi',
                history=history['en_vi'], scaler=scaler_en_vi
            )

            batch_loss_vi_en = 0
            batch_tokens_vi_en = 0
            batch_loss_en_vi = 0
            batch_tokens_en_vi = 0

    # VALIDATION
    print("\n Validating...")

    val_loss_vi_en, val_ppl_vi_en = validate(
        model_vi_en, val_loader_vi_en, criterion_vi_en, device
    )
    val_loss_en_vi, val_ppl_en_vi = validate(
        model_en_vi, val_loader_en_vi, criterion_en_vi, device
    )

    epoch_time = time.time() - epoch_start_time

    # Epoch metrics (max(..., 1) avoids ZeroDivisionError on an empty loader)
    train_loss_vi_en = total_loss_vi_en / max(total_tokens_vi_en, 1)
    train_loss_en_vi = total_loss_en_vi / max(total_tokens_en_vi, 1)
    train_ppl_vi_en = calculate_perplexity(train_loss_vi_en)
    train_ppl_en_vi = calculate_perplexity(train_loss_en_vi)

    # Update history
    history['vi_en']['train_loss'].append(train_loss_vi_en)
    history['vi_en']['train_ppl'].append(train_ppl_vi_en)
    history['vi_en']['val_loss'].append(val_loss_vi_en)
    history['vi_en']['val_ppl'].append(val_ppl_vi_en)
    history['vi_en']['lr'].append(scheduler_vi_en.get_last_lr()[0])
    history['vi_en']['epoch_times'].append(epoch_time)

    history['en_vi']['train_loss'].append(train_loss_en_vi)
    history['en_vi']['train_ppl'].append(train_ppl_en_vi)
    history['en_vi']['val_loss'].append(val_loss_en_vi)
    history['en_vi']['val_ppl'].append(val_ppl_en_vi)
    history['en_vi']['lr'].append(scheduler_en_vi.get_last_lr()[0])
    history['en_vi']['epoch_times'].append(epoch_time)

    # Print results
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch}/{CONFIG['num_epochs']} COMPLETE - {epoch_time/60:.2f} min")
    print(f"{'='*70}")

    print(f"\n🇻🇳→🇬🇧 VI→EN:")
    print(f"  Train: Loss={train_loss_vi_en:.4f}, PPL={train_ppl_vi_en:.2f}")
    print(f"  Val:   Loss={val_loss_vi_en:.4f}, PPL={val_ppl_vi_en:.2f}")
    print(f"  LR: {scheduler_vi_en.get_last_lr()[0]:.6f}")

    print(f"\n🇬🇧→🇻🇳 EN→VI:")
    print(f"  Train: Loss={train_loss_en_vi:.4f}, PPL={train_ppl_en_vi:.2f}")
    print(f"  Val:   Loss={val_loss_en_vi:.4f}, PPL={val_ppl_en_vi:.2f}")
    print(f"  LR: {scheduler_en_vi.get_last_lr()[0]:.6f}")

    # Check best
    is_best_vi_en = val_loss_vi_en < best_val_loss_vi_en
    is_best_en_vi = val_loss_en_vi < best_val_loss_en_vi

    if is_best_vi_en:
        best_val_loss_vi_en = val_loss_vi_en
        print(f"\n   NEW BEST VI→EN! Val PPL: {val_ppl_vi_en:.2f}")

    if is_best_en_vi:
        best_val_loss_en_vi = val_loss_en_vi
        print(f"   NEW BEST EN→VI! Val PPL: {val_ppl_en_vi:.2f}")

    # Save end-of-epoch checkpoint (batch index 0)
    print(f"\n Saving epoch checkpoint...")

    save_checkpoint(
        model_vi_en, optimizer_vi_en, scheduler_vi_en,
        epoch, 0, train_loss_vi_en, val_loss_vi_en,
        f'{OUTPUT_PATH}/checkpoints/vi_en',
        history=history['vi_en'], scaler=scaler_vi_en, is_best=is_best_vi_en
    )

    save_checkpoint(
        model_en_vi, optimizer_en_vi, scheduler_en_vi,
        epoch, 0, train_loss_en_vi, val_loss_en_vi,
        f'{OUTPUT_PATH}/checkpoints/en_vi',
        history=history['en_vi'], scaler=scaler_en_vi, is_best=is_best_en_vi
    )

print(f"\n Training end: {time.strftime('%H:%M:%S')}")
print("\n" + "="*70)
print(" TRAINING COMPLETE!")
print("="*70)


def _format_best_ppl(best_loss):
    """Perplexity of *best_loss* as text, or 'n/a' when no validation loss exists (inf)."""
    if best_loss == float('inf'):
        return "n/a"
    return f"{calculate_perplexity(best_loss):.2f}"


def _format_last_ppl(hist):
    """Most recent validation perplexity in *hist* as text, or 'n/a' if there is none."""
    return f"{hist['val_ppl'][-1]:.2f}" if hist['val_ppl'] else "n/a"


# Final summary. The guards keep this from crashing (IndexError / inf) when no
# epoch ran and no validation result exists.
print(f"\n FINAL RESULTS:")
print(f"\n🇻🇳→🇬🇧 VI→EN:")
print(f"  Best Val PPL: {_format_best_ppl(best_val_loss_vi_en)}")
print(f"  Final Val PPL: {_format_last_ppl(history['vi_en'])}")

print(f"\n🇬🇧→🇻🇳 EN→VI:")
print(f"  Best Val PPL: {_format_best_ppl(best_val_loss_en_vi)}")
print(f"  Final Val PPL: {_format_last_ppl(history['en_vi'])}")

# Sums every recorded epoch, including those from runs before this resume
print(f"\n Total training time: {sum(history['vi_en']['epoch_times'])/60:.1f} min")

In [None]:
# CELL 9: COMPREHENSIVE EVALUATION - BLEU & METRICS
print("\n" + "="*70)
print("COMPREHENSIVE EVALUATION - BLEU SCORES")
print("="*70)

import torch
from tqdm import tqdm
import json
import os
import numpy as np

# Import evaluation functions
from inference_evaluation_v2 import (
    compute_sentence_bleu,
    calculate_corpus_bleu,
    calculate_gemini_score,
    evaluate_translations,
    print_evaluation_results,
    translate_sentence
)

# EVALUATION CONFIGURATION

EVAL_CONFIG = {
    'beam_size': 5,
    'max_length': 200,
    'num_samples': 1000,  # None = use all samples from file
    'show_examples': 10,  # Number of examples to show
    'use_gemini': False,
    # Never hard-code API keys in a notebook: they leak with every copy, commit
    # or public share. Supply the key via the GEMINI_API_KEY environment
    # variable (e.g. populated from a Kaggle secret); None when unset. The key
    # that was previously hard-coded here should be revoked.
    'gemini_api_key': os.environ.get("GEMINI_API_KEY"),
}

# File paths
VI_FILE_PATH = '/kaggle/input/eva-model/kaggle_vi_sents_first1000.txt'
EN_FILE_PATH = '/kaggle/input/eva-model/kaggle_en_sents_first1000.txt'

print("\n Evaluation Configuration:")
for key, value in EVAL_CONFIG.items():
    if key != 'gemini_api_key':
        print(f"  {key:20s}: {value}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FUNCTION: LOAD SENTENCES FROM FILE

def load_sentences_from_file(file_path, max_samples=None):
    """Read whitespace-stripped, non-empty lines from a UTF-8 text file.

    Args:
        file_path: Path to a text file with one sentence per line.
        max_samples: Stop once this many sentences are collected. ``None``
            (or any other falsy value, such as 0) reads the whole file.

    Returns:
        list[str]: The collected sentences in file order.

    Note:
        Blank lines are dropped. If two parallel files have blank lines at
        different positions, their sentences will no longer line up by index.
    """
    print(f"\n Loading sentences from: {file_path}")

    collected = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            collected.append(text)
            # Checked after appending, so the cap is reached exactly.
            if max_samples and len(collected) >= max_samples:
                break

    print(f" Loaded {len(collected)} sentences")
    return collected

# FUNCTION: GENERATE TRANSLATIONS FROM FILE


def generate_translations_from_file(model, source_sentences, src_tokenizer, 
                                   tgt_tokenizer, device, beam_size=5, max_len=100):
    """Translate each source sentence with beam search.

    Args:
        model: Translation model; switched to eval mode before decoding.
        source_sentences: Iterable of raw source-language strings.
        src_tokenizer: Tokenizer for the source language.
        tgt_tokenizer: Tokenizer for the target language.
        device: Torch device passed through to ``translate_sentence``.
        beam_size: Beam width for decoding.
        max_len: Maximum generated length passed to ``translate_sentence``.

    Returns:
        list: One prediction per input, in the same order as the inputs.
    """
    model.eval()

    print(f"\n Generating translations for {len(source_sentences)} samples...")

    # Every decoding argument except the sentence is the same for all inputs.
    decode_options = dict(
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        device=device,
        use_beam_search=True,
        beam_size=beam_size,
        max_len=max_len,
    )

    with torch.no_grad():
        predictions = [
            translate_sentence(model=model, sentence=text, **decode_options)
            for text in tqdm(source_sentences, desc="Translating")
        ]

    print(f" Generated {len(predictions)} translations")
    return predictions


# LOAD DATA FROM FILES


print("\n" + "="*70)
print(" LOADING DATA FROM FILES")
print("="*70)

# Load Vietnamese sentences
vi_sentences = load_sentences_from_file(VI_FILE_PATH, EVAL_CONFIG['num_samples'])

# Load English sentences (references)
en_sentences = load_sentences_from_file(EN_FILE_PATH, EVAL_CONFIG['num_samples'])

# The two files are parallel: sentence k of one is paired with sentence k of
# the other. Different counts mean that pairing is unreliable (for example,
# blank lines dropped on only one side), so report it rather than truncating
# silently.
if len(vi_sentences) != len(en_sentences):
    print(f"  WARNING: sentence count mismatch (VI={len(vi_sentences)}, "
          f"EN={len(en_sentences)}); pairs may be misaligned. "
          f"Truncating to the shorter list.")

min_length = min(len(vi_sentences), len(en_sentences))
vi_sentences = vi_sentences[:min_length]
en_sentences = en_sentences[:min_length]

print(f"\n Using {min_length} sentence pairs for evaluation")

# EVALUATE VI->EN MODEL

print("\n" + "="*70)
print(" EVALUATING VI->EN MODEL")
print("="*70)

print("\n Loading best VI->EN checkpoint...")
# Candidate checkpoints, tried in order; the first existing file is used.
# NOTE(review): the uploaded epoch-14 checkpoint is listed first, so it takes
# precedence over a best_model.pt written by training in this session.
# Confirm that ordering is intended.
checkpoint_paths_vi_en = [
    '/kaggle/input/checkpoint-nlp/best_model_vi_en_14.pt',
    f'{OUTPUT_PATH}/checkpoints/vi_en/best_model.pt',
    f'{OUTPUT_PATH}/checkpoints/vi_en/latest_checkpoint.pt'
]

checkpoint_path = None
for path in checkpoint_paths_vi_en:
    if os.path.exists(path):
        checkpoint_path = path
        print(f" Found checkpoint: {path}")
        break

if checkpoint_path is None:
    # List the searched locations so a missing file is easy to diagnose.
    raise FileNotFoundError(
        f"No VI->EN checkpoint found! Searched: {checkpoint_paths_vi_en}"
    )

# The checkpoint is read as a dict; only its 'model_state_dict' entry is used.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_vi_en.load_state_dict(checkpoint['model_state_dict'])
model_vi_en.eval()
print(" Model loaded!")

# Generate translations
sources_vi = vi_sentences
refs_en = en_sentences
preds_en = generate_translations_from_file(
    model_vi_en, vi_sentences, vi_tokenizer, en_tokenizer,
    device, beam_size=EVAL_CONFIG['beam_size'],
    max_len=EVAL_CONFIG['max_length']
)

# Calculate CORPUS BLEU
print("\n Calculating Corpus BLEU scores...")
bleu_results_vi_en = calculate_corpus_bleu(refs_en, preds_en)

print("\n" + "="*70)
print(" VI->EN RESULTS")
print("="*70)
print(f"\n Corpus BLEU Score: {bleu_results_vi_en['corpus_bleu']:.2f}")
print(f"\n Detailed Metrics:")
print(f"  BLEU-1: {bleu_results_vi_en['corpus_bleu_1']:.2f}")
print(f"  BLEU-2: {bleu_results_vi_en['corpus_bleu_2']:.2f}")
print(f"  BLEU-3: {bleu_results_vi_en['corpus_bleu_3']:.2f}")
print(f"  BLEU-4: {bleu_results_vi_en['corpus_bleu_4']:.2f}")
print(f"  Brevity Penalty: {bleu_results_vi_en['brevity_penalty']:.3f}")
print(f"  Length Ratio: {bleu_results_vi_en['length_ratio']:.3f}")

# Show sample translations with their sentence-level BLEU
if EVAL_CONFIG['show_examples'] > 0:
    print(f"\n Sample Translations (first {EVAL_CONFIG['show_examples']}):")
    print("="*70)
    for i in range(min(EVAL_CONFIG['show_examples'], len(sources_vi))):
        individual_bleu = compute_sentence_bleu(refs_en[i], preds_en[i])
        
        print(f"\n[Example {i+1}] - Sentence BLEU: {individual_bleu['bleu']:.2f}")
        print(f" Source (VI): {sources_vi[i]}")
        print(f" Reference:   {refs_en[i]}")
        print(f" Prediction:  {preds_en[i]}")
        print(f"    BLEU-1/2/3/4: {individual_bleu['bleu_1']:.1f}/{individual_bleu['bleu_2']:.1f}/{individual_bleu['bleu_3']:.1f}/{individual_bleu['bleu_4']:.1f}")
        print("-"*70)

# Optional: Gemini evaluation on the first 5 pairs (requires an API key)
if EVAL_CONFIG['use_gemini'] and EVAL_CONFIG['gemini_api_key']:
    print("\n  Running Gemini evaluation (sample of 5 translations)...")
    gemini_scores_vi_en = []
    for i in range(min(5, len(sources_vi))):
        score = calculate_gemini_score(
            sources_vi[i], refs_en[i], preds_en[i],
            api_key=EVAL_CONFIG['gemini_api_key']
        )
        gemini_scores_vi_en.append(score)
        if score['gemini_score']:
            print(f"  Example {i+1}: Score={score['gemini_score']:.1f}, "
                  f"Fluency={score['fluency']:.1f}, Adequacy={score['adequacy']:.1f}")

# EVALUATE EN->VI MODEL

print("\n" + "="*70)
print(" EVALUATING EN->VI MODEL")
print("="*70)

print("\n Loading best EN->VI checkpoint...")
# Candidate checkpoints, tried in order; the first existing file is used.
# NOTE(review): the uploaded epoch-14 checkpoint is listed first, so it takes
# precedence over a best_model.pt written by training in this session.
# Confirm that ordering is intended.
checkpoint_paths_en_vi = [
    '/kaggle/input/checkpoint-nlp/best_model_en_vi_14.pt',
    f'{OUTPUT_PATH}/checkpoints/en_vi/best_model.pt',
    f'{OUTPUT_PATH}/checkpoints/en_vi/latest_checkpoint.pt'
]

checkpoint_path = None
for path in checkpoint_paths_en_vi:
    if os.path.exists(path):
        checkpoint_path = path
        print(f" Found checkpoint: {path}")
        break

if checkpoint_path is None:
    # List the searched locations so a missing file is easy to diagnose.
    raise FileNotFoundError(
        f"No EN->VI checkpoint found! Searched: {checkpoint_paths_en_vi}"
    )

# The checkpoint is read as a dict; only its 'model_state_dict' entry is used.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_en_vi.load_state_dict(checkpoint['model_state_dict'])
model_en_vi.eval()
print(" Model loaded!")

# Generate translations (the English file is the source, Vietnamese the reference)
sources_en = en_sentences
refs_vi = vi_sentences
preds_vi = generate_translations_from_file(
    model_en_vi, en_sentences, en_tokenizer, vi_tokenizer,
    device, beam_size=EVAL_CONFIG['beam_size'],
    max_len=EVAL_CONFIG['max_length']
)

# Calculate CORPUS BLEU
print("\n Calculating Corpus BLEU scores...")
bleu_results_en_vi = calculate_corpus_bleu(refs_vi, preds_vi)

print("\n" + "="*70)
print(" EN->VI RESULTS")
print("="*70)
print(f"\nCorpus BLEU Score: {bleu_results_en_vi['corpus_bleu']:.2f}")
print(f"\n Detailed Metrics:")
print(f"  BLEU-1: {bleu_results_en_vi['corpus_bleu_1']:.2f}")
print(f"  BLEU-2: {bleu_results_en_vi['corpus_bleu_2']:.2f}")
print(f"  BLEU-3: {bleu_results_en_vi['corpus_bleu_3']:.2f}")
print(f"  BLEU-4: {bleu_results_en_vi['corpus_bleu_4']:.2f}")
print(f"  Brevity Penalty: {bleu_results_en_vi['brevity_penalty']:.3f}")
print(f"  Length Ratio: {bleu_results_en_vi['length_ratio']:.3f}")

# Show examples with their sentence-level BLEU
if EVAL_CONFIG['show_examples'] > 0:
    print(f"\n Sample Translations (first {EVAL_CONFIG['show_examples']}):")
    print("="*70)
    for i in range(min(EVAL_CONFIG['show_examples'], len(sources_en))):
        individual_bleu = compute_sentence_bleu(refs_vi[i], preds_vi[i])
        
        print(f"\n[Example {i+1}] - Sentence BLEU: {individual_bleu['bleu']:.2f}")
        print(f" Source (EN): {sources_en[i]}")
        print(f" Reference:   {refs_vi[i]}")
        print(f" Prediction:  {preds_vi[i]}")
        print(f"    BLEU-1/2/3/4: {individual_bleu['bleu_1']:.1f}/{individual_bleu['bleu_2']:.1f}/{individual_bleu['bleu_3']:.1f}/{individual_bleu['bleu_4']:.1f}")
        print("-"*70)

# Optional: Gemini evaluation on the first 5 pairs (requires an API key)
if EVAL_CONFIG['use_gemini'] and EVAL_CONFIG['gemini_api_key']:
    print("\n Running Gemini evaluation (sample of 5 translations)...")
    gemini_scores_en_vi = []
    for i in range(min(5, len(sources_en))):
        score = calculate_gemini_score(
            sources_en[i], refs_vi[i], preds_vi[i],
            api_key=EVAL_CONFIG['gemini_api_key']
        )
        gemini_scores_en_vi.append(score)
        if score['gemini_score']:
            print(f"  Example {i+1}: Score={score['gemini_score']:.1f}, "
                  f"Fluency={score['fluency']:.1f}, Adequacy={score['adequacy']:.1f}")

# COMPREHENSIVE EVALUATION SUMMARY

print("\n" + "="*70)
print(" COMPREHENSIVE EVALUATION SUMMARY")
print("="*70)

# Both directions are scored with identical evaluator settings; only the
# source/reference/hypothesis lists differ.
shared_eval_kwargs = {
    'use_gemini': EVAL_CONFIG['use_gemini'],
    'gemini_api_key': EVAL_CONFIG['gemini_api_key'],
    'use_sacrebleu': True,
}

# VI->EN comprehensive evaluation
print("\n VI->EN Comprehensive Metrics:")
eval_results_vi_en = evaluate_translations(
    sources=sources_vi, references=refs_en, hypotheses=preds_en,
    **shared_eval_kwargs
)
print_evaluation_results(eval_results_vi_en)

# EN->VI comprehensive evaluation
print("\n EN->VI Comprehensive Metrics:")
eval_results_en_vi = evaluate_translations(
    sources=sources_en, references=refs_vi, hypotheses=preds_vi,
    **shared_eval_kwargs
)
print_evaluation_results(eval_results_en_vi)

# FINAL SUMMARY

print("\n" + "="*70)
print(" FINAL EVALUATION SUMMARY")
print("="*70)

# Per-direction headline metrics: (heading, corpus BLEU results, predictions).
print("\n Overall Performance:")
direction_rows = (
    ("\n  VI->EN:", bleu_results_vi_en, preds_en),
    ("\n   EN->VI:", bleu_results_en_vi, preds_vi),
)
for heading, corpus_scores, hypotheses in direction_rows:
    print(heading)
    print(f"    Corpus BLEU:    {corpus_scores['corpus_bleu']:.2f}")
    print(f"    BLEU-4:         {corpus_scores['corpus_bleu_4']:.2f}")
    print(f"    Brevity Penalty: {corpus_scores['brevity_penalty']:.3f}")
    print(f"    Test Samples:   {len(hypotheses)}")

# Unweighted mean of the two directions' corpus BLEU.
vi_en_corpus_bleu = bleu_results_vi_en['corpus_bleu']
en_vi_corpus_bleu = bleu_results_en_vi['corpus_bleu']
avg_bleu = (vi_en_corpus_bleu + en_vi_corpus_bleu) / 2
print(f"\n   Average BLEU: {avg_bleu:.2f}")

# Interpretation guide
bleu_guide_lines = (
    "  < 10:  Almost unusable",
    "  10-19: Difficult to understand",
    "  20-29: Understandable with effort",
    "  30-39: Understandable",
    "  40-49: High quality",
    "  50-59: Very high quality",
    "  > 60:  Native-like quality",
)
print("\nBLEU Score Interpretation:")
for guide_line in bleu_guide_lines:
    print(guide_line)

# Quality assessment: the first band whose exclusive upper bound exceeds the
# average wins; anything at or above 50 falls through to "Excellent".
quality_bands = (
    (20, " Poor - Needs significant improvement"),
    (30, "  Fair - Understandable but needs work"),
    (40, " Good - Acceptable quality"),
    (50, " Very Good - High quality"),
)
quality = next(
    (label for upper_bound, label in quality_bands if avg_bleu < upper_bound),
    " Excellent - Professional quality",
)

print(f"\n Overall Quality Assessment: {quality}")

# SAVE RESULTS

print("\n Saving evaluation results...")

os.makedirs(f'{OUTPUT_PATH}/results', exist_ok=True)

# Never persist credentials: EVAL_CONFIG can hold the Gemini API key, and this
# JSON is written next to the other shareable result files. This mirrors the
# filter used when the config is printed.
config_to_save = {
    key: value for key, value in EVAL_CONFIG.items() if key != 'gemini_api_key'
}

# Save comprehensive results
eval_results_full = {
    'vi_en': {
        'corpus_bleu': bleu_results_vi_en['corpus_bleu'],
        'bleu_1': bleu_results_vi_en['corpus_bleu_1'],
        'bleu_2': bleu_results_vi_en['corpus_bleu_2'],
        'bleu_3': bleu_results_vi_en['corpus_bleu_3'],
        'bleu_4': bleu_results_vi_en['corpus_bleu_4'],
        'brevity_penalty': bleu_results_vi_en['brevity_penalty'],
        'length_ratio': bleu_results_vi_en['length_ratio'],
        'num_samples': len(preds_en)
    },
    'en_vi': {
        'corpus_bleu': bleu_results_en_vi['corpus_bleu'],
        'bleu_1': bleu_results_en_vi['corpus_bleu_1'],
        'bleu_2': bleu_results_en_vi['corpus_bleu_2'],
        'bleu_3': bleu_results_en_vi['corpus_bleu_3'],
        'bleu_4': bleu_results_en_vi['corpus_bleu_4'],
        'brevity_penalty': bleu_results_en_vi['brevity_penalty'],
        'length_ratio': bleu_results_en_vi['length_ratio'],
        'num_samples': len(preds_vi)
    },
    'overall': {
        'avg_bleu': avg_bleu,
        'quality_assessment': quality
    },
    'config': config_to_save,
    'data_source': {
        'vi_file': VI_FILE_PATH,
        'en_file': EN_FILE_PATH
    }
}

with open(f'{OUTPUT_PATH}/results/evaluation_results.json', 'w', encoding='utf-8') as f:
    json.dump(eval_results_full, f, indent=2, ensure_ascii=False)

# Save sample translations
# One text report per direction with the first 50 pairs and sentence-level BLEU.
# Each spec is (file name, header line, sources, references, predictions).
sample_report_specs = (
    ('sample_translations_vi_en.txt', "VI->EN SAMPLE TRANSLATIONS\n",
     sources_vi, refs_en, preds_en),
    ('sample_translations_en_vi.txt', "EN->VI SAMPLE TRANSLATIONS\n",
     sources_en, refs_vi, preds_vi),
)

for report_name, header, report_sources, report_refs, report_preds in sample_report_specs:
    with open(f'{OUTPUT_PATH}/results/{report_name}', 'w', encoding='utf-8') as report:
        report.write(header)
        report.write("="*70 + "\n\n")
        for idx in range(min(50, len(report_sources))):
            sent_bleu = compute_sentence_bleu(report_refs[idx], report_preds[idx])
            report.write(f"[Example {idx+1}] BLEU: {sent_bleu['bleu']:.2f}\n")
            report.write(f"Source: {report_sources[idx]}\n")
            report.write(f"Reference: {report_refs[idx]}\n")
            report.write(f"Prediction: {report_preds[idx]}\n")
            report.write(f"BLEU-1/2/3/4: {sent_bleu['bleu_1']:.1f}/{sent_bleu['bleu_2']:.1f}/{sent_bleu['bleu_3']:.1f}/{sent_bleu['bleu_4']:.1f}\n")
            report.write("-"*70 + "\n\n")

print(" Results saved!")
for saved_name in ('evaluation_results.json',
                   'sample_translations_vi_en.txt',
                   'sample_translations_en_vi.txt'):
    print(f"   - {saved_name}")

print("\n" + "="*70)
print(" EVALUATION COMPLETE!")
print("="*70)

# Full paths of everything written under the results directory.
results_dir = f'{OUTPUT_PATH}/results'
print("\n Output Files:")
print(f"   JSON results:   {results_dir}/evaluation_results.json")
print(f"   VI->EN samples: {results_dir}/sample_translations_vi_en.txt")
print(f"   EN->VI samples: {results_dir}/sample_translations_en_vi.txt")

# Headline numbers, formatted for copying into a report.
print("\n" + "="*70)
print("KEY METRICS FOR REPORT:")
print("="*70)
vi_en_headline = bleu_results_vi_en['corpus_bleu']
en_vi_headline = bleu_results_en_vi['corpus_bleu']
print(f"  VI->EN Corpus BLEU: {vi_en_headline:.2f}")
print(f"  EN->VI Corpus BLEU: {en_vi_headline:.2f}")
print(f"  Average BLEU:       {avg_bleu:.2f}")
print("="*70)