Skip to content

Commit

Permalink
feat: allow extracting multiple contents
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 3, 2021
1 parent 939408b commit 96f6acb
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 8 deletions.
58 changes: 52 additions & 6 deletions jina/types/document/__init__.py
Expand Up @@ -7,7 +7,7 @@
import urllib.request
import warnings
from hashlib import blake2b
from typing import Union, Dict, Optional, TypeVar, Any, Tuple
from typing import Union, Dict, Optional, TypeVar, Any, Tuple, List

import numpy as np
from google.protobuf import json_format
Expand Down Expand Up @@ -426,6 +426,14 @@ def chunks(self) -> 'ChunkSet':
"""Get all chunks of the current document."""
return ChunkSet(self._pb_body.chunks, reference_doc=self)

def __getattr__(self, item):
    """Fall back to the underlying protobuf body for unknown attributes.

    Only invoked when normal attribute lookup fails. Plain field names are
    resolved directly on ``self._pb_body``; names not present there are
    resolved via ``dunder_get``, which supports ``__``-separated nested
    paths such as ``tags__id``.

    :param item: attribute name, optionally a ``__``-separated nested path
    :return: the resolved value from the protobuf body
    """
    # Guard against infinite recursion: if ``_pb_body`` itself is absent
    # (e.g. during unpickling/copy before ``__init__`` has run), reading
    # ``self._pb_body`` below would re-enter ``__getattr__`` forever.
    if item == '_pb_body':
        raise AttributeError(item)
    # NOTE: the debug ``print`` that ran on every attribute miss was
    # removed — it flooded stdout in normal operation.
    if hasattr(self._pb_body, item):
        value = getattr(self._pb_body, item)
    else:
        # NOTE(review): presumably ``dunder_get`` raises for invalid
        # paths so the callers' except-clauses fire — confirm upstream.
        value = dunder_get(self._pb_body, item)
    return value

def set_attrs(self, **kwargs):
"""Bulk update Document fields with key-value specified in kwargs
Expand Down Expand Up @@ -483,12 +491,8 @@ def get_attrs(self, *args) -> Dict[str, Any]:

ret = {}
for k in args:

try:
if hasattr(self, k):
value = getattr(self, k)
else:
value = dunder_get(self._pb_body, k)
value = getattr(self, k)

if value is None:
raise ValueError
Expand All @@ -499,6 +503,48 @@ def get_attrs(self, *args) -> Dict[str, Any]:
ret[k] = None
return ret

def get_attrs_values(self, *args) -> List[Any]:
    """Bulk fetch Document fields and return a list of the values of these fields.

    .. note::
        Arguments will be extracted using `dunder_get`, so nested fields can
        be addressed with ``__``-separated paths (e.g. ``tags__id``).

    .. highlight:: python
    .. code-block:: python

        d = Document({'id': '123', 'hello': 'world', 'tags': {'id': 'external_id', 'good': 'bye'}})

        assert d.id == '123'  # true
        assert d.tags['hello'] == 'world'  # true
        assert d.tags['good'] == 'bye'  # true
        assert d.tags['id'] == 'external_id'  # true

        res = d.get_attrs_values(*['id', 'tags__hello', 'tags__good', 'tags__id'])

        assert res == ['123', 'world', 'bye', 'external_id']

    :param args: the field names to fetch, in the order the values should be returned
    :return: a list of values aligned with ``args``; ``None`` for any field
        that is missing or unset (a warning is logged for each)
    """
    ret = []
    for k in args:
        try:
            value = getattr(self, k)
            # An unset (None) field is treated the same as a missing one.
            if value is None:
                raise ValueError

            ret.append(value)
        except (AttributeError, ValueError):
            default_logger.warning(f'Could not get attribute `{typename(self)}.{k}`, returning `None`')
            ret.append(None)

    return ret

@property
def buffer(self) -> bytes:
"""Return ``buffer``, one of the content form of a Document.
Expand Down
5 changes: 3 additions & 2 deletions jina/types/sets/document.py
Expand Up @@ -193,13 +193,14 @@ def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
"""
return self._extract_docs('content')

def _extract_docs(self, attr: str) -> Tuple['np.ndarray', 'DocumentSet']:
def _extract_docs(self, *attr: str) -> Tuple['np.ndarray', 'DocumentSet']:
contents = []
docs_pts = []
bad_docs = []

for doc in self:
content = getattr(doc, attr)
content = doc.get_attrs_values(*attr)
content = content[0] if len(content) == 1 else content

if content is not None:
contents.append(content)
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/types/sets/test_documentset.py
Expand Up @@ -276,3 +276,42 @@ def test_get_content_bytes_fields(field):
assert len(contents) == batch_size
for content in contents:
assert content == b'bytes'


@pytest.mark.parametrize('fields', [['id', 'text'], ['content_hash', 'modality']])
def test_get_content_multiple_fields_text(fields):
    """Extracting several scalar fields at once returns one row per doc
    and one column per requested field."""
    batch_size = 10

    kwargs = {
        field: f'text-{field}' for field in fields
    }

    docs = DocumentSet([Document(**kwargs) for _ in range(batch_size)])

    contents, pts = docs._extract_docs(*fields)

    # leftover debug ``print`` removed — keep test output clean
    assert contents.shape == (batch_size, len(fields))
    assert len(contents) == batch_size
    for content in contents:
        assert len(content) == len(fields)


@pytest.mark.parametrize('num_rows', [1, 2, 3])
def test_get_content_multiple_fields_arrays(num_rows):
    """Extracting several ndarray fields stacks them into shape
    (batch, n_fields, num_rows, embed_size)."""
    fields = ['blob', 'embedding']

    batch_size = 10
    embed_size = 20

    def _make_doc():
        # each requested field carries an independent random matrix
        payload = {name: np.random.random((num_rows, embed_size)) for name in fields}
        return Document(**payload)

    docs = DocumentSet([_make_doc() for _ in range(batch_size)])

    contents, pts = docs._extract_docs(*fields)

    expected_shape = (batch_size, len(fields), num_rows, embed_size)
    assert contents.shape == expected_shape
    assert len(contents) == batch_size
    assert all(len(row) == len(fields) for row in contents)

0 comments on commit 96f6acb

Please sign in to comment.