Skip to content

Commit

Permalink
fix(drivers): fix the default value for depth_range
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed Aug 13, 2020
1 parent 80d368d commit 2e793ad
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 9 deletions.
10 changes: 7 additions & 3 deletions jina/drivers/__init__.py
Expand Up @@ -170,11 +170,11 @@ def __getstate__(self) -> Dict[str, Any]:

class BaseRecursiveDriver(BaseDriver):

def __init__(self, depth_range: Tuple[int] = (0, 0), apply_order: str = 'post',
def __init__(self, depth_range: Tuple[int] = (0, 1), apply_order: str = 'post',
traverse_on: Tuple[str] = ('chunks',), *args, **kwargs):
"""
:param depth_range: right-exclusive range of the recursion depth, (0,0) for root-level only
:param depth_range: right-exclusive range of the recursion depth, (0, 1) for root-level only
:param apply_order: the traverse and apply order. if 'post' then first traverse then call apply, if 'pre' then first apply then traverse
:param args:
:param kwargs:
Expand Down Expand Up @@ -226,8 +226,10 @@ def post_traverse(_docs, traverse_on, context_doc=None):
"""
if _docs:
for d in _docs:
# check if apply to next level
if d.level_depth < self._depth_end:
post_traverse(getattr(d, traverse_on), traverse_on, d)
# check if apply to the current level
if self.is_apply and (d.level_depth >= self._depth_start):
self._apply(d, context_doc, traverse_on, *args, **kwargs)

Expand All @@ -242,9 +244,11 @@ def pre_traverse(_docs, traverse_on, context_doc=None):
self._apply_all(_docs, context_doc, traverse_on, *args, **kwargs)

for d in _docs:
# check if apply on the current level
if self.is_apply and d.level_depth >= self._depth_start:
self._apply(d, context_doc, traverse_on, *args, **kwargs)
if d.level_depth < self._depth_end:
# check if apply to the next level
if (d.level_depth + 1) < self._depth_end:
pre_traverse(getattr(d, traverse_on), traverse_on, d)

if self.recursion_order == 'post':
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/drivers/querylang/test_querylang_drivers.py
Expand Up @@ -128,7 +128,7 @@ def validate(req):

f = (Flow().add(uses='DummyModeIdSegmenter')
.add(
uses='- !FilterQL | {lookups: {modality: mode2}, traverse_on: [chunks], depth_range: [1, 1]}'))
uses='- !FilterQL | {lookups: {modality: mode2}, traverse_on: [chunks], depth_range: [1, 2]}'))

with f:
f.index(random_docs_with_chunks(), output_fn=validate, callback_on_body=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/drivers/yaml/mockencoder-mode1.yml
Expand Up @@ -6,9 +6,9 @@ requests:
with:
lookups: {modality: mode1}
traverse_on: [chunks]
depth_range: [1, 1]
depth_range: [1, 2]
- !EncodeDriver
with:
method: encode
traverse_on: [chunks]
depth_range: [1, 1]
depth_range: [1, 2]
4 changes: 2 additions & 2 deletions tests/unit/drivers/yaml/mockencoder-mode2.yml
Expand Up @@ -6,9 +6,9 @@ requests:
with:
lookups: {modality: mode2}
traverse_on: [chunks]
depth_range: [1, 1]
depth_range: [1, 2]
- !EncodeDriver
with:
method: encode
traverse_on: [chunks]
depth_range: [1, 1]
depth_range: [1, 2]
2 changes: 1 addition & 1 deletion tests/unit/flow/test_flow.py
Expand Up @@ -508,7 +508,7 @@ def input_fn():

flow = Flow().add(name='chunk_seg', parallel=3, uses='_pass').\
add(name='encoder12', parallel=2,
uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, depth_range: [0, 0]}')
uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, depth_range: [0, 1]}')
with flow:
flow.index(input_fn=input_fn, output_fn=validate)

Expand Down

0 comments on commit 2e793ad

Please sign in to comment.