In [56]:

import sacrebleu.utils
from sacrebleu.metrics.bleu import BLEU
from tqdm import tqdm
import transformers
import torch

In [57]:
template = """{source_lang}: {source_text}
{target_lang}: {target_text}"""

def apply_prompt(training=False, eos_token=None, **kwargs):
    # note: we strip because of potential trailing whitespace
    # we also provide a default value for target_text so that it can be omitted
    return template.format(**{"target_text": "", **kwargs}).strip() + ("" if not training or eos_token is None else eos_token)

def apply_prompt_n_shot(examples, n: int, eos_token: str, **kwargs):
    return (eos_token + "\n\n").join(
        [apply_prompt(**{"target_text": "", **example}) for example in examples[:n]] + [apply_prompt(**kwargs)]
    )

EXAMPLE_SENTENCES = [
    {
        "source_lang": "English",
        "target_lang": "Czech",
        "source_text": "I am sorry to hear that.",
        "target_text": "To je mi líto.",
    },
    {
        "source_lang": "English",
        "target_lang": "Czech",
        "source_text": "How much does it cost?",
        "target_text": "Kolik to stojí?",
    },
    {
        "source_lang": "English",
        "target_lang": "Czech",
        "source_text": "Prague is the capital of the Czech Republic.",
        "target_text": "Praha je hlavní město České republiky.",
    },
    {
        "source_lang": "English",
        "target_lang": "Czech",
        "source_text": "Pay attention to the road.",
        "target_text": "Dávej pozor na silnici.",
    },
    {
        "source_lang": "English",
        "target_lang": "Czech",
        "source_text": "I have a headache.",
        "target_text": "Bolí mě hlava.",
    }
]

In [58]:
print(sacrebleu.utils.get_source_file("wmt22", "en-cs"))
with open(sacrebleu.utils.get_source_file("wmt22", "en-cs"), "r", encoding="utf-8") as fd:
    sources = list(map(str.strip, fd.readlines()))
with open(sacrebleu.utils.get_reference_files("wmt22", "en-cs")[0], "r", encoding="utf-8") as fd:
    references = list(map(str.strip, fd.readlines()))
source_lang = "English"
target_lang = "Czech"

/storage/praha1/home/hrabalm/.sacrebleu/wmt22/wmt22.en-cs.src


In [59]:
from transformers import StoppingCriteria
class EosListStoppingCriteria(StoppingCriteria):
    # Adopted from: https://github.com/huggingface/transformers/issues/26959
    def __init__(self, eos_sequence = [13]):  # Stop on newline
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids

def translate(model, tokenizer, source_lang, target_lang, source_texts: list[str], n_shot: int = 0):
    prompts = [apply_prompt_n_shot(EXAMPLE_SENTENCES, n_shot, eos_token=tokenizer.eos_token, source_lang=source_lang, target_lang=target_lang, source_text=source_text) for source_text in source_texts]

    translations = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True, stopping_criteria=[EosListStoppingCriteria()])
        decoded = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
        translations.append(decoded)

    return translations

def evaluate(model, tokenizer, n_shot: int = 0):
    translations = translate(model, tokenizer, source_lang, target_lang, sources, n_shot=n_shot)
    return BLEU().corpus_score(translations, [references])

In [60]:
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "outputs/mistral-ft-qlora",
    max_seq_length = 4096,
    dtype = None,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)


==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   /|    GPU: Tesla T4. Max memory: 14.581 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.0. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.24. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth




In [None]:
sources[:5]

: 

In [61]:
translate(model, tokenizer, source_lang, target_lang, sources[:5])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['Pokud nebudou nalézt, budou jistě volat.',
 'Ale je to lépe, když se blíží k vaší dodací adresě, můžete se s nimi spojit.',
 'Samco Sport je kvalitní kovová nádoba pro chlazení motoru.',
 'Speciálně navržený pro všechny ventilní trubice motoru, ventilní trubice karburátoru, trubice ventilace nádrže na palivo, trubice odvodu přebytku chlazení a může být použit pro trubky nabíjecího nádrže a izolace vodičů.',
 'Vhodné použití v nízkotlakových instalacích.']

In [62]:
evaluate(model, tokenizer, n_shot=0)  # note: took ~79min
# BLEU = 18.03 47.0/22.6/12.9/7.7 (BP = 1.000 ratio = 1.013 hyp_len = 35230 ref_len = 34787)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

BLEU = 18.03 47.0/22.6/12.9/7.7 (BP = 1.000 ratio = 1.013 hyp_len = 35230 ref_len = 34787)

In [63]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b-bnb-4bit",
    max_seq_length=4096,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)
evaluate(model, tokenizer, n_shot=5)

==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   /|    GPU: Tesla T4. Max memory: 14.581 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.0. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.24. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

KeyboardInterrupt: 