In [1]:
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import json
import numpy as np
from tqdm import tqdm
import jsonlines

In [2]:
# Load the pickled corpus embeddings. Later cells iterate this object with
# .items() as (abstract_id, embedding) pairs.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("corpus_paragraph_biosentvec.pkl", "rb") as corpus_pkl:
    corpus_embeddings = pickle.load(corpus_pkl)

In [3]:
# Load the pickled claim embeddings. Later cells iterate this object with
# .items() as (claim_id, embedding) pairs.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("test_claim_biosentvec.pkl", "rb") as claims_pkl:
    claim_embeddings = pickle.load(claims_pkl)

In [4]:
# Path of the claims file; it is read line by line, one JSON object per line.
claim_file = "SciFact/claims_test.jsonl"

In [5]:
# Parse the claims file: every line is decoded as one JSON object.
with open(claim_file) as claims_fh:
    claims = [json.loads(raw_line) for raw_line in claims_fh]

# Index the claims by their 'id' field for direct lookup in later cells.
claims_by_id = {entry['id']: entry for entry in claims}

In [6]:
# Score every claim against every abstract.
# Result layout: all_similarities[claim_id][abstract_id] holds whatever
# cosine_similarity returns for that pair. A later cell reduces each entry
# with np.max, so an entry is presumably a matrix of per-row scores
# (TODO confirm the embedding shapes).
# The outer loop stays explicit so results accumulate claim by claim
# under the progress bar.
all_similarities = {}
for claim_id, claim_embedding in tqdm(claim_embeddings.items()):
    all_similarities[claim_id] = {
        abstract_id: cosine_similarity(claim_embedding, abstract_embedding)
        for abstract_id, abstract_embedding in corpus_embeddings.items()
    }

100%|██████████| 300/300 [09:49<00:00,  1.97s/it]


In [None]:
# Optional shortcut: reuse similarities saved earlier instead of recomputing.
# Load only when `all_similarities` is not already defined, so that running
# the notebook top to bottom does not replace freshly computed scores with
# the contents of this file.
# NOTE(review): the file name says "dev" while the rest of this notebook works
# on the test claims. Confirm this is the intended cache before relying on it.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
if "all_similarities" not in globals():
    with open("dev_paragraph_similarities.pkl","rb") as f:
        all_similarities = pickle.load(f)

In [7]:
# Rank the abstracts for each claim from most to least similar.
# Each stored similarity entry is collapsed to its maximum value, and that
# maximum is the ranking score. ordered_corpus[claim_id] ends up as a NumPy
# array of abstract ids in descending score order.
ordered_corpus = {}
for claim_id, claim_similarities in tqdm(all_similarities.items()):
    # keys() and values() of one dict iterate in the same order, so the id at
    # position i belongs to the score at position i.
    corpus_ids = np.array(list(claim_similarities.keys()))
    max_similarity = [np.max(similarity) for similarity in claim_similarities.values()]
    # argsort is ascending; reversing it yields best-first order.
    sorted_order = np.argsort(max_similarity)[::-1]
    ordered_corpus[claim_id] = corpus_ids[sorted_order]

100%|██████████| 300/300 [00:13<00:00, 22.47it/s]


In [8]:
# Number of top-ranked abstracts kept per claim. It is also embedded in the
# name of the retrieval output file written at the end of the notebook.
k = 30
# Keep only the first k entries of each claim's ranked abstract ids.
retrieved_corpus = {
    claim_id: ranked_ids[:k]
    for claim_id, ranked_ids in ordered_corpus.items()
}

In [None]:
# Precision of the retrieved abstracts against each claim's "cited_doc_ids":
# TP counts retrieved ids that appear in the cited list, FP the remainder.
# NOTE(review): needs a "cited_doc_ids" field on every claim; verify the loaded
# claims file provides it.
TP = 0
FP = 0

for claim_id, abstract_ids in retrieved_corpus.items():
    gold_ids = claims_by_id[claim_id]["cited_doc_ids"]
    # Retrieved ids are converted with int() before the membership check.
    hits = sum(1 for abstract_id in abstract_ids if int(abstract_id) in gold_ids)
    TP += hits
    FP += len(abstract_ids) - hits
print("Precision:", TP/(TP+FP))

In [None]:
# Recall of the retrieved abstracts against each claim's "cited_doc_ids":
# TP counts cited ids found among the retrieved ids, FN the cited ids missed.
# NOTE(review): needs a "cited_doc_ids" field on every claim; verify the loaded
# claims file provides it.
TP = 0
FN = 0

for claim_id, claim in claims_by_id.items():
    predicted_ids = retrieved_corpus[claim_id]
    cited_ids = claim["cited_doc_ids"]
    # Cited ids are converted with str() before the membership check.
    found = sum(1 for abstract_id in cited_ids if str(abstract_id) in predicted_ids)
    TP += found
    FN += len(cited_ids) - found
print("Recall:", TP/(TP+FN))

In [None]:
# Precision of the retrieved abstracts against the keys of each claim's
# "evidence" mapping: TP counts retrieved ids that are evidence keys, FP the
# remainder. Ids are compared as-is, with no type conversion.
# NOTE(review): needs an "evidence" field on every claim; verify the loaded
# claims file provides it.
TP = 0
FP = 0

for claim_id, abstract_ids in retrieved_corpus.items():
    gold_ids = claims_by_id[claim_id]["evidence"].keys()
    hits = sum(1 for abstract_id in abstract_ids if abstract_id in gold_ids)
    TP += hits
    FP += len(abstract_ids) - hits
print("Precision:", TP/(TP+FP))

In [None]:
# Recall of the retrieved abstracts against the keys of each claim's
# "evidence" mapping: TP counts evidence keys found among the retrieved ids,
# FN the evidence keys missed. Ids are compared as-is, with no type conversion.
# NOTE(review): needs an "evidence" field on every claim; verify the loaded
# claims file provides it.
TP = 0
FN = 0

for claim_id, claim in claims_by_id.items():
    predicted_ids = retrieved_corpus[claim_id]
    evidence_ids = claim["evidence"].keys()
    found = sum(1 for abstract_id in evidence_ids if abstract_id in predicted_ids)
    TP += found
    FN += len(evidence_ids) - found
print("Recall:", TP/(TP+FN))

In [10]:
# Write every claim, in ascending claim-id order, to a JSONL file with its
# retrieved abstract ids attached under "retrieved_doc_ids".
# The claim dicts in claims_by_id are updated in place, so they keep the
# added field after this cell runs.
with jsonlines.open("SciFact/claims_test_retrieved.jsonl", 'w') as output:
    claim_ids = sorted(claims_by_id)
    # The loop variable is deliberately not called `id`, which would shadow
    # the builtin for the rest of the kernel session.
    for claim_id in claim_ids:
        # .tolist() turns the NumPy array into a plain list for serialization.
        claims_by_id[claim_id]["retrieved_doc_ids"] = retrieved_corpus[claim_id].tolist()
        output.write(claims_by_id[claim_id])

In [9]:
# Write one record per claim, in ascending claim-id order, of the form
# {"claim_id": ..., "doc_ids": [...]} with the retrieved abstract ids
# converted to int. The output file name embeds the cut-off k.
with jsonlines.open("SciFact/abstract_retrieval_test_"+str(k)+".jsonl", 'w') as output:
    claim_ids = sorted(claims_by_id)
    # Distinct names for the claim id and the per-document id keep the
    # "claim_id" value independent of comprehension scoping rules, and avoid
    # shadowing the builtin `id`.
    for claim_id in claim_ids:
        doc_ids = [int(doc_id) for doc_id in retrieved_corpus[claim_id].tolist()]
        output.write({"claim_id": claim_id, "doc_ids": doc_ids})