fix(evaluate): fix length and hash check on evaluate mixin (#3943)
hanxiao committed Nov 17, 2021
1 parent b4f6d5c commit 961f48c
Showing 3 changed files with 140 additions and 5 deletions.
78 changes: 76 additions & 2 deletions docs/fundamentals/document/documentarray-api.md
@@ -332,7 +332,63 @@ Note that in the above GPU example we did a conversion. In practice, there is no
da.embeddings = torch_model(da.blobs) # <- no .numpy() is necessary
```

And then just use `.match(da)`.

### Evaluate matches

You can easily evaluate the performance of matches via {func}`~jina.types.arrays.mixins.evaluation.EvaluationMixin.evaluate`, provided that you have the ground truth of the matches.

Jina provides some common metrics used in the information retrieval community that allow one to evaluate the nearest-neighbour matches. These metrics include precision, recall, R-precision, hit rate, NDCG, etc. The full list of functions can be found in {mod}`jina.math.evaluation`.
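
Under the hood, these metric functions operate on a list of binary relevance values. Below is a minimal sketch of calling one of them directly; the exact signature is an assumption (a sequence of 0/1 relevance values plus a cutoff `k`):

```python
from jina.math.evaluation import precision_at_k

# 1 = relevant match, 0 = irrelevant match, ordered by descending rank
binary_relevance = [1, 0, 1, 1, 0, 0, 1, 0, 0, 0]

# three of the top-5 entries are relevant, so we expect 0.6
print(precision_at_k(binary_relevance, k=5))
```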

For example, let's create a `DocumentArray` with random embeddings and match it to itself:

```python
import numpy as np
from jina import DocumentArray

da = DocumentArray.empty(10)
da.embeddings = np.random.random([10, 3])
da.match(da, exclude_self=True)
```

Now `da.matches` contains the matches; let's use it as the ground truth. Next, let's create imperfect matches by mixing ten "noise Documents" into every `d.matches`.

```python
import copy

da2 = copy.deepcopy(da)

for d in da2:
    d.matches.extend(DocumentArray.empty(10))
    d.matches = d.matches.shuffle()

print(da2.evaluate(da, metric='precision_at_k', k=10))
```

The average Precision@10 should now be close to 0.5: after mixing in the noise, each Document has 9 relevant matches and 10 noise matches, so roughly half of its top-10 matches are relevant.
```text
0.5399999999999999
```

Note that this value is an average over all Documents of `da2`. If you want to look at the individual evaluations, you can check the {attr}`~jina.Document.evaluations` attribute, e.g.

```python
for d in da2:
    print(d.evaluations['precision_at_k'].value)
```

```text
0.4000000059604645
0.6000000238418579
0.5
0.5
0.5
0.4000000059604645
0.5
0.4000000059604645
0.5
0.30000001192092896
```

Note that `evaluate()` works only when the two `DocumentArray`s have the same length and their Documents are aligned by a hash function. The default hash function simply uses {attr}`~jina.Document.id`; you can specify your own hash function via the `hash_fn` argument.
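
For instance, here is a minimal sketch that aligns two `DocumentArray`s by their `text` instead of their `id` (the texts assigned below are purely illustrative):

```python
import numpy as np
from jina import DocumentArray

da_left = DocumentArray.empty(10)
da_right = DocumentArray.empty(10)

# give both sides the same texts so that they hash to the same values
for i, (dl, dr) in enumerate(zip(da_left, da_right)):
    dl.text = dr.text = f'doc-{i}'

da_left.embeddings = np.random.random([10, 3])
da_right.embeddings = np.random.random([10, 3])
da_left.match(da_left, exclude_self=True)
da_right.match(da_right, exclude_self=True)

# align Documents by text rather than by the default `.id`
print(da_left.evaluate(da_right, metric='precision_at_k', hash_fn=lambda d: d.text))
```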

## Traverse nested structure

@@ -488,7 +544,6 @@ DocumentArray([

If you simply want to traverse **all** chunks and matches regardless of their level, you can use {meth}`jina.types.arrays.mixins.traverse.TraverseMixin.flatten`. It will return a `DocumentArray` with all chunks and matches flattened into the top level, with no nested structure left.
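
As a quick sketch (the nested structure below is made up purely for illustration):

```python
from jina import Document, DocumentArray

d = Document()
d.chunks.append(Document())
d.matches.append(Document())

da = DocumentArray([d])

flat = da.flatten()  # chunks and matches are pulled up to the top level
print(len(flat))
```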


## Visualization

`DocumentArray` provides the `.plot_embeddings` function to plot Document embeddings in a 2D graph. `.plot_embeddings` supports two methods
@@ -544,6 +599,25 @@ da.save('data.bin', file_format='binary')
da1 = DocumentArray.load('data.bin', file_format='binary')
```

## Batching

One can batch a large `DocumentArray` into small ones via {func}`~jina.types.arrays.mixins.group.GroupMixin.batch`. This is useful when a `DocumentArray` is too big to process at once. It is particularly useful for `DocumentArrayMemmap`, as it ensures the data gets loaded on demand and in a memory-conservative manner.

```python
from jina import DocumentArray

da = DocumentArray.empty(1000)

for b_da in da.batch(batch_size=256):
    print(len(b_da))
```

```text
256
256
256
232
```
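
The same pattern applies to `DocumentArrayMemmap`; a sketch is given below (the on-disk path is just an example):

```python
from jina import Document, DocumentArrayMemmap

dam = DocumentArrayMemmap('./my-memmap')  # example storage path
dam.extend(Document() for _ in range(1000))

# Documents are loaded on demand, one batch at a time
for b_da in dam.batch(batch_size=256):
    print(len(b_da))
```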


## Sampling
12 changes: 12 additions & 0 deletions jina/types/arrays/mixins/evaluation.py
@@ -18,6 +18,7 @@ def evaluate(
        metric: Union[str, Callable[..., float]],
        hash_fn: Optional[Callable[['Document'], str]] = None,
        metric_name: Optional[str] = None,
        strict: bool = True,
        **kwargs,
    ) -> Optional[float]:
        """Compute ranking evaluation metrics for a given `DocumentArray` when compared with a groundtruth.
@@ -31,9 +32,13 @@
        :param metric: The name of the metric, or multiple metrics to be computed
        :param hash_fn: The function used for identifying the uniqueness of Documents. If not given, then ``Document.id`` is used.
        :param metric_name: If provided, the results of the metrics computation will be stored in the `evaluations` field of each Document. If not provided, the name will be computed based on the metrics name.
        :param strict: If set, the left- and right-hand sides are required to be fully aligned: they must have the same length and their Documents must be hashed to the same values. This prevents you from accidentally evaluating on irrelevant matches.
        :param kwargs: Additional keyword arguments to be passed to `metric_fn`
        :return: The average evaluation computed or a list of them if multiple metrics are required
        """
        if strict:
            self._check_length(len(other))

        if hash_fn is None:
            hash_fn = lambda d: d.id
@@ -48,6 +53,13 @@
        metric_name = metric_name or metric_fn.__name__
        results = []
        for d, gd in zip(self, other):
            if strict and hash_fn(d) != hash_fn(gd):
                raise ValueError(
                    f'Document {d} from the left-hand side and '
                    f'{gd} from the right-hand side are not hashed to the same value. '
                    f'This means your left and right DocumentArray may not be aligned, or your '
                    f'`hash_fn` is badly designed.'
                )
            if not d.matches or not gd.matches:
                raise ValueError(
                    f'Document {d!r} or {gd!r} has no matches, please check your Document'
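
For reference, a minimal sketch of how the new `strict` flag is meant to be used, assuming the behaviour described in the docstring above: with the default `strict=True` the call raises a `ValueError` when the two arrays are not aligned, while `strict=False` skips the alignment check.

```python
import numpy as np
from jina import DocumentArray

da1 = DocumentArray.empty(10)
da2 = DocumentArray.empty(10)  # different ids from da1

da1.embeddings = np.random.random([10, 3])
da2.embeddings = np.random.random([10, 3])
da1.match(da1, exclude_self=True)
da2.match(da2, exclude_self=True)

# da1.evaluate(da2, metric='precision_at_k')  # would raise ValueError: ids differ

da1.evaluate(da2, metric='precision_at_k', strict=False)  # alignment check skipped
```
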
55 changes: 52 additions & 3 deletions tests/unit/types/arrays/mixins/test_eval_class.py
@@ -1,3 +1,5 @@
import copy

import numpy as np
import pytest

@@ -46,13 +48,60 @@ def test_eval_mixin_zero_match(metric_fn, kwargs):
    da1.embeddings = np.random.random([10, 256])
    da1.match(da1, exclude_self=True)

    da2 = copy.deepcopy(da1)
    da2.embeddings = np.random.random([10, 256])
    da2.match(da2, exclude_self=True)

    r = da1.evaluate(da2, metric=metric_fn, **kwargs)
    assert isinstance(r, float)
    assert r == 1.0
    for d in da1:
        d: Document
        assert d.evaluations[metric_fn].value == 1.0


def test_diff_len_should_raise():
    da1 = DocumentArray.empty(10)
    da2 = DocumentArray.empty(5)
    with pytest.raises(ValueError):
        da1.evaluate(da2, metric='precision_at_k')


def test_diff_hash_fun_should_raise():
    da1 = DocumentArray.empty(10)
    da2 = DocumentArray.empty(10)
    with pytest.raises(ValueError):
        da1.evaluate(da2, metric='precision_at_k')


def test_same_hash_same_len_fun_should_work():
    da1 = DocumentArray.empty(10)
    da1.embeddings = np.random.random([10, 3])
    da1.match(da1)
    da2 = DocumentArray.empty(10)
    da2.embeddings = np.random.random([10, 3])
    da2.match(da2)
    with pytest.raises(ValueError):
        da1.evaluate(da2, metric='precision_at_k')
    for d1, d2 in zip(da1, da2):
        d1.id = d2.id

    da1.evaluate(da2, metric='precision_at_k')


def test_adding_noise():
    da = DocumentArray.empty(10)

    da.embeddings = np.random.random([10, 3])
    da.match(da, exclude_self=True)

    da2 = copy.deepcopy(da)

    for d in da2:
        d.matches.extend(DocumentArray.empty(10))
        d.matches = d.matches.shuffle()

    assert da2.evaluate(da, metric='precision_at_k', k=10) < 1.0

    for d in da2:
        assert 0.0 < d.evaluations['precision_at_k'].value < 1.0
