Skip to content

Commit

Permalink
Reset parallel batch statistics at the end of a run (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
aabadie authored and lesteve committed Oct 13, 2017
1 parent 36f08e5 commit df568d0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
28 changes: 23 additions & 5 deletions joblib/_parallel_backends.py
Expand Up @@ -169,9 +169,14 @@ class AutoBatchingMixin(object):
# on a single worker while other workers have no work to process any more.
MAX_IDEAL_BATCH_DURATION = 2

# Batching counters
_effective_batch_size = 1
_smoothed_batch_duration = 0.0
# Batching counters default values
_DEFAULT_EFFECTIVE_BATCH_SIZE = 1
_DEFAULT_SMOOTHED_BATCH_DURATION = 0.0

def __init__(self):
self._effective_batch_size = self._DEFAULT_EFFECTIVE_BATCH_SIZE
self._smoothed_batch_duration = self._DEFAULT_SMOOTHED_BATCH_DURATION


def compute_batch_size(self):
"""Determine the optimal batch size"""
Expand Down Expand Up @@ -215,7 +220,8 @@ def compute_batch_size(self):
# CallBack as long as the batch_size is constant. Therefore
# we need to reset the estimate whenever we re-tune the batch
# size.
self._smoothed_batch_duration = 0
self._smoothed_batch_duration = \
self._DEFAULT_SMOOTHED_BATCH_DURATION

return batch_size

Expand All @@ -225,7 +231,7 @@ def batch_completed(self, batch_size, duration):
# Update the smoothed streaming estimate of the duration of a batch
# from dispatch to completion
old_duration = self._smoothed_batch_duration
if old_duration == 0:
if old_duration == self._DEFAULT_SMOOTHED_BATCH_DURATION:
# First record of duration for this batch size after the last
# reset.
new_duration = duration
Expand All @@ -235,6 +241,14 @@ def batch_completed(self, batch_size, duration):
new_duration = 0.8 * old_duration + 0.2 * duration
self._smoothed_batch_duration = new_duration

def reset_batch_stats(self):
"""Reset batch statistics to default values.
This avoids interferences with future jobs.
"""
self._effective_batch_size = self._DEFAULT_EFFECTIVE_BATCH_SIZE
self._smoothed_batch_duration = self._DEFAULT_SMOOTHED_BATCH_DURATION


class ThreadingBackend(PoolManagerMixin, ParallelBackendBase):
"""A ParallelBackend which will use a thread pool to execute batches in.
Expand Down Expand Up @@ -342,6 +356,8 @@ def terminate(self):
if self.JOBLIB_SPAWNED_PROCESS in os.environ:
del os.environ[self.JOBLIB_SPAWNED_PROCESS]

self.reset_batch_stats()


class LokyBackend(AutoBatchingMixin, ParallelBackendBase):
"""Managing pool of workers with loky instead of multiprocessing."""
Expand Down Expand Up @@ -411,6 +427,8 @@ def terminate(self):
delete_folder(self._workers._temp_folder)
self._workers = None

self.reset_batch_stats()

def abort_everything(self, ensure_ready=True):
"""Shutdown the workers and restart a new one with the same parameters
"""
Expand Down
30 changes: 30 additions & 0 deletions joblib/test/test_parallel.py
Expand Up @@ -940,3 +940,33 @@ def __reduce__(self):
with warns(DeprecationWarning):
with raises(ValueError):
delayed(UnpicklableCallable(), check_pickle=True)


@with_multiprocessing
@parametrize('backend', ['multiprocessing', 'loky'])
def test_backend_batch_statistics_reset(backend):
"""Test that a parallel backend correctly resets its batch statistics."""
relative_tolerance = 0.2
n_jobs = 2
n_inputs = 500
task_time = 2. / n_inputs

p = Parallel(verbose=10, n_jobs=n_jobs, backend=backend)
start_time = time.time()
p(delayed(time.sleep)(task_time) for i in range(n_inputs))
ref_time = time.time() - start_time
assert (p._backend._effective_batch_size ==
p._backend._DEFAULT_EFFECTIVE_BATCH_SIZE)
assert (p._backend._smoothed_batch_duration ==
p._backend._DEFAULT_SMOOTHED_BATCH_DURATION)

start_time = time.time()
p(delayed(time.sleep)(task_time) for i in range(n_inputs))
test_time = time.time() - start_time
assert (p._backend._effective_batch_size ==
p._backend._DEFAULT_EFFECTIVE_BATCH_SIZE)
assert (p._backend._smoothed_batch_duration ==
p._backend._DEFAULT_SMOOTHED_BATCH_DURATION)

# Tolerance in the timing comparison to avoid random failures on CIs
assert test_time / ref_time <= 1 + relative_tolerance

0 comments on commit df568d0

Please sign in to comment.