feat: use da.evaluate in Evaluator + configurable metrics (#352)
gmastrapas committed Jan 25, 2022
1 parent 36ad4a2 commit 82238b1
Showing 4 changed files with 257 additions and 184 deletions.
42 changes: 35 additions & 7 deletions docs/components/tuner/evaluation.md
@@ -8,7 +8,7 @@ or integrated in the training loop via the [evaluation callback](#using-the-eval

## Using the evaluator

The evaluator can be used standalone in order to compute the evaluation metrics on a sequence
The evaluator can be used standalone in order to compute evaluation metrics on a sequence
of documents:
```python
from finetuner.tuner.evaluation import Evaluator
@@ -27,14 +27,14 @@ The `query_data` (or eval data) are the documents that will be evaluated. They c
{term}`class dataset` or the {term}`session dataset` format. They should contain ground truths, in the form of
matches (`doc.matches`) when using session format and in the form of labels when using class format.

If an embedding model is given, the query docs are embedded, otherwise they are assumed to carry
If an embedding model is given, both query and index docs are embedded, otherwise they are assumed to carry
representations.

The `index_data` (or catalog) is an optional argument that defines the dataset against which the
query docs are matched. If not provided, query docs are matched against themselves.

The `evaluate()` method returns the computed metrics as a dictionary, mapping metric names to values.
The computed metrics are the following:
By default, the following metrics are computed:

- Precision
- Recall
@@ -46,7 +46,7 @@ The computed metrics are the following:
- DCG
- NDCG

More information on Information Retrieval metrics can be found
More details on these Information Retrieval metrics can be found
[here](https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)).

Let's consider an example. First, let's build a model that receives a NumPy array with a single element as input, and
@@ -63,7 +63,6 @@ class EmbeddingModel(torch.nn.Module):
        output_shape: (bs, 10)
        """
        return inputs.repeat(1, 10)

```
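Put together, the model is just a module whose `forward` repeats the single input feature ten times to form a 10-dimensional embedding. A self-contained sketch of the full class is shown below; the `forward` signature and the input-shape note are assumptions inferred from the surrounding description:

```python
import torch


class EmbeddingModel(torch.nn.Module):
    def forward(self, inputs):
        """
        input_shape: (bs, 1)
        output_shape: (bs, 10)
        """
        # repeat the single input feature 10 times to form a 10-dimensional embedding
        return inputs.repeat(1, 10)
```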

Now let's create some example data. We will divide it into two sets: the evaluation data and the index data. The
@@ -134,7 +133,7 @@ evaluator = Evaluator(query_data, index_data, embed_model)
metrics = evaluator.evaluate(limit=1, distance='euclidean')
print(metrics)
```
```
```json
{
"r_precision": 1.0,
"precision_at_k": 1.0,
@@ -153,7 +152,7 @@ When evaluating with a bigger matching limit, we expect precision to drop:
metrics = evaluator.evaluate(limit=2, distance='euclidean')
print(metrics)
```
```
```json
{
"r_precision": 1.0,
"precision_at_k": 0.5,
@@ -167,6 +166,35 @@ print(metrics)
}
```
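The drop is easy to verify with the metric functions from `docarray.math.evaluation` that the custom-metrics example below also uses. A minimal sketch, assuming (as in the example above) that each query has exactly one relevant document and that it is ranked first:

```python
from docarray.math.evaluation import precision_at_k, recall_at_k

# relevance of the top-2 matches for a single query: first match relevant, second not
binary_relevance = [1, 0]
max_rel = 1  # total number of relevant documents for this query

print(precision_at_k(binary_relevance, k=2))        # 1/2 = 0.5, precision halves
print(recall_at_k(binary_relevance, max_rel, k=2))  # 1/1 = 1.0, recall is unchanged
```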

To customize the computed metrics, an optional `metrics` argument can be passed to the
Evaluator constructor. It maps metric names to tuples of metric functions and their
keyword arguments. For example:

```python
from docarray.math.evaluation import precision_at_k, recall_at_k

def f_score_at_k(binary_relevance, max_rel, k=None, beta=1.0):
    precision = precision_at_k(binary_relevance, k=k)
    recall = recall_at_k(binary_relevance, max_rel, k=k)
    return ((1 + beta**2) * precision * recall) / (beta**2 * precision + recall)

metrics = {
    'precision@5': (precision_at_k, {'k': 5}),
    'recall@5': (recall_at_k, {'k': 5}),
    'f1score@5': (f_score_at_k, {'k': 5, 'beta': 1.0})
}

evaluator = Evaluator(query_data, index_data, embed_model, metrics=metrics)
print(evaluator.evaluate(limit=2, distance='euclidean'))
```
```json
{
"precision@5": 0.2,
"recall@5": 1.0,
"f1score@5": 0.33333333333333337
}
```
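Note that if a query retrieves no relevant documents at the cutoff, both precision and recall are zero and the `f_score_at_k` above divides by zero. A defensively written variant (an illustration only, not part of this commit) could guard against that:

```python
from docarray.math.evaluation import precision_at_k, recall_at_k

def safe_f_score_at_k(binary_relevance, max_rel, k=None, beta=1.0):
    precision = precision_at_k(binary_relevance, k=k)
    recall = recall_at_k(binary_relevance, max_rel, k=k)
    denominator = beta**2 * precision + recall
    if denominator == 0.0:
        # no relevant documents among the top-k results
        return 0.0
    return ((1 + beta**2) * precision * recall) / denominator
```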

## Using the evaluation callback

The evaluator can be handy for computing metrics in an evaluation script, or following a `finetuner.fit`
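As a hedged sketch of how the callback introduced in this commit could be wired into training (the import path and the exact `finetuner.fit` signature are assumptions, and `train_data` is a hypothetical training set):

```python
import finetuner
from finetuner.tuner.callback import EvaluationCallback  # import path is an assumption

callback = EvaluationCallback(
    query_data,
    index_data=index_data,
    metrics=metrics,        # the custom-metrics dictionary from above; None keeps the defaults
    limit=2,
    distance='euclidean',
)

finetuner.fit(
    embed_model,
    train_data=train_data,  # hypothetical training data
    epochs=5,
    callbacks=[callback],   # metrics are computed at the end of every epoch
)
```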
42 changes: 30 additions & 12 deletions finetuner/tuner/callback/evaluation.py
@@ -1,5 +1,5 @@
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

from ... import embed
from ..evaluation import Evaluator
@@ -13,29 +13,44 @@

class EvaluationCallback(BaseCallback):
"""
A callback that uses the Evaluator to calculate IR metrics at the end of each epoch. When used
with other callbacks that rely on metrics, like checkpoints and logging, this callback should be
defined first, so that it precedes in execution.
A callback that uses the Evaluator to calculate IR metrics at the end of each epoch.
When used with other callbacks that rely on metrics, like checkpoints and logging,
this callback should be defined first, so that it precedes in execution.
"""

    def __init__(
        self,
        query_data: 'DocumentArray',
        index_data: Optional['DocumentArray'] = None,
        metrics: Optional[
            Dict[str, Tuple[Callable[..., float], Dict[str, Any]]]
        ] = None,
        exclude_self: bool = True,
        limit: int = 20,
        distance: str = 'cosine',
        num_workers: int = 1,
    ):
"""
:param query_data: Search data used by the evaluator at the end of each epoch, to evaluate the model.
:param index_data: Index data or catalog used by the evaluator at the end of each epoch, to evaluate the model.
:param limit: The number of top search results to consider, when computing the evaluation metrics.
:param distance: The type of distance metric to use when matching query and index docs, available options are
``'cosine'``, ``'euclidean'`` and ``'sqeuclidean'``.
:param num_workers: The number of workers to use when matching query and index data.
:param query_data: Search data used by the evaluator at the end of each epoch,
to evaluate the model.
:param index_data: Index data or catalog used by the evaluator at the end of
each epoch, to evaluate the model.
:param metrics: A dictionary that specifies the metrics to calculate. It maps
metric names to tuples of metric functions and their keyword arguments. If
set to None, default metrics are computed.
:param exclude_self: Whether to exclude self when matching.
:param limit: The number of top search results to consider when computing the
evaluation metrics.
:param distance: The type of distance metric to use when matching query and
index docs, available options are ``'cosine'``, ``'euclidean'`` and
``'sqeuclidean'``.
:param num_workers: The number of workers to use when matching query and
index data.
"""
self._query_data = query_data
self._index_data = index_data
self._metrics = metrics
self._exclude_self = exclude_self
self._limit = limit
self._distance = distance
self._num_workers = num_workers
@@ -118,8 +133,11 @@ def on_epoch_end(self, tuner: 'BaseTuner'):
        )

        # compute metrics
        evaluator = Evaluator(self._query_data, index_data)
        evaluator = Evaluator(self._query_data, index_data, metrics=self._metrics)
        tuner.state.eval_metrics = evaluator.evaluate(
            limit=self._limit, distance=self._distance, num_workers=self._num_workers
            exclude_self=self._exclude_self,
            limit=self._limit,
            distance=self._distance,
            num_workers=self._num_workers,
        )
        tuner._progress_bar.update(task_id=self._match_pbar_id, visible=False)
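
Because the evaluation callback stores its results on `tuner.state.eval_metrics`, a callback placed after it in the callbacks list can read them in its own `on_epoch_end`. A minimal sketch (the `BaseCallback` import path and the print format are illustrative assumptions):

```python
from finetuner.tuner.callback.base import BaseCallback  # import path is an assumption


class PrintMetricsCallback(BaseCallback):
    """Prints the metrics computed by EvaluationCallback; must come after it in the callbacks list."""

    def on_epoch_end(self, tuner):
        for name, value in tuner.state.eval_metrics.items():
            print(f'{name}: {value:.4f}')
```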
