In [1]:
# Install the notebook's dependencies with the %pip magic, one package per line.
# No versions are pinned, so each run installs whatever release is current;
# `-U` additionally upgrades an already-installed `datasets`.
%pip install -U datasets
%pip install transformers
%pip install torch
%pip install sacrebleu
%pip install tqdm
%pip install numpy
%pip install regex
%pip install accelerate
%pip install bert-score
%pip install sentence-transformers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Using cached nvidia_curand_cu12

In [23]:
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from tqdm import tqdm
import logging
from sacrebleu import corpus_bleu
import os
import json
from sacrebleu.metrics import BLEU, CHRF
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer
import numpy as np
from collections import Counter
import re
import requests
from datasets import Dataset

In [3]:
# Pick the compute device in order of preference: CUDA GPU, Apple MPS, then CPU.
# Availability checks are wrapped in lambdas so a later backend is only probed
# when every earlier one turned out to be unavailable.
_device_candidates = (
    ("cuda", lambda: torch.cuda.is_available(), "Using CUDA (NVIDIA GPU) for acceleration"),
    ("mps", lambda: torch.backends.mps.is_available(), "Using MPS (Metal Performance Shaders) for acceleration"),
)
for _backend_name, _backend_available, _backend_banner in _device_candidates:
    if _backend_available():
        device = torch.device(_backend_name)
        print(_backend_banner)
        break
else:
    # No accelerator matched: fall back to the CPU.
    device = torch.device("cpu")
    print("No GPU acceleration available, using CPU")

Using CUDA (NVIDIA GPU) for acceleration


In [31]:
# Reference/source text folders: one root per benchmark ("flores101" and the
# Flores-Fix-4-Africa variant), each with one sub-folder per target language.
for _benchmark_dir in ("flores101", "floresfixforafrica"):
    os.makedirs(_benchmark_dir, exist_ok=True)
    for _language_dir in ("northern-sotho", "hausa", "zulu"):
        os.makedirs(f"{_benchmark_dir}/{_language_dir}", exist_ok=True)

# Metrics folders: one sub-folder per benchmark under "metrics".
os.makedirs("metrics", exist_ok=True)
for _benchmark_dir in ("flores101", "floresfixforafrica"):
    os.makedirs(f"metrics/{_benchmark_dir}", exist_ok=True)

In [5]:
# Load the "devtest" (used as test) and "dev" (used as training) splits of
# gsarti/flores_101 for English, Northern Sotho, Hausa and Zulu.
# NOTE(review): cache_dir=None is passed explicitly on every call; it looks like
# this is also load_dataset's default, in which case the regular cache is still
# used -- confirm against the datasets documentation.
_flores101_splits = {
    (_lang_code, _split_name): load_dataset(
        "gsarti/flores_101", name=_lang_code, split=_split_name, cache_dir=None
    )
    for _lang_code in ("eng", "nso", "hau", "zul")
    for _split_name in ("devtest", "dev")
}

# English (source side)
eng_dataset_test = _flores101_splits[("eng", "devtest")]
eng_dataset_training = _flores101_splits[("eng", "dev")]

# Northern Sotho
nso_dataset_test = _flores101_splits[("nso", "devtest")]
nso_dataset_training = _flores101_splits[("nso", "dev")]

# Hausa
hau_dataset_test = _flores101_splits[("hau", "devtest")]
hau_dataset_training = _flores101_splits[("hau", "dev")]

# Zulu
zul_dataset_test = _flores101_splits[("zul", "devtest")]
zul_dataset_training = _flores101_splits[("zul", "dev")]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.22k [00:00<?, ?B/s]

flores_101.py:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/109k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/117k [00:00<?, ?B/s]

Generating dev split:   0%|          | 0/997 [00:00<?, ? examples/s]

Generating devtest split:   0%|          | 0/1012 [00:00<?, ? examples/s]

0000.parquet:   0%|          | 0.00/116k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/124k [00:00<?, ?B/s]

Generating dev split:   0%|          | 0/997 [00:00<?, ? examples/s]

Generating devtest split:   0%|          | 0/1012 [00:00<?, ? examples/s]

0000.parquet:   0%|          | 0.00/113k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/120k [00:00<?, ?B/s]

Generating dev split:   0%|          | 0/997 [00:00<?, ? examples/s]

Generating devtest split:   0%|          | 0/1012 [00:00<?, ? examples/s]

0000.parquet:   0%|          | 0.00/118k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/126k [00:00<?, ?B/s]

Generating dev split:   0%|          | 0/997 [00:00<?, ? examples/s]

Generating devtest split:   0%|          | 0/1012 [00:00<?, ? examples/s]

In [6]:
def fetch_data_from_github(url, timeout=30):
    """Download a text file and return its lines with surrounding whitespace stripped.

    Parameters:
    - url (str): Address of the raw text file to download.
    - timeout (float): Seconds to wait for the server before giving up. Without
      a timeout, requests.get can block indefinitely on an unresponsive host.

    Returns:
    - list of str: One entry per line of the response body. An empty list is
      returned (after printing a message) when the request fails or the server
      answers with a status other than 200, so callers should validate the
      length of the result.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection problems follow the same "report and return nothing" path
        # as a bad HTTP status instead of aborting the whole cell.
        print(f"Failed to fetch data from {url}: {exc}")
        return []

    if response.status_code != 200:
        print(f"Failed to fetch data from {url} (HTTP {response.status_code})")
        return []

    return [line.strip() for line in response.text.splitlines()]

In [7]:
# data links
zul_dev_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/dev/zul_Latn.dev"
zul_devtest_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/devtest/zul_Latn.devtest"

nso_dev_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/dev/nso_Latn.dev"
nso_devtest_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/devtest/nso_Latn.devtest"

hau_dev_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/dev/hau_Latn.dev"
hau_devtest_url = "https://raw.githubusercontent.com/dsfsi/flores-fix-4-africa/main/data/corrected/devtest/hau_Latn.devtest"

#zulu
zul_dataset_flores_fix_for_africa_test = fetch_data_from_github(zul_devtest_url)
zul_dataset_flores_fix_for_africa_training = fetch_data_from_github(zul_dev_url)

#nothern sotho
nso_dataset_flores_fix_for_africa_test = fetch_data_from_github(nso_devtest_url)
nso_dataset_flores_fix_for_africa_training = fetch_data_from_github(nso_dev_url)

#hausa
hau_dataset_flores_fix_for_africa_test = fetch_data_from_github(hau_devtest_url)
hau_dataset_flores_fix_for_africa_training = fetch_data_from_github(hau_dev_url)

In [8]:
# Every FLORES-101 target language must have exactly as many sentences as the
# English source, for both the test and the training split.
for _target_test, _target_training in (
    (nso_dataset_test, nso_dataset_training),
    (zul_dataset_test, zul_dataset_training),
    (hau_dataset_test, hau_dataset_training),
):
    assert len(eng_dataset_test) == len(_target_test), "Mismatched test dataset sizes"
    assert len(eng_dataset_training) == len(_target_training), "Mismatched training dataset sizes"

In [9]:
# The corrected (Flores-Fix-4-Africa) sentence lists must line up with the
# English source as well; an empty list from a failed download is caught here.
for _fixed_test, _fixed_training in (
    (nso_dataset_flores_fix_for_africa_test, nso_dataset_flores_fix_for_africa_training),
    (zul_dataset_flores_fix_for_africa_test, zul_dataset_flores_fix_for_africa_training),
    (hau_dataset_flores_fix_for_africa_test, hau_dataset_flores_fix_for_africa_training),
):
    assert len(eng_dataset_test) == len(_fixed_test), "Mismatched test dataset sizes"
    assert len(eng_dataset_training) == len(_fixed_training), "Mismatched training dataset sizes"

In [10]:
# Dump the FLORES-101 English source and target-language reference sentences
# to one-sentence-per-line text files, four files per language folder.
for _lang_dir, _ref_test_rows, _ref_training_rows in (
    ("northern-sotho", nso_dataset_test, nso_dataset_training),
    ("hausa", hau_dataset_test, hau_dataset_training),
    ("zulu", zul_dataset_test, zul_dataset_training),
):
    _file_stem = f"flores101/{_lang_dir}/flores101.{_lang_dir}"
    # (file suffix, rows) in the order the files are written.
    _file_plan = (
        ("source.test", eng_dataset_test),
        ("ref.test", _ref_test_rows),
        ("source.training", eng_dataset_training),
        ("ref.training", _ref_training_rows),
    )
    for _suffix, _rows in _file_plan:
        with open(f"{_file_stem}.{_suffix}.txt", "w", encoding="utf-8") as _out:
            _out.writelines(_row["sentence"] + "\n" for _row in _rows)


In [11]:
# Dump the English source (from the FLORES-101 rows) and the corrected
# Flores-Fix-4-Africa references (plain strings) to one-sentence-per-line files.
for _lang_dir, _fixed_test_lines, _fixed_training_lines in (
    ("northern-sotho", nso_dataset_flores_fix_for_africa_test, nso_dataset_flores_fix_for_africa_training),
    ("hausa", hau_dataset_flores_fix_for_africa_test, hau_dataset_flores_fix_for_africa_training),
    ("zulu", zul_dataset_flores_fix_for_africa_test, zul_dataset_flores_fix_for_africa_training),
):
    _file_stem = f"floresfixforafrica/{_lang_dir}/floresfixforafrica.{_lang_dir}"

    # Source side: rows are dicts with a "sentence" field.
    for _suffix, _rows in (
        ("source.test", eng_dataset_test),
        ("source.training", eng_dataset_training),
    ):
        with open(f"{_file_stem}.{_suffix}.txt", "w", encoding="utf-8") as _out:
            _out.writelines(_row["sentence"]  + "\n" for _row in _rows)

    # Reference side: already a list of sentence strings.
    for _suffix, _lines in (
        ("ref.test", _fixed_test_lines),
        ("ref.training", _fixed_training_lines),
    ):
        with open(f"{_file_stem}.{_suffix}.txt", "w", encoding="utf-8") as _out:
            _out.writelines(_sentence + "\n" for _sentence in _lines)

In [12]:
def load_and_prepare_data(target_lang, split="dev"):
    """Build an English -> target-language sentence-pair dataset from facebook/flores.

    Parameters:
    - target_lang (str): Short language code; one of "hau", "nso" or "zul".
    - split (str): Dataset split to load for both languages (default "dev").

    Returns:
    - A Hugging Face Dataset with 'input_text' (English sentences) and
      'target_text' (target-language sentences), paired by position.

    Raises:
    - ValueError: If the language code is unsupported, or if the two loaded
      splits differ in length (positional pairing would then be wrong).
    - Any exception raised by load_dataset; it is printed and re-raised.

    NOTE(review): other cells of this notebook read "gsarti/flores_101" while
    this function reads "facebook/flores" -- confirm that both provide the same
    English sentences in the same order.
    """
    dataset_name = "facebook/flores"
    source_lang = "eng_Latn"

    # Map short language codes to the config names that include the script.
    lang_code_map = {
        "hau": "hau_Latn",
        "nso": "nso_Latn",
        "zul": "zul_Latn"
    }

    try:
        target_lang_full = lang_code_map.get(target_lang)
        if not target_lang_full:
            raise ValueError(f"Unsupported target language code: {target_lang}")

        # Load source and target datasets
        source_dataset = load_dataset(dataset_name, name=source_lang, split=split, trust_remote_code=True)
        target_dataset = load_dataset(dataset_name, name=target_lang_full, split=split, trust_remote_code=True)

        # Sentences are paired by index, so both sides must have the same size
        # (same guarantee create_dataset_from_sentences gives).
        if len(source_dataset) != len(target_dataset):
            raise ValueError("Source and target datasets must have the same length.")

        # Create training pairs
        training_data = {
            "input_text": [src["sentence"] for src in source_dataset],
            "target_text": [tgt["sentence"] for tgt in target_dataset]
        }

        # Convert to HuggingFace Dataset (Dataset is imported at the top of the notebook).
        return Dataset.from_dict(training_data)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise


In [13]:
def initialize_model(model_path="facebook/nllb-200-distilled-600M"):
    """Load a pretrained seq2seq model and its tokenizer onto the best available device.

    Parameters:
    - model_path (str): Identifier or path handed to from_pretrained; defaults
      to the distilled 600M NLLB-200 checkpoint.

    Returns:
    - (tokenizer, model, device): the tokenizer, the model already moved to
      `device`, and the torch device that was selected.

    Any exception is printed and re-raised.
    """
    try:
        # Device preference: CUDA GPU first, then Apple MPS, otherwise CPU.
        # Checks are lazy so MPS is only probed when CUDA is unavailable.
        accelerator_options = (
            ("cuda", lambda: torch.cuda.is_available(), "Using CUDA (NVIDIA GPU) for acceleration"),
            ("mps", lambda: torch.backends.mps.is_available(), "Using MPS (Metal Performance Shaders) for acceleration"),
        )
        for backend_name, backend_available, banner in accelerator_options:
            if backend_available():
                break
        else:
            backend_name, banner = "cpu", "No GPU acceleration available, using CPU"

        device = torch.device(backend_name)
        print(banner)

        # Load tokenizer first, then the model, and move the model to the device.
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model = model.to(device)

        return tokenizer, model, device
    except Exception as e:
        print(f"Error initializing model: {e}")
        raise


In [14]:
def translate_english_to_target_lang(model, tokenizer, device, ref_sentences, output_file, target_lang_code, batch_size=16):
    """Translate English sentences in batches and write one translation per line.

    Parameters:
    - model: Seq2seq model exposing generate().
    - tokenizer: Matching tokenizer (callable, with convert_tokens_to_ids and
      batch_decode).
    - device: Device the tokenized batches are moved to before generation.
    - ref_sentences: Mapping-like object whose "input_text" entry holds the
      English source sentences.
    - output_file (str): Destination path for the translations.
    - target_lang_code (str): One of "hau", "nso" or "zul".
    - batch_size (int): Number of sentences translated per generate() call.

    Raises:
    - ValueError: If target_lang_code is not supported.
    - Any exception from tokenization/generation/IO; it is printed and re-raised.

    The translations are first written to "<output_file>.tmp" and only renamed
    to output_file once every batch succeeded. The driver cells skip translation
    whenever output_file exists and is non-empty, so a half-written file left by
    a crash must never appear under the final name.
    """
    tmp_file = f"{output_file}.tmp"
    try:
        # Map language codes to NLLB language codes
        nllb_lang_codes = {
            "hau": "hau_Latn",
            "nso": "nso_Latn",
            "zul": "zul_Latn"
        }

        target_lang = nllb_lang_codes.get(target_lang_code)
        if not target_lang:
            raise ValueError(f"Unsupported target language code: {target_lang_code}")

        # Generation is forced to start with the target-language token.
        forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

        print(f"Starting translation to {target_lang}...")

        # Materialize as a plain list so slicing yields a list of strings
        # regardless of the container type of the column.
        all_sentences = list(ref_sentences["input_text"])

        # Ensure output directory exists; a bare file name has no directory
        # part and os.makedirs("") would raise.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(tmp_file, "w", encoding="utf-8") as f:
            # Process in batches
            for i in tqdm(range(0, len(all_sentences), batch_size), desc=f"Translating to {target_lang}"):
                batch_sentences = all_sentences[i:i + batch_size]

                # Tokenize batch without length limits
                inputs = tokenizer(
                    batch_sentences,
                    return_tensors="pt",
                    padding=True,
                    truncation=False
                ).to(device)

                with torch.no_grad():
                    translated = model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                        num_beams=4,
                        early_stopping=True
                    )

                # Decode and write batch results
                decoded = tokenizer.batch_decode(translated, skip_special_tokens=True)
                for translation in decoded:
                    f.write(translation.strip() + "\n")

        # Publish the complete file under its final name in one step.
        os.replace(tmp_file, output_file)

        print(f"Translation completed. Results saved to {output_file}")
    except Exception as e:
        print(f"Error during translation: {e}")
        raise
    finally:
        # After a successful rename the temp file no longer exists; after a
        # failure this removes the partial output.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)


In [15]:
def calculate_metrics(hypotheses, references, lang, device):
    """Score translations against references with several metrics.

    Parameters:
    - hypotheses (list of str): System translations.
    - references (list of str): Reference translations, paired by position.
    - lang: Forwarded unchanged to bert_score as its `lang` argument.
      NOTE(review): the driver cells pass folder names such as "hausa" rather
      than language codes -- confirm how bert_score treats such values.
    - device: Forwarded to bert_score as its `device` argument.

    Returns:
    - dict with the keys 'BLEU', 'chrF', 'BERTScore', 'Semantic_Score' and
      'Error_Distribution'. If any step raises, the exception is printed, its
      text is stored under 'error', and the metrics computed so far are
      returned; later metrics are then missing.
    """
    results = {}

    try:
        # Corpus-level BLEU, with `references` as the single reference set.
        bleu_scorer = BLEU()
        results['BLEU'] = bleu_scorer.corpus_score(hypotheses, [references]).score

        # Corpus-level chrF, same inputs.
        chrf_scorer = CHRF()
        results['chrF'] = chrf_scorer.corpus_score(hypotheses, [references]).score

        # BERTScore: only the mean of the F1 component is kept.
        _precision, _recall, f1 = bert_score(hypotheses, references, lang=lang, device=device)
        results['BERTScore'] = f1.mean().item()

        # Semantic Similarity Score (replacing COMET): mean cosine similarity
        # between sentence embeddings of each hypothesis/reference pair.
        encoder = SentenceTransformer('all-MiniLM-L6-v2')
        hyp_vectors = encoder.encode(hypotheses, convert_to_tensor=True)
        ref_vectors = encoder.encode(references, convert_to_tensor=True)
        pairwise_similarity = torch.nn.functional.cosine_similarity(hyp_vectors, ref_vectors)
        results['Semantic_Score'] = pairwise_similarity.mean().item()

        # Error Type Frequency Distribution
        results['Error_Distribution'] = analyze_errors(hypotheses, references)

    except Exception as e:
        print(f"Error calculating metrics: {e}")
        results['error'] = str(e)

    return results


In [16]:
def analyze_errors(hypotheses, references):
    """Count simple surface-level differences between hypotheses and references.

    Each hypothesis is compared with the reference at the same position using
    whitespace-separated tokens (as sets, so repeated words count once):

    - 'word_order': same token set but different strings (+1 per sentence).
    - 'missing_words': reference tokens absent from the hypothesis (+count).
    - 'extra_words': hypothesis tokens absent from the reference (+count).
    - 'case_errors': strings equal only after lower-casing (+1 per sentence).
    - 'punctuation_errors': strings equal only after removing every character
      that is neither a word character nor whitespace (+1 per sentence).

    Returns a plain dict containing only the categories that occurred.
    """
    non_word_non_space = re.compile(r'[^\w\s]')
    tally = Counter()

    for candidate, gold in zip(hypotheses, references):
        strings_differ = candidate != gold
        candidate_vocab = set(candidate.split())
        gold_vocab = set(gold.split())

        # Word order errors
        if strings_differ and candidate_vocab == gold_vocab:
            tally['word_order'] += 1

        # Missing words
        absent = gold_vocab - candidate_vocab
        if absent:
            tally['missing_words'] += len(absent)

        # Extra words
        surplus = candidate_vocab - gold_vocab
        if surplus:
            tally['extra_words'] += len(surplus)

        # Case errors
        if strings_differ and candidate.lower() == gold.lower():
            tally['case_errors'] += 1

        # Punctuation errors
        if strings_differ and non_word_non_space.sub('', candidate) == non_word_non_space.sub('', gold):
            tally['punctuation_errors'] += 1

    return dict(tally)


In [29]:
# Function to save metrics to a file
def save_metrics_to_file(metrics, lang, datasetname, output_dir="metrics"):
    """Write a metrics dict as JSON to <output_dir>/<datasetname>/<lang>_metrics.json.

    Parameters:
    - metrics (dict): JSON-serializable metric values.
    - lang (str): Language label used in the file name.
    - datasetname (str): Sub-folder of output_dir the file is stored in.
    - output_dir (str): Root folder for all metrics files.

    The destination folder, including the datasetname sub-folder, is created
    if it does not exist yet. The path of the written file is printed.
    """
    metrics_file = f"{output_dir}/{datasetname}/{lang}_metrics.json"
    # Create the full parent folder of the file, not just output_dir; otherwise
    # open() fails for a datasetname that has no folder yet.
    os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
    print(metrics_file)

    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=4)

In [18]:
# Function to load saved metrics from a file
def load_metrics_from_file(lang, datasetname, output_dir="metrics"):
    """Read <output_dir>/<datasetname>/<lang>_metrics.json if it exists.

    Returns the parsed JSON content, or None when there is no such file.
    """
    cached_path = f"{output_dir}/{datasetname}/{lang}_metrics.json"

    # Nothing cached yet for this language/dataset combination.
    if not os.path.exists(cached_path):
        return None

    with open(cached_path, "r", encoding="utf-8") as handle:
        cached_metrics = json.load(handle)
    return cached_metrics

In [19]:
def create_dataset_from_sentences(source_dataset, target_dataset):
    """
    Pair source rows with target sentences and wrap them in a Hugging Face Dataset.

    Parameters:
    - source_dataset: Sized iterable of rows, each with a "sentence" entry
      (e.g. the English split).
    - target_dataset: Sized iterable of target-language sentences (e.g. Hausa),
      aligned with source_dataset by position.

    Returns:
    - A Hugging Face Dataset object containing 'input_text' and 'target_text'.

    Raises:
    - ValueError: If the two inputs differ in length.
    """
    # Pairing is positional, so the two sides must be the same size.
    if len(source_dataset) != len(target_dataset):
        raise ValueError("Source and target datasets must have the same length.")

    source_sentences = [row["sentence"] for row in source_dataset]
    target_sentences = list(target_dataset)

    return Dataset.from_dict({
        "input_text": source_sentences,
        "target_text": target_sentences,
    })


In [33]:
# Configuration: folder name of each target language -> short language code
languages = {
    "hausa": "hau",
    "northern-sotho": "nso",
    "zulu": "zul"
}

model_name = "nllb-200-distilled-600M"
model_path = "facebook/nllb-200-distilled-600M"
output_base_dir = "output"

# Initialize model and tokenizer
tokenizer, model, device = initialize_model(model_path)

# Translate (unless a translation file already exists) and score each language
for lang, code in languages.items():
    output_file = f"{output_base_dir}/{model_name}/flores101.{lang}.hyp.txt"

    # Check if translation already exists
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        print(f"Translation for {lang} already exists. Skipping translation...")
    else:
        print(f"Translating to {lang}...")
        test_data = load_and_prepare_data(target_lang=code, split="devtest")
        translate_english_to_target_lang(model, tokenizer, device, test_data, output_file, code)

    # Check if metrics are already saved
    metrics = load_metrics_from_file(lang,"flores101")

    # calculate_metrics records a failure under an 'error' key; a cached file
    # carrying that key is an incomplete result and must be recomputed.
    if metrics and "error" in metrics:
        metrics = None

    if not metrics:
        # Calculate metrics
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                hyps = [line.strip() for line in f]

            with open(f"flores101/{lang}/flores101.{lang}.ref.test.txt", "r", encoding="utf-8") as f:
                refs = [line.strip() for line in f]

            # Hypotheses and references are paired line by line.
            if len(hyps) != len(refs):
                raise ValueError(f"{len(hyps)} hypotheses but {len(refs)} references")

            metrics = calculate_metrics(hyps, refs, lang, device)

            # Do not cache (or try to format) a partial result.
            if "error" in metrics:
                raise RuntimeError(metrics["error"])

            # Save metrics to file
            save_metrics_to_file(metrics,lang,"flores101")

            print(f"\nMetrics for {lang}:")
            print("-" * 30)
            for metric_name, value in metrics.items():
                if metric_name != 'Error_Distribution':
                    print(f"{metric_name:15}: {value:.4f}")

            if 'Error_Distribution' in metrics:
                print("\nError Distribution:")
                for error_type, count in metrics['Error_Distribution'].items():
                    print(f"{error_type:20}: {count}")

        except Exception as e:
            print(f"Error calculating metrics for {lang}: {e}")
            print(f"Could not calculate metrics for {lang}")

# Print summary of all metrics
print("\nSummary of all metrics:")
print("-" * 60)
print(f"{'Language':15} {'BLEU':>8} {'chrF':>8} {'BERTScore':>10} {'Semantic':>8}")
print("-" * 60)

for lang in languages:
    try:
        # Load metrics from file
        metrics = load_metrics_from_file(lang,"flores101")

        # Only complete results are shown; anything else gets an "Error" row.
        if metrics and "error" not in metrics:
            print(f"{lang:15} {metrics['BLEU']:8.2f} {metrics['chrF']:8.2f} {metrics['BERTScore']:10.2f} {metrics['Semantic_Score']:8.2f}")
        else:
            print(f"{lang:15} {'Error':>8} {'Error':>8} {'Error':>10} {'Error':>8}")

    except Exception as e:
        print(f"Error loading metrics for {lang}: {e}")
        print(f"{lang:15} {'Error':>8} {'Error':>8} {'Error':>10} {'Error':>8}")

print("-" * 60)


Using CUDA (NVIDIA GPU) for acceleration
Translation for hausa already exists. Skipping translation...
Translation for northern-sotho already exists. Skipping translation...
Translation for zulu already exists. Skipping translation...

Summary of all metrics:
------------------------------------------------------------
Language            BLEU     chrF  BERTScore Semantic
------------------------------------------------------------
hausa              23.80    51.36       0.82     0.79
northern-sotho     21.75    51.05       0.81     0.77
zulu               16.82    56.27       0.83     0.79
------------------------------------------------------------


In [28]:
# Initialize model and tokenizer
# NOTE(review): this reloads the same checkpoint the flores101 evaluation cell
# already loaded; kept so the cell does not depend on that cell's model object.
tokenizer, model, device = initialize_model(model_path)

# Corrected (Flores-Fix-4-Africa) test references, keyed by language folder name.
fixed_test_references = {
    "hausa": hau_dataset_flores_fix_for_africa_test,
    "zulu": zul_dataset_flores_fix_for_africa_test,
    "northern-sotho": nso_dataset_flores_fix_for_africa_test,
}

# Translate (unless a translation file already exists) and score each language
for lang, code in languages.items():
    output_file = f"{output_base_dir}/{model_name}/floresfixforafrica.{lang}.hyp.txt"

    # Check if translation already exists
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        print(f"Translation for {lang} already exists. Skipping translation...")
    else:
        print(f"Translating to {lang}...")
        # A language without an entry raises KeyError here instead of silently
        # skipping the translation and failing later on a missing file.
        test_data = create_dataset_from_sentences(eng_dataset_test, fixed_test_references[lang])
        translate_english_to_target_lang(model, tokenizer, device, test_data, output_file, code)

    # Check if metrics are already saved
    metrics = load_metrics_from_file(lang,"floresfixforafrica")

    # calculate_metrics records a failure under an 'error' key; a cached file
    # carrying that key is an incomplete result and must be recomputed.
    if metrics and "error" in metrics:
        metrics = None

    if not metrics:
        # Calculate metrics
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                hyps = [line.strip() for line in f]

            with open(f"floresfixforafrica/{lang}/floresfixforafrica.{lang}.ref.test.txt", "r", encoding="utf-8") as f:
                refs = [line.strip() for line in f]

            # Hypotheses and references are paired line by line.
            if len(hyps) != len(refs):
                raise ValueError(f"{len(hyps)} hypotheses but {len(refs)} references")

            metrics = calculate_metrics(hyps, refs, lang, device)

            # Do not cache (or try to format) a partial result.
            if "error" in metrics:
                raise RuntimeError(metrics["error"])

            # Save metrics to file
            save_metrics_to_file(metrics, lang, "floresfixforafrica")

            print(f"\nMetrics for {lang}:")
            print("-" * 30)
            for metric_name, value in metrics.items():
                if metric_name != 'Error_Distribution':
                    print(f"{metric_name:15}: {value:.4f}")

            if 'Error_Distribution' in metrics:
                print("\nError Distribution:")
                for error_type, count in metrics['Error_Distribution'].items():
                    print(f"{error_type:20}: {count}")

        except Exception as e:
            print(f"Error calculating metrics for {lang}: {e}")
            print(f"Could not calculate metrics for {lang}")

# Print summary of all metrics
print("\nSummary of all metrics:")
print("-" * 60)
print(f"{'Language':15} {'BLEU':>8} {'chrF':>8} {'BERTScore':>10} {'Semantic':>8}")
print("-" * 60)

for lang in languages:
    try:
        # Load metrics from file
        metrics = load_metrics_from_file(lang,"floresfixforafrica")

        # Only complete results are shown; anything else gets an "Error" row.
        if metrics and "error" not in metrics:
            print(f"{lang:15} {metrics['BLEU']:8.2f} {metrics['chrF']:8.2f} {metrics['BERTScore']:10.2f} {metrics['Semantic_Score']:8.2f}")
        else:
            print(f"{lang:15} {'Error':>8} {'Error':>8} {'Error':>10} {'Error':>8}")

    except Exception as e:
        print(f"Error loading metrics for {lang}: {e}")
        print(f"{lang:15} {'Error':>8} {'Error':>8} {'Error':>10} {'Error':>8}")

print("-" * 60)

Using CUDA (NVIDIA GPU) for acceleration
Translation for hausa already exists. Skipping translation...
Translation for northern-sotho already exists. Skipping translation...
Translation for zulu already exists. Skipping translation...

Summary of all metrics:
------------------------------------------------------------
Language            BLEU     chrF  BERTScore Semantic
------------------------------------------------------------
hausa              23.45    51.08       0.82     0.79
northern-sotho     22.06    51.28       0.81     0.77
zulu               17.83    57.14       0.84     0.80
------------------------------------------------------------
