In [None]:
!pip install faiss-cpu datasets scikit-learn
from google.colab import files
import pickle
import faiss
import numpy as np
from datasets import load_dataset
from sklearn.metrics import average_precision_score

In [None]:
uploaded = files.upload()
input_file = "scifact_evidence_embeddings.pkl"
with open(input_file, "rb") as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
    embeddings = pickle.load(f)

# The pickle maps (doc_id, abstract) -> embedding vector. The three lists below
# stay aligned by position, so FAISS row i corresponds to doc_id_list[i].
embedding_list = []
doc_id_list = []
abstract_list = []

for doc, embedding in embeddings.items():
    doc_id, abstract = doc
    doc_id_list.append(doc_id)
    abstract_list.append(abstract)
    embedding_list.append(embedding)

embedding_matrix = np.array(embedding_list).astype('float32')
# Fail fast with a readable message. Without this, an empty pickle raises an opaque
# IndexError on shape[1], and vectors with an extra axis produce a matrix FAISS rejects.
assert embedding_matrix.ndim == 2 and embedding_matrix.shape[0] > 0, (
    f"embedding_matrix has incorrect shape: {embedding_matrix.shape}"
)
embedding_dim = embedding_matrix.shape[1]

# Exact (brute-force) L2 index; IndexFlatL2 needs no training step before add().
index = faiss.IndexFlatL2(embedding_dim)
index.add(embedding_matrix)

In [None]:
# Claim embeddings: a mapping whose keys unpack as (claim_id, claim_text)
# and whose values are the claim vectors.
claim_file = "scifact_claim_embeddings.pkl"
with open(claim_file, "rb") as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
    claims = pickle.load(f)

# keys() and values() iterate in the same order, so the two lists stay
# aligned position by position.
claim_id_list = list(claims.keys())
claim_embedding_list = list(claims.values())

scifact_corpus = load_dataset("scifact", "corpus")
scifact_claims = load_dataset("scifact", "claims")

# Gold relevance per claim id, taken from the train split only.
# If an id appears in several rows, the last row's value wins.
claim_id_to_gold = {
    row['id']: row['cited_doc_ids']
    for row in scifact_claims['train']
}

# Keep only claims that have gold labels, pairing each embedding with its gold list.
aligned_gold_relevant_ids = []
aligned_claim_embeddings = []
for (claim_id, _claim_text), claim_vector in zip(claim_id_list, claim_embedding_list):
    if claim_id not in claim_id_to_gold:
        continue
    aligned_gold_relevant_ids.append(claim_id_to_gold[claim_id])
    aligned_claim_embeddings.append(claim_vector)

In [None]:
def map_faiss_indices_to_doc_ids(retrieved_indices, doc_id_list):
    """
    Map FAISS row indices to document IDs.

    Args:
        retrieved_indices: per-query rows of FAISS indices, e.g. the ``I``
            array returned by ``index.search``.
        doc_id_list: document IDs in the same order their vectors were
            added to the index.

    Returns:
        A list with one list of document IDs per query, best match first.

    FAISS pads a row with -1 when fewer than k results exist. Those entries
    are dropped here; otherwise ``doc_id_list[-1]`` would silently turn them
    into the last document in the corpus.
    """
    return [
        [doc_id_list[idx] for idx in indices if idx >= 0]
        for indices in retrieved_indices
    ]


def calculate_map(retrieved_indices, relevant_indices):
    """
    Calculate the Mean Average Precision (MAP) of ranked retrieval lists.

    For each query, average precision is computed as
        AP = sum(precision@r for each rank r holding a relevant doc)
             / min(|relevant|, len(retrieved))
    Normalizing by the number of relevant documents, capped at the list
    length, means relevant documents that were never retrieved lower the
    score. Scoring the truncated list with sklearn's average_precision_score
    instead normalizes by the number of relevant docs *found*, which
    inflates AP.

    Queries whose relevant list is empty are excluded from the mean, since
    AP is undefined for them. Queries with relevant docs but no hits
    contribute 0.

    Args:
        retrieved_indices: per-query ranked lists of document IDs (best first).
        relevant_indices: per-query iterables of gold document IDs, aligned
            with ``retrieved_indices``.

    Returns:
        MAP as a float in [0, 1]; 0.0 when no query has relevant documents.
    """
    ap_values = []
    for retrieved, relevant in zip(retrieved_indices, relevant_indices):
        relevant_set = set(relevant)
        if not relevant_set:
            continue

        found = set()
        precision_sum = 0.0
        for rank, doc_id in enumerate(retrieved, start=1):
            # Count each relevant document once, even if it is retrieved twice.
            if doc_id in relevant_set and doc_id not in found:
                found.add(doc_id)
                precision_sum += len(found) / rank

        denom = min(len(relevant_set), len(retrieved))
        ap_values.append(precision_sum / denom if denom else 0.0)

    return sum(ap_values) / len(ap_values) if ap_values else 0.0


if len(aligned_claim_embeddings) == 0:
    print("No aligned claim embeddings found. Exiting.")
else:
    claim_matrix = np.array(aligned_claim_embeddings).astype('float32')

    # FAISS search expects a 2-D (n_queries, dim) matrix matching the index dimension.
    assert claim_matrix.ndim == 2, f"claim_matrix has incorrect shape: {claim_matrix.shape}"
    assert claim_matrix.shape[1] == index.d, (
        f"claim dim {claim_matrix.shape[1]} != index dim {index.d}"
    )

    k = 20  # retrieval cutoff for MAP@k
    D, I = index.search(claim_matrix, k)
    # NOTE(review): hits are found by equality, so the IDs in doc_id_list must share
    # a type with the gold IDs (e.g. both int). TODO: confirm against the pickle.
    mapped_retrieved_ids = map_faiss_indices_to_doc_ids(I, doc_id_list)

    if len(mapped_retrieved_ids) != len(aligned_gold_relevant_ids):
        print(f"Mismatch between retrieved and relevant indices for k={k}")
    else:
        map_score = calculate_map(mapped_retrieved_ids, aligned_gold_relevant_ids)
        print(f"MAP@{k}: {map_score}")
