refactor: change metric evaluator name adapt to eval_at (#1570)
* refactor: change metric evaluator name adapt to eval_at

* fix: test rankingevaluation_driver

* fix: test metric name eval

* fix: remove metric from evaluators

* fix: remove metric from mocks

Co-authored-by: Florian Hönicke <hoenicke.florian@gmail.com>
JoanFM and florian-hoenicke authored Jan 6, 2021
1 parent ba9e9ae commit 7200c65
Showing 6 changed files with 15 additions and 23 deletions.
6 changes: 5 additions & 1 deletion jina/drivers/evaluate.py
@@ -54,7 +54,11 @@ def _apply_all(
evaluation.value = self.exec_fn(self.extract(doc), self.extract(groundtruth))
if self._running_avg:
evaluation.value = self.exec.mean
- evaluation.op_name = self.exec.metric
+
+ if hasattr(self.exec, 'eval_at'):
+     evaluation.op_name = f'{self.exec.__class__.__name__}@{self.exec.eval_at}'
+ else:
+     evaluation.op_name = self.exec.__class__.__name__
evaluation.ref_id = groundtruth.id

def extract(self, doc: 'Document') -> Any:
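
For illustration, a minimal standalone sketch of the naming scheme introduced above (hypothetical stand-in classes, not part of this diff): an executor that exposes `eval_at` is reported as `ClassName@eval_at`, while anything else falls back to its bare class name.

    # Hypothetical stand-ins for evaluator executors; only the attribute shape matters here.
    class PrecisionEvaluator:
        def __init__(self, eval_at=None):
            self.eval_at = eval_at

    class CosineEvaluator:
        pass  # no eval_at attribute

    def op_name(executor):
        # mirrors the branch added to _apply_all above
        if hasattr(executor, 'eval_at'):
            return f'{executor.__class__.__name__}@{executor.eval_at}'
        return executor.__class__.__name__

    assert op_name(PrecisionEvaluator(eval_at=2)) == 'PrecisionEvaluator@2'
    assert op_name(CosineEvaluator()) == 'CosineEvaluator'
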
2 changes: 0 additions & 2 deletions jina/executors/evaluators/rank/precision.py
@@ -8,8 +8,6 @@ class PrecisionEvaluator(BaseRankingEvaluator):
It computes how many of the first given `eval_at` matches are found in the groundtruth
"""

- metric = 'Precision@N'
-
def __init__(self,
eval_at: Optional[int] = None,
*args, **kwargs):
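
As a rough illustration of the docstring above (a sketch only, not the repository's actual PrecisionEvaluator implementation), precision at `eval_at` counts how many of the first `eval_at` matches appear in the groundtruth:

    # Sketch of precision@k; the real evaluator may handle edge cases differently.
    def precision_at(matches, groundtruth, eval_at):
        top = matches[:eval_at]
        if not top:
            return 0.0
        hits = len(set(top) & set(groundtruth))
        return hits / len(top)

    print(precision_at(['a', 'b', 'c'], ['a', 'c'], eval_at=2))  # 0.5
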
2 changes: 0 additions & 2 deletions jina/executors/evaluators/rank/recall.py
@@ -8,8 +8,6 @@ class RecallEvaluator(BaseRankingEvaluator):
It computes how many of the first given `eval_at` groundtruth are found in the matches
"""

- metric = 'Recall@N'
-
def __init__(self,
eval_at: Optional[int] = None,
*args, **kwargs):
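
Similarly, a sketch of what the RecallEvaluator docstring describes (illustrative only, not the actual implementation): the fraction of groundtruth items recovered within the first `eval_at` matches.

    # Sketch of recall@k; the real evaluator may handle edge cases differently.
    def recall_at(matches, groundtruth, eval_at):
        if not groundtruth:
            return 0.0
        found = len(set(matches[:eval_at]) & set(groundtruth))
        return found / len(groundtruth)

    print(recall_at(['a', 'b', 'c'], ['a', 'c'], eval_at=2))  # 0.5
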
4 changes: 0 additions & 4 deletions tests/unit/drivers/test_craftevaluation_driver.py
@@ -12,10 +12,6 @@

class MockDiffEvaluator(BaseEvaluator):

- @property
- def metric(self):
-     return 'MockDiffEvaluator'
-
def evaluate(self, actual: Any, desired: Any, *args, **kwargs) -> float:
return abs(len(actual) - len(desired))

4 changes: 0 additions & 4 deletions tests/unit/drivers/test_encodingevaluation_driver.py
@@ -12,10 +12,6 @@ class MockDiffEvaluator(BaseEmbeddingEvaluator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

- @property
- def metric(self):
-     return 'MockDiffEvaluator'
-
def evaluate(self, actual: 'np.array', desired: 'np.array', *args, **kwargs) -> float:
""""
:param actual: the embedding of the document (resulting from an Encoder)
20 changes: 10 additions & 10 deletions tests/unit/drivers/test_rankingevaluation_driver.py
@@ -23,7 +23,7 @@ def expect_parts(self) -> int:
class RunningAvgRankEvaluateDriver(RankEvaluateDriver):

def __init__(self, field: str, *args, **kwargs):
- super().__init__(field, running_avg=True, *args, **kwargs)
+ super().__init__(field, runining_avg=True, *args, **kwargs)

@property
def exec_fn(self):
@@ -40,7 +40,7 @@ def simple_rank_evaluate_driver(field):


@pytest.fixture
- def ruuningavg_rank_evaluate_driver(field):
+ def runningavg_rank_evaluate_driver(field):
return RunningAvgRankEvaluateDriver(field)


@@ -66,25 +66,25 @@ def add_matches(doc: jina_pb2.DocumentProto, num_matches):

@pytest.mark.parametrize('field', ['tags__id', 'score__value'])
def test_ranking_evaluate_simple_driver(simple_rank_evaluate_driver,
- ground_truth_pairs):
+                                         ground_truth_pairs):
simple_rank_evaluate_driver.attach(executor=PrecisionEvaluator(eval_at=2), runtime=None)
simple_rank_evaluate_driver._apply_all(ground_truth_pairs)
for pair in ground_truth_pairs:
doc = pair.doc
assert len(doc.evaluations) == 1
- assert doc.evaluations[0].op_name == 'Precision@N'
+ assert doc.evaluations[0].op_name == 'PrecisionEvaluator@2'
assert doc.evaluations[0].value == 1.0


@pytest.mark.parametrize('field', ['tags__id', 'score__value'])
- def test_ranking_evaluate_ruuningavg_driver(ruuningavg_rank_evaluate_driver,
-                                             ground_truth_pairs):
-     ruuningavg_rank_evaluate_driver.attach(executor=PrecisionEvaluator(eval_at=2), runtime=None)
-     ruuningavg_rank_evaluate_driver._apply_all(ground_truth_pairs)
+ def test_ranking_evaluate_runningavg_driver(runningavg_rank_evaluate_driver,
+                                             ground_truth_pairs):
+     runningavg_rank_evaluate_driver.attach(executor=PrecisionEvaluator(eval_at=2), runtime=None)
+     runningavg_rank_evaluate_driver._apply_all(ground_truth_pairs)
for pair in ground_truth_pairs:
doc = pair.doc
assert len(doc.evaluations) == 1
- assert doc.evaluations[0].op_name == 'Precision@N'
+ assert doc.evaluations[0].op_name == 'PrecisionEvaluator@2'
assert doc.evaluations[0].value == 1.0


@@ -152,7 +152,7 @@ def test_ranking_evaluate_driver_matches_in_chunks(simple_chunk_rank_evaluate_dr
assert len(doc.chunks) == 1
chunk = doc.chunks[0]
assert len(chunk.evaluations) == 1 # evaluation done at chunk level
- assert chunk.evaluations[0].op_name == 'Precision@N'
+ assert chunk.evaluations[0].op_name == 'PrecisionEvaluator@2'
assert chunk.evaluations[0].value == 1.0


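
The running-average path exercised by these tests relies on the driver reading `self.exec.mean` (see the evaluate.py hunk above). A simplified sketch of that behaviour, assuming the executor keeps a running mean of the values it has evaluated (the real BaseEvaluator tracks this internally):

    # Simplified running-mean bookkeeping, for illustration only.
    class RunningMean:
        def __init__(self):
            self.total = 0.0
            self.count = 0

        def update(self, value):
            self.total += value
            self.count += 1

        @property
        def mean(self):
            return self.total / self.count if self.count else 0.0

    rm = RunningMean()
    for value in (1.0, 1.0, 1.0):
        rm.update(value)
    print(rm.mean)  # 1.0, matching the value the running-avg tests assert per document
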
