Skip to content

Commit

Permalink
refactor(types): allow doc building from arbitrary json/dict (#1877)
Browse files Browse the repository at this point in the history
* refactor(types): allow doc building from arbitrary json/dict
  • Loading branch information
hanxiao committed Feb 5, 2021
1 parent 8af7c41 commit 5e2767a
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cli/__init__.py
Expand Up @@ -17,7 +17,7 @@ def _get_run_args(print_args: bool = True):
from pkg_resources import resource_filename
p = parser._actions[-1].choices[sys.argv[1]]
default_args = {a.dest: a.default for a in p._actions if
isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)}
isinstance(a, (_StoreAction, _StoreTrueAction))}

with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp:
logo_str = fp.read()
Expand Down
2 changes: 1 addition & 1 deletion daemon/parser.py
Expand Up @@ -44,7 +44,7 @@ def _get_run_args(print_args: bool = True):
if print_args:
from pkg_resources import resource_filename
default_args = {a.dest: a.default for a in parser._actions if
isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)}
isinstance(a, (_StoreAction, _StoreTrueAction))}

with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp:
logo_str = fp.read()
Expand Down
3 changes: 1 addition & 2 deletions extra-requirements.txt
Expand Up @@ -63,8 +63,7 @@ flaky: test
mock: test
requests: http, devel, test, daemon
prettytable: devel, test
sseclient-py: test
optuna: test, optimizer
optuna: cicd, optimizer
websockets: http, devel, test, ws, daemon
wsproto: http, devel, test, ws, daemon
pydantic: http, devel, test, daemon
Expand Down
2 changes: 1 addition & 1 deletion jina/drivers/querylang/queryset/dunderkey.py
Expand Up @@ -128,7 +128,7 @@ def dunder_get(_dict: Any, key: str) -> Any:

if isinstance(part1, int):
result = guard_iter(_dict)[part1]
elif isinstance(_dict, dict) or isinstance(_dict, Struct):
elif isinstance(_dict, (dict, Struct)):
if part1 in _dict:
result = _dict[part1]
else:
Expand Down
2 changes: 1 addition & 1 deletion jina/flow/base.py
Expand Up @@ -107,7 +107,7 @@ def _parse_endpoints(op_flow, pod_name, endpoint, connect_to_last_pod=False) ->
else:
endpoint = []

if isinstance(endpoint, list) or isinstance(endpoint, tuple):
if isinstance(endpoint, (list, tuple)):
for idx, s in enumerate(endpoint):
if s == pod_name:
raise FlowTopologyError('the income/output of a pod can not be itself')
Expand Down
4 changes: 2 additions & 2 deletions jina/helper.py
Expand Up @@ -278,14 +278,14 @@ def _scan(sub_d: Union[Dict, List], p):
def _replace(sub_d: Union[Dict, List], p):
if isinstance(sub_d, Dict):
for k, v in sub_d.items():
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p.__dict__[k])
else:
if isinstance(v, str) and pat.findall(v):
sub_d[k] = _sub(v, p)
elif isinstance(sub_d, List):
for idx, v in enumerate(sub_d):
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p[idx])
else:
if isinstance(v, str) and pat.findall(v):
Expand Down
4 changes: 2 additions & 2 deletions jina/jaml/__init__.py
Expand Up @@ -127,7 +127,7 @@ def _scan(sub_d, p):
def _replace(sub_d, p, resolve_ref=False):
if isinstance(sub_d, dict):
for k, v in sub_d.items():
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p.__dict__[k], resolve_ref)
else:
if isinstance(v, str):
Expand All @@ -137,7 +137,7 @@ def _replace(sub_d, p, resolve_ref=False):
sub_d[k] = _sub(v)
elif isinstance(sub_d, list):
for idx, v in enumerate(sub_d):
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p[idx], resolve_ref)
else:
if isinstance(v, str):
Expand Down
2 changes: 1 addition & 1 deletion jina/jaml/helper.py
Expand Up @@ -184,7 +184,7 @@ def _finditem(obj, key='py_modules'):
value = obj.get(key, [])
if isinstance(value, str):
mod.append(value)
elif isinstance(value, list) or isinstance(value, tuple):
elif isinstance(value, (list, tuple)):
mod.extend(value)
for k, v in obj.items():
if isinstance(v, dict):
Expand Down
2 changes: 1 addition & 1 deletion jina/jaml/parsers/__init__.py
Expand Up @@ -55,7 +55,7 @@ def get_parser(cls: Type['JAMLCompatible'], version: Optional[str]) -> 'Versione
"""
all_parsers, legacy_parser = _get_all_parser(cls)
if version:
if isinstance(version, float) or isinstance(version, int):
if isinstance(version, (float, int)):
version = str(version)
for p in all_parsers:
if p.version == version:
Expand Down
47 changes: 41 additions & 6 deletions jina/types/document/__init__.py
@@ -1,4 +1,5 @@
import base64
import json
import mimetypes
import os
import urllib.parse
Expand Down Expand Up @@ -29,6 +30,9 @@
DocumentSourceType = TypeVar('DocumentSourceType',
jina_pb2.DocumentProto, bytes, str, Dict)

# Every name a JSON/dict key may use to address a ``DocumentProto`` field:
# the union of the descriptor's ``fields_by_camelcase_name`` and ``fields_by_name`` keys.
_document_fields = set(jina_pb2.DocumentProto().DESCRIPTOR.fields_by_camelcase_name).union(
    jina_pb2.DocumentProto().DESCRIPTOR.fields_by_name)


class Document(ProtoTypeMixin):
"""
Expand Down Expand Up @@ -93,6 +97,7 @@ class Document(ProtoTypeMixin):
"""

def __init__(self, document: Optional[DocumentSourceType] = None,
field_resolver: Dict[str, str] = None,
copy: bool = False, **kwargs):
"""
Expand All @@ -104,7 +109,24 @@ def __init__(self, document: Optional[DocumentSourceType] = None,
it builds a view or a copy from it.
:param copy: when ``document`` is given as a :class:`DocumentProto` object, build a
view (i.e. weak reference) from it or a deep copy from it.
:param kwargs: other parameters to be set
:param field_resolver: a map from field names defined in ``document`` (JSON, dict) to the field
names defined in Protobuf. This is only used when the given ``document`` is
a JSON string or a Python dict.
:param kwargs: other parameters to be set _after_ the document is constructed
.. note::
When ``document`` is a JSON string or Python dictionary object, the constructor will only map the values
from known fields defined in Protobuf; all unknown fields are mapped to ``document.tags``. For example,
.. highlight:: python
.. code-block:: python
d = Document({'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}})
assert d.id == '123' # true
assert d.tags['hello'] == 'world' # true
assert d.tags['good'] == 'bye' # true
"""
self._pb_body = jina_pb2.DocumentProto()
try:
Expand All @@ -113,10 +135,23 @@ def __init__(self, document: Optional[DocumentSourceType] = None,
self._pb_body.CopyFrom(document)
else:
self._pb_body = document
elif isinstance(document, dict):
json_format.ParseDict(document, self._pb_body)
elif isinstance(document, str):
json_format.Parse(document, self._pb_body)
elif isinstance(document, (dict, str)):
if isinstance(document, str):
document = json.loads(document)

if field_resolver:
document = {field_resolver.get(k, k): v for k, v in document.items()}

user_fields = set(document.keys())
if _document_fields.issuperset(user_fields):
json_format.ParseDict(document, self._pb_body)
else:
_intersect = _document_fields.intersection(user_fields)
_remainder = user_fields.difference(_intersect)
if _intersect:
json_format.ParseDict({k: document[k] for k in _intersect}, self._pb_body)
if _remainder:
self._pb_body.tags.update({k: document[k] for k in _remainder})
elif isinstance(document, bytes):
# directly parsing from a binary string gives a high false-positive rate;
# fortunately protobuf throws a warning when the parsing seems to go wrong
Expand Down Expand Up @@ -318,7 +353,7 @@ def set_attrs(self, **kwargs):
"""
for k, v in kwargs.items():
if isinstance(v, list) or isinstance(v, tuple):
if isinstance(v, (list, tuple)):
if k == 'chunks':
self.chunks.extend(v)
elif k == 'matches':
Expand Down
2 changes: 1 addition & 1 deletion jina/types/score/__init__.py
Expand Up @@ -72,7 +72,7 @@ def set_attrs(self, **kwargs):
"""
for k, v in kwargs.items():
if isinstance(v, list) or isinstance(v, tuple):
if isinstance(v, (list, tuple)):
self._pb_body.ClearField(k)
getattr(self._pb_body, k).extend(v)
elif isinstance(v, dict):
Expand Down
5 changes: 1 addition & 4 deletions tests/unit/clients/python/test_request.py
@@ -1,4 +1,5 @@
import os
import sys

import numpy as np
import pytest
Expand All @@ -13,7 +14,6 @@
from jina.proto.jina_pb2 import DocumentProto
from jina.types.ndarray.generic import NdArray

import sys

@pytest.mark.skipif(sys.version_info < (3, 8, 0), reason='somehow this does not work on Github workflow with Py3.7, '
'but Py 3.8 is fine, local Py3.7 is fine')
Expand Down Expand Up @@ -45,9 +45,6 @@ def test_data_type_builder_doc_bad():
with pytest.raises(BadDocType):
_new_doc_from_data(MessageToJson(a) + '🍔', DataInputType.DOCUMENT)

with pytest.raises(BadDocType):
_new_doc_from_data({'🍔': '🍔'}, DataInputType.DOCUMENT)


@pytest.mark.parametrize('input_type', [DataInputType.AUTO, DataInputType.CONTENT])
def test_data_type_builder_auto(input_type):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/types/document/test_document.py
@@ -1,3 +1,5 @@
import json

import numpy as np
import pytest
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -316,3 +318,59 @@ def build_document(chunk=None):
d2.chunks.clear()
d2.update_content_hash(include_fields=('chunks',), exclude_fields=None)
assert d1.content_hash != d2.content_hash


@pytest.mark.parametrize('from_str', [True, False])
@pytest.mark.parametrize('d_src', [
    {'id': '123', 'mime_type': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}},
    {'id': '123', 'mimeType': 'txt', 'parentId': '456', 'tags': {'hello': 'world'}},
    {'id': '123', 'mimeType': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}},
])
def test_doc_from_dict_cases(d_src, from_str):
    """Build a Document from a dict, or its JSON dump, that uses only known field names.

    The parametrized sources spell the field names in snake_case, in camelCase,
    and as a mix of both; each must populate the same attributes.
    """
    # regular case: feed either the dict itself or its JSON serialization
    source = json.dumps(d_src) if from_str else d_src
    doc = Document(source)
    assert doc.tags['hello'] == 'world'
    assert doc.mime_type == 'txt'
    assert doc.id == '123'
    assert doc.parent_id == '456'


@pytest.mark.parametrize('from_str', [True, False])
def test_doc_arbitrary_dict(from_str):
    """Keys that are not Document fields must land in ``tags``; known keys stay as fields."""

    def _build(payload):
        # Hand the constructor either the raw dict or its JSON string form.
        return Document(json.dumps(payload) if from_str else payload)

    # mix of a known field (`id`), an unknown key (`hello`) and explicit tags
    mixed = _build({'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}})
    assert mixed.id == '123'
    assert mixed.tags['hello'] == 'world'
    assert mixed.tags['good'] == 'bye'

    # only unknown keys
    unknown_only = _build({'hello': 'world', 'good': 'bye'})
    assert unknown_only.tags['hello'] == 'world'
    assert unknown_only.tags['good'] == 'bye'


@pytest.mark.parametrize('from_str', [True, False])
def test_doc_field_resolver(from_str):
    """``field_resolver`` renames source keys before they are mapped onto Document fields."""

    def _build(payload, **kwargs):
        # Hand the constructor either the raw dict or its JSON string form.
        return Document(json.dumps(payload) if from_str else payload, **kwargs)

    # without a resolver, `music_id` is an unknown key and is kept in tags
    unresolved = _build({'music_id': '123', 'hello': 'world', 'tags': {'good': 'bye'}})
    assert unresolved.id != '123'
    assert unresolved.tags['hello'] == 'world'
    assert unresolved.tags['good'] == 'bye'
    assert unresolved.tags['music_id'] == '123'

    # with a resolver, `music_id` becomes `id` and no longer appears in tags
    resolved = _build({'music_id': '123', 'hello': 'world', 'tags': {'good': 'bye'}},
                      field_resolver={'music_id': 'id'})
    assert resolved.id == '123'
    assert resolved.tags['hello'] == 'world'
    assert resolved.tags['good'] == 'bye'
    assert 'music_id' not in resolved.tags

0 comments on commit 5e2767a

Please sign in to comment.