Skip to content

Commit

Permalink
test: add tests for batching executors
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 3, 2021
1 parent f4aad55 commit 3d48e16
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 6 deletions.
4 changes: 3 additions & 1 deletion jina/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .metas import get_default_metas
from ..helper import batch_iterator, typename, convert_tuple_to_list
from ..logging import default_logger
from itertools import islice
from itertools import islice, chain


def as_aggregate_method(func: Callable) -> Callable:
Expand Down Expand Up @@ -202,6 +202,8 @@ def _merge_results_after_batching(final_result, merge_over_axis: int = 0):
for col in range(num_cols):
reduced_result.append(np.concatenate([row[col] for row in final_result], merge_over_axis))
final_result = tuple(reduced_result)
elif isinstance(final_result[0], list):
final_result = list(chain.from_iterable(final_result))

if len(final_result):
return final_result
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from jina.executors.crafters import BaseCrafter
from jina.executors.decorators import batching, batching_multi_input
from jina import Document
from jina.types.sets import DocumentSet


class DummyCrafterText(BaseCrafter):
    """Test crafter that tags every text of a batch with a ``-crafted`` suffix."""

    def __init__(self, *args, **kwargs):
        # Nothing extra to set up; forward everything to the base class.
        super().__init__(*args, **kwargs)

    @batching(batch_size=3)
    def craft(self, text, *args, **kwargs):
        """Return one ``{'text': ...}`` dict per entry of ``text``.

        The assertion pins the length of every received batch to 3, the
        ``batch_size`` handed to the ``batching`` decorator.
        """
        assert len(text) == 3
        crafted = []
        for entry in text:
            crafted.append({'text': f'{entry}-crafted'})
        return crafted


def test_batching_text_single():
    """Craft 15 texts in batches of 3 and check the merged, ordered result.

    ``DummyCrafterText.craft`` returns one list of dicts per batch; the test
    expects those per-batch lists to come back as a single flat sequence with
    one crafted entry per input document, in input order.
    """
    num_docs = 15
    docs = DocumentSet([Document(text=f'text-{i}') for i in range(num_docs)])
    texts, _ = docs._extract_docs('text')

    crafter = DummyCrafterText()
    crafted_docs = crafter.craft(texts)

    # Guard against a vacuous pass: the loop below asserts nothing when the
    # result is empty and misses dropped entries when it is too short.
    assert len(crafted_docs) == num_docs
    for i, crafted_doc in enumerate(crafted_docs):
        assert crafted_doc['text'] == f'text-{i}-crafted'


class DummyCrafterTextId(BaseCrafter):
    """Test crafter that consumes a batch of texts together with their ids."""

    def __init__(self, *args, **kwargs):
        # Nothing extra to set up; forward everything to the base class.
        super().__init__(*args, **kwargs)

    @batching(batch_size=3)
    def craft(self, text, id, *args, **kwargs):
        """Pair every text with its id and mark the text as crafted.

        Both ``text`` and ``id`` are asserted to hold exactly 3 entries, the
        ``batch_size`` handed to the ``batching`` decorator.
        """
        assert len(text) == 3
        assert len(id) == 3
        crafted = []
        for doc_id, txt in zip(id, text):
            crafted.append({'text': f'{txt}-crafted', 'id': doc_id})
        return crafted


def test_batching_text_multi():
    """Craft the texts and ids of 15 documents and check the crafted texts."""
    documents = [Document(text=f'text-{i}') for i in range(15)]
    docs = DocumentSet(documents)
    text_ids, _ = docs._extract_docs('text', 'id')

    crafter = DummyCrafterTextId()
    # NOTE(review): ``craft`` declares separate ``text`` and ``id`` parameters
    # but receives a single positional argument here -- confirm that the
    # ``batching`` decorator splits ``text_ids`` as intended.
    crafted_docs = crafter.craft(text_ids)

    for position, crafted in enumerate(crafted_docs):
        assert crafted['text'] == f'text-{position}-crafted'
47 changes: 42 additions & 5 deletions tests/unit/executors/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self):

@single
def f(self, data):
    # Each call must receive a plain int rather than a list of values.
    assert isinstance(data, int)
    # Count how many times the wrapped method is actually invoked.
    self.call_nbr += 1
    return data

Expand All @@ -118,14 +119,14 @@ def f(self, data):

instance = A(1)
result = instance.f([1, 1, 1, 1])
assert result == [[1], [1], [1], [1]]
assert result == [1, 1, 1, 1]
assert len(instance.batch_sizes) == 4
for batch_size in instance.batch_sizes:
assert batch_size == 1

instance = A(3)
result = instance.f([1, 1, 1, 1])
assert result == [[1, 1, 1], [1]]
assert result == [1, 1, 1, 1]
assert len(instance.batch_sizes) == 2
assert instance.batch_sizes[0] == 3
assert instance.batch_sizes[1] == 1
Expand All @@ -150,14 +151,14 @@ def f(self, key, data):

instance = A(1)
result = instance.f(None, [1, 1, 1, 1])
assert result == [[1], [1], [1], [1]]
assert result == [1, 1, 1, 1]
assert len(instance.batch_sizes) == 4
for batch_size in instance.batch_sizes:
assert batch_size == 1

instance = A(3)
result = instance.f(None, [1, 1, 1, 1])
assert result == [[1, 1, 1], [1]]
assert result == [1, 1, 1, 1]
assert len(instance.batch_sizes) == 2
assert instance.batch_sizes[0] == 3
assert instance.batch_sizes[1] == 1
Expand Down Expand Up @@ -199,7 +200,7 @@ def f(self, data, ord_idx):
assert instance.ord_idx[4].start == 8
assert instance.ord_idx[4].stop == 10

assert result == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
assert result == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


@pytest.mark.skip(
Expand Down Expand Up @@ -277,3 +278,39 @@ def score(
assert batch[0] == query_meta
assert len(batch[1]) == batch_size
assert len(batch[2]) == batch_size


def test_batching_as_ndarray():
    """Check ``as_ndarray`` stacked on ``batching`` for several batch sizes.

    For every batch size the call must return a single ``np.ndarray`` equal to
    the four input values, while ``f`` records the length of each batch it
    was handed.
    """

    class A:
        def __init__(self, batch_size):
            self.batch_size = batch_size
            self.batch_sizes = []

        @as_ndarray
        @batching
        def f(self, data):
            self.batch_sizes.append(len(data))
            return np.array(data)

    expected_result = np.array([1., 1., 1., 1.])
    # (batch_size, lengths of the batches ``f`` is expected to receive)
    scenarios = [
        (1, [1, 1, 1, 1]),
        (3, [3, 1]),
        (5, [4]),
    ]
    for size, expected_batch_sizes in scenarios:
        instance = A(size)
        result = instance.f([1, 1, 1, 1])
        assert isinstance(result, np.ndarray)
        np.testing.assert_equal(result, expected_result)
        assert instance.batch_sizes == expected_batch_sizes

0 comments on commit 3d48e16

Please sign in to comment.