refactor: change batch_size to request_size
JoanFM committed Jan 13, 2021
1 parent 0bcd12f commit 95e3c1b
Showing 15 changed files with 47 additions and 43 deletions.
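For callers, this commit is a pure rename: every call site that passed batch_size now passes request_size, with unchanged semantics (the number of documents bundled into each request). A minimal before/after sketch, modeled on the tests further down; the random_docs helper is hypothetical, and '_pass' is assumed to be the built-in no-op executor of this Jina version:

    from jina import Flow

    def random_docs(n):
        # hypothetical helper: yield n plain-text documents
        for i in range(n):
            yield f'document {i}'

    f = Flow().add(uses='_pass')
    with f:
        # before this commit:
        #   f.index(input_fn=random_docs(100), batch_size=10)
        # after this commit:
        f.index(input_fn=random_docs(100), request_size=10)  # 100 docs -> 10 requests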
2 changes: 1 addition & 1 deletion jina/clients/base.py
@@ -135,7 +135,7 @@ async def _get_results(self,
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
- p_bar.update(self.args.batch_size)
+ p_bar.update(self.args.request_size)
except KeyboardInterrupt:
self.logger.warning('user cancel the process')
except grpc.aio._call.AioRpcError as rpc_ex:
8 changes: 5 additions & 3 deletions jina/clients/request.py
@@ -42,7 +42,7 @@ def _build_doc_from_content():


def _generate(data: GeneratorSourceType,
- batch_size: int = 0,
+ request_size: int = 0,
mode: RequestType = RequestType.INDEX,
mime_type: str = None,
queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
@@ -55,10 +55,10 @@ def _generate(data: GeneratorSourceType,
:return:
"""

- _kwargs = dict(mime_type=mime_type, length=batch_size, weight=1.0)
+ _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)

try:
- for batch in batch_iterator(data, batch_size):
+ for batch in batch_iterator(data, request_size):
req = Request()
req.request_type = str(mode)
for content in batch:
@@ -89,10 +89,12 @@ def index(*args, **kwargs):
"""Generate a indexing request"""
yield from _generate(*args, **kwargs)


def update(*args, **kwargs):
"""Generate a update request"""
yield from _generate(*args, **kwargs)


def delete(*args, **kwargs):
"""Generate a delete request"""
yield from _generate(*args, **kwargs)
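The _generate function above drives the whole rename: it slices the input stream into chunks of at most request_size documents and emits one Request per chunk via batch_iterator. A standalone sketch of that chunking behavior, using a hypothetical helper rather than the library's own batch_iterator:

    from itertools import islice
    from typing import Iterable, Iterator, List

    def chunk_requests(data: Iterable, request_size: int) -> Iterator[List]:
        """Yield successive lists of at most request_size items, mirroring
        how _generate groups documents into one Request per chunk."""
        it = iter(data)
        while True:
            chunk = list(islice(it, request_size))
            if not chunk:
                return
            yield chunk

    # 10 docs with request_size=3 -> chunks of 3, 3, 3, 1
    assert [len(c) for c in chunk_requests(range(10), 3)] == [3, 3, 3, 1]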
2 changes: 1 addition & 1 deletion jina/clients/websockets.py
@@ -73,7 +73,7 @@ async def send_requests(request_iterator):
on_always=on_always,
continue_on_error=self.args.continue_on_error,
logger=self.logger)
- p_bar.update(self.args.batch_size)
+ p_bar.update(self.args.request_size)
if self.args.return_results:
result.append(response)
self.num_responses += 1
4 changes: 2 additions & 2 deletions jina/helloworld/__init__.py
@@ -62,7 +62,7 @@ def hello_world(args):
# run it!
with f:
f.index(index_generator(num_docs=targets['index']['data'].shape[0], target=targets),
- batch_size=args.index_batch_size)
+ request_size=args.index_request_size)

# wait for couple of seconds
countdown(8, reason=colored('behold! im going to switch to query mode', 'cyan',
@@ -75,7 +75,7 @@ def hello_world(args):
f.search(query_generator(num_docs=args.num_query, target=targets, with_groundtruth=True),
shuffle=True,
on_done=print_result,
- batch_size=args.query_batch_size,
+ request_size=args.query_request_size,
top_k=args.top_k)

# write result to html
8 changes: 4 additions & 4 deletions jina/optimizers/flow_runner.py
@@ -14,23 +14,23 @@ def __init__(
self,
flow_yaml: str,
documents: Iterator,
- batch_size: int,
+ request_size: int,
task: str, # this can be only index or search as it is used to call the flow API
callback: Optional = None,
overwrite_workspace: bool = False,
):
"""
:param flow_yaml: path to flow yaml
:param documents: iterator with list or generator for getting the documents
- :param batch_size: batch size used in the flow
+ :param request_size: request size used in the flow
:param task: task of the flow which can be `index` or `search`
:param callback: callback to be passed to the flow's `on_done`
:param overwrite_workspace: overwrite workspace created by the flow
"""
self.flow_yaml = flow_yaml
# TODO: Make changes for working with doc generator (Pratik, before v1.0)
self.documents = documents if type(documents) == list else list(documents)
- self.batch_size = batch_size
+ self.request_size = request_size
if task in ('index', 'search'):
self.task = task
else:
@@ -73,7 +73,7 @@ def run(
with Flow.load_config(self.flow_yaml, context=trial_parameters) as f:
getattr(f, self.task)(
self.documents,
- batch_size=self.batch_size,
+ request_size=self.request_size,
on_done=self.callback,
**kwargs,
)
2 changes: 1 addition & 1 deletion jina/parsers/client.py
@@ -5,7 +5,7 @@
def mixin_client_cli_parser(parser):
gp = add_arg_group(parser, title='Client')

- gp.add_argument('--batch-size', type=int, default=100,
+ gp.add_argument('--request-size', type=int, default=100,
help='the number of documents in each request')
gp.add_argument('--mode', choices=list(RequestType), type=RequestType.from_string,
# required=True,
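Since argparse maps --request-size to an args.request_size attribute, the attribute reads in base.py and websockets.py above line up with this renamed flag. A self-contained sketch of the parsing behavior using plain argparse (the real code goes through Jina's add_arg_group helper, assumed here to wrap a standard parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--request-size', type=int, default=100,
                        help='the number of documents in each request')

    args = parser.parse_args([])                        # no flag -> default
    assert args.request_size == 100
    args = parser.parse_args(['--request-size', '32'])  # explicit value
    assert args.request_size == 32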
10 changes: 6 additions & 4 deletions jina/parsers/helloworld.py
@@ -33,9 +33,10 @@ def set_hw_parser(parser=None):
gp.add_argument('--index-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
help='the url of index labels data (should be in idx3-ubyte.gz format)')
- gp.add_argument('--index-batch-size', type=int,
+ gp.add_argument('--index-request-size', type=int,
default=1024,
- help='the batch size in indexing')
+ help='the request size in indexing (the maximum number of documents that will be included in a '
+ 'Request before sending it)')

gp = add_arg_group(parser, title='Search')
gp.add_argument('--uses-query', type=str,
@@ -47,9 +48,10 @@ def set_hw_parser(parser=None):
gp.add_argument('--query-labels-url', type=str,
default='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
help='the url of query labels data (should be in idx3-ubyte.gz format)')
- gp.add_argument('--query-batch-size', type=int,
+ gp.add_argument('--query-request-size', type=int,
default=32,
- help='the batch size in searching')
+ help='the request size in searching (the maximum number of documents that will be included in a '
+ 'Request before sending it)')
gp.add_argument('--num-query', type=int, default=128,
help='number of queries to visualize')
gp.add_argument('--top-k', type=int, default=50,
2 changes: 1 addition & 1 deletion jina/peapods/runtimes/asyncio/grpc/async_call.py
@@ -56,7 +56,7 @@ async def prefetch_req(num_req, fetch_to):

with TimeContext(f'prefetching {self.args.prefetch} requests', self.logger):
self.logger.warning('if this takes too long, you may want to take smaller "--prefetch" or '
'ask client to reduce "--batch-size"')
'ask client to reduce "--request-size"')
is_req_empty = await prefetch_req(self.args.prefetch, prefetch_task)
if is_req_empty and not prefetch_task:
self.logger.error('receive an empty stream from the client! '
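The reworded warning reflects that the gateway buffers up to --prefetch requests, each carrying up to --request-size documents, so the two flags multiply into the number of in-flight documents. A back-of-envelope check with illustrative values (not necessarily Jina's defaults):

    prefetch = 50       # example --prefetch value on the gateway
    request_size = 100  # example --request-size on the client
    # rough upper bound on documents buffered at once
    print(prefetch * request_size)  # 5000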
@@ -82,18 +82,18 @@ def document_generator(num_docs, num_chunks, num_chunks_chunks):
yield doc


- @pytest.mark.parametrize('request_batch_size', [8, 16, 32])
+ @pytest.mark.parametrize('request_size', [8, 16, 32])
@pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
- def test_encode_driver_batching(request_batch_size, driver_batch_size, tmpdir):
+ def test_encode_driver_batching(request_size, driver_batch_size, tmpdir):
num_docs = 137
num_chunks = 0
num_chunks_chunks = 0

- num_requests = int(num_docs / request_batch_size)
- num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+ num_requests = int(num_docs / request_size)
+ num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
- valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+ valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
@@ -103,7 +103,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
- num_docs_in_same_request=request_batch_size,
+ num_docs_in_same_request=request_size,
total_num_docs=num_docs)

driver = EncodeDriver(batch_size=driver_batch_size,
@@ -117,23 +117,23 @@ def fail_if_error(resp):

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
- batch_size=request_batch_size,
+ request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)


- @pytest.mark.parametrize('request_batch_size', [8, 16, 32])
- @pytest.mark.parametrize('driver_batch_size', [8, 32, 64])
+ @pytest.mark.parametrize('request_size', [8, 16, 32])
+ @pytest.mark.parametrize('driver_batch_size', [3, 4, 13])
@pytest.mark.parametrize('num_chunks', [2, 8])
@pytest.mark.parametrize('num_chunks_chunks', [2, 8])
- def test_encode_driver_batching_with_chunks(request_batch_size, driver_batch_size, num_chunks, num_chunks_chunks,
+ def test_encode_driver_batching_with_chunks(request_size, driver_batch_size, num_chunks, num_chunks_chunks,
tmpdir):
- num_docs = 128
- num_requests = int(num_docs / request_batch_size)
- num_docs_last_req_batch = num_docs % (num_requests * request_batch_size)
+ num_docs = 137
+ num_requests = int(num_docs / request_size)
+ num_docs_last_req_batch = num_docs % (num_requests * request_size)

def validate_response(resp):
- valid_resp_length = (len(resp.search.docs) == request_batch_size) or (
+ valid_resp_length = (len(resp.search.docs) == request_size) or (
len(resp.search.docs) == num_docs_last_req_batch)
assert valid_resp_length
for doc in resp.search.docs:
Expand All @@ -147,7 +147,7 @@ def fail_if_error(resp):
assert False

encoder = MockEncoder(driver_batch_size=driver_batch_size,
- num_docs_in_same_request=request_batch_size + request_batch_size*num_chunks + request_batch_size*num_chunks*num_chunks_chunks,
+ num_docs_in_same_request=request_size + request_size*num_chunks + request_size*num_chunks*num_chunks_chunks,
total_num_docs=num_docs + num_docs*num_chunks + num_docs*num_chunks*num_chunks_chunks)

driver = EncodeDriver(batch_size=driver_batch_size,
Expand All @@ -161,6 +161,6 @@ def fail_if_error(resp):

with Flow().add(uses=executor_yml_file) as f:
f.search(input_fn=document_generator(num_docs, num_chunks, num_chunks_chunks),
- batch_size=request_batch_size,
+ request_size=request_size,
on_done=validate_response,
on_error=fail_if_error)
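The updated tests fix num_docs at 137 so that the final request is always partial; the expected per-response sizes follow from the same arithmetic used in the test bodies. A quick check across the parametrized values:

    num_docs = 137
    for request_size in (8, 16, 32):
        num_requests = int(num_docs / request_size)      # number of full requests
        last = num_docs % (num_requests * request_size)  # docs in the final, partial request
        print(request_size, num_requests, last)
    # request_size=8  -> 17 full requests, 1 doc in the last request
    # request_size=16 -> 8 full requests, 9 docs in the last request
    # request_size=32 -> 4 full requests, 9 docs in the last request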
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_evaluation_from_file.py
@@ -73,7 +73,7 @@ def test_evaluation_from_file(random_workspace, index_groundtruth, evaluate_docs
monkeypatch.setenv("RESTFUL", restful)

with Flow.load_config(index_yaml) as index_gt_flow:
- index_gt_flow.index(input_fn=index_groundtruth, batch_size=10)
+ index_gt_flow.index(input_fn=index_groundtruth, request_size=10)

m = mocker.Mock()

@@ -24,7 +24,7 @@ def fill_responses(resp):
with Flow().load_config(os.path.join(cur_dir, 'flow.yml')) as f:
f.search(input_fn=data,
on_done=fill_responses,
- batch_size=1,
+ request_size=1,
callback_on='body')

del os.environ['JINA_NON_BLOCKING_PARALLEL']
2 changes: 1 addition & 1 deletion tests/integration/optimizers/test_optimizer.py
@@ -19,7 +19,7 @@ def document_generator(num_doc):
eval_flow_runner = FlowRunner(
flow_yaml='tests/integration/optimizers/flow.yml',
documents=document_generator(10),
- batch_size=1,
+ request_size=1,
task='search',
callback=EvaluationCallback(),
)
@@ -39,7 +39,7 @@ def random_workspace(tmpdir):
def test_indexer_with_ref_indexer(random_workspace, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
- index_flow.index(input_fn=index_docs, batch_size=10)
+ index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

@@ -69,7 +69,7 @@ def random_workspace_move(tmpdir):
def test_indexer_with_ref_indexer_move(random_workspace_move, parallel, index_docs, mocker):
top_k = 10
with Flow.load_config('index.yml') as index_flow:
- index_flow.index(input_fn=index_docs, batch_size=10)
+ index_flow.index(input_fn=index_docs, request_size=10)

mock = mocker.Mock()

6 changes: 3 additions & 3 deletions tests/unit/optimizers/test_flow_runner.py
@@ -20,7 +20,7 @@ def callback(resp):
flow_runner = FlowRunner(
flow_yaml='flow.yml',
documents=random_docs(5),
- batch_size=1,
+ request_size=1,
task='index',
overwrite_workspace=True,
)
@@ -32,7 +32,7 @@ def callback(resp):
flow_runner = FlowRunner(
flow_yaml='flow.yml',
documents=random_docs(5),
- batch_size=1,
+ request_size=1,
task='search',
callback=callback
)
@@ -48,7 +48,7 @@ def test_wrong_task():
_ = FlowRunner(
flow_yaml='flow.yml',
documents=random_docs(5),
- batch_size=1,
+ request_size=1,
task='query',
)
assert 'task can be either of index or search' == str(excinfo.value)
4 changes: 2 additions & 2 deletions tests/unit/test_loadbalance.py
@@ -30,7 +30,7 @@ def test_lb():
uses='SlowWorker',
parallel=10)
with f:
- f.index(input_fn=random_docs(100), batch_size=10)
+ f.index(input_fn=random_docs(100), request_size=10)


def test_roundrobin():
@@ -39,4 +39,4 @@ def test_roundrobin():
uses='SlowWorker',
parallel=10, scheduling=SchedulerType.ROUND_ROBIN)
with f:
- f.index(input_fn=random_docs(100), batch_size=10)
+ f.index(input_fn=random_docs(100), request_size=10)
