In [1]:
import json
import os
from collections import defaultdict
import numpy as np
import pytrec_eval
from scipy.stats import ttest_rel
from tqdm.notebook import tqdm

from common import QrelDataLoader, weight_add_result

  from tqdm.autonotebook import tqdm


In [2]:
def each_q_evaluate(qrels, results, k_values=(1, 10, 100)):
    """Score a run per query with pytrec_eval at the given rank cutoffs.

    The measures requested are ``map_cut``, ``recall``, ``P`` and ``ndcg_cut``,
    each parameterised with every cutoff in ``k_values``.

    Args:
        qrels: relevance judgements, passed unchanged to
            ``pytrec_eval.RelevanceEvaluator``.
        results: the run to score, passed unchanged to ``evaluator.evaluate``.
        k_values: iterable of rank cutoffs. Any iterable works, including a
            one-shot generator, because it is consumed exactly once.

    Returns:
        Whatever ``RelevanceEvaluator.evaluate`` returns; callers in this
        notebook index it as ``scores[qid]["ndcg_cut_10"]``.
    """
    # Build the cutoff suffix once: it is identical for all four measures, and
    # joining a generator four times would leave the last three suffixes empty.
    cutoffs = ",".join(str(k) for k in k_values)
    measures = {prefix + "." + cutoffs for prefix in ("map_cut", "recall", "P", "ndcg_cut")}
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
    return evaluator.evaluate(results)

In [3]:
def pair_t_test(qrels, results1, results2, metric="ndcg_cut_10"):
    """Run a paired t-test between two runs on one per-query metric.

    Both runs are scored with ``each_q_evaluate``. The two samples are aligned
    on the query ids of ``qrels``; a query that is absent from a run's
    evaluation output contributes a score of 0 for that run.

    Args:
        qrels: relevance judgements keyed by query id.
        results1: first run (first sample of the test).
        results2: second run (second sample of the test).
        metric: key of the per-query measure to compare.

    Returns:
        The result object of ``scipy.stats.ttest_rel(sample1, sample2)``.
    """
    per_run_evaluations = (
        each_q_evaluate(qrels, results1),
        each_q_evaluate(qrels, results2),
    )

    samples = []
    for evaluation in per_run_evaluations:
        sample = []
        for qid in qrels.keys():
            if qid in evaluation:
                sample.append(evaluation[qid][metric])
            else:
                # Query missing from the evaluation output counts as zero.
                sample.append(0)
        samples.append(sample)

    first_sample, second_sample = samples
    return ttest_rel(first_sample, second_sample)

In [32]:
def test_per_dataset_w_bm25_dense(source_result, dataset, data_dir_root, result_dir_root):
    """Paired t-tests of ``source_result`` against the stored BM25 and dense runs.

    Args:
        source_result: the run under test; passed to ``pair_t_test`` as the
            first sample.
        dataset: dataset directory name, used under both roots.
        data_dir_root: root directory holding one folder per dataset; the
            "test" split is loaded from ``<data_dir_root>/<dataset>``.
        result_dir_root: root directory holding
            ``<dataset>/result/.../analysis.json`` run files.

    Returns:
        dict mapping ``"bm25"`` and ``"dense"`` to the corresponding
        ``pair_t_test`` result.

    Raises:
        OSError: if a run file cannot be opened.
    """
    data_dir = os.path.join(data_dir_root, dataset)
    # Only the qrels are needed here; the queries half of the pair is ignored.
    _queries, qrels = QrelDataLoader(data_folder=data_dir).load(split="test")

    result_paths = {
        "bm25": os.path.join(result_dir_root, dataset, "result/bm25/analysis.json"),
        "dense": os.path.join(result_dir_root, dataset, "result/dot/mpnet-v3-mse-beir-dot/analysis.json"),
    }

    test_results = {}
    # Load and test one baseline at a time so only a single baseline run is
    # held in memory, mirroring test_per_dataset_w_splade.
    for key, result_path in result_paths.items():
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(result_path, encoding="utf-8") as f:
            target_result = json.load(f)
        test_results[key] = pair_t_test(qrels, source_result, target_result)

    return test_results

In [33]:
def test_per_dataset_w_splade(source_result, dataset, data_dir_root, result_dir_root):
    """Paired t-test of ``source_result`` against the stored SPLADE run.

    Args:
        source_result: the run under test; passed to ``pair_t_test`` as the
            first sample.
        dataset: dataset directory name, used under both roots.
        data_dir_root: root directory holding one folder per dataset; the
            "test" split is loaded from ``<data_dir_root>/<dataset>``.
        result_dir_root: root directory holding
            ``<dataset>/result/dot/distil-splade/analysis.json``.

    Returns:
        dict mapping ``"splade"`` to the ``pair_t_test`` result.

    Raises:
        OSError: if the run file cannot be opened.
    """
    data_dir = os.path.join(data_dir_root, dataset)
    # Only the qrels are needed here; the queries half of the pair is ignored.
    _queries, qrels = QrelDataLoader(data_folder=data_dir).load(split="test")

    result_paths = {
        "splade": os.path.join(result_dir_root, dataset, "result/dot/distil-splade/analysis.json"),
    }

    test_results = {}
    for key, result_path in result_paths.items():
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(result_path, encoding="utf-8") as f:
            target_result = json.load(f)
        test_results[key] = pair_t_test(qrels, source_result, target_result)
    return test_results

In [34]:
# NOTE: superseded draft of multiple_test, kept for reference only. The active
# definition is in the next cell. This draft walks the p-values in ascending
# order and stops at the first corrected value above alpha.
# def multiple_test(test_pvalues, alpha):
#     rejects = defaultdict(list)
#     for algo, pvalues in test_pvalues.items():
#         sorted_pvalues = sorted(pvalues.items(), key=lambda x: x[1])
#         sorted_pvalues = [(ds, pvalue) for ds, pvalue in sorted_pvalues if ds not in {"msmarco", "trec-robust04-title"}]
#         M = len(sorted_pvalues)
#         for i, (ds, pvalue) in enumerate(sorted_pvalues):
#             fixed_pvalue = pvalue * M / (i+1) * sum([1/j for j in range(1,i+2)])
#             if fixed_pvalue > alpha:
#                 break
#             rejects[algo].append(ds)
#     return rejects
    

In [72]:
def multiple_test(test_pvalues, alpha, exclude=("msmarco", "trec-robust04-title"), verbose=True):
    """Select the datasets whose null hypothesis is rejected, per algorithm.

    For each algorithm the p-values are ranked from largest to smallest and
    corrected as ``p * M / rank``, where ``M`` is the number of p-values left
    after exclusion and ``rank`` counts from the smallest p-value (so the
    largest p-value has rank ``M``). Walking down from the largest p-value,
    the first corrected value strictly below ``alpha`` stops the walk; that
    dataset and every dataset with a smaller p-value are rejected. This is a
    Benjamini-Hochberg-style step-up rule with a strict ``<`` comparison.

    Args:
        test_pvalues: mapping ``algo -> {dataset: p-value}``.
        alpha: significance threshold for the corrected p-values.
        exclude: dataset names left out of the correction entirely; they count
            neither towards ``M`` nor towards the returned sets.
        verbose: when true, print ``algo dataset p-value corrected-p-value``
            for every dataset visited before the walk stops.

    Returns:
        dict mapping each algorithm to the set of rejected dataset names. An
        algorithm with no p-values left after exclusion maps to an empty set.
    """
    excluded = set(exclude)
    rejects = {}
    for algo, pvalues in test_pvalues.items():
        # Largest p-value first, so position i corresponds to rank M - i.
        ranked = sorted(
            ((ds, pvalue) for ds, pvalue in pvalues.items() if ds not in excluded),
            key=lambda item: -item[1],
        )
        M = len(ranked)
        not_rejects = set()
        # With M == 0 the loop body never runs, so nothing is indexed or divided.
        for i, (ds, pvalue) in enumerate(ranked):
            fixed_pvalue = pvalue * M / (M - i)
            if verbose:
                print(algo, ds, pvalue, fixed_pvalue)
            if fixed_pvalue < alpha:
                # Step-up: this dataset and all remaining (smaller) p-values are rejected.
                break
            not_rejects.add(ds)
        rejects[algo] = {ds for ds, _ in ranked} - not_rejects
    return rejects

In [52]:
# Paired t-tests of the cbm25 run ("result/lss/mpnet-tod") against the stored
# BM25 and dense runs, for every dataset listed below.
data_dir_root = "/home/gaia_data/iida.h/BEIR/datasets/"
result_dir_root = "/home/gaia_data/iida.h/BEIR/C-BM25/results/"
    
# algo -> dataset -> p-value; filled only where the t statistic is positive.
cbm25_test_pvalues = defaultdict(dict)
# dataset -> algo -> full t-test result (statistic and p-value).
cbm25_test_results = defaultdict(dict)
datasets = ["arguana", "climate-fever", "dbpedia-entity", "fever", "fiqa", "hotpotqa", "msmarco", "nfcorpus", "nq",
      "quora", "scidocs", "scifact", "trec-covid", "trec-robust04-title", "trec-robust04-desc", "webis-touche2020"]
for dataset in datasets:
    cbm25_result_path = os.path.join(result_dir_root, dataset, "result/lss/mpnet-tod/analysis.json")
    with open(cbm25_result_path) as f:
        # Only the first value of the JSON object is used as the run.
        # NOTE(review): presumably the file is keyed by a parameter setting - confirm.
        source_result = list(json.load(f).values())[0]
    test_dataset_results = test_per_dataset_w_bm25_dense(source_result, dataset, data_dir_root, result_dir_root)
    cbm25_test_results[dataset] = test_dataset_results
    for algo, test_dataset_result in test_dataset_results.items():
        # The source run is the first sample of the paired test, so a positive
        # statistic means it scored higher on average than the baseline.
        # NOTE(review): dropping the other tests here shrinks the family size M
        # that multiple_test later corrects for - confirm this is intended.
        if test_dataset_result.statistic > 0:
            cbm25_test_pvalues[algo][dataset] = test_dataset_result.pvalue

In [53]:
# Display the raw (uncorrected) p-values kept for the cbm25 comparisons.
cbm25_test_pvalues

defaultdict(dict,
            {'bm25': {'arguana': 2.464945516201108e-40,
              'climate-fever': 4.933654838352857e-62,
              'dbpedia-entity': 1.4446514667235053e-16,
              'fever': 0.0,
              'fiqa': 8.201073482955002e-22,
              'hotpotqa': 1.1579499482805089e-245,
              'msmarco': 2.9292923269298105e-06,
              'nfcorpus': 5.2136776178102795e-06,
              'nq': 5.233053743559298e-183,
              'quora': 1.6705795409597773e-106,
              'scidocs': 3.7572179679093955e-07,
              'scifact': 0.00013028857098378187,
              'trec-covid': 1.2479070705800925e-07,
              'trec-robust04-title': 0.0005139565923109141,
              'trec-robust04-desc': 1.0558573144856432e-05},
             'dense': {'climate-fever': 2.2949247509158473e-12,
              'dbpedia-entity': 0.48272237799948015,
              'fever': 3.8020274747604186e-34,
              'fiqa': 0.5110611664728221,
              'hotpotqa'

In [76]:
# Multiple-testing correction at alpha = 0.05; returns the rejected datasets per baseline.
multiple_test(cbm25_test_pvalues, 0.05)

bm25 scifact 0.00013028857098378187 0.00013028857098378187
dense fiqa 0.5110611664728221 0.5110611664728221
dense dbpedia-entity 0.48272237799948015 0.5266062305448874
dense trec-robust04-desc 0.01013176065757653 0.012158112789091836


{'bm25': {'arguana',
  'climate-fever',
  'dbpedia-entity',
  'fever',
  'fiqa',
  'hotpotqa',
  'nfcorpus',
  'nq',
  'quora',
  'scidocs',
  'scifact',
  'trec-covid',
  'trec-robust04-desc'},
 'dense': {'climate-fever',
  'fever',
  'hotpotqa',
  'nfcorpus',
  'quora',
  'scidocs',
  'scifact',
  'trec-covid',
  'trec-robust04-desc',
  'webis-touche2020'}}

In [37]:
def load_dataset_and_cbm25_dense_bm25_result(dataset, data_dir_root, result_dir_root):
    """Load the test qrels of a dataset together with its three stored runs.

    Args:
        dataset: dataset directory name, used under both roots.
        data_dir_root: root directory holding one folder per dataset.
        result_dir_root: root directory holding ``<dataset>/result/...`` files.

    Returns:
        Tuple ``(all_qids, qrels, cbm25_result, dense_result, bm25_result)``
        where ``all_qids`` is ``qrels.keys()`` and ``cbm25_result`` is the
        first value of the cbm25 analysis JSON object.
    """
    def _read_analysis(relative_path):
        # Parse one analysis file located under this dataset's result folder.
        with open(os.path.join(result_dir_root, dataset, relative_path)) as handle:
            return json.load(handle)

    loader = QrelDataLoader(data_folder=os.path.join(data_dir_root, dataset))
    _, qrels = loader.load(split="test")

    # Files are read in the order cbm25, bm25, dense.
    cbm25_payload = _read_analysis("result/lss/mpnet-tod/analysis.json")
    cbm25_result = list(cbm25_payload.values())[0]
    bm25_result = _read_analysis("result/bm25/analysis.json")
    dense_result = _read_analysis("result/dot/mpnet-v3-mse-beir-dot/analysis.json")

    return (qrels.keys(), qrels, cbm25_result, dense_result, bm25_result)

In [55]:
# Paired t-tests of the combined cbm25+dense run ("hcbm25") against the
# combined bm25+dense run ("hbm25") and the stored SPLADE run, per dataset.
data_dir_root = "/home/gaia_data/iida.h/BEIR/datasets/"
result_dir_root = "/home/gaia_data/iida.h/BEIR/C-BM25/results/"
    
datasets = ["arguana", "climate-fever", "dbpedia-entity", "fever", "fiqa", "hotpotqa", "msmarco", "nfcorpus", "nq",
      "quora", "scidocs", "scifact", "trec-covid", "trec-robust04-title", "trec-robust04-desc", "webis-touche2020"]

# algo -> dataset -> p-value; filled only where the t statistic is positive.
hcbm25_test_pvalues = defaultdict(dict)
# dataset -> algo -> full t-test result (statistic and p-value).
hcbm25_test_results = defaultdict(dict)
for dataset in tqdm(datasets):
    all_qids, qrels, cbm25_result, dense_result, bm25_result = load_dataset_and_cbm25_dense_bm25_result(dataset, data_dir_root, result_dir_root)
    # NOTE(review): weight_add_result comes from common; presumably it mixes the
    # two runs' scores with weight 0.5 over all_qids - confirm against its source.
    hbm25_result = weight_add_result(bm25_result, dense_result, all_qids, 0.5)
    hcbm25_result = weight_add_result(cbm25_result, dense_result, all_qids, 0.5)
    # hcbm25 is the first sample, so a positive statistic favours hcbm25.
    hbm25_test_dataset_results = {"hbm25": pair_t_test(qrels, hcbm25_result, hbm25_result)}
    sp_test_dataset_results = test_per_dataset_w_splade(hcbm25_result, dataset, data_dir_root, result_dir_root)
    test_dataset_results = {**hbm25_test_dataset_results, **sp_test_dataset_results}
    hcbm25_test_results[dataset] = test_dataset_results
    for algo, test_dataset_result in test_dataset_results.items():
        # NOTE(review): dropping non-positive statistics here shrinks the family
        # size M that multiple_test later corrects for - confirm this is intended.
        if test_dataset_result.statistic > 0:
            hcbm25_test_pvalues[algo][dataset] = test_dataset_result.pvalue
                                                                   

  0%|          | 0/16 [00:00<?, ?it/s]

In [56]:
# Display every t-test result (statistic and p-value), including negative statistics.
hcbm25_test_results

defaultdict(dict,
            {'arguana': {'hbm25': Ttest_relResult(statistic=7.6224058309373035, pvalue=4.5583093241402896e-14),
              'splade': Ttest_relResult(statistic=-3.1576434612552964, pvalue=0.0016243289360790471)},
             'climate-fever': {'hbm25': Ttest_relResult(statistic=5.8161653065573, pvalue=7.314597323453154e-09),
              'splade': Ttest_relResult(statistic=5.211544751216174, pvalue=2.1273486714276232e-07)},
             'dbpedia-entity': {'hbm25': Ttest_relResult(statistic=-0.3199438251015226, pvalue=0.7491783822684387),
              'splade': Ttest_relResult(statistic=-4.537767687587477, pvalue=7.534238479434832e-06)},
             'fever': {'hbm25': Ttest_relResult(statistic=2.8377193750063374, pvalue=0.004557436815083441),
              'splade': Ttest_relResult(statistic=-3.3615659747112727, pvalue=0.0007793783757575142)},
             'fiqa': {'hbm25': Ttest_relResult(statistic=0.1908144019112572, pvalue=0.8487308634256152),
              'sp

In [57]:
# Display the raw (uncorrected) p-values kept for the hcbm25 comparisons.
hcbm25_test_pvalues

defaultdict(dict,
            {'hbm25': {'arguana': 4.5583093241402896e-14,
              'climate-fever': 7.314597323453154e-09,
              'fever': 0.004557436815083441,
              'fiqa': 0.8487308634256152,
              'nfcorpus': 0.42673935264084695,
              'scidocs': 0.6830871007461452,
              'scifact': 0.39890875877594256,
              'trec-covid': 1.2696341895143044e-05,
              'trec-robust04-title': 0.0496963162295799,
              'webis-touche2020': 0.050660407219942166},
             'splade': {'climate-fever': 2.1273486714276232e-07,
              'fiqa': 0.05752742333854616,
              'quora': 2.408375867943711e-55,
              'scidocs': 0.036189138053583675,
              'scifact': 0.032201734792060024,
              'trec-covid': 0.06250663936426228,
              'trec-robust04-title': 0.025365403406381788,
              'trec-robust04-desc': 0.020688456214385968,
              'webis-touche2020': 0.00021353971037997772}})

In [79]:
# Multiple-testing correction at alpha = 0.05; returns the rejected datasets per baseline.
multiple_test(hcbm25_test_pvalues, 0.05)

hbm25 fiqa 0.8487308634256152 0.8487308634256152
hbm25 scidocs 0.6830871007461452 0.7684729883394134
hbm25 nfcorpus 0.42673935264084695 0.5486648819668032
hbm25 scifact 0.39890875877594256 0.5983631381639138
hbm25 webis-touche2020 0.050660407219942166 0.09118873299589589
hbm25 fever 0.004557436815083441 0.010254232833937742
splade trec-covid 0.06250663936426228 0.06250663936426228
splade fiqa 0.05752742333854616 0.06574562667262418
splade scidocs 0.036189138053583675 0.0482521840714449


{'hbm25': {'arguana', 'climate-fever', 'fever', 'trec-covid'},
 'splade': {'climate-fever',
  'quora',
  'scidocs',
  'scifact',
  'trec-robust04-desc',
  'webis-touche2020'}}