Skip to content

Commit

Permalink
fix(drivers): fix ranker driver
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 28, 2020
1 parent dd43083 commit ed6f904
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 120 deletions.
33 changes: 3 additions & 30 deletions jina/drivers/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def _apply(self, doc: 'jina_pb2.Document', *args, **kwargs):
query_chunk_meta = {}
match_chunk_meta = {}
for c in doc.chunks:
for k in c.matches:
match_idx.append((k.id, k.parent_id, c.id, k.score.value))
for match in c.matches:
match_idx.append((match.parent_id, match.id, c.id, match.score.value))
query_chunk_meta[c.id] = pb_obj2dict(c, self.exec.required_keys)
match_chunk_meta[k.id] = pb_obj2dict(k, self.exec.required_keys)
match_chunk_meta[match.id] = pb_obj2dict(match, self.exec.required_keys)

# np.uint32 uses 32 bits. np.float32 has a 24-bit significand (23 stored mantissa bits plus one implicit bit), so
# integers greater than 2^24 will have their least significant bits truncated.
Expand All @@ -52,30 +52,3 @@ def _apply(self, doc: 'jina_pb2.Document', *args, **kwargs):
r.score.ref_id = doc.id # label the score is computed against doc
r.score.value = score
r.score.op_name = exec.__class__.__name__


class DocRankDriver(BaseRankDriver):
    """Score documents' matches based on their features and the query document.

    The driver collects the ids and required metadata of ``doc.matches``,
    passes them to the attached executor through ``self.exec_fn`` and writes
    the returned scores back onto the matches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recursion_order = 'post'

    def _apply(self, doc: 'jina_pb2.Document', *args, **kwargs):
        """Score every match of ``doc`` in place.

        :param doc: the query document whose ``matches`` are scored
        """
        required_keys = self.exec.required_keys
        match_idx = []
        query_doc_meta = {doc.id: pb_obj2dict(doc, required_keys)}
        match_doc_meta = {}

        for match in doc.matches:
            match_idx.append(match.id)
            match_doc_meta[match.id] = pb_obj2dict(match, required_keys)

        # Nothing to score: leave the document untouched.
        if not match_idx:
            return

        match_idx = np.array(match_idx, dtype=np.float64)
        doc_scores = self.exec_fn(match_idx, query_doc_meta, match_doc_meta)

        # ``enumerate`` yields ``(position, row)`` pairs, so the row has to be
        # unpacked in a nested pattern; unpacking three names directly raised
        # ``ValueError`` on the first iteration.
        # NOTE(review): assumes each row is an ``(id, score)`` pair returned in
        # the same order as ``doc.matches`` - confirm against the executor.
        op_name = self.exec.__class__.__name__  # the executor, not the builtin ``exec``
        for idx, (_, score) in enumerate(doc_scores):
            doc.matches[idx].score.value = score
            doc.matches[idx].score.op_name = op_name
106 changes: 106 additions & 0 deletions tests/unit/drivers/test_chunk2doc_rank_drivers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os

from jina.proto import jina_pb2
from jina.drivers.rank import Chunk2DocRankDriver
from jina.executors.rankers import Chunk2DocRanker, MaxRanker, MinRanker
from tests import JinaTestCase

cur_dir = os.path.dirname(os.path.abspath(__file__))


class MockLengthRanker(Chunk2DocRanker):
    """Ranker used by the driver tests; it only requires the ``length`` key.

    The score is read from the first row of ``match_idx`` alone: the value in
    column ``col_doc_id`` is returned together with the ``length`` stored in
    ``match_chunk_meta`` for the id found in column ``col_chunk_id``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.required_keys = {'length'}

    def _get_score(self, match_idx, query_chunk_meta, match_chunk_meta, *args, **kwargs):
        """Return ``(doc id, length)`` taken from the first row of ``match_idx``."""
        first_row = match_idx[0]
        doc_id = first_row[self.col_doc_id]
        chunk_id = first_row[self.col_chunk_id]
        return doc_id, match_chunk_meta[chunk_id]['length']


class SimpleChunk2DocRankDriver(Chunk2DocRankDriver):
    """``Chunk2DocRankDriver`` for tests whose ``exec_fn`` is overridden to return ``self._exec_fn``."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

    @property
    def exec_fn(self):
        """Return the ``_exec_fn`` attribute of this driver unchanged."""
        return self._exec_fn


def create_document_to_score():
    """Build the fixture document shared by the ranker driver tests.

    Resulting ids::

        doc 1
          chunk 2 -> match 4 (parent 40), match 5 (parent 50)
          chunk 3 -> match 6 (parent 60), match 7 (parent 70)

    Each match gets ``length`` and ``score.value`` equal to its own id, and
    ``score.ref_id`` equal to the id of the chunk that holds it.
    """
    document = jina_pb2.Document()
    document.id = 1
    for chunk_offset in (1, 2):
        query_chunk = document.chunks.add()
        query_chunk.id = document.id + chunk_offset
        for match_offset in (0, 1):
            chunk_match = query_chunk.matches.add()
            chunk_match.id = 2 * query_chunk.id + match_offset
            chunk_match.parent_id = 10 * chunk_match.id
            chunk_match.length = chunk_match.id
            # to be used by MaxRanker and MinRanker
            chunk_match.score.ref_id = query_chunk.id
            chunk_match.score.value = chunk_match.id
    return document


class Chunk2DocRankerDriverTestCase(JinaTestCase):
    """Check the matches written by ``Chunk2DocRankDriver`` for several rankers."""

    def _rank_fixture_with(self, executor_cls):
        """Apply a driver backed by ``executor_cls()`` to a fresh fixture document and return it."""
        doc = create_document_to_score()
        driver = SimpleChunk2DocRankDriver()
        executor = executor_cls()
        driver.attach(executor=executor, pea=None)
        driver._apply(doc)
        return doc

    def _assert_ranking(self, doc, expected, assert_score):
        """Compare ``doc.matches`` with ``expected``, a list of ``(id, score)`` pairs in rank order.

        ``assert_score`` is the assertion method used to compare score values.
        """
        self.assertEqual(len(doc.matches), len(expected))
        for position, (expected_id, expected_score) in enumerate(expected):
            ranked = doc.matches[position]
            self.assertEqual(ranked.id, expected_id)
            assert_score(ranked.score.value, expected_score)
        for ranked in doc.matches:
            # match score is computed w.r.t. doc.id
            self.assertEqual(ranked.score.ref_id, doc.id)

    def test_chunk2doc_ranker_driver_mock_exec(self):
        doc = self._rank_fixture_with(MockLengthRanker)
        self._assert_ranking(doc, [(70, 7), (60, 6), (50, 5), (40, 4)], self.assertEqual)

    def test_chunk2doc_ranker_driver_MaxRanker(self):
        doc = self._rank_fixture_with(MaxRanker)
        self._assert_ranking(doc, [(70, 7), (60, 6), (50, 5), (40, 4)], self.assertEqual)

    def test_chunk2doc_ranker_driver_MinRanker(self):
        doc = self._rank_fixture_with(MinRanker)
        expected = [(40, 1 / (1 + 4)), (50, 1 / (1 + 5)), (60, 1 / (1 + 6)), (70, 1 / (1 + 7))]
        self._assert_ranking(doc, expected, self.assertAlmostEqual)
62 changes: 0 additions & 62 deletions tests/unit/drivers/test_rank_drivers.py

This file was deleted.

8 changes: 1 addition & 7 deletions tests/unit/executors/rankers/test_bi_match.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import unittest

from jina.executors.rankers.bi_match import BiMatchRanker
from tests.unit.executors.rankers import RankerTestCase


class MyTestCase(RankerTestCase):
class BiMatchTestCase(RankerTestCase):
    """``RankerTestCase`` whose ``ranker`` attribute is a ``BiMatchRanker``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ranker under test, created with its default arguments.
        self.ranker = BiMatchRanker()


if __name__ == '__main__':
unittest.main()
8 changes: 1 addition & 7 deletions tests/unit/executors/rankers/test_bm25.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import unittest

from jina.executors.rankers.tfidf import BM25Ranker
from tests.unit.executors.rankers import RankerTestCase


class MyTestCase(RankerTestCase):
class BM25TestCase(RankerTestCase):
    """``RankerTestCase`` whose ``ranker`` attribute is a ``BM25Ranker``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ranker under test, created with its default arguments.
        self.ranker = BM25Ranker()


if __name__ == '__main__':
unittest.main()
8 changes: 1 addition & 7 deletions tests/unit/executors/rankers/test_max.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import unittest

from jina.executors.rankers import MaxRanker
from tests.unit.executors.rankers import RankerTestCase


class MyTestCase(RankerTestCase):
class MaxRankerTestCase(RankerTestCase):
    """``RankerTestCase`` whose ``ranker`` attribute is a ``MaxRanker``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ranker under test, created with its default arguments.
        self.ranker = MaxRanker()


if __name__ == '__main__':
unittest.main()
8 changes: 1 addition & 7 deletions tests/unit/executors/rankers/test_tfidf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import unittest

from jina.executors.rankers.tfidf import TfIdfRanker
from tests.unit.executors.rankers import RankerTestCase


class MyTestCase(RankerTestCase):
class TfIdfRankerTestCase(RankerTestCase):
    """``RankerTestCase`` whose ``ranker`` attribute is a ``TfIdfRanker`` built with ``threshold=0.2``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ranker under test.
        self.ranker = TfIdfRanker(threshold=0.2)


if __name__ == '__main__':
unittest.main()

0 comments on commit ed6f904

Please sign in to comment.