In [None]:
import sys
# Make modules under ../src importable; the project imports below (wikidataDB,
# wikidataRetriever, ...) presumably live there.
sys.path.append('../src')

import os
# Set before the project imports below so that any module reading LANGUAGE at import
# time sees it. NOTE(review): where LANGUAGE is consumed is not visible here — confirm.
os.environ["LANGUAGE"] = 'ar'

from sqlalchemy.sql.expression import func
from wikidataDB import WikidataEntity, WikidataID, Session
from wikidataRetriever import WikidataKeywordSearch, AstraDBConnect
from SPARQLWrapper import SPARQLWrapper, JSON

import json
import os  # NOTE(review): duplicate of the `import os` above; harmless.
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import pickle
from datasets import load_dataset
import re
from requests.exceptions import HTTPError
import time
from tqdm import tqdm  # NOTE(review): duplicate of the tqdm import above; harmless.

def is_in_wikipedia(qid):
    """Tell whether the entity ``qid`` is flagged as having a Wikipedia article.

    The record is fetched with ``WikidataID.get_id``. A QID with no record yields
    ``False``; otherwise the record's ``in_wikipedia`` attribute is returned unchanged.
    """
    record = WikidataID.get_id(qid)
    return False if record is None else record.in_wikipedia

In [None]:
import numpy as np
import pickle

def calculate_mrr_score(df, pred_col, true_cols):
    """Mean Reciprocal Rank of the predicted QID lists.

    For every row, the list in ``pred_col`` is de-duplicated (first occurrence kept).
    The row scores ``1 / r``, where ``r`` is the 1-based rank of the first predicted QID
    contained in the row's ``true_cols`` list, or ``0`` when no prediction is correct.

    Args:
        df: DataFrame with one query per row. It is not modified.
        pred_col: Column holding the ordered list of retrieved QIDs.
        true_cols: Column holding the list of correct QIDs for the row.

    Returns:
        The mean reciprocal rank over all rows, or NaN when ``df`` is empty.
    """
    if df.empty:
        return float('nan')

    def reciprocal_rank(row):
        # De-duplicate per row instead of writing back into a (global) DataFrame.
        predictions = list(dict.fromkeys(row[pred_col]))
        relevant = row[true_cols]
        for rank, qid in enumerate(predictions, start=1):
            if qid in relevant:
                return 1 / rank
        return 0

    return df.apply(reciprocal_rank, axis=1).mean()

def calculate_ndcg_score(df, pred_col, true_cols):
    """Mean binary-relevance NDCG of the predicted QID lists.

    For every row:
    - The list in ``pred_col`` is de-duplicated (first occurrence kept).
    - A predicted QID at 1-based rank ``r`` adds ``1 / log2(r + 1)`` to the DCG when it
      appears in the row's ``true_cols``.
    - The ideal DCG assumes the first ``min(#distinct correct QIDs, #predictions)`` ranks
      are all hits.

    Rows with no predictions or no correct QIDs score 0 rather than 0/0.

    Args:
        df: DataFrame with one query per row. It is not modified.
        pred_col: Column holding the ordered list of retrieved QIDs.
        true_cols: Column holding the list of correct QIDs for the row.

    Returns:
        The mean NDCG over all rows, or NaN when ``df`` is empty.
    """
    if df.empty:
        return float('nan')

    def row_ndcg(row):
        # De-duplicate per row instead of writing back into a (global) DataFrame.
        predictions = list(dict.fromkeys(row[pred_col]))
        # A set, so repeated correct QIDs do not inflate the ideal DCG.
        relevant = set(row[true_cols])
        dcg = sum(1 / np.log2(rank + 1)
                  for rank, qid in enumerate(predictions, start=1) if qid in relevant)
        ideal_hits = min(len(relevant), len(predictions))
        idcg = sum(1 / np.log2(rank + 1) for rank in range(1, ideal_hits + 1))
        return dcg / idcg if idcg > 0 else 0.0

    return df.apply(row_ndcg, axis=1).mean()

collection = "wikidata_test_v1"
evaluation_dataset = "REDFM"
# Pickled retrieval results to score. Result files are named like
# retrieval_results_{evaluation_dataset}-{collection}-<DB/query languages>.pkl;
# point this at another run to evaluate it instead.
filename = "../data/Evaluation Data/retrieval_results_REDFM-wikidata_test_v1-DB(en,ar)-Query(EN)_DB-AR-EN_Query-EN.pkl"

# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.
with open(filename, "rb") as results_file:
    prep = pickle.load(results_file)
assert pd.isna(prep['Retrieval QIDs']).sum() == 0, "Evaluation not complete"

# For Mintaka, LC_QuAD, and RuBQ
# prep = prep[prep.apply(lambda x: all(x['Question in Wikipedia'] + x['Answer in Wikipedia']), axis=1)]
# prep['Correct QIDs'] = prep.apply(lambda x: x['Question QIDs'] + x['Answer QIDs'], axis=1)

# For REDFM: keep rows whose correct entity is in Wikipedia. `.copy()` detaches the
# filtered frame so adding 'Correct QIDs' is not a chained assignment.
prep = prep[prep['Correct in Wikipedia']].copy()
prep['Correct QIDs'] = prep['Correct QID'].apply(lambda x: [x])

# Identify which results file the scores below belong to.
print(filename)
print("MRR:")
print(calculate_mrr_score(prep, 'Retrieval QIDs', 'Correct QIDs'))
print("NDCG:")
print(calculate_ndcg_score(prep, 'Retrieval QIDs', 'Correct QIDs'))
print()

In [None]:
from wikidataDB import WikidataEntity
from wikidataEmbed import WikidataTextifier, JinaAIReranker

collection = "wikidata_test_v1"
evaluation_dataset = "Mintaka"
filename = f"retrieval_results_{evaluation_dataset}-{collection}-de_DB-EN_Query-DE"
# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.
with open(f"../data/Evaluation Data/{filename}.pkl", "rb") as results_file:
    prep = pickle.load(results_file)

# Textifier and reranker used by rerank_qids: the textifier is configured with
# language='en' and claim/property aliases disabled.
textifier = WikidataTextifier(with_claim_aliases=False, with_property_aliases=False, language='en')
reranker = JinaAIReranker()

def rerank_qids(query, qids, reranker, textifier):
    """Re-order candidate QIDs by reranker score against ``query``.

    Each QID is loaded with ``WikidataEntity.get_entity``, rendered with
    ``textifier.entity_to_text`` and scored with ``reranker.rank(query, texts)``.

    Args:
        query: Question text to rank the candidates against.
        qids: Iterable of candidate QIDs in retrieval order.
        reranker: Object with ``rank(query, texts)``; assumed to return one score per
            text, in the same order as ``texts``.
        textifier: Object with ``entity_to_text(entity)``.

    Returns:
        The QIDs sorted by descending score. Ties keep their input order.
    """
    # Materialize once: the QIDs are iterated twice (loading and pairing with scores).
    qids = list(qids)
    entities = [WikidataEntity.get_entity(qid) for qid in qids]
    texts = [textifier.entity_to_text(entity) for entity in entities]
    scores = reranker.rank(query, texts)

    # Pair scores with the QIDs passed in, not a fixed row of a global DataFrame.
    # sorted(..., reverse=True) is stable, so ties keep retrieval order.
    ranked = sorted(zip(scores, qids), key=lambda pair: pair[0], reverse=True)
    return [qid for _, qid in ranked]

# Rerank the retrieved candidates of the first question (result is an ordered QID list).
scores = rerank_qids(
    query=prep.iloc[0]['Question'],
    qids=prep.iloc[0]['Retrieval QIDs'],
    reranker=reranker,
    textifier=textifier,
)

In [None]:
# Relies on `evaluation_dataset` and `collection` set in an earlier cell.
filename = f"retrieval_results_{evaluation_dataset}-{collection}-en_DB-EN_Query-EN"
# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.
with open(f"../data/Evaluation Data/{filename}.pkl", "rb") as results_file:
    prep = pickle.load(results_file)
prep

In [None]:
import pickle
import numpy as np

def calculate_accuracy_score(df):
    """Fraction of rows whose top-scored retrieved QID equals 'Correct QID'.

    The top QID of a row is the entry of 'Retrieval QIDs' at the position of the
    maximum value in 'Retrieval Score' (the first such position on ties).
    """
    def best_qid(row):
        best_position = np.argmax(row['Retrieval Score'])
        return row['Retrieval QIDs'][best_position]

    predicted = df.apply(best_qid, axis=1)
    hits = predicted == df['Correct QID']
    return hits.mean()

def calculate_log_odds_ratio_score(df, eps=1e-6):
    """Mean log-odds difference between the correct and the wrong candidate.

    For each row:
    - Take the highest 'Retrieval Score' attained by 'Correct QID' and by 'Wrong QID'
      among 'Retrieval QIDs' (0 when the QID was not retrieved).
    - Convert each to log-odds ``log(p / (1 - p))``.
    - Subtract: correct minus wrong.

    The function returns the mean of these differences over rows.

    Scores are assumed to be probabilities in [0, 1] — TODO confirm for the retriever in
    use. They are clipped to ``[eps, 1 - eps]`` before the transform. Without the clip, a
    score of exactly 1.0 raises ZeroDivisionError on Python floats, and a 0 score (for
    example, a missing QID) gives -inf or NaN.

    Args:
        df: DataFrame with 'Correct QID', 'Wrong QID', 'Retrieval QIDs' and
            'Retrieval Score' columns.
        eps: Clipping margin for the scores (default 1e-6).

    Returns:
        The mean log-odds ratio over rows.
    """
    def log_odds(probability):
        clipped = np.clip(probability, eps, 1 - eps)
        return np.log(clipped / (1 - clipped))

    def log_odds_ratio(row):
        scored = list(zip(row['Retrieval QIDs'], row['Retrieval Score']))
        # Maximum score of each target QID; 0 when it was not retrieved at all.
        max_correct_score = max((s for q, s in scored if q == row['Correct QID']), default=0)
        max_wrong_score = max((s for q, s in scored if q == row['Wrong QID']), default=0)
        return log_odds(max_correct_score) - log_odds(max_wrong_score)

    return df.apply(log_odds_ratio, axis=1).mean()

collection = "wikidata_test_v2"
evaluation_dataset = "Wikidata-Disamb"
# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.
with open(f"../data/Evaluation Data/retrieval_results_{evaluation_dataset}-{collection}-en.pkl", "rb") as results_file:
    prep = pickle.load(results_file)
assert pd.isna(prep['Retrieval QIDs']).sum() == 0, "Evaluation not complete"

calculate_accuracy_score(prep)

In [None]:
import matplotlib.pyplot as plt

def calculate_accuracy_over_K(df, pred_col, true_cols):
    """Accuracy@K curve: fraction of rows with a correct QID in the top K predictions.

    For every row, the list in ``pred_col`` is de-duplicated (first occurrence kept). The
    rank of its first QID found in ``true_cols`` is then recorded; rows with no hit never
    count as correct.

    Args:
        df: DataFrame with one query per row. It is not modified.
        pred_col: Column holding the ordered list of retrieved QIDs.
        true_cols: Column holding the list of correct QIDs for the row.

    Returns:
        A list whose element ``K`` is the accuracy when the top ``K`` predictions are kept.
        It covers ``K = 0 .. max hit rank`` inclusive. It is ``[0.0]`` when no row has a
        hit and ``[]`` for an empty ``df``.
    """
    if df.empty:
        return []

    def first_hit_rank(row):
        # De-duplicate per row instead of writing back into a (global) DataFrame.
        predictions = list(dict.fromkeys(row[pred_col]))
        relevant = row[true_cols]
        for rank, qid in enumerate(predictions, start=1):
            if qid in relevant:
                return rank
        return np.nan

    ranks = df.apply(first_hit_rank, axis=1)
    hit_ranks = ranks.dropna()
    max_rank = int(hit_ranks.max()) if not hit_ranks.empty else 0
    # Include K == max_rank so the curve reaches its final value; NaN (miss) <= K is False.
    return [(ranks <= k).mean() for k in range(max_rank + 1)]

evaluation_dataset = "REDFM"
# Legend label used for each collection's curve.
model_labels = {"wikidata_test_v1": "Jina", "wikidata_test_v2": "Nvidia"}

accuracy_by_collection = {}
for collection in model_labels:
    # NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.
    with open(f"../data/Evaluation Data/retrieval_results_{evaluation_dataset}-{collection}-en.pkl", "rb") as results_file:
        prep = pickle.load(results_file)
    assert pd.isna(prep['Retrieval QIDs']).sum() == 0, "Evaluation not complete"
    # `.copy()` so that adding 'Correct QIDs' to the filtered frame is not a chained assignment.
    prep = prep[prep['Correct in Wikipedia']].copy()
    prep['Correct QIDs'] = prep['Correct QID'].apply(lambda x: [x])
    accuracy_by_collection[collection] = calculate_accuracy_over_K(prep, 'Retrieval QIDs', 'Correct QIDs')

accuracy_v1 = accuracy_by_collection["wikidata_test_v1"]
accuracy_v2 = accuracy_by_collection["wikidata_test_v2"]

# Line chart of accuracy@K (list index K = number of entities retrieved)
fig, ax = plt.subplots()
for collection, accuracy in accuracy_by_collection.items():
    ax.plot(list(range(len(accuracy))), np.array(accuracy) * 100, label=model_labels[collection])
ax.set_title('Accuracy of 1 correct item in REDFM')
ax.set_xlabel('# Entities Retrieved')
ax.set_ylabel('Accuracy %')
ax.legend()

# Show the chart
plt.show()


In [None]:
from sqlalchemy.sql import func
from tqdm import tqdm

# NOTE(review): `sample_ids` is loaded by a later cell ("Sample IDs (EN).pkl"); this cell
# only works after that cell has run — consider moving it below.
# Number of extra random Wikipedia entities to add, derived from the 'from Evaluation' counts.
sample_count = sample_ids['from Evaluation'].sum()*2 - (~sample_ids['from Evaluation']).sum()

# Known QIDs go in a set for O(1) lookups. New rows are buffered in a list and
# concatenated once, instead of growing the DataFrame row by row.
seen_qids = set(sample_ids['QID'].values)
new_rows = []
try:
    with tqdm(total=sample_count) as progressbar, Session() as session:
        # Random ordering; yield_per fetches the results in batches of 1000.
        entities = (
            session.query(WikidataID)
            .filter(WikidataID.in_wikipedia == True)  # noqa: E712 (SQLAlchemy column expression)
            .order_by(func.random())
            .yield_per(1000)
        )

        for entity in entities:
            # Check the quota before adding, so nothing is added when sample_count <= 0.
            if len(new_rows) >= sample_count:
                break
            if entity.id in seen_qids:
                continue
            seen_qids.add(entity.id)
            new_rows.append({
                'QID': entity.id,
                'from Evaluation': False,
                'In Wikipedia': True,
                'Sample 2': True
            })
            progressbar.update(1)
finally:
    # Keep whatever was collected, even if the loop is interrupted.
    if new_rows:
        sample_ids = pd.concat([sample_ids, pd.DataFrame(new_rows)], ignore_index=True)

In [None]:
import pickle

# prep = pickle.load(open("/home/philippe.saade/GitHub/WikidataTextEmbedding/data/Evaluation Data/KGConv/processed_dataframe.pkl", "rb"))

# Load the entity sample and keep only entities flagged 'In Wikipedia'.
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced.
with open("../data/Evaluation Data/Sample IDs (EN).pkl", "rb") as sample_file:
    sample_ids = pickle.load(sample_file)
sample_ids = sample_ids[sample_ids['In Wikipedia']]

# Set of sampled QIDs for O(1) membership tests.
sample_qids_set = set(sample_ids['QID'].values)

# Vectorized membership tests: flag whether each question/answer QID is in the Wikipedia sample
# prep['Question in Wikipedia'] = prep['Question QID'].isin(sample_qids_set)
# prep['Answer in Wikipedia'] = prep['Answer QID'].isin(sample_qids_set)
# prep

In [None]:
# Rows added by the random sampling pass. Rows added without a 'Sample 2' value hold NaN
# there, and boolean indexing rejects NaN, so compare against True explicitly.
sample_ids[sample_ids['Sample 2'].eq(True)]

In [None]:
# Add every in-Wikipedia answer QID from `prep` that is not yet in the sample.
# NOTE(review): expects a `prep` with aligned list columns 'Answer QIDs' and
# 'Answer in Wikipedia', not the REDFM frame loaded above — confirm run order.
new_rows = []
for _, row in tqdm(prep.iterrows(), total=len(prep)):
    for qid, in_wikipedia in zip(row['Answer QIDs'], row['Answer in Wikipedia']):
        if in_wikipedia and qid not in sample_qids_set:
            # Mark as seen immediately so a QID repeated across rows is added only once.
            sample_qids_set.add(qid)
            new_rows.append({
                'QID': qid,
                'from Evaluation': True,
                'In Wikipedia': True,
                'from Evaluation 2': True
            })

# Concatenate once instead of growing the DataFrame inside the loop.
if new_rows:
    sample_ids = pd.concat([sample_ids, pd.DataFrame(new_rows)], ignore_index=True)

In [None]:
def remove_spans(sentence, spans, replace_with='Entity'):
    """Replace each ``(start, end)`` character span of ``sentence`` with ``replace_with``.

    Spans use offsets into the ORIGINAL sentence and are applied in ascending order of
    ``start``. Each replacement can change the string's length, so later spans are
    shifted by the accumulated length difference. Overlapping spans are not handled
    specially.
    """
    result = sentence
    shift = 0  # accumulated (replacement length - removed length) so far
    for start, end in sorted(spans, key=lambda span: span[0]):
        head = result[:start + shift]
        tail = result[end + shift:]
        result = head + replace_with + tail
        shift -= (end - start) - len(replace_with)
    return result

# Masked copy of each sentence with every entity span replaced by the default placeholder.
data['Sentence no entity'] = data.apply(
    lambda record: remove_spans(record['Sentence'], record['Entity Span']),
    axis=1,
)