Skip to content

Commit

Permalink
feat: allow rank driver access info in tags (#1718)
Browse files Browse the repository at this point in the history
* feat: allow rank driver access info in tags

* fix: consider possibility no required keys

* refactor: update jina/executors/rankers/__init__.py

Co-authored-by: Nan Wang <nan.wang@jina.ai>

* refactor: update jina/executors/rankers/__init__.py

Co-authored-by: Nan Wang <nan.wang@jina.ai>

* refactor: move queryset from drivers to types

* refactor: keep relative imports

* refactor: refactor get_attrs

* fix: make sure order of get_attrs is respected

* fix: solve id and tags__id conflict

* fix: fix considering not having required_keys

* fix: update jina/types/document/__init__.py

Co-authored-by: Nan Wang <nan.wang@jina.ai>

* fix: access only using dunder_get

* fix: fix hub build io test

* fix: fix how to get attrs

Co-authored-by: Nan Wang <nan.wang@jina.ai>
  • Loading branch information
JoanFM and nan-wang committed Feb 16, 2021
1 parent d153381 commit 5bc4955
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 34 deletions.
4 changes: 1 addition & 3 deletions jina/drivers/rank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...types.document import Document
from ...types.score import NamedScore


if False:
from ...types.sets import DocumentSet

Expand Down Expand Up @@ -49,11 +48,10 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
- Set the ``traversal_paths`` of this driver such that it traverses along the ``matches`` of the ``chunks`` at the level desired.
"""

# if at the top-level already, no need to aggregate further
query_meta = context_doc.get_attrs(*self.exec.required_keys)

old_match_scores = {match.id: match.score.value for match in docs}
match_meta = {match.id: match.get_attrs(*self.exec.required_keys) for match in docs}

# if there are no matches, no need to sort them
if not old_match_scores:
return
Expand Down
1 change: 1 addition & 0 deletions jina/drivers/rank/aggregate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
match_meta = {}
parent_id_chunk_id_map = defaultdict(list)
matches_by_id = defaultdict(Document)

query_meta[context_doc.id] = context_doc.get_attrs(*self.exec.required_keys)
for match in docs:
match_info = self._extract_query_match_info(match=match, query=context_doc)
Expand Down
9 changes: 9 additions & 0 deletions jina/executors/rankers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class Chunk2DocRanker(BaseRanker):
"""

required_keys = {'text'} #: a set of ``str``, key-values to extracted from the chunk-level protobuf message
"""set: Set of required keys to be extracted from matches and query to fill the information of `query` and `chunk` meta information.
These are the set of keys to be extracted from `Document`.
All the keys not found in the `DocumentProto` fields, will be extracted from the `tags` structure of `Document`.
.. seealso::
:meth:`get_attrs` of :class:`Document`
"""
COL_MATCH_PARENT_ID = 'match_parent_id'
COL_MATCH_ID = 'match_id'
COL_DOC_CHUNK_ID = 'doc_chunk_id'
Expand Down
40 changes: 39 additions & 1 deletion jina/types/document/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from ..score import NamedScore
from ..sets.chunk import ChunkSet
from ..sets.match import MatchSet
from ..querylang.queryset.dunderkey import dunder_get
from ...excepts import BadDocType
from ...helper import is_url, typename, random_identity, download_mermaid_url
from ...importer import ImportExtensions
from ...proto import jina_pb2
from ...logging import default_logger

__all__ = ['Document', 'DocumentContentType', 'DocumentSourceType']
DIGEST_SIZE = 8
Expand Down Expand Up @@ -457,8 +459,44 @@ def get_attrs(self, *args) -> Dict[str, Any]:
.. seealso::
:meth:`update` for bulk set/update attributes
.. note::
Arguments will be extracted using `dunder_get`
.. highlight:: python
.. code-block:: python
d = Document({'id': '123', 'hello': 'world', 'tags': {'id': 'external_id', 'good': 'bye'}})
assert d.id == '123' # true
assert d.tags['hello'] == 'world' # true
assert d.tags['good'] == 'bye' # true
assert d.tags['id'] == 'external_id' # true
res = d.get_attrs(*['id', 'tags__hello', 'tags__good', 'tags__id'])
assert res['id'] == '123' # true
assert res['tags__hello'] == 'world' # true
assert res['tags__good'] == 'bye' # true
assert res['tags__id'] == 'external_id' # true
"""
return {k: getattr(self, k) for k in args if hasattr(self, k)}

ret = {}
for k in args:

try:
if hasattr(self, k):
value = getattr(self, k)
else:
value = dunder_get(self._pb_body, k)

if not value:
raise ValueError

ret[k] = value
continue
except (AttributeError, ValueError):
default_logger.warning(f'Could not get attribute from key {k}, returning None')
ret[k] = None
return ret

@property
def buffer(self) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/docker/test_hub_build_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def test_hub_build_level_fail(monkeypatch, test_workspace, docker_image):
os.path.join(cur_dir, 'yaml/test-joint.yml'), 60, True,
JinaLogger('unittest'))

assert expected_failed_levels == failed_levels
assert expected_failed_levels == failed_levels
25 changes: 0 additions & 25 deletions tests/unit/drivers/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@
from jina.types.ndarray.generic import NdArray


@pytest.fixture(scope='function')
def document():
with Document() as doc:
doc.text = 'this is text'
doc.tags['id'] = 'id in tags'
doc.tags['inner_dict'] = {'id': 'id in inner_dict'}
with Document() as chunk:
chunk.text = 'text in chunk'
chunk.tags['id'] = 'id in chunk tags'
doc.chunks.add(chunk)
return doc


@pytest.mark.parametrize(
'proto_type', ['float32', 'float64', 'uint8']
)
Expand All @@ -44,18 +31,6 @@ def test_array_protobuf_conversions_with_quantize(quantize, proto_type):
np.testing.assert_almost_equal(d.value, random_array, decimal=2)


def test_pb_obj2dict(document):
res = document.get_attrs('text', 'tags', 'chunks')
assert res['text'] == 'this is text'
assert res['tags']['id'] == 'id in tags'
assert res['tags']['inner_dict']['id'] == 'id in inner_dict'
rcs = list(res['chunks'])
assert len(rcs) == 1
assert isinstance(rcs[0], Document)
assert rcs[0].text == 'text in chunk'
assert rcs[0].tags['id'] == 'id in chunk tags'


def test_add_route():
r = jina_pb2.RequestProto()
r.control.command = jina_pb2.RequestProto.ControlRequestProto.IDLE
Expand Down
56 changes: 52 additions & 4 deletions tests/unit/types/document/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import numpy as np
import pytest

from google.protobuf.json_format import MessageToDict
from jina import NdArray, Request
from jina.proto.jina_pb2 import DocumentProto
from jina.types.document import Document
from jina.types.score import NamedScore
from tests import random_docs

DOCUMENTS_PER_LEVEL = 1


@pytest.mark.parametrize('field', ['blob', 'embedding'])
def test_ndarray_get_set(field):
Expand Down Expand Up @@ -230,7 +229,6 @@ def test_doc_setattr():


def test_doc_score():
from jina import Document
from jina.types.score import NamedScore
with Document() as doc:
doc.text = 'text'
Expand Down Expand Up @@ -459,7 +457,6 @@ def test_update_embedding():

def test_non_empty_fields():
d_score = Document(score=NamedScore(value=42))
print(d_score.ListFields())
assert d_score.non_empty_fields == ('id', 'score')

d = Document()
Expand Down Expand Up @@ -495,3 +492,54 @@ def test_update_exclude_field():
# check if merging on embedding is correct
assert len(d.chunks) == 1
assert d.chunks[0].id == '🐢'


def test_get_attr():
d = Document({'id': '123', 'text': 'document', 'feature1': 121, 'name': 'name',
'tags': {'id': 'identity', 'a': 'b', 'c': 'd'}})
d.score = NamedScore(value=42)

required_keys = ['id', 'text', 'tags__name', 'tags__feature1', 'score__value', 'tags__c', 'tags__id', 'tags__inexistant', 'inexistant']
res = d.get_attrs(*required_keys)

assert len(res.keys()) == len(required_keys)
assert res['id'] == '123'
assert res['tags__feature1'] == 121
assert res['tags__name'] == 'name'
assert res['text'] == 'document'
assert res['tags__c'] == 'd'
assert res['tags__id'] == 'identity'
assert res['score__value'] == 42
assert res['tags__inexistant'] is None
assert res['inexistant'] is None

res2 = d.get_attrs(*['tags', 'text'])
assert len(res2.keys()) == 2
assert res2['text'] == 'document'
assert res2['tags'] == d.tags

d = Document({'id': '123', 'tags': {'outterkey': {'innerkey': 'real_value'}}})
res3 = d.get_attrs(*['tags__outterkey__innerkey'])
assert len(res3.keys()) == 1
assert res3['tags__outterkey__innerkey'] == 'real_value'


def test_pb_obj2dict():
document = Document()
with document:
document.text = 'this is text'
document.tags['id'] = 'id in tags'
document.tags['inner_dict'] = {'id': 'id in inner_dict'}
with Document() as chunk:
chunk.text = 'text in chunk'
chunk.tags['id'] = 'id in chunk tags'
document.chunks.add(chunk)
res = document.get_attrs('text', 'tags', 'chunks')
assert res['text'] == 'this is text'
assert res['tags']['id'] == 'id in tags'
assert res['tags']['inner_dict']['id'] == 'id in inner_dict'
rcs = list(res['chunks'])
assert len(rcs) == 1
assert isinstance(rcs[0], Document)
assert rcs[0].text == 'text in chunk'
assert rcs[0].tags['id'] == 'id in chunk tags'

0 comments on commit 5bc4955

Please sign in to comment.