Skip to content

Commit

Permalink
Improve reporting of BEIR results on command line (#1888)
Browse files Browse the repository at this point in the history
> python -m pyserini.2cr.beir --all --display-commands --dry-run
  • Loading branch information
lintool committed May 12, 2024
1 parent d27b944 commit bf68fc5
Showing 1 changed file with 61 additions and 18 deletions.
79 changes: 61 additions & 18 deletions pyserini/2cr/beir.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ def run_conditions(args):

for expected in datasets['scores']:
for metric in expected:
if not args.skip_eval:
if not args.skip_eval and not args.dry_run:
if not os.path.exists(runfile):
continue

score = float(run_eval_and_return_metric(metric, f'beir-v1.0.0-{dataset}-test',
trec_eval_metric_definitions[metric], runfile))
if math.isclose(score, float(expected[metric])):
Expand All @@ -250,6 +250,7 @@ def run_conditions(args):

top_level_sums = defaultdict(lambda: defaultdict(float))
cqadupstack_sums = defaultdict(lambda: defaultdict(float))
cqa_scores = defaultdict(lambda: defaultdict(float))
final_scores = defaultdict(lambda: defaultdict(float))

# Compute the running sums to compute the final mean scores
Expand All @@ -267,30 +268,72 @@ def run_conditions(args):
for metric in metrics:
# Compute mean over cqa sub-collections first
cqa_score = cqadupstack_sums[model][metric] / 12
cqa_scores[model][metric] = cqa_score
# Roll cqa scores into final overall mean
final_score = (top_level_sums[model][metric] + cqa_score) / 18
final_scores[model][metric] = final_score

print(' ' * 30 + 'BM25-flat' + ' ' * 10 + 'BM25-mf' + ' ' * 13 + 'SPLADE' + ' ' * 11 + 'Contriever' + ' ' * 5 + 'Contriever-msmarco' + ' ' * 2 + 'BGE-base-en-v1.5' + ' ' * 5 + 'cohere-embed-english-v3.0')
print(' ' * 26 + 'nDCG@10 R@100 ' * 6)
cqa_output_flag = False

print(' ' * 30 + 'BM25-flat' + ' ' * 10 + 'BM25-mf' + ' ' * 13 + 'SPLADE' + ' ' * 11 + 'Contriever' + ' ' * 5 + 'Contriever-msmarco' + ' ' * 2 + 'BGE-base-en-v1.5' + ' ' * 4 + 'cohere-en-v3.0')
print(' ' * 26 + 'nDCG@10 R@100 ' * 7)
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
for dataset in beir_keys:
# The first encounter of 'cqa', print out the average.
if dataset.startswith('cqa') and not cqa_output_flag:
print('cqa' + ' ' * 22 + f'{cqa_scores["bm25-flat"]["nDCG@10"]:8.3f}{cqa_scores["bm25-flat"]["R@100"]:8.3f} ' +
f'{cqa_scores["bm25-multifield"]["nDCG@10"]:8.3f}{cqa_scores["bm25-multifield"]["R@100"]:8.3f} ' +
f'{cqa_scores["splade-pp-ed"]["nDCG@10"]:8.3f}{cqa_scores["splade-pp-ed"]["R@100"]:8.3f} ' +
f'{cqa_scores["contriever"]["nDCG@10"]:8.3f}{cqa_scores["contriever"]["R@100"]:8.3f} ' +
f'{cqa_scores["contriever-msmarco"]["nDCG@10"]:8.3f}{cqa_scores["contriever-msmarco"]["R@100"]:8.3f} ' +
f'{cqa_scores["bge-base-en-v1.5"]["nDCG@10"]:8.3f}{cqa_scores["bge-base-en-v1.5"]["R@100"]:8.3f} ' +
f'{cqa_scores["cohere-embed-english-v3.0"]["nDCG@10"]:8.3f}{cqa_scores["cohere-embed-english-v3.0"]["R@100"]:8.3f}')
cqa_output_flag = True
continue

# Skip all other cqa sub-collections.
if dataset.startswith('cqa'):
continue

print(f'{dataset:25}' +
f'{table[dataset]["bm25-flat"]["nDCG@10"]:8.3f}{table[dataset]["bm25-flat"]["R@100"]:8.3f} ' +
f'{table[dataset]["bm25-multifield"]["nDCG@10"]:8.3f}{table[dataset]["bm25-multifield"]["R@100"]:8.3f} ' +
f'{table[dataset]["splade-pp-ed"]["nDCG@10"]:8.3f}{table[dataset]["splade-pp-ed"]["R@100"]:8.3f} ' +
f'{table[dataset]["contriever"]["nDCG@10"]:8.3f}{table[dataset]["contriever"]["R@100"]:8.3f} ' +
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.3f}{table[dataset]["contriever-msmarco"]["R@100"]:8.3f} ' +
f'{table[dataset]["bge-base-en-v1.5"]["nDCG@10"]:8.3f}{table[dataset]["bge-base-en-v1.5"]["R@100"]:8.3f} ' +
f'{table[dataset]["cohere-embed-english-v3.0"]["nDCG@10"]:8.3f}{table[dataset]["cohere-embed-english-v3.0"]["R@100"]:8.3f}')
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
print('avg' + ' ' * 22 + f'{final_scores["bm25-flat"]["nDCG@10"]:8.3f}{final_scores["bm25-flat"]["R@100"]:8.3f} ' +
f'{final_scores["bm25-multifield"]["nDCG@10"]:8.3f}{final_scores["bm25-multifield"]["R@100"]:8.3f} ' +
f'{final_scores["splade-pp-ed"]["nDCG@10"]:8.3f}{final_scores["splade-pp-ed"]["R@100"]:8.3f} ' +
f'{final_scores["contriever"]["nDCG@10"]:8.3f}{final_scores["contriever"]["R@100"]:8.3f} ' +
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.3f}{final_scores["contriever-msmarco"]["R@100"]:8.3f} ' +
f'{final_scores["bge-base-en-v1.5"]["nDCG@10"]:8.3f}{final_scores["bge-base-en-v1.5"]["R@100"]:8.3f} ' +
f'{final_scores["cohere-embed-english-v3.0"]["nDCG@10"]:8.3f}{final_scores["cohere-embed-english-v3.0"]["R@100"]:8.3f}')

print('\n')
# Separately print out all the cqa sub-collections.
for dataset in beir_keys:
if not dataset.startswith('cqa'):
continue

print(f'{dataset:25}' +
f'{table[dataset]["bm25-flat"]["nDCG@10"]:8.4f}{table[dataset]["bm25-flat"]["R@100"]:8.4f} ' +
f'{table[dataset]["bm25-multifield"]["nDCG@10"]:8.4f}{table[dataset]["bm25-multifield"]["R@100"]:8.4f} ' +
f'{table[dataset]["splade-pp-ed"]["nDCG@10"]:8.4f}{table[dataset]["splade-pp-ed"]["R@100"]:8.4f} ' +
f'{table[dataset]["contriever"]["nDCG@10"]:8.4f}{table[dataset]["contriever"]["R@100"]:8.4f} ' +
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}{table[dataset]["contriever-msmarco"]["R@100"]:8.4f} ' +
f'{table[dataset]["bge-base-en-v1.5"]["nDCG@10"]:8.4f}{table[dataset]["bge-base-en-v1.5"]["R@100"]:8.4f} ' +
f'{table[dataset]["cohere-embed-english-v3.0"]["nDCG@10"]:8.4f}{table[dataset]["cohere-embed-english-v3.0"]["R@100"]:8.4f}')
f'{table[dataset]["bm25-flat"]["nDCG@10"]:8.3f}{table[dataset]["bm25-flat"]["R@100"]:8.3f} ' +
f'{table[dataset]["bm25-multifield"]["nDCG@10"]:8.3f}{table[dataset]["bm25-multifield"]["R@100"]:8.3f} ' +
f'{table[dataset]["splade-pp-ed"]["nDCG@10"]:8.3f}{table[dataset]["splade-pp-ed"]["R@100"]:8.3f} ' +
f'{table[dataset]["contriever"]["nDCG@10"]:8.3f}{table[dataset]["contriever"]["R@100"]:8.3f} ' +
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.3f}{table[dataset]["contriever-msmarco"]["R@100"]:8.3f} ' +
f'{table[dataset]["bge-base-en-v1.5"]["nDCG@10"]:8.3f}{table[dataset]["bge-base-en-v1.5"]["R@100"]:8.3f} ' +
f'{table[dataset]["cohere-embed-english-v3.0"]["nDCG@10"]:8.3f}{table[dataset]["cohere-embed-english-v3.0"]["R@100"]:8.3f}')
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
print('avg' + ' ' * 22 + f'{final_scores["bm25-flat"]["nDCG@10"]:8.4f}{final_scores["bm25-flat"]["R@100"]:8.4f} ' +
f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
f'{final_scores["splade-pp-ed"]["nDCG@10"]:8.4f}{final_scores["splade-pp-ed"]["R@100"]:8.4f} ' +
f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f} ' +
f'{final_scores["bge-base-en-v1.5"]["nDCG@10"]:8.4f}{final_scores["bge-base-en-v1.5"]["R@100"]:8.4f} ' +
f'{final_scores["cohere-embed-english-v3.0"]["nDCG@10"]:8.4f}{final_scores["cohere-embed-english-v3.0"]["R@100"]:8.4f}')
print('avg' + ' ' * 22 + f'{cqa_scores["bm25-flat"]["nDCG@10"]:8.3f}{cqa_scores["bm25-flat"]["R@100"]:8.3f} ' +
f'{cqa_scores["bm25-multifield"]["nDCG@10"]:8.3f}{cqa_scores["bm25-multifield"]["R@100"]:8.3f} ' +
f'{cqa_scores["splade-pp-ed"]["nDCG@10"]:8.3f}{cqa_scores["splade-pp-ed"]["R@100"]:8.3f} ' +
f'{cqa_scores["contriever"]["nDCG@10"]:8.3f}{cqa_scores["contriever"]["R@100"]:8.3f} ' +
f'{cqa_scores["contriever-msmarco"]["nDCG@10"]:8.3f}{cqa_scores["contriever-msmarco"]["R@100"]:8.3f} ' +
f'{cqa_scores["bge-base-en-v1.5"]["nDCG@10"]:8.3f}{cqa_scores["bge-base-en-v1.5"]["R@100"]:8.3f} ' +
f'{cqa_scores["cohere-embed-english-v3.0"]["nDCG@10"]:8.3f}{cqa_scores["cohere-embed-english-v3.0"]["R@100"]:8.3f}')

end = time.time()

Expand Down

0 comments on commit bf68fc5

Please sign in to comment.