In [8]:
%env XDG_CACHE=/workspace/.cache
%env HF_HOME=/workspace/.cache/huggingface

env: XDG_CACHE=/workspace/.cache
env: HF_HOME=/workspace/.cache/huggingface


In [9]:
from datasets import load_dataset
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import pandas as pd
import random
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyonmttok
import ctranslate2
from metrics import *

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [10]:
model_id = "projecte-aina/aguila-7b"
#model_id = "tiiuae/falcon-7b"
model_name = model_id.split('/')[1]
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True,
                                             device_map="auto")


Loading checkpoint shards: 100%|██████████| 15/15 [00:11<00:00,  1.32it/s]


In [11]:
## Lets Do the translation layer
from huggingface_hub import snapshot_download
print("Loading translator Models...")

ca_en_model_folder = snapshot_download(repo_id="projecte-aina/mt-aina-ca-en", revision="main")
tokenizer_ca_en = pyonmttok.Tokenizer(
    mode="none", sp_model_path=ca_en_model_folder + "/spm.model"
)
ca_en_model = ctranslate2.Translator(ca_en_model_folder, device="cuda")

Loading translator Models...


Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 27666.91it/s]


In [12]:
def min_max_scaling(tensor):
    min_val = torch.min(tensor)
    max_val = torch.max(tensor)
    scaled_tensor = (tensor - min_val) / (max_val - min_val)
    return scaled_tensor


def compute_probability(input_text, answer):
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    answer_tokens = tokenizer(answer)['input_ids']
    answer_probability = 1
    with torch.no_grad():
        for token in answer_tokens:
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)

            # Logits are in the outputs, you can access the last token's logits like this:
            logits = outputs.logits[:, -1, :]
            #log_probs = torch.log_softmax(logits, dim=-1).cpu()
            log_probs = min_max_scaling(logits)
            answer_probability *= log_probs[0][token]

            # Prepare input_ids for the next token prediction
            new_token = torch.tensor([[token]]).to(model.device)
            inputs = {'input_ids': torch.cat([inputs['input_ids'], new_token], dim=1),
                    'attention_mask': torch.cat([inputs['attention_mask'], torch.tensor([[1]]).to(model.device)], dim=1)}
    #return torch.exp(answer_probability).item()
    return answer_probability.item()


def run_inference(txt, num_tokens=20, stop_text='\n'):
    # Tokenize the input text
    inputs = tokenizer(txt, return_tensors="pt").to(model.device)
    # Calculate the total length of the output (input length + number of tokens to generate)
    max_length = len(inputs['input_ids'][0]) + num_tokens

    with torch.no_grad():
        # Generate tokens
        tokens = model.generate(**inputs, do_sample=True, top_k=1, eos_token_id=tokenizer.eos_token_id, max_length=max_length)

        # Decode the generated tokens into text
        generated_text = tokenizer.decode(tokens[0], skip_special_tokens=True)

        # Slice the generated text to exclude the input prompt
        generated_only = generated_text[len(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)):]

        # If a stop text is found, truncate the output at its first occurrence
        if stop_text in generated_only:
            generated_only = generated_only.split(stop_text)[0]

        return generated_only.strip()


def translate(sample):
    def translate_to_english(txt):
        lines = txt.split("\n")
        toks, _ = tokenizer_ca_en.tokenize_batch(lines)
        translated = ca_en_model.translate_batch(toks)
        ts = []
        for t in translated:
            ts.append(tokenizer_ca_en.detokenize(t.hypotheses[0]))

        return "\n".join(ts)
    en_prompt = translate_to_english(sample['prompt'])
    en_answer = translate_to_english(sample['answer'])
    return {"prompt": en_prompt, "answer": en_answer}


def compute_metrics(sample):
    prob = compute_probability(sample['prompt'], sample['answer'])
    prediction = run_inference(sample['prompt'])
    f1 = f1_score(prediction, sample['answer'])
    bleu = calculate_bleu_score(prediction, sample['answer'])
    return {"prediction": prediction, "prob": prob, "f1": f1, "bleu": bleu}


import re
from collections import Counter
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
nltk.download('punkt')
import math
def calculate_bleu_score(prediction, ground_truth):
    """
    Calculate BLEU score for a prediction against a ground truth.

    Args:
    prediction (str): The predicted text.
    ground_truth (str): The reference text (ground truth).

    Returns:
    float: The BLEU score.
    """
    # Tokenizing the texts into words
    prediction_tokens = word_tokenize(prediction)
    ground_truth_tokens = [word_tokenize(ground_truth)]  # List of lists for multiple references support

    # Calculating BLEU score
    bleu_score = sentence_bleu(ground_truth_tokens, prediction_tokens)
    return bleu_score


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [13]:
viquiquad = load_dataset("data", data_files="viquiquad.csv", split="train[:10]")
viquiquad_en = viquiquad.map(translate)

Map: 100%|██████████| 10/10 [00:03<00:00,  2.70 examples/s]


In [14]:
results_ca = viquiquad.map(compute_metrics)
results_ca.to_pandas()

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Map: 100%|██████████| 10/10 [00:22<00:00,  2.27s/ examples]


Unnamed: 0,prompt,answer,prediction,prob,f1,bleu
0,En aquesta època es va consolidar el concepte ...,"Life, Paris-Match, Stern o Época","Paris-Match, Stern i Época",0.648438,0.666667,4.947126e-78
1,"Després de la seva mort, s'han celebrat divers...",en un certamen de fotografia en blanc i negre ...,----,0.546875,0.0,0.0
2,El 1952 es va fer soci de l'Agrupació Fotogràf...,un treball sobre l'emplaçament on es construir...,----,0.5625,0.0,0.0
3,Durant la dècada de 1960 també va exercir de r...,viatjar per un gran nombre de països,----,0.527344,0.0,0.0
4,El 1957 hi va haver la primera de les dues exp...,el pas de les classes populars cap al nou ento...,----,0.546875,0.0,0.0
5,"Finalment, cal destacar que cap de les dues pr...",incòmoda,----,0.5625,0.0,0.0
6,"Els dos relats però, no només presenten simili...",parlant amb ella,----,0.539062,0.0,0.0
7,Helen Fielding va crear la vida de Bridget Jon...,a través dels conflictes sentimentals en les p...,----,0.507812,0.0,0.0
8,"No és fins un mes més tard, ja en ple Nadal, q...",Natasha,----,0.78125,0.0,0.0
9,L'elecció de Renée Zellweger com protagonista ...,Time,----,0.789062,0.0,0.0


In [15]:
results_en = viquiquad_en.map(compute_metrics)
results_en.to_pandas()

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
Map: 100%|██████████| 10/10 [00:22<00:00,  2.22s/ examples]


Unnamed: 0,prompt,answer,prediction,prob,f1,bleu
0,"During this period, the modern concept of phot...","Life, Paris-Match, Stern or Época","Paris Match, Paris Match, Paris Match, Paris M...",0.326172,0.0,1.1008880000000001e-231
1,"After his death, several events and exhibition...",in a black and white photography contest about...,Photographic contest about Barcelona,0.589844,0.5,2.474304e-78
2,In 1952 he became a member of the Photographic...,a work on the site where the College of Archit...,to draw the site where the Colegio de Arquitec...,0.75,0.416667,0.2393949
3,During the 1960s he also served as a reporter ...,travel through a large number of countries,photographic reporting,0.447266,0.0,0.0
4,In 1957 there was the first of the two exhibit...,the shift of popular classes towards the new u...,The photographic poetics of the time,0.279297,0.166667,7.107197e-232
5,"Finally, it should be noted that neither of th...",inconvenient,Bridget Jones is aware of her own imperfection...,0.503906,0.0,0.0
6,"The two stories, however, not only present sim...",talking to her,Mark Darcy,0.6875,0.0,0.0
7,Helen Fielding created Bridget Jones's life in...,through the sentimental conflicts in couples a...,through sentimental conflicts in couples and t...,0.683594,1.0,0.8423627
8,"It is not until a month later, already in the ...",Natasha Natasha,Mark Darcy,0.431641,0.0,0.0
9,The choice of Renée Zellweger as the main char...,Time to Time,Time magazine,0.492188,0.4,9.29188e-232


In [16]:
results_ca.to_csv(f"results/{model_name}-viquiquad-ca.csv", index=False)
results_en.to_csv(f"results/{model_name}-viquiquad-en.csv", index=False)

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 261.31ba/s]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 362.45ba/s]


12753

In [18]:
results_ca_mean = results_ca.to_pandas()[['prob', 'f1', 'bleu']].mean()
results_en_mean = results_en.to_pandas()[['prob', 'f1', 'bleu']].mean()
print(results_ca_mean)
print(results_en_mean)

prob    6.011719e-01
f1      6.666667e-02
bleu    4.947126e-79
dtype: float64
prob    0.519141
f1      0.248333
bleu    0.108176
dtype: float64
