In [None]:
# Show the visible GPU(s), driver version and current memory usage before loading any model.
!nvidia-smi

In [None]:
import os
import re
import pickle
import time
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches

import torch

import huggingface_hub
from datasets import load_dataset, Dataset
import transformers
from transformers import BitsAndBytesConfig
import evaluate
from evaluate import evaluator
from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a
from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh

# First CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Decoding hyper-parameters shared by every model's generate() call below.
num_beams = 5
max_new_tokens = 512
top_p = 0.9
temperature = 0.6

# Most generate() calls below use do_sample=True, so fix the RNG seeds to make runs reproducible.
seed = 0
torch.manual_seed(seed)  # seeds the torch RNG on all devices
np.random.seed(seed)

In [None]:
# Short language codes evaluated in this notebook.
language_tested = ["en", "de", "cs", "is", "zh", "ru"] # Only from or to English
# Candidate evaluation metric names.
metrics_available = ["bleu", "rouge", "bleurt", "sacrebleu", "comet", "meteor", "chrf", "bert_score"]
# Candidate Hugging Face Hub checkpoints, grouped by model family.
models_available = [
    # NLLB
    "facebook/nllb-200-distilled-600M",
    # ALMA
    "haoranxu/ALMA-7B",
    # Llama 3 Instruct
    "meta-llama/Llama-3.2-1B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    # Falcon 3 Mamba Instruct
    "tiiuae/Falcon3-Mamba-7B-Instruct",
    # Falcon 3 Instruct
    "tiiuae/Falcon3-1B-Instruct",
    "tiiuae/Falcon3-3B-Instruct",
    "tiiuae/Falcon3-7B-Instruct",
    # Qwen 2.5 Instruct
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    # Mistral Instruct
    "mistralai/Mistral-7B-Instruct-v0.3",
    # BayLing
    "ICTNLP/bayling-2-7b",
    # Bloom & Bloomz
    "bigscience/bloom-560m",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigscience/bloomz-1b7",
    "bigscience/bloomz-3b",
    "bigscience/bloomz-7b1",
    # OPT
    "facebook/opt-125m",
    "facebook/opt-350m",
    "facebook/opt-6.7b",
    "facebook/opt-iml-1.3b",
    # MPT
    "mosaicml/mpt-7b-instruct",
]


# Evaluation datasets on the Hugging Face Hub (WMT23 test sets and FLORES+).
ds_available = ["haoranxu/WMT23-Test",
                "openlanguagedata/flores_plus"]

## Inference functions

In [None]:
#################################   NLLB

def get_input_targets_NLLB(dataset_wnt_format, source_lang, target_lang):
    """
    Split a WMT-style dataset into NLLB inputs and references.

    NLLB consumes raw sentences, so the same list is returned as both `sources` and `inputs`.

    Returns:
        (sources, inputs, targets) where `inputs is sources`.
    """
    pairs = dataset_wnt_format[f"{source_lang}-{target_lang}"]
    source_sentences = [pair[source_lang] for pair in pairs]
    references = [pair[target_lang] for pair in pairs]
    return source_sentences, source_sentences, references

def translate_list_of_str_NLLB(list_str, tokenizer, model, to_laguage, from_language=None):
    """
    Translate a batch of sentences with NLLB and return the decoded translations.

    Args:
        list_str: source sentences, padded together into one batch.
        tokenizer: NLLB tokenizer.
        model: NLLB seq2seq model.
        to_laguage: short target-language code (e.g. "de"), mapped to its FLORES code.
        from_language: optional short source-language code. When given, the tokenizer's
            `src_lang` is set to the matching FLORES code for this call only and restored
            afterwards; when None, the tokenizer's current `src_lang` is used unchanged.

    Returns:
        list[str]: one translation per input sentence.
    """
    equivalence_language_to_FLORES = {"en": "eng_Latn", "de": "deu_Latn", "ru": "rus_Cyrl", "is": "isl_Latn", "zh": "zho_Hans", "cs": "ces_Latn"}
    language_tgt_FLORES = equivalence_language_to_FLORES[to_laguage]
    previous_src_lang = getattr(tokenizer, "src_lang", None)
    if from_language is not None:
        tokenizer.src_lang = equivalence_language_to_FLORES[from_language]
    try:
        with torch.no_grad():
            inputs = tokenizer(list_str, return_tensors="pt", padding=True)
            translated = model.generate(inputs["input_ids"].to(device),
                                        # Pass the padding mask explicitly instead of relying on
                                        # generate() to infer it from the pad token id.
                                        attention_mask=inputs["attention_mask"].to(device),
                                        forced_bos_token_id=tokenizer.convert_tokens_to_ids(language_tgt_FLORES),
                                        num_beams=num_beams, max_length=max_new_tokens, early_stopping=True,
                                        ).cpu()
    finally:
        if from_language is not None:
            tokenizer.src_lang = previous_src_lang
    return tokenizer.batch_decode(translated, skip_special_tokens=True)

def translate_batched_NLLB(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate `inputs` with NLLB in chunks of `batch_size`; returns one prediction per input.

    The last chunk may be shorter than `batch_size`, so every input is translated.

    For 8GB VRAM, use batch_size = 4
    For 16GB VRAM, use batch_size = 8 (better working with unbatch version to avoid pad noise).
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_NLLB(batch, tokenizer, model, target_language))
    return preds

#################################   ALMA

def get_input_targets_ALMA(dataset, source_lang, target_lang):
    """
    Build ALMA completion prompts for the `source_lang`-`target_lang` column of a WMT-style dataset.

    Each prompt reads "Translate from <Src> to <Tgt>:", then a line "<Src>: <sentence> ",
    then a line "<Tgt>:" that the model completes.

    Returns:
        (sources, prompts, targets): raw source sentences, prompts, reference translations.
    """
    # NOTE(review): "Islandic" is a misspelling of "Icelandic", but translate_list_of_str_ALMA
    # splits the decoded text on the same name, so both must be changed together.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    src_name = language_name[source_lang]
    tgt_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [pair[source_lang] for pair in pairs]
    prompts = [
        f"Translate from {src_name} to {tgt_name}:\n{src_name}: {pair.get(source_lang)} \n{tgt_name}:"
        for pair in pairs
    ]
    targets = [pair[target_lang] for pair in pairs]
    return sources, prompts, targets

def translate_list_of_str_ALMA(list_str, tokenizer, model, target_language):
    """
    Translate a batch of ALMA prompts and return only the generated translations.

    Args:
        list_str: prompts built by `get_input_targets_ALMA`.
        tokenizer: ALMA tokenizer (loaded with padding_side='left' in this notebook).
        model: ALMA causal LM.
        target_language: short target-language code; its name is used to strip the echoed prompt.

    Returns:
        list[str]: for each decoded output, the text between the 2nd and 3rd occurrence of "<Target>:".
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    # Resolve the name before generating so an unknown code fails fast.
    tgt_language_name = language_name[target_language]
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(inputs["input_ids"].to(device),
                                    # Explicit mask so padded positions of shorter prompts are ignored.
                                    attention_mask=inputs["attention_mask"].to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # The prompt contains "<Target>:" twice ("... to German:" and "\nGerman:"), so the
    # generated text starts after the second occurrence (remove prompt).
    return [t.split(f"{tgt_language_name}:")[2] for t in translated_text]

def translate_batched_ALMA(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate ALMA prompts in chunks of `batch_size`; returns one prediction per prompt.

    The last chunk may be shorter than `batch_size`, so every prompt is translated.

    For 8GB VRAM, use batch_size=1
    For 16GB VRAM, use batch_size=3
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_ALMA(batch, tokenizer, model, target_language))
    return preds

#################################   Llama 3

def get_input_targets_Llama3(dataset, source_lang, target_lang):
    """
    Build chat-format prompts (system + user message) for Llama 3 Instruct.

    The user message reads "Translate from <Src> to <Tgt>:", then a line
    "<Src>: <sentence> ", then a line "<Tgt>:".

    Returns:
        (sources, conversations, targets): raw source sentences, one message list per
        sentence, and reference translations.
    """
    # "Icelandic" is only used inside the prompt; the Llama 3 output extraction does not rely on it.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in pairs]
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}:"
          + f"\n{source_lang_name}: {example.get(source_lang)} \n{target_lang_name}:"
          }] for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets

def extract_translation_Llama3(translated_prompt):
    """
    Extract the assistant answer from a decoded Llama 3 Instruct generation.

    Steps:
      1. keep the text after the last assistant header;
      2. cut at the first <|end_of_text|>;
      3. when the model opened further turns, keep the text after the last
         "<|eot_id|><|start_header_id|>" marker;
      4. strip trailing <|eot_id|> tokens.

    Without the last step, the translation still ends with the <|eot_id|> turn terminator
    whenever padding follows it.
    """
    answer = translated_prompt.split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1]
    translation_only = answer.split("<|end_of_text|>")[0]
    translation_only = translation_only.split("<|eot_id|><|start_header_id|>assistant\n")[-1]
    translation_only = translation_only.split("<|eot_id|><|start_header_id|>")[-1]
    eot = "<|eot_id|>"
    while translation_only.endswith(eot):
        translation_only = translation_only[:-len(eot)]
    return translation_only

def translate_list_of_str_Llama3(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat conversations with a Llama 3 Instruct model.

    Args:
        list_str: list of conversations (lists of {"role", "content"} dicts), e.g. from
            `get_input_targets_Llama3`.
        tokenizer: Llama 3 tokenizer with a pad token set.
        model: Llama 3 causal LM.
        target_language: unused; kept so every `translate_list_of_str_*` shares a signature.

    Returns:
        list[str]: one extracted translation per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the BOS token; add_special_tokens=False
        # keeps the tokenizer from prepending a second one.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           add_special_tokens=False, return_tensors="pt").to(model.device)
        out_tokens = model.generate(**tokens,
                                    num_beams=num_beams, do_sample=True,
                                    temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        translations = tokenizer.batch_decode(out_tokens)
        return [extract_translation_Llama3(trans) for trans in translations]
    
def translate_batched_Llama3(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate Llama 3 conversations in chunks of `batch_size`; returns one prediction per input.

    The last chunk may be shorter than `batch_size`, so every conversation is translated.

    For 8GB VRAM use
        batch_size=20 with Llama3 1B,
        batch_size=4 with Llama3 3B
    For 16 GB VRAM use 
        batch_size=40 with Llama3 1B,
        batch_size=10 with Llama3 3B,
        batch_size=5 with Llama3 8B,
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Llama3(batch, tokenizer, model, target_language=None))
    return preds

#################################   Falcon 3 (Normal + Mamba)

def get_input_targets_Falcon3(dataset, source_lang, target_lang):
    """
    Build chat-format prompts for Falcon 3 and its Mamba variant.

    The user message is "Translate from <Src> to <Tgt>: <sentence>"; the space after the
    colon keeps the instruction and the sentence from being glued together.

    Returns:
        (sources, conversations, targets): raw source sentences, one message list per
        sentence, and reference translations.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in pairs]
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}: {example.get(source_lang)}"
          }] for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets

def extract_translation_Falcon3Mamba(translated_prompt):
    """
    Return the assistant turn of a decoded Falcon 3 Mamba generation.

    Takes the text after the last "<|im_end|>, newline, <|im_start|>assistant, newline" marker
    (or the whole string if the marker is absent) and cuts it at the first <|im_end|>.
    """
    _, _, assistant_turn = translated_prompt.rpartition("<|im_end|>\n<|im_start|>assistant\n")
    translation, _, _ = assistant_turn.partition("<|im_end|>")
    return translation

def translate_list_of_str_Falcon3Mamba(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat conversations with Falcon 3 Mamba.

    `target_language` is unused and only kept for a uniform `translate_list_of_str_*` signature.
    Returns one extracted translation per conversation.
    """
    with torch.no_grad():
        prompts = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(prompts, padding=True, padding_side='left', return_tensors="pt")
        encoded = encoded.to(model.device)
        sampling = dict(num_beams=num_beams, do_sample=True,
                        temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        generated = model.generate(**encoded, **sampling)
        decoded = tokenizer.batch_decode(generated)
        return list(map(extract_translation_Falcon3Mamba, decoded))
    
def translate_batched_Falcon3Mamba(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate conversations with Falcon 3 Mamba in chunks of `batch_size`; one prediction per input.

    The last chunk may be shorter than `batch_size`, so every conversation is translated.

    For 16GB VRAM, use
        batch_size=4 with Falcon Mamba 7B (8 bits quantization),
        batch_size=4 with Falcon Mamba 7B (4 bits quantization),
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Falcon3Mamba(batch, tokenizer, model, target_language))
    return preds

def extract_translation_Falcon3(translated_prompt):
    """
    Extract the assistant answer from a decoded Falcon 3 generation.

    Keeps the text after the last assistant marker, removes <|pad|> tokens, cuts at the
    first <|endoftext|>, strips leading characters that are neither letters nor digits
    (in any script), and removes a stray "assistant|>" fragment.
    """
    answer = translated_prompt.split("\n<|assistant|>\n")[-1]
    # Pads can sit before the answer (left-padded prompt) or after it (finished sequences
    # padded to the batch length); removing them all handles both cases.
    answer = answer.replace("<|pad|>", "")
    translation_only = answer.split("<|endoftext|>")[0]
    # [\W_] is Unicode-aware: it strips leading whitespace/punctuation/markup while keeping
    # CJK, Cyrillic and accented letters (an ASCII-only [^a-zA-Z0-9] class deletes them).
    translation_only = re.sub(r"^[\W_]*", "", translation_only)
    return translation_only.replace("assistant|>\n", "")

def translate_list_of_str_Falcon3(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat conversations with a Falcon 3 Instruct model.

    `target_language` is unused and only kept for a uniform `translate_list_of_str_*` signature.
    Returns one extracted translation per conversation.
    """
    with torch.no_grad():
        prompts = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(prompts, padding=True, padding_side='left', return_tensors="pt")
        encoded = encoded.to(model.device)
        sampling = dict(num_beams=num_beams, do_sample=True,
                        temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        generated = model.generate(**encoded, **sampling)
        decoded = tokenizer.batch_decode(generated)
        return list(map(extract_translation_Falcon3, decoded))
    
def translate_batched_Falcon3(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate conversations with Falcon 3 in chunks of `batch_size`; one prediction per input.

    The last chunk may be shorter than `batch_size`, so every conversation is translated.

    For 16GB VRAM, use
        batch_size=8 with Falcon 7B (8 bits quantization),
        batch_size=4 with Falcon 3B,
        batch_size=12 with Falcon 1B
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Falcon3(batch, tokenizer, model, target_language))
    return preds

#################################   Qwen 2.5

def get_input_targets_Qwen2_5(dataset, source_lang, target_lang):
    """
    Build chat-format prompts for Qwen 2.5 Instruct.

    The user message is "Translate from <Src> to <Tgt>: <sentence>"; the space after the
    colon keeps the instruction and the sentence from being glued together.

    Returns:
        (sources, conversations, targets): raw source sentences, one message list per
        sentence, and reference translations.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in pairs]
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}: {example.get(source_lang)}"
          }] for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets

def extract_translation_Qwen2_5(translated_prompt):
    """
    Return the assistant turn of a decoded Qwen 2.5 generation.

    Takes the text after the last assistant marker (or the whole string if absent), cuts it
    at the first <|im_end|>, and removes every <|endoftext|> token.
    """
    assistant_turn = translated_prompt.rpartition("\n<|im_start|>assistant\n")[2]
    translation = assistant_turn.partition("<|im_end|>")[0]
    return translation.replace("<|endoftext|>", "")

def translate_list_of_str_Qwen2_5(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat conversations with a Qwen 2.5 Instruct model.

    `target_language` is unused and only kept for a uniform `translate_list_of_str_*` signature.
    Returns one extracted translation per conversation.
    """
    with torch.no_grad():
        prompts = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(prompts, padding=True, padding_side='left', return_tensors="pt")
        encoded = encoded.to(model.device)
        sampling = dict(num_beams=num_beams, do_sample=True,
                        temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        generated = model.generate(**encoded, **sampling)
        decoded = tokenizer.batch_decode(generated)
        return list(map(extract_translation_Qwen2_5, decoded))

def translate_batched_Qwen2_5(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate conversations with Qwen 2.5 in chunks of `batch_size`; one prediction per input.

    The last chunk may be shorter than `batch_size`, so every conversation is translated.

    For 16GB VRAM, use batch_size=100 (up to)
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Qwen2_5(batch, tokenizer, model, target_language))
    return preds

#################################   Mistral


def get_input_targets_Mistral(dataset, source_lang, target_lang):
    """
    Build chat-format prompts for Mistral Instruct.

    The user message is "Translate from <Src> to <Tgt>: <sentence>"; the space after the
    colon keeps the instruction and the sentence from being glued together.

    Returns:
        (sources, conversations, targets): raw source sentences, one message list per
        sentence, and reference translations.
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in pairs]
    inputs = [
        [{"role": "system", "content": "You are a translator, you output only the translation in the desired language."},
         {"role": "user",
          "content": f"Translate from {source_lang_name} to {target_lang_name}: {example.get(source_lang)}"
          }] for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets

def extract_translation_Mistral(translated_prompt):
    """
    Return the assistant answer of a decoded Mistral Instruct generation.

    Takes the text after the last "[/INST] " marker (or the whole string if absent) and
    cuts it at the first "</s>".
    """
    after_instruction = translated_prompt.rpartition("[/INST] ")[2]
    return after_instruction.partition("</s>")[0]

def translate_list_of_str_Mistral(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of chat conversations with a Mistral Instruct model.

    Args:
        list_str: list of conversations (lists of {"role", "content"} dicts).
        tokenizer: Mistral tokenizer.
        model: Mistral causal LM.
        target_language: unused; kept so every `translate_list_of_str_*` shares a signature.

    Returns:
        list[str]: one extracted translation per conversation.
    """
    with torch.no_grad():
        instruct_messages = tokenizer.apply_chat_template(list_str, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already starts with the BOS token; add_special_tokens=False
        # keeps the tokenizer from prepending a second one.
        tokens = tokenizer(instruct_messages, padding=True, padding_side='left',
                           add_special_tokens=False, return_tensors="pt").to(model.device)
        out_tokens = model.generate(**tokens,
                                    num_beams=num_beams, do_sample=True,
                                    temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
        translations = tokenizer.batch_decode(out_tokens)
        return [extract_translation_Mistral(trans) for trans in translations]

def translate_batched_Mistral(inputs, model, tokenizer, batch_size, target_language=None):
    """
    Translate conversations with Mistral in chunks of `batch_size`; one prediction per input.

    The last chunk may be shorter than `batch_size`, so every conversation is translated.

    For 16GB VRAM, use batch_size=2 (up to)
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Mistral(batch, tokenizer, model, target_language))
    return preds

#################################   BayLing

def get_input_targets_BayLing(dataset, source_lang, target_lang):
    """
    Build BayLing completion prompts for the `source_lang`-`target_lang` column of a WMT-style dataset.

    Each prompt reads "Translate from <Src> to <Tgt>:", then a line "<Src>: <sentence> ",
    then a line "<Tgt>:" that the model completes.

    Returns:
        (sources, prompts, targets): raw source sentences, prompts, reference translations.
    """
    # NOTE(review): "Islandic" is a misspelling of "Icelandic", but translate_list_of_str_BayLing
    # splits the decoded text on the same name, so both must be changed together.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    src_name = language_name[source_lang]
    tgt_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [pair[source_lang] for pair in pairs]
    prompts = [
        f"Translate from {src_name} to {tgt_name}:\n{src_name}: {pair.get(source_lang)} \n{tgt_name}:"
        for pair in pairs
    ]
    targets = [pair[target_lang] for pair in pairs]
    return sources, prompts, targets

def translate_list_of_str_BayLing(list_str, tokenizer, model, target_language):
    """
    Translate a batch of BayLing prompts and return only the generated translations.

    Args:
        list_str: prompts built by `get_input_targets_BayLing`.
        tokenizer: BayLing tokenizer.
        model: BayLing causal LM.
        target_language: short target-language code; its name is used to strip the echoed prompt.

    Returns:
        list[str]: for each decoded output, the text between the 2nd and 3rd occurrence of "<Target>:".
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    # Resolve the name before generating so an unknown code fails fast.
    tgt_language_name = language_name[target_language]
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(inputs["input_ids"].to(device),
                                    # Explicit mask so padded positions of shorter prompts are ignored.
                                    attention_mask=inputs["attention_mask"].to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # The prompt contains "<Target>:" twice, so the generated text starts after the second one (remove prompt).
    return [t.split(f"{tgt_language_name}:")[2] for t in translated_text]

def translate_batched_BayLing(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate BayLing prompts in chunks of `batch_size`; returns one prediction per prompt.

    The last chunk may be shorter than `batch_size`, so every prompt is translated.
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_BayLing(batch, tokenizer, model, target_language))
    return preds

#################################   BLOOM & BLOOMZ

def get_input_targets_BLOOM(dataset, source_lang, target_lang):
    """
    Build BLOOM/BLOOMZ completion prompts for the `source_lang`-`target_lang` column of a WMT-style dataset.

    Each prompt reads "Translate from <Src> to <Tgt>:", then a line "<Src>: <sentence> ",
    then a line "<Tgt>:" that the model completes.

    Returns:
        (sources, prompts, targets): raw source sentences, prompts, reference translations.
    """
    # NOTE(review): "Islandic" is a misspelling of "Icelandic", but translate_list_of_str_BLOOM
    # splits the decoded text on the same name, so both must be changed together.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    src_name = language_name[source_lang]
    tgt_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [pair[source_lang] for pair in pairs]
    prompts = [
        f"Translate from {src_name} to {tgt_name}:\n{src_name}: {pair.get(source_lang)} \n{tgt_name}:"
        for pair in pairs
    ]
    targets = [pair[target_lang] for pair in pairs]
    return sources, prompts, targets

def translate_list_of_str_BLOOM(list_str, tokenizer, model, target_language):
    """
    Translate a batch of BLOOM/BLOOMZ prompts and return only the generated translations.

    Args:
        list_str: prompts built by `get_input_targets_BLOOM`.
        tokenizer: BLOOM tokenizer.
        model: BLOOM causal LM.
        target_language: short target-language code; its name is used to strip the echoed prompt.

    Returns:
        list[str]: for each decoded output, the text between the 2nd and 3rd occurrence of "<Target>:".
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    # Resolve the name before generating so an unknown code fails fast.
    tgt_language_name = language_name[target_language]
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(inputs["input_ids"].to(device),
                                    # Explicit mask so padded positions of shorter prompts are ignored.
                                    attention_mask=inputs["attention_mask"].to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # The prompt contains "<Target>:" twice, so the generated text starts after the second one (remove prompt).
    return [t.split(f"{tgt_language_name}:")[2] for t in translated_text]

def translate_batched_BLOOM(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate BLOOM/BLOOMZ prompts in chunks of `batch_size`; returns one prediction per prompt.

    The last chunk may be shorter than `batch_size`, so every prompt is translated.
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_BLOOM(batch, tokenizer, model, target_language))
    return preds

#################################   OPT & OPT Instruct

def get_input_targets_OPT(dataset, source_lang, target_lang):
    """
    Build OPT / OPT-IML completion prompts for the `source_lang`-`target_lang` column of a WMT-style dataset.

    Each prompt reads "Translate from <Src> to <Tgt>:", then a line "<Src>: <sentence> ",
    then a line "<Tgt>:" that the model completes.

    Returns:
        (sources, prompts, targets): raw source sentences, prompts, reference translations.
    """
    # NOTE(review): "Islandic" is a misspelling of "Icelandic", but translate_list_of_str_OPT
    # splits the decoded text on the same name, so both must be changed together.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    src_name = language_name[source_lang]
    tgt_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [pair[source_lang] for pair in pairs]
    prompts = [
        f"Translate from {src_name} to {tgt_name}:\n{src_name}: {pair.get(source_lang)} \n{tgt_name}:"
        for pair in pairs
    ]
    targets = [pair[target_lang] for pair in pairs]
    return sources, prompts, targets

def translate_list_of_str_OPT(list_str, tokenizer, model, target_language):
    """
    Translate a batch of OPT / OPT-IML prompts and return only the generated translations.

    Args:
        list_str: prompts built by `get_input_targets_OPT`.
        tokenizer: OPT tokenizer.
        model: OPT causal LM.
        target_language: short target-language code; its name is used to strip the echoed prompt.

    Returns:
        list[str]: for each decoded output, the text between the 2nd and 3rd occurrence of "<Target>:".
    """
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}
    # Resolve the name before generating so an unknown code fails fast.
    tgt_language_name = language_name[target_language]
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(inputs["input_ids"].to(device),
                                    # Explicit mask so padded positions of shorter prompts are ignored.
                                    attention_mask=inputs["attention_mask"].to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    # The prompt contains "<Target>:" twice, so the generated text starts after the second one (remove prompt).
    return [t.split(f"{tgt_language_name}:")[2] for t in translated_text]

def translate_batched_OPT(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate OPT prompts in chunks of `batch_size`; returns one prediction per prompt.

    The last chunk may be shorter than `batch_size`, so every prompt is translated.
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_OPT(batch, tokenizer, model, target_language))
    return preds

#################################   MPT

def get_input_targets_MPT(dataset, source_lang, target_lang):
    """
    Build MPT-Instruct prompts for the `source_lang`-`target_lang` column of a WMT-style dataset.

    The prompt follows the MPT-7B-Instruct layout: intro blurb, "### Instruction:" on its own
    line followed by the instruction, then "### Response:" (which
    `translate_list_of_str_MPT` splits on).

    Returns:
        (sources, prompts, targets): raw source sentences, prompts, reference translations.
    """
    # Language names only appear in the prompt (the MPT output is split on "### Response:"),
    # so the correct spelling "Icelandic" is safe to use here.
    language_name = {"en": "English", "de": "German", "ru": "Russian", "is": "Icelandic", "zh": "Chinese", "cs": "Czech"}
    source_lang_name = language_name[source_lang]
    target_lang_name = language_name[target_lang]
    pairs = dataset[f"{source_lang}-{target_lang}"]
    sources = [example[source_lang] for example in pairs]
    inputs = [(
        "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n### Instruction:\n"
        + f"Translate from {source_lang_name} to {target_lang_name}: {example.get(source_lang)}"
        + "\n### Response:")
        for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets

def translate_list_of_str_MPT(list_str, tokenizer, model, target_language=None):
    """
    Translate a batch of MPT-Instruct prompts and return only the generated responses.

    Args:
        list_str: prompts built by `get_input_targets_MPT`.
        tokenizer: MPT tokenizer.
        model: MPT causal LM.
        target_language: unused; kept so every `translate_list_of_str_*` shares a signature.

    Returns:
        list[str]: for each decoded output, the text after the last newline + "### Response:".
    """
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(inputs["input_ids"].to(device),
                                    # Explicit mask so padded positions of shorter prompts are ignored.
                                    attention_mask=inputs["attention_mask"].to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
        translated_text = [t.split("\n### Response:")[-1] for t in translated_text] # Remove prompt
    return translated_text

def translate_batched_MPT(inputs, model, tokenizer, batch_size, target_language):
    """
    Translate MPT prompts in chunks of `batch_size`; returns one prediction per prompt.

    The last chunk may be shorter than `batch_size`, so every prompt is translated.
    """
    preds = []
    # Step by batch_size so a trailing partial batch is still translated.
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_MPT(batch, tokenizer, model, target_language))
    return preds

In [None]:
def reduce_flores_to_some_languages(ds_flores, directions):
    """
    Keep only the FLORES+ rows whose language appears in one of `directions`.

    Args:
        ds_flores: FLORES+ split (rows with "iso_639_3", "glottocode", "text", ...).
        directions: strings like "en-de"; characters 0-1 and 3-4 are the two language codes.

    Returns:
        Dataset: the matching rows. For Chinese, only rows with glottocode "beij1234" are kept.
    """
    print("Extracting all languages in directions from FLORES...")
    # Unique language codes, in first-seen order.
    list_languages = []
    for direction in directions:
        for code in (direction[0:2], direction[3:5]):
            if code not in list_languages:
                list_languages.append(code)

    language_to_iso = {"en": "eng", "de": "deu", "cs": "ces", "is": "isl", "zh": "cmn", "ru": "rus"}
    ds_list = []
    for elem in ds_flores:
        for lang in list_languages:
            if elem["iso_639_3"] != language_to_iso[lang]:
                continue
            # Chinese rows are restricted to a single glottocode; other languages are kept as-is.
            if lang != "zh" or elem["glottocode"] == "beij1234":
                ds_list.append(elem)
    return Dataset.from_list(ds_list)

def transform_to_WNT_style(ds_flores, lang, lang_start="en"):
    """
    Pair FLORES+ sentences of `lang_start` and `lang` into the WMT23 column layout.

    Sentences are paired by their order of appearance in `ds_flores` (the i-th `lang_start`
    row with the i-th `lang` row); an assertion checks both languages have the same count.
    For Chinese, only rows with glottocode "beij1234" are used.

    Returns:
        Dataset with one column "<lang_start>-<lang>" holding {lang_start: ..., lang: ...} dicts.
    """
    language_to_iso = {"en": "eng", "de": "deu", "cs": "ces", "is": "isl", "zh": "cmn", "ru": "rus"}

    def _accepted(elem, code):
        # Non-Chinese rows are always accepted; Chinese rows need the expected glottocode.
        return code != "zh" or elem["glottocode"] == "beij1234"

    target_texts, start_texts = [], []
    for elem in ds_flores:
        iso = elem["iso_639_3"]
        if iso == language_to_iso[lang]:
            if _accepted(elem, lang):
                target_texts.append(elem["text"])
        elif iso == language_to_iso[lang_start]:
            if _accepted(elem, lang_start):
                start_texts.append(elem["text"])
    assert len(target_texts) == len(start_texts)
    paired = [{f"{lang_start}": start_text, f"{lang}": target_text}
              for start_text, target_text in zip(start_texts, target_texts)]
    return Dataset.from_dict({f"{lang_start}-{lang}": paired})

## Dataset handling
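We use the WMT23 test sets (the preprocessed split released by the ALMA authors, `haoranxu/WMT23-Test`) and the FLORES+ dataset, which we reformat into the same layout as WMT23 (one `src-tgt` column of sentence pairs).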

### WMT23

In [None]:
# The helper functions index the split by its "<src>-<tgt>" column (e.g. dataset["en-de"]),
# so the config loaded here must match the direction used by the cells below (en -> de).
wmt_direction = "en-de"
ds_wnt = load_dataset("haoranxu/WMT23-Test", wmt_direction)["test"]
print(len(ds_wnt), ds_wnt[0:4])

### FLORES

In [None]:
# Log in to the Hugging Face Hub before loading FLORES+; the token is read from a local `credentials` module.
# NOTE(review): make sure `credentials.py` is never committed.
from credentials import hf_token
huggingface_hub.login(token = hf_token)
# Only the "devtest" split is used below.
ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]

In [None]:
# `reduce_flores_to_some_languages` and `transform_to_WNT_style` are defined once, in the
# helper cell that follows the inference functions. Re-defining identical copies here would
# silently shadow that definition, so no second copy is kept in this cell.

In [None]:
# Every evaluated pair is English <-> one other language, in both directions:
# en-de, de-en, en-cs, cs-en, en-is, is-en, en-zh, zh-en, en-ru, ru-en.
directions = [
    pair
    for other in ("de", "cs", "is", "zh", "ru")
    for pair in (f"en-{other}", f"{other}-en")
]
ds_flores_reduced = reduce_flores_to_some_languages(ds_flores, directions)

In [None]:
# Time transform_to_WNT_style (en -> zh) on the full FLORES+ split...
t1 = time.time()
ds_flores_wnt_style = transform_to_WNT_style(ds_flores, lang="zh", lang_start="en")
print(f"Time to compute: {time.time()-t1:.2f}s", ds_flores_wnt_style[0:4])

# ...and on the split pre-filtered to the evaluated languages.
t1 = time.time()
ds_flores_wnt_style_reduced = transform_to_WNT_style(ds_flores_reduced, lang="zh", lang_start="en")
print(f"Time to compute: {time.time()-t1:.2f}s", ds_flores_wnt_style_reduced[0:4])

## Models & inference loops

### NLLB

In [None]:
# NLLB-200 distilled 600M (seq2seq), placed on `device` with the checkpoint's own dtype.
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype="auto", device_map=device)

In [None]:
# NLLB takes raw sentences, so `sources` and `inputs` are the same list (en -> de).
sources, inputs, targets = get_input_targets_NLLB(ds_wnt, source_lang="en", target_lang="de")

In [None]:
# Smoke test: translate the first 8 sentences one at a time (a batch of 1 needs no padding).
for i in range(8):
    print(translate_list_of_str_NLLB(inputs[i:i+1], tokenizer, model, "de"))

### ALMA

In [None]:
# Load the ALMA-7B checkpoint in 8-bit (no LoRA adapter is loaded here); the tokenizer pads on
# the left, as required for batched generation with a decoder-only model.
tokenizer = transformers.LlamaTokenizer.from_pretrained("haoranxu/ALMA-7B", padding_side='left')
Q_config = BitsAndBytesConfig(load_in_8bit=True) 
model = transformers.AutoModelForCausalLM.from_pretrained("haoranxu/ALMA-7B", torch_dtype="auto", device_map=device, quantization_config=Q_config)

In [None]:
# ALMA completion prompts for en -> de.
sources, inputs, targets = get_input_targets_ALMA(ds_wnt, source_lang="en", target_lang="de")

In [None]:
# Smoke test: sampled translations of the first 4 prompts, one prompt at a time.
for i in range(4):
    print(translate_list_of_str_ALMA(inputs[i:i+1], tokenizer, model, "de"))

### Llama 3 Instruct

In [None]:
# Llama 3 Instruct (requires a Hugging Face Hub login; the token comes from `credentials`).
# Pick ONE checkpoint below: the 8B model is loaded in 8-bit, the smaller ones unquantized.
# Before this change, every load line was commented out. The pad-token line then silently
# modified whatever tokenizer an earlier cell had left behind, and `model` stayed stale.
from credentials import hf_token
huggingface_hub.login(token = hf_token)
llama3_checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# llama3_checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
# llama3_checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = transformers.AutoTokenizer.from_pretrained(llama3_checkpoint)
llama3_q_config = BitsAndBytesConfig(load_in_8bit=True) if "8B" in llama3_checkpoint else None
model = transformers.AutoModelForCausalLM.from_pretrained(llama3_checkpoint, torch_dtype="auto", device_map=device, quantization_config=llama3_q_config)
# Reuse EOS as the pad token, both for tokenizing batches and for generation.
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the Llama 3 helper.
sources, inputs, targets = get_input_targets_Llama3(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate the first five prompts as one batch; the returned list is displayed as the cell output.
translate_list_of_str_Llama3(
    inputs[:5], tokenizer, model
)

### Llama 3.2 3B base (non-instruct), 4-bit quantized (for comparison with the fine-tuned version)

In [None]:
# Base (non-instruct) Llama 3.2 3B, quantized to 4-bit NF4, for comparison with the fine-tuned version.
from credentials import hf_token
huggingface_hub.login(token=hf_token)
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
Q_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)
# Pad with EOS, for both the tokenizer and generation (presumably the base tokenizer defines no pad token).
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Language code -> name used in the prompts of the non-instruct Llama 3 (4-bit) model.
# NOTE(review): "Islandic" is a misspelling of "Icelandic". It is kept because it is prompt text
# (fed to the model and used to strip the prompt), and changing it would change model inputs
# relative to runs that were already generated.
LANGUAGE_NAMES_LLAMA3NI4BIT = {"en": "English", "de": "German", "ru": "Russian", "is": "Islandic", "zh": "Chinese", "cs": "Czech"}


def get_input_targets_Llama3NI4bit(dataset, source_lang, target_lang):
    """Build raw completion prompts for the non-instruct Llama 3.2 3B model.

    Args:
        dataset: mapping whose key f"{source_lang}-{target_lang}" holds an iterable of
            dicts keyed by language code (e.g. {"en": ..., "de": ...}).
        source_lang, target_lang: language codes present in LANGUAGE_NAMES_LLAMA3NI4BIT.

    Returns:
        (sources, inputs, targets): source sentences, prompts ending with
        "<target language name>:", and reference translations, all aligned.
    """
    source_lang_name = LANGUAGE_NAMES_LLAMA3NI4BIT[source_lang]
    target_lang_name = LANGUAGE_NAMES_LLAMA3NI4BIT[target_lang]
    # Materialize once so the three passes below read the same data.
    pairs = list(dataset[f"{source_lang}-{target_lang}"])
    sources = [example[source_lang] for example in pairs]
    # Base formulation, e.g. "Translate this from Chinese to English:\nChinese: 我爱机器翻译。\nEnglish:"
    inputs = [(
        f"Translate from {source_lang_name} to {target_lang_name} and end your answer as soon as the task is finished:"
        + f"\n{source_lang_name}: {example[source_lang]} \n{target_lang_name}:")
        for example in pairs]
    targets = [example[target_lang] for example in pairs]
    return sources, inputs, targets


def translate_list_of_str_Llama3NI4bit(list_str, tokenizer, model, target_language):
    """Translate a batch of prompts built by `get_input_targets_Llama3NI4bit`.

    Generation samples (do_sample=True) with the notebook-level `num_beams`,
    `max_new_tokens`, `temperature` and `top_p`. The decoded text still contains the
    prompt, so only the text after the last "<target language name>:" marker is kept.

    Returns:
        list[str]: one translation per input prompt.
    """
    tgt_language_name = LANGUAGE_NAMES_LLAMA3NI4BIT[target_language]
    with torch.no_grad():
        inputs = tokenizer(list_str, return_tensors="pt", padding=True)
        translated = model.generate(**inputs.to(device),
                                    num_beams=num_beams, max_new_tokens=max_new_tokens, do_sample=True,
                                    temperature=temperature, top_p=top_p
                                    ).cpu()
    translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return [t.split(f"{tgt_language_name}:")[-1] for t in translated_text]  # Remove prompt


def translate_batched_Llama3NI4bit(inputs, model, tokenizer, batch_size, target_language):
    """Translate `inputs` in chunks of `batch_size` and return the concatenated predictions.

    The range steps by `batch_size` so the final, possibly partial, batch is also translated.
    The previous `len(inputs)//batch_size` loop silently dropped it, which misaligned
    predictions with targets.
    """
    preds = []
    for start in tqdm(range(0, len(inputs), batch_size)):
        batch = inputs[start:start + batch_size]
        preds.extend(translate_list_of_str_Llama3NI4bit(batch, tokenizer, model, target_language))
    return preds

In [None]:
# Build completion prompts for en -> de from `ds_wnt` for the base Llama 3.2 3B model.
sources, inputs, targets = get_input_targets_Llama3NI4bit(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate the first 4 prompts one at a time into German and print each result.
for i in range(4):
    print(
        translate_list_of_str_Llama3NI4bit(inputs[i:i + 1], tokenizer, model, target_language="de")
    )

### Falcon 3 Instruct (mamba and transformer)

In [None]:
# Falcon 3 Instruct. Active choice: the 7B transformer model in 8-bit.
# Alternatives (swap in one tokenizer/model pair instead of the active one):
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-Mamba-7B-Instruct")
# Q_config = BitsAndBytesConfig(load_in_8bit=True)
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-Mamba-7B-Instruct", torch_dtype="auto", device_map=device, quantization_config=Q_config)
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-3B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-3B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-1B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("tiiuae/Falcon3-1B-Instruct", torch_dtype="auto", device_map=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("tiiuae/Falcon3-7B-Instruct")
Q_config = BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "tiiuae/Falcon3-7B-Instruct",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the Falcon 3 helper.
sources, inputs, targets = get_input_targets_Falcon3(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Batch-translate the first four prompts with the Mamba variant's translation function.
# NOTE(review): the load cell above currently activates Falcon3-7B (transformer), not the Mamba model.
translate_list_of_str_Falcon3Mamba(
    inputs[:4], tokenizer, model
)

In [None]:
# Batch-translate the first eight prompts; the returned list is displayed as the cell output.
translate_list_of_str_Falcon3(
    inputs[:8], tokenizer, model
)

### Qwen 2.5

In [None]:
# Qwen 2.5 Instruct. Active choice: 0.5B, unquantized. Larger variants are listed below.
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    device_map=device,
    torch_dtype="auto",
)
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8")
# model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8", torch_dtype="auto", device_map=device)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the Qwen 2.5 helper.
sources, inputs, targets = get_input_targets_Qwen2_5(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate the first 10 prompts one at a time and print each result.
for i in range(10):
    print(
        translate_list_of_str_Qwen2_5(inputs[i:i + 1], tokenizer, model)
    )

### Mistral 7B

In [None]:
# Mistral-7B-Instruct v0.3 in 8-bit (requires a Hub login; the token comes from `credentials`).
from credentials import hf_token
huggingface_hub.login(token=hf_token)
tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
tokenizer.pad_token = tokenizer.eos_token  # pad batches with EOS
Q_config = BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the Mistral helper.
sources, inputs, targets = get_input_targets_Mistral(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate the first 4 prompts one at a time and print each result.
for i in range(4):
    print(
        translate_list_of_str_Mistral(inputs[i:i + 1], tokenizer, model)
    )

### Bayling

In [None]:
# BayLing 2 (7B) in 8-bit; batches are padded with EOS.
tokenizer = transformers.AutoTokenizer.from_pretrained("ICTNLP/bayling-2-7b")
tokenizer.pad_token = tokenizer.eos_token
Q_config = BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "ICTNLP/bayling-2-7b",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the BayLing helper.
sources, inputs, targets = get_input_targets_BayLing(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate the first 4 prompts one at a time into German and print each result.
for i in range(4):
    print(
        translate_list_of_str_BayLing(inputs[i:i + 1], tokenizer, model, target_language="de")
    )

### BLOOM & BLOOMZ

In [None]:
# BLOOM / BLOOMZ. Active choice: BLOOMZ 7B1 in 8-bit.
# Alternatives (swap in one tokenizer/model pair instead of the active one):
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m")
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16, device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-1b7")
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b7", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-3b")
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-3b", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
# Q_config = BitsAndBytesConfig(load_in_8bit=True)
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype="auto", device_map=device, quantization_config=Q_config)
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloomz-1b7")
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloomz-1b7", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloomz-3b")
# model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloomz-3b", torch_dtype="auto", device_map=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
Q_config = BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-7b1",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the BLOOM helper.
sources, inputs, targets = get_input_targets_BLOOM(
    ds_wnt, "en", "de"
)

In [None]:
# Translate the first two prompts with batches of 1; the returned list is displayed.
translate_batched_BLOOM(
    inputs[:2], model, tokenizer, batch_size=1, target_language="de"
)

### OPT & OPT Instruct

In [None]:
# OPT / OPT-IML. Active choice: OPT-IML 1.3B, unquantized.
# Alternatives (swap in one tokenizer/model pair instead of the active one):
# tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-125m")
# model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")
# model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype="auto", device_map=device)
# tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-6.7b")
# Q_config = BitsAndBytesConfig(load_in_8bit=True)
# model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype="auto", device_map=device, quantization_config=Q_config)
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-iml-1.3b")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "facebook/opt-iml-1.3b",
    device_map=device,
    torch_dtype="auto",
)

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the OPT helper.
sources, inputs, targets = get_input_targets_OPT(
    ds_wnt, "en", "de"
)

In [None]:
# Translate prompts 4 and 5 with batches of 1; the returned list is displayed.
translate_batched_OPT(
    inputs[4:6], model, tokenizer, batch_size=1, target_language="de"
)

### MPT

In [None]:
# MPT-7B Instruct in 8-bit; EOS doubles as the pad token for the tokenizer and for generation.
tokenizer = transformers.AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
tokenizer.pad_token = tokenizer.eos_token
Q_config = BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",
    quantization_config=Q_config,
    device_map=device,
    torch_dtype="auto",
)
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Build (sources, prompts, references) for en -> de from `ds_wnt` with the MPT helper.
sources, inputs, targets = get_input_targets_MPT(
    ds_wnt,
    source_lang="en",
    target_lang="de",
)

In [None]:
# Translate prompts 4 and 5 one at a time into German and print each result.
for i in range(4, 6):
    print(
        translate_list_of_str_MPT(inputs[i:i + 1], tokenizer, model, target_language="de")
    )

## General aggregation functions for benchmark

In [None]:
def _int8_config():
    """Return the 8-bit bitsandbytes config used for the ~7B checkpoints."""
    return BitsAndBytesConfig(load_in_8bit=True)


def _hf_login():
    """Log in to the Hugging Face Hub with the token stored in `credentials`."""
    from credentials import hf_token
    huggingface_hub.login(token = hf_token)


def _load_checkpoint(checkpoint, quantization_config=None, torch_dtype="auto",
                     model_cls=None, tokenizer_cls=None, **tokenizer_kwargs):
    """Load a (tokenizer, model) pair from `checkpoint`, placing the model on `device`.

    `model_cls` defaults to AutoModelForCausalLM and `tokenizer_cls` to AutoTokenizer.
    Extra keyword arguments are forwarded to the tokenizer's `from_pretrained`.
    `quantization_config` is only passed to the model when it is not None.
    """
    if model_cls is None:
        model_cls = transformers.AutoModelForCausalLM
    if tokenizer_cls is None:
        tokenizer_cls = transformers.AutoTokenizer
    tokenizer = tokenizer_cls.from_pretrained(checkpoint, **tokenizer_kwargs)
    model_kwargs = {"torch_dtype": torch_dtype, "device_map": device}
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config
    model = model_cls.from_pretrained(checkpoint, **model_kwargs)
    return tokenizer, model


def _use_eos_as_pad(tokenizer, model):
    """Reuse EOS as the pad token, both for the tokenizer and for generation."""
    tokenizer.pad_token = tokenizer.eos_token
    model.generation_config.pad_token_id = tokenizer.pad_token_id


def load_model(model_name, model_size=None):
    """Load the tokenizer and model used by the benchmark for `model_name`.

    Args:
        model_name: one of "alma", "nllb", "llama3", "llama3-NI-4bit", "falcon3-mamba",
            "falcon3", "qwen2.5", "mistral", "bayling", "bloom", "bloomz", "opt",
            "opt-instruct", "mpt".
        model_size: size tag for families with several checkpoints (e.g. "1B", "3B",
            "0.5B"). Any unlisted value falls back to the family's largest checkpoint,
            which is loaded in 8-bit. It is ignored by single-checkpoint families.

    Returns:
        (tokenizer, model), with the model placed on `device`.

    Raises:
        ValueError: if `model_name` is unknown. Previously this surfaced as an
            UnboundLocalError on `tokenizer`.
    """
    if model_name == "alma":
        return _load_checkpoint("haoranxu/ALMA-7B", quantization_config=_int8_config(),
                                tokenizer_cls=transformers.LlamaTokenizer, padding_side='left')

    if model_name == "nllb":
        return _load_checkpoint("facebook/nllb-200-distilled-600M",
                                model_cls=transformers.AutoModelForSeq2SeqLM)

    if model_name == "llama3":
        _hf_login()
        if model_size in ("1B", "3B"):
            tokenizer, model = _load_checkpoint(f"meta-llama/Llama-3.2-{model_size}-Instruct")
        else:
            # Bug fix: this branch built `nQ_cofig` but passed `Q_config`, a local name that is
            # unbound here, which raised UnboundLocalError.
            tokenizer, model = _load_checkpoint("meta-llama/Llama-3.1-8B-Instruct",
                                                quantization_config=_int8_config())
        _use_eos_as_pad(tokenizer, model)
        return tokenizer, model

    if model_name == "llama3-NI-4bit":
        _hf_login()
        # Bug fix: the 4-bit config is now actually passed (same `nQ_cofig` / `Q_config` mix-up).
        q_config_4bit = BitsAndBytesConfig(load_in_4bit=True,
                                           bnb_4bit_quant_type="nf4",
                                           bnb_4bit_compute_dtype=torch.float16,
                                           bnb_4bit_use_double_quant=False)
        tokenizer, model = _load_checkpoint("meta-llama/Llama-3.2-3B", quantization_config=q_config_4bit)
        _use_eos_as_pad(tokenizer, model)
        return tokenizer, model

    if model_name == "falcon3-mamba":
        return _load_checkpoint("tiiuae/Falcon3-Mamba-7B-Instruct", quantization_config=_int8_config())

    if model_name == "falcon3":
        if model_size in ("1B", "3B"):
            tokenizer, model = _load_checkpoint(f"tiiuae/Falcon3-{model_size}-Instruct")
        else:
            tokenizer, model = _load_checkpoint("tiiuae/Falcon3-7B-Instruct", quantization_config=_int8_config())
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        return tokenizer, model

    if model_name == "qwen2.5":
        if model_size in ("0.5B", "1.5B", "3B"):
            return _load_checkpoint(f"Qwen/Qwen2.5-{model_size}-Instruct")
        return _load_checkpoint("Qwen/Qwen2.5-7B-Instruct", quantization_config=_int8_config())

    if model_name == "mistral":
        _hf_login()
        tokenizer, model = _load_checkpoint("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=_int8_config())
        _use_eos_as_pad(tokenizer, model)
        return tokenizer, model

    if model_name == "bayling":
        tokenizer, model = _load_checkpoint("ICTNLP/bayling-2-7b", quantization_config=_int8_config())
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer, model

    if model_name == "bloom":
        if model_size == "0.5B":
            return _load_checkpoint("bigscience/bloom-560m", torch_dtype=torch.bfloat16)
        if model_size == "1B":
            return _load_checkpoint("bigscience/bloom-1b7")
        if model_size == "3B":
            return _load_checkpoint("bigscience/bloom-3b")
        return _load_checkpoint("bigscience/bloom-7b1", quantization_config=_int8_config())

    if model_name == "bloomz":
        if model_size == "1B":
            return _load_checkpoint("bigscience/bloomz-1b7")
        if model_size == "3B":
            return _load_checkpoint("bigscience/bloomz-3b")
        return _load_checkpoint("bigscience/bloomz-7b1", quantization_config=_int8_config())

    if model_name == "opt":
        if model_size == "0.1B":
            return _load_checkpoint("facebook/opt-125m")
        if model_size == "0.3B":
            return _load_checkpoint("facebook/opt-350m")
        return _load_checkpoint("facebook/opt-6.7b", quantization_config=_int8_config())

    if model_name == "opt-instruct":
        return _load_checkpoint("facebook/opt-iml-1.3b")

    if model_name == "mpt":
        tokenizer, model = _load_checkpoint("mosaicml/mpt-7b-instruct", quantization_config=_int8_config())
        _use_eos_as_pad(tokenizer, model)
        return tokenizer, model

    raise ValueError(f"Unknown model_name {model_name!r}")

def get_support_fn(model_name):
    """Return the model-specific (prompt builder, batched translator) pair for a `load_model` name.

    Returns:
        (get_input_targets_fn, tslt_fn), called in this notebook as
        get_input_targets_fn(ds, source_lang, target_lang) -> (sources, inputs, targets) and
        tslt_fn(inputs, model, tokenizer, batch_size, target_language) -> predictions.

    Raises:
        ValueError: for an unknown `model_name`. Previously this surfaced as an UnboundLocalError.
    """
    if model_name == "alma":
        return get_input_targets_ALMA, translate_batched_ALMA
    if model_name == "nllb":
        return get_input_targets_NLLB, translate_batched_NLLB
    if model_name == "llama3":
        return get_input_targets_Llama3, translate_batched_Llama3
    if model_name == "llama3-NI-4bit":
        return get_input_targets_Llama3NI4bit, translate_batched_Llama3NI4bit
    if model_name == "falcon3-mamba":
        # Same prompts as Falcon 3, but a Mamba-specific translation routine.
        return get_input_targets_Falcon3, translate_batched_Falcon3Mamba
    if model_name == "falcon3":
        return get_input_targets_Falcon3, translate_batched_Falcon3
    if model_name == "qwen2.5":
        return get_input_targets_Qwen2_5, translate_batched_Qwen2_5
    if model_name == "mistral":
        return get_input_targets_Mistral, translate_batched_Mistral
    if model_name == "bayling":
        return get_input_targets_BayLing, translate_batched_BayLing
    if model_name in ("bloom", "bloomz"):
        return get_input_targets_BLOOM, translate_batched_BLOOM
    if model_name in ("opt", "opt-instruct"):
        return get_input_targets_OPT, translate_batched_OPT
    if model_name == "mpt":
        return get_input_targets_MPT, translate_batched_MPT
    raise ValueError(f"Unknown model_name {model_name!r}")

def get_inp_tgt_lang(direction):
    """Split a "src-tgt" direction string into (source, target) language codes.

    Splitting on the first "-" (instead of fixed slices [0:2] / [3:5]) also supports
    codes that are not exactly two letters long, e.g. "fil-en".

    Raises:
        ValueError: if `direction` contains no "-".
    """
    source, sep, target = direction.partition("-")
    if not sep:
        raise ValueError(f"Direction {direction!r} is not of the form 'src-tgt'")
    return source, target

def reduce_dataset(inputs, sources, targets, final_nb):
    """Deterministically subsample three aligned lists with the same random indices.

    Indices are drawn WITHOUT replacement from a local RandomState seeded with 42, so the
    subset contains no duplicated sentences and the global NumPy RNG is left untouched.
    If `final_nb` exceeds the number of samples, every sample is kept (in shuffled order).

    NOTE(review): the previous version used np.random.choice's default (replace=True), which
    produced duplicates and selected different indices. Pickled translations generated with
    that version are not aligned with the subsets returned here; regenerate them before
    evaluating.

    Args:
        inputs, sources, targets: equally long sequences indexed in parallel.
        final_nb: requested number of samples.

    Returns:
        (inputs, sources, targets) reduced, as three lists in argument order.
    """
    n_keep = min(final_nb, len(inputs))
    rng = np.random.RandomState(42)
    idx = rng.choice(len(inputs), n_keep, replace=False)
    return [inputs[i] for i in idx], [sources[i] for i in idx], [targets[i] for i in idx]

def get_translations_filename(direction, dataset_name, model_name, model_size, reduce_size):
    """Path of the pickle that stores the generated translations for one run.

    The model size, when given, is appended to the model name with a dash.
    """
    size_suffix = "" if model_size is None else "-" + model_size
    stem = f"{dataset_name}_{model_name}{size_suffix}_{direction}_red-{reduce_size}"
    return "./generated_translations/evaluations/" + stem + ".pkl"

def get_eval_filename(direction, dataset_name, model_name, model_size, reduce_size):
    """Path of the pickle that stores the raw evaluation results for one run.

    The model size, when given, is appended to the model name with a dash.
    """
    size_suffix = "" if model_size is None else "-" + model_size
    stem = f"raw_{dataset_name}_{model_name}{size_suffix}_{direction}_red-{reduce_size}"
    return "./evaluations/" + stem + ".pkl"


def generate_translation_different_directions(directions, dataset_name, model_name, batch_size, reduce_size = None, model_size = None):
    """Translate every direction of one dataset with one model and pickle the predictions.

    Args:
        directions: list of "src-tgt" language-code strings.
        dataset_name: "flores" (FLORES+ devtest) or "wnt23" (haoranxu/WMT23-Test).
        model_name, model_size: forwarded to `load_model` / `get_support_fn`.
        batch_size: forwarded to the model's batched translation function.
        reduce_size: if not None, number of samples kept per direction (see `reduce_dataset`).

    Side effects:
        Writes one pickle per direction at `get_translations_filename(...)`. The model is
        always released afterwards (also on error), so the next model can be loaded
        without restarting the kernel.

    Raises:
        ValueError: if `dataset_name` is unsupported. This is checked before any model is loaded;
            previously the failure surfaced only later, as an unbound `ds`.
    """
    import gc

    if dataset_name not in ("flores", "wnt23"):
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'flores' or 'wnt23'")

    # Loading full flores (if necessary)
    if dataset_name == "flores":
        from credentials import hf_token
        huggingface_hub.login(token = hf_token)
        ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]

    # Loading corresponding model
    print("Loading model...")
    tokenizer, model = load_model(model_name, model_size)
    try:
        get_input_targets_fn, tslt_fn = get_support_fn(model_name)

        for direction in directions:
            print(f"Translating {direction} with model {model_name}"
                  +(f"-{model_size}" if model_size is not None else "")
                  +f" for dataset {dataset_name}...")
            input_language, target_language = get_inp_tgt_lang(direction)

            # Getting the right split corresponding to the translation direction
            if dataset_name == "flores":
                ds = transform_to_WNT_style(ds_flores, lang=target_language, lang_start=input_language)
            elif direction != "cs-en":
                ds = load_dataset("haoranxu/WMT23-Test", direction)["test"]
            else:
                # cs-en is built from the en-cs config. The list is reversed so that a reduced
                # cs-en subset does not pick the same sentences as the en-cs one.
                ds = load_dataset("haoranxu/WMT23-Test", "en-cs")["test"]
                ds = Dataset.from_dict({"cs-en": ds["en-cs"][::-1]})

            # Extracting input & targets
            sources, inputs, targets = get_input_targets_fn(ds, input_language, target_language)
            print(f"Total number of samples: {len(sources)}" + ("" if reduce_size is None else f"; reduced to {reduce_size} (numpy seed = 42)"))
            if reduce_size is not None:
                # reduce_dataset applies the same indices to all three lists and returns them in
                # argument order, so passing (sources, inputs, targets) keeps them aligned.
                sources, inputs, targets = reduce_dataset(sources, inputs, targets, reduce_size)
            translation_pred = tslt_fn(inputs, model, tokenizer, batch_size, target_language)

            os.makedirs("./generated_translations/evaluations", exist_ok=True)
            translations_filename = get_translations_filename(direction, dataset_name, model_name, model_size, reduce_size)
            with open(translations_filename, "wb") as f:
                pickle.dump(translation_pred, f, pickle.HIGHEST_PROTOCOL)
    finally:
        # De-load the model from the GPU, even if a direction failed. Deleting the local names
        # alone does not free GPU memory: collect garbage and release PyTorch's cached blocks.
        model.cpu()
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def generate_translation_several_models(directions, dataset_name, model_names, model_sizes, batch_size, reduce_size):
    """Run `generate_translation_different_directions` for each (model_name, model_size) pair.

    `model_names` and `model_sizes` are parallel lists (use None for models without a size).

    Raises:
        ValueError: if the two lists differ in length. Previously `zip` silently dropped
            the extra entries, so some models were never run.
    """
    if len(model_names) != len(model_sizes):
        raise ValueError(f"model_names ({len(model_names)}) and model_sizes ({len(model_sizes)}) "
                         "must have the same length")
    for model_name, model_size in zip(model_names, model_sizes):
        generate_translation_different_directions(directions=directions,
                                                dataset_name=dataset_name,
                                                model_name=model_name,
                                                model_size=model_size,
                                                batch_size=batch_size,
                                                reduce_size=reduce_size)
        
def generate_translation_several_datasets(directions, dataset_names, model_names, model_sizes, batch_size, reduce_size):
    """Run `generate_translation_several_models` once per dataset name, in order."""
    for current_dataset in dataset_names:
        generate_translation_several_models(directions=directions,
                                            dataset_name=current_dataset,
                                            model_names=model_names,
                                            model_sizes=model_sizes,
                                            batch_size=batch_size,
                                            reduce_size=reduce_size)

## Translation part

In [None]:
# WMT23, 100 samples per direction: ALMA, NLLB, Llama 3, Falcon 3 (incl. Mamba) and Qwen 2.5 up to 3B.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("alma", None),
    ("nllb", None),
    ("llama3", "1B"), ("llama3", "3B"), ("llama3", "8B"),
    ("falcon3-mamba", None),
    ("falcon3", "1B"), ("falcon3", "3B"), ("falcon3", "7B"),
    ("qwen2.5", "0.5B"), ("qwen2.5", "1.5B"), ("qwen2.5", "3B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# WMT23, 100 samples per direction: Qwen 2.5 7B and Mistral 7B.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("qwen2.5", "7B"),
    ("mistral", None),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# WMT23, 100 samples per direction: the two smallest BLOOM checkpoints.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("bloom", "0.5B"),
    ("bloom", "1B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# WMT23, 100 samples per direction: BLOOM 7B on en <-> de and en <-> cs only.
directions = [pair
              for lang in ("de", "cs")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("bloom", "7B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# WMT23, 100 samples per direction: BLOOMZ 7B.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("bloomz", "7B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# WMT23, 100 samples per direction: OPT-IML 1.3B.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("opt-instruct", None),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# FLORES, 200 samples per direction: BLOOMZ 7B and OPT-IML 1.3B.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("bloomz", "7B"),
    ("opt-instruct", None),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="flores",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=200)

In [None]:
# Left to do if time allows: the remaining OPT / MPT / BayLing / BLOOM(Z) runs.

# 1) FLORES, 200 samples per direction.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("opt", "7B"),
    ("mpt", None),
    ("bayling", None),
    ("bloomz", "1B"), ("bloomz", "3B"),
    ("bloom", "0.5B"), ("bloom", "1B"), ("bloom", "3B"), ("bloom", "7B"),
    ("opt", "0.1B"), ("opt", "0.3B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="flores",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=200)

# 2) WMT23, 100 samples per direction.
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("opt", "7B"),
    ("mpt", None),
    ("bayling", None),
    ("bloomz", "1B"), ("bloomz", "3B"),
    ("bloom", "3B"),
    ("opt", "0.1B"), ("opt", "0.3B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

# 3) WMT23: BLOOM 7B on the directions not covered earlier (is, zh, ru).
directions = [pair
              for lang in ("is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("bloom", "7B"),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="wnt23",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=100)

In [None]:
# FLORES, 200 samples per direction: the main model set (ALMA, NLLB, Llama 3, Falcon 3, Qwen 2.5, Mistral).
directions = [pair
              for lang in ("de", "cs", "is", "zh", "ru")
              for pair in (f"en-{lang}", f"{lang}-en")]
model_names, model_sizes = map(list, zip(*[
    ("alma", None),
    ("nllb", None),
    ("llama3", "1B"), ("llama3", "3B"), ("llama3", "8B"),
    ("falcon3-mamba", None),
    ("falcon3", "1B"), ("falcon3", "3B"), ("falcon3", "7B"),
    ("qwen2.5", "0.5B"), ("qwen2.5", "1.5B"), ("qwen2.5", "3B"), ("qwen2.5", "7B"),
    ("mistral", None),
]))

generate_translation_several_models(directions=directions,
                                    dataset_name="flores",
                                    model_names=model_names,
                                    model_sizes=model_sizes,
                                    batch_size=1,
                                    reduce_size=200)

In [None]:
# WMT23, 100 samples per direction: Falcon 3 Mamba on en <-> de and en <-> cs.
generate_translation_different_directions(
    directions=[pair for lang in ("de", "cs") for pair in (f"en-{lang}", f"{lang}-en")],
    dataset_name="wnt23",
    model_name="falcon3-mamba",
    model_size=None,
    batch_size=1,
    reduce_size=100,
)

In [None]:
# WMT23, 100 samples per direction: Falcon 3 Mamba on en <-> is, zh and ru.
generate_translation_different_directions(
    directions=[pair for lang in ("is", "zh", "ru") for pair in (f"en-{lang}", f"{lang}-en")],
    dataset_name="wnt23",
    model_name="falcon3-mamba",
    model_size=None,
    batch_size=1,
    reduce_size=100,
)

## Metrics from predictions: evaluation function

### Evaluations

In [None]:
def eval_rouge(metric, sources, targets, translation_infered, target_language):
    """Score every prediction with ROUGE and summarise the per-sentence values.

    `sources` and `target_language` are unused; they are accepted so that every
    `eval_*` function shares the same signature.

    Returns a dict keyed by ROUGE variant ("rouge1", "rouge2", "rougeL",
    "rougeLsum"); each value holds "mean_score", "std_score" (population std)
    and "std_unbias_score" (sample std, ddof=1) of the per-sentence scores.
    """
    # use_aggregator=False keeps one score per sentence instead of a single aggregate,
    # which is what the statistics below need.
    per_sentence = metric.compute(predictions=translation_infered,
                                  references=targets,
                                  use_aggregator=False)
    variants = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    return {
        variant: {
            "mean_score": np.mean(per_sentence[variant]).item(),
            "std_score": np.std(per_sentence[variant]).item(),
            "std_unbias_score": np.std(per_sentence[variant], ddof=1).item(),
        }
        for variant in variants
    }

def eval_bleu(metric, sources, targets, translation_infered, target_language):
    """Sentence-level BLEU for every (prediction, reference) pair.

    Chinese targets are tokenised with sacrebleu's ``TokenizerZh``; every other
    target language uses ``Tokenizer13a``. A ``ZeroDivisionError`` raised by the
    metric (presumably for degenerate inputs such as an empty prediction) is
    recorded as a score and brevity penalty of 0.

    `sources` is unused; it is accepted so that every `eval_*` function shares the
    same signature.

    Returns ``{"bleu": {...}}`` holding the per-sentence "scores" and
    "brevity_penalty" lists plus "mean_score", "std_score" (population std) and
    "std_unbias_score" (sample std, ddof=1) of the scores.
    """
    # The tokenizer only depends on the target language: build it once rather than
    # once per sentence.
    tokenizer = TokenizerZh() if target_language == "zh" else Tokenizer13a()

    results_bleu = {"scores": [], "brevity_penalty": []}
    for trans, tgt in zip(translation_infered, targets):
        try:
            bleu_out = metric.compute(predictions=[trans],
                                      references=[[tgt]],
                                      tokenizer=tokenizer)
        except ZeroDivisionError:
            bleu_out = {"bleu": 0., "brevity_penalty": 0.}

        results_bleu["scores"].append(bleu_out["bleu"])
        results_bleu["brevity_penalty"].append(bleu_out["brevity_penalty"])

    # For further statistical treatment
    results_bleu["mean_score"] = np.mean(results_bleu["scores"]).item()
    results_bleu["std_score"] = np.std(results_bleu["scores"]).item()
    results_bleu["std_unbias_score"] = np.std(results_bleu["scores"], ddof=1).item()
    return {"bleu": results_bleu}

def eval_sacrebleu(metric, sources, targets, translation_infered, target_language):
    """Sentence-level SacreBLEU for every (prediction, reference) pair.

    The "zh" tokenizer is used for Chinese targets, "13a" otherwise. A
    ``ZeroDivisionError`` from the metric is recorded as score 0 and bp 0.
    `sources` is unused (kept for the shared `eval_*` signature).

    Returns ``{"sacrebleu": {...}}`` with the per-sentence "scores" and
    "brevity_penalty" lists and their mean / population std / sample std.
    """
    tokenize_name = "zh" if target_language == "zh" else "13a"
    scores = []
    penalties = []
    for prediction, reference in zip(translation_infered, targets):
        try:
            sentence_out = metric.compute(predictions=[prediction],
                                          references=[[reference]],
                                          tokenize=tokenize_name)
        except ZeroDivisionError:
            sentence_out = {"score": 0., "bp": 0.}
        scores.append(sentence_out["score"])
        penalties.append(sentence_out["bp"])

    summary = {
        "scores": scores,
        "brevity_penalty": penalties,
        "mean_score": np.mean(scores).item(),
        "std_score": np.std(scores).item(),
        "std_unbias_score": np.std(scores, ddof=1).item(),
    }
    return {"sacrebleu": summary}

def eval_chrf_and_chrfplusplus(metric, sources, targets, translation_infered, target_language):
    """Sentence-level chrF and chrF++ computed with the same `chrf` metric object.

    chrF uses character n-grams only (word_order=0, eps_smoothing=False); chrF++
    adds word bigrams (word_order=2, eps_smoothing=True). A ``ZeroDivisionError``
    from the metric is recorded as a score of 0. `sources` and `target_language`
    are unused (kept for the shared `eval_*` signature).

    Returns ``{"chrf": {...}, "chrfplusplus": {...}}``, each holding the
    per-sentence "scores" and their mean / population std / sample std.
    """
    def _sentence_chrf(prediction, reference, word_order, eps_smoothing):
        # One metric call for a single sentence; failures degrade to a zero score.
        try:
            return metric.compute(predictions=[prediction],
                                  references=[[reference]],
                                  word_order=word_order,
                                  eps_smoothing=eps_smoothing)
        except ZeroDivisionError:
            return {"score": 0.}

    def _summarise(scores):
        # Attach the statistics used by the later aggregation/plotting code.
        return {"scores": scores,
                "mean_score": np.mean(scores).item(),
                "std_score": np.std(scores).item(),
                "std_unbias_score": np.std(scores, ddof=1).item()}

    chrf_scores = []
    chrfpp_scores = []
    for prediction, reference in zip(translation_infered, targets):
        plain = _sentence_chrf(prediction, reference, 0, False)
        plus_plus = _sentence_chrf(prediction, reference, 2, True)
        chrf_scores.append(plain['score'])
        chrfpp_scores.append(plus_plus['score'])

    chrf_summary = _summarise(chrf_scores)
    chrfpp_summary = _summarise(chrfpp_scores)
    return {"chrf": chrf_summary,
            "chrfplusplus": chrfpp_summary}

def eval_comet(metric, sources, targets, translation_infered, target_language):
    """COMET scores (source-aware) for all predictions at once.

    The dict returned by `metric.compute` must contain a per-sentence "scores"
    list; it is extended in place with "std_score" (population std) and
    "std_unbias_score" (sample std, ddof=1). No "mean_score" is added here, so it
    is expected to come from the metric's own output. `target_language` is unused.
    """
    comet_out = metric.compute(predictions=translation_infered,
                               references=targets,
                               sources=sources)
    sentence_scores = comet_out["scores"]
    population_std = np.std(sentence_scores).item()
    sample_std = np.std(sentence_scores, ddof=1).item()
    comet_out["std_score"] = population_std
    comet_out["std_unbias_score"] = sample_std
    return {"comet": comet_out}

def eval_bleurt(metric, sources, targets, translation_infered, target_language):
    """BLEURT scores for all predictions at once, plus summary statistics.

    The dict returned by `metric.compute` must contain a per-sentence "scores"
    list; it is extended in place with "mean_score", "std_score" (population std)
    and "std_unbias_score" (sample std, ddof=1). `sources` and `target_language`
    are unused (kept for the shared `eval_*` signature).
    """
    bleurt_out = metric.compute(predictions=translation_infered,
                                references=targets)
    sentence_scores = bleurt_out["scores"]
    statistics = {
        "mean_score": np.mean(sentence_scores).item(),
        "std_score": np.std(sentence_scores).item(),
        "std_unbias_score": np.std(sentence_scores, ddof=1).item(),
    }
    bleurt_out.update(statistics)
    return {"bleurt": bleurt_out}

def eval_bertscore(metric, sources, targets, translation_infered, target_language):
    """BERTScore for all predictions at once, summarised on the F1 component.

    `target_language` is forwarded as the metric's `lang` argument. The dict
    returned by the metric is extended in place with "mean_score", "std_score"
    and "std_unbias_score" (ddof=1), all computed from its per-sentence "f1"
    list. `sources` is unused (kept for the shared `eval_*` signature).
    """
    bert_out = metric.compute(predictions=translation_infered,
                              references=targets,
                              lang=target_language)
    f1_values = bert_out["f1"]
    bert_out.update({
        "mean_score": np.mean(f1_values).item(),
        "std_score": np.std(f1_values).item(),
        "std_unbias_score": np.std(f1_values, ddof=1).item(),
    })
    return {"bertscore": bert_out}

def eval_meteor(metric, sources, targets, translation_infered, target_language):
    """Sentence-level METEOR for every (prediction, reference) pair.

    Each metric call receives a single prediction and a single (non-nested)
    reference. `sources` and `target_language` are unused (kept for the shared
    `eval_*` signature).

    Returns ``{"meteor": {...}}`` with the per-sentence "scores" and their
    mean / population std / sample std (ddof=1).
    """
    sentence_scores = [
        metric.compute(predictions=[prediction], references=[reference])["meteor"]
        for prediction, reference in zip(translation_infered, targets)
    ]
    summary = {"scores": sentence_scores}
    summary["mean_score"] = np.mean(sentence_scores).item()
    summary["std_score"] = np.std(sentence_scores).item()
    summary["std_unbias_score"] = np.std(sentence_scores, ddof=1).item()
    return {"meteor": summary}

def get_eval_fn(metric_name):
    """Return the evaluation function registered for `metric_name`.

    Every returned function has the signature
    ``(metric, sources, targets, translation_infered, target_language) -> dict``.
    "chrf" maps to a function that computes both chrF and chrF++.

    :param metric_name: one of "rouge", "bleu", "sacrebleu", "chrf", "comet",
        "bleurt", "bertscore", "meteor".
    :raises ValueError: for any other name, so that a typo (e.g. "bert_score")
        fails here with a clear message instead of returning None and crashing
        later when the result is called.
    """
    # An if-chain (rather than a dict literal) only touches the eval function that is
    # actually requested.
    if metric_name == "rouge":
        return eval_rouge
    elif metric_name == "bleu":
        return eval_bleu
    elif metric_name == "sacrebleu":
        return eval_sacrebleu
    elif metric_name == "chrf":
        return eval_chrf_and_chrfplusplus
    elif metric_name == "comet":
        return eval_comet
    elif metric_name == "bleurt":
        return eval_bleurt
    elif metric_name == "bertscore":
        return eval_bertscore
    elif metric_name == "meteor":
        return eval_meteor
    supported = ("rouge", "bleu", "sacrebleu", "chrf", "comet", "bleurt", "bertscore", "meteor")
    raise ValueError(f"Unknown metric {metric_name!r}; expected one of {', '.join(supported)}")

def load_metric(metric_name):
    """Load the Hugging Face `evaluate` metric used for `metric_name`.

    :param metric_name: one of "rouge", "bleu", "sacrebleu", "chrf", "comet",
        "bleurt", "bertscore", "meteor".
    :return: the object returned by ``evaluate.load``.
    :raises ValueError: for any other name, instead of silently returning None.
    """
    # Positional arguments forwarded to `evaluate.load` for each supported metric.
    load_args = {
        "rouge": ("rouge",),
        "bleu": ("bleu",),
        "sacrebleu": ("sacrebleu",),
        "chrf": ("chrf",),
        "comet": ("comet",),
        # BLEURT is loaded with the explicit "bleurt-large-512" configuration.
        "bleurt": ("bleurt", "bleurt-large-512"),
        "bertscore": ("bertscore",),
        "meteor": ("meteor",),
    }
    if metric_name not in load_args:
        raise ValueError(f"Unknown metric {metric_name!r}; expected one of {', '.join(load_args)}")
    return evaluate.load(*load_args[metric_name])

In [None]:
def eval_one_metric_one_model(metric_name, metric, directions, dataset_name, model_name, model_size, reduce_size):
    """Score one model's saved translations with one metric, for several directions.

    For each direction, the (sources, targets) pairs are rebuilt from the dataset with
    the same `reduce_size` subsampling as at generation time. The pickled predictions are
    loaded and scored with `metric`. The resulting dict is merged into that direction's
    evaluation pickle, so results of several metrics accumulate in the same file.

    :param metric_name: name understood by `get_eval_fn` (e.g. "bleu", "chrf").
    :param metric: metric object returned by `load_metric(metric_name)`.
    :param directions: list of "src-tgt" language pairs, e.g. ["en-de", "de-en"].
    :param dataset_name: "flores" (FLORES+ devtest) or "wnt23" (haoranxu/WMT23-Test).
    :param model_name: model family whose translations are evaluated.
    :param model_size: model size suffix, or None.
    :param reduce_size: number of samples kept, or None for the whole split.
    :raises ValueError: if `dataset_name` is neither "flores" nor "wnt23".
    """
    # Fail fast: with an unknown dataset `ds` would never be assigned and the loop
    # would crash with an UnboundLocalError, after the expensive setup below.
    if dataset_name not in ("flores", "wnt23"):
        raise ValueError(f"Unknown dataset_name {dataset_name!r}: expected 'flores' or 'wnt23'")

    # Getting right evaluation function
    metric_eval_fn = get_eval_fn(metric_name)

    # Loading full flores once for all directions (if necessary)
    if dataset_name == "flores":
        # Log in with a token read from a local `credentials` module, presumably
        # because FLORES+ is a gated dataset.
        from credentials import hf_token
        huggingface_hub.login(token=hf_token)
        ds_flores = load_dataset("openlanguagedata/flores_plus")["devtest"]
        ds_flores = reduce_flores_to_some_languages(ds_flores, directions)

    # Directory for the evaluation pickles (assumes get_eval_filename points inside it — TODO confirm).
    os.makedirs("./evaluations", exist_ok=True)

    for direction in directions:
        print(f"Evaluating translations {direction} with model {model_name}"
              +(f"-{model_size}" if model_size is not None else "")
              +f" for dataset {dataset_name}...")
        input_language, target_language = get_inp_tgt_lang(direction)

        # Loading previous eval if existing: other metrics are kept, not overwritten
        eval_filename = get_eval_filename(direction, dataset_name, model_name, model_size, reduce_size)
        if os.path.exists(eval_filename):
            with open(eval_filename, "rb") as f:
                complete_eval = pickle.load(f)
        else:
            complete_eval = {}

        # Getting the right split corresponding to the translation direction
        if dataset_name == "flores":
            ds = transform_to_WNT_style(ds_flores, lang=target_language, lang_start=input_language)
        elif direction != "cs-en":
            ds = load_dataset("haoranxu/WMT23-Test", direction)["test"]
        else:
            # cs-en is built from the en-cs config, reversed so that a reduced subset does
            # not contain the same sentences as en-cs.
            ds = load_dataset("haoranxu/WMT23-Test", "en-cs")["test"]
            ds = Dataset.from_dict({"cs-en": ds["en-cs"][::-1]})

        # Extracting input & targets
        get_input_targets_fn, _ = get_support_fn(model_name)
        sources, inputs, targets = get_input_targets_fn(ds, input_language, target_language)
        print(f"Total number of samples: {len(sources)}" + ("" if reduce_size is None else f"; reduced to {reduce_size} (numpy seed = 42)"))
        if reduce_size is not None:
            # /!\ Use same reduce size and same seed to ensure sources and previous inputs are the same /!\
            sources, inputs, targets = reduce_dataset(sources, inputs, targets, reduce_size)

        # Loading precomputed translations
        translations_filename = get_translations_filename(direction, dataset_name, model_name, model_size, reduce_size)
        with open(translations_filename, "rb") as f:
            translation_pred = pickle.load(f)

        # Evaluating translations for this direction and merging into previous results
        eval_dict = metric_eval_fn(metric, sources, targets, translation_pred, target_language)
        complete_eval.update(eval_dict)

        with open(eval_filename, "wb") as f:
            pickle.dump(complete_eval, f, pickle.HIGHEST_PROTOCOL)

def eval_one_metric(metric_name, directions, dataset_names, model_names, model_sizes, reduce_sizes):
    """Evaluate every (dataset, model) combination with a single metric.

    The metric is loaded once and reused for all combinations. `dataset_names` is
    paired element-wise with `reduce_sizes`, and `model_names` with `model_sizes`.

    :raises ValueError: if a paired list does not have the same length as its
        partner (``zip`` would otherwise silently drop the unmatched entries).
        The check runs before the (potentially slow) metric load.
    """
    if len(dataset_names) != len(reduce_sizes):
        raise ValueError(f"dataset_names ({len(dataset_names)}) and reduce_sizes "
                         f"({len(reduce_sizes)}) must have the same length")
    if len(model_names) != len(model_sizes):
        raise ValueError(f"model_names ({len(model_names)}) and model_sizes "
                         f"({len(model_sizes)}) must have the same length")

    print(f"Computing evaluations with {metric_name}...")
    metric = load_metric(metric_name)
    for dataset_name, reduce_size in zip(dataset_names, reduce_sizes):
        for model_name, model_size in zip(model_names, model_sizes):
            eval_one_metric_one_model(metric_name, metric, directions, dataset_name, model_name, model_size, reduce_size)

def eval_metrics(metric_names, directions, dataset_names, model_names, model_sizes, reduce_sizes):
    """Run `eval_one_metric` for each name in `metric_names`, in order.

    All other arguments are forwarded unchanged to every call.
    """
    shared_args = (directions, dataset_names, model_names, model_sizes, reduce_sizes)
    for name in metric_names:
        eval_one_metric(name, *shared_args)

In [None]:
# First evaluation pass: every metric except BLEURT.
# Names must be understood by load_metric / get_eval_fn.
metric_names = ["rouge", "bleu", "sacrebleu", "chrf", "comet", "meteor", "bertscore"]

# Paired element-wise: each dataset is evaluated with the reduce_size used when its
# translations were generated (100 for WMT23, 200 for FLORES+).
dataset_names = ["wnt23", "flores"]
reduce_sizes = [100, 200]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

# Paired element-wise with model_sizes (None = model without a size suffix).
model_names = ["alma",
               "nllb",
               "llama3", "llama3", "llama3",
               "falcon3-mamba",
               "falcon3", "falcon3", "falcon3",
               "qwen2.5", "qwen2.5", "qwen2.5", "qwen2.5",
               "mistral",
               "bloomz",
               "opt-instruct",]
model_sizes = [None,
               None,
               "1B", "3B", "8B",
               None,
               "1B", "3B", "7B",
               "0.5B", "1.5B", "3B", "7B",
               None,
               "7B",
               None,]

In [None]:
# Score every (metric, dataset, model, direction) combination configured above; scores
# are merged into the per-direction evaluation pickles.
eval_metrics(metric_names, directions, dataset_names, model_names, model_sizes, reduce_sizes)

In [None]:
# Separate pass: BLEURT only, for the bloomz-7B and opt-instruct models.
metric_names = ["bleurt"]

# Paired element-wise (order differs from the previous cell but pairs are the same).
dataset_names = ["flores", "wnt23"]
reduce_sizes = [200, 100]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

model_names = ["bloomz",
               "opt-instruct"]
model_sizes = ["7B",
               None]

In [None]:
# Run the BLEURT pass configured in the previous cell.
eval_metrics(metric_names, directions, dataset_names, model_names, model_sizes, reduce_sizes)

### Plot

In [None]:
def parallelCoordinatesPlot(title, N, data, category, ynames, colors=None, category_names=None, savepath=None):
    """
    Draw a parallel-coordinates plot with one smooth (Bezier) line per data set.

    A legend is added if category_names is not None.

    :param title: The title of the plot.
    :param N: Number of data sets (i.e., lines) to draw; must not exceed the length of the arrays in `data`.
    :param data: A list containing one array per parallel axis, each containing N data points.
    :param category: A sequence containing the (0-based) category index of each data set.
    :param ynames: The labels of the parallel axes.
    :param colors: A sequence of colors indexed by category (wraps around if there are more
        categories than colors). Defaults to the tab10 colors.
    :param category_names: Labels of the categories. Must have the same length as set(category).
    :param savepath: If not None, the figure is saved there (parent directory created if missing).
    :return: None
    """
    fig, host = plt.subplots(figsize=(24, 8))

    # organize the data: one row per data set, one column per parallel axis
    ys = np.dstack(data)[0]
    n_axes = ys.shape[1]
    ymins = ys.min(axis=0)
    ymaxs = ys.max(axis=0)
    dys = ymaxs - ymins
    # A constant axis would have a zero range (division by zero below): give it a unit range.
    dys = np.where(dys == 0, 1.0, dys)
    ymins = ymins - dys * 0.05  # add 5% padding below and above
    ymaxs = ymaxs + dys * 0.05
    dys = ymaxs - ymins

    # transform all data to be compatible with the main (host) axis
    zs = np.zeros_like(ys, dtype=float)
    zs[:, 0] = ys[:, 0]
    zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

    axes = [host] + [host.twinx() for _ in range(n_axes - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins[i], ymaxs[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax != host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            ax.spines["right"].set_position(("axes", i / (n_axes - 1)))

    host.set_xlim(0, n_axes - 1)
    host.set_xticks(range(n_axes))
    host.set_xticklabels(ynames, fontsize=7)
    host.tick_params(axis='x', which='major', pad=7)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()
    host.set_title(title, fontsize=15)

    if colors is None:
        colors = plt.cm.tab10.colors
    n_categories = len(category_names) if category_names is not None else len(set(category))
    legend_handles = [None] * n_categories

    # Bezier control vertices: for each axis, one vertex on the axis itself plus one a third
    # of the way towards each neighbouring axis (the first and last axis have one neighbour).
    # x: every third of a unit from 0 to n_axes - 1 (this depends on the number of AXES,
    # not on the number of lines); y: each value repeated three times minus the first and
    # last copies. The x positions and path codes are identical for every line.
    x_ctrl = np.linspace(0, n_axes - 1, n_axes * 3 - 2, endpoint=True)
    codes = [Path.MOVETO] + [Path.CURVE4] * (len(x_ctrl) - 1)
    for j in range(N):
        # to just draw straight lines between the axes:
        # host.plot(range(n_axes), zs[j,:], c=colors[category[j] % len(colors)])
        verts = list(zip(x_ctrl, np.repeat(zs[j, :], 3)[1:-1]))
        path = Path(verts, codes)
        edge_color = colors[category[j] % len(colors)]
        patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=edge_color)
        legend_handles[category[j]] = patch
        host.add_patch(patch)

    # Legend built once, after every line has been drawn; at least one column.
    if category_names is not None and N > 0:
        host.legend(legend_handles, category_names,
                    loc='lower center', bbox_to_anchor=(0.5, -0.18),
                    ncol=max(1, len(category_names) // 2), fancybox=True, shadow=True)

    plt.tight_layout()
    if savepath is not None:
        save_dir = os.path.dirname(savepath)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(savepath)
    plt.show()
    plt.close(fig)

In [None]:
def get_full_model_name(model_name, model_size):
    """Return "<model_name>-<model_size>", or just "<model_name>" when model_size is None."""
    if model_size is None:
        return f"{model_name}"
    return f"{model_name}-{model_size}"

def concatenate_results(directions, models, model_sizes, datasets, reduce_sizes, metrics_names, agg_keys, verbose=False):
    """
    Collect aggregated scores from the saved evaluation pickles, grouped per metric.

    agg_keys should be a list of keys present in the output dictionary of every desired metric;
    keys valid for all metrics are ["mean_score", "std_score", "std_unbias_score"] (or a subset).

    Returns {agg_key: [values_metric_0, values_metric_1, ...]}, with metrics in the order of
    `metrics_names`. Each inner list has one value per (dataset, model, direction), iterated in
    that nesting order. `datasets`/`reduce_sizes` and `models`/`model_sizes` are paired with zip.
    Raises KeyError for a display name missing from the table below.
    """
    display_name_to_key = {"ROUGE-1": "rouge1",
                           "ROUGE-2": "rouge2",
                           "ROUGE-L": "rougeL",
                           "ROUGE-Lsum": "rougeLsum",
                           "BLEU": "bleu",
                           "SacreBLEU": "sacrebleu",
                           "chrF": "chrf",
                           "chrF++": "chrfplusplus",
                           "COMET": "comet",
                           "BLEURT": "bleurt",
                           "BERTscore": "bertscore",
                           "METEOR": "meteor"}
    # Order follows metrics_names (it becomes the order of the plot axes), not the dict's order.
    metric_keys = [display_name_to_key[display_name] for display_name in metrics_names]

    collected = {}
    for agg_key in agg_keys:
        collected[agg_key] = [[] for _ in metric_keys]

    def _load_evaluations(direction, dataset_name, model_name, model_size, reduce_size):
        # Read one pickled evaluation dict, optionally echoing its path.
        path = get_eval_filename(direction, dataset_name, model_name, model_size, reduce_size)
        if verbose:
            print(path)
        with open(path, "rb") as handle:
            return pickle.load(handle)

    print("Extracting and concatenating metrics...")
    for dataset_name, reduce_size in zip(datasets, reduce_sizes):
        for model_name, model_size in zip(models, model_sizes):
            for direction in directions:
                evaluations = _load_evaluations(direction, dataset_name, model_name, model_size, reduce_size)
                for position, metric_key in enumerate(metric_keys):
                    for agg_key in agg_keys:
                        collected[agg_key][position].append(evaluations[metric_key][agg_key])
    return collected

def make_parallel_plot(directions,
                       models, model_sizes,
                       datasets, reduce_sizes,
                       metrics_names,
                       list_colors_per, colors=None, verbose=False, savepath=None):
    """Parallel-coordinates plot of mean metric scores, one line per (dataset, model, direction).

    One figure is produced for each entry of `list_colors_per`. Lines are coloured by
    "dataset", "direction" or "model"; any other value puts every line in one category.

    :param directions, models, model_sizes, datasets, reduce_sizes: select the evaluation
        pickles (see `concatenate_results`); models/model_sizes and datasets/reduce_sizes
        are paired with zip.
    :param metrics_names: display names of the metrics, one parallel axis each.
    :param list_colors_per: iterable of colouring criteria, one figure per entry.
    :param colors: colour sequence used for every figure; when None, a palette suited to
        each criterion is chosen independently for every figure.
    :param verbose: print every evaluation file that is read.
    :param savepath: output path. When several figures are requested, "_<criterion>" is
        inserted before the extension so the figures do not overwrite each other.
    """
    list_colors_per = list(list_colors_per)

    # Aggregate eval data
    data = concatenate_results(directions, models, model_sizes, datasets, reduce_sizes, metrics_names, agg_keys=["mean_score"], verbose=verbose)
    data = data["mean_score"]

    dataset_name2real_name = {"wnt23": "WNT23", "flores": "FLORES+"}

    for colors_per in list_colors_per:
        print(f"Generating categories based {colors_per} type ('colors_per' param)...")

        # Category index per element, and legend labels, for the chosen criterion
        if colors_per == "dataset":
            elem2cat = {dataset_name: i for i, dataset_name in enumerate(datasets)}
            dataset_name2real_name_and_reduction = {}
            for dataset_name, reduce_size in zip(datasets, reduce_sizes):
                dataset_name2real_name_and_reduction[dataset_name] = dataset_name2real_name[dataset_name] + f" - reduct to {reduce_size} samples"
            category_names = [dataset_name2real_name_and_reduction[dataset_name] for dataset_name in datasets]
        elif colors_per == "direction":
            elem2cat = {direction: i for i, direction in enumerate(directions)}
            category_names = directions
        elif colors_per == "model":
            full_names = [get_full_model_name(model_name, model_size) for model_name, model_size in zip(models, model_sizes)]
            elem2cat = {full_name: i for i, full_name in enumerate(full_names)}
            category_names = full_names
        else:
            elem2cat = {}
            category_names = ["No category"]

        # Same nesting (and same zip pairing) as concatenate_results, so category[k]
        # describes data line k.
        category = []
        for dataset_name, _ in zip(datasets, reduce_sizes):
            for model_name, model_size in zip(models, model_sizes):
                for direction in directions:
                    if colors_per == "dataset":
                        category.append(elem2cat[dataset_name])
                    elif colors_per == "direction":
                        category.append(elem2cat[direction])
                    elif colors_per == "model":
                        category.append(elem2cat[get_full_model_name(model_name, model_size)])
                    else:
                        category.append(0)

        # Per-figure palette: do not let the first figure's default leak into the next ones.
        plot_colors = colors
        if plot_colors is None:
            if colors_per == "dataset":
                plot_colors = plt.cm.Accent.colors
            elif colors_per == "direction":
                plot_colors = plt.cm.tab20.colors
            else:
                # Dark2 followed by tab10 without its index-7 colour
                plot_colors = plt.cm.Dark2.colors + plt.cm.tab10.colors[0:7] + plt.cm.tab10.colors[8:]

        figure_savepath = savepath
        if savepath is not None and len(list_colors_per) > 1:
            root, ext = os.path.splitext(savepath)
            figure_savepath = f"{root}_{colors_per}{ext}"

        # Plot: one line per (dataset, model, direction) actually loaded
        print("Plotting in parallel coordinates plot...")
        parallelCoordinatesPlot(title = f"Influence of {colors_per} on translation performances",
                                N = len(category),
                                data = data,
                                category = category,
                                category_names = category_names,
                                ynames = metrics_names,
                                colors=plot_colors,
                                savepath=figure_savepath)

In [None]:
# Display names (keys of concatenate_results' mapping), one parallel axis each.
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

# Paired element-wise; must match the reduce sizes used for generation/evaluation.
dataset_names = ["wnt23", "flores"]
reduce_sizes = [100, 200]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

# Every evaluated model, paired element-wise with model_sizes.
model_names = ["alma",
               "nllb",
               "llama3", "llama3", "llama3",
               "falcon3-mamba",
               "falcon3", "falcon3", "falcon3",
               "qwen2.5", "qwen2.5", "qwen2.5", "qwen2.5",
               "mistral",
               "bloomz",
               "opt-instruct",]
model_sizes = [None,
               None,
               "1B", "3B", "8B",
               None,
               "1B", "3B", "7B",
               "0.5B", "1.5B", "3B", "7B",
               None,
               "7B",
               None,]

# All models, both datasets: lines coloured by dataset.
make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["dataset"],
                    colors=None,
                    savepath = "./results/evaluations_figures/all_dataset")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["flores"]
reduce_sizes = [200]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

# Largest size of each model family on FLORES+, lines coloured by model.
# model_sizes must have exactly one entry per model name (a trailing extra entry
# used to be silently dropped by zip).
model_names = ["alma", "nllb", "llama3", "falcon3-mamba", "falcon3", "qwen2.5",
               "mistral", "bloomz"]
model_sizes = [None, None, "8B", None, "7B", "7B",
               None, "7B"]

make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["model"],
                    colors=None,
                    savepath="./results/evaluations_figures/main_model_flores")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["flores"]
reduce_sizes = [200]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

# 3B variants of the instruct families, compared with ALMA and opt-instruct on FLORES+.
model_names = ["alma",
               "llama3",
               "falcon3",
               "qwen2.5",
               "opt-instruct",]
model_sizes = [None,
               "3B",
               "3B",
               "3B",
               None,]

make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["model"],
                    colors=None,
                    savepath = "./results/evaluations_figures/medium_and_alma_model_flores")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["flores"]
reduce_sizes = [200]

directions = ["en-de", "de-en",
              "en-cs", "cs-en",
              "en-is", "is-en",
              "en-zh", "zh-en",
              "en-ru", "ru-en"]

# Smallest variants of the instruct families, compared with ALMA on FLORES+.
# One size per model name: the previous list had 7 entries for 5 models, so zip
# paired qwen2.5 with "3B"/"7B" instead of the small "0.5B"/"1.5B" variants.
model_names = ["alma",
               "llama3",
               "falcon3",
               "qwen2.5", "qwen2.5"]

model_sizes = [None,
               "1B",
               "1B",
               "0.5B", "1.5B"]

make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["model"],
                    colors=None,
                    savepath = "./results/evaluations_figures/small_and_alma_model_flores")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["flores"]
reduce_sizes = [200]

# Chinese directions only.
directions = ["en-zh", "zh-en"]

# Paired element-wise with model_sizes.
model_names = ["alma", "nllb",
               "llama3",
               "falcon3-mamba",
               "falcon3",
               "qwen2.5", "qwen2.5",
               "mistral",
               "bloomz"]
model_sizes = [None, None,
               "8B",
            None,
            "7B",
               "3B", "7B",
               None,
               "7B"]

make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["model"],
                    colors = None,
                    savepath = "./results/evaluations_figures/chinese_dir_model_flores")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["wnt23", "flores"]
reduce_sizes = [100, 200]

# Directions out of English only.
directions = ["en-de",
              "en-cs",
              "en-is",
              "en-zh",
              "en-ru",]

model_names = ["alma",
               "nllb"]

model_sizes = [None,
               None]

# ALMA and NLLB on both datasets, lines coloured by translation direction.
make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["direction"],
                    colors = plt.cm.Dark2.colors,
                    savepath = "./results/evaluations_figures/alma-nllb_from_en_direction")

In [None]:
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "ROUGE-Lsum",
                "BLEU", "SacreBLEU", "chrF", "chrF++",
                "COMET", "BLEURT", "BERTscore", "METEOR"]

dataset_names = ["wnt23", "flores"]
reduce_sizes = [100, 200]

# Directions into English only.
directions = ["de-en",
              "cs-en",
              "is-en",
              "zh-en",
              "ru-en"]

model_names = ["alma",
               "nllb"]

model_sizes = [None,
               None]

# ALMA and NLLB on both datasets, lines coloured by translation direction.
make_parallel_plot(directions,
                    model_names, model_sizes,
                    dataset_names, reduce_sizes,
                    metric_names,
                    list_colors_per = ["direction"],
                    colors = plt.cm.Dark2.colors,
                    savepath = "./results/evaluations_figures/alma-nllb_to_en_direction")