Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(types): allow doc building from arbitrary json/dict #1877

Merged
merged 3 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _get_run_args(print_args: bool = True):
from pkg_resources import resource_filename
p = parser._actions[-1].choices[sys.argv[1]]
default_args = {a.dest: a.default for a in p._actions if
isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)}
isinstance(a, (_StoreAction, _StoreTrueAction))}

with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp:
logo_str = fp.read()
Expand Down
2 changes: 1 addition & 1 deletion daemon/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_run_args(print_args: bool = True):
if print_args:
from pkg_resources import resource_filename
default_args = {a.dest: a.default for a in parser._actions if
isinstance(a, _StoreAction) or isinstance(a, _StoreTrueAction)}
isinstance(a, (_StoreAction, _StoreTrueAction))}

with open(resource_filename('jina', '/'.join(('resources', 'jina.logo')))) as fp:
logo_str = fp.read()
Expand Down
3 changes: 1 addition & 2 deletions extra-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ flaky: test
mock: test
requests: http, devel, test, daemon
prettytable: devel, test
sseclient-py: test
optuna: test, optimizer
optuna: cicd, optimizer
websockets: http, devel, test, ws, daemon
wsproto: http, devel, test, ws, daemon
pydantic: http, devel, test, daemon
Expand Down
2 changes: 1 addition & 1 deletion jina/drivers/querylang/queryset/dunderkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def dunder_get(_dict: Any, key: str) -> Any:

if isinstance(part1, int):
result = guard_iter(_dict)[part1]
elif isinstance(_dict, dict) or isinstance(_dict, Struct):
elif isinstance(_dict, (dict, Struct)):
if part1 in _dict:
result = _dict[part1]
else:
Expand Down
2 changes: 1 addition & 1 deletion jina/flow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _parse_endpoints(op_flow, pod_name, endpoint, connect_to_last_pod=False) ->
else:
endpoint = []

if isinstance(endpoint, list) or isinstance(endpoint, tuple):
if isinstance(endpoint, (list, tuple)):
for idx, s in enumerate(endpoint):
if s == pod_name:
raise FlowTopologyError('the income/output of a pod can not be itself')
Expand Down
4 changes: 2 additions & 2 deletions jina/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,14 @@ def _scan(sub_d: Union[Dict, List], p):
def _replace(sub_d: Union[Dict, List], p):
if isinstance(sub_d, Dict):
for k, v in sub_d.items():
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p.__dict__[k])
else:
if isinstance(v, str) and pat.findall(v):
sub_d[k] = _sub(v, p)
elif isinstance(sub_d, List):
for idx, v in enumerate(sub_d):
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p[idx])
else:
if isinstance(v, str) and pat.findall(v):
Expand Down
4 changes: 2 additions & 2 deletions jina/jaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _scan(sub_d, p):
def _replace(sub_d, p, resolve_ref=False):
if isinstance(sub_d, dict):
for k, v in sub_d.items():
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p.__dict__[k], resolve_ref)
else:
if isinstance(v, str):
Expand All @@ -137,7 +137,7 @@ def _replace(sub_d, p, resolve_ref=False):
sub_d[k] = _sub(v)
elif isinstance(sub_d, list):
for idx, v in enumerate(sub_d):
if isinstance(v, dict) or isinstance(v, list):
if isinstance(v, (dict, list)):
_replace(v, p[idx], resolve_ref)
else:
if isinstance(v, str):
Expand Down
2 changes: 1 addition & 1 deletion jina/jaml/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _finditem(obj, key='py_modules'):
value = obj.get(key, [])
if isinstance(value, str):
mod.append(value)
elif isinstance(value, list) or isinstance(value, tuple):
elif isinstance(value, (list, tuple)):
mod.extend(value)
for k, v in obj.items():
if isinstance(v, dict):
Expand Down
2 changes: 1 addition & 1 deletion jina/jaml/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_parser(cls: Type['JAMLCompatible'], version: Optional[str]) -> 'Versione
"""
all_parsers, legacy_parser = _get_all_parser(cls)
if version:
if isinstance(version, float) or isinstance(version, int):
if isinstance(version, (float, int)):
version = str(version)
for p in all_parsers:
if p.version == version:
Expand Down
47 changes: 41 additions & 6 deletions jina/types/document/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import json
import mimetypes
import os
import urllib.parse
Expand Down Expand Up @@ -29,6 +30,9 @@
DocumentSourceType = TypeVar('DocumentSourceType',
jina_pb2.DocumentProto, bytes, str, Dict)

_document_fields = set(list(jina_pb2.DocumentProto().DESCRIPTOR.fields_by_camelcase_name) + list(
jina_pb2.DocumentProto().DESCRIPTOR.fields_by_name))


class Document(ProtoTypeMixin):
"""
Expand Down Expand Up @@ -93,6 +97,7 @@ class Document(ProtoTypeMixin):
"""

def __init__(self, document: Optional[DocumentSourceType] = None,
field_resolver: Dict[str, str] = None,
copy: bool = False, **kwargs):
"""

Expand All @@ -104,7 +109,24 @@ def __init__(self, document: Optional[DocumentSourceType] = None,
it builds a view or a copy from it.
:param copy: when ``document`` is given as a :class:`DocumentProto` object, build a
view (i.e. weak reference) from it or a deep copy from it.
:param kwargs: other parameters to be set
:param field_resolver: a map from field names defined in ``document`` (JSON, dict) to the field
names defined in Protobuf. This is only used when the given ``document`` is
a JSON string or a Python dict.
:param kwargs: other parameters to be set _after_ the document is constructed

.. note::

When ``document`` is a JSON string or a Python dictionary object, the constructor only maps the values
of known fields defined in Protobuf; all unknown fields are mapped to ``document.tags``. For example,

.. highlight:: python
.. code-block:: python

d = Document({'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this one!


assert d.id == '123' # true
assert d.tags['hello'] == 'world' # true
assert d.tags['good'] == 'bye' # true
"""
self._pb_body = jina_pb2.DocumentProto()
try:
Expand All @@ -113,10 +135,23 @@ def __init__(self, document: Optional[DocumentSourceType] = None,
self._pb_body.CopyFrom(document)
else:
self._pb_body = document
elif isinstance(document, dict):
json_format.ParseDict(document, self._pb_body)
elif isinstance(document, str):
json_format.Parse(document, self._pb_body)
elif isinstance(document, (dict, str)):
if isinstance(document, str):
document = json.loads(document)

if field_resolver:
document = {field_resolver.get(k, k): v for k, v in document.items()}

user_fields = set(document.keys())
if _document_fields.issuperset(user_fields):
json_format.ParseDict(document, self._pb_body)
else:
_intersect = _document_fields.intersection(user_fields)
_remainder = user_fields.difference(_intersect)
if _intersect:
json_format.ParseDict({k: document[k] for k in _intersect}, self._pb_body)
if _remainder:
self._pb_body.tags.update({k: document[k] for k in _remainder})
elif isinstance(document, bytes):
# directly parsing from binary string gives large false-positive
# fortunately protobuf throws a warning when the parsing seems go wrong
Expand Down Expand Up @@ -318,7 +353,7 @@ def set_attrs(self, **kwargs):

"""
for k, v in kwargs.items():
if isinstance(v, list) or isinstance(v, tuple):
if isinstance(v, (list, tuple)):
if k == 'chunks':
self.chunks.extend(v)
elif k == 'matches':
Expand Down
2 changes: 1 addition & 1 deletion jina/types/score/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def set_attrs(self, **kwargs):

"""
for k, v in kwargs.items():
if isinstance(v, list) or isinstance(v, tuple):
if isinstance(v, (list, tuple)):
self._pb_body.ClearField(k)
getattr(self._pb_body, k).extend(v)
elif isinstance(v, dict):
Expand Down
5 changes: 1 addition & 4 deletions tests/unit/clients/python/test_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys

import numpy as np
import pytest
Expand All @@ -13,7 +14,6 @@
from jina.proto.jina_pb2 import DocumentProto
from jina.types.ndarray.generic import NdArray

import sys

@pytest.mark.skipif(sys.version_info < (3, 8, 0), reason='somehow this does not work on Github workflow with Py3.7, '
'but Py 3.8 is fine, local Py3.7 is fine')
Expand Down Expand Up @@ -45,9 +45,6 @@ def test_data_type_builder_doc_bad():
with pytest.raises(BadDocType):
_new_doc_from_data(MessageToJson(a) + '🍔', DataInputType.DOCUMENT)

with pytest.raises(BadDocType):
_new_doc_from_data({'🍔': '🍔'}, DataInputType.DOCUMENT)


@pytest.mark.parametrize('input_type', [DataInputType.AUTO, DataInputType.CONTENT])
def test_data_type_builder_auto(input_type):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/types/document/test_document.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import numpy as np
import pytest
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -316,3 +318,59 @@ def build_document(chunk=None):
d2.chunks.clear()
d2.update_content_hash(include_fields=('chunks',), exclude_fields=None)
assert d1.content_hash != d2.content_hash


@pytest.mark.parametrize('from_str', [True, False])
@pytest.mark.parametrize('d_src', [
    {'id': '123', 'mime_type': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}},
    {'id': '123', 'mimeType': 'txt', 'parentId': '456', 'tags': {'hello': 'world'}},
    {'id': '123', 'mimeType': 'txt', 'parent_id': '456', 'tags': {'hello': 'world'}},
])
def test_doc_from_dict_cases(d_src, from_str):
    """Build a ``Document`` from a dict, or from its JSON dump, and check the fields.

    The parametrized sources spell the same keys in snake_case, in camelCase
    and in a mix of both; each one is expected to yield the same attribute
    values on the resulting document.
    """
    # regular case: feed either the raw dict or its JSON string form
    source = json.dumps(d_src) if from_str else d_src
    doc = Document(source)

    assert doc.tags['hello'] == 'world'
    assert doc.mime_type == 'txt'
    assert doc.id == '123'
    assert doc.parent_id == '456'


@pytest.mark.parametrize('from_str', [True, False])
def test_doc_arbitrary_dict(from_str):
    """Build a ``Document`` from sources that contain keys other than ``id`` and ``tags``.

    The assertions expect such extra keys (``hello``, ``good``) to be readable
    from ``tags``, next to any entries passed under ``tags`` explicitly.
    """

    def _as_source(payload):
        # Hand over the dict itself, or its JSON dump when testing the string path.
        return json.dumps(payload) if from_str else payload

    # mix of ``id``, an extra key and an explicit ``tags`` entry
    mixed_doc = Document(_as_source({'id': '123', 'hello': 'world', 'tags': {'good': 'bye'}}))
    assert mixed_doc.id == '123'
    assert mixed_doc.tags['hello'] == 'world'
    assert mixed_doc.tags['good'] == 'bye'

    # only extra keys, no ``id`` and no ``tags``
    extras_doc = Document(_as_source({'hello': 'world', 'good': 'bye'}))
    assert extras_doc.tags['hello'] == 'world'
    assert extras_doc.tags['good'] == 'bye'


@pytest.mark.parametrize('from_str', [True, False])
def test_doc_field_resolver(from_str):
    """Check ``Document`` construction with and without a ``field_resolver``.

    Without a resolver the ``music_id`` key is expected to be readable from
    ``tags`` and ``id`` must differ from its value; with
    ``field_resolver={'music_id': 'id'}`` the value is expected on ``id``
    and ``music_id`` must be absent from ``tags``.
    """

    def _make_source():
        # A fresh payload per construction; dumped to JSON for the string path.
        payload = {'music_id': '123', 'hello': 'world', 'tags': {'good': 'bye'}}
        return json.dumps(payload) if from_str else payload

    # no resolver given
    unresolved = Document(_make_source())
    assert unresolved.id != '123'
    assert unresolved.tags['hello'] == 'world'
    assert unresolved.tags['good'] == 'bye'
    assert unresolved.tags['music_id'] == '123'

    # resolver renames ``music_id`` to ``id``
    resolved = Document(_make_source(), field_resolver={'music_id': 'id'})
    assert resolved.id == '123'
    assert resolved.tags['hello'] == 'world'
    assert resolved.tags['good'] == 'bye'
    assert 'music_id' not in resolved.tags