[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nyacly/rutooro-mt-model/blob/main/notebooks/train_nllb_colab.ipynb)

# NLLB-200 Tooro-English Fine-tuning
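This notebook fine-tunes `facebook/nllb-200-distilled-600M` for English→Tooro (Rutooro) translation using the Hugging Face ecosystem (`transformers`, `datasets`, `evaluate`). Tooro is not one of NLLB-200's built-in languages, so `ttj_Latn` is a new target-language code for the model. Each step prints diagnostic output to help with debugging.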

In [None]:

# Install required libraries
!pip install -U transformers datasets evaluate sacrebleu > /dev/null
print('Installed libraries.')


In [None]:

from google.colab import drive

# Mount Google Drive so the dataset, checkpoints and saved model persist across
# Colab sessions. A failed mount is only reported; the notebook keeps running.
try:
    drive.mount('/content/drive')
except Exception as mount_error:
    print('Drive mount failed:', mount_error)


In [None]:

from pathlib import Path

# Every artefact lives under the mounted Google Drive folder.
drive_root = Path('/content/drive/MyDrive')
if not drive_root.is_dir():
    # mkdir below would otherwise silently create these folders on the VM's
    # local disk, and anything saved there is lost when the runtime resets.
    print('WARNING: Google Drive does not appear to be mounted; files written under',
          drive_root, 'will not persist.')

data_dir = drive_root / 'rutooro-mt-data'
model_dir = drive_root / 'rutooro-mt-models'
output_dir = drive_root / 'rutooro-mt-outputs'

for p in [data_dir, model_dir, output_dir]:
    p.mkdir(parents=True, exist_ok=True)

print('Data directory:', data_dir)
print('Model directory:', model_dir)
print('Output directory:', output_dir)


In [None]:

from datasets import load_dataset

# Prefer the public Hub dataset; fall back to a JSON file uploaded to data_dir.
local_path = data_dir / 'english_rutooro.json'

try:
    raw_ds = load_dataset('michsethowusu/english-tooro_sentence-pairs_mt560')
    print('Loaded dataset from HuggingFace.')
except Exception as hub_error:
    print('Failed to load from Hub:', hub_error)
    if not local_path.exists():
        # Chain the Hub error so its cause (network, auth, name) stays visible.
        raise RuntimeError(
            f'Dataset not found on the Hub and no local copy at {local_path}. '
            'Please upload the data.'
        ) from hub_error
    raw_ds = load_dataset('json', data_files=str(local_path))
    print('Loaded dataset from', local_path)

print(raw_ds)


In [None]:

from datasets import DatasetDict

# Ensure train/validation/test splits exist. A Hub dataset may ship only a
# 'train' split, so carve the missing splits out of it.
#  - A fixed seed makes the split identical on every re-run.
#  - Validation and test come from ONE shuffled held-out set, so they never overlap.
#  - The guard makes this cell idempotent: once 'validation' exists it is a no-op.
SPLIT_SEED = 42

if 'validation' not in raw_ds:
    if 'test' in raw_ds:
        # Keep the provided test split; take 10% of train for validation.
        train_val = raw_ds['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
        raw_ds = DatasetDict({
            'train': train_val['train'],
            'validation': train_val['test'],
            'test': raw_ds['test'],
        })
    else:
        # 80/10/10: hold out 20% of train, then halve it into validation/test.
        train_heldout = raw_ds['train'].train_test_split(test_size=0.2, seed=SPLIT_SEED)
        val_test = train_heldout['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)
        raw_ds = DatasetDict({
            'train': train_heldout['train'],
            'validation': val_test['train'],
            'test': val_test['test'],
        })
print(raw_ds)


In [None]:

# Map possible column names to 'eng' and 'ttj'
def map_columns(example, eng_keys=('english', 'source', 'eng'),
                ttj_keys=('rutooro', 'target', 'ttj', 'tt')):
    """Normalise one raw example to the ``{'eng', 'ttj'}`` schema.

    For each side, the first candidate column holding non-blank text wins.
    String values are stripped, so whitespace-only cells count as missing.

    Args:
        example: one dataset row, as a mapping of column name -> value.
        eng_keys: candidate column names for the English side, in priority order.
        ttj_keys: candidate column names for the Tooro side, in priority order.

    Returns:
        ``{'eng': ..., 'ttj': ...}``. A side is ``None`` when no candidate column
        has text, so a truthiness filter afterwards drops the row.
    """
    def first_text(keys):
        # Return the first non-empty (after stripping) value among `keys`.
        for key in keys:
            value = example.get(key)
            if isinstance(value, str):
                value = value.strip()
            if value:
                return value
        return None

    return {'eng': first_text(eng_keys), 'ttj': first_text(ttj_keys)}

# Replace every raw column with the normalised 'eng'/'ttj' pair.
raw_ds = raw_ds.map(map_columns, remove_columns=raw_ds['train'].column_names)

# Filter out rows where either side is missing or empty.
raw_ds = raw_ds.filter(lambda x: bool(x['eng'] and x['ttj']))

# If map_columns recognised none of the dataset's columns, every row is dropped.
# Fail here with a clear message instead of an IndexError on the sample below.
empty_splits = [split for split, ds in raw_ds.items() if len(ds) == 0]
if empty_splits:
    raise ValueError(
        f'No usable sentence pairs left in split(s) {empty_splits}; '
        'check that map_columns knows the dataset\'s column names.'
    )

print('After filtering:', raw_ds)
print('Sample:', raw_ds['train'][0])


In [None]:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = 'facebook/nllb-200-distilled-600M'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Tooro is not one of NLLB-200's pretrained languages, so 'ttj_Latn' is not in
# the vocabulary. Left alone, it maps to <unk>, and every tokenized target would
# start with <unk> instead of a language tag.
# Register it as a special token and grow the embedding matrix only if the new
# id does not already fit.
tgt_lang_code = 'ttj_Latn'
if tgt_lang_code not in tokenizer.get_vocab():
    tokenizer.add_tokens([tgt_lang_code], special_tokens=True)
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

# `tokenizer.lang_code_to_id` no longer exists in recent transformers releases,
# so look the language-code ids up directly.
src_lang_id = tokenizer.convert_tokens_to_ids('eng_Latn')
tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
assert tgt_lang_id != tokenizer.unk_token_id, f'{tgt_lang_code} still maps to <unk>'

# Force generate() (used during evaluation with predict_with_generate) to start
# every output with the Tooro language tag.
model.generation_config.forced_bos_token_id = tgt_lang_id

print('Tokenizer language codes:', src_lang_id, tgt_lang_id)


In [None]:

max_length = 128  # token cap applied to both source and target sequences


def preprocess(example):
    """Tokenize English sources and Tooro targets for seq2seq training.

    Works on a single example (string fields) or a batch (list fields). The
    target text is passed as ``text_target``, so the returned encoding also
    carries ``labels``. Both sides are truncated to ``max_length`` tokens.
    """
    is_batch = isinstance(example['eng'], list)
    sources = example['eng'] if is_batch else [example['eng']]
    targets = example['ttj'] if is_batch else [example['ttj']]

    # The NLLB tokenizer uses these attributes to choose the language-code
    # tokens it adds to the source and target sequences.
    tokenizer.src_lang = 'eng_Latn'
    tokenizer.tgt_lang = 'ttj_Latn'

    return tokenizer(sources, text_target=targets, max_length=max_length, truncation=True)

# Drop the raw text columns so only model inputs (input_ids, attention_mask,
# labels) reach the data collator. It cannot pad string fields, so leftover
# 'eng'/'ttj' columns would crash batching.
processed_ds = raw_ds.map(
    preprocess,
    batched=True,
    remove_columns=raw_ds['train'].column_names,
)

print(processed_ds)
# Sanity check: the decoded label should begin with the target language tag,
# not <unk>.
print('Decoded sample label:', tokenizer.decode(processed_ds['train'][0]['labels']))


In [None]:

from transformers import DataCollatorForSeq2Seq

# Dynamically pads each batch. Labels are padded with the collator's
# label_pad_token_id (-100 by default) so the loss ignores padding. Passing the
# model lets the collator derive decoder inputs from the labels when the model
# supports it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


In [None]:

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=str(output_dir),
    # `eval_strategy` replaces `evaluation_strategy`, which current transformers
    # releases no longer accept.
    eval_strategy='epoch',
    save_strategy='epoch',
    # Checkpoints go to Drive. Each one of a 600M-parameter model (weights plus
    # optimizer state) is several GB, so keep only the newest/best two.
    save_total_limit=2,
    # Reload the checkpoint with the best validation chrF once training ends.
    load_best_model_at_end=True,
    metric_for_best_model='chrf',
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir=str(output_dir / 'logs'),
    predict_with_generate=True,
    seed=42,  # explicit for reproducibility
    # remove_unused_columns is left at its default (True) so raw text columns
    # never reach the data collator.
)
print(training_args)


In [None]:

import evaluate
import numpy as np

bleu = evaluate.load('sacrebleu')
chrf = evaluate.load('chrf')

def compute_metrics(eval_preds):
    """Decode generated and reference token ids and score them with BLEU and chrF.

    Args:
        eval_preds: ``(predictions, label_ids)`` as produced by ``Seq2SeqTrainer``.
            Because ``predict_with_generate`` is enabled, predictions are
            generated token ids rather than logits.

    Returns:
        dict with corpus-level ``'bleu'`` and ``'chrf'`` scores.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # The Trainer pads predictions across eval batches with -100, and labels use
    # -100 for ignored positions. batch_decode cannot handle negative ids, so map
    # both back to the pad token, which skip_special_tokens then removes.
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    decoded_labels = [text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    # Both sacrebleu and chrF expect one *list* of references per prediction.
    references = [[text] for text in decoded_labels]
    bleu_score = bleu.compute(predictions=decoded_preds, references=references)['score']
    chrf_score = chrf.compute(predictions=decoded_preds, references=references)['score']
    return {'bleu': bleu_score, 'chrf': chrf_score}


In [None]:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_ds['train'],
    eval_dataset=processed_ds['validation'],
    # `processing_class` replaces the deprecated `tokenizer=` argument
    # (transformers >= 4.46). The Trainer still saves it with the model.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


In [None]:

# Fine-tune the model. Per training_args, the validation split is evaluated and a
# checkpoint is saved once per epoch. The returned TrainOutput is shown as this
# cell's output.
trainer.train()


In [None]:

# Score the held-out test split. metric_key_prefix='test' reports these as
# test_* instead of mislabelling them as validation (eval_*) metrics in the logs.
metrics = trainer.evaluate(eval_dataset=processed_ds['test'], metric_key_prefix='test')
# Persist the scores next to the checkpoints (test_results.json in output_dir).
trainer.save_metrics('test', metrics)
print('Test set metrics:', metrics)


In [None]:

# Write the fine-tuned model to Drive. Trainer.save_model also writes the
# tokenizer the Trainer was constructed with.
model_save_path = model_dir.joinpath('nllb-tooro')
trainer.save_model(output_dir=str(model_save_path))
print('Model saved to', model_save_path)
