In [1]:
import os

# Point this process at a Java 11 JDK. The variable is set before pyserini is
# imported further down in this cell.
# NOTE(review): presumably pyserini needs JAVA_HOME to locate a JVM — confirm.
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-11-openjdk-amd64'

# Echo the value back as a quick sanity check that the assignment took effect.
print(os.environ.get('JAVA_HOME'))


from pyserini.search.lucene import LuceneSearcher
from datasets import get_dataset
import ir_datasets
import pandas as pd
from collections import defaultdict
import pickle
import time
import numpy as np
from tqdm import trange


/usr/lib/jvm/java-11-openjdk-amd64


In [2]:
# Create the Lucene searcher from the prebuilt index named 'msmarco-v1-passage'.
# NOTE(review): presumably the index is downloaded and cached on first use — confirm.
searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')

In [3]:
# Override the searcher's BM25 parameters.
# NOTE(review): the positional arguments are presumably (k1, b) — confirm
# against the pyserini LuceneSearcher API.
searcher.set_bm25(0.82, 0.68)

In [4]:
# Tab-separated training queries, no header row. The array built here is
# later read as column 0 = query id, column 1 = query text.
QUERIES_PATH = '/home/catalinlup/MyWorkspace/MasterThesis/datasets/queries/msmarco_psg_queries.train.tsv'
queries = (
    pd.read_csv(QUERIES_PATH, sep='\t', header=None)
    .to_numpy()
)

In [5]:
# Peek at the raw (query id, query text) array.
print(queries)

# Split the two columns. The ids are stringified because they are handed to
# the searcher as `qids` later on; process_sample converts them back to int
# when it looks up the qrels.
query_text = queries[:, 1]
qids = [str(raw_qid) for raw_qid in queries[:, 0]]

# Number of queries loaded.
print(len(queries))

[[121352 'define extreme']
 [634306 'what does chattel mean on credit history']
 [920825 'what was the great leap forward brainly']
 ...
 [210839 'how can i watch the day after']
 [908165 'what to use instead of pgp in windows']
 [50393 'benefits of boiling lemons and drinking juice.']]
808731


In [6]:
# Tab-separated relevance judgements, no header row. Downstream code reads
# column 0 as the query id and column 2 as the passage id.
# NOTE(review): columns 1 and 3 look like the TREC qrels "iteration" and
# "relevance" fields — confirm; they are not used in this notebook.
QRELS_PATH = '/home/catalinlup/MyWorkspace/MasterThesis/datasets/qrels/msmarco_psg.qrels.train.tsv'
qrels = (
    pd.read_csv(QRELS_PATH, sep='\t', header=None)
    .to_numpy()
)
qrels

array([[1185869,       0,       0,       1],
       [1185868,       0,      16,       1],
       [ 597651,       0,      49,       1],
       ...,
       [ 559149,       0, 8841547,       1],
       [ 706678,       0, 8841643,       1],
       [ 405466,       0, 8841735,       1]])

In [7]:
# Group the judged passage ids by query id: qid -> [docid, ...].
# Column 0 of `qrels` holds the query id and column 2 the passage id.
relevant_docs_by_qid = defaultdict(list)
for qid, docid in zip(qrels[:, 0], qrels[:, 2]):
    relevant_docs_by_qid[qid].append(docid)

In [8]:
import math

# Number of hits requested per query from the searcher.
TOP_K = 100
# Upper bound on the negatives kept per query.
NUM_NEG_SAMPLES = 32
# Queries are searched and checkpointed in chunks of this many.
BATCH_SIZE = 1000
# Ceiling division: the final batch may hold fewer than BATCH_SIZE queries.
NUM_BATCHES = (len(qids) + BATCH_SIZE - 1) // BATCH_SIZE
# Directory that receives one pickle file per processed batch.
CHECKPOINT_PATH = './negative_sample_batches'

In [9]:
def process_sample(qids_sample, hits, num_neg_samples, relevant_docs=None):
    """Collect positive and negative passages for one batch of queries.

    For every query that has at least one judged-relevant passage:

    * positives are the relevant passages that appear in the query's hit
      list, visited in hit-list order;
    * negatives are non-relevant passages taken from the *end* of the hit
      list backwards, up to ``num_neg_samples`` of them.

    Args:
        qids_sample: Iterable of query ids. Each id is used as-is to index
            ``hits`` and converted with ``int()`` to index ``relevant_docs``.
        hits: Mapping from query id to a reversible sequence of hit objects
            exposing ``.docid`` and ``.score``.
            NOTE(review): assumed to be ordered best-first — confirm.
        num_neg_samples: Maximum number of negatives per query. Values
            ``<= 0`` produce no negatives.
        relevant_docs: Optional mapping ``int qid -> iterable of docids``.
            Defaults to the notebook-level ``relevant_docs_by_qid``.

    Returns:
        A ``defaultdict`` mapping ``int`` qid to a ``defaultdict(list)`` with
        the keys ``'pos_docs'`` and ``'neg_docs'``; each holds
        ``(int docid, score)`` tuples. Queries without relevant passages, or
        with an empty hit list, get no entry.

    Raises:
        KeyError: If a query with relevant passages is missing from ``hits``.
    """
    if relevant_docs is None:
        # Backward-compatible default: the qrels lookup built earlier in the notebook.
        relevant_docs = relevant_docs_by_qid

    samples_by_qid = defaultdict(lambda: defaultdict(list))

    for qid in qids_sample:
        qid_key = int(qid)

        # .get() instead of [] so that looking up an unjudged query does not
        # insert an empty list into a defaultdict; a set makes the repeated
        # membership tests below constant-time.
        relevant = {int(docid) for docid in relevant_docs.get(qid_key, ())}
        if not relevant:
            continue

        ranked_hits = hits[qid]

        # Positives: scan in hit-list order and stop as soon as every
        # relevant passage has been found.
        for hit in ranked_hits:
            positives = samples_by_qid[qid_key]['pos_docs']
            docid = int(hit.docid)
            if docid in relevant:
                positives.append((docid, hit.score))
            if len(positives) >= len(relevant):
                break

        # Without this guard one negative would be appended before the limit
        # is checked, even when no negatives were requested.
        if num_neg_samples <= 0:
            continue

        # Negatives: walk the hit list from its end, skipping relevant passages.
        for hit in reversed(ranked_hits):
            docid = int(hit.docid)
            if docid in relevant:
                continue

            negatives = samples_by_qid[qid_key]['neg_docs']
            negatives.append((docid, hit.score))
            if len(negatives) >= num_neg_samples:
                break

    return samples_by_qid

In [10]:

# Index of the first batch to process. Starting above 0 skips the earlier
# batches entirely; set this to 0 to process every batch.
START_BATCH = 505

# Make sure the checkpoint directory exists before the first write.
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

for bi in trange(START_BATCH, NUM_BATCHES):
    # Slicing clamps at the end of the sequence, so the final (shorter)
    # batch needs no special handling.
    batch_start = bi * BATCH_SIZE
    batch_end = batch_start + BATCH_SIZE
    query_text_sample = query_text[batch_start:batch_end]
    qids_sample = qids[batch_start:batch_end]

    hits = searcher.batch_search(queries=query_text_sample, qids=qids_sample, k=TOP_K, threads=6)

    # The outer defaultdict returned by process_sample uses a lambda factory,
    # which pickle cannot serialise; convert it to a plain dict first.
    samples_by_qid = dict(process_sample(qids_sample, hits, NUM_NEG_SAMPLES))

    # One checkpoint file per batch; the context manager closes (and thereby
    # flushes) the file even if pickling raises.
    checkpoint_file = os.path.join(CHECKPOINT_PATH, f'sample_batch_{bi}.pickle')
    with open(checkpoint_file, 'wb') as checkpoint:
        pickle.dump(samples_by_qid, checkpoint)
    


100%|█████████████████████████████████████████| 304/304 [12:57<00:00,  2.56s/it]
