Skip to content

Commit

Permalink
feat: add setters and getters for da and dam (#3418)
Browse files Browse the repository at this point in the history
Co-authored-by: felix-wang <35718120+numb3r3@users.noreply.github.com>
Co-authored-by: AlaeddineAbdessalem <alaeddine-13@live.fr>
  • Loading branch information
3 people committed Sep 20, 2021
1 parent 19f31c0 commit 4a5eee9
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 0 deletions.
61 changes: 61 additions & 0 deletions jina/types/arrays/document.py
Expand Up @@ -25,6 +25,7 @@
from .search_ops import DocumentArraySearchOpsMixin
from .traversable import TraversableSequence
from ..document import Document
from ..struct import StructView
from ...helper import typename
from ...proto import jina_pb2

Expand Down Expand Up @@ -152,6 +153,49 @@ def blobs(self, blobs: np.ndarray):
for d, x in zip(self, blobs):
d.blob = x

@property
def tags(self) -> List[StructView]:
"""Get the tags attribute of all Documents"""
...

@tags.setter
def tags(self, tags: Sequence[Union[Dict, StructView]]):
"""Set the tags attribute for all Documents
:param tags: A sequence of tags to set, should be the same length as the
number of Documents
"""

if len(tags) != len(self):
raise ValueError(
f'the number of tags in the input ({len(tags)}), should match the'
f'number of Documents ({len(self)})'
)

for doc, tags_doc in zip(self, tags):
doc.tags = tags_doc

@property
def texts(self) -> List[str]:
"""Get the text attribute of all Documents"""
...

@texts.setter
def texts(self, texts: Sequence[str]):
"""Set the text attribute for all Documents
:param texts: A sequence of texts to set, should be the same length as the
number of Documents
"""
if len(texts) != len(self):
raise ValueError(
f'the number of texts in the input ({len(texts)}), should match the'
f'number of Documents ({len(self)})'
)

for doc, text in zip(self, texts):
doc.text = text


class DocumentArray(
TraversableSequence,
Expand Down Expand Up @@ -540,6 +584,23 @@ def embeddings(self, emb: np.ndarray):
for d, x in zip(self, emb):
d.embedding = x

@DocumentArrayGetAttrMixin.tags.getter
def tags(self) -> List[StructView]:
"""Get the tags attribute of all Documents
:return: List of ``tags`` attributes for all Documents
"""
tags = [StructView(d.tags) for d in self._pb_body]
return tags

@DocumentArrayGetAttrMixin.texts.getter
def texts(self) -> List[str]:
"""Get the text attribute of all Documents
:return: List of ``text`` attributes for all Documents
"""
return [d.text for d in self._pb_body]

@DocumentArrayGetAttrMixin.blobs.getter
def blobs(self) -> np.ndarray:
"""Return a `np.ndarray` stacking all the `blob` attributes.
Expand Down
17 changes: 17 additions & 0 deletions jina/types/arrays/memmap.py
Expand Up @@ -24,6 +24,7 @@
from .search_ops import DocumentArraySearchOpsMixin
from .traversable import TraversableSequence
from ..document import Document
from ..struct import StructView
from ...logging.predefined import default_logger


Expand Down Expand Up @@ -563,6 +564,22 @@ def embeddings(self, emb: np.ndarray):
for d, x in zip(self, emb):
d.embedding = x

@DocumentArrayGetAttrMixin.tags.getter
def tags(self) -> Tuple[StructView]:
"""Get the tags attribute of all Documents
:return: List of ``tags`` attributes for all Documents
"""
return self.get_attributes('tags')

@DocumentArrayGetAttrMixin.texts.getter
def texts(self) -> Tuple[str]:
"""Get the text attribute of all Documents
:return: List of ``text`` attributes for all Documents
"""
return self.get_attributes('text')

@DocumentArrayGetAttrMixin.blobs.getter
def blobs(self) -> np.ndarray:
"""Return a `np.ndarray` stacking all the `blob` attributes.
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/types/arrays/test_documentarray.py
Expand Up @@ -564,6 +564,54 @@ def test_blobs_setter_da():
np.testing.assert_almost_equal(x, doc.blob)


def test_tags_getter_da():
da = DocumentArray([Document(tags={'a': 2, 'c': 'd'}) for _ in range(100)])
assert len(da.tags) == 100
assert da.tags == da.get_attributes('tags')


def test_tags_setter_da():
tags = [{'a': 2, 'c': 'd'} for _ in range(100)]
da = DocumentArray([Document() for _ in range(100)])
da.tags = tags
assert da.tags == tags

for x, doc in zip(tags, da):
assert x == doc.tags


def test_setter_wrong_len():
da = DocumentArray([Document() for _ in range(100)])
tags = [{'1': 2}]

with pytest.raises(ValueError, match='the number of tags in the'):
da.tags = tags


def test_texts_getter_da():
da = DocumentArray([Document(text='hello') for _ in range(100)])
assert len(da.texts) == 100
assert da.texts == da.get_attributes('text')


def test_texts_setter_da():
texts = ['text' for _ in range(100)]
da = DocumentArray([Document() for _ in range(100)])
da.texts = texts
assert da.texts == texts

for x, doc in zip(texts, da):
assert x == doc.text


def test_texts_wrong_len():
da = DocumentArray([Document() for _ in range(100)])
texts = ['hello']

with pytest.raises(ValueError, match='the number of texts in the'):
da.texts = texts


def test_blobs_wrong_len():
da = DocumentArray([Document() for _ in range(100)])
blobs = np.ones((2, 10, 10))
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/types/arrays/test_memmap.py
Expand Up @@ -529,6 +529,62 @@ def test_blobs_setter_dam(tmpdir):
np.testing.assert_almost_equal(x, doc.blob)


def test_tags_getter_dam(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document(tags={'a': 2, 'c': 'd'}) for _ in range(100)])
assert len(dam.tags) == 100
assert dam.tags == dam.get_attributes('tags')


def test_tags_setter_dam(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
tags = [{'a': 2, 'c': 'd'} for _ in range(100)]
dam.extend([Document() for _ in range(100)])
dam.tags = tags
assert dam.tags == tags

for x, doc in zip(tags, dam):
assert x == doc.tags


def test_setter_wrong_len(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document() for _ in range(100)])
tags = [{'1': 2}]

with pytest.raises(ValueError, match='the number of tags in the'):
dam.tags = tags


def test_texts_getter_dam(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document(text='hello') for _ in range(100)])
assert len(dam.texts) == 100
t1 = dam.texts
t2 = dam.get_attributes('text')
assert t1 == t2


def test_texts_setter_dam(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document() for _ in range(100)])
texts = ['text' for _ in range(100)]
dam.texts = texts
assert dam.texts == texts

for x, doc in zip(texts, dam):
assert x == doc.text


def test_texts_wrong_len(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document() for _ in range(100)])
texts = ['hello']

with pytest.raises(ValueError, match='the number of texts in the'):
dam.texts = texts


def test_blobs_wrong_len(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
dam.extend([Document() for x in range(100)])
Expand Down

0 comments on commit 4a5eee9

Please sign in to comment.