diff --git a/jina/drivers/__init__.py b/jina/drivers/__init__.py index 98d1c0de2a261..601ccd9f3483a 100644 --- a/jina/drivers/__init__.py +++ b/jina/drivers/__init__.py @@ -170,11 +170,11 @@ def __getstate__(self) -> Dict[str, Any]: class BaseRecursiveDriver(BaseDriver): - def __init__(self, depth_range: Tuple[int] = (0, 0), apply_order: str = 'post', + def __init__(self, depth_range: Tuple[int] = (0, 1), apply_order: str = 'post', traverse_on: Tuple[str] = ('chunks',), *args, **kwargs): """ - :param depth_range: right-exclusive range of the recursion depth, (0,0) for root-level only + :param depth_range: right-exclusive range of the recursion depth, (0, 1) for root-level only :param apply_order: the traverse and apply order. if 'post' then first traverse then call apply, if 'pre' then first apply then traverse :param args: :param kwargs: @@ -226,8 +226,10 @@ def post_traverse(_docs, traverse_on, context_doc=None): """ if _docs: for d in _docs: + # check if apply to next level if d.level_depth < self._depth_end: post_traverse(getattr(d, traverse_on), traverse_on, d) + # check if apply to the current level if self.is_apply and (d.level_depth >= self._depth_start): self._apply(d, context_doc, traverse_on, *args, **kwargs) @@ -242,9 +244,11 @@ def pre_traverse(_docs, traverse_on, context_doc=None): self._apply_all(_docs, context_doc, traverse_on, *args, **kwargs) for d in _docs: + # check if apply on the current level if self.is_apply and d.level_depth >= self._depth_start: self._apply(d, context_doc, traverse_on, *args, **kwargs) - if d.level_depth < self._depth_end: + # check if apply to the next level + if (d.level_depth + 1) < self._depth_end: pre_traverse(getattr(d, traverse_on), traverse_on, d) if self.recursion_order == 'post': diff --git a/tests/unit/drivers/querylang/test_querylang_drivers.py b/tests/unit/drivers/querylang/test_querylang_drivers.py index 4fdec40eaaca5..6c7816f104f26 100644 --- a/tests/unit/drivers/querylang/test_querylang_drivers.py +++ b/tests/unit/drivers/querylang/test_querylang_drivers.py @@ -128,7 +128,7 @@ def validate(req): f = (Flow().add(uses='DummyModeIdSegmenter') .add( - uses='- !FilterQL | {lookups: {modality: mode2}, traverse_on: [chunks], depth_range: [1, 1]}')) + uses='- !FilterQL | {lookups: {modality: mode2}, traverse_on: [chunks], depth_range: [1, 2]}')) with f: f.index(random_docs_with_chunks(), output_fn=validate, callback_on_body=True) diff --git a/tests/unit/drivers/yaml/mockencoder-mode1.yml b/tests/unit/drivers/yaml/mockencoder-mode1.yml index 1fe00a5fb32e6..fd0fdf31e04bf 100644 --- a/tests/unit/drivers/yaml/mockencoder-mode1.yml +++ b/tests/unit/drivers/yaml/mockencoder-mode1.yml @@ -6,9 +6,9 @@ requests: with: lookups: {modality: mode1} traverse_on: [chunks] - depth_range: [1, 1] + depth_range: [1, 2] - !EncodeDriver with: method: encode traverse_on: [chunks] - depth_range: [1, 1] + depth_range: [1, 2] diff --git a/tests/unit/drivers/yaml/mockencoder-mode2.yml b/tests/unit/drivers/yaml/mockencoder-mode2.yml index d144f8055ac81..1fda352eccb0b 100644 --- a/tests/unit/drivers/yaml/mockencoder-mode2.yml +++ b/tests/unit/drivers/yaml/mockencoder-mode2.yml @@ -6,9 +6,9 @@ requests: with: lookups: {modality: mode2} traverse_on: [chunks] - depth_range: [1, 1] + depth_range: [1, 2] - !EncodeDriver with: method: encode traverse_on: [chunks] - depth_range: [1, 1] + depth_range: [1, 2] diff --git a/tests/unit/flow/test_flow.py b/tests/unit/flow/test_flow.py index 4ec96649c6347..8ce72ef7507bd 100644 --- a/tests/unit/flow/test_flow.py +++ b/tests/unit/flow/test_flow.py @@ -508,7 +508,7 @@ def input_fn(): flow = Flow().add(name='chunk_seg', parallel=3, uses='_pass').\ add(name='encoder12', parallel=2, - uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, depth_range: [0, 0]}') + uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, depth_range: [0, 1]}') with flow: flow.index(input_fn=input_fn, output_fn=validate)