From 79fc7ecc76bc3878524a3557a5e5f214748ffb76 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Wed, 29 Jul 2020 17:39:37 +0200 Subject: [PATCH] test(drivers): add test for segment driver --- jina/counter.py | 8 ++- jina/drivers/craft.py | 1 - jina/executors/frameworks.py | 1 - tests/unit/drivers/test_segmenter_driver.py | 79 +++++++++++++++++++++ tests/unit/test_client.py | 8 +-- 5 files changed, 89 insertions(+), 8 deletions(-) create mode 100644 tests/unit/drivers/test_segmenter_driver.py diff --git a/jina/counter.py b/jina/counter.py index 81cc746fa3be4..e5d8df4d11b2f 100644 --- a/jina/counter.py +++ b/jina/counter.py @@ -17,10 +17,14 @@ def __next__(self): class SimpleCounter(BaseCounter): def __init__(self, seed: int = 0): super().__init__(seed) + # note that zero is reserved + if self.seed == 0: + self.seed = 1 def __next__(self): - self.seed += 1 # note the first number starts with 1 and zero is reserved - return self.seed + ret = self.seed + self.seed += 1 + return ret class RandomUintCounter(BaseCounter): diff --git a/jina/drivers/craft.py b/jina/drivers/craft.py index bca7faf2ed2ee..6745079bbb240 100644 --- a/jina/drivers/craft.py +++ b/jina/drivers/craft.py @@ -69,6 +69,5 @@ def _apply(self, doc: 'jina_pb2.Document', *args, **kwargs): c.parent_id = doc.id c.level_depth = doc.level_depth + 1 c.mime_type = doc.mime_type - doc.length = len(ret) else: self.logger.warning(f'doc {doc.id} at level {doc.level_depth} gives no chunk') diff --git a/jina/executors/frameworks.py b/jina/executors/frameworks.py index 3ea5c4d2fd433..9ebe9d18b39e8 100644 --- a/jina/executors/frameworks.py +++ b/jina/executors/frameworks.py @@ -1,7 +1,6 @@ __copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved." __license__ = "Apache-2.0" -import subprocess from abc import abstractmethod from . import BaseExecutor diff --git a/tests/unit/drivers/test_segmenter_driver.py b/tests/unit/drivers/test_segmenter_driver.py new file mode 100644 index 0000000000000..f6868f1120f27 --- /dev/null +++ b/tests/unit/drivers/test_segmenter_driver.py @@ -0,0 +1,79 @@ +from typing import Dict, List +import numpy as np +from jina.drivers.craft import SegmentDriver +from jina.executors.crafters import BaseSegmenter +from jina.drivers.helper import array2pb +from jina.proto import jina_pb2 +from tests import JinaTestCase + + +class MockSegmenter(BaseSegmenter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.required_keys = {'text'} + + def craft(self, text: str, *args, **kwargs) -> List[Dict]: + if text == 'valid': + # length, parent_id and id are protected keys that won't affect the segments + return [{'blob': np.array([0.0, 0.0, 0.0]), 'weight': 0}, + {'blob': np.array([1.0, 1.0, 1.0]), 'weight': 1}, + {'blob': np.array([2.0, 2.0, 2.0]), 'weight': 2, 'length': 10, 'parent_id': 50, 'id': 10}] + else: + return [{'non_existing_key': 1}] + + +class SimpleSegmentDriver(SegmentDriver): + + def __init__(self, first_chunk_id, *args, **kwargs): + super().__init__(first_chunk_id=first_chunk_id, random_chunk_id=False, *args, **kwargs) + + @property + def exec_fn(self): + return self._exec_fn + + +def create_documents_to_segment(): + doc1 = jina_pb2.Document() + doc1.id = 1 + doc1.text = 'valid' + doc1.length = 2 + doc2 = jina_pb2.Document() + doc2.id = 2 + doc2.text = 'invalid' + doc2.length = 2 + return [doc1, doc2] + + +class SegmentDriverTestCase(JinaTestCase): + + def test_segment_driver(self): + docs = create_documents_to_segment() + driver = SimpleSegmentDriver(first_chunk_id=3) + executor = MockSegmenter() + driver.attach(executor=executor, pea=None) + driver._apply(docs[0]) + + self.assertEqual(docs[0].length, 2) + self.assertEqual(docs[1].length, 2) + + self.assertEqual(docs[0].chunks[0].id, 3) + self.assertEqual(docs[0].chunks[0].parent_id, docs[0].id) + self.assertEqual(docs[0].chunks[0].blob, array2pb(np.array([0.0, 0.0, 0.0]))) + self.assertEqual(docs[0].chunks[0].weight, 0) + self.assertEqual(docs[0].chunks[0].length, 3) + + self.assertEqual(docs[0].chunks[1].id, 4) + self.assertEqual(docs[0].chunks[1].parent_id, docs[0].id) + self.assertEqual(docs[0].chunks[1].blob, array2pb(np.array([1.0, 1.0, 1.0]))) + self.assertEqual(docs[0].chunks[1].weight, 1) + self.assertEqual(docs[0].chunks[1].length, 3) + + self.assertEqual(docs[0].chunks[2].id, 5) + self.assertEqual(docs[0].chunks[2].parent_id, docs[0].id) + self.assertEqual(docs[0].chunks[2].blob, array2pb(np.array([2.0, 2.0, 2.0]))) + self.assertEqual(docs[0].chunks[2].weight, 2) + self.assertEqual(docs[0].chunks[2].length, 3) + + with self.assertRaises(AttributeError) as error: + driver._apply(docs[1]) + self.assertEqual(error.exception.__str__(), '\'Document\' object has no attribute \'non_existing_key\'') diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 90bdab7d6a13a..865025574e1e4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -17,7 +17,7 @@ cur_dir = os.path.dirname(os.path.abspath(__file__)) -class MyTestCase(JinaTestCase): +class ClientTestCase(JinaTestCase): def test_client(self): f = Flow().add(uses='_forward') @@ -70,13 +70,13 @@ def test_gateway_index_with_args(self): json={'data': [ 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC', 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AvdGjTZeOlQq07xSYPgJjlWRwfWEBx2+CgAVrPrP+O5ghhOa+a0cocoWnaMJFAsBuCQCgiJOKDBcIQTiLieOrPD/cp/6iZ/Iu4HqAh5dGzggIQVJI3WqTxwVTDjs5XJOy38AlgHoaKgY+xJEXeFTyR7FOfF7JNWjs3b8evQE6B2dTDvQZx3n3Rz6rgOtVlaZRLvR9geCAxuY3G+0mepEAhrTISES3bwPWYYi48OUrQOc//IaJeij9xZGGmDIG9kc73fNI7eA8VMBAAD//0SxXMMT90UdAAAAAElFTkSuQmCC'], - 'first_doc_id': 5, + 'first_doc_id': 0, }) j = a.json() self.assertTrue('index' in j) self.assertEqual(len(j['index']['docs']), 2) - self.assertEqual(j['index']['docs'][0]['id'], 6) # doc zero is reserved - self.assertEqual(j['index']['docs'][1]['id'], 7) + self.assertEqual(j['index']['docs'][0]['id'], 1) # doc zero is reserved + self.assertEqual(j['index']['docs'][1]['id'], 2) self.assertEqual(j['index']['docs'][0]['uri'], 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC') self.assertEqual(a.status_code, 200)