-
Notifications
You must be signed in to change notification settings - Fork 4
/
parallel.py
58 lines (41 loc) · 1.49 KB
/
parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Helper functions for parallel processing.
.. currentmodule:: erlab.parallel
"""
# Public names exported by this module.
__all__ = ["joblib_progress", "joblib_progress_qt"]
import contextlib
import sys
import joblib
import joblib._parallel_backends
import tqdm.auto
from qtpy import QtCore
@contextlib.contextmanager
def joblib_progress(file=None, **kwargs):
    """Route joblib's progress reporting into a tqdm progress bar.

    While the context is active, ``joblib.parallel.Parallel.print_progress`` is
    replaced by a hook that advances a tqdm bar so that it matches the number of
    completed tasks. On exit, the original method is put back and the bar is
    closed.

    Parameters
    ----------
    file
        Output stream handed to tqdm. If `None`, ``sys.stdout`` is used.
    **kwargs
        Additional keyword arguments forwarded to ``tqdm.auto.tqdm``.

    Yields
    ------
    The tqdm progress bar object that is being updated.
    """
    stream = sys.stdout if file is None else file
    progress_bar = tqdm.auto.tqdm(iterable=None, file=stream, **kwargs)

    def _advance_bar(self):
        # The bar only ever moves forward; ignore calls that report no new work.
        if self.n_completed_tasks <= progress_bar.n:
            return
        newly_done = self.n_completed_tasks - progress_bar.n
        progress_bar.update(n=newly_done)

    parallel_cls = joblib.parallel.Parallel
    saved_hook = parallel_cls.print_progress
    parallel_cls.print_progress = _advance_bar

    try:
        yield progress_bar
    finally:
        # Undo the monkeypatch first, then release the bar.
        parallel_cls.print_progress = saved_hook
        progress_bar.close()
@contextlib.contextmanager
def joblib_progress_qt(signal: QtCore.Signal):
    """Report joblib progress through a Qt signal, for interactive windows.

    While the context is active, ``joblib.parallel.Parallel.print_progress`` is
    replaced by a hook that emits the current number of completed tasks through
    `signal`. The original method is put back on exit.

    Parameters
    ----------
    signal
        Signal whose ``emit`` method is called with the number of completed tasks.

    Yields
    ------
    None
    """
    parallel_cls = joblib.parallel.Parallel
    saved_hook = parallel_cls.print_progress

    def _emit_completed(self):
        signal.emit(self.n_completed_tasks)

    parallel_cls.print_progress = _emit_completed

    try:
        yield
    finally:
        # Always undo the monkeypatch, even if the body raised.
        parallel_cls.print_progress = saved_hook