Skip to content

Commit

Permalink
test(driver): refactor rank driver tests (#2135)
Browse files Browse the repository at this point in the history
* test: refactor driver tests

* test: fix types

* fix: cooments

* test: fix wrong setting

* fix: revert changes
  • Loading branch information
Yongxuanzhang committed Mar 16, 2021
1 parent fba4308 commit b3550e9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 67 deletions.
Expand Up @@ -3,7 +3,7 @@
from jina import Document
from jina.drivers.rank.aggregate import AggregateMatches2DocRankDriver
from jina.executors.rankers import Chunk2DocRanker
from jina.proto import jina_pb2
from jina.types.score import NamedScore
from jina.types.sets import DocumentSet


Expand Down Expand Up @@ -67,38 +67,22 @@ def create_document_to_score_same_depth_level():
# | matches: (id: 4, parent_id: 30, score.value: 20, length: 2),
# | matches: (id: 5, parent_id: 30, score.value: 10, length: 1),

doc = jina_pb2.DocumentProto()
doc.id = str(1) * 16
doc = Document()
doc.id = 1

match2 = doc.matches.add()
match2.id = str(2) * 16
match2.parent_id = str(20) * 8
match2.length = 3
match2.score.ref_id = doc.id
match2.score.value = 30

match3 = doc.matches.add()
match3.id = str(3) * 16
match3.parent_id = str(20) * 8
match3.length = 4
match3.score.ref_id = doc.id
match3.score.value = 40

match4 = doc.matches.add()
match4.id = str(4) * 16
match4.parent_id = str(30) * 8
match4.length = 2
match4.score.ref_id = doc.id
match4.score.value = 20

match5 = doc.matches.add()
match5.id = str(4) * 16
match5.parent_id = str(30) * 8
match5.length = 1
match5.score.ref_id = doc.id
match5.score.value = 10

return Document(doc)
for match_id, parent_id, match_score, match_length in [
(2, 20, 30, 3),
(3, 20, 40, 4),
(4, 30, 20, 2),
(5, 30, 10, 1),
]:
match = Document()
match.id = match_id
match.parent_id = parent_id
match.length = match_length
match.score = NamedScore(value=match_score, ref_id=doc.id)
doc.matches.append(match)
return doc


def test_collect_matches2doc_ranker_driver_mock_ranker():
Expand All @@ -109,10 +93,10 @@ def test_collect_matches2doc_ranker_driver_mock_ranker():
driver()
dm = list(doc.matches)
assert len(dm) == 2
assert dm[0].id == '20' * 8
assert dm[0].id == '20'
assert dm[0].score.value == 3
assert dm[1].id == '30' * 8
assert dm[1].score.value == 1
assert dm[1].id == '30'
assert dm[1].score.value == 2
for match in dm:
# match score is computed w.r.t to doc.id
assert match.score.ref_id == doc.id
Expand All @@ -132,20 +116,20 @@ def test_collect_matches2doc_ranker_driver_min_ranker(keep_source_matches_as_chu
min_value_30 = sys.maxsize
min_value_20 = sys.maxsize
for match in doc.matches:
if match.parent_id == '30' * 8:
if match.parent_id == '30':
if match.score.value < min_value_30:
min_value_30 = match.score.value
if match.parent_id == '20' * 8:
if match.parent_id == '20':
if match.score.value < min_value_20:
min_value_20 = match.score.value

assert min_value_30 < min_value_20
driver()
dm = list(doc.matches)
assert len(dm) == 2
assert dm[0].id == '30' * 8
assert dm[0].id == '30'
assert dm[0].score.value == pytest.approx((1.0 / (1.0 + min_value_30)), 0.0000001)
assert dm[1].id == '20' * 8
assert dm[1].id == '20'
assert dm[1].score.value == pytest.approx((1.0 / (1.0 + min_value_20)), 0.0000001)
for match in dm:
# match score is computed w.r.t to doc.id
Expand All @@ -166,9 +150,9 @@ def test_collect_matches2doc_ranker_driver_max_ranker(keep_source_matches_as_chu
driver()
dm = list(doc.matches)
assert len(dm) == 2
assert dm[0].id == '20' * 8
assert dm[0].id == '20'
assert dm[0].score.value == 40
assert dm[1].id == '30' * 8
assert dm[1].id == '30'
assert dm[1].score.value == 20
for match in dm:
# match score is computed w.r.t to doc.id
Expand Down
49 changes: 24 additions & 25 deletions tests/unit/drivers/rank/aggregate/test_chunk2doc_rank_drivers.py
Expand Up @@ -3,7 +3,7 @@
from jina import Document
from jina.drivers.rank.aggregate import Chunk2DocRankDriver
from jina.executors.rankers import Chunk2DocRanker
from jina.proto import jina_pb2
from jina.types.score import NamedScore
from jina.types.sets import DocumentSet

DISCOUNT_VAL = 0.5
Expand Down Expand Up @@ -72,25 +72,26 @@ def create_document_to_score():
# |- chunk: 3
# |- matches: (id: 6, parent_id: 60, score.value: 6),
# |- matches: (id: 7, parent_id: 70, score.value: 7)
doc = jina_pb2.DocumentProto()
doc = Document()
doc.id = '1'
for c in range(2):
chunk = doc.chunks.add()
chunk = Document()
chunk_id = str(c + 2)
chunk.id = chunk_id
for m in range(2):
match = chunk.matches.add()
match = Document()
match_id = 2 * int(chunk_id) + m
match.id = str(match_id)
parent_id = 10 * int(match_id)
match.parent_id = str(parent_id)
match.length = int(match_id)
# to be used by MaxRanker and MinRanker
match.score.ref_id = chunk.id
match.score.value = int(match_id)
match.score = NamedScore(value=int(match_id), ref_id=chunk.id)
match.tags['price'] = match.score.value
match.tags['discount'] = DISCOUNT_VAL
return Document(doc)
chunk.matches.append(match)
doc.chunks.append(chunk)
return doc


def create_chunk_matches_to_score():
Expand All @@ -101,25 +102,25 @@ def create_chunk_matches_to_score():
# |- chunks: (id: 20)
# |- matches: (id: 21, parent_id: 2, score.value: 4),
# |- matches: (id: 22, parent_id: 2, score.value: 5)
doc = jina_pb2.DocumentProto()
doc = Document()
doc.id = '1'
doc.granularity = 0
num_matches = 2
for parent_id in range(1, 3):
chunk = doc.chunks.add()
chunk = Document()
chunk_id = parent_id * 10
chunk.id = str(chunk_id)
chunk.granularity = doc.granularity + 1
for score_value in range(parent_id * 2, parent_id * 2 + num_matches):
match = chunk.matches.add()
match = Document()
match.granularity = chunk.granularity
match.parent_id = str(parent_id)
match.score.value = score_value
match.score.ref_id = chunk.id
match.score = NamedScore(value=score_value, ref_id=chunk.id)
match.id = str(10 * int(parent_id) + score_value)
match.length = 4

return Document(doc)
chunk.matches.append(match)
doc.chunks.append(chunk)
return doc


def create_chunk_chunk_matches_to_score():
Expand All @@ -131,34 +132,33 @@ def create_chunk_chunk_matches_to_score():
# |- chunks: (id: 20)
# |- matches: (id: 21, parent_id: 2, score.value: 4),
# |- matches: (id: 22, parent_id: 2, score.value: 5)
doc = jina_pb2.DocumentProto()
doc = Document()
doc.id = '100'
doc.granularity = 0
chunk = doc.chunks.add()
chunk = Document()
chunk.id = '101'
chunk.parent_id = doc.id
chunk.granularity = doc.granularity + 1
num_matches = 2
for parent_id in range(1, 3):
chunk_chunk = chunk.chunks.add()
chunk_chunk = Document()
chunk_chunk.id = str(parent_id * 10)
chunk_chunk.parent_id = str(parent_id)
chunk_chunk.granularity = chunk.granularity + 1
for score_value in range(parent_id * 2, parent_id * 2 + num_matches):
match = chunk_chunk.matches.add()
match = Document()
match.parent_id = str(parent_id)
match.score.value = score_value
match.score.ref_id = chunk_chunk.id
match.score = NamedScore(value=score_value, ref_id=chunk_chunk.id)
match.id = str(10 * parent_id + score_value)
match.length = 4
chunk_chunk.matches.append(match)
chunk.chunks.append(chunk_chunk)
doc.chunks.append(chunk)
return Document(doc)


@pytest.mark.parametrize(
'executor', [MockMaxRanker(), MockPriceDiscountRanker(), MockLengthRanker()]
)
@pytest.mark.parametrize('keep_source_matches_as_chunks', [False, True])
def test_chunk2doc_ranker_driver_mock_ranker(keep_source_matches_as_chunks, executor):
def test_chunk2doc_ranker_driver_mock_ranker(keep_source_matches_as_chunks):
doc = create_document_to_score()
driver = SimpleChunk2DocRankDriver(
docs=DocumentSet([doc]),
Expand Down Expand Up @@ -196,7 +196,6 @@ def test_chunk2doc_ranker_driver_max_ranker(keep_source_matches_as_chunks):
scale = 1 if not isinstance(executor, MockPriceDiscountRanker) else DISCOUNT_VAL
assert len(doc.matches) == 4
assert doc.matches[0].id == '70'

assert doc.matches[0].score.value == 7 * scale
assert doc.matches[1].id == '60'
assert doc.matches[1].score.value == 6 * scale
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/drivers/rank/test_matches2doc_rank_drivers.py
Expand Up @@ -3,6 +3,7 @@
from jina import Document
from jina.drivers.rank import Matches2DocRankDriver
from jina.executors.rankers import Match2DocRanker
from jina.types.score import NamedScore
from jina.executors.decorators import batching_multi_input
from jina.types.sets import DocumentSet

Expand Down Expand Up @@ -61,7 +62,7 @@ def create_document_to_score():
with Document() as match:
match.id = str(match_id) * match_length
match.length = match_score
match.score.value = match_score
match.score = NamedScore(value=match_score, ref_id=doc.id)
doc.matches.append(match)
return doc

Expand Down

0 comments on commit b3550e9

Please sign in to comment.