Add support for MXNet async dependency engine. (#3242)
Signed-off-by: Josh Romero <joshr@nvidia.com>
romerojosh committed Nov 9, 2021
1 parent bf85497 commit b5d121e
Showing 6 changed files with 141 additions and 6 deletions.
1 change: 1 addition & 0 deletions horovod/common/common.h
@@ -188,6 +188,7 @@ struct Event {
Event(std::shared_ptr<gpuEvent_t> event, gpuStream_t stream) :
event(event), stream(stream) {};
std::shared_ptr<gpuEvent_t> event;
uint64_t event_idx;
gpuStream_t stream = nullptr;
#endif
};
5 changes: 5 additions & 0 deletions horovod/common/ops/cuda_operations.cc
@@ -50,6 +50,7 @@ class GPUContext::impl {
}
if (!queue.empty()) {
*event = queue.front();
event->event_idx = ++cuda_event_idx[key];
queue.pop();
return cudaSuccess;
}
@@ -59,6 +60,9 @@ class GPUContext::impl {
status = cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
event->event = std::make_shared<cudaEvent_t>(ev);
event->stream = stream;
auto key2 = std::make_pair(device, stream);
event->event_idx = ++cuda_event_idx[key2];


return status;
}
@@ -255,6 +259,7 @@ class GPUContext::impl {
std::unordered_map<std::pair<int, cudaStream_t>, std::queue<Event>>
cuda_events;
std::unordered_map<std::pair<int, cudaStream_t>, bool> prepopulated;
std::unordered_map<std::pair<int, cudaStream_t>, std::atomic<uint64_t>> cuda_event_idx;
std::mutex cuda_events_mutex;

static constexpr int N_CUDA_EVENTS_PREPOPULATE = 128;
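The two hunks above work together: common.h gives each Event a generation number (event_idx), and the event pool in cuda_operations.cc bumps a per-(device, stream) atomic counter whenever an event is handed out, whether it comes from the reuse queue or is freshly created. MXNet's async dependency engine later uses this index (as pool_index) to tell which recording of a reusable event an entry refers to. A minimal standalone sketch of the counter pattern follows; PairHash, PooledEvent, and EventPool are hypothetical names rather than Horovod's GPUContext::impl, and external locking around pool access is assumed, as Horovod does with cuda_events_mutex.

// Sketch only: per-(device, stream) generation counters for pooled GPU events.
#include <atomic>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>

struct PairHash {  // hypothetical; Horovod supplies its own hash for this key type
  std::size_t operator()(const std::pair<int, const void*>& p) const {
    return std::hash<int>()(p.first) ^ (std::hash<const void*>()(p.second) << 1);
  }
};

struct PooledEvent {
  uint64_t event_idx = 0;  // how many events this (device, stream) pool has issued
};

class EventPool {
 public:
  // Caller is expected to hold the pool mutex, mirroring cuda_events_mutex above.
  PooledEvent Get(int device, const void* stream) {
    auto key = std::make_pair(device, stream);
    PooledEvent ev;
    ev.event_idx = ++event_idx_[key];  // per-key counter, first value is 1
    return ev;
  }

 private:
  std::unordered_map<std::pair<int, const void*>, std::atomic<uint64_t>, PairHash>
      event_idx_;
};

The resulting index travels with the Event and reappears below as the pool_index stored in MXNet's reader/writer event entries.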
126 changes: 124 additions & 2 deletions horovod/mxnet/mpi_ops.cc
@@ -20,6 +20,10 @@
#include "cuda_util.h"
#include "mpi_ops.h"

#if MXNET_MAJOR >= 2 || MXNET_ASYNC_GPU_ENGINE_SUPPORTED
#define MXNET_ASYNC_GPU_ENGINE_SUPPORTED 1
#endif

namespace horovod {
namespace mxnet {

@@ -72,7 +76,66 @@ bool IsTensorOnCPU(NDArray* tensor) {
return tensor->ctx().dev_mask() == cpu::kDevMask;
}

#if HAVE_CUDA
class MXReadyEvent : public common::ReadyEvent {
public:
MXReadyEvent(gpuEvent_t event) : event_(event) {};
bool Ready() const override {
HVD_GPU_CHECK(gpuEventSynchronize(event_));
return true;
};
gpuEvent_t event() const override {
return event_;
}

private:
gpuEvent_t event_;
};
#endif

ReadyEventList FormReadyEventList(NDArray* input, NDArray* output) {
ReadyEventList ready_event_list;

#if HAVE_CUDA && MXNET_ASYNC_GPU_ENGINE_SUPPORTED
// Get events from input tensor writers
{
auto& sync_obj = input->var()->sync_object;
std::lock_guard<std::mutex> l(sync_obj.mutex);
if (!sync_obj.writer_event.empty()) {
auto ev = sync_obj.writer_event[0].event.lock();
if (ev) {
ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
}
}
}

// Get events from output tensor reader and writers
{
auto& sync_obj = output->var()->sync_object;
std::lock_guard<std::mutex> l(sync_obj.mutex);
for (auto& cuda_event : sync_obj.reader_events) {
auto ev = cuda_event.event.lock();
if (ev) {
ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
}
}
if (!sync_obj.writer_event.empty()) {
auto ev = sync_obj.writer_event[0].event.lock();
if (ev) {
ready_event_list.AddReadyEvent(std::make_shared<MXReadyEvent>(*ev));
}
}
}
#endif
return ready_event_list;
}

#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
void DoHorovodOperation(void*, void* on_start_ptr, void* on_complete_ptr, void* param) {
auto on_start = *static_cast<CallbackOnStart*>(on_start_ptr);
#else
void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
#endif
ThrowIfError(common::CheckInitialized());

auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
@@ -91,14 +154,17 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
std::vector<ReadyEventList> ready_event_lists;
hvd_tensors.reserve(num_tensors);
hvd_contexts.reserve(num_tensors);
ready_event_lists.resize(num_tensors);
ready_event_lists.reserve(num_tensors);
callbacks.reserve(num_tensors);

auto callback_mutex = std::make_shared<std::mutex>();
for (int i = 0; i < num_tensors; ++i) {
auto input_tensor = ops_param->input_tensors[i].get();
auto output_tensor = ops_param->output_tensors[i].get();
auto output = ops_param->outputs[i];

ready_event_lists.emplace_back(FormReadyEventList(input_tensor, output_tensor));

hvd_tensors.emplace_back(std::make_shared<MXTensor>(input_tensor));
if (TensorUtil::GetDevice(input_tensor) != device) {
throw std::logic_error("Tensors in list must be on same device.");
@@ -109,10 +175,56 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
}
hvd_contexts.push_back(ctx);
callbacks.emplace_back([on_complete, ops_param, callback_mutex, i](const Status& status) {
auto input_tensor = ops_param->input_tensors[i].get();
auto output_tensor = ops_param->output_tensors[i].get();
#if HAVE_CUDA
auto hvd_event = status.event;
if (hvd_event.event) {
#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
auto async_engine_enabled = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
if (async_engine_enabled) {
{
auto &sync_obj = input_tensor->var()->sync_object;
std::lock_guard<std::mutex> l(sync_obj.mutex);
// If some reader event is already recorded on the same stream,
// we want to replace ourselves by it
int i;
for (i = 0; i < sync_obj.reader_events.size(); ++i) {
auto stream = sync_obj.reader_events[i].stream;
if (stream == hvd_event.stream) {
sync_obj.reader_events[i].event = hvd_event.event;
sync_obj.reader_events[i].pool_index = hvd_event.event_idx;
break;
}
}
if (i == sync_obj.reader_events.size()) {
sync_obj.reader_events.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
}
}

{
auto &sync_obj = output_tensor->var()->sync_object;
std::lock_guard<std::mutex> l(sync_obj.mutex);
sync_obj.reader_events.clear();
sync_obj.writer_event.clear();
sync_obj.writer_event.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
}

if (ops_param->received_splits_tensor) {
{
auto &sync_obj = ops_param->received_splits_tensor.get()->var()->sync_object;
std::lock_guard<std::mutex> l(sync_obj.mutex);
sync_obj.reader_events.clear();
sync_obj.writer_event.clear();
sync_obj.writer_event.push_back({hvd_event.event, hvd_event.stream, hvd_event.event_idx});
}
}
} else {
HVD_GPU_CHECK(gpuEventSynchronize(*(hvd_event.event)));
}
#else
HVD_GPU_CHECK(gpuEventSynchronize(*(hvd_event.event)));
#endif
}
#endif

@@ -163,6 +275,9 @@ void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
break;
case OperationType::ALLTOALL:
{
#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
on_start(); // Need to call on_start to sync on possible D2H copy of splits tensor.
#endif
auto hvd_splits = std::make_shared<MXTensor>(ops_param->splits_tensor.get());
enqueue_result = EnqueueTensorAlltoall(
hvd_contexts[0], hvd_tensors[0], hvd_splits, ready_event_lists[0],
@@ -281,7 +396,12 @@ inline void PushHorovodOperation(OperationType op_type, NDArray* const * inputs,
}
}
#if HAVE_CUDA
#if MXNET_ASYNC_GPU_ENGINE_SUPPORTED
void DoHorovodOperationCudaOnCPU(void*, void* on_start_ptr, void* on_complete_ptr, void* param) {
auto on_start = *static_cast<CallbackOnStart*>(on_start_ptr);
#else
void DoHorovodOperationCudaOnCPU(void*, void* on_complete_ptr, void* param) {
#endif
ThrowIfError(common::CheckInitialized());

auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
@@ -299,14 +419,16 @@ void DoHorovodOperationCudaOnCPU(void*, void* on_complete_ptr, void* param) {
std::vector<ReadyEventList> ready_event_lists;
hvd_cpu_buffers.reserve(num_tensors);
hvd_contexts.reserve(num_tensors);
ready_event_lists.resize(num_tensors);
ready_event_lists.reserve(num_tensors);
callbacks.reserve(num_tensors);

auto callback_mutex = std::make_shared<std::mutex>();
for (int i = 0; i < num_tensors; ++i) {
auto input = ops_param->cpu_input_tensors[i].get();
auto output = ops_param->cpu_output_tensors[i].get();

ready_event_lists.emplace_back(FormReadyEventList(input, output));

hvd_cpu_buffers.emplace_back(std::make_shared<MXTensor>(input));
if (TensorUtil::GetDevice(input) != device) {
throw std::logic_error("Tensors in list must be on same device.");
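The bulk of this file's change is the completion callback: when MXNET_ASYNC_GPU_ENGINE is enabled, it no longer blocks on gpuEventSynchronize but hands Horovod's completion event back to each NDArray's sync_object — for inputs it replaces any reader event already recorded on the same stream (or appends a new one), and for outputs and the received-splits tensor it clears previous readers/writers so the Horovod event becomes the sole writer the engine waits on. A condensed sketch of that replace-or-append logic, using placeholder SyncObject/EventEntry types rather than MXNet's real engine structures (which hold weak event handles):

// Sketch with assumed types; illustrates the update pattern in the callback above.
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

struct EventEntry {
  std::shared_ptr<int> event;    // stand-in for the GPU event handle
  const void* stream = nullptr;  // stream the event was recorded on
  uint64_t pool_index = 0;       // generation index from the event pool
};

struct SyncObject {
  std::mutex mutex;
  std::vector<EventEntry> reader_events;
  std::vector<EventEntry> writer_event;  // at most one entry
};

// Inputs: keep one reader event per stream; the newest recording wins.
void AddOrReplaceReader(SyncObject& obj, const EventEntry& ev) {
  std::lock_guard<std::mutex> lock(obj.mutex);
  for (auto& existing : obj.reader_events) {
    if (existing.stream == ev.stream) {
      existing.event = ev.event;
      existing.pool_index = ev.pool_index;
      return;
    }
  }
  obj.reader_events.push_back(ev);
}

// Outputs: earlier readers/writer are superseded; the completion event
// becomes the single writer later operations must wait on.
void ResetToWriter(SyncObject& obj, const EventEntry& ev) {
  std::lock_guard<std::mutex> lock(obj.mutex);
  obj.reader_events.clear();
  obj.writer_event.clear();
  obj.writer_event.push_back(ev);
}

When the environment variable is unset, the callback keeps the old behavior and synchronizes on the event before signaling completion.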
3 changes: 3 additions & 0 deletions horovod/mxnet/mpi_ops.h
@@ -33,6 +33,9 @@ using namespace horovod::common;

typedef ::mxnet::NDArray NDArray;
typedef ::mxnet::Engine::CallbackOnComplete CallbackOnComplete;
#if MXNET_MAJOR >= 2 || MXNET_ASYNC_GPU_ENGINE_SUPPORTED
typedef ::mxnet::Engine::CallbackOnStart CallbackOnStart;
#endif
typedef Request::RequestType OperationType;
typedef std::shared_ptr<MXTensor> MXTensorSharedPtr;
typedef std::shared_ptr<NDArray> NDArraySharedPtr;
7 changes: 4 additions & 3 deletions horovod/torch/cuda_util.cc
@@ -15,6 +15,7 @@

#if HAVE_GPU
#include "cuda_runtime.h"
#include <ATen/ATen.h>
#include <THC/THC.h>
#else
#include <stdexcept>
@@ -31,8 +32,8 @@ with_device::with_device(int device) {
restore_device_ = CPU_DEVICE_ID;
} else {
#if HAVE_GPU
THCudaCheck(cudaGetDevice(&restore_device_));
THCudaCheck(cudaSetDevice(device));
C10_CUDA_CHECK(cudaGetDevice(&restore_device_));
C10_CUDA_CHECK(cudaSetDevice(device));
#else
throw std::logic_error("Internal error. Requested device context manager "
"with GPU device but not compiled with CUDA.");
@@ -43,7 +44,7 @@ with_device::with_device(int device) {
with_device::~with_device() {
#if HAVE_GPU
if (restore_device_ != CPU_DEVICE_ID) {
THCudaCheck(cudaSetDevice(restore_device_));
C10_CUDA_CHECK(cudaSetDevice(restore_device_));
}
#endif
}
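This file only swaps PyTorch's deprecated THCudaCheck for C10_CUDA_CHECK; both wrap a CUDA runtime call and report any status other than cudaSuccess through PyTorch's error machinery. For readers unfamiliar with the pattern, a generic check macro of the same shape (hypothetical MY_CUDA_CHECK; the real c10 macro surfaces the failure as an exception rather than aborting) looks roughly like:

// Illustrative only; prefer the checking macro shipped with your framework.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define MY_CUDA_CHECK(call)                                          \
  do {                                                               \
    cudaError_t my_err_ = (call);                                    \
    if (my_err_ != cudaSuccess) {                                    \
      std::fprintf(stderr, "CUDA error '%s' at %s:%d\n",             \
                   cudaGetErrorString(my_err_), __FILE__, __LINE__); \
      std::abort();                                                  \
    }                                                                \
  } while (0)

// Usage mirroring the edited lines:
//   MY_CUDA_CHECK(cudaGetDevice(&restore_device_));
//   MY_CUDA_CHECK(cudaSetDevice(device));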
5 changes: 4 additions & 1 deletion test/parallel/base_test_mxnet.py
@@ -33,7 +33,10 @@
from mxnet.test_utils import almost_equal, same
import horovod.mxnet as hvd

has_gpu = mx.context.num_gpus() > 0
try:
has_gpu = mx.context.num_gpus() > 0
except AttributeError:
has_gpu = mx.device.num_gpus() > 0

ccl_supported_types = set(['int32', 'int64', 'float32', 'float64'])

