Adds trec_utils.evaluate, a function that calls trec_eval and parses its results, and a mechanism to build a qrel and run in-memory and return trec_eval measures.
Christophe Van Gysel committed Jan 20, 2017
1 parent 2882114 commit ceb1fa7
Showing 1 changed file with 93 additions and 22 deletions.
115 changes: 93 additions & 22 deletions py/cvangysel/trec_utils.py
@@ -11,6 +11,7 @@
import re
import scipy.stats
import shutil
import subprocess
import sys

measures = {
@@ -626,18 +627,43 @@ def compute_significance(first_trec_eval, second_trec_eval, measures):
    return significance_results


def parse_trec_run(f, return_score=False,
                   ignore_duplicates=False,
                   ignore_parse_errors=False):
    run = collections.defaultdict(dict)

    for idx, line in enumerate(f):
        if not line.strip():  # Skip empty lines.
            continue

        try:
            data = line.strip().split()

            # In some old runs, the 7th field contained
            # relevance feedback labels.
            if len(data) not in (6, 7):
                raise ValueError()

            topic_id, _, candidate, ranking, score, _ = data[:6]
        except ValueError as e:
            logging.error('Encountered parsing error at line %d (%s).',
                          idx + 1, line.strip())

            if not ignore_parse_errors:
                raise e

            continue  # Skip the malformed line.

        if not ignore_duplicates:
            assert candidate not in run[topic_id], (
                topic_id, candidate)
        elif candidate in run[topic_id]:
            logging.warning('Candidate %s occurs at least twice '
                            'in topic %s at rank %s.',
                            candidate, topic_id, ranking)

        if not return_score:
            run[topic_id][candidate] = float(ranking)
        else:
            run[topic_id][candidate] = float(score)

    return run

@@ -667,12 +693,30 @@ def write_run(self, model_name, out_f,
        write_run(model_name, data, out_f, max_objects_per_query)


def write_run(model_name, data, out_f,
              max_objects_per_query=sys.maxsize,
              skip_sorting=False):
    return write_ranking(
        model_name, data, out_f, max_objects_per_query, skip_sorting,
        '{subject} Q0 {object} {rank} {relevance:.40f} {model_name}\n')


def write_qrel(model_name, data, out_f,
               max_objects_per_query=sys.maxsize,
               skip_sorting=False):
    return write_ranking(
        None, data, out_f, max_objects_per_query, skip_sorting,
        '{subject} 0 {object} {relevance}\n')


class OnlineTRECRun(object):

    def __init__(self, name, rank_cutoff=sys.maxsize, write_fn=write_run):
        self.name = name
        self.rank_cutoff = rank_cutoff

        self.write_fn = write_fn

        self.tmp_file = tempfile.NamedTemporaryFile(
            mode='w', delete=False)

@@ -682,10 +726,10 @@ def __init__(self, name, rank_cutoff=sys.maxsize):
    def add_ranking(self, subject_id, object_assessments):
        assert self.tmp_file

        self.write_fn(self.name,
                      data={subject_id: object_assessments},
                      out_f=self.tmp_file,
                      max_objects_per_query=self.rank_cutoff)

    def close_and_write(self, out_path, overwrite=True):
        assert self.tmp_file
@@ -702,10 +746,39 @@ def close_and_write(self, out_path, overwrite=True):

        os.remove(tmp_file_path)

    def close_and_evaluate(self, qrel):
        assert self.tmp_file

        assert isinstance(qrel, OnlineTRECRun) and qrel.write_fn == write_qrel
        assert qrel.tmp_file

        qrel_tmp_file_path = qrel.tmp_file.name
        run_tmp_file_path = self.tmp_file.name

        # Flush the in-memory run and qrel to disk for trec_eval.
        qrel.tmp_file.close()
        self.tmp_file.close()

        trec_eval = evaluate(run_tmp_file_path, qrel_tmp_file_path)

        os.remove(qrel_tmp_file_path)
        os.remove(run_tmp_file_path)

        return trec_eval


def evaluate(run_path, qrel_path):
    assert os.path.exists(run_path)
    assert os.path.exists(qrel_path)

    # -q reports per-query results; -m all_trec computes all measures.
    command = ['trec_eval', '-q', '-m', 'all_trec', qrel_path, run_path]
    out = subprocess.check_output(command)

    return parse_trec_eval(out.decode('ascii').split('\n'))


def write_ranking(model_name, data, out_f,
                  max_objects_per_query,
                  skip_sorting,
                  format):
    """
    Write a run to an output file.
@@ -745,11 +818,9 @@ def write_run(model_name, data, out_f,
            if isinstance(object_id, bytes):
                object_id = object_id.decode('utf8')

            out_f.write(format.format(
                subject=subject_id,
                object=object_id,
                rank=rank + 1,
                relevance=relevance,
                model_name=model_name))
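
A minimal usage sketch of the mechanism this commit adds, assuming trec_eval is on the PATH and the py/ directory is importable; the topic ID, document IDs, and the (score, document) pair format passed to add_ranking are illustrative assumptions based on write_ranking's sort order:

from cvangysel import trec_utils

# Build a run in-memory; pairs are assumed to be (score, document).
run = trec_utils.OnlineTRECRun('example_model')
run.add_ranking('Q1', [(0.9, 'DOC1'), (0.4, 'DOC2')])

# Re-use OnlineTRECRun with write_qrel to build the qrel in-memory.
qrel = trec_utils.OnlineTRECRun('qrel', write_fn=trec_utils.write_qrel)
qrel.add_ranking('Q1', [(1, 'DOC1'), (0, 'DOC2')])

# Shells out to trec_eval, parses its measures, and removes both
# temporary files.
measures = run.close_and_evaluate(qrel)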
