In [26]:
import pandas as pd
import pyterrier as pt
import numpy as np
import os
from repro_eval.Evaluator import RpdEvaluator
from repro_eval.util import arp, arp_scores
import pytrec_eval
import yaml
if not pt.started():
    pt.init()

# Time Fuse
Fuse the new run with one old run by boosting old (known) documents up and new documents down.

In [102]:
def time_fuse(run_recent, run_old, _lambda=0.5):
    qid_ranking_groups = run_old.groupby('qid')
    qid_ranking_dict = {qid: list(ranking['docno']) for qid, ranking in qid_ranking_groups}
    
    def weigh(row):
        if not qid_ranking_dict.get(row['qid']):
            print("Could not find", row['qid'])
            
        if row['docno'] in qid_ranking_dict.get(row['qid'], []):
            return row['score'] * _lambda ** 2
        else:
            return row['score'] * (1-_lambda) ** 2              
    reranking = run_recent.copy()
    
    # min max normalization per topic
    reranking['score'] = reranking.groupby('qid')['score'].transform(lambda x : x / x.max())
    
    # weight if in old ranking
    reranking['score'] = reranking.progress_apply(weigh, axis=1)
    reranking = reranking.sort_values(['qid','score'], ascending=False).groupby('qid').head(10000)
    reranking['rank'] = reranking.groupby('qid')['score'].rank(ascending=False).astype(int)
    return reranking

In [3]:
bm25_t5 = pt.io.read_results('data/results/trec/CIR_BM25_D-t5_T-t5')
bm25_t4_extended = pt.io.read_results('data/results/trec/CIR_BM25_D-t4_T-t5_extended')

In [4]:
# _lambda = 0.5 + 1e-3
_lambda = 0.5 + 1e-12
bm25_t5_reranked_by_t4 = time_fuse(bm25_t5, bm25_t4_extended, _lambda=_lambda)

 22%|██▏       | 325376/1463631 [00:03<00:11, 100129.99it/s]


KeyboardInterrupt: 

In [None]:
pt.io.write_results(bm25_t5_reranked_by_t4, 'bm25_t5_reranked_by_t4', format='trec',run_name='bm25_t5_reranked_by_t4')
pt.io.write_results(bm25_t5, 'bm25_t5', format='trec', run_name='bm25_t5')

rpd_eval = RpdEvaluator(run_b_orig_path='bm25_t5', run_b_rep_path='bm25_t5_reranked_by_t4')
correlations = rpd_eval.ktau_union().get('baseline')
correlation_scores = [x for x in list(correlations.values()) if ~np.isnan(x)]
print("Avg. Kendall's tau: ", sum(correlation_scores) / len(correlation_scores))

Avg. Kendall's tau:  0.056043715260201694


# Lambda sweep

In [81]:
base_path = "data"
runs_path = "results/trec"
reranked_path = "results/fuse_time"

run_new_path = "CIR_BM25_D-t3_T-t3"
run_old_path = "CIR_BM25_D-t2_T-t3_extended"

with open("data/LongEval/metadata.yml", "r") as yamlfile:
    config = yaml.load(yamlfile, Loader=yaml.FullLoader)

In [103]:
run_new = pt.io.read_results(os.path.join(base_path, runs_path, run_new_path))
run_old = pt.io.read_results(os.path.join(base_path, runs_path, run_old_path))

with open(os.path.join(base_path, config["subcollections"]["t3"]["qrels"]["test"]), "r") as f_qrels:
    qrels = pytrec_eval.parse_qrel(f_qrels)
evaluator = pytrec_eval.RelevanceEvaluator(qrels, pytrec_eval.supported_measures)

results = {}
for l in np.logspace(-3, -12, num=10):  # sweep from 1e-3 to 1e-12
    _lambda = 0.5 + l
    run_reranked = time_fuse(run_new, run_old, _lambda=_lambda)
    results[_lambda] = {}
    
    # write results
    run_name = f'CIR_BM25_D-t3_T-t3_rr-t2-{l}'
    run_reranked_path = os.path.join(base_path, reranked_path, run_name)
    pt.io.write_results(run_reranked, run_reranked_path, format='trec', run_name=run_name)

    # evaluate
    rpd_eval = RpdEvaluator(run_b_orig_path=os.path.join(base_path, runs_path, run_new_path), run_b_rep_path=run_reranked_path)
    
    correlations = rpd_eval.ktau_union().get('baseline')
    correlation_scores = [x for x in list(correlations.values()) if ~np.isnan(x)]
    avg_tau = sum(correlation_scores) / len(correlation_scores)
    results[_lambda]["tau"] = avg_tau
    
    with open(run_reranked_path) as run_reranked:
        run = pytrec_eval.parse_run(run_reranked)
        scores = evaluator.evaluate(run)
        results[_lambda]["arp"] = arp_scores(scores)

 43%|████▎     | 250858/585414 [00:02<00:02, 123910.65it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 116661.91it/s]
 43%|████▎     | 250317/585414 [00:02<00:02, 123777.88it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 116324.29it/s]
 43%|████▎     | 251829/585414 [00:02<00:02, 124825.75it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:04<00:00, 117113.72it/s]
 43%|████▎     | 250704/585414 [00:02<00:02, 124158.99it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 116875.40it/s]
 42%|████▏     | 247741/585414 [00:02<00:02, 122791.54it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 115944.91it/s]
 43%|████▎     | 250453/585414 [00:02<00:02, 124132.22it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 116406.06it/s]
 43%|████▎     | 251655/585414 [00:02<00:02, 125090.66it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:04<00:00, 117095.85it/s]
 43%|████▎     | 250979/585414 [00:02<00:02, 124421.60it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 117038.65it/s]
 43%|████▎     | 250989/585414 [00:02<00:02, 123601.07it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 116075.11it/s]
 43%|████▎     | 249877/585414 [00:02<00:02, 123762.85it/s]

Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713
Could not find q012351539607713


100%|██████████| 585414/585414 [00:05<00:00, 115964.32it/s]


In [104]:
for l in results.keys():
    print(l, results[l]["tau"], results[l]["arp"]["bpref"])

0.501 0.07894661726059203 0.4369563935417115
0.5001 0.4268032114101738 0.4371686732635674
0.50001 0.8274175099065462 0.43726071550788775
0.500001 0.9377921482960695 0.43728103787875366
0.5000001 0.9530700020867431 0.43728103787875366
0.50000001 0.9545358759874004 0.43728103787875366
0.500000001 0.9547528153347746 0.43728103787875366
0.5000000001 0.9547562635990923 0.43728103787875366
0.50000000001 0.9547562635990923 0.43728103787875366
0.500000000001 0.9547562635990923 0.43728103787875366


In [105]:
# raw ranking
with open(os.path.join(base_path, runs_path, run_new_path)) as run_new:
    run = pytrec_eval.parse_run(run_new)
    scores = evaluator.evaluate(run)
    print(arp_scores(scores)["bpref"])

0.43728103787875366


# Filter Fuse

In [106]:
import numpy as np
from repro_eval.Evaluator import RpdEvaluator
from ranx import Run, fuse

In [107]:
def filter_and_fuse(run_recent, old_runs: list):
    qid_ranking_groups = run_recent.groupby('qid')
    qid_ranking_dict_recent = {qid: pd.Series(ranking['score'].values, ranking['docno']).to_dict() for qid, ranking in qid_ranking_groups}
    
    runs = [Run.from_dict(qid_ranking_dict_recent)]
    
    for run_old in old_runs:
        qid_ranking_groups = run_old.groupby('qid')
        qid_ranking_dict_old = {qid: pd.Series(ranking['score'].values, ranking['docno']).to_dict() for qid, ranking in qid_ranking_groups}
        for qid, ranking in qid_ranking_dict_old.items():
            docs_recent = qid_ranking_dict_recent.get(qid).keys()
            qid_ranking_dict_old[qid] = {docid: score for docid, score in ranking.items() if docid in docs_recent}
        runs.append(Run.from_dict(qid_ranking_dict_old))
    
    combined_run = fuse(runs = runs, method = "rrf")

    return combined_run

In [129]:
base_path = "data"
runs_path = "results/trec"
reranked_path = "results/filter_fuse"

run_new_path = "CIR_BM25_D-t3_T-t3"
run_old_path = "CIR_BM25_D-t2_T-t3_extended"

with open("data/LongEval/metadata.yml", "r") as yamlfile:
    config = yaml.load(yamlfile, Loader=yaml.FullLoader)

In [72]:
run_new = pt.io.read_results(os.path.join(base_path, runs_path, run_new_path))

old_runs = []
for name in os.listdir(os.path.join(base_path, runs_path)):
    if "T-t3" in name and name.endswith("extended"):
        run = pt.io.read_results(os.path.join(base_path, runs_path, name))
        old_runs.append(run)

In [113]:
# find core qids
topic_sets = []
for i in old_runs:
    topic_sets.append(set(i["qid"]))
topic_sets.append(set(run_new["qid"]))

core = set.intersection(*topic_sets)
print(len(core))

597


In [119]:
old_runs_cleaned = []
for run in old_runs:
    old_runs_cleaned.append(run[run["qid"].isin(core)])
    

In [120]:
old_runs = old_runs_cleaned

In [127]:
run_new_cleaned = run_new[run_new["qid"].isin(core)]

In [130]:
run_reranked = filter_and_fuse(run_new_cleaned, old_runs)

In [135]:
run_name = f'CIR_BM25_D-t3_T-t3_rr-ff'
run_reranked_path = os.path.join(base_path, reranked_path, run_name)
run_reranked.save(run_reranked_path,  kind='trec')

In [138]:
rpd_eval = RpdEvaluator(run_b_orig_path=os.path.join(base_path, runs_path, run_new_path), run_b_rep_path=run_reranked_path)

correlations = rpd_eval.ktau_union().get('baseline')
correlation_scores = [x for x in list(correlations.values()) if ~np.isnan(x)]

print("Avg. Kendall's tau: ", sum(correlation_scores) / len(correlation_scores))

with open(run_reranked_path) as run_reranked:
    run = pytrec_eval.parse_run(run_reranked)
    scores = evaluator.evaluate(run)
    print(arp_scores(scores)["bpref"])

Avg. Kendall's tau:  0.004242197558995947
0.4218068440668467
