<a href="https://colab.research.google.com/github/mosesmakola/ola-speaks/blob/emmanuel-dev/TrainandTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers datasets huggingface_hub snowflake-connector-python pandas numpy scikit-learn tqdm tensorboard

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting snowflake-connector-python
  Downloading snowflake_connector_python-3.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.8/67.8 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cub

In [None]:
! pip install bitsandbytes tokenizers snowflake-connector-python[pandas]

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl (76.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.5


In [None]:
%%capture
!pip install sacrebleu nltk

In [None]:
# All imports
import os
import logging
import numpy as np
import pandas as pd
import torch
from multiprocessing import Pool
from tqdm.auto import tqdm
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer
)
from datasets import Dataset, load_from_disk
from sklearn.model_selection import train_test_split
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas
import sacrebleu
import nltk

nltk.download('punkt', quiet=True)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [None]:
# Logging: INFO and above go both to nllb_training.log (in the working
# directory) and to the notebook's stderr stream.
_log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.FileHandler("nllb_training.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_log_format, handlers=_log_handlers)

# Notebook-wide logger shared by every helper below.
logger = logging.getLogger(__name__)

In [None]:
from google.colab import userdata

# --- Hugging Face authentication (token stored as a Colab secret) ---
HF_TOKEN = userdata.get('HF_TOKEN')
logger.info("Logging into Hugging Face Hub")
login(token=HF_TOKEN)

# --- Compute device: prefer the GPU whenever one is visible ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

# --- On-disk cache locations for the tokenized datasets ---
CACHE_DIR = "./cached_datasets"
os.makedirs(CACHE_DIR, exist_ok=True)
TRAIN_CACHE, VAL_CACHE, TEST_CACHE = (
    os.path.join(CACHE_DIR, split_dir)
    for split_dir in ("train_dataset", "val_dataset", "test_dataset")
)

In [None]:
# Load model and tokenizer
logger.info("Loading model and tokenizer")
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keep the master weights in float32, also on GPU. Mixed precision is applied by the
# Trainer's autocast (bf16/fp16 flags in the training arguments). Storing the weights
# themselves in bfloat16 and optimizing them directly makes many AdamW updates at
# lr=2e-5 round away to nothing, so the model barely learns.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto" if device == "cuda" else None,
    torch_dtype=torch.float32,
)

# Gradient checkpointing (enabled in the training arguments) does not work with the
# decoder KV cache, so disable the cache while training.
model.config.use_cache = False

In [None]:
# Snowflake credentials, read from Colab secrets instead of being hard-coded.
SF_USER, SF_PASSWORD, SF_ACCOUNT = (
    userdata.get(secret_name) for secret_name in ("SF_USER", "SF_PASSWORD", "SF_ACCOUNT")
)

In [None]:
def fetch_data_in_batches(batch_size=10000, max_rows=None):
    """Fetch the parallel Bible corpus (ENG, YOR, LIN) from Snowflake.

    The query runs once, and its result set is streamed with
    ``cursor.fetchmany(batch_size)``. Re-running the query with
    ``LIMIT ... OFFSET ...`` and no ``ORDER BY`` does not guarantee a stable
    row order between executions, so consecutive pages could overlap or skip rows.

    Args:
        batch_size: Rows pulled from the cursor per round trip. Must be > 0.
        max_rows: Optional cap on the number of rows fetched. ``None`` or 0
            means all rows. Which rows are returned under a cap is not specified
            (there is no ORDER BY).

    Returns:
        DataFrame with columns ENG, YOR, LIN. Rows without English text, and
        rows missing both YOR and LIN, are dropped.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    query = "SELECT ENG, YOR, LIN FROM TEXT_LANGUAGE_DATA.BIBLE.RAW_BIBLE"
    if max_rows:
        # int() ensures only a plain integer is interpolated into the SQL text.
        query += f" LIMIT {int(max_rows)}"

    logger.info("Connecting to Snowflake")
    conn = snowflake.connector.connect(
        user=SF_USER,
        password=SF_PASSWORD,
        account=SF_ACCOUNT,
        warehouse='OLASPEAKS',
        database='TEXT_LANGUAGE_DATA',
        schema='PUBLIC',
    )
    all_rows = []
    try:
        cur = conn.cursor()
        try:
            logger.info("Fetching data in batches")
            cur.execute(query)
            while True:
                batch = cur.fetchmany(batch_size)
                if not batch:
                    break
                all_rows.extend(batch)
                logger.info(f"Fetched {len(all_rows)} rows so far...")
        finally:
            cur.close()
    finally:
        # Always release the connection, even if the query or fetch fails.
        conn.close()

    logger.info(f"Total rows fetched: {len(all_rows)}")

    df = pd.DataFrame(all_rows, columns=["ENG", "YOR", "LIN"])

    yor_nan_count = df['YOR'].isna().sum()
    lin_nan_count = df['LIN'].isna().sum()
    logger.info(f"Rows missing Yoruba: {yor_nan_count}, Rows missing Lingala: {lin_nan_count}")

    # Without a source sentence a row can neither be tokenized nor evaluated.
    eng_nan_count = df['ENG'].isna().sum()
    if eng_nan_count:
        logger.info(f"Rows missing English (dropped): {eng_nan_count}")
    df = df.dropna(subset=['ENG'])

    df = df.dropna(subset=['YOR', 'LIN'], how='all')
    logger.info(f"Rows after filtering: {len(df)}")

    return df

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of English -> target-language pairs for NLLB fine-tuning.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batched mapping with list-valued columns ``ENG`` (source text),
            ``target_text`` (reference translation) and ``target_lang`` (an NLLB
            language code such as ``"yor_Latn"``; it may differ per row).

    Returns:
        The tokenizer output for the sources (``input_ids``, ``attention_mask``),
        padded/truncated to 128 tokens. It also has ``labels``: the target token
        ids padded/truncated to 128, with pad positions replaced by -100 so the
        loss ignores them.
    """
    max_length = 128

    tokenizer.src_lang = "eng_Latn"
    model_inputs = tokenizer(
        examples["ENG"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    targets = examples["target_text"]
    target_langs = examples["target_lang"]

    # The target-language code decides which language token the tokenizer puts on
    # the labels. Group rows by language so each language is tokenized in one call
    # instead of one call per row.
    rows_by_lang = {}
    for row_idx, lang in enumerate(target_langs):
        rows_by_lang.setdefault(lang, []).append(row_idx)

    labels = [None] * len(targets)
    for lang, row_indices in rows_by_lang.items():
        tokenizer.tgt_lang = lang
        # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
        encoded = tokenizer(
            text_target=[targets[i] for i in row_indices],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        for row_idx, token_ids in zip(row_indices, encoded["input_ids"]):
            labels[row_idx] = [
                -100 if token_id == tokenizer.pad_token_id else token_id
                for token_id in token_ids
            ]

    model_inputs["labels"] = labels
    return model_inputs

In [None]:
def prepare_datasets(df, test_size=0.1, val_size=0.1, random_state=42):
    """Split ``df`` into train/validation/test frames, stratified by target coverage.

    Each row is labelled ``both``, ``yor_only``, ``lin_only`` or ``none``,
    depending on which of YOR/LIN are present. Both splits are stratified on that
    label so every split gets a similar language mix. ``df`` itself is not modified.

    Args:
        df: DataFrame with at least ENG, YOR and LIN columns.
        test_size: Fraction of all rows assigned to the test split.
        val_size: Fraction of all rows assigned to the validation split.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        ``(train_df, val_df, test_df)``, with the same columns as ``df``.
    """
    has_yor = df['YOR'].notna()
    has_lin = df['LIN'].notna()

    # Later assignments take precedence, giving the same labels as the if/elif chain:
    # both > yor_only > lin_only > none.
    strat = pd.Series('none', index=df.index)
    strat.loc[has_lin] = 'lin_only'
    strat.loc[has_yor] = 'yor_only'
    strat.loc[has_yor & has_lin] = 'both'

    def _stratify_or_none(labels):
        # train_test_split refuses to stratify when a class has fewer than 2
        # members. Fall back to a plain random split instead of crashing.
        return labels if labels.value_counts().min() >= 2 else None

    holdout = test_size + val_size

    # First split into train and temp (test + validation); keep temp's labels aligned.
    train_df, temp_df, _, temp_strat = train_test_split(
        df,
        strat,
        test_size=holdout,
        random_state=random_state,
        stratify=_stratify_or_none(strat),
    )

    # The second split's `test_size` is the share of temp_df returned second, and
    # that share is assigned to test_df. It must therefore be the *test* fraction;
    # using the validation fraction swaps the two sizes when test_size != val_size.
    val_df, test_df = train_test_split(
        temp_df,
        test_size=test_size / holdout,
        random_state=random_state,
        stratify=_stratify_or_none(temp_strat),
    )

    logger.info(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    # Log distribution of languages in each split
    for split_name, split_df in [("Train", train_df), ("Val", val_df), ("Test", test_df)]:
        yor_count = split_df['YOR'].notna().sum()
        lin_count = split_df['LIN'].notna().sum()
        both_count = (split_df['YOR'].notna() & split_df['LIN'].notna()).sum()
        logger.info(f"{split_name} - Yoruba: {yor_count}, Lingala: {lin_count}, Both: {both_count}")

    return train_df, val_df, test_df

In [None]:
def prepare_language_data(train_df, val_df, test_df=None):
    """Build one English -> target DataFrame per (language, split).

    For each split and each target language (Yoruba = YOR / ``yor_Latn``,
    Lingala = LIN / ``lin_Latn``), rows that have that language are kept. Two
    columns are added: ``target_lang`` (the NLLB code) and ``target_text`` (a copy
    of the language column). A few pairs from each frame are logged as a sanity check.

    Args:
        train_df: Training split with ENG/YOR/LIN columns.
        val_df: Validation split.
        test_df: Optional test split. When ``None``, no ``*_test`` entries are built.

    Returns:
        Dict with keys like ``"yoruba_train"``, ``"lingala_train"``, ``"yoruba_val"``, ...,
        in split-then-language order.
    """
    targets = (
        ("yoruba", "YOR", "yor_Latn"),
        ("lingala", "LIN", "lin_Latn"),
    )
    splits = [("train", train_df), ("val", val_df)]
    if test_df is not None:
        splits.append(("test", test_df))

    language_datasets = {}
    for split_name, split_df in splits:
        for lang_name, column, lang_code in targets:
            subset = split_df[split_df[column].notna()].copy()
            subset["target_lang"] = lang_code
            subset["target_text"] = subset[column]
            language_datasets[f"{lang_name}_{split_name}"] = subset

    for dataset_name, dataset in language_datasets.items():
        logger.info(f"{dataset_name}: {len(dataset)} examples")
        if len(dataset) == 0:
            continue

        # Spot-check alignment. The fixed seed makes the log reproducible across runs.
        sample = dataset.sample(min(3, len(dataset)), random_state=0)
        for _, row in sample.iterrows():
            logger.info(f"Sample from {dataset_name}:")
            # str() so a missing or non-string value cannot crash the log line.
            logger.info(f"  Source: {str(row['ENG'])[:50]}...")
            logger.info(f"  Target: {str(row['target_text'])[:50]}...")

    return language_datasets

In [None]:
def translate(text, source_lang, target_lang, max_length=128):
    """Translate text with the notebook's global NLLB ``model``/``tokenizer``.

    Args:
        text: A single string, or a list of strings translated as one padded batch.
        source_lang: NLLB code of the input language, e.g. ``"eng_Latn"``.
        target_lang: NLLB code of the output language, e.g. ``"yor_Latn"``. Its
            language token is forced as the first generated token.
        max_length: Maximum generated length in tokens.

    Returns:
        The translation as a string when ``text`` is a string, otherwise a list
        of translations in input order.
    """
    single = isinstance(text, str)
    tokenizer.src_lang = source_lang
    inputs = tokenizer([text] if single else list(text), return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Set forced BOS token to target language
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    # The model may still be in train mode (dropout active) after fine-tuning, and the
    # loading cell sets config.use_cache=False for gradient checkpointing. Neither is
    # wanted when decoding.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
            use_cache=True,
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded[0] if single else decoded

In [None]:
def calculate_bleu_score(references, hypothesis):
    """Sentence-level NLTK BLEU of one hypothesis against one reference.

    Despite its plural name, ``references`` is a single reference string. Both
    strings are split on whitespace. Smoothing method1 keeps the score from
    collapsing to zero when some n-gram order has no matches.
    """
    return sentence_bleu(
        [references.split()],
        hypothesis.split(),
        smoothing_function=SmoothingFunction().method1,
    )

In [None]:
def calculate_sacrebleu_score(references, hypothesis):
    """sacreBLEU score (0-100) of a single hypothesis against a single reference.

    Uses ``sacrebleu.sentence_bleu``, the API sacreBLEU provides for
    sentence-level scores (it uses the effective n-gram order). ``corpus_bleu``
    applied to a one-sentence "corpus" does not, which penalizes short sentences
    that have no higher-order n-grams.

    Args:
        references: One reference translation (a string, despite the plural name).
        hypothesis: The system translation.

    Returns:
        The BLEU score as a float.
    """
    return sacrebleu.sentence_bleu(hypothesis, [references]).score

In [None]:
def evaluate_translations(model, tokenizer, test_df, source_lang, target_lang, lang_code,
                          max_examples=100, random_state=42):
    """Translate a sample of ``test_df`` and score it with NLTK BLEU and sacreBLEU.

    NOTE(review): translation goes through the module-level ``translate()``
    helper, which reads the global model/tokenizer. The ``model`` and
    ``tokenizer`` arguments are kept for API compatibility but are not used here.

    Args:
        test_df: DataFrame with an ``ENG`` column and the ``lang_code`` column.
        source_lang: NLLB code of the source language (e.g. ``"eng_Latn"``).
        target_lang: NLLB code of the target language (e.g. ``"yor_Latn"``).
        lang_code: Column in ``test_df`` holding the reference translations.
        max_examples: Evaluate at most this many rows (``None`` = all).
        random_state: Seed for choosing the evaluated subset.

    Returns:
        ``None`` if no row has a reference. Otherwise a dict with:
        ``language``; ``avg_bleu`` and ``avg_sacrebleu`` (means of per-sentence
        scores); ``corpus_sacrebleu`` (sacreBLEU over all pairs as one corpus,
        the conventionally reported number); and ``examples`` (per-pair records).
    """
    filtered_df = test_df[test_df[lang_code].notna()]
    if filtered_df.empty:
        logger.warning(f"No test examples available for {target_lang}")
        return None

    # Limit to a reasonable number for evaluation
    if max_examples is not None and len(filtered_df) > max_examples:
        filtered_df = filtered_df.sample(max_examples, random_state=random_state)

    logger.info(f"Evaluating {len(filtered_df)} examples for {source_lang} → {target_lang}")

    examples = []
    for _, row in tqdm(filtered_df.iterrows(), total=len(filtered_df), desc=f"Evaluating {target_lang}"):
        english = row["ENG"]
        expected = row[lang_code]
        translated = translate(english, source_lang, target_lang)

        # Local names must not be `sacrebleu`; that would shadow the module used below.
        sentence_bleu_score = calculate_bleu_score(expected, translated)
        sentence_sacrebleu_score = calculate_sacrebleu_score(expected, translated)

        examples.append({
            "source": english,
            "expected": expected,
            "translated": translated,
            "bleu": sentence_bleu_score,
            "sacrebleu": sentence_sacrebleu_score,
        })

    # `examples` is non-empty because filtered_df is non-empty.
    avg_bleu = sum(ex["bleu"] for ex in examples) / len(examples)
    avg_sacrebleu = sum(ex["sacrebleu"] for ex in examples) / len(examples)
    corpus_sacrebleu = sacrebleu.corpus_bleu(
        [ex["translated"] for ex in examples],
        [[ex["expected"] for ex in examples]],
    ).score

    logger.info(f"Average BLEU score for {target_lang}: {avg_bleu:.4f}")
    logger.info(f"Average sacreBLEU score for {target_lang}: {avg_sacrebleu:.4f}")
    logger.info(f"Corpus sacreBLEU score for {target_lang}: {corpus_sacrebleu:.4f}")

    return {
        "language": target_lang,
        "avg_bleu": avg_bleu,
        "avg_sacrebleu": avg_sacrebleu,
        "corpus_sacrebleu": corpus_sacrebleu,
        "examples": examples,
    }

In [None]:
from datasets import concatenate_datasets

# ready
def _tokenize_language_frame(frame):
    """Convert one language/split DataFrame into a tokenized Hugging Face Dataset."""
    # preserve_index=False: the frames are row-filtered, so their non-default index
    # would otherwise become an extra `__index_level_0__` column in the dataset.
    dataset = Dataset.from_pandas(frame, preserve_index=False)
    return dataset.map(
        preprocess_function,
        batched=True,
        batch_size=64,
        num_proc=4,
        remove_columns=["ENG", "YOR", "LIN", "target_lang", "target_text"],
    )


def _build_split_dataset(lang_data, split_name):
    """Tokenize the Yoruba and Lingala frames of one split and concatenate them."""
    return concatenate_datasets([
        _tokenize_language_frame(lang_data[f"yoruba_{split_name}"]),
        _tokenize_language_frame(lang_data[f"lingala_{split_name}"]),
    ])


def main():
    """Prepare (or load cached) tokenized datasets, fine-tune NLLB, and save it.

    The cache is reused when all tokenized splits and the raw test DataFrame exist
    on disk. Set the environment variable REPROCESS_DATA=1 to rebuild it.

    Returns:
        ``(trainer, test_df)``: the fitted Seq2SeqTrainer and the raw held-out test
        DataFrame used by the evaluation cells.
    """
    test_df_path = "./cached_datasets/test_df.pkl"
    cache_paths = (TRAIN_CACHE, VAL_CACHE, TEST_CACHE, test_df_path)
    use_cache = (
        all(os.path.exists(path) for path in cache_paths)
        and os.environ.get('REPROCESS_DATA') != '1'
    )

    if use_cache:
        logger.info("Loading cached datasets")
        train_dataset = load_from_disk(TRAIN_CACHE)
        val_dataset = load_from_disk(VAL_CACHE)
        # Pickle can execute code on load; this file is written by this function only.
        test_df = pd.read_pickle(test_df_path)
    else:
        # Set to None to process all rows, or a number to limit during development
        max_rows = None  # e.g., 10000 for faster development

        df = fetch_data_in_batches(max_rows=max_rows)

        # Prepare datasets with stratified sampling
        train_df, val_df, test_df = prepare_datasets(df)

        # Save the raw test split for the evaluation cells
        os.makedirs(os.path.dirname(test_df_path), exist_ok=True)
        test_df.to_pickle(test_df_path)

        lang_data = prepare_language_data(train_df, val_df, test_df)

        logger.info("Converting to Hugging Face datasets")
        logger.info("Processing training datasets")
        train_dataset = _build_split_dataset(lang_data, "train")
        logger.info("Processing validation datasets")
        val_dataset = _build_split_dataset(lang_data, "val")
        logger.info("Processing test datasets")
        test_dataset = _build_split_dataset(lang_data, "test")

        logger.info("Caching processed datasets")
        train_dataset.save_to_disk(TRAIN_CACHE)
        val_dataset.save_to_disk(VAL_CACHE)
        test_dataset.save_to_disk(TEST_CACHE)

    logger.info("Setting up training arguments")
    # bf16 autocast only on a GPU that supports it. TrainingArguments rejects
    # bf16=True when no such GPU is available (e.g. on CPU).
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
    training_args = Seq2SeqTrainingArguments(
        output_dir="./nllb-finetuned",
        eval_strategy="steps",  # formerly `evaluation_strategy` (deprecated)
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=2,
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        predict_with_generate=True,
        bf16=use_bf16,
        logging_dir="./logs",
        report_to="tensorboard",
        push_to_hub=False,  # Set to True if you want to push to HF Hub
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding="max_length",
        max_length=128,
    )

    logger.info("Setting up trainer")
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
        data_collator=data_collator,
    )

    logger.info("Starting training")
    trainer.train()

    logger.info("Saving model")
    trainer.save_model("./nllb-finetuned-final")

    return trainer, test_df

In [None]:
# Run the full pipeline; keep the trainer and the raw held-out split for the evaluation cells below.
trainer, test_df = main()

  trainer = Seq2SeqTrainer(


Step,Training Loss,Validation Loss


In [None]:
# Inspection-only cell: prints the docstring of `concatenate_datasets`
# (the training cell already imports it). Safe to skip.
from datasets import concatenate_datasets

help(concatenate_datasets)

Help on function concatenate_datasets in module datasets.combine:

concatenate_datasets(dsets: list[~DatasetType], info: Optional[datasets.info.DatasetInfo] = None, split: Optional[datasets.splits.NamedSplit] = None, axis: int = 0) -> ~DatasetType
    Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
    
    Args:
        dsets (`List[datasets.Dataset]`):
            List of Datasets to concatenate.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        axis (`{0, 1}`, defaults to `0`):
            Axis to concatenate over, where `0` means over rows (vertically) and `1` means over columns
            (horizontally).
    
            <Added version="1.6.0"/>
    
    Example:
    
    ```py
    >>> ds3 = concatenate_datasets([ds1, ds2])
    ```



In [None]:
logger.info("Evaluating model on test set")

# (results key, NLLB target code, reference column in test_df) per evaluated language.
_EVAL_TARGETS = (
    ("yoruba", "yor_Latn", "YOR"),
    ("lingala", "lin_Latn", "LIN"),
)

evaluation_results = {}
for result_key, nllb_code, reference_column in _EVAL_TARGETS:
    lang_results = evaluate_translations(
        model, tokenizer, test_df,
        source_lang="eng_Latn", target_lang=nllb_code, lang_code=reference_column,
    )
    if lang_results:
        evaluation_results[result_key] = lang_results

# Summary metrics plus the first five scored examples per language.
print("\n==== EVALUATION RESULTS ====")
for lang, results in evaluation_results.items():
    heading = lang.upper()
    print(f"\n{heading} TRANSLATION METRICS:")
    print(f"Average BLEU score: {results['avg_bleu']:.4f}")
    print(f"Average sacreBLEU score: {results['avg_sacrebleu']:.4f}")

    print(f"\n{heading} TRANSLATION EXAMPLES:")
    for number, example in enumerate(results['examples'][:5], start=1):
        print(f"Example {number}:")
        print(f"Source: {example['source']}")
        print(f"Expected: {example['expected']}")
        print(f"Translated: {example['translated']}")
        print(f"BLEU: {example['bleu']:.4f}, sacreBLEU: {example['sacrebleu']:.4f}")
        print("-" * 50)

logger.info("Training and evaluation complete")

In [None]:
def interactive_translate():
    """Prompt loop for manual spot checks of English -> Yoruba/Lingala translation.

    Repeatedly asks for English text and a target ('yor' or 'lin'), then prints
    the translation. Typing 'exit' (any case) at the text prompt ends the loop.
    """
    nllb_codes = {'yor': "yor_Latn", 'lin': "lin_Latn"}

    print("Interactive Translation Testing")
    print("Type 'exit' to quit")

    while True:
        text = input("\nEnter English text to translate: ")
        if text.lower() == 'exit':
            return

        choice = input("Translate to (yor/lin): ").lower()
        target_lang = nllb_codes.get(choice)
        if target_lang is None:
            print("Invalid language. Use 'yor' or 'lin'.")
            continue

        print(f"\nTranslation: {translate(text, 'eng_Latn', target_lang)}")

In [None]:
# Blocks on input() until 'exit' is typed; skip this cell during an unattended "Run All".
interactive_translate()

---

In [None]:
def main():
    """Train on the train/validation splits only (alternate entry point).

    NOTE(review): running this cell replaces the ``main()`` defined in the earlier
    training cell, which also caches the test split and returns ``(trainer, test_df)``.
    Keep only one of the two definitions.

    Loads tokenized train/val datasets from TRAIN_CACHE/VAL_CACHE when both exist.
    Otherwise it fetches from Snowflake, splits, tokenizes and caches them. It then
    fine-tunes the global ``model`` and saves it to ./nllb-finetuned-final.

    Returns:
        The fitted Seq2SeqTrainer.
    """
    def _tokenize(frame):
        """Turn one language/split DataFrame into a tokenized Dataset."""
        # preserve_index=False keeps the filtered frame's index out of the dataset.
        dataset = Dataset.from_pandas(frame, preserve_index=False)
        return dataset.map(
            preprocess_function,
            batched=True,
            batch_size=64,
            num_proc=4,
            remove_columns=["ENG", "YOR", "LIN", "target_lang", "target_text"],
        )

    if os.path.exists(TRAIN_CACHE) and os.path.exists(VAL_CACHE):
        logger.info("Loading cached datasets")
        train_dataset = load_from_disk(TRAIN_CACHE)
        val_dataset = load_from_disk(VAL_CACHE)
    else:
        # Set to None to process all rows, or a number to limit during development
        max_rows = None  # e.g., 10000 for faster development

        df = fetch_data_in_batches(batch_size=10000, max_rows=max_rows)
        train_df, val_df, test_df = prepare_datasets(df)

        # Pass all three splits; prepare_language_data takes train, val and test.
        lang_data = prepare_language_data(train_df, val_df, test_df)

        logger.info("Converting to Hugging Face datasets")
        logger.info("Processing training datasets")
        train_dataset = concatenate_datasets(
            [_tokenize(lang_data["yoruba_train"]), _tokenize(lang_data["lingala_train"])]
        )
        logger.info("Processing validation datasets")
        val_dataset = concatenate_datasets(
            [_tokenize(lang_data["yoruba_val"]), _tokenize(lang_data["lingala_val"])]
        )

        logger.info("Caching processed datasets")
        train_dataset.save_to_disk(TRAIN_CACHE)
        val_dataset.save_to_disk(VAL_CACHE)

    logger.info("Setting up training arguments")
    # bf16 autocast on GPUs that support it, matching the other training cell. The
    # previous fp16 flag did not match the bf16 setup used elsewhere in this notebook.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
    training_args = Seq2SeqTrainingArguments(
        output_dir="./nllb-finetuned",
        eval_strategy="steps",  # formerly `evaluation_strategy` (deprecated)
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=2,
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        predict_with_generate=True,
        bf16=use_bf16,
        logging_dir="./logs",
        report_to="tensorboard",
        push_to_hub=False,  # Set to True if you want to push to HF Hub
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding="max_length",
        max_length=128,
    )

    logger.info("Setting up trainer")
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
        data_collator=data_collator,
    )

    logger.info("Starting training")
    trainer.train()

    logger.info("Saving model")
    trainer.save_model("./nllb-finetuned-final")

    return trainer

In [None]:
# Translation helper for the spot-check cell below.
# NOTE(review): this redefines the earlier `translate`; keep the two in sync or delete one.
def translate(text, source_lang, target_lang, max_length=128):
    """Translate text with the notebook's global NLLB ``model``/``tokenizer``.

    Args:
        text: A single string, or a list of strings translated as one padded batch.
        source_lang: NLLB code of the input language, e.g. ``"eng_Latn"``.
        target_lang: NLLB code of the output language. Its language token is
            forced as the first generated token.
        max_length: Maximum generated length in tokens.

    Returns:
        A string for string input, otherwise a list of translations in input order.
    """
    single = isinstance(text, str)
    tokenizer.src_lang = source_lang
    inputs = tokenizer([text] if single else list(text), return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Set forced BOS token to target language
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    # The model may still be in train mode (dropout active) after fine-tuning, and the
    # loading cell sets config.use_cache=False. Neither is wanted when decoding.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
            use_cache=True,
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded[0] if single else decoded

In [None]:
logger.info("Evaluation on test examples:")

# Recover the held-out split if `test_df` is not defined in this kernel
# (e.g. after a restart without re-running the training cell).
if "test_df" not in globals():
    _test_df_path = "./cached_datasets/test_df.pkl"
    if os.path.exists(_test_df_path):
        # Written by the training cell; pickle can run code, so only load files this notebook created.
        test_df = pd.read_pickle(_test_df_path)
    else:
        # NOTE(review): this fallback takes arbitrary rows from the full table, which
        # can overlap the training split, so it is not a true held-out sample.
        test_df = fetch_data_in_batches(batch_size=100, max_rows=100)

# Fixed seed so the printed examples are the same on every run.
test_samples = test_df.sample(min(5, len(test_df)), random_state=42)

for display_name, reference_column, nllb_code in (
    ("Yoruba", "YOR", "yor_Latn"),
    ("Lingala", "LIN", "lin_Latn"),
):
    print(f"\nTranslation Examples (English → {display_name}):")
    for _, row in test_samples.iterrows():
        if pd.isna(row[reference_column]):
            continue
        english = row["ENG"]
        translated = translate(english, "eng_Latn", nllb_code)

        print(f"Source: {english}")
        print(f"Expected: {row[reference_column]}")
        print(f"Translated: {translated}")
        print("-" * 50)

logger.info("Training and evaluation complete")


Translation Examples (English → Yoruba):
Source: (The gold of that land is good; aromatic resin
Expected: “Mo ṣẹ̀ṣẹ̀ rí ẹni tí ó dàbí mi,
Translated: (Gólọ́ọ̀lù ilẹ̀ yẹn dára; òwú òórùn
--------------------------------------------------
Source: God formed a man
Expected: Orúkọ odò kẹta ni Tigirisi, òun ni ó ṣàn lọ sí apá ìlà oòrùn Asiria. Ẹkẹrin ni odò Yufurate.
Translated: Ọlọ́run dá èèyàn
--------------------------------------------------
Source: God had not sent rain on the earth and there was no one to work the ground,
Expected: Láti inú ilẹ̀ ni OLUWA Ọlọrun ti mú kí oríṣìíríṣìí igi hù jáde ninu ọgbà náà, tí wọ́n dùn ún wò, tí wọ́n sì dára fún jíjẹ. Igi ìyè wà láàrin ọgbà náà, ati igi ìmọ̀ ibi ati ire.
Translated: Ọlọ́run kò tíì rọ̀jò lórí ilẹ̀ ayé, kò sì sí ẹni tó lè ṣiṣẹ́ lórí ilẹ̀ náà.
--------------------------------------------------
Source: God had planted a garden in the east, in Eden; and there he put the man he had formed.
Expected: Lẹ́yìn náà OLUWA Ọlọrun sọ pé, “Kò dára