forked from dask/dask
/
multiprocessing.py
102 lines (79 loc) · 2.81 KB
/
multiprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import absolute_import, division, print_function
import multiprocessing
import pickle
import sys
from .async import get_async # TODO: get better get
from .context import _globals
from .optimize import fuse, cull
import cloudpickle
if sys.version_info.major < 3:
import copy_reg as copyreg
else:
import copyreg
def _reduce_method_descriptor(descriptor):
    """Pickle reducer for builtin method descriptors.

    On unpickle, the descriptor is rebuilt by looking the attribute up
    again on its owning class: ``getattr(owner, name)``.
    """
    owner = descriptor.__objclass__
    name = descriptor.__name__
    return getattr, (owner, name)


# Builtin method descriptors (e.g. ``set.union``) are not picklable by
# default; ``type(set.union)`` stands in for <class 'method_descriptor'>.
copyreg.pickle(type(set.union), _reduce_method_descriptor)
def _dumps(obj):
    """Serialize *obj* with cloudpickle at the highest available protocol.

    cloudpickle handles closures, lambdas and interactively-defined
    functions that plain pickle rejects.
    """
    return cloudpickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)


# cloudpickle output is ordinary pickle data, so the stdlib loader suffices.
_loads = pickle.loads
def _process_get_id():
    """Return the OS identifier (pid) of the current worker process."""
    proc = multiprocessing.current_process()
    return proc.ident
def get(dsk, keys, num_workers=None, func_loads=None, func_dumps=None,
        optimize_graph=True, **kwargs):
    """ Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (defaults to number of cores)
    func_loads : function
        Function to use for function deserialization
        (defaults to cloudpickle.loads)
    func_dumps : function
        Function to use for function serialization
        (defaults to cloudpickle.dumps)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    """
    pool = _globals['pool']
    # Only close a pool we created ourselves; a user-supplied pool is theirs.
    cleanup = pool is None
    if cleanup:
        pool = multiprocessing.Pool(num_workers,
                                    initializer=initialize_worker_process)

    # Trim the graph down to what `keys` actually needs, then optionally
    # fuse linear task chains into single tasks.
    culled, dependencies = cull(dsk, keys)
    graph = fuse(culled, keys, dependencies)[0] if optimize_graph else culled

    # Explicit marshalling functions let us catch serialization errors and
    # report them to the user.  Priority: argument > global setting > default.
    loads = func_loads or _globals.get('func_loads') or _loads
    dumps = func_dumps or _globals.get('func_dumps') or _dumps

    # Note former versions used a multiprocessing Manager to share
    # a Queue between parent and workers, but this is fragile on Windows
    # (issue #1652).
    try:
        result = get_async(pool.apply_async, len(pool._pool), graph, keys,
                           get_id=_process_get_id,
                           dumps=dumps, loads=loads, **kwargs)
    finally:
        if cleanup:
            pool.close()
    return result
def initialize_worker_process():
    """
    Initialize a worker process before running any tasks in it.
    """
    # A forked child inherits the parent's numpy random state; if numpy is
    # already imported, re-seed so sibling workers do not all draw the same
    # random sequences.
    numpy_mod = sys.modules.get('numpy')
    if numpy_mod is None:
        return
    numpy_mod.random.seed()