In [None]:
import time
import numpy as np
from tqdm import tqdm

import json
import pandas as pd

from transformers import AutoTokenizer, AutoModel

import warnings
import os
import ast

warnings.filterwarnings('ignore')

import re
from unidecode import unidecode

from pymongo.mongo_client import MongoClient
import json

In [None]:
def compute_f1(prediction, truth):
    """Token-level F1 between two whitespace-tokenized strings (SQuAD style).

    Overlap is a multiset intersection: a token repeated in both strings
    counts once per shared occurrence, so identical strings always score 1.0.

    Args:
        prediction: predicted answer, ideally already passed through ``normalize``.
        truth: gold answer, ideally already passed through ``normalize``.

    Returns:
        F1 in [0, 1]. If either side has no tokens, returns 1 when both are
        empty and 0 otherwise.
    """
    from collections import Counter  # local import keeps the notebook's import cell unchanged

    pred_tokens = prediction.split()
    truth_tokens = truth.split()

    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    # Counter '&' keeps min(count_pred, count_truth) per token.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

In [None]:
import string
def normalize(input_string):
    """Canonicalize an answer string before EM/F1 comparison.

    Pipeline:
      1. transliterate to ASCII with ``unidecode`` and lowercase;
      2. drop a comma or period sitting between two digits
         (``7,531`` / ``7.531`` -> ``7531``; note this also merges decimals,
         e.g. ``3.5`` -> ``35``);
      3. replace every other ``string.punctuation`` character with a space;
      4. collapse whitespace runs to one space and strip both ends.
    """
    text = unidecode(input_string).lower()
    text = re.sub(r'(?<=\d)[,\.](?=\d)', '', text)
    text = re.sub(f"[{re.escape(string.punctuation)}]", " ", text)
    return re.sub(r"\s+", " ", text).strip()

## MuSiQue

In [None]:
def get_mongo_client(mongo_uri):
    """Build and return a pymongo ``MongoClient`` for the given connection URI."""
    return MongoClient(mongo_uri)

mongo_client = get_mongo_client("mongodb://localhost:63819/?directConnection=true")
# MongoClient connects lazily; issuing a command here surfaces connection
# problems in this cell rather than in the middle of the scoring loop.
mongo_client.list_database_names()

# Alias table used to expand MuSiQue predictions before scoring.
db = mongo_client["onto_triplets_db_llama"]
entity_aliases_collection = db["entity_aliases"]

In [None]:
with open("datasets/musique.json", "r") as f:
    ds = json.load(f)

# Index samples by their MuSiQue id; if an id repeats, the last occurrence wins.
id2sample = {elem['id']: elem for elem in ds}

In [None]:
qa_res_file = 'qa_logs/onto_triplets_db_llama_Meta-llama_Llama-3.3-70B-Instruct_structured_True_multi_step_True_use_qualifiers_True_use_filtered_triplets_True_musique_test_run_3.jsonl'

# JSONL log: one record per line; map sample_id -> model answer (later lines win).
sample_id2ans = {}
with open(qa_res_file, "r") as f:
    sample_id2ans.update((rec['sample_id'], rec['answer']) for rec in map(json.loads, f))

In [None]:
# Score every MuSiQue sample that has a logged model answer. The prediction is
# expanded with its aliases from `entity_aliases_collection` (looked up by
# `label` only), and the best EM / F1 over all (gold answer, alias) pairs is kept.
# Samples without a logged answer are excluded from the averages.
f1s = []
ems = []
for sample_id, sample in id2sample.items():
    if sample_id not in sample_id2ans:
        continue

    raw_ans = sample_id2ans[sample_id]
    # Logged answers are not always strings (e.g. ints); normalize() needs str.
    if not isinstance(raw_ans, str):
        raw_ans = "" if raw_ans is None else str(raw_ans)

    aliases = {raw_ans}
    retrieved_aliases = entity_aliases_collection.find(
        {"$or": [{"label": raw_ans}]}, {"_id": 0, "label": 1, "alias": 1}
    )
    for doc in retrieved_aliases:
        # A document may lack one of the projected fields; skip missing values.
        aliases.update(str(v) for v in (doc.get('alias'), doc.get('label')) if v is not None)
    normalized_aliases = [normalize(alias) for alias in aliases]

    gold_answers_variations = [sample['answer']]
    # Optional: also accept the dataset's gold aliases.
    # gold_answers_variations.extend(sample['answer_aliases'])

    max_f1, max_em = 0, 0
    # Normalized gold answer that produced each maximum (for the log line below).
    max_f1_entity, max_em_entity = '', ''
    for golden_answer in gold_answers_variations:
        golden_norm = normalize(golden_answer)  # once per gold answer, not per alias
        for alias_norm in normalized_aliases:
            f1 = compute_f1(prediction=alias_norm, truth=golden_norm)
            em = int(alias_norm == golden_norm)
            if f1 > max_f1:
                max_f1, max_f1_entity = f1, golden_norm
            if em > max_em:
                max_em, max_em_entity = em, golden_norm

    # Log the model's raw answer, not one of its normalized aliases.
    print(sample_id, " | ", raw_ans, " | ", max_em_entity, " | ", max_f1_entity)
    f1s.append(max_f1)
    ems.append(max_em)

assert ems, "no sample ids are shared between the dataset and the QA log"
sum(ems) / len(ems), sum(f1s) / len(f1s)

## HotpotQA


In [None]:
with open("datasets/hotpotqa200.json", "r") as f:
    # Full slice presumably left as a hook for sub-sampling (e.g. [:20]) — TODO confirm.
    ds = json.load(f)[:]

# HotpotQA samples are keyed by '_id'; if an id repeats, the last occurrence wins.
id2sample = {elem['_id']: elem for elem in ds}

In [None]:
# Reuse the client created in the MuSiQue setup cell (same server URI) instead of
# opening a second MongoClient and orphaning the first one without closing it.
# Issuing a command here surfaces connection problems in this cell.
mongo_client.list_database_names()

# Alias table used to expand HotpotQA predictions before scoring.
db = mongo_client.get_database("triplets_db_hotpot_llama")
entity_aliases_collection = db.get_collection("entity_aliases")

In [None]:
qa_res_file = 'qa_logs/triplets_db_hotpot_llama_Meta-llama_Llama-3.3-70B-Instruct_structured_True_multi_step_True_use_qualifiers_True_use_filtered_triplets_True_hotpot_test_run_3.jsonl'

# JSONL log: one record per line; map sample_id -> model answer (later lines win).
sample_id2ans = {}
with open(qa_res_file, "r") as f:
    sample_id2ans.update((rec['sample_id'], rec['answer']) for rec in map(json.loads, f))

In [None]:
# Score every HotpotQA sample. Unlike the MuSiQue cell, aliases are looked up by
# both `label` and `alias`, and verbose polar answers ("No, ...", "Yes, ...") are
# collapsed to a bare "no"/"yes" before scoring. Samples that cannot be scored
# are excluded from the averages but recorded in `skipped` and reported.
f1s = []
ems = []
skipped = {}  # sample_id -> reason it was excluded
for sample_id, sample in id2sample.items():
    if sample_id not in sample_id2ans:
        skipped[sample_id] = "no logged answer"
        continue
    try:
        raw_ans = sample_id2ans[sample_id]
        # Logged answers are not always strings (e.g. ints); .lower()/normalize() need str.
        if not isinstance(raw_ans, str):
            raw_ans = "" if raw_ans is None else str(raw_ans)

        aliases = {raw_ans}
        retrieved_aliases = entity_aliases_collection.find(
            {"$or": [{"label": raw_ans}, {'alias': raw_ans}]}, {"_id": 0, "label": 1, "alias": 1}
        )
        for doc in retrieved_aliases:
            # A document may lack one of the projected fields; skip missing values.
            aliases.update(str(v) for v in (doc.get('alias'), doc.get('label')) if v is not None)

        normalized_aliases = []
        for alias in aliases:
            lowered = alias.lower()
            # \b keeps words that merely end in "no"/"yes" (e.g. "Reno, Nevada")
            # from being collapsed; "no," is checked first and wins if both occur.
            if re.search(r'\bno,', lowered):
                alias = 'no'
            elif re.search(r'\byes,', lowered):
                alias = 'yes'
            normalized_aliases.append(normalize(alias))

        max_f1, max_em = 0, 0
        # Normalized gold answer that produced each maximum (for the log line below).
        max_f1_entity, max_em_entity = '', ''
        gold_answers_variations = [sample['answer']]
        # gold_answers_variations.extend(sample['answer_aliases'])
        for golden_answer in gold_answers_variations:
            golden_norm = normalize(golden_answer)  # once per gold answer, not per alias
            for alias_norm in normalized_aliases:
                f1 = compute_f1(prediction=alias_norm, truth=golden_norm)
                em = int(alias_norm == golden_norm)
                if f1 > max_f1:
                    max_f1, max_f1_entity = f1, golden_norm
                if em > max_em:
                    max_em, max_em_entity = em, golden_norm

        # Log the model's raw answer, not one of its normalized aliases.
        print(sample_id, " | ", raw_ans, " | ", max_em_entity, " | ", max_f1_entity)
        f1s.append(max_f1)
        ems.append(max_em)
    except Exception as e:  # best-effort: keep going, but remember why
        skipped[sample_id] = repr(e)

if skipped:
    print(f"Skipped {len(skipped)} of {len(id2sample)} samples (excluded from averages):",
          list(skipped.items())[:10])
assert ems, "no HotpotQA samples could be scored"
sum(ems) / len(ems), sum(f1s) / len(f1s)

In [None]:
# NOTE(review): disabled cell. `accs_llama_hopot_with_aliases_with_decomposition_with_filtered_triplets`
# is not defined anywhere in the visible notebook. From the e[0]/e[1] indexing and the
# "EM"/"F1" labels, it was presumably a list of per-run (EM, F1) pairs, e.g. the tuples
# produced by the scoring cells above — TODO confirm before re-enabling.
# When enabled, it prints the mean and standard deviation of EM and F1 across those runs.
# eval_setup = accs_llama_hopot_with_aliases_with_decomposition_with_filtered_triplets
# em = [e[0] for e in eval_setup]
# f1 = [e[1] for e in eval_setup]
# print("EM: ", np.mean(em), np.std(em))
# print("F1: ", np.mean(f1), np.std(f1))
# print(round(np.mean(em), 3), "±", round(np.std(em), 3))
# print(round(np.mean(f1), 3), "±", round(np.std(f1), 3))