From 5e2767a43cb23873ed0583413d387091027fcbb8 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Fri, 5 Feb 2021 12:38:28 +0100 Subject: [PATCH] refactor(types): allow doc building from arbitrary json/dict (#1877) * refactor(types): allow doc building from arbitrary json/dict --- cli/__init__.py | 2 +- daemon/parser.py | 2 +- extra-requirements.txt | 3 +- jina/drivers/querylang/queryset/dunderkey.py | 2 +- jina/flow/base.py | 2 +- jina/helper.py | 4 +- jina/jaml/__init__.py | 4 +- jina/jaml/helper.py | 2 +- jina/jaml/parsers/__init__.py | 2 +- jina/types/document/__init__.py | 47 ++++++++++++++-- jina/types/score/__init__.py | 2 +- tests/unit/clients/python/test_request.py | 5 +- tests/unit/types/document/test_document.py | 58 ++++++++++++++++++++ 13 files changed, 112 insertions(+), 23 deletions(-) diff --git a/cli/__init__.py b/cli/__init__.py index 33deb88dc78b3..a60137aa14fe5 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -17,7 +17,7 @@ def _get_run_args(print_args: bool = True): from pkg_resources import resource_filename p = parser._actions[-1].choices[sys.argv[1]] default_args = {a.dest: a.default for a in p._actions if - isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)} + isinstance(a, (_StoreAction, _StoreTrueAction))} with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp: logo_str = fp.read() diff --git a/daemon/parser.py b/daemon/parser.py index d9ab2f476943b..fc4b7457294c7 100644 --- a/daemon/parser.py +++ b/daemon/parser.py @@ -44,7 +44,7 @@ def _get_run_args(print_args: bool = True): if print_args: from pkg_resources import resource_filename default_args = {a.dest: a.default for a in parser._actions if - isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)} + isinstance(a, (_StoreAction, _StoreTrueAction))} with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp: logo_str = fp.read() diff --git a/extra-requirements.txt b/extra-requirements.txt index 9052cbfe9ba55..184ead3495f89 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -63,8 +63,7 @@ flaky: test mock: test requests: http, devel, test, daemon prettytable: devel, test -sseclient-py: test -optuna: test, optimizer +optuna: cicd, optimizer websockets: http, devel, test, ws, daemon wsproto: http, devel, test, ws, daemon pydantic: http, devel, test, daemon diff --git a/jina/drivers/querylang/queryset/dunderkey.py b/jina/drivers/querylang/queryset/dunderkey.py index 308dde7b99b02..40daf8384bc70 100644 --- a/jina/drivers/querylang/queryset/dunderkey.py +++ b/jina/drivers/querylang/queryset/dunderkey.py @@ -128,7 +128,7 @@ def dunder_get(_dict: Any, key: str) -> Any: if isinstance(part1, int): result = guard_iter(_dict)[part1] - elif isinstance(_dict, dict) or isinstance(_dict, Struct): + elif isinstance(_dict, (dict, Struct)): if part1 in _dict: result = _dict[part1] else: diff --git a/jina/flow/base.py b/jina/flow/base.py index 094c4d099ffdf..2be01a2ceabb3 100644 --- a/jina/flow/base.py +++ b/jina/flow/base.py @@ -107,7 +107,7 @@ def _parse_endpoints(op_flow, pod_name, endpoint, connect_to_last_pod=False) -> else: endpoint = [] - if isinstance(endpoint, list) or isinstance(endpoint, tuple): + if isinstance(endpoint, (list, tuple)): for idx, s in enumerate(endpoint): if s == pod_name: raise FlowTopologyError('the income/output of a pod can not be itself') diff --git a/jina/helper.py b/jina/helper.py index cd8800a70608f..48ebb8f72f7cf 100644 --- a/jina/helper.py +++ b/jina/helper.py @@ -278,14 +278,14 @@ def _scan(sub_d: Union[Dict, List], p): def _replace(sub_d: Union[Dict, List], p): if isinstance(sub_d, Dict): for k, v in sub_d.items(): - if isinstance(v, dict) or isinstance(v, list): + if isinstance(v, (dict, list)): _replace(v, p.__dict__[k]) else: if isinstance(v, str) and pat.findall(v): sub_d[k] = _sub(v, p) elif isinstance(sub_d, List): for idx, v in enumerate(sub_d): - if isinstance(v, dict) or isinstance(v, list): + if isinstance(v, (dict, list)): _replace(v, p[idx]) else: if isinstance(v, str) and pat.findall(v): diff --git a/jina/jaml/__init__.py b/jina/jaml/__init__.py index 72ec1540254af..50a4e26f8422a 100644 --- a/jina/jaml/__init__.py +++ b/jina/jaml/__init__.py @@ -127,7 +127,7 @@ def _scan(sub_d, p): def _replace(sub_d, p, resolve_ref=False): if isinstance(sub_d, dict): for k, v in sub_d.items(): - if isinstance(v, dict) or isinstance(v, list): + if isinstance(v, (dict, list)): _replace(v, p.__dict__[k], resolve_ref) else: if isinstance(v, str): @@ -137,7 +137,7 @@ def _replace(sub_d, p, resolve_ref=False): sub_d[k] = _sub(v) elif isinstance(sub_d, list): for idx, v in enumerate(sub_d): - if isinstance(v, dict) or isinstance(v, list): + if isinstance(v, (dict, list)): _replace(v, p[idx], resolve_ref) else: if isinstance(v, str): diff --git a/jina/jaml/helper.py b/jina/jaml/helper.py index 46fba3fa560a0..a879b8b2c2cc8 100644 --- a/jina/jaml/helper.py +++ b/jina/jaml/helper.py @@ -184,7 +184,7 @@ def _finditem(obj, key='py_modules'): value = obj.get(key, []) if isinstance(value, str): mod.append(value) - elif isinstance(value, list) or isinstance(value, tuple): + elif isinstance(value, (list, tuple)): mod.extend(value) for k, v in obj.items(): if isinstance(v, dict): diff --git a/jina/jaml/parsers/__init__.py b/jina/jaml/parsers/__init__.py index 88798a64c0c10..7d04d177e52cb 100644 --- a/jina/jaml/parsers/__init__.py +++ b/jina/jaml/parsers/__init__.py @@ -55,7 +55,7 @@ def get_parser(cls: Type['JAMLCompatible'], version: Optional[str]) -> 'Versione """ all_parsers, legacy_parser = _get_all_parser(cls) if version: - if isinstance(version, float) or isinstance(version, int): + if isinstance(version, (float, int)): version = str(version) for p in all_parsers: if p.version == version: diff --git a/jina/types/document/__init__.py b/jina/types/document/__init__.py index c2cf49b65bc85..350f4dc838865 100644 --- a/jina/types/document/__init__.py +++ b/jina/types/document/__init__.py @@ -1,4 +1,5 @@ import base64 +import json import mimetypes import os import urllib.parse @@ -29,6 +30,9 @@ DocumentSourceType = TypeVar('DocumentSourceType', jina_pb2.DocumentProto, bytes, str, Dict) +_document_fields = set(list(jina_pb2.DocumentProto().DESCRIPTOR.fields_by_camelcase_name) + list( + jina_pb2.DocumentProto().DESCRIPTOR.fields_by_name)) + class Document(ProtoTypeMixin): """ @@ -93,6 +97,7 @@ class Document(ProtoTypeMixin): """ def __init__(self, document: Optional[DocumentSourceType] = None, + field_resolver: Dict[str, str] = None, copy: bool = False, **kwargs): """ @@ -104,7 +109,24 @@ def __init__(self, document: Optional[DocumentSourceType] = None, it builds a view or a copy from it. :param copy: when ``document`` is given as a :class:`DocumentProto` object, build a view (i.e. weak reference) from it or a deep copy from it. - :param kwargs: other parameters to be set + :param field_resolver: a map from field names defined in ``document`` (JSON, dict) to the field + names defined in Protobuf. This is only used when the given ``document`` is + a JSON string or a Python dict. + :param kwargs: other parameters to be set _after_ the document is constructed + + .. note:: + + When ``document`` is a JSON string or Python dictionary object, the constructor will only map the values + from known fields defined in Protobuf, all unknown fields are mapped to ``document.tags``. For example, + + .. highlight:: python + .. code-block:: python + + d = Document({'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}}) + + assert d.id == '123' # true + assert d.tags['hello'] == 'world' # true + assert d.tags['good'] == 'bye' # true """ self._pb_body = jina_pb2.DocumentProto() try: @@ -113,10 +135,23 @@ def __init__(self, document: Optional[DocumentSourceType] = None, self._pb_body.CopyFrom(document) else: self._pb_body = document - elif isinstance(document, dict): - json_format.ParseDict(document, self._pb_body) - elif isinstance(document, str): - json_format.Parse(document, self._pb_body) + elif isinstance(document, (dict, str)): + if isinstance(document, str): + document = json.loads(document) + + if field_resolver: + document = {field_resolver.get(k, k): v for k, v in document.items()} + + user_fields = set(document.keys()) + if _document_fields.issuperset(user_fields): + json_format.ParseDict(document, self._pb_body) + else: + _intersect = _document_fields.intersection(user_fields) + _remainder = user_fields.difference(_intersect) + if _intersect: + json_format.ParseDict({k: document[k] for k in _intersect}, self._pb_body) + if _remainder: + self._pb_body.tags.update({k: document[k] for k in _remainder}) elif isinstance(document, bytes): # directly parsing from binary string gives large false-positive # fortunately protobuf throws a warning when the parsing seems go wrong @@ -318,7 +353,7 @@ def set_attrs(self, **kwargs): """ for k, v in kwargs.items(): - if isinstance(v, list) or isinstance(v, tuple): + if isinstance(v, (list, tuple)): if k == 'chunks': self.chunks.extend(v) elif k == 'matches': diff --git a/jina/types/score/__init__.py b/jina/types/score/__init__.py index 777e820178a79..a71d7d1e8f270 100644 --- a/jina/types/score/__init__.py +++ b/jina/types/score/__init__.py @@ -72,7 +72,7 @@ def set_attrs(self, **kwargs): """ for k, v in kwargs.items(): - if isinstance(v, list) or isinstance(v, tuple): + if isinstance(v, (list, tuple)): self._pb_body.ClearField(k) getattr(self._pb_body, k).extend(v) elif isinstance(v, dict): diff --git a/tests/unit/clients/python/test_request.py b/tests/unit/clients/python/test_request.py index 0c294e86f1913..99b840b74a0e7 100644 --- a/tests/unit/clients/python/test_request.py +++ b/tests/unit/clients/python/test_request.py @@ -1,4 +1,5 @@ import os +import sys import numpy as np import pytest @@ -13,7 +14,6 @@ from jina.proto.jina_pb2 import DocumentProto from jina.types.ndarray.generic import NdArray -import sys @pytest.mark.skipif(sys.version_info < (3, 8, 0), reason='somehow this does not work on Github workflow with Py3.7, ' 'but Py 3.8 is fine, local Py3.7 is fine') @@ -45,9 +45,6 @@ def test_data_type_builder_doc_bad(): with pytest.raises(BadDocType): _new_doc_from_data(MessageToJson(a) + '🍔', DataInputType.DOCUMENT) - with pytest.raises(BadDocType): - _new_doc_from_data({'🍔': '🍔'}, DataInputType.DOCUMENT) - @pytest.mark.parametrize('input_type', [DataInputType.AUTO, DataInputType.CONTENT]) def test_data_type_builder_auto(input_type): diff --git a/tests/unit/types/document/test_document.py b/tests/unit/types/document/test_document.py index 994f9cdb06a3e..b0776b070c77d 100644 --- a/tests/unit/types/document/test_document.py +++ b/tests/unit/types/document/test_document.py @@ -1,3 +1,5 @@ +import json + import numpy as np import pytest from google.protobuf.json_format import MessageToDict @@ -316,3 +318,59 @@ def build_document(chunk=None): d2.chunks.clear() d2.update_content_hash(include_fields=('chunks',), exclude_fields=None) assert d1.content_hash != d2.content_hash + + +@pytest.mark.parametrize('from_str', [True, False]) +@pytest.mark.parametrize('d_src', [ + {'id': '123', 'mime_type': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}}, + {'id': '123', 'mimeType': 'txt', 'parentId': '456', 'tags': {'hello': 'world'}}, + {'id': '123', 'mimeType': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}}, +]) +def test_doc_from_dict_cases(d_src, from_str): + # regular case + if from_str: + d_src = json.dumps(d_src) + d = Document(d_src) + assert d.tags['hello'] == 'world' + assert d.mime_type == 'txt' + assert d.id == '123' + assert d.parent_id == '456' + + +@pytest.mark.parametrize('from_str', [True, False]) +def test_doc_arbitrary_dict(from_str): + d_src = {'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}} + if from_str: + d_src = json.dumps(d_src) + d = Document(d_src) + assert d.id == '123' + assert d.tags['hello'] == 'world' + assert d.tags['good'] == 'bye' + + d_src = {'hello': 'world', 'good': 'bye'} + if from_str: + d_src = json.dumps(d_src) + d = Document(d_src) + assert d.tags['hello'] == 'world' + assert d.tags['good'] == 'bye' + + +@pytest.mark.parametrize('from_str', [True, False]) +def test_doc_field_resolver(from_str): + d_src = {'music_id': '123', 'hello': 'world', 'tags': {'good': 'bye'}} + if from_str: + d_src = json.dumps(d_src) + d = Document(d_src) + assert d.id != '123' + assert d.tags['hello'] == 'world' + assert d.tags['good'] == 'bye' + assert d.tags['music_id'] == '123' + + d_src = {'music_id': '123', 'hello': 'world', 'tags': {'good': 'bye'}} + if from_str: + d_src = json.dumps(d_src) + d = Document(d_src, field_resolver={'music_id': 'id'}) + assert d.id == '123' + assert d.tags['hello'] == 'world' + assert d.tags['good'] == 'bye' + assert 'music_id' not in d.tags