"""Manages the spawning of mpi processes to send to the various solvers.
"""
import os
import functools
import numpy as np
from .slepc_linalg import (
eigs_slepc,
svds_slepc,
mfn_multiply_slepc,
ssolve_slepc,
)
from ..core import _NUM_THREAD_WORKERS
# Work out if already running as mpi
if (
("OMPI_COMM_WORLD_SIZE" in os.environ) # OpenMPI
or ("PMI_SIZE" in os.environ) # MPICH
):
QUIMB_MPI_LAUNCHED = "_QUIMB_MPI_LAUNCHED" in os.environ
ALREADY_RUNNING_AS_MPI = True
USE_SYNCRO = "QUIMB_SYNCRO_MPI" in os.environ
else:
QUIMB_MPI_LAUNCHED = False
ALREADY_RUNNING_AS_MPI = False
USE_SYNCRO = False
# default to not allowing mpi spawning capabilities
ALLOW_SPAWN = {
"TRUE": True,
"ON": True,
"FALSE": False,
"OFF": False,
}[os.environ.get("QUIMB_MPI_SPAWN", "False").upper()]
# Work out the desired total number of workers
for _NUM_MPI_WORKERS_VAR in (
"QUIMB_NUM_MPI_WORKERS",
"QUIMB_NUM_PROCS",
"OMPI_COMM_WORLD_SIZE",
"PMI_SIZE",
"OMP_NUM_THREADS",
):
if _NUM_MPI_WORKERS_VAR in os.environ:
NUM_MPI_WORKERS = int(os.environ[_NUM_MPI_WORKERS_VAR])
break
else:
import psutil
_NUM_MPI_WORKERS_VAR = "psutil"
NUM_MPI_WORKERS = psutil.cpu_count(logical=False)
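
# A usage sketch (``my_script.py`` is a placeholder, not executed here): both
# the worker count and whether spawning is allowed are controlled purely by
# environment variables, e.g. from the shell:
#
#     $ QUIMB_MPI_SPAWN=TRUE QUIMB_NUM_MPI_WORKERS=4 python my_script.py
#
# or by launching explicitly with the helper script referenced in
# ``get_mpi_pool`` below:
#
#     $ quimb-mpi-python my_script.py
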
def can_use_mpi_pool():
"""Function to determine whether we are allowed to call `get_mpi_pool`."""
return ALLOW_SPAWN or ALREADY_RUNNING_AS_MPI
def bcast(result, comm, result_rank):
"""Broadcast a result to all workers, dispatching to proper MPI (rather
than pickled) communication if the result is a numpy array.
"""
rank = comm.Get_rank()
# make sure all workers know if result is an array or not
if rank == result_rank:
is_ndarray = isinstance(result, np.ndarray)
else:
is_ndarray = None
is_ndarray = comm.bcast(is_ndarray, root=result_rank)
# standard (pickle) bcast if not array
if not is_ndarray:
return comm.bcast(result, root=result_rank)
# make sure all workers have shape and dtype
if rank == result_rank:
shape_dtype = result.shape, str(result.dtype)
else:
shape_dtype = None
shape_dtype = comm.bcast(shape_dtype, root=result_rank)
shape, dtype = shape_dtype
# allocate data space
if rank != result_rank:
result = np.empty(shape, dtype=dtype)
# use fast communication for main array
comm.Bcast(result, root=result_rank)
return result
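
# A minimal usage sketch of ``bcast`` (assuming the script is already running
# under MPI, e.g. via ``mpiexec -n 4``), where an array computed on rank 0
# only is shared with every rank via the fast buffer-based ``comm.Bcast``:
#
#     from mpi4py import MPI
#
#     comm = MPI.COMM_WORLD
#     if comm.Get_rank() == 0:
#         res = np.arange(10.0)  # computed on the root rank only
#     else:
#         res = None
#     res = bcast(res, comm, result_rank=0)  # now every rank holds the array
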
class SyncroFuture:
def __init__(self, result, result_rank, comm):
self._result = result
self.result_rank = result_rank
self.comm = comm
def result(self):
rank = self.comm.Get_rank()
if rank == self.result_rank:
should_it = isinstance(self._result, tuple) and any(
isinstance(x, np.ndarray) for x in self._result
)
if should_it:
iterate_over = len(self._result)
else:
iterate_over = 0
else:
iterate_over = None
iterate_over = self.comm.bcast(iterate_over, root=self.result_rank)
if iterate_over:
if rank != self.result_rank:
self._result = (None,) * iterate_over
result = tuple(
bcast(x, self.comm, self.result_rank) for x in self._result
)
else:
result = bcast(self._result, self.comm, self.result_rank)
return result
@staticmethod
def cancel():
raise ValueError(
"SyncroFutures cannot be cancelled - they are "
"submitted in a parallel round-robin fasion where "
"each worker immediately computes all its results."
)
class SynchroMPIPool:
"""An object that looks like a ``concurrent.futures`` executor but actually
distributes tasks in a round-robin fashion based to MPI workers, before
broadcasting the results to each other.
"""
def __init__(self):
import itertools
from mpi4py import MPI
self.comm = MPI.COMM_WORLD
self.size = self.comm.Get_size()
self.rank = self.comm.Get_rank()
self.counter = itertools.cycle(range(0, self.size))
self._max_workers = self.size
def submit(self, fn, *args, **kwargs):
# round robin iterate through ranks
current_counter = next(self.counter)
        # accept the job and compute only if this process has the matching rank
if current_counter == self.rank:
res = fn(*args, **kwargs)
else:
res = None
        # wrap the result in a SyncroFuture, which will broadcast the result
return SyncroFuture(res, current_counter, self.comm)
def shutdown(self):
pass
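
# A usage sketch of syncro mode (assuming ``QUIMB_SYNCRO_MPI`` is set and the
# same script runs on every rank): each rank submits every task, each task is
# computed by exactly one rank, and ``.result()`` broadcasts it so that all
# ranks end up with identical results (``some_fn`` is a placeholder):
#
#     pool = SynchroMPIPool()
#     futures = [pool.submit(some_fn, x) for x in range(8)]
#     results = [f.result() for f in futures]  # same list on every rank
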
class CachedPoolWithShutdown:
"""Decorator for caching the mpi pool when called with the equivalent args,
and shutting down previous ones when not needed.
"""
def __init__(self, pool_fn):
self._settings = "__UNINITIALIZED__"
self._pool_fn = pool_fn
def __call__(self, num_workers=None, num_threads=1):
        # convert None to the default so the cache key is the same
if num_workers is None:
num_workers = NUM_MPI_WORKERS
elif ALREADY_RUNNING_AS_MPI and (num_workers != NUM_MPI_WORKERS):
raise ValueError(
"Can't specify number of processes when running "
"under MPI rather than spawning processes."
)
# first call
if self._settings == "__UNINITIALIZED__":
self._pool = self._pool_fn(num_workers, num_threads)
self._settings = (num_workers, num_threads)
# new type of pool requested
elif self._settings != (num_workers, num_threads):
self._pool.shutdown()
self._pool = self._pool_fn(num_workers, num_threads)
self._settings = (num_workers, num_threads)
return self._pool
@CachedPoolWithShutdown
def get_mpi_pool(num_workers=None, num_threads=1):
"""Get the MPI executor pool, with specified number of processes and
threads per process.
"""
if (num_workers == 1) and (num_threads == _NUM_THREAD_WORKERS):
from concurrent.futures import ProcessPoolExecutor
return ProcessPoolExecutor(1)
if not QUIMB_MPI_LAUNCHED:
raise RuntimeError(
"For the moment, quimb programs using `get_mpi_pool` need to be "
"explicitly launched using `quimb-mpi-python`."
)
if USE_SYNCRO:
return SynchroMPIPool()
if not can_use_mpi_pool():
raise RuntimeError(
"`get_mpi_pool()` cannot be explicitly called unless already "
"running under MPI, or you set the environment variable "
"`QUIMB_MPI_SPAWN=True`."
)
from mpi4py.futures import MPIPoolExecutor
return MPIPoolExecutor(
num_workers,
main=False,
env={
"OMP_NUM_THREADS": str(num_threads),
"QUIMB_NUM_MPI_WORKERS": str(num_workers),
"_QUIMB_MPI_LAUNCHED": "SPAWNED",
},
)
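
# A brief sketch of direct pool usage (assuming the program was launched via
# ``quimb-mpi-python``, or spawning is otherwise permitted, e.g. with
# ``QUIMB_MPI_SPAWN=TRUE``):
#
#     pool = get_mpi_pool(num_workers=4, num_threads=1)
#     future = pool.submit(sum, range(100))  # any picklable callable
#     future.result()  # -> 4950
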
class GetMPIBeforeCall(object):
"""Wrap a function to automatically get the correct communicator before
    it is called, and to set the `comm_self` kwarg to allow forced self mode.
This is called by every mpi process before the function evaluation.
"""
def __init__(self, fn):
self.fn = fn
def __call__(
self, *args, comm_self=False, wait_for_workers=None, **kwargs
):
"""
Parameters
----------
args
Supplied to self.fn
comm_self : bool, optional
Whether to force use of MPI.COMM_SELF
wait_for_workers : int, optional
            If set, wait for the communicator to have this many workers; this
            can help to catch some errors regarding expected worker numbers.
kwargs
Supplied to self.fn
"""
from mpi4py import MPI
if not comm_self:
comm = MPI.COMM_WORLD
else:
comm = MPI.COMM_SELF
if wait_for_workers is not None:
from time import time
t0 = time()
while comm.Get_size() != wait_for_workers:
if time() - t0 > 2:
raise RuntimeError(
f"Timeout while waiting for {wait_for_workers} "
f"workers to join comm {comm}."
)
res = self.fn(*args, comm=comm, **kwargs)
return res
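
# A sketch of the wrapping pattern (``my_solver`` and ``mat`` are hypothetical;
# the wrapped function must accept an explicit ``comm`` keyword, as the SLEPc
# routines below do):
#
#     my_solver_mpi = GetMPIBeforeCall(my_solver)
#     # called on every mpi rank; the communicator is injected automatically
#     out = my_solver_mpi(mat, comm_self=False, wait_for_workers=4)
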
class SpawnMPIProcessesFunc(object):
"""Automatically wrap a function to be executed in parallel by a
pool of mpi workers.
This is only called by the master mpi process in manual mode, only by
the (non-mpi) spawning process in automatic mode, or by all processes in
syncro mode.
"""
def __init__(self, fn):
self.fn = fn
def __call__(
self,
*args,
num_workers=None,
num_threads=1,
mpi_pool=None,
spawn_all=USE_SYNCRO or (not ALREADY_RUNNING_AS_MPI),
**kwargs,
):
"""
Parameters
----------
args
Supplied to `self.fn`.
num_workers : int, optional
            How many processes in total should run the function in parallel.
num_threads : int, optional
            How many (OMP) threads each process should use.
mpi_pool : pool-like, optional
If not None (default), submit function to this pool.
spawn_all : bool, optional
            Whether to spawn all ``num_workers`` processes (True), or only
            ``num_workers - 1``, so that the current process can also do work.
kwargs
Supplied to `self.fn`.
Returns
-------
`fn` output from the master process.
"""
if num_workers is None:
num_workers = NUM_MPI_WORKERS
if (
            # mpi spawning not allowed and not already running under mpi
(not can_use_mpi_pool())
or
# no pool or communicator needed
(num_workers == 1)
):
return self.fn(*args, comm_self=True, **kwargs)
kwargs["wait_for_workers"] = num_workers
if mpi_pool is not None:
pool = mpi_pool
else:
pool = get_mpi_pool(num_workers, num_threads)
# the (non mpi) main process is idle while the workers compute.
if spawn_all:
futures = [
pool.submit(self.fn, *args, **kwargs)
for _ in range(num_workers)
]
results = [f.result() for f in futures]
        # the current process is itself the master mpi process and contributes work
else:
futures = [
pool.submit(self.fn, *args, **kwargs)
for _ in range(num_workers - 1)
]
results = [self.fn(*args, **kwargs)] + [
f.result() for f in futures
]
        # get the master result (it is not always the first one submitted)
return next(r for r in results if r is not None)
# ---------------------------------- SLEPC ---------------------------------- #
eigs_slepc_mpi = functools.wraps(eigs_slepc)(GetMPIBeforeCall(eigs_slepc))
eigs_slepc_spawn = functools.wraps(eigs_slepc)(
SpawnMPIProcessesFunc(eigs_slepc_mpi)
)
svds_slepc_mpi = functools.wraps(svds_slepc)(GetMPIBeforeCall(svds_slepc))
svds_slepc_spawn = functools.wraps(svds_slepc)(
SpawnMPIProcessesFunc(svds_slepc_mpi)
)
mfn_multiply_slepc_mpi = functools.wraps(mfn_multiply_slepc)(
GetMPIBeforeCall(mfn_multiply_slepc)
)
mfn_multiply_slepc_spawn = functools.wraps(mfn_multiply_slepc)(
SpawnMPIProcessesFunc(mfn_multiply_slepc_mpi)
)
ssolve_slepc_mpi = functools.wraps(ssolve_slepc)(
GetMPIBeforeCall(ssolve_slepc)
)
ssolve_slepc_spawn = functools.wraps(ssolve_slepc)(
SpawnMPIProcessesFunc(ssolve_slepc_mpi)
)
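
# A usage sketch of the user-facing spawned wrappers (a hedged example: the
# keyword arguments accepted are simply those of the underlying SLEPc routine,
# here ``eigs_slepc``):
#
#     import quimb as qu
#     from quimb.linalg.mpi_launcher import eigs_slepc_spawn
#
#     ham = qu.ham_heis(12, sparse=True)
#     lk, vk = eigs_slepc_spawn(ham, k=2, num_workers=4)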