Skip to content

Commit

Permalink
fix: fix match score (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jan 11, 2022
1 parent 3a906db commit 5c1fb55
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docarray/array/mixins/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def match(
if only_id:
d = Document(id=rhv[_id].id)
else:
d = rhv[int(_id)] # type: Document
d = Document(rhv[int(_id)], copy=True) # type: Document

if d.id in lhv:
d = Document(
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/array/mixins/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,34 @@ def test_diff_framework_match(ndarray_val):
da = DocumentArray.empty(10)
da.embeddings = ndarray_val
da.match(da)


def test_match_ensure_scores_unique():
import numpy as np
from docarray import DocumentArray

da1 = DocumentArray.empty(4)
da1.embeddings = np.array(
[[0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [1, 2, 2, 1, 0]]
)

da2 = DocumentArray.empty(5)
da2.embeddings = np.array(
[
[0.0, 0.1, 0.0, 0.0, 0.0],
[1.0, 0.1, 0.0, 0.0, 0.0],
[1.0, 1.2, 1.0, 1.0, 0.0],
[1.0, 2.2, 2.0, 1.0, 0.0],
[4.0, 5.2, 2.0, 1.0, 0.0],
]
)

da1.match(da2, metric='euclidean', only_id=False, limit=5)

assert len(da1) == 4
for query in da1:
previous_score = -10000
assert len(query.matches) == 5
for m in query.matches:
assert m.scores['euclidean'].value >= previous_score
previous_score = m.scores['euclidean'].value

0 comments on commit 5c1fb55

Please sign in to comment.