From a9e3f15ec27c7e7946df4599cb3731f1b04ac701 Mon Sep 17 00:00:00 2001
From: Josh Romero
Date: Mon, 25 Oct 2021 16:13:39 -0700
Subject: [PATCH] Add support for MXNet async dependency engine.

Signed-off-by: Josh Romero
---
 horovod/common/common.h               |   1 +
 horovod/common/ops/cuda_operations.cc |   5 +
 horovod/mxnet/mpi_ops.cc              | 126 +++++++++++++++++++++++++-
 horovod/mxnet/mpi_ops.h               |   3 +
 4 files changed, 133 insertions(+), 2 deletions(-)

diff --git a/horovod/common/common.h b/horovod/common/common.h
index a6fadbf005..037e60dd72 100644
--- a/horovod/common/common.h
+++ b/horovod/common/common.h
@@ -188,6 +188,7 @@ struct Event {
   Event(std::shared_ptr<gpuEvent_t> event, gpuStream_t stream) :
     event(event), stream(stream) {};
   std::shared_ptr<gpuEvent_t> event;
+  uint64_t event_idx;
   gpuStream_t stream = nullptr;
 #endif
 };
diff --git a/horovod/common/ops/cuda_operations.cc b/horovod/common/ops/cuda_operations.cc
index 1f8d64d332..9e0eb53d38 100644
--- a/horovod/common/ops/cuda_operations.cc
+++ b/horovod/common/ops/cuda_operations.cc
@@ -50,6 +50,7 @@ class GPUContext::impl {
       }
       if (!queue.empty()) {
         *event = queue.front();
+        event->event_idx = ++cuda_event_idx[key];
         queue.pop();
         return cudaSuccess;
       }
@@ -59,6 +60,9 @@ class GPUContext::impl {
     status = cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
     event->event = std::make_shared<gpuEvent_t>(ev);
     event->stream = stream;
+    auto key2 = std::make_pair(device, stream);
+    event->event_idx = ++cuda_event_idx[key2];
+
     return status;
   }
@@ -255,6 +259,7 @@ class GPUContext::impl {
   std::unordered_map<std::pair<int, gpuStream_t>, std::queue<Event>> cuda_events;
   std::unordered_map<std::pair<int, gpuStream_t>, bool> prepopulated;
+  std::unordered_map<std::pair<int, gpuStream_t>, std::atomic<uint64_t>> cuda_event_idx;
   std::mutex cuda_events_mutex;
 
   static constexpr int N_CUDA_EVENTS_PREPOPULATE = 128;
diff --git a/horovod/mxnet/mpi_ops.cc b/horovod/mxnet/mpi_ops.cc
index 3a778bac03..69a525c582 100644
--- a/horovod/mxnet/mpi_ops.cc
+++ b/horovod/mxnet/mpi_ops.cc
@@ -20,6 +20,10 @@
 #include "cuda_util.h"
 #include "mpi_ops.h"
 
+#if MXNET_MAJOR >= 2 || MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+#define MXNET_ASYNC_GPU_ENGINE_SUPPORTED 1
+#endif
+
 namespace horovod {
 namespace mxnet {
@@ -72,7 +76,66 @@ bool IsTensorOnCPU(NDArray* tensor) {
   return tensor->ctx().dev_mask() == cpu::kDevMask;
 }
 
+#if HAVE_CUDA
+class MXReadyEvent : public common::ReadyEvent {
+public:
+  MXReadyEvent(gpuEvent_t event) : event_(event) {};
+  bool Ready() const override {
+    HVD_GPU_CHECK(gpuEventSynchronize(event_));
+    return true;
+  };
+  gpuEvent_t event() const override {
+    return event_;
+  }
+
+private:
+  gpuEvent_t event_;
+};
+#endif
+
+ReadyEventList FormReadyEventList(NDArray* input, NDArray* output) {
+  ReadyEventList ready_event_list;
+
+#if HAVE_CUDA && MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+  // Get events from input tensor writers
+  {
+    auto& sync_obj = input->var()->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    if (!sync_obj.writer_event.empty()) {
+      auto ev = sync_obj.writer_event[0].event.lock();
+      if (ev) {
+        ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
+      }
+    }
+  }
+
+  // Get events from output tensor reader and writers
+  {
+    auto& sync_obj = output->var()->sync_object;
+    std::lock_guard<std::mutex> l(sync_obj.mutex);
+    for (auto& cuda_event : sync_obj.reader_events) {
+      auto ev = cuda_event.event.lock();
+      if (ev) {
+        ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
+      }
+    }
+    if (!sync_obj.writer_event.empty()) {
+      auto ev = sync_obj.writer_event[0].event.lock();
+      if (ev) {
+        ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
+      }
+    }
+  }
+#endif
+  return ready_event_list;
+}
+
+#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+void DoHorovodOperation(void*, void* on_start_ptr, void* on_complete_ptr, void* param) {
+  auto on_start = *static_cast<CallbackOnStart*>(on_start_ptr);
+#else
 void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
+#endif
   ThrowIfError(common::CheckInitialized());
 
   auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
@@ -91,14 +154,17 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
   std::vector<ReadyEventList> ready_event_lists;
   hvd_tensors.reserve(num_tensors);
   hvd_contexts.reserve(num_tensors);
-  ready_event_lists.resize(num_tensors);
+  ready_event_lists.reserve(num_tensors);
   callbacks.reserve(num_tensors);
   auto callback_mutex = std::make_shared<std::mutex>();
 
   for (int i = 0; i < num_tensors; ++i) {
     auto input_tensor = ops_param->input_tensors[i].get();
+    auto output_tensor = ops_param->output_tensors[i].get();
     auto output = ops_param->outputs[i];
 
+    ready_event_lists.emplace_back(FormReadyEventList(input_tensor, output_tensor));
+
     hvd_tensors.emplace_back(std::make_shared<MXTensor>(input_tensor));
     if (TensorUtil::GetDevice(input_tensor) != device) {
       throw std::logic_error("Tensors in list must be on same device.");
@@ -109,10 +175,56 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
     }
     hvd_contexts.push_back(ctx);
     callbacks.emplace_back([on_complete, ops_param, callback_mutex, i](const Status& status) {
+      auto input_tensor = ops_param->input_tensors[i].get();
+      auto output_tensor = ops_param->output_tensors[i].get();
 #if HAVE_CUDA
       auto hvd_event = status.event;
       if (hvd_event.event) {
+#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+        auto async_engine_enabled = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
+        if (async_engine_enabled) {
+          {
+            auto &sync_obj = input_tensor->var()->sync_object;
+            std::lock_guard<std::mutex> l(sync_obj.mutex);
+            // If some reader event is already recorded on the same stream,
+            // we want to replace ourselves by it
+            int i;
+            for (i = 0; i < sync_obj.reader_events.size(); ++i) {
+              auto stream = sync_obj.reader_events[i].stream;
+              if (stream == hvd_event.stream) {
+                sync_obj.reader_events[i].event = hvd_event.event;
+                sync_obj.reader_events[i].pool_index = hvd_event.event_idx;
+                break;
+              }
+            }
+            if (i == sync_obj.reader_events.size()) {
+              sync_obj.reader_events.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
+            }
+          }
+
+          {
+            auto &sync_obj = output_tensor->var()->sync_object;
+            std::lock_guard<std::mutex> l(sync_obj.mutex);
+            sync_obj.reader_events.clear();
+            sync_obj.writer_event.clear();
+            sync_obj.writer_event.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
+          }
+
+          if (ops_param->received_splits_tensor) {
+            {
+              auto &sync_obj = ops_param->received_splits_tensor.get()->var()->sync_object;
+              std::lock_guard<std::mutex> l(sync_obj.mutex);
+              sync_obj.reader_events.clear();
+              sync_obj.writer_event.clear();
+              sync_obj.writer_event.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
+            }
+          }
+        } else {
+          HVD_GPU_CHECK(gpuEventSynchronize(*(hvd_event.event)));
+        }
+#else
         HVD_GPU_CHECK(gpuEventSynchronize(*(hvd_event.event)));
+#endif
       }
 #endif
@@ -163,6 +275,9 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
       break;
     case OperationType::ALLTOALL:
     {
+#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+      on_start(); // Need to call on_start to sync on possible D2H copy of splits tensor.
+#endif
       auto hvd_splits = std::make_shared<MXTensor>(ops_param->splits_tensor.get());
       enqueue_result = EnqueueTensorAlltoall(
           hvd_contexts[0], hvd_tensors[0], hvd_splits, ready_event_lists[0],
@@ -281,7 +396,12 @@ inline void PushHorovodOperation(OperationType op_type, NDArray* const * inputs,
   }
 }
 #if HAVE_CUDA
+#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+void DoHorovodOperationCudaOnCPU(void*, void* on_start_ptr, void* on_complete_ptr, void* param) {
+  auto on_start = *static_cast<CallbackOnStart*>(on_start_ptr);
+#else
 void DoHorovodOperationCudaOnCPU(void*, void* on_complete_ptr, void* param) {
+#endif
   ThrowIfError(common::CheckInitialized());
 
   auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
@@ -299,7 +419,7 @@ void DoHorovodOperationCudaOnCPU(void*, void* on_complete_ptr, void* param) {
   std::vector<ReadyEventList> ready_event_lists;
   hvd_cpu_buffers.reserve(num_tensors);
   hvd_contexts.reserve(num_tensors);
-  ready_event_lists.resize(num_tensors);
+  ready_event_lists.reserve(num_tensors);
   callbacks.reserve(num_tensors);
   auto callback_mutex = std::make_shared<std::mutex>();
@@ -307,6 +427,8 @@ void DoHorovodOperationCudaOnCPU(void*, void* on_complete_ptr, void* param) {
     auto input = ops_param->cpu_input_tensors[i].get();
     auto output = ops_param->cpu_output_tensors[i].get();
 
+    ready_event_lists.emplace_back(FormReadyEventList(input, output));
+
     hvd_cpu_buffers.emplace_back(std::make_shared<MXTensor>(input));
     if (TensorUtil::GetDevice(input) != device) {
       throw std::logic_error("Tensors in list must be on same device.");
diff --git a/horovod/mxnet/mpi_ops.h b/horovod/mxnet/mpi_ops.h
index 482f3592c7..97854c37ff 100644
--- a/horovod/mxnet/mpi_ops.h
+++ b/horovod/mxnet/mpi_ops.h
@@ -33,6 +33,9 @@ using namespace horovod::common;
 typedef ::mxnet::NDArray NDArray;
 typedef ::mxnet::Engine::CallbackOnComplete CallbackOnComplete;
+#if MXNET_MAJOR >= 2 || MXNET_ASYNC_GPU_ENGINE_SUPPORTED
+typedef ::mxnet::Engine::CallbackOnStart CallbackOnStart;
+#endif
 typedef Request::RequestType OperationType;
 typedef std::shared_ptr<MXTensor> MXTensorSharedPtr;
 typedef std::shared_ptr<NDArray> NDArraySharedPtr;
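
For context on the synchronization pattern this patch moves to: previously the completion callback blocked the host with HVD_GPU_CHECK(gpuEventSynchronize(...)); with the async dependency engine enabled, the CUDA event recorded on Horovod's stream is instead attached to the NDArray's sync_object, and downstream consumers wait on it on their own stream. Below is a minimal standalone CUDA sketch of that record-and-wait pattern; the names (produce, consume, hvd_stream, mxnet_stream) are illustrative only and it uses no Horovod or MXNet APIs.

// Standalone sketch (illustrative names, no Horovod/MXNet APIs).
// Build with: nvcc -o event_dep event_dep.cu
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                                      \
  do {                                                                        \
    cudaError_t err = (call);                                                 \
    if (err != cudaSuccess) {                                                 \
      std::printf("CUDA error %s at %s:%d\n", cudaGetErrorString(err),        \
                  __FILE__, __LINE__);                                        \
      std::exit(1);                                                           \
    }                                                                         \
  } while (0)

// Stand-in for the collective that writes the output tensor.
__global__ void produce(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
}

// Stand-in for the downstream framework op that reads it.
__global__ void consume(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] += 1.0f;
}

int main() {
  const int n = 1 << 20;
  float* buf = nullptr;
  CHECK_CUDA(cudaMalloc(&buf, n * sizeof(float)));

  cudaStream_t hvd_stream, mxnet_stream;
  CHECK_CUDA(cudaStreamCreate(&hvd_stream));
  CHECK_CUDA(cudaStreamCreate(&mxnet_stream));

  // Event recorded after the producer finishes on its stream, analogous to
  // the Event carried in status.event (created with timing disabled).
  cudaEvent_t done;
  CHECK_CUDA(cudaEventCreateWithFlags(&done, cudaEventDisableTiming));

  produce<<<(n + 255) / 256, 256, 0, hvd_stream>>>(buf, n);
  CHECK_CUDA(cudaEventRecord(done, hvd_stream));

  // Old path: the host would block here with cudaEventSynchronize(done).
  // New path: the consumer's stream waits on the event on-device, which is
  // what attaching the event to the NDArray's sync_object lets MXNet do.
  CHECK_CUDA(cudaStreamWaitEvent(mxnet_stream, done, 0));
  consume<<<(n + 255) / 256, 256, 0, mxnet_stream>>>(buf, n);

  CHECK_CUDA(cudaStreamSynchronize(mxnet_stream));  // demo only, before exit
  CHECK_CUDA(cudaEventDestroy(done));
  CHECK_CUDA(cudaStreamDestroy(hvd_stream));
  CHECK_CUDA(cudaStreamDestroy(mxnet_stream));
  CHECK_CUDA(cudaFree(buf));
  return 0;
}

The key point of the design is that the dependency is expressed with cudaStreamWaitEvent on the device timeline, so the host thread running the callback never has to block on the collective's completion.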