Commit 5530733

refactor: change batch_size to request_size (#1677)

JoanFM committed Jan 14, 2021
1 parent 5b2fa02 commit 5530733
Showing 18 changed files with 61 additions and 58 deletions.
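The rename is a pure spelling change at the call site: every client-facing `batch_size` keyword and `--batch-size` flag becomes `request_size` / `--request-size`, with defaults unchanged and the old CLI flag kept as an alias (see jina/parsers/client.py below). A minimal before/after sketch — the flow config and data here are hypothetical placeholders, modeled on the calls that appear in this diff:

```python
from jina import Flow

docs = (f'dummy doc {i}' for i in range(1000))  # any iterable of inputs

with Flow().add(uses='encoder.yml') as f:  # 'encoder.yml' is a placeholder pod config
    # before this commit:
    #   f.index(input_fn=docs, batch_size=100)
    # after this commit, the same call reads:
    f.index(input_fn=docs, request_size=100)  # at most 100 documents per request
```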
6 changes: 3 additions & 3 deletions cli/autocomplete.py
@@ -35,8 +35,8 @@ def _gaa(key, parser):
'commands': ['--help', '--version', '--version-full', 'hello-world', 'pod', 'flow', 'gateway', 'ping', 'check',
'hub', 'pea', 'log', 'client', 'export-api'], 'completions': {
'hello-world': ['--help', '--workdir', '--download-proxy', '--shards', '--parallel', '--uses-index',
-                    '--index-data-url', '--index-labels-url', '--index-batch-size', '--uses-query',
-                    '--query-data-url', '--query-labels-url', '--query-batch-size', '--num-query', '--top-k'],
+                    '--index-data-url', '--index-labels-url', '--index-request-size', '--uses-query',
+                    '--query-data-url', '--query-labels-url', '--query-request-size', '--num-query', '--top-k'],
'pod': ['--help', '--name', '--log-config', '--identity', '--log-id', '--show-exc-info', '--port-ctrl',
'--ctrl-with-ipc', '--timeout-ctrl', '--ssh-server', '--ssh-keyfile', '--ssh-password', '--uses',
'--py-modules', '--port-in', '--port-out', '--host-in', '--host-out', '--socket-in', '--socket-out',
@@ -76,6 +76,6 @@ def _gaa(key, parser):
'--host', '--port-expose', '--daemon', '--runtime-backend', '--runtime', '--runtime-cls',
'--timeout-ready', '--env', '--expose-public', '--pea-id', '--pea-role'],
'log': ['--help', '--groupby-regex', '--refresh-time'],
-    'client': ['--help', '--batch-size', '--mode', '--top-k', '--mime-type', '--continue-on-error',
+    'client': ['--help', '--request-size', '--mode', '--top-k', '--mime-type', '--continue-on-error',
'--return-results', '--max-message-size', '--proxy', '--prefetch', '--prefetch-on-recv', '--restful',
'--rest-api', '--host', '--port-expose'], 'export-api': ['--help', '--yaml-path', '--json-path']}}
2 changes: 1 addition & 1 deletion jina/clients/base.py
@@ -135,7 +135,7 @@ async def _get_results(self,
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
-                    p_bar.update(self.args.batch_size)
+                    p_bar.update(self.args.request_size)
except KeyboardInterrupt:
            self.logger.warning('user cancelled the process')
except grpc.aio._call.AioRpcError as rpc_ex:
8 changes: 5 additions & 3 deletions jina/clients/request.py
@@ -42,7 +42,7 @@ def _build_doc_from_content():


def _generate(data: GeneratorSourceType,
-              batch_size: int = 0,
+              request_size: int = 0,
mode: RequestType = RequestType.INDEX,
mime_type: str = None,
queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
@@ -55,10 +55,10 @@ def _generate(data: GeneratorSourceType,
:return:
"""

-    _kwargs = dict(mime_type=mime_type, length=batch_size, weight=1.0)
+    _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)

try:
-        for batch in batch_iterator(data, batch_size):
+        for batch in batch_iterator(data, request_size):
req = Request()
req.request_type = str(mode)
for content in batch:
@@ -89,10 +89,12 @@ def index(*args, **kwargs):
"""Generate a indexing request"""
yield from _generate(*args, **kwargs)

+
def update(*args, **kwargs):
"""Generate a update request"""
yield from _generate(*args, **kwargs)

+
def delete(*args, **kwargs):
"""Generate a delete request"""
yield from _generate(*args, **kwargs)
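`_generate` above delegates the slicing to `batch_iterator`, which groups the input stream into chunks of at most `request_size` documents, one chunk per `Request`. Below is a simplified, hypothetical re-implementation for illustration only — the real helper lives in `jina.helper`, handles more input types, and presumably treats the `request_size=0` default specially; note its generic `batch_size` parameter name is untouched by this rename:

```python
from itertools import islice
from typing import Iterable, Iterator, List


def batch_iterator(data: Iterable, batch_size: int) -> Iterator[List]:
    """Yield successive lists of at most `batch_size` items from `data`.

    Minimal sketch; assumes batch_size > 0.
    """
    it = iter(data)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk
```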
2 changes: 1 addition & 1 deletion jina/clients/websockets.py
@@ -73,7 +73,7 @@ async def send_requests(request_iterator):
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
-                        p_bar.update(self.args.batch_size)
+                        p_bar.update(self.args.request_size)
if self.args.return_results:
result.append(response)
self.num_responses += 1
4 changes: 2 additions & 2 deletions jina/helloworld/__init__.py
@@ -62,7 +62,7 @@ def hello_world(args):
# run it!
with f:
f.index(index_generator(num_docs=targets['index']['data'].shape[0], target=targets),
-                batch_size=args.index_batch_size)
+                request_size=args.index_request_size)

# wait for couple of seconds
countdown(8, reason=colored('behold! im going to switch to query mode', 'cyan',
@@ -75,7 +75,7 @@ def hello_world(args):
f.search(query_generator(num_docs=args.num_query, target=targets, with_groundtruth=True),
shuffle=True,
on_done=print_result,
-                 batch_size=args.query_batch_size,
+                 request_size=args.query_request_size,
top_k=args.top_k)

# write result to html
8 changes: 4 additions & 4 deletions jina/optimizers/flow_runner.py
@@ -14,23 +14,23 @@ def __init__(
self,
flow_yaml: str,
documents: Iterator,
-            batch_size: int,
+            request_size: int,
            task: str,  # this can only be index or search, as it is used to call the flow API
callback: Optional = None,
overwrite_workspace: bool = False,
):
"""
:param flow_yaml: path to flow yaml
:param documents: iterator with list or generator for getting the documents
-        :param batch_size: batch size used in the flow
+        :param request_size: request size used in the flow
:param task: task of the flow which can be `index` or `search`
:param callback: callback to be passed to the flow's `on_done`
:param overwrite_workspace: overwrite workspace created by the flow
"""
self.flow_yaml = flow_yaml
# TODO: Make changes for working with doc generator (Pratik, before v1.0)
self.documents = documents if type(documents) == list else list(documents)
-        self.batch_size = batch_size
+        self.request_size = request_size
if task in ('index', 'search'):
self.task = task
else:
@@ -73,7 +73,7 @@ def run(
with Flow.load_config(self.flow_yaml, context=trial_parameters) as f:
getattr(f, self.task)(
self.documents,
-                batch_size=self.batch_size,
+                request_size=self.request_size,
on_done=self.callback,
**kwargs,
)
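With the rename, constructing a runner for the optimizer now looks like the sketch below — a hypothetical usage example based on the constructor signature above; the YAML path and document source are placeholders, and the import path is assumed to follow the file location:

```python
from jina.optimizers.flow_runner import FlowRunner

runner = FlowRunner(
    flow_yaml='flows/index.yml',                 # placeholder path
    documents=(f'doc {i}' for i in range(100)),  # materialized to a list internally
    request_size=10,                             # formerly batch_size
    task='index',                                # must be 'index' or 'search'
)
# runner.run(...) then invokes f.index(documents, request_size=10, on_done=callback)
```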
5 changes: 3 additions & 2 deletions jina/parsers/client.py
@@ -5,7 +5,8 @@
def mixin_client_cli_parser(parser):
gp = add_arg_group(parser, title='Client')

-    gp.add_argument('--batch-size', type=int, default=100,
+    # TODO (Joan): Remove `--batch-size` alias whenever the examples and documentations are updated
+    gp.add_argument('--request-size', '--batch-size', type=int, default=100,
help='the number of documents in each request')
gp.add_argument('--mode', choices=list(RequestType), type=RequestType.from_string,
# required=True,
@@ -17,4 +17,4 @@ def mixin_client_cli_parser(parser):
gp.add_argument('--continue-on-error', action='store_true', default=False,
                    help='whether to continue on all requests when the callback function throws an error')
gp.add_argument('--return-results', action='store_true', default=False,
-                    help='if to return all results as a list')
\ No newline at end of file
+                    help='if to return all results as a list')
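The backward-compatible alias works because argparse derives the attribute name (`dest`) from the first long option string, so both spellings store into `args.request_size` and existing scripts keep working. A standalone sketch of the same pattern:

```python
import argparse

parser = argparse.ArgumentParser()
# dest is inferred from the first long option: 'request_size'
parser.add_argument('--request-size', '--batch-size', type=int, default=100,
                    help='the number of documents in each request')

print(parser.parse_args(['--batch-size', '50']).request_size)    # 50 (legacy flag)
print(parser.parse_args(['--request-size', '32']).request_size)  # 32 (new flag)
print(parser.parse_args([]).request_size)                        # 100 (default)
```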
2 changes: 0 additions & 2 deletions jina/parsers/flow.py
@@ -1,7 +1,5 @@
import argparse

-from pkg_resources import resource_filename
-
from .base import set_base_parser
from .helper import _SHOW_ALL_ARGS
from ..enums import FlowOutputType, FlowOptimizeLevel, FlowInspectType
10 changes: 6 additions & 4 deletions jina/parsers/helloworld.py
@@ -33,9 +33,10 @@ def set_hw_parser(parser=None):
gp.add_argument('--index-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
                    help='the url of index labels data (should be in idx1-ubyte.gz format)')
-    gp.add_argument('--index-batch-size', type=int,
+    gp.add_argument('--index-request-size', type=int,
                     default=1024,
-                    help='the batch size in indexing')
+                    help='the request size in indexing (the maximum number of documents that will be included in a '
+                         'Request before sending it)')

gp = add_arg_group(parser, title='Search')
gp.add_argument('--uses-query', type=str,
@@ -47,9 +47,10 @@
gp.add_argument('--query-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
                    help='the url of query labels data (should be in idx1-ubyte.gz format)')
-    gp.add_argument('--query-batch-size', type=int,
+    gp.add_argument('--query-request-size', type=int,
                     default=32,
-                    help='the batch size in searching')
+                    help='the request size in searching (the maximum number of documents that will be included in a '
+                         'Request before sending it)')
gp.add_argument('--num-query', type=int, default=128,
help='number of queries to visualize')
gp.add_argument('--top-k', type=int, default=50,
2 changes: 1 addition & 1 deletion jina/peapods/runtimes/asyncio/grpc/async_call.py
@@ -56,7 +56,7 @@ async def prefetch_req(num_req, fetch_to):

with TimeContext(f'prefetching {self.args.prefetch} requests', self.logger):
self.logger.warning('if this takes too long, you may want to take smaller "--prefetch" or '
-                                'ask client to reduce "--batch-size"')
+                                'ask client to reduce "--request-size"')
is_req_empty = await prefetch_req(self.args.prefetch, prefetch_task)
if is_req_empty and not prefetch_task:
self.logger.error('receive an empty stream from the client! '
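The warning above exists because the gateway buffers up to `--prefetch` requests at a time, so the number of documents held in flight grows roughly with `prefetch * request_size`. A back-of-envelope check with hypothetical values:

```python
prefetch = 50       # hypothetical gateway buffer size (--prefetch)
request_size = 100  # documents per request (--request-size)

# rough upper bound on documents buffered while prefetching;
# shrink either flag if this step stalls or memory balloons
in_flight_docs = prefetch * request_size
print(in_flight_docs)  # 5000
```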
@@ -82,18 +82,18 @@ def document_generator(num_docs, num_chunks, num_chunks_chunks):
yield doc


-@pytest.mark.parametrize('request_batch_size', [8, 16, 32])
+@pytest.mark.parametrize('request_size', [8, 16, 32])
@pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
-def test_encode_driver_batching(request_batch_size, driver_batch_size, tmpdir):
+def test_encode_driver_batching(request_size, driver_batch_size, tmpdir):
num_docs = 137
num_chunks = 0
num_chunks_chunks = 0

-    num_requests = int(num_docs / request_batch_size)
-    num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+    num_requests = int(num_docs / request_size)
+    num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
-        valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+        valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
@@ -103,7 +103,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
-                          num_docs_in_same_request=request_batch_size,
+                          num_docs_in_same_request=request_size,
total_num_docs=num_docs)

driver = EncodeDriver(batch_size=driver_batch_size,
@@ -117,23 +117,23 @@ def fail_if_error(resp):

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
-                 batch_size=request_batch_size,
+                 request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)


-@pytest.mark.parametrize('request_batch_size', [8, 16, 32])
-@pytest.mark.parametrize('driver_batch_size', [8, 32, 64])
+@pytest.mark.parametrize('request_size', [8, 16, 32])
+@pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
@pytest.mark.parametrize('num_chunks', [2, 8])
@pytest.mark.parametrize('num_chunks_chunks', [2, 8])
-def test_encode_driver_batching_with_chunks(request_batch_size, driver_batch_size, num_chunks, num_chunks_chunks,
+def test_encode_driver_batching_with_chunks(request_size, driver_batch_size, num_chunks, num_chunks_chunks,
tmpdir):
-    num_docs = 128
-    num_requests = int(num_docs / request_batch_size)
-    num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+    num_docs = 137
+    num_requests = int(num_docs / request_size)
+    num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
-        valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+        valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
@@ -147,7 +147,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
-                          num_docs_in_same_request=request_batch_size + request_batch_size*num_chunks + request_batch_size*num_chunks*num_chunks_chunks,
+                          num_docs_in_same_request=request_size + request_size*num_chunks + request_size*num_chunks*num_chunks_chunks,
total_num_docs=num_docs + num_docs*num_chunks + num_docs*num_chunks*num_chunks_chunks)

driver = EncodeDriver(batch_size=driver_batch_size,
Expand All @@ -161,6 +161,6 @@ def fail_if_error(resp):

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
-                 batch_size=request_batch_size,
+                 request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)
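The expected length of the final response in these tests follows from the remainder arithmetic at the top of each test; a quick standalone check of the same expressions (hypothetical script, values copied from the tests above):

```python
num_docs = 137
for request_size in (8, 16, 32):
    num_requests = int(num_docs / request_size)      # number of full-size requests
    last = num_docs % (num_requests * request_size)  # docs left for the final request
    print(f'request_size={request_size}: {num_requests} full requests, last carries {last}')
# request_size=8:  17 full requests, last carries 1
# request_size=16: 8 full requests, last carries 9
# request_size=32: 4 full requests, last carries 9
```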
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_evaluation_from_file.py
@@ -73,7 +73,7 @@ def test_evaluation_from_file(random_workspace, index_groundtruth, evaluate_docs
monkeypatch.setenv("RESTFUL", restful)

with Flow.load_config(index_yaml) as index_gt_flow:
-        index_gt_flow.index(input_fn=index_groundtruth, batch_size=10)
+        index_gt_flow.index(input_fn=index_groundtruth, request_size=10)

m = mocker.Mock()

@@ -24,7 +24,7 @@ def fill_responses(resp):
with Flow().load_config(os.path.join(cur_dir, 'flow.yml')) as f:
f.search(input_fn=data,
on_done=fill_responses,
-                 batch_size=1,
+                 request_size=1,
callback_on='body')

del os.environ['JINA_NON_BLOCKING_PARALLEL']
2 changes: 1 addition & 1 deletion tests/integration/optimizers/test_optimizer.py
@@ -19,7 +19,7 @@ def document_generator(num_doc):
eval_flow_runner = FlowRunner(
flow_yaml='tests/integration/optimizers/flow.yml',
documents=document_generator(10),
-        batch_size=1,
+        request_size=1,
task='search',
callback=EvaluationCallback(),
)
@@ -39,7 +39,7 @@ def random_workspace(tmpdir):
def test_indexer_with_ref_indexer(random_workspace, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
-        index_flow.index(input_fn=index_docs, batch_size=10)
+        index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

@@ -69,7 +69,7 @@ def random_workspace_move(tmpdir):
def test_indexer_with_ref_indexer_move(random_workspace_move, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
-        index_flow.index(input_fn=index_docs, batch_size=10)
+        index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

18 changes: 9 additions & 9 deletions tests/unit/clients/python/test_request.py
@@ -77,7 +77,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'

-    req = _generate(data=random_lines(100), batch_size=100)
+    req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -91,7 +91,7 @@ def test_request_generate_lines_from_list():
def random_lines(num_lines):
return [f'i\'m dummy doc {j}' for j in range(1, num_lines + 1)]

-    req = _generate(data=random_lines(100), batch_size=100)
+    req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -106,7 +106,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'https://github.com i\'m dummy doc {j}'

-    req = _generate(data=random_lines(100), batch_size=100)
+    req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -121,7 +121,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'

-    req = _generate(data=random_lines(100), batch_size=100)
+    req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -141,7 +141,7 @@ def random_docs(num_docs):
doc.mime_type = 'mime_type'
yield doc

-    req = _generate(data=random_docs(100), batch_size=100)
+    req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -174,7 +174,7 @@ def random_docs(num_docs):
}
yield doc

-    req = _generate(data=random_docs(100), batch_size=100)
+    req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -213,7 +213,7 @@ def random_docs(num_docs):
}
yield json.dumps(doc)

-    req = _generate(data=random_docs(100), batch_size=100)
+    req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -231,7 +231,7 @@ def random_docs(num_docs):
def test_request_generate_numpy_arrays():
input_array = np.random.random([10, 10])

-    req = _generate(data=input_array, batch_size=5)
+    req = _generate(data=input_array, request_size=5)

request = next(req)
assert len(request.index.docs) == 5
@@ -253,7 +253,7 @@ def generator():
for array in input_array:
yield array

-    req = _generate(data=generator(), batch_size=5)
+    req = _generate(data=generator(), request_size=5)

request = next(req)
assert len(request.index.docs) == 5
