Add an option to disable multiprocessing in Inferencer (#117)
tanaysoni committed Oct 25, 2019
1 parent 894b1c1 commit 7f2203e
Showing 2 changed files with 26 additions and 18 deletions.
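The motivation, spelled out in the docstring added below, is that the cost of spawning worker processes can dominate the inference work itself for tiny inputs. A rough, standalone way to see that overhead in isolation (a sketch, independent of FARM; timings vary by machine and start method):

import multiprocessing as mp
import time

def noop(x):
    return x

if __name__ == "__main__":
    # Time pool creation plus a single trivial task...
    start = time.perf_counter()
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(noop, [0])
    print(f"pool spawn + one task: {time.perf_counter() - start:.3f}s")

    # ...versus calling the same function directly in-process.
    start = time.perf_counter()
    noop(0)
    print(f"direct call:           {time.perf_counter() - start:.6f}s")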
farm/infer.py (25 additions, 17 deletions)
@@ -140,7 +140,7 @@ def inference_from_file(self, file):
         preds_all = self.inference_from_dicts(dicts, rest_api_schema=False)
         return preds_all
 
-    def inference_from_dicts(self, dicts, rest_api_schema=False):
+    def inference_from_dicts(self, dicts, rest_api_schema=False, use_multiprocessing=True):
         """
         Runs down-stream inference using the prediction head.
@@ -149,6 +149,8 @@ def inference_from_dicts(self, dicts, rest_api_schema=False):
         :param rest_api_schema: whether to conform to the schema used for dicts in the HTTP API for Inference.
         :type rest_api_schema: bool
         :return: dict of predictions
+        :param use_multiprocessing: the time incurred in spawning processes can outweigh the performance boost for a very
+            small number of dicts, e.g. HTTP APIs for inference. This flag allows disabling multiprocessing for such cases.
         """
         if self.prediction_type == "embedder":
@@ -170,25 +172,31 @@ def inference_from_dicts(self, dicts, rest_api_schema=False):
         dict_batches_to_process = int(len(dicts) / multiprocessing_chunk_size)
         num_cpus_used = min(mp.cpu_count(), dict_batches_to_process) or 1
 
-        with ExitStack() as stack:
-            p = stack.enter_context(mp.Pool(processes=num_cpus_used))
+        if use_multiprocessing:
+            with ExitStack() as stack:
+                p = stack.enter_context(mp.Pool(processes=num_cpus_used))
 
-            logger.info(
-                f"Got ya {num_cpus_used} parallel workers to do inference on {len(dicts)}dicts (chunksize = {multiprocessing_chunk_size})..."
-            )
-            log_ascii_workers(num_cpus_used, logger)
+                logger.info(
+                    f"Got ya {num_cpus_used} parallel workers to do inference on {len(dicts)} dicts (chunksize = {multiprocessing_chunk_size})..."
+                )
+                log_ascii_workers(num_cpus_used, logger)
 
-            results = p.imap(
-                partial(self._multiproc, processor=self.processor, rest_api_schema=rest_api_schema),
-                grouper(dicts, multiprocessing_chunk_size),
-                1,
-            )
+                results = p.imap(
+                    partial(self._multiproc, processor=self.processor, rest_api_schema=rest_api_schema),
+                    grouper(dicts, multiprocessing_chunk_size),
+                    1,
+                )
 
-            preds_all = []
-            with tqdm(total=len(dicts), unit=" Dicts") as pbar:
-                for dataset, tensor_names, sample in results:
-                    preds_all.extend(self._run_inference(dataset, tensor_names, sample))
-                    pbar.update(multiprocessing_chunk_size)
+                preds_all = []
+                with tqdm(total=len(dicts), unit=" Dicts") as pbar:
+                    for dataset, tensor_names, sample in results:
+                        preds_all.extend(self._run_inference(dataset, tensor_names, sample))
+                        pbar.update(multiprocessing_chunk_size)
+        else:
+            chunk = next(grouper(dicts, len(dicts)))
+            dataset, tensor_names, sample = self._multiproc(chunk, processor=self.processor, rest_api_schema=rest_api_schema)
+            preds_all = self._run_inference(dataset, tensor_names, sample)
 
         return preds_all
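A minimal usage sketch of the new flag; the model directory and input dict are placeholders, not part of this commit, and Inferencer.load is FARM's usual entry point for a saved model:

from farm.infer import Inferencer

# Hypothetical saved model directory.
model = Inferencer.load("saved_models/bert-english-doc-classification")

dicts = [{"text": "FARM makes transfer learning simple."}]

# Default: chunk the dicts and preprocess them in a multiprocessing pool.
preds = model.inference_from_dicts(dicts=dicts)

# New: stay in the current process, avoiding spawn overhead for tiny inputs.
preds = model.inference_from_dicts(dicts=dicts, use_multiprocessing=False)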

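The single-process branch relies on grouper(dicts, len(dicts)) yielding exactly one chunk holding all dicts. Assuming FARM's grouper behaves like the common chunking recipe (a sketch, not the library's actual implementation):

from itertools import islice

def grouper(iterable, n):
    # Yield successive chunks of at most n items.
    it = iter(iterable)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk

dicts = [{"text": f"doc {i}"} for i in range(5)]
assert next(grouper(dicts, len(dicts))) == dicts          # the else-branch: one chunk
assert [len(c) for c in grouper(dicts, 2)] == [2, 2, 1]   # the pool path: chunked work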
farm/inference_rest_api.py (1 addition, 1 deletion)

@@ -79,7 +79,7 @@ def post(self, model_id):
         dicts = request.get_json().get("input", None)
         if not dicts:
             return {}
-        results = model.inference_from_dicts(dicts=dicts, rest_api_schema=True)
+        results = model.inference_from_dicts(dicts=dicts, rest_api_schema=True, use_multiprocessing=False)
         return results[0]
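For context, the handler above is exercised roughly like this; the host, port, and route are deployment assumptions, while the "input" key matches the request parsing above:

import requests

# Assumed local deployment of FARM's Flask REST API.
url = "http://localhost:5000/models/1/inference"
payload = {"input": [{"text": "One short request, now served without a pool."}]}

response = requests.post(url, json=payload)
print(response.json())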


