forked from horovod/horovod
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mpi_ops.cc
460 lines (402 loc) · 14.7 KB
/
mpi_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <queue>
#include <thread>
#include <unordered_map>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#define EIGEN_USE_THREADS
#if HAVE_CUDA
#include "tensorflow/stream_executor/stream.h"
#endif
#define OMPI_SKIP_MPICXX
#include "../common/operations.h"
using namespace tensorflow;
using namespace horovod;
namespace horovod {
namespace tensorflow {
namespace {
// Maps a Horovod common::Status onto the equivalent TensorFlow Status.
// Each recognised status type keeps its reason text; any other type
// collapses to a generic Unknown error.
Status ConvertStatus(common::Status status) {
  const auto kind = status.type();
  if (kind == common::OK) {
    return Status::OK();
  }
  if (kind == common::UNKNOWN_ERROR) {
    return errors::Unknown(status.reason());
  }
  if (kind == common::PRECONDITION_ERROR) {
    return errors::FailedPrecondition(status.reason());
  }
  if (kind == common::ABORTED) {
    return errors::Aborted(status.reason());
  }
  return errors::Unknown("Unknown error.");
}
// Converts a TensorFlow Status into Horovod's common::Status.
//
// OK, UNKNOWN, FAILED_PRECONDITION and ABORTED map onto their direct
// common:: counterparts and keep the original message. No other TensorFlow
// code has a dedicated common:: type, so those codes become an UnknownError.
// That error carries status.ToString() (the TensorFlow code name plus its
// message) instead of a bare placeholder. Without this, errors such as those
// converted from allocate_persistent in TFOpContext::AllocatePersistent
// would reach the user with their real cause erased.
common::Status ConvertStatus(Status status) {
  switch (status.code()) {
  case error::Code::OK:
    return common::Status::OK();
  case error::Code::UNKNOWN:
    return common::Status::UnknownError(status.error_message());
  case error::Code::FAILED_PRECONDITION:
    return common::Status::PreconditionError(status.error_message());
  case error::Code::ABORTED:
    return common::Status::Aborted(status.error_message());
  default:
    // Preserve the code name and message rather than discarding them.
    return common::Status::UnknownError(status.ToString());
  }
}
#if HAVE_CUDA
// GPU-only ReadyEvent: records an event on the op's device-context stream
// when constructed. Ready() reports whether that event is no longer pending.
class TFReadyEvent : public common::ReadyEvent {
public:
  // Explicit: a DeviceContext* should never silently convert to an event.
  explicit TFReadyEvent(DeviceContext* device_context);
  bool Ready() const override;

private:
  std::shared_ptr<perftools::gputools::Event> event_;
};
#endif
// PersistentBuffer backed by a TensorFlow PersistentTensor that is allocated
// through an OpKernelContext.
class TFPersistentBuffer : public common::PersistentBuffer {
public:
  TFPersistentBuffer(OpKernelContext* context, int64_t size);

  // Returns a pointer to the buffer's bytes. `context` must be the
  // TFOpContext of the op that is accessing the buffer.
  const void*
  AccessData(std::shared_ptr<common::OpContext> context) const override;

private:
  std::shared_ptr<PersistentTensor> tensor_;
};
// Adapts a tensorflow::Tensor to Horovod's common::Tensor interface by
// storing a copy of the tensorflow::Tensor object.
class TFTensor : public common::Tensor {
public:
  // Explicit: avoid accidental implicit conversion from tensorflow::Tensor.
  explicit TFTensor(::tensorflow::Tensor& tensor);
  const common::MPIDataType dtype() const override;
  const common::TensorShape shape() const override;
  const void* data() const override;
  int64_t size() const override;

protected:
  ::tensorflow::Tensor tensor_;
};
// Horovod OpContext implementation that allocates buffers and outputs
// through the wrapped (non-owned) OpKernelContext.
class TFOpContext : public common::OpContext {
public:
  // Explicit: an OpKernelContext* should never silently become a context.
  explicit TFOpContext(OpKernelContext* context);
  common::Status AllocatePersistent(
      int64_t size, std::shared_ptr<common::PersistentBuffer>* tensor) override;
  common::Status
  AllocateOutput(common::TensorShape shape,
                 std::shared_ptr<common::Tensor>* tensor) override;
  common::Framework framework() const override;
  OpKernelContext* GetKernelContext() const;

private:
  OpKernelContext* context_ = nullptr;
};
#if HAVE_CUDA
// Creates a StreamExecutor event and records it on `device_context`'s
// stream, so Ready() can later poll whether the work queued on that stream
// before this point has been reached.
TFReadyEvent::TFReadyEvent(DeviceContext* device_context) {
  auto executor = device_context->stream()->parent();
  // Owned by event_ from the moment it is created (no raw owning pointer),
  // so nothing can leak between allocation and hand-off.
  event_ = std::make_shared<perftools::gputools::Event>(executor);
  event_->Init();
  device_context->stream()->ThenRecordEvent(event_.get());
}

// True once polling the recorded event yields any status other than
// kPending.
bool TFReadyEvent::Ready() const {
  return event_->PollForStatus() !=
         perftools::gputools::Event::Status::kPending;
}
#endif
// Allocates a persistent one-dimensional DT_INT8 buffer of `size` elements
// through the kernel context. A failed allocation is reported by throwing
// the TensorFlow Status; TFOpContext::AllocatePersistent catches it.
TFPersistentBuffer::TFPersistentBuffer(OpKernelContext* context, int64_t size)
    : tensor_(std::make_shared<PersistentTensor>()) {
  TensorShape dims;
  dims.AddDim(size);

  // Required out-parameter; the buffer itself is reached later via tensor_.
  Tensor* scratch = nullptr;
  const Status alloc_status =
      context->allocate_persistent(DT_INT8, dims, tensor_.get(), &scratch);
  if (!alloc_status.ok()) {
    throw alloc_status;
  }

#if HAVE_CUDA
  // GPU allocation is asynchronous: block until the op's stream has caught
  // up so the buffer is usable once the constructor returns.
  if (auto* device_context = context->op_device_context()) {
    device_context->stream()->BlockHostUntilDone();
  }
#endif
}
// Returns a pointer to the first byte of the persistent buffer, resolved
// through the kernel context held by `context`.
const void* TFPersistentBuffer::AccessData(
    std::shared_ptr<common::OpContext> context) const {
  // Only TFOpContext creates TFPersistentBuffer, so the context handed back
  // here is expected to be a TFOpContext. Casting the raw pointer avoids an
  // extra shared_ptr copy.
  auto* tf_context = dynamic_cast<TFOpContext*>(context.get());
  const Tensor* buffer =
      tensor_->AccessTensor(tf_context->GetKernelContext());
  return static_cast<const void*>(buffer->tensor_data().data());
}
TFTensor::TFTensor(::tensorflow::Tensor& tensor) : tensor_(tensor) {}
// Translates the wrapped tensor's TensorFlow dtype into Horovod's
// MPIDataType. Throws std::logic_error for any dtype without a mapping.
const common::MPIDataType TFTensor::dtype() const {
  switch (tensor_.dtype()) {
  case DT_UINT8:  return common::HOROVOD_UINT8;
  case DT_INT8:   return common::HOROVOD_INT8;
  case DT_UINT16: return common::HOROVOD_UINT16;
  case DT_INT16:  return common::HOROVOD_INT16;
  case DT_INT32:  return common::HOROVOD_INT32;
  case DT_INT64:  return common::HOROVOD_INT64;
  case DT_FLOAT:  return common::HOROVOD_FLOAT32;
  case DT_DOUBLE: return common::HOROVOD_FLOAT64;
  case DT_BOOL:   return common::HOROVOD_BOOL;
  default:
    break;
  }
  throw std::logic_error("Invalid tensor type.");
}
// Copies the wrapped tensor's dimensions, in order, into a Horovod
// common::TensorShape.
const common::TensorShape TFTensor::shape() const {
  common::TensorShape result;
  const int rank = tensor_.dims();
  for (int axis = 0; axis < rank; ++axis) {
    result.AddDim(tensor_.dim_size(axis));
  }
  return result;
}
// Pointer to the wrapped tensor's raw bytes.
const void* TFTensor::data() const {
  return static_cast<const void*>(tensor_.tensor_data().data());
}

// Size of the wrapped tensor's raw data, in bytes (tensor_data().size()).
int64_t TFTensor::size() const {
  return static_cast<int64_t>(tensor_.tensor_data().size());
}
TFOpContext::TFOpContext(OpKernelContext* context) : context_(context) {}
// Allocates a `size`-element persistent buffer and stores it in *tensor.
// TFPersistentBuffer's constructor throws the TensorFlow Status on failure;
// that Status is converted and returned instead of propagating.
common::Status TFOpContext::AllocatePersistent(
    int64_t size, std::shared_ptr<common::PersistentBuffer>* tensor) {
  try {
    *tensor = std::make_shared<TFPersistentBuffer>(context_, size);
  } catch (const Status& failure) {
    return ConvertStatus(failure);
  }
  return common::Status::OK();
}
// Allocates output 0 of the running op with the given shape and wraps it
// as a common::Tensor in *tensor. On failure, *tensor is left untouched and
// the converted error is returned.
common::Status
TFOpContext::AllocateOutput(common::TensorShape shape,
                            std::shared_ptr<common::Tensor>* tensor) {
  TensorShape tf_shape;
  for (int idx = 0; idx < shape.dims(); ++idx) {
    tf_shape.AddDim(shape.dim_size(idx));
  }
  Tensor* tf_tensor = nullptr;
  Status status = context_->allocate_output(0, tf_shape, &tf_tensor);
  if (!status.ok()) {
    // Nothing was allocated, so there is no stream work to wait for; return
    // without blocking the host on the GPU stream.
    return ConvertStatus(status);
  }
  *tensor = std::make_shared<TFTensor>(*tf_tensor);
#if HAVE_CUDA
  // On GPU allocation is asynchronous, we need to wait for it to
  // complete.
  auto device_context = context_->op_device_context();
  if (device_context != nullptr) {
    device_context->stream()->BlockHostUntilDone();
  }
#endif
  return ConvertStatus(status);
}
// Identifies this context as belonging to the TensorFlow binding.
common::Framework TFOpContext::framework() const { return common::Framework::TENSORFLOW; }
// Exposes the wrapped (non-owned) kernel context.
OpKernelContext* TFOpContext::GetKernelContext() const {
  return context_;
}
// Returns the GPU id of the device the op runs on, or CPU_DEVICE_ID when
// the op has no device or the device reports no GPU info.
int GetDeviceID(OpKernelContext* context) {
  auto* dev = context->device();
  if (dev == nullptr || dev->tensorflow_gpu_device_info() == nullptr) {
    return CPU_DEVICE_ID;
  }
  return dev->tensorflow_gpu_device_info()->gpu_id;
}
// On GPU, records an event on the op's stream that signals when input data
// is ready and previously requested tensors are allocated. Returns nullptr
// when the op has no device context or the build lacks CUDA support. The
// caller owns the returned pointer (callers here wrap it in shared_ptr).
common::ReadyEvent* RecordReadyEvent(OpKernelContext* context) {
#if HAVE_CUDA
  if (auto* device_context = context->op_device_context()) {
    return new TFReadyEvent(device_context);
  }
#endif
  return nullptr;
}
} // namespace
class HorovodAllreduceOp : public AsyncOpKernel {
public:
explicit HorovodAllreduceOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()),
done);
auto node_name = name();
auto device = GetDeviceID(context);
auto tensor = context->input(0);
Tensor* output;
OP_REQUIRES_OK_ASYNC(
context, context->allocate_output(0, tensor.shape(), &output), done);
// ReadyEvent makes sure input tensor is ready, and output is allocated.
auto ready_event = std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context));
auto hvd_context = std::make_shared<TFOpContext>(context);
auto hvd_tensor = std::make_shared<TFTensor>(tensor);
auto hvd_output = std::make_shared<TFTensor>(*output);
auto enqueue_result = EnqueueTensorAllreduce(
hvd_context, hvd_tensor, hvd_output, ready_event, node_name, device,
[context, done](const common::Status& status) {
context->SetStatus(ConvertStatus(status));
done();
});
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done);
}
};
// Kernel registrations for HorovodAllreduce: the CPU kernel is always
// registered, the GPU kernel only when HOROVOD_GPU_ALLREDUCE is set.
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_CPU),
HorovodAllreduceOp);
#if HOROVOD_GPU_ALLREDUCE
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU),
HorovodAllreduceOp);
#endif
// Op definition. T is limited to int32/int64/float32/float64, a narrower
// set than the dtypes TFTensor::dtype() can map.
REGISTER_OP("HorovodAllreduce")
.Attr("T: {int32, int64, float32, float64}")
.Input("tensor: T")
.Output("sum: T")
// The reduced output has exactly the input's shape.
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Perform an MPI Allreduce on a tensor. All other processes that do a reduction
on a tensor with the same name must have the same dimension for that tensor.
Tensors are reduced with other tensors that have the same node name for the
allreduce.
Arguments
tensor: A tensor to reduce.
Output
sum: A tensor with the same shape as `tensor`, summed across all MPI processes.
)doc");
class HorovodAllgatherOp : public AsyncOpKernel {
public:
explicit HorovodAllgatherOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()),
done);
auto node_name = name();
auto device = GetDeviceID(context);
auto tensor = context->input(0);
// ReadyEvent makes sure input tensor is ready. We cannot pre-allocate
// output for allgather, since shape of result is only known after all
// ranks make a request.
auto ready_event = std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context));
auto hvd_context = std::make_shared<TFOpContext>(context);
auto hvd_tensor = std::make_shared<TFTensor>(tensor);
auto enqueue_result = EnqueueTensorAllgather(
hvd_context, hvd_tensor, ready_event, node_name, device,
[context, done](const common::Status& status) {
context->SetStatus(ConvertStatus(status));
done();
});
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done);
}
}; // namespace tensorflow
// Kernel registrations for HorovodAllgather: the CPU kernel is always
// registered, the GPU kernel only when HOROVOD_GPU_ALLGATHER is set.
REGISTER_KERNEL_BUILDER(Name("HorovodAllgather").Device(DEVICE_CPU),
                        HorovodAllgatherOp);
#if HOROVOD_GPU_ALLGATHER
REGISTER_KERNEL_BUILDER(Name("HorovodAllgather").Device(DEVICE_GPU),
                        HorovodAllgatherOp);
#endif
// Op definition. The Doc() entry for the result now uses the declared output
// name `output` (it previously said `gathered`, which names no input, output
// or attr of this op).
REGISTER_OP("HorovodAllgather")
    .Attr(
        "T: {uint8, int8, uint16, int16, int32, int64, float32, float64, bool}")
    .Input("tensor: T")
    .Output("output: T")
    // Same shape as the input except dimension 0, which is unknown until
    // every rank has contributed.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle output;
      TF_RETURN_IF_ERROR(
          c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
      c->set_output(0, output);
      return Status::OK();
    })
    .Doc(R"doc(
Perform an MPI Allgather on a tensor. All other processes that do a gather on a
tensor with the same name must have the same rank for that tensor, and have the
same dimension on all but the first dimension.
Arguments
tensor: A tensor to gather.
Output
output: A tensor with the same shape as `tensor` except for the first dimension.
)doc");
// Async kernel that broadcasts input 0 from `root_rank` to every other
// Horovod process.
class HorovodBroadcastOp : public AsyncOpKernel {
public:
  explicit HorovodBroadcastOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("root_rank", &root_rank_));
    // A negative rank can never name a participating process; fail kernel
    // construction instead of handing it to the collective.
    OP_REQUIRES(context, root_rank_ >= 0,
                errors::InvalidArgument(
                    "root_rank must be non-negative, but got ", root_rank_));
  }

  // On the root rank the input is forwarded as the output and no separate
  // output tensor is passed to Horovod. On other ranks an output shaped like
  // the input is allocated to receive the data. `done` runs from the
  // completion callback, or immediately on an earlier failure.
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()),
                         done);
    auto node_name = name();
    auto device = GetDeviceID(context);
    auto tensor = context->input(0);
    Tensor* output = nullptr;
    if (common::horovod_rank() == root_rank_) {
      context->set_output(0, tensor);
    } else {
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, tensor.shape(), &output), done);
    }
    // ReadyEvent makes sure input tensor is ready, and output is allocated.
    auto ready_event =
        std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context));
    auto hvd_context = std::make_shared<TFOpContext>(context);
    auto hvd_tensor = std::make_shared<TFTensor>(tensor);
    // Stays null on the root rank, where no output was allocated.
    std::shared_ptr<TFTensor> hvd_output = nullptr;
    if (output != nullptr) {
      hvd_output = std::make_shared<TFTensor>(*output);
    }
    auto enqueue_result = EnqueueTensorBroadcast(
        hvd_context, hvd_tensor, hvd_output, root_rank_, ready_event, node_name,
        device, [context, done](const common::Status& status) {
          context->SetStatus(ConvertStatus(status));
          done();
        });
    OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done);
  }

private:
  // Rank whose input is broadcast; read from the "root_rank" attr.
  int root_rank_ = 0;
};
// Kernel registrations for HorovodBroadcast: the CPU kernel is always
// registered, the GPU kernel only when HOROVOD_GPU_BROADCAST is set.
REGISTER_KERNEL_BUILDER(Name("HorovodBroadcast").Device(DEVICE_CPU),
HorovodBroadcastOp);
#if HOROVOD_GPU_BROADCAST
REGISTER_KERNEL_BUILDER(Name("HorovodBroadcast").Device(DEVICE_GPU),
HorovodBroadcastOp);
#endif
// Op definition. `root_rank` is read by HorovodBroadcastOp's constructor.
REGISTER_OP("HorovodBroadcast")
.Attr(
"T: {uint8, int8, uint16, int16, int32, int64, float32, float64, bool}")
.Attr("root_rank: int")
.Input("tensor: T")
.Output("output: T")
// The broadcast output has exactly the input's shape.
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Perform an MPI Broadcast on a tensor. All other processes that do a broadcast
on a tensor with the same name must have the same dimension for that tensor.
Arguments
tensor: A tensor to broadcast.
root_rank: Rank that will send data, other ranks will receive data.
Output
output: A tensor with the same shape as `tensor` and same value as
`tensor` on root rank.
)doc");
} // namespace tensorflow
} // namespace horovod