From 59f17f84b17fa52c8f1bc7c4b3412a9aac460cf5 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Fri, 18 Mar 2022 23:07:16 +0000 Subject: [PATCH 01/11] separate rocm implementation added Signed-off-by: Wei Han Signed-off-by: weihanmines --- CMakeLists.txt | 4 + horovod/common/common.h | 2 +- horovod/common/ops/gpu_operations.h | 16 + horovod/tensorflow/mpi_ops.cc | 4 + horovod/tensorflow/xla_mpi_ops.cc | 530 ++++++++++++++++++++++++++++ 5 files changed, 555 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e8ded61d08..7707ed0a30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -350,6 +350,10 @@ if(HAVE_CUDA OR HAVE_SUB_PROJECT_CUDA) add_subdirectory(horovod/common/ops/cuda) endif() +if(HAVE_ROCM) + add_subdirectory(horovod/common/ops/rocm) +endif() + # if we need compatible c++ abi # Duplicate gloo folder and add it as a new sub-project if(HAVE_GLOO AND ((DEFINED Tensorflow_CXX11 AND NOT Tensorflow_CXX11) OR (DEFINED Pytorch_CXX11 AND NOT Pytorch_CXX11) OR (DEFINED Mxnet_CXX11 AND NOT Mxnet_CXX11))) diff --git a/horovod/common/common.h b/horovod/common/common.h index 8876a02c9e..f46c4fc81a 100644 --- a/horovod/common/common.h +++ b/horovod/common/common.h @@ -45,7 +45,6 @@ using gpuStream_t = cudaStream_t; throw std::logic_error(std::string("GPU Error:") + cudaGetErrorString(cuda_result)); \ } \ } while (0) -#endif #elif HAVE_ROCM #include using gpuError_t = hipError_t; @@ -64,6 +63,7 @@ using gpuStream_t = hipStream_t; } \ } while (0) #endif +#endif namespace horovod { diff --git a/horovod/common/ops/gpu_operations.h b/horovod/common/ops/gpu_operations.h index f9342cde33..c538b1cea0 100644 --- a/horovod/common/ops/gpu_operations.h +++ b/horovod/common/ops/gpu_operations.h @@ -176,6 +176,22 @@ class GPUAllreduce : public AllreduceOp { double scale_factor, std::vector& entries); #endif +#if HAVE_ROCM + void MemcpyInFusionBuffer(const std::vector& entries, + const void*& fused_input_data, void*& buffer_data, + size_t& buffer_len) override; + + void MemcpyOutFusionBuffer(const void* buffer_data, + std::vector& entries) override; + + void ScaleMemcpyInFusionBuffer(const std::vector& entries, + const void*& fused_input_data, + void*& buffer_data, size_t& buffer_len, + double scale_factor); + void ScaleMemcpyOutFusionBuffer(void* buffer_data, size_t buffer_len, + double scale_factor, + std::vector& entries); +#endif void MemcpyEntryInFusionBuffer(const std::vector& entries, const TensorTableEntry& e, diff --git a/horovod/tensorflow/mpi_ops.cc b/horovod/tensorflow/mpi_ops.cc index a226ad9c10..535736743a 100644 --- a/horovod/tensorflow/mpi_ops.cc +++ b/horovod/tensorflow/mpi_ops.cc @@ -27,6 +27,10 @@ #define EIGEN_USE_GPU #endif // HAVE_CUDA || HAVE_ROCM +#if HAVE_ROCM +#define EIGEN_USE_HIP +#endif + #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index 59ba7d511c..7c06232584 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -572,4 +572,534 @@ XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); #endif // TENSORFLOW_VERSION >= 2006000000 #endif // HAVE_CUDA +#if HAVE_ROCM + +#include +#include "../common/common.h" + +#define OMPI_SKIP_MPICXX +#include "../common/operations.h" +#include "../common/utils/env_parser.h" +#include "./custom_call_config_generated.h" + +using namespace tensorflow; + +namespace horovod { +namespace xla { 
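+// The ROCm path below mirrors the CUDA XLA implementation earlier in this
+// file, with the CUDA runtime calls replaced by their HIP equivalents
+// (e.g. cudaEventCreate -> hipEventCreate, cudaStreamWaitEvent ->
+// hipStreamWaitEvent).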
+namespace { + +common::DataType GetHVDType(::xla::PrimitiveType type) { + switch (type) { + case ::xla::U8: + return common::HOROVOD_UINT8; + case ::xla::S8: + return common::HOROVOD_INT8; + case ::xla::U16: + return common::HOROVOD_UINT16; + case ::xla::S16: + return common::HOROVOD_INT16; + case ::xla::S32: + return common::HOROVOD_INT32; + case ::xla::S64: + return common::HOROVOD_INT64; + case ::xla::F16: + return common::HOROVOD_FLOAT16; + case ::xla::F32: + return common::HOROVOD_FLOAT32; + case ::xla::F64: + return common::HOROVOD_FLOAT64; + case ::xla::PRED: + return common::HOROVOD_BOOL; + default: + throw std::logic_error("Invalid XLA tensor type."); + } +} + +// CustomCallConfig stores configurations of Horovod ops. We pass this config +// to ::xla::CustomCall so that the XLA CustomCall can represent various Horovod +// ops. Flatbuffer is used to serialize the config into string to conform to the +// XLA CustomCall interface. +class CustomCallConfig { +public: + std::string SerializeToString(); + void ParseFromString(std::string); + +public: + std::string tensor_name_; + common::DataType tensor_type_; + std::vector> input_shapes_; + std::vector> output_shapes_; + float prescale_factor_; + float postscale_factor_; + int root_rank_; + int reduce_op_; + int process_set_id_; +}; + +std::string CustomCallConfig::SerializeToString() { + flatbuffers::FlatBufferBuilder fbb(1024); + + std::vector> input_shapes_obj; + absl::c_for_each(input_shapes_, [&](const std::vector& dims) { + input_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims)); + }); + std::vector> output_shapes_obj; + absl::c_for_each(output_shapes_, [&](const std::vector& dims) { + output_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims)); + }); + auto wire = wire::CreateCustomCallConfigDirect( + fbb, tensor_name_.c_str(), (common::wire::DataType)tensor_type_, + &input_shapes_obj, &output_shapes_obj, prescale_factor_, + postscale_factor_, root_rank_, reduce_op_, process_set_id_); + fbb.Finish(wire); + + uint8_t* buf = fbb.GetBufferPointer(); + auto size = fbb.GetSize(); + return std::string((char*)buf, size); +} + +void CustomCallConfig::ParseFromString(std::string input) { + const wire::CustomCallConfig* obj = + flatbuffers::GetRoot( + (const uint8_t*)input.data()); + + tensor_name_ = obj->tensor_name()->str(); + tensor_type_ = (common::DataType)obj->tensor_type(); + for (auto it = obj->input_shapes()->begin(); it != obj->input_shapes()->end(); + it++) { + auto shape_obj = *it; + input_shapes_.push_back(std::vector(shape_obj->dims()->begin(), + shape_obj->dims()->end())); + } + for (auto it = obj->output_shapes()->begin(); + it != obj->output_shapes()->end(); it++) { + auto shape_obj = *it; + output_shapes_.push_back(std::vector(shape_obj->dims()->begin(), + shape_obj->dims()->end())); + } + prescale_factor_ = obj->prescale_factor(); + postscale_factor_ = obj->postscale_factor(); + root_rank_ = obj->root_rank(); + reduce_op_ = obj->reduce_op(); + process_set_id_ = obj->process_set_id(); + + if (VLOG_IS_ON(2)) { + VLOG(2) << "tensor_name " << tensor_name_; + VLOG(2) << "tensor_type " << tensor_type_; + VLOG(2) << "prescale_factor = " << prescale_factor_; + VLOG(2) << "postscale_factor = " << postscale_factor_; + VLOG(2) << "root_rank = " << root_rank_; + VLOG(2) << "reduce_op = " << reduce_op_; + VLOG(2) << "process_set_id = " << process_set_id_; + } +} + +// HVDAllreduceOp is an XLAOpKernel that lowers the Tensorflow HorovodAllreduce +// op into XLA HLOs. 
The overall idea is to lower an Tensorflow op into two +// corresponding HLO custom-calls, `start` and `end` calls, so that the XLA can +// asynchronously interact with the Horovod runtime. The `start` call is always +// non-blocking for latency hiding and the `end` call could be blocking. For +// example, as shown in HVDAllreduceOp::Compile() below, the "HorovodAllreduce" +// op is lowered into the "CallbackHVDAllreduce" and "CallbackHVDAllreduceDone" +// HLO custom-calls, whose implementations are also provided through dynamic +// registration in this file. +class HVDAllreduceOp : public XlaOpKernel { +public: + explicit HVDAllreduceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_op", &reduce_op_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("prescale_factor", &prescale_factor_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("postscale_factor", &postscale_factor_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_name_scope", &ignore_name_scope_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("process_set_id", &process_set_id_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + node_name_ = name(); + if (ignore_name_scope_) { + auto pos = node_name_.find_last_of('/'); + if (pos != std::string::npos) { + node_name_ = node_name_.substr(pos + 1); + } + } + + // Generate below HLOs: + // start = custom-call(in), custom_call_target="CallbackHVDAllreduce" + // end = custom-call(start), + // custom_call_target="CallbackHVDAllreduceDone" + // Note that tensors `in`, `start`, and `end'` are aliased, as we want the + // all-reduce operation to be in-place. + ::xla::XlaBuilder* const b = ctx->builder(); + // First, generate HVDAllreduce. + std::vector< + std::pair<::xla::ShapeIndex, std::pair>> + output_operand_aliasing = { + {::xla::ShapeIndex{}, {0, ::xla::ShapeIndex{}}}}; + ::xla::XlaOp input = ctx->Input(0); + ::xla::XlaOp allreduce_start = b->ReportErrorOrReturn( + BuildAllreduceCustomCall(b, {input}, /*is_start=*/true)); + // Then, generate HVDAllreduceDone. + ::xla::XlaOp allreduce_end = b->ReportErrorOrReturn( + BuildAllreduceCustomCall(b, {allreduce_start}, + /*is_start=*/false, output_operand_aliasing)); + ctx->SetOutput(0, allreduce_end); + return; + } + +private: + ::xla::StatusOr<::xla::XlaOp> BuildAllreduceCustomCall( + ::xla::XlaBuilder* b, absl::Span operands, + bool is_start, + absl::Span>> + output_operand_aliasing = {}); + +private: + std::string node_name_; + int reduce_op_; + // Using float since TF does not support double OP attributes + float prescale_factor_; + float postscale_factor_; + bool ignore_name_scope_; + int process_set_id_; +}; + +// Implements a customized registrar so that the registration is an opt-in, +// controlled by HOROVOD_ENABLE_XLA_OPS. 
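+// The XLA ops are registered only when HOROVOD_ENABLE_XLA_OPS is set to a
+// truthy value in the environment before Horovod is loaded, e.g.
+// (illustrative usage):
+//   HOROVOD_ENABLE_XLA_OPS=1 python train.py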
+#define HVD_REGISTER_XLA_OP(NAME, OP) \ + HVD_REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) + +#define HVD_REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, OP_NAME, OP) \ + HVD_REGISTER_XLA_OP_UNIQ(COUNTER, OP_NAME, OP) + +#define HVD_REGISTER_XLA_OP_UNIQ(CTR, OP_NAME, OP) \ + static HVDXlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ + OP_NAME, [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { return new OP(context); }); + +class HVDXlaOpRegistrar { +public: + HVDXlaOpRegistrar(string op_name, + ::tensorflow::XlaOpRegistry::Factory factory) { + bool enable_xla_ops = false; + common::SetBoolFromEnv(HOROVOD_ENABLE_XLA_OPS, enable_xla_ops, true); + if (enable_xla_ops) { + xla_op_registrar_ = new XlaOpRegistrar( + ::tensorflow::XlaOpRegistrationBuilder::Name(op_name).Build(factory)); + } + } + +private: + XlaOpRegistrar* xla_op_registrar_; +}; + +HVD_REGISTER_XLA_OP("HorovodAllreduce", HVDAllreduceOp); + +// A helper function to build HLOs for all-reduce. +::xla::StatusOr<::xla::XlaOp> HVDAllreduceOp::BuildAllreduceCustomCall( + ::xla::XlaBuilder* b, absl::Span operands, + bool is_start, + absl::Span< + const std::pair<::xla::ShapeIndex, std::pair>> + output_operand_aliasing) { + string call_target_name = + is_start ? "CallbackHVDAllreduce" : "CallbackHVDAllreduceDone"; + CustomCallConfig config; + config.tensor_name_ = node_name_; + for (const ::xla::XlaOp& opnd : operands) { + TF_ASSIGN_OR_RETURN(::xla::Shape shape, b->GetShape(opnd)); + config.input_shapes_.push_back(std::vector( + shape.dimensions().begin(), shape.dimensions().end())); + } + TF_ASSIGN_OR_RETURN(::xla::Shape output_shape, b->GetShape(operands.at(0))); + config.output_shapes_.push_back(std::vector( + output_shape.dimensions().begin(), output_shape.dimensions().end())); + config.tensor_type_ = GetHVDType(output_shape.element_type()); + config.prescale_factor_ = prescale_factor_; + config.postscale_factor_ = postscale_factor_; + config.reduce_op_ = reduce_op_; + config.process_set_id_ = process_set_id_; + + return ::xla::CustomCall( + b, call_target_name, operands, output_shape, config.SerializeToString(), + /*has_side_effect=*/false, output_operand_aliasing, /*literal=*/nullptr, + // Special schedule hints are given so that XLA knows how to schedule + // the opague custom-calls for performance. + is_start ? ::xla::CustomCallSchedule::SCHEDULE_EARLIEST + : ::xla::CustomCallSchedule::SCHEDULE_LATEST); +} + +// Returns a hash for rendezvous. +uint64 GetRendezvousKeyHash(const string& key) { + string k = strings::StrCat(key); + return Hash64(k.data(), k.size()); +} + +// Implements a rendezvous to coordinate the `start` and `end` HLO callbacks. +class HVDCustomCallRendezvous { +public: + struct Payload { + std::shared_ptr event; + }; + + // This `Signal` method places payload to be consumed by Wait(). + // + // Requirement: tensor_name shall be unique in a graph. + void Signal(string tensor_name, common::Event hvd_event) { + // Use `tensor_name` to generate a hash value to retrieve the queue. + uint64 key_hash = GetRendezvousKeyHash(tensor_name); + mutex_lock l(mu_); + InitQueue(key_hash); + + Queue& queue = *table_[key_hash]; + if (queue.empty() || queue.front() != nullptr) { + // No earlier waiters are waiting, so simply push a payload in the back. + queue.push_back(new Payload{hvd_event.event}); + return; + } + + // There is an earlier waiter to consume this signal. Place payload + // at the front of the queue where the waiter is polling. 
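+    //
+    // Illustrative queue states for a single tensor_name:
+    //   Signal() first: []        -> [payload]  (a later Wait() pops it)
+    //   Wait() first:   []        -> [nullptr]  (the waiter polls the front)
+    //   then Signal():  [nullptr] -> [payload]  (the waiter consumes and pops)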
+ CHECK(nullptr == queue.front()); + queue.front() = new Payload{hvd_event.event}; + } + + // The `Wait` method consumes Payloads. We assume there is at most one + // outstanding `Wait` call due to its blocking nature to simplify the + // implementation. Consequently, this method always operates on the very + // first item in the queue. + void Wait(string tensor_name, hipStream_t stream) { + uint64 key_hash = GetRendezvousKeyHash(tensor_name); + + { + mutex_lock l(mu_); + InitQueue(key_hash); + Queue& queue = *table_[key_hash]; + if (queue.empty()) { + // So long as the queue is empty, place a NULL payload. Then waiting for + // Signal() to place the payload below. + queue.push_back(nullptr); + } + } + + auto has_available_signal = [&]() { + mutex_lock l(mu_); + Queue& queue = *table_[key_hash]; + return nullptr != queue.front(); + }; + while (!has_available_signal()) { + // Busy waiting. As we don't anticipate the blocking occurs frequently, + // this busy waiting should be fine. If this creates any performance + // overhead, we may implement conditional var wait. + std::this_thread::sleep_for(std::chrono::nanoseconds(100)); + } + + mutex_lock l(mu_); + Queue* queue = table_[key_hash]; + Payload* payload = queue->front(); + std::shared_ptr event = payload->event; + queue->pop_front(); + if (queue->empty()) { + table_.erase(key_hash); + delete queue; + } + if (event) { + HVD_GPU_CHECK(hipStreamWaitEvent(stream, *event, /*flags=*/0)); + } + delete payload; + } + +private: + // This method is not thread-safe. + void InitQueue(uint64 key_hash) { + auto it = table_.find(key_hash); + if (it == table_.end()) { + table_[key_hash] = new Queue(); + } + } + +private: + // `nullptr` denotes non-readiness of the payload. + typedef std::deque Queue; + // maps a hash value to queue. We will use tensor_names to generate the hash + // values. + typedef absl::flat_hash_map Table; + + mutex mu_; + Table table_ GUARDED_BY(mu_); +}; + +/*static*/ HVDCustomCallRendezvous* GetHVDCustomCallRendezvous() { + static HVDCustomCallRendezvous* self = new HVDCustomCallRendezvous(); + return self; +} + +class XLAReadyEvent : public common::ReadyEvent { +public: + XLAReadyEvent(hipStream_t stream) : stream_(stream) { + HVD_GPU_CHECK(hipEventCreate(&event_)); + HVD_GPU_CHECK(hipEventRecord(event_, stream)); + } + ~XLAReadyEvent() { HVD_GPU_CHECK(hipEventDestroy(event_)); } + + bool Ready() const override { + hipError_t result = hipEventQuery(event_); + return hipErrorNotReady != result; + } + gpuEvent_t event() const override { return event_; } + +private: + hipStream_t stream_; // Not Owned. + hipEvent_t event_; // Owned. +}; + +class XLATensor : public common::Tensor { +public: + XLATensor(common::DataType type, common::TensorShape shape, void* buffer) + : type_(type), shape_(std::move(shape)), buffer_(buffer) {} + + virtual const common::DataType dtype() const override { return type_; } + virtual const common::TensorShape shape() const override { return shape_; } + virtual const void* data() const override { return buffer_; } + virtual int64_t size() const override { + return shape_.num_elements() * common::DataType_Size(type_); + } + +protected: + common::DataType type_; + common::TensorShape shape_; + void* buffer_; // Not owned. 
+}; + +class XLAOpContext : public common::OpContext { +public: + XLAOpContext(int device) : device_(device) {} + + virtual common::Status AllocatePersistent( + int64_t size, std::shared_ptr* tensor) override; + + virtual common::Status + AllocateOutput(common::TensorShape shape, + std::shared_ptr* tensor) override; + + virtual common::Status + AllocateZeros(int64_t num_elements, common::DataType dtype, + std::shared_ptr* tensor) override; + + virtual common::Framework framework() const override { + return common::Framework::XLA; + } + +private: + int device_; +}; + +class XLAPersistentBuffer : public common::PersistentBuffer { +public: + XLAPersistentBuffer(int device, int64_t size); + virtual const void* + AccessData(std::shared_ptr context) const override; + +private: + int device_; + void* buffer_; +}; + +XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) + : device_(device) { + int restore_device; + HVD_GPU_CHECK(hipGetDevice(&restore_device)); + HVD_GPU_CHECK(hipSetDevice(device)); + // Simply call cudaMalloc for persistent buffer. + HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); + HVD_GPU_CHECK(hipSetDevice(restore_device)); +} + +const void* XLAPersistentBuffer::AccessData( + std::shared_ptr /*context*/) const { + return buffer_; +} + +common::Status XLAOpContext::AllocatePersistent( + int64_t size, std::shared_ptr* tensor) { + *tensor = std::make_shared(device_, size); + return common::Status::OK(); +} + +common::Status +XLAOpContext::AllocateOutput(common::TensorShape shape, + std::shared_ptr* tensor) { + // XLA must manage I/O buffers. + return common::Status::PreconditionError( + "AllocateOutput is not supported for XLA."); +} + +common::Status +XLAOpContext::AllocateZeros(int64_t num_elements, common::DataType dtype, + std::shared_ptr* tensor) { + // XLA must manage I/O buffers. + return common::Status::PreconditionError( + "AllocateZeros is not supported for XLA."); +} + +common::ReadyEvent* RecordReadyEvent(hipStream_t stream) { + return new XLAReadyEvent(stream); +} + +int GetDeviceOrdinal(void* ptr) { + hipPointerAttribute_t attrs; + HVD_GPU_CHECK(hipPointerGetAttributes(&attrs, ptr)); + return attrs.device; +} + +// Implements for the `HVDAllreduce` HLO CustomCall. +void CallbackHVDAllreduce(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + CHECK(common::CheckInitialized().ok()); + CustomCallConfig config; + config.ParseFromString(std::string(opaque, opaque_len)); + + // Enqueue requests to the Horovod runtime. + common::ReadyEventList ready_event_list; + ready_event_list.AddReadyEvent( + std::shared_ptr(RecordReadyEvent(stream))); + int dev_ordinal = GetDeviceOrdinal(buffers[0]); + auto hvd_context = std::make_shared(dev_ordinal); + auto hvd_input = std::make_shared( + config.tensor_type_, common::TensorShape(config.input_shapes_[0]), + buffers[0]); + auto hvd_output = std::make_shared( + config.tensor_type_, common::TensorShape(config.input_shapes_[0]), + buffers[1]); + common::Status enqueue_result = EnqueueTensorAllreduce( + hvd_context, hvd_input, hvd_output, ready_event_list, config.tensor_name_, + dev_ordinal, + [=](const common::Status& status) { + // When request is done processing, signal `HVDAllreduceDone`. 
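+        // The signalled event is consumed by CallbackHVDAllreduceDone below,
+        // which waits for this payload on the host and then makes the XLA
+        // stream wait on the recorded GPU event via hipStreamWaitEvent.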
+ CHECK(status.ok()) << status.reason(); + GetHVDCustomCallRendezvous()->Signal(config.tensor_name_, status.event); + }, + (horovod::common::ReduceOp)config.reduce_op_, + (double)config.prescale_factor_, (double)config.postscale_factor_, + config.process_set_id_); + CHECK(enqueue_result.ok()) << enqueue_result.reason(); +} + +// Implements for the `HVDAllreduceDone` HLO CustomCall. +void CallbackHVDAllreduceDone(hipStream_t stream, void** /*buffers*/, + const char* opaque, size_t opaque_len) { + // Blocking until the request is done processing by the Horovod runtime. + VLOG(2) << "hvd-allreduce-done - Start"; + CustomCallConfig config; + config.ParseFromString(std::string(opaque, opaque_len)); + GetHVDCustomCallRendezvous()->Wait(config.tensor_name_, stream); + VLOG(2) << "hvd-allreduce-done - End"; +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCm"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCm"); + +} // namespace +} // namespace tensorflow +} // namespace horovod +#endif //HAVE_ROCM #endif // HAVE_GPU From e9d381e6c0f537274c13490554d3a7e1439d71db Mon Sep 17 00:00:00 2001 From: weihanmines Date: Sat, 19 Mar 2022 05:31:00 +0000 Subject: [PATCH 02/11] add rocm kenrels Signed-off-by: Wei Han Signed-off-by: weihanmines --- horovod/common/ops/hip_operations.cc | 119 +++++++-- horovod/common/ops/rocm/CMakeLists.txt | 23 ++ horovod/common/ops/rocm/hip_kernels.cu | 325 +++++++++++++++++++++++++ horovod/common/ops/rocm/hip_kernels.h | 48 ++++ 4 files changed, 494 insertions(+), 21 deletions(-) create mode 100644 horovod/common/ops/rocm/CMakeLists.txt create mode 100644 horovod/common/ops/rocm/hip_kernels.cu create mode 100644 horovod/common/ops/rocm/hip_kernels.h diff --git a/horovod/common/ops/hip_operations.cc b/horovod/common/ops/hip_operations.cc index 0479c6151e..5307831329 100644 --- a/horovod/common/ops/hip_operations.cc +++ b/horovod/common/ops/hip_operations.cc @@ -14,17 +14,18 @@ // limitations under the License. // ============================================================================= +#include "../hashes.h" #include "../message.h" #include "gpu_operations.h" +#include "rocm/hip_kernels.h" #include namespace horovod { namespace common { - class GPUContext::impl { public: - hipError_t GetGpuEvent(hipEvent_t* event) { + hipError_t GetGpuEvent(Event* event, hipStream_t stream) { int device; auto status = hipGetDevice(&device); if (status != hipSuccess) { @@ -34,18 +35,39 @@ class GPUContext::impl { auto& mutex = hip_events_mutex; { std::lock_guard guard(mutex); - auto& queue = hip_events[device]; + auto key = std::make_pair(device, stream); + auto& queue = hip_events[key]; + if (!prepopulated[key]) { + // On first call for device and stream pair, prepopulate event queue. + // This is to minimize event reuse of callback events passed to + // framework. 
+ for (int i = 0; i < N_HIP_EVENTS_PREPOPULATE; ++i) { + hipEvent_t ev; + status = hipEventCreateWithFlags(&ev, hipEventDisableTiming); + queue.emplace(std::make_shared(ev), stream); + } + prepopulated[key] = true; + } if (!queue.empty()) { *event = queue.front(); + event->event_idx = ++hip_event_idx[key]; queue.pop(); return hipSuccess; } } - return hipEventCreateWithFlags(event, hipEventDisableTiming); + hipEvent_t ev; + status = hipEventCreateWithFlags(&ev, hipEventDisableTiming); + event->event = std::make_shared(ev); + event->stream = stream; + auto key2 = std::make_pair(device, stream); + event->event_idx = ++hip_event_idx[key2]; + + + return status; } - hipError_t ReleaseGpuEvent(hipEvent_t event) { + hipError_t ReleaseGpuEvent(Event event) { int device; auto status = hipGetDevice(&device); if (status != hipSuccess) { @@ -55,7 +77,7 @@ class GPUContext::impl { auto& mutex = hip_events_mutex; { std::lock_guard guard(mutex); - auto& queue = hip_events[device]; + auto& queue = hip_events[std::make_pair(device, event.stream)]; queue.push(event); } @@ -69,22 +91,59 @@ class GPUContext::impl { } } - void RecordEvent(std::queue>& event_queue, + void RecordEvent(std::queue>& event_queue, std::string name, hipStream_t& stream) { - hipEvent_t event; - ErrorCheck("GetGpuEvent", GetGpuEvent(&event)); - ErrorCheck("hipEventRecord", hipEventRecord(event, stream)); + Event event; + ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream)); + ErrorCheck("hipEventRecord", + hipEventRecord(*(event.event), event.stream)); event_queue.emplace(name, event); } - void - WaitForEvents(std::queue>& event_queue, + Event RecordEvent(hipStream_t& stream) { + Event event; + ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream)); + ErrorCheck("hipEventRecord", + hipEventRecord(*(event.event), event.stream)); + return event; + } + + void WaitForEvents(std::queue>& event_queue, const std::vector& entries, Timeline& timeline, const std::function& error_check_callback) { while (!event_queue.empty()) { std::string name; - hipEvent_t event; + Event event; + std::tie(name, event) = event_queue.front(); + event_queue.pop(); + if (name != "") { + timeline.ActivityStartAll(entries, name); + } + + hipError_t hip_result = hipEventSynchronize(*(event.event)); + if (hip_result != hipSuccess) { + throw std::logic_error(std::string("cudaEventSynchronize failed: ") + + hipGetErrorString(hip_result)); + } + if (error_check_callback) { + error_check_callback(); + } + + if (name != "") { + timeline.ActivityEndAll(entries); + } + ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event)); + } + } + + void WaitForEventsElastic( + std::queue>& event_queue, + const std::vector& entries, Timeline& timeline, + const std::function& error_check_callback) { + while (!event_queue.empty()) { + std::string name; + Event event; std::tie(name, event) = event_queue.front(); event_queue.pop(); if (name != "") { @@ -95,13 +154,13 @@ class GPUContext::impl { // complete hipError_t hip_result; while (true) { - hip_result = hipEventQuery(event); + hip_result = hipEventQuery(*(event.event)); if (hip_result == hipSuccess) { break; } if (hip_result != hipErrorNotReady) { - throw std::logic_error(std::string("hipEventQuery failed: ") + + throw std::logic_error(std::string("cudaEventQuery failed: ") + hipGetErrorString(hip_result)); } @@ -118,11 +177,25 @@ class GPUContext::impl { } } - void WaitForEventsElastic( - std::queue>& event_queue, - const std::vector& entries, Timeline& timeline, - const std::function& error_check_callback) { - WaitForEvents(event_queue, 
entries, timeline, error_check_callback); + void ClearEvents(std::queue>& event_queue, + const std::vector& entries, + Timeline& timeline, + const std::function& error_check_callback, + bool elastic) { + while (!event_queue.empty()) { + std::string name; + Event event; + std::tie(name, event) = event_queue.front(); + event_queue.pop(); + if (name != "") { + timeline.ActivityStartAll(entries, name); + } + + if (name != "") { + timeline.ActivityEndAll(entries); + } + ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event)); + } } void StreamCreate(hipStream_t* stream) { @@ -176,8 +249,12 @@ class GPUContext::impl { private: // We reuse HIP events as it appears that their creation carries non-zero // cost. - std::unordered_map> hip_events; + std::unordered_map, std::queue> + hip_events; + std::unordered_map, bool> prepopulated; + std::unordered_map, std::atomic> hip_event_idx; std::mutex hip_events_mutex; + static constexpr int N_HIP_EVENTS_PREPOPULATE = 128; }; #include "gpu_context_impl.cc" diff --git a/horovod/common/ops/rocm/CMakeLists.txt b/horovod/common/ops/rocm/CMakeLists.txt new file mode 100644 index 0000000000..eec1f0c293 --- /dev/null +++ b/horovod/common/ops/rocm/CMakeLists.txt @@ -0,0 +1,23 @@ +if (NOT DEFINED HCC_APTH) + if (DEFINED ENV{HCC_PATH}) + set(HIP_PATH ${HCC_PATH} CACHE PATH "Path to which HCC has been installed") + else() + set(HCC_PATH "${ROCM_PATH}/hcc" CACHE PATH "Path to which HCC has been set") + endif() + set(HCC_HOME "{HCC_PATH}") +endif() + +list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) +set(HIP_CLANG_PATH "${ROCM_PATH}/llvm/bin") +set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) +set(HIP_HCC_FLAGS ${HIP_HCC_FLAGS};-D__HIP_PLATFORM_HIPCC__=1;-fPIC) +find_package(HIP QUIET REQUIRED) +set(HIP_HIPCC_FLAGS ${HIP_HIPCC_FLAGS};-fPIC) +list(APPEND HIP_HCC_FLAGS_RELEASE -O3 -fPIC) +list(APPEND HIP_HCC_FLAGS_DEBUG -G -fPIC) + +list(APPEND HIP_HIPCC_FLAGS -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC) +hip_add_library(horovod_cuda_kernels STATIC hip_kernels.cu) +target_compile_definitions(horovod_cuda_kernels PRIVATE _GLIBCXX_USE_CXX11_ABI=1) +hip_add_library(compatible_horovod_cuda_kernels STATIC hip_kernels.cu) +target_compile_definitions(compatible_horovod_cuda_kernels PRIVATE _GLIBCXX_USE_CXX11_ABI=0) diff --git a/horovod/common/ops/rocm/hip_kernels.cu b/horovod/common/ops/rocm/hip_kernels.cu new file mode 100644 index 0000000000..af463a3de9 --- /dev/null +++ b/horovod/common/ops/rocm/hip_kernels.cu @@ -0,0 +1,325 @@ +// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "hip_kernels.h" + +#include +#include + +namespace horovod { +namespace common { + +template +__device__ void batched_memcpy_d(size_t idx, const void* in, void* out, size_t size) { + + const T* input = reinterpret_cast(in); + T* output = reinterpret_cast(out); + const size_t num_elements = size / sizeof(T); + + for (size_t i = idx; i < num_elements; i += blockDim.x * blocks_per_copy) { + output[i] = input[i]; + } + + // Deal with any remaining bytes + size_t remainder = size % sizeof(T); + if (remainder > 0 && idx < remainder) { + const unsigned char* input_r = reinterpret_cast(input + num_elements); + unsigned char* output_r = reinterpret_cast(output + num_elements); + output_r[idx] = input_r[idx]; + } +} + +template +__global__ void batched_memcpy_k(BatchedD2DParams params) { + const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x; + + const size_t size = params.sizes[blockIdx.x / blocks_per_copy]; + const void* input = params.in[blockIdx.x / blocks_per_copy]; + void* output = params.out[blockIdx.x / blocks_per_copy]; + + // Check alignment relative to 16 bytes + size_t align_in = reinterpret_cast(input) % BATCHED_D2D_PADDING; + size_t align_out = reinterpret_cast(output) % BATCHED_D2D_PADDING; + + // Select load/store size based on the misaligned buffer + size_t align = (align_out == 0) ? align_in : align_out; + if (align_in && align_out) { + // If both are misaligned, use unsigned char (this should not occur + // as fusion buffer locations should be aligned by applying BATCH_D2D_PADDING + // during construction.) + align = 1; + } + + if (align % 16 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 8 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 4 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 2 == 0) { + batched_memcpy_d(idx, input, output, size); + } else { + batched_memcpy_d(idx, input, output, size); + } +} + +#define NTHREADS_D2D_KERNEL 1024 +#define BLOCKS_PER_COPY_D2D_KERNEL 8 +void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream) +{ + batched_memcpy_k<<>>(params); +} + +template +__global__ void scale_buffer_k(const T* input, T* output, int64_t num_elements, const TS scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = scale_factor * input[i]; + } +} + +// Specialization for half2 +__global__ void scale_buffer_half2_k(const __half* input, __half* output, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + const __half2* input_h2 = reinterpret_cast(input); + __half2* output_h2 = reinterpret_cast<__half2 *>(output); + const __half2 scale_factor_h2 = __halves2half2(scale_factor, scale_factor); + + for (size_t i = idx; i < num_elements / 2; i += gridDim.x * blockDim.x) { + output_h2[i] = __hmul2(scale_factor_h2, input_h2[i]); + } + + // Deal with last element if num_elements is odd + if (idx == 0 && num_elements % 2) { + output[num_elements - 1] = __hmul(scale_factor, input[num_elements - 1]); + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i])); + } +#endif +} + +// Specialization for architectures without __half 
compute +template<> +__global__ void scale_buffer_k(const __half* input, __half* output, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = scale_factor * input[i]; + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i])); + } +#endif +} + +#define NTHREADS_SCALE_BUFFER_KERNEL 512 +void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, + DataType dtype, hipStream_t stream) { + const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + const int threads = NTHREADS_SCALE_BUFFER_KERNEL; + switch (dtype) { + case HOROVOD_UINT8: + scale_buffer_k<<>>((const uint8_t*) fused_input_data, (uint8_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT8: + scale_buffer_k<<>>((const int8_t*) fused_input_data, (int8_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT32: + scale_buffer_k<<>>((const int32_t*) fused_input_data, (int32_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT64: + scale_buffer_k<<>>((const int64_t*) fused_input_data, (int64_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_FLOAT16: + { + __half scale_factor_half = __float2half((float) scale_factor); + if ((size_t) fused_input_data % 4 == 0 && (size_t) buffer_data % 4 == 0) { + // If alignment allows, use half2 specialized kernel + int64_t num_elements_h2 = (num_elements + 1) / 2; + int64_t blocks_h2 = (num_elements_h2 + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + scale_buffer_half2_k<<>>((const __half*) fused_input_data, (__half*) buffer_data, + num_elements, scale_factor_half); + } else { + scale_buffer_k<<>>((const __half*) fused_input_data, (__half*) buffer_data, + num_elements, scale_factor_half); + } + break; + } + case HOROVOD_FLOAT32: + scale_buffer_k<<>>((const float*) fused_input_data, (float*) buffer_data, + num_elements, (float) scale_factor); + break; + case HOROVOD_FLOAT64: + scale_buffer_k<<>>((const double*) fused_input_data, (double*) buffer_data, + num_elements, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " not supported by ScaleBufferCudaImpl."); + } +} + +template +__device__ void batched_scaled_memcpy_d(size_t idx, const T* input, T* output, size_t size, const TS scale_factor) { + + const int64_t num_words = size / sizeof(TL); + const TL* read_ptr = reinterpret_cast(input); + TL* write_ptr = reinterpret_cast(output); + for (size_t i = idx; i < num_words; i += blockDim.x * blocks_per_copy) { + // Load word + TL word = read_ptr[i]; + T* val = reinterpret_cast(&word); + + // Scale elements in word + for (int j = 0; j < sizeof(TL) / sizeof(T); ++j) { + val[j] *= scale_factor; + } + + // Write word + write_ptr[i] = word; + } + + // Deal with any remaining elements + size_t remainder = (size % sizeof(TL)) / sizeof(T); + if (remainder > 0 && idx < remainder) { + const T* input_r = reinterpret_cast(read_ptr + num_words); + T* output_r = reinterpret_cast(write_ptr + num_words); + output_r[idx] = scale_factor * input_r[idx]; + } +} + +// Specialization for architectures without __half compute +template +__device__ void batched_scaled_memcpy_d(size_t idx, 
const __half* input, __half* output, size_t size, const __half scale_factor) { + + const int64_t num_words = size / sizeof(TL); + const TL* read_ptr = reinterpret_cast(input); + TL* write_ptr = reinterpret_cast(output); + for (size_t i = idx; i < num_words; i += blockDim.x * blocks_per_copy) { + // Load word + TL word = read_ptr[i]; + __half* val = reinterpret_cast<__half*>(&word); + + // Scale elements in word + for (int j = 0; j < sizeof(TL) / sizeof(__half); ++j) { +#if __CUDA_ARCH__ > 530 + val[j] *= scale_factor; +#else + val[j] = __float2half(__half2float(scale_factor) * __half2float(val[j])); +#endif + } + + // Write word + write_ptr[i] = word; + } + + // Deal with any remaining elements + size_t remainder = (size % sizeof(TL)) / sizeof(__half); + if (remainder > 0 && idx < remainder) { + const __half* input_r = reinterpret_cast(read_ptr + num_words); + __half* output_r = reinterpret_cast<__half*>(write_ptr + num_words); +#if __CUDA_ARCH__ > 530 + output_r[idx] = scale_factor * input_r[idx]; +#else + output_r[idx] = __float2half(__half2float(scale_factor) * __half2float(input_r[idx])); +#endif + } +} + +template +__global__ void batched_scaled_memcpy_k(BatchedD2DParams params, TS scale_factor) { + const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x; + + const size_t size = params.sizes[blockIdx.x / blocks_per_copy]; + const T* input = reinterpret_cast(params.in[blockIdx.x / blocks_per_copy]); + T* output = reinterpret_cast(params.out[blockIdx.x / blocks_per_copy]); + + // Check alignment relative to 16 bytes + size_t align_in = reinterpret_cast(input) % BATCHED_D2D_PADDING; + size_t align_out = reinterpret_cast(output) % BATCHED_D2D_PADDING; + + // Select load/store size based on the misaligned buffer + size_t align = (align_out == 0) ? 
align_in : align_out; + if (align_in && align_out) { + + // If both are misaligned, use datatype size + align = sizeof(T); + } + + if (align % 16 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 8 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 4 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 2 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } +} + +void BatchedScaledD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, double scale_factor, + DataType dtype, hipStream_t stream) { + const int64_t blocks = num_copies * BLOCKS_PER_COPY_D2D_KERNEL; + const int threads = NTHREADS_D2D_KERNEL; + switch (dtype) { + case HOROVOD_UINT8: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT8: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT32: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT64: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_FLOAT16: { + __half scale_factor_half = __float2half((float) scale_factor); + batched_scaled_memcpy_k<__half, BLOCKS_PER_COPY_D2D_KERNEL><<>>(params, scale_factor_half); + break; + } + case HOROVOD_FLOAT32: + batched_scaled_memcpy_k<<>>(params, (float) scale_factor); + break; + case HOROVOD_FLOAT64: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " not supported by BatchedScaledD2DMemcpyCudaImpl."); + } +} + +} // namespace common +} // namespace horovod + diff --git a/horovod/common/ops/rocm/hip_kernels.h b/horovod/common/ops/rocm/hip_kernels.h new file mode 100644 index 0000000000..e99e05e389 --- /dev/null +++ b/horovod/common/ops/rocm/hip_kernels.h @@ -0,0 +1,48 @@ +// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef HIP_KERNELS_H +#define HIP_KERNELS_H + +#include + +#include "../../message.h" + +#define BATCHED_D2D_CAPACITY 160 +#define BATCHED_D2D_PADDING 16 + +namespace horovod { +namespace common { + +struct BatchedD2DParams { + void* out[BATCHED_D2D_CAPACITY]; + void* in[BATCHED_D2D_CAPACITY]; + size_t sizes[BATCHED_D2D_CAPACITY]; +}; + +// Performs a batched d2d memcopy +void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream); + +// Scales buffer by scalar +void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, + double scale_factor, DataType dtype, hipStream_t stream); + +void BatchedScaledD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, double scale_factor, + DataType dtype, hipStream_t stream); + +} // namespace common +} // namespace horovod + +#endif // CUDA_KERNELS_H From db4e412ae2f59bb5223a08730d245a02f8493d27 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 21 Mar 2022 23:19:23 +0000 Subject: [PATCH 03/11] complete ROCm implementaiton Signed-off-by: Wei Han Signed-off-by: weihanmines --- horovod/common/ops/gpu_operations.cc | 237 +++++++++++++++++++++++++ horovod/common/ops/rocm/CMakeLists.txt | 1 + horovod/common/ops/rocm/hip_kernels.cu | 10 +- horovod/common/ops/rocm/hip_kernels.h | 6 +- horovod/tensorflow/CMakeLists.txt | 7 + horovod/torch/CMakeLists.txt | 7 + 6 files changed, 260 insertions(+), 8 deletions(-) diff --git a/horovod/common/ops/gpu_operations.cc b/horovod/common/ops/gpu_operations.cc index 0078d604f7..aa62447800 100644 --- a/horovod/common/ops/gpu_operations.cc +++ b/horovod/common/ops/gpu_operations.cc @@ -18,6 +18,9 @@ #if HAVE_CUDA #include "cuda/cuda_kernels.h" #endif +#if HAVE_ROCM +#include "rocm/hip_kernels.h" +#endif #include @@ -213,6 +216,68 @@ void GPUAllreduce::MemcpyInFusionBuffer( } #endif +#if HAVE_ROCM +void GPUAllreduce::MemcpyInFusionBuffer( + const std::vector& entries, const void*& fused_input_data, + void*& buffer_data, size_t& buffer_len) { + // Access the fusion buffer. 
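+  // Note on the batched path below: entries are packed back to back, with
+  // each offset rounded up to BATCHED_D2D_PADDING (16 bytes) so every fused
+  // chunk starts on a 16-byte boundary and the batched copy kernel can use
+  // wide ulonglong2 loads and stores when both pointers are aligned.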
+ auto& first_entry = entries[0]; + auto buffer = global_state_->fusion_buffer.GetBuffer( + first_entry.device, first_entry.context->framework(), + global_state_->current_nccl_stream); + buffer_data = const_cast(buffer->AccessData(first_entry.context)); + + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + auto& first_entry = entries[0]; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*)e.tensor->data(); + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += + BATCHED_D2D_PADDING * + ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { + // Perform batched d2d memcpy + BatchedD2DMemcpyROCmImpl( + d2d_params, count, + gpu_context_->streams[global_state_->current_nccl_stream] + [first_entry.device]); + // TODO: https://github.com/horovod/horovod/issues/2230 + // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", + // cudaGetLastError()); + count = 0; + } + } + buffer_len = (size_t)offset; + + } else { + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset); + offset += e.tensor->size(); + } + + buffer_len = (size_t)offset; + } + + // Set the input data to originate from the buffer. + fused_input_data = buffer_data; +} +#endif + #if HAVE_CUDA void GPUAllreduce::ScaleMemcpyInFusionBuffer( const std::vector& entries, const void*& fused_input_data, @@ -280,6 +345,73 @@ void GPUAllreduce::ScaleMemcpyInFusionBuffer( } #endif +#if HAVE_ROCM +void GPUAllreduce::ScaleMemcpyInFusionBuffer( + const std::vector& entries, const void*& fused_input_data, + void*& buffer_data, size_t& buffer_len, double scale_factor) { + auto& first_entry = entries[0]; + // Access the fusion buffer. 
+ auto buffer = global_state_->fusion_buffer.GetBuffer( + first_entry.device, first_entry.context->framework(), + global_state_->current_nccl_stream); + buffer_data = const_cast(buffer->AccessData(first_entry.context)); + + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*)e.tensor->data(); + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += + BATCHED_D2D_PADDING * + ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { + // Perform batched d2d memcpy + BatchedScaledD2DMemcpyROCmImpl( + d2d_params, count, scale_factor, first_entry.tensor->dtype(), + gpu_context_->streams[global_state_->current_nccl_stream] + [first_entry.device]); + // TODO: https://github.com/horovod/horovod/issues/2230 + // gpu_context_->ErrorCheck("BatchedScaledD2DMemcpyCudaImpl", + // cudaGetLastError()); + count = 0; + } + } + buffer_len = (size_t)offset; + + } else { + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset); + offset += e.tensor->size(); + } + + buffer_len = (size_t)offset; + int64_t num_elements = + buffer_len / DataType_Size(first_entry.tensor->dtype()); + if (scale_factor != 1.0) { + ScaleBuffer(scale_factor, entries, buffer_data, buffer_data, + num_elements); + } + } + + // Set the input data to originate from the buffer. 
+ fused_input_data = buffer_data; +} +#endif + void GPUAllreduce::MemcpyEntryInFusionBuffer( const std::vector& entries, const TensorTableEntry& e, void* buffer_data_at_offset) { @@ -338,6 +470,54 @@ void GPUAllreduce::MemcpyOutFusionBuffer( } #endif +#if HAVE_ROCM +void GPUAllreduce::MemcpyOutFusionBuffer( + const void* buffer_data, std::vector& entries) { + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + auto& first_entry = entries[0]; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data()); + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += + BATCHED_D2D_PADDING * + ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { + // Perform batched d2d memcpy + BatchedD2DMemcpyROCmImpl( + d2d_params, count, + gpu_context_->streams[global_state_->current_nccl_stream] + [first_entry.device]); + // TODO: https://github.com/horovod/horovod/issues/2230 + // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", + // cudaGetLastError()); + count = 0; + } + } + + } else { + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e); + offset += e.tensor->size(); + } + } +} +#endif + #if HAVE_CUDA void GPUAllreduce::ScaleMemcpyOutFusionBuffer( void* buffer_data, size_t buffer_len, double scale_factor, @@ -395,6 +575,63 @@ void GPUAllreduce::ScaleMemcpyOutFusionBuffer( } #endif +#if HAVE_ROCM +void GPUAllreduce::ScaleMemcpyOutFusionBuffer( + void* buffer_data, size_t buffer_len, double scale_factor, + std::vector& entries) { + auto& first_entry = entries[0]; + + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data()); + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += + BATCHED_D2D_PADDING * + ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { + // Perform batched d2d memcpy + BatchedScaledD2DMemcpyROCmImpl( + d2d_params, count, scale_factor, first_entry.tensor->dtype(), + gpu_context_->streams[global_state_->current_nccl_stream] + [first_entry.device]); + // TODO: https://github.com/horovod/horovod/issues/2230 + // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", + // cudaGetLastError()); + count = 0; + } + } + + } else { + int64_t num_elements = + buffer_len / DataType_Size(first_entry.tensor->dtype()); + if (scale_factor != 1.0) { + ScaleBuffer(scale_factor, entries, buffer_data, buffer_data, + num_elements); + } + + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e); + offset += e.tensor->size(); + } + } +} +#endif + void 
GPUAllreduce::MemcpyEntryOutFusionBuffer( const std::vector& entries, const void* buffer_data_at_offset, TensorTableEntry& e) { diff --git a/horovod/common/ops/rocm/CMakeLists.txt b/horovod/common/ops/rocm/CMakeLists.txt index eec1f0c293..b20b85ec0e 100644 --- a/horovod/common/ops/rocm/CMakeLists.txt +++ b/horovod/common/ops/rocm/CMakeLists.txt @@ -1,3 +1,4 @@ +message(STATUS "Built Horovod for ROCm") if (NOT DEFINED HCC_APTH) if (DEFINED ENV{HCC_PATH}) set(HIP_PATH ${HCC_PATH} CACHE PATH "Path to which HCC has been installed") diff --git a/horovod/common/ops/rocm/hip_kernels.cu b/horovod/common/ops/rocm/hip_kernels.cu index af463a3de9..8fe8e1caa9 100644 --- a/horovod/common/ops/rocm/hip_kernels.cu +++ b/horovod/common/ops/rocm/hip_kernels.cu @@ -77,7 +77,7 @@ __global__ void batched_memcpy_k(BatchedD2DParams params) { #define NTHREADS_D2D_KERNEL 1024 #define BLOCKS_PER_COPY_D2D_KERNEL 8 -void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream) +void BatchedD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream) { batched_memcpy_k<<>>(params); @@ -136,7 +136,7 @@ __global__ void scale_buffer_k(const __half* input, __half* output, int64_t num_ } #define NTHREADS_SCALE_BUFFER_KERNEL 512 -void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, +void ScaleBufferROCmImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, DataType dtype, hipStream_t stream) { const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; const int threads = NTHREADS_SCALE_BUFFER_KERNEL; @@ -182,7 +182,7 @@ void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const break; default: throw std::logic_error("Type " + DataType_Name(dtype) + - " not supported by ScaleBufferCudaImpl."); + " not supported by ScaleBufferROCmImpl."); } } @@ -286,7 +286,7 @@ __global__ void batched_scaled_memcpy_k(BatchedD2DParams params, TS scale_factor } } -void BatchedScaledD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, double scale_factor, +void BatchedScaledD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, double scale_factor, DataType dtype, hipStream_t stream) { const int64_t blocks = num_copies * BLOCKS_PER_COPY_D2D_KERNEL; const int threads = NTHREADS_D2D_KERNEL; @@ -316,7 +316,7 @@ void BatchedScaledD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, do break; default: throw std::logic_error("Type " + DataType_Name(dtype) + - " not supported by BatchedScaledD2DMemcpyCudaImpl."); + " not supported by BatchedScaledD2DMemcpyROCmImpl."); } } diff --git a/horovod/common/ops/rocm/hip_kernels.h b/horovod/common/ops/rocm/hip_kernels.h index e99e05e389..d3d99f490c 100644 --- a/horovod/common/ops/rocm/hip_kernels.h +++ b/horovod/common/ops/rocm/hip_kernels.h @@ -33,13 +33,13 @@ struct BatchedD2DParams { }; // Performs a batched d2d memcopy -void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream); +void BatchedD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream); // Scales buffer by scalar -void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, +void ScaleBufferROCmImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, DataType dtype, hipStream_t stream); -void BatchedScaledD2DMemcpyCudaImpl(BatchedD2DParams& params, int 
num_copies, double scale_factor, +void BatchedScaledD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, double scale_factor, DataType dtype, hipStream_t stream); } // namespace common diff --git a/horovod/tensorflow/CMakeLists.txt b/horovod/tensorflow/CMakeLists.txt index 875d5a7e47..fdf7ebe43d 100644 --- a/horovod/tensorflow/CMakeLists.txt +++ b/horovod/tensorflow/CMakeLists.txt @@ -52,6 +52,13 @@ if(HAVE_CUDA) list(APPEND TF_LINKER_LIBS compatible_horovod_cuda_kernels) endif() endif() +if(HAVE_ROCM) + if (Tensorflow_CXX11) + list(APPEND TF_LINKER_LIBS horovod_cuda_kernels) + else() + list(APPEND TF_LINKER_LIBS compatible_horovod_cuda_kernels) + endif() +endif() set(CMAKE_CXX_FLAGS "${Tensorflow_COMPILE_FLAGS} ${CMAKE_CXX_FLAGS}") parse_version(${Tensorflow_VERSION} VERSION_DEC) add_definitions(-DTENSORFLOW_VERSION=${VERSION_DEC}) diff --git a/horovod/torch/CMakeLists.txt b/horovod/torch/CMakeLists.txt index aadab89854..d0f86fae2d 100644 --- a/horovod/torch/CMakeLists.txt +++ b/horovod/torch/CMakeLists.txt @@ -53,6 +53,13 @@ if(HAVE_CUDA) list(APPEND PYTORCH_LINKER_LIBS compatible_horovod_cuda_kernels) endif() endif() +if(HAVE_ROCM) + if (Pytorch_CXX11) + list(APPEND PYTORCH_LINKER_LIBS horovod_cuda_kernels) + else() + list(APPEND PYTORCH_LINKER_LIBS compatible_horovod_cuda_kernels) + endif() +endif() parse_version(${Pytorch_VERSION} VERSION_DEC) add_definitions(-DPYTORCH_VERSION=${VERSION_DEC} -DTORCH_API_INCLUDE_EXTENSION_H=1) set(Pytorch_CXX11 ${Pytorch_CXX11} PARENT_SCOPE) From 8fa52b779313c4472e23e182f39c71a2d502b787 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Wed, 27 Apr 2022 05:19:59 +0000 Subject: [PATCH 04/11] platform string changed in XLA backend Signed-off-by: weihanmines --- horovod/tensorflow/xla_mpi_ops.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index 7c06232584..5894999d03 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -19,6 +19,7 @@ #include #include + #if TENSORFLOW_VERSION >= 2006000000 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -570,7 +571,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); } // namespace tensorflow } // namespace horovod -#endif // TENSORFLOW_VERSION >= 2006000000 #endif // HAVE_CUDA #if HAVE_ROCM @@ -1095,11 +1095,12 @@ void CallbackHVDAllreduceDone(hipStream_t stream, void** /*buffers*/, VLOG(2) << "hvd-allreduce-done - End"; } -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCm"); -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCm"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); } // namespace } // namespace tensorflow } // namespace horovod #endif //HAVE_ROCM #endif // HAVE_GPU +#endif // TENSORFLOW_VERSION >= 2006000000 From 703ec25ef6e7e21cfef392c63a34fbe8973e60e7 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 9 May 2022 20:43:40 +0000 Subject: [PATCH 05/11] remove duplates in gpu operations Signed-off-by: weihanmines --- horovod/common/ops/gpu_operations.cc | 238 ++------------------------- horovod/common/ops/gpu_operations.h | 18 +- 2 files changed, 17 insertions(+), 239 deletions(-) diff --git a/horovod/common/ops/gpu_operations.cc b/horovod/common/ops/gpu_operations.cc index aa62447800..61b3c89274 100644 --- a/horovod/common/ops/gpu_operations.cc +++ b/horovod/common/ops/gpu_operations.cc @@ -154,7 +154,7 @@ bool 
GPUAllreduce::Enabled(const ParameterManager& param_manager, return entries[0].device != CPU_DEVICE_ID; } -#if HAVE_CUDA +#if HAVE_CUDA || HAVE_ROCM void GPUAllreduce::MemcpyInFusionBuffer( const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len) { @@ -188,72 +188,17 @@ void GPUAllreduce::MemcpyInFusionBuffer( if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { // Perform batched d2d memcpy +#if HAVE_CUDA BatchedD2DMemcpyCudaImpl( d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); - // TODO: https://github.com/horovod/horovod/issues/2230 - // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", - // cudaGetLastError()); - count = 0; - } - } - buffer_len = (size_t)offset; - - } else { - int64_t offset = 0; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset); - offset += e.tensor->size(); - } - - buffer_len = (size_t)offset; - } - - // Set the input data to originate from the buffer. - fused_input_data = buffer_data; -} -#endif - -#if HAVE_ROCM -void GPUAllreduce::MemcpyInFusionBuffer( - const std::vector& entries, const void*& fused_input_data, - void*& buffer_data, size_t& buffer_len) { - // Access the fusion buffer. - auto& first_entry = entries[0]; - auto buffer = global_state_->fusion_buffer.GetBuffer( - first_entry.device, first_entry.context->framework(), - global_state_->current_nccl_stream); - buffer_data = const_cast(buffer->AccessData(first_entry.context)); - - if (global_state_->batch_d2d_memcopies) { - int64_t offset = 0; - int idx = 0; - int count = 0; - - BatchedD2DParams d2d_params; - auto& first_entry = entries[0]; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - - // Set input/output pointers and sizes - d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; - d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*)e.tensor->data(); - d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); - - offset += - BATCHED_D2D_PADDING * - ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); - idx++; - count++; - - if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { - // Perform batched d2d memcpy +#elif HAVE_ROCM BatchedD2DMemcpyROCmImpl( d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); +#endif // TODO: https://github.com/horovod/horovod/issues/2230 // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", // cudaGetLastError()); @@ -278,7 +223,7 @@ void GPUAllreduce::MemcpyInFusionBuffer( } #endif -#if HAVE_CUDA +#if HAVE_CUDA || HAVE_ROCM void GPUAllreduce::ScaleMemcpyInFusionBuffer( const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len, double scale_factor) { @@ -311,77 +256,17 @@ void GPUAllreduce::ScaleMemcpyInFusionBuffer( if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { // Perform batched d2d memcpy +#if HAVE_CUDA BatchedScaledD2DMemcpyCudaImpl( d2d_params, count, scale_factor, first_entry.tensor->dtype(), gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); - // TODO: https://github.com/horovod/horovod/issues/2230 - // gpu_context_->ErrorCheck("BatchedScaledD2DMemcpyCudaImpl", - // cudaGetLastError()); - count = 0; - } - } - buffer_len = (size_t)offset; - - } else { - int64_t offset = 0; - for (auto& e : entries) { - void* 
buffer_data_at_offset = (uint8_t*)buffer_data + offset; - MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset); - offset += e.tensor->size(); - } - - buffer_len = (size_t)offset; - int64_t num_elements = - buffer_len / DataType_Size(first_entry.tensor->dtype()); - if (scale_factor != 1.0) { - ScaleBuffer(scale_factor, entries, buffer_data, buffer_data, - num_elements); - } - } - - // Set the input data to originate from the buffer. - fused_input_data = buffer_data; -} -#endif - -#if HAVE_ROCM -void GPUAllreduce::ScaleMemcpyInFusionBuffer( - const std::vector& entries, const void*& fused_input_data, - void*& buffer_data, size_t& buffer_len, double scale_factor) { - auto& first_entry = entries[0]; - // Access the fusion buffer. - auto buffer = global_state_->fusion_buffer.GetBuffer( - first_entry.device, first_entry.context->framework(), - global_state_->current_nccl_stream); - buffer_data = const_cast(buffer->AccessData(first_entry.context)); - - if (global_state_->batch_d2d_memcopies) { - int64_t offset = 0; - int idx = 0; - int count = 0; - - BatchedD2DParams d2d_params; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - - // Set input/output pointers and sizes - d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; - d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*)e.tensor->data(); - d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); - - offset += - BATCHED_D2D_PADDING * - ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); - idx++; - count++; - - if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { - // Perform batched d2d memcpy +#elif HAVE_ROCM BatchedScaledD2DMemcpyROCmImpl( d2d_params, count, scale_factor, first_entry.tensor->dtype(), gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); +#endif // TODO: https://github.com/horovod/horovod/issues/2230 // gpu_context_->ErrorCheck("BatchedScaledD2DMemcpyCudaImpl", // cudaGetLastError()); @@ -422,7 +307,7 @@ void GPUAllreduce::MemcpyEntryInFusionBuffer( ->streams[global_state_->current_nccl_stream][first_entry.device]); } -#if HAVE_CUDA +#if HAVE_CUDA || HAVE_ROCM void GPUAllreduce::MemcpyOutFusionBuffer( const void* buffer_data, std::vector& entries) { if (global_state_->batch_d2d_memcopies) { @@ -448,58 +333,17 @@ void GPUAllreduce::MemcpyOutFusionBuffer( if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { // Perform batched d2d memcpy +#if HAVE_CUDA BatchedD2DMemcpyCudaImpl( d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); - // TODO: https://github.com/horovod/horovod/issues/2230 - // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", - // cudaGetLastError()); - count = 0; - } - } - - } else { - int64_t offset = 0; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e); - offset += e.tensor->size(); - } - } -} -#endif - -#if HAVE_ROCM -void GPUAllreduce::MemcpyOutFusionBuffer( - const void* buffer_data, std::vector& entries) { - if (global_state_->batch_d2d_memcopies) { - int64_t offset = 0; - int idx = 0; - int count = 0; - - BatchedD2DParams d2d_params; - auto& first_entry = entries[0]; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - - // Set input/output pointers and sizes - d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data()); - d2d_params.in[idx % 
BATCHED_D2D_CAPACITY] = buffer_data_at_offset; - d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); - - offset += - BATCHED_D2D_PADDING * - ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); - idx++; - count++; - - if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { - // Perform batched d2d memcpy +#elif HAVE_ROCM BatchedD2DMemcpyROCmImpl( d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); +#endif // TODO: https://github.com/horovod/horovod/issues/2230 // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", // cudaGetLastError()); @@ -518,7 +362,7 @@ void GPUAllreduce::MemcpyOutFusionBuffer( } #endif -#if HAVE_CUDA +#if HAVE_CUDA || HAVE_ROCM void GPUAllreduce::ScaleMemcpyOutFusionBuffer( void* buffer_data, size_t buffer_len, double scale_factor, std::vector& entries) { @@ -546,67 +390,17 @@ void GPUAllreduce::ScaleMemcpyOutFusionBuffer( if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { // Perform batched d2d memcpy +#if HAVE_CUDA BatchedScaledD2DMemcpyCudaImpl( d2d_params, count, scale_factor, first_entry.tensor->dtype(), gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); - // TODO: https://github.com/horovod/horovod/issues/2230 - // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", - // cudaGetLastError()); - count = 0; - } - } - - } else { - int64_t num_elements = - buffer_len / DataType_Size(first_entry.tensor->dtype()); - if (scale_factor != 1.0) { - ScaleBuffer(scale_factor, entries, buffer_data, buffer_data, - num_elements); - } - - int64_t offset = 0; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e); - offset += e.tensor->size(); - } - } -} -#endif - -#if HAVE_ROCM -void GPUAllreduce::ScaleMemcpyOutFusionBuffer( - void* buffer_data, size_t buffer_len, double scale_factor, - std::vector& entries) { - auto& first_entry = entries[0]; - - if (global_state_->batch_d2d_memcopies) { - int64_t offset = 0; - int idx = 0; - int count = 0; - - BatchedD2DParams d2d_params; - for (auto& e : entries) { - void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; - - // Set input/output pointers and sizes - d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data()); - d2d_params.in[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; - d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); - - offset += - BATCHED_D2D_PADDING * - ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); - idx++; - count++; - - if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) { - // Perform batched d2d memcpy +#elif HAVE_ROCM BatchedScaledD2DMemcpyROCmImpl( d2d_params, count, scale_factor, first_entry.tensor->dtype(), gpu_context_->streams[global_state_->current_nccl_stream] [first_entry.device]); +#endif // TODO: https://github.com/horovod/horovod/issues/2230 // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", // cudaGetLastError()); diff --git a/horovod/common/ops/gpu_operations.h b/horovod/common/ops/gpu_operations.h index c538b1cea0..0cebbc0485 100644 --- a/horovod/common/ops/gpu_operations.h +++ b/horovod/common/ops/gpu_operations.h @@ -160,23 +160,7 @@ class GPUAllreduce : public AllreduceOp { const Response& response) const override; protected: -#if HAVE_CUDA - void MemcpyInFusionBuffer(const std::vector& entries, - const void*& fused_input_data, void*& buffer_data, - size_t& buffer_len) override; 
- - void MemcpyOutFusionBuffer(const void* buffer_data, - std::vector& entries) override; - - void ScaleMemcpyInFusionBuffer(const std::vector& entries, - const void*& fused_input_data, - void*& buffer_data, size_t& buffer_len, - double scale_factor); - void ScaleMemcpyOutFusionBuffer(void* buffer_data, size_t buffer_len, - double scale_factor, - std::vector& entries); -#endif -#if HAVE_ROCM +#if HAVE_CUDA || HAVE_ROCM void MemcpyInFusionBuffer(const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len) override; From 7ad8abeeb7bec350503bee076fc36b522e6a2f07 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Tue, 10 May 2022 16:00:31 +0000 Subject: [PATCH 06/11] use HAVE_GPU instead of HAVE_CUDA || HAVE_ROCM Signed-off-by: weihanmines --- horovod/common/ops/gpu_operations.cc | 8 ++++---- horovod/common/ops/gpu_operations.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/horovod/common/ops/gpu_operations.cc b/horovod/common/ops/gpu_operations.cc index 61b3c89274..0adef04198 100644 --- a/horovod/common/ops/gpu_operations.cc +++ b/horovod/common/ops/gpu_operations.cc @@ -154,7 +154,7 @@ bool GPUAllreduce::Enabled(const ParameterManager& param_manager, return entries[0].device != CPU_DEVICE_ID; } -#if HAVE_CUDA || HAVE_ROCM +#if HAVE_GPU void GPUAllreduce::MemcpyInFusionBuffer( const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len) { @@ -223,7 +223,7 @@ void GPUAllreduce::MemcpyInFusionBuffer( } #endif -#if HAVE_CUDA || HAVE_ROCM +#if HAVE_GPU void GPUAllreduce::ScaleMemcpyInFusionBuffer( const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len, double scale_factor) { @@ -307,7 +307,7 @@ void GPUAllreduce::MemcpyEntryInFusionBuffer( ->streams[global_state_->current_nccl_stream][first_entry.device]); } -#if HAVE_CUDA || HAVE_ROCM +#if HAVE_GPU void GPUAllreduce::MemcpyOutFusionBuffer( const void* buffer_data, std::vector& entries) { if (global_state_->batch_d2d_memcopies) { @@ -362,7 +362,7 @@ void GPUAllreduce::MemcpyOutFusionBuffer( } #endif -#if HAVE_CUDA || HAVE_ROCM +#if HAVE_GPU void GPUAllreduce::ScaleMemcpyOutFusionBuffer( void* buffer_data, size_t buffer_len, double scale_factor, std::vector& entries) { diff --git a/horovod/common/ops/gpu_operations.h b/horovod/common/ops/gpu_operations.h index 0cebbc0485..709aaa36fa 100644 --- a/horovod/common/ops/gpu_operations.h +++ b/horovod/common/ops/gpu_operations.h @@ -160,7 +160,7 @@ class GPUAllreduce : public AllreduceOp { const Response& response) const override; protected: -#if HAVE_CUDA || HAVE_ROCM +#if HAVE_GPU void MemcpyInFusionBuffer(const std::vector& entries, const void*& fused_input_data, void*& buffer_data, size_t& buffer_len) override; From 7fde67ae087c7e89842d675fb72c88891e845323 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Tue, 10 May 2022 21:08:46 +0000 Subject: [PATCH 07/11] remove duplication in xla mpi ops impl Signed-off-by: weihanmines --- horovod/common/common.h | 6 + horovod/tensorflow/xla_mpi_ops.cc | 584 +++--------------------------- 2 files changed, 49 insertions(+), 541 deletions(-) diff --git a/horovod/common/common.h b/horovod/common/common.h index f46c4fc81a..46565437e7 100644 --- a/horovod/common/common.h +++ b/horovod/common/common.h @@ -33,9 +33,12 @@ using gpuError_t = cudaError_t; using gpuEvent_t = cudaEvent_t; using gpuStream_t = cudaStream_t; +using gpuPointerAttribute_t = cudaPointerAttributes; #define gpuEventCreateWithFlags cudaEventCreateWithFlags #define 
gpuEventDisableTiming cudaEventDisableTiming #define gpuEventRecord cudaEventRecord +#define gpuEventQuery cudaEventQuery +#define gpuErrorNotReady cudaErrorNotReady #define gpuEventSynchronize cudaEventSynchronize #define gpuStreamWaitEvent cudaStreamWaitEvent #define HVD_GPU_CHECK(x) \ @@ -50,9 +53,12 @@ using gpuStream_t = cudaStream_t; using gpuError_t = hipError_t; using gpuEvent_t = hipEvent_t; using gpuStream_t = hipStream_t; +using gpuPointerAttribute_t = hipPointerAttribute_t; #define gpuEventCreateWithFlags hipEventCreateWithFlags #define gpuEventDisableTiming hipEventDisableTiming #define gpuEventRecord hipEventRecord +#define gpuEventQuery hipEventQuery +#define gpuErrorNotReady hipErrorNotReady #define gpuEventSynchronize hipEventSynchronize #define gpuStreamWaitEvent hipStreamWaitEvent #define HVD_GPU_CHECK(x) \ diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index 5894999d03..a52baa4e50 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -34,6 +34,7 @@ #include "tensorflow/core/platform/human_readable_json.h" #if HAVE_GPU +#include "../common/common.h" #if HAVE_CUDA #include @@ -50,6 +51,15 @@ #include "../common/operations.h" #include "../common/utils/env_parser.h" #include "./custom_call_config_generated.h" +#elif HAVE_ROCM + +#include + +#define OMPI_SKIP_MPICXX +#include "../common/operations.h" +#include "../common/utils/env_parser.h" +#include "./custom_call_config_generated.h" +#endif // HAVE_CUDA using namespace tensorflow; @@ -336,7 +346,7 @@ class HVDCustomCallRendezvous { // outstanding `Wait` call due to its blocking nature to simplify the // implementation. Consequently, this method always operates on the very // first item in the queue. - void Wait(string tensor_name, CUstream stream) { + void Wait(string tensor_name, gpuStream_t stream) { uint64 key_hash = GetRendezvousKeyHash(tensor_name); { @@ -372,7 +382,11 @@ class HVDCustomCallRendezvous { delete queue; } if (event) { +#if HAVE_CUDA CUDA_CALL(cudaStreamWaitEvent(stream, *event, /*flags=*/0)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipStreamWaitEvent(stream, *event, /*flags=*/0)); +#endif } delete payload; } @@ -404,21 +418,28 @@ class HVDCustomCallRendezvous { class XLAReadyEvent : public common::ReadyEvent { public: - XLAReadyEvent(cudaStream_t stream) : stream_(stream) { + XLAReadyEvent(gpuStream_t stream) : stream_(stream) { +#if HAVE_CUDA CUDA_CALL(cudaEventCreate(&event_)); CUDA_CALL(cudaEventRecord(event_, stream)); } ~XLAReadyEvent() { CUDA_CALL(cudaEventDestroy(event_)); } +#elif HAVE_ROCM + HVD_GPU_CHECK(hipEventCreate(&event_)); + HVD_GPU_CHECK(hipEventRecord(event_, stream)); + } + ~XLAReadyEvent() { HVD_GPU_CHECK(hipEventDestroy(event_)); } +#endif bool Ready() const override { - cudaError_t result = cudaEventQuery(event_); - return cudaErrorNotReady != result; + gpuError_t result = gpuEventQuery(event_); + return gpuErrorNotReady != result; } gpuEvent_t event() const override { return event_; } private: - cudaStream_t stream_; // Not Owned. - cudaEvent_t event_; // Owned. + gpuStream_t stream_; // Not Owned. + gpuEvent_t event_; // Owned. }; class XLATensor : public common::Tensor { @@ -476,11 +497,19 @@ class XLAPersistentBuffer : public common::PersistentBuffer { XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) : device_(device) { int restore_device; +#if HAVE_CUDA CUDA_CALL(cudaGetDevice(&restore_device)); CUDA_CALL(cudaSetDevice(device)); // Simply call cudaMalloc for persistent buffer. 
CUDA_CALL(cudaMalloc((void**)&buffer_, size)); CUDA_CALL(cudaSetDevice(restore_device)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipGetDevice(&restore_device)); + HVD_GPU_CHECK(hipSetDevice(device)); + // Simply call cudaMalloc for persistent buffer. + HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); + HVD_GPU_CHECK(hipSetDevice(restore_device)); +#endif } const void* XLAPersistentBuffer::AccessData( @@ -510,18 +539,22 @@ XLAOpContext::AllocateZeros(int64_t num_elements, common::DataType dtype, "AllocateZeros is not supported for XLA."); } -common::ReadyEvent* RecordReadyEvent(cudaStream_t stream) { +common::ReadyEvent* RecordReadyEvent(gpuStream_t stream) { return new XLAReadyEvent(stream); } int GetDeviceOrdinal(void* ptr) { - cudaPointerAttributes attrs; + gpuPointerAttribute_t attrs; +#if HAVE_CUDA CUDA_CALL(cudaPointerGetAttributes(&attrs, ptr)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipPointerGetAttributes(&attrs, ptr)); +#endif return attrs.device; } // Implements for the `HVDAllreduce` HLO CustomCall. -void CallbackHVDAllreduce(CUstream stream, void** buffers, const char* opaque, +void CallbackHVDAllreduce(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { CHECK(common::CheckInitialized().ok()); CustomCallConfig config; @@ -554,7 +587,7 @@ void CallbackHVDAllreduce(CUstream stream, void** buffers, const char* opaque, } // Implements for the `HVDAllreduceDone` HLO CustomCall. -void CallbackHVDAllreduceDone(CUstream stream, void** /*buffers*/, +void CallbackHVDAllreduceDone(gpuStream_t stream, void** /*buffers*/, const char* opaque, size_t opaque_len) { // Blocking until the request is done processing by the Horovod runtime. VLOG(2) << "hvd-allreduce-done - Start"; @@ -571,536 +604,5 @@ XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); } // namespace tensorflow } // namespace horovod -#endif // HAVE_CUDA -#if HAVE_ROCM - -#include -#include "../common/common.h" - -#define OMPI_SKIP_MPICXX -#include "../common/operations.h" -#include "../common/utils/env_parser.h" -#include "./custom_call_config_generated.h" - -using namespace tensorflow; - -namespace horovod { -namespace xla { -namespace { - -common::DataType GetHVDType(::xla::PrimitiveType type) { - switch (type) { - case ::xla::U8: - return common::HOROVOD_UINT8; - case ::xla::S8: - return common::HOROVOD_INT8; - case ::xla::U16: - return common::HOROVOD_UINT16; - case ::xla::S16: - return common::HOROVOD_INT16; - case ::xla::S32: - return common::HOROVOD_INT32; - case ::xla::S64: - return common::HOROVOD_INT64; - case ::xla::F16: - return common::HOROVOD_FLOAT16; - case ::xla::F32: - return common::HOROVOD_FLOAT32; - case ::xla::F64: - return common::HOROVOD_FLOAT64; - case ::xla::PRED: - return common::HOROVOD_BOOL; - default: - throw std::logic_error("Invalid XLA tensor type."); - } -} - -// CustomCallConfig stores configurations of Horovod ops. We pass this config -// to ::xla::CustomCall so that the XLA CustomCall can represent various Horovod -// ops. Flatbuffer is used to serialize the config into string to conform to the -// XLA CustomCall interface. 
-class CustomCallConfig { -public: - std::string SerializeToString(); - void ParseFromString(std::string); - -public: - std::string tensor_name_; - common::DataType tensor_type_; - std::vector> input_shapes_; - std::vector> output_shapes_; - float prescale_factor_; - float postscale_factor_; - int root_rank_; - int reduce_op_; - int process_set_id_; -}; - -std::string CustomCallConfig::SerializeToString() { - flatbuffers::FlatBufferBuilder fbb(1024); - - std::vector> input_shapes_obj; - absl::c_for_each(input_shapes_, [&](const std::vector& dims) { - input_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims)); - }); - std::vector> output_shapes_obj; - absl::c_for_each(output_shapes_, [&](const std::vector& dims) { - output_shapes_obj.push_back(wire::CreateTensorShapeDirect(fbb, &dims)); - }); - auto wire = wire::CreateCustomCallConfigDirect( - fbb, tensor_name_.c_str(), (common::wire::DataType)tensor_type_, - &input_shapes_obj, &output_shapes_obj, prescale_factor_, - postscale_factor_, root_rank_, reduce_op_, process_set_id_); - fbb.Finish(wire); - - uint8_t* buf = fbb.GetBufferPointer(); - auto size = fbb.GetSize(); - return std::string((char*)buf, size); -} - -void CustomCallConfig::ParseFromString(std::string input) { - const wire::CustomCallConfig* obj = - flatbuffers::GetRoot( - (const uint8_t*)input.data()); - - tensor_name_ = obj->tensor_name()->str(); - tensor_type_ = (common::DataType)obj->tensor_type(); - for (auto it = obj->input_shapes()->begin(); it != obj->input_shapes()->end(); - it++) { - auto shape_obj = *it; - input_shapes_.push_back(std::vector(shape_obj->dims()->begin(), - shape_obj->dims()->end())); - } - for (auto it = obj->output_shapes()->begin(); - it != obj->output_shapes()->end(); it++) { - auto shape_obj = *it; - output_shapes_.push_back(std::vector(shape_obj->dims()->begin(), - shape_obj->dims()->end())); - } - prescale_factor_ = obj->prescale_factor(); - postscale_factor_ = obj->postscale_factor(); - root_rank_ = obj->root_rank(); - reduce_op_ = obj->reduce_op(); - process_set_id_ = obj->process_set_id(); - - if (VLOG_IS_ON(2)) { - VLOG(2) << "tensor_name " << tensor_name_; - VLOG(2) << "tensor_type " << tensor_type_; - VLOG(2) << "prescale_factor = " << prescale_factor_; - VLOG(2) << "postscale_factor = " << postscale_factor_; - VLOG(2) << "root_rank = " << root_rank_; - VLOG(2) << "reduce_op = " << reduce_op_; - VLOG(2) << "process_set_id = " << process_set_id_; - } -} - -// HVDAllreduceOp is an XLAOpKernel that lowers the Tensorflow HorovodAllreduce -// op into XLA HLOs. The overall idea is to lower an Tensorflow op into two -// corresponding HLO custom-calls, `start` and `end` calls, so that the XLA can -// asynchronously interact with the Horovod runtime. The `start` call is always -// non-blocking for latency hiding and the `end` call could be blocking. For -// example, as shown in HVDAllreduceOp::Compile() below, the "HorovodAllreduce" -// op is lowered into the "CallbackHVDAllreduce" and "CallbackHVDAllreduceDone" -// HLO custom-calls, whose implementations are also provided through dynamic -// registration in this file. 
-class HVDAllreduceOp : public XlaOpKernel { -public: - explicit HVDAllreduceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_op", &reduce_op_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("prescale_factor", &prescale_factor_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("postscale_factor", &postscale_factor_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_name_scope", &ignore_name_scope_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("process_set_id", &process_set_id_)); - } - - void Compile(XlaOpKernelContext* ctx) override { - node_name_ = name(); - if (ignore_name_scope_) { - auto pos = node_name_.find_last_of('/'); - if (pos != std::string::npos) { - node_name_ = node_name_.substr(pos + 1); - } - } - - // Generate below HLOs: - // start = custom-call(in), custom_call_target="CallbackHVDAllreduce" - // end = custom-call(start), - // custom_call_target="CallbackHVDAllreduceDone" - // Note that tensors `in`, `start`, and `end'` are aliased, as we want the - // all-reduce operation to be in-place. - ::xla::XlaBuilder* const b = ctx->builder(); - // First, generate HVDAllreduce. - std::vector< - std::pair<::xla::ShapeIndex, std::pair>> - output_operand_aliasing = { - {::xla::ShapeIndex{}, {0, ::xla::ShapeIndex{}}}}; - ::xla::XlaOp input = ctx->Input(0); - ::xla::XlaOp allreduce_start = b->ReportErrorOrReturn( - BuildAllreduceCustomCall(b, {input}, /*is_start=*/true)); - // Then, generate HVDAllreduceDone. - ::xla::XlaOp allreduce_end = b->ReportErrorOrReturn( - BuildAllreduceCustomCall(b, {allreduce_start}, - /*is_start=*/false, output_operand_aliasing)); - ctx->SetOutput(0, allreduce_end); - return; - } - -private: - ::xla::StatusOr<::xla::XlaOp> BuildAllreduceCustomCall( - ::xla::XlaBuilder* b, absl::Span operands, - bool is_start, - absl::Span>> - output_operand_aliasing = {}); - -private: - std::string node_name_; - int reduce_op_; - // Using float since TF does not support double OP attributes - float prescale_factor_; - float postscale_factor_; - bool ignore_name_scope_; - int process_set_id_; -}; - -// Implements a customized registrar so that the registration is an opt-in, -// controlled by HOROVOD_ENABLE_XLA_OPS. -#define HVD_REGISTER_XLA_OP(NAME, OP) \ - HVD_REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) - -#define HVD_REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, OP_NAME, OP) \ - HVD_REGISTER_XLA_OP_UNIQ(COUNTER, OP_NAME, OP) - -#define HVD_REGISTER_XLA_OP_UNIQ(CTR, OP_NAME, OP) \ - static HVDXlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ - OP_NAME, [](::tensorflow::OpKernelConstruction* context) \ - -> ::tensorflow::OpKernel* { return new OP(context); }); - -class HVDXlaOpRegistrar { -public: - HVDXlaOpRegistrar(string op_name, - ::tensorflow::XlaOpRegistry::Factory factory) { - bool enable_xla_ops = false; - common::SetBoolFromEnv(HOROVOD_ENABLE_XLA_OPS, enable_xla_ops, true); - if (enable_xla_ops) { - xla_op_registrar_ = new XlaOpRegistrar( - ::tensorflow::XlaOpRegistrationBuilder::Name(op_name).Build(factory)); - } - } - -private: - XlaOpRegistrar* xla_op_registrar_; -}; - -HVD_REGISTER_XLA_OP("HorovodAllreduce", HVDAllreduceOp); - -// A helper function to build HLOs for all-reduce. -::xla::StatusOr<::xla::XlaOp> HVDAllreduceOp::BuildAllreduceCustomCall( - ::xla::XlaBuilder* b, absl::Span operands, - bool is_start, - absl::Span< - const std::pair<::xla::ShapeIndex, std::pair>> - output_operand_aliasing) { - string call_target_name = - is_start ? 
"CallbackHVDAllreduce" : "CallbackHVDAllreduceDone"; - CustomCallConfig config; - config.tensor_name_ = node_name_; - for (const ::xla::XlaOp& opnd : operands) { - TF_ASSIGN_OR_RETURN(::xla::Shape shape, b->GetShape(opnd)); - config.input_shapes_.push_back(std::vector( - shape.dimensions().begin(), shape.dimensions().end())); - } - TF_ASSIGN_OR_RETURN(::xla::Shape output_shape, b->GetShape(operands.at(0))); - config.output_shapes_.push_back(std::vector( - output_shape.dimensions().begin(), output_shape.dimensions().end())); - config.tensor_type_ = GetHVDType(output_shape.element_type()); - config.prescale_factor_ = prescale_factor_; - config.postscale_factor_ = postscale_factor_; - config.reduce_op_ = reduce_op_; - config.process_set_id_ = process_set_id_; - - return ::xla::CustomCall( - b, call_target_name, operands, output_shape, config.SerializeToString(), - /*has_side_effect=*/false, output_operand_aliasing, /*literal=*/nullptr, - // Special schedule hints are given so that XLA knows how to schedule - // the opague custom-calls for performance. - is_start ? ::xla::CustomCallSchedule::SCHEDULE_EARLIEST - : ::xla::CustomCallSchedule::SCHEDULE_LATEST); -} - -// Returns a hash for rendezvous. -uint64 GetRendezvousKeyHash(const string& key) { - string k = strings::StrCat(key); - return Hash64(k.data(), k.size()); -} - -// Implements a rendezvous to coordinate the `start` and `end` HLO callbacks. -class HVDCustomCallRendezvous { -public: - struct Payload { - std::shared_ptr event; - }; - - // This `Signal` method places payload to be consumed by Wait(). - // - // Requirement: tensor_name shall be unique in a graph. - void Signal(string tensor_name, common::Event hvd_event) { - // Use `tensor_name` to generate a hash value to retrieve the queue. - uint64 key_hash = GetRendezvousKeyHash(tensor_name); - mutex_lock l(mu_); - InitQueue(key_hash); - - Queue& queue = *table_[key_hash]; - if (queue.empty() || queue.front() != nullptr) { - // No earlier waiters are waiting, so simply push a payload in the back. - queue.push_back(new Payload{hvd_event.event}); - return; - } - - // There is an earlier waiter to consume this signal. Place payload - // at the front of the queue where the waiter is polling. - CHECK(nullptr == queue.front()); - queue.front() = new Payload{hvd_event.event}; - } - - // The `Wait` method consumes Payloads. We assume there is at most one - // outstanding `Wait` call due to its blocking nature to simplify the - // implementation. Consequently, this method always operates on the very - // first item in the queue. - void Wait(string tensor_name, hipStream_t stream) { - uint64 key_hash = GetRendezvousKeyHash(tensor_name); - - { - mutex_lock l(mu_); - InitQueue(key_hash); - Queue& queue = *table_[key_hash]; - if (queue.empty()) { - // So long as the queue is empty, place a NULL payload. Then waiting for - // Signal() to place the payload below. - queue.push_back(nullptr); - } - } - - auto has_available_signal = [&]() { - mutex_lock l(mu_); - Queue& queue = *table_[key_hash]; - return nullptr != queue.front(); - }; - while (!has_available_signal()) { - // Busy waiting. As we don't anticipate the blocking occurs frequently, - // this busy waiting should be fine. If this creates any performance - // overhead, we may implement conditional var wait. 
- std::this_thread::sleep_for(std::chrono::nanoseconds(100)); - } - - mutex_lock l(mu_); - Queue* queue = table_[key_hash]; - Payload* payload = queue->front(); - std::shared_ptr event = payload->event; - queue->pop_front(); - if (queue->empty()) { - table_.erase(key_hash); - delete queue; - } - if (event) { - HVD_GPU_CHECK(hipStreamWaitEvent(stream, *event, /*flags=*/0)); - } - delete payload; - } - -private: - // This method is not thread-safe. - void InitQueue(uint64 key_hash) { - auto it = table_.find(key_hash); - if (it == table_.end()) { - table_[key_hash] = new Queue(); - } - } - -private: - // `nullptr` denotes non-readiness of the payload. - typedef std::deque Queue; - // maps a hash value to queue. We will use tensor_names to generate the hash - // values. - typedef absl::flat_hash_map Table; - - mutex mu_; - Table table_ GUARDED_BY(mu_); -}; - -/*static*/ HVDCustomCallRendezvous* GetHVDCustomCallRendezvous() { - static HVDCustomCallRendezvous* self = new HVDCustomCallRendezvous(); - return self; -} - -class XLAReadyEvent : public common::ReadyEvent { -public: - XLAReadyEvent(hipStream_t stream) : stream_(stream) { - HVD_GPU_CHECK(hipEventCreate(&event_)); - HVD_GPU_CHECK(hipEventRecord(event_, stream)); - } - ~XLAReadyEvent() { HVD_GPU_CHECK(hipEventDestroy(event_)); } - - bool Ready() const override { - hipError_t result = hipEventQuery(event_); - return hipErrorNotReady != result; - } - gpuEvent_t event() const override { return event_; } - -private: - hipStream_t stream_; // Not Owned. - hipEvent_t event_; // Owned. -}; - -class XLATensor : public common::Tensor { -public: - XLATensor(common::DataType type, common::TensorShape shape, void* buffer) - : type_(type), shape_(std::move(shape)), buffer_(buffer) {} - - virtual const common::DataType dtype() const override { return type_; } - virtual const common::TensorShape shape() const override { return shape_; } - virtual const void* data() const override { return buffer_; } - virtual int64_t size() const override { - return shape_.num_elements() * common::DataType_Size(type_); - } - -protected: - common::DataType type_; - common::TensorShape shape_; - void* buffer_; // Not owned. -}; - -class XLAOpContext : public common::OpContext { -public: - XLAOpContext(int device) : device_(device) {} - - virtual common::Status AllocatePersistent( - int64_t size, std::shared_ptr* tensor) override; - - virtual common::Status - AllocateOutput(common::TensorShape shape, - std::shared_ptr* tensor) override; - - virtual common::Status - AllocateZeros(int64_t num_elements, common::DataType dtype, - std::shared_ptr* tensor) override; - - virtual common::Framework framework() const override { - return common::Framework::XLA; - } - -private: - int device_; -}; - -class XLAPersistentBuffer : public common::PersistentBuffer { -public: - XLAPersistentBuffer(int device, int64_t size); - virtual const void* - AccessData(std::shared_ptr context) const override; - -private: - int device_; - void* buffer_; -}; - -XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) - : device_(device) { - int restore_device; - HVD_GPU_CHECK(hipGetDevice(&restore_device)); - HVD_GPU_CHECK(hipSetDevice(device)); - // Simply call cudaMalloc for persistent buffer. 
- HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); - HVD_GPU_CHECK(hipSetDevice(restore_device)); -} - -const void* XLAPersistentBuffer::AccessData( - std::shared_ptr /*context*/) const { - return buffer_; -} - -common::Status XLAOpContext::AllocatePersistent( - int64_t size, std::shared_ptr* tensor) { - *tensor = std::make_shared(device_, size); - return common::Status::OK(); -} - -common::Status -XLAOpContext::AllocateOutput(common::TensorShape shape, - std::shared_ptr* tensor) { - // XLA must manage I/O buffers. - return common::Status::PreconditionError( - "AllocateOutput is not supported for XLA."); -} - -common::Status -XLAOpContext::AllocateZeros(int64_t num_elements, common::DataType dtype, - std::shared_ptr* tensor) { - // XLA must manage I/O buffers. - return common::Status::PreconditionError( - "AllocateZeros is not supported for XLA."); -} - -common::ReadyEvent* RecordReadyEvent(hipStream_t stream) { - return new XLAReadyEvent(stream); -} - -int GetDeviceOrdinal(void* ptr) { - hipPointerAttribute_t attrs; - HVD_GPU_CHECK(hipPointerGetAttributes(&attrs, ptr)); - return attrs.device; -} - -// Implements for the `HVDAllreduce` HLO CustomCall. -void CallbackHVDAllreduce(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - CHECK(common::CheckInitialized().ok()); - CustomCallConfig config; - config.ParseFromString(std::string(opaque, opaque_len)); - - // Enqueue requests to the Horovod runtime. - common::ReadyEventList ready_event_list; - ready_event_list.AddReadyEvent( - std::shared_ptr(RecordReadyEvent(stream))); - int dev_ordinal = GetDeviceOrdinal(buffers[0]); - auto hvd_context = std::make_shared(dev_ordinal); - auto hvd_input = std::make_shared( - config.tensor_type_, common::TensorShape(config.input_shapes_[0]), - buffers[0]); - auto hvd_output = std::make_shared( - config.tensor_type_, common::TensorShape(config.input_shapes_[0]), - buffers[1]); - common::Status enqueue_result = EnqueueTensorAllreduce( - hvd_context, hvd_input, hvd_output, ready_event_list, config.tensor_name_, - dev_ordinal, - [=](const common::Status& status) { - // When request is done processing, signal `HVDAllreduceDone`. - CHECK(status.ok()) << status.reason(); - GetHVDCustomCallRendezvous()->Signal(config.tensor_name_, status.event); - }, - (horovod::common::ReduceOp)config.reduce_op_, - (double)config.prescale_factor_, (double)config.postscale_factor_, - config.process_set_id_); - CHECK(enqueue_result.ok()) << enqueue_result.reason(); -} - -// Implements for the `HVDAllreduceDone` HLO CustomCall. -void CallbackHVDAllreduceDone(hipStream_t stream, void** /*buffers*/, - const char* opaque, size_t opaque_len) { - // Blocking until the request is done processing by the Horovod runtime. 
- VLOG(2) << "hvd-allreduce-done - Start"; - CustomCallConfig config; - config.ParseFromString(std::string(opaque, opaque_len)); - GetHVDCustomCallRendezvous()->Wait(config.tensor_name_, stream); - VLOG(2) << "hvd-allreduce-done - End"; -} - -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); - -} // namespace -} // namespace tensorflow -} // namespace horovod -#endif //HAVE_ROCM #endif // HAVE_GPU #endif // TENSORFLOW_VERSION >= 2006000000 From b75eba96384b51a855d2bc4a953da5f8833899d9 Mon Sep 17 00:00:00 2001 From: weihanmines <39137181+weihanmines@users.noreply.github.com> Date: Mon, 16 May 2022 12:23:23 -0500 Subject: [PATCH 08/11] Update horovod/common/ops/rocm/CMakeLists.txt Co-authored-by: Enrico Minack Signed-off-by: weihanmines --- horovod/common/ops/rocm/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horovod/common/ops/rocm/CMakeLists.txt b/horovod/common/ops/rocm/CMakeLists.txt index b20b85ec0e..e0bbf65836 100644 --- a/horovod/common/ops/rocm/CMakeLists.txt +++ b/horovod/common/ops/rocm/CMakeLists.txt @@ -1,4 +1,4 @@ -message(STATUS "Built Horovod for ROCm") +message(STATUS "Build Horovod for ROCm") if (NOT DEFINED HCC_APTH) if (DEFINED ENV{HCC_PATH}) set(HIP_PATH ${HCC_PATH} CACHE PATH "Path to which HCC has been installed") From f7aa621a349cb970f30e041c9a72887f589174df Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 16 May 2022 18:18:35 +0000 Subject: [PATCH 09/11] removed an extra line in hip_operations.cc Signed-off-by: weihanmines --- horovod/common/ops/hip_operations.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/horovod/common/ops/hip_operations.cc b/horovod/common/ops/hip_operations.cc index 5307831329..d53dde3bc5 100644 --- a/horovod/common/ops/hip_operations.cc +++ b/horovod/common/ops/hip_operations.cc @@ -63,7 +63,6 @@ class GPUContext::impl { auto key2 = std::make_pair(device, stream); event->event_idx = ++hip_event_idx[key2]; - return status; } From 4f5a5db5fe23ee1c59f33a6df16cf9b2f319caf1 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 16 May 2022 18:28:35 +0000 Subject: [PATCH 10/11] Added sync warning message in [cuda|hip]_kernels.[h|cu] files Signed-off-by: weihanmines --- horovod/common/ops/cuda/cuda_kernels.cu | 3 +++ horovod/common/ops/cuda/cuda_kernels.h | 3 +++ horovod/common/ops/rocm/hip_kernels.cu | 3 +++ horovod/common/ops/rocm/hip_kernels.h | 5 ++++- 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/horovod/common/ops/cuda/cuda_kernels.cu b/horovod/common/ops/cuda/cuda_kernels.cu index 0a6e65dddf..f27b208c6d 100644 --- a/horovod/common/ops/cuda/cuda_kernels.cu +++ b/horovod/common/ops/cuda/cuda_kernels.cu @@ -13,6 +13,9 @@ // limitations under the License. // ============================================================================= +// ATTENTION: Any change here might obsolete hip_kernels.cu in rocm folder. +// Please keep this file synced with hip_kernels.cu. + #include "cuda_kernels.h" #include diff --git a/horovod/common/ops/cuda/cuda_kernels.h b/horovod/common/ops/cuda/cuda_kernels.h index 70c678ad4b..b0f47dfc8d 100644 --- a/horovod/common/ops/cuda/cuda_kernels.h +++ b/horovod/common/ops/cuda/cuda_kernels.h @@ -13,6 +13,9 @@ // limitations under the License. // ============================================================================= +// ATTENTION: Any change here might obsolete hip_kernels.h in rocm folder. +// Please keep this file synced with hip_kernels.h. 
+ #ifndef CUDA_KERNELS_H #define CUDA_KERNELS_H diff --git a/horovod/common/ops/rocm/hip_kernels.cu b/horovod/common/ops/rocm/hip_kernels.cu index 8fe8e1caa9..d515ad3b4c 100644 --- a/horovod/common/ops/rocm/hip_kernels.cu +++ b/horovod/common/ops/rocm/hip_kernels.cu @@ -13,6 +13,9 @@ // limitations under the License. // ============================================================================= +// ATTENTION: Any change here might obsolete cuda_kernels.cu in cuda folder. +// Please keep this file synced with cuda_kernels.cu. + #include "hip_kernels.h" #include diff --git a/horovod/common/ops/rocm/hip_kernels.h b/horovod/common/ops/rocm/hip_kernels.h index d3d99f490c..2b578b1678 100644 --- a/horovod/common/ops/rocm/hip_kernels.h +++ b/horovod/common/ops/rocm/hip_kernels.h @@ -13,6 +13,9 @@ // limitations under the License. // ============================================================================= +// ATTENTION: Any change here might obsolete cuda_kernels.h in cuda folder. +// Please keep this file synced with cuda_kernels.h. + #ifndef HIP_KERNELS_H #define HIP_KERNELS_H @@ -45,4 +48,4 @@ void BatchedScaledD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, do } // namespace common } // namespace horovod -#endif // CUDA_KERNELS_H +#endif // HIP_KERNELS_H From ba0c354a6d97eb0acfc9006000963d5ebb2d8ee5 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Wed, 25 May 2022 19:26:12 +0000 Subject: [PATCH 11/11] fixed a comment string and added a preprocessor branch for ROCM Signed-off-by: weihanmines --- horovod/tensorflow/xla_mpi_ops.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index a52baa4e50..d2ea00c5d4 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -506,7 +506,7 @@ XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) #elif HAVE_ROCM HVD_GPU_CHECK(hipGetDevice(&restore_device)); HVD_GPU_CHECK(hipSetDevice(device)); - // Simply call cudaMalloc for persistent buffer. + // Simply call hipMalloc for persistent buffer. HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); HVD_GPU_CHECK(hipSetDevice(restore_device)); #endif @@ -597,8 +597,13 @@ void CallbackHVDAllreduceDone(gpuStream_t stream, void** /*buffers*/, VLOG(2) << "hvd-allreduce-done - End"; } +#if HAVE_CUDA XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); +#elif HAVE_ROCM +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); +#endif } // namespace } // namespace tensorflow
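
For reference, a minimal standalone sketch of the vendor-neutral pattern this series converges on: guarding shared GPU code with HAVE_GPU and writing it against the gpu* aliases (including the gpuEventQuery and gpuErrorNotReady aliases added in patch 07), so one helper compiles for both CUDA and ROCm builds. This is illustrative only and not part of the patch; the PollableEvent name, the example namespace, and the include path are assumptions.

#include "horovod/common/common.h"

#if HAVE_GPU
namespace example {

// A pollable, non-blocking readiness event written purely in terms of the
// gpu* aliases, in the spirit of XLAReadyEvent above.
class PollableEvent {
public:
  explicit PollableEvent(gpuStream_t stream) : stream_(stream) {
    HVD_GPU_CHECK(gpuEventCreateWithFlags(&event_, gpuEventDisableTiming));
    HVD_GPU_CHECK(gpuEventRecord(event_, stream));
  }
  ~PollableEvent() {
    // No gpuEventDestroy alias is visible in this series, so fall back to the
    // per-vendor call, as XLAReadyEvent's destructor does.
#if HAVE_CUDA
    cudaEventDestroy(event_);
#elif HAVE_ROCM
    hipEventDestroy(event_);
#endif
  }
  // Returns true once all work recorded on the stream before the event has
  // completed; never blocks.
  bool Ready() const { return gpuEventQuery(event_) != gpuErrorNotReady; }
  gpuEvent_t event() const { return event_; }

private:
  gpuStream_t stream_; // Not owned.
  gpuEvent_t event_;   // Owned.
};

} // namespace example
#endif // HAVE_GPU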