# EN-AR Model Training

## Objective
- Train a bidirectional EN <-> AR encoder-decoder model from random initialization.
- Use the cleaned combined dataset exported by Notebook 01.

## Scope
- Notebook-first, micro-step implementation (1-2 short cells per step).


In [1]:
import os
os.environ["PATH"] = r"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.44.35207\bin\Hostx64\x64" + os.pathsep + os.environ.get("PATH", "")

In [2]:
# Setup: imports and training constants
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

candidate_roots = [Path.cwd(), Path.cwd().parent]
PROJECT_ROOT = next((r for r in candidate_roots if (r / "artifacts").exists()), Path.cwd())
DATA_PATH = PROJECT_ROOT / "artifacts" / "eda" / "final_cleaned_combined_dataset.parquet"

@dataclass
class TrainConfig:
    max_seq_len: int = 128
    vocab_size: int = 32_000
    train_ratio: float = 0.90
    val_ratio: float = 0.05
    test_ratio: float = 0.05

config = TrainConfig()
print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset path: {DATA_PATH}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(config)


Project root: c:\My Projects\en-ar-translation
Dataset path: c:\My Projects\en-ar-translation\artifacts\eda\final_cleaned_combined_dataset.parquet
CUDA available: True
TrainConfig(max_seq_len=128, vocab_size=32000, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05)


In [3]:
# Micro-step 2: load cleaned dataset and validate schema
assert DATA_PATH.exists(), f"Cleaned dataset not found: {DATA_PATH}"
df = pd.read_parquet(DATA_PATH)

required_columns = ["en", "ar"]
missing = [c for c in required_columns if c not in df.columns]
assert not missing, f"Missing required columns: {missing}"

print(f"Loaded dataset: {DATA_PATH}")
print(f"Shape: {df.shape}")
print(f"Columns: {list(df.columns)}")


Loaded dataset: c:\My Projects\en-ar-translation\artifacts\eda\final_cleaned_combined_dataset.parquet
Shape: (827576, 2)
Columns: ['en', 'ar']


In [4]:
# Micro-step 3: remove invalid/empty rows and report before/after
rows_before = len(df)

df = df.dropna(subset=["en", "ar"]).copy()
df["en"] = df["en"].astype(str).str.strip()
df["ar"] = df["ar"].astype(str).str.strip()
df = df[(df["en"] != "") & (df["ar"] != "")].reset_index(drop=True)

rows_after = len(df)
rows_removed = rows_before - rows_after
print(f"Rows before cleaning: {rows_before:,}")
print(f"Rows after cleaning: {rows_after:,}")
print(f"Rows removed: {rows_removed:,}")


Rows before cleaning: 827,576
Rows after cleaning: 827,546
Rows removed: 30


In [5]:
# Quick check: 10 random rows from current dataset
df.sample(n=10)[["en", "ar"]].reset_index(drop=True)

Unnamed: 0,en,ar
0,man with blue shirt standing in a gymnasium,رجل ذو قميص أزرق يقف في صالة رياضية
1,Hart did not run for public office again.,لم تبحث شركة Hart عن مقر عام لها مجددا.
2,uh present day linda ronstadt,في الوقت الحاضر (ليندا رونستد)
3,The current mayor is Leonard Reed.,ليونارد ريد هو عمدة البلدية الحالي.
4,"""It's for Annabelle.""",إنها لأنابيل.
5,thats not your problem,هذه ليست مشكلتك
6,leave me alone,كلا ! دعوني وشأني
7,a large silver filigree pendant on a white bac...,قلادة فضية كبيرة على خلفية بيضاء
8,a large brick building with a blue sign that r...,مبنى كبير من الطوب مع علامة زرقاء التي تقرأ'mo...
9,egyptian president muhammad morsi has started ...,بدا الرييس المصري محمد مرسي رسميا بالتغريد عبر...


In [6]:
# Micro-step 4: deterministic hash split (90/5/5) with leakage guard
assert abs((config.train_ratio + config.val_ratio + config.test_ratio) - 1.0) < 1e-9, "Split ratios must sum to 1.0"

pair_hash = pd.util.hash_pandas_object(df[["en", "ar"]], index=False).astype("uint64")
u = pair_hash / np.float64(2**64)

train_cut = config.train_ratio
val_cut = config.train_ratio + config.val_ratio
df["split"] = np.where(u < train_cut, "train", np.where(u < val_cut, "val", "test"))

leak_count = int((df.groupby(["en", "ar"])["split"].nunique() > 1).sum())
assert leak_count == 0, f"Leakage detected across splits for {leak_count} pairs"

split_counts = df["split"].value_counts().rename_axis("split").reset_index(name="rows")
split_counts["ratio"] = (split_counts["rows"] / len(df)).round(4)
split_counts


Unnamed: 0,split,rows,ratio
0,train,744926,0.9002
1,val,41445,0.0501
2,test,41175,0.0498


In [7]:
# Micro-step 5: build bidirectional rows with direction tokens
required_split_cols = ["en", "ar", "split"]
missing_split_cols = [c for c in required_split_cols if c not in df.columns]
assert not missing_split_cols, f"Missing columns before bidirectional build: {missing_split_cols}"

df_en_to_ar = pd.DataFrame({
    "source_text": "<2ar> " + df["en"],
    "target_text": df["ar"],
    "direction": "en_to_ar",
    "split": df["split"],
})

df_ar_to_en = pd.DataFrame({
    "source_text": "<2en> " + df["ar"],
    "target_text": df["en"],
    "direction": "ar_to_en",
    "split": df["split"],
})

df_bi = pd.concat([df_en_to_ar, df_ar_to_en], ignore_index=True)

print(f"Base rows: {len(df):,}")
print(f"Bidirectional rows: {len(df_bi):,}")

direction_split_counts = (
    df_bi.groupby(["split", "direction"]).size().reset_index(name="rows")
)
direction_split_counts


Base rows: 827,546
Bidirectional rows: 1,655,092


Unnamed: 0,split,direction,rows
0,test,ar_to_en,41175
1,test,en_to_ar,41175
2,train,ar_to_en,744926
3,train,en_to_ar,744926
4,val,ar_to_en,41445
5,val,en_to_ar,41445


In [8]:
# Micro-step 6: create train/val/test views from bidirectional dataset
required_bi_cols = ["source_text", "target_text", "direction", "split"]
missing_bi_cols = [c for c in required_bi_cols if c not in df_bi.columns]
assert not missing_bi_cols, f"Missing columns in df_bi: {missing_bi_cols}"

train_df = df_bi[df_bi["split"] == "train"].reset_index(drop=True)
val_df = df_bi[df_bi["split"] == "val"].reset_index(drop=True)
test_df = df_bi[df_bi["split"] == "test"].reset_index(drop=True)

assert len(train_df) + len(val_df) + len(test_df) == len(df_bi), "Split size mismatch"
assert len(train_df) > 0 and len(val_df) > 0 and len(test_df) > 0, "One split is empty"

print(f"train rows: {len(train_df):,}")
print(f"val rows: {len(val_df):,}")
print(f"test rows: {len(test_df):,}")


train rows: 1,489,852
val rows: 82,890
test rows: 82,350


In [9]:
# Micro-step 7: tokenizer setup (train split only)
try:
    from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
except ImportError as e:
    raise ImportError("`tokenizers` is required. Install with: pip install tokenizers") from e

TOKENIZER_DIR = PROJECT_ROOT / "artifacts" / "tokenizer"
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
TOKENIZER_PATH = TOKENIZER_DIR / "en_ar_bpe_tokenizer.json"

SPECIAL_TOKENS = ["<pad>", "<s>", "</s>", "<unk>", "<2ar>", "<2en>"]

def train_corpus_iterator(df_in):
    for row in df_in.itertuples(index=False):
        yield row.source_text
        yield row.target_text

print(f"Tokenizer output path: {TOKENIZER_PATH}")
print(f"Special tokens: {SPECIAL_TOKENS}")
print("Tokenizer mode: ByteLevel BPE")
print(f"Train rows for tokenizer: {len(train_df):,}")
print(f"Approx lines seen by tokenizer iterator: {len(train_df) * 2:,}")


Tokenizer output path: c:\My Projects\en-ar-translation\artifacts\tokenizer\en_ar_bpe_tokenizer.json
Special tokens: ['<pad>', '<s>', '</s>', '<unk>', '<2ar>', '<2en>']
Tokenizer mode: ByteLevel BPE
Train rows for tokenizer: 1,489,852
Approx lines seen by tokenizer iterator: 2,979,704


In [10]:
# Micro-step 8: train and save shared ByteLevel BPE tokenizer
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=config.vocab_size,
    special_tokens=SPECIAL_TOKENS,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

tokenizer.train_from_iterator(train_corpus_iterator(train_df), trainer=trainer)
tokenizer.save(str(TOKENIZER_PATH))

vocab_size_trained = tokenizer.get_vocab_size()
print(f"Trained tokenizer vocab size: {vocab_size_trained:,}")
print(f"Saved tokenizer to: {TOKENIZER_PATH}")


Trained tokenizer vocab size: 32,000
Saved tokenizer to: c:\My Projects\en-ar-translation\artifacts\tokenizer\en_ar_bpe_tokenizer.json


In [11]:
# Micro-step 9: tokenizer audit (critical checks)
from collections import Counter

# Ensure tokenizer object is available (reload from disk if needed)
if "tokenizer" not in globals():
    tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))

# 1) Special-token integrity
special_token_ids = {tok: tokenizer.token_to_id(tok) for tok in SPECIAL_TOKENS}
missing_special = [tok for tok, tid in special_token_ids.items() if tid is None]
assert not missing_special, f"Missing special tokens in tokenizer vocab: {missing_special}"

id_2ar = special_token_ids["<2ar>"]
id_2en = special_token_ids["<2en>"]
probe_2ar = tokenizer.encode("<2ar> this is a test")
probe_2en = tokenizer.encode("<2en> english direction probe")
assert len(probe_2ar.ids) > 0 and probe_2ar.ids[0] == id_2ar, "<2ar> is not preserved as first token"
assert len(probe_2en.ids) > 0 and probe_2en.ids[0] == id_2en, "<2en> is not preserved as first token"

# 2) Sample-based token-length and truncation audit (real tokenizer lengths)
audit_sample_size = min(20000, len(df_bi))
audit_df = df_bi.sample(n=audit_sample_size, random_state=RANDOM_SEED) if len(df_bi) > audit_sample_size else df_bi

src_lengths = []
tgt_lengths = []
unk_counter = Counter()
unk_id = special_token_ids["<unk>"]

for row in audit_df.itertuples(index=False):
    src_ids = tokenizer.encode(row.source_text).ids
    tgt_ids = tokenizer.encode(row.target_text).ids
    src_lengths.append(len(src_ids))
    tgt_lengths.append(len(tgt_ids))
    unk_counter["src_unk_tokens"] += sum(1 for i in src_ids if i == unk_id)
    unk_counter["tgt_unk_tokens"] += sum(1 for i in tgt_ids if i == unk_id)

src_lengths = np.array(src_lengths, dtype=np.int32)
tgt_lengths = np.array(tgt_lengths, dtype=np.int32)

tokenizer_audit = pd.DataFrame({
    "metric": [
        "sample_rows",
        "trained_vocab_size",
        "src_len_p50", "src_len_p90", "src_len_p95", "src_len_p99",
        "tgt_len_p50", "tgt_len_p90", "tgt_len_p95", "tgt_len_p99",
        "src_trunc_ratio_gt_max_len",
        "tgt_trunc_ratio_gt_max_len",
        "either_trunc_ratio_gt_max_len",
        "src_unk_tokens",
        "tgt_unk_tokens",
    ],
    "value": [
        int(len(audit_df)),
        int(tokenizer.get_vocab_size()),
        int(np.percentile(src_lengths, 50)), int(np.percentile(src_lengths, 90)), int(np.percentile(src_lengths, 95)), int(np.percentile(src_lengths, 99)),
        int(np.percentile(tgt_lengths, 50)), int(np.percentile(tgt_lengths, 90)), int(np.percentile(tgt_lengths, 95)), int(np.percentile(tgt_lengths, 99)),
        float((src_lengths > config.max_seq_len).mean()),
        float((tgt_lengths > config.max_seq_len).mean()),
        float(((src_lengths > config.max_seq_len) | (tgt_lengths > config.max_seq_len)).mean()),
        int(unk_counter["src_unk_tokens"]),
        int(unk_counter["tgt_unk_tokens"]),
    ]
})

print("Special token IDs:", special_token_ids)
print("Direction token probes OK (<2ar>/<2en> preserved).")
tokenizer_audit


Special token IDs: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '<2ar>': 4, '<2en>': 5}
Direction token probes OK (<2ar>/<2en> preserved).


Unnamed: 0,metric,value
0,sample_rows,20000.0
1,trained_vocab_size,32000.0
2,src_len_p50,9.0
3,src_len_p90,18.0
4,src_len_p95,21.0
5,src_len_p99,28.0
6,tgt_len_p50,8.0
7,tgt_len_p90,17.0
8,tgt_len_p95,20.0
9,tgt_len_p99,28.0


In [12]:
# Micro-step 10: build and save Hugging Face compatible fast tokenizer
from transformers import PreTrainedTokenizerFast

HF_TOKENIZER_DIR = TOKENIZER_DIR / "hf_tokenizer"
HF_TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)

hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(TOKENIZER_PATH),
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    additional_special_tokens=["<2ar>", "<2en>"],
)

# Verify essential token ids exist
required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<2ar>", "<2en>"]
token_id_map = {tok: hf_tokenizer.convert_tokens_to_ids(tok) for tok in required_tokens}
missing_ids = [tok for tok, tid in token_id_map.items() if tid is None or tid < 0]
assert not missing_ids, f"Missing token ids in hf_tokenizer: {missing_ids}"

hf_tokenizer.save_pretrained(str(HF_TOKENIZER_DIR))

print(f"Saved HF tokenizer to: {HF_TOKENIZER_DIR}")
print("Token IDs:", token_id_map)
print(f"HF vocab size: {hf_tokenizer.vocab_size:,}")


Saved HF tokenizer to: c:\My Projects\en-ar-translation\artifacts\tokenizer\hf_tokenizer
Token IDs: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '<2ar>': 4, '<2en>': 5}
HF vocab size: 32,000


In [13]:
# Micro-step 11: tokenize train/val/test (truncation only, no static padding)
from datasets import Dataset

def tokenize_batch(batch):
    src = hf_tokenizer(
        batch["source_text"],
        truncation=True,
        max_length=config.max_seq_len,
        padding=False,
    )
    tgt = hf_tokenizer(
        batch["target_text"],
        truncation=True,
        max_length=config.max_seq_len,
        padding=False,
    )
    src["labels"] = tgt["input_ids"]
    return src

train_ds = Dataset.from_pandas(train_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)
val_ds = Dataset.from_pandas(val_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)
test_ds = Dataset.from_pandas(test_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)

train_tok = train_ds.map(tokenize_batch, batched=True, desc="Tokenizing train")
val_tok = val_ds.map(tokenize_batch, batched=True, desc="Tokenizing val")
test_tok = test_ds.map(tokenize_batch, batched=True, desc="Tokenizing test")

print(f"Tokenized train rows: {len(train_tok):,}")
print(f"Tokenized val rows: {len(val_tok):,}")
print(f"Tokenized test rows: {len(test_tok):,}")
print(train_tok[0].keys())


Tokenizing train:   0%|          | 0/1489852 [00:00<?, ? examples/s]

Tokenizing val:   0%|          | 0/82890 [00:00<?, ? examples/s]

Tokenizing test:   0%|          | 0/82350 [00:00<?, ? examples/s]

Tokenized train rows: 1,489,852
Tokenized val rows: 82,890
Tokenized test rows: 82,350
dict_keys(['source_text', 'target_text', 'direction', 'split', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])


In [14]:
# Micro-step 12: dynamic padding collator + one-batch sanity check
from transformers import DataCollatorForSeq2Seq

model_input_cols = ["input_ids", "attention_mask", "labels"]
train_tok_model = train_tok.remove_columns([c for c in train_tok.column_names if c not in model_input_cols])

data_collator = DataCollatorForSeq2Seq(
    tokenizer=hf_tokenizer,
    model=None,
    padding="longest",
    label_pad_token_id=-100,
)

batch_size_probe = min(32, len(train_tok_model))
probe_ds = train_tok_model.shuffle(seed=RANDOM_SEED).select(range(batch_size_probe))
probe_features = [probe_ds[i] for i in range(len(probe_ds))]
probe_batch = data_collator(probe_features)

print(f"Probe batch size: {batch_size_probe}")
print(f"input_ids shape: {tuple(probe_batch['input_ids'].shape)}")
print(f"attention_mask shape: {tuple(probe_batch['attention_mask'].shape)}")
print(f"labels shape: {tuple(probe_batch['labels'].shape)}")
print(f"Dynamic padded source length (batch max): {probe_batch['input_ids'].shape[1]}")
print(f"Dynamic padded target length (batch max): {probe_batch['labels'].shape[1]}")
print(f"Configured truncation cap: {config.max_seq_len}")


Probe batch size: 32
input_ids shape: (32, 74)
attention_mask shape: (32, 74)
labels shape: (32, 37)
Dynamic padded source length (batch max): 74
Dynamic padded target length (batch max): 37
Configured truncation cap: 128


In [15]:
# Micro-step 13: manual audit of random final tokenized examples
sample_n = min(5, len(train_tok))
audit_ds = train_tok.shuffle().select(range(sample_n))

audit_rows = []
for item in audit_ds:
    src_ids = item["input_ids"]
    lbl_ids = item["labels"]
    audit_rows.append({
        "direction": item.get("direction", ""),
        "source_text": item["source_text"],
        "target_text": item["target_text"],
        "source_len": len(src_ids),
        "target_len": len(lbl_ids),
        "source_ids": src_ids,
        "target_ids": lbl_ids,
        "decoded_source_from_ids": hf_tokenizer.decode(src_ids, skip_special_tokens=False),
        "decoded_target_from_ids": hf_tokenizer.decode(lbl_ids, skip_special_tokens=False),
    })

tokenized_audit_df = pd.DataFrame(audit_rows)
tokenized_audit_df


Unnamed: 0,direction,source_text,target_text,source_len,target_len,source_ids,target_ids,decoded_source_from_ids,decoded_target_from_ids
0,ar_to_en,<2en> فقد نفذ مني البروبان تقريبا,im nearly out of propane,7,6,"[5, 3071, 18650, 4714, 1520, 16742, 4348]","[431, 9763, 684, 333, 4298, 2385]",<2en> فقد نفذ مني البروبان تقريبا,im nearly out of propane
1,en_to_ar,<2ar> another glass for the lady,أتريدين كأسا آخر؟,6,4,"[4, 2200, 1849, 434, 295, 5523]","[27912, 17415, 1714, 411]",<2ar> another glass for the lady,أتريدين كأسا آخر؟
2,ar_to_en,<2en> أنت مضحك,mm youre funny,3,3,"[5, 895, 10872]","[3683, 1151, 5910]",<2en> أنت مضحك,mm youre funny
3,ar_to_en,<2en> رغم ذلك، فإن هذه المعتقدات لها ما يبررها...,"However, these beliefs are clearly justified.",17,10,"[5, 10193, 569, 360, 3970, 687, 332, 1144, 335...","[3799, 17, 1614, 18768, 88, 450, 13120, 798, 4...",<2en> رغم ذلك، فإن هذه المعتقدات لها ما يبررها...,"However, these beliefs are clearly justified."
4,ar_to_en,<2en> قراءة والمشاركة في مدونة الاصوات الصاعدة...,read and participate on the rising voices blog...,39,35,"[5, 10144, 1723, 12101, 325, 7438, 2178, 29307...","[4776, 355, 18161, 372, 295, 13727, 2129, 2799...",<2en> قراءة والمشاركة في مدونة الاصوات الصاعدة...,read and participate on the rising voices blog...


In [16]:
for item in audit_ds:
    print("=" * 60)
    print(f"Direction : {item.get('direction', '')}")
    print(f"Source    : {item['source_text']}")
    print(f"Target    : {item['target_text']}")
    print(f"Input IDs : {item['input_ids']}")
    print(f"Labels    : {item['labels']}")
    print(f"Attn Mask : {item['attention_mask']}")
    print(f"Decoded S : {hf_tokenizer.decode(item['input_ids'], skip_special_tokens=False)}")
    print(f"Decoded T : {hf_tokenizer.decode(item['labels'], skip_special_tokens=False)}")
print("=" * 60)

Direction : ar_to_en
Source    : <2en> فقد نفذ مني البروبان تقريبا
Target    : im nearly out of propane
Input IDs : [5, 3071, 18650, 4714, 1520, 16742, 4348]
Labels    : [431, 9763, 684, 333, 4298, 2385]
Attn Mask : [1, 1, 1, 1, 1, 1, 1]
Decoded S : <2en> فقد نفذ مني البروبان تقريبا
Decoded T : im nearly out of propane
Direction : en_to_ar
Source    : <2ar> another glass for the lady
Target    : أتريدين كأسا آخر؟
Input IDs : [4, 2200, 1849, 434, 295, 5523]
Labels    : [27912, 17415, 1714, 411]
Attn Mask : [1, 1, 1, 1, 1, 1]
Decoded S : <2ar> another glass for the lady
Decoded T : أتريدين كأسا آخر؟
Direction : ar_to_en
Source    : <2en> أنت مضحك
Target    : mm youre funny
Input IDs : [5, 895, 10872]
Labels    : [3683, 1151, 5910]
Attn Mask : [1, 1, 1]
Decoded S : <2en> أنت مضحك
Decoded T : mm youre funny
Direction : ar_to_en
Source    : <2en> رغم ذلك، فإن هذه المعتقدات لها ما يبررها بوضوح.
Target    : However, these beliefs are clearly justified.
Input IDs : [5, 10193, 569, 360, 3970, 6

In [17]:
# Micro-step 14: define and instantiate random-init BART model
from transformers import BartConfig, BartForConditionalGeneration

bart_config = BartConfig(
    vocab_size=hf_tokenizer.vocab_size,
    max_position_embeddings=config.max_seq_len + 2,
    d_model=512,
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    encoder_ffn_dim=2048,
    decoder_ffn_dim=2048,
    dropout=0.1, # Main residual dropout (applied after attention & FFN blocks)
    attention_dropout=0.1, # Dropout applied to attention probabilities
    activation_dropout=0.0, # Dropout after FFN activation (kept 0 for stability)
    pad_token_id=hf_tokenizer.pad_token_id,
    bos_token_id=hf_tokenizer.bos_token_id,
    eos_token_id=hf_tokenizer.eos_token_id,
    decoder_start_token_id=hf_tokenizer.bos_token_id, # the decoder will start with <bos> to predict the first token
)

model = BartForConditionalGeneration(bart_config)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Initialized random BART model (no pretrained weights).")
print(f"Vocab size: {bart_config.vocab_size:,}")
print(f"d_model: {bart_config.d_model}, enc_layers: {bart_config.encoder_layers}, dec_layers: {bart_config.decoder_layers}")
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")


Initialized random BART model (no pretrained weights).
Vocab size: 32,000
d_model: 512, enc_layers: 6, dec_layers: 6
Total params: 60,659,712
Trainable params: 60,659,712


In [18]:
# Micro-step 15: one-batch forward-pass sanity check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.train()

probe_size = min(8, len(train_tok_model))
probe_ds = train_tok_model.shuffle(seed=RANDOM_SEED).select(range(probe_size))
probe_features = [probe_ds[i] for i in range(len(probe_ds))]
probe_batch = data_collator(probe_features)

batch_on_device = {k: v.to(device) for k, v in probe_batch.items()}
outputs = model(**batch_on_device)

loss_value = float(outputs.loss.detach().cpu().item())
assert np.isfinite(loss_value), f"Non-finite loss detected: {loss_value}"

print(f"Device: {device}")
print(f"Probe batch size: {probe_size}")
print(f"Loss: {loss_value:.6f}")
print(f"Logits shape: {tuple(outputs.logits.shape)}")
print(outputs.logits)


Device: cuda
Probe batch size: 8
Loss: 10.462287
Logits shape: (8, 17, 32000)
tensor([[[ 0.0000e+00,  4.7105e+00,  7.3093e-01,  ..., -1.7524e-01,
          -4.8019e-01, -2.6668e-01],
         [ 0.0000e+00,  8.7805e-04, -7.4898e-01,  ...,  1.4209e-01,
           7.2610e-02, -9.1130e-02],
         [ 0.0000e+00, -1.0881e-01,  3.3885e-01,  ..., -6.3099e-01,
          -1.9324e-01, -2.8855e-01],
         ...,
         [ 0.0000e+00, -1.6474e-01,  3.3509e-01,  ..., -8.8127e-02,
          -6.1816e-01, -4.6103e-01],
         [ 0.0000e+00,  1.9915e-01,  2.0464e-01,  ...,  2.3678e-01,
           2.9234e-01, -7.0112e-01],
         [ 0.0000e+00, -7.7394e-01, -3.0749e-01,  ..., -1.1064e-01,
          -6.7551e-01, -4.5202e-01]],

        [[ 0.0000e+00,  4.4847e+00,  1.6182e-01,  ...,  7.3620e-01,
          -7.9925e-01,  6.1794e-02],
         [ 0.0000e+00,  1.7318e-01, -5.4782e-01,  ...,  8.3050e-01,
           3.1044e-01, -4.5127e-02],
         [ 0.0000e+00,  3.1265e-01,  2.2879e-01,  ...,  3.6026e-01

In [19]:
# Micro-step 16: training hyperparameters + optimizer/scheduler setup
from transformers import get_cosine_with_min_lr_schedule_with_warmup

assert "model" in globals(), "Model must be initialized before optimizer setup"

learning_rate = 3e-4
weight_decay = 0.01
adam_betas = (0.9, 0.98)
adam_eps = 1e-8

max_steps = 32_000
warmup_ratio = 0.015
num_warmup_steps = int(max_steps * warmup_ratio)

# Cosine decay floor: keep LR at >=10% of initial LR
min_lr_rate = 0.10
min_lr = learning_rate * min_lr_rate

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=adam_betas,
    eps=adam_eps,
    weight_decay=weight_decay,
)

lr_scheduler = get_cosine_with_min_lr_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_steps,
    min_lr_rate=min_lr_rate,
)

print(f"max learning_rate: {learning_rate}")
print(f"min_lr_rate ratio: {min_lr_rate:.2f}")
print(f"min_lr value: {min_lr}")
print(f"weight_decay: {weight_decay}")
print(f"adam_betas: {adam_betas}")
print(f"max_steps: {max_steps:,}")
print(f"num_warmup_steps: {num_warmup_steps:,}")


max learning_rate: 0.0003
min_lr_rate ratio: 0.10
min_lr value: 2.9999999999999997e-05
weight_decay: 0.01
adam_betas: (0.9, 0.98)
max_steps: 32,000
num_warmup_steps: 480


In [20]:
# Micro-step 17: runtime memory config (checkpointing OFF by default)
assert "model" in globals(), "Model must exist before runtime memory config"

# Start simple: no gradient checkpointing unless we face OOM
use_gradient_checkpointing = False
if use_gradient_checkpointing:
    model.gradient_checkpointing_enable()
else:
    model.gradient_checkpointing_disable()

# Mixed precision for 8GB VRAM training efficiency
use_fp16 = torch.cuda.is_available()
try:
    grad_scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)
except TypeError:
    # Fallback for older torch versions
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

# Initial batch settings (can be tuned after smoke run)
per_device_train_batch_size = 14
gradient_accumulation_steps = 8
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps

# Optional acceleration on Ampere+ GPUs
if torch.cuda.is_available() and hasattr(torch.backends.cuda.matmul, "allow_tf32"):
    print("Enabling TF32 for matmul and cuDNN (Ampere+ GPUs)")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

print(f"use_gradient_checkpointing: {use_gradient_checkpointing}")
print(f"use_fp16: {use_fp16}")
print(f"per_device_train_batch_size: {per_device_train_batch_size}")
print(f"gradient_accumulation_steps: {gradient_accumulation_steps}")
print(f"effective_batch_size (examples/update): {effective_batch_size}")


Enabling TF32 for matmul and cuDNN (Ampere+ GPUs)
use_gradient_checkpointing: False
use_fp16: True
per_device_train_batch_size: 14
gradient_accumulation_steps: 8
effective_batch_size (examples/update): 112


In [21]:
# Micro-step 18: dynamic training-volume analytics
assert "per_device_train_batch_size" in globals(), "Run runtime config cell first"
assert "gradient_accumulation_steps" in globals(), "Run runtime config cell first"
assert "max_steps" in globals(), "Run optimizer/scheduler cell first"
assert "train_df" in globals(), "Run split/bidirectional cells first"

examples_per_micro_batch = int(per_device_train_batch_size)
examples_per_optimizer_step = int(per_device_train_batch_size * gradient_accumulation_steps)
examples_seen_total = int(max_steps * examples_per_optimizer_step)
train_examples = int(len(train_df))
approx_epochs = (examples_seen_total / train_examples) if train_examples else 0.0

analytics_df = pd.DataFrame([
    {"metric": "examples_per_micro_batch", "value": examples_per_micro_batch},
    {"metric": "gradient_accumulation_steps", "value": int(gradient_accumulation_steps)},
    {"metric": "examples_per_optimizer_step", "value": examples_per_optimizer_step},
    {"metric": "max_steps", "value": int(max_steps)},
    {"metric": "examples_seen_total", "value": examples_seen_total},
    {"metric": "train_examples", "value": train_examples},
    {"metric": "approx_epochs_over_train_split", "value": round(approx_epochs, 4)},
    {"metric": "coverage_percent_of_train_split %", "value": round(approx_epochs * 100, 2)},
])

analytics_df = analytics_df.astype({"value": float})


print("Training-volume summary:")
analytics_df.style.format({"value": "{:,.0f}"})


Training-volume summary:


Unnamed: 0,metric,value
0,examples_per_micro_batch,14
1,gradient_accumulation_steps,8
2,examples_per_optimizer_step,112
3,max_steps,32000
4,examples_seen_total,3584000
5,train_examples,1489852
6,approx_epochs_over_train_split,2
7,coverage_percent_of_train_split %,241


In [22]:
# Micro-step 19: quick training-time estimate for current max_steps
assert "max_steps" in globals(), "Run optimizer/scheduler setup cell first"

# Set this after a short timed run (seconds per optimizer step)
sec_per_step_estimate = 1.5

total_seconds = max_steps * sec_per_step_estimate
total_hours = total_seconds / 3600

print(f"max_steps: {max_steps:,}")
print(f"sec_per_step_estimate: {sec_per_step_estimate:.3f}")
print(f"estimated_total_seconds: {total_seconds:,.0f}")
print(f"estimated_total_hours: {total_hours:.2f}")

for s in [1.0, 1.5, 2.0, 2.5, 3.0]:
    h = (max_steps * s) / 3600
    print(f"if {s:.1f}s/step -> {h:.2f} hours")


max_steps: 32,000
sec_per_step_estimate: 1.500
estimated_total_seconds: 48,000
estimated_total_hours: 13.33
if 1.0s/step -> 8.89 hours
if 1.5s/step -> 13.33 hours
if 2.0s/step -> 17.78 hours
if 2.5s/step -> 22.22 hours
if 3.0s/step -> 26.67 hours


In [23]:
# model = BartForConditionalGeneration(bart_config)
# model = model.to(device)

In [24]:
# Micro-step 19b: optional torch.compile acceleration
use_torch_compile = False #torch.cuda.is_available() and hasattr(torch, "compile")

if use_torch_compile:
    try:
        # `reduce-overhead` is a practical starting mode for training loops
        model = torch.compile(model, backend="eager", dynamic=True)
        print("torch.compile: enabled")
    except Exception as e:
        print(f"torch.compile: failed, continuing without compile. reason={e}")
else:
    print("torch.compile: skipped (CUDA or torch.compile unavailable)")


torch.compile: skipped (CUDA or torch.compile unavailable)


In [25]:
# Micro-step 20: 50-step smoke training run
import time
from torch.utils.data import DataLoader

smoke_steps = 50
assert smoke_steps > 0, "smoke_steps must be positive"

train_tok_model = train_tok.remove_columns([c for c in train_tok.column_names if c not in ["input_ids", "attention_mask", "labels"]])
val_tok_model = val_tok.remove_columns([c for c in val_tok.column_names if c not in ["input_ids", "attention_mask", "labels"]])

train_loader = DataLoader(
    train_tok_model,
    batch_size=per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    drop_last=False,
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_tok_model,
    batch_size=per_device_train_batch_size,
    shuffle=False,
    collate_fn=data_collator,
    drop_last=False,
    num_workers=2,
    pin_memory=True,
)

model.train()
optimizer.zero_grad(set_to_none=True)
start = time.time()
running_loss = 0.0
optimizer_steps_done = 0

for micro_step, batch in enumerate(train_loader, start=1):
    if optimizer_steps_done >= smoke_steps:
        break

    batch = {k: v.to(device) for k, v in batch.items()}

    if use_fp16:
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            out = model(**batch)
            loss = out.loss / gradient_accumulation_steps
        grad_scaler.scale(loss).backward()
    else:
        out = model(**batch)
        loss = out.loss / gradient_accumulation_steps
        loss.backward()

    running_loss += float(out.loss.detach().cpu().item())

    if micro_step % gradient_accumulation_steps == 0:
        if use_fp16:
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        lr_scheduler.step()
        optimizer_steps_done += 1

        if optimizer_steps_done % 10 == 0 or optimizer_steps_done == 1:
            avg_loss = running_loss / micro_step
            current_lr = optimizer.param_groups[0]["lr"]
            print(f"step={optimizer_steps_done:>3} avg_train_loss={avg_loss:.4f} lr={current_lr:.6g}")

# Quick single-batch validation check
model.eval()
with torch.no_grad():
    val_batch = next(iter(val_loader))
    val_batch = {k: v.to(device) for k, v in val_batch.items()}
    if use_fp16:
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            val_out = model(**val_batch)
    else:
        val_out = model(**val_batch)

elapsed = time.time() - start
steps_per_sec = optimizer_steps_done / elapsed if elapsed > 0 else 0.0
sec_per_step_estimate = 1.0 / steps_per_sec if steps_per_sec > 0 else float("inf")

print("-" * 60)
print(f"smoke optimizer steps completed: {optimizer_steps_done}")
print(f"elapsed_sec: {elapsed:.2f}")
print(f"steps_per_sec: {steps_per_sec:.4f}")
print(f"sec_per_step_estimate: {sec_per_step_estimate:.4f}")
print(f"val_loss_single_batch: {float(val_out.loss.detach().cpu().item()):.4f}")
if torch.cuda.is_available():
    peak_mem_gb = torch.cuda.max_memory_allocated(device) / (1024**3)
    print(f"peak_cuda_memory_gb: {peak_mem_gb:.3f}")


step=  1 avg_train_loss=10.4594 lr=6.25e-07
step= 10 avg_train_loss=10.4527 lr=6.25e-06
step= 20 avg_train_loss=10.4176 lr=1.25e-05
step= 30 avg_train_loss=10.3541 lr=1.875e-05
step= 40 avg_train_loss=10.2686 lr=2.5e-05
step= 50 avg_train_loss=10.1809 lr=3.125e-05
------------------------------------------------------------
smoke optimizer steps completed: 50
elapsed_sec: 104.25
steps_per_sec: 0.4796
sec_per_step_estimate: 2.0850
val_loss_single_batch: 9.5379
peak_cuda_memory_gb: 2.543


In [None]:
# Micro-step 21a: helper functions for eval, metrics, and checkpointing
import math

def _load_text_metrics(enable_comet=False):
    try:
        import evaluate
    except ImportError as e:
        raise ImportError("`evaluate` is required for BLEU/chrF metrics. Install: pip install evaluate sacrebleu") from e

    _bleu_metric = evaluate.load("sacrebleu")
    _chrf_metric = evaluate.load("chrf")
    _comet_metric = None
    if enable_comet:
        try:
            _comet_metric = evaluate.load("comet")
        except Exception as e:
            print(f"[warn] COMET unavailable; continuing without COMET. Reason: {e}")
            _comet_metric = None
    return _bleu_metric, _chrf_metric, _comet_metric

def _eval_val_loss(_model, _val_loader, _device, _use_fp16, _max_batches=None):
    _model.eval()
    losses = []
    with torch.no_grad():
        for b_idx, batch in enumerate(_val_loader):
            if _max_batches is not None and b_idx >= _max_batches:
                break
            batch = {k: v.to(_device) for k, v in batch.items()}
            if _use_fp16:
                with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                    out = _model(**batch)
            else:
                out = _model(**batch)
            losses.append(float(out.loss.detach().cpu().item()))
    _model.train()
    return float(np.mean(losses)) if losses else float("nan")

def _eval_text_metrics(
    _model,
    _tokenizer,
    _eval_df,
    _device,
    _batch_size,
    _max_batches,
    _num_beams,
    _seed,
    _bleu_metric,
    _chrf_metric,
    _comet_metric=None,
):

    _model.eval()
    if _eval_df is None or len(_eval_df) == 0:
        _model.train()
        return {
            "num_samples": 0,
            "num_batches": 0,
            "bleu": float("nan"),
            "chrf": float("nan"),
            "comet": float("nan"),
            "bleu_en_to_ar": float("nan"),
            "bleu_ar_to_en": float("nan"),
            "chrf_en_to_ar": float("nan"),
            "chrf_ar_to_en": float("nan"),
        }

    if _max_batches is None:
        sample_df = _eval_df.reset_index(drop=True)
    else:
        max_samples = int(_max_batches * _batch_size)
        max_samples = min(max_samples, len(_eval_df))
        if max_samples <= 0:
            _model.train()
            return {
                "num_samples": 0,
                "num_batches": 0,
                "bleu": float("nan"),
                "chrf": float("nan"),
                "comet": float("nan"),
                "bleu_en_to_ar": float("nan"),
                "bleu_ar_to_en": float("nan"),
                "chrf_en_to_ar": float("nan"),
                "chrf_ar_to_en": float("nan"),
            }
        sample_df = _eval_df.sample(n=max_samples, random_state=int(_seed)).reset_index(drop=True)

    preds = []
    refs = []
    srcs = []
    dirs = []
    with torch.no_grad():
        for start in range(0, len(sample_df), _batch_size):
            chunk = sample_df.iloc[start : start + _batch_size]
            src_batch = chunk["source_text"].astype(str).tolist()
            ref_batch = chunk["target_text"].astype(str).tolist()
            dir_batch = chunk["direction"].astype(str).tolist()

            enc = _tokenizer(
                src_batch,
                truncation=True,
                max_length=config.max_seq_len,
                padding=True,
                return_tensors="pt",
            )
            enc = {k: v.to(_device) for k, v in enc.items()}
            gen_ids = _model.generate(
                **enc,
                max_new_tokens=config.max_seq_len,
                num_beams=_num_beams,
                do_sample=False,
            )
            pred_batch = _tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend([p.strip() for p in pred_batch])
            refs.extend([r.strip() for r in ref_batch])
            srcs.extend(src_batch)
            dirs.extend(dir_batch)

    _model.train()

    if len(preds) == 0:
        return {
            "num_samples": 0,
            "num_batches": 0,
            "bleu": float("nan"),
            "chrf": float("nan"),
            "comet": float("nan"),
            "bleu_en_to_ar": float("nan"),
            "bleu_ar_to_en": float("nan"),
            "chrf_en_to_ar": float("nan"),
            "chrf_ar_to_en": float("nan"),
        }

    def _score_subset(_preds, _refs):
        if len(_preds) == 0:
            return float("nan"), float("nan")
        _bleu = float(_bleu_metric.compute(predictions=_preds, references=[[r] for r in _refs])["score"])
        _chrf = float(_chrf_metric.compute(predictions=_preds, references=_refs)["score"])
        return _bleu, _chrf

    bleu_all, chrf_all = _score_subset(preds, refs)
    idx_en_to_ar = [i for i, d in enumerate(dirs) if d == "en_to_ar"]
    idx_ar_to_en = [i for i, d in enumerate(dirs) if d == "ar_to_en"]
    bleu_en_to_ar, chrf_en_to_ar = _score_subset([preds[i] for i in idx_en_to_ar], [refs[i] for i in idx_en_to_ar])
    bleu_ar_to_en, chrf_ar_to_en = _score_subset([preds[i] for i in idx_ar_to_en], [refs[i] for i in idx_ar_to_en])

    comet_score = float("nan")
    if _comet_metric is not None:
        try:
            comet_out = _comet_metric.compute(predictions=preds, references=refs, sources=srcs)
            comet_score = float(comet_out.get("mean_score", comet_out.get("score", float("nan"))))
        except Exception as e:
            print(f"[warn] COMET scoring failed at runtime; continuing without COMET value. Reason: {e}")

    return {
        "num_samples": int(len(preds)),
        "num_batches": int(math.ceil(len(preds) / _batch_size)),
        "bleu": bleu_all,
        "chrf": chrf_all,
        "comet": comet_score,
        "bleu_en_to_ar": bleu_en_to_ar,
        "bleu_ar_to_en": bleu_ar_to_en,
        "chrf_en_to_ar": chrf_en_to_ar,
        "chrf_ar_to_en": chrf_ar_to_en,
    }

def _save_checkpoint(_model, _tokenizer, _optimizer, _scheduler, _scaler, _step, _val_loss, _target_dir):
    _target_dir.mkdir(parents=True, exist_ok=True)
    _model.save_pretrained(str(_target_dir))
    _tokenizer.save_pretrained(str(_target_dir / "tokenizer"))
    torch.save({
        "step": int(_step),
        "val_loss": float(_val_loss) if _val_loss is not None else None,
        "optimizer": _optimizer.state_dict(),
        "scheduler": _scheduler.state_dict(),
        "scaler": _scaler.state_dict() if _scaler is not None else None,
    }, _target_dir / "trainer_state.pt")


In [None]:
# Micro-step 21: full training loop with interval logging, checkpoints, and best-model tracking
import json as pyjson
import shutil
import time
from datetime import datetime
from torch.utils.data import DataLoader

# -----------------------------
# Tunable run hyperparameters
# -----------------------------
log_every_steps = 50
eval_every_steps = 200
save_every_steps = 500
keep_last_n_checkpoints = 3
max_val_batches = 300  # set None to use full val loader
text_metric_num_beams = 1
max_text_metric_batches_log = 1   # lightweight text-metric pass at log intervals
max_text_metric_batches_eval = 8  # stronger text-metric pass at eval intervals
enable_comet = False
train_num_workers = 2
val_num_workers = 2
pin_memory = True

assert max_steps > 0, "max_steps must be > 0"
assert gradient_accumulation_steps > 0, "gradient_accumulation_steps must be > 0"

# Build train/val model-input datasets (dynamic padding via collator)
model_input_cols = ["input_ids", "attention_mask", "labels"]
train_tok_model = train_tok.remove_columns([c for c in train_tok.column_names if c not in model_input_cols])
val_tok_model = val_tok.remove_columns([c for c in val_tok.column_names if c not in model_input_cols])

train_loader = DataLoader(
    train_tok_model,
    batch_size=per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=train_num_workers,
    pin_memory=pin_memory,
    persistent_workers=(train_num_workers > 0),
)
val_loader = DataLoader(
    val_tok_model,
    batch_size=per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=val_num_workers,
    pin_memory=pin_memory,
    persistent_workers=(val_num_workers > 0),
)

In [None]:
# Run dirs and paths
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = PROJECT_ROOT / "artifacts" / "runs" / run_id
ckpt_root = PROJECT_ROOT / "checkpoints" / run_id
best_dir = ckpt_root / "best_model"
run_dir.mkdir(parents=True, exist_ok=True)
ckpt_root.mkdir(parents=True, exist_ok=True)
best_dir.mkdir(parents=True, exist_ok=True)
train_metrics_path = run_dir / "train_metrics.csv"
eval_metrics_path = run_dir / "eval_metrics.csv"
config_path = run_dir / "run_config.json"

run_config = {
    "run_id": run_id,
    "learning_rate": learning_rate,
    "min_lr_rate": min_lr_rate,
    "weight_decay": weight_decay,
    "adam_betas": list(adam_betas),
    "adam_eps": adam_eps,
    "max_steps": max_steps,
    "warmup_ratio": warmup_ratio,
    "num_warmup_steps": num_warmup_steps,
    "per_device_train_batch_size": per_device_train_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "effective_batch_size": effective_batch_size,
    "use_fp16": bool(use_fp16),
    "use_gradient_checkpointing": bool(use_gradient_checkpointing),
    "log_every_steps": log_every_steps,
    "eval_every_steps": eval_every_steps,
    "save_every_steps": save_every_steps,
    "keep_last_n_checkpoints": keep_last_n_checkpoints,
    "max_val_batches": max_val_batches,
    "text_metric_num_beams": text_metric_num_beams,
    "max_text_metric_batches_log": max_text_metric_batches_log,
    "max_text_metric_batches_eval": max_text_metric_batches_eval,
    "enable_comet": bool(enable_comet),
}
config_path.write_text(pyjson.dumps(run_config, indent=2), encoding="utf-8")

print(f"run_id: {run_id}")
print(f"run_dir: {run_dir}")
print(f"ckpt_root: {ckpt_root}")

assert "_load_text_metrics" in globals(), "Run helper-functions cell (Micro-step 21a) first"
assert "_eval_val_loss" in globals(), "Run helper-functions cell (Micro-step 21a) first"
assert "_eval_text_metrics" in globals(), "Run helper-functions cell (Micro-step 21a) first"
assert "_save_checkpoint" in globals(), "Run helper-functions cell (Micro-step 21a) first"
bleu_metric, chrf_metric, comet_metric = _load_text_metrics(enable_comet=enable_comet)

In [None]:
# Training state
model.train()
optimizer.zero_grad(set_to_none=True)
train_metrics_log = []
eval_metrics_log = []
best_val_loss = float("inf")
optimizer_steps_done = 0
micro_step = 0
start_time = time.time()
interval_start = start_time
interval_raw_loss_sum = 0.0
interval_micro_steps = 0
interval_opt_steps = 0
saved_ckpts = []

while optimizer_steps_done < max_steps:
    for batch in train_loader:
        if optimizer_steps_done >= max_steps:
            break

        micro_step += 1
        batch = {k: v.to(device) for k, v in batch.items()}

        if use_fp16:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                out = model(**batch)
                scaled_loss = out.loss / gradient_accumulation_steps
            grad_scaler.scale(scaled_loss).backward()
        else:
            out = model(**batch)
            scaled_loss = out.loss / gradient_accumulation_steps
            scaled_loss.backward()

        raw_loss = float(out.loss.detach().cpu().item())
        interval_raw_loss_sum += raw_loss
        interval_micro_steps += 1

        # One optimizer update happens after N micro-batches (gradient accumulation).
        if micro_step % gradient_accumulation_steps == 0:
            if use_fp16:
                grad_scaler.unscale_(optimizer)
            grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).detach().cpu().item())

            if use_fp16:
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                optimizer.step()

            optimizer.zero_grad(set_to_none=True)
            lr_scheduler.step()
            optimizer_steps_done += 1
            interval_opt_steps += 1

            current_lr = float(optimizer.param_groups[0]["lr"])
            examples_seen = int(optimizer_steps_done * per_device_train_batch_size * gradient_accumulation_steps)
            approx_epochs = examples_seen / len(train_df) if len(train_df) else float("nan")

            # Interval logging is based on optimizer steps (not micro-steps).
            if optimizer_steps_done % log_every_steps == 0 or optimizer_steps_done == 1:
                interval_elapsed = time.time() - interval_start
                avg_interval_loss = interval_raw_loss_sum / max(1, interval_micro_steps)
                steps_per_sec = interval_opt_steps / interval_elapsed if interval_elapsed > 0 else 0.0
                sec_per_step = 1.0 / steps_per_sec if steps_per_sec > 0 else float("inf")

                mem_alloc_gb = None
                mem_peak_gb = None
                if torch.cuda.is_available():
                    mem_alloc_gb = torch.cuda.memory_allocated(device) / (1024**3)
                    mem_peak_gb = torch.cuda.max_memory_allocated(device) / (1024**3)

                print("=" * 88)
                print(f"step {optimizer_steps_done:>6}/{max_steps} | avg_loss={avg_interval_loss:.4f} | lr={current_lr:.6g} | grad_norm={grad_norm:.4f}")
                print(f"interval_sec={interval_elapsed:.2f} | steps/sec={steps_per_sec:.4f} | sec/step={sec_per_step:.4f}")
                print(f"examples_seen={examples_seen:,} | approx_epochs={approx_epochs:.4f}")
                print(f"batch_shapes input={tuple(batch['input_ids'].shape)} labels={tuple(batch['labels'].shape)}")
                if mem_alloc_gb is not None:
                    print(f"cuda_mem_alloc_gb={mem_alloc_gb:.3f} | cuda_mem_peak_gb={mem_peak_gb:.3f}")

                text_log_t0 = time.time()
                text_metrics_log = _eval_text_metrics(
                    model,
                    hf_tokenizer,
                    val_df,
                    device,
                    per_device_train_batch_size,
                    max_text_metric_batches_log,
                    text_metric_num_beams,
                    _seed=RANDOM_SEED + optimizer_steps_done,
                    _bleu_metric=bleu_metric,
                    _chrf_metric=chrf_metric,
                    _comet_metric=comet_metric,
                )
                text_log_sec = time.time() - text_log_t0
                print(
                    f"[text@log] samples={text_metrics_log['num_samples']} | bleu={text_metrics_log['bleu']:.3f} "
                    f"| chrf={text_metrics_log['chrf']:.3f} | comet={text_metrics_log['comet']:.3f} "
                    f"| text_eval_sec={text_log_sec:.2f}"
                )

                train_metrics_log.append({
                    "step": int(optimizer_steps_done),
                    "avg_interval_loss": float(avg_interval_loss),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "interval_sec": float(interval_elapsed),
                    "steps_per_sec": float(steps_per_sec),
                    "sec_per_step": float(sec_per_step),
                    "examples_seen": int(examples_seen),
                    "approx_epochs": float(approx_epochs),
                    "batch_input_len": int(batch["input_ids"].shape[1]),
                    "batch_label_len": int(batch["labels"].shape[1]),
                    "cuda_mem_alloc_gb": float(mem_alloc_gb) if mem_alloc_gb is not None else None,
                    "cuda_mem_peak_gb": float(mem_peak_gb) if mem_peak_gb is not None else None,
                    "text_eval_sec": float(text_log_sec),
                    "text_samples": int(text_metrics_log["num_samples"]),
                    "text_batches": int(text_metrics_log["num_batches"]),
                    "bleu": float(text_metrics_log["bleu"]),
                    "chrf": float(text_metrics_log["chrf"]),
                    "comet": float(text_metrics_log["comet"]),
                    "bleu_en_to_ar": float(text_metrics_log["bleu_en_to_ar"]),
                    "bleu_ar_to_en": float(text_metrics_log["bleu_ar_to_en"]),
                    "chrf_en_to_ar": float(text_metrics_log["chrf_en_to_ar"]),
                    "chrf_ar_to_en": float(text_metrics_log["chrf_ar_to_en"]),
                })
                pd.DataFrame(train_metrics_log).to_csv(train_metrics_path, index=False)

                interval_start = time.time()
                interval_raw_loss_sum = 0.0
                interval_micro_steps = 0
                interval_opt_steps = 0

            # Run validation periodically and always at the final step.
            should_eval = (optimizer_steps_done % eval_every_steps == 0) or (optimizer_steps_done == max_steps)
            if should_eval:
                val_loss = _eval_val_loss(model, val_loader, device, use_fp16, max_batches=max_val_batches)
                text_eval_t0 = time.time()
                text_metrics_eval = _eval_text_metrics(
                    model,
                    hf_tokenizer,
                    val_df,
                    device,
                    per_device_train_batch_size,
                    max_text_metric_batches_eval,
                    text_metric_num_beams,
                    _seed=RANDOM_SEED + 10_000 + optimizer_steps_done,
                    _bleu_metric=bleu_metric,
                    _chrf_metric=chrf_metric,
                    _comet_metric=comet_metric,
                )
                text_eval_sec = time.time() - text_eval_t0
                print(
                    f"[eval] step={optimizer_steps_done} val_loss={val_loss:.4f} | "
                    f"bleu={text_metrics_eval['bleu']:.3f} | chrf={text_metrics_eval['chrf']:.3f} | "
                    f"comet={text_metrics_eval['comet']:.3f} | text_eval_sec={text_eval_sec:.2f}"
                )
                eval_metrics_log.append({
                    "step": int(optimizer_steps_done),
                    "val_loss": float(val_loss),
                    "lr": current_lr,
                    "text_eval_sec": float(text_eval_sec),
                    "text_samples": int(text_metrics_eval["num_samples"]),
                    "text_batches": int(text_metrics_eval["num_batches"]),
                    "bleu": float(text_metrics_eval["bleu"]),
                    "chrf": float(text_metrics_eval["chrf"]),
                    "comet": float(text_metrics_eval["comet"]),
                    "bleu_en_to_ar": float(text_metrics_eval["bleu_en_to_ar"]),
                    "bleu_ar_to_en": float(text_metrics_eval["bleu_ar_to_en"]),
                    "chrf_en_to_ar": float(text_metrics_eval["chrf_en_to_ar"]),
                    "chrf_ar_to_en": float(text_metrics_eval["chrf_ar_to_en"]),
                })
                pd.DataFrame(eval_metrics_log).to_csv(eval_metrics_path, index=False)

                if np.isfinite(val_loss) and val_loss < best_val_loss:
                    best_val_loss = float(val_loss)
                    _save_checkpoint(model, hf_tokenizer, optimizer, lr_scheduler, grad_scaler, optimizer_steps_done, val_loss, best_dir)
                    print(f"[best] new best val_loss={best_val_loss:.4f} saved to {best_dir}")

            # Save periodic checkpoints independently from validation schedule.
            should_save = (optimizer_steps_done % save_every_steps == 0) or (optimizer_steps_done == max_steps)
            if should_save:
                ckpt_dir = ckpt_root / f"step_{optimizer_steps_done:06d}"
                _save_checkpoint(model, hf_tokenizer, optimizer, lr_scheduler, grad_scaler, optimizer_steps_done, None, ckpt_dir)
                saved_ckpts.append(ckpt_dir)
                print(f"[ckpt] saved: {ckpt_dir}")

                # Optional retention policy to limit disk usage.
                if keep_last_n_checkpoints is not None and keep_last_n_checkpoints > 0:
                    while len(saved_ckpts) > keep_last_n_checkpoints:
                        old = saved_ckpts.pop(0)
                        if old.exists():
                            shutil.rmtree(old)
                            print(f"[ckpt] removed old checkpoint: {old}")

total_elapsed = time.time() - start_time
print("=" * 88)
print(f"Training finished at step {optimizer_steps_done}/{max_steps}")
print(f"Total elapsed sec: {total_elapsed:.2f}")
print(f"Best val loss: {best_val_loss if np.isfinite(best_val_loss) else None}")
print(f"Train metrics CSV: {train_metrics_path}")
print(f"Eval metrics CSV: {eval_metrics_path}")
print(f"Best model dir: {best_dir}")
