refactor(helper): merge single and single_multi_input into one (#2192)
* refactor(helper): merge single and single_multi_input into one
hanxiao committed Mar 17, 2021
1 parent b5f8010 commit abceba0
Showing 11 changed files with 79 additions and 175 deletions.
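Based on the merged `single` signature and the docstring example visible in the decorators.py diff below, the behaviour of the removed `single_multi_input(num_data=...)` is now reached through `single(slice_nargs=...)`. A minimal sketch (the class and argument names are illustrative, and the asserts mirror what the docstring documents rather than output verified against this commit):

```python
from typing import Dict

from jina.executors.decorators import single


class PairCrafter:
    # slice_nargs=2 marks two positional arguments (text, doc_id) as data;
    # slice_on defaults to 1, which skips `self` on bound methods
    @single(slice_nargs=2)
    def craft(self, text: str, doc_id: str) -> Dict:
        return {'text': f'{text}-crafted', 'id': f'{doc_id}-crafted'}


crafter = PairCrafter()

# single instances are passed straight through to the wrapped method
assert crafter.craft('hello', '1') == {'text': 'hello-crafted', 'id': '1-crafted'}

# iterables are unrolled element by element and the per-instance results merged
results = crafter.craft(['hello', 'world'], ['1', '2'])
assert results[0] == {'text': 'hello-crafted', 'id': '1-crafted'}
assert results[1] == {'text': 'world-crafted', 'id': '2-crafted'}
```

Before this commit the same method would have needed `@single_multi_input(num_data=2)`; methods with one data argument keep using plain `@single`.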
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -128,7 +128,7 @@ jobs:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
core-test:
needs: [prep-testbed, commit-lint, lint-flake-8, check-docstring, check-black]
needs: [prep-testbed, commit-lint, lint-flake-8]
runs-on: ubuntu-20.04
strategy:
fail-fast: false
@@ -198,7 +198,7 @@ jobs:

# just for blocking the merge until all parallel core-test are successful
success-all-test:
needs: [core-test, docker-image-test]
needs: [core-test, docker-image-test, check-docstring, check-black]
if: always()
runs-on: ubuntu-20.04
steps:
1 change: 0 additions & 1 deletion jina/drivers/segment.py
@@ -47,7 +47,6 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs):
f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} '
f'and length of returned crafted documents: {len(docs_chunks)}, the length must be the same'
)
self.logger.error(msg)
raise LengthMismatchException(msg)

for doc, chunks in zip(docs_pts, docs_chunks):
196 changes: 52 additions & 144 deletions jina/executors/decorators.py
@@ -5,14 +5,14 @@

import inspect
from functools import wraps
from itertools import islice, chain
from typing import Callable, Any, Union, Iterator, List, Optional, Dict, Iterable

import numpy as np

from .metas import get_default_metas
from ..helper import batch_iterator, typename, convert_tuple_to_list
from ..logging import default_logger
from itertools import islice, chain


def as_aggregate_method(func: Callable) -> Callable:
@@ -78,7 +78,7 @@ def wrap_func(cls, func_lst, wrapper):
"""
for f_name in func_lst:
if hasattr(cls, f_name) and all(
getattr(cls, f_name) != getattr(i, f_name, None) for i in cls.mro()[1:]
getattr(cls, f_name) != getattr(i, f_name, None) for i in cls.mro()[1:]
):
setattr(cls, f_name, wrapper(getattr(cls, f_name)))

@@ -170,7 +170,7 @@ def arg_wrapper(self, *args, **kwargs):


def _get_slice(
data: Union[Iterator[Any], List[Any], np.ndarray], total_size: int
data: Union[Iterator[Any], List[Any], np.ndarray], total_size: int
) -> Union[Iterator[Any], List[Any], np.ndarray]:
if isinstance(data, Dict):
data = islice(data.items(), total_size)
@@ -200,7 +200,7 @@ def _get_total_size(full_data_size, batch_size, num_batch):


def _merge_results_after_batching(
final_result, merge_over_axis: int = 0, flatten: bool = True
final_result, merge_over_axis: int = 0, flatten: bool = True
):
if not final_result:
return
@@ -215,15 +215,15 @@


def batching(
func: Optional[Callable[[Any], np.ndarray]] = None,
batch_size: Optional[Union[int, Callable]] = None,
num_batch: Optional[int] = None,
split_over_axis: int = 0,
merge_over_axis: int = 0,
slice_on: int = 1,
label_on: Optional[int] = None,
ordinal_idx_arg: Optional[int] = None,
flatten_output: bool = True,
func: Optional[Callable[[Any], np.ndarray]] = None,
batch_size: Optional[Union[int, Callable]] = None,
num_batch: Optional[int] = None,
split_over_axis: int = 0,
merge_over_axis: int = 0,
slice_on: int = 1,
label_on: Optional[int] = None,
ordinal_idx_arg: Optional[int] = None,
flatten_output: bool = True,
) -> Any:
"""Split the input of a function into small batches and call :func:`func` on each batch
, collect the merged result and return. This is useful when the input is too big to fit into memory
@@ -266,8 +266,8 @@ def arg_wrapper(*args, **kwargs):
args = list(args)

b_size = (
batch_size(data) if callable(batch_size) else batch_size
) or getattr(args[0], 'batch_size', None)
batch_size(data) if callable(batch_size) else batch_size
) or getattr(args[0], 'batch_size', None)
# no batching if b_size is None
if b_size is None or data is None:
return func(*args, **kwargs)
@@ -288,7 +288,7 @@ def arg_wrapper(*args, **kwargs):
slice_idx = None

for b in batch_iterator(
data[:total_size], b_size, split_over_axis, yield_slice=yield_slice
data[:total_size], b_size, split_over_axis, yield_slice=yield_slice
):
if yield_slice:
slice_idx = b
@@ -330,13 +330,13 @@ def arg_wrapper(*args, **kwargs):


def batching_multi_input(
func: Optional[Callable[[Any], np.ndarray]] = None,
batch_size: Optional[Union[int, Callable]] = None,
num_batch: Optional[int] = None,
split_over_axis: int = 0,
merge_over_axis: int = 0,
slice_on: int = 1,
num_data: int = 1,
func: Optional[Callable[[Any], np.ndarray]] = None,
batch_size: Optional[Union[int, Callable]] = None,
num_batch: Optional[int] = None,
split_over_axis: int = 0,
merge_over_axis: int = 0,
slice_on: int = 1,
slice_nargs: int = 1,
) -> Any:
"""Split the input of a function into small batches and call :func:`func` on each batch
, collect the merged result and return. This is useful when the input is too big to fit into memory
@@ -348,7 +348,7 @@ def batching_multi_input(
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
:param num_data: the number of data inside the arguments
:param slice_nargs: the number of data inside the arguments
:return: the merged result as if run :func:`func` once on the input.
..warning:
@@ -397,24 +397,23 @@ def arg_wrapper(*args, **kwargs):
full_data_size = _get_size(args[slice_on], split_over_axis)
total_size = _get_total_size(full_data_size, b_size, num_batch)
final_result = []
yield_dict = [
isinstance(args[slice_on + i], Dict) for i in range(0, num_data)
]

yield_dict = [isinstance(args[slice_on + i], Dict) for i in range(0, slice_nargs)]
yield_slice = [isinstance(args[slice_on + i], np.memmap) for i in range(0, slice_nargs)]

data_iterators = [
batch_iterator(
_get_slice(args[slice_on + i], total_size),
b_size,
split_over_axis,
yield_slice=yield_slice[i],
yield_dict=yield_dict[i],
)
for i in range(0, num_data)
for i in range(0, slice_nargs)
]

for batch in data_iterators[0]:
args[slice_on] = batch
for idx in range(1, num_data):
args[slice_on + idx] = next(data_iterators[idx])

for new_args in zip(*data_iterators):
args[slice_on: slice_on + slice_nargs] = new_args
r = func(*args, **kwargs)

if r is not None:
@@ -431,109 +430,19 @@ def arg_wrapper(*args, **kwargs):


def single(
func: Optional[Callable[[Any], np.ndarray]] = None,
merge_over_axis: int = 0,
slice_on: int = 1,
flatten_output: bool = False,
) -> Any:
"""
Guarantee that the input of a function is provided as a single instance and not in batches
:param func: function to decorate
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
:param flatten_output: Flag to determine if a result of list of lists needs to be flattened in output
:return: the merged result as if run :func:`func` once on the input.
Example:
.. highlight:: python
.. code-block:: python
class OneByOneCrafter:
@single
def craft(self, text: str) -> Dict:
.. note:
Single decorator will let the user interact with the executor in 3 different ways:
- Providing batches: (This decorator will make sure that the actual method receives just a single instance)
- Providing a single instance
- Providing a single instance through kwargs.
.. highlight:: python
.. code-block:: python
class OneByOneCrafter:
@single
def craft(self, text: str) -> Dict:
return {'text' : f'{text}-crafted'}
crafter = OneByOneCrafter()
results = crafted.craft(['text1', 'text2'])
assert len(results) == 2
assert results[0] == {'text': 'text1-crafted'}
assert results[1] == {'text': 'text2-crafted'}
result = crafter.craft('text')
assert result['text'] == 'text-crafted'
results = crafted.craft(text='text')
assert result['text'] == 'text-crafted'
"""

def _single(func):
@wraps(func)
def arg_wrapper(*args, **kwargs):

# like this one can use the function with single kwargs
if (
len(args) <= slice_on
or isinstance(args[slice_on], str)
or isinstance(args[slice_on], bytes)
or not isinstance(args[slice_on], Iterable)
):
# like this one can use the function with single kwargs
return func(*args, **kwargs)

args = list(args)
data = args[slice_on]

default_logger.debug(f'batching disabled for {func.__qualname__}')

final_result = []
for instance in data:
args[slice_on] = instance
r = func(*args, **kwargs)
if r is not None:
final_result.append(r)

return _merge_results_after_batching(
final_result, merge_over_axis, flatten_output
)

return arg_wrapper

if func:
return _single(func)
else:
return _single


def single_multi_input(
func: Optional[Callable[[Any], np.ndarray]] = None,
merge_over_axis: int = 0,
slice_on: int = 1,
num_data: int = 1,
flatten_output: bool = True,
func: Optional[Callable[[Any], np.ndarray]] = None,
merge_over_axis: int = 0,
slice_on: int = 1,
slice_nargs: int = 1,
flatten_output: bool = False,
) -> Any:
"""Guarantee that the inputs of a function with more than one argument is provided as single instances and not in batches
:param func: function to decorate
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
:param num_data: the number of data inside the arguments
:param slice_nargs: the number of positional arguments considered as data
:param flatten_output: If this is set to True, the results from different batches will be chained and the returning value is a list of the results. Otherwise, the returning value is a list of lists, in which each element is a list containing the result from one single batch. Note if there is only one batch returned, the returned result is always flatten.
:return: the merged result as if run :func:`func` once on the input.
@@ -546,7 +455,7 @@ def single_multi_input(
class OneByOneCrafter:
@single_multi_input
@single
def craft(self, text: str, id: str) -> Dict:
...
@@ -560,7 +469,7 @@ def craft(self, text: str, id: str) -> Dict:
.. code-block:: python
class OneByOneCrafter:
@single_multi_input
@single
def craft(self, text: str, id: str) -> Dict:
return {'text': f'{text}-crafted', 'id': f'{id}-crafted'}
@@ -587,27 +496,26 @@ def arg_wrapper(*args, **kwargs):
args = list(args)
default_logger.debug(f'batching disabled for {func.__qualname__}')

data_iterators = args[slice_on: slice_on + slice_nargs]

if len(args) <= slice_on:
# like this one can use the function with single kwargs
return func(*args, **kwargs)

data_iterators = list([args[slice_on + i] for i in range(0, num_data)])

if (
len(args) <= slice_on
or isinstance(data_iterators[0], str)
or isinstance(data_iterators[0], bytes)
or not isinstance(data_iterators[0], Iterable)
elif len(args) < slice_on + slice_nargs:
raise IndexError(f'can not select positional args at {slice_on}: {slice_nargs}, '
f'your `args` has {len(args)} arguments.')
elif (
len(args) <= slice_on
or isinstance(data_iterators[0], str)
or isinstance(data_iterators[0], bytes)
or not isinstance(data_iterators[0], Iterable)
):
# like this one can use the function with single kwargs
return func(*args, **kwargs)

final_result = []
for i, instance in enumerate(data_iterators[0]):
args[slice_on] = instance
for idx in range(1, num_data):
args[slice_on + idx] = data_iterators[idx][i]

for new_args in zip(*data_iterators):
args[slice_on: slice_on + slice_nargs] = new_args
r = func(*args, **kwargs)

if r is not None:
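The core simplification in `arg_wrapper` above is replacing the manual `next()` bookkeeping over secondary iterators with `zip(*data_iterators)` plus a slice assignment into `args`. A standalone sketch of that lockstep pattern, with an illustrative `batch` helper instead of jina's `batch_iterator`:

```python
from itertools import islice
from typing import Iterable, Iterator, List


def batch(data: Iterable, size: int) -> Iterator[List]:
    """Yield consecutive chunks of `data` with at most `size` items each."""
    it = iter(data)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


texts = ['a', 'b', 'c', 'd']
ids = [1, 2, 3, 4]

args = ['self-placeholder', texts, ids]  # positional args as the decorator sees them
slice_on, slice_nargs, b_size = 1, 2, 2

# one batch iterator per data argument, advanced in lockstep by zip(),
# mirroring `for new_args in zip(*data_iterators)` in the decorator
data_iterators = [batch(args[slice_on + i], b_size) for i in range(slice_nargs)]

for new_args in zip(*data_iterators):
    # drop the aligned batches back into the positional arguments
    args[slice_on: slice_on + slice_nargs] = new_args
    print(args[slice_on], args[slice_on + 1])  # ['a', 'b'] [1, 2]  then  ['c', 'd'] [3, 4]
```

One behavioural nuance: `zip()` stops at the shortest iterator, so this is equivalent to the replaced `next()` loop only when every data argument yields the same number of batches.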
2 changes: 1 addition & 1 deletion tests/integration/evaluation/rank/yaml/dummy_ranker.py
@@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.match_required_keys = ['tags__dummy_score']

@batching_multi_input(num_data=3)
@batching_multi_input(slice_nargs=3)
def score(
self,
old_match_scores: List[Dict],
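For downstream executors this part of the change is a pure keyword rename, as the one-line test update above shows. A hypothetical ranker crossing this commit would only touch the decorator call (the parameters after `old_match_scores` are placeholders, since the diff is truncated here):

```python
from typing import Dict, List

from jina.executors.decorators import batching_multi_input


class MyRanker:  # placeholder class name
    # before this commit: @batching_multi_input(num_data=3)
    @batching_multi_input(slice_nargs=3)  # three positional arguments carry data
    def score(
        self,
        old_match_scores: List[Dict],
        query_metas: List[Dict],   # placeholder parameter
        match_metas: List[Dict],   # placeholder parameter
    ):
        ...
```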
