forked from horovod/horovod
/
mpi_ops.py
406 lines (310 loc) · 14.6 KB
/
mpi_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Load all the necessary PyTorch C types.
import torch
from horovod.torch import mpi_lib_impl
from horovod.torch import mpi_lib
from horovod.torch import rank, size
# Registry of in-flight operations. Schema: handle -> (input, output).
# Entries are added by the `_*_async` helpers and removed by `synchronize()`.
# We keep input in order to make sure it does not get garbage collected
# before the operation is finished.
_handle_map = {}
# Null pointer, passed to the native functions in place of the operation name
# when the caller did not provide one.
_NULL = mpi_lib._ffi.NULL
def _check_function(function_factory, tensor):
    """
    Resolves the name of the native function for `tensor` and validates that
    the tensor can be passed to it.

    Arguments:
        function_factory: A callable that maps a tensor to the name of the
                          `mpi_lib` function handling that tensor's type.
        tensor: The tensor the operation is going to be applied to.

    Returns:
        The name of the `mpi_lib` attribute to call.

    Raises:
        ValueError: If `mpi_lib` has no attribute with the resolved name, or
                    if the tensor is not contiguous.
    """
    function_name = function_factory(tensor)

    # The type check comes first so an unsupported type is reported even when
    # the tensor is also non-contiguous.
    supported = hasattr(mpi_lib, function_name)
    if not supported:
        raise ValueError('Tensor type %s is not supported.' % tensor.type())

    if tensor.is_contiguous():
        return function_name
    raise ValueError('Tensor is required to be contiguous.')
def _allreduce_function_factory(tensor):
    """
    Builds the name of the native asynchronous allreduce function for the
    type of `tensor`: the result of `tensor.type()` with dots replaced by
    underscores, appended to a fixed prefix.
    """
    type_suffix = '_'.join(tensor.type().split('.'))
    return 'horovod_torch_allreduce_async_' + type_suffix
def _allreduce_async(tensor, output, average, name):
    """
    Starts a native asynchronous allreduce of `tensor` into `output`.

    Arguments:
        tensor: The input tensor.
        output: The tensor that receives the result; may be `tensor` itself.
        average: A flag forwarded to the native function.
        name: A name of the reduction operation, or None.

    Returns:
        The handle returned by the native function. The (tensor, output) pair
        is recorded in `_handle_map` under that handle.
    """
    function_name = _check_function(_allreduce_function_factory, tensor)
    native_allreduce = getattr(mpi_lib, function_name)

    # The native call takes an encoded name, or a null pointer when no name
    # was given.
    if name is None:
        native_name = _NULL
    else:
        native_name = name.encode()

    handle = native_allreduce(tensor, output, average, native_name)
    _handle_map[handle] = (tensor, output)
    return handle
def allreduce_async(tensor, average=True, name=None):
    """
    A function that performs asynchronous averaging or summation of the input
    tensor over all the Horovod processes. The input tensor is not modified.

    The reduction operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The reduction will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to average or sum.
        average: A flag indicating whether to compute average or summation,
                 defaults to average.
        name: A name of the reduction operation.

    Returns:
        A handle to the allreduce operation that can be used with `poll()` or
        `synchronize()`.
    """
    # A separate buffer of the same type and shape receives the result, so
    # the caller's tensor is left untouched.
    result_buffer = tensor.new(tensor.shape)
    return _allreduce_async(tensor, result_buffer, average, name)
class HorovodAllreduce(torch.autograd.Function):
    """An autograd function that performs allreduce on a tensor."""

    @staticmethod
    def forward(ctx, tensor, average, name):
        # Remember the reduction mode so backward() reduces the gradient the
        # same way.
        ctx.average = average
        return synchronize(allreduce_async(tensor, average, name))

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = allreduce(grad_output, ctx.average)
        # One slot per forward() argument; `average` and `name` get no
        # gradient.
        return grad_input, None, None
def allreduce(tensor, average=True, name=None):
    """
    A function that performs averaging or summation of the input tensor over
    all the Horovod processes. The input tensor is not modified.

    The reduction operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The reduction will not
    start until all processes are ready to send and receive the tensor.

    This acts as a thin wrapper around an autograd function. If your input
    tensor requires gradients, then calling this function will allow gradients
    to be computed and backpropagated.

    Arguments:
        tensor: A tensor to average or sum.
        average: A flag indicating whether to compute average or summation,
                 defaults to average.
        name: A name of the reduction operation.

    Returns:
        A tensor of the same shape and type as `tensor`, averaged or summed
        across all processes.
    """
    reduce_op = HorovodAllreduce.apply
    return reduce_op(tensor, average, name)
def allreduce_async_(tensor, average=True, name=None):
    """
    A function that performs asynchronous in-place averaging or summation of
    the input tensor over all the Horovod processes.

    The reduction operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The reduction will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to average or sum.
        average: A flag indicating whether to compute average or summation,
                 defaults to average.
        name: A name of the reduction operation.

    Returns:
        A handle to the allreduce operation that can be used with `poll()` or
        `synchronize()`.
    """
    # In-place: the input tensor doubles as the output buffer.
    handle = _allreduce_async(tensor, tensor, average, name)
    return handle
def allreduce_(tensor, average=True, name=None):
    """
    A function that performs in-place averaging or summation of the input
    tensor over all the Horovod processes.

    The reduction operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The reduction will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to average or sum.
        average: A flag indicating whether to compute average or summation,
                 defaults to average.
        name: A name of the reduction operation.

    Returns:
        A tensor of the same shape and type as `tensor`, averaged or summed
        across all processes.
    """
    # Start the in-place operation and wait for it to complete.
    return synchronize(allreduce_async_(tensor, average, name))
def _allgather_function_factory(tensor):
    """
    Builds the name of the native asynchronous allgather function for the
    type of `tensor`: the result of `tensor.type()` with dots replaced by
    underscores, appended to a fixed prefix.
    """
    type_suffix = '_'.join(tensor.type().split('.'))
    return 'horovod_torch_allgather_async_' + type_suffix
def _allgather_async(tensor, output, name):
    """
    Starts a native asynchronous allgather of `tensor` into `output`.

    Arguments:
        tensor: The input tensor.
        output: The tensor that receives the gathered result.
        name: A name of the allgather operation, or None.

    Returns:
        The handle returned by the native function. The (tensor, output) pair
        is recorded in `_handle_map` under that handle.
    """
    function_name = _check_function(_allgather_function_factory, tensor)
    native_allgather = getattr(mpi_lib, function_name)

    # The native call takes an encoded name, or a null pointer when no name
    # was given.
    if name is None:
        native_name = _NULL
    else:
        native_name = name.encode()

    handle = native_allgather(tensor, output, native_name)
    _handle_map[handle] = (tensor, output)
    return handle
def allgather_async(tensor, name=None):
    """
    A function that asynchronously concatenates the input tensor with the same
    input tensor on all other Horovod processes. The input tensor is not
    modified.

    The concatenation is done on the first dimension, so the input tensors on
    the different processes must have the same rank and shape, except for the
    first dimension, which is allowed to be different.

    Arguments:
        tensor: A tensor to allgather.
        name: A name of the allgather operation.

    Returns:
        A handle to the allgather operation that can be used with `poll()` or
        `synchronize()`.
    """
    # Unlike allreduce, the output is created empty rather than with the
    # input's shape.
    gathered = tensor.new()
    return _allgather_async(tensor, gathered, name)
class HorovodAllgather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, name):
        # Remember the size of this process's first dimension; backward()
        # uses it to cut this process's slice out of the gathered gradient.
        ctx.dim = tensor.shape[0]
        handle = allgather_async(tensor, name)
        return synchronize(handle)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient of the gathered tensor over all processes.
        grad_reduced = allreduce(grad_output, average=False)

        # Collect the first-dimension size contributed by every process.
        dim_t = torch.IntTensor([ctx.dim])
        dim = allgather(dim_t).view(size())

        # This process's slice starts after the rows of all lower ranks.
        # int() replaces the former `.data[0]`: the sum is a 0-dim tensor on
        # PyTorch >= 0.4, which cannot be indexed with [0], and narrow()
        # needs a plain Python int for the offset.
        r = rank()
        offset = int(torch.sum(dim.narrow(0, 0, r))) if r != 0 else 0

        # One slot per forward() argument; `name` gets no gradient.
        return grad_reduced.narrow(0, offset, ctx.dim), None
def allgather(tensor, name=None):
    """
    A function that concatenates the input tensor with the same input tensor
    on all other Horovod processes. The input tensor is not modified.

    The concatenation is done on the first dimension, so the input tensors on
    the different processes must have the same rank and shape, except for the
    first dimension, which is allowed to be different.

    This acts as a thin wrapper around an autograd function. If your input
    tensor requires gradients, then calling this function will allow gradients
    to be computed and backpropagated.

    Arguments:
        tensor: A tensor to allgather.
        name: A name of the allgather operation.

    Returns:
        A tensor of the same type as `tensor`, concatenated on dimension zero
        across all processes. The shape is identical to the input shape,
        except for the first dimension, which may be greater and is the sum of
        all first dimensions of the tensors in different Horovod processes.
    """
    gather_op = HorovodAllgather.apply
    return gather_op(tensor, name)
def _broadcast_function_factory(tensor):
    """
    Builds the name of the native asynchronous broadcast function for the
    type of `tensor`: the result of `tensor.type()` with dots replaced by
    underscores, appended to a fixed prefix.
    """
    type_suffix = '_'.join(tensor.type().split('.'))
    return 'horovod_torch_broadcast_async_' + type_suffix
def _broadcast_async(tensor, output, root_rank, name):
    """
    Starts a native asynchronous broadcast of `tensor` into `output`.

    Arguments:
        tensor: The input tensor.
        output: The tensor that receives the result; may be `tensor` itself.
        root_rank: The rank to broadcast the value from.
        name: A name of the broadcast operation, or None.

    Returns:
        The handle returned by the native function. The (tensor, output) pair
        is recorded in `_handle_map` under that handle.
    """
    function_name = _check_function(_broadcast_function_factory, tensor)
    native_broadcast = getattr(mpi_lib, function_name)

    # The native call takes an encoded name, or a null pointer when no name
    # was given.
    if name is None:
        native_name = _NULL
    else:
        native_name = name.encode()

    handle = native_broadcast(tensor, output, root_rank, native_name)
    _handle_map[handle] = (tensor, output)
    return handle
def broadcast_async(tensor, root_rank, name=None):
    """
    A function that asynchronously broadcasts the input tensor on root rank to
    the same input tensor on all other Horovod processes. The input tensor is
    not modified.

    The broadcast operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The broadcast will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to broadcast.
        root_rank: The rank to broadcast the value from.
        name: A name of the broadcast operation.

    Returns:
        A handle to the broadcast operation that can be used with `poll()` or
        `synchronize()`.
    """
    # A separate buffer of the same type and shape receives the result, so
    # the caller's tensor is left untouched.
    result_buffer = tensor.new(tensor.shape)
    return _broadcast_async(tensor, result_buffer, root_rank, name)
class HorovodBroadcast(torch.autograd.Function):
    """An autograd function that broadcasts a tensor."""

    @staticmethod
    def forward(ctx, tensor, root_rank, name):
        # backward() needs to know which rank was the source of the value.
        ctx.root_rank = root_rank
        return synchronize(broadcast_async(tensor, root_rank, name))

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradients from every process.
        grad_reduced = allreduce(grad_output, average=False)

        # Only the root rank keeps the summed gradient; on every other rank
        # it is zeroed in place.
        is_root = rank() == ctx.root_rank
        if not is_root:
            grad_reduced *= 0

        # One slot per forward() argument; `root_rank` and `name` get no
        # gradient.
        return grad_reduced, None, None
def broadcast(tensor, root_rank, name=None):
    """
    A function that broadcasts the input tensor on root rank to the same input
    tensor on all other Horovod processes. The input tensor is not modified.

    The broadcast operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The broadcast will not
    start until all processes are ready to send and receive the tensor.

    This acts as a thin wrapper around an autograd function. If your input
    tensor requires gradients, then calling this function will allow gradients
    to be computed and backpropagated.

    Arguments:
        tensor: A tensor to broadcast.
        root_rank: The rank to broadcast the value from.
        name: A name of the broadcast operation.

    Returns:
        A tensor of the same shape and type as `tensor`, with the value
        broadcasted from root rank.
    """
    broadcast_op = HorovodBroadcast.apply
    return broadcast_op(tensor, root_rank, name)
def broadcast_async_(tensor, root_rank, name=None):
    """
    A function that asynchronously broadcasts the input tensor on root rank to
    the same input tensor on all other Horovod processes. The operation is
    performed in-place.

    The broadcast operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The broadcast will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to broadcast.
        root_rank: The rank to broadcast the value from.
        name: A name of the broadcast operation.

    Returns:
        A handle to the broadcast operation that can be used with `poll()` or
        `synchronize()`.
    """
    # In-place: the input tensor doubles as the output buffer.
    handle = _broadcast_async(tensor, tensor, root_rank, name)
    return handle
def broadcast_(tensor, root_rank, name=None):
    """
    A function that broadcasts the input tensor on root rank to the same input
    tensor on all other Horovod processes. The operation is performed
    in-place.

    The broadcast operation is keyed by the name. If name is not provided, an
    incremented auto-generated name is used. The tensor type and shape must be
    the same on all Horovod processes for a given name. The broadcast will not
    start until all processes are ready to send and receive the tensor.

    Arguments:
        tensor: A tensor to broadcast.
        root_rank: The rank to broadcast the value from.
        name: A name of the broadcast operation.

    Returns:
        A tensor of the same shape and type as `tensor`, with the value
        broadcasted from root rank.
    """
    # Start the in-place operation and wait for it to complete.
    return synchronize(broadcast_async_(tensor, root_rank, name))
def poll(handle):
    """
    Polls an allreduce, allgather or broadcast handle to determine whether the
    underlying asynchronous operation has completed. After `poll()` returns
    `True`, `synchronize()` will return without blocking.

    Arguments:
        handle: A handle returned by an allreduce, allgather or broadcast
                asynchronous operation.

    Returns:
        A flag indicating whether the operation has completed.
    """
    # The native call reports status as an integer; any non-zero value means
    # the operation is done.
    status = mpi_lib.horovod_torch_poll(handle)
    return status != 0
def synchronize(handle):
    """
    Synchronizes an asynchronous allreduce, allgather or broadcast operation
    until it's completed. Returns the result of the operation.

    Arguments:
        handle: A handle returned by an allreduce, allgather or broadcast
                asynchronous operation.

    Returns:
        An output tensor of the operation, or None if the handle is not
        registered in `_handle_map`.
    """
    if handle in _handle_map:
        mpi_lib.horovod_torch_wait_and_clear(handle)
        # Removing the entry also drops the reference that kept the input
        # tensor alive; only the output is handed back.
        output = _handle_map.pop(handle)[1]
        return output
    return None