Skip to content

Commit

Permalink
Revert "fix: fix multimodal example (#2178)" (#2191)
Browse files Browse the repository at this point in the history
This reverts commit 8849200.
  • Loading branch information
hanxiao committed Mar 17, 2021
1 parent 8849200 commit 44e6e91
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 147 deletions.
3 changes: 0 additions & 3 deletions jina/resources/multimodal/pods/segmenter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os

from jina import Segmenter, Crafter
from jina.executors.decorators import single, single_multi_input


class SimpleCrafter(Crafter):
"""Simple crafter for multimodal example."""

@single
def craft(self, tags):
"""
Read the data and add tags.
Expand All @@ -24,7 +22,6 @@ def craft(self, tags):
class BiSegmenter(Segmenter):
"""Segmenter for multimodal example."""

@single_multi_input(num_data=2, flatten_output=False)
def segment(self, text, uri):
"""
Segment data into text and uri.
Expand Down
12 changes: 5 additions & 7 deletions jina/types/sets/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def all_embeddings(self) -> Tuple['np.ndarray', 'DocumentSet']:
and the documents have no embedding in a :class:`DocumentSet`.
:rtype: A tuple of embedding in :class:`np.ndarray`
"""
return self.extract_docs('embedding', stack_contents=True)
return self.extract_docs('embedding')

@property
def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
Expand All @@ -170,16 +170,14 @@ def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
and the documents have no contents in a :class:`DocumentSet`.
:rtype: A tuple of embedding in :class:`np.ndarray`
"""
# stack true for backward compatibility, but will not work if content is blob of different shapes
return self.extract_docs('content', stack_contents=True)
return self.extract_docs('content')

def extract_docs(
self, *fields: str, stack_contents: bool = False
self, *fields: str
) -> Tuple[Union['np.ndarray', List['np.ndarray']], 'DocumentSet']:
"""Return in batches all the values of the fields
:param fields: Variable length argument with the name of the fields to extract
:param stack_contents: boolean flag indicating if output lists should be stacked with `np.stack`
:return: Returns an :class:`np.ndarray` or a list of :class:`np.ndarray` with the batches for these fields
"""

Expand All @@ -200,7 +198,7 @@ def extract_docs(
for idx, c in enumerate(contents):
if not c:
continue
if stack_contents and not isinstance(c[0], bytes):
if not isinstance(c[0], bytes):
contents[idx] = np.stack(c)
else:
for doc in self:
Expand All @@ -213,7 +211,7 @@ def extract_docs(

if not contents:
contents = None
elif stack_contents and not isinstance(contents[0], bytes):
elif not isinstance(contents[0], bytes):
contents = np.stack(contents)

if bad_docs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ def craft(self, text, *args, **kwargs):
return {'text': f'{text}-crafted'}


@pytest.mark.parametrize('stack', [False, True])
@pytest.mark.parametrize(
'crafter', [DummyCrafterTextSingle(), DummyCrafterTextBatching()]
)
def test_batching_text_one_argument(stack, crafter):
def test_batching_text_one_argument(crafter):
docs = DocumentSet([Document(text=f'text-{i}') for i in range(15)])
texts, _ = docs.extract_docs('text', stack_contents=stack)
texts, _ = docs.extract_docs('text')

crafted_docs = crafter.craft(texts)
for i, crafted_doc in enumerate(crafted_docs):
Expand Down Expand Up @@ -94,14 +93,13 @@ def craft(self, text, id, *args, **kwargs):
return {'text': f'{text}-crafted', 'id': f'{id}-crafted'}


@pytest.mark.parametrize('stack', [False, True])
@pytest.mark.parametrize(
'crafter', [DummyCrafterTextIdSingle(), DummyCrafterTextIdBatching()]
)
def test_batching_text_multi(stack, crafter):
def test_batching_text_multi(crafter):
docs = DocumentSet([Document(text=f'text-{i}', id=f'id-{i}') for i in range(15)])
required_keys = ['text', 'id']
text_ids, _ = docs.extract_docs(*required_keys, stack_contents=stack)
text_ids, _ = docs.extract_docs(*required_keys)

crafted_docs = crafter.craft(*text_ids)

Expand Down Expand Up @@ -141,6 +139,7 @@ def __init__(self, *args, **kwargs):
@batching(batch_size=3)
def craft(self, blob, *args, **kwargs):
assert len(blob) == 3
assert blob.shape == (3, 2, 5)
return [{'blob': b} for b in blob]


Expand All @@ -154,17 +153,18 @@ def craft(self, blob, *args, **kwargs):
return {'blob': blob}


@pytest.mark.parametrize('stack', [False, True])
@pytest.mark.parametrize(
'crafter', [DummyCrafterBlobSingle(), DummyCrafterBlobBatching()]
)
def test_batching_blob_one_argument(stack, crafter):
def test_batching_blob_one_argument(crafter):
docs = DocumentSet([Document(blob=np.array([[i] * 5, [i] * 5])) for i in range(15)])
texts, _ = docs.extract_docs('blob', stack_contents=stack)
texts, _ = docs.extract_docs('blob')

crafted_docs = crafter.craft(texts)
for i, crafted_doc in enumerate(crafted_docs):
np.testing.assert_equal(crafted_doc['blob'], np.array([[i] * 5, [i] * 5]))
np.testing.assert_equal(
crafted_doc['blob'], np.array([[i, i, i, i, i], [i, i, i, i, i]])
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -200,6 +200,8 @@ def __init__(self, *args, **kwargs):
def craft(self, blob, embedding, *args, **kwargs):
assert len(blob) == 3
assert len(embedding) == 3
assert blob.shape == (3, 2, 5)
assert embedding.shape == (3, 5)
return [{'blob': b, 'embedding': e} for b, e in zip(blob, embedding)]


Expand All @@ -214,11 +216,10 @@ def craft(self, blob, embedding, *args, **kwargs):
return {'blob': blob, 'embedding': embedding}


@pytest.mark.parametrize('stack', [False, True])
@pytest.mark.parametrize(
'crafter', [DummyCrafterBlobEmbeddingSingle(), DummyCrafterBlobEmbeddingBatching()]
)
def test_batching_blob_multi(stack, crafter):
def test_batching_blob_multi(crafter):
docs = DocumentSet(
[
Document(
Expand All @@ -229,7 +230,7 @@ def test_batching_blob_multi(stack, crafter):
]
)
required_keys = ['blob', 'embedding']
text_ids, _ = docs.extract_docs(*required_keys, stack_contents=stack)
text_ids, _ = docs.extract_docs(*required_keys)

crafted_docs = crafter.craft(*text_ids)

Expand Down Expand Up @@ -279,6 +280,8 @@ def __init__(self, *args, **kwargs):
def craft(self, text, embedding, *args, **kwargs):
assert len(text) == 3
assert len(embedding) == 3
assert text.shape == (3,)
assert embedding.shape == (3, 5)
return [
{'text': f'{t}-crafted', 'embedding': e} for t, e in zip(text, embedding)
]
Expand All @@ -296,16 +299,15 @@ def craft(self, text, embedding, *args, **kwargs):
return {'text': f'{text}-crafted', 'embedding': embedding}


@pytest.mark.parametrize('stack', [False, True])
@pytest.mark.parametrize(
'crafter', [DummyCrafterTextEmbeddingSingle(), DummyCrafterTextEmbeddingBatching()]
)
def test_batching_mix_multi(stack, crafter):
def test_batching_mix_multi(crafter):
docs = DocumentSet(
[Document(text=f'text-{i}', embedding=np.array([i] * 5)) for i in range(15)]
)
required_keys = ['text', 'embedding']
text_ids, _ = docs.extract_docs(*required_keys, stack_contents=stack)
text_ids, _ = docs.extract_docs(*required_keys)

crafted_docs = crafter.craft(*text_ids)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def encode(self, data, *args, **kwargs):
)
def test_batching_encode_text(encoder):
docs = DocumentSet([Document(text=f'text-{i}') for i in range(15)])
texts, _ = docs.all_contents
texts, _ = docs.extract_docs('text')

embeds = encoder.encode(texts)

Expand Down Expand Up @@ -75,7 +75,7 @@ def encode(self, data, *args, **kwargs):
)
def test_batching_encode_blob(encoder):
docs = DocumentSet([Document(blob=np.random.random((10, 20))) for _ in range(15)])
blob, _ = docs.all_contents
blob, _ = docs.extract_docs('blob')

embeds = encoder.encode(blob)

Expand Down
Empty file.
Empty file.
60 changes: 0 additions & 60 deletions tests/integration/helloworld/multimodal/test_hello_multimodal.py

This file was deleted.

0 comments on commit 44e6e91

Please sign in to comment.