Skip to content

Commit

Permalink
test(drivers): add test for segment driver
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 30, 2020
1 parent 55fca18 commit 79fc7ec
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 8 deletions.
8 changes: 6 additions & 2 deletions jina/counter.py
Expand Up @@ -17,10 +17,14 @@ def __next__(self):
class SimpleCounter(BaseCounter):
    """Sequential counter that hands out consecutive integers.

    Zero is reserved: a seed of ``0`` is replaced by ``1`` so the first value
    returned is never zero. Any other seed is returned as the first value.
    """

    def __init__(self, seed: int = 0):
        """Initialise the counter.

        :param seed: first value to hand out; ``0`` is bumped to ``1``.
        """
        # presumably BaseCounter.__init__ stores ``seed`` on ``self.seed`` -- confirm
        super().__init__(seed)
        # note that zero is reserved
        if self.seed == 0:
            self.seed = 1

    def __next__(self):
        """Return the current value and advance the counter by one."""
        # Return-then-advance: incrementing first would skip the seed itself,
        # so a reserved-zero seed of 1 would yield 2 as the first value.
        current = self.seed
        self.seed += 1
        return current


class RandomUintCounter(BaseCounter):
Expand Down
1 change: 0 additions & 1 deletion jina/drivers/craft.py
Expand Up @@ -69,6 +69,5 @@ def _apply(self, doc: 'jina_pb2.Document', *args, **kwargs):
c.parent_id = doc.id
c.level_depth = doc.level_depth + 1
c.mime_type = doc.mime_type
doc.length = len(ret)
else:
self.logger.warning(f'doc {doc.id} at level {doc.level_depth} gives no chunk')
1 change: 0 additions & 1 deletion jina/executors/frameworks.py
@@ -1,7 +1,6 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import subprocess
from abc import abstractmethod

from . import BaseExecutor
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/drivers/test_segmenter_driver.py
@@ -0,0 +1,79 @@
from typing import Dict, List
import numpy as np
from jina.drivers.craft import SegmentDriver
from jina.executors.crafters import BaseSegmenter
from jina.drivers.helper import array2pb
from jina.proto import jina_pb2
from tests import JinaTestCase


class MockSegmenter(BaseSegmenter):
    """Segmenter stub that returns canned segments based on the input text."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.required_keys = {'text'}

    def craft(self, text: str, *args, **kwargs) -> List[Dict]:
        """Return three fixed segments for ``'valid'``, else one bogus segment.

        :param text: the document text; only the literal ``'valid'`` yields
            well-formed segments.
        :return: a list of segment dicts.
        """
        if text != 'valid':
            # a key that does not exist on the target document
            return [{'non_existing_key': 1}]

        segments = [{'blob': np.array([float(i)] * 3), 'weight': i} for i in range(3)]
        # length, parent_id and id are protected keys that won't affect the segments
        segments[-1].update({'length': 10, 'parent_id': 50, 'id': 10})
        return segments


class SimpleSegmentDriver(SegmentDriver):
    """SegmentDriver with random chunk ids switched off, for use in tests."""

    def __init__(self, first_chunk_id, *args, **kwargs):
        """Forward to ``SegmentDriver`` with ``random_chunk_id=False``.

        :param first_chunk_id: passed through as ``first_chunk_id``.
        """
        super().__init__(*args, first_chunk_id=first_chunk_id, random_chunk_id=False, **kwargs)

    @property
    def exec_fn(self):
        """Return the stored ``_exec_fn`` attribute directly."""
        return self._exec_fn


def create_documents_to_segment():
    """Build two documents (ids 1 and 2) with texts 'valid' and 'invalid'.

    Both documents get ``length`` set to 2.

    :return: list ``[doc1, doc2]`` of ``jina_pb2.Document`` objects.
    """
    documents = []
    for doc_id, text in ((1, 'valid'), (2, 'invalid')):
        doc = jina_pb2.Document()
        doc.id = doc_id
        doc.text = text
        doc.length = 2
        documents.append(doc)
    return documents


class SegmentDriverTestCase(JinaTestCase):
    """Exercise ``SimpleSegmentDriver._apply`` with the mock segmenter."""

    def test_segment_driver(self):
        """Check chunk fields for a valid doc and the error for an invalid one."""
        valid_doc, invalid_doc = create_documents_to_segment()
        segment_driver = SimpleSegmentDriver(first_chunk_id=3)
        segment_driver.attach(executor=MockSegmenter(), pea=None)
        segment_driver._apply(valid_doc)

        # the document lengths set by the fixture are expected to be unchanged
        self.assertEqual(valid_doc.length, 2)
        self.assertEqual(invalid_doc.length, 2)

        # chunk ids are expected to run consecutively from first_chunk_id;
        # blob and weight mirror the mock segmenter's i-th segment
        for offset in range(3):
            chunk = valid_doc.chunks[offset]
            self.assertEqual(chunk.id, 3 + offset)
            self.assertEqual(chunk.parent_id, valid_doc.id)
            self.assertEqual(chunk.blob, array2pb(np.array([float(offset)] * 3)))
            self.assertEqual(chunk.weight, offset)
            self.assertEqual(chunk.length, 3)

        with self.assertRaises(AttributeError) as error:
            segment_driver._apply(invalid_doc)
        self.assertEqual(str(error.exception), '\'Document\' object has no attribute \'non_existing_key\'')
8 changes: 4 additions & 4 deletions tests/unit/test_client.py
Expand Up @@ -17,7 +17,7 @@
cur_dir = os.path.dirname(os.path.abspath(__file__))


class MyTestCase(JinaTestCase):
class ClientTestCase(JinaTestCase):

def test_client(self):
f = Flow().add(uses='_forward')
Expand Down Expand Up @@ -70,13 +70,13 @@ def test_gateway_index_with_args(self):
json={'data': [
'',
''],
'first_doc_id': 5,
'first_doc_id': 0,
})
j = a.json()
self.assertTrue('index' in j)
self.assertEqual(len(j['index']['docs']), 2)
self.assertEqual(j['index']['docs'][0]['id'], 6) # doc zero is reserved
self.assertEqual(j['index']['docs'][1]['id'], 7)
self.assertEqual(j['index']['docs'][0]['id'], 1) # doc zero is reserved
self.assertEqual(j['index']['docs'][1]['id'], 2)
self.assertEqual(j['index']['docs'][0]['uri'],
'')
self.assertEqual(a.status_code, 200)
Expand Down

0 comments on commit 79fc7ec

Please sign in to comment.