# Notebook 2: ByT5 Training with NLLB-Augmented Data

Fine-tunes `notninja/byt5-base-akkadian` on gold translations + NLLB teacher translations.

**Requires:** Output from Notebook 1 (`gold_with_nllb.parquet`) uploaded as Kaggle dataset `nicbarthelemy1/akkadian-nllb-translations`.

**Training plan:**
- Phase 1: All augmented gold data (~252K rows: 126K gold + 126K NLLB)
- Phase 2: Old Assyrian only (competition domain specialization)
- Checkpoint selection on the competition validation set (95 OA samples) by `geo_mean`, the geometric mean of BLEU and chrF++

In [None]:
!pip install -q sacrebleu datasets accelerate

In [None]:
import os, gc, math, time, warnings
from pathlib import Path
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq,
)
from datasets import Dataset
from sacrebleu.metrics import BLEU, CHRF

# Query CUDA availability once and reuse the answer for both the device
# string and the optional hardware report.
has_cuda = torch.cuda.is_available()
if has_cuda:
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {gpu_props.total_memory / 1e9:.1f} GB')

## Load Data

In [None]:
def find_file(filename, base='/kaggle/input'):
    """Locate the directory under ``base`` that directly contains ``filename``.

    The tree is walked top-down and sub-directories are visited in sorted
    order, so when several input datasets ship a file with the same name
    (e.g. ``test.csv``) the same directory is returned on every run instead
    of whichever one the filesystem happens to list first.

    Args:
        filename: Exact file name to look for (not a path or a glob).
        base: Root directory of the search.

    Returns:
        ``pathlib.Path`` of the first directory containing ``filename``, or
        ``None`` when nothing matches.
    """
    for root, dirs, files in os.walk(base):
        # Sorting in place steers the order in which os.walk descends.
        dirs.sort()
        if filename in files:
            return Path(root)
    return None

# Show input layout (at most 15 entries per top-level input directory)
print('Contents of /kaggle/input/')
for d in sorted(os.listdir('/kaggle/input')):
    full = os.path.join('/kaggle/input', d)
    if os.path.isdir(full):
        print(f'  {d}/')
        for f in sorted(os.listdir(full))[:15]:
            print(f'    {f}')

# Load NLLB-augmented gold data (from Notebook 1 output)
NLLB_DIR = find_file('gold_with_nllb.parquet')
if NLLB_DIR is None:
    raise FileNotFoundError(
        'Cannot find gold_with_nllb.parquet. '
        'Run Notebook 1 (nllb_distill) first, then upload its output as a Kaggle dataset.'
    )
print(f'\nNLLB data at: {NLLB_DIR}')
gold_nllb_df = pd.read_parquet(NLLB_DIR / 'gold_with_nllb.parquet')

# Load original assembled data (for val splits).
# Fail with an explicit message instead of a TypeError on `None / str`
# when the validation split is not attached.
ASSEMBLED_DIR = find_file('val_competition.parquet')
if ASSEMBLED_DIR is None:
    raise FileNotFoundError(
        'Cannot find val_competition.parquet. '
        'Attach the dataset containing the competition validation split as a notebook input.'
    )
val_comp = pd.read_parquet(ASSEMBLED_DIR / 'val_competition.parquet')

# Competition test data (optional: the submission step is skipped without it)
COMP_DIR = find_file('test.csv')
test_df = pd.read_csv(COMP_DIR / 'test.csv') if COMP_DIR else None

# Also load NLLB test predictions if available (for ensembling)
nllb_test_path = find_file('test_nllb_predictions.csv')
nllb_test_df = pd.read_csv(nllb_test_path / 'test_nllb_predictions.csv') if nllb_test_path else None

print(f'Gold+NLLB training: {len(gold_nllb_df)}')
print(f'Competition val: {len(val_comp)}')
if test_df is not None:
    print(f'Test samples: {len(test_df)}')

## Build Augmented Training Set

In [None]:
# Gold pairs: source text with its reference translation, tagged by origin.
orig = gold_nllb_df[['transliteration', 'translation', 'dialect']].assign(source_type='gold')

# Teacher pairs: the same sources with the NLLB output as target.
# Rows with an empty NLLB translation are filtered out.
nllb = gold_nllb_df[['transliteration', 'nllb_translation', 'dialect']].rename(
    columns={'nllb_translation': 'translation'}
)
has_text = nllb['translation'].str.len() > 0
nllb = nllb.loc[has_text].reset_index(drop=True).assign(source_type='nllb')

# Stack both sources with equal weight, then shuffle with a fixed seed.
augmented = (
    pd.concat([orig, nllb], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
print(f'Augmented training: {len(augmented)} rows ({len(orig)} gold + {len(nllb)} nllb)')

# Old Assyrian rows only; this subset feeds Phase 2.
is_oa = augmented['dialect'] == 'old_assyrian'
oa_augmented = augmented[is_oa].reset_index(drop=True)
print(f'OA augmented: {len(oa_augmented)} rows')

## Load ByT5 Model

In [None]:
# Checkpoint to fine-tune, the task prefix prepended to every source, and
# the maximum tokenized lengths for sources and targets.
BYT5_MODEL = 'notninja/byt5-base-akkadian'
PREFIX = 'translate Akkadian to English: '
MAX_SOURCE, MAX_TARGET = 768, 512

print(f'Loading {BYT5_MODEL}...')
tokenizer = AutoTokenizer.from_pretrained(BYT5_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(BYT5_MODEL)
model = model.to(device)

# Report the parameter count in millions.
param_counts = [weight.numel() for weight in model.parameters()]
n_params = sum(param_counts) / 1e6
print(f'Loaded ({n_params:.0f}M params)')
print(f'GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB')

In [None]:
def preprocess(examples):
    """Tokenize a batch of transliteration/translation pairs.

    Args:
        examples: Batch mapping with 'transliteration' and 'translation'
            entries, each an iterable of values (coerced to ``str``).

    Returns:
        The tokenizer output for the prefixed sources (truncated to
        ``MAX_SOURCE``) with the target token ids (truncated to
        ``MAX_TARGET``) stored under 'labels'.
    """
    sources = []
    for text in examples['transliteration']:
        sources.append(PREFIX + str(text))
    references = list(map(str, examples['translation']))

    encoded = tokenizer(sources, max_length=MAX_SOURCE, truncation=True)
    target_encoding = tokenizer(references, max_length=MAX_TARGET, truncation=True)
    encoded['labels'] = target_encoding['input_ids']
    return encoded

def score_preds(preds, refs):
    """Score predictions against a single set of references.

    Args:
        preds: Predicted translations, one string per sample.
        refs: Reference translations aligned with ``preds``.

    Returns:
        Dict with corpus-level 'bleu', 'chrf' (computed with
        ``word_order=2``) and 'geo_mean', the square root of their product.
    """
    bleu_score = BLEU().corpus_score(preds, [refs]).score
    chrf_score = CHRF(word_order=2).corpus_score(preds, [refs]).score
    # Clamp at zero so the square root is always defined.
    product = max(bleu_score, 0) * max(chrf_score, 0)
    return dict(bleu=bleu_score, chrf=chrf_score, geo_mean=math.sqrt(product))

def compute_metrics(eval_preds):
    """Decode generated ids and label ids, then score them.

    Args:
        eval_preds: Pair ``(predictions, label_ids)`` as handed over by the
            trainer during evaluation.

    Returns:
        The metric dict produced by ``score_preds`` ('bleu', 'chrf',
        'geo_mean').
    """
    preds, labels = eval_preds
    # Predictions may arrive wrapped in a tuple; the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 marks padded/ignored positions and is not a decodable token id,
    # so swap it for the pad token on both sides before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return score_preds(decoded_preds, decoded_labels)

# Prepare eval dataset (competition val — used for all phases)
eval_frame = val_comp[['transliteration', 'translation']]
eval_ds = Dataset.from_pandas(eval_frame)
# Drop every raw column after tokenization so only model inputs remain.
raw_eval_columns = eval_ds.column_names
eval_ds = eval_ds.map(preprocess, batched=True, remove_columns=raw_eval_columns)

# Collator shared by training and evaluation.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
print(f'Eval dataset: {len(eval_ds)} samples')

## Phase 1: All Augmented Gold Data

In [None]:
print(f'=== Phase 1: All augmented gold ({len(augmented)} samples) ===')

# Tokenize the full augmented corpus; raw text columns are dropped afterwards.
train_ds = Dataset.from_pandas(augmented[['transliteration', 'translation']])
train_ds = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)

phase1_config = dict(
    output_dir='/kaggle/working/byt5-nllb-phase1',
    report_to='none',
    dataloader_num_workers=2,
    logging_steps=100,
    # Optimisation: 8 sequences x 4 accumulation steps per update, per device.
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    # Evaluation by generation on the competition validation set.
    per_device_eval_batch_size=4,
    eval_strategy='steps',
    eval_steps=2000,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET,
    generation_num_beams=4,
    # Checkpointing: save on the eval cadence, reload the best geo_mean at the end.
    save_strategy='steps',
    save_steps=2000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='geo_mean',
    greater_is_better=True,
)
phase1_args = Seq2SeqTrainingArguments(**phase1_config)

trainer = Seq2SeqTrainer(
    model=model,
    args=phase1_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

# Persist the model together with its tokenizer in one directory.
phase1_best_dir = '/kaggle/working/byt5-nllb-phase1-best'
trainer.save_model(phase1_best_dir)
tokenizer.save_pretrained(phase1_best_dir)
print('Phase 1 complete')

## Phase 2: Old Assyrian Specialization

In [None]:
print(f'\n=== Phase 2: Old Assyrian ({len(oa_augmented)} samples) ===')

# Tokenize the Old Assyrian subset; raw text columns are dropped afterwards.
oa_train_ds = Dataset.from_pandas(oa_augmented[['transliteration', 'translation']])
oa_train_ds = oa_train_ds.map(preprocess, batched=True, remove_columns=oa_train_ds.column_names)

phase2_config = dict(
    output_dir='/kaggle/working/byt5-nllb-phase2',
    report_to='none',
    dataloader_num_workers=2,
    logging_steps=50,
    # Optimisation: 4 sequences x 8 accumulation steps per update, per device.
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    weight_decay=0.01,
    fp16=True,
    # Evaluation by generation on the competition validation set.
    per_device_eval_batch_size=4,
    eval_strategy='steps',
    eval_steps=500,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET,
    generation_num_beams=4,
    # Checkpointing: save on the eval cadence, reload the best geo_mean at the end.
    save_strategy='steps',
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='geo_mean',
    greater_is_better=True,
)
phase2_args = Seq2SeqTrainingArguments(**phase2_config)

# Training continues on the same `model` object used in the previous phase.
trainer = Seq2SeqTrainer(
    model=model,
    args=phase2_args,
    train_dataset=oa_train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

# Persist the model together with its tokenizer in one directory.
phase2_best_dir = '/kaggle/working/byt5-nllb-phase2-best'
trainer.save_model(phase2_best_dir)
tokenizer.save_pretrained(phase2_best_dir)
print('Phase 2 complete')

## Final Evaluation

In [None]:
# Score the final in-memory model on the competition validation set.
# Coerce both columns to str, exactly as preprocess() does, so a non-string
# cell cannot break the prefix concatenation or the scorer.
comp_trans = [str(t) for t in val_comp['transliteration']]
comp_refs = [str(t) for t in val_comp['translation']]

model.eval()
preds = []
for i in range(0, len(comp_trans), 4):
    batch = [PREFIX + t for t in comp_trans[i:i+4]]
    enc = tokenizer(batch, max_length=MAX_SOURCE, truncation=True,
                    padding=True, return_tensors='pt').to(device)
    # Inference only: no gradients needed.
    with torch.no_grad():
        out = model.generate(
            **enc, max_length=MAX_TARGET, num_beams=5,
            length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True,
        )
    preds.extend(tokenizer.batch_decode(out, skip_special_tokens=True))

scores = score_preds(preds, comp_refs)
print(f'ByT5 final: BLEU={scores["bleu"]:.2f}  chrF++={scores["chrf"]:.2f}  geo_mean={scores["geo_mean"]:.4f}')

# Show a few source / output / reference triples for a quick sanity check.
for j in range(min(5, len(preds))):
    print(f'\n[{j}] Src: {comp_trans[j][:120]}...')
    print(f'    Out: {preds[j][:250]}')
    print(f'    Ref: {comp_refs[j][:250]}')

## Generate Submission

In [None]:
# Translate the competition test set and write the submission file.
if test_df is not None:
    # Coerce to str, as preprocess() does, so a non-string cell cannot break
    # the prefix concatenation.
    test_trans = [str(t) for t in test_df['transliteration']]
    test_preds = []
    for i in range(0, len(test_trans), 2):
        batch = [PREFIX + t for t in test_trans[i:i+2]]
        enc = tokenizer(batch, max_length=MAX_SOURCE, truncation=True,
                        padding=True, return_tensors='pt').to(device)
        # Inference only: no gradients needed.
        with torch.no_grad():
            out = model.generate(
                **enc, max_length=MAX_TARGET, num_beams=5,
                length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True,
            )
        test_preds.extend(tokenizer.batch_decode(out, skip_special_tokens=True))

    submission = pd.DataFrame({'id': test_df['id'], 'translation': test_preds})
    submission.to_csv('/kaggle/working/submission.csv', index=False)
    print(f'Submission saved ({len(submission)} rows)')

    # NLLB predictions are optional and shown only for side-by-side
    # inspection. Rows are paired by position; the lookup is bounded and
    # stringified so a shorter file or a non-string cell cannot crash the
    # cell.
    nllb_preds = None
    if nllb_test_df is not None:
        nllb_preds = nllb_test_df['nllb_translation'].tolist()

    for j, (src, pred) in enumerate(zip(test_trans, test_preds)):
        print(f'\n[{j}] {src[:100]}...')
        print(f'    ByT5: {pred[:300]}')
        if nllb_preds is not None and j < len(nllb_preds):
            print(f'    NLLB: {str(nllb_preds[j])[:300]}')
else:
    print('No test.csv found')