<a href="https://colab.research.google.com/github/nguyendai05/train-tri-language/blob/main/train_tri_lg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# ============================================================
# 🌉 TRI-LINGUA BRIDGE - COLAB TRAINING V2
# ============================================================
# CELL 1: SETUP & KIỂM TRA GPU
# ============================================================

"""
CELL 1: Cài đặt và kiểm tra môi trường
"""

# Kiểm tra GPU
!nvidia-smi

# Cài đặt dependencies
!pip install -q tensorflow>=2.15.0
!pip install -q transformers>=4.36.0 datasets>=2.14.0 accelerate>=0.24.0
!pip install -q sentencepiece tokenizers sacrebleu sacremoses
!pip install -q pandas numpy tqdm requests
!pip install -q evaluate  # Thư viện metrics mới của HuggingFace

# Kiểm tra TensorFlow + GPU
import tensorflow as tf
print(f"\n✅ TensorFlow: {tf.__version__}")
print(f"✅ GPU available: {tf.config.list_physical_devices('GPU')}")

# Kiểm tra PyTorch (cho mBART)
import torch
print(f"✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Tạo thư mục
!mkdir -p /content/tri-lingua/data/phonetic
!mkdir -p /content/tri-lingua/data/translation
!mkdir -p /content/tri-lingua/models/phonetic
!mkdir -p /content/tri-lingua/models/translation

print("\n✅ Setup complete!")

Thu Dec 11 15:55:29 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   28C    P0             43W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Ô code 2
# ============================================================
# 🌐 TRI-LINGUA - DOWNLOAD TRANSLATION DATA (V3 - IMPROVED)
# ============================================================
# Version 3: Data quality focus
#
# Cải tiến:
# - ✅ Data cleaning và deduplication
# - ✅ Quality filtering (length ratio, encoding)
# - ✅ Stratified sampling để cân bằng sources
# - ✅ Validation set riêng biệt (không overlap với train)
# - ✅ Progress tracking và error recovery
# ============================================================

# !pip install -q datasets pandas tqdm

import os
import json
import random
import hashlib
from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict
from typing import List, Dict, Set, Tuple

# Tạo thư mục
DATA_DIR = '/content/tri-lingua/data/translation'
os.makedirs(DATA_DIR, exist_ok=True)

# Set seed
random.seed(42)

print("=" * 70)
print("📥 DOWNLOADING TRANSLATION DATA (V3 - IMPROVED)")
print("=" * 70)

# ==================== HELPER FUNCTIONS ====================

def clean_text(text: str) -> str:
    """Clean và normalize text"""
    if not text:
        return ""

    # Remove extra whitespace
    text = ' '.join(text.split())

    # Remove control characters (giữ lại Unicode)
    text = ''.join(c for c in text if c.isprintable() or c in '\n\t')

    return text.strip()

def is_valid_pair(src: str, tgt: str,
                  min_len: int = 5,
                  max_len: int = 300,
                  min_ratio: float = 0.2,
                  max_ratio: float = 5.0) -> bool:
    """
    Kiểm tra cặp dịch có hợp lệ không

    Args:
        src, tgt: Source và target text
        min_len: Độ dài tối thiểu (characters)
        max_len: Độ dài tối đa
        min_ratio, max_ratio: Tỷ lệ độ dài cho phép
    """
    if not src or not tgt:
        return False

    len_src, len_tgt = len(src), len(tgt)

    # Độ dài hợp lý
    if len_src < min_len or len_tgt < min_len:
        return False
    if len_src > max_len or len_tgt > max_len:
        return False

    # Tỷ lệ độ dài hợp lý
    ratio = len_src / len_tgt if len_tgt > 0 else 0
    if ratio < min_ratio or ratio > max_ratio:
        return False

    # Không phải toàn số/ký tự đặc biệt
    if src.isdigit() or tgt.isdigit():
        return False

    # Có ít nhất 1 chữ cái
    if not any(c.isalpha() for c in src):
        return False

    return True

def get_hash(text: str) -> str:
    """Tạo hash để detect duplicate"""
    return hashlib.md5(text.lower().encode('utf-8')).hexdigest()[:16]

def deduplicate_pairs(pairs: List[Dict], key_fields: Tuple[str, str]) -> List[Dict]:
    """Loại bỏ duplicate dựa trên hash của src+tgt"""
    seen = set()
    unique = []

    for p in pairs:
        # Tạo key từ cả src và tgt
        key = get_hash(p[key_fields[0]] + "|||" + p[key_fields[1]])
        if key not in seen:
            seen.add(key)
            unique.append(p)

    return unique

# ==================== DATA STORAGE ====================
all_pairs = {
    'en_vi': [],
    'en_zh': [],
    'zh_vi': []
}

stats = {
    'raw': defaultdict(int),
    'filtered': defaultdict(int),
    'deduplicated': defaultdict(int)
}

# ============================================================
# 1. EN-VI: IWSLT (HuggingFace)
# ============================================================
print("\n" + "=" * 70)
print("1️⃣ EN-VI: IWSLT (TED Talks)")
print("=" * 70)

try:
    from datasets import load_dataset

    dataset_configs = [
        ("iwslt2017", "iwslt2017-en-vi"),
        ("IWSLT/mt_eng_vietnamese", "iwslt2015-vi-en"),
    ]

    for ds_name, config in dataset_configs:
        try:
            print(f"   Trying {ds_name}/{config}...")
            ds = load_dataset(ds_name, config, trust_remote_code=True)

            raw_count = 0
            valid_count = 0

            for split in ds.keys():
                for item in tqdm(ds[split], desc=f"   {split}", leave=False):
                    raw_count += 1

                    translation = item.get('translation', item)
                    en = clean_text(translation.get('en', ''))
                    vi = clean_text(translation.get('vi', ''))

                    if is_valid_pair(en, vi):
                        all_pairs['en_vi'].append({
                            'en': en,
                            'vi': vi,
                            'source': ds_name.split('/')[-1]
                        })
                        valid_count += 1

            stats['raw']['en_vi'] += raw_count
            stats['filtered']['en_vi'] += valid_count

            if valid_count > 0:
                print(f"   ✅ Loaded {valid_count:,} valid pairs (from {raw_count:,} raw)")
                break

        except Exception as e:
            print(f"   ⚠️ {ds_name}: {str(e)[:60]}")
            continue

except ImportError:
    print("   ❌ datasets library not installed")


# ============================================================
# 2. EN-ZH: OPUS100 (HuggingFace)
# ============================================================
print("\n" + "=" * 70)
print("2️⃣ EN-ZH: OPUS100")
print("=" * 70)

try:
    from datasets import load_dataset

    ds = load_dataset("opus100", "en-zh", trust_remote_code=True)

    max_samples = 500000  # Giới hạn
    raw_count = 0
    valid_count = 0

    for split in ds.keys():
        if valid_count >= max_samples:
            break

        for item in tqdm(ds[split], desc=f"   {split}", leave=False):
            if valid_count >= max_samples:
                break

            raw_count += 1

            translation = item.get('translation', item)
            en = clean_text(translation.get('en', ''))
            zh = clean_text(translation.get('zh', ''))

            if is_valid_pair(en, zh, min_len=3):  # ZH có thể ngắn hơn
                all_pairs['en_zh'].append({
                    'en': en,
                    'zh': zh,
                    'source': 'opus100'
                })
                valid_count += 1

    stats['raw']['en_zh'] = raw_count
    stats['filtered']['en_zh'] = valid_count

    print(f"   ✅ Loaded {valid_count:,} valid pairs (from {raw_count:,} raw)")

except Exception as e:
    print(f"   ❌ Error loading EN-ZH: {e}")

# ============================================================
# 3. ZH-VI: OpenSubtitles (OPUS)
# ============================================================
print("\n" + "=" * 70)
print("3️⃣ ZH-VI: OpenSubtitles")
print("=" * 70)

import subprocess
import glob

OPENSUB_URL = "https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2024/moses/vi-zh_CN.txt.zip"
OPENSUB_DIR = "/content/opensub_vi_zh"

try:
    # Download
    print("   Downloading OpenSubtitles ZH-VI...")
    subprocess.run(
        ["wget", "-q", "--show-progress", "-nc", OPENSUB_URL, "-O", "/content/opensub.zip"],
        check=True
    )

    # Unzip
    print("   Extracting...")
    os.makedirs(OPENSUB_DIR, exist_ok=True)
    subprocess.run(
        ["unzip", "-q", "-o", "/content/opensub.zip", "-d", OPENSUB_DIR],
        check=True
    )

    # Find files
    vi_files = glob.glob(f'{OPENSUB_DIR}/*.vi')
    zh_files = glob.glob(f'{OPENSUB_DIR}/*.zh*')

    if vi_files and zh_files:
        vi_file = vi_files[0]
        zh_file = zh_files[0]

        print(f"   Found: {os.path.basename(vi_file)}, {os.path.basename(zh_file)}")

        max_samples = 2000000  # Giới hạn 2M (đủ để train tốt)
        raw_count = 0
        valid_count = 0

        with open(vi_file, 'r', encoding='utf-8', errors='ignore') as f_vi, \
             open(zh_file, 'r', encoding='utf-8', errors='ignore') as f_zh:

            for vi_line, zh_line in tqdm(zip(f_vi, f_zh), desc="   Processing", total=max_samples):
                if valid_count >= max_samples:
                    break

                raw_count += 1

                vi = clean_text(vi_line)
                zh = clean_text(zh_line)

                # Filter cho subtitle (thường ngắn hơn)
                if is_valid_pair(zh, vi, min_len=2, max_len=200, min_ratio=0.1, max_ratio=10):
                    all_pairs['zh_vi'].append({
                        'zh': zh,
                        'vi': vi,
                        'source': 'opensubtitles'
                    })
                    valid_count += 1

        stats['raw']['zh_vi'] = raw_count
        stats['filtered']['zh_vi'] = valid_count

        print(f"   ✅ Loaded {valid_count:,} valid pairs (from {raw_count:,} raw)")
    else:
        print(f"   ❌ Files not found in {OPENSUB_DIR}")

except Exception as e:
    print(f"   ❌ Error: {e}")

# Thử MultiCCAligned nếu OpenSubtitles không đủ
if len(all_pairs['zh_vi']) < 100000:
    print("\n   Trying MultiCCAligned as backup...")

    MULTICC_URL = "https://object.pouta.csc.fi/OPUS-MultiCCAligned/v1.1/moses/vi-zh_CN.txt.zip"

    try:
        subprocess.run(
            ["wget", "-q", "-nc", MULTICC_URL, "-O", "/content/multicc.zip"],
            check=True
        )

        multicc_dir = "/content/multicc_vi_zh"
        os.makedirs(multicc_dir, exist_ok=True)
        subprocess.run(
            ["unzip", "-q", "-o", "/content/multicc.zip", "-d", multicc_dir],
            check=True
        )

        vi_files = glob.glob(f'{multicc_dir}/*.vi')
        zh_files = glob.glob(f'{multicc_dir}/*.zh*')

        if vi_files and zh_files:
            count_before = len(all_pairs['zh_vi'])

            with open(vi_files[0], 'r', encoding='utf-8', errors='ignore') as f_vi, \
                 open(zh_files[0], 'r', encoding='utf-8', errors='ignore') as f_zh:

                for vi_line, zh_line in tqdm(zip(f_vi, f_zh), desc="   MultiCC"):
                    vi = clean_text(vi_line)
                    zh = clean_text(zh_line)

                    if is_valid_pair(zh, vi, min_len=2, max_len=200):
                        all_pairs['zh_vi'].append({
                            'zh': zh,
                            'vi': vi,
                            'source': 'multiccaligned'
                        })

            added = len(all_pairs['zh_vi']) - count_before
            print(f"   ✅ Added {added:,} pairs from MultiCCAligned")

    except Exception as e:
        print(f"   ⚠️ MultiCCAligned error: {e}")

# ============================================================
# 4. DEDUPLICATION
# ============================================================
print("\n" + "=" * 70)
print("4️⃣ DEDUPLICATION")
print("=" * 70)

for pair_type in all_pairs:
    before = len(all_pairs[pair_type])

    if pair_type == 'en_vi':
        all_pairs[pair_type] = deduplicate_pairs(all_pairs[pair_type], ('en', 'vi'))
    elif pair_type == 'en_zh':
        all_pairs[pair_type] = deduplicate_pairs(all_pairs[pair_type], ('en', 'zh'))
    elif pair_type == 'zh_vi':
        all_pairs[pair_type] = deduplicate_pairs(all_pairs[pair_type], ('zh', 'vi'))

    after = len(all_pairs[pair_type])
    removed = before - after
    stats['deduplicated'][pair_type] = after

    print(f"   {pair_type}: {before:,} → {after:,} (removed {removed:,} duplicates)")


# ============================================================
# 5. SUMMARY & STATISTICS
# ============================================================
print("\n" + "=" * 70)
print("5️⃣ DATA SUMMARY")
print("=" * 70)

total = sum(len(v) for v in all_pairs.values())

print(f"""
📊 TRANSLATION PAIRS (after cleaning & dedup):
   EN-VI: {len(all_pairs['en_vi']):,} pairs
   EN-ZH: {len(all_pairs['en_zh']):,} pairs
   ZH-VI: {len(all_pairs['zh_vi']):,} pairs
   ─────────────────────────
   TOTAL: {total:,} pairs
""")

# Thống kê theo source
print("📊 By source:")
for pair_type, pairs in all_pairs.items():
    if not pairs:
        continue
    print(f"\n   {pair_type.upper()}:")
    source_counts = defaultdict(int)
    for p in pairs:
        source_counts[p.get('source', 'unknown')] += 1
    for src, count in sorted(source_counts.items(), key=lambda x: -x[1]):
        print(f"      {src}: {count:,}")

# ============================================================
# 6. CREATE TRAIN/VAL/TEST SPLITS (Stratified)
# ============================================================
print("\n" + "=" * 70)
print("6️⃣ CREATING STRATIFIED SPLITS")
print("=" * 70)

def create_stratified_splits(pairs: List[Dict],
                             train_ratio: float = 0.90,
                             val_ratio: float = 0.05) -> Dict[str, List[Dict]]:
    """
    Tạo splits với stratification theo source
    Đảm bảo val/test không overlap với train
    """
    # Group by source
    by_source = defaultdict(list)
    for p in pairs:
        by_source[p.get('source', 'unknown')].append(p)

    train, val, test = [], [], []

    for source, source_pairs in by_source.items():
        random.shuffle(source_pairs)
        n = len(source_pairs)

        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))

        train.extend(source_pairs[:train_end])
        val.extend(source_pairs[train_end:val_end])
        test.extend(source_pairs[val_end:])

    # Shuffle each split
    random.shuffle(train)
    random.shuffle(val)
    random.shuffle(test)

    return {'train': train, 'val': val, 'test': test}

# Tạo unified format
print("   Creating unified format...")

all_data = []

# EN-VI
for p in all_pairs['en_vi']:
    all_data.append({
        'en': p['en'],
        'vi': p['vi'],
        'zh': '',
        'pair_type': 'en-vi',
        'source': p.get('source', '')
    })

# EN-ZH
for p in all_pairs['en_zh']:
    all_data.append({
        'en': p['en'],
        'vi': '',
        'zh': p['zh'],
        'pair_type': 'en-zh',
        'source': p.get('source', '')
    })

# ZH-VI
for p in all_pairs['zh_vi']:
    all_data.append({
        'en': '',
        'vi': p['vi'],
        'zh': p['zh'],
        'pair_type': 'zh-vi',
        'source': p.get('source', '')
    })

print(f"   Total unified samples: {len(all_data):,}")

# Stratified split
splits = create_stratified_splits(all_data)

print(f"""
   Splits:
      Train: {len(splits['train']):,} ({len(splits['train'])/len(all_data)*100:.1f}%)
      Val:   {len(splits['val']):,} ({len(splits['val'])/len(all_data)*100:.1f}%)
      Test:  {len(splits['test']):,} ({len(splits['test'])/len(all_data)*100:.1f}%)
""")

# Verify no overlap
train_hashes = set(get_hash(json.dumps(d, sort_keys=True)) for d in splits['train'])
val_hashes = set(get_hash(json.dumps(d, sort_keys=True)) for d in splits['val'])
test_hashes = set(get_hash(json.dumps(d, sort_keys=True)) for d in splits['test'])

train_val_overlap = len(train_hashes & val_hashes)
train_test_overlap = len(train_hashes & test_hashes)
val_test_overlap = len(val_hashes & test_hashes)

print(f"   Overlap check:")
print(f"      Train-Val: {train_val_overlap}")
print(f"      Train-Test: {train_test_overlap}")
print(f"      Val-Test: {val_test_overlap}")

if train_val_overlap + train_test_overlap + val_test_overlap == 0:
    print("   ✅ No overlap between splits!")
else:
    print("   ⚠️ Some overlap detected (may be due to duplicate sources)")

# ============================================================
# 7. SAVE DATA
# ============================================================
print("\n" + "=" * 70)
print("7️⃣ SAVING DATA")
print("=" * 70)

# Save as JSONL
for name, data in splits.items():
    path = f'{DATA_DIR}/{name}.jsonl'
    with open(path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print(f"   ✅ {name}: {len(data):,} samples → {path}")

# Save metadata
metadata = {
    'total': len(all_data),
    'en_vi': len(all_pairs['en_vi']),
    'en_zh': len(all_pairs['en_zh']),
    'zh_vi': len(all_pairs['zh_vi']),
    'train': len(splits['train']),
    'val': len(splits['val']),
    'test': len(splits['test']),
    'stats': {
        'raw': dict(stats['raw']),
        'filtered': dict(stats['filtered']),
        'deduplicated': dict(stats['deduplicated'])
    }
}

with open(f'{DATA_DIR}/metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"   ✅ Metadata saved")

# Save individual pair files (for reference)
for pair_type, pairs in all_pairs.items():
    if pairs:
        path = f'{DATA_DIR}/{pair_type}_pairs.json'
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(pairs[:10000], f, ensure_ascii=False, indent=2)  # Save sample
        print(f"   ✅ {pair_type} sample (10k) → {path}")

# ============================================================
# 8. SAMPLE DATA
# ============================================================
print("\n" + "=" * 70)
print("📋 SAMPLE DATA")
print("=" * 70)

for pair_type, pairs in all_pairs.items():
    if not pairs:
        continue

    print(f"\n{pair_type.upper()} samples:")
    print("-" * 60)
    samples = random.sample(pairs, min(3, len(pairs)))
    for s in samples:
        if pair_type == 'en_vi':
            print(f"   EN: {s['en'][:70]}{'...' if len(s['en']) > 70 else ''}")
            print(f"   VI: {s['vi'][:70]}{'...' if len(s['vi']) > 70 else ''}")
        elif pair_type == 'en_zh':
            print(f"   EN: {s['en'][:70]}{'...' if len(s['en']) > 70 else ''}")
            print(f"   ZH: {s['zh'][:70]}{'...' if len(s['zh']) > 70 else ''}")
        elif pair_type == 'zh_vi':
            print(f"   ZH: {s['zh'][:70]}{'...' if len(s['zh']) > 70 else ''}")
            print(f"   VI: {s['vi'][:70]}{'...' if len(s['vi']) > 70 else ''}")
        print()

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("✅ TRANSLATION DATA DOWNLOAD COMPLETE!")
print("=" * 70)
print(f"""
📊 FINAL SUMMARY:

   EN-VI: {len(all_pairs['en_vi']):,} pairs (IWSLT)
   EN-ZH: {len(all_pairs['en_zh']):,} pairs (OPUS100)
   ZH-VI: {len(all_pairs['zh_vi']):,} pairs (OpenSubtitles)

   TOTAL: {len(all_data):,} samples

   Splits:
      Train: {len(splits['train']):,}
      Val:   {len(splits['val']):,}
      Test:  {len(splits['test']):,}

📁 Files saved to: {DATA_DIR}/
   - train.jsonl
   - val.jsonl
   - test.jsonl
   - metadata.json
   - *_pairs.json (samples)

🔧 Improvements in V3:
   - ✅ Text cleaning & normalization
   - ✅ Quality filtering (length, ratio)
   - ✅ Deduplication
   - ✅ Stratified splits (no overlap)
   - ✅ Error handling for encoding issues

🚀 Next: Run 08_train_mbart_v2.py
""")

📥 DOWNLOADING TRANSLATION DATA (V3 - IMPROVED)

1️⃣ EN-VI: IWSLT (TED Talks)


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'iwslt2017' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'iwslt2017' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   Trying iwslt2017/iwslt2017-en-vi...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

iwslt2017.py: 0.00B [00:00, ?B/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'IWSLT/mt_eng_vietnamese' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'IWSLT/mt_eng_vietnamese' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ⚠️ iwslt2017: Dataset scripts are no longer supported, but found iwslt2017
   Trying IWSLT/mt_eng_vietnamese/iwslt2015-vi-en...


README.md: 0.00B [00:00, ?B/s]

mt_eng_vietnamese.py: 0.00B [00:00, ?B/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'opus100' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'opus100' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ⚠️ IWSLT/mt_eng_vietnamese: Dataset scripts are no longer supported, but found mt_eng_vi

2️⃣ EN-ZH: OPUS100


README.md: 0.00B [00:00, ?B/s]

en-zh/test-00000-of-00001.parquet:   0%|          | 0.00/355k [00:00<?, ?B/s]

en-zh/train-00000-of-00001.parquet:   0%|          | 0.00/143M [00:00<?, ?B/s]

en-zh/validation-00000-of-00001.parquet:   0%|          | 0.00/359k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

   test:   0%|          | 0/2000 [00:00<?, ?it/s]

   train:   0%|          | 0/1000000 [00:00<?, ?it/s]

   ✅ Loaded 500,000 valid pairs (from 584,527 raw)

3️⃣ ZH-VI: OpenSubtitles
   Downloading OpenSubtitles ZH-VI...
   Extracting...
   Found: OpenSubtitles.vi-zh_CN.vi, OpenSubtitles.vi-zh_CN.zh_CN


   Processing:   0%|          | 0/2000000 [00:00<?, ?it/s]

   ✅ Loaded 2,000,000 valid pairs (from 2,038,262 raw)

4️⃣ DEDUPLICATION
   en_vi: 0 → 0 (removed 0 duplicates)
   en_zh: 500,000 → 483,886 (removed 16,114 duplicates)
   zh_vi: 2,000,000 → 1,911,660 (removed 88,340 duplicates)

5️⃣ DATA SUMMARY

📊 TRANSLATION PAIRS (after cleaning & dedup):
   EN-VI: 0 pairs
   EN-ZH: 483,886 pairs
   ZH-VI: 1,911,660 pairs
   ─────────────────────────
   TOTAL: 2,395,546 pairs

📊 By source:

   EN_ZH:
      opus100: 483,886

   ZH_VI:
      opensubtitles: 1,911,660

6️⃣ CREATING STRATIFIED SPLITS
   Creating unified format...
   Total unified samples: 2,395,546

   Splits:
      Train: 2,155,991 (90.0%)
      Val:   119,777 (5.0%)
      Test:  119,778 (5.0%)

   Overlap check:
      Train-Val: 0
      Train-Test: 0
      Val-Test: 0
   ✅ No overlap between splits!

7️⃣ SAVING DATA
   ✅ train: 2,155,991 samples → /content/tri-lingua/data/translation/train.jsonl
   ✅ val: 119,777 samples → /content/tri-lingua/data/translation/val.jsonl
   ✅ test: 119,

In [None]:
# Ô code 3
# ============================================================
# 🌐 TRI-LINGUA - TRAIN mBART-50 TRANSLATION MODEL (V2 - FIXED)
# ============================================================
# Fine-tune mBART-50 cho dịch tam ngữ EN ↔ VI ↔ ZH
# Optimized cho Google Colab A100
#
# FIXES V2:
# - ✅ Tokenization đúng cách cho mBART (labels với forced_bos)
# - ✅ Labels padding với -100 (ignore trong loss)
# - ✅ Thêm BLEU metric để đánh giá
# - ✅ Data balancing (upsampling cặp ít data)
# - ✅ Gradient checkpointing để tiết kiệm VRAM
# - ✅ Mixed precision training ổn định hơn
# ============================================================

# !pip install -q transformers>=4.36.0 datasets accelerate sentencepiece sacremoses sacrebleu

import os
import json
import torch
import random
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from collections import Counter

from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
from datasets import Dataset
import evaluate  # Thư viện mới cho metrics

# Tắt wandb
os.environ["WANDB_DISABLED"] = "true"

# Set seed cho reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# ==================== PATHS ====================
DATA_DIR = '/content/tri-lingua/data/translation'
MODEL_DIR = '/content/tri-lingua/models/translation'
FINAL_MODEL_DIR = f'{MODEL_DIR}/final'

Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
Path(FINAL_MODEL_DIR).mkdir(parents=True, exist_ok=True)

# ==================== CONFIG ====================
@dataclass
class TrainingConfig:
    base_model: str = 'facebook/mbart-large-50-many-to-many-mmt'
    max_length: int = 128

    # Training - Optimized cho A100 80GB với dataset 6M+
    batch_size: int = 16          # Giảm để ổn định hơn với gradient checkpointing
    gradient_accumulation_steps: int = 4  # Effective batch = 64
    learning_rate: float = 2e-5   # Giảm nhẹ để ổn định
    num_epochs: int = 2           # 2 epochs cho convergence tốt hơn
    warmup_ratio: float = 0.05    # Tăng warmup
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0    # Gradient clipping

    # Checkpoint & Evaluation
    eval_steps: int = 3000        # Eval thường xuyên hơn
    save_steps: int = 3000
    save_total_limit: int = 3
    logging_steps: int = 200

    # Performance
    fp16: bool = True
    gradient_checkpointing: bool = True  # Tiết kiệm VRAM
    dataloader_num_workers: int = 4
    dataloader_pin_memory: bool = True

    # Data balancing
    balance_data: bool = True     # Cân bằng các cặp ngôn ngữ
    max_samples_per_direction: int = 500000  # Giới hạn mỗi direction

    # Sampling
    max_train_samples: int = -1

CONFIG = TrainingConfig()

# Language codes cho mBART-50
LANG_CODES = {
    'en': 'en_XX',
    'vi': 'vi_VN',
    'zh': 'zh_CN'
}

print("=" * 70)
print("🌐 TRAINING mBART-50 TRANSLATION MODEL (V2 - FIXED)")
print("=" * 70)

# ==================== CHECK GPU ====================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   VRAM: {vram_gb:.1f} GB")

    # Auto-adjust batch size based on VRAM
    if vram_gb >= 70:  # A100 80GB
        CONFIG.batch_size = 16
        CONFIG.gradient_checkpointing = True
    elif vram_gb >= 35:  # A100 40GB
        CONFIG.batch_size = 8
        CONFIG.gradient_checkpointing = True
    else:  # T4/V100
        CONFIG.batch_size = 4
        CONFIG.gradient_accumulation_steps = 16
        CONFIG.gradient_checkpointing = True

    print(f"   → batch_size={CONFIG.batch_size}, grad_accum={CONFIG.gradient_accumulation_steps}")

# ==================== LOAD DATA ====================
print("\n📊 Loading data...")

def load_jsonl(path: str) -> List[Dict]:
    """Load JSONL file với error handling"""
    data = []
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            try:
                line = line.strip()
                if line:
                    item = json.loads(line)
                    data.append(item)
            except json.JSONDecodeError as e:
                print(f"   ⚠️ Skip line {i}: {e}")
    return data

train_data = load_jsonl(f'{DATA_DIR}/train.jsonl')
val_data = load_jsonl(f'{DATA_DIR}/val.jsonl')

print(f"   Train: {len(train_data):,} samples")
print(f"   Val: {len(val_data):,} samples")


# ==================== DATA CLEANING ====================
print("\n🧹 Cleaning data...")

def clean_text(text: str) -> str:
    """Clean và normalize text"""
    if not text:
        return ""

    # Remove extra whitespace
    text = ' '.join(text.split())

    # Remove control characters
    text = ''.join(c for c in text if c.isprintable() or c in '\n\t')

    return text.strip()

def is_valid_pair(src: str, tgt: str, min_len: int = 3, max_len: int = 200) -> bool:
    """Kiểm tra cặp dịch có hợp lệ không"""
    if not src or not tgt:
        return False

    # Độ dài hợp lý
    if len(src) < min_len or len(tgt) < min_len:
        return False
    if len(src) > max_len or len(tgt) > max_len:
        return False

    # Tỷ lệ độ dài hợp lý (tránh câu bị cắt/lỗi)
    ratio = len(src) / len(tgt) if len(tgt) > 0 else 0
    if ratio < 0.1 or ratio > 10:
        return False

    return True

# ==================== CREATE TRANSLATION PAIRS ====================
print("\n🔄 Creating translation pairs...")

def create_translation_pairs(data: List[Dict], balance: bool = False) -> List[Dict]:
    """Tạo cặp dịch cho tất cả directions với data balancing"""
    pairs_by_direction = {
        'en_vi': [], 'vi_en': [],
        'en_zh': [], 'zh_en': [],
        'vi_zh': [], 'zh_vi': []
    }

    seen = set()  # Để loại bỏ duplicate

    for item in tqdm(data, desc="   Processing"):
        en = clean_text(item.get('en', ''))
        vi = clean_text(item.get('vi', ''))
        zh = clean_text(item.get('zh', ''))

        # EN ↔ VI
        if en and vi and is_valid_pair(en, vi):
            key = f"en_vi:{en}:{vi}"
            if key not in seen:
                seen.add(key)
                pairs_by_direction['en_vi'].append({
                    'src': en, 'tgt': vi,
                    'src_lang': 'en_XX', 'tgt_lang': 'vi_VN'
                })
                pairs_by_direction['vi_en'].append({
                    'src': vi, 'tgt': en,
                    'src_lang': 'vi_VN', 'tgt_lang': 'en_XX'
                })

        # EN ↔ ZH
        if en and zh and is_valid_pair(en, zh):
            key = f"en_zh:{en}:{zh}"
            if key not in seen:
                seen.add(key)
                pairs_by_direction['en_zh'].append({
                    'src': en, 'tgt': zh,
                    'src_lang': 'en_XX', 'tgt_lang': 'zh_CN'
                })
                pairs_by_direction['zh_en'].append({
                    'src': zh, 'tgt': en,
                    'src_lang': 'zh_CN', 'tgt_lang': 'en_XX'
                })

        # VI ↔ ZH
        if vi and zh and is_valid_pair(vi, zh):
            key = f"vi_zh:{vi}:{zh}"
            if key not in seen:
                seen.add(key)
                pairs_by_direction['vi_zh'].append({
                    'src': vi, 'tgt': zh,
                    'src_lang': 'vi_VN', 'tgt_lang': 'zh_CN'
                })
                pairs_by_direction['zh_vi'].append({
                    'src': zh, 'tgt': vi,
                    'src_lang': 'zh_CN', 'tgt_lang': 'vi_VN'
                })

    # Thống kê
    print("\n   Distribution before balancing:")
    for direction, pairs in pairs_by_direction.items():
        print(f"      {direction}: {len(pairs):,}")

    # Data balancing: upsampling cặp ít, downsampling cặp nhiều
    if balance:
        print("\n   Balancing data...")

        # Tìm số lượng target (median)
        counts = [len(p) for p in pairs_by_direction.values() if len(p) > 0]
        if counts:
            target_count = min(
                int(np.median(counts)),
                CONFIG.max_samples_per_direction
            )

            for direction in pairs_by_direction:
                current = len(pairs_by_direction[direction])
                if current == 0:
                    continue

                if current < target_count:
                    # Upsampling
                    multiplier = target_count // current
                    remainder = target_count % current
                    pairs_by_direction[direction] = (
                        pairs_by_direction[direction] * multiplier +
                        random.sample(pairs_by_direction[direction], remainder)
                    )
                elif current > target_count:
                    # Downsampling
                    pairs_by_direction[direction] = random.sample(
                        pairs_by_direction[direction], target_count
                    )

            print("   Distribution after balancing:")
            for direction, pairs in pairs_by_direction.items():
                print(f"      {direction}: {len(pairs):,}")

    # Gộp tất cả
    all_pairs = []
    for pairs in pairs_by_direction.values():
        all_pairs.extend(pairs)

    return all_pairs

train_pairs = create_translation_pairs(train_data, balance=CONFIG.balance_data)
val_pairs = create_translation_pairs(val_data, balance=False)  # Không balance val

# Shuffle train pairs
random.shuffle(train_pairs)

# Giới hạn nếu cần
if CONFIG.max_train_samples > 0 and len(train_pairs) > CONFIG.max_train_samples:
    print(f"\n   ⚠️ Limiting train: {len(train_pairs):,} → {CONFIG.max_train_samples:,}")
    train_pairs = train_pairs[:CONFIG.max_train_samples]

print(f"\n   Final train pairs: {len(train_pairs):,}")
print(f"   Final val pairs: {len(val_pairs):,}")

# ==================== LOAD MODEL & TOKENIZER ====================
print(f"\n📦 Loading {CONFIG.base_model}...")

tokenizer = MBart50TokenizerFast.from_pretrained(CONFIG.base_model)
model = MBartForConditionalGeneration.from_pretrained(
    CONFIG.base_model,
    torch_dtype=torch.float16 if CONFIG.fp16 and torch.cuda.is_available() else torch.float32
)

# Enable gradient checkpointing để tiết kiệm VRAM
if CONFIG.gradient_checkpointing:
    model.gradient_checkpointing_enable()
    print("   ✅ Gradient checkpointing enabled")

print(f"✅ Model loaded!")
print(f"   Parameters: {model.num_parameters():,}")


# ==================== TOKENIZE DATA (FIXED) ====================
print("\n🔤 Tokenizing data (FIXED method)...")

def tokenize_pairs_fixed(pairs: List[Dict], tokenizer, max_length: int) -> Dataset:
    """
    Tokenize đúng cách cho mBART:
    - Source: tokenize với src_lang
    - Target: tokenize với tgt_lang, thêm forced_bos_token
    - Labels: padding với -100 để ignore trong loss
    """
    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    pad_token_id = tokenizer.pad_token_id

    for pair in tqdm(pairs, desc="   Tokenizing"):
        # === SOURCE ===
        tokenizer.src_lang = pair['src_lang']
        source = tokenizer(
            pair['src'],
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # === TARGET ===
        # Quan trọng: Set tgt_lang để tokenizer thêm đúng language token
        tokenizer.src_lang = pair['tgt_lang']

        # Tokenize target
        target = tokenizer(
            pair['tgt'],
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # Lấy labels và thay padding bằng -100
        labels = target['input_ids'].squeeze().tolist()
        labels = [
            label if label != pad_token_id else -100
            for label in labels
        ]

        input_ids_list.append(source['input_ids'].squeeze().tolist())
        attention_mask_list.append(source['attention_mask'].squeeze().tolist())
        labels_list.append(labels)

    # Tạo HuggingFace Dataset
    dataset = Dataset.from_dict({
        'input_ids': input_ids_list,
        'attention_mask': attention_mask_list,
        'labels': labels_list
    })

    return dataset

train_dataset = tokenize_pairs_fixed(train_pairs, tokenizer, CONFIG.max_length)
val_dataset = tokenize_pairs_fixed(val_pairs, tokenizer, CONFIG.max_length)

print(f"   ✅ Train dataset: {len(train_dataset):,} samples")
print(f"   ✅ Val dataset: {len(val_dataset):,} samples")

# ==================== METRICS (BLEU) ====================
print("\n📊 Setting up metrics...")

# Load BLEU metric
try:
    bleu_metric = evaluate.load("sacrebleu")
    print("   ✅ BLEU metric loaded")
except Exception as e:
    print(f"   ⚠️ Could not load BLEU: {e}")
    bleu_metric = None

def compute_metrics(eval_preds):
    """Compute BLEU score cho evaluation"""
    if bleu_metric is None:
        return {}

    preds, labels = eval_preds

    # Decode predictions
    if isinstance(preds, tuple):
        preds = preds[0]

    # Replace -100 với pad_token_id
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # BLEU expects list of references for each prediction
    decoded_labels = [[label] for label in decoded_labels]

    # Compute BLEU
    result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)

    return {"bleu": result["score"]}

# ==================== TRAINING ARGUMENTS ====================
print("\n🚀 Setting up training...")

# Tính số steps
total_train_samples = len(train_dataset)
steps_per_epoch = total_train_samples // (CONFIG.batch_size * CONFIG.gradient_accumulation_steps)
total_steps = steps_per_epoch * CONFIG.num_epochs

print(f"""
   Batch size: {CONFIG.batch_size}
   Gradient accumulation: {CONFIG.gradient_accumulation_steps}
   Effective batch size: {CONFIG.batch_size * CONFIG.gradient_accumulation_steps}
   Learning rate: {CONFIG.learning_rate}
   Epochs: {CONFIG.num_epochs}
   Steps per epoch: {steps_per_epoch:,}
   Total steps: {total_steps:,}
   Eval/Save every: {CONFIG.eval_steps} steps
   Gradient checkpointing: {CONFIG.gradient_checkpointing}
""")

training_args = Seq2SeqTrainingArguments(
    output_dir=MODEL_DIR,

    # Training
    num_train_epochs=CONFIG.num_epochs,
    per_device_train_batch_size=CONFIG.batch_size,
    per_device_eval_batch_size=CONFIG.batch_size,
    gradient_accumulation_steps=CONFIG.gradient_accumulation_steps,

    # Optimizer
    learning_rate=CONFIG.learning_rate,
    weight_decay=CONFIG.weight_decay,
    warmup_ratio=CONFIG.warmup_ratio,
    lr_scheduler_type='cosine',  # Cosine schedule thường tốt hơn linear
    max_grad_norm=CONFIG.max_grad_norm,

    # Evaluation & Checkpointing
    eval_strategy='steps',
    eval_steps=CONFIG.eval_steps,
    save_strategy='steps',
    save_steps=CONFIG.save_steps,
    save_total_limit=CONFIG.save_total_limit,

    # Best model
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,

    # Generation (cho BLEU)
    predict_with_generate=True if bleu_metric else False,
    generation_max_length=CONFIG.max_length,
    generation_num_beams=4,

    # Logging
    logging_dir=f'{MODEL_DIR}/logs',
    logging_steps=CONFIG.logging_steps,
    report_to='none',

    # Performance
    fp16=CONFIG.fp16 and torch.cuda.is_available(),
    dataloader_num_workers=CONFIG.dataloader_num_workers,
    dataloader_pin_memory=CONFIG.dataloader_pin_memory,

    # Gradient checkpointing
    gradient_checkpointing=CONFIG.gradient_checkpointing,

    # Resume
    ignore_data_skip=False,
)

# ==================== DATA COLLATOR ====================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
    label_pad_token_id=-100
)

# ==================== TRAINER ====================
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if bleu_metric else None,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

print("✅ Trainer initialized")

# ==================== FIND CHECKPOINT TO RESUME ====================
def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Tìm checkpoint mới nhất để resume"""
    output_path = Path(output_dir)
    checkpoints = list(output_path.glob("checkpoint-*"))

    if not checkpoints:
        return None

    # Sort by step number
    checkpoints = sorted(checkpoints, key=lambda x: int(x.name.split("-")[1]))
    latest = checkpoints[-1]

    return str(latest)

latest_checkpoint = find_latest_checkpoint(MODEL_DIR)

# ==================== TRAIN ====================
print("\n" + "=" * 70)
print("🚀 STARTING TRAINING")
print("=" * 70)

if latest_checkpoint:
    print(f"✨ Resuming from checkpoint: {latest_checkpoint}")
    trainer.train(resume_from_checkpoint=latest_checkpoint)
else:
    print("✨ Starting fresh training...")
    trainer.train()


# ==================== SAVE FINAL MODEL ====================
print("\n💾 Saving final model...")

trainer.save_model(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

# Save training config
config_dict = {
    'base_model': CONFIG.base_model,
    'max_length': CONFIG.max_length,
    'batch_size': CONFIG.batch_size,
    'gradient_accumulation_steps': CONFIG.gradient_accumulation_steps,
    'learning_rate': CONFIG.learning_rate,
    'num_epochs': CONFIG.num_epochs,
    'warmup_ratio': CONFIG.warmup_ratio,
    'gradient_checkpointing': CONFIG.gradient_checkpointing,
    'balance_data': CONFIG.balance_data,
    'languages': ['en', 'vi', 'zh'],
    'directions': ['en→vi', 'vi→en', 'en→zh', 'zh→en', 'vi→zh', 'zh→vi'],
    'train_samples': len(train_dataset),
    'val_samples': len(val_dataset),
}

with open(f'{FINAL_MODEL_DIR}/training_config.json', 'w') as f:
    json.dump(config_dict, f, indent=2, ensure_ascii=False)

print(f"✅ Model saved to: {FINAL_MODEL_DIR}")

# ==================== COMPREHENSIVE TEST ====================
print("\n" + "=" * 70)
print("🧪 COMPREHENSIVE TEST")
print("=" * 70)

model.eval()
model.to(device)

def translate(text: str, src_lang: str, tgt_lang: str, num_beams: int = 5) -> str:
    """Translation với proper language handling"""
    tokenizer.src_lang = LANG_CODES[src_lang]

    inputs = tokenizer(text, return_tensors='pt', max_length=128, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[LANG_CODES[tgt_lang]],
            max_length=128,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Test cases cho tất cả directions
test_cases = [
    # EN → VI
    ("Hello, how are you?", "en", "vi", "Xin chào, bạn khỏe không?"),
    ("The weather is beautiful today.", "en", "vi", "Thời tiết hôm nay đẹp."),
    ("I love learning languages.", "en", "vi", "Tôi thích học ngôn ngữ."),

    # VI → EN
    ("Xin chào, bạn khỏe không?", "vi", "en", "Hello, how are you?"),
    ("Việt Nam là một đất nước xinh đẹp.", "vi", "en", "Vietnam is a beautiful country."),

    # EN → ZH
    ("Hello world", "en", "zh", "你好世界"),
    ("Thank you very much.", "en", "zh", "非常感谢。"),

    # ZH → EN
    ("你好世界", "zh", "en", "Hello world"),
    ("谢谢你的帮助。", "zh", "en", "Thank you for your help."),

    # VI → ZH
    ("Xin chào", "vi", "zh", "你好"),
    ("Cảm ơn bạn.", "vi", "zh", "谢谢你。"),

    # ZH → VI
    ("你好", "zh", "vi", "Xin chào"),
    ("谢谢", "zh", "vi", "Cảm ơn"),
]

print("\nTest translations:")
print("-" * 70)

results = []
for text, src, tgt, expected in test_cases:
    try:
        result = translate(text, src, tgt)
        status = "✅" if result.strip() else "⚠️"
        print(f"[{src}→{tgt}] {text}")
        print(f"   Expected: {expected}")
        print(f"   Got:      {result} {status}")
        print()
        results.append({
            'src': text,
            'tgt': result,
            'expected': expected,
            'direction': f"{src}→{tgt}"
        })
    except Exception as e:
        print(f"[{src}→{tgt}] Error: {e}\n")

# Save test results
with open(f'{FINAL_MODEL_DIR}/test_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

# ==================== SUMMARY ====================
print("\n" + "=" * 70)
print("✅ TRAINING COMPLETE!")
print("=" * 70)
print(f"""
📊 Summary:
   Train samples: {len(train_dataset):,}
   Val samples: {len(val_dataset):,}
   Data balanced: {CONFIG.balance_data}
   Gradient checkpointing: {CONFIG.gradient_checkpointing}

📁 Files saved to: {FINAL_MODEL_DIR}/
   - config.json
   - model.safetensors (or pytorch_model.bin)
   - tokenizer files
   - training_config.json
   - test_results.json

📥 To download:
   Run cell 10_download_models.py

   Or manually:
   !zip -r translation_model.zip {FINAL_MODEL_DIR}
   from google.colab import files
   files.download('translation_model.zip')

🔧 Key improvements in V2:
   - ✅ Proper tokenization for mBART
   - ✅ Labels padding with -100
   - ✅ Data balancing across language pairs
   - ✅ Gradient checkpointing for memory efficiency
   - ✅ Cosine learning rate schedule
   - ✅ Comprehensive test suite
""")

🌐 TRAINING mBART-50 TRANSLATION MODEL (V2 - FIXED)
🖥️ Device: cuda
   GPU: NVIDIA A100-SXM4-40GB
   VRAM: 42.5 GB
   → batch_size=8, grad_accum=4

📊 Loading data...
   Train: 2,155,991 samples
   Val: 119,777 samples

🧹 Cleaning data...

🔄 Creating translation pairs...


   Processing:   0%|          | 0/2155991 [00:00<?, ?it/s]


   Distribution before balancing:
      en_vi: 0
      vi_en: 0
      en_zh: 383,870
      zh_en: 383,870
      vi_zh: 1,686,214
      zh_vi: 1,686,214

   Balancing data...
   Distribution after balancing:
      en_vi: 0
      vi_en: 0
      en_zh: 500,000
      zh_en: 500,000
      vi_zh: 500,000
      zh_vi: 500,000


   Processing:   0%|          | 0/119777 [00:00<?, ?it/s]


   Distribution before balancing:
      en_vi: 0
      vi_en: 0
      en_zh: 21,350
      zh_en: 21,350
      vi_zh: 93,739
      zh_vi: 93,739

   Final train pairs: 2,000,000
   Final val pairs: 230,178

📦 Loading facebook/mbart-large-50-many-to-many-mmt...


tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

   ✅ Gradient checkpointing enabled
✅ Model loaded!
   Parameters: 610,879,488

🔤 Tokenizing data (FIXED method)...


   Tokenizing:   0%|          | 0/2000000 [00:00<?, ?it/s]

In [4]:
# Ô code 4 download
# ============================================================
# CELL 9: TEST mBART TRANSLATION MODEL
# ============================================================

"""
CELL 9: Test và đánh giá mBART Translation Model
"""

import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import json

MODEL_DIR = '/content/tri-lingua/models/translation/final'

# Language codes
LANG_CODES = {
    'en': 'en_XX',
    'vi': 'vi_VN',
    'zh': 'zh_CN'
}

print("=" * 60)
print("🧪 TESTING mBART TRANSLATION MODEL")
print("=" * 60)

# ==================== LOAD MODEL ====================
print("\n📦 Loading model...")

tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_DIR)
model = MBartForConditionalGeneration.from_pretrained(MODEL_DIR)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print(f"✅ Model loaded on {device}")

# ==================== TRANSLATION FUNCTION ====================
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """
    Dịch text từ src_lang sang tgt_lang

    Args:
        text: Text cần dịch
        src_lang: 'en', 'vi', 'zh'
        tgt_lang: 'en', 'vi', 'zh'

    Returns:
        Translated text
    """
    # Set source language
    tokenizer.src_lang = LANG_CODES[src_lang]

    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', max_length=128, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate with forced_bos_token_id for target language
    forced_bos_token_id = tokenizer.lang_code_to_id[LANG_CODES[tgt_lang]]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=128,
            num_beams=5,
            early_stopping=True
        )

    # Decode
    translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return translated

# ==================== TEST CASES ====================
print("\n📊 Test Results:")
print("-" * 70)

test_cases = [
    # EN → VI
    ("Hello, how are you?", "en", "vi"),
    ("The economy is developing rapidly.", "en", "vi"),
    ("I am studying at university.", "en", "vi"),
    ("Education is very important.", "en", "vi"),

    # VI → EN
    ("Xin chào, bạn khỏe không?", "vi", "en"),
    ("Kinh tế đang phát triển nhanh.", "vi", "en"),
    ("Tôi đang học đại học.", "vi", "en"),
    ("Giáo dục rất quan trọng.", "vi", "en"),

    # EN → ZH
    ("Hello, how are you?", "en", "zh"),
    ("The economy is important.", "en", "zh"),
    ("I study science.", "en", "zh"),

    # ZH → EN
    ("你好，你好吗？", "zh", "en"),
    ("经济很重要。", "zh", "en"),
    ("我学习科学。", "zh", "en"),

    # VI → ZH
    ("Xin chào", "vi", "zh"),
    ("Kinh tế rất quan trọng.", "vi", "zh"),
    ("Tôi học khoa học.", "vi", "zh"),

    # ZH → VI
    ("你好", "zh", "vi"),
    ("经济很重要。", "zh", "vi"),
    ("我学习科学。", "zh", "vi"),
]

for text, src, tgt in test_cases:
    result = translate(text, src, tgt)
    print(f"[{src}→{tgt}] {text}")
    print(f"         → {result}")
    print()

# ==================== BATCH TRANSLATION ====================
def batch_translate(texts: list, src_lang: str, tgt_lang: str) -> list:
    """Batch translation"""
    tokenizer.src_lang = LANG_CODES[src_lang]

    inputs = tokenizer(texts, return_tensors='pt', padding=True,
                       max_length=128, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    forced_bos_token_id = tokenizer.lang_code_to_id[LANG_CODES[tgt_lang]]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=128,
            num_beams=5
        )

    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

print("\n📊 Batch Translation Test:")
print("-" * 70)

batch_texts = [
    "Hello world",
    "How are you?",
    "Nice to meet you"
]

results = batch_translate(batch_texts, 'en', 'vi')
for text, result in zip(batch_texts, results):
    print(f"  {text} → {result}")

# ==================== INTERACTIVE TEST ====================
print("\n" + "=" * 60)
print("🎮 Interactive Translation")
print("   Format: <src> <tgt> <text>")
print("   Example: en vi Hello world")
print("   Languages: en, vi, zh")
print("   Type 'quit' to exit")
print("=" * 60)

while True:
    try:
        user_input = input("\n> ").strip()

        if user_input.lower() == 'quit':
            break

        parts = user_input.split(' ', 2)
        if len(parts) < 3:
            print("  ❌ Format: <src> <tgt> <text>")
            continue

        src, tgt, text = parts

        if src not in LANG_CODES or tgt not in LANG_CODES:
            print("  ❌ Languages: en, vi, zh")
            continue

        result = translate(text, src, tgt)
        print(f"  → {result}")

    except KeyboardInterrupt:
        break
    except Exception as e:
        print(f"  ❌ Error: {e}")

print("\n👋 Done!")

In [5]:
# Ô code 5
# ============================================================
# CELL 10: DOWNLOAD TRAINED MODELS
# ============================================================

"""
CELL 10: Đóng gói và download models về máy local
"""

import shutil
import os
from pathlib import Path
from google.colab import files

MODEL_DIR = '/content/tri-lingua/models'

print("=" * 60)
print("📥 PACKAGING & DOWNLOADING MODELS")
print("=" * 60)

# ==================== 1. PHONETIC MODEL ====================
print("\n1️⃣ Packaging Phonetic Model (BiLSTM)...")

phonetic_dir = f'{MODEL_DIR}/phonetic'
if os.path.exists(phonetic_dir):
    # List files
    print("   Files:")
    for f in os.listdir(phonetic_dir):
        size = os.path.getsize(f'{phonetic_dir}/{f}') / 1024
        print(f"      {f}: {size:.1f} KB")

    # Create zip
    shutil.make_archive('/content/phonetic_model_bilstm', 'zip', phonetic_dir)
    print("   ✅ Created phonetic_model_bilstm.zip")
else:
    print("   ⚠️ Phonetic model not found")

# ==================== 2. TRANSLATION MODEL ====================
print("\n2️⃣ Packaging Translation Model (mBART)...")

translation_dir = f'{MODEL_DIR}/translation/final'
if os.path.exists(translation_dir):
    # List files
    print("   Files:")
    total_size = 0
    for f in os.listdir(translation_dir):
        size = os.path.getsize(f'{translation_dir}/{f}') / (1024 * 1024)
        total_size += size
        print(f"      {f}: {size:.1f} MB")
    print(f"   Total: {total_size:.1f} MB")

    # Create zip
    print("   Creating zip (this may take a while for large models)...")
    shutil.make_archive('/content/translation_model_mbart', 'zip', translation_dir)
    print("   ✅ Created translation_model_mbart.zip")
else:
    print("   ⚠️ Translation model not found")

# ==================== 3. DOWNLOAD ====================
print("\n" + "=" * 60)
print("📥 DOWNLOADING...")
print("=" * 60)

print("\n1️⃣ Downloading Phonetic Model...")
if os.path.exists('/content/phonetic_model_bilstm.zip'):
    files.download('/content/phonetic_model_bilstm.zip')
    print("   ✅ Downloaded!")

print("\n2️⃣ Downloading Translation Model...")
print("   ⚠️ mBART model is large (~2.4GB). Download may take a while.")
if os.path.exists('/content/translation_model_mbart.zip'):
    files.download('/content/translation_model_mbart.zip')
    print("   ✅ Downloaded!")

# ==================== 4. USAGE INSTRUCTIONS ====================
print("\n" + "=" * 60)
print("📋 HƯỚNG DẪN SỬ DỤNG")
print("=" * 60)

print("""
1️⃣ Giải nén models:

   cd tri-lingua-bridge/ml/models
   unzip phonetic_model_bilstm.zip -d phonetic/
   unzip translation_model_mbart.zip -d translation/

2️⃣ Cấu trúc thư mục sau khi giải nén:

   ml/models/
   ├── phonetic/
   │   ├── encoder.weights.h5
   │   ├── decoder.weights.h5
   │   ├── input_vocab.json
   │   ├── output_vocab.json
   │   └── model_config.json
   └── translation/
       ├── config.json
       ├── pytorch_model.bin (hoặc model.safetensors)
       ├── tokenizer_config.json
       ├── sentencepiece.bpe.model
       └── special_tokens_map.json

3️⃣ Test Phonetic Model (TensorFlow):

   python ml/inference_phonetic.py

4️⃣ Test Translation Model (PyTorch):

   python ml/inference_translation.py

5️⃣ Tích hợp vào app:

   from ml.inference import TriLinguaML

   ml = TriLinguaML()

   # Phonetic
   pinyin = ml.hanviet_to_pinyin("KINH TẾ")

   # Translation
   result = ml.translate("Hello", "en", "vi")
""")

print("=" * 60)
print("✅ All done!")
print("=" * 60)
