In [None]:
"""akkadian_v5b_train.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/13Tup1kPmEkrrMMI7DXSzZ0w361ZP_3Cm

# Akkadian V5b Training (Glossary-Prompt)

Stage A: Publications English doc-level (optional)
Stage B: Sentence-level main training with glossary prompts

## 0. Setup (Colab)

Mount Google Drive if running on Colab.
"""

In [None]:
# Mount Google Drive when running on Colab. This stays best-effort (the
# notebook keeps running either way), but a failed mount is now reported
# instead of being swallowed, because later cells write checkpoints under
# /content/drive.
try:
    from google.colab import drive  # type: ignore

    drive.mount("/content/drive")
except ImportError:
    # google.colab is not importable here, i.e. not running on Colab: nothing to mount.
    pass
except Exception as exc:
    print(f"WARNING: Google Drive mount failed ({exc!r}); outputs may not be persisted.")

In [None]:
"""## 1. Imports & Configuration"""

In [None]:
!pip install -q sacrebleu

In [None]:
from __future__ import annotations

In [None]:
import json
import os
import random
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sacrebleu.metrics import BLEU, CHRF
from transformers import (
    AutoModelForSeq2SeqLM,
    ByT5Tokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    set_seed,
)

In [None]:
# Source-side separator: any run of whitespace and/or hyphens.
SRC_SPLIT_RE = re.compile(r"[\s\-]+")
# Target-side token: an ASCII letter followed by letters, apostrophes or
# hyphens, or alternatively a run of digits.
TGT_TOKEN_RE = re.compile(r"[A-Za-z][A-Za-z'\-]*|\d+")

In [None]:
def tokenize_src(text: str) -> list[str]:
    """Split source text into tokens using ``SRC_SPLIT_RE``.

    Falsy input (``None``, ``""``) yields an empty list. Any other value is
    passed through ``str()`` before splitting, and the empty pieces that
    leading or trailing separators produce are discarded.
    """
    pieces = SRC_SPLIT_RE.split(str(text)) if text else []
    return list(filter(None, pieces))

In [None]:
def tokenize_tgt(text: str) -> list[str]:
    """Extract target-side tokens by collecting every ``TGT_TOKEN_RE`` match.

    Falsy input (``None``, ``""``) yields an empty list; any other value is
    passed through ``str()`` before matching.
    """
    return TGT_TOKEN_RE.findall(str(text)) if text else []

In [None]:
@dataclass
class Config:
    """Run configuration for the two-stage V5b training notebook.

    ``model_size`` selects the ByT5 checkpoint and, when ``output_dir`` is not
    given, the default output directory. ``model_name`` is derived in
    ``__post_init__`` and cannot be passed to the constructor.

    Raises:
        ValueError: if ``model_size`` is not "small", "base" or "large".
            Previously any unrecognised value silently selected the large
            checkpoint, so a typo could start an unintended, much heavier run.
    """

    model_size: str = "small"  # one of "small", "base", "large"
    data_dir: Optional[Path] = None
    output_dir: Optional[Path] = None

    # Stage A (publications)
    use_publications_stage: bool = True
    stage_a_epochs: int = 2
    stage_a_lr: float = 5e-5

    # Stage B (sentence-level)
    stage_b_epochs: int = 8
    stage_b_lr: float = 1e-4

    # Sequence lengths
    max_source_length: int = 256
    max_target_length: int = 256

    # Training
    seed: int = 42
    # A100 40GB: push larger per-device batch and use modest accumulation
    batch_size: int = 16
    gradient_accumulation_steps: int = 2
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Hardware
    fp16: bool = False
    bf16: bool = True
    gradient_checkpointing: bool = False
    dataloader_num_workers: int = 4

    # Glossary prompt
    use_glossary_prompt: bool = True
    glossary_path: Optional[Path] = None
    glossary_max_items: int = 8
    glossary_drop_prob_train: float = 0.5
    glossary_drop_prob_eval: float = 0.0

    # Glossary build params (if file missing)
    glossary_min_src_count: int = 5
    glossary_min_pair_count: int = 2
    glossary_min_score: float = 0.15
    glossary_max_targets: int = 2
    glossary_min_src_len: int = 2
    glossary_min_tgt_len: int = 2

    # Model-specific
    model_name: str = field(init=False)

    def __post_init__(self):
        """Derive ``model_name`` and the default ``output_dir`` from ``model_size``."""
        checkpoints = {
            "small": "google/byt5-small",
            "base": "google/byt5-base",
            "large": "google/byt5-large",
        }
        if self.model_size not in checkpoints:
            raise ValueError(
                f"Unknown model_size {self.model_size!r}; expected one of {sorted(checkpoints)}"
            )
        self.model_name = checkpoints[self.model_size]
        # An explicitly supplied output_dir always wins over the per-size default.
        if self.output_dir is None:
            self.output_dir = Path(f"/content/drive/MyDrive/akkadian/v5b-{self.model_size}")

In [None]:
def resolve_data_dir() -> Path:
    """Locate the V5b data directory, trying known locations in priority order.

    Order: the ``V5B_DATA_DIR`` environment variable (if it names an existing
    path), Google Drive folders that contain ``v5_sentence_train.csv``, the
    relative folders ``data/v5b`` then ``data/v5``, and finally any directory
    directly under ``/kaggle/input`` that contains ``v5_sentence_train.csv``.

    Raises:
        FileNotFoundError: if none of the locations qualifies.
    """
    marker = "v5_sentence_train.csv"

    override = os.environ.get("V5B_DATA_DIR")
    if override and Path(override).exists():
        return Path(override)

    # Colab common locations (Google Drive); these must contain the training CSV.
    drive_dirs = (
        Path("/content/drive/MyDrive/akkadian/data/v5b"),
        Path("/content/drive/MyDrive/akkadian/v5b"),
        Path("/content/drive/MyDrive/data/v5b"),
        Path("/content/drive/MyDrive/v5b"),
    )
    for drive_dir in drive_dirs:
        if (drive_dir / marker).exists():
            return drive_dir

    # Relative folders only need to exist; their contents are not inspected here.
    for relative_dir in (Path("data/v5b"), Path("data/v5")):
        if relative_dir.exists():
            return relative_dir

    kaggle_root = Path("/kaggle/input")
    if kaggle_root.exists():
        match = next((d for d in kaggle_root.iterdir() if (d / marker).exists()), None)
        if match is not None:
            return match

    raise FileNotFoundError(
        "V5b data directory not found. Set V5B_DATA_DIR or place data/v5b. "
        "For Colab, put data in /content/drive/MyDrive/akkadian/data/v5b (or similar)."
    )

In [None]:
def resolve_glossary_path(data_dir: Path) -> Optional[Path]:
    """Return the first existing glossary JSON file, or ``None`` if there is none.

    ``data_dir / "v5b_glossary.json"`` is tried first, then the relative path
    ``data/v5b/v5b_glossary.json``.
    """
    search_order = (
        data_dir / "v5b_glossary.json",
        Path("data/v5b/v5b_glossary.json"),
    )
    return next((candidate for candidate in search_order if candidate.exists()), None)

In [None]:
# Build the run configuration for the "small" preset, then fill in the paths
# that depend on the environment the notebook runs in.
CFG = Config(model_size="small")
CFG.data_dir = resolve_data_dir()
CFG.glossary_path = resolve_glossary_path(CFG.data_dir)  # None when no glossary file exists
# Create the output directory (and any missing parents) before training starts.
CFG.output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Print a run banner: model size, resolved paths, model name and CUDA availability.
print("=" * 60)
print(f"üöÄ Akkadian V5b Training: {CFG.model_size.upper()}")
print("=" * 60)
print(f"üìÅ Data: {CFG.data_dir}")
print(f"üìÅ Output: {CFG.output_dir}")
print(f"ü§ñ Model: {CFG.model_name}")
print(f"üéÆ CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query the device name when CUDA reports as available.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
print("=" * 60)

In [None]:
# Seed the run from CFG.seed using the set_seed helper imported from transformers.
set_seed(CFG.seed)

In [None]:
"""## 2. Helpers"""

In [None]:
def load_pairs(path: Path) -> pd.DataFrame:
    """Read a CSV of parallel text and keep only rows with both sides present.

    Args:
        path: CSV file that must provide ``src`` and ``tgt`` columns.

    Returns:
        The frame with rows missing ``src`` or ``tgt`` removed and the index
        renumbered from zero. Other columns are kept as read.

    Raises:
        ValueError: if either required column is absent.
    """
    frame = pd.read_csv(path)
    required = {"src", "tgt"}
    if not required <= set(frame.columns):
        raise ValueError(f"Missing src/tgt columns: {path}")
    complete_rows = frame.dropna(subset=["src", "tgt"])
    return complete_rows.reset_index(drop=True)

In [None]:
def build_glossary_from_df(df: pd.DataFrame) -> dict[str, list[str]]:
    """Derive a source-token -> target-words glossary from co-occurrence counts.

    Each row contributes the *set* of its source tokens and the *set* of its
    target tokens (after the ``CFG.glossary_min_*_len`` length filters), so a
    token counts at most once per row. A target word is kept for a source
    token when the pair co-occurs in at least ``CFG.glossary_min_pair_count``
    rows and ``pair_rows / source_rows`` reaches ``CFG.glossary_min_score``.
    Source tokens seen in fewer than ``CFG.glossary_min_src_count`` rows are
    skipped entirely.

    Returns:
        Mapping from source token to at most ``CFG.glossary_max_targets``
        target words, best first (score, then pair count, then alphabetical).
    """
    rows_with_src: Counter[str] = Counter()
    pair_rows: dict[str, Counter[str]] = defaultdict(Counter)

    for src_text, tgt_text in zip(df["src"], df["tgt"]):
        src_vocab = {w for w in tokenize_src(src_text) if len(w) >= CFG.glossary_min_src_len}
        tgt_vocab = {w for w in tokenize_tgt(tgt_text) if len(w) >= CFG.glossary_min_tgt_len}
        # Rows where either side has no usable token contribute nothing.
        if not (src_vocab and tgt_vocab):
            continue
        for src_word in src_vocab:
            rows_with_src[src_word] += 1
            for tgt_word in tgt_vocab:
                pair_rows[src_word][tgt_word] += 1

    result: dict[str, list[str]] = {}
    for src_word, n_rows in rows_with_src.items():
        if n_rows < CFG.glossary_min_src_count:
            continue
        ranked = sorted(
            (
                (hits / n_rows, hits, tgt_word)
                for tgt_word, hits in pair_rows[src_word].items()
                if hits >= CFG.glossary_min_pair_count
                and hits / n_rows >= CFG.glossary_min_score
            ),
            # Highest score first, then highest pair count, then alphabetical.
            key=lambda entry: (-entry[0], -entry[1], entry[2]),
        )
        if ranked:
            result[src_word] = [word for _, _, word in ranked[: CFG.glossary_max_targets]]

    return result

In [None]:
def load_glossary(path: Path) -> dict[str, list[str]]:
    """Load a glossary from a UTF-8 JSON object.

    Every value of the top-level object is converted with ``list()``, so the
    returned mapping always holds fresh list objects.
    """
    raw = json.loads(path.read_text(encoding="utf-8"))
    glossary = {}
    for src_token, targets in raw.items():
        glossary[src_token] = list(targets)
    return glossary

In [None]:
def build_glossary_prompt(
    src: str,
    glossary: dict[str, list[str]] | None,
    max_items: int,
    drop_prob: float,
    rng: random.Random,
) -> str:
    if not glossary:
        return src
    if drop_prob > 0 and rng.random() < drop_prob:
        return src

    items: list[str] = []
    used = set()
    for tok in tokenize_src(src):
        if tok in used:
            continue
        tgts = glossary.get(tok)
        if not tgts:
            continue
        tgt = tgts[0]
        items.append(f"{tok}={tgt}")
        used.add(tok)
        if len(items) >= max_items:
            break

    if not items:
        return src

    return "GLOSSARY: " + "; ".join(items) + " ||| " + src

In [None]:
def apply_glossary(df: pd.DataFrame, glossary: dict[str, list[str]], drop_prob: float, seed: int) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``src_aug`` column of glossary-prompted sources.

    One ``random.Random(seed)`` instance is shared across all rows, so the
    drop decisions are reproducible for a given seed and row order. The input
    frame is left untouched.
    """
    rng = random.Random(seed)
    augmented = df.copy()
    prompts = []
    for source_text in augmented["src"].tolist():
        prompts.append(
            build_glossary_prompt(source_text, glossary, CFG.glossary_max_items, drop_prob, rng)
        )
    augmented["src_aug"] = prompts
    return augmented

In [None]:
def build_compute_metrics(tokenizer):
    """Create the ``compute_metrics`` callback used by the trainer.

    Args:
        tokenizer: Tokenizer used to decode predictions and labels; its
            ``pad_token_id`` replaces the -100 placeholders before decoding.

    Returns:
        A function mapping ``(predictions, labels)`` to a dict with corpus
        ``bleu``, ``chrf`` (chrF with ``word_order=2``) and ``geo_mean``, the
        geometric mean of the two (0.0 if either score is not positive).
    """
    bleu = BLEU()
    chrf = CHRF(word_order=2)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # -100 is not a valid token id, so swap it for the pad id before decoding.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = [
            p.strip() for p in tokenizer.batch_decode(predictions, skip_special_tokens=True)
        ]
        decoded_labels = [
            l.strip() for l in tokenizer.batch_decode(labels, skip_special_tokens=True)
        ]

        # sacrebleu's corpus_score takes a list of reference *streams*, each
        # parallel to the hypotheses. With one reference per sentence that is
        # a single stream: [[ref_1, ..., ref_N]]. The earlier
        # [[ref_1], ..., [ref_N]] layout was read as N streams of length one,
        # so only the first hypothesis was actually scored.
        references = [decoded_labels]

        bleu_score = bleu.corpus_score(decoded_preds, references).score
        chrf_score = chrf.corpus_score(decoded_preds, references).score
        geo_mean = np.sqrt(bleu_score * chrf_score) if bleu_score > 0 and chrf_score > 0 else 0.0

        return {"bleu": bleu_score, "chrf": chrf_score, "geo_mean": geo_mean}

    return compute_metrics

In [None]:
class LogCallback(TrainerCallback):
    """Console progress reporter: epoch banner, mean logged loss, eval scores."""

    def __init__(self, label: str):
        self.label = label
        self.epoch = 0
        self.losses = []

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Reset the per-epoch loss buffer and print the epoch banner."""
        # A missing or zero state.epoch is treated as epoch index 0.
        self.epoch = int(state.epoch or 0)
        self.losses = []
        rule = "=" * 60
        print(f"\n{rule}\nüìä {self.label} Epoch {self.epoch + 1}/{args.num_train_epochs}\n{rule}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Remember every logged training loss of the current epoch."""
        if not logs or "loss" not in logs:
            return
        self.losses.append(logs["loss"])

    def on_epoch_end(self, args, state, control, **kwargs):
        """Print the mean of the losses logged during this epoch, if any."""
        if not self.losses:
            return
        mean_loss = sum(self.losses) / len(self.losses)
        print(f"\nüìâ {self.label} Train Loss: {mean_loss:.4f}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Print BLEU, chrF and their geometric mean (0 when a key is absent)."""
        if not metrics:
            return
        bar = "‚îÄ" * 40
        print(f"\n{bar}\nüìà Validation ({self.label})\n{bar}")
        print(f"   BLEU: {metrics.get('eval_bleu', 0):.2f}")
        print(f"   chrF: {metrics.get('eval_chrf', 0):.2f}")
        print(f"   Geo:  {metrics.get('eval_geo_mean', 0):.2f}\n{bar}")

In [None]:
class HistoryCallback(TrainerCallback):
    """Record logged train losses and eval scores so they can be plotted later."""

    def __init__(self, label: str):
        self.label = label
        self.train_steps: list[int] = []
        self.train_losses: list[float] = []
        self.eval_epochs: list[float] = []
        self.eval_bleu: list[float] = []
        self.eval_chrf: list[float] = []
        self.eval_geo: list[float] = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store (global_step, loss) for every log entry that carries a loss."""
        if not logs or "loss" not in logs:
            return
        self.train_steps.append(int(state.global_step))
        self.train_losses.append(float(logs["loss"]))

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Store the epoch and the three eval scores (0.0 when a key is absent)."""
        if not metrics:
            return
        self.eval_epochs.append(0.0 if state.epoch is None else float(state.epoch))
        for series, key in (
            (self.eval_bleu, "eval_bleu"),
            (self.eval_chrf, "eval_chrf"),
            (self.eval_geo, "eval_geo_mean"),
        ):
            series.append(float(metrics.get(key, 0.0)))

    def plot(self):
        """Draw train loss (left) and eval metrics (right); hide a panel with no data."""
        has_train = bool(self.train_steps)
        has_eval = bool(self.eval_epochs)
        if not (has_train or has_eval):
            return

        fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 4))

        if has_train:
            loss_ax.plot(self.train_steps, self.train_losses, label="train_loss")
            loss_ax.set_title(f"{self.label} Train Loss")
            loss_ax.set_xlabel("step")
            loss_ax.set_ylabel("loss")
            loss_ax.grid(True, alpha=0.3)
        else:
            loss_ax.set_visible(False)

        if has_eval:
            for series, name in (
                (self.eval_bleu, "BLEU"),
                (self.eval_chrf, "chrF"),
                (self.eval_geo, "GeoMean"),
            ):
                metric_ax.plot(self.eval_epochs, series, label=name)
            metric_ax.set_title(f"{self.label} Eval Metrics")
            metric_ax.set_xlabel("epoch")
            metric_ax.set_ylabel("score")
            metric_ax.legend()
            metric_ax.grid(True, alpha=0.3)
        else:
            metric_ax.set_visible(False)

        fig.tight_layout()
        plt.show()

In [None]:
"""## 3. Load Data"""

In [None]:
# Announce the start of dataset loading.
print("üìñ Loading V5b datasets...")

In [None]:
# Sentence-level train/validation CSVs are looked up directly inside the data directory.
sentence_train_path = CFG.data_dir / "v5_sentence_train.csv"
sentence_val_path = CFG.data_dir / "v5_sentence_val.csv"

In [None]:
# Fail fast when either sentence-level split is missing.
if not sentence_train_path.exists() or not sentence_val_path.exists():
    raise FileNotFoundError("v5_sentence_train/val.csv not found in data dir")

In [None]:
# Load both splits; load_pairs drops rows with a missing src or tgt.
sent_train_df = load_pairs(sentence_train_path)
sent_val_df = load_pairs(sentence_val_path)

In [None]:
# The publications doc-level pairs are optional: pub_df stays None when the file is absent.
pub_pairs_path = CFG.data_dir / "v5_publications_doc_pairs.csv"
pub_df = load_pairs(pub_pairs_path) if pub_pairs_path.exists() else None

In [None]:
# Report the row counts of the loaded datasets.
print(f"   Sentence train: {len(sent_train_df):,}")
print(f"   Sentence val: {len(sent_val_df):,}")
if pub_df is not None:
    print(f"   Publications doc pairs: {len(pub_df):,}")
else:
    print("   Publications doc pairs: not found")

In [None]:
"""## 4. Glossary Prompt (Stage B)"""

In [None]:
# Attach glossary prompts to the sentence-level data (adds a "src_aug" column).
# The glossary is loaded from CFG.glossary_path when that file exists and is
# otherwise built from the training split only.
if CFG.use_glossary_prompt:
    glossary = None
    if CFG.glossary_path and CFG.glossary_path.exists():
        print(f"üß† Loading glossary: {CFG.glossary_path}")
        glossary = load_glossary(CFG.glossary_path)
    else:
        print("üß† Building glossary from train set (file not found)")
        glossary = build_glossary_from_df(sent_train_df)

    print(f"   Glossary size: {len(glossary):,}")

    # Training rows randomly lose their prompt with probability
    # CFG.glossary_drop_prob_train.
    sent_train_df = apply_glossary(
        sent_train_df,
        glossary,
        drop_prob=CFG.glossary_drop_prob_train,
        seed=CFG.seed,
    )
    # Validation rows use CFG.glossary_drop_prob_eval instead.
    sent_val_df = apply_glossary(
        sent_val_df,
        glossary,
        drop_prob=CFG.glossary_drop_prob_eval,
        seed=CFG.seed,
    )
else:
    print("üß† Glossary prompt disabled")

In [None]:
"""## 5. Model Setup"""

In [None]:
# Announce which pretrained checkpoint is about to be loaded.
print(f"\nü§ñ Loading model: {CFG.model_name}")

In [None]:
# Load the tokenizer and the seq2seq model for the checkpoint named by CFG.model_name.
tokenizer = ByT5Tokenizer.from_pretrained(CFG.model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name)

In [None]:
# Show the tokenizer length next to the model's configured vocab size, plus the parameter count.
print(f"   Tokenizer: {len(tokenizer)}, Model vocab: {model.config.vocab_size}")
print(f"   Params: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Optional: turn on gradient checkpointing when the config asks for it.
if CFG.gradient_checkpointing:
    model.gradient_checkpointing_enable()
    print("   ‚úÖ Gradient checkpointing enabled")

In [None]:
"""## 6. Tokenization"""

In [None]:
def tokenize_fn(examples):
    """Tokenize a batch: sources become model inputs, targets become ``labels``.

    The glossary-augmented ``src_aug`` column is used when the batch has one,
    otherwise plain ``src``. Both sides are truncated to their configured
    maximum length and left unpadded.
    """
    source_column = "src_aug" if "src_aug" in examples else "src"
    shared_options = dict(truncation=True, padding=False)

    encoded = tokenizer(
        examples[source_column],
        max_length=CFG.max_source_length,
        **shared_options,
    )
    target_encoding = tokenizer(
        examples["tgt"],
        max_length=CFG.max_target_length,
        **shared_options,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [None]:
def to_dataset(df: pd.DataFrame) -> Dataset:
    """Convert a pairs frame into a tokenized ``Dataset``.

    Only ``src``, ``tgt`` and (when present) ``src_aug`` are carried over, and
    those text columns are removed again once ``tokenize_fn`` has been mapped
    over the data in batches.
    """
    text_columns = ["src", "tgt"] + (["src_aug"] if "src_aug" in df.columns else [])
    raw_dataset = Dataset.from_pandas(df[text_columns])
    return raw_dataset.map(tokenize_fn, batched=True, remove_columns=text_columns)

In [None]:
"""## 7. Stage A: Publications Doc-Level (optional)"""

In [None]:
# Stage A: optional warm-up on publications doc-level pairs. It runs only when
# enabled in the config and when the publications file was found and non-empty.
if CFG.use_publications_stage and pub_df is not None and len(pub_df) > 0:
    print("\nüèÅ Stage A: Publications doc-level")
    pub_train_ds = to_dataset(pub_df)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
    history_a = HistoryCallback("Stage A")

    # No evaluation in this stage: train only, checkpoint once per epoch.
    stage_a_args = dict(
        output_dir=str(CFG.output_dir / "stage_a_checkpoints"),
        num_train_epochs=CFG.stage_a_epochs,
        per_device_train_batch_size=CFG.batch_size,
        gradient_accumulation_steps=CFG.gradient_accumulation_steps,
        learning_rate=CFG.stage_a_lr,
        weight_decay=CFG.weight_decay,
        warmup_ratio=CFG.warmup_ratio,
        max_grad_norm=CFG.max_grad_norm,
        fp16=CFG.fp16,
        bf16=CFG.bf16,
        evaluation_strategy="no",
        save_strategy="epoch",
        save_total_limit=1,
        predict_with_generate=False,
        dataloader_num_workers=CFG.dataloader_num_workers,
        logging_steps=50,
        report_to="none",
        seed=CFG.seed,
    )

    # transformers renamed `evaluation_strategy` to `eval_strategy`; try the
    # old spelling first and fall back to the new one.
    try:
        training_args = Seq2SeqTrainingArguments(**stage_a_args)
    except TypeError:
        stage_a_args["eval_strategy"] = stage_a_args.pop("evaluation_strategy")
        training_args = Seq2SeqTrainingArguments(**stage_a_args)

    stage_a_trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=pub_train_ds,
        data_collator=data_collator,
        callbacks=[LogCallback("Stage A"), history_a],
    )
    # Same compatibility idea for the trainer: `processing_class` is the newer
    # keyword, `tokenizer` the older one. Without this fallback the cell only
    # runs on releases that already know `processing_class`, even though the
    # arguments above accept both spellings.
    try:
        trainer = Seq2SeqTrainer(processing_class=tokenizer, **stage_a_trainer_kwargs)
    except TypeError:
        trainer = Seq2SeqTrainer(tokenizer=tokenizer, **stage_a_trainer_kwargs)

    trainer.train()
    history_a.plot()
else:
    print("\n‚è≠Ô∏è  Stage A skipped (no publications data or disabled)")

In [None]:
"""## 8. Stage B: Sentence-Level Main Training"""

In [None]:
# Announce the start of the sentence-level stage.
print("\nüèÅ Stage B: Sentence-level training")

In [None]:
# Tokenize both sentence-level splits; to_dataset uses "src_aug" when that column exists.
sent_train_ds = to_dataset(sent_train_df)
sent_val_ds = to_dataset(sent_val_df)

In [None]:
# Collator for Stage B; padding is done here, since tokenize_fn leaves sequences unpadded.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

In [None]:
# Stage B training arguments, collected in a plain dict so the
# evaluation-strategy keyword can be renamed if this transformers version
# rejects the old spelling.
stage_b_args = dict(
    output_dir=str(CFG.output_dir / "stage_b_checkpoints"),
    num_train_epochs=CFG.stage_b_epochs,
    per_device_train_batch_size=CFG.batch_size,
    per_device_eval_batch_size=CFG.batch_size * 2,  # eval batch is twice the train batch
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    learning_rate=CFG.stage_b_lr,
    weight_decay=CFG.weight_decay,
    warmup_ratio=CFG.warmup_ratio,
    max_grad_norm=CFG.max_grad_norm,
    fp16=CFG.fp16,
    bf16=CFG.bf16,
    # Evaluate and checkpoint once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # Best checkpoint = highest eval_geo_mean (the "geo_mean" key from compute_metrics).
    load_best_model_at_end=True,
    metric_for_best_model="eval_geo_mean",
    greater_is_better=True,
    # Generate during evaluation so BLEU/chrF can be computed on decoded text.
    predict_with_generate=True,
    generation_max_length=CFG.max_target_length,
    dataloader_num_workers=CFG.dataloader_num_workers,
    logging_steps=50,
    report_to="none",
    seed=CFG.seed,
)

In [None]:
# Try the `evaluation_strategy` keyword first; on TypeError retry with the
# same value under `eval_strategy`.
try:
    training_args = Seq2SeqTrainingArguments(**stage_b_args)
except TypeError:
    stage_b_args["eval_strategy"] = stage_b_args.pop("evaluation_strategy")
    training_args = Seq2SeqTrainingArguments(**stage_b_args)

In [None]:
# Build the Stage B trainer with both console logging and history collection.
history_b = HistoryCallback("Stage B")
stage_b_trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=sent_train_ds,
    eval_dataset=sent_val_ds,
    data_collator=data_collator,
    compute_metrics=build_compute_metrics(tokenizer),
    callbacks=[LogCallback("Stage B"), history_b],
)
# `processing_class` is the newer keyword for passing the tokenizer and
# `tokenizer` the older one. Fall back on TypeError so this cell runs on the
# same range of transformers versions as the training-arguments cell, which
# already accepts both spellings of its renamed keyword.
try:
    trainer = Seq2SeqTrainer(processing_class=tokenizer, **stage_b_trainer_kwargs)
except TypeError:
    trainer = Seq2SeqTrainer(tokenizer=tokenizer, **stage_b_trainer_kwargs)

In [None]:
# Run Stage B training, then plot the collected loss and eval curves.
trainer.train()
history_b.plot()

In [None]:
"""## 9. Save Model"""

In [None]:
# Save the trainer's model and the tokenizer side by side under <output_dir>/model.
model_dir = CFG.output_dir / "model"
trainer.save_model(str(model_dir))
tokenizer.save_pretrained(str(model_dir))
print(f"\nüíæ Saved: {model_dir}")

In [None]:
# One last evaluation pass with the trainer's current model; missing metric keys print as 0.
results = trainer.evaluate()
print(f"\nüìà Final: BLEU={results.get('eval_bleu',0):.2f}, chrF={results.get('eval_chrf',0):.2f}, Geo={results.get('eval_geo_mean',0):.2f}")

In [None]:
# Closing banner.
print(f"\n{'='*60}\n‚úÖ V5b Training Complete!\n{'='*60}")