In [1]:
import argparse
import os
import numpy as np
import scipy.stats
import statistics
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from operator import itemgetter

plt.style.use('ggplot')

from msmarco_compare import compute_metrics_from_files 

In [2]:
def load_metrics(file):
    """Parse per-query ``trec_eval -q`` output into nested dicts.

    Each non-empty line must hold three tab-separated fields: metric name,
    query id and score. Rows whose query id is ``'all'`` (presumably the
    run-wide summary rows trec_eval emits) are skipped. Blank lines are ignored.

    Args:
        file: path to a file containing ``trec_eval -q`` output.

    Returns:
        dict mapping metric name -> {query id (str): score (float)}.

    Raises:
        ValueError: if a non-empty line does not have exactly three fields,
            or a per-query score is not a number.
    """
    metrics = {}
    with open(file, 'r') as f:
        for line in f:
            # Tolerate blank/trailing lines instead of failing the 3-way unpack.
            if not line.strip():
                continue
            metric, qid, score = (field.strip() for field in line.split('\t'))
            if qid == 'all':
                continue
            metrics.setdefault(metric, {})[qid] = float(score)
    return metrics

In [3]:
def plot(all_results, ymin=-1, ymax=1, output_path=".", metric=None):
    """Save a bar chart of per-topic score differences as a PDF.

    Args:
        all_results: sequence of (topic id, diff) pairs. It is NOT modified;
            a copy sorted by diff (largest first) is plotted.
        ymin, ymax: y-axis limits.
        output_path: directory in which the PDF is written.
        metric: metric name used in the title, y label and file name. When
            omitted, falls back to the notebook-level ``metric`` variable so
            existing callers keep working.

    Returns:
        Path of the written file, ``<output_path>/per_query_<metric>.pdf``.
    """
    if metric is None:
        # Backward compatibility: this function used to read the global implicitly.
        metric = globals().get('metric', 'metric')

    ranked = sorted(all_results, key=itemgetter(1), reverse=True)
    positions = list(range(len(ranked)))
    diffs = [float(diff) for _, diff in ranked]

    fig, ax = plt.subplots(1, 1, figsize=(16, 3))
    # Centre each bar on its tick so the topic label sits under its own bar.
    ax.bar(positions, diffs, width=0.6, align='center')
    ax.set_xticks(positions)
    # str() keeps non-numeric topic ids working; fontsize is passed by keyword
    # because passing a fontdict positionally is deprecated in matplotlib.
    ax.set_xticklabels([str(topic) for topic, _ in ranked], fontsize=4, rotation='vertical')
    ax.grid(True)
    ax.set_title("Per-topic analysis on {}".format(metric))
    ax.set_xlabel('Topics')
    ax.set_ylabel('{} Diff'.format(metric))
    ax.set_ylim(ymin, ymax)

    output_fn = os.path.join(output_path, 'per_query_{}.pdf'.format(metric))
    fig.savefig(output_fn, bbox_inches='tight', format='pdf')
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate them
    return output_fn

In [11]:
# Run configuration: edit these values to compare a different pair of runs.
base = "runs/BM25.txt"          # baseline run file
comp = "runs/word2vec.txt"      # run compared against the baseline
qrels = "qrels.dl19-doc.txt"    # relevance judgements passed to the evaluator
metric = "map"                  # one trec_eval measure name, e.g. "map" or "P.10"
msmarco = False                 # True: score with msmarco_compare instead of trec_eval
ymin = -1                       # y-axis limits of the per-topic diff plot
ymax = 1

import subprocess

# Location of the trec_eval binary. Set the TREC_EVAL environment variable to
# override this machine-specific default instead of editing the notebook.
TREC_EVAL = os.environ.get('TREC_EVAL', 'G:/python/ir/eval/trec_eval.9.0.4/trec_eval')

if msmarco:
    base_all, base_metrics = compute_metrics_from_files(qrels, base, per_query_score=True)
    comp_all, comp_metrics = compute_metrics_from_files(qrels, comp, per_query_score=True)
else:
    # Run trec_eval once per run and capture its per-query output in a file.
    # An argument list (no shell) passes paths containing spaces verbatim, and
    # check=True raises CalledProcessError when trec_eval fails, instead of
    # continuing with an empty or partial eval file.
    for run_path, eval_file in ((base, 'eval.base'), (comp, 'eval.comp')):
        with open(eval_file, 'w') as eval_out:
            subprocess.run(
                [TREC_EVAL, '-q', '-M1000', '-m', metric, qrels, run_path],
                stdout=eval_out,
                check=True,
            )

    base_metrics = load_metrics('eval.base')
    comp_metrics = load_metrics('eval.comp')

# trec_eval expects something like 'P.10' on the command line but outputs 'P_10'
metric = metric.replace('.', '_')

# Pair the two runs topic by topic. Only topics scored in both runs are
# compared; each pair's diff (comp - base) is printed and accumulated.
keys = []           # topic ids present in both runs, in base iteration order
all_results = []    # (topic id, diff) pairs, later handed to plot()
num_better = num_worse = num_unchanged = 0
biggest_gain, biggest_gain_topic = 0, ''
biggest_loss, biggest_loss_topic = 0, ''

if msmarco:
    # In MS MARCO mode the per-query scores are looked up under "MRR@10".
    metric = "MRR@10"

base_per_topic = base_metrics[metric]
for key in base_per_topic:
    if key in comp_metrics[metric]:
        keys.append(key)
        base_score = base_per_topic[key]
        comp_score = comp_metrics[metric][key]
        diff = comp_score - base_score

        # Arbitrary tolerance: diffs within [-0.01, 0.01] count as "unchanged".
        if diff > 0.01:
            num_better += 1
        elif diff < -0.01:
            num_worse += 1
        else:
            num_unchanged += 1

        # Strict comparisons: on ties, the earliest topic keeps the record.
        if diff > biggest_gain:
            biggest_gain, biggest_gain_topic = diff, key
        if diff < biggest_loss:
            biggest_loss, biggest_loss_topic = diff, key

        all_results.append((key, diff))
        print(key, f'{base_score:.4}', f'{comp_score:.4}', f'{diff:.4}', sep='\t')

# Extract the paired scores (same topic order as `keys`)
a = [base_metrics[metric][k] for k in keys]
b = [comp_metrics[metric][k] for k in keys]

# Paired t-test. scipy computes the statistic from the differences a - b,
# so a negative t-statistic means the comparison run scored higher on average.
(tstat, pvalue) = scipy.stats.ttest_rel(a, b)
print(f'base mean: {np.mean(a):.4}')
print(f'comp mean: {np.mean(b):.4}')
print(f't-statistic: {tstat:.6}, p-value: {pvalue:.6}')
print(f'better (diff > 0.01): {num_better:>3}')
print(f'worse (diff < -0.01): {num_worse:>3}')
print(f'(mostly) unchanged  : {num_unchanged:>3}')
print(f'biggest gain: {biggest_gain:.4} (topic {biggest_gain_topic})')
print(f'biggest loss: {biggest_loss:.4} (topic {biggest_loss_topic})')

plot(all_results, ymin=ymin, ymax=ymax)

1037798	0.0	0.0	0.0
104861	0.0091	0.0091	0.0
1063750	0.0076	0.0074	-0.0002
1103812	0.0292	0.0292	0.0
1106007	0.0041	0.0041	0.0
1110199	0.0711	0.0711	0.0
1112341	0.0056	0.007	0.0014
1113437	0.0067	0.0067	0.0
1114646	0.1019	0.1019	0.0
1114819	0.0158	0.0158	0.0
1115776	0.6092	0.5687	-0.0405
1117099	0.011	0.0125	0.0015
1121402	0.1818	0.1818	0.0
1124210	0.0362	0.0362	0.0
1129237	0.0614	0.0526	-0.0088
1132213	0.3366	0.3366	0.0
1133167	0.0503	0.0503	0.0
130510	0.2063	0.2063	0.0
131843	0.24	0.28	0.04
146187	0.177	0.177	0.0
148538	0.0257	0.0257	0.0
156493	0.0662	0.0662	0.0
182539	0.1626	0.1626	0.0
183378	0.0241	0.0241	0.0
19335	0.1016	0.1016	0.0
207786	0.085	0.085	0.0
264014	0.0565	0.0565	0.0
287683	0.3278	0.3417	0.0139
359349	0.0546	0.0546	0.0
405717	0.1124	0.1124	0.0
443396	0.0	0.0	0.0
451602	0.0059	0.0059	0.0
47923	0.0071	0.0071	0.0
489204	0.0029	0.0029	0.0
490595	0.1293	0.1293	0.0
527433	0.0683	0.0683	0.0
573724	0.14	0.1448	0.0048
833860	0.0348	0.0348	0.0
855410	1.0	1.0	0.0
87181	0.0421	0.0

  ax.set_xticklabels([int(ele[0]) for ele in all_results], {'fontsize': 4}, rotation='vertical')
