
# Burmese Grammar Correction on CPU (ByT5-small)

This notebook fine-tunes a small multilingual model (**google/byt5-small**) for Burmese grammar correction on **CPU only**.  
It expects a CSV dataset with two columns (auto-detected): e.g. **Incorrect** and **Correct**.

**What you'll get:**
- Robust CSV loader that auto-detects source/target columns (e.g. `Incorrect` → `Correct`).
- CPU-only training (no GPU required) with sensible defaults.
- Evaluation metrics (BLEU, chrF, exact match, WER, CER).
- Saved validation predictions to a CSV.
- `correct_sentence(text)` function for inference.


In [None]:

# If running locally for the first time, uncomment to install dependencies:
# %pip install -U transformers datasets evaluate sacrebleu jiwer accelerate sentencepiece
!pip install evaluate
!pip install sacrebleu
!pip install jiwer





In [None]:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import math
import random
from pathlib import Path

import pandas as pd
from datasets import Dataset, DatasetDict
import numpy as np

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

import evaluate
import sacrebleu
from jiwer import cer, wer

print("Torch:", torch.__version__)
import transformers, datasets
print("Transformers:", transformers.__version__)
print("Datasets:", datasets.__version__)



Torch: 2.8.0+cu126
Transformers: 4.55.2
Datasets: 4.0.0


In [None]:

# ===================== Config =====================
DATA_CSV   = Path("testdata.csv")   # Your uploaded file
OUTPUT_DIR = Path("./out-mbart50-gec")        # Where to save model & outputs
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"           # small & CPU-friendly

# Sequence lengths (shorter = faster; increase if needed later)
MAX_SRC_LEN = 256
MAX_TGT_LEN = 256

SEED       = 42

# Training schedule (keep small for CPU; increase after a quick sanity run)
EPOCHS     = 5
LR         = 5e-4
BATCH_SIZE = 4          # per-device batch size; if OOM on CPU, lower to 2 or 1
GRAD_ACCUM = 2          # effective batch size ~= BATCH_SIZE * GRAD_ACCUM
LOG_STEPS  = 50
SAVE_STEPS = 200

# Generation settings (used for eval & inference)
NUM_BEAMS  = 4
GEN_MAXLEN = MAX_TGT_LEN
# ==================================================

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
assert DATA_CSV.exists(), f"Dataset not found at {DATA_CSV}. Upload it or fix DATA_CSV."
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)


<torch._C.Generator at 0x7d4a97b0c930>

In [None]:
# --- Robust CSV load (handles common encodings) ---
def _read_csv_smart(path: Path):
    errs = []
    for enc in ["utf-8-sig", "utf-8", "utf-16", "cp932"]:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as e:
            errs.append((enc, str(e)))
    raise RuntimeError(f"Failed to read CSV with common encodings. Errors: {errs}")

df_raw = _read_csv_smart(DATA_CSV)
print("Columns detected:", list(df_raw.columns))
print("Sample rows:", len(df_raw))
display(df_raw.head(10))

Columns detected: ['source', 'target']
Sample rows: 999


Unnamed: 0,source,target
0,သူအိမ်ကို ထွက်သွားသည်။,သူအိမ်သို့ ထွက်သွားသည်။
1,သူစာသင်ကျောင်းကို သွားသည်။,သူစာသင်ကျောင်းသို့ သွားသည်။
2,ကျောင်းမှ စာမေးပွဲဖြေသည်။,ကျောင်း၌ စာမေးပွဲဖြေသည်။
3,ဆေးရုံကို စောင့်ဆိုင်းသည်။,ဆေးရုံ၌ စောင့်ဆိုင်းသည်။
4,စခန်းမှ ထွက်လာသည်။,စခန်းကို ထွက်လာသည်။
5,သူအလုပ်သွားကို ပြင်ဆင်သည်။,သူအလုပ်သွားဖို့ ပြင်ဆင်သည်။
6,သူအိမ်ကို ပြန်လာသည်။,သူအိမ်သို့ ပြန်လာသည်။
7,သူအတန်းမှ ထွက်သွားသည်။,သူအတန်းကို ထွက်သွားသည်။
8,သူဈေးကို သွားသည်။,သူဈေးသို့ သွားသည်။
9,မိဘအိမ်မှ ပြန်လာသည်။,မိဘအိမ်သို့ ပြန်လာသည်။


In [None]:

SRC_CANDS = ['incorrect','error','noisy','input','source','wrong','bad','sentence_in']
TGT_CANDS = ['correct','target','fixed','output','sentence_out','sentence']

def _detect_cols(columns):
    lower = {c.lower(): c for c in columns}
    src = None
    tgt = None
    for k in SRC_CANDS:
        if k in lower:
            src = lower[k]; break
    for k in TGT_CANDS:
        if k in lower:
            tgt = lower[k]; break
    # Fallback common case: 'Incorrect' and 'Correct'
    if src is None and 'Incorrect' in columns:
        src = 'Incorrect'
    if tgt is None and 'Correct' in columns:
        tgt = 'Correct'
    return src, tgt

src_col, tgt_col = _detect_cols(df_raw.columns)
if not src_col or not tgt_col:
    raise ValueError(
        f"Could not detect columns. Found: {list(df_raw.columns)}\n"
        f"Please rename columns to e.g. 'Incorrect' and 'Correct'."
    )
print(f"Using columns -> source: '{src_col}', target: '{tgt_col}'")

df = df_raw[[src_col, tgt_col]].dropna()
df = df.astype({src_col: str, tgt_col: str})
df[src_col] = df[src_col].str.strip()
df[tgt_col] = df[tgt_col].str.strip()
df = df[df[src_col] != ""]

print("After cleaning, rows:", len(df))
display(df.head(3))

Using columns -> source: 'source', target: 'target'
After cleaning, rows: 999


Unnamed: 0,source,target
0,သူအိမ်ကို ထွက်သွားသည်။,သူအိမ်သို့ ထွက်သွားသည်။
1,သူစာသင်ကျောင်းကို သွားသည်။,သူစာသင်ကျောင်းသို့ သွားသည်။
2,ကျောင်းမှ စာမေးပွဲဖြေသည်။,ကျောင်း၌ စာမေးပွဲဖြေသည်။


In [None]:
# Train/Val split

n_total = len(df)
n_train = int(n_total * 0.7)
n_val = int(n_total * 0.15)
n_test = n_total - n_train - n_val
df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

df_train = df_shuf.iloc[:n_train]
df_val   = df_shuf.iloc[n_train:n_train+n_val]
df_test  = df_shuf.iloc[n_train+n_val:]
print(f"Split -> train: {len(df_train)}, val: {len(df_val)}, test: {len(df_test)}")

# Bootstrapping
BOOTSTRAP_FACTOR = 10  # 10x larger training set
df_train_bootstrapped = df_train.sample(
    n=len(df_train) * BOOTSTRAP_FACTOR,
    replace=True,
    random_state=SEED
).reset_index(drop=True)

print(f"Unique samples -> train: {len(df_train)}, val: {len(df_val)}, test: {len(df_test)}")
print(f"Bootstrapped -> train: {len(df_train_bootstrapped)} ({BOOTSTRAP_FACTOR}x), val: {len(df_val)}, test: {len(df_test)}")

ds_train = Dataset.from_pandas(df_train_bootstrapped, preserve_index=False)
ds_val   = Dataset.from_pandas(df_val,   preserve_index=False)
ds_test  = Dataset.from_pandas(df_test,  preserve_index=False)


# Create DatasetDict with all three splits
ds = DatasetDict({
    "train": ds_train,
    "validation": ds_val,
    "test": ds_test})


Split -> train: 699, val: 149, test: 151
Unique samples -> train: 699, val: 149, test: 151
Bootstrapped -> train: 6990 (10x), val: 149, test: 151


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model     = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Helpful prefix to guide the model. You can tweak this (e.g., "fix: " or "gec: ").
TASK_PREFIX = "fix: "

In [None]:
import numpy as np

# Make sure pad_token_id is valid for decoding
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None) or 0
model.config.pad_token_id = tokenizer.pad_token_id

# Get a reliable vocab size (tokenizer first, else model.config)
VOCAB_SIZE = getattr(tokenizer, "vocab_size", None) or getattr(model.config, "vocab_size", None)
assert VOCAB_SIZE is not None, "Could not determine vocab size."

def _clamp_ids(arr):
    """Clamp any out-of-range/negative ids to pad_token_id and return int64."""
    arr = np.asarray(arr)
    arr = np.where(arr < 0, tokenizer.pad_token_id, arr)
    arr = np.where(arr >= VOCAB_SIZE, tokenizer.pad_token_id, arr)
    return arr.astype("int64")

def safe_batch_decode(batch_ids):
    """
    Try normal batch_decode on clamped ids.
    Since mBART50 uses standard tokenization, no need for ByT5 specific decoding.
    """
    ids = _clamp_ids(batch_ids)

    # Use batch_decode directly as mBART50 handles this in standard way
    return tokenizer.batch_decode(ids, skip_special_tokens=True)


In [None]:
def preprocess_fn(batch):
    # Prepare inputs with task prefix, ensuring no None or empty values
    inputs = [TASK_PREFIX + x for x in batch["%s" % src_col] if x and x != ""]
    targets = [x for x in batch["%s" % tgt_col] if x and x != ""]

    # Skip batch if inputs or targets are empty after filtering
    if len(inputs) == 0 or len(targets) == 0:
        return {}

    # Tokenize inputs (standard tokenization for mBART50)
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_SRC_LEN,
        truncation=True,
        padding=False,
    )

    # Tokenize targets (standard tokenization for mBART50)
    labels = tokenizer(
        targets,
        max_length=MAX_TGT_LEN,
        truncation=True,
        padding=False
    )

    # Add labels to model inputs for training
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Ensure no None or empty values are in the dataset before mapping
ds = ds.filter(lambda x: x is not None and x != "" and all(v is not None for v in x.values()))

# Tokenize the dataset using the updated preprocess function
ds_tok = ds.map(preprocess_fn, batched=True, remove_columns=ds["train"].column_names)

# Print the tokenized dataset for verification
print(ds_tok)


Filter:   0%|          | 0/6990 [00:00<?, ? examples/s]

Filter:   0%|          | 0/149 [00:00<?, ? examples/s]

Filter:   0%|          | 0/151 [00:00<?, ? examples/s]

Map:   0%|          | 0/6990 [00:00<?, ? examples/s]

Map:   0%|          | 0/149 [00:00<?, ? examples/s]

Map:   0%|          | 0/151 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 6990
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 149
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 151
    })
})


In [None]:

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest")


In [None]:
import evaluate, sacrebleu
from jiwer import cer, wer

bleu = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    pred_ids, label_ids = eval_preds

    # Replace -100 in labels BEFORE decode
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    preds = safe_batch_decode(pred_ids)
    refs  = safe_batch_decode(label_ids)

    bleu_res = bleu.compute(predictions=preds, references=[[r] for r in refs])
    chrf = sacrebleu.corpus_chrf(preds, [refs]).score
    exact = float(np.mean([p.strip() == r.strip() for p, r in zip(preds, refs)]) * 100.0)
    _wer = wer(refs, preds) * 100.0
    _cer = cer(refs, preds) * 100.0

    return {
        "bleu": round(bleu_res["score"], 4),
        "chrf": round(chrf, 4),
        "exact_match": round(exact, 2),
        "wer": round(_wer, 2),
        "cer": round(_cer, 2),
    }



In [None]:

# Keep arguments conservative to avoid version mismatches across Transformers.
SAFE_ARGS = dict(
    output_dir=str(OUTPUT_DIR),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    save_steps=min(SAVE_STEPS, max(50, len(ds_tok["train"])//2)) if len(ds_tok["train"]) > 0 else SAVE_STEPS,
    logging_steps=LOG_STEPS,
    save_total_limit=2,
    predict_with_generate=True,
    generation_max_length=GEN_MAXLEN,
    generation_num_beams=NUM_BEAMS,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="none",
)

print("Using Seq2SeqTrainingArguments with keys:", sorted(SAFE_ARGS.keys()))
training_args = Seq2SeqTrainingArguments(**SAFE_ARGS)


Using Seq2SeqTrainingArguments with keys: ['bf16', 'fp16', 'generation_max_length', 'generation_num_beams', 'gradient_accumulation_steps', 'group_by_length', 'learning_rate', 'logging_steps', 'num_train_epochs', 'output_dir', 'per_device_eval_batch_size', 'per_device_train_batch_size', 'predict_with_generate', 'report_to', 'save_steps', 'save_total_limit']


In [None]:

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if len(ds_tok["validation"]) > 0 else None,
)


  trainer = Seq2SeqTrainer(


In [64]:

train_result = trainer.train()
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)


Step,Training Loss
50,4.6518
100,3.3867
150,2.769
200,2.6266
250,2.1437
300,2.0675
350,1.7159
400,1.7451
450,1.4932
500,1.4595




('out-mt5-gec/tokenizer_config.json',
 'out-mt5-gec/special_tokens_map.json',
 'out-mt5-gec/sentencepiece.bpe.model',
 'out-mt5-gec/added_tokens.json',
 'out-mt5-gec/tokenizer.json')

In [65]:
print("=== VALIDATION SET SAMPLES ===")
for i in range(min(3, len(df_val))):
    print(f"Val sample {i}:")
    print(f"  Source: {df_val.iloc[i][src_col]}")
    print(f"  Target: {df_val.iloc[i][tgt_col]}")
    print()

print("=== TEST SET SAMPLES ===")
for i in range(min(3, len(df_test))):
    print(f"Test sample {i}:")
    print(f"  Source: {df_test.iloc[i][src_col]}")
    print(f"  Target: {df_test.iloc[i][tgt_col]}")
    print()

# Or see the first few rows:
print("Validation set head:")
print(df_val.head(3))
print("\nTest set head:")
print(df_test.head(3))

=== VALIDATION SET SAMPLES ===
Val sample 0:
  Source: အချုပ်အခြာအာဏာပိုင်က လက်ထောက်ကို စာရင်းစစ်ရန် လွှတ်လိုက်သည်။
  Target: အချုပ်အခြာအာဏာပိုင်က လက်ထောက်ကို စာရင်းစစ်ရန် စေလွှတ်လိုက်သည်။

Val sample 1:
  Source: ဤအစီအစဉ်ကိုလက်ခံမည်ဆို အစည်းအဝေးပိတ်သိမ်းမည်။
  Target: ဤအစီအစဉ်ကိုလက်ခံမည်ဆိုလျှင် အစည်းအဝေးပိတ်သိမ်းမည်။

Val sample 2:
  Source: ဓာတ်ပုံအဖြစ် ကြည့်မည်။
  Target: ဓာတ်ပုံကို ကြည့်မည်။

=== TEST SET SAMPLES ===
Test sample 0:
  Source: အကြောင်းအရာဖြင့် မဖြေရှင်းနိုင်ဘူး။
  Target: အကြောင်းအရာကြောင့် မဖြေရှင်းနိုင်ဘူး။

Test sample 1:
  Source: ကျွန်တော် မပြောမည်။
  Target: ကျွန်တော် မပြောဘူး။

Test sample 2:
  Source: အချုပ်အခြာအာဏာပိုင်က သံတမန်ကို နိုင်ငံခြားသို့ လွှတ်လိုက်သည်။
  Target: အချုပ်အခြာအာဏာပိုင်က သံတမန်ကို နိုင်ငံခြားသို့ စေလွှတ်လိုက်သည်။

Validation set head:
                                                source  \
699  အချုပ်အခြာအာဏာပိုင်က လက်ထောက်ကို စာရင်းစစ်ရန် ...   
700      ဤအစီအစဉ်ကိုလက်ခံမည်ဆို အစည်းအဝေးပိတ်သိမ်းမည်။   
701                            

In [66]:
# Validation
if len(ds_tok["validation"]) > 0:
    metrics = trainer.evaluate(max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
    print("Validation metrics:", metrics)

    preds = trainer.predict(ds_tok["validation"], max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
    pred_texts = safe_batch_decode(preds.predictions)

    df_out = df_val.copy()
    df_out["prediction"] = pred_texts
    out_csv = OUTPUT_DIR / "val_predictions.csv"
    df_out.to_csv(out_csv, index=False, encoding="utf-8-sig")
    print(f"Saved validation predictions to: {out_csv}")
else:
    print("No validation split (dataset too small). Skipping eval.")


Validation metrics: {'eval_loss': 5.769466400146484, 'eval_bleu': 3.1505, 'eval_chrf': 34.1514, 'eval_exact_match': 0.67, 'eval_wer': 87.42, 'eval_cer': 61.62, 'eval_runtime': 20.0677, 'eval_samples_per_second': 7.425, 'eval_steps_per_second': 1.894, 'epoch': 5.0}
Saved validation predictions to: out-mt5-gec/val_predictions.csv


In [None]:
print("\n" + "="*60)
print("FINAL TEST EVALUATION (on completely unseen data):")
print("="*60)

# Evaluate on TEST set (never seen during training)
test_metrics = trainer.evaluate(
    ds_tok["test"],
    max_length=GEN_MAXLEN,
    num_beams=NUM_BEAMS
)
print("TEST metrics (unbiased final results):", test_metrics)

# Get test predictions
test_preds = trainer.predict(ds_tok["test"], max_length=GEN_MAXLEN, num_beams=NUM_BEAMS)
test_pred_texts = safe_batch_decode(test_preds.predictions)

# Save test predictions
df_test_out = df_test.copy()
df_test_out["prediction"] = test_pred_texts
test_out_csv = OUTPUT_DIR / "test_predictions.csv"
df_test_out.to_csv(test_out_csv, index=False, encoding="utf-8-sig")
print(f"Saved TEST predictions to: {test_out_csv}")

In [None]:

def correct_sentence(text: str, max_new_tokens: int = GEN_MAXLEN) -> str:
    """Return a corrected version of the input sentence using the fine-tuned model."""
    model.eval()
    inp = TASK_PREFIX + text.strip()
    enc = tokenizer([inp], return_tensors="pt", truncation=True, max_length=MAX_SRC_LEN)
    with torch.no_grad():
        out = model.generate(
            **enc,
            num_beams=NUM_BEAMS,
            max_length=max_new_tokens,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

# Quick smoke test on first few (after training)
try:
    for i in range(min(3, len(df_val))):
        src = df_val.iloc[i][src_col]
        tgt = df_val.iloc[i][tgt_col]
        pred = correct_sentence(src)
        print(f"Incorrect: {src}\nTarget   : {tgt}\nPred     : {pred}\n{'-'*60}")
except Exception as e:
    print("Inference smoke test skipped:", e)


In [69]:
# ---- Device setup (run once, after loading model/tokenizer) ----
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)                     # move weights to device
model.eval()                         # eval mode for inference

# Some models (e.g., mT5/ByT5/MBART) can be picky about pad/eos
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else pad_id

# ---- Inference helper ----
def correct_sentence(text: str, max_new_tokens: int = GEN_MAXLEN) -> str:
    inp = "fix: " + text.strip()
    # BatchEncoding supports .to(device) to move all tensors together
    enc = tokenizer(
        [inp],
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SRC_LEN,
        padding=True
    ).to(device)

    with torch.no_grad():
        out = model.generate(
            **enc,
            num_beams=NUM_BEAMS,
            # prefer max_new_tokens for “newly generated” length
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_id,
            eos_token_id=eos_id
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

# ---- Quick smoke test (keeps everything on the same device) ----
try:
    n = min(3, len(df_val))
    for i in range(n):
        src = str(df_val.iloc[i][src_col])
        tgt = str(df_val.iloc[i][tgt_col])
        pred = correct_sentence(src)
        print(f"Incorrect: {src}\nTarget   : {tgt}\nPred     : {pred}\n{'-'*60}")
except Exception as e:
    print("Inference smoke test skipped:", repr(e))


Incorrect: အချုပ်အခြာအာဏာပိုင်က လက်ထောက်ကို စာရင်းစစ်ရန် လွှတ်လိုက်သည်။
Target   : အချုပ်အခြာအာဏာပိုင်က လက်ထောက်ကို စာရင်းစစ်ရန် စေလွှတ်လိုက်သည်။
Pred     : အကြီးအကဲက အဖွဲ့သားများကို အစည်းအဝေးသို့ စေလွှတ်လိုက်သည်။
------------------------------------------------------------
Incorrect: ဤအစီအစဉ်ကိုလက်ခံမည်ဆို အစည်းအဝေးပိတ်သိမ်းမည်။
Target   : ဤအစီအစဉ်ကိုလက်ခံမည်ဆိုလျှင် အစည်းအဝေးပိတ်သိမ်းမည်။
Pred     : ဤအစီအစဉ်ကို ကျောင်းအဖြစ် ပြောင်းရွှေ့သွားမည်။
------------------------------------------------------------
Incorrect: ဓာတ်ပုံအဖြစ် ကြည့်မည်။
Target   : ဓာတ်ပုံကို ကြည့်မည်။
Pred     : အရသာခံဖို့ စီစဉ်သည်။
------------------------------------------------------------


In [70]:

# Enter your own sentence to correct (after training completes)
sample_text = "မလာလျှင် "
print(correct_sentence(sample_text))


မနက်မှာ အစားအစာဝယ်သည်။
