In [1]:
!nvidia-smi

Mon Feb 24 15:59:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A2                      Off |   00000000:CA:00.0 Off |                    0 |
|  0%   48C    P0             25W /   60W |       4MiB /  15356MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os
import numpy as np
import evaluate
from evaluate import evaluator
from datasets import load_dataset
import transformers
import torch
from tqdm import tqdm
import pickle
import re
import peft
from datasets import Dataset
import huggingface_hub
from transformers import BitsAndBytesConfig
# Torch device string used by the cells below: the first CUDA GPU when one is
# available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

  from .autonotebook import tqdm as notebook_tqdm
2025-02-24 15:59:57.919767: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-24 15:59:57.934378: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740412797.952869   56824 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740412797.958301   56824 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-24 15:59:57.984633: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorF

cuda:0


In [3]:
# Shared generation settings, read as globals by the Llama 3, Qwen 2.5 and
# Mistral translate_list_of_str_* helpers.
num_beams = 5          # beams passed to model.generate
max_new_tokens = 512   # cap on generated tokens per input
top_p = 0.9            # nucleus-sampling threshold
temperature = 0.6      # sampling temperature

In [4]:
# Two-letter codes of the languages evaluated (pairs always have English on one side).
language_tested = ["en", "de", "cs", "is", "zh", "ru"] # Only from or to English
# Names of the metrics considered in this notebook.
metrics_available = ["bleu", "rouge", "bleurt", "sacrebleu", "comet", "meteor", "chrf", "bert_score"]
# Hugging Face Hub identifiers of the candidate translation models.
models_available = [
    # NLLB
    "facebook/nllb-200-distilled-600M",
    # ALMA
    "haoranxu/ALMA-7B",
    # Llama 3 Instruct
    "meta-llama/Llama-3.2-1B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    # Falcon 3 Mamba Instruct
    "tiiuae/Falcon3-Mamba-7B-Instruct",
    # Falcon 3 Instruct
    "tiiuae/Falcon3-7B-Instruct",
    "tiiuae/Falcon3-3B-Instruct",
    "tiiuae/Falcon3-1B-Instruct",
    # Qwen 2.5 Instruct
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
    # Mistral Instruct
    "mistralai/Mistral-7B-Instruct-v0.3",
] # TODO mistral 7B, Bloom 7B, OPT 7B, MPT 7B, Bayling (?)

# Hugging Face Hub identifiers of the evaluation datasets.
ds_available = ["haoranxu/WMT23-Test",
                "openlanguagedata/flores_plus"]

## Inference functions

In [5]:
#################################   NLLB

def get_input_targets_NLLB(dataset_wnt_format, source_lang, target_lang):
    """
    Split a WMT-style dataset into source sentences and reference translations.

    The dataset is indexed by the "<source>-<target>" key; each item under that
    key is a mapping holding one entry per language code. NLLB takes the raw
    sentence as model input, so the source list is returned twice (as the
    sources and as the model inputs).

    Returns:
        (sources, inputs, targets) where `sources` and `inputs` are the same list.
    """
    pair_key = f"{source_lang}-{target_lang}"

    source_texts = []
    for pair in dataset_wnt_format[pair_key]:
        source_texts.append(pair[source_lang])

    reference_texts = []
    for pair in dataset_wnt_format[pair_key]:
        reference_texts.append(pair[target_lang])

    return source_texts, source_texts, reference_texts

def translate_list_of_str_NLLB(list_str, tokenizer, model, to_laguage):
    """
    Translate a batch of sentences with an NLLB sequence-to-sequence model.

    Args:
        list_str: list of source sentences (str).
        tokenizer: NLLB tokenizer; called on the whole batch with padding enabled.
        model: NLLB model exposing `generate` and `device`.
        to_laguage: two-letter target language code, one of
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        List of decoded translations (special tokens removed), one per input.

    Raises:
        KeyError: if `to_laguage` is not one of the supported codes.
    """
    # NOTE(review): the source-language tag is whatever the tokenizer is
    # configured with; it is not set here - confirm for non-English sources.
    equivalence_language_to_FLORES = {"en": "eng_Latn", "de": "deu_Latn", "ru": "rus_Cyrl", "is": "isl_Latn", "zh": "zho_Hans", "cs": "ces_Latn"}
    language_tgt_FLORES = equivalence_language_to_FLORES[to_laguage]
    with torch.no_grad():
        # Move the whole encoding (input_ids AND attention_mask) to the model's
        # device so padded positions are explicitly masked during generation.
        inputs = tokenizer(list_str, return_tensors="pt", padding=True).to(model.device)
        translated = model.generate(**inputs,
                                    forced_bos_token_id=tokenizer.convert_tokens_to_ids(language_tgt_FLORES),
                                    num_beams=5, max_length=512, early_stopping=True,
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return translated_text

def translate_batched_NLLB(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate `inputs` with NLLB in consecutive batches of `batch_size`.

    For 8GB VRAM, use batch_size = 4
    For 16GB VRAM, use batch_size = 8 (better working with unbatch version to avoid pad noise).

    Args:
        inputs: list of source sentences.
        model, tokenizer: forwarded to translate_list_of_str_NLLB.
        batch_size: number of sentences per generate call (must be >= 1).
        target_language: two-letter target language code.

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_NLLB(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

#################################   ALMA

def get_input_targets_ALMA(dataset, source_lang, target_lang):
    """
    Build ALMA prompts from a WMT-style dataset.

    Args:
        dataset: mapping indexed by "<source>-<target>"; each item under that
            key holds one entry per language code.
        source_lang, target_lang: two-letter codes among
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        (sources, inputs, targets): raw source sentences, the prompt built for
        each of them, and the reference translations.

    Raises:
        KeyError: if a language code is not supported.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    examples = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in examples]
    # NOTE(review): ALMA's reference formulation is
    # "Translate this from Chinese to English:" followed by "Chinese: ..." and
    # "English:" on their own lines; this prompt omits "this" and keeps a space
    # before the second line break - confirm this is intended.
    inputs = [(
        f"Translate from {source_lang_name} to {target_lang_name}:"
        + f"\n{source_lang_name}: {source} \n{target_lang_name}:")
        for source in sources]
    targets = [example[target_lang] for example in examples]
    return sources, inputs, targets

def translate_list_of_str_ALMA(list_str, tokenizer, model, target_language):
    """
    Translate a batch of ALMA prompts and return only the generated text.

    Args:
        list_str: list of prompts (str).
        tokenizer: tokenizer called on the batch with padding enabled.
        model: causal LM exposing `generate` and `device`.
        target_language: kept for signature compatibility; the prompt is now
            removed by token position, so the language label is not needed.

    Returns:
        List of translations (surrounding whitespace stripped), one per prompt.
    """
    with torch.no_grad():
        # Forward input_ids together with the attention mask so padding is ignored.
        inputs = tokenizer(list_str, return_tensors="pt", padding=True).to(model.device)
        translated = model.generate(**inputs,
                                    num_beams=5, max_new_tokens=512, do_sample=True,
                                    temperature=0.6, top_p=0.9
                                    ).cpu()
        # The output rows are the (padded) prompt followed by the new tokens:
        # slicing at the prompt length removes the prompt reliably, even when
        # the source or the translation itself contains a "<Language>:" label.
        prompt_length = inputs["input_ids"].shape[1]
        translated_text = tokenizer.batch_decode(translated[:, prompt_length:], skip_special_tokens=True)
    return [text.strip() for text in translated_text]

def translate_batched_ALMA(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate `inputs` with ALMA in consecutive batches of `batch_size`.

    For 8GB VRAM, use batch_size=1
    For 16GB VRAM, use batch_size=3

    Args:
        inputs: list of ALMA prompts.
        model, tokenizer: forwarded to translate_list_of_str_ALMA.
        batch_size: number of prompts per generate call (must be >= 1).
        target_language: two-letter target language code.

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_ALMA(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

#################################   Llama 3

def get_input_targets_Llama3(dataset, source_lang, target_lang):
    """
    Build chat-style translation prompts for Llama 3 Instruct from a WMT-style dataset.

    Args:
        dataset: mapping indexed by "<source>-<target>"; each item under that
            key holds one entry per language code.
        source_lang, target_lang: two-letter codes among
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        (sources, inputs, targets): raw source sentences, one conversation
        (system message + user message) per sentence, and the references.

    Raises:
        KeyError: if a language code is not supported.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    examples = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in examples]
    # The user message spells out both language names, then gives the source
    # sentence and an open "<target language>:" label for the model to complete.
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}:"
          + f"\n{source_lang_name}: {source} \n{target_lang_name}:"
          }] for source in sources]
    targets = [example[target_lang] for example in examples]
    return sources, inputs, targets

def extract_translation_Llama3(translated_prompt):
    """
    Isolate the assistant's answer from a fully decoded Llama 3 chat sequence.

    Keeps the text after the last assistant header, cuts it at the first
    "<|end_of_text|>" marker, then keeps only what follows the last
    end-of-turn/header marker still present in the remainder.
    """
    # Text following the last complete assistant header.
    _, _, reply = translated_prompt.rpartition("<|start_header_id|>assistant<|end_header_id|>\n\n")
    # Discard the first end-of-text marker and everything after it.
    reply, _, _ = reply.partition("<|end_of_text|>")
    # Skip past leftover turn markers, most specific pattern first.
    for marker in ("<|eot_id|><|start_header_id|>assistant\n", "<|eot_id|><|start_header_id|>"):
        _, _, reply = reply.rpartition(marker)
    return reply

def translate_list_of_str_Llama3(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat-formatted prompts with a Llama 3 Instruct model.

    Args:
        list_str: list of conversations (each a list of {"role", "content"} dicts).
        tokenizer: tokenizer providing the chat template.
        model: causal LM exposing `generate` and `device`.
        target_language: unused; kept so every translate_list_of_str_* helper
            shares the same signature.

    Returns:
        List of extracted translations, one per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # tokenizing with add_special_tokens=False avoids a duplicated BOS token.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           return_tensors="pt", add_special_tokens=False).to(model.device)
        out_tokens = model.generate(**tokens,
                                    num_beams=num_beams, do_sample=True,
                                    temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        # Decode with special tokens kept: the extraction helper relies on them.
        translations = tokenizer.batch_decode(out_tokens)
    return [extract_translation_Llama3(trans) for trans in translations]
    
def translate_batched_Llama3(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate `inputs` with Llama 3 Instruct in consecutive batches of `batch_size`.

    For 8GB VRAM use
        batch_size=20 with Llama3 1B,
        batch_size=4 with Llama3 3B
    For 16 GB VRAM use
        batch_size=40 with Llama3 1B,
        batch_size=10 with Llama3 3B,
        batch_size=5 with Llama3 8B,

    Args:
        inputs: list of conversations.
        model, tokenizer: forwarded to translate_list_of_str_Llama3.
        batch_size: number of conversations per generate call (must be >= 1).
        target_language: forwarded as-is (the per-batch helper ignores it).

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_Llama3(inputs[start : start + batch_size], tokenizer, model, target_language=target_language)
        preds.extend(tslt)
    return preds

#################################   Falcon 3 (Normal + Mamba)

def get_input_targets_Falcon3(dataset, source_lang, target_lang):
    """
    Build chat-style translation prompts for Falcon 3 and its Mamba version.

    Args:
        dataset: mapping indexed by "<source>-<target>"; each item under that
            key holds one entry per language code.
        source_lang, target_lang: two-letter codes among
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        (sources, inputs, targets): raw source sentences, one conversation
        (system message + user message) per sentence, and the references.

    Raises:
        KeyError: if a language code is not supported.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    examples = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in examples]
    # The user message is the instruction immediately followed by the source
    # sentence (no separator after the colon).
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}:"
          + f"{source}"
          }] for source in sources]
    targets = [example[target_lang] for example in examples]
    return sources, inputs, targets

def extract_translation_Falcon3Mamba(translated_prompt):
    """
    Isolate the assistant's answer from a decoded Falcon 3 Mamba chat sequence.

    Returns the text located after the last assistant-turn opener and before
    the first "<|im_end|>" that follows it.
    """
    _, _, reply = translated_prompt.rpartition("<|im_end|>\n<|im_start|>assistant\n")
    reply, _, _ = reply.partition("<|im_end|>")
    return reply

def translate_list_of_str_Falcon3Mamba(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat-formatted prompts with Falcon 3 Mamba Instruct.

    Args:
        list_str: list of conversations (each a list of {"role", "content"} dicts).
        tokenizer: tokenizer providing the chat template.
        model: causal LM exposing `generate` and `device`.
        target_language: unused; kept so every translate_list_of_str_* helper
            shares the same signature.

    Returns:
        List of extracted translations, one per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # the tokenizer must not add its own on top of them.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           return_tensors="pt", add_special_tokens=False).to(model.device)
        # Generation settings are fixed here (300 new tokens at most) rather
        # than read from the notebook-level constants.
        out_tokens = model.generate(**tokens,
                                    num_beams=5, do_sample=True,
                                    temperature=0.6, top_p=0.9, max_new_tokens=300)
        # Decode with special tokens kept: the extraction helper relies on them.
        translations = tokenizer.batch_decode(out_tokens)
    return [extract_translation_Falcon3Mamba(trans) for trans in translations]
    
def translate_batched_Falcon3Mamba(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate `inputs` with Falcon 3 Mamba in consecutive batches of `batch_size`.

    For 16GB VRAM, use
        batch_size=4 with Falcon Mamba 7B (8 bits quantization),
        batch_size=4 with Falcon Mamba 7B (4 bits quantization),

    Args:
        inputs: list of conversations.
        model, tokenizer: forwarded to translate_list_of_str_Falcon3Mamba.
        batch_size: number of conversations per generate call (must be >= 1).
        target_language: forwarded as-is (the per-batch helper ignores it).

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_Falcon3Mamba(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

def extract_translation_Falcon3(translated_prompt):
    """
    Isolate the assistant's answer from a decoded Falcon 3 chat sequence.

    Steps: keep the text after the last assistant marker, drop leading
    "<|pad|>" tokens, cut at the first "<|endoftext|>", strip leading
    non-word characters, and remove any remaining "assistant|>" marker
    followed by a line break.

    Args:
        translated_prompt: decoded model output, special tokens included.

    Returns:
        The translation text.
    """
    answerpadded = translated_prompt.split("\n<|assistant|>\n")[-1]
    answer = answerpadded.split("<|pad|>")[-1]
    translation_only = answer.split("<|endoftext|>")[0]
    # Strip leading markup/whitespace only. `\W` is Unicode-aware for str
    # patterns, so translations that start with Cyrillic, CJK or accented
    # letters are preserved (an ASCII-only class would eat those letters).
    translation_only = re.sub(r"^[\W_]*", "", translation_only)
    return translation_only.replace("assistant|>\n", "")

def translate_list_of_str_Falcon3(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat-formatted prompts with Falcon 3 Instruct.

    Args:
        list_str: list of conversations (each a list of {"role", "content"} dicts).
        tokenizer: tokenizer providing the chat template.
        model: causal LM exposing `generate` and `device`.
        target_language: unused; kept so every translate_list_of_str_* helper
            shares the same signature.

    Returns:
        List of extracted translations, one per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # the tokenizer must not add its own on top of them.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           return_tensors="pt", add_special_tokens=False).to(model.device)
        # Generation settings are fixed here (500 new tokens at most) rather
        # than read from the notebook-level constants.
        out_tokens = model.generate(**tokens,
                                    num_beams=5, do_sample=True,
                                    temperature=0.6, top_p=0.9, max_new_tokens=500)
        # Decode with special tokens kept: the extraction helper relies on them.
        translations = tokenizer.batch_decode(out_tokens)
    return [extract_translation_Falcon3(trans) for trans in translations]
    
def translate_batched_Falcon3(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate `inputs` with Falcon 3 Instruct in consecutive batches of `batch_size`.

    For 16GB VRAM, use
        batch_size=8 with Falcon 7B (8 bits quantization),
        batch_size=4 with Falcon 3B,
        batch_size=12 with Falcon 1B

    Args:
        inputs: list of conversations.
        model, tokenizer: forwarded to translate_list_of_str_Falcon3.
        batch_size: number of conversations per generate call (must be >= 1).
        target_language: forwarded as-is (the per-batch helper ignores it).

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_Falcon3(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

#################################   Qwen 2.5

def get_input_targets_Qwen2_5(dataset, source_lang, target_lang):
    """
    Build chat-style translation prompts for Qwen 2.5 Instruct.

    Args:
        dataset: mapping indexed by "<source>-<target>"; each item under that
            key holds one entry per language code.
        source_lang, target_lang: two-letter codes among
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        (sources, inputs, targets): raw source sentences, one conversation
        (system message + user message) per sentence, and the references.

    Raises:
        KeyError: if a language code is not supported.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    examples = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in examples]
    # The user message is the instruction immediately followed by the source
    # sentence (no separator after the colon).
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}:"
          + f"{source}"
          }] for source in sources]
    targets = [example[target_lang] for example in examples]
    return sources, inputs, targets

def extract_translation_Qwen2_5(translated_prompt):
    """
    Isolate the assistant's answer from a decoded Qwen 2.5 chat sequence.

    Takes the text after the last assistant-turn opener, keeps what precedes
    the first "<|im_end|>", and deletes every "<|endoftext|>" left in it.
    """
    *_, reply = translated_prompt.split("\n<|im_start|>assistant\n")
    reply, _, _ = reply.partition("<|im_end|>")
    # Joining the pieces around each marker removes all of its occurrences.
    return "".join(reply.split("<|endoftext|>"))

def translate_list_of_str_Qwen2_5(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat-formatted prompts with Qwen 2.5 Instruct.

    Args:
        list_str: list of conversations (each a list of {"role", "content"} dicts).
        tokenizer: tokenizer providing the chat template.
        model: causal LM exposing `generate` and `device`.
        target_language: unused; kept so every translate_list_of_str_* helper
            shares the same signature.

    Returns:
        List of extracted translations, one per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # the tokenizer must not add its own on top of them.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           return_tensors="pt", add_special_tokens=False).to(model.device)
        out_tokens = model.generate(**tokens,
                                    num_beams=num_beams, do_sample=True,
                                    temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        # Decode with special tokens kept: the extraction helper relies on them.
        translations = tokenizer.batch_decode(out_tokens)
    return [extract_translation_Qwen2_5(trans) for trans in translations]

def translate_batched_Qwen2_5(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate `inputs` with Qwen 2.5 Instruct in consecutive batches of `batch_size`.

    For 16GB VRAM, use batch_size=100 (up to)

    Args:
        inputs: list of conversations.
        model, tokenizer: forwarded to translate_list_of_str_Qwen2_5.
        batch_size: number of conversations per generate call (must be >= 1).
        target_language: forwarded as-is (the per-batch helper ignores it).

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_Qwen2_5(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

#################################   Mistral


def get_input_targets_Mistral(dataset, source_lang, target_lang):
    """
    Build chat-style translation prompts for Mistral Instruct.

    Args:
        dataset: mapping indexed by "<source>-<target>"; each item under that
            key holds one entry per language code.
        source_lang, target_lang: two-letter codes among
            "en", "de", "ru", "is", "zh", "cs".

    Returns:
        (sources, inputs, targets): raw source sentences, one conversation
        (system message + user message) per sentence, and the references.

    Raises:
        KeyError: if a language code is not supported.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    examples = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in examples]
    # The user message is the instruction immediately followed by the source
    # sentence (no separator after the colon).
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}:"
          + f"{source}"
          }] for source in sources]
    targets = [example[target_lang] for example in examples]
    return sources, inputs, targets

def extract_translation_Mistral(translated_prompt):
    """
    Isolate the assistant's answer from a decoded Mistral chat sequence.

    Returns the text found after the last "[/INST] " marker and before the
    first "</s>" that follows it.
    """
    _, _, reply = translated_prompt.rpartition("[/INST] ")
    reply, _, _ = reply.partition("</s>")
    return reply

def translate_list_of_str_Mistral(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat-formatted prompts with Mistral Instruct.

    Args:
        list_str: list of conversations (each a list of {"role", "content"} dicts).
        tokenizer: tokenizer providing the chat template.
        model: causal LM exposing `generate` and `device`.
        target_language: unused; kept so every translate_list_of_str_* helper
            shares the same signature.

    Returns:
        List of extracted translations, one per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # tokenizing with add_special_tokens=False avoids a duplicated BOS token.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           return_tensors="pt", add_special_tokens=False).to(model.device)
        out_tokens = model.generate(**tokens,
                                    num_beams=num_beams, do_sample=True,
                                    temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        # Decode with special tokens kept: the extraction helper relies on them.
        translations = tokenizer.batch_decode(out_tokens)
    return [extract_translation_Mistral(trans) for trans in translations]

def translate_batched_Mistral(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate `inputs` with Mistral Instruct in consecutive batches of `batch_size`.

    For 16GB VRAM, use batch_size=2 (up to)

    Args:
        inputs: list of conversations.
        model, tokenizer: forwarded to translate_list_of_str_Mistral.
        batch_size: number of conversations per generate call (must be >= 1).
        target_language: forwarded as-is (the per-batch helper ignores it).

    Returns:
        List with one translation per element of `inputs`, in order.
    """
    preds = []
    # Iterate over batch start offsets so the last, possibly shorter, batch is
    # translated too (an integer-division loop count would silently drop it).
    for start in tqdm(range(0, len(inputs), batch_size)):
        tslt = translate_list_of_str_Mistral(inputs[start : start + batch_size], tokenizer, model, target_language)
        preds.extend(tslt)
    return preds

## Dataset handling
We use WMT23 from the authors' preprocessed split, and the FLORES+ dataset formatted in the same way as WMT23.

### WMT23

In [15]:
# Load the "test" split of the "en-de" configuration of haoranxu/WMT23-Test,
# then print its size and the first four rows as a sanity check.
ds_wnt = load_dataset("haoranxu/WMT23-Test", "en-de")["test"]
print(len(ds_wnt), ds_wnt[0:4])

557 {'en-de': [{'en': 'Police arrest 15 after violent protest outside UK refugee hotel', 'de': 'Polizei verhaftet 15 Menschen nach gewalttätigen Protesten vor einer Flüchtlingsunterkunft in Großbritannien'}, {'en': 'The incident comes after increase in numbers of refugees and asylum seekers crossing the Channel to the UK in boats. Police have arrested 15 people after an anti-refugee demonstration outside a hotel used to house asylum seekers turned violent near the English city of Liverpool. The Merseyside Police department said a police officer and two civilians sustained minor injuries during the disturbance on Friday night in Knowsley. The police force said some protesters threw objects and set a police van on fire. The people arrested, who ranged in age from 13 to 54, were detained "following violent disorder." Merseyside police commissioner Emily Spurrell told Radio City, "It was incredibly dangerous and there were a couple of injuries amongst the police officers."', 'de': 'Der Vor

### FLORES

In [27]:
# Authenticate on the Hugging Face Hub with a token imported from a local
# `credentials` module (so the token itself is not written in the notebook),
# then load the "devtest" split of FLORES+.
from credentials import hf_token
huggingface_hub.login(token = hf_token)
ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]

Resolving data files:   0%|          | 0/219 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/213 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/219 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/213 [00:00<?, ?it/s]

In [28]:
def transform_to_WNT_style(ds_flores, lang, lang_start="en"):
    """
    Re-arrange FLORES rows into the paired layout used by the WMT23 dataset.

    Every row of `ds_flores` is read through its 'iso_639_3' and "text" fields.
    Rows whose ISO code matches `lang` or `lang_start` are collected in
    iteration order, and the i-th sentence of each language is paired together.

    Args:
        ds_flores: iterable of rows exposing 'iso_639_3' and "text".
        lang: two-letter code of the other language of the pair.
        lang_start: two-letter code of the first language of the pair.

    Returns:
        A Dataset with a single column "<lang_start>-<lang>" whose items are
        dicts holding one sentence per language code.

    Raises:
        KeyError: if a language code is not supported.
        AssertionError: if both languages do not have the same number of rows.
    """
    language_to_iso = {"en": "eng", "de": "deu", "cs": "ces", "is": "isl", "zh": "zho", "ru": "rus"}
    # Resolve both codes once, before scanning: unsupported codes fail
    # immediately and the lookups are not repeated for every row.
    iso_lang = language_to_iso[lang]
    iso_lang_start = language_to_iso[lang_start]
    list_sentence_lang, list_sentence_lang_start = [], []
    for elem in ds_flores:
        iso_code = elem['iso_639_3']
        if iso_code == iso_lang:
            list_sentence_lang.append(elem["text"])
        elif iso_code == iso_lang_start:
            list_sentence_lang_start.append(elem["text"])
    # NOTE(review): pairing relies on both languages appearing in the same
    # sentence order in `ds_flores` - confirm against the dataset layout.
    assert len(list_sentence_lang) == len(list_sentence_lang_start), (
        f"Mismatched sentence counts: {len(list_sentence_lang)} for '{lang}' "
        f"vs {len(list_sentence_lang_start)} for '{lang_start}'")
    print(f"Number of samples: {len(list_sentence_lang)}")
    final_text_list = [
        {f"{lang_start}": sentence_start, f"{lang}": sentence}
        for sentence_start, sentence in zip(list_sentence_lang_start, list_sentence_lang)
    ]
    return Dataset.from_dict({f"{lang_start}-{lang}": final_text_list})

In [3]:
# Convert FLORES to the paired "en-de" layout and preview the first four rows.
# Requires the cell defining transform_to_WNT_style to have been run first.
ds_flores_wnt_style = transform_to_WNT_style(ds_flores, lang="de", lang_start="en")
print(ds_flores_wnt_style[0:4])

NameError: name 'transform_to_WNT_style' is not defined

## Metrics from predictions: evaluation function

In [4]:
def evaluate_translation(sources, targets, translation_infered, bertscore_lang="de"):
    """
    Score translations against references with eight metrics from `evaluate`.

    Args:
        sources: list of source sentences (only used by COMET).
        targets: list of reference translations, one per prediction.
        translation_infered: list of predicted translations.
        bertscore_lang: language code passed to BERTScore; it should be the
            language of the references. Defaults to "de", the value that was
            previously hard-coded.

    Returns:
        Dict mapping "bleu", "rouge", "bleurt", "sacrebleu", "comet", "meteor",
        "chrf" and "bertscore" to the raw result returned by each metric.
    """
    print("Computing BLEU...")
    bleu = evaluate.load("bleu")
    results_bleu = bleu.compute(predictions=translation_infered, references=targets)

    print("Computing ROUGE...")
    rouge = evaluate.load('rouge')
    results_rouge = rouge.compute(predictions=translation_infered, references=targets)

    print("Computing BLEURT...")
    bleurt = evaluate.load("bleurt", module_type="metric")
    results_bleurt = bleurt.compute(predictions=translation_infered, references=targets)

    print("Computing SACREBLEU...")
    sacrebleu = evaluate.load("sacrebleu")
    results_sacrebleu = sacrebleu.compute(predictions=translation_infered, references=targets)

    print("Computing COMET...")
    comet_metric = evaluate.load('comet')
    results_comet = comet_metric.compute(predictions=translation_infered, references=targets, sources=sources)

    print("Computing METEOR...")
    meteor = evaluate.load('meteor')
    results_meteor = meteor.compute(predictions=translation_infered, references=targets)

    print("Computing Chrf++...")
    chrf = evaluate.load("chrf")
    # Each prediction gets its single reference wrapped in a list.
    # NOTE(review): the metric runs with its default settings here; the chrF++
    # variant announced by the message is usually selected with word_order=2 -
    # confirm which variant is intended.
    results_chrf = chrf.compute(predictions=translation_infered, references=[[reference] for reference in targets])

    print("Computing_bert_score...")
    bertscore = evaluate.load("bertscore")
    results_bert = bertscore.compute(predictions=translation_infered, references=targets, lang=bertscore_lang)
    return {
        "bleu": results_bleu,
        "rouge": results_rouge,
        "bleurt": results_bleurt,
        "sacrebleu": results_sacrebleu,
        "comet": results_comet,
        "meteor": results_meteor,
        "chrf": results_chrf,
        "bertscore": results_bert,
    }

## Models & inference loops

### NLLB

In [5]:
# Load the NLLB-200 distilled 600M tokenizer and seq2seq model onto `device`.
# `tokenizer` and `model` are notebook-level names that later model sections overwrite.
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype="auto", device_map=device)

In [None]:
# English sources, NLLB inputs (the raw sources) and German references from WMT23.
sources, inputs, targets = get_input_targets_NLLB(ds_wnt, source_lang="en", target_lang="de")

In [11]:
# Qualitative check: translate the first eight inputs one at a time (batches of
# size 1, so no padding is involved) and print each result.
for i in range(8):
    print(translate_list_of_str_NLLB(inputs[i:i+1], tokenizer, model, "de"))

['Polizei verhaftet 15 Personen nach gewalttätigen Protesten vor einem Flüchtlingshotel in Großbritannien']
['Die Polizei hat 15 Personen verhaftet, nachdem eine Anti-Flüchtlingsdemonstration vor einem Hotel, in dem Asylsuchende untergebracht wurden, in der Nähe der englischen Stadt Liverpool gewalttätig wurde. Die Polizeibehörde von Merseyside sagte, dass ein Polizist und zwei Zivilisten während der Unruhe am Freitagabend in Knowsley leicht verletzt wurden. Die Polizei sagte, dass einige Demonstranten Gegenstände geworfen und einen Polizeiwagen in Brand gesetzt hatten. Die festgenommenen Personen, die im Alter von 13 bis 54 Jahren waren, wurden "nach gewalttätiger Unordnung" festgenommen. "Es war unglaublich gefährlich und es gab ein paar Verletzungen unter den Polizisten", sagte die Polizeibeauftragte von Merseyside Emily Spurrell gegenüber Radio City.']
['Das Innenministerium nutzt das Hotel seit dem vergangenen Jahr, um Asylsuchende vorübergehend zu beherbergen, berichten lokale Me

In [9]:
# Translate the first two inputs into German, one sentence per batch.
# The keyword must be `target_language`, the parameter name of translate_batched_NLLB.
predicted_trslt = translate_batched_NLLB(inputs[0:2], model, tokenizer, batch_size = 1, target_language = "de")

100%|██████████| 2/2 [00:04<00:00,  2.38s/it]


In [None]:
# Persist the NLLB en-de predictions so the evaluation can be re-run without
# regenerating them.
with open("./translation_NLLB_en-de_save.pkl", "wb") as f:
    pickle.dump(predicted_trslt, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the saved NLLB en-de predictions.
# NOTE: unpickling can execute arbitrary code - only load files produced by this notebook.
with open("./translation_NLLB_en-de_save.pkl", "rb") as f:
    predicted_trslt = pickle.load(f)

In [12]:
# Score the two predicted translations against the first two sources/references.
evaluation = evaluate_translation(sources[0:2], targets[0:2], predicted_trslt)

Computing BLEU...
Computing ROUGE...
Computing BLEURT...




INFO:tensorflow:Reading checkpoint /home/onyxia/.cache/huggingface/metrics/bleurt/default/downloads/extracted/ec99b4b83def7843831e88a47af3e7c90f3315b30f2a6acb70a5587bac0264e0/bleurt-base-128.


INFO:tensorflow:Reading checkpoint /home/onyxia/.cache/huggingface/metrics/bleurt/default/downloads/extracted/ec99b4b83def7843831e88a47af3e7c90f3315b30f2a6acb70a5587bac0264e0/bleurt-base-128.


INFO:tensorflow:Config file found, reading.


INFO:tensorflow:Config file found, reading.


INFO:tensorflow:Will load checkpoint bert_custom


INFO:tensorflow:Will load checkpoint bert_custom


INFO:tensorflow:Loads full paths and checks that files exists.


INFO:tensorflow:Loads full paths and checks that files exists.


INFO:tensorflow:... name:bert_custom


INFO:tensorflow:... name:bert_custom


INFO:tensorflow:... vocab_file:vocab.txt


INFO:tensorflow:... vocab_file:vocab.txt


INFO:tensorflow:... bert_config_file:bert_config.json


INFO:tensorflow:... bert_config_file:bert_config.json


INFO:tensorflow:... do_lower_case:True


INFO:tensorflow:... do_lower_case:True


INFO:tensorflow:... max_seq_length:128


INFO:tensorflow:... max_seq_length:128


INFO:tensorflow:Creating BLEURT scorer.


INFO:tensorflow:Creating BLEURT scorer.


INFO:tensorflow:Creating WordPiece tokenizer.


INFO:tensorflow:Creating WordPiece tokenizer.


INFO:tensorflow:WordPiece tokenizer instantiated.


INFO:tensorflow:WordPiece tokenizer instantiated.


INFO:tensorflow:Creating Eager Mode predictor.


INFO:tensorflow:Creating Eager Mode predictor.


INFO:tensorflow:Loading model.


INFO:tensorflow:Loading model.


INFO:tensorflow:BLEURT initialized.


INFO:tensorflow:BLEURT initialized.


Computing SACREBLEU...
Computing COMET...


Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 31207.62it/s]
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.5.0.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/huggingface/hub/models--Unbabel--wmt22-comet-da/snapshots/f49d328952c3470eff6bb6f545d62bfdb6e66304/checkpoints/model.ckpt`
/opt/conda/envs/SNLP-gpu/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
huggingface/token

Computing METEOR...


[nltk_data] Downloading package wordnet to /home/onyxia/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/onyxia/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/onyxia/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Computing Chrf++...
Computing_bert_score...


### ALMA

In [4]:
# Load base model and LoRA weights
tokenizer = transformers.LlamaTokenizer.from_pretrained("haoranxu/ALMA-7B", padding_side='left')
Q_config = BitsAndBytesConfig(load_in_8bit=True) 
model = transformers.AutoModelForCausalLM.from_pretrained("haoranxu/ALMA-7B", torch_dtype="auto", device_map=device, quantization_config=Q_config)

Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.34s/it]


In [None]:
# English sources, ALMA prompts and German references from WMT23.
sources, inputs, targets = get_input_targets_ALMA(ds_wnt, source_lang="en", target_lang="de")

In [10]:
# Qualitative check: translate the first four prompts one at a time and print each result.
for i in range(4):
    print(translate_list_of_str_ALMA(inputs[i:i+1], tokenizer, model, "de"))



['Polizei nimmt 15 Personen nach gewalttätigen Protesten außerhalb eines ukrainischen Flüchtlingshotels fest']
['Der Vorfall ereignete sich nach einem Anstieg der Zahl von Flüchtlingen und Asylbewerbern, die mit Booten über den Kanal nach Großbritannien übersetzten. Die Polizei hat 15 Personen nach einer anti-Flüchtlings-Demonstration außerhalb eines Hotels festgenommen, in dem Asylbewerber untergebracht waren. Das Polizeipräsidium Merseyside teilte mit, dass ein Polizeibeamter und zwei Zivilisten leichte Verletzungen erlitten hatten während der Auseinandersetzung am Freitagabend in Knowsley. Die Polizeibehörde teilte mit, einige Demonstranten hätten Gegenstände geworfen und einen Polizeiwagen in Brand gesteckt. Die festgenommenen Personen, die zwischen 13 und 54 Jahre alt waren, wurden "nach gewalttätigen Ausschreitungen" verhaftet. Die Polizeipräsidentin von Merseyside, Emily Spurrell, sagte im Radio City: "Es war unglaublich gefährlich und es gab einige Verletzungen unter den Polize

### Llama 3 Instruct

In [None]:
from credentials import hf_token
huggingface_hub.login(token = hf_token)
# Exactly one tokenizer/model pair must be active: with every option commented out, this cell
# silently reuses the tokenizer and model left over from the previous section.
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto", device_map=device)
# Alternative sizes (swap with the pair above):
# tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Q_config = BitsAndBytesConfig(load_in_8bit=True) 
# model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map=device, quantization_config=Q_config)
# The tokenizer gets EOS as its padding token; mirror it in the generation config
# (same setup as the "llama3" branch of `load_model`).
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

Downloading shards:   0%|          | 0/4 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Extract sources, Llama 3 prompts and reference targets for en -> de from `ds_wnt`.
sources, inputs, targets = get_input_targets_Llama3(ds_wnt, source_lang="en", target_lang="de")

In [None]:
# Translate the first five prompts in a single call; the returned list is the cell output.
translate_list_of_str_Llama3(inputs[0:5], tokenizer, model)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['Polizei verhaftet 15 nach gewaltsamen Protesten vor einem Flüchtlingshotel in Großbritannien',
 'Das Ereignis kommt nach einem Anstieg der Zahlen von Flüchtlingen und Asylsuchenden, die in Booten über den Kanal in das Vereinigte Königreich einreisen. Die Polizei hat 15 Menschen festgenommen, nachdem eine Anti-Flüchtlingsdemonstration vor einem Hotel, das als Unterkunft für Asylsuchende diente, in der Nähe der englischen Stadt Liverpool gewalttätig wurde. Die Polizeibehörde von Merseyside sagte, ein Polizeibeamter und zwei Zivilisten seien bei dem Aufruhr am Freitagabend in Knowsley leicht verletzt worden. Die Polizei sagte, einige Demonstranten hätten Gegenstände geworfen und ein Polizeifahrzeug in Brand gesteckt. Die festgenommenen Personen, die im Alter von 13 bis 54 Jahre alt waren, wurden "nach gewalttätigem Verhalten" festgenommen. Die Polizeikommissarin von Merseyside, Emily Spurrell, sagte bei Radio City, "Es war äußerst gefährlich und es gab einige Verletzungen unter den Poli

### Falcon 3 Instruct (mamba and transformer)

In [4]:
# Active choice: Falcon3-7B-Instruct (transformer), quantized to 8-bit.
Q_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-7B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "tiiuae/Falcon3-7B-Instruct",
    torch_dtype="auto",
    device_map=device,
    quantization_config=Q_config,
)
# Alternatives (swap with the pair above):
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-Mamba-7B-Instruct")
# Q_config = BitsAndBytesConfig(load_in_8bit=True)
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-Mamba-7B-Instruct", torch_dtype="auto", device_map=device, quantization_config=Q_config)
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-3B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-3B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-1B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-1B-Instruct", torch_dtype="auto", device_map=device)

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.66s/it]


In [None]:
# Extract sources, Falcon 3 prompts and reference targets for en -> de from `ds_wnt`.
sources, inputs, targets = get_input_targets_Falcon3(ds_wnt, source_lang="en", target_lang="de")

In [None]:
# Translate the first four prompts with the Mamba-specific helper.
# NOTE(review): the loading cell above currently activates the transformer checkpoint
# (Falcon3-7B-Instruct), not the Mamba one; switch it before running this cell. TODO confirm.
translate_list_of_str_Falcon3Mamba(inputs[0:4], tokenizer, model)

In [None]:
# Translate the first eight prompts with the transformer Falcon 3 helper.
translate_list_of_str_Falcon3(inputs[0:8], tokenizer, model)

Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.


### Qwen 2.5

In [None]:
# Active choice: the smallest Qwen 2.5 instruct checkpoint (0.5B), loaded without quantization.
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    torch_dtype="auto",
    device_map=device,
)
# Alternatives (swap with the pair above):
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8", torch_dtype="auto", device_map=device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [None]:
# Extract sources, Qwen 2.5 prompts and reference targets for en -> de from `ds_wnt`.
sources, inputs, targets = get_input_targets_Qwen2_5(ds_wnt, source_lang="en", target_lang="de")

In [22]:
# Translate the first ten prompts one at a time (slices of length 1) and print each result.
for i in range(10):
    print(translate_list_of_str_Qwen2_5(inputs[i:i+1], tokenizer, model))

['Polizei arrebten 15 nach einer heftigen Protestzeiten in einem UK-reflektierenden Hotel']
['Das Ereignis folgt einer steigenden Anzahl von Flüchtlingen und Asylbewohnern, die den Atlantik über den England nach Großbritannien fliegen. Die Polizei hat 15 Menschen verhaftet, nachdem ein Anti-Refugee-Demonstration im Hotel, wo Asylbewohner wohnen, in der Nähe der Stadt Liverpool ungemacht wurde. Der Merseyside Polizeidirektor Emily Spurrell sagte am Radio City: "Es war sehr gefährlich und es gab ein paar Verletzungen bei den Polizisten."']
['Der Staatsanwalt hat seit dem letzten Jahr das Hotel für Asylbewerber benutzt, um sie zu暂时erhalten. George Howarth, der im UK-Parlament innehat, sagte, dass die Begegnung am Freitag morgens nicht die Gemeinschaft repräsentiert. "Die Menschen von Knowsley sind nicht bigots und sind freundlich zu Menschen, die aus vielen der gefährlichsten Städten in der Welt nach einem Ort der Sicherheit suchen," sagte er. "Diese Menschen, die gegen Asylbewerber prote

### Mistral 7B

In [None]:
# Authenticate on the Hub, then load Mistral-7B-Instruct-v0.3 in 8-bit.
from credentials import hf_token
huggingface_hub.login(token = hf_token)
Q_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# Use the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    torch_dtype="auto",
    device_map=device,
    quantization_config=Q_config,
)

Downloading shards: 100%|██████████| 3/3 [05:45<00:00, 115.15s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.89s/it]


In [None]:
# Extract sources, Mistral prompts and reference targets for en -> de from `ds_wnt`.
sources, inputs, targets = get_input_targets_Mistral(ds_wnt, source_lang="en", target_lang="de")

In [18]:
# Translate the first four prompts one at a time (slices of length 1) and print each result.
for i in range(4):
    print(translate_list_of_str_Mistral(inputs[i:i+1], tokenizer, model))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['Polizei verhaftet 15 nach gewalttätigen Protesten vor Flüchtlingshotel im Vereinigten Königreich.']


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['Das Ereignis ereignete sich nach dem Anstieg der Zahl von Flüchtlingen und Asylsuchenden, die über den Kanal nach Großbritannien in Booten gelangten. Die Polizei hat 15 Personen festgenommen, nachdem eine antiflüchtlingsdemonstration vor einem Hotel, das Asylsuchende unterkunftlich macht, gewalttätig geworden ist, in der Nähe der englischen Stadt Liverpool. Das Polizeidepartement Merseyside erklärte, dass ein Polizeibeamter und zwei Zivilisten leichte Verletzungen erlitten haben, während die Störung am Freitagabend in Knowsley stattfand. Die Polizeibehörde erklärte, dass einige Demonstranten Gegenstände warfen und ein Polizeifahrzeug anzündeten. Die festgenommenen Personen, deren Alter sich zwischen 13 und 54 Jahren befand, wurden "nach gewalttätigem Verhalten" festgehalten. Die Polizeichefin von Merseyside, Emily Spurrell, erzählte Radio City, "Es war sehr gefährlich und es gab einige Verletzungen unter den Polizeibeamten."']


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['Das Home Office nutzt seit letztem Jahr ein Hotel, um Asylsuchende vorübergehend unterzubringen, laut lokalen Medien. George Howarth, der Knowsley im britischen Parlament vertritt, sagte, dass die Gewalt am Freitagabend nicht den Charakter der Gemeinde widerspiegelte. "Die Leute von Knowsley sind nicht Rassisten und freundlich gegenüber Menschen, die aus einigen der gefährlichsten Orte der Welt fliehen, um sich einen sicheren Ort zu suchen." Diejenigen, die gegen Flüchtlinge bei dieser Demonstration demonstrierten, seien nicht die Vertreter dieser Gemeinde. Die Demonstration fand unter erhöhten Spannungen statt, als sich zunehmende Zahlen von Flüchtlingen und Migranten über den Kanal in kleinen Booten begeben.']
['Mehr als 45.000 Menschen erreichten das Vereinigte Königreich über diesen Weg im Jahr 2022, und die meisten meldeten sich als Flüchtlinge an. Das System zur Prüfung von Flüchtlingsanträgen ist aufgrund politischer Unruhen und administrativer Verspätungen fast zum Stillstand

### New model

In [None]:
# Scratch cell: try chat-template prompting on whichever tokenizer/model pair is currently loaded.
messages = [
    [
        {
            "role": "system",
            "content": "Translate from English to Czech:\nEnglish: Hey! I am an american \nCzech:",
        }
    ],
    [
        {
            "role": "system",
            "content": "Translate from German to Czech:\nGerman: Hallo, Ich komme aus Deutschland \nCzech:",
        }
    ],
]
# Render the conversations as plain strings first so the exact prompt text can be inspected.
instruct_messages = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
tokens = tokenizer(instruct_messages, padding=True, return_tensors="pt").to(model.device)
print(tokens)
print(instruct_messages)
# Generation only: gradients are not needed.
with torch.no_grad():
    out_tokens = model.generate(
        **tokens,
        num_beams=5,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
# Decoded without stripping special tokens, so the template markers stay visible.
translations = tokenizer.batch_decode(out_tokens)
translations

## General aggregation functions for benchmark

In [15]:
def _hf_login():
    """Log in to the Hugging Face Hub with the token from the local `credentials` module."""
    from credentials import hf_token
    huggingface_hub.login(token = hf_token)


def _load_causal_lm(repo_id, quantize_8bit=False):
    """Load a causal LM on `device` with automatic dtype, optionally quantized to 8-bit."""
    kwargs = {"torch_dtype": "auto", "device_map": device}
    if quantize_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    return transformers.AutoModelForCausalLM.from_pretrained(repo_id, **kwargs)


def load_model(model_name, model_size=None):
    """Load a tokenizer/model pair and the prompt-building and translation helpers that go with it.

    Args:
        model_name: One of "alma", "nllb", "llama3", "falcon3-mamba", "falcon3",
            "qwen2.5" or "mistral".
        model_size: Optional size tag, only read by the "llama3" ("1B", "3B"),
            "falcon3" ("1B", "3B") and "qwen2.5" ("0.5B", "1.5B", "3B") branches.
            Any other value (including None) selects that family's largest variant.

    Returns:
        Tuple `(tokenizer, model, get_input_targets_fn, tslt_fn)`.

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    if model_name == "alma":
        tokenizer = transformers.LlamaTokenizer.from_pretrained("haoranxu/ALMA-7B", padding_side='left')
        model = _load_causal_lm("haoranxu/ALMA-7B", quantize_8bit=True)
        get_input_targets_fn = get_input_targets_ALMA
        tslt_fn = translate_batched_ALMA

    elif model_name == "nllb":
        # The only encoder-decoder model in the list, hence the Seq2Seq auto class.
        tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        model = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype="auto", device_map=device)
        get_input_targets_fn = get_input_targets_NLLB
        tslt_fn = translate_batched_NLLB

    elif model_name == "llama3":
        _hf_login()
        if model_size in ("1B", "3B"):
            repo_id = f"meta-llama/Llama-3.2-{model_size}-Instruct"
            tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
            model = _load_causal_lm(repo_id)
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
            model = _load_causal_lm("meta-llama/Llama-3.1-8B-Instruct", quantize_8bit=True)
        # Pad with EOS and mirror that choice in the generation config.
        tokenizer.pad_token = tokenizer.eos_token
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        get_input_targets_fn = get_input_targets_Llama3
        tslt_fn = translate_batched_Llama3

    elif model_name == "falcon3-mamba":
        tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-Mamba-7B-Instruct")
        model = _load_causal_lm("tiiuae/Falcon3-Mamba-7B-Instruct", quantize_8bit=True)
        get_input_targets_fn = get_input_targets_Falcon3
        tslt_fn = translate_batched_Falcon3Mamba

    elif model_name == "falcon3":
        if model_size in ("1B", "3B"):
            repo_id = f"tiiuae/Falcon3-{model_size}-Instruct"
            tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
            model = _load_causal_lm(repo_id)
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-7B-Instruct")
            model = _load_causal_lm("tiiuae/Falcon3-7B-Instruct", quantize_8bit=True)
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        get_input_targets_fn = get_input_targets_Falcon3
        tslt_fn = translate_batched_Falcon3

    elif model_name == "qwen2.5":
        if model_size in ("0.5B", "1.5B", "3B"):
            repo_id = f"Qwen/Qwen2.5-{model_size}-Instruct"
        else:
            # The 7B variant is a pre-quantized GPTQ checkpoint, so no bitsandbytes config is passed.
            repo_id = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8"
        tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
        model = _load_causal_lm(repo_id)
        get_input_targets_fn = get_input_targets_Qwen2_5
        tslt_fn = translate_batched_Qwen2_5

    elif model_name == "mistral":
        _hf_login()
        tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
        tokenizer.pad_token = tokenizer.eos_token
        model = _load_causal_lm("mistralai/Mistral-7B-Instruct-v0.3", quantize_8bit=True)
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        get_input_targets_fn = get_input_targets_Mistral
        tslt_fn = translate_batched_Mistral

    else:
        # Previously an unknown name fell through and crashed with UnboundLocalError at the return.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'alma', 'nllb', 'llama3', 'falcon3-mamba', 'falcon3', 'qwen2.5', 'mistral'."
        )

    return tokenizer, model, get_input_targets_fn, tslt_fn


def reduce_dataset(inputs, sources, targets, final_nb):
    """Draw a reproducible random subset of three aligned lists.

    The same indices are applied to all three lists, so element k of each
    returned list comes from the same original position.

    Args:
        inputs, sources, targets: Aligned sequences of equal length.
        final_nb: Requested subset size; clamped to `len(inputs)`.

    Returns:
        Tuple `(inputs_subset, sources_subset, targets_subset)` of lists.
    """
    total = len(inputs)
    sample_size = min(final_nb, total)
    # A local generator keeps the draw reproducible (seed 42) without reseeding NumPy's global RNG.
    rng = np.random.RandomState(42)
    # replace=False: the default samples with replacement and can return the same sentence twice.
    idx = rng.choice(total, size=sample_size, replace=False)
    return [inputs[i] for i in idx], [sources[i] for i in idx], [targets[i] for i in idx]


def generate_translation_different_directions(directions, dataset_name, model_name, batch_size, reduce_size = None, model_size = None):
    """Translate every requested direction with one model and pickle the predictions.

    Args:
        directions: Iterable of direction strings such as "en-de"; characters 0-1 are
            read as the source language and characters 3-4 as the target language.
        dataset_name: "flores" or "wnt23".
        model_name: Name forwarded to `load_model`.
        batch_size: Batch size forwarded to the model's translation function.
        reduce_size: If not None, translate only a random subset of that many samples.
        model_size: Optional size tag forwarded to `load_model`.

    Raises:
        ValueError: If `dataset_name` is not supported.

    Side effects:
        Writes one pickle per direction under ./generated_translations/evaluations/.
        NOTE(review): the file name does not include `model_size`, so runs of the same
        family at different sizes overwrite each other.
    """
    # Fail before the expensive model load; previously this surfaced later as an unbound `ds`.
    if dataset_name not in ("flores", "wnt23"):
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'flores' or 'wnt23'.")

    # Loading full flores (if necessary)
    if dataset_name == "flores":
        from credentials import hf_token
        huggingface_hub.login(token = hf_token)
        ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]

    # Loading corresponding model
    print("Loading model...")
    tokenizer, model, get_input_targets_fn, tslt_fn = load_model(model_name, model_size)

    try:
        # exist_ok avoids the check-then-create race; the directory is the same for every direction.
        os.makedirs("./generated_translations/evaluations", exist_ok=True)

        for direction in directions:
            print(f"Translating {direction} with model {model_name} for dataset {dataset_name}...")
            input_language, target_language = direction[0:2], direction[3:5]

            # Getting the right split corresponding to the translation direction
            if dataset_name == "flores":
                ds = transform_to_WNT_style(ds_flores, lang=target_language, lang_start=input_language)
            elif direction != "cs-en":
                ds = load_dataset("haoranxu/WMT23-Test", direction)["test"]
            else:
                # cs-en is built from the en-cs test split instead of its own config.
                ds = load_dataset("haoranxu/WMT23-Test", "en-cs")["test"]
                ds = Dataset.from_dict({f"cs-en": ds["en-cs"][::-1]}) # Reverse list to avoid having same sentences (if reduce_size not None)

            # Extracting input & targets
            sources, inputs, targets = get_input_targets_fn(ds, input_language, target_language)
            print(f"Total number of samples: {len(sources)}" + ("" if reduce_size is None else f"; reduced to {reduce_size} (numpy seed = 42)"))
            if reduce_size is not None:
                # Keyword arguments: the helper's parameter order is (inputs, sources, targets).
                inputs, sources, targets = reduce_dataset(inputs=inputs, sources=sources, targets=targets, final_nb=reduce_size)

            translation_pred = tslt_fn(inputs, model, tokenizer, batch_size, target_language)
            with open(f"./generated_translations/evaluations/{dataset_name}_{model_name}_{direction}_red-{reduce_size}.pkl", "wb") as f:
                pickle.dump(translation_pred, f, pickle.HIGHEST_PROTOCOL)
    finally:
        # Move the model off its device and drop the references even if a direction fails.
        model.cpu()
        del model, tokenizer


In [7]:
# Translation directions to benchmark, written "<source>-<target>"; each pair is listed both ways.
directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

In [None]:
# Small run for ALMA: 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="alma",
                                          model_size=None,
                                          batch_size=1,
                                          reduce_size=2)

Loading model...


Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.96s/it]


Translating en-de with model alma for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:14<00:00,  7.24s/it]


Translating de-en with model alma for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:11<00:00,  5.65s/it]


Translating en-cs with model alma for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:12<00:00,  6.11s/it]


Translating cs-en with model alma for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:07<00:00,  3.93s/it]


Translating en-is with model alma for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:26<00:00, 13.11s/it]


Translating is-en with model alma for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:11<00:00,  5.75s/it]


Translating en-zh with model alma for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:08<00:00,  4.03s/it]


Translating zh-en with model alma for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:17<00:00,  8.50s/it]


Translating en-ru with model alma for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:11<00:00,  5.56s/it]


Translating ru-en with model alma for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:06<00:00,  3.08s/it]


In [13]:
# Small run for Llama 3 (1B): 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="llama3",
                                          model_size="1B",
                                          batch_size=1,
                                          reduce_size=2)

Loading model...
Translating en-de with model llama3 for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.26s/it]


Translating de-en with model llama3 for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


Translating en-cs with model llama3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.50it/s]


Translating cs-en with model llama3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


Translating en-is with model llama3 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Translating is-en with model llama3 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.30s/it]


Translating en-zh with model llama3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.61it/s]


Translating zh-en with model llama3 for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.04s/it]


Translating en-ru with model llama3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:17<00:00,  8.61s/it]


Translating ru-en with model llama3 for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.97it/s]


In [9]:
# Small run for Falcon3-Mamba: 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="falcon3-mamba",
                                          model_size=None,
                                          batch_size=1,
                                          reduce_size=2)

Loading model...


Downloading shards: 100%|██████████| 3/3 [07:01<00:00, 140.48s/it]
The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.34s/it]


Translating en-de with model falcon3-mamba for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


  0%|          | 0/2 [00:00<?, ?it/s]The 'batch_size' argument of MambaCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
100%|██████████| 2/2 [00:18<00:00,  9.34s/it]


Translating de-en with model falcon3-mamba for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:15<00:00,  7.94s/it]


Translating en-cs with model falcon3-mamba for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:17<00:00,  8.78s/it]


Translating cs-en with model falcon3-mamba for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:40<00:00, 20.23s/it]


Translating en-is with model falcon3-mamba for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [01:18<00:00, 39.22s/it]


Translating is-en with model falcon3-mamba for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [01:16<00:00, 38.42s/it]


Translating en-zh with model falcon3-mamba for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:16<00:00,  8.33s/it]


Translating zh-en with model falcon3-mamba for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:23<00:00, 11.58s/it]


Translating en-ru with model falcon3-mamba for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:42<00:00, 21.39s/it]


Translating ru-en with model falcon3-mamba for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:12<00:00,  6.37s/it]


In [16]:
# Small run for Falcon 3 (1B): 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="falcon3",
                                          model_size="1B",
                                          batch_size=1,
                                          reduce_size=2)

Loading model...
Translating en-de with model falcon3 for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


Translating de-en with model falcon3 for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.24s/it]


Translating en-cs with model falcon3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:21<00:00, 10.76s/it]


Translating cs-en with model falcon3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:20<00:00, 10.39s/it]


Translating en-is with model falcon3 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


Translating is-en with model falcon3 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:04<00:00,  2.14s/it]


Translating en-zh with model falcon3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.49s/it]


Translating zh-en with model falcon3 for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:03<00:00,  1.61s/it]


Translating en-ru with model falcon3 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:20<00:00, 10.14s/it]


Translating ru-en with model falcon3 for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.52it/s]


In [11]:
# Small run for Qwen 2.5 (0.5B): 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="qwen2.5",
                                          model_size="0.5B",
                                          batch_size=1,
                                          reduce_size=2)

Loading model...


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Translating en-de with model qwen2.5 for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:04<00:00,  2.16s/it]


Translating de-en with model qwen2.5 for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.12s/it]


Translating en-cs with model qwen2.5 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.27s/it]


Translating cs-en with model qwen2.5 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.36s/it]


Translating en-is with model qwen2.5 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:03<00:00,  1.87s/it]


Translating is-en with model qwen2.5 for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


Translating en-zh with model qwen2.5 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.91it/s]


Translating zh-en with model qwen2.5 for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:02<00:00,  1.36s/it]


Translating en-ru with model qwen2.5 for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:15<00:00,  7.99s/it]


Translating ru-en with model qwen2.5 for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:01<00:00,  1.29it/s]


In [17]:
# Small run for Mistral 7B: 2 sampled sentences per direction, batch size 1.
generate_translation_different_directions(directions,
                                          dataset_name="wnt23",
                                          model_name="mistral",
                                          model_size=None,
                                          batch_size=1,
                                          reduce_size=2)

Loading model...


Loading checkpoint shards: 100%|██████████| 3/3 [01:05<00:00, 21.75s/it]


Translating en-de with model mistral for dataset wnt23...
Total number of samples: 557; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:22<00:00, 11.06s/it]


Translating de-en with model mistral for dataset wnt23...
Total number of samples: 549; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:13<00:00,  6.77s/it]


Translating en-cs with model mistral for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:39<00:00, 19.61s/it]


Translating cs-en with model mistral for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:14<00:00,  7.16s/it]


Translating en-is with model mistral for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:54<00:00, 27.37s/it]


Translating is-en with model mistral for dataset wnt23...
Total number of samples: 1000; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:16<00:00,  8.43s/it]


Translating en-zh with model mistral for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:23<00:00, 11.53s/it]


Translating zh-en with model mistral for dataset wnt23...
Total number of samples: 1976; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:20<00:00, 10.38s/it]


Translating en-ru with model mistral for dataset wnt23...
Total number of samples: 2074; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [01:25<00:00, 42.90s/it]


Translating ru-en with model mistral for dataset wnt23...
Total number of samples: 1723; reduced to 2 (numpy seed = 42)


100%|██████████| 2/2 [00:12<00:00,  6.33s/it]


In [None]:
def eval_translation_different_directions(directions, dataset_name, model_name, batch_size, model_size=None):
    """Translate and score every requested direction with one model.

    Args:
        directions: Iterable of direction strings such as "en-de"; characters 0-1 are
            read as the source language and characters 3-4 as the target language.
        dataset_name: "flores" or "wnt23".
        model_name: Name forwarded to `load_model`.
        batch_size: Batch size forwarded to the model's translation function.
        model_size: Optional size tag forwarded to `load_model` (default None).

    Returns:
        Dict mapping each direction string to the result of `evaluate_translation`.

    Raises:
        ValueError: If `dataset_name` is not supported.

    Side effects:
        Writes one pickle of predictions per direction under
        ./generated_translations/evaluations/.
    """
    # Fail before the expensive model load; previously this surfaced later as an unbound `ds`.
    if dataset_name not in ("flores", "wnt23"):
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'flores' or 'wnt23'.")

    # Loading full flores (if necessary)
    if dataset_name == "flores":
        from credentials import hf_token
        huggingface_hub.login(token = hf_token)
        ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]

    # Loading corresponding model
    tokenizer, model, get_input_targets_fn, tslt_fn = load_model(model_name, model_size)

    # The output directory may not exist yet when this runs before any generation call.
    os.makedirs("./generated_translations/evaluations", exist_ok=True)

    evaluation_all_direction = {}
    for direction in directions:
        input_language, target_language = direction[0:2], direction[3:5]

        # Getting the right split corresponding to the translation direction
        if dataset_name == "flores":
            ds = transform_to_WNT_style(ds_flores, lang=target_language, lang_start=input_language)
        elif direction != "cs-en":
            ds = load_dataset("haoranxu/WMT23-Test", direction)["test"]
        else:
            # Same handling as the generation function: cs-en is built from the en-cs
            # test split, in reversed order.
            ds = load_dataset("haoranxu/WMT23-Test", "en-cs")["test"]
            ds = Dataset.from_dict({"cs-en": ds["en-cs"][::-1]})

        # Extracting input & targets
        sources, inputs, targets = get_input_targets_fn(ds, input_language, target_language)
        translation_pred = tslt_fn(inputs, model, tokenizer, batch_size, target_language)
        with open(f"./generated_translations/evaluations/{dataset_name}_{model_name}_{direction}.pkl", "wb") as f:
            pickle.dump(translation_pred, f, pickle.HIGHEST_PROTOCOL)
        evaluation_all_direction[f"{direction}"] = evaluate_translation(sources, targets, translation_pred)
    return evaluation_all_direction

def eval_translation_general(directions, dataset_names, model_names, batch_size):
    evaluation_all_datasets = {}
    for dataset_name in dataset_names:
        evaluation_all_models = {}
        for model_name in model_names:
            evaluation_all_models[f"{model_name}"] = eval_translation_different_directions(directions, dataset_name, model_name, batch_size)
        evaluation_all_datasets[f"{dataset_name}"] = evaluation_all_models