Skip to content

Commit

Permalink
fix(jaml): remove duplicate py_modules while parsing yaml (#3830)
Browse files Browse the repository at this point in the history
  • Loading branch information
bwanglzu committed Nov 1, 2021
1 parent d22f471 commit d78530b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
11 changes: 4 additions & 7 deletions jina/jaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,13 +613,10 @@ def _delitem(

@classmethod
def _override_yml_params(cls, raw_yaml, field_name, override_field):
if override_field is not None:
field_params = raw_yaml.get(field_name, None)
if field_params:
field_params.update(**override_field)
raw_yaml.update(field_params)
else:
raw_yaml[field_name] = override_field
if override_field:
field_params = raw_yaml.get(field_name, {})
field_params.update(**override_field)
raw_yaml[field_name] = field_params

@staticmethod
def is_valid_jaml(obj: Dict) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions jina/jaml/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,14 @@ def load_py_modules(d: Dict, extra_search_paths: Optional[List[str]] = None) ->
:param d: the dictionary to traverse
:param extra_search_paths: any extra paths to search
"""
mod = []
mod = set()

def _finditem(obj, key='py_modules'):
value = obj.get(key, [])
if isinstance(value, str):
mod.append(value)
mod.add(value)
elif isinstance(value, (list, tuple)):
mod.extend(value)
mod.update(value)
for k, v in obj.items():
if isinstance(v, dict):
_finditem(v, key)
Expand Down
44 changes: 43 additions & 1 deletion tests/unit/jaml/test_type_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jina.enums import SocketType
from jina.executors import BaseExecutor
from jina.jaml import JAML
from jina.jaml import JAML, JAMLCompatible
from jina import __default_executor__, requests


Expand Down Expand Up @@ -101,3 +101,45 @@ def test_cls_from_tag():
assert JAML.cls_from_tag('!MyExec') == MyExec
assert JAML.cls_from_tag('BaseExecutor') == BaseExecutor
assert JAML.cls_from_tag('Nonexisting') is None


@pytest.mark.parametrize(
'field_name, override_field',
[
('with', None),
('metas', None),
('requests', None),
('with', {'a': 456, 'b': 'updated-test'}),
(
'metas',
{'name': 'test-name-updated', 'workspace': 'test-work-space-updated'},
),
('requests', {'/foo': 'baz'}),
# assure py_modules only occurs once #3830
(
'metas',
{
'name': 'test-name-updated',
'workspace': 'test-work-space-updated',
'py_modules': 'test_module.py',
},
),
],
)
def test_override_yml_params(field_name, override_field):
original_raw_yaml = {
'jtype': 'SimpleIndexer',
'with': {'a': 123, 'b': 'test'},
'metas': {'name': 'test-name', 'workspace': 'test-work-space'},
'requests': {'/foo': 'bar'},
}
updated_raw_yaml = original_raw_yaml
JAMLCompatible()._override_yml_params(updated_raw_yaml, field_name, override_field)
if override_field:
assert updated_raw_yaml[field_name] == override_field
else:
assert original_raw_yaml == updated_raw_yaml
# assure we don't create py_modules twice
if override_field == 'metas' and 'py_modules' in override_field:
assert 'py_modules' in updated_raw_yaml['metas']
assert 'py_modules' not in updated_raw_yaml

0 comments on commit d78530b

Please sign in to comment.