Skip to content

Commit

Permalink
refactor: rename arguments and test output type
Browse files: browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 5, 2021
1 parent 888b148 commit 720c2a4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
16 changes: 7 additions & 9 deletions jina/types/sets/document.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import MutableSequence
from typing import Union, Iterable, Tuple, Sequence
from typing import Union, Iterable, Tuple, Sequence, List


import numpy as np
Expand Down Expand Up @@ -165,17 +165,14 @@ def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
"""
return self._extract_docs('content')

def _extract_docs(self, *attr: str) -> Tuple['np.ndarray', 'DocumentSet']:
list_of_contents_output = len(attr) > 1
if list_of_contents_output:
contents = [list() for _ in range(len(attr))]
else:
contents = []
def _extract_docs(self, *fields: str) -> Tuple[Union['np.ndarray', List['np.ndarray']], 'DocumentSet']:
list_of_contents_output = len(fields) > 1
contents = [[] for _ in fields if len(fields) > 1]
docs_pts = []
bad_docs = []

for doc in self:
content = doc.get_attrs_values(*attr)
content = doc.get_attrs_values(*fields)
content = content if list_of_contents_output else content[0]

if content is not None:
Expand All @@ -196,7 +193,8 @@ def _extract_docs(self, *attr: str) -> Tuple['np.ndarray', 'DocumentSet']:

if bad_docs and docs_pts:
default_logger.warning(
f'found {len(bad_docs)} no-{attr} docs at granularity {docs_pts[0].granularity}')
f'found {len(bad_docs)} docs at granularity {docs_pts[0].granularity} are missing one of the '
f'following fields: {fields} ')

return contents, DocumentSet(docs_pts)

Expand Down
12 changes: 12 additions & 0 deletions tests/unit/types/sets/test_documentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def test_get_content(num_rows, field):
docs.append(Document())

contents, pts = docs._extract_docs(field)
assert isinstance(contents, np.ndarray)

assert contents.shape == (batch_size, num_rows, embed_size)

Expand All @@ -253,6 +254,7 @@ def test_get_content_text_fields(field):
docs = DocumentSet([Document(**kwargs) for _ in range(batch_size)])

contents, pts = docs._extract_docs(field)
assert isinstance(contents, np.ndarray)

assert contents.shape == (batch_size,)
assert len(contents) == batch_size
Expand All @@ -274,6 +276,7 @@ def test_get_content_bytes_fields(field):

assert contents.shape == (batch_size,)
assert len(contents) == batch_size
assert isinstance(contents, np.ndarray)
for content in contents:
assert content == b'bytes'

Expand All @@ -291,6 +294,9 @@ def test_get_content_multiple_fields_text(fields):
contents, pts = docs._extract_docs(*fields)

assert len(contents) == len(fields)
assert isinstance(contents, list)
assert isinstance(contents[0], np.ndarray)
assert isinstance(contents[1], np.ndarray)

for content in contents:
assert len(content) == batch_size
Expand All @@ -312,6 +318,9 @@ def test_get_content_multiple_fields_arrays(num_rows):
contents, pts = docs._extract_docs(*fields)

assert len(contents) == len(fields)
assert isinstance(contents, list)
assert isinstance(contents[0], np.ndarray)
assert isinstance(contents[1], np.ndarray)

for content in contents:
assert len(content) == batch_size
Expand All @@ -333,6 +342,9 @@ def test_get_content_multiple_fields_merge(num_rows):
contents, pts = docs._extract_docs(*fields)

assert len(contents) == len(fields)
assert isinstance(contents, list)
assert isinstance(contents[0], np.ndarray)
assert isinstance(contents[1], np.ndarray)

for content in contents:
assert len(content) == batch_size
Expand Down

0 comments on commit 720c2a4

Please sign in to comment.