Skip to content

Commit

Permalink
feat(jaml): add keyword jtype to represent YAML tag (#2044)
Browse files Browse the repository at this point in the history
* feat(jaml): add keyword jtype to represent YAML tag

* feat(parser): add description field

* feat(jaml): add keyword jtype to represent YAML tag
  • Loading branch information
hanxiao committed Feb 25, 2021
1 parent 974edbb commit 070ebde
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 9 deletions.
46 changes: 42 additions & 4 deletions jina/jaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import warnings
from types import SimpleNamespace
from typing import Dict, Any, Union, TextIO, Optional
from typing import Dict, Any, Union, TextIO, Optional, List

import yaml
from yaml.constructor import FullConstructor
Expand Down Expand Up @@ -91,6 +91,44 @@ def load(stream,
r = JAML.expand_dict(r, context)
return r

@staticmethod
def escape(value: str, include_unknown_tags: bool = True) -> str:
"""
Escape the YAML content by replacing all customized tags ``!`` to ``jtype: ``.
:param value: the original YAML content
:param include_unknown_tags: if to include unknown tags during escaping
"""
if include_unknown_tags:
r = r'!(\w+)\b'
else:
r = '|'.join(JAML.registered_tags())
r = rf'!({r})\b'
return re.sub(r, r'jtype: \1', value)

@staticmethod
def unescape(value: str, include_unknown_tags: bool = True) -> str:
"""
Unescape the YAML content by replacing all ``jtype: `` to tags.
:param value: the escaped YAML content
:param include_unknown_tags: if to include unknown tags during unescaping
"""
if include_unknown_tags:
r = r'jtype: (\w+)\b'
else:
r = '|'.join(JAML.registered_tags())
r = rf'jtype: ({r})\b'
return re.sub(r, r'!\1', value)

@staticmethod
def registered_tags() -> List[str]:
"""
Return a list of :class:`JAMLCompatible` classes that have been registered.
"""
return list(v[1:] for v in set(JinaLoader.yaml_constructors.keys()) if v and v.startswith('!'))

@staticmethod
def load_no_tags(stream, **kwargs):
"""
Expand All @@ -100,7 +138,7 @@ def load_no_tags(stream, **kwargs):
:param **kwargs: other kwargs
:return: the Python object
"""
safe_yml = '\n'.join(v if not re.match(r'^[\s-]*?!\b', v) else v.replace('!', '__cls: ') for v in stream)
safe_yml = JAML.escape('\n'.join(v for v in stream))
return JAML.load(safe_yml, **kwargs)

@staticmethod
Expand Down Expand Up @@ -430,10 +468,10 @@ def load_config(cls,
# also add YAML parent path to the search paths
load_py_modules(no_tag_yml, extra_search_paths=(os.path.dirname(s_path),) if s_path else None)
# revert yaml's tag and load again, this time with substitution
revert_tag_yml = JAML.dump(no_tag_yml).replace('__cls: ', '!')
tag_yml = JAML.unescape(JAML.dump(no_tag_yml))

# load into object, no more substitute
return JAML.load(revert_tag_yml, substitute=False)
return JAML.load(tag_yml, substitute=False)

@classmethod
def inject_config(cls, raw_config: Dict, *args, **kwargs) -> Dict:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/flow/test_flow_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_flow_uses_from_dict():
class DummyEncoder(BaseEncoder):
pass

d1 = {'__cls': 'DummyEncoder',
d1 = {'jtype': 'DummyEncoder',
'metas': {'name': 'dummy1'}}
with Flow().add(uses=d1):
pass
183 changes: 183 additions & 0 deletions tests/unit/jaml/test_type_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import pytest

from jina.executors import BaseExecutor
from jina.jaml import JAML


class MyExecutor(BaseExecutor):
pass


def test_non_empty_reg_tags():
assert JAML.registered_tags()
assert 'BaseExecutor' in JAML.registered_tags()


@pytest.mark.parametrize('include_unk, expected', [
(True, '''
jtype: BaseExecutor {}
jtype: Blah {}
'''
),
(False, '''
jtype: BaseExecutor {}
!Blah {}
'''
)
])
def test_include_unknown(include_unk, expected):
y = '''
!BaseExecutor {}
!Blah {}
'''
assert JAML.escape(y, include_unknown_tags=include_unk).strip() == expected.strip()


@pytest.mark.parametrize('original, escaped',
[(
'''
!BaseExecutor {}
!Blah {}
!MyExecutor {}
''',
'''
jtype: BaseExecutor {}
!Blah {}
jtype: MyExecutor {}
'''
), (
'''
!BaseExecutor
with:
a: 123
b: BaseExecutor
jtype: unknown-blah
''',
'''
jtype: BaseExecutor
with:
a: 123
b: BaseExecutor
jtype: unknown-blah
'''
), (
'''
!CompoundIndexer
components:
- !NumpyIndexer
with:
index_filename: vec.gz
metric: euclidean
metas:
name: vecidx
- !BinaryPbIndexer
with:
index_filename: doc.gz
metas:
name: docidx
metas:
name: indexer
workspace: $JINA_WORKSPACE
''',
'''
jtype: CompoundIndexer
components:
- jtype: NumpyIndexer
with:
index_filename: vec.gz
metric: euclidean
metas:
name: vecidx
- jtype: BinaryPbIndexer
with:
index_filename: doc.gz
metas:
name: docidx
metas:
name: indexer
workspace: $JINA_WORKSPACE
'''
), (
'''
!CompoundIndexer
metas:
workspace: $TMP_WORKSPACE
components:
- !NumpyIndexer
with:
metric: euclidean
index_filename: vec.gz
metas:
name: vecidx # a customized name
- !BinaryPbIndexer
with:
index_filename: chunk.gz
metas:
name: kvidx # a customized name
requests:
on:
IndexRequest:
- !VectorIndexDriver
with:
executor: NumpyIndexer
filter_by: $FILTER_BY
- !KVIndexDriver
with:
executor: BinaryPbIndexer
filter_by: $FILTER_BY
SearchRequest:
- !VectorSearchDriver
with:
executor: NumpyIndexer
filter_by: $FILTER_BY
- !KVSearchDriver
with:
executor: BinaryPbIndexer
filter_by: $FILTER_BY
''',
'''
jtype: CompoundIndexer
metas:
workspace: $TMP_WORKSPACE
components:
- jtype: NumpyIndexer
with:
metric: euclidean
index_filename: vec.gz
metas:
name: vecidx # a customized name
- jtype: BinaryPbIndexer
with:
index_filename: chunk.gz
metas:
name: kvidx # a customized name
requests:
on:
IndexRequest:
- jtype: VectorIndexDriver
with:
executor: NumpyIndexer
filter_by: $FILTER_BY
- jtype: KVIndexDriver
with:
executor: BinaryPbIndexer
filter_by: $FILTER_BY
SearchRequest:
- jtype: VectorSearchDriver
with:
executor: NumpyIndexer
filter_by: $FILTER_BY
- jtype: KVSearchDriver
with:
executor: BinaryPbIndexer
filter_by: $FILTER_BY
'''
),
])
def test_escape(original, escaped):
assert JAML.escape(original, include_unknown_tags=False).strip() == escaped.strip()
assert JAML.unescape(JAML.escape(original, include_unknown_tags=False),
include_unknown_tags=False).strip() == original.strip()
8 changes: 4 additions & 4 deletions tests/unit/test_yamlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_load_from_dict():
# workspace: ${{this.name}}-${{this.batch_size}}

d1 = {
'__cls': 'BaseEncoder',
'jtype': 'BaseEncoder',
'metas': {'name': '${{BE_TEST_NAME}}',
'batch_size': '${{BATCH_SIZE}}',
'pea_id': '${{pea_id}}',
Expand All @@ -212,16 +212,16 @@ def test_load_from_dict():
# name: compound1

d2 = {
'__cls': 'CompoundExecutor',
'jtype': 'CompoundExecutor',
'components':
[
{
'__cls': 'BinaryPbIndexer',
'jtype': 'BinaryPbIndexer',
'with': {'index_filename': 'tmp1'},
'metas': {'name': 'test1'}
},
{
'__cls': 'BinaryPbIndexer',
'jtype': 'BinaryPbIndexer',
'with': {'index_filename': 'tmp2'},
'metas': {'name': 'test2'}
},
Expand Down

0 comments on commit 070ebde

Please sign in to comment.