-
Notifications
You must be signed in to change notification settings - Fork 3.9k
/
topology.py
456 lines (372 loc) · 16.9 KB
/
topology.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
# Copyright 2019 The Microsoft DeepSpeed Team
from deepspeed.utils import logger
from deepspeed import comm as dist
import sys
from collections import namedtuple
from itertools import product as cartesian_product
class ProcessTopology:
""" Manages the mapping of n-dimensional Cartesian coordinates to linear
indices. This mapping is used to map the rank of processes to the grid
for various forms of parallelism.
Each axis of the tensor is accessed by its name. The provided ordering
of the axes defines the layout of the topology. ProcessTopology uses a "row-major"
layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and
(x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates
(x,y) and (x+1,y) would be adjacent.
Some methods return ProcessCoord namedtuples.
"""
def __init__(self, axes, dims):
"""Create a mapping of n-dimensional tensor coordinates to linear indices.
Arguments:
axes (list): the names of the tensor axes
dims (list): the dimension (length) of each axis of the topology tensor
"""
self.axes = axes # names of each topology axis
self.dims = dims # length of each topology axis
# This is actually a class that lets us hash {'row':3, 'col':2} mappings
self.ProcessCoord = namedtuple('ProcessCoord', axes)
self.mapping = {}
ranges = [range(d) for d in dims]
# example: 1, (0,0,1)
for global_rank, coord in enumerate(cartesian_product(*ranges)):
key = {axis: coord[self.axes.index(axis)] for axis in self.axes}
key = self.ProcessCoord(**key)
# for example, {ProcessCoord(row=0, col=1) : 1}
self.mapping[key] = global_rank
def get_rank(self, **coord_kwargs):
"""Return the global rank of a process via its coordinates.
Coordinates are specified as kwargs. For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_rank(x=0, y=1)
1
"""
if len(coord_kwargs) != len(self.axes):
raise ValueError('get_rank() does not support slices. Use filter_match())')
key = self.ProcessCoord(**coord_kwargs)
assert key in self.mapping, f'key {kwargs} invalid'
return self.mapping[key]
def get_axis_names(self):
"""Return a list of the axis names in the ordering of the topology. """
return self.axes
def get_rank_repr(self,
rank,
omit_axes=['data',
'pipe'],
inner_sep='_',
outer_sep='-'):
"""Return a string representation of a rank.
This method is primarily used for checkpointing model data.
For example:
>>> topo = Topo(axes=['a', 'b'], dims=[2, 2])
>>> topo.get_rank_repr(rank=3)
'a_01-b_01'
>>> topo.get_rank_repr(rank=3, omit_axes=['a'])
'b_01'
Args:
rank (int): A rank in the topology.
omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe'].
inner_sep (str, optional): [description]. Defaults to '_'.
outer_sep (str, optional): [description]. Defaults to '-'.
Returns:
str: A string representation of the coordinate owned by ``rank``.
"""
omit_axes = frozenset(omit_axes)
axes = [a for a in self.get_axis_names() if a not in omit_axes]
names = []
for ax in axes:
ax_rank = getattr(self.get_coord(rank=rank), ax)
names.append(f'{ax}{inner_sep}{ax_rank:02d}')
return outer_sep.join(names)
def get_dim(self, axis):
"""Return the number of processes along the given axis.
For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_dim('y')
3
"""
if axis not in self.axes:
return 0
return self.dims[self.axes.index(axis)]
def get_coord(self, rank):
"""Return the coordinate owned by a process rank.
The axes of the returned namedtuple can be directly accessed as members. For
example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> coord = X.get_coord(rank=1)
>>> coord.x
0
>>> coord.y
1
"""
for coord, idx in self.mapping.items():
if idx == rank:
return coord
raise ValueError(f'rank {rank} not found in topology.')
def get_axis_comm_lists(self, axis):
""" Construct lists suitable for a communicator group along axis ``axis``.
Example:
>>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> topo.get_axis_comm_lists('pipe')
[
[0, 4], # data=0, model=0
[1, 5], # data=0, model=1
[2, 6], # data=1, model=0
[3, 7], # data=1, model=1
]
Returns:
A list of lists whose coordinates match in all axes *except* ``axis``.
"""
# We don't want to RuntimeError because it allows us to write more generalized
# code for hybrid parallelisms.
if axis not in self.axes:
return []
# Grab all axes but `axis`
other_axes = [a for a in self.axes if a != axis]
lists = []
# Construct all combinations of coords with other_axes
ranges = [range(self.get_dim(a)) for a in other_axes]
for coord in cartesian_product(*ranges):
other_keys = {a: coord[other_axes.index(a)] for a in other_axes}
# now go over all ranks in `axis`.
sub_list = []
for axis_key in range(self.get_dim(axis)):
key = self.ProcessCoord(**other_keys, **{axis: axis_key})
sub_list.append(self.mapping[key])
lists.append(sub_list)
return lists
def filter_match(self, **filter_kwargs):
"""Return the list of ranks whose coordinates match the provided criteria.
Example:
>>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> X.filter_match(pipe=0, data=1)
[2, 3]
>>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
[ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]
Arguments:
**filter_kwargs (dict): criteria used to select coordinates.
Returns:
The list of ranks whose coordinates match filter_kwargs.
"""
def _filter_helper(x):
for key, val in filter_kwargs.items():
if getattr(x, key) != val:
return False
return True
coords = filter(_filter_helper, self.mapping.keys())
return [self.mapping[coord] for coord in coords]
def get_axis_list(self, axis, idx):
"""Returns the list of global ranks whose coordinate in an axis is idx.
For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_axis_list(axis='x', idx=0)
[0, 1, 2]
>>> X.get_axis_list(axis='y', idx=0)
[0, 3]
"""
# This could be faster by generating the desired keys directly instead of
# filtering.
axis_num = self.axes.index(axis)
ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx]
return ranks
def world_size(self):
return len(self.mapping)
def __str__(self):
return str(self.mapping)
def _prime_factors(N):
""" Returns the prime factorization of positive integer N. """
if N <= 0:
raise ValueError("Values must be strictly positive.")
primes = []
while N != 1:
for candidate in range(2, N + 1):
if N % candidate == 0:
primes.append(candidate)
N //= candidate
break
return primes
class PipeDataParallelTopology(ProcessTopology):
""" A topology specialization for hybrid data and pipeline parallelism.
Uses data parallelism on the last dimension to encourage gradient
reductions to use high-bandwidth intra-node links and lower-volume
pipeline communications to use low-bandwidth inter-node links.
"""
def __init__(self, num_pp, num_dp):
super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])
class PipeModelDataParallelTopology(ProcessTopology):
""" A topology for hybrid pipeline, model, and data parallelism. """
def __init__(self, num_pp, num_mp, num_dp):
super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])
class PipelineParallelGrid:
"""Implements a grid object that stores the data parallel ranks
corresponding to each of the model parallel stages
The grid object organizes the processes in a distributed pytorch job
into a 2D grid, of stage_id and data_parallel_id.
self.stage_id and self.data_parallel_id stores the stage id
and the data parallel id of current process.
self.dp_group groups the processes by stage_id.
self.dp_group[i], is a list containing all process ranks whose
stage_id is i.
self.p2p_groups stores a list of tuple, where each tuple
stores process ranks of adjacent stages for a given data_parallel_id.
For example if num_stage is 5 then a tuple [7,8] represents stages [3, 4],
with data_parallel id = 1. A stage wrap around will appear as non-adjacent ranks,
for example tuple [4,0] with representing wrap-around stage 4 and 0, for
data_parallel_id = 0, or similarly [9,5] represents wrapped around stages [4,0]
for data_parallel_id = 1.
"""
def __init__(self, topology=None, process_group=None):
# TODO use process_group if provided
self.global_rank = dist.get_rank()
self.world_size = dist.get_world_size()
if topology is not None:
if self.global_rank == 0:
print('Using topology:', topology)
self._topo = topology
else:
num_pp = 1
num_dp = 1
for idx, prime in enumerate(_prime_factors(self.world_size)):
if idx % 2 == 0:
num_pp *= prime
else:
num_dp *= prime
self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)
self.data_parallel_size = max(self._topo.get_dim('data'), 1)
self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
self.model_parallel_size = max(self._topo.get_dim('model'), 1)
self.slice_parallel_size = self.model_parallel_size
assert self._is_grid_valid(), "Invalid Grid"
self.stage_id = self.get_stage_id()
self.data_parallel_id = self.get_data_parallel_id()
# Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
# to detect overflow, etc.
self.ds_model_proc_group = None
self.ds_model_rank = -1
for dp in range(self.data_parallel_size):
ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
if self.global_rank == 0:
#print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
pass
proc_group = dist.new_group(ranks=ranks)
if self.global_rank in ranks:
self.ds_model_proc_group = proc_group
self.ds_model_world_size = len(ranks)
self.ds_model_rank = ranks.index(self.global_rank)
assert self.ds_model_rank > -1
assert self.ds_model_proc_group is not None
# Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
self.dp_group = []
self.dp_groups = self._topo.get_axis_comm_lists('data')
for g in self.dp_groups:
proc_group = dist.new_group(ranks=g)
if self.global_rank in g:
self.dp_group = g
self.dp_proc_group = proc_group
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))
self.p2p_groups = self._build_p2p_groups()
# Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
self.pp_group = []
self.pp_proc_group = None
self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
for ranks in self.pipe_groups:
if self.global_rank == 0:
#print(f'RANK={self.global_rank} building pipeline group: {ranks}')
pass
proc_group = dist.new_group(ranks=ranks)
if self.global_rank in ranks:
self.pp_group = ranks
self.pp_proc_group = proc_group
assert self.pp_proc_group is not None
# Create new ProcessGroup for model (tensor-slicing) collectives
# Short circuit case without model parallelism.
# TODO: it would be nice if topology had bcast semantics to avoid this branching
# case?
if self.model_parallel_size == 1:
for group_rank in range(self.world_size):
group_rank = [group_rank]
group = dist.new_group(ranks=group_rank)
if group_rank[0] == self.global_rank:
self.slice_group = group_rank
self.slice_proc_group = group
return
else:
self.mp_group = []
self.model_groups = self._topo.get_axis_comm_lists('model')
for g in self.model_groups:
proc_group = dist.new_group(ranks=g)
if self.global_rank in g:
self.slice_group = g
self.slice_proc_group = proc_group
def get_stage_id(self):
return self._topo.get_coord(rank=self.global_rank).pipe
def get_data_parallel_id(self):
return self._topo.get_coord(rank=self.global_rank).data
def _build_p2p_groups(self):
"""Groups for sending and receiving activations and gradients across model
parallel stages.
"""
comm_lists = self._topo.get_axis_comm_lists('pipe')
p2p_lists = []
for rank in range(self.world_size):
for l in comm_lists:
assert len(l) == self.pipe_parallel_size
if rank in l:
idx = l.index(rank)
buddy_rank = l[(idx + 1) % self.pipe_parallel_size]
p2p_lists.append([rank, buddy_rank])
break # next global rank
assert len(p2p_lists) == self.world_size
return p2p_lists
def _is_grid_valid(self):
ranks = 1
for ax in self._topo.get_axis_names():
ranks *= self._topo.get_dim(ax)
return ranks == dist.get_world_size()
#returns the global rank of the process with the provided stage id
#which has the same data_parallel_id as caller process
def stage_to_global(self, stage_id, **kwargs):
me = self._topo.get_coord(self.global_rank)
transform = me._replace(pipe=stage_id, **kwargs)._asdict()
return self._topo.get_rank(**transform)
def topology(self):
return self._topo
# MPU functions for DeepSpeed integration
def get_global_rank(self):
return self.global_rank
def get_pipe_parallel_rank(self):
""" The stage of the pipeline this rank resides in. """
return self.get_stage_id()
def get_pipe_parallel_world_size(self):
""" The number of stages in the pipeline. """
return self.pipe_parallel_size
def get_pipe_parallel_group(self):
""" The group of ranks within the same pipeline. """
return self.pp_proc_group
def get_data_parallel_rank(self):
""" Which pipeline this rank resides in. """
return self.data_parallel_id
def get_data_parallel_world_size(self):
""" The number of pipelines. """
return self.data_parallel_size
def get_data_parallel_group(self):
""" The group of ranks within the same stage of all pipelines. """
return self.dp_proc_group
# These are model parallel groups across all types of model parallelism.
# Deepspeed uses them to detect overflow, etc.
def get_model_parallel_rank(self):
return self.ds_model_rank
def get_model_parallel_world_size(self):
return self.ds_model_world_size
def get_model_parallel_group(self):
return self.ds_model_proc_group
# For Megatron-style tensor slicing
def get_slice_parallel_rank(self):
if 'model' in self._topo.get_axis_names():
return self._topo.get_coord(rank=self.global_rank).model
else:
return 0
def get_slice_parallel_world_size(self):
return self.slice_parallel_size
def get_slice_parallel_group(self):
return self.slice_proc_group