Skip to content

Commit

Permalink
test: sharding non existent (#1719)
Browse files Browse the repository at this point in the history
* test: sharding non existent

* test: sharding merge

* fix: proof of concept on how to reduce at request level

* refactor: promote apply_root

* refactor: fixed circular import

* fix: type error

* fix: another linting error fixed

* feat: cleaner interface

* fix: name change due to deprecation

* test: sharding move merge root

Co-authored-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
Co-authored-by: Maximilian Werk <maximilian.werk@jina.ai>
  • Loading branch information
3 people committed Jan 22, 2021
1 parent d800a58 commit 86853bf
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 20 deletions.
27 changes: 18 additions & 9 deletions jina/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from ..executors import AnyExecutor
from ..logging.logger import JinaLogger
from ..types.message import Message
from ..types.request import Request
from ..types.document import Document
from ..types.request import Request
from ..types.sets import QueryLangSet, DocumentSet


Expand Down Expand Up @@ -200,6 +200,13 @@ def queryset(self) -> 'QueryLangSet':
else:
return []

@property
def docs(self):
if self.expect_parts > 1:
return (d for r in reversed(self.partial_reqs) for d in r.docs)
else:
return self.req.docs

@property
def logger(self) -> 'JinaLogger':
"""Shortcut to ``self.runtime.logger``"""
Expand Down Expand Up @@ -239,6 +246,15 @@ def __init__(self, traversal_paths: Tuple[str] = ('c', 'r'), *args, **kwargs):
super().__init__(*args, **kwargs)
self._traversal_paths = [path.lower() for path in traversal_paths]

def _apply_root(
self,
docs: 'DocumentSet',
field: str,
*args,
**kwargs,
) -> None:
return self._apply_all(docs, None, field, *args, **kwargs)

# TODO(Han): probably want to publicize this, as it is not obvious for driver
# developer which one should be inherited
def _apply_all(
Expand All @@ -256,20 +272,13 @@ def _apply_all(
:param field: where ``docs`` comes from, either ``matches`` or ``chunks``
"""

@property
def docs(self):
if self.expect_parts > 1:
return (d for r in reversed(self.partial_reqs) for d in r.docs)
else:
return self.req.docs

def __call__(self, *args, **kwargs):
self._traverse_apply(self.docs, *args, **kwargs)

def _traverse_apply(self, docs: 'DocumentSet', *args, **kwargs) -> None:
for path in self._traversal_paths:
if path[0] == 'r':
self._traverse_rec(docs, None, None, [], *args, **kwargs)
self._apply_root(docs, 'docs', *args, **kwargs)
for doc in docs:
self._traverse_rec(
[doc],
Expand Down
3 changes: 0 additions & 3 deletions jina/drivers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

if False:
from ..types.sets import DocumentSet
from ..types.document import Document


class ConvertDriver(BaseRecursiveDriver):
Expand All @@ -14,8 +13,6 @@ def __init__(self, convert_fn: str, *args, **kwargs):
def _apply_all(
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs,
) -> None:
Expand Down
2 changes: 0 additions & 2 deletions jina/drivers/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def __call__(self, *args, **kwargs):
def _apply_all(
self,
docs: Iterator['DocGroundtruthPair'],
context_doc: 'DocGroundtruthPair' = None,
field: str = None,
*args,
**kwargs
) -> None:
Expand Down
5 changes: 0 additions & 5 deletions jina/drivers/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from . import BaseExecutableDriver
from ..helper import typename
from ..types.document import Document

if False:
from ..types.sets import DocumentSet
Expand Down Expand Up @@ -32,8 +31,6 @@ def __init__(self, output_tag: str = 'prediction', *args, **kwargs):
def _apply_all(
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -147,8 +144,6 @@ class Prediction2DocBlobDriver(BasePredictDriver):
def _apply_all(
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs,
) -> None:
Expand Down
9 changes: 9 additions & 0 deletions jina/drivers/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@ def __call__(self, *args, **kwargs):
self._traverse_apply(self.docs, *args, **kwargs)
self.doc_pointers.clear()

def _apply_root(self, docs: 'DocumentSet', field: str, *args, **kwargs):
docs = []
for doc in self.docs:
docs.append(doc)
request = self.msg.request
request.body.ClearField(field)
request.docs.extend(docs)

def _apply_all(
self,
docs: 'DocumentSet',
context_doc: 'Document',
field: str,
*args,
**kwargs) -> None:

if context_doc.id not in self.doc_pointers:
self.doc_pointers[context_doc.id] = context_doc
else:
Expand Down
12 changes: 12 additions & 0 deletions jina/resources/executors._merge_root.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!BaseExecutor
with: {}
metas:
name: merge_root
requests:
on:
[SearchRequest, TrainRequest, IndexRequest, DeleteRequest, UpdateRequest]:
- !ReduceAllDriver
with:
traversal_paths: ['r']
ControlRequest:
- !ControlReqDriver {}
60 changes: 60 additions & 0 deletions tests/integration/sharding/test_search_non_existent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import random
import string

import numpy as np
import pytest

from jina import Document, Flow

random.seed(0)
np.random.seed(0)

cur_dir = os.path.dirname(os.path.abspath(__file__))


@pytest.fixture
def config(tmpdir):
os.environ['JINA_SHARDING_DIR'] = str(tmpdir)
yield
del os.environ['JINA_SHARDING_DIR']


def random_docs(start, end, embed_dim=10):
for j in range(start, end):
d = Document()
d.id = f'{j:0>16}'
d.tags['id'] = j
d.text = ''.join(random.choice(string.ascii_lowercase) for _ in range(10)).encode('utf8')
d.embedding = np.random.random([embed_dim])
yield d


def test_search_non_existent(config, mocker):
yaml_file = 'index_kv_simple.yml'

def validate_results(resp):
mock()
assert len(resp.docs) == 3

with Flow().add(
uses=os.path.join(cur_dir, 'yaml', yaml_file),
shards=2,
separated_workspace=True,
) as index_flow:
index_flow.index(input_fn=random_docs(0, 3), request_size=1)

mock = mocker.Mock()
with Flow(read_only=True).add(
show_exc_info=True,
uses=os.path.join(cur_dir, 'yaml', yaml_file),
shards=2,
separated_workspace=True,
uses_after='_merge_root',
polling='all'
) as search_flow:
search_flow.search(input_fn=random_docs(0, 5),
on_done=validate_results,
request_size=5
)
mock.assert_called_once()
20 changes: 20 additions & 0 deletions tests/integration/sharding/yaml/index_kv_simple.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
!BinaryPbIndexer
with:
index_filename: doc.gzip
metas:
name: kvidx
workspace: $JINA_SHARDING_DIR

requests:
on:
IndexRequest:
- !KVIndexDriver
with:
executor: kvidx
traversal_paths: ['r']
SearchRequest:
- !KVSearchDriver
with:
executor: kvidx
is_merge: false
traversal_paths: ['r']
4 changes: 3 additions & 1 deletion tests/unit/yaml/test-executor-with-custom-driver.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ metas:
requests:
on:
IndexRequest:
- !DummyEncodeDriver {} # cannot be found in dummy_encode_driver.py
- !DummyEncodeDriver
with:
traversal_paths: ['c'] # cannot be found in dummy_encode_driver.py
SearchRequest:
- !EncodeDriver {}

0 comments on commit 86853bf

Please sign in to comment.