In [1]:
from rank_bm25 import BM25Okapi
from itertools import chain
import pandas as pd
import numpy as np
import ujson
import json
import ast
import os

In [2]:
# Location of the per-claim key-point table, relative to this notebook.
file_path = '../data/keypoints.csv'
# The CSV is decoded as ISO-8859-1 rather than pandas' default encoding.
data = pd.read_csv(filepath_or_buffer=file_path, encoding='ISO-8859-1')

def fit_bm25(document):
    """Build a sentence-level BM25 index over a collection of documents.

    Every element of ``document`` is iterated as a sequence of sentence
    strings. All sentences are flattened into one corpus, and two parallel
    lists record where each flattened sentence came from.

    Args:
        document: Iterable of documents, each an iterable of sentence strings.

    Returns:
        A 4-tuple ``(bm25, flattened_docs, doc_indices, sentence_indices)``:
        the fitted ``BM25Okapi`` model, the flat list of sentences, and for
        each flat position the index of its document and the index of the
        sentence within that document.
    """
    # One (sentence, document position, sentence position) record per
    # sentence, in document order and then sentence order.
    located = [
        (text, doc_pos, sent_pos)
        for doc_pos, sentences in enumerate(document)
        for sent_pos, text in enumerate(sentences)
    ]
    flattened_docs = [text for text, _, _ in located]
    doc_indices = [doc_pos for _, doc_pos, _ in located]
    sentence_indices = [sent_pos for _, _, sent_pos in located]

    # Tokens are produced by splitting on single spaces and lower-casing.
    tokenized_docs = [
        [token.lower() for token in text.split(" ")]
        for text in flattened_docs
    ]
    bm25 = BM25Okapi(tokenized_docs)
    return bm25, flattened_docs, doc_indices, sentence_indices

def get_top_k_sentences_for_query(query, bm25, flattened_docs, doc_indices, sentence_indices, top_k):
    """Return the ``top_k`` highest-scoring sentences for ``query``.

    Args:
        query: Query string; split on single spaces into tokens.
        bm25: Fitted index exposing ``get_scores(tokens)``, returning one
            score per entry of ``flattened_docs``.
        flattened_docs: Flat list of sentences the index was fitted on.
        doc_indices: For each flat position, the index of the source document.
        sentence_indices: For each flat position, the index of the sentence
            within its source document.
        top_k: Maximum number of sentences to return.

    Returns:
        A 3-tuple ``(sentences, doc_indices, sentence_indices)`` of parallel
        lists ordered from highest to lowest score.
    """
    # The corpus tokens are lower-cased when the index is built, so the query
    # must be lower-cased the same way; otherwise any query word containing an
    # upper-case letter can never match a corpus token.
    tokenized_query = [token.lower() for token in query.split(" ")]
    scores = bm25.get_scores(tokenized_query)
    # argsort is ascending; reverse it to rank best-first, then truncate.
    top_k_idx = np.argsort(scores)[::-1][:top_k]
    top_k_sentences = [flattened_docs[i] for i in top_k_idx]
    top_k_doc_indices = [doc_indices[i] for i in top_k_idx]
    top_k_sentence_indices = [sentence_indices[i] for i in top_k_idx]
    return top_k_sentences, top_k_doc_indices, top_k_sentence_indices

folder_path = '../knowledge_store/'

# Retrieval depth: the main key point gets a deeper cut than the
# combined/single key points.
MAIN_TOP_K = 70
SUB_TOP_K = 12

counter = 0
claims, info = [], []
for i, r in data.iterrows():
    # Knowledge-store files are named by 0-based row position and hold one
    # JSON object per line.
    file_path = folder_path + f"{counter}.json"
    with open(file_path, 'r', encoding='utf-8') as file:
        knowledge_store = [ujson.loads(line) for line in file]
    claim = r['Claims']
    claims.append(claim)

    # Parse the stringified dict once instead of once per key.
    key_points = ast.literal_eval(r['Key Points'])
    main = key_points['main']
    combined = key_points['combined']
    single = key_points['single']

    # Pair every query with its depth BEFORE dropping empty strings, so the
    # depth stays attached to the right query. Iterating the pairs later also
    # avoids looking the depth up by value, which picks the first occurrence
    # and therefore the wrong depth when a query text is repeated.
    candidates = (
        [(main, MAIN_TOP_K)]
        + [(text, SUB_TOP_K) for text in combined]
        + [(text, SUB_TOP_K) for text in single]
    )
    query_plan = [(text, depth) for text, depth in candidates if text != ""]
    queries = [text for text, _ in query_plan]
    top_k = [depth for _, depth in query_plan]

    urls, document = [], []
    for item in knowledge_store:
        document.append(item['url2text'])
        urls.append(item['url'])

    bm25, flattened_docs, doc_indices, sentence_indices = fit_bm25(document)
    results = []
    for query, t_k in query_plan:
        top_k_sentences, top_k_doc_indices, top_k_sentence_indices = get_top_k_sentences_for_query(
            query, bm25, flattened_docs, doc_indices, sentence_indices, t_k
        )
        # Tag each sentence with "<document index>_<sentence index>" so its
        # origin can be recovered from the flat text.
        top_k_sentences = [
            f"{sentence} <{doc_pos}_{sent_pos}>"
            for sentence, doc_pos, sent_pos in zip(
                top_k_sentences, top_k_doc_indices, top_k_sentence_indices
            )
        ]
        results.append({
            'query': query,
            'top_k_sentences': top_k_sentences
        })
    info.append(results)
    counter += 1
    # Progress indicator every 100 claims.
    if counter % 100 == 0:
        print(counter)

# One text blob per claim: sentences joined by newlines within a query,
# queries separated by a dashed rule.
strs = [
    "\n------------------------------\n".join(
        "\n".join(dic['top_k_sentences']) for dic in entry
    )
    for entry in info
]

# NOTE: this rebinds `data` from the input DataFrame to the output records.
data = [
    {'claim': claim_text, 'retrievals': retrieved}
    for claim_text, retrieved in zip(claims, strs)
]

with open('../data/retrieval.json', 'w') as f:
    json.dump(data, f, indent=4)

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
