-
Notifications
You must be signed in to change notification settings - Fork 37
/
mpi.py
220 lines (181 loc) · 6.99 KB
/
mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.
import os
import sys
import itertools
import numpy as np
from ._libtoast import Logger
from .pshmem import MPIShared, MPILock
# Module-level MPI detection.  ``use_mpi`` is tri-state: None means "not yet
# determined"; after this block runs it is True or False.  ``MPI`` is bound to
# the mpi4py.MPI module when available, else stays None.
use_mpi = None
MPI = None
if use_mpi is None:
    # See if the user has explicitly disabled MPI.
    if "MPI_DISABLE" in os.environ:
        use_mpi = False
    else:
        # Special handling for running on a NERSC login node. This is for convenience.
        # The same behavior could be implemented with environment variables set in a
        # shell resource file.
        at_nersc = "NERSC_HOST" in os.environ
        in_slurm = "SLURM_JOB_NAME" in os.environ
        # Only attempt MPI off NERSC login nodes (i.e. not at NERSC, or
        # inside a SLURM allocation).
        if (not at_nersc) or in_slurm:
            try:
                import mpi4py.MPI as MPI
                use_mpi = True
            except Exception:
                # There could be many possible exceptions raised during the
                # import (not just ImportError), but never swallow
                # SystemExit / KeyboardInterrupt, so catch Exception rather
                # than using a bare except.  Logger is already imported at
                # the top of this module.
                log = Logger.get()
                log.info("mpi4py not found- using serial operations only")
                use_mpi = False
def get_world():
    """Retrieve the default world communicator and its properties.

    If MPI is enabled, this returns MPI.COMM_WORLD and the process rank and
    number of processes.  If MPI is disabled, this returns None for the
    communicator, zero for the rank, and one for the number of processes.

    Returns:
        (tuple): The communicator, number of processes, and rank.

    """
    # Serial fallback: no communicator, one process, rank zero.
    if not use_mpi:
        return None, 1, 0
    comm = MPI.COMM_WORLD
    return comm, comm.size, comm.rank
class Comm(object):
    """Class which represents a two-level hierarchy of MPI communicators.

    A Comm object splits the full set of processes into groups of size
    "group".  If group_size does not divide evenly into the size of the given
    communicator, then those processes remain idle.

    A Comm object stores three MPI communicators:  The "world" communicator
    given here, which contains all processes to consider, a "group"
    communicator (one per group), and a "rank" communicator which contains the
    processes with the same group-rank across all groups.

    If MPI is not enabled, then all communicators are set to None.

    Args:
        world (mpi4py.MPI.Comm): the MPI communicator containing all processes.
        group (int): the size of each process group.

    Raises:
        RuntimeError: if the world size is not evenly divisible by the
            requested group size.

    """

    def __init__(self, world=None, groupsize=0):
        log = Logger.get()
        if world is None:
            if use_mpi:
                # Default is COMM_WORLD
                world = MPI.COMM_WORLD
            else:
                # MPI is disabled, leave world as None.
                pass
        else:
            if use_mpi:
                # We were passed a communicator to use. Check that it is
                # actually a communicator, otherwise fall back to COMM_WORLD.
                if not isinstance(world, MPI.Comm):
                    log.warning(
                        "Specified world communicator is not a valid "
                        "mpi4py.MPI.Comm object. Using COMM_WORLD."
                    )
                    world = MPI.COMM_WORLD
            else:
                log.warning(
                    "World communicator specified even though "
                    "MPI is disabled. Ignoring this constructor "
                    "argument."
                )
                world = None
        # Special case, MPI available but the user wants a serial
        # data object.  Guard on use_mpi: when MPI is disabled, the MPI
        # module object is None and accessing MPI.COMM_SELF would raise
        # AttributeError (bug fix).
        if use_mpi and (world == MPI.COMM_SELF):
            world = None
        self._wcomm = world
        # Serial defaults; overridden below when a real communicator exists.
        self._wrank = 0
        self._wsize = 1
        if self._wcomm is not None:
            self._wrank = self._wcomm.rank
            self._wsize = self._wcomm.size
        self._gsize = groupsize
        # Reject nonsensical group sizes; zero means "one group of everyone".
        if (self._gsize < 0) or (self._gsize > self._wsize):
            log.warning(
                "Invalid groupsize ({}). Should be between {} "
                "and {}. Using single process group instead.".format(
                    groupsize, 0, self._wsize
                )
            )
            self._gsize = 0
        if self._gsize == 0:
            self._gsize = self._wsize
        self._ngroups = self._wsize // self._gsize
        if self._ngroups * self._gsize != self._wsize:
            msg = (
                "World communicator size ({}) is not evenly divisible "
                "by requested group size ({}).".format(self._wsize, self._gsize)
            )
            log.error(msg)
            raise RuntimeError(msg)
        # Index of this process's group, and its rank within that group.
        self._group = self._wrank // self._gsize
        self._grank = self._wrank % self._gsize
        if self._ngroups == 1:
            # We just have one group with all processes.
            self._gcomm = self._wcomm
            if use_mpi:
                self._rcomm = MPI.COMM_SELF
            else:
                self._rcomm = None
        else:
            # We need to split the communicator. This code is never executed
            # unless MPI is enabled and we have multiple groups.
            self._gcomm = self._wcomm.Split(self._group, self._grank)
            self._rcomm = self._wcomm.Split(self._grank, self._group)

    @property
    def world_size(self):
        """The size of the world communicator."""
        return self._wsize

    @property
    def world_rank(self):
        """The rank of this process in the world communicator."""
        return self._wrank

    @property
    def ngroups(self):
        """The number of process groups."""
        return self._ngroups

    @property
    def group(self):
        """The group containing this process."""
        return self._group

    @property
    def group_size(self):
        """The size of the group containing this process."""
        return self._gsize

    @property
    def group_rank(self):
        """The rank of this process in the group communicator."""
        return self._grank

    @property
    def comm_world(self):
        """The world communicator."""
        return self._wcomm

    @property
    def comm_group(self):
        """The communicator shared by processes within this group."""
        return self._gcomm

    @property
    def comm_rank(self):
        """The communicator shared by processes with the same group_rank."""
        return self._rcomm

    def __repr__(self):
        lines = [
            "  World MPI communicator = {}".format(self._wcomm),
            "  World MPI size = {}".format(self._wsize),
            "  World MPI rank = {}".format(self._wrank),
            "  Group MPI communicator = {}".format(self._gcomm),
            "  Group MPI size = {}".format(self._gsize),
            "  Group MPI rank = {}".format(self._grank),
            "  Rank MPI communicator = {}".format(self._rcomm),
        ]
        return "<toast.Comm\n{}\n>".format("\n".join(lines))