refactor: prepare changes to have batching for every executor #2110

Merged · 23 commits · Mar 8, 2021

Changes from 18 commits

Commits (23)
939408b  test: test _extract content from docset (JoanFM, Mar 3, 2021)
f4aad55  feat: allow extracting multiple contents (JoanFM, Mar 3, 2021)
d1297eb  test: add tests for batching executors (JoanFM, Mar 3, 2021)
d9ce11b  docs: add docstrings in document (JoanFM, Mar 3, 2021)
15b2257  docs: add docstrings in documentset (JoanFM, Mar 3, 2021)
532085e  fix: fix getting content from multi input (JoanFM, Mar 3, 2021)
33911a7  test: add tests to batching for segmenters (JoanFM, Mar 4, 2021)
dd99a39  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 4, 2021)
7bdaa99  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 4, 2021)
f3d1774  test: add explicit tests about batching for encoders (JoanFM, Mar 4, 2021)
28650bd  fix: single should handle non iterable input (JoanFM, Mar 4, 2021)
888b148  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 5, 2021)
720c2a4  refactor: rename arguments and test output type (JoanFM, Mar 5, 2021)
3cd1dc8  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 5, 2021)
650db8b  fix: fix black (JoanFM, Mar 5, 2021)
498135a  fix: update jina/executors/decorators.py (JoanFM, Mar 8, 2021)
a48a96c  fix: update jina/executors/decorators.py (JoanFM, Mar 8, 2021)
cef0f83  fix: update jina/types/sets/document.py (JoanFM, Mar 8, 2021)
0a822fe  fix: improve readability extract content (JoanFM, Mar 8, 2021)
b23c2d0  fix: update jina/types/sets/document.py (JoanFM, Mar 8, 2021)
58588cd  fix: update test_craft_executors_batching.py (JoanFM, Mar 8, 2021)
78238ac  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 8, 2021)
d1ab9ec  Merge branch 'master' of https://github.com/jina-ai/jina into test-ge… (JoanFM, Mar 8, 2021)
96 changes: 88 additions & 8 deletions jina/executors/decorators.py
@@ -5,14 +5,14 @@

import inspect
from functools import wraps
from typing import Callable, Any, Union, Iterator, List, Optional, Dict
from typing import Callable, Any, Union, Iterator, List, Optional, Dict, Iterable

import numpy as np

from .metas import get_default_metas
from ..helper import batch_iterator, typename, convert_tuple_to_list
from ..logging import default_logger
from itertools import islice
from itertools import islice, chain


def as_aggregate_method(func: Callable) -> Callable:
@@ -199,14 +199,17 @@ def _get_total_size(full_data_size, batch_size, num_batch):
return total_size


def _merge_results_after_batching(final_result, merge_over_axis: int = 0):
def _merge_results_after_batching(
final_result, merge_over_axis: int = 0, flatten: bool = True
):
if len(final_result) == 1:
# the only result of one batch
return final_result[0]

if len(final_result) and merge_over_axis is not None:
if final_result:
if isinstance(final_result[0], np.ndarray):
final_result = np.concatenate(final_result, merge_over_axis)
if len(final_result[0].shape) > 1:
final_result = np.concatenate(final_result, merge_over_axis)
elif isinstance(final_result[0], tuple):
reduced_result = []
num_cols = len(final_result[0])
@@ -215,6 +218,8 @@ def _merge_results_after_batching(final_result, merge_over_axis: int = 0):
np.concatenate([row[col] for row in final_result], merge_over_axis)
)
final_result = tuple(reduced_result)
elif isinstance(final_result[0], list) and flatten:
final_result = list(chain.from_iterable(final_result))

if len(final_result):
return final_result
@@ -229,6 +234,7 @@ def batching(
slice_on: int = 1,
label_on: Optional[int] = None,
ordinal_idx_arg: Optional[int] = None,
flatten_output: bool = True,
) -> Any:
"""Split the input of a function into small batches and call :func:`func` on each batch
, collect the merged result and return. This is useful when the input is too big to fit into memory
@@ -244,6 +250,7 @@ def batching(
:param ordinal_idx_arg: the location of the ordinal indexes argument. Needed for classes
where the decorated function needs to know the ordinal indexes of the data in the batch
(not used when ``label_on`` is used)
:param flatten_output: If set to True, the results from different batches are chained and the return value is a single list of results. Otherwise, the return value is a list of lists, where each element contains the results from one batch. Note that if there is only one batch, the returned result is always flattened.
:return: the merged result as if :func:`func` had been run once on the whole input.

Example:
@@ -321,7 +328,9 @@ def arg_wrapper(*args, **kwargs):
if r is not None:
final_result.append(r)

return _merge_results_after_batching(final_result, merge_over_axis)
return _merge_results_after_batching(
final_result, merge_over_axis, flatten_output
)

return arg_wrapper
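A hypothetical usage sketch of the ``flatten_output`` flag documented above. The ``ToyCrafter`` class and its data are made up for illustration only, and passing ``batch_size`` through the decorator is assumed to work as in the existing docstring examples of this file:

from jina.executors.decorators import batching

class ToyCrafter:
    @batching(batch_size=2)  # flatten_output defaults to True
    def craft_flat(self, data):
        return [d.upper() for d in data]  # one list per batch

    @batching(batch_size=2, flatten_output=False)
    def craft_nested(self, data):
        return [d.upper() for d in data]

crafter = ToyCrafter()
print(crafter.craft_flat(['a', 'b', 'c', 'd']))    # expected: ['A', 'B', 'C', 'D']
print(crafter.craft_nested(['a', 'b', 'c', 'd']))  # expected: [['A', 'B'], ['C', 'D']]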

@@ -436,6 +445,7 @@ def single(
func: Callable[[Any], np.ndarray] = None,
merge_over_axis: int = 0,
slice_on: int = 1,
flatten_output: bool = True,
) -> Any:
"""
Guarantee that the input of a function is provided as a single instance and not in batches
@@ -444,6 +454,7 @@ def single(
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
:param flatten_output: whether a result consisting of a list of lists should be flattened in the output
:return: the merged result as if :func:`func` had been run once on the whole input.

Example:
@@ -458,9 +469,12 @@ def craft(self, text: str) -> Dict:
def _single(func):
@wraps(func)
def arg_wrapper(*args, **kwargs):
# priority: decorator > class_attribute
# by default data is in args[1] (self needs to be taken into account)
data = args[slice_on]

if not isinstance(data, Iterable):
return func(*args, **kwargs)

args = list(args)

default_logger.debug(f'batching disabled for {func.__qualname__}')
@@ -472,8 +486,74 @@ def arg_wrapper(*args, **kwargs):
if r is not None:
final_result.append(r)

return _merge_results_after_batching(
final_result, merge_over_axis, flatten_output
)

return arg_wrapper

if func:
return _single(func)
else:
return _single
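A hypothetical sketch of the non-iterable passthrough added above: when the sliced argument is not an ``Iterable``, ``func`` is called directly; otherwise it is called once per element and the results are merged. ``ToySquarer`` is made up for illustration only:

from jina.executors.decorators import single

class ToySquarer:
    @single
    def square(self, value):
        return value * value

s = ToySquarer()
print(s.square(3))          # non-iterable input, called once directly: expected 9
print(s.square([1, 2, 3]))  # iterable input, called per element: expected [1, 4, 9]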


def single_multi_input(
func: Callable[[Any], np.ndarray] = None,
merge_over_axis: int = 0,
slice_on: int = 1,
num_data: int = 1,
) -> Any:
"""Guarantee that the inputs of a function with more than one argument is provided as single instances and not in batches

:param func: function to decorate
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
:param num_data: the number of data arguments inside the arguments
:return: the merged result as if :func:`func` had been run once on the whole input.

.. warning::
data arguments will be taken starting from ``slice_on`` to ``slice_on + num_data``

Example:
.. highlight:: python
.. code-block:: python

class OneByOneCrafter:

@single_multi_input
def craft(self, text: str, id: str) -> Dict:
...
"""

def _single_multi_input(func):
@wraps(func)
def arg_wrapper(*args, **kwargs):
# by default data is in args[1:] (self needs to be taken into account)
args = list(args)
default_logger.debug(f'batching disabled for {func.__qualname__}')
data_iterators = [args[slice_on + i] for i in range(0, num_data)]

if not isinstance(data_iterators[0], Iterable):
return func(*args, **kwargs)

final_result = []
for i, instance in enumerate(data_iterators[0]):
args[slice_on] = instance
for idx in range(1, num_data):
args[slice_on + idx] = data_iterators[idx][i]

r = func(*args, **kwargs)

if r is not None:
final_result.append(r)

return _merge_results_after_batching(final_result, merge_over_axis)

return arg_wrapper

return _single(func)
if func:
return _single_multi_input(func)
else:
return _single_multi_input
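
A hypothetical usage sketch of the new ``single_multi_input`` decorator. The ``PairCrafter`` class and its data are made up for illustration; ``num_data=2`` tells the decorator that two data arguments follow ``self``:

from jina.executors.decorators import single_multi_input

class PairCrafter:
    @single_multi_input(num_data=2)
    def combine(self, texts, ids):
        return {'text': texts, 'id': ids}

crafter = PairCrafter()
print(crafter.combine(['a', 'b'], ['1', '2']))
# expected: [{'text': 'a', 'id': '1'}, {'text': 'b', 'id': '2'}]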