refactor(helper): merge single and single_multi_input into one #2192

Merged 6 commits on Mar 17, 2021
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -128,7 +128,7 @@ jobs:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}

core-test:
-needs: [prep-testbed, commit-lint, lint-flake-8, check-docstring, check-black]
+needs: [prep-testbed, commit-lint, lint-flake-8]
runs-on: ubuntu-20.04
strategy:
fail-fast: false
@@ -198,7 +198,7 @@ jobs:

# just for blocking the merge until all parallel core-test are successful
success-all-test:
-needs: [core-test, docker-image-test]
+needs: [core-test, docker-image-test, check-docstring, check-black]
if: always()
runs-on: ubuntu-20.04
steps:
1 change: 0 additions & 1 deletion jina/drivers/segment.py
@@ -47,7 +47,6 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs):
f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} '
f'and length of returned crafted documents: {len(docs_chunks)}, the length must be the same'
)
-self.logger.error(msg)
raise LengthMismatchException(msg)

for doc, chunks in zip(docs_pts, docs_chunks):
196 changes: 52 additions & 144 deletions jina/executors/decorators.py
@@ -5,14 +5,14 @@

import inspect
from functools import wraps
+from itertools import islice, chain
from typing import Callable, Any, Union, Iterator, List, Optional, Dict, Iterable

import numpy as np

from .metas import get_default_metas
from ..helper import batch_iterator, typename, convert_tuple_to_list
from ..logging import default_logger
-from itertools import islice, chain


def as_aggregate_method(func: Callable) -> Callable:
@@ -78,7 +78,7 @@ def wrap_func(cls, func_lst, wrapper):
"""
for f_name in func_lst:
if hasattr(cls, f_name) and all(
-        getattr(cls, f_name) != getattr(i, f_name, None) for i in cls.mro()[1:]
+    getattr(cls, f_name) != getattr(i, f_name, None) for i in cls.mro()[1:]
):
setattr(cls, f_name, wrapper(getattr(cls, f_name)))

@@ -170,7 +170,7 @@ def arg_wrapper(self, *args, **kwargs):


def _get_slice(
-        data: Union[Iterator[Any], List[Any], np.ndarray], total_size: int
+    data: Union[Iterator[Any], List[Any], np.ndarray], total_size: int
) -> Union[Iterator[Any], List[Any], np.ndarray]:
if isinstance(data, Dict):
data = islice(data.items(), total_size)
@@ -200,7 +200,7 @@ def _get_total_size(full_data_size, batch_size, num_batch):


def _merge_results_after_batching(
-        final_result, merge_over_axis: int = 0, flatten: bool = True
+    final_result, merge_over_axis: int = 0, flatten: bool = True
):
if not final_result:
return
@@ -215,15 +215,15 @@


def batching(
-        func: Optional[Callable[[Any], np.ndarray]] = None,
-        batch_size: Optional[Union[int, Callable]] = None,
-        num_batch: Optional[int] = None,
-        split_over_axis: int = 0,
-        merge_over_axis: int = 0,
-        slice_on: int = 1,
-        label_on: Optional[int] = None,
-        ordinal_idx_arg: Optional[int] = None,
-        flatten_output: bool = True,
+    func: Optional[Callable[[Any], np.ndarray]] = None,
+    batch_size: Optional[Union[int, Callable]] = None,
+    num_batch: Optional[int] = None,
+    split_over_axis: int = 0,
+    merge_over_axis: int = 0,
+    slice_on: int = 1,
+    label_on: Optional[int] = None,
+    ordinal_idx_arg: Optional[int] = None,
+    flatten_output: bool = True,
) -> Any:
"""Split the input of a function into small batches and call :func:`func` on each batch
, collect the merged result and return. This is useful when the input is too big to fit into memory
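For illustration (not part of this diff): a minimal sketch of applying `@batching` to a toy class. `ToyEncoder` and the sample array are made up, and the sketch assumes per-batch `np.ndarray` results are concatenated back by the decorator's merge step:

import numpy as np

from jina.executors.decorators import batching


class ToyEncoder:
    # hypothetical executor-like class, used only to illustrate the decorator
    @batching(batch_size=2)
    def encode(self, data: np.ndarray) -> np.ndarray:
        # each call sees at most 2 rows; results are merged after batching
        return data * 2


enc = ToyEncoder()
out = enc.encode(np.arange(12).reshape(6, 2))  # internally split into 3 batches
print(out.shape)  # expected (6, 2): per-batch results concatenated back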
@@ -266,8 +266,8 @@ def arg_wrapper(*args, **kwargs):
args = list(args)

b_size = (
-        batch_size(data) if callable(batch_size) else batch_size
-    ) or getattr(args[0], 'batch_size', None)
+    batch_size(data) if callable(batch_size) else batch_size
+) or getattr(args[0], 'batch_size', None)
# no batching if b_size is None
if b_size is None or data is None:
return func(*args, **kwargs)
@@ -288,7 +288,7 @@ def arg_wrapper(*args, **kwargs):
slice_idx = None

for b in batch_iterator(
-        data[:total_size], b_size, split_over_axis, yield_slice=yield_slice
+    data[:total_size], b_size, split_over_axis, yield_slice=yield_slice
):
if yield_slice:
slice_idx = b
@@ -330,13 +330,13 @@ def arg_wrapper(*args, **kwargs):


def batching_multi_input(
-        func: Optional[Callable[[Any], np.ndarray]] = None,
-        batch_size: Optional[Union[int, Callable]] = None,
-        num_batch: Optional[int] = None,
-        split_over_axis: int = 0,
-        merge_over_axis: int = 0,
-        slice_on: int = 1,
-        num_data: int = 1,
+    func: Optional[Callable[[Any], np.ndarray]] = None,
+    batch_size: Optional[Union[int, Callable]] = None,
+    num_batch: Optional[int] = None,
+    split_over_axis: int = 0,
+    merge_over_axis: int = 0,
+    slice_on: int = 1,
+    slice_nargs: int = 1,
) -> Any:
"""Split the input of a function into small batches and call :func:`func` on each batch
, collect the merged result and return. This is useful when the input is too big to fit into memory
@@ -348,7 +348,7 @@ def batching_multi_input(
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
-:param num_data: the number of data inside the arguments
+:param slice_nargs: the number of data inside the arguments
:return: the merged result as if run :func:`func` once on the input.

..warning:
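For illustration (not part of this diff): a sketch of the renamed `slice_nargs` parameter, assuming a made-up `PairScorer` class. With `slice_nargs=2`, both positional data arguments are cut into aligned batches:

from typing import List

from jina.executors.decorators import batching_multi_input


class PairScorer:
    # hypothetical class: both arguments are batched in lockstep
    @batching_multi_input(batch_size=2, slice_nargs=2)
    def score(self, queries: List[str], docs: List[str]) -> List[str]:
        # each call sees aligned batches of `queries` and `docs`
        return [f'{q}->{d}' for q, d in zip(queries, docs)]


scorer = PairScorer()
print(scorer.score(['q1', 'q2', 'q3'], ['d1', 'd2', 'd3']))
# expected: ['q1->d1', 'q2->d2', 'q3->d3'], merged across two batches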
@@ -397,24 +397,23 @@ def arg_wrapper(*args, **kwargs):
full_data_size = _get_size(args[slice_on], split_over_axis)
total_size = _get_total_size(full_data_size, b_size, num_batch)
final_result = []
-yield_dict = [
-    isinstance(args[slice_on + i], Dict) for i in range(0, num_data)
-]
+yield_dict = [isinstance(args[slice_on + i], Dict) for i in range(0, slice_nargs)]
+yield_slice = [isinstance(args[slice_on + i], np.memmap) for i in range(0, slice_nargs)]

data_iterators = [
batch_iterator(
_get_slice(args[slice_on + i], total_size),
b_size,
split_over_axis,
yield_slice=yield_slice[i],
yield_dict=yield_dict[i],
)
-    for i in range(0, num_data)
+    for i in range(0, slice_nargs)
]

-for batch in data_iterators[0]:
-    args[slice_on] = batch
-    for idx in range(1, num_data):
-        args[slice_on + idx] = next(data_iterators[idx])
+for new_args in zip(*data_iterators):
+    args[slice_on: slice_on + slice_nargs] = new_args
r = func(*args, **kwargs)

if r is not None:
@@ -431,109 +430,19 @@


def single(
-        func: Optional[Callable[[Any], np.ndarray]] = None,
-        merge_over_axis: int = 0,
-        slice_on: int = 1,
-        flatten_output: bool = False,
-) -> Any:
-    """
-    Guarantee that the input of a function is provided as a single instance and not in batches
-
-    :param func: function to decorate
-    :param merge_over_axis: merge over which axis into a single result
-    :param slice_on: the location of the data. When using inside a class,
-        ``slice_on`` should take ``self`` into consideration.
-    :param flatten_output: Flag to determine if a result of list of lists needs to be flattened in output
-    :return: the merged result as if run :func:`func` once on the input.
-
-    Example:
-        .. highlight:: python
-        .. code-block:: python
-
-            class OneByOneCrafter:
-                @single
-                def craft(self, text: str) -> Dict:
-
-    .. note:
-        Single decorator will let the user interact with the executor in 3 different ways:
-            - Providing batches: (This decorator will make sure that the actual method receives just a single instance)
-            - Providing a single instance
-            - Providing a single instance through kwargs.
-
-        .. highlight:: python
-        .. code-block:: python
-
-            class OneByOneCrafter:
-                @single
-                def craft(self, text: str) -> Dict:
-                    return {'text' : f'{text}-crafted'}
-
-            crafter = OneByOneCrafter()
-
-            results = crafted.craft(['text1', 'text2'])
-            assert len(results) == 2
-            assert results[0] == {'text': 'text1-crafted'}
-            assert results[1] == {'text': 'text2-crafted'}
-
-            result = crafter.craft('text')
-            assert result['text'] == 'text-crafted'
-
-            results = crafted.craft(text='text')
-            assert result['text'] == 'text-crafted'
-    """
-
-    def _single(func):
-        @wraps(func)
-        def arg_wrapper(*args, **kwargs):
-
-            # like this one can use the function with single kwargs
-            if (
-                len(args) <= slice_on
-                or isinstance(args[slice_on], str)
-                or isinstance(args[slice_on], bytes)
-                or not isinstance(args[slice_on], Iterable)
-            ):
-                # like this one can use the function with single kwargs
-                return func(*args, **kwargs)
-
-            args = list(args)
-            data = args[slice_on]
-
-            default_logger.debug(f'batching disabled for {func.__qualname__}')
-
-            final_result = []
-            for instance in data:
-                args[slice_on] = instance
-                r = func(*args, **kwargs)
-                if r is not None:
-                    final_result.append(r)
-
-            return _merge_results_after_batching(
-                final_result, merge_over_axis, flatten_output
-            )
-
-        return arg_wrapper
-
-    if func:
-        return _single(func)
-    else:
-        return _single


-def single_multi_input(
-        func: Optional[Callable[[Any], np.ndarray]] = None,
-        merge_over_axis: int = 0,
-        slice_on: int = 1,
-        num_data: int = 1,
-        flatten_output: bool = True,
+    func: Optional[Callable[[Any], np.ndarray]] = None,
+    merge_over_axis: int = 0,
+    slice_on: int = 1,
+    slice_nargs: int = 1,
+    flatten_output: bool = False,
) -> Any:
"""Guarantee that the inputs of a function with more than one argument is provided as single instances and not in batches

:param func: function to decorate
:param merge_over_axis: merge over which axis into a single result
:param slice_on: the location of the data. When using inside a class,
``slice_on`` should take ``self`` into consideration.
-:param num_data: the number of data inside the arguments
+:param slice_nargs: the number of positional arguments considered as data
:param flatten_output: If set to True, the results from different batches are chained and the return value is a flat list of results. Otherwise, the return value is a list of lists, where each element holds the results of one batch. Note that if only one batch is returned, the result is always flattened.
:return: the merged result as if run :func:`func` once on the input.
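To make the `flatten_output` wording concrete, a toy contrast (values made up; `chain` is the `itertools.chain` already imported by this module):

from itertools import chain

per_batch_results = [[1, 2], [3]]  # one inner list per processed batch
print(list(chain.from_iterable(per_batch_results)))  # flatten_output=True -> [1, 2, 3]
print(per_batch_results)  # flatten_output=False -> [[1, 2], [3]]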

@@ -546,7 +455,7 @@ def single_multi_input(

class OneByOneCrafter:

-@single_multi_input
+@single
def craft(self, text: str, id: str) -> Dict:
...

@@ -560,7 +469,7 @@ def craft(self, text: str, id: str) -> Dict:
.. code-block:: python

class OneByOneCrafter:
-@single_multi_input
+@single
def craft(self, text: str, id: str) -> Dict:
return {'text': f'{text}-crafted', 'id': f'{id}-crafted'}
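For call sites, the migration this PR implies is a decorator swap; a hypothetical before/after sketch, assuming `slice_nargs=2` is passed so both positional arguments are treated as data, per the new signature:

from typing import Dict

from jina.executors.decorators import single


class OneByOneCrafter:
    # before this PR: @single_multi_input(num_data=2)
    @single(slice_nargs=2)
    def craft(self, text: str, id: str) -> Dict:
        return {'text': f'{text}-crafted', 'id': f'{id}-crafted'}


crafter = OneByOneCrafter()
results = crafter.craft(['text1', 'text2'], ['id1', 'id2'])
# expected: two dicts, one per (text, id) pair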

@@ -587,27 +496,26 @@ def arg_wrapper(*args, **kwargs):
args = list(args)
default_logger.debug(f'batching disabled for {func.__qualname__}')

-data_iterators = list([args[slice_on + i] for i in range(0, num_data)])
-
-if (
-    len(args) <= slice_on
-    or isinstance(data_iterators[0], str)
-    or isinstance(data_iterators[0], bytes)
-    or not isinstance(data_iterators[0], Iterable)
+data_iterators = args[slice_on: slice_on + slice_nargs]
+
+if len(args) <= slice_on:
+    # like this one can use the function with single kwargs
+    return func(*args, **kwargs)
+elif len(args) < slice_on + slice_nargs:
+    raise IndexError(f'can not select positional args at {slice_on}: {slice_nargs}, '
+                     f'your `args` has {len(args)} arguments.')
+elif (
+    len(args) <= slice_on
+    or isinstance(data_iterators[0], str)
+    or isinstance(data_iterators[0], bytes)
+    or not isinstance(data_iterators[0], Iterable)
):
# like this one can use the function with single kwargs
return func(*args, **kwargs)

final_result = []
-for i, instance in enumerate(data_iterators[0]):
-    args[slice_on] = instance
-    for idx in range(1, num_data):
-        args[slice_on + idx] = data_iterators[idx][i]
+for new_args in zip(*data_iterators):
+    args[slice_on: slice_on + slice_nargs] = new_args
r = func(*args, **kwargs)

if r is not None:
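The core of the refactor, in both `batching_multi_input` and the merged `single`, is replacing the manual `next()` bookkeeping with `zip(*data_iterators)` plus slice assignment into `args`. A standalone sketch of that pattern (plain Python, no jina imports; all names here are made up):

def call_batched(func, args, slice_on, slice_nargs, iterators):
    args = list(args)
    results = []
    for new_args in zip(*iterators):  # advances all iterators in lockstep
        args[slice_on: slice_on + slice_nargs] = new_args
        results.append(func(*args))
    return results


batches_a = iter([[1, 2], [3]])
batches_b = iter([['x', 'y'], ['z']])
print(call_batched(lambda _self, a, b: list(zip(a, b)),
                   ('self', None, None), 1, 2, [batches_a, batches_b]))
# [[(1, 'x'), (2, 'y')], [(3, 'z')]]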
2 changes: 1 addition & 1 deletion tests/integration/evaluation/rank/yaml/dummy_ranker.py
@@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.match_required_keys = ['tags__dummy_score']

-@batching_multi_input(num_data=3)
+@batching_multi_input(slice_nargs=3)
def score(
self,
old_match_scores: List[Dict],