
only read cols we need
PhilipMay committed Jun 6, 2020
1 parent 675b867 commit e4b8dbf
Showing 1 changed file with 37 additions and 49 deletions.
farm/infer.py: 37 additions & 49 deletions
@@ -121,7 +121,7 @@ def __init__(
         model.connect_heads_with_processor(processor.tasks, require_labels=False)
         set_all_seeds(42)

-        self._set_multiprocessing_pool(num_processes)
+        self.num_processes = num_processes

     @classmethod
     def load(
@@ -259,28 +259,6 @@ def load(
             disable_tqdm=disable_tqdm
         )

-    def _set_multiprocessing_pool(self, num_processes):
-        """
-        Initialize a multiprocessing.Pool for instances of Inferencer.
-
-        :param num_processes: the number of processes for `multiprocessing.Pool`. Set to value of 0 to disable
-                              multiprocessing. Set to None to let Inferencer use all CPU cores. If you want to
-                              debug the Language Model, you might need to disable multiprocessing!
-        :type num_processes: int
-        :return:
-        """
-        self.process_pool = None
-        if num_processes == 0: # disable multiprocessing
-            self.process_pool = None
-        else:
-            if num_processes is None: # use all CPU cores
-                num_processes = mp.cpu_count() - 1
-            self.process_pool = mp.Pool(processes=num_processes)
-            logger.info(
-                f"Got ya {num_processes} parallel workers to do inference ..."
-            )
-            log_ascii_workers(n=num_processes,logger=logger)
-
     def save(self, path):
         self.model.save(path)
         self.processor.save(path)
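
The helper removed above encodes the `num_processes` contract that the rest of the commit keeps: 0 disables multiprocessing, None means "use all CPU cores" (in practice cpu_count() - 1, as the TODO in the last hunk notes), and any positive number is used as-is. A minimal standalone sketch of that contract; illustrative code, not part of FARM:

    import multiprocessing as mp

    def resolve_num_processes(num_processes):
        # 0     -> multiprocessing disabled (debugging, web frameworks)
        # None  -> "all CPU cores", implemented as cpu_count() - 1
        # n > 0 -> exactly n worker processes
        if num_processes == 0:
            return 0
        if num_processes is None:
            return mp.cpu_count() - 1
        return num_processes

For example, resolve_num_processes(None) returns 7 on an 8-core machine, which is the count the worker log line in the last hunk reports.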
@@ -358,7 +336,7 @@ def inference_from_dicts(
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")

-        if self.process_pool is None: # multiprocessing disabled (helpful for debugging or using in web frameworks)
+        if self.num_processes == 0: # multiprocessing disabled (helpful for debugging or using in web frameworks)
             predictions = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
             return predictions
         else: # use multiprocessing for inference
@@ -439,32 +417,42 @@ def _inference_with_multiprocessing(
         :rtype: iter
         """

-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        results = self.process_pool.imap(
-            partial(self._create_datasets_chunkwise, processor=self.processor),
-            grouper(iterable=dicts, n=multiprocessing_chunksize),
-            1,
-        )
+        # TODO: the docstring of __init__ said we use all CPU cores but we use one less
+        num_processes = mp.cpu_count() - 1

-        # Once a process spits out a preprocessed chunk, we feed this dataset directly to the model.
-        # So we don't need to wait until all preprocessing has finished before getting first predictions.
-        for dataset, tensor_names, baskets in results:
-            # TODO change format of formatted_preds in QA (list of dicts)
-            if aggregate_preds:
-                predictions = self._get_predictions_and_aggregate(
-                    dataset, tensor_names, baskets
-                )
-            else:
-                predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-            if return_json:
-                # TODO this try catch should be removed when all tasks return prediction objects
-                try:
-                    predictions = [x.to_json() for x in predictions]
-                except AttributeError:
-                    pass
-            yield from predictions
+        # Use context manager to close the pool again
+        with mp.Pool(processes=num_processes) as pool:
+            logger.info(
+                f"Got ya {num_processes} parallel workers to do inference ..."
+            )
+            log_ascii_workers(n=num_processes, logger=logger)
+
+            # We group the input dicts into chunks and feed each chunk to a different process
+            # in the pool, where it gets converted to a pytorch dataset
+            results = pool.imap(
+                partial(self._create_datasets_chunkwise, processor=self.processor),
+                grouper(iterable=dicts, n=multiprocessing_chunksize),
+                1,
+            )
+
+            # Once a process spits out a preprocessed chunk, we feed this dataset directly to the model.
+            # So we don't need to wait until all preprocessing has finished before getting first predictions.
+            for dataset, tensor_names, baskets in results:
+                # TODO change format of formatted_preds in QA (list of dicts)
+                if aggregate_preds:
+                    predictions = self._get_predictions_and_aggregate(
+                        dataset, tensor_names, baskets
+                    )
+                else:
+                    predictions = self._get_predictions(dataset, tensor_names, baskets)
+
+                if return_json:
+                    # TODO this try catch should be removed when all tasks return prediction objects
+                    try:
+                        predictions = [x.to_json() for x in predictions]
+                    except AttributeError:
+                        pass
+                yield from predictions

     @classmethod
     def _create_datasets_chunkwise(cls, chunk, processor):
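
The new `_inference_with_multiprocessing` follows a create-use-close pattern: the pool lives only inside the `with` block, and `pool.imap` hands each preprocessed chunk to the consumer as soon as a worker finishes it, so inference can start before preprocessing is complete. Below is a self-contained sketch of that pattern. All names (`grouper`, `preprocess`, `stream_results`) are illustrative stand-ins for FARM's helpers, and the `max(1, ...)` guard is an addition the diff itself does not have:

    import multiprocessing as mp
    from functools import partial
    from itertools import islice

    def grouper(iterable, n):
        # Stand-in for FARM's grouper: yield successive chunks of size n.
        it = iter(iterable)
        chunk = list(islice(it, n))
        while chunk:
            yield chunk
            chunk = list(islice(it, n))

    def preprocess(chunk, factor):
        # Stand-in for _create_datasets_chunkwise; runs inside a worker process.
        return [x * factor for x in chunk]

    def stream_results(items, chunksize=3):
        # The diff uses cpu_count() - 1 directly; max(1, ...) guards single-core machines.
        num_processes = max(1, mp.cpu_count() - 1)
        # The pool exists only for the duration of this call; the context
        # manager tears the workers down once the generator is exhausted.
        with mp.Pool(processes=num_processes) as pool:
            results = pool.imap(partial(preprocess, factor=2), grouper(items, chunksize), 1)
            for processed_chunk in results:
                # Consume each chunk as soon as a worker delivers it, instead
                # of waiting for all preprocessing to finish.
                yield from processed_chunk

    if __name__ == "__main__":
        print(list(stream_results(range(10))))  # [0, 2, 4, ..., 18]

Because `stream_results` is a generator, the pool is created lazily on first consumption and torn down once iteration ends, which is what lets the commit drop the long-lived `self.process_pool`.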

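For callers, the visible effect is that worker pools no longer outlive a call: they are created inside `inference_from_dicts` and torn down when it finishes, instead of being built once in `__init__` and kept around. A hedged usage sketch, assuming `Inferencer.load` forwards `num_processes` to `__init__` as the first hunk suggests; the model name and input dict are illustrative:

    from farm.infer import Inferencer

    # num_processes=0 keeps inference in the main process (handy for debugging
    # or inside web frameworks, per the comment in the third hunk).
    inferencer = Inferencer.load("deepset/roberta-base-squad2", num_processes=0)

    # With the default num_processes, a pool is now created per call and closed
    # again afterwards, rather than living as long as the Inferencer itself.
    dicts = [{"qas": ["Who maintains FARM?"], "context": "FARM is maintained by deepset."}]
    result = inferencer.inference_from_dicts(dicts=dicts)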