In [1]:
from ranx import Qrels, Run, evaluate
import json

import pyterrier  as pt
import pandas as pd

import os
from collections import defaultdict
from collection import SparseCollection, SparseCollectionCSR
from retriever import SparseRetriever
from weighting_model import BM25WeightingModel
from backend import TYPE
import json

from tqdm import tqdm

from text2vec import BagOfWords

os.environ["CUDA_VISIBLE_DEVICES"]="1"

dataset_folder = "beir_datasets/msmarco/"

qrels = defaultdict(dict)
run = {}

with open(os.path.join(dataset_folder, "relevant_pairs.jsonl")) as f:
    for q_data in map(json.loads, f):
        qrels[q_data["id"]][q_data["doc_id"]] = 1

        run[q_data["id"]] = q_data["question"]
        
question_ids, question_text = list(zip(*run.items()))

In [2]:
if not pt.started():
    pt.init()
        
indexref = pt.IndexRef.of(f"./{dataset_folder}/terrier_index/")
index = pt.IndexFactory.of(indexref)

def tp_func():
    stops = pt.autoclass("org.terrier.terms.Stopwords")(None)
    stemmer = pt.autoclass("org.terrier.terms.PorterStemmer")(None)
    def _apply_func(row):
        words = row["query"].split(" ") # this is safe following pt.rewrite.tokenise()
        words = [stemmer.stem(w) for w in words if not stops.isStopword(w) ]
        return words
    return _apply_func 

pipe = pt.rewrite.tokenise() >> pt.apply.query(tp_func())
token2id = {word.getKey():i for i,word in enumerate(index.getLexicon()) }

vocab_size = len(index.getLexicon())

def tokenizer(text):
    tokens_ids = []
    for token in pipe(pd.DataFrame([{"qid":0, "query":text.lower()}]))["query"][0]:
        if token in token2id:
            token_id=token2id[token]
            if token_id is not None:
                tokens_ids.append(token_id)
    return tokens_ids

#mp.set_start_method("spawn")

# load collection data and metadata from a previously created folder
print("load collection")
sparse_collection = SparseCollectionCSR.load_from_file(f"{dataset_folder}/csr_terrier_bm25_12_075") # cpu
print("dataset size", sparse_collection.shape)

bow = BagOfWords(tokenizer, vocab_size)

PyTerrier 0.9.2 has loaded Terrier 5.7 (built by craigm on 2022-11-10 18:30) and terrier-helper 0.0.7

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


09:17:39.146 [main] WARN org.terrier.structures.BaseCompressingMetaIndex - Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 1.9 GiB of memory would be required.
load collection
dataset size (8841823, 1179529)


In [3]:
sparse_retriver = SparseRetriever(sparse_collection, bow, BM25WeightingModel())

print("Num questions", len(question_text))

#questions = questions[:10000]
    
#s = time.time()
# Retrieve by default utilizes the maximum amount of resources available
out = sparse_retriver.retrieve(question_text, top_k=1000, profiling=False, return_scores=True) # TODO load directly to the device
#e = time.time()

Collection is already in BM25 weighting schema, using its parameters
Num questions 6980


  collection_matrix = torch.sparse_csr_tensor(self.crow, self.indice, self.values, self.shape, dtype=self.values.dtype)
100%|██████████| 6980/6980 [01:00<00:00, 114.55it/s]


Retrieval time: 60.93796443939209 QPS 114.54271674831205
Mem transference time: 0.039444923400878906


In [4]:
ranx_run = defaultdict(dict)

for i in range(len(out.ids)):
    scores = out.scores[i].tolist()
    for j in range(len(out.ids[i])):
        ranx_run[question_ids[i]][out.ids[i][j]] = scores[j]

In [5]:
qrels = Qrels(qrels)
ranx_sp_run = Run(ranx_run)

In [6]:
evaluate(qrels, ranx_sp_run, ["map@1000", "mrr", "ndcg@1000"])

{'map@1000': 0.19365641766430777,
 'mrr': 0.19688139749257347,
 'ndcg@1000': 0.3164874107311945}

In [7]:
### 

In [8]:
bm25_pipe = pt.rewrite.tokenise() >> pt.BatchRetrieve(index, wmodel="BM25", num_results=1000)#).parallel(3)
        
run_pyterrier = defaultdict(dict)

for i, question in enumerate(tqdm(question_text)):
    
    questions_dataframe = pd.DataFrame([{"qid":0, "query":question.lower()}])

    df_results = bm25_pipe.transform(questions_dataframe)
    
    for _, row in df_results.iterrows():    
        run_pyterrier[question_ids[i]][row["docno"]] = row["score"]

100%|██████████| 6980/6980 [18:15<00:00,  6.37it/s]


In [9]:
ranx_run_pyterrier = Run(run_pyterrier)

In [10]:
evaluate(qrels, ranx_run_pyterrier, ["map@1000", "mrr", "ndcg@1000"], make_comparable=True)

{'map@1000': 0.19364555151521373,
 'mrr': 0.19687720303739742,
 'ndcg@1000': 0.31645799399293445}

In [27]:
df_results

Unnamed: 0,qid,docid,docno,rank,score,query_0,query
0,0,8009380,8009380,0,29.169715,glioma meaning,glioma meaning
1,0,6358134,6358134,1,26.253322,glioma meaning,glioma meaning
2,0,8009378,8009378,2,26.250039,glioma meaning,glioma meaning
3,0,8009383,8009383,3,26.250039,glioma meaning,glioma meaning
4,0,6001369,6001369,4,26.109108,glioma meaning,glioma meaning
5,0,4381975,4381975,5,25.877556,glioma meaning,glioma meaning
6,0,3908516,3908516,6,25.547754,glioma meaning,glioma meaning
7,0,6344704,6344704,7,25.411167,glioma meaning,glioma meaning
8,0,1906021,1906021,8,25.406041,glioma meaning,glioma meaning
9,0,8106537,8106537,9,25.326009,glioma meaning,glioma meaning


In [50]:
qrels["264827"]

{'7071066': 1}

In [51]:
run_pyterrier["264827"]

{'7071072': 38.779505472536314,
 '5730193': 37.883670468410756,
 '6246494': 37.8742953586615,
 '7071066': 36.64356242034553,
 '2577289': 35.56788058562289,
 '8289021': 35.40961137723089,
 '5730192': 35.21928198733615,
 '645426': 34.271307365645356,
 '6786307': 34.04565945158207,
 '8289022': 33.91535999544205}

In [52]:
ranx_run["264827"]

{'7071072': 38.786888122558594,
 '5730193': 37.912994384765625,
 '6246494': 37.90830612182617,
 '7071066': 36.649681091308594,
 '2577289': 35.59677505493164,
 '8289021': 35.415245056152344,
 '5730192': 35.22487258911133,
 '645426': 34.27838897705078,
 '6786307': 34.06928253173828,
 '8289022': 33.92045211791992,
 '8289027': 33.04380416870117,
 '7071070': 32.890533447265625,
 '8289026': 32.7933349609375,
 '8289025': 32.54016876220703,
 '5730189': 32.462318420410156,
 '6246492': 32.440635681152344,
 '2577288': 32.290958404541016,
 '6365854': 32.240684509277344,
 '8289024': 31.794281005859375,
 '5813009': 31.587934494018555,
 '7071067': 31.460586547851562,
 '6246499': 31.366180419921875,
 '7071071': 31.344715118408203,
 '632140': 31.331851959228516,
 '6365851': 31.28065299987793,
 '7982528': 30.927696228027344,
 '5708657': 30.859756469726562,
 '6786306': 30.490520477294922,
 '6786304': 30.391342163085938,
 '7350773': 30.210357666015625,
 '5730196': 30.209400177001953,
 '5708660': 29.971702

In [49]:
question_ids[20]

'264827'