# Burmese Grammar Error Correction (mBART-50)

This notebook fine-tunes **facebook/mbart-large-50-many-to-many-mmt** for **Burmese** grammar error correction.

Key features:
- Robust CSV loading with multiple encodings (`utf-8-sig`, `utf-8`, `utf-16`, `cp932`).
- Auto-detection of source/target columns (incorrect → correct).
- Clean preprocessing & safe decoding helpers.
- Proper **mBART-50 language tags** for Burmese (`my_MM`).
- Conservative yet effective training args with warmup, scheduler, weight decay, early best-model loading.
- Multiple metrics: BLEU, chrF, Exact Match, WER, CER.


In [1]:
# %%capture
!pip -q install -U transformers datasets evaluate sacrebleu jiwer accelerate sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m99.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m85.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m90.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os, math, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch

import evaluate
import sacrebleu
from jiwer import cer, wer

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
print('Torch:', torch.__version__)
import transformers, datasets
print('Transformers:', transformers.__version__)
print('Datasets:', datasets.__version__)

Torch: 2.8.0+cu126
Transformers: 4.56.0
Datasets: 4.0.0


In [3]:
# ===================== Config =====================
DATA_CSV   = Path("testdata.csv")  # <-- change if needed
OUTPUT_DIR = Path("./out-mbart50-gec")
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"  # mBART-50

# Language code for Burmese (mBART-50 uses language tags)
LANG_CODE  = "my_MM"

# Sequence lengths
MAX_SRC_LEN = 256
MAX_TGT_LEN = 256

# Training schedule
SEED       = 42
EPOCHS     = 20
LR         = 1e-4        # stable for fine-tuning; consider 5e-5 for smaller data
BATCH_SIZE = 4           # per device
GRAD_ACCUM = 4           # effective batch size = 16
LOG_STEPS  = 50
SAVE_STEPS = 200

# Generation
NUM_BEAMS  = 5
GEN_MAXLEN = MAX_TGT_LEN
TASK_PREFIX = "fix: "
# ==================================================

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
assert DATA_CSV.exists(), f"Dataset not found at {DATA_CSV}. Upload it or fix DATA_CSV."
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
print('Config OK - using', MODEL_NAME)

Config OK - using facebook/mbart-large-50-many-to-many-mmt


In [4]:
# Robust CSV loader with encoding fallbacks
def read_csv_smart(path: Path):
    errs = []
    for enc in ["utf-8-sig", "utf-8", "utf-16", "cp932"]:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as e:
            errs.append((enc, str(e)))
    raise RuntimeError(f"Failed to read {path} with tried encodings: {errs}")

df_raw = read_csv_smart(DATA_CSV)
print('Columns detected:', list(df_raw.columns))
print('Total rows:', len(df_raw))
df_raw.head(5)

Columns detected: ['source', 'target']
Total rows: 999


Unnamed: 0,source,target
0,သူအိမ်ကို ထွက်သွားသည်။,သူအိမ်သို့ ထွက်သွားသည်။
1,သူစာသင်ကျောင်းကို သွားသည်။,သူစာသင်ကျောင်းသို့ သွားသည်။
2,ကျောင်းမှ စာမေးပွဲဖြေသည်။,ကျောင်း၌ စာမေးပွဲဖြေသည်။
3,ဆေးရုံကို စောင့်ဆိုင်းသည်။,ဆေးရုံ၌ စောင့်ဆိုင်းသည်။
4,စခန်းမှ ထွက်လာသည်။,စခန်းကို ထွက်လာသည်။


In [5]:
# Auto-detect source/target columns
SRC_CANDS = ['incorrect','error','noisy','input','source','wrong','bad','sentence_in']
TGT_CANDS = ['correct','target','fixed','output','sentence_out','sentence']

def detect_cols(columns):
    lower = {c.lower(): c for c in columns}
    src = None; tgt = None
    for k in SRC_CANDS:
        if k in lower:
            src = lower[k]; break
    for k in TGT_CANDS:
        if k in lower:
            tgt = lower[k]; break
    if src is None and 'Incorrect' in columns:
        src = 'Incorrect'
    if tgt is None and 'Correct' in columns:
        tgt = 'Correct'
    return src, tgt

src_col, tgt_col = detect_cols(df_raw.columns)
if not src_col or not tgt_col:
    raise ValueError(
        f"Could not detect columns. Found: {list(df_raw.columns)}\n"
        f"Please rename columns to e.g. 'Incorrect' and 'Correct'."
    )
print(f"Using columns -> source: '{src_col}', target: '{tgt_col}'")

# Basic cleaning and dedup
df = df_raw[[src_col, tgt_col]].dropna()
df = df.astype({src_col: str, tgt_col: str})
df[src_col] = df[src_col].str.strip()
df[tgt_col] = df[tgt_col].str.strip()
df = df[(df[src_col] != '') & (df[tgt_col] != '')]
df = df[df[src_col] != df[tgt_col]]
df = df.drop_duplicates(subset=[src_col, tgt_col])
print('After cleaning:', len(df))
df.head(5)

Using columns -> source: 'source', target: 'target'
After cleaning: 993


Unnamed: 0,source,target
0,သူအိမ်ကို ထွက်သွားသည်။,သူအိမ်သို့ ထွက်သွားသည်။
1,သူစာသင်ကျောင်းကို သွားသည်။,သူစာသင်ကျောင်းသို့ သွားသည်။
2,ကျောင်းမှ စာမေးပွဲဖြေသည်။,ကျောင်း၌ စာမေးပွဲဖြေသည်။
3,ဆေးရုံကို စောင့်ဆိုင်းသည်။,ဆေးရုံ၌ စောင့်ဆိုင်းသည်။
4,စခန်းမှ ထွက်လာသည်။,စခန်းကို ထွက်လာသည်။


In [6]:
# Train/val/test split and (optional) light bootstrapping for train
from math import floor

n_total = len(df)
n_train = int(n_total * 0.8)
n_val   = int(n_total * 0.1)
n_test  = n_total - n_train - n_val

df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
df_train = df_shuf.iloc[:n_train]
df_val   = df_shuf.iloc[n_train:n_train+n_val]
df_test  = df_shuf.iloc[n_train+n_val:]

print(f"Split sizes -> train: {len(df_train)}, val: {len(df_val)}, test: {len(df_test)}")

# OPTIONAL: modest bootstrapping to stabilize updates on small data
BOOTSTRAP_FACTOR = 2  # keep small to avoid overfitting
if BOOTSTRAP_FACTOR > 1 and len(df_train) > 0:
    df_train = df_train.sample(n=len(df_train) * BOOTSTRAP_FACTOR,
                               replace=True, random_state=SEED).reset_index(drop=True)
    print('Bootstrapped train size:', len(df_train))

ds = DatasetDict({
    'train': Dataset.from_pandas(df_train.reset_index(drop=True)),
    'validation': Dataset.from_pandas(df_val.reset_index(drop=True)),
    'test': Dataset.from_pandas(df_test.reset_index(drop=True)),
})
ds

Split sizes -> train: 794, val: 99, test: 100
Bootstrapped train size: 1588


DatasetDict({
    train: Dataset({
        features: ['source', 'target'],
        num_rows: 1588
    })
    validation: Dataset({
        features: ['source', 'target'],
        num_rows: 99
    })
    test: Dataset({
        features: ['source', 'target'],
        num_rows: 100
    })
})

In [7]:
# Tokenizer + model (mBART-50) with Burmese language tags
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model     = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

tokenizer.src_lang = LANG_CODE
tokenizer.tgt_lang = LANG_CODE
model.config.forced_bos_token_id = tokenizer.lang_code_to_id[LANG_CODE]

# Ensure pad token id is set consistently
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = getattr(tokenizer, 'eos_token_id', None) or 0
model.config.pad_token_id = tokenizer.pad_token_id

VOCAB_SIZE = getattr(tokenizer, 'vocab_size', None) or getattr(model.config, 'vocab_size', None)
assert VOCAB_SIZE is not None, 'Could not determine vocab size.'
print('Tokenizer/model ready. Vocab size:', VOCAB_SIZE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Tokenizer/model ready. Vocab size: 250054


In [8]:
# Safe decoding helpers
def _clamp_ids(arr):
    arr = np.asarray(arr)
    arr = np.where(arr < 0, tokenizer.pad_token_id, arr)
    arr = np.where(arr >= VOCAB_SIZE, tokenizer.pad_token_id, arr)
    return arr.astype('int64')

def safe_batch_decode(batch_ids):
    ids = _clamp_ids(batch_ids)
    return tokenizer.batch_decode(ids, skip_special_tokens=True)

In [9]:
# Preprocessing & tokenization
def preprocess_fn(batch):
    inputs  = [TASK_PREFIX + x for x in batch[src_col]]
    targets = [x for x in batch[tgt_col]]
    model_inputs = tokenizer(inputs, max_length=MAX_SRC_LEN, truncation=True, padding=False)
    labels = tokenizer(targets, max_length=MAX_TGT_LEN, truncation=True, padding=False)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

remove_cols = list(ds['train'].features.keys())
ds_tok = ds.map(preprocess_fn, batched=True, remove_columns=remove_cols)
print(ds_tok)

Map:   0%|          | 0/1588 [00:00<?, ? examples/s]

Map:   0%|          | 0/99 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1588
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 99
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 100
    })
})


In [10]:
# Data collator (dynamic padding)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding='longest')

# Metrics
bleu_metric = evaluate.load('sacrebleu')
def compute_metrics(eval_preds):
    pred_ids, label_ids = eval_preds
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    preds = safe_batch_decode(pred_ids)
    refs  = safe_batch_decode(label_ids)

    bleu_res = bleu_metric.compute(predictions=preds, references=[[r] for r in refs])
    chrf = sacrebleu.corpus_chrf(preds, [refs]).score
    exact = float(np.mean([p.strip() == r.strip() for p, r in zip(preds, refs)]) * 100.0)
    _wer = wer(refs, preds) * 100.0
    _cer = cer(refs, preds) * 100.0

    return {
        'bleu': round(bleu_res['score'], 4),
        'chrf': round(chrf, 4),
        'exact_match': round(exact, 2),
        'wer': round(_wer, 2),
        'cer': round(_cer, 2),
    }

Downloading builder script: 0.00B [00:00, ?B/s]

In [11]:
!pip install -U transformers




In [12]:
# Training arguments (conservative + stable)
SAFE_ARGS = dict(
    output_dir=str(OUTPUT_DIR),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    save_steps=min(SAVE_STEPS, max(50, len(ds_tok['train'])//2)) if len(ds_tok['train']) > 0 else SAVE_STEPS,
    logging_steps=LOG_STEPS,
    save_total_limit=2,
    predict_with_generate=True,
    generation_max_length=GEN_MAXLEN,
    generation_num_beams=NUM_BEAMS,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to='none',
)

# Stability & quality improvements
SAFE_ARGS.update(dict(
    weight_decay=0.01,
    lr_scheduler_type='linear',
    warmup_ratio=0.06,
    eval_strategy='steps',   # ✅ old param name
    eval_steps=LOG_STEPS,
    save_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='chrf',
    greater_is_better=True,
))


print('Using Seq2SeqTrainingArguments with keys:', sorted(SAFE_ARGS.keys()))
training_args = Seq2SeqTrainingArguments(**SAFE_ARGS)

Using Seq2SeqTrainingArguments with keys: ['bf16', 'eval_steps', 'eval_strategy', 'fp16', 'generation_max_length', 'generation_num_beams', 'gradient_accumulation_steps', 'greater_is_better', 'group_by_length', 'learning_rate', 'load_best_model_at_end', 'logging_steps', 'lr_scheduler_type', 'metric_for_best_model', 'num_train_epochs', 'output_dir', 'per_device_eval_batch_size', 'per_device_train_batch_size', 'predict_with_generate', 'report_to', 'save_steps', 'save_strategy', 'save_total_limit', 'warmup_ratio', 'weight_decay']


In [13]:
# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_tok['train'],
    eval_dataset=ds_tok['validation'] if len(ds_tok['validation']) > 0 else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if len(ds_tok['validation']) > 0 else None,
)
print('Trainer ready')

  trainer = Seq2SeqTrainer(


Trainer ready


In [14]:
# Train & save
train_result = trainer.train()
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
print('Training complete. Model saved to', OUTPUT_DIR)

Step,Training Loss,Validation Loss,Bleu,Chrf,Exact Match,Wer,Cer
50,2.2176,0.838215,11.5013,74.574,12.12,83.91,35.81
100,0.4667,0.395513,45.8446,85.9488,39.39,29.34,10.91
150,0.2355,0.310947,63.3348,85.6441,53.54,21.45,30.79
200,0.1711,0.28221,66.2639,87.0002,60.61,18.93,27.57
250,0.1446,0.303058,65.4403,92.6772,62.63,17.67,6.03
300,0.1016,0.320178,62.2139,91.6218,58.59,22.4,8.64
350,0.0929,0.279958,67.7105,91.8322,60.61,19.87,9.06
400,0.0684,0.379857,55.4614,88.6394,55.56,18.93,11.27
450,0.1372,0.267897,61.8952,87.494,57.58,21.77,12.22
500,0.0618,0.273001,58.7938,87.5174,58.59,22.71,11.95


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


Training complete. Model saved to out-mbart50-gec


In [15]:
# Evaluate on validation set + save predictions
if len(ds_tok['validation']) > 0:
    val_metrics = trainer.evaluate(max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
    print('Validation metrics:', val_metrics)
    preds = trainer.predict(ds_tok['validation'], max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
    pred_texts = safe_batch_decode(preds.predictions)
    df_out = df_val.copy()
    df_out['prediction'] = pred_texts
    p = OUTPUT_DIR / 'val_predictions.csv'
    df_out.to_csv(p, index=False, encoding='utf-8-sig')
    print('Saved validation predictions to:', p)
else:
    print('No validation split available.')

Validation metrics: {'eval_loss': 0.2747446298599243, 'eval_bleu': 66.2639, 'eval_chrf': 87.0002, 'eval_exact_match': 60.61, 'eval_wer': 18.93, 'eval_cer': 27.57, 'eval_runtime': 24.8132, 'eval_samples_per_second': 3.99, 'eval_steps_per_second': 1.008, 'epoch': 20.0}
Saved validation predictions to: out-mbart50-gec/val_predictions.csv


In [16]:
# Final test evaluation & predictions
if len(ds_tok['test']) > 0:
    test_metrics = trainer.evaluate(eval_dataset=ds_tok['test'], max_length=GEN_MAXLEN, num_beams=NUM_BEAMS, metric_key_prefix='test')
    print('Test metrics:', test_metrics)
    test_preds = trainer.predict(ds_tok['test'], max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
    df_test_out = df_test.copy()
    df_test_out['prediction'] = safe_batch_decode(test_preds.predictions)
    p = OUTPUT_DIR / 'test_predictions.csv'
    df_test_out.to_csv(p, index=False, encoding='utf-8-sig')
    print('Saved test predictions to:', p)
else:
    print('No test split available.')

Test metrics: {'test_loss': 0.3413468301296234, 'test_bleu': 72.1487, 'test_chrf': 85.06, 'test_exact_match': 59.0, 'test_wer': 18.91, 'test_cer': 40.54, 'test_runtime': 29.3699, 'test_samples_per_second': 3.405, 'test_steps_per_second': 0.851, 'epoch': 20.0}
Saved test predictions to: out-mbart50-gec/test_predictions.csv


In [21]:
# Inference helper (device-aware) for ad-hoc sentences
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device); model.eval()

pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else pad_id

def correct_sentence(text: str, max_new_tokens: int = GEN_MAXLEN) -> str:
    inp = TASK_PREFIX + text.strip()
    enc = tokenizer([inp], return_tensors='pt', truncation=True, max_length=MAX_SRC_LEN)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        out = model.generate(
            **enc,
            num_beams=NUM_BEAMS,
            max_length=max_new_tokens,
            pad_token_id=pad_id,
            eos_token_id=eos_id,
            forced_bos_token_id=tokenizer.lang_code_to_id[LANG_CODE],
            no_repeat_ngram_size=3,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

# Quick smoke test (replace with your sentence)
sample_text = "ဤပန်းကန်ထဲ ဟင်းက မွှေးသောကြောင့် စားချင်စဖွယ်ရှိသည်။"
print('Input :', sample_text)
print('Output:', correct_sentence(sample_text))

Input : ဤပန်းကန်ထဲ ဟင်းက မွှေးသောကြောင့် စားချင်စဖွယ်ရှိသည်။
Output: ဤပန်းကန်ထဲမှဟင်းက မွှေးသောကြောင့် စားချင်စဖွယ်ရှိသည်။
