# 1. Environment Setup & Data Loading



In [None]:
import json
import glob
import pandas as pd
import ast

import torch
import gc

import numpy as np

import pandas as pd
import os

import re

import matplotlib.pyplot as plt
import seaborn as sns

## 1.1 Mount Google Drive
First, we need to connect Google Drive to this Colab notebook. This is crucial for saving the results so that no work is lost to runtime timeouts.


In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive unless the mount point already exists.
if os.path.exists('/content/drive'):
    print("Google Drive is already mounted.")
else:
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

# Sanity check: the user's Drive root must be visible below the mount point.
print(
    "SUCCESS: Google Drive mounted successfully."
    if os.path.exists('/content/drive/MyDrive')
    else "WARNING: Drive mount failed. Check permissions."
)

## 1.2 Clone AskQE repository & Change Directory
Clone the official **askqe** repository into the local Colab environment (/content/askqe). This provides access to the raw datasets (ContraTICO and BioMQM) and necessary utility scripts without latency.






In [None]:
# Clone only if not already present to avoid errors on re-runs
# NOTE(review): `git clone` without a target directory clones into the current
# working directory, so the result is /content/askqe only when this cell is run
# from /content - confirm the working directory on a fresh runtime.
if not os.path.exists('/content/askqe'):
    print("Cloning AskQE repository...")
    !git clone https://github.com/dayeonki/askqe
else:
    print("Repository already cloned.")

# Change Directory
# (%cd is an IPython magic: it changes the notebook's working directory for all later cells.)
%cd /content/askqe

## 1.3 Installing dependencies


In [None]:
# Quietly (-q) install the extra packages for this notebook; versions are not pinned.
!pip install -q bitsandbytes accelerate peft sentence-transformers jsonlines datasets einops

## 1.4 Define Functions to parse data
Define helper functions to parse the specific JSONL schemas of the datasets:

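load_contratico_linked: Iterates through the perturbation-specific back-translation files (e.g., spelling.jsonl, omission.jsonl) for a target language pair, extracts the perturbation type from the filename, and attaches the questions of each row through its ID (using the map built by get_question_map).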

load_biomqm_merged: Loads the biomedical domain dataset containing real-world translation errors and joins it with its generated questions on the source text.

Both functions standardize the output into a unified Pandas DataFrame structure containing: source, mt, backtrans, questions, and perturbation_type.

In [None]:
#  HELPER 1: Load Questions into a Map (ID -> List[Questions])
def get_question_map(repo_root):
    """
    Build a lookup table from row id to its list of questions.

    Searches ``<repo_root>/QG/*/atomic_*.jsonl`` and reads the first match in
    sorted order. Each line is expected to be a JSON object with an ``id`` and
    a ``questions`` field; ``questions`` may be a real list or a stringified
    Python list such as ``"['a', 'b']"``.

    Args:
        repo_root: Path of the cloned repository.

    Returns:
        dict mapping each id to a list of questions, e.g.
        ``{'doc_id': ['Question 1', 'Question 2']}``. Rows whose questions
        cannot be parsed into a list map to ``[]``. An empty dict is returned
        when no question file is found.
    """
    search_pattern = os.path.join(repo_root, "QG", "*", "atomic_*.jsonl")
    # glob returns matches in arbitrary order; sorting makes the chosen file
    # the same on every run.
    candidates = sorted(glob.glob(search_pattern))

    if not candidates:
        print("[ERROR] No QG atomic file found! Check path: askqe/QG/")
        return {}

    q_file = candidates[0]
    print(f"Loading Questions from: {q_file}")

    q_map = {}
    skipped = 0
    with open(q_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if not isinstance(d, dict):
                skipped += 1
                continue

            row_id = d.get('id')
            q_raw = d.get('questions')

            # Parse Stringified List "['a', 'b']" -> List ['a', 'b']
            if isinstance(q_raw, str):
                try:
                    q_raw = ast.literal_eval(q_raw)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    q_raw = []
            # literal_eval may also yield a non-list (e.g. "42"); treat it as no questions.
            q_list = q_raw if isinstance(q_raw, list) else []

            # Only a missing/empty id is rejected, so a legitimate id of 0 is kept.
            if row_id is None or row_id == "":
                continue
            try:
                q_map[row_id] = q_list
            except TypeError:
                # Unhashable id (e.g. a list) cannot be used as a key.
                skipped += 1

    if skipped:
        print(f"[WARN] Skipped {skipped} unreadable line(s) in {q_file}")
    print(f"-> Mapped {len(q_map)} IDs with questions.")
    return q_map

# HELPER 2: Load ContraTICO (Linked via ID)
def load_contratico_linked(repo_root, lang_pair, q_map):
    """
    Load the ContraTICO back-translation files of one language pair.

    Reads every ``<repo_root>/backtranslation/<lang_pair>/*.jsonl`` file (in
    sorted order). The perturbation type is the file name without the
    ``.jsonl`` extension and without the ``bt-`` marker. Questions are looked
    up in ``q_map`` through the row's ``id``; rows without questions or without
    a perturbed translation are dropped.

    Args:
        repo_root: Path of the cloned repository.
        lang_pair: Language-pair folder name, e.g. ``"en-es"``.
        q_map: dict mapping row id -> list of questions.

    Returns:
        pandas.DataFrame with columns id, source, mt, backtrans, questions,
        perturbation_type and dataset (empty DataFrame when nothing matched).
    """
    # Target: backtranslation/en-es/*.jsonl
    target_dir = os.path.join(repo_root, "backtranslation", lang_pair)
    # Sorted so the row order (and therefore the DataFrame index) is reproducible.
    files = sorted(glob.glob(os.path.join(target_dir, "*.jsonl")))

    all_rows = []
    print(f"Processing ContraTICO for {lang_pair}...")

    for p in files:
        file_name = os.path.basename(p)
        ptype = file_name.replace(".jsonl", "").replace("bt-", "")
        # Explicit UTF-8: the files contain non-ASCII target-language text.
        with open(p, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                try:
                    d = json.loads(line)
                except json.JSONDecodeError as err:
                    print(f"[WARN] {file_name}:{line_no}: skipping malformed JSON ({err})")
                    continue

                row_id = d.get('id')

                # Retrieve Questions from Map
                questions = q_map.get(row_id, [])

                # Field names depend on the target language (e.g. pert_es /
                # bt_pert_es), so the values are found by key prefix. The
                # 'pert_' prefix cannot match a 'perturbation' key, and a
                # 'bt_pert_*' key does not start with 'pert_'.
                mt_text = next((v for k, v in d.items() if k.startswith('pert_')), "")
                bt_text = next((v for k, v in d.items() if k.startswith('bt_pert_')), "")

                if questions and mt_text:
                    all_rows.append({
                        "id": row_id,
                        "source": d.get('en'),  # Assuming 'en' is source
                        "mt": mt_text,
                        "backtrans": bt_text,
                        "questions": questions,
                        "perturbation_type": ptype,
                        "dataset": "ContraTICO"
                    })
    return pd.DataFrame(all_rows)

#  HELPER 3: Load BioMQM (Merged via Source)
def load_biomqm_merged(repo_root, lang_pair_filter):
    """
    Load BioMQM rows of one target language and attach their questions.

    Reads ``biomqm/dev_with_backtranslation.jsonl`` and
    ``biomqm/askqe/askqe_qg.jsonl`` below ``repo_root``, keeps the rows whose
    ``lang_tgt`` equals the target code of ``lang_pair_filter`` and joins the
    questions onto the dev rows through the source text (``src``).

    Args:
        repo_root: Path of the cloned repository.
        lang_pair_filter: Language pair such as ``"en-es"``; the part after
            the last ``-`` is used as target language code.

    Returns:
        pandas.DataFrame with columns source, mt, backtrans, questions,
        perturbation_type, errors and dataset. An empty DataFrame is returned
        when either input file is missing or has no row for the language.
    """
    path_dev = os.path.join(repo_root, "biomqm/dev_with_backtranslation.jsonl")
    path_qg = os.path.join(repo_root, "biomqm/askqe/askqe_qg.jsonl")

    # Both files are required; a missing QG file used to raise FileNotFoundError.
    if not (os.path.exists(path_dev) and os.path.exists(path_qg)):
        return pd.DataFrame()

    tgt_lang_code = lang_pair_filter.split("-")[-1]  # "es"

    def _read_rows(path):
        """Return the JSON objects of `path` whose lang_tgt matches the filter."""
        rows = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                d = json.loads(line)
                if d.get("lang_tgt") == tgt_lang_code:
                    rows.append(d)
        return pd.DataFrame(rows)

    df_dev = _read_rows(path_dev)
    df_qg = _read_rows(path_qg)

    if df_dev.empty or df_qg.empty:
        return pd.DataFrame()

    # Keep one question set per source text. Without this, a source that occurs
    # k times in both files would produce k*k merged rows.
    qg_unique = df_qg[['src', 'questions']].drop_duplicates(subset='src', keep='first')
    # Avoid a questions_x / questions_y clash if the dev file carries the column too.
    dev_clean = df_dev.drop(columns=['questions'], errors='ignore')

    # Merge on Source text ('src')
    merged = pd.merge(dev_clean, qg_unique, on='src', how='inner')

    final_rows = []
    for _, row in merged.iterrows():
        # Parse questions (stringified list -> list); unparsable values become [].
        q_raw = row['questions']
        if isinstance(q_raw, str):
            try:
                q_raw = ast.literal_eval(q_raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                q_raw = []
        q_list = q_raw

        final_rows.append({
            "source": row['src'],
            "mt": row['tgt'],
            "backtrans": row.get('bt_tgt'),
            "questions": q_list,
            "perturbation_type": "Real_Error",
            "errors": row.get('errors_tgt', []),
            "dataset": "BioMQM"
        })
    return pd.DataFrame(final_rows)

## 1.5 Load Data & Save on drive to persistency
We explicitly set the target language pair to English-Spanish (en-es). We execute the parsing functions defined in the previous step and serialize the resulting DataFrames (.pkl format) to the persistent Google Drive directory.

In [None]:
# CONFIG
# Language pair to process, the Drive folder receiving the pickled DataFrames,
# and the location of the cloned repository.
TARGET_LANG = "en-es"
SAVE_DIR = "/content/drive/MyDrive/Progetto_AskQE/data"
REPO_ROOT = "/content/askqe"

# 1. Build Question Map (Critical Step)
print("--- Step 1: Mapping Questions ---")
q_map = get_question_map(REPO_ROOT)

if q_map:
    # 2. Load Datasets
    print(f"\n--- Step 2: Loading Datasets ({TARGET_LANG}) ---")
    df_contra = load_contratico_linked(REPO_ROOT, TARGET_LANG, q_map)
    df_bio = load_biomqm_merged(REPO_ROOT, TARGET_LANG)

    # 3. Save
    print("\n--- Step 3: Saving to Drive ---")

    # to_pickle does not create missing folders, so make sure the target exists
    # (on a fresh Drive the project folder is not there yet).
    os.makedirs(SAVE_DIR, exist_ok=True)
    df_contra.to_pickle(f"{SAVE_DIR}/df_contra_{TARGET_LANG}.pkl")
    df_bio.to_pickle(f"{SAVE_DIR}/df_bio_{TARGET_LANG}.pkl")

    # 4. Report
    print("\n[DONE] Checkpoint Created.")
    print(f"ContraTICO: {len(df_contra)} rows")
    print(f"BioMQM:     {len(df_bio)} rows")

    if not df_contra.empty:
        print(f"Sample Question: {df_contra.iloc[0]['questions']}")
else:
    # Without questions nothing useful can be saved.
    print("[FAIL] Could not build question map. Aborting save.")


## 1.6  Sample from Datasets
For ContraTICO, every source sentence appears once for each of the 8 perturbation types, so we keep 63 sources per type to obtain 63 × 8 = 504 rows.

In [None]:
def sample_df_by_perturbation_contraTICO(df, n_samples_per_type):
    """
    Keep the first ``n_samples_per_type`` rows of every perturbation type.

    The selection is deterministic (``head``, not a random sample). Groups are
    concatenated in sorted ``perturbation_type`` order and the original index
    labels are preserved.

    Args:
        df: DataFrame with a ``perturbation_type`` column.
        n_samples_per_type: Maximum number of rows kept per type.

    Returns:
        DataFrame with the selected rows; an empty copy of ``df`` when ``df``
        has no rows.
    """
    # An empty frame has nothing to group; pd.concat([]) would raise ValueError
    # (and a frame without the column would raise KeyError in groupby).
    if df.empty:
        return df.copy()

    per_type_heads = [
        group.head(n_samples_per_type)
        for _, group in df.groupby('perturbation_type')
    ]
    if not per_type_heads:
        # Only possible when every perturbation_type value is missing (NaN).
        return df.iloc[0:0].copy()
    return pd.concat(per_type_heads)

# Keep the first 63 rows of each ContraTICO perturbation type.
n_df_contra = sample_df_by_perturbation_contraTICO(df_contra,63)
# Sample first 500 rows of BioMQM
n_df_biomqm = df_bio.head(500)

# Persist both subsets with an "n_" prefix; these are the files the inference
# cells read back later.
n_df_contra.to_pickle(f"{SAVE_DIR}/n_df_contra_{TARGET_LANG}.pkl")
n_df_biomqm.to_pickle(f"{SAVE_DIR}/n_df_bio_{TARGET_LANG}.pkl")

# Report the resulting subset sizes.
print(len(n_df_contra))
print(len(n_df_biomqm))

# 2. Question Answering for each model

## 2.1 Prompt & Import

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)
from tqdm.auto import tqdm


# Prompt used for every (sentence, question) pair.
# This is a plain string, not an f-string: "{{sentence}}" and "{{question}}"
# are literal double-brace markers that ModelEngine.predict fills in with
# str.replace. The text ends with "Answer:" followed by a newline, so the
# model's continuation is the answer.
QA_PROMPT_TEMPLATE = """
You are a question answering model.

Rules:
- Use ONLY the information in the sentence.
- If the question can be answered with Yes or No, answer strictly "Yes" or "No".
- If the question asks for a specific span (number, phrase, symptom, etc.), copy it EXACTLY from the sentence.
- Do NOT invent information.
- Answer with a SHORT phrase, without explanations or additional text.
- Do not output code.
- Respond with the answer only.

Sentence: {{sentence}}
Question: {{question}}
Answer:
"""

print("Prompt Template Loaded.")

## 2.2 Create Class & Function to use models and generate answers


In [None]:
class ModelEngine:
    """
    Wraps one Hugging Face causal language model for question answering.

    Typical use: ``load_model()`` -> ``predict(...)`` for each sentence ->
    ``unload_model()``. The tokenizer, model and text-generation pipeline are
    ``None`` until loaded and are reset to ``None`` when unloaded.
    """

    # Leading "1." list marker in front of an answer; the lookahead keeps
    # decimals such as "1.5" intact.
    _LEADING_ENUM = re.compile(r"^\s*1\.(?!\d)\s*")
    # First period that is not a decimal point (i.e. not followed by a digit).
    _SENTENCE_END = re.compile(r"\.(?!\d)")

    def __init__(self, model_id):
        """Store the model id; nothing is downloaded or loaded here."""
        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self.pipeline = None

    def load_model(self):
        """Loads the model in 4-bit quantization to fit in Colab GPU."""
        print(f"Loading {self.model_id}...")

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # Some tokenizers ship without a pad token; reuse EOS in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True
        )

        # Create a text-generation pipeline for easier inference
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=20,  # Limit answer length
            do_sample=False,    # Greedy decoding for reproducibility
            temperature=0.0,
            pad_token_id=self.tokenizer.eos_token_id
        )
        print(f"Model {self.model_id} loaded successfully.")

    def unload_model(self):
        """
        Frees up VRAM.

        The attributes are reset to None instead of being deleted, so the
        method can be called repeatedly (or before load_model) and
        get_confidence keeps working as a no-model check afterwards.
        """
        print(f"Unloading {self.model_id}...")
        self.pipeline = None
        self.model = None
        self.tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("GPU Memory cleared.")

    def clean_answer(self, ans):
        """
        Reduce a raw model answer to a short phrase.

        Steps: keep the first line, drop everything from the first '#',
        remove a leading "1." list marker, then cut at the first period that
        is not a decimal point. Decimal numbers are therefore preserved
        ("37.5 degrees. More" -> "37.5 degrees") and an enumerated answer
        keeps its content ("1. Yes" -> "Yes").

        Args:
            ans: Raw answer text.

        Returns:
            The cleaned answer string (possibly empty).
        """
        ans = ans.strip().split("\n")[0]

        if "#" in ans:
            ans = ans.split("#", 1)[0].strip()

        ans = self._LEADING_ENUM.sub("", ans, count=1)
        return self._SENTENCE_END.split(ans, maxsplit=1)[0].strip()

    def get_confidence(self, prompt, answer_text):
        """
        Score how likely `answer_text` is as a continuation of `prompt`.

        The prompt and the answer are concatenated, the prompt positions are
        masked with -100 in the labels, and the confidence is exp(-loss) of
        the model's returned loss (presumably the mean negative
        log-likelihood of the answer tokens - confirm against the model docs).

        Returns:
            0.0 when the answer is empty or no model is loaded, 1e-5 when no
            answer token is left to score or an error occurs, otherwise
            exp(-loss) as a float.
        """
        if not answer_text or not self.model:
            print("Empty answer or model not loaded")
            return 0.0

        try:
            full_text = prompt + answer_text
            inputs = self.tokenizer(full_text, return_tensors="pt").to(self.model.device)

            prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
            prompt_len = prompt_ids.shape[-1]
            labels = inputs["input_ids"].clone()
            # ">=": when the prompt covers the whole sequence every label would
            # be masked and the loss would be NaN.
            if prompt_len >= labels.shape[-1]:
                print("prompt_len > seq_len", prompt_len, labels.shape[-1])
                return 1e-5

            labels[:, :prompt_len] = -100

            with torch.no_grad():
                outputs = self.model(**inputs, labels=labels)
                loss = outputs.loss.item()

            conf = float(np.exp(-loss))
            print("loss:", loss, "conf:", conf)
            return conf

        except Exception as e:
            # Best effort: a failed confidence must not abort the inference run.
            print("get_confidence error:", repr(e))
            return 1e-5

    def predict(self, sentence, questions):
        """
        Answer every question about `sentence`, one generation per question.

        Each question is inserted into the prompt as the string form of a
        one-element list (``str([q])``). The answer is the first line of the
        generated continuation, cleaned with clean_answer, and scored with
        get_confidence.

        Args:
            sentence: Text the answers must be taken from; a non-string value
                (e.g. NaN/None from a DataFrame) is treated as "".
            questions: Iterable of question strings.

        Returns:
            (answers, confidences): two lists aligned with `questions`.
        """
        if not isinstance(sentence, str):
            sentence = ""

        answers = []
        confidences = []
        for q in questions:
            q_list_str = str([q])
            input_text = QA_PROMPT_TEMPLATE.replace("{{sentence}}", sentence).replace("{{question}}", q_list_str)
            outputs = self.pipeline(input_text)
            raw_output = outputs[0]['generated_text']
            print("raw output --", raw_output)
            print("-- end row output --")

            # The output normally echoes the prompt; strip it by length so an
            # "Answer:" inside the sentence or question cannot confuse the split.
            if raw_output.startswith(input_text):
                generated_part = raw_output[len(input_text):]
            elif "Answer:" in raw_output:
                generated_part = raw_output.split("Answer:", 1)[1]
            else:
                generated_part = raw_output

            ans = generated_part.strip().split("\n")[0].strip()
            ans_f = self.clean_answer(ans)
            conf = self.get_confidence(input_text, ans_f)

            confidences.append(conf)
            answers.append(ans_f)

        print(f"===== out =====")
        print(f"sentence: {sentence}")
        print(f"questions: {questions}")
        print(f"ans: {answers}")
        print(f"conf: {confidences}")
        print("============================\n")
        return answers, confidences


## 2.3 Define loop function for the inference
Main loop: Loads model -> Iterates Data -> Saves Results -> Unloads Model.
Supports resuming from interruption.
    

In [None]:
def run_inference_loop(model_name, dataframe, output_file, sample_limit):
    """
    Run source / back-translation question answering for one model.

    For every row of `dataframe` the questions are answered twice (once on
    `source`, once on `backtrans`) and one JSON line is appended to
    `output_file`. Rows whose index label is already present in the file
    (`row_index`) are skipped, so an interrupted run can be resumed.

    Args:
        model_name: Hugging Face model id passed to ModelEngine.
        dataframe: DataFrame with the columns questions, source, backtrans and
            perturbation_type (id is optional).
        output_file: Path of the JSONL checkpoint/result file (appended to).
        sample_limit: Only rows at positions below this limit are considered;
            None (or 0) means no limit.
    """
    # 1. Check for existing progress (cheap, so done before loading the model)
    processed_ids = set()
    needs_newline = False
    if os.path.exists(output_file):
        print(f"Found existing checkpoint: {output_file}")
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # e.g. a line truncated by an interrupted run
                    continue
                if isinstance(data, dict):
                    processed_ids.add(data.get('row_index'))
        print(f"Resuming... {len(processed_ids)} rows already processed.")

        # If the last write was cut off, the file does not end with a newline;
        # appending directly would glue the next record onto the broken line.
        if os.path.getsize(output_file) > 0:
            with open(output_file, 'rb') as f_bin:
                f_bin.seek(-1, os.SEEK_END)
                needs_newline = f_bin.read(1) != b"\n"

    # 2. Initialize Engine
    engine = ModelEngine(model_name)
    try:
        engine.load_model()

        # 3. Processing Loop
        print(f"Starting inference on {len(dataframe)} rows...")
        total_rows = min(len(dataframe), sample_limit) if sample_limit else len(dataframe)

        # Open file in append mode
        with open(output_file, 'a', encoding='utf-8') as f_out:
            if needs_newline:
                f_out.write("\n")

            for i, (idx, row) in enumerate(tqdm(dataframe.iterrows(), total=total_rows)):

                # Optional: Limit for testing
                if sample_limit and i >= sample_limit:
                    break

                # Skip if already done
                if idx in processed_ids:
                    continue

                questions = row['questions']
                if not questions:
                    continue  # Skip empty questions

                # DOUBLE INFERENCE
                # A_src: Answers based on Source Sentence
                print("scr -->")
                ans_src, conf_scr = engine.predict(row['source'], questions)
                print("<-- end scr")

                # A_bt: Answers based on Backtranslated Sentence
                print("bt -->")
                bt_text = row['backtrans'] if isinstance(row['backtrans'], str) else ""
                ans_bt, conf_bt = engine.predict(bt_text, questions)
                print("<-- end bt")

                # 4. Save Result
                result_entry = {
                    "row_index": idx,
                    "id": row.get('id', f"row_{idx}"),
                    "perturbation_type": row['perturbation_type'],
                    "questions": questions,
                    "source": row['source'],
                    "ans_src": ans_src,
                    "conf_scr": conf_scr,
                    "backtrans": row['backtrans'],
                    "ans_bt": ans_bt,
                    "conf_bt": conf_bt,
                    "model": model_name
                }

                f_out.write(json.dumps(result_entry) + "\n")
                # Flush after every row: one row costs many model calls, so the
                # flush is negligible and nothing is lost on a runtime timeout.
                f_out.flush()
    finally:
        # 5. Cleanup - also on errors, so a failed run does not keep the model
        # in GPU memory and starve the next one.
        engine.unload_model()

    print(f"Finished inference for {model_name}.")

## 2.4 HuggingFace Login

In [None]:
from huggingface_hub import login, whoami
from getpass import getpass

# Ask for the token without echoing it, so it is not stored in the notebook output.
token = getpass("Insert Hugging Face token: ")

try:
    login(token)
    # whoami() is called after login; its 'name' field is shown as confirmation.
    info = whoami()
    print(f"Login succeded! User: {info['name']}")
except Exception as e:
    # Any failure (login, whoami or a missing 'name') is reported, not raised.
    print("Login failed:", e)

## 2.5 Execute LLama Inference
We execute the inference pipeline using the Llama-3.1-8B-Instruct model. The results are saved incrementally to Google Drive in JSONL format. This ensures that even in case of a Colab timeout, processed rows are preserved.

Prerequisites:

Hugging Face Token with access to meta-llama/Llama-3.1-8B-Instruct.



In [None]:
# 1. Configuration
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATA_DIR = "/content/drive/MyDrive/Progetto_AskQE/data"
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Debugging: Set to 10 to test quickly. Set to None for full run.
TEST_LIMIT = None

# 2. Load Data from Checkpoints
# NOTE: TARGET_LANG is defined in the section 1.5 configuration cell, which
# must have been run in this session. The sampled "n_" pickles are loaded and
# replace the df_contra / df_bio globals.
print(f"Loading datasets from {DATA_DIR}...")
df_contra = pd.read_pickle(f"{DATA_DIR}/n_df_contra_{TARGET_LANG}.pkl")
df_bio = pd.read_pickle(f"{DATA_DIR}/n_df_bio_{TARGET_LANG}.pkl")

# 3. Execution: ContraTICO
output_contra = f"{RESULTS_DIR}/results_llama3_contra.jsonl"
print(f"\n--- Processing ContraTICO with {MODEL_ID} ---")
print(f"Output: {output_contra}")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_contra,
    output_file=output_contra,
    sample_limit=TEST_LIMIT
)

# 4. Execution: BioMQM
# (run_inference_loop loads and unloads the model on each call, so the model is
# loaded a second time here.)
output_bio = f"{RESULTS_DIR}/results_llama3_biomqm.jsonl"
print(f"\n--- Processing BioMQM with {MODEL_ID} ---")
print(f"Output: {output_bio}")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_bio,
    output_file=output_bio,
    sample_limit=TEST_LIMIT
)
print("\n[SUCCESS] Llama-3 Inference Complete.")

## 2.6 Execute Qwen Inference

We execute the inference pipeline using the Qwen2.5-7B-Instruct model. The results are saved incrementally to Google Drive in JSONL format. This ensures that even in case of a Colab timeout, processed rows are preserved.

Prerequisites:

Hugging Face Token with access to Qwen/Qwen2.5-7B-Instruct.



In [None]:
# 1. Configuration
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
DATA_DIR = "/content/drive/MyDrive/Progetto_AskQE/data"
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Debugging: Set to 10 to test quickly. Set to None for full run.
TEST_LIMIT = None

# 2. Load Data
# NOTE: the language pair is hard-coded as "en-es" in these file names (the
# Llama cell builds the same names from TARGET_LANG).
print(f"Loading datasets for {MODEL_ID}...")
df_contra = pd.read_pickle(f"{DATA_DIR}/n_df_contra_en-es.pkl")
df_bio = pd.read_pickle(f"{DATA_DIR}/n_df_bio_en-es.pkl")

# 3. Execution: ContraTICO
output_contra = f"{RESULTS_DIR}/results_qwen_contra.jsonl"
print(f"\n--- Processing ContraTICO with qwen ---")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_contra,
    output_file=output_contra,
    sample_limit=TEST_LIMIT
)

# 4. Execution: BioMQM
output_bio = f"{RESULTS_DIR}/results_qwen_biomqm.jsonl"
print(f"\n--- Processing BioMQM with qwen ---")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_bio,
    output_file=output_bio,
    sample_limit=TEST_LIMIT
)

print("\n[SUCCESS] qwen Inference Complete.")

## 2.7 Execute Gemma Inference

We execute the inference pipeline using the gemma-2-9b-it model. The results are saved incrementally to Google Drive in JSONL format. This ensures that even in case of a Colab timeout, processed rows are preserved.

Prerequisites:

Hugging Face Token with access to google/gemma-2-9b-it.

In [None]:
# 1. Configuration
MODEL_ID = "google/gemma-2-9b-it"
DATA_DIR = "/content/drive/MyDrive/Progetto_AskQE/data"
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Debugging: Set to 10 to test quickly. Set to None for full run.
TEST_LIMIT = None

# 2. Load Data
# NOTE: the language pair is hard-coded as "en-es" in these file names (the
# Llama cell builds the same names from TARGET_LANG).
print(f"Loading datasets for {MODEL_ID}...")
df_contra = pd.read_pickle(f"{DATA_DIR}/n_df_contra_en-es.pkl")
df_bio = pd.read_pickle(f"{DATA_DIR}/n_df_bio_en-es.pkl")

# 3. Execution: ContraTICO
output_contra = f"{RESULTS_DIR}/results_gemma_contra.jsonl"
print(f"\n--- Processing ContraTICO with Gemma ---")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_contra,
    output_file=output_contra,
    sample_limit=TEST_LIMIT
)

# 4. Execution: BioMQM
output_bio = f"{RESULTS_DIR}/results_gemma_biomqm.jsonl"
print(f"\n--- Processing BioMQM with Gemma ---")
run_inference_loop(
    model_name=MODEL_ID,
    dataframe=df_bio,
    output_file=output_bio,
    sample_limit=TEST_LIMIT
)

print("\n[SUCCESS] Gemma Inference Complete.")

# 3. Extension 1 - LLM Ensemble


## 3.1 ensemble method: centroid selection
The pipeline ensures robustness by identifying the semantic consensus among different models, effectively filtering out individual hallucinations.
- Pairwise Comparison: For both the source and back-translation segments, calculate the SBERT similarity between the three LLM-generated responses
- Centroid Identification: Extract the response with the highest average similarity to the others (the "Semantic Centroid") as $best_{src}$ and $best_{bt}$.


In [None]:
from sentence_transformers import SentenceTransformer, util
from tqdm.auto import tqdm

# Drive folder holding the per-model results_*.jsonl files; the ensemble file is written there too.
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
# SentenceTransformer model name loaded by CentroidSelector to embed the candidate answers.
SELECTION_MODEL = "all-mpnet-base-v2"

# CENTROID LOGIC
class CentroidSelector:
    """
    Picks, for every question, the candidate answer closest to all the others.

    Candidates are embedded with the SentenceTransformer named by
    SELECTION_MODEL; the answer with the highest mean cosine similarity to the
    other candidates (the "semantic centroid") is selected.
    """

    def __init__(self):
        """Load the selection model and move it to the GPU when one is available."""
        print(f" Loading Selection Model: {SELECTION_MODEL}...")
        self.model = SentenceTransformer(SELECTION_MODEL)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def select_best_answers(self, answers_list_of_lists):
        """
        Input: [ [Ans_ModelA_Q1, Ans_ModelA_Q2], [Ans_ModelB_Q1, ...], ... ]
        Output: [Best_Ans_Q1, Best_Ans_Q2]

        The number of questions is taken from the first model's list; a model
        with fewer answers is padded with "". A question whose candidates are
        all blank yields "". An empty input yields [].
        """
        # No model output at all: nothing to select (indexing [0] would raise).
        if not answers_list_of_lists:
            return []

        num_questions = len(answers_list_of_lists[0])
        num_models = len(answers_list_of_lists)

        final_answers = []

        for q_idx in range(num_questions):
            # One candidate per model, "" as padding when a model has no answer
            # at this position.
            candidates = [
                model_answers[q_idx] if q_idx < len(model_answers) else ""
                for model_answers in answers_list_of_lists
            ]

            # If all empty, return empty
            if all(not c.strip() for c in candidates):
                final_answers.append("")
                continue

            # Embed candidates
            embeddings = self.model.encode(candidates, convert_to_tensor=True, show_progress_bar=False)

            # Pairwise cosine similarity matrix (num_models x num_models)
            cos_scores = util.cos_sim(embeddings, embeddings).cpu().numpy()

            # Mean similarity of each candidate to the others: the row sum
            # minus the self-similarity (taken as 1.0), over the other models.
            avg_sims = []
            for i in range(num_models):
                row_sum = np.sum(cos_scores[i]) - 1.0  # Remove self
                avg_sim = row_sum / (num_models - 1) if num_models > 1 else 0
                avg_sims.append(avg_sim)

            # Pick index with max average similarity (first one wins on ties)
            best_idx = np.argmax(avg_sims)
            final_answers.append(candidates[best_idx])

        return final_answers

def run_centroid_ensemble(dataset_tag):
    """
    Build the centroid ensemble for one dataset.

    dataset_tag: 'contra' or 'biomqm'
    Merges all results_*_{dataset_tag}.jsonl files in RESULTS_DIR (ensemble and
    scored files excluded) into results_ensembleC_{dataset_tag}.jsonl.

    Records are aligned across models by `row_index` for 'biomqm' and by
    (`id`, `perturbation_type`) otherwise. Every output record copies the
    fields of the first model that contained the key and replaces `model`,
    `ans_src` and `ans_bt`; other per-model fields (e.g. the confidences) are
    therefore those of that first model, not of the ensemble.
    """
    print(f"\n Generating Centroid Ensemble for: {dataset_tag.upper()}...")

    # 1. Find all individual model files
    pattern = os.path.join(RESULTS_DIR, f"results_*_{dataset_tag}.jsonl")
    # Exclude files that are already ensembles or scored files. The check uses
    # the file name only, and the list is sorted so that the model order (and
    # with it the 'meta' record and tie-breaking) is the same on every run.
    files = sorted(
        f for f in glob.glob(pattern)
        if "ensemble" not in os.path.basename(f) and "scored" not in os.path.basename(f)
    )

    if len(files) < 2:
        print(f" Need at least 2 models to ensemble. Found {len(files)}.")
        return

    print(f"   Combining models: {[os.path.basename(f) for f in files]}")

    # 2. Load data
    # Map structure: { key: { 'meta': record, 'inputs': { model: record } } }
    aligned_data = {}
    prefix = "results_"
    suffix = f"_{dataset_tag}.jsonl"
    skipped = 0

    for f_path in files:
        # File names look like results_<model>_<dataset_tag>.jsonl (guaranteed
        # by the glob pattern), so the model name is the part in between.
        base_name = os.path.basename(f_path)
        model_name = base_name[len(prefix):-len(suffix)]
        with open(f_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    rec = json.loads(line)
                    if dataset_tag == 'biomqm':
                        idx = rec['row_index']  # Use row_index as primary key
                    else:
                        idx = (rec['id'], rec['perturbation_type'])  # Use id and perturbation_type as primary key
                    if idx not in aligned_data:
                        # Save metadata from first model
                        aligned_data[idx] = {'meta': rec, 'inputs': {}}
                    aligned_data[idx]['inputs'][model_name] = rec
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Malformed line or record without the key fields.
                    skipped += 1

    if skipped:
        print(f" [WARN] Skipped {skipped} unreadable record(s).")

    # 3. Initialize Selector
    selector = CentroidSelector()

    # 4. Process and Save
    outfile = os.path.join(RESULTS_DIR, f"results_ensembleC_{dataset_tag}.jsonl")

    with open(outfile, 'w', encoding='utf-8') as f_out:
        sorted_indices = sorted(aligned_data.keys())

        for idx in tqdm(sorted_indices, desc="Computing Centroids"):
            group = aligned_data[idx]
            models = list(group['inputs'].keys())

            # Extract lists of answers (one list per model)
            src_answers_batch = [group['inputs'][m].get('ans_src', []) for m in models]
            bt_answers_batch = [group['inputs'][m].get('ans_bt', []) for m in models]

            #  CENTROID SELECTION
            best_src = selector.select_best_answers(src_answers_batch)
            best_bt = selector.select_best_answers(bt_answers_batch)

            # Create Ensemble Record
            # We take metadata from the first model, but replace answers
            ensemble_rec = group['meta'].copy()
            ensemble_rec['model'] = "Ensemble-Centroid"
            ensemble_rec['ans_src'] = best_src
            ensemble_rec['ans_bt'] = best_bt

            f_out.write(json.dumps(ensemble_rec) + "\n")

    print(f"Ensemble Saved: {os.path.basename(outfile)}")

#  EXECUTE
# The tags must match the suffix of the result files written by the inference
# cells (results_<model>_contra.jsonl / results_<model>_biomqm.jsonl).
run_centroid_ensemble('contra')
run_centroid_ensemble('biomqm')

## 3.2 Compute AskQE Score for single models and ensemble
- Semantic Alignment: Perform a final SBERT comparison between $best_{src}$ and $best_{bt}$ to measure information preservation.
- Final Metric: The per-question similarity values are averaged, and the mean is recorded as the $AskQE_{score}$.

In [None]:
from sentence_transformers import SentenceTransformer, util
from tqdm.auto import tqdm

# CONFIGURATION
# We use 'all-mpnet-base-v2' to compute SBERT
# (SentenceTransformer model name handed to Scorer).
SBERT_MODEL_ID = "all-mpnet-base-v2"
# Drive folder that is scanned for result files and receives the scored copies.
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"

# SCORING ENGINE
class Scorer:
    """Sentence-embedding scorer that compares source answers with back-translation answers."""

    def __init__(self, model_name):
        """Load the SentenceTransformer `model_name` and place it on GPU when available."""
        print(f" Loading SBERT model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        # Move to GPU if available
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.model.to(self.device)
        print(f" Model loaded on {self.device}")

    def calculate_similarity(self, answers_src, answers_bt):
        """
        Cosine similarity between two aligned answer lists.

        The i-th source answer is compared with the i-th back-translation
        answer.

        Returns:
            - per-question scores: list of floats (one per answer pair)
            - aggregate score: mean of those scores
            ([], 0.0) when either list is empty.
        """
        # Nothing to compare when one side has no answers.
        if not (answers_src and answers_bt):
            return [], 0.0

        # Embed both sides with identical settings.
        src_emb, bt_emb = (
            self.model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
            for batch in (answers_src, answers_bt)
        )

        # Element-wise cosine similarity, converted to a plain Python list.
        per_question = util.pairwise_cos_sim(src_emb, bt_emb).cpu().numpy().tolist()

        # Aggregate (Mean)
        if per_question:
            mean_score = np.mean(per_question)
        else:
            mean_score = 0.0

        return per_question, float(mean_score)

# EXECUTION LOOP FOR INDIVIDUAL AND ENSAMBLE

def run_individual_scoring():
    """
    Score every individual-model result file in RESULTS_DIR.

    All results_*.jsonl files whose name contains neither "ensemble" nor
    "scored" are processed. For each record the source and back-translation
    answers are truncated to the same length and compared with Scorer; the
    record is written, extended with `pair_scores` and `askqe_score`, to
    results_scored_<original file name> (overwritten on every run).
    Rows that cannot be scored are skipped and reported per file.
    """
    # 1. Find path (before loading the model, so nothing is loaded when there
    # is nothing to score)
    all_files = glob.glob(os.path.join(RESULTS_DIR, "results_*.jsonl"))
    # keep only individual models not scored
    input_files = sorted(
        f for f in all_files
        if not any(tag in os.path.basename(f) for tag in ("ensemble", "scored"))
    )

    if not input_files:
        print(" No files found.")
        return

    scorer = Scorer(SBERT_MODEL_ID)

    print(f"{len(input_files)} single models to analyze found.")

    for input_path in input_files:
        filename = os.path.basename(input_path)
        out_name = f"results_scored_{filename}"
        print(out_name)

        output_path = os.path.join(RESULTS_DIR, out_name)

        print(f"\nScoring modello: {filename}")
        print(f"-> Output: {out_name}")

        processed_count = 0
        skipped_count = 0
        last_error = None
        with open(input_path, 'r', encoding='utf-8') as fin, \
             open(output_path, 'w', encoding='utf-8') as fout:

            # Read all lines first so tqdm knows the total.
            lines = fin.readlines()
            for line in tqdm(lines, desc=f"Scoring {filename}"):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                    ans_src = record.get('ans_src', [])
                    ans_bt = record.get('ans_bt', [])

                    # Align both lists (truncate to the shortest)
                    min_len = min(len(ans_src), len(ans_bt))
                    ans_src = ans_src[:min_len]
                    ans_bt = ans_bt[:min_len]

                    pair_scores, avg_score = scorer.calculate_similarity(ans_src, ans_bt)

                    record['pair_scores'] = pair_scores
                    record['askqe_score'] = avg_score

                    fout.write(json.dumps(record) + "\n")
                    processed_count += 1
                except Exception as e:
                    # Best effort per row, but keep track of what was dropped
                    # instead of losing rows silently.
                    skipped_count += 1
                    last_error = e

        print(f"Completed: {processed_count} rows processed.")
        if skipped_count:
            print(f"[WARN] {skipped_count} rows skipped in {filename}. Last error: {last_error!r}")

    del scorer
    torch.cuda.empty_cache()

def run_scoring_pipeline(hallucination_threshold=0.3):
    """Score the ensemble result files and flag possible hallucinations.

    Every ``results_ensembleC_*.jsonl`` file in ``RESULTS_DIR`` whose name does
    not contain ``scored_`` is processed. Each record is extended with
    ``pair_scores``, ``askqe_score`` and a parallel ``hallucination`` list and
    written to ``results_scored_<original file name>``.

    Args:
        hallucination_threshold: pairs whose similarity is strictly below this
            value are labelled ``"hall"``, all others ``"no"``.
    """
    print("pipeline --------------------")

    # 1. Find the ensemble result files, leaving out anything already scored.
    input_files = glob.glob(os.path.join(RESULTS_DIR, "results_ensembleC_*.jsonl"))
    input_files = [f for f in input_files if "scored_" not in os.path.basename(f)]

    if not input_files:
        print(" No input files found in results folder.")
        return

    print(f"Found {len(input_files)} files to score.")

    # 2. Initialize the scorer only when there is work to do.
    scorer = Scorer(SBERT_MODEL_ID)
    try:
        for input_path in input_files:
            filename = os.path.basename(input_path)
            output_path = os.path.join(RESULTS_DIR, f"results_scored_{filename}")

            print(f"\nProcessing: {filename}...")
            print(f"-> Saving to: {output_path}")

            processed_count = 0

            with open(input_path, 'r', encoding='utf-8') as fin, \
                 open(output_path, 'w', encoding='utf-8') as fout:

                # Read all lines first so tqdm can show the total.
                lines = fin.readlines()

                for line in tqdm(lines, desc="Scoring Rows"):
                    if not line.strip():
                        continue  # blank line, nothing to score
                    try:
                        record = json.loads(line)

                        # Extract answers
                        ans_src = record.get('ans_src', [])
                        ans_bt = record.get('ans_bt', [])

                        # Ensure lists are aligned in length (truncate to shortest)
                        min_len = min(len(ans_src), len(ans_bt))
                        ans_src = ans_src[:min_len]
                        ans_bt = ans_bt[:min_len]

                        # Calculate Score
                        pair_scores, avg_score = scorer.calculate_similarity(ans_src, ans_bt)

                        # Possible hallucinations: one label per scored pair
                        hall = [
                            "hall" if score < hallucination_threshold else "no"
                            for score in pair_scores
                        ]

                        # Add scores to record
                        record['pair_scores'] = pair_scores
                        record['askqe_score'] = avg_score
                        record['hallucination'] = hall

                        # Write to new file
                        fout.write(json.dumps(record) + "\n")
                        processed_count += 1

                    except Exception as e:
                        print(f"Error parsing line: {e}")
                        continue

            print(f"Completed {filename}: Scored {processed_count} rows.")
    finally:
        # Release the model even when scoring is interrupted.
        del scorer
        torch.cuda.empty_cache()

    print("\n ALL SCORING COMPLETED.")

# Run both scoring passes; each writes results_scored_<name> files into RESULTS_DIR
run_individual_scoring() # single model
run_scoring_pipeline() # ensemble

# 4.  Results

## 4.1 Results on ContraTICO

Compute all metrics for the ContraTICO dataset:
- Mean AskQE score per perturbation type
- Decision accuracy, obtained by fitting a two-component GMM on the AskQE scores and comparing the resulting accept/reject decisions with the label implied by the perturbation type (Critical = Reject, Minor = Accept)

In [None]:
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score

# Configuration
# Directory with the scored result files that the analysis below reads
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
# Perturbation types treated as critical: analyze_contratico_all_models maps
# these to a gold 'reject' decision and every other type to 'accept'
CRITICAL_ERRORS = ['omission', 'expansion_impact', 'alteration']

def analyze_contratico_all_models():
    """Summarize AskQE scores on ContraTICO for every scored result file.

    For each ``results_scored_results_*_contra.jsonl`` file in ``RESULTS_DIR``
    this computes the mean ``askqe_score`` per ``perturbation_type`` and a
    decision accuracy: a two-component Gaussian mixture is fitted on the
    scores, the component with the higher mean is read as 'accept', and the
    predictions are compared with the gold label derived from
    ``CRITICAL_ERRORS`` (critical -> 'reject', anything else -> 'accept').

    Returns:
        A DataFrame indexed by model name with one column per perturbation
        type plus ``DECISION_ACCURACY``; ``None`` when no file could be
        analyzed.
    """
    # Search all scored files
    pattern = os.path.join(RESULTS_DIR, "results_scored_results_*_contra.jsonl")
    files = glob.glob(pattern)

    if not files:
        print(" Error: No files found.")
        return

    all_summaries = []

    for f_path in files:
        # Identify model name from the file name
        model_id = os.path.basename(f_path).replace("results_scored_results_", "").replace("_contra.jsonl", "")

        data = []
        with open(f_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError:
                    continue  # malformed row: ignore it, keep the rest

        df = pd.DataFrame(data)

        if df.empty or 'askqe_score' not in df.columns or 'perturbation_type' not in df.columns:
            continue

        # The mixture model cannot be fitted on NaN scores.
        df = df.dropna(subset=['askqe_score']).copy()
        if len(df) < 2:
            print(f" Skipping {model_id}: fewer than 2 scored rows.")
            continue

        # A) MEAN SCORE PER PERTURBATION TYPE
        means = df.groupby('perturbation_type')['askqe_score'].mean()

        # B) DECISION ACCURACY (GMM)
        # Accept vs Reject
        df['gold_decision'] = df['perturbation_type'].apply(
            lambda x: 'reject' if x in CRITICAL_ERRORS else 'accept'
        )

        # Fit GMM on AskQE scores
        scores_v = df['askqe_score'].values.reshape(-1, 1)
        gmm = GaussianMixture(n_components=2, random_state=42).fit(scores_v)
        preds = gmm.predict(scores_v)

        # Mapping (higher mean --> accept). Comparing against a single label
        # stays well defined even if both component means coincide.
        accept_label = int(np.argmax(gmm.means_))
        df['model_decision'] = ['accept' if p == accept_label else 'reject' for p in preds]

        acc = accuracy_score(df['gold_decision'], df['model_decision'])

        # Saving Statistics
        summary = means.to_dict()
        summary['Model'] = model_id
        summary['DECISION_ACCURACY'] = acc
        all_summaries.append(summary)

    if not all_summaries:
        # Without this guard set_index('Model') fails on an empty frame.
        print(" Error: No valid scored data found in the matched files.")
        return None

    # Create final table
    comparison_df = pd.DataFrame(all_summaries).set_index('Model')

    # Column order: perturbation types first, accuracy last
    cols = [c for c in comparison_df.columns if c != 'DECISION_ACCURACY']
    comparison_df = comparison_df[cols + ['DECISION_ACCURACY']]

    print("\n" + "="*80)
    print("CONTRATICO ANALYSIS (EN-ES)")
    print("="*80)
    print(comparison_df.round(4).to_string())
    print("-" * 80)

    return comparison_df

# EXECUTION
# stats_table is the per-model comparison table (None when no files are found)
stats_table = analyze_contratico_all_models()

## 4.2 Merge output file with human ratings
We merge the ensemble's scored results file with the human ratings file, so that the model's scores can be compared with the human annotations.

In [None]:
from tqdm.auto import tqdm

# CONFIGURATION
REPO_ROOT = "/content/drive/MyDrive/askqe"
# JSONL file with the human ratings, read by load_human_ratings_map
HUMAN_RATINGS_PATH = "/content/drive/MyDrive/askqe/biomqm/human_simulation/classified/human_ratings.jsonl"
# Directory that holds the scored result files and receives the merged output
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"

# Normalization
def normalize_text(text):
    """Return *text* lower-cased, trimmed, with whitespace runs collapsed.

    Every run of whitespace becomes a single space. Any value that is not a
    ``str`` yields an empty string.
    """
    if isinstance(text, str):
        collapsed = re.sub(r'\s+', ' ', text)
        return collapsed.strip().lower()
    return ""

# 1. Load Human Ratings Map
def load_human_ratings_map(path):
    """Load Spanish-target human ratings keyed by normalized source text.

    Each line of *path* is parsed as JSON. Records are ignored when their
    ``lang_tgt`` is not ``'es'``, when ``src`` is missing or empty, or when
    ``segment_score`` or ``decision`` is missing. Lines that are not valid
    JSON objects are skipped. If several records normalize to the same source
    text, the last one wins.

    Returns:
        dict mapping ``normalize_text(src)`` to a dict with the keys
        ``score``, ``decision`` and ``severity``; empty when *path* does not
        exist.
    """
    print(f"Loading Human Ratings from: {path}")
    if not os.path.exists(path):
        print(f" Error: Human ratings file not found at {path}")
        return {}

    ratings_map = {}

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Only parsing problems are tolerated; anything else should surface.
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(d, dict):
                continue

            # Keep only ratings for Spanish targets that carry a source text
            if d.get('lang_tgt') != 'es':
                continue
            src = d.get('src')
            if not src:
                continue
            norm_src = normalize_text(src)

            # Extract Score
            segment_score = d.get('segment_score')
            human_decision = d.get('decision')
            severity = d.get('severity')
            if segment_score is not None and human_decision is not None:
                ratings_map[norm_src] = {
                    'score': segment_score,
                    'decision': human_decision,
                    'severity': severity
                }

    # Report distinct sources: duplicates overwrite each other in the map.
    print(f"Loaded {len(ratings_map)} human ratings.")
    return ratings_map

# 2. Merge Pipeline
def merge_ratings_into_results():
    """Attach human ratings to the scored BioMQM ensemble results.

    Rows of ``results_scored_results_ensembleC_biomqm.jsonl`` (in
    ``RESULTS_DIR``) are matched to the ratings loaded from
    ``HUMAN_RATINGS_PATH`` by normalized ``source`` text. Each row gets
    ``human_score``, ``human_decision`` and ``severity`` (``None`` when no
    rating matches) and is written to ``final_<original file name>``.
    Lines that are not valid JSON objects are dropped.
    """
    # A. Load Map
    ratings_map = load_human_ratings_map(HUMAN_RATINGS_PATH)
    if not ratings_map:
        print(" Aborting merge: No ratings loaded.")
        return

    # B. Find the scored BioMQM result file
    search_pattern = os.path.join(RESULTS_DIR, "results_scored_results_ensembleC_biomqm.jsonl")
    result_files = glob.glob(search_pattern)

    if not result_files:
        print(" No scored BioMQM files found. Did you run Cell 9?")
        return

    print(f"Found {len(result_files)} result files to merge.")

    for filepath in result_files:
        filename = os.path.basename(filepath)
        print(f"\nProcessing {filename}...")

        updated_rows = []
        matched_count = 0

        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate malformed rows only; other errors should surface.
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(row, dict):
                    continue

                norm_src = normalize_text(row.get('source', ''))

                # Lookup
                data = ratings_map.get(norm_src)
                if data is not None:
                    row['human_score'] = data['score']
                    row['human_decision'] = data['decision']
                    row['severity'] = data['severity']
                    matched_count += 1
                else:
                    # Explicitly None if missing
                    row['human_score'] = None
                    row['human_decision'] = None
                    row['severity'] = None

                updated_rows.append(row)

        # Save as a separate 'final_' file; the scored input is left untouched
        output_path = os.path.join(RESULTS_DIR, f"final_{filename}")
        with open(output_path, 'w', encoding='utf-8') as f_out:
            for row in updated_rows:
                f_out.write(json.dumps(row) + "\n")

        print(f"Merged {matched_count}/{len(updated_rows)} rows. Saved to: final_{filename}")

# Run: writes final_<name> files (with the human ratings attached) into RESULTS_DIR
merge_ratings_into_results()

## 4.3 Results on BioMQM
We evaluate:
- Kendall’s Tau ($\tau$): Measures the rank-order correlation between AskQE scores and human ratings, indicating alignment with human perception regardless of the score scale.
- Decision Accuracy: Calculates the percentage of matching "Accept/Reject" labels between the model and human ground truth, utilizing a Gaussian Mixture Model (GMM) for score thresholding (same as contraTICO).
- Disagreement Analysis: Counts decision mismatches across different MQM severity levels to identify where the model diverges from human intuition.
- Mean Score per Severity: Computes average AskQE values for each error category to verify the metric's sensitivity to increasingly severe translation errors.

In [None]:
from scipy.stats import kendalltau
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score

def evaluate_model_performance(file_path, model_score_key='askqe_score'):
    """
    Compute Kendall's Tau, Decision Accuracy (GMM), Disagreement and Mean Scores

    Args:
        file_path: JSONL file whose records carry ``human_score``,
            ``human_decision`` and ``severity`` next to the model score.
        model_score_key: record key holding the model score to evaluate.

    Returns:
        ``(disagreement_counts, mean_scores, tau, accuracy, p_value)``, where
        the first two are Series indexed by severity; ``None`` when fewer than
        two usable records are found.
    """
    data_list = []

    # 1. Load data (only rows that have both a model score and a human score)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank or trailing lines
            d = json.loads(line)
            if d.get(model_score_key) is not None and d.get('human_score') is not None:
                data_list.append({
                    'score': d[model_score_key],
                    'h_score': d['human_score'],
                    'h_decision': d['human_decision'],
                    'severity': d['severity']
                })

    if not data_list:
        print("Error: No valid data found for the analysis.")
        return None

    if len(data_list) < 2:
        # A two-component mixture cannot be fitted on a single sample.
        print("Error: At least two valid records are needed for the analysis.")
        return None

    df = pd.DataFrame(data_list)

    model_scores = df['score'].values
    X_model = model_scores.reshape(-1, 1)  # column vector expected by the GMM
    Y_human_scores = df['h_score'].values

    # 2. Compute Kendall's tau on the two 1-D score vectors
    tau, p_value = kendalltau(model_scores, Y_human_scores)

    # 3. Compute decision accuracy via GMM
    gmm = GaussianMixture(n_components=2, random_state=42)
    gmm.fit(X_model)
    clusters = gmm.predict(X_model)

    # Mapping: Higher mean = 'accept' (scalar comparison of the two means)
    accept_cluster = 0 if gmm.means_[0, 0] > gmm.means_[1, 0] else 1

    df['m_decision'] = ['accept' if c == accept_cluster else 'reject' for c in clusters]
    accuracy = accuracy_score(df['h_decision'], df['m_decision'])

    # 4. Compute Disagreement (number of mismatching decisions per severity)
    df['disagreement'] = df['h_decision'] != df['m_decision']
    disagreement_counts = df.groupby('severity')['disagreement'].sum()

    # 5. Mean score per severity
    mean_scores = df.groupby('severity')['score'].mean()

    print(f"Processed results for {model_score_key}")
    print("-" * 30)

    return disagreement_counts, mean_scores, tau, accuracy, p_value

# Execution
# To reproduce the results for a single model, change the model name in the path.
file_path = '/content/drive/MyDrive/Progetto_AskQE/results/final_results_scored_results_ensembleC_biomqm.jsonl'
results = evaluate_model_performance(file_path, 'askqe_score')

# evaluate_model_performance returns None when there is nothing to report
if results is not None:
    dis_counts, means, tau, accuracy, p_value = results

    # Global metrics
    print(f"--- BIOMQM RESULT: {file_path.rsplit('/', 1)[-1]} ---")
    print(f"Global Kendall Tau: {tau:.4f} (p-value: {p_value:.5e})")
    print(f"Global Decision Accuracy (GMM): {accuracy*100:.2f}%")
    print("-" * 50)

    # Per-severity table in a fixed order; severities without data show 0
    results_table = pd.DataFrame({
        'Mean AskQE Score': means,
        'Disagreement Count': dis_counts
    })
    results_table = results_table.reindex(['no error','neutral', 'minor', 'major', 'critical'])
    results_table = results_table.fillna(0)

    print("For severity type:")
    print(results_table)
    print("-" * 50)

# 5. Plots for contraTICO and BioMQM


**box plot of confidence_backtranslation**

In [None]:
# Model tags used by print_graph to build the per-model result file names
models_to_plot = [
    "llama3", "qwen", "gemma"
]

# Directory containing the results_<model>_<dataset>.jsonl files
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"

def print_graph(dataset_tag):
    """Draw and save a box plot of average back-translation confidence.

    For every model in ``models_to_plot`` the file
    ``results_<model>_<dataset_tag>.jsonl`` is read from ``RESULTS_DIR`` and
    the mean of each row's ``conf_bt`` list is computed (0 when ``conf_bt`` is
    not a non-empty list). Models whose file is missing or unreadable are
    skipped with a message. The figure is saved as
    ``confidence_bt_<dataset_tag>.png`` in the project's plots folder.

    Args:
        dataset_tag: dataset suffix used in the file names, e.g. ``'contra'``.
    """
    all_conf_bt_data = []

    print(f"\nPlotting Back-Translation Confidence for {dataset_tag.capitalize()} dataset...")

    for model_name in models_to_plot:
        # Construct the file path for each model's result
        file_pattern = os.path.join(RESULTS_DIR, f"results_{model_name}_{dataset_tag}.jsonl")
        files = glob.glob(file_pattern)

        if not files:
            print(f"  Warning: No result file found for {model_name} in {file_pattern}. Skipping.")
            continue

        file_path = files[0]

        try:
            df = pd.read_json(file_path, lines=True)

            df['avg_conf_bt'] = df['conf_bt'].apply(lambda x: sum(x) / len(x) if isinstance(x, list) and len(x) > 0 else 0)

            if not df.empty:
                all_conf_bt_data.append(pd.DataFrame({
                    'Model': model_name,
                    'Confidence_BT': df['avg_conf_bt']
                }))
                print(f"  Loaded {len(df)} rows for {model_name}.")

        except Exception as e:
            print(f"  Error processing {file_path}: {e}. Skipping.")
            continue

    if not all_conf_bt_data:
        print("No data available to plot after filtering. Please ensure result files exist and contain 'conf_bt'.")
        return

    # Fresh index: the per-model frames all start at 0.
    combined_df = pd.concat(all_conf_bt_data, ignore_index=True)

    # Create the figure only once there is data, so no empty figure is left open.
    plt.figure(figsize=(10, 7))
    sns.boxplot(x='Model', y='Confidence_BT', data=combined_df)
    plt.title(f'Distribution of Average Back-Translation Confidence Scores for {dataset_tag.capitalize()} Dataset')
    plt.xlabel('Model')
    plt.ylabel('Average Confidence (Back-Translation)')
    plt.ylim(0, 1)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plot_filename = f"/content/drive/MyDrive/Progetto_AskQE/plots/confidence_bt_{dataset_tag}.png"
    # savefig does not create missing folders
    os.makedirs(os.path.dirname(plot_filename), exist_ok=True)
    plt.savefig(plot_filename, dpi=300)
    plt.show()


# One confidence box plot per dataset
print_graph('contra')
print_graph('biomqm')

**plot of askQE score divided per perturbation_type**

In [None]:
# Input folder with the scored results and output folder for the figures
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
PLOTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/plots"
os.makedirs(PLOTS_DIR, exist_ok=True)  # create the plots folder if it is missing

# Maps the model tag used in result file names to the label shown in the plots
MODEL_MAP = {
    "ensembleC": "Ensemble-Centroid"
}

def load_all_scored_results(dataset_tag):
    """
    Loads the scored results of every model listed in MODEL_MAP.

    For each entry the file
    ``results_scored_results_<prefix>_<dataset_tag>.jsonl`` is read from
    RESULTS_DIR; lines that are not valid JSON are skipped. When a
    ``row_index`` column is present it becomes the DataFrame index.

    Returns:
        dict mapping each display name to its DataFrame; models without a
        matching file are left out.
    """
    data_map = {}
    for file_prefix, display_name in MODEL_MAP.items():
        # Pattern for scored files
        pattern = os.path.join(RESULTS_DIR, f"results_scored_results_{file_prefix}_{dataset_tag}.jsonl")
        files = glob.glob(pattern)

        if files:
            filepath = files[0]
            data = []
            with open(filepath, 'r', encoding='utf-8') as f:
                for line in f:
                    # Tolerate malformed rows only; other errors should surface.
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError:
                        continue

            df = pd.DataFrame(data)
            if 'row_index' in df.columns:
                df.set_index('row_index', inplace=True)
            data_map[display_name] = df
            print(f"Loaded {display_name} ({dataset_tag}): {len(df)} rows from {os.path.basename(filepath)}")
        else:
            print(f"No scored file found for {display_name} with pattern {pattern}")

    return data_map

def analyze_by_perturbation(dataset_tag):
    """Box-plot the AskQE score distribution per perturbation type and model.

    Scored results are loaded through ``load_all_scored_results``; rows
    without an ``askqe_score`` are dropped. Categories on the x axis are
    ordered by ascending median score of the "Ensemble-Centroid" model when it
    has data, otherwise by the median over all models. The figure is saved to
    ``PLOTS_DIR`` as ``<dataset_tag>_askqe_by_perturbation.png``.
    """
    print(f"\n{dataset_tag} Analysis (AskQE Score by Perturbation Type)")
    per_model = load_all_scored_results(dataset_tag)

    # One long-format frame per model, tagged with the model's display name
    frames = []
    for label, scored in per_model.items():
        kept = scored.dropna(subset=['askqe_score'])
        if kept.empty:
            continue
        frames.append(kept[['perturbation_type', 'askqe_score']].assign(Model=label))

    if not frames:
        print("No data available to plot after loading and filtering.")
        return

    full_df = pd.concat(frames)

    def _median_order(frame):
        # Perturbation types sorted by ascending median score
        return frame.groupby('perturbation_type')['askqe_score'].median().sort_values().index

    ensemble_rows = full_df[full_df['Model'] == 'Ensemble-Centroid']
    if "Ensemble-Centroid" in per_model and not ensemble_rows.empty:
        order_idx = _median_order(ensemble_rows)
    else:
        order_idx = _median_order(full_df)

    # PLOT (ContraTICO gets a larger canvas)
    plt.figure(figsize=(14, 8) if dataset_tag == 'contra' else (10, 6))
    sns.boxplot(x='perturbation_type', y='askqe_score', hue='Model', data=full_df, order=order_idx, palette="Set3")

    plt.title(f'{dataset_tag}: AskQE Score Distribution by Perturbation Type and Model')
    plt.xticks(rotation=45, ha='right')
    plt.ylabel('AskQE Score')
    plt.xlabel('Perturbation Type')
    plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    plot_filename = f"{PLOTS_DIR}/{dataset_tag}_askqe_by_perturbation.png"
    plt.savefig(plot_filename, dpi=300)
    print(f"Saved plot: {plot_filename}")
    plt.show()

# EXECUTE: one plot per dataset, saved into PLOTS_DIR
analyze_by_perturbation("contra")
analyze_by_perturbation("biomqm")


**Possible hallucination**

In [None]:
def summarize_hallucinations(repo_root: str, dataset_tag: str):
    """Count flagged answers in the scored ensemble file of one dataset.

    Reads ``results_scored_results_ensembleC_<dataset_tag>.jsonl`` from
    *repo_root*, ignoring blank lines. The total is the number of entries in
    each record's ``questions`` list; a hallucination is every ``"hall"``
    entry in its ``hallucination`` list.

    Returns:
        dict with the keys ``dataset``, ``total_questions``,
        ``hallucinations`` and ``hallucination_rate`` (0.0 when there are no
        questions).
    """
    scored_file = os.path.join(
        repo_root, f"results_scored_results_ensembleC_{dataset_tag}.jsonl"
    )

    n_questions = 0
    n_flagged = 0

    with open(scored_file, "r", encoding="utf-8") as handle:
        # Lazy parsing keeps the line-by-line processing order
        records = (json.loads(raw) for raw in handle if raw.strip())
        for rec in records:
            n_questions += len(rec["questions"])
            n_flagged += len([flag for flag in rec["hallucination"] if flag == "hall"])

    if n_questions > 0:
        rate = n_flagged / n_questions
    else:
        rate = 0.0

    summary = {}
    summary["dataset"] = dataset_tag
    summary["total_questions"] = n_questions
    summary["hallucinations"] = n_flagged
    summary["hallucination_rate"] = rate
    return summary


# Folder holding the scored ensemble files read by summarize_hallucinations
repo_root = "/content/drive/MyDrive/Progetto_AskQE/results"

# One summary dict per dataset
results = [
    summarize_hallucinations(repo_root, "biomqm"),
    summarize_hallucinations(repo_root, "contra"),
]

# Show the summaries as a table, one row per dataset
df = pd.DataFrame(results)
print(df)

# 6. Extension 2 - error categorization

We define an LLM-as-a-judge model (using Qwen2.5-7B-Instruct) which, given the source, the backtranslation, and the QA, assigns a translation error category.

For each example, we build a structured prompt, query the model, and store the predicted category and the explanation. We then compare the category predicted by the model with the gold category (perturbation_type) present in the dataset.

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
from tqdm.auto import tqdm

# Configuration
RESULTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/results"
PLOTS_DIR = "/content/drive/MyDrive/Progetto_AskQE/plots"
# Identifier of the model loaded as the LLM judge
JUDGE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# Perturbation_type categories: the eight labels offered in JUDGE_PROMPT plus
# "Unknown". The list also fixes the row order of the true/predicted count table.
ERROR_CATEGORIES = [
    "omission", "expansion_noimpact", "expansion_impact",
    "intensifier", "spelling", "synonym", "word_order", "alteration", "Unknown"
]

# Prompt template for the judge. run_judge_pipeline fills the {{source}},
# {{backtrans}} and {{qa_evidence}} placeholders with str.replace, and parses
# the "Category:" / "Explanation:" lines requested in the instructions.
JUDGE_PROMPT = """You are a professional linguist evaluating Machine Translation quality using Backtranslation.

Task: Analyze the discrepancies between the SOURCE text and the BACKTRANSLATION to identify the specific type of translation error.

Definitions:
- Source: The original English text.
- Backtranslation: The translation of the MT output back into English.
- QA Evidence: Questions asked to both texts. Different answers indicate an error.

Error Categories (Choose ONE):
1. omission: Information in Source is missing in Backtranslation.
2. expansion_noimpact: Added detail that does not change the overall meaning; stylistic expansion.
3. expansion_impact: Added content that changes or extends the meaning compared to the source.
4. intensifier: Adds intensity (e.g., “severe”) that was not in the source, i.e., extra semantic content.
5. spelling: Orthographic / surface error.
6. synonym: Lexical substitution.
7. word_order: Directly corresponds to word order / reordering errors.
8. alteration: Stronger change of meaning or contradiction relative to the source.

*** ANALYSIS ***
Source: "{{source}}"
Backtranslation: "{{backtrans}}"

QA Evidence (Discrepancies):
{{qa_evidence}}

*** INSTRUCTIONS ***
Write ONLY two lines:
1) Category: <state the category explicitly in the format "Category: [CategoryName]">.
2) Explanation: <max 15 words>.


Response:
"""

def format_qa_evidence(questions, ans_src, ans_bt):
    """Render question/answer triples as a plain-text evidence block.

    The three sequences are walked in lockstep (stopping at the shortest);
    each triple yields a ``Q:`` line, a ``Source Ans:`` line and a
    ``Backtranslation Ans:`` line. Returns an empty string when there are no
    triples.
    """
    return "\n".join(
        f"Q:{question}\nSource Ans:{src_answer}\nBacktranslation Ans:{bt_answer}"
        for question, src_answer, bt_answer in zip(questions, ans_src, ans_bt)
    )


def _parse_judge_response(response):
    """Extract (category, explanation) from the judge's raw completion.

    Uses the first line starting with ``Category:`` and the first line
    starting with ``Explanation:``. A missing or empty category becomes
    ``"Unknown"``; a missing explanation becomes ``""``.
    """
    category = None
    explanation = None

    for line in response.splitlines():
        line = line.strip()

        if category is None and line.startswith("Category:"):
            category_raw = line[len("Category:"):].strip()
            # Drop the surrounding brackets / trailing period of "[Name]."
            category = category_raw.strip("[.]").strip()
        elif explanation is None and line.startswith("Explanation:"):
            raw_expl = line[len("Explanation:"):].strip()
            # Keep the first sentence only
            explanation = raw_expl.split(".", 1)[0].strip()

        if category is not None and explanation is not None:
            break

    return (category or "Unknown"), (explanation or "")


def run_judge_pipeline(df, engine):
    """
    Runs the Judge LLM on a DataFrame.

    Args:
        df: one row per example, providing ``questions``, ``ans_src``,
            ``ans_bt``, ``source`` and ``backtrans``; ``perturbation_type``,
            ``row_index`` and ``askqe_score`` are read when present.
        engine: object exposing a ``pipeline`` callable that returns a list
            whose first item has a ``generated_text`` entry starting with the
            prompt.

    Returns:
        List of result dicts (``row_index``, ``index_db``, ``true_category``,
        ``predicted_category``, ``explanation``, ``askqe_score``). The same
        records are written to ``extension2.jsonl`` in ``RESULTS_DIR``. Rows
        on which inference fails are reported and left out.
    """
    results = []

    target_df = df

    print(f"Judging {len(target_df)} samples...")

    # Start from an empty output file, then append one line per judged row so
    # partial progress is kept without rewriting the whole file every time.
    output_path = os.path.join(RESULTS_DIR, "extension2.jsonl")
    with open(output_path, "w", encoding="utf-8"):
        pass

    for idx, row in tqdm(target_df.iterrows(), total=len(target_df)):
        # 1. Prepare Evidence
        qa_text = format_qa_evidence(row['questions'], row['ans_src'], row['ans_bt'])

        # 2. Build Prompt
        prompt = JUDGE_PROMPT.replace("{{source}}", row['source'])\
                            .replace("{{backtrans}}", row['backtrans'])\
                            .replace("{{qa_evidence}}", qa_text)
        # 3. Inference
        try:
            # NOTE(review): keyword was misspelled "max_lenght"; assuming a
            # Hugging Face text-generation pipeline, where max_length=None
            # leaves max_new_tokens as the only length limit. Confirm.
            outputs = engine.pipeline(prompt, max_new_tokens=20, max_length=None)
            generated_text = outputs[0]['generated_text']

            # The completion is whatever follows the echoed prompt
            response = generated_text[len(prompt):].strip()

            # Parsed per row, so nothing carries over from the previous row
            category, explanation = _parse_judge_response(response)

            print(category)
            print(explanation)
            print("--------------------")
            item = {
                "row_index": idx,
                "index_db": row.get('row_index'),
                "true_category": row.get('perturbation_type', 'Real'), # Ground Truth
                "predicted_category": category,
                "explanation": explanation,
                "askqe_score": row.get('askqe_score', 0.0)
            }
            results.append(item)

            # save the output
            with open(output_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

        except Exception as e:
            print(f"Judge Error: {e}")
            continue

    return results

def get_ensemble_file(tag):
    """Return the path of the scored ensemble results file for *tag*."""
    return os.path.join(RESULTS_DIR, f"results_scored_results_ensembleC_{tag}.jsonl")


# Path of the scored ContraTICO ensemble results
path_contra = get_ensemble_file('contra')

df_judge_contra = pd.DataFrame()

# Load the judge model named by JUDGE_MODEL_ID (Qwen2.5-7B-Instruct)
print("Loading Judge Model (Qwen2.5-7B-Instruct)")
judge_engine = ModelEngine(JUDGE_MODEL_ID)
judge_engine.load_model()

try:
    if os.path.exists(path_contra):
        print("\nValidating Judge on ContraTICO")
        with open(path_contra, 'r', encoding='utf-8') as f:
            data_c = [json.loads(l) for l in f if l.strip()]
        df_c = pd.DataFrame(data_c)

        # Run Judge
        judge_results_c = run_judge_pipeline(df_c, judge_engine)
        print("------------------")

        judge_results_c_df = pd.DataFrame(judge_results_c)

        if judge_results_c_df.empty:
            # No columns to index when every row failed
            print("No judge results to summarize.")
        else:
            # count per category (labels outside ERROR_CATEGORIES are not shown)
            pred_count = (
                judge_results_c_df["predicted_category"]
                .value_counts()
                .reindex(ERROR_CATEGORIES, fill_value=0)
            )

            true_count = (
                judge_results_c_df["true_category"]
                .value_counts()
                .reindex(ERROR_CATEGORIES, fill_value=0)
            )

            # final table
            table = pd.DataFrame({
                "category": ERROR_CATEGORIES,
                "count_true": true_count.values,
                "count_pred": pred_count.values
            })

            print(table.to_string(index=False))

            # bar chart
            ax = table.set_index("category")[["count_true", "count_pred"]].plot(kind="bar")
            ax.set_ylabel("count")
            ax.set_title("True vs Predicted per category")

            # reuse the existing tick labels and rotate them for readability
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

            plt.tight_layout()
            plt.show()
finally:
    # Cleanup: release the judge model even if the run fails midway
    judge_engine.unload_model()

print("\n[DONE] Error Categorization Complete.")