refactor: change batch_size to request_size #1677

Merged: 2 commits, Jan 14, 2021
6 changes: 3 additions & 3 deletions cli/autocomplete.py
@@ -35,8 +35,8 @@ def _gaa(key, parser):
'commands': ['--help', '--version', '--version-full', 'hello-world', 'pod', 'flow', 'gateway', 'ping', 'check',
'hub', 'pea', 'log', 'client', 'export-api'], 'completions': {
'hello-world': ['--help', '--workdir', '--download-proxy', '--shards', '--parallel', '--uses-index',
- '--index-data-url', '--index-labels-url', '--index-batch-size', '--uses-query',
- '--query-data-url', '--query-labels-url', '--query-batch-size', '--num-query', '--top-k'],
+ '--index-data-url', '--index-labels-url', '--index-request-size', '--uses-query',
+ '--query-data-url', '--query-labels-url', '--query-request-size', '--num-query', '--top-k'],
'pod': ['--help', '--name', '--log-config', '--identity', '--log-id', '--show-exc-info', '--port-ctrl',
'--ctrl-with-ipc', '--timeout-ctrl', '--ssh-server', '--ssh-keyfile', '--ssh-password', '--uses',
'--py-modules', '--port-in', '--port-out', '--host-in', '--host-out', '--socket-in', '--socket-out',
@@ -76,6 +76,6 @@ def _gaa(key, parser):
'--host', '--port-expose', '--daemon', '--runtime-backend', '--runtime', '--runtime-cls',
'--timeout-ready', '--env', '--expose-public', '--pea-id', '--pea-role'],
'log': ['--help', '--groupby-regex', '--refresh-time'],
- 'client': ['--help', '--batch-size', '--mode', '--top-k', '--mime-type', '--continue-on-error',
+ 'client': ['--help', '--request-size', '--mode', '--top-k', '--mime-type', '--continue-on-error',
'--return-results', '--max-message-size', '--proxy', '--prefetch', '--prefetch-on-recv', '--restful',
'--rest-api', '--host', '--port-expose'], 'export-api': ['--help', '--yaml-path', '--json-path']}}
2 changes: 1 addition & 1 deletion jina/clients/base.py
@@ -135,7 +135,7 @@ async def _get_results(self,
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
- p_bar.update(self.args.batch_size)
+ p_bar.update(self.args.request_size)
except KeyboardInterrupt:
self.logger.warning('user cancel the process')
except grpc.aio._call.AioRpcError as rpc_ex:
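A note on what this hunk touches: the progress bar counts documents, so each arriving response advances it by one request's worth of docs. A minimal sketch of that accounting, assuming a tqdm-like bar in place of Jina's own ProgressBar:

```python
from tqdm import tqdm

# Illustrative numbers (not from the PR): 250 docs sent as requests of up to 100.
request_size, num_docs = 100, 250
with tqdm(total=num_docs) as p_bar:
    for _ in range(3):              # one update per response received
        p_bar.update(request_size)  # advances by the nominal request size,
                                    # so a smaller final request overshoots slightly
```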
8 changes: 5 additions & 3 deletions jina/clients/request.py
@@ -42,7 +42,7 @@ def _build_doc_from_content():


def _generate(data: GeneratorSourceType,
- batch_size: int = 0,
+ request_size: int = 0,
mode: RequestType = RequestType.INDEX,
mime_type: str = None,
queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
@@ -55,10 +55,10 @@ def _generate(data: GeneratorSourceType,
:return:
"""

- _kwargs = dict(mime_type=mime_type, length=batch_size, weight=1.0)
+ _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)

try:
- for batch in batch_iterator(data, batch_size):
+ for batch in batch_iterator(data, request_size):
req = Request()
req.request_type = str(mode)
for content in batch:
@@ -89,10 +89,12 @@ def index(*args, **kwargs):
"""Generate a indexing request"""
yield from _generate(*args, **kwargs)


def update(*args, **kwargs):
"""Generate a update request"""
yield from _generate(*args, **kwargs)


def delete(*args, **kwargs):
"""Generate a delete request"""
yield from _generate(*args, **kwargs)
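For intuition about the renamed parameter: request_size caps how many documents _generate packs into a single Request. A minimal, hypothetical stand-in for batch_iterator (Jina's real helper lives elsewhere and may differ in detail) shows the chunking:

```python
from itertools import islice
from typing import Iterable, Iterator, List

def batch_iterator(data: Iterable, request_size: int) -> Iterator[List]:
    # Yield consecutive chunks of at most request_size items.
    it = iter(data)
    while chunk := list(islice(it, request_size)):
        yield chunk

# 10 documents with request_size=4 -> requests carrying 4, 4, and 2 docs
print([len(chunk) for chunk in batch_iterator(range(10), 4)])  # [4, 4, 2]
```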
2 changes: 1 addition & 1 deletion jina/clients/websockets.py
@@ -73,7 +73,7 @@ async def send_requests(request_iterator):
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
- p_bar.update(self.args.batch_size)
+ p_bar.update(self.args.request_size)
if self.args.return_results:
result.append(response)
self.num_responses += 1
4 changes: 2 additions & 2 deletions jina/helloworld/__init__.py
@@ -62,7 +62,7 @@ def hello_world(args):
# run it!
with f:
f.index(index_generator(num_docs=targets['index']['data'].shape[0], target=targets),
- batch_size=args.index_batch_size)
+ request_size=args.index_request_size)

# wait for couple of seconds
countdown(8, reason=colored('behold! im going to switch to query mode', 'cyan',
@@ -75,7 +75,7 @@ def hello_world(args):
f.search(query_generator(num_docs=args.num_query, target=targets, with_groundtruth=True),
shuffle=True,
on_done=print_result,
- batch_size=args.query_batch_size,
+ request_size=args.query_request_size,
top_k=args.top_k)

# write result to html
8 changes: 4 additions & 4 deletions jina/optimizers/flow_runner.py
@@ -14,23 +14,23 @@ def __init__(
self,
flow_yaml: str,
documents: Iterator,
- batch_size: int,
+ request_size: int,
task: str, # this can be only index or search as it is used to call the flow API
callback: Optional = None,
overwrite_workspace: bool = False,
):
"""
:param flow_yaml: path to flow yaml
:param documents: iterator with list or generator for getting the documents
- :param batch_size: batch size used in the flow
+ :param request_size: request size used in the flow
:param task: task of the flow which can be `index` or `search`
:param callback: callback to be passed to the flow's `on_done`
:param overwrite_workspace: overwrite workspace created by the flow
"""
self.flow_yaml = flow_yaml
# TODO: Make changes for working with doc generator (Pratik, before v1.0)
self.documents = documents if type(documents) == list else list(documents)
- self.batch_size = batch_size
+ self.request_size = request_size
if task in ('index', 'search'):
self.task = task
else:
@@ -73,7 +73,7 @@ def run(
with Flow.load_config(self.flow_yaml, context=trial_parameters) as f:
getattr(f, self.task)(
self.documents,
- batch_size=self.batch_size,
+ request_size=self.request_size,
on_done=self.callback,
**kwargs,
)
5 changes: 3 additions & 2 deletions jina/parsers/client.py
@@ -5,7 +5,8 @@
def mixin_client_cli_parser(parser):
gp = add_arg_group(parser, title='Client')

- gp.add_argument('--batch-size', type=int, default=100,
+ # TODO (Joan): Remove `--batch-size` alias whenever the examples and documentations are updated
+ gp.add_argument('--request-size', '--batch-size', type=int, default=100,
help='the number of documents in each request')
gp.add_argument('--mode', choices=list(RequestType), type=RequestType.from_string,
# required=True,
@@ -17,4 +18,4 @@ def mixin_client_cli_parser(parser):
gp.add_argument('--continue-on-error', action='store_true', default=False,
help='if to continue on all requests when callback function throws an error')
gp.add_argument('--return-results', action='store_true', default=False,
- help='if to return all results as a list')
\ No newline at end of file
+ help='if to return all results as a list')
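The alias above leans on standard argparse behavior: extra option strings become aliases for the same action, and the attribute name is derived from the first long option. A self-contained check (illustrative, not Jina code):

```python
import argparse

parser = argparse.ArgumentParser()
# dest becomes request_size (first long option); --batch-size is an alias
parser.add_argument('--request-size', '--batch-size', type=int, default=100,
                    help='the number of documents in each request')

print(parser.parse_args([]).request_size)                      # 100
print(parser.parse_args(['--batch-size', '32']).request_size)  # 32, old flag still works
```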
2 changes: 0 additions & 2 deletions jina/parsers/flow.py
@@ -1,7 +1,5 @@
import argparse

- from pkg_resources import resource_filename
-
from .base import set_base_parser
from .helper import _SHOW_ALL_ARGS
from ..enums import FlowOutputType, FlowOptimizeLevel, FlowInspectType
10 changes: 6 additions & 4 deletions jina/parsers/helloworld.py
@@ -33,9 +33,10 @@ def set_hw_parser(parser=None):
gp.add_argument('--index-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
help='the url of index labels data (should be in idx3-ubyte.gz format)')
- gp.add_argument('--index-batch-size', type=int,
+ gp.add_argument('--index-request-size', type=int,
default=1024,
- help='the batch size in indexing')
+ help='the request size in indexing (the maximum number of documents that will be included in a '
+ 'Request before sending it)')

gp = add_arg_group(parser, title='Search')
gp.add_argument('--uses-query', type=str,
@@ -47,9 +48,10 @@
gp.add_argument('--query-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
help='the url of query labels data (should be in idx3-ubyte.gz format)')
- gp.add_argument('--query-batch-size', type=int,
+ gp.add_argument('--query-request-size', type=int,
default=32,
- help='the batch size in searching')
+ help='the request size in searching (the maximum number of documents that will be included in a '
+ 'Request before sending it)')
gp.add_argument('--num-query', type=int, default=128,
help='number of queries to visualize')
gp.add_argument('--top-k', type=int, default=50,
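A quick way to see the renamed demo flags in action (illustrative; assumes set_hw_parser builds a fresh parser when called without arguments, as the signature above suggests):

```python
from jina.parsers.helloworld import set_hw_parser

# Hypothetical values; defaults in the diff are 1024 for index and 32 for query.
args = set_hw_parser().parse_args(['--index-request-size', '512',
                                   '--query-request-size', '16'])
print(args.index_request_size, args.query_request_size)  # 512 16
```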
2 changes: 1 addition & 1 deletion jina/peapods/runtimes/asyncio/grpc/async_call.py
@@ -56,7 +56,7 @@ async def prefetch_req(num_req, fetch_to):

with TimeContext(f'prefetching {self.args.prefetch} requests', self.logger):
self.logger.warning('if this takes too long, you may want to take smaller "--prefetch" or '
- 'ask client to reduce "--batch-size"')
+ 'ask client to reduce "--request-size"')
is_req_empty = await prefetch_req(self.args.prefetch, prefetch_task)
if is_req_empty and not prefetch_task:
self.logger.error('receive an empty stream from the client! '
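The warning's arithmetic, roughly (assumed semantics, for intuition only): the gateway keeps up to prefetch requests in flight, and each carries up to request_size documents, so both knobs bound peak buffering.

```python
# Back-of-envelope with illustrative numbers: documents buffered at peak
# is on the order of prefetch * request_size.
prefetch, request_size = 50, 100
print(prefetch * request_size)  # 5000 documents potentially held at once
```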
@@ -82,18 +82,18 @@ def document_generator(num_docs, num_chunks, num_chunks_chunks):
yield doc


- @pytest.mark.parametrize('request_batch_size', [8, 16, 32])
+ @pytest.mark.parametrize('request_size', [8, 16, 32])
@pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
- def test_encode_driver_batching(request_batch_size, driver_batch_size, tmpdir):
+ def test_encode_driver_batching(request_size, driver_batch_size, tmpdir):
num_docs = 137
num_chunks = 0
num_chunks_chunks = 0

- num_requests = int(num_docs / request_batch_size)
- num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+ num_requests = int(num_docs / request_size)
+ num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
- valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+ valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
@@ -103,7 +103,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
- num_docs_in_same_request=request_batch_size,
+ num_docs_in_same_request=request_size,
total_num_docs=num_docs)

driver = EncodeDriver(batch_size=driver_batch_size,
@@ -117,23 +117,23 @@

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
- batch_size=request_batch_size,
+ request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)


- @pytest.mark.parametrize('request_batch_size', [8, 16, 32])
- @pytest.mark.parametrize('driver_batch_size', [8, 32, 64])
+ @pytest.mark.parametrize('request_size', [8, 16, 32])
+ @pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
@pytest.mark.parametrize('num_chunks', [2, 8])
@pytest.mark.parametrize('num_chunks_chunks', [2, 8])
- def test_encode_driver_batching_with_chunks(request_batch_size, driver_batch_size, num_chunks, num_chunks_chunks,
+ def test_encode_driver_batching_with_chunks(request_size, driver_batch_size, num_chunks, num_chunks_chunks,
tmpdir):
- num_docs = 128
- num_requests = int(num_docs / request_batch_size)
- num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+ num_docs = 137
+ num_requests = int(num_docs / request_size)
+ num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
- valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+ valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
@@ -147,7 +147,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
- num_docs_in_same_request=request_batch_size + request_batch_size*num_chunks + request_batch_size*num_chunks*num_chunks_chunks,
+ num_docs_in_same_request=request_size + request_size*num_chunks + request_size*num_chunks*num_chunks_chunks,
total_num_docs=num_docs + num_docs*num_chunks + num_docs*num_chunks*num_chunks_chunks)

driver = EncodeDriver(batch_size=driver_batch_size,
@@ -161,6 +161,6 @@ def fail_if_error(resp):

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
- batch_size=request_batch_size,
+ request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)
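The size bookkeeping these tests rely on, worked for one parametrization (num_docs=137, request_size=16): int(137 / 16) gives 8 full requests, and the trailing request carries 137 % (8 * 16) = 9 documents, so every response must hold either 16 or 9 docs.

```python
# Worked check of the test's arithmetic for one parameter pair.
num_docs, request_size = 137, 16
num_requests = int(num_docs / request_size)                         # 8
num_docs_last_req_batch = num_docs % (num_requests * request_size)  # 9
assert num_docs == num_requests * request_size + num_docs_last_req_batch
```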
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_evaluation_from_file.py
@@ -73,7 +73,7 @@ def test_evaluation_from_file(random_workspace, index_groundtruth, evaluate_docs
monkeypatch.setenv("RESTFUL", restful)

with Flow.load_config(index_yaml) as index_gt_flow:
- index_gt_flow.index(input_fn=index_groundtruth, batch_size=10)
+ index_gt_flow.index(input_fn=index_groundtruth, request_size=10)

m = mocker.Mock()

@@ -24,7 +24,7 @@ def fill_responses(resp):
with Flow().load_config(os.path.join(cur_dir, 'flow.yml')) as f:
f.search(input_fn=data,
on_done=fill_responses,
- batch_size=1,
+ request_size=1,
callback_on='body')

del os.environ['JINA_NON_BLOCKING_PARALLEL']
2 changes: 1 addition & 1 deletion tests/integration/optimizers/test_optimizer.py
@@ -19,7 +19,7 @@ def document_generator(num_doc):
eval_flow_runner = FlowRunner(
flow_yaml='tests/integration/optimizers/flow.yml',
documents=document_generator(10),
- batch_size=1,
+ request_size=1,
task='search',
callback=EvaluationCallback(),
)
@@ -39,7 +39,7 @@ def random_workspace(tmpdir):
def test_indexer_with_ref_indexer(random_workspace, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
- index_flow.index(input_fn=index_docs, batch_size=10)
+ index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

@@ -69,7 +69,7 @@ def random_workspace_move(tmpdir):
def test_indexer_with_ref_indexer_move(random_workspace_move, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
- index_flow.index(input_fn=index_docs, batch_size=10)
+ index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

18 changes: 9 additions & 9 deletions tests/unit/clients/python/test_request.py
@@ -77,7 +77,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'

- req = _generate(data=random_lines(100), batch_size=100)
+ req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -91,7 +91,7 @@ def test_request_generate_lines_from_list():
def random_lines(num_lines):
return [f'i\'m dummy doc {j}' for j in range(1, num_lines + 1)]

- req = _generate(data=random_lines(100), batch_size=100)
+ req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -106,7 +106,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'https://github.com i\'m dummy doc {j}'

- req = _generate(data=random_lines(100), batch_size=100)
+ req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -121,7 +121,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'

- req = _generate(data=random_lines(100), batch_size=100)
+ req = _generate(data=random_lines(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -141,7 +141,7 @@ def random_docs(num_docs):
doc.mime_type = 'mime_type'
yield doc

- req = _generate(data=random_docs(100), batch_size=100)
+ req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -174,7 +174,7 @@ def random_docs(num_docs):
}
yield doc

- req = _generate(data=random_docs(100), batch_size=100)
+ req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -213,7 +213,7 @@ def random_docs(num_docs):
}
yield json.dumps(doc)

- req = _generate(data=random_docs(100), batch_size=100)
+ req = _generate(data=random_docs(100), request_size=100)

request = next(req)
assert len(request.index.docs) == 100
@@ -231,7 +231,7 @@ def random_docs(num_docs):
def test_request_generate_numpy_arrays():
input_array = np.random.random([10, 10])

- req = _generate(data=input_array, batch_size=5)
+ req = _generate(data=input_array, request_size=5)

request = next(req)
assert len(request.index.docs) == 5
@@ -253,7 +253,7 @@ def generator():
for array in input_array:
yield array

- req = _generate(data=generator(), batch_size=5)
+ req = _generate(data=generator(), request_size=5)

request = next(req)
assert len(request.index.docs) == 5