In [1]:
import pandas as pd
import pickle
from sentence_transformers import CrossEncoder
import torch
from tqdm import tqdm
from typing import List, Tuple
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
#The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)

In [4]:
queries = pd.read_csv(r"queries.doctrain.tsv", sep='\t', header=None, names=['qid', 'query'], index_col=['qid'])['query']

In [5]:
# load semantic search results
with open('semantic_search_results.pkl', 'rb') as f:
    semantic_search_results: List[Tuple[int, int]] = pickle.load(f)  # (qid, pid)

In [6]:
docs = pd.read_table('msmarco-docs.tsv', header=None, names=['id', 'url', 'title', 'body'], usecols=['body'], dtype='string')['body']

In [7]:
docs.head(1)

0    Science & Mathematics Physics The hot glowing ...
Name: body, dtype: string

In [8]:
docs.dropna(inplace=True)

In [9]:
# docs.index.to_series().to_csv('corpus_id_to_doc_id.txt', header=False, index=False)

In [10]:
docs.map(type).value_counts()

body
<class 'str'>    3201821
Name: count, dtype: int64

In [11]:
#Encode all passages
docs = docs.tolist()

In [12]:
# corpus_embeddings = bi_encoder.encode(docs, convert_to_tensor=True, show_progress_bar=True, device=device)

In [13]:
# save embeddings
# torch.save(corpus_embeddings, 'corpus_embeddings.pt')

In [14]:
cross_inp = []
for result in tqdm(semantic_search_results):
    for qid, pid in result:
        cross_inp.append([queries[qid], docs[pid]])

100%|██████████| 367013/367013 [00:18<00:00, 19401.92it/s]


In [15]:
cross_scores = cross_encoder.predict(cross_inp, batch_size=224, show_progress_bar=True)

Batches:   0%|          | 0/52431 [00:00<?, ?it/s]

In [18]:
# save cross scores
torch.save(cross_scores, 'cross_scores.pt')

In [16]:
top_k = 32
n_results = 10

In [20]:
results = []
cross_score_idx = 0
for semantic_search_result in tqdm(semantic_search_results):
    q_i_results = []
    for j in range(n_results):
        q_i_results.append((semantic_search_result[j][1], cross_scores[cross_score_idx]))
        cross_score_idx += 1
    results.append([x[0] for x in sorted(q_i_results, key=lambda x: x[1], reverse=True)])

367013it [00:02, 153200.42it/s]


In [21]:
with open('corpus_id_to_doc_id.txt') as f:
    doc_ids = list(map(int, f.read().splitlines()))

In [22]:
relevance = pd.read_csv(r"msmarco-doctrain-qrels-idconverted.tsv", sep='\t', header=None, names=['qid', 'pid'], index_col=['qid'], dtype=int)['pid']
relevance.map(type).value_counts()

pid
<class 'int'>    367013
Name: count, dtype: int64

In [23]:
pak_sum = 0
for qid, result in tqdm(zip(queries.index, results)):
    for i in range(n_results):
        if relevance[qid] == doc_ids[result[i]]:
            pak_sum += 1 / (i + 1)
            break
pak_sum / len(results)

367013it [00:02, 132359.19it/s]


0.17274428971159622