# Set up

In [None]:
# Mount Google Drive at /content/drive. DRIVE_BASE_PATH in the configuration
# cell points inside this mount, so this cell has to run before it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install necessary libraries
!pip install -q datasets evaluate underthesea jiwer
!pip install transformers==4.48.1

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.9/20.9 MB[0m [31m76.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m657.8/657.8 kB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m61.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m82.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers==4.48.1
  Downloading transformers-4.48.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.48.1-py3-none-any.whl (9.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m93.1 MB/s[0m eta [36m0:00:00[0m
[?25hI

# Import Libraries

In [None]:
import os
import sys
import random
import re
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from typing import List, Dict, Callable, Tuple, Optional
from dataclasses import dataclass, field
from tqdm.notebook import tqdm
import difflib
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import re
from jiwer import wer, cer
from underthesea import word_tokenize as underthesea_tokenize, pos_tag as underthesea_pos_tag
from difflib import SequenceMatcher

# Configuration

In [None]:
# CONFIGURATION
DRIVE_BASE_PATH = "/content/drive/MyDrive/DACNTT"

# Create the project root and its working sub-folders in one pass;
# the empty name stands for the root folder itself.
for _subdir in ("", "Dataset", "data_processed", "models"):
    Path(os.path.join(DRIVE_BASE_PATH, _subdir)).mkdir(parents=True, exist_ok=True)

In [None]:
# --- Data Configuration ---
@dataclass
class DataConfig:
    """Locations of the raw corpus and of the processed train/test CSV files.

    NOTE(review): none of the attributes below carries a type annotation, so
    ``@dataclass`` does not turn them into fields. They are plain class
    attributes shared by every instance, and the generated ``__init__``
    accepts no arguments (``DataConfig(test_size=0.2)`` raises ``TypeError``).
    """
    # Raw .txt corpus files, resolved under the Drive project folder.
    raw_txt_files= [
        os.path.join(DRIVE_BASE_PATH, "Dataset/data.txt")
    ]
    # Folder holding the processed splits; the two CSV paths below live inside it.
    processed_base_path = os.path.join(DRIVE_BASE_PATH, "data_processed")
    processed_train_path = os.path.join(processed_base_path, "train.csv")
    processed_test_path = os.path.join(processed_base_path, "test.csv")
    # Presumably the fraction of samples held out for the test split --
    # the splitting code is not visible here, TODO confirm.
    test_size = 0.10

# --- Model Configuration ---
@dataclass
class ModelConfig:
    """Checkpoint and sequence-length settings for the seq2seq model.

    Both attributes are annotated so that ``@dataclass`` registers them as
    real fields: they show up in ``repr`` and can be overridden per instance,
    e.g. ``ModelConfig(max_length=128)``. Because the defaults are immutable,
    they also remain readable on the class itself (``ModelConfig.max_length``).

    Attributes:
        model_name_or_path: Model identifier or local path of the checkpoint.
        max_length: Maximum sequence length; ``TrainingConfig`` derives its
            ``generation_max_length`` from the class-level value.
    """
    model_name_or_path: str = "VietAI/vit5-base"
    max_length: int = 256

# --- Augmentation Configuration ---
@dataclass
class AugmentationConfig:
    """Probabilities and lookup tables that drive the synthetic-noise generators.

    Scalar fields are probabilities / counts. The container fields are lookup
    tables: any table left empty is filled with the built-in Vietnamese
    default in ``__post_init__``, while a non-empty table passed to the
    constructor is kept as supplied.

    Attributes:
        prob_apply_noise: Chance that ``create_noisy_text`` corrupts a sentence
            at all; otherwise the cleaned text is returned unchanged.
        max_augmentations_per_sample, num_augmented_samples: Augmentation
            counts (consumers are outside this part of the notebook).
        prob_typo, prob_region_tone, prob_spelling_confusion,
        prob_semantic_confusion, prob_abbreviation: Per-noise-type
            probabilities (consumers are outside this part of the notebook).
        prob_delete_word, prob_duplicate_word, prob_swap_adjacent: Relative
            weights ``corrupt_structure`` uses to choose between dropping,
            duplicating and swapping tokens.
        prob_remove_tone_specific, prob_change_tone_specific,
        prob_add_tone_specific: Relative weights ``corrupt_tone`` uses to
            choose what to do with a selected vowel.
        vietnamese_tones: base vowel -> [base, then its five toned forms].
        base_vowel_map: every form listed in ``vietnamese_tones`` -> its base.
        all_vowels_with_tone: flat list of the toned (non-base) forms.
        telex_typing_map: TELEX keystrokes -> resulting character.
        reverse_telex_map: character -> TELEX keystrokes.
        initial_consonant_confusion_map, final_consonant_map,
        vowel_confusion_map: spelling-confusion tables used by
            ``corrupt_spelling``.
        semantic_confusion_pairs: commonly confused word pairs.
        semantic_map: bidirectional lookup built from the pairs.
        abbreviation_map: full form -> list of chat-style abbreviations.
    """
    prob_apply_noise: float = 0.85
    max_augmentations_per_sample: int = 2
    num_augmented_samples: int = 3

    # Word/Token level noise probabilities
    prob_typo: float = 0.50
    prob_region_tone: float = 0.35
    prob_spelling_confusion: float = 0.10
    prob_semantic_confusion: float = 0.15

    # Sentence/Grammar level noise probabilities
    prob_delete_word: float = 0.05
    prob_duplicate_word: float = 0.05
    prob_swap_adjacent: float = 0.05
    prob_abbreviation: float = 0.30

    # Probabilities for corrupt_tone internal logic
    prob_remove_tone_specific: float = 0.40
    prob_change_tone_specific: float = 0.50
    prob_add_tone_specific: float = 0.10

    vietnamese_tones: Dict[str, List[str]] = field(default_factory=dict)
    base_vowel_map: Dict[str, str] = field(default_factory=dict)
    all_vowels_with_tone: List[str] = field(default_factory=list)
    telex_typing_map: Dict[str, str] = field(default_factory=dict)
    reverse_telex_map: Dict[str, str] = field(default_factory=dict)

    initial_consonant_confusion_map: Dict[str, List[str]] = field(default_factory=dict)
    final_consonant_map: Dict[str, List[str]] = field(default_factory=dict)
    vowel_confusion_map: Dict[str, List[str]] = field(default_factory=dict)

    semantic_map: Dict[str, List[str]] = field(default_factory=dict)
    semantic_confusion_pairs: List[Tuple[str, str]] = field(default_factory=list)
    abbreviation_map: Dict[str, List[str]] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in every lookup table that was not supplied by the caller.

        A table counts as "not supplied" when it is empty, which is what the
        field defaults produce. Derived tables (``base_vowel_map``,
        ``all_vowels_with_tone``, ``reverse_telex_map``, ``semantic_map``) are
        computed from their source table *after* that source has been
        resolved, so overriding a source table also updates what is derived
        from it. An explicitly empty table is indistinguishable from the
        default and is therefore replaced by the built-in data.
        """
        if not self.vietnamese_tones:
            self.vietnamese_tones = {
                'a': ['a', 'à', 'á', 'ả', 'ã', 'ạ'], 'ă': ['ă', 'ằ', 'ắ', 'ẳ', 'ẵ', 'ặ'], 'â': ['â', 'ầ', 'ấ', 'ẩ', 'ẫ', 'ậ'],
                'e': ['e', 'è', 'é', 'ẻ', 'ẽ', 'ẹ'], 'ê': ['ê', 'ề', 'ế', 'ể', 'ễ', 'ệ'], 'i': ['i', 'ì', 'í', 'ỉ', 'ĩ', 'ị'],
                'o': ['o', 'ò', 'ó', 'ỏ', 'õ', 'ọ'], 'ô': ['ô', 'ồ', 'ố', 'ổ', 'ỗ', 'ộ'], 'ơ': ['ơ', 'ờ', 'ớ', 'ở', 'ỡ', 'ợ'],
                'u': ['u', 'ù', 'ú', 'ủ', 'ũ', 'ụ'], 'ư': ['ư', 'ừ', 'ứ', 'ử', 'ữ', 'ự'], 'y': ['y', 'ỳ', 'ý', 'ỷ', 'ỹ', 'ỵ']
            }
        if not self.base_vowel_map:
            # Every listed form (the base itself included) maps back to its base vowel.
            self.base_vowel_map = {
                toned_vowel: base
                for base, tones_list in self.vietnamese_tones.items()
                for toned_vowel in tones_list
            }
        if not self.all_vowels_with_tone:
            # Position 0 of each list is the untoned base, so it is skipped.
            self.all_vowels_with_tone = [
                v for tones_list in self.vietnamese_tones.values() for v in tones_list[1:]
            ]

        if not self.telex_typing_map:
            self.telex_typing_map = {
                'aw': 'ă', 'aa': 'â', 'dd': 'đ', 'ee': 'ê', 'oo': 'ô', 'ow': 'ơ', 'uw': 'ư',
                'as': 'á', 'af': 'à', 'ar': 'ả', 'ax': 'ã', 'aj': 'ạ', 'es': 'é', 'ef': 'è',
                'er': 'ẻ', 'ex': 'ẽ', 'ej': 'ẹ', 'os': 'ó', 'of': 'ò', 'or': 'ỏ', 'ox': 'õ',
                'oj': 'ọ', 'is': 'í', 'if': 'ì', 'ir': 'ỉ', 'ix': 'ĩ', 'ij': 'ị', 'us': 'ú',
                'uf': 'ù', 'ur': 'ủ', 'ux': 'ũ', 'uj': 'ụ', 'ys': 'ý', 'yf': 'ỳ', 'yr': 'ỷ',
                'yx': 'ỹ', 'yj': 'ỵ', 'aws': 'ắ', 'awf': 'ằ', 'awr': 'ẳ', 'awx': 'ẵ',
                'awj': 'ặ', 'aas': 'ấ', 'aaf': 'ầ', 'aar': 'ẩ', 'aax': 'ẫ', 'aaj': 'ậ',
                'ees': 'ế', 'eef': 'ề', 'eer': 'ể', 'eex': 'ễ', 'eej': 'ệ', 'oos': 'ố',
                'oof': 'ồ', 'oor': 'ổ', 'oox': 'ỗ', 'ooj': 'ộ', 'ows': 'ớ', 'owf': 'ờ',
                'owr': 'ở', 'owx': 'ỡ', 'owj': 'ợ', 'uws': 'ứ', 'uwf': 'ừ', 'uwr': 'ử',
                'uwx': 'ữ', 'uwj': 'ự',
            }
        if not self.reverse_telex_map:
            self.reverse_telex_map = {v: k for k, v in self.telex_typing_map.items()}

        if not self.initial_consonant_confusion_map:
            # Consolidated initial consonant confusion map
            merged_initial_errors = {
                's': ['x'], 'x': ['s'],
                'n': ['l'], 'l': ['n'],
                'ch': ['tr', 'c', 't'], 'tr': ['ch'],
                'd': ['gi', 'r', 'v'], 'gi': ['d', 'r', 'v'],
                'r': ['d', 'g', 'gi', 'v'], 'v': ['d','gi','r','z'],
                'c': ['k', 'q', 't', 'ch'], 'k': ['c', 'q'], 'q': ['c', 'k'],
                'p': ['b'], 'b': ['p'],
                't': ['th', 'c', 'ch'], 'th': ['t'],
                'i': ['y'], 'y': ['i'],
                'nh': ['n', 'ng'], 'ng': ['n', 'nh', 'g'],
                'g': ['ng', 'gh'], 'gh': ['g']
            }
            phonetic_like_updates = {
                'ph': ['f'], 'f': ['ph']
            }
            # Merge without duplicating a replacement that is already listed.
            for key, value_list in phonetic_like_updates.items():
                bucket = merged_initial_errors.setdefault(key, [])
                for v_update in value_list:
                    if v_update not in bucket:
                        bucket.append(v_update)
            self.initial_consonant_confusion_map = merged_initial_errors

        if not self.final_consonant_map:
            self.final_consonant_map = {
                'n': ['ng', 'm'], 'ng': ['n', 'm'], 't': ['c', 'k'], 'c': ['t', 'k'], 'm': ['n', 'ng']
            }
        if not self.vowel_confusion_map:
            self.vowel_confusion_map = {
                'ê': ['e'], 'e': ['ê'], 'ô': ['o'], 'o': ['ô'], 'ơ': ['o'],
                'u': ['ư'], 'ư': ['u'], 'i':['y'], 'y':['i']
            }

        if not self.semantic_confusion_pairs:
            self.semantic_confusion_pairs = [
                ('năm', 'lăm'), ('sinh', 'sanh'), ('sử dụng', 'xử dụng'), ('của', 'cũ'),
                ('rồi', 'dồi'), ('nghĩ', 'nghỉ'), ('kỹ', 'kỷ'), ('chuyện', 'truyện'),
                ('lên', 'nên'), ('để', 'đễ'), ('vô', 'vào'), ('ra', 'gia'), ('sao', 'sang'),
                ('rất', 'nhất'), ('trong', 'trên'), ('với', 'về'), ('là', 'ra'),
                ('dành', 'giành'), ('kết cuộc', 'kết cục'), ('tham quan', 'tham quang'),
                ('đọc giả', 'độc giả'), ('súc tích', 'xúc tích'),('bàng quang', 'bàng quan'),
                ('chẩn đoán', 'chuẩn đoán'), ('giả thuyết', 'giả thiết'), ('sáng lạng', 'xán lạn'),
            ]

        if not self.semantic_map:
            # Each pair is registered in both directions.
            self.semantic_map = {}
            for w1, w2 in self.semantic_confusion_pairs:
                self.semantic_map.setdefault(w1, []).append(w2)
                self.semantic_map.setdefault(w2, []).append(w1)

        if not self.abbreviation_map:
            self.abbreviation_map = {
                'không': ['ko', 'k', 'kg', 'kô', 'khum', 'hong'], 'được': ['đc', 'dc', 'dk', 'đk', 'ok', 'oke'],
                'tôi': ['tui', 't', 'toy', 'tao', 'tau'], 'bạn': ['bn', 'b', 'mày', 'mi'], 'chúng ta': ['cta', 'ct'],
                'học': ['hok', 'hc'], 'biết': ['bít', 'bit', 'bik', 'pk'], 'này': ['nè', 'ni'],
                'rồi': ['rùi', 'rui', 'rr', 'ròi'], 'vậy': ['v', 'zậy', 'za', 'z', 'ntn', 'tn'],
                'đi': ['di', 'ik', 'dj'], 'làm': ['lam', 'lm'],'có': ['co', 'cóa', 'cóá'],
                'đang': ['đg', 'dang', 'đag'], 'anh': ['a'], 'em' : ['e'], 'thì': ['thi', 'thy', 'thỳ'],
                'yêu': ['iu', 'ew'],'người yêu': ['ny', 'ngiu'], 'chồng': ['ck'], 'vợ': ['vk'],
                'gì': ['j', 'zj', 'hok bik'], 'quá': ['wa', 'qá'],'tin nhắn': ['tn', 'mess'],
                'điện thoại': ['đt', 'dt'],'bây giờ': ['bh', 'h', 'bi h'],'bao giờ': ['bg', 'bj h'],
                'với': ['vs', 'w'],'luôn': ['lun'],'chứ': ['chớ'],
                'cũng': ['cũg', 'cug'],'nhưng': ['nhug', 'nhưg', 'nhma'],'cho': ['choa'],
                'hôm nay': ['hnay'], 'sinh nhật': ['sn', 'snhat'],'online': ['onl', 'ol'],
                'offline': ['off', 'of'],'mọi người': ['mn'], 'về': ['dìa','zìa'],
                'tiếng Anh': ['ta'], 'tiếng Việt': ['tv'],'Việt Nam': ['vn'],
                'à': ['ah'], 'ừ': ['uh', 'ukm', 'ừm'],
                'haha': ['hihi', 'hehe', 'hoho', 'kaka'],
                'vui vẻ': ['vv'],'hạnh phúc': ['hp'],
                'chia tay': ['ctay', 'ct'],'nói chuyện': ['nc'],
                'thích': ['thik', 'thix'],'muốn': ['mún'],
                'xinh': ['xink', 'xynh'],'Trời ơi': ['tr oi', 'OMG'],
                'vân vân': ['vv', 'etc'],'xin lỗi': ['sr', 'xl'],
                'cảm ơn': ['tks', 'thanks', 'mơn'],
                'đồng ý': ['đy', 'oki'],'chúc mừng': ['cm'],
                'nhắn tin': ['nt']
            }

# --- Training Configuration ---
@dataclass
class TrainingConfig:
    """Hyper-parameters and bookkeeping options for fine-tuning.

    NOTE(review): ``fp16`` is the only annotated attribute, so it is the only
    real dataclass field. Everything else is a plain class attribute that the
    generated ``__init__`` does not accept as a keyword argument.
    The attribute names look like ``Seq2SeqTrainingArguments`` parameters --
    presumably they are forwarded to it later; that call is not visible in
    this part of the notebook (TODO confirm).
    """
    # Checkpoints are written under the Drive project folder.
    output_dir = os.path.join(DRIVE_BASE_PATH, "models")
    num_train_epochs = 3
    per_device_train_batch_size = 8
    gradient_accumulation_steps = 1
    learning_rate = 4e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    # Logging, evaluation and checkpointing all use the same 5000-step interval.
    logging_strategy = "steps"
    logging_steps = 5000
    eval_strategy = "steps"
    eval_steps = 5000
    save_strategy = "steps"
    save_steps = 5000
    save_total_limit = 1
    # Resolved per instance: the factory calls torch.cuda.is_available() at construction time.
    fp16: bool = field(default_factory=torch.cuda.is_available)
    # Computed once, when this class body runs, from ModelConfig's class-level max_length.
    generation_max_length = ModelConfig.max_length + 20
    generation_num_beams = 2
    # NOTE(review): wandb is not among the packages installed by the pip cell
    # above -- presumably preinstalled in the runtime; confirm before running.
    report_to = "wandb"

# One shared configuration object per concern, created in declaration order.
data_cfg, model_cfg = DataConfig(), ModelConfig()
aug_cfg, train_cfg = AugmentationConfig(), TrainingConfig()

# Make sure both output folders exist before any later cell writes to them.
for _output_dir in (data_cfg.processed_base_path, train_cfg.output_dir):
    Path(_output_dir).mkdir(parents=True, exist_ok=True)

# Data Augmentation for Noisy Data

In [None]:
# --- Utility Functions ---
def clean_text(text):
    """Collapse every run of whitespace into one space and trim both ends.

    Args:
        text: Value to clean; anything that is not a ``str`` is treated as empty.

    Returns:
        The normalised string, or ``""`` for non-string input.
    """
    if isinstance(text, str):
        return re.sub(r'\s+', ' ', text).strip()
    return ""

def preprocess_vietnamese_text(text):
    """Clean ``text`` and word-segment it with underthesea.

    Args:
        text: Raw sentence; non-string or blank input yields ``""``.

    Returns:
        The output of ``underthesea_tokenize(..., format="text")`` for the
        cleaned sentence, or the cleaned sentence itself if the tokenizer
        raises.
    """
    cleaned = clean_text(text)
    if cleaned:
        try:
            return underthesea_tokenize(cleaned, format="text")
        except Exception:
            # Best effort: an untokenized sentence is preferred over a failure.
            return cleaned
    return ""

def remove_underscore_in_names(text):
    """Replace every underscore that joins two word characters with a space.

    The previous pattern ``(\\w)_(\\w)`` consumed the characters on both sides
    of the underscore, so when a one-character syllable sat between two
    underscores (``"bộ_y_tế"``, ``"a_b_c"``) the second underscore could no
    longer match and was left in the output. Lookarounds inspect the
    neighbours without consuming them, so every joining underscore is handled.

    Args:
        text: Text that may contain underscore-joined tokens; non-strings
            yield ``""``.

    Returns:
        The text with joining underscores turned into spaces. Leading and
        trailing underscores (no word character on one side) are kept.
    """
    if not isinstance(text, str):
        return ""
    return re.sub(r'(?<=\w)_(?=\w)', ' ', text)

def normalize_punctuation_spacing(text):
    """Remove whitespace before ``. , ! ? ; :`` and collapse repeated whitespace.

    Args:
        text: Sentence to tidy up; non-strings yield ``""``.

    Returns:
        The sentence with punctuation attached to the preceding word, single
        spaces between words and no leading or trailing whitespace.
    """
    if not isinstance(text, str):
        return ""
    # Rules run in order: first glue punctuation to the left, then squeeze spaces.
    rewrite_rules = (
        (r'\s+([.,!?;:])', r'\1'),
        (r'\s+', ' '),
    )
    for pattern, replacement in rewrite_rules:
        text = re.sub(pattern, replacement, text)
    return text.strip()


# --- Data Loading ---
def load_raw_txt_data(file_paths, max_samples_per_file = None):
    """Read non-empty, stripped lines from a list of UTF-8 text files.

    Args:
        file_paths: Iterable of file paths; missing files are reported and skipped.
        max_samples_per_file: When truthy and a file has more lines than this,
            that many lines are drawn with ``random.sample`` before blank
            lines are filtered out.

    Returns:
        One flat list with the stripped, non-empty lines of every readable
        file. Any exception raised while reading or processing a file is
        printed and the function moves on to the next file.
    """
    collected = []
    for file_path in file_paths:
        source = Path(file_path)
        if not source.exists():
            print(f"Warning: File not found {source}, skipping.")
            continue
        try:
            with open(source, 'r', encoding='utf-8') as handle:
                raw_lines = handle.readlines()
                if max_samples_per_file and len(raw_lines) > max_samples_per_file:
                    print(f"Sampling {max_samples_per_file} lines from {source.name}")
                    raw_lines = random.sample(raw_lines, max_samples_per_file)

                for raw_line in tqdm(raw_lines, desc=f"Processing {source.name}"):
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    collected.append(stripped)
        except Exception as e:
            print(f"Error reading or processing file {source.name}: {e}")
    return collected

## Tone Error

Example:

Before: Hôm nay thời tiết rất đẹp, tôi muốn đi chơi.

After: Hốm nay thoi tiet rat đép, tôi muôn đi chơi.

In [None]:
# --- Augmentation Functions ---
def corrupt_tone(word, config, max_retries = 2):
    """Remove, change or add the tone mark on one or two vowels of ``word``.

    Args:
        word: Token to corrupt; non-string or blank values are returned as is.
        config: Object exposing ``vietnamese_tones``, ``base_vowel_map`` and
            the three ``prob_*_tone_specific`` weights.
        max_retries: How many more times to try again when an attempt left
            the word unchanged.

    Returns:
        The corrupted word. It can equal the input when no vowel is present
        or when every attempt drew an action that does not apply to the
        selected vowel (for example "remove" on an untoned vowel).
    """
    if not isinstance(word, str) or not word.strip():
        return word

    tone_table = config.vietnamese_tones
    to_base = config.base_vowel_map
    letters = list(word)

    vowel_positions = []
    for position, letter in enumerate(letters):
        lowered = letter.lower()
        if lowered in tone_table or lowered in to_base:
            vowel_positions.append(position)
    if not vowel_positions:
        return word

    how_many = random.randint(1, min(2, len(vowel_positions)))
    modified = False
    for position in random.sample(vowel_positions, how_many):
        current = letters[position]
        current_lower = current.lower()
        base = to_base.get(current_lower, current_lower)
        if not base or base not in tone_table:
            continue

        # A vowel carries a tone when it maps to a base different from itself.
        has_tone = current_lower in to_base and to_base[current_lower] != current_lower
        variants = tone_table[base]
        action = random.choices(
            ['remove_tone', 'change_tone', 'add_tone'],
            weights=[config.prob_remove_tone_specific, config.prob_change_tone_specific, config.prob_add_tone_specific],
            k=1
        )[0]

        replacement = None
        if has_tone:
            if action == 'remove_tone':
                replacement = base
            elif action == 'change_tone':
                alternatives = [tone for tone in variants if tone != current_lower]
                if alternatives:
                    replacement = random.choice(alternatives)
        elif action == 'add_tone':
            alternatives = [tone for tone in variants if tone != base]
            if alternatives:
                replacement = random.choice(alternatives)

        if replacement:
            cased = replacement.upper() if current.isupper() else replacement
            if letters[position] != cased:
                letters[position] = cased
                modified = True

    # ``letters`` is only written to when ``modified`` is set, so an
    # unmodified attempt always reproduces the input and may be retried.
    if modified or max_retries <= 0:
        return "".join(letters)
    return corrupt_tone(word, config, max_retries - 1)

## Typing Telex Error

Example:

Before: Hôm nay thời tiết rất đẹp, tôi muốn đi chơi.

After: Hôm nay thowif tieest raats ddejp, tôi muoons ddi chowif.

In [None]:
def corrupt_telex_typing(word, config, max_retries = 2):
    """Replace one or two characters of ``word`` with raw TELEX keystrokes.

    Candidate characters are every character that has an entry in
    ``config.reverse_telex_map`` (accented vowels and, unlike before, also
    ``đ`` -> ``dd``) plus the plain base vowels, which receive a randomly
    chosen TELEX sequence of one of their modified forms.

    Args:
        word: Token to corrupt; non-string or blank values are returned as is.
        config: Object exposing ``reverse_telex_map``, ``telex_typing_map``,
            ``vietnamese_tones``, ``base_vowel_map`` and
            ``all_vowels_with_tone``.
        max_retries: How many more times to try again when an attempt left
            the word unchanged.

    Returns:
        The word with the selected characters expanded to their keystrokes,
        or the unchanged word when it contains no candidate character.
    """
    if not isinstance(word, str) or not word.strip():
        return word

    parts = list(word)
    # Lower-case character by character so positions always line up with
    # ``parts``, even for characters whose lower-case form has another length.
    lowered = [char.lower() for char in parts]
    candidate_indices = [
        i for i, char_lower in enumerate(lowered)
        if char_lower in config.reverse_telex_map
        or char_lower in config.all_vowels_with_tone
        or char_lower in config.vietnamese_tones
    ]
    if not candidate_indices:
        return word

    num_changes = random.randint(1, min(2, len(candidate_indices)))
    indices_to_change = random.sample(candidate_indices, num_changes)

    changed_this_call = False
    # Work right to left: expanding a character must not shift the positions
    # that are still waiting to be processed.
    for idx_to_change in sorted(indices_to_change, reverse=True):
        original_char = parts[idx_to_change]
        original_char_lower = lowered[idx_to_change]

        telex_sequence = None
        if original_char_lower in config.reverse_telex_map:
            telex_sequence = config.reverse_telex_map[original_char_lower]
        elif original_char_lower in config.vietnamese_tones:
            # Plain base vowel: pick the keystrokes of any modified form of it.
            possible_telex_mods = [
                keys for keys, produced in config.telex_typing_map.items()
                if config.base_vowel_map.get(produced) == original_char_lower
                and produced != original_char_lower
            ]
            if possible_telex_mods:
                telex_sequence = random.choice(possible_telex_mods)

        if telex_sequence:
            replacement_chars = list(telex_sequence)
            if original_char.isupper():
                replacement_chars[0] = replacement_chars[0].upper()
            parts[idx_to_change:idx_to_change + 1] = replacement_chars
            changed_this_call = True

    if not changed_this_call and max_retries > 0:
        return corrupt_telex_typing(word, config, max_retries - 1)
    return "".join(parts)

## Spelling Error

Example:

Before: Hôm nay thời tiết rất đẹp, tôi muốn đi chơi.

After: Hôm nay thời tiết **gất** đẹp, tôi muốn đi **trơi**.

In [None]:
def corrupt_spelling(word, config = None):
    """Apply one random spelling confusion (initial, final or vowel) to ``word``.

    Candidate edits are collected from three tables and one of them is drawn
    uniformly:

    * the initial consonant cluster, looked up in
      ``config.initial_consonant_confusion_map``;
    * the final consonant cluster, looked up in ``config.final_consonant_map``;
    * every single character found in ``config.vowel_confusion_map``.

    For the initial and final positions only the *longest* matching key is
    used. Matching every key prefix made ``'t'`` fire on words that really
    start with ``'th'`` or ``'tr'`` and produced strings such as ``thhời`` or
    ``chrời``.

    Args:
        word: Token to corrupt; non-string or blank values are returned as is.
        config: Augmentation settings. ``None`` (the default) resolves to the
            module-level ``aug_cfg`` at call time, so re-creating ``aug_cfg``
            in the notebook takes effect without redefining this function.

    Returns:
        The misspelled word, or the original word when no table applies.
        Initial/final edits are rebuilt from the lower-cased word: an
        ALL-CAPS word stays upper-case, a capitalised word keeps only its
        leading capital. Vowel edits keep the case of every character.
    """
    if config is None:
        config = aug_cfg
    if not isinstance(word, str) or not word.strip():
        return word

    word_lower = word.lower()
    possible_changes = []

    # Initial consonant: only the longest matching cluster is the real initial.
    initial_map = config.initial_consonant_confusion_map
    initial_matches = [key for key in initial_map if key and word_lower.startswith(key)]
    if initial_matches:
        initial = max(initial_matches, key=len)
        remainder = word_lower[len(initial):]
        for rep in initial_map[initial] or []:
            if rep != initial:
                possible_changes.append(('affix', rep + remainder))

    # Final consonant: same longest-match rule, applied to the word ending.
    final_map = config.final_consonant_map
    final_matches = [key for key in final_map if key and word_lower.endswith(key)]
    if final_matches:
        final = max(final_matches, key=len)
        stem = word_lower[:-len(final)]
        for rep in final_map[final] or []:
            if rep != final:
                possible_changes.append(('affix', stem + rep))

    # Vowels: every position is its own candidate; casing is restored per character.
    for index, original_char in enumerate(word):
        char_lower = original_char.lower()
        for replacement_vowel in config.vowel_confusion_map.get(char_lower) or []:
            if replacement_vowel != char_lower:
                possible_changes.append(('vowel', index, replacement_vowel))

    if not possible_changes:
        return word

    chosen = random.choice(possible_changes)

    if chosen[0] == 'vowel':
        _, index, new_vowel = chosen
        new_word_list = list(word)
        new_word_list[index] = new_vowel.upper() if word[index].isupper() else new_vowel
        return "".join(new_word_list)

    # Initial and final edits share the same case-restoration rule.
    new_stem = chosen[1]
    if word.isupper():
        return new_stem.upper()
    if new_stem and word[0].isupper():
        return new_stem[0].upper() + new_stem[1:]
    return new_stem

## Similar Meaning Error

Example:

Before: Anh ấy **dành** cả buổi sáng để học bài.

After: Anh ấy **giành** cả buổi sáng để học bài.

In [None]:
def corrupt_semantic(tokens, config = aug_cfg):
    """Swap tokens for a commonly confused counterpart from ``config.semantic_map``.

    Args:
        tokens: List of token strings; an empty value is returned unchanged.
        config: Object exposing ``semantic_map`` (lower-case token -> list of
            confusable replacements).

    Returns:
        A new token list. When at least one token is in the map, up to
        ``max(1, len(tokens) // 10)`` of them are replaced. Otherwise, with a
        5% chance and only for inputs longer than two tokens, one random
        token is overwritten with a common verb. Replacements mirror the
        ALL-CAPS or Title Case form of the token they replace.
    """
    if not tokens:
        return tokens

    def _recase(template, lowered):
        # Carry the casing style of the replaced token over to the replacement.
        if template.isupper():
            return lowered.upper()
        if template.istitle():
            return lowered.title()
        return lowered

    result = list(tokens)
    lowered_tokens = [token.lower() for token in result]
    candidates = [
        position for position, lowered in enumerate(lowered_tokens)
        if lowered in config.semantic_map and config.semantic_map[lowered]
    ]

    if not candidates:
        # Rare fallback so that sentences without a known pair still get noise.
        if random.random() < 0.05 and len(result) > 2:
            position = random.randint(0, len(result) - 1)
            victim = result[position]
            fallback_pool = [
                verb for verb in ['biết', 'hiểu', 'nói', 'làm', 'đi', 'có', 'thấy', 'đến']
                if verb != victim.lower()
            ]
            if fallback_pool:
                result[position] = _recase(victim, random.choice(fallback_pool))
        return result

    budget = min(max(1, len(tokens) // 10), len(candidates))
    for position in random.sample(candidates, budget):
        victim = result[position]
        options = config.semantic_map.get(victim.lower(), [])
        if options:
            result[position] = _recase(victim, random.choice(options))

    return result

## Structure Error

Example:

Before: Mặt trời mọc ở đằng đông và lặn ở đằng tây.

After:
- Mặt trời mọc ở đằng đông và ở đằng tây.

- Mặt trời mọc ở đằng **đông đông** và lặn ở đằng tây.

- Mặt trời mọc ở **đông đằng** và lặn ở đằng tây.



In [None]:
def corrupt_structure(tokens, config= aug_cfg, pos_tags = None):
    """Apply one structural corruption: drop, duplicate or swap adjacent tokens.

    Args:
        tokens: List of tokens; an empty list is returned unchanged.
        config: Object exposing ``prob_delete_word``, ``prob_duplicate_word``
            and ``prob_swap_adjacent``; they act as relative weights for
            choosing the action, and actions with weight <= 0 are excluded.
        pos_tags: Optional tag list. It is only used when its length matches
            ``tokens``; positions with preferred tags are then chosen first.

    Returns:
        A new token list with the chosen corruption applied. "drop" and
        "swap_adjacent" do nothing for inputs shorter than two tokens.
    """
    working = list(tokens)
    if not working:
        return tokens

    tags_usable = pos_tags is not None and len(pos_tags) == len(working)

    weighted_actions = [
        (name, weight)
        for name, weight in (
            ('drop', config.prob_delete_word),
            ('duplicate', config.prob_duplicate_word),
            ('swap_adjacent', config.prob_swap_adjacent),
        )
        if weight > 0
    ]
    if not weighted_actions:
        return working
    names, weights = zip(*weighted_actions)
    action = random.choices(names, weights=weights, k=1)[0]

    def _pick(preferred, upper_bound):
        # POS-guided position when one exists, otherwise a uniform draw.
        if preferred:
            return random.choice(preferred)
        return random.randint(0, upper_bound)

    count = len(working)
    if action == 'drop':
        if count >= 2:
            preferred = (
                [i for i, tag in enumerate(pos_tags) if tag in ['P', 'E', 'L', 'R', 'A']]
                if tags_usable else []
            )
            del working[_pick(preferred, count - 1)]

    elif action == 'duplicate':
        preferred = (
            [i for i, tag in enumerate(pos_tags) if tag in ['A', 'E', 'N']]
            if tags_usable else []
        )
        at = _pick(preferred, count - 1)
        working.insert(at, working[at])

    elif action == 'swap_adjacent':
        if count >= 2:
            # Ordered tag pairs (left, right) whose swap is preferred.
            swappable_pairs = (
                ('N', 'A'), ('A', 'N'),
                ('V', 'E'), ('E', 'V'),
                ('V', 'R'), ('V', 'C'), ('R', 'V'), ('C', 'V'),
                ('L', 'N'), ('N', 'L'),
            )
            preferred = (
                [i for i in range(count - 1) if (pos_tags[i], pos_tags[i + 1]) in swappable_pairs]
                if tags_usable else []
            )
            left = _pick(preferred, count - 2)
            working[left], working[left + 1] = working[left + 1], working[left]

    return working

## Abbreviation Error

Example:

Before: Tôi không biết bây giờ phải làm gì, nhưng tôi cũng không muốn bỏ cuộc.

After: **Tui** **k** **bít** **bh** phải làm **j**, **nhma** **t** cũng **hong** **mún** bỏ cuộc.

In [None]:
def corrupt_abbreviation(tokens, config = None):
    """Replace one or two tokens with a chat-style abbreviation.

    Tokens are matched against ``config.abbreviation_map`` case-insensitively.
    The previous lookup lower-cased the token but not the map keys, so
    mixed-case entries such as ``'Việt Nam'`` or ``'tiếng Anh'`` could never
    match.

    Args:
        tokens: List of token strings; an empty list is returned unchanged.
        config: Augmentation settings. ``None`` (the default) resolves to the
            module-level ``aug_cfg`` at call time, so re-creating ``aug_cfg``
            in the notebook takes effect without redefining this function.

    Returns:
        A new token list in which between one and
        ``min(2, len(tokens) // 5 + 1, number of candidates)`` tokens are
        abbreviated. ALL-CAPS tokens get an upper-case abbreviation and Title
        Case tokens a title-cased one (a single-letter abbreviation is
        upper-cased). Without any candidate the tokens are returned as they are.
    """
    if config is None:
        config = aug_cfg

    new_tokens = list(tokens)
    if not new_tokens:
        return tokens

    # Case-insensitive view of the table. Lists are copied so the config is
    # never mutated; keys that collide after lower-casing are merged.
    lookup = {}
    for phrase, abbreviations in config.abbreviation_map.items():
        if abbreviations:
            lookup.setdefault(phrase.lower(), []).extend(abbreviations)

    candidate_indices = [
        i for i, token in enumerate(new_tokens) if token.lower() in lookup
    ]
    if not candidate_indices:
        return new_tokens

    # Every term below is at least 1 here, so at least one change is possible.
    max_changes_hard_limit = 2
    limit_by_tokens_count = (len(tokens) // 5) + 1
    num_possible_changes = min(max_changes_hard_limit, limit_by_tokens_count, len(candidate_indices))

    num_to_actually_change = random.randint(1, num_possible_changes)
    indices_to_process = random.sample(candidate_indices, num_to_actually_change)

    for idx_to_change in indices_to_process:
        original_token = new_tokens[idx_to_change]
        replacement_abbr = random.choice(lookup[original_token.lower()])
        if original_token.isupper():
            new_tokens[idx_to_change] = replacement_abbr.upper()
        elif original_token.istitle() and len(replacement_abbr) > 0:
            new_tokens[idx_to_change] = replacement_abbr.title() if len(replacement_abbr) > 1 else replacement_abbr.upper()
        else:
            new_tokens[idx_to_change] = replacement_abbr

    return new_tokens

## Create Noise Data

In [None]:
def create_noisy_text(original_text, config = aug_cfg, pos_tag_for_initial_tokenize = False, verbose = False):
    """Build a noisy ("incorrect") variant of a sentence by chaining random corruptions.

    Args:
        original_text: Source sentence. Non-string or blank input is returned unchanged.
        config: Augmentation settings. Read here: ``prob_apply_noise``, the
            per-corruption probabilities and ``max_augmentations_per_sample``.
        pos_tag_for_initial_tokenize: When True, tokenize with ``underthesea_pos_tag``
            so tokens and POS tags come from the same pass.
        verbose: When True, print diagnostics for tokenizer/tagger/augmentation failures.

    Returns:
        The corrupted sentence. Falls back to the cleaned text when noise is skipped
        (probability ``1 - config.prob_apply_noise``), when no tokens are produced, or
        when the corrupted result is empty. Returns ``original_text`` unchanged when
        cleaning yields an empty string.
    """
    if not isinstance(original_text, str) or not original_text.strip():
        return original_text
    text_cleaned = clean_text(original_text)

    if not text_cleaned:
        return original_text
    if random.random() > config.prob_apply_noise:
        return text_cleaned

    initial_tokens = []
    current_pos_tags = None

    if pos_tag_for_initial_tokenize:
        try:
            tagged_output = underthesea_pos_tag(text_cleaned)
            initial_tokens = [pair[0] for pair in tagged_output]
            current_pos_tags = [pair[1] for pair in tagged_output]
            if verbose and (not initial_tokens or (tagged_output and not current_pos_tags)):
                print(f"Underthesea_pos_tag returned empty tokens/tags for: '{text_cleaned}'")
        except Exception as e:
            if verbose:
                print(f"Underthesea_pos_tag failed for initial tokenization: {e}. Falling back to split().")
            initial_tokens = text_cleaned.split()
    else:
        try:
            initial_tokens = underthesea_tokenize(text_cleaned)
        except Exception as e:
            if verbose: print(f"Underthesea_tokenize failed: {e}. Falling back to split().")
            initial_tokens = text_cleaned.split()

    if not initial_tokens:
        return text_cleaned
    current_tokens = list(initial_tokens)

    structure_corruption_possible = any(p > 0 for p in [
        config.prob_delete_word, config.prob_duplicate_word, config.prob_swap_adjacent
    ])

    # POS tags are only needed by structure corruption; compute them lazily and only
    # keep them when the tagger's tokenization lines up with ours.
    if current_pos_tags is None and structure_corruption_possible:
        try:
            temp_sentence_for_pos = " ".join(current_tokens)
            if temp_sentence_for_pos:
                tagged_result = underthesea_pos_tag(temp_sentence_for_pos)
                if len(tagged_result) == len(current_tokens):
                    current_pos_tags = [tag for _, tag in tagged_result]
                elif verbose:
                    print(f"Token count mismatch after POS tagging. "
                          f"Initial tokens: {len(current_tokens)}, POS_tagged tokens: {len(tagged_result)}. "
                          f"Sentence: '{temp_sentence_for_pos}'. Grammar corruption will be random.")
        except Exception as e:
            if verbose: print(f"Error during on-demand POS tagging: {e}. Grammar corruption will be random.")

    total_structure_prob = (config.prob_delete_word +
                          config.prob_duplicate_word +
                          config.prob_swap_adjacent)
    # (function, sampling weight, granularity): 'token' functions receive one token,
    # 'sentence' functions receive the whole token list.
    augmentation_candidates = [
        (corrupt_tone, config.prob_region_tone, 'token'),
        (corrupt_telex_typing, config.prob_typo, 'token'),
        (corrupt_spelling, config.prob_spelling_confusion, 'token'),
        (corrupt_semantic, config.prob_semantic_confusion, 'sentence'),
        (corrupt_structure, total_structure_prob, 'sentence'),
        (corrupt_abbreviation, config.prob_abbreviation, 'sentence')
    ]
    valid_augmentation_candidates = [(f, p, l) for f, p, l in augmentation_candidates if p > 0]
    if not valid_augmentation_candidates:
        return " ".join(current_tokens)

    aug_funcs, aug_probs, aug_levels = zip(*valid_augmentation_candidates)
    num_types_to_attempt = random.randint(1, min(config.max_augmentations_per_sample, len(aug_funcs)))
    # Sampling is with replacement, so the same corruption may be applied more than once.
    selected_aug_indices = random.choices(range(len(aug_funcs)), weights=aug_probs, k=num_types_to_attempt)

    for i in selected_aug_indices:
        func_to_apply = aug_funcs[i]
        level = aug_levels[i]
        original_token_count_before_this_aug = len(current_tokens)
        # Snapshot so a same-length reordering (adjacent swap) can be detected below.
        tokens_before_this_aug = list(current_tokens)
        try:
            if not current_tokens: break
            if level == 'token':
                idx_to_corrupt = random.randint(0, len(current_tokens) - 1)
                current_tokens[idx_to_corrupt] = func_to_apply(current_tokens[idx_to_corrupt], config=config)
            elif level == 'sentence':
                if func_to_apply == corrupt_structure:
                    pos_tags_for_this_call = None
                    if current_pos_tags and len(current_tokens) == len(current_pos_tags):
                        pos_tags_for_this_call = current_pos_tags
                    elif verbose and current_pos_tags and len(current_tokens) != len(current_pos_tags):
                         print(f"Token count changed ({original_token_count_before_this_aug} -> {len(current_tokens)}), "
                               f"or POS tags ({len(current_pos_tags) if current_pos_tags else 'None'}) are invalid for grammar corruption.")
                    current_tokens = func_to_apply(current_tokens, config=config, pos_tags=pos_tags_for_this_call)
                else:
                    current_tokens = func_to_apply(current_tokens, config=config)
            # Tags are positional: drop them when the length changed, and also when
            # structure corruption altered the sequence at the same length (a swap),
            # otherwise a later structure corruption would read misaligned tags.
            structure_changed_tokens = (func_to_apply == corrupt_structure
                                        and list(current_tokens) != tokens_before_this_aug)
            if len(current_tokens) != original_token_count_before_this_aug or structure_changed_tokens:
                current_pos_tags = None
        except Exception as e:
            # Best effort: a failing corruption is skipped and the remaining ones still run.
            if verbose: print(f"Augmentation error {func_to_apply.__name__}: {e}. Tokens: {current_tokens[:5]}")

    result = " ".join(filter(None, current_tokens))
    # Re-attach punctuation that tokenization separated from the preceding word.
    result = re.sub(r'\s+([.,!?;:])', r'\1', result)
    result = re.sub(r'([.,!?;:])\s+([.,!?;:])', r'\1\2', result)
    result = re.sub(r'\s+', ' ', result).strip()

    return result if result else text_cleaned

# Dataset Preparation for Training

In [None]:
# Build (incorrect, correct) training pairs by corrupting every clean sentence,
# then de-duplicate, split into train/test and persist both splits as CSV.
print("Loading raw data for augmentation base...")
processed_correct_sentences = load_raw_txt_data(data_cfg.raw_txt_files)

if not processed_correct_sentences:
    raise ValueError("No correct sentences loaded. Check data.")
print(f"Loaded {len(processed_correct_sentences)} correct sentences for augmentation.")

print("Starting data augmentation process...")
augmented_pairs = []
for correct_sent in tqdm(processed_correct_sentences, desc="Augmenting data"):
    generated_this_round_distinct = 0
    # Seeded with the clean sentence so unchanged outputs are never stored as "incorrect".
    seen_in_this_round = {correct_sent}
    for _ in range(aug_cfg.num_augmented_samples):
        generated_input_sent = create_noisy_text(correct_sent, aug_cfg, verbose=False)

        if generated_input_sent and generated_input_sent != correct_sent and generated_input_sent not in seen_in_this_round:
            augmented_pairs.append((generated_input_sent, correct_sent))
            seen_in_this_round.add(generated_input_sent)
            generated_this_round_distinct += 1

    if generated_this_round_distinct == 0 and aug_cfg.num_augmented_samples > 0:
        # Last attempt with noise forced on. The shared config is restored in `finally`
        # so an error or a kernel interrupt cannot leave prob_apply_noise stuck at 1.0.
        original_prob = aug_cfg.prob_apply_noise
        aug_cfg.prob_apply_noise = 1.0
        try:
            forced_incorrect_sent = create_noisy_text(correct_sent, aug_cfg, verbose=False)
        finally:
            aug_cfg.prob_apply_noise = original_prob
        if forced_incorrect_sent and forced_incorrect_sent != correct_sent and forced_incorrect_sent not in seen_in_this_round:
            augmented_pairs.append((forced_incorrect_sent, correct_sent))

if not augmented_pairs:
    print("No data generated after augmentation. Adding original sentences as pairs if any.")
    for cs in processed_correct_sentences:
        if cs.strip():
             augmented_pairs.append((cs,cs))
    if not augmented_pairs:
        raise ValueError("No data generated, and no raw sentences. Check config and raw data.")

print(f"Generated a total of {len(augmented_pairs)} (incorrect, correct) pairs.")
del processed_correct_sentences

random.shuffle(augmented_pairs)
df_all = pd.DataFrame(augmented_pairs, columns=['incorrect', 'correct'])

df_all['incorrect'] = df_all['incorrect'].apply(normalize_punctuation_spacing)
df_all['correct'] = df_all['correct'].apply(normalize_punctuation_spacing)
# NOTE(review): underscores are stripped from 'correct' only - confirm 'incorrect' should keep them.
df_all['correct'] = df_all['correct'].apply(remove_underscore_in_names)

df_all = df_all.dropna(subset=['incorrect', 'correct'])
df_all = df_all[df_all['incorrect'].str.strip().astype(bool) & df_all['correct'].str.strip().astype(bool)]
df_all = df_all.drop_duplicates(subset=['incorrect', 'correct'], keep='first').reset_index(drop=True)

print(f"Total pairs after dropping duplicates and empty strings: {len(df_all)}")
del augmented_pairs

if len(df_all) == 0:
    raise ValueError("DataFrame is empty after augmentation and deduplication. Check data generation.")

# A float test_size is a fraction of the data; anything else is an absolute row count.
test_set_actual_size = data_cfg.test_size
num_test_samples = int(len(df_all) * test_set_actual_size) if isinstance(test_set_actual_size, float) else int(test_set_actual_size)
# Clamp so that at least one row is always left for training.
num_test_samples = max(0, min(num_test_samples, len(df_all) -1))

if len(df_all) <= 1 :
    print(f"Very few samples ({len(df_all)}). Allocating all to training if > 0, else test will be empty.")
    df_train = df_all.copy() if len(df_all) > 0 else pd.DataFrame(columns=['incorrect', 'correct'])
    df_test = pd.DataFrame(columns=['incorrect', 'correct'])

elif num_test_samples == 0 and len(df_all) > 0:
    print(f"Test set size is 0. All {len(df_all)} samples allocated to training set.")
    df_train = df_all.copy()
    df_test = pd.DataFrame(columns=['incorrect', 'correct'])
else:
    # Rows were shuffled at the pair level above, so a positional cut is a random split.
    # NOTE(review): variants of the same clean sentence can land in both splits.
    df_test = df_all.iloc[:num_test_samples]
    df_train = df_all.iloc[num_test_samples:]

del df_all

print(f"Train samples: {len(df_train)}, Test samples: {len(df_test)}")

if not df_train.empty:
    df_train.to_csv(data_cfg.processed_train_path, index=False, encoding='utf-8')
    print(f"Train data saved to {data_cfg.processed_train_path}")
else:
    print("Train dataset is empty. Training cannot proceed.")

if not df_test.empty:
    df_test.to_csv(data_cfg.processed_test_path, index=False, encoding='utf-8')
    print(f"Test data saved to {data_cfg.processed_test_path}")
else:
    print("Test dataset is empty. Evaluation might not be possible or meaningful.")

Loading raw data for augmentation base...


Processing data.txt:   0%|          | 0/20182 [00:00<?, ?it/s]

Loaded 20182 correct sentences for augmentation.
Starting data augmentation process...


Augmenting data:   0%|          | 0/20182 [00:00<?, ?it/s]

Generated a total of 46760 (incorrect, correct) pairs.
Total pairs after dropping duplicates and empty strings: 46741
Train samples: 42067, Test samples: 4674
Train data saved to /content/drive/MyDrive/DACNTT/data_processed/train.csv
Test data saved to /content/drive/MyDrive/DACNTT/data_processed/test.csv


# Define Model

In [None]:
# Load the pretrained tokenizer and seq2seq model named in the model config.
tokenizer = T5Tokenizer.from_pretrained(model_cfg.model_name_or_path)
model = T5ForConditionalGeneration.from_pretrained(model_cfg.model_name_or_path)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
print(f"Using device: {device}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/820k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.12k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.40M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/702 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/904M [00:00<?, ?B/s]

Using device: cuda


# Load Train/Test Sets

In [None]:
# Reload the saved train/test CSV files and wrap them in a DatasetDict.
print("Creating datasets from DataFrames...")
raw_datasets = DatasetDict()

print(f"Loading data from {data_cfg.processed_train_path} and {data_cfg.processed_test_path}...")

# Start from empty frames so the checks below still work when a file is missing or unreadable.
df_train, df_test = pd.DataFrame(), pd.DataFrame()

try:
    if not os.path.exists(data_cfg.processed_train_path):
        print(f"Warning: Train data file not found at {data_cfg.processed_train_path}")
    else:
        df_train = pd.read_csv(data_cfg.processed_train_path, encoding='utf-8')
        print(f"Train data loaded: {len(df_train)} samples.")

    if not os.path.exists(data_cfg.processed_test_path):
        print(f"Warning: Test data file not found at {data_cfg.processed_test_path}. Evaluation might not be possible.")
    else:
        df_test = pd.read_csv(data_cfg.processed_test_path, encoding='utf-8')
        print(f"Test data loaded: {len(df_test)} samples.")
except Exception as e:
    print(f"Error loading data from CSV files: {e}")

# Only non-empty frames become splits.
if df_train.empty:
    print("df_train is empty. Cannot create train dataset.")
else:
    raw_datasets['train'] = Dataset.from_pandas(df_train)
    print(f"Train dataset created from DataFrame with {len(raw_datasets['train'])} samples.")

if df_test.empty:
    print("df_test is empty. Cannot create test dataset.")
else:
    raw_datasets['test'] = Dataset.from_pandas(df_test)
    print(f"Test dataset created from DataFrame with {len(raw_datasets['test'])} samples.")

if len(raw_datasets) == 0:
    raise RuntimeError("No datasets were created. Check if df_train or df_test were populated correctly.")

# Every split that exists must expose both text columns used for tokenization.
if 'train' in raw_datasets and not {'incorrect', 'correct'}.issubset(raw_datasets['train'].column_names):
    raise ValueError("Train data must contain 'incorrect' and 'correct' columns.")
if 'test' in raw_datasets and not {'incorrect', 'correct'}.issubset(raw_datasets['test'].column_names):
    raise ValueError("Test data must contain 'incorrect' and 'correct' columns.")

print(raw_datasets)

Using device: cuda
Creating datasets from DataFrames...
Loading data from /content/drive/MyDrive/DACNTT/data_processed/train.csv and /content/drive/MyDrive/DACNTT/data_processed/test.csv...
Train data loaded: 42067 samples.
Test data loaded: 4674 samples.
Train dataset created from DataFrame with 42067 samples.
Test dataset created from DataFrame with 4674 samples.
DatasetDict({
    train: Dataset({
        features: ['incorrect', 'correct'],
        num_rows: 42067
    })
    test: Dataset({
        features: ['incorrect', 'correct'],
        num_rows: 4674
    })
})


# Tokenizer & Calculate Metrics Function

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of (incorrect, correct) sentence pairs for seq2seq training.

    Args:
        examples: Batch mapping with 'incorrect' and 'correct' lists of sentences.
            ``None`` entries are encoded as empty strings.

    Returns:
        The tokenizer output for the "Fix: "-prefixed inputs, with an added "labels"
        entry holding the target token ids. Padding positions in the labels are set
        to -100 so they do not contribute to the training loss.
    """
    inputs = ["Fix: " + str(text) if text is not None else "" for text in examples['incorrect']]
    targets = [str(text) if text is not None else "" for text in examples['correct']]
    model_inputs = tokenizer(inputs, max_length=model_cfg.max_length, padding="max_length", truncation=True)

    labels = tokenizer(text_target=targets, max_length=model_cfg.max_length, padding="max_length", truncation=True)
    # Labels are padded to max_length here, so no later step masks the padding.
    # Left as pad ids, the loss would be dominated by trivially predictable pad tokens.
    pad_token_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_token_id else -100 for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

def calculate_correction_metrics(incorrect_sents: List[str],
                                 predicted_sents: List[str],
                                 reference_sents: List[str]) -> Dict[str, float]:
    """Compute edit-level precision, recall and F1 for sentence correction.

    For each sentence, the edits needed to turn the incorrect text into the reference
    ("original errors") are compared with the edits the model actually made. Edits are
    identified by their token span in the incorrect sentence plus the replacement, so
    the same mistake occurring twice in one sentence counts as two errors.

    Args:
        incorrect_sents: Noisy input sentences.
        predicted_sents: Model outputs, aligned with ``incorrect_sents``.
        reference_sents: Gold corrections, aligned with ``incorrect_sents``.

    Returns:
        Dict with precision, recall, f1_score and the raw counts they derive from.

    Raises:
        ValueError: If the three lists differ in length.
    """

    def _as_text(sentence):
        # None / empty values are scored as empty sentences.
        return str(sentence) if sentence else ""

    def _collect_edits(source_tokens, target_tokens):
        # Every non-equal opcode is one edit, anchored at its span in the source tokens.
        matcher = difflib.SequenceMatcher(None, source_tokens, target_tokens)
        return {
            (i1, i2, tuple(source_tokens[i1:i2]), tuple(target_tokens[j1:j2]))
            for tag, i1, i2, j1, j2 in matcher.get_opcodes()
            if tag != 'equal'
        }

    total_tp, total_fp, total_fn = 0, 0, 0

    if not (len(incorrect_sents) == len(predicted_sents) == len(reference_sents)):
        raise ValueError("Input lists must have the same length.")

    for inc_sent, pred_sent, ref_sent in zip(incorrect_sents, predicted_sents, reference_sents):
        try:
            inc_tokens = underthesea_tokenize(_as_text(inc_sent))
            pred_tokens = underthesea_tokenize(_as_text(pred_sent))
            ref_tokens = underthesea_tokenize(_as_text(ref_sent))
        except Exception:
            # Best effort: fall back to whitespace tokens for all three sentences
            # so they stay comparable with each other.
            inc_tokens = _as_text(inc_sent).split()
            pred_tokens = _as_text(pred_sent).split()
            ref_tokens = _as_text(ref_sent).split()

        original_errors = _collect_edits(inc_tokens, ref_tokens)
        model_changes = _collect_edits(inc_tokens, pred_tokens)

        total_tp += len(original_errors & model_changes)
        total_fp += len(model_changes - original_errors)
        total_fn += len(original_errors - model_changes)

    total_errors_in_incorrect = total_tp + total_fn
    total_model_edits = total_tp + total_fp

    precision = total_tp / total_model_edits if total_model_edits > 0 else 0.0
    recall = total_tp / total_errors_in_incorrect if total_errors_in_incorrect > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "true positives": total_tp,
        "false positives": total_fp,
        "false negatives": total_fn,
        "Total errors": total_errors_in_incorrect,
        "Total model edits": total_model_edits
    }

In [None]:
# Tokenize both splits, then assemble the collator, training arguments and trainer.
print("Tokenizing datasets...")
# The raw text columns are dropped after tokenization; take their names from
# whichever split actually has rows.
remove_cols = None
if 'train' in raw_datasets and raw_datasets['train'].num_rows > 0:
    remove_cols = raw_datasets['train'].column_names
elif 'test' in raw_datasets and raw_datasets['test'].num_rows > 0:
    remove_cols = raw_datasets['test'].column_names
else:
    print("No columns to remove as datasets might be empty or not loaded.")

# Use half of the available cores for tokenization, and always at least one.
num_cpus = os.cpu_count()
num_proc_tokenizer = max(1, num_cpus // 2 if num_cpus is not None else 1)
if sys.platform == "win32":
    print("Running on Windows, setting num_proc for tokenization to 1.")
    num_proc_tokenizer = 1

tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=remove_cols if remove_cols else [],
    num_proc=num_proc_tokenizer,
    load_from_cache_file=True
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    # -100 marks label positions the loss must ignore; padding labels with the
    # pad token id would make the model train on padding.
    label_pad_token_id=-100,
)

Path(train_cfg.output_dir).mkdir(parents=True, exist_ok=True)


training_args = Seq2SeqTrainingArguments(
    output_dir=train_cfg.output_dir,
    num_train_epochs=train_cfg.num_train_epochs,
    per_device_train_batch_size=train_cfg.per_device_train_batch_size,
    gradient_accumulation_steps=train_cfg.gradient_accumulation_steps,
    learning_rate=train_cfg.learning_rate,
    weight_decay=train_cfg.weight_decay,
    warmup_ratio=train_cfg.warmup_ratio,
    logging_strategy=train_cfg.logging_strategy,
    logging_steps=train_cfg.logging_steps,
    # `eval_strategy` is the current name; `evaluation_strategy` is deprecated.
    eval_strategy=train_cfg.eval_strategy if 'test' in tokenized_datasets else "no",
    eval_steps=train_cfg.eval_steps if 'test' in tokenized_datasets else None,
    save_strategy=train_cfg.save_strategy,
    save_steps=train_cfg.save_steps,
    save_total_limit=train_cfg.save_total_limit,
    fp16=train_cfg.fp16,
    generation_max_length=train_cfg.generation_max_length,
    generation_num_beams=train_cfg.generation_num_beams,
    report_to=train_cfg.report_to
)

train_dataset_for_trainer = tokenized_datasets.get("train")
eval_dataset_for_trainer = tokenized_datasets.get("test")

if not train_dataset_for_trainer:
    print("Train dataset is not available after tokenization. Exiting.")
    sys.exit(1)

if not eval_dataset_for_trainer:
    print("Evaluation dataset is not available. Trainer will run without intermediate evaluation.")
    # Assigning the deprecated `evaluation_strategy` attribute after construction
    # has no effect, so set the attribute the trainer actually reads.
    training_args.eval_strategy = "no"
    training_args.metric_for_best_model = None
    training_args.load_best_model_at_end = False

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_for_trainer,
    eval_dataset=eval_dataset_for_trainer,
    data_collator=data_collator,
    # `tokenizer=` is deprecated on the trainer in favour of `processing_class=`.
    processing_class=tokenizer,
)

Tokenizing datasets...


Map:   0%|          | 0/42067 [00:00<?, ? examples/s]

Map:   0%|          | 0/4674 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


# Starting Training

In [None]:
# Train, save the final model, and record which saved model later cells should load.
model_to_load_path = None
metrics = {}
try:
    print("Starting training...")
    train_result = trainer.train()
    trainer.save_state()

    # Always keep an explicit copy of the final weights and tokenizer for inference.
    final_model_explicit_save_path = os.path.join(train_cfg.output_dir, "ViT5_model")
    trainer.save_model(final_model_explicit_save_path)
    tokenizer.save_pretrained(final_model_explicit_save_path)
    print(f"Training completed. Final model state saved. Model for inference is at {final_model_explicit_save_path}")

    if not (training_args.load_best_model_at_end and trainer.state.best_model_checkpoint and Path(trainer.state.best_model_checkpoint).exists()):
        model_to_load_path = final_model_explicit_save_path
        print(f"Using explicitly saved final model from: {model_to_load_path}")
    else:
        model_to_load_path = trainer.state.best_model_checkpoint
        print(f"Best model checkpoint is at: {model_to_load_path}")

except Exception as e:
    print(f"An error occurred during training or evaluation: {e}")
    if hasattr(trainer, 'state'):
        trainer.save_state()

    # Recovery order: best checkpoint, then the explicit final model, then the
    # most recently modified "checkpoint-*" directory.
    if not (hasattr(trainer, 'state') and trainer.state.best_model_checkpoint and Path(trainer.state.best_model_checkpoint).exists()):
        output_dir_path = Path(train_cfg.output_dir)
        final_model_check = output_dir_path / "ViT5_model"

        if not final_model_check.exists():
            checkpoints = [
                d for d in output_dir_path.iterdir()
                if d.is_dir() and d.name.startswith("checkpoint-")
            ]
            checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            if checkpoints:
                model_to_load_path = str(checkpoints[0])
        else:
            model_to_load_path = str(final_model_check)
    else:
        model_to_load_path = trainer.state.best_model_checkpoint

    if not model_to_load_path:
        print("Training interrupted/finished. No usable model checkpoint found for inference.")
    else:
        print(f"Training interrupted/finished. Attempting to use model: {model_to_load_path}")

Starting training...


Step,Training Loss,Validation Loss
5000,0.0165,0.010465
10000,0.008,0.008512
15000,0.0039,0.008767


Training completed. Final model state saved. Model for inference is at /content/drive/MyDrive/DACNTT/models/model/final_model_checkpoint
Using explicitly saved final model from: /content/drive/MyDrive/DACNTT/models/model/final_model_checkpoint


# Evaluation and Inference

In [None]:
# --- Inference and Extended Evaluation ---
def correct_text_inference(text, model_to_use, tokenizer_to_use, device_to_use,
                           max_len_inf, num_beams_inf = 3,
                           repetition_penalty = 1.05, no_repeat_ngram_size = 2):
    """Correct a single sentence with the fine-tuned seq2seq model.

    Args:
        text: Raw input sentence.
        model_to_use: Model exposing ``to``, ``eval`` and ``generate``.
        tokenizer_to_use: Tokenizer used to encode the prompt and decode the output.
        device_to_use: Device the inputs and the model are moved to.
        max_len_inf: Maximum input length in tokens; generation may run 20 tokens longer.
        num_beams_inf: Beam width for generation.
        repetition_penalty: Repetition penalty passed to ``generate``.
        no_repeat_ngram_size: N-gram size that may not repeat in the output.

    Returns:
        The cleaned corrected sentence, or "" when preprocessing leaves nothing.
    """

    preprocessed_text = preprocess_vietnamese_text(text)
    if not preprocessed_text:
        return ""

    # The prompt must carry the same "Fix: " prefix that tokenize_function adds
    # during training. A bare unary "+" on a string raises TypeError.
    input_text_inf = "Fix: " + preprocessed_text
    inputs = tokenizer_to_use(input_text_inf, return_tensors="pt",
                              max_length=max_len_inf, truncation=True, padding=True).to(device_to_use)

    model_to_use.to(device_to_use)
    model_to_use.eval()

    with torch.no_grad():
        outputs = model_to_use.generate(
            **inputs,
            max_length=max_len_inf + 20,
            num_beams=num_beams_inf,
            early_stopping=True,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size
        )
    corrected_text = tokenizer_to_use.decode(outputs[0], skip_special_tokens=True)
    return clean_text(corrected_text)

def evaluate_model(model_to_eval, dataset_raw_eval, tokenizer_eval,
                                  device_eval, max_len_eval, num_beams_eval,
                                  output_dir_path_eval):
    """Run the model over a raw test set, then report and persist the metrics.

    Args:
        model_to_eval: Model used for generation.
        dataset_raw_eval: Iterable of examples with 'incorrect' and 'correct' fields.
        tokenizer_eval: Tokenizer matching the model.
        device_eval: Device used for inference.
        max_len_eval: Maximum input length passed to inference.
        num_beams_eval: Beam width passed to inference.
        output_dir_path_eval: Directory receiving "evaluation_summary.txt" and
            "evaluation_details.csv".

    Returns:
        Dict of metrics: CER, WER and the correction precision/recall/F1 counts.
        CER/WER are -1.0 when their computation fails.
    """
    evaluation_details = []

    model_to_eval.eval()
    model_to_eval.to(device_eval)

    print("\nStarting extended evaluation on raw test data...")
    for example in tqdm(dataset_raw_eval, desc="Extended Evaluation on Raw Test Set"):
        raw_input = example['incorrect']
        raw_reference = example['correct']
        input_text_original = "" if raw_input is None else str(raw_input)
        reference_text_original = "" if raw_reference is None else str(raw_reference).strip()

        predicted_text_eval = correct_text_inference(
            input_text_original, model_to_eval, tokenizer_eval, device_eval,
            max_len_eval, num_beams_eval
        ).strip()

        evaluation_details.append({
            "input": input_text_original,
            "predicted_output": predicted_text_eval,
            "reference_output": reference_text_original
        })

    # Column views over the per-example records, in dataset order.
    inputs_original_eval = [row["input"] for row in evaluation_details]
    predictions_eval = [row["predicted_output"] for row in evaluation_details]
    references_jiwer_eval = [row["reference_output"] for row in evaluation_details]

    results_eval = {}
    # Character / word error rate against the references.
    try:
        if not (predictions_eval or references_jiwer_eval):
            results_eval["CER"], results_eval["WER"] = 0.0, 0.0
        else:
            results_eval["CER"] = cer(references_jiwer_eval, predictions_eval)
            results_eval["WER"] = wer(references_jiwer_eval, predictions_eval)
    except Exception as e_jiwer_eval:
        print(f"Error in CER/WER calculation: {e_jiwer_eval}")
        results_eval["CER"], results_eval["WER"] = -1.0, -1.0

    # Edit-level precision / recall / F1.
    try:
        if not inputs_original_eval:
            print("Warning: Input list is empty, skipping correction P/R/F1 calculation.")
        else:
            results_eval.update(
                calculate_correction_metrics(
                    inputs_original_eval,
                    predictions_eval,
                    references_jiwer_eval
                )
            )
    except Exception as e_prf:
        print(f"Error calculating Correction P/R/F1: {e_prf}")

    print("\n--- Detailed Final Evaluation Results ---")
    for metric_name, score_val in results_eval.items():
        if not isinstance(score_val, float):
            print(f"{metric_name}: {score_val}")
        else:
            print(f"{metric_name}: {score_val:.4f}")

    Path(output_dir_path_eval).mkdir(parents=True, exist_ok=True)
    summary_path_extended = os.path.join(output_dir_path_eval, "evaluation_summary.txt")
    with open(summary_path_extended, "w", encoding='utf-8') as f:
        f.write("Extended Evaluation Metrics:\n")
        for key, value in results_eval.items():
            f.write(f"{key}: {value}\n")
    print(f"Extended evaluation summary saved to {summary_path_extended}")

    if not evaluation_details:
        print("No evaluation cases found or evaluation set was empty.")
    else:
        error_cases_df = pd.DataFrame(evaluation_details)
        error_file_path = os.path.join(output_dir_path_eval, "evaluation_details.csv")
        error_cases_df.to_csv(error_file_path, index=False, encoding='utf-8')
        print(f"Full evaluation details saved to: {error_file_path}")
    return results_eval

In [None]:
# Load the saved model from Drive and score it on the raw test split.
# NOTE(review): this hard-coded path replaces the value chosen by the training cell - confirm it is intended.
model_to_load_path = '/content/drive/MyDrive/DACNTT/models/ViT5_model'
if not (model_to_load_path and Path(model_to_load_path).exists()):
    print(f"Error: model_to_load_path is not a valid path: {model_to_load_path}")
    print("Please ensure the model training completed successfully and saved a model checkpoint.")
else:
    tokenizer_inf = T5Tokenizer.from_pretrained(model_to_load_path)
    model_inf = T5ForConditionalGeneration.from_pretrained(model_to_load_path)
    model_inf.to(device)
    model_inf.eval()

    raw_test_dataset_for_eval = raw_datasets.get("test")
    if not raw_test_dataset_for_eval or raw_test_dataset_for_eval.num_rows <= 0:
        print("Skipping evaluation as data is not available or empty.")
    else:
        evaluate_model(
            model_to_eval=model_inf,
            dataset_raw_eval=raw_test_dataset_for_eval,
            tokenizer_eval=tokenizer_inf,
            device_eval=device,
            max_len_eval=model_cfg.max_length,
            num_beams_eval=train_cfg.generation_num_beams,
            output_dir_path_eval=train_cfg.output_dir
        )


Starting extended evaluation on raw test data...


Extended Evaluation on Raw Test Set:   0%|          | 0/4674 [00:00<?, ?it/s]


--- Detailed Final Evaluation Results ---
CER: 0.0599
WER: 0.2057
precision: 0.3079
recall: 0.5838
f1_score: 0.4031
true positives: 4004
false positives: 9002
false negatives: 2855
Total errors: 6859
Total model edits: 13006
Extended evaluation summary saved to /content/drive/MyDrive/DACNTT/models/model/evaluation_summary.txt
Full evaluation details saved to: /content/drive/MyDrive/DACNTT/models/model/evaluation_details.csv


In [None]:
# Qualitative check: run the saved model on a handful of hand-written noisy sentences.
if not (model_to_load_path and Path(model_to_load_path).exists()):
    print("\nNo valid model path found to load for inference.")
else:
    print(f"\nLoading model from: {model_to_load_path} for inference.")
    try:
        tokenizer_inf = T5Tokenizer.from_pretrained(model_to_load_path)
        model_inf = T5ForConditionalGeneration.from_pretrained(model_to_load_path)
        model_inf.to(device)

        test_sentences = [
            "chương trỉnhnh được páht sóng vào lúc 19h",
            "chúc mừng bạnn đã trúng giải nhất",
            "công nghề thônngg tin đáng phát chiển rất nhanh",
            "tôi mún đi chơi với bạn bè cuỗi tuần này",
            "anh âý là một ngừơi tốt bụng và thân thiện",
            "xin chào tắc cả mọi người",
            "chúng ta càn phãi cố gắng hơn nũa.",
            "Thòi tiết hôm nay rất đepj.",
            "Bây h phải lm s đay.",
            "Học hok tốt thì kho mak đc điểm cao.",
            "cuốn truyện này rất hay và ý nghĩa",
            "tôi nghỉ bạn nên đi du lịch đễ giải tỏa căng thẳng",
            "thay đổi khí hậu ảnh hưởng lớn đến trái đất",
            "Năm nay kinh tế có nhìu triển vọng hơn năm ngoái.",
            "dù rất mệt nhưng anh ấy vẫn cố hoàn thành công việc",
            "hôm nay trời rất đẹp đẹp",
            "mưa trời hôm nay nhiều",
            "",  # Test empty string
            "   ",  # Test string with only spaces
        ]
        print("\n--- Inference Examples ---")
        for sentence in test_sentences:
            corrected = correct_text_inference(
                sentence,
                model_inf,
                tokenizer_inf,
                device,
                model_cfg.max_length,
                num_beams_inf=train_cfg.generation_num_beams,
            )
            print(f"Input    : '{sentence}'\nCorrected: '{corrected}'\n" + "-" * 30)

    except Exception as e:
        print(f"Error loading model or during inference: {e}")


Loading model from: /content/drive/MyDrive/DACNTT/models/model/final_model_checkpoint for inference.

--- Inference Examples ---
Input    : 'chương trỉnhnh được páht sóng vào lúc 19h'
Corrected: 'chương trình được phát sóng vào lúc 19h.'
------------------------------
Input    : 'chúc mừng bạnn đã trúng giải nhất'
Corrected: 'chúc mừng bạn đã trúng giải nhất.'
------------------------------
Input    : 'công nghề thônngg tin đáng phát chiển rất nhanh'
Corrected: 'công nghệ thông tin đáng phát triển rất nhanh.'
------------------------------
Input    : 'tôi mún đi chơi với bạn bè cuỗi tuần này'
Corrected: 'tôi muốn đi chơi với bạn bè cuối tuần này.'
------------------------------
Input    : 'anh âý là một ngừơi tốt bụng và thân thiện'
Corrected: 'anh ấy là một người tốt bụng và thân thiện.'
------------------------------
Input    : 'xin chào tắc cả mọi người'
Corrected: 'Xin chào tất cả mọi người.'
------------------------------
Input    : 'chúng ta càn phãi cố gắng hơn nũa.'
Corrected: