In [1]:
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir import util, LoggingHandler

dataset = "nfcorpus"

url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = "dataset/{}".format(dataset)
data_path = util.download_and_unzip(url, out_dir)

  from tqdm.autonotebook import tqdm


In [2]:
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

  0%|          | 0/3633 [00:00<?, ?it/s]

In [3]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval

from beir_models import Splade, BEIRSpladeModel
from transformers import AutoModelForMaskedLM, AutoTokenizer

splade_dir = "../msmarco/Splade_0_MLMTransformer"
model = Splade(splade_dir)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(splade_dir)
beir_splade = BEIRSpladeModel(model, tokenizer)

dres = DRES(beir_splade)
retriever = EvaluateRetrieval(dres, score_function="dot")
results = retriever.retrieve(corpus, queries)
ndcg, map_, recall, p = EvaluateRetrieval.evaluate(qrels, results, [1, 10, 100, 1000])
print(ndcg,recall, p) 

Batches:   0%|          | 0/11 [00:00<?, ?it/s]

Batches:   0%|          | 0/114 [00:00<?, ?it/s]

{'NDCG@1': 0.48452, 'NDCG@10': 0.34692, 'NDCG@100': 0.30894, 'NDCG@1000': 0.38538} {'Recall@1': 0.0621, 'Recall@10': 0.16507, 'Recall@100': 0.28621, 'Recall@1000': 0.56959} {'P@1': 0.49845, 'P@10': 0.24644, 'P@100': 0.077, 'P@1000': 0.01928}


In [4]:
from beir_models import BEIRColBERT
from beir.reranking import Rerank
colbert_dir = "colbert_hn_[q][d]_spladeneg_num20_denoiseFalse_marginkldiv5_from-colbert-batch_size_16-2022-06-30_06-08-47"
epoch = 50000
beir_colbert = BEIRColBERT(f"training_with_sentence_transformers/output/{colbert_dir}/{epoch}/0_ColBERTTransformer")

In [5]:
reranker = Rerank(beir_colbert, batch_size=128)
rerank_results = reranker.rerank(corpus, queries, results, top_k=100)
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
print(ndcg,recall, p)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 323/323 [02:15<00:00,  2.38it/s]

{'NDCG@1': 0.48607, 'NDCG@3': 0.41429, 'NDCG@5': 0.38385, 'NDCG@10': 0.34748, 'NDCG@100': 0.30876, 'NDCG@1000': 0.29425} {'Recall@1': 0.06492, 'Recall@3': 0.11089, 'Recall@5': 0.13017, 'Recall@10': 0.16371, 'Recall@100': 0.28622, 'Recall@1000': 0.28622} {'P@1': 0.49845, 'P@10': 0.24644, 'P@100': 0.077, 'P@1000': 0.01928}





In [10]:
sorted(results['PLAIN-2'].items(), key=lambda x: -x[1])

[('MED-2431', 22.63833999633789),
 ('MED-2429', 22.299915313720703),
 ('MED-4827', 21.93878173828125),
 ('MED-14', 21.818187713623047),
 ('MED-10', 21.562702178955078),
 ('MED-2427', 17.89531898498535),
 ('MED-2439', 17.712846755981445),
 ('MED-2525', 17.55705451965332),
 ('MED-2428', 17.493106842041016),
 ('MED-2814', 17.292911529541016),
 ('MED-3856', 16.93328285217285),
 ('MED-4830', 16.731578826904297),
 ('MED-2434', 16.1236572265625),
 ('MED-1732', 15.875133514404297),
 ('MED-2435', 15.646245002746582),
 ('MED-1193', 15.574505805969238),
 ('MED-5001', 15.559112548828125),
 ('MED-950', 15.518030166625977),
 ('MED-3832', 15.262715339660645),
 ('MED-4690', 15.179137229919434),
 ('MED-3551', 15.165375709533691),
 ('MED-1829', 15.114436149597168),
 ('MED-3130', 15.0867280960083),
 ('MED-4162', 15.039274215698242),
 ('MED-3553', 14.985124588012695),
 ('MED-2430', 14.974706649780273),
 ('MED-2102', 14.851646423339844),
 ('MED-4097', 14.785355567932129),
 ('MED-3840', 14.766180992126465),

In [11]:
sorted(rerank_results['PLAIN-2'].items(), key=lambda x: -x[1])

[('MED-2429', 20.067201614379883),
 ('MED-14', 20.012500762939453),
 ('MED-10', 19.887901306152344),
 ('MED-2431', 19.629043579101562),
 ('MED-4827', 18.887794494628906),
 ('MED-2439', 16.6368408203125),
 ('MED-2525', 16.495521545410156),
 ('MED-2428', 16.06807518005371),
 ('MED-2427', 15.341209411621094),
 ('MED-1193', 15.126762390136719),
 ('MED-3856', 15.089019775390625),
 ('MED-3832', 14.57259464263916),
 ('MED-2434', 14.519179344177246),
 ('MED-4162', 14.41952133178711),
 ('MED-4830', 14.413945198059082),
 ('MED-950', 14.25903034210205),
 ('MED-2440', 14.253463745117188),
 ('MED-2430', 14.172355651855469),
 ('MED-3130', 13.971807479858398),
 ('MED-3840', 13.964836120605469),
 ('MED-2102', 13.942554473876953),
 ('MED-5001', 13.936078071594238),
 ('MED-1829', 13.926973342895508),
 ('MED-4759', 13.842513084411621),
 ('MED-5117', 13.825496673583984),
 ('MED-4756', 13.776257514953613),
 ('MED-2122', 13.775945663452148),
 ('MED-4097', 13.757621765136719),
 ('MED-5184', 13.72982311248779

In [14]:
from collections import defaultdict
k1=60
k2=60
top_k = 100

def rankfuse(ranklist1, ranklist2, k1=60,k2=60,top_k=100):
    top_run = defaultdict(dict)
    for q in ranklist1:
        res = sorted(ranklist1[q].items(), key=lambda x: -x[1])[:top_k]
        for idx, item in enumerate(res):
            top_run[q][item[0]] = 1.0/(int(idx+1) + k1)

    for q in ranklist2:
        res = sorted(ranklist2[q].items(), key=lambda x: -x[1])[:top_k]
        for idx, item in enumerate(res):
            if item[0] in top_run[q]:
                top_run[q][item[0]] += 1.0/(int(idx+1) + k2)
            else:
                top_run[q][item[0]] = 1.0/(int(idx+1) + k2)

    return top_run

In [15]:
run = rankfuse(rerank_results, results)

In [16]:
run

defaultdict(dict,
            {'PLAIN-2': {'MED-2429': 0.03252247488101534,
              'MED-14': 0.031754032258064516,
              'MED-10': 0.03125763125763126,
              'MED-2431': 0.032018442622950824,
              'MED-4827': 0.03125763125763126,
              'MED-2439': 0.03007688828584351,
              'MED-2525': 0.029631255487269532,
              'MED-2428': 0.02919863597612958,
              'MED-2427': 0.02964426877470356,
              'MED-1193': 0.027443609022556388,
              'MED-3856': 0.028169014084507043,
              'MED-3832': 0.026547116736990152,
              'MED-2434': 0.0273972602739726,
              'MED-4162': 0.02541827541827542,
              'MED-4830': 0.027222222222222224,
              'MED-950': 0.025978407557354925,
              'MED-2440': 0.023856578204404292,
              'MED-2430': 0.024448419797257006,
              'MED-3130': 0.024706420619185605,
              'MED-3840': 0.023735955056179776,
              'MED-2102':

In [17]:
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
print(ndcg,recall, p)

{'NDCG@1': 0.48607, 'NDCG@3': 0.41429, 'NDCG@5': 0.38385, 'NDCG@10': 0.34748, 'NDCG@100': 0.30876, 'NDCG@1000': 0.29425} {'Recall@1': 0.06492, 'Recall@3': 0.11089, 'Recall@5': 0.13017, 'Recall@10': 0.16371, 'Recall@100': 0.28622, 'Recall@1000': 0.28622} {'P@1': 0.49845, 'P@10': 0.24644, 'P@100': 0.077, 'P@1000': 0.01928}


In [4]:
import numpy as np
1/(1+np.exp(3))

0.0066928509242848554