/
xla_mpi_ops.cc
575 lines (504 loc) · 20.4 KB
/
xla_mpi_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
// Copyright 2021 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <queue>
#include <thread>
#include <unordered_map>
#if TENSORFLOW_VERSION >= 2005000000
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/human_readable_json.h"
#if HAVE_GPU
#if HAVE_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
// Evaluates a CUDA runtime call exactly once and CHECK-fails unless it
// returns cudaSuccess or cudaErrorCudartUnloading. The latter (the CUDA
// runtime is shutting down) is tolerated, presumably so that calls made during
// process teardown do not abort.
//
// The body is wrapped in do { } while (0) so that `CUDA_CALL(f);` is a single
// statement and composes safely with unbraced if/else. The local has a
// deliberately unusual name so it cannot capture a name used inside `func`.
#define CUDA_CALL(func)                                                   \
  do {                                                                    \
    cudaError_t hvd_cuda_call_err_ = (func);                              \
    CHECK(hvd_cuda_call_err_ == cudaSuccess ||                            \
          hvd_cuda_call_err_ == cudaErrorCudartUnloading)                 \
        << "CUDA: " << cudaGetErrorString(hvd_cuda_call_err_);            \
  } while (0)
#define OMPI_SKIP_MPICXX
#include "../common/operations.h"
#include "../common/utils/env_parser.h"
#include "./custom_call_config_generated.h"
using namespace tensorflow;
namespace horovod {
namespace xla {
namespace {
// Translates an XLA element type into the corresponding Horovod DataType.
// Element types that have no entry below (for example U32, U64 or BF16)
// raise std::logic_error.
common::DataType GetHVDType(::xla::PrimitiveType xla_type) {
  switch (xla_type) {
    case ::xla::PRED: return common::HOROVOD_BOOL;
    case ::xla::U8:   return common::HOROVOD_UINT8;
    case ::xla::S8:   return common::HOROVOD_INT8;
    case ::xla::U16:  return common::HOROVOD_UINT16;
    case ::xla::S16:  return common::HOROVOD_INT16;
    case ::xla::S32:  return common::HOROVOD_INT32;
    case ::xla::S64:  return common::HOROVOD_INT64;
    case ::xla::F16:  return common::HOROVOD_FLOAT16;
    case ::xla::F32:  return common::HOROVOD_FLOAT32;
    case ::xla::F64:  return common::HOROVOD_FLOAT64;
    default:
      break;
  }
  throw std::logic_error("Invalid XLA tensor type.");
}
// CustomCallConfig stores the configuration of a Horovod op. The config is
// passed to ::xla::CustomCall as its opaque payload so that a single XLA
// CustomCall can represent the various Horovod ops. FlatBuffers is used to
// serialize the config into a string, as the XLA CustomCall interface expects
// the opaque payload to be a string.
class CustomCallConfig {
 public:
  // Serializes every field below into a FlatBuffers-encoded string.
  std::string SerializeToString();
  // Populates the fields from a string produced by SerializeToString().
  void ParseFromString(std::string);

 public:
  std::string tensor_name_;
  common::DataType tensor_type_{};
  std::vector<std::vector<int64_t>> input_shapes_;
  std::vector<std::vector<int64_t>> output_shapes_;
  // Scale factors default to 1.0, i.e. no scaling.
  float prescale_factor_ = 1.0f;
  float postscale_factor_ = 1.0f;
  // Every numeric field has a default so that SerializeToString() never reads
  // an indeterminate value for a field that a particular op leaves unset.
  int root_rank_ = 0;
  int reduce_op_ = 0;
  int process_set_id_ = 0;
};
std::string CustomCallConfig::SerializeToString() {
flatbuffers::FlatBufferBuilder fbb(1024);
std::vector<flatbuffers::Offset<wire::TensorShape>> input_shapes_obj;
absl::c_for_each(input_shapes_, [&](const std::vector<int64_t>& dims) {
input_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims));
});
std::vector<flatbuffers::Offset<wire::TensorShape>> output_shapes_obj;
absl::c_for_each(output_shapes_, [&](const std::vector<int64_t>& dims) {
output_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims));
});
auto wire = wire::CreateCustomCallConfigDirect(
fbb, tensor_name_.c_str(), (common::wire::DataType)tensor_type_,
&input_shapes_obj, &output_shapes_obj, prescale_factor_,
postscale_factor_, root_rank_, reduce_op_, process_set_id_);
fbb.Finish(wire);
uint8_t* buf = fbb.GetBufferPointer();
auto size = fbb.GetSize();
return std::string((char*)buf, size);
}
void CustomCallConfig::ParseFromString(std::string input) {
const wire::CustomCallConfig* obj =
flatbuffers::GetRoot<wire::CustomCallConfig>(
(const uint8_t*)input.data());
tensor_name_ = obj->tensor_name()->str();
tensor_type_ = (common::DataType)obj->tensor_type();
for (auto it = obj->input_shapes()->begin(); it != obj->input_shapes()->end();
it++) {
auto shape_obj = *it;
input_shapes_.push_back(std::vector<int64_t>(shape_obj->dims()->begin(),
shape_obj->dims()->end()));
}
for (auto it = obj->output_shapes()->begin();
it != obj->output_shapes()->end(); it++) {
auto shape_obj = *it;
output_shapes_.push_back(std::vector<int64_t>(shape_obj->dims()->begin(),
shape_obj->dims()->end()));
}
prescale_factor_ = obj->prescale_factor();
postscale_factor_ = obj->postscale_factor();
root_rank_ = obj->root_rank();
reduce_op_ = obj->reduce_op();
process_set_id_ = obj->process_set_id();
if (VLOG_IS_ON(2)) {
VLOG(2) << "tensor_name " << tensor_name_;
VLOG(2) << "tensor_type " << tensor_type_;
VLOG(2) << "prescale_factor = " << prescale_factor_;
VLOG(2) << "postscale_factor = " << postscale_factor_;
VLOG(2) << "root_rank = " << root_rank_;
VLOG(2) << "reduce_op = " << reduce_op_;
VLOG(2) << "process_set_id = " << process_set_id_;
}
}
// HVDAllreduceOp is an XLAOpKernel that lowers the Tensorflow HorovodAllreduce
// op into XLA HLOs. The overall idea is to lower an Tensorflow op into two
// corresponding HLO custom-calls, `start` and `end` calls, so that the XLA can
// asynchronously interact with the Horovod runtime. The `start` call is always
// non-blocking for latency hiding and the `end` call could be blocking. For
// example, as shown in HVDAllreduceOp::Compile() below, the "HorovodAllreduce"
// op is lowered into the "CallbackHVDAllreduce" and "CallbackHVDAllreduceDone"
// HLO custom-calls, whose implementations are also provided through dynamic
// registration in this file.
class HVDAllreduceOp : public XlaOpKernel {
public:
explicit HVDAllreduceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_op", &reduce_op_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("prescale_factor", &prescale_factor_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("postscale_factor", &postscale_factor_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_name_scope", &ignore_name_scope_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("process_set_id", &process_set_id_));
}
void Compile(XlaOpKernelContext* ctx) override {
node_name_ = name();
if (ignore_name_scope_) {
auto pos = node_name_.find_last_of('/');
if (pos != std::string::npos) {
node_name_ = node_name_.substr(pos + 1);
}
}
// Generate below HLOs:
// start = custom-call(in), custom_call_target="CallbackHVDAllreduce"
// end = custom-call(start),
// custom_call_target="CallbackHVDAllreduceDone"
// Note that tensors `in`, `start`, and `end'` are aliased, as we want the
// all-reduce operation to be in-place.
::xla::XlaBuilder* const b = ctx->builder();
// First, generate HVDAllreduce.
std::vector<
std::pair<::xla::ShapeIndex, std::pair<int64, ::xla::ShapeIndex>>>
output_operand_aliasing = {
{::xla::ShapeIndex{}, {0, ::xla::ShapeIndex{}}}};
::xla::XlaOp input = ctx->Input(0);
::xla::XlaOp allreduce_start = b->ReportErrorOrReturn(
BuildAllreduceCustomCall(b, {input}, /*is_start=*/true));
// Then, generate HVDAllreduceDone.
::xla::XlaOp allreduce_end = b->ReportErrorOrReturn(
BuildAllreduceCustomCall(b, {allreduce_start},
/*is_start=*/false, output_operand_aliasing));
ctx->SetOutput(0, allreduce_end);
return;
}
private:
::xla::StatusOr<::xla::XlaOp> BuildAllreduceCustomCall(
::xla::XlaBuilder* b, absl::Span<const ::xla::XlaOp> operands,
bool is_start,
absl::Span<const std::pair<::xla::ShapeIndex,
std::pair<int64, ::xla::ShapeIndex>>>
output_operand_aliasing = {});
private:
std::string node_name_;
int reduce_op_;
// Using float since TF does not support double OP attributes
float prescale_factor_;
float postscale_factor_;
bool ignore_name_scope_;
int process_set_id_;
};
// Implements a customized registrar so that the registration is an opt-in,
// controlled by HOROVOD_ENABLE_XLA_OPS (the environment check happens in
// HVDXlaOpRegistrar's constructor).
//
// HVD_REGISTER_XLA_OP(NAME, OP) defines a file-scope static HVDXlaOpRegistrar
// whose factory constructs a new OP from the OpKernelConstruction it is given.
// The extra HVD_REGISTER_XLA_OP_UNIQ_HELPER level exists so that __COUNTER__
// is macro-expanded before it is token-pasted (##) into the object name,
// giving each use a unique identifier.
#define HVD_REGISTER_XLA_OP(NAME, OP) \
HVD_REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
#define HVD_REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, OP_NAME, OP) \
HVD_REGISTER_XLA_OP_UNIQ(COUNTER, OP_NAME, OP)
// Note: this expansion already ends in ';', so a use written as
// `HVD_REGISTER_XLA_OP(...);` leaves an empty declaration at namespace scope.
#define HVD_REGISTER_XLA_OP_UNIQ(CTR, OP_NAME, OP) \
static HVDXlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
OP_NAME, [](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new OP(context); });
// Registers an XLA kernel for `op_name` with TensorFlow's XlaOpRegistry, but
// only if the HOROVOD_ENABLE_XLA_OPS environment variable enables it (as
// parsed by common::SetBoolFromEnv). Instances are created at static
// initialization time by HVD_REGISTER_XLA_OP.
class HVDXlaOpRegistrar {
 public:
  HVDXlaOpRegistrar(string op_name,
                    ::tensorflow::XlaOpRegistry::Factory factory) {
    bool enable_xla_ops = false;
    common::SetBoolFromEnv(HOROVOD_ENABLE_XLA_OPS, enable_xla_ops, true);
    if (enable_xla_ops) {
      xla_op_registrar_ = new XlaOpRegistrar(
          ::tensorflow::XlaOpRegistrationBuilder::Name(op_name).Build(factory));
    }
  }

 private:
  // Null when XLA ops are disabled. The registrar is never deleted, so the
  // registration stays in place for the life of the process.
  XlaOpRegistrar* xla_op_registrar_ = nullptr;
};

HVD_REGISTER_XLA_OP("HorovodAllreduce", HVDAllreduceOp);
// Builds one Horovod allreduce HLO custom-call.
//
// `is_start` selects the custom-call target: the non-blocking
// "CallbackHVDAllreduce" or the blocking "CallbackHVDAllreduceDone". The op
// configuration (tensor name, shapes, dtype, scaling, reduce op, process set)
// is serialized into the custom-call's opaque string. The output shape and
// element type are taken from the first operand. `output_operand_aliasing` is
// forwarded to ::xla::CustomCall unchanged. Returns an error status if an
// operand's shape cannot be determined.
::xla::StatusOr<::xla::XlaOp> HVDAllreduceOp::BuildAllreduceCustomCall(
    ::xla::XlaBuilder* b, absl::Span<const ::xla::XlaOp> operands,
    bool is_start,
    absl::Span<
        const std::pair<::xla::ShapeIndex, std::pair<int64, ::xla::ShapeIndex>>>
        output_operand_aliasing) {
  const string call_target_name =
      is_start ? "CallbackHVDAllreduce" : "CallbackHVDAllreduceDone";
  CustomCallConfig config;
  config.tensor_name_ = node_name_;
  for (const ::xla::XlaOp& opnd : operands) {
    TF_ASSIGN_OR_RETURN(::xla::Shape shape, b->GetShape(opnd));
    config.input_shapes_.push_back(std::vector<int64_t>(
        shape.dimensions().begin(), shape.dimensions().end()));
  }
  TF_ASSIGN_OR_RETURN(::xla::Shape output_shape, b->GetShape(operands.at(0)));
  config.output_shapes_.push_back(std::vector<int64_t>(
      output_shape.dimensions().begin(), output_shape.dimensions().end()));
  config.tensor_type_ = GetHVDType(output_shape.element_type());
  config.prescale_factor_ = prescale_factor_;
  config.postscale_factor_ = postscale_factor_;
  config.reduce_op_ = reduce_op_;
  // Allreduce has no root rank, but the field is serialized, so give it a
  // defined value instead of leaving it indeterminate.
  config.root_rank_ = 0;
  config.process_set_id_ = process_set_id_;
  return ::xla::CustomCall(
      b, call_target_name, operands, output_shape, config.SerializeToString(),
      /*has_side_effect=*/false, output_operand_aliasing, /*literal=*/nullptr,
      // Special schedule hints are given so that XLA knows how to schedule
      // the opaque custom-calls for performance: start as early and finish
      // as late as possible.
      is_start ? ::xla::CustomCallSchedule::EARLIEST
               : ::xla::CustomCallSchedule::LATEST);
}
// Returns the 64-bit hash of `key` used to index the rendezvous table.
// Distinct keys that happen to collide would share a rendezvous queue.
uint64 GetRendezvousKeyHash(const string& key) {
  // Hash the caller's bytes directly; no intermediate copy is needed.
  return Hash64(key.data(), key.size());
}
// Implements a rendezvous to coordinate the `start` and `end` HLO callbacks.
class HVDCustomCallRendezvous {
public:
struct Payload {
std::shared_ptr<gpuEvent_t> event;
};
// This `Signal` method places payload to be consumed by Wait().
//
// Requirement: tensor_name shall be unique in a graph.
void Signal(string tensor_name, common::Event hvd_event) {
// Use `tensor_name` to generate a hash value to retrieve the queue.
uint64 key_hash = GetRendezvousKeyHash(tensor_name);
mutex_lock l(mu_);
InitQueue(key_hash);
Queue& queue = *table_[key_hash];
if (queue.empty() || queue.front() != nullptr) {
// No earlier waiters are waiting, so simply push a payload in the back.
queue.push_back(new Payload{hvd_event.event});
return;
}
// There is an earlier waiter to consume this signal. Place payload
// at the front of the queue where the waiter is polling.
CHECK(nullptr == queue.front());
queue.front() = new Payload{hvd_event.event};
}
// The `Wait` method consumes Payloads. We assume there is at most one
// outstanding `Wait` call due to its blocking nature to simplify the
// implementation. Consequently, this method always operates on the very
// first item in the queue.
void Wait(string tensor_name, CUstream stream) {
uint64 key_hash = GetRendezvousKeyHash(tensor_name);
{
mutex_lock l(mu_);
InitQueue(key_hash);
Queue& queue = *table_[key_hash];
if (queue.empty()) {
// So long as the queue is empty, place a NULL payload. Then waiting for
// Signal() to place the payload below.
queue.push_back(nullptr);
}
}
auto has_available_signal = [&]() {
mutex_lock l(mu_);
Queue& queue = *table_[key_hash];
return nullptr != queue.front();
};
while (!has_available_signal()) {
// Busy waiting. As we don't anticipate the blocking occurs frequently,
// this busy waiting should be fine. If this creates any performance
// overhead, we may implement conditional var wait.
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
}
mutex_lock l(mu_);
Queue* queue = table_[key_hash];
Payload* payload = queue->front();
std::shared_ptr<gpuEvent_t> event = payload->event;
queue->pop_front();
if (queue->empty()) {
table_.erase(key_hash);
delete queue;
}
if (event) {
CUDA_CALL(cudaStreamWaitEvent(stream, *event, /*flags=*/0));
}
delete payload;
}
private:
// This method is not thread-safe.
void InitQueue(uint64 key_hash) {
auto it = table_.find(key_hash);
if (it == table_.end()) {
table_[key_hash] = new Queue();
}
}
private:
// `nullptr` denotes non-readiness of the payload.
typedef std::deque<Payload*> Queue;
// maps a hash value to queue. We will use tensor_names to generate the hash
// values.
typedef absl::flat_hash_map<uint64, Queue*> Table;
mutex mu_;
Table table_ GUARDED_BY(mu_);
};
// Returns the process-wide rendezvous shared by the start/done callbacks.
// It is created on first use and never destroyed.
HVDCustomCallRendezvous* GetHVDCustomCallRendezvous() {
  static HVDCustomCallRendezvous* const instance =
      new HVDCustomCallRendezvous();
  return instance;
}
class XLAReadyEvent : public common::ReadyEvent {
public:
XLAReadyEvent(cudaStream_t stream) : stream_(stream) {
CUDA_CALL(cudaEventCreate(&event_));
CUDA_CALL(cudaEventRecord(event_, stream));
}
~XLAReadyEvent() { CUDA_CALL(cudaEventDestroy(event_)); }
bool Ready() const override {
cudaError_t result = cudaEventQuery(event_);
return cudaErrorNotReady != result;
}
gpuEvent_t event() const override { return event_; }
private:
cudaStream_t stream_; // Not Owned.
cudaEvent_t event_; // Owned.
};
// Non-owning view that exposes an XLA-managed buffer to Horovod as a
// common::Tensor with the given element type and shape.
class XLATensor : public common::Tensor {
 public:
  XLATensor(common::DataType type, common::TensorShape shape, void* buffer)
      : type_(type), shape_(std::move(shape)), buffer_(buffer) {}

  const common::DataType dtype() const override { return type_; }
  const common::TensorShape shape() const override { return shape_; }
  const void* data() const override { return buffer_; }
  // Size in bytes: element count times the size of one element.
  int64_t size() const override {
    return shape_.num_elements() * common::DataType_Size(type_);
  }

 protected:
  common::DataType type_;
  common::TensorShape shape_;
  void* buffer_;  // Not owned.
};
// Horovod OpContext used by the XLA custom-call callbacks. It allocates
// persistent buffers on `device_`; output and zero-filled allocations are
// rejected because XLA manages all custom-call I/O buffers.
class XLAOpContext : public common::OpContext {
 public:
  explicit XLAOpContext(int device) : device_(device) {}
  common::Status AllocatePersistent(
      int64_t size, std::shared_ptr<common::PersistentBuffer>* tensor) override;
  common::Status
  AllocateOutput(common::TensorShape shape,
                 std::shared_ptr<common::Tensor>* tensor) override;
  common::Status
  AllocateZeros(int64_t num_elements, common::DataType dtype,
                std::shared_ptr<common::Tensor>* tensor) override;
  common::Framework framework() const override {
    return common::Framework::XLA;
  }

 private:
  // CUDA device ordinal on which persistent buffers are allocated.
  int device_;
};
class XLAPersistentBuffer : public common::PersistentBuffer {
public:
XLAPersistentBuffer(int device, int64_t size);
virtual const void*
AccessData(std::shared_ptr<common::OpContext> context) const override;
private:
int device_;
void* buffer_;
};
// Allocates `size` bytes with cudaMalloc on `device`. cudaMalloc allocates on
// the current device, so that device is switched temporarily and the
// caller's current device is restored afterwards.
XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size)
    : device_(device) {
  int caller_device;
  CUDA_CALL(cudaGetDevice(&caller_device));
  CUDA_CALL(cudaSetDevice(device_));
  CUDA_CALL(cudaMalloc(&buffer_, size));
  CUDA_CALL(cudaSetDevice(caller_device));
}
// Returns the device pointer allocated by the constructor. The op context is
// unused.
const void* XLAPersistentBuffer::AccessData(
    std::shared_ptr<common::OpContext> /*context*/) const {
  const void* const data = buffer_;
  return data;
}
// Allocates a `size`-byte XLAPersistentBuffer on this context's device and
// stores it in `*tensor`. Always returns OK; allocation failures CHECK-fail
// inside the buffer constructor.
common::Status XLAOpContext::AllocatePersistent(
    int64_t size, std::shared_ptr<common::PersistentBuffer>* tensor) {
  auto buffer = std::make_shared<XLAPersistentBuffer>(device_, size);
  *tensor = std::move(buffer);
  return common::Status::OK();
}
// Always fails with a precondition error: XLA must manage I/O buffers, so
// Horovod may not allocate outputs itself.
common::Status XLAOpContext::AllocateOutput(
    common::TensorShape /*shape*/,
    std::shared_ptr<common::Tensor>* /*tensor*/) {
  return common::Status::PreconditionError(
      "AllocateOutput is not supported for XLA.");
}
// Always fails with a precondition error: XLA must manage I/O buffers, so
// Horovod may not allocate zero-filled tensors itself.
common::Status XLAOpContext::AllocateZeros(
    int64_t /*num_elements*/, common::DataType /*dtype*/,
    std::shared_ptr<common::Tensor>* /*tensor*/) {
  return common::Status::PreconditionError(
      "AllocateZeros is not supported for XLA.");
}
common::ReadyEvent* RecordReadyEvent(cudaStream_t stream) {
return new XLAReadyEvent(stream);
}
int GetDeviceOrdinal(void* ptr) {
cudaPointerAttributes attrs;
CUDA_CALL(cudaPointerGetAttributes(&attrs, ptr));
return attrs.device;
}
// Implements the `CallbackHVDAllreduce` HLO custom-call target, i.e. the
// `start` half of the allreduce lowering.
//
// buffers[0] is the operand and buffers[1] the result. The call records a
// CUDA ready event on `stream` and enqueues an allreduce with the Horovod
// runtime, then returns without waiting. When Horovod finishes the request,
// its callback signals the rendezvous that CallbackHVDAllreduceDone waits on.
void CallbackHVDAllreduce(CUstream stream, void** buffers, const char* opaque,
                          size_t opaque_len) {
  const common::Status init_status = common::CheckInitialized();
  CHECK(init_status.ok()) << init_status.reason();
  CustomCallConfig config;
  config.ParseFromString(std::string(opaque, opaque_len));

  // Horovod must not read the input until the work already queued on
  // `stream` is done.
  common::ReadyEventList ready_event_list;
  ready_event_list.AddReadyEvent(
      std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(stream)));

  const int dev_ordinal = GetDeviceOrdinal(buffers[0]);
  auto hvd_context = std::make_shared<XLAOpContext>(dev_ordinal);
  auto hvd_input = std::make_shared<XLATensor>(
      config.tensor_type_, common::TensorShape(config.input_shapes_[0]),
      buffers[0]);
  // The result buffer is described by the output shape recorded at lowering.
  auto hvd_output = std::make_shared<XLATensor>(
      config.tensor_type_, common::TensorShape(config.output_shapes_[0]),
      buffers[1]);

  // The completion callback runs after this function returns, so it captures
  // only the tensor name rather than the whole config.
  const std::string tensor_name = config.tensor_name_;
  const common::Status enqueue_result = EnqueueTensorAllreduce(
      hvd_context, hvd_input, hvd_output, ready_event_list, tensor_name,
      dev_ordinal,
      [tensor_name](const common::Status& status) {
        // When request is done processing, signal `HVDAllreduceDone`.
        CHECK(status.ok()) << status.reason();
        GetHVDCustomCallRendezvous()->Signal(tensor_name, status.event);
      },
      static_cast<horovod::common::ReduceOp>(config.reduce_op_),
      static_cast<double>(config.prescale_factor_),
      static_cast<double>(config.postscale_factor_), config.process_set_id_);
  CHECK(enqueue_result.ok()) << enqueue_result.reason();
}
// Implements the `CallbackHVDAllreduceDone` HLO custom-call target, i.e. the
// `end` half of the allreduce lowering.
//
// Blocks the calling thread until the matching `start` request has been
// signalled by the Horovod runtime, then makes `stream` wait on the event
// carried with that signal (if any). The buffers are unused: the result
// aliases the `start` output.
void CallbackHVDAllreduceDone(CUstream stream, void** /*buffers*/,
                              const char* opaque, size_t opaque_len) {
  VLOG(2) << "hvd-allreduce-done - Start";
  CustomCallConfig done_config;
  done_config.ParseFromString(std::string(opaque, opaque_len));
  HVDCustomCallRendezvous* const rendezvous = GetHVDCustomCallRendezvous();
  rendezvous->Wait(done_config.tensor_name_, stream);
  VLOG(2) << "hvd-allreduce-done - End";
}
// Registers the start/done callbacks as XLA custom-call targets for the
// "CUDA" platform, under the same names that BuildAllreduceCustomCall uses as
// custom_call_target.
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA");
} // namespace
} // namespace xla
} // namespace horovod
#endif // HAVE_CUDA
#endif // HAVE_GPU
#endif // TENSORFLOW_VERSION >= 2005000000