In [1]:
%pip install sacrebleu
%pip install datasets
%pip install transformers torch
%pip install --upgrade pip setuptools wheel

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset

# Load the FLORES-101 devtest split for each side of the English→Zulu pair.
# The two splits are expected to be sentence-parallel; the next cell asserts
# that their lengths match before writing them out.
eng_dataset, zul_dataset = (
    load_dataset("gsarti/flores_101", name=lang_code, split="devtest", trust_remote_code=True)
    for lang_code in ("eng", "zul")
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Source (English) and reference (Zulu) files must stay line-aligned:
# line i of each file is sentence i of the devtest split.
assert len(eng_dataset) == len(zul_dataset), "Mismatched dataset sizes"


def write_sentence_file(path, dataset):
    """Write each record's ``"sentence"`` to ``path`` as one UTF-8 line.

    Sentences are stripped and any internal line breaks are replaced by a space,
    so a stray newline inside a sentence cannot shift line alignment between
    the source, reference and hypothesis files.
    """
    with open(path, "w", encoding="utf-8") as f:
        for entry in dataset:
            f.write(" ".join(entry["sentence"].splitlines()).strip() + "\n")


write_sentence_file("flores101.source.txt", eng_dataset)
write_sentence_file("flores101.ref.txt", zul_dataset)

In [4]:
# Reference side for scoring. Reuse the Zulu split already loaded above instead
# of downloading it again, and normalise each sentence onto a single stripped
# line so the reference file stays aligned with the hypothesis file, which is
# also written one stripped line per sentence.
ref_sentences = [" ".join(entry["sentence"].splitlines()).strip() for entry in zul_dataset]

with open("flores101.ref.txt", "w", encoding="utf-8") as f:
    f.writelines(sent + "\n" for sent in ref_sentences)

In [None]:
import logging

import torch
from datasets import load_dataset
# tqdm.auto picks the notebook widget when available and falls back to text.
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# basicConfig is a no-op if the root logger already has handlers; here it makes
# the INFO progress messages from the functions below appear in the cell output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_and_prepare_data(dataset_name="gsarti/flores_101", language="eng", split="devtest"):
    """Load one language split of a FLORES-style dataset as ``{language: sentence}`` records.

    Args:
        dataset_name: Hugging Face Hub dataset id.
        language: dataset config name (e.g. ``"eng"``, ``"zul"``). It is also the
            key of every returned record, so the default call yields
            ``[{"eng": ...}, ...]``, the shape ``translate_english_to_zulu`` reads.
        split: dataset split to load.

    Returns:
        list[dict[str, str]]: one record per row, in dataset order, holding the
        row's ``"sentence"`` field.

    Raises:
        Exception: anything raised while loading or reading the dataset is
            logged and then re-raised unchanged.
    """
    try:
        dataset = load_dataset(dataset_name, name=language, split=split, trust_remote_code=True)
        # Key by the requested language so a non-English split is not
        # mislabelled as "eng".
        return [{language: entry["sentence"]} for entry in dataset]
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise

def initialize_model(model_name="facebook/nllb-200-distilled-600M"):
    """Load a seq2seq checkpoint and its tokenizer, placing the model on GPU if one is available.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        tuple: ``(tokenizer, model, device)``, with ``model`` already moved to ``device``.

    Raises:
        Exception: any loading error is logged and then re-raised unchanged.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(target_device)
    except Exception as err:
        logger.error(f"Error initializing model: {err}")
        raise
    return tok, seq2seq, target_device

def translate_english_to_zulu(ref_sentences, output_file="flores101.hyp.txt",
                              src_lang="eng_Latn", tgt_lang="zul_Latn", batch_size=1):
    """Translate English sentences to Zulu with NLLB-200 and write one hypothesis per line.

    Args:
        ref_sentences: iterable of dicts whose ``"eng"`` value is an English
            *source* sentence (the shape returned by ``load_and_prepare_data()``).
            Despite the name, these are the inputs to translate, not references.
        output_file: hypothesis file to (over)write; line ``i`` holds the
            translation of the ``i``-th input sentence.
        src_lang: NLLB code of the input language, set on the tokenizer.
        tgt_lang: NLLB code forced as the first generated token.
        batch_size: sentences per ``generate`` call. The default of 1 matches the
            original sentence-by-sentence loop; larger values trade memory for speed.

    Raises:
        ValueError: if ``batch_size < 1`` or ``tgt_lang`` is not a token in the
            tokenizer vocabulary.
        Exception: any model-loading or generation error is logged and re-raised.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    try:
        tokenizer, model, device = initialize_model()
        # State the source language explicitly instead of relying on the
        # tokenizer's default.
        tokenizer.src_lang = src_lang

        forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
        # An unknown code maps silently to <unk>, which would then be forced as
        # the first generated token; fail loudly instead.
        if forced_bos_token_id is None or forced_bos_token_id == tokenizer.unk_token_id:
            raise ValueError(f"Unknown target language code: {tgt_lang!r}")

        sources = [entry["eng"] for entry in ref_sentences]
        logger.info(f"Starting translation of {len(sources)} sentences...")

        with open(output_file, "w", encoding="utf-8") as f:
            for start in tqdm(range(0, len(sources), batch_size), desc="Translating"):
                inputs = tokenizer(
                    sources[start:start + batch_size],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)

                with torch.no_grad():
                    translated = model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                    )

                for decoded in tokenizer.batch_decode(translated, skip_special_tokens=True):
                    # Exactly one line per sentence keeps the file aligned with
                    # the reference file.
                    f.write(" ".join(decoded.splitlines()).strip() + "\n")

        logger.info(f"Translation completed. Results saved to {output_file}")
    except Exception as e:
        logger.error(f"Error during translation: {e}")
        raise

# These are the English *source* records (keyed "eng") to be translated. Keep
# them under their own name so they do not overwrite the Zulu `ref_sentences`
# list built earlier for the reference file.
source_sentences = load_and_prepare_data()
translate_english_to_zulu(source_sentences)

INFO:__main__:Starting translation of 1012 sentences...
Translating: 100%|██████████| 1012/1012 [18:39<00:00,  1.11s/it]
INFO:__main__:Translation completed. Results saved to flores101.hyp.txt


In [6]:
from sacrebleu.metrics import BLEU, CHRF

with open("flores101.hyp.txt", "r", encoding="utf-8") as f:
    hyps = [line.strip() for line in f]

with open("flores101.ref.txt", "r", encoding="utf-8") as f:
    refs = [[line.strip() for line in f]]

# Corpus metrics compare line i with line i, so both files must be aligned.
assert len(hyps) == len(refs[0]), (
    f"Hypothesis/reference line count mismatch: {len(hyps)} vs {len(refs[0])}"
)

# Same default BLEU configuration as sacrebleu.corpus_bleu, but via the metric
# object so the reproducibility signature can be reported with the score.
bleu = BLEU()
bleu_score = bleu.corpus_score(hyps, refs)
print(f"BLEU score: {bleu_score.score:.2f}")
print(f"  signature: {bleu.get_signature()}")

# chrF++ (character n-grams plus word bigrams) is commonly reported alongside
# BLEU for morphologically rich languages such as Zulu.
chrf = CHRF(word_order=2)
chrf_score = chrf.corpus_score(hyps, refs)
print(f"chrF++ score: {chrf_score.score:.2f}")
print(f"  signature: {chrf.get_signature()}")

BLEU score: 15.61
