-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
113 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import os | ||
|
||
from jina.proto import jina_pb2 | ||
from jina.drivers.rank import Chunk2DocRankDriver | ||
from jina.executors.rankers import Chunk2DocRanker, MaxRanker, MinRanker | ||
from tests import JinaTestCase | ||
|
||
cur_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
|
||
class MockLengthRanker(Chunk2DocRanker):
    """Test-only ranker whose score is the ``length`` entry of a matched chunk.

    The score is taken from the first row of ``match_idx`` only; no
    aggregation over the remaining rows is performed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``length`` is the only key this ranker reads from the chunk meta info.
        self.required_keys = {'length'}

    def _get_score(self, match_idx, query_chunk_meta, match_chunk_meta, *args, **kwargs):
        """Return ``(doc_id, length)`` derived from the first row of ``match_idx``.

        :param match_idx: indexable rows; columns are addressed through
            ``self.col_doc_id`` and ``self.col_chunk_id``
        :param query_chunk_meta: unused by this ranker
        :param match_chunk_meta: mapping from a chunk id to a dict holding ``'length'``
        :return: tuple of the document id and the ``length`` of the matched chunk
        """
        first_row = match_idx[0]
        doc_id = first_row[self.col_doc_id]
        matched_chunk_id = first_row[self.col_chunk_id]
        return doc_id, match_chunk_meta[matched_chunk_id]['length']
|
||
|
||
class SimpleChunk2DocRankDriver(Chunk2DocRankDriver):
    """``Chunk2DocRankDriver`` variant for tests that overrides ``exec_fn``.

    The override hands back ``self._exec_fn`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def exec_fn(self):
        """Return whatever is stored in ``self._exec_fn``."""
        bound_fn = self._exec_fn
        return bound_fn
|
||
|
||
def create_document_to_score():
    """Build the fixture document that the ranking tests operate on.

    Resulting layout (ids)::

        doc 1
        |- chunk 2 -> matches 4 (parent 40), 5 (parent 50)
        |- chunk 3 -> matches 6 (parent 60), 7 (parent 70)

    For every match, ``length`` and ``score.value`` equal the match id and
    ``score.ref_id`` equals the id of the chunk that owns the match.

    :return: the populated ``jina_pb2.Document``
    """
    document = jina_pb2.Document()
    document.id = 1
    for chunk_offset in (1, 2):
        chunk = document.chunks.add()
        chunk.id = document.id + chunk_offset
        first_match_id = 2 * chunk.id
        for match_id in (first_match_id, first_match_id + 1):
            match = chunk.matches.add()
            match.id = match_id
            match.parent_id = 10 * match_id
            match.length = match_id
            # score fields are what MaxRanker and MinRanker consume
            match.score.ref_id = chunk.id
            match.score.value = match_id
    return document
|
||
|
||
class Chunk2DocRankerDriverTestCase(JinaTestCase):
    """Checks the document-level matches produced by the chunk-to-doc rank driver.

    Each test ranks the fixture from ``create_document_to_score`` with a
    different ranker and verifies the order, ids and scores in ``doc.matches``.
    The expected match ids (40..70) are the ``parent_id`` values of the
    chunk-level matches in the fixture.
    """

    @staticmethod
    def _rank_fixture(ranker_cls):
        """Create the fixture document and apply the driver with ``ranker_cls``.

        :param ranker_cls: ranker class, instantiated without arguments
        :return: the fixture document after ``driver._apply`` has run on it
        """
        doc = create_document_to_score()
        driver = SimpleChunk2DocRankDriver()
        driver.attach(executor=ranker_cls(), pea=None)
        driver._apply(doc)
        return doc

    def _assert_matches(self, doc, expected, assert_score):
        """Verify ``doc.matches`` against ``expected``, position by position.

        :param doc: document returned by ``_rank_fixture``
        :param expected: ``(match_id, score_value)`` pairs in ranked order
        :param assert_score: assertion used for the score comparison
            (``assertEqual`` or ``assertAlmostEqual``)
        """
        self.assertEqual(len(doc.matches), 4)
        for position, (match_id, score_value) in enumerate(expected):
            ranked = doc.matches[position]
            self.assertEqual(ranked.id, match_id)
            assert_score(ranked.score.value, score_value)
        for ranked in doc.matches:
            # match score is computed with respect to doc.id
            self.assertEqual(ranked.score.ref_id, doc.id)

    def test_chunk2doc_ranker_driver_mock_exec(self):
        doc = self._rank_fixture(MockLengthRanker)
        self._assert_matches(
            doc,
            [(70, 7), (60, 6), (50, 5), (40, 4)],
            self.assertEqual,
        )

    def test_chunk2doc_ranker_driver_MaxRanker(self):
        doc = self._rank_fixture(MaxRanker)
        self._assert_matches(
            doc,
            [(70, 7), (60, 6), (50, 5), (40, 4)],
            self.assertEqual,
        )

    def test_chunk2doc_ranker_driver_MinRanker(self):
        doc = self._rank_fixture(MinRanker)
        # expected scores follow 1 / (1 + value), so the order is reversed
        self._assert_matches(
            doc,
            [
                (40, 1 / (1 + 4)),
                (50, 1 / (1 + 5)),
                (60, 1 / (1 + 6)),
                (70, 1 / (1 + 7)),
            ],
            self.assertAlmostEqual,
        )
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
import unittest | ||
|
||
from jina.executors.rankers.bi_match import BiMatchRanker | ||
from tests.unit.executors.rankers import RankerTestCase | ||
|
||
|
||
# The diff hunk lists the former header ``class MyTestCase(RankerTestCase):``
# directly above this one; ``BiMatchTestCase`` is the name after the change.
class BiMatchTestCase(RankerTestCase):
    """``RankerTestCase`` specialisation that uses a ``BiMatchRanker``.

    NOTE(review): presumably the tests inherited from ``RankerTestCase`` read
    ``self.ranker`` - confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ranker instance under test, created with default arguments
        self.ranker = BiMatchRanker()
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
import unittest | ||
|
||
from jina.executors.rankers.tfidf import BM25Ranker | ||
from tests.unit.executors.rankers import RankerTestCase | ||
|
||
|
||
# The diff hunk lists the former header ``class MyTestCase(RankerTestCase):``
# directly above this one; ``BM25TestCase`` is the name after the change.
class BM25TestCase(RankerTestCase):
    """``RankerTestCase`` specialisation that uses a ``BM25Ranker``.

    NOTE(review): presumably the tests inherited from ``RankerTestCase`` read
    ``self.ranker`` - confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ranker instance under test, created with default arguments
        self.ranker = BM25Ranker()
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
import unittest | ||
|
||
from jina.executors.rankers import MaxRanker | ||
from tests.unit.executors.rankers import RankerTestCase | ||
|
||
|
||
# The diff hunk lists the former header ``class MyTestCase(RankerTestCase):``
# directly above this one; ``MaxRankerTestCase`` is the name after the change.
class MaxRankerTestCase(RankerTestCase):
    """``RankerTestCase`` specialisation that uses a ``MaxRanker``.

    NOTE(review): presumably the tests inherited from ``RankerTestCase`` read
    ``self.ranker`` - confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ranker instance under test, created with default arguments
        self.ranker = MaxRanker()
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
import unittest | ||
|
||
from jina.executors.rankers.tfidf import TfIdfRanker | ||
from tests.unit.executors.rankers import RankerTestCase | ||
|
||
|
||
# The diff hunk lists the former header ``class MyTestCase(RankerTestCase):``
# directly above this one; ``TfIdfRankerTestCase`` is the name after the change.
class TfIdfRankerTestCase(RankerTestCase):
    """``RankerTestCase`` specialisation that uses a ``TfIdfRanker``.

    NOTE(review): presumably the tests inherited from ``RankerTestCase`` read
    ``self.ranker`` - confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ranker instance under test, built with threshold=0.2
        self.ranker = TfIdfRanker(threshold=0.2)
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |