In [None]:
!pip install rank-bm25
!pip install sentence_transformers
!pip install nltk
import json
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from sentence_transformers import CrossEncoder
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

In [None]:
# Load the unified evidence corpus. The object is consumed as a mapping further
# down (its .values() are the evidence texts), so inspect an entry with
# next(iter(evidence_corpus.items())) rather than positional indexing.
# The encoding is given explicitly so that decoding does not depend on the
# platform's default locale.
with open('/kaggle/input/quantemp/corpus_evidence_unified.json', 'r', encoding='utf-8') as f:
    evidence_corpus = json.load(f)

In [None]:
# Load the test claims. Each entry is later indexed with the 'claim' key to
# obtain the claim text. The encoding is given explicitly so that decoding
# does not depend on the platform's default locale.
with open('/kaggle/input/quantemp/test_claims_quantemp.json', 'r', encoding='utf-8') as f:
    claims = json.load(f)


In [None]:
# Materialize the evidence texts as a list; BM25 result indices are looked up
# in this list, so its order must stay fixed once the index is built.
corpus = list(tqdm(evidence_corpus.values()))
# Plain whitespace tokenization, the same scheme applied to claims at query time.
tokenized_corpus = [text.split() for text in corpus]

In [None]:
# Build the BM25 (Okapi) index over the whitespace-tokenized corpus.
# retrieve_evidence() queries this index and maps hits back into `corpus`.
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
def retrieve_evidence(claim, k=100):
    """Return the ``k`` corpus documents with the highest BM25 score for ``claim``.

    Args:
        claim: Claim text; tokenized on whitespace, like the indexed corpus.
        k: Maximum number of documents to return. Values <= 0 yield an
           empty list.

    Returns:
        A list of documents from the module-level ``corpus``, ordered from
        highest to lowest BM25 score. Contains fewer than ``k`` items when
        the corpus itself is smaller than ``k``.
    """
    # Guard: with k == 0 the slice below would be [-0:], i.e. the *whole*
    # array, and a negative k would silently drop only a few documents.
    if k <= 0:
        return []
    tokenized_claim = claim.split()
    scores = bm25.get_scores(tokenized_claim)
    # argsort is ascending: take the last k indices, then flip to descending.
    top_k_indices = scores.argsort()[-k:][::-1]
    return [corpus[idx] for idx in top_k_indices]

In [None]:
import torch

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load the pretrained MS MARCO MiniLM cross-encoder used to score
# (claim, snippet) pairs in rerank_snippets().
re_ranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Move the underlying model onto the device chosen above (GPU when available).
re_ranker.model.to(device)

def rerank_snippets(claim, evidence_list, top_k=5):
    """Re-rank evidence snippets for ``claim`` and join the best ones.

    Args:
        claim: Claim text, paired with every candidate snippet.
        evidence_list: Candidate snippets (strings). Empty or
            whitespace-only snippets are ignored.
        top_k: Number of highest-scoring snippets to keep.

    Returns:
        The ``top_k`` best snippets, highest score first, joined by a single
        space. An empty string when no non-blank snippet is available.
    """
    # Filter once and reuse the filtered list for both scoring and ranking,
    # so that scores[i] always belongs to candidates[i]. Pairing the scores
    # with the unfiltered list would shift them onto the wrong snippets as
    # soon as one blank entry is present.
    candidates = [snippet for snippet in evidence_list if snippet.strip()]
    if not candidates:
        return ""
    scores = re_ranker.predict([[claim, snippet] for snippet in candidates])
    # Sort on the score alone; the sort is stable, so snippets with equal
    # scores keep their incoming (retrieval) order.
    order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return ' '.join(candidates[i] for i in order[:top_k])


In [None]:
# Per-claim outputs of the two retrieval stages.
total_evidence = []  # raw BM25 candidates, before re-ranking
nli_data = []        # re-ranked evidence string per claim
# For a quick smoke test, iterate over claims[:5] instead of the full list.

for claim in tqdm(claims):
    claim_text = claim['claim']
    evidence_list = retrieve_evidence(claim_text)

    # Record the un-reranked candidates first; they are written to their own file.
    total_evidence.append({
        'claim': claim_text,
        'evidence_snippet_list': evidence_list,
    })

    # rerank_snippets() returns the top snippets joined into a single string.
    nli_data.append({
        'claim': claim_text,
        'evidence': rerank_snippets(claim_text, evidence_list),
    })

In [None]:
# Persist the claims together with their re-ranked evidence string.
with open('/kaggle/working/nli_input_test_reranktop5.json', 'w') as out_file:
    json.dump(nli_data, out_file)

In [None]:
# Persist the claims together with their raw BM25 candidates (no re-ranking).
with open('/kaggle/working/nli_input_test_evidence_withoutrerank.json', 'w') as out_file:
    json.dump(total_evidence, out_file)