
Add AMD GPU XLA Op Implementation #3486

Merged
merged 11 commits on May 26, 2022
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -350,6 +350,10 @@ if(HAVE_CUDA OR HAVE_SUB_PROJECT_CUDA)
add_subdirectory(horovod/common/ops/cuda)
endif()

if(HAVE_ROCM)
add_subdirectory(horovod/common/ops/rocm)
endif()

# if we need compatible c++ abi
# Duplicate gloo folder and add it as a new sub-project
if(HAVE_GLOO AND ((DEFINED Tensorflow_CXX11 AND NOT Tensorflow_CXX11) OR (DEFINED Pytorch_CXX11 AND NOT Pytorch_CXX11) OR (DEFINED Mxnet_CXX11 AND NOT Mxnet_CXX11)))
8 changes: 7 additions & 1 deletion horovod/common/common.h
@@ -33,9 +33,12 @@
using gpuError_t = cudaError_t;
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
using gpuPointerAttribute_t = cudaPointerAttributes;
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuEventRecord cudaEventRecord
#define gpuEventQuery cudaEventQuery
#define gpuErrorNotReady cudaErrorNotReady
#define gpuEventSynchronize cudaEventSynchronize
#define gpuStreamWaitEvent cudaStreamWaitEvent
#define HVD_GPU_CHECK(x) \
@@ -45,15 +48,17 @@ using gpuStream_t = cudaStream_t;
throw std::logic_error(std::string("GPU Error:") + cudaGetErrorString(cuda_result)); \
} \
} while (0)
#endif
#elif HAVE_ROCM
#include <hip/hip_runtime_api.h>
using gpuError_t = hipError_t;
using gpuEvent_t = hipEvent_t;
using gpuStream_t = hipStream_t;
using gpuPointerAttribute_t = hipPointerAttribute_t;
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuEventRecord hipEventRecord
#define gpuEventQuery hipEventQuery
#define gpuErrorNotReady hipErrorNotReady
#define gpuEventSynchronize hipEventSynchronize
#define gpuStreamWaitEvent hipStreamWaitEvent
#define HVD_GPU_CHECK(x) \
@@ -64,6 +69,7 @@ using gpuStream_t = hipStream_t;
} \
} while (0)
#endif
#endif


namespace horovod {
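The common.h changes above let backend-neutral code be written once against the gpu* aliases. The following standalone sketch is not part of the PR; it mirrors the aliasing in the header so that one helper compiles under nvcc (-DHAVE_CUDA=1) or hipcc (-DHAVE_ROCM=1). The gpuSuccess and gpuGetErrorString aliases are added here only to keep the sketch self-contained.

// Illustrative sketch only -- not part of this PR.
#include <stdexcept>
#include <string>

#if HAVE_CUDA
#include <cuda_runtime.h>
using gpuError_t = cudaError_t;
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
#define gpuSuccess cudaSuccess                // alias assumed for this sketch
#define gpuGetErrorString cudaGetErrorString  // alias assumed for this sketch
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuEventRecord cudaEventRecord
#define gpuEventSynchronize cudaEventSynchronize
#elif HAVE_ROCM
#include <hip/hip_runtime_api.h>
using gpuError_t = hipError_t;
using gpuEvent_t = hipEvent_t;
using gpuStream_t = hipStream_t;
#define gpuSuccess hipSuccess                 // alias assumed for this sketch
#define gpuGetErrorString hipGetErrorString   // alias assumed for this sketch
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuEventRecord hipEventRecord
#define gpuEventSynchronize hipEventSynchronize
#endif

// Local stand-in for HVD_GPU_CHECK, named differently to avoid clashing
// with the real macro defined in common.h.
#define GPU_CHECK(x)                                                  \
  do {                                                                \
    gpuError_t err = (x);                                             \
    if (err != gpuSuccess) {                                          \
      throw std::logic_error(std::string("GPU Error:") +             \
                             gpuGetErrorString(err));                 \
    }                                                                 \
  } while (0)

// Backend-neutral helper: record an event on a stream and block until it fires.
void WaitForStream(gpuStream_t stream) {
  gpuEvent_t event;
  GPU_CHECK(gpuEventCreateWithFlags(&event, gpuEventDisableTiming));
  GPU_CHECK(gpuEventRecord(event, stream));
  GPU_CHECK(gpuEventSynchronize(event));
}

Because only the aliases differ per backend, the same helper body serves both CUDA and ROCm builds, which is the point of widening common.h in this PR.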
3 changes: 3 additions & 0 deletions horovod/common/ops/cuda/cuda_kernels.cu
@@ -13,6 +13,9 @@
// limitations under the License.
// =============================================================================

// ATTENTION: Any change here might obsolete hip_kernels.cu in rocm folder.
// Please keep this file synced with hip_kernels.cu.

#include "cuda_kernels.h"

#include <stdexcept>
3 changes: 3 additions & 0 deletions horovod/common/ops/cuda/cuda_kernels.h
@@ -13,6 +13,9 @@
// limitations under the License.
// =============================================================================

// ATTENTION: Any change here might obsolete hip_kernels.h in rocm folder.
// Please keep this file synced with hip_kernels.h.

#ifndef CUDA_KERNELS_H
#define CUDA_KERNELS_H

39 changes: 35 additions & 4 deletions horovod/common/ops/gpu_operations.cc
@@ -18,6 +18,9 @@
#if HAVE_CUDA
#include "cuda/cuda_kernels.h"
#endif
#if HAVE_ROCM
#include "rocm/hip_kernels.h"
#endif

#include <thread>

@@ -151,7 +154,7 @@ bool GPUAllreduce::Enabled(const ParameterManager& param_manager,
return entries[0].device != CPU_DEVICE_ID;
}

#if HAVE_CUDA
#if HAVE_GPU
void GPUAllreduce::MemcpyInFusionBuffer(
const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
void*& buffer_data, size_t& buffer_len) {
@@ -185,10 +188,17 @@ void GPUAllreduce::MemcpyInFusionBuffer(

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
// Perform batched d2d memcpy
#if HAVE_CUDA
BatchedD2DMemcpyCudaImpl(
d2d_params, count,
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#elif HAVE_ROCM
BatchedD2DMemcpyROCmImpl(
d2d_params, count,
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#endif
// TODO: https://github.com/horovod/horovod/issues/2230
// gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
// cudaGetLastError());
@@ -213,7 +223,7 @@ void GPUAllreduce::MemcpyInFusionBuffer(
}
#endif

#if HAVE_CUDA
#if HAVE_GPU
void GPUAllreduce::ScaleMemcpyInFusionBuffer(
const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
void*& buffer_data, size_t& buffer_len, double scale_factor) {
@@ -246,10 +256,17 @@ void GPUAllreduce::ScaleMemcpyInFusionBuffer(

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
// Perform batched d2d memcpy
#if HAVE_CUDA
BatchedScaledD2DMemcpyCudaImpl(
d2d_params, count, scale_factor, first_entry.tensor->dtype(),
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#elif HAVE_ROCM
BatchedScaledD2DMemcpyROCmImpl(
d2d_params, count, scale_factor, first_entry.tensor->dtype(),
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#endif
// TODO: https://github.com/horovod/horovod/issues/2230
// gpu_context_->ErrorCheck("BatchedScaledD2DMemcpyCudaImpl",
// cudaGetLastError());
@@ -290,7 +307,7 @@ void GPUAllreduce::MemcpyEntryInFusionBuffer(
->streams[global_state_->current_nccl_stream][first_entry.device]);
}

#if HAVE_CUDA
#if HAVE_GPU
void GPUAllreduce::MemcpyOutFusionBuffer(
const void* buffer_data, std::vector<TensorTableEntry>& entries) {
if (global_state_->batch_d2d_memcopies) {
@@ -316,10 +333,17 @@ void GPUAllreduce::MemcpyOutFusionBuffer(

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
// Perform batched d2d memcpy
#if HAVE_CUDA
BatchedD2DMemcpyCudaImpl(
d2d_params, count,
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#elif HAVE_ROCM
BatchedD2DMemcpyROCmImpl(
d2d_params, count,
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#endif
// TODO: https://github.com/horovod/horovod/issues/2230
// gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
// cudaGetLastError());
@@ -338,7 +362,7 @@ void GPUAllreduce::MemcpyOutFusionBuffer(
}
#endif

#if HAVE_CUDA
#if HAVE_GPU
void GPUAllreduce::ScaleMemcpyOutFusionBuffer(
void* buffer_data, size_t buffer_len, double scale_factor,
std::vector<TensorTableEntry>& entries) {
@@ -366,10 +390,17 @@ void GPUAllreduce::ScaleMemcpyOutFusionBuffer(

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
// Perform batched d2d memcpy
#if HAVE_CUDA
BatchedScaledD2DMemcpyCudaImpl(
d2d_params, count, scale_factor, first_entry.tensor->dtype(),
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#elif HAVE_ROCM
BatchedScaledD2DMemcpyROCmImpl(
d2d_params, count, scale_factor, first_entry.tensor->dtype(),
gpu_context_->streams[global_state_->current_nccl_stream]
[first_entry.device]);
#endif
// TODO: https://github.com/horovod/horovod/issues/2230
// gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
// cudaGetLastError());
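Each of the four fusion-buffer paths above repeats the same #if HAVE_CUDA / #elif HAVE_ROCM dispatch around the batched device-to-device memcpy. The hypothetical wrapper below is not part of the PR; it shows one way that selection could be written once. The BatchedD2DParams type and the *Impl signatures are assumed to match the call sites in the diff, and gpuStream_t / DataType are assumed to be in scope as they are in gpu_operations.cc.

// Hypothetical helper, not part of the PR: centralizes the backend dispatch
// used at the four batched-memcpy call sites above.
#if HAVE_CUDA
#include "cuda/cuda_kernels.h"
#elif HAVE_ROCM
#include "rocm/hip_kernels.h"
#endif

namespace horovod {
namespace common {

inline void BatchedD2DMemcpy(BatchedD2DParams& params, int count,
                             gpuStream_t stream) {
#if HAVE_CUDA
  BatchedD2DMemcpyCudaImpl(params, count, stream);
#elif HAVE_ROCM
  BatchedD2DMemcpyROCmImpl(params, count, stream);
#endif
}

inline void BatchedScaledD2DMemcpy(BatchedD2DParams& params, int count,
                                   double scale_factor, DataType dtype,
                                   gpuStream_t stream) {
#if HAVE_CUDA
  BatchedScaledD2DMemcpyCudaImpl(params, count, scale_factor, dtype, stream);
#elif HAVE_ROCM
  BatchedScaledD2DMemcpyROCmImpl(params, count, scale_factor, dtype, stream);
#endif
}

} // namespace common
} // namespace horovod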
2 changes: 1 addition & 1 deletion horovod/common/ops/gpu_operations.h
@@ -160,7 +160,7 @@ class GPUAllreduce : public AllreduceOp {
const Response& response) const override;

protected:
#if HAVE_CUDA
#if HAVE_GPU
void MemcpyInFusionBuffer(const std::vector<TensorTableEntry>& entries,
const void*& fused_input_data, void*& buffer_data,
size_t& buffer_len) override;
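The header change above widens the guard from HAVE_CUDA to HAVE_GPU so the same overrides are declared for ROCm builds. How HAVE_GPU gets defined is not shown in this diff; the lines below are only a sketch of the assumed relationship, in case the build system did not already provide the macro.

// Assumption for illustration only: HAVE_GPU is taken to mean "either backend
// is enabled". In Horovod this is expected to come from the build system.
#if !defined(HAVE_GPU) && (defined(HAVE_CUDA) || defined(HAVE_ROCM))
#define HAVE_GPU 1
#endif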
118 changes: 97 additions & 21 deletions horovod/common/ops/hip_operations.cc
@@ -14,17 +14,18 @@
// limitations under the License.
// =============================================================================

#include "../hashes.h"
#include "../message.h"
#include "gpu_operations.h"
#include "rocm/hip_kernels.h"

#include <thread>

namespace horovod {
namespace common {

class GPUContext::impl {
public:
hipError_t GetGpuEvent(hipEvent_t* event) {
hipError_t GetGpuEvent(Event* event, hipStream_t stream) {
int device;
auto status = hipGetDevice(&device);
if (status != hipSuccess) {
@@ -34,18 +35,38 @@ class GPUContext::impl {
auto& mutex = hip_events_mutex;
{
std::lock_guard<std::mutex> guard(mutex);
auto& queue = hip_events[device];
auto key = std::make_pair(device, stream);
auto& queue = hip_events[key];
if (!prepopulated[key]) {
// On first call for device and stream pair, prepopulate event queue.
// This is to minimize event reuse of callback events passed to
// framework.
for (int i = 0; i < N_HIP_EVENTS_PREPOPULATE; ++i) {
hipEvent_t ev;
status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
queue.emplace(std::make_shared<hipEvent_t>(ev), stream);
}
prepopulated[key] = true;
}
if (!queue.empty()) {
*event = queue.front();
event->event_idx = ++hip_event_idx[key];
queue.pop();
return hipSuccess;
}
}

return hipEventCreateWithFlags(event, hipEventDisableTiming);
hipEvent_t ev;
status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
event->event = std::make_shared<hipEvent_t>(ev);
event->stream = stream;
auto key2 = std::make_pair(device, stream);
event->event_idx = ++hip_event_idx[key2];

return status;
}

hipError_t ReleaseGpuEvent(hipEvent_t event) {
hipError_t ReleaseGpuEvent(Event event) {
int device;
auto status = hipGetDevice(&device);
if (status != hipSuccess) {
@@ -55,7 +76,7 @@ class GPUContext::impl {
auto& mutex = hip_events_mutex;
{
std::lock_guard<std::mutex> guard(mutex);
auto& queue = hip_events[device];
auto& queue = hip_events[std::make_pair(device, event.stream)];
queue.push(event);
}

@@ -69,22 +90,59 @@ class GPUContext::impl {
}
}

void RecordEvent(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
void RecordEvent(std::queue<std::pair<std::string, Event>>& event_queue,
std::string name, hipStream_t& stream) {
hipEvent_t event;
ErrorCheck("GetGpuEvent", GetGpuEvent(&event));
ErrorCheck("hipEventRecord", hipEventRecord(event, stream));
Event event;
ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
ErrorCheck("hipEventRecord",
hipEventRecord(*(event.event), event.stream));
event_queue.emplace(name, event);
}

void
WaitForEvents(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
Event RecordEvent(hipStream_t& stream) {
Event event;
ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
ErrorCheck("hipEventRecord",
hipEventRecord(*(event.event), event.stream));
return event;
}

void WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries,
Timeline& timeline,
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
hipEvent_t event;
Event event;
std::tie(name, event) = event_queue.front();
event_queue.pop();
if (name != "") {
timeline.ActivityStartAll(entries, name);
}

hipError_t hip_result = hipEventSynchronize(*(event.event));
if (hip_result != hipSuccess) {
throw std::logic_error(std::string("cudaEventSynchronize failed: ") +
hipGetErrorString(hip_result));
}
if (error_check_callback) {
error_check_callback();
}

if (name != "") {
timeline.ActivityEndAll(entries);
}
ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
}
}

void WaitForEventsElastic(
std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback) {
while (!event_queue.empty()) {
std::string name;
Event event;
std::tie(name, event) = event_queue.front();
event_queue.pop();
if (name != "") {
@@ -95,13 +153,13 @@ class GPUContext::impl {
// complete
hipError_t hip_result;
while (true) {
hip_result = hipEventQuery(event);
hip_result = hipEventQuery(*(event.event));
if (hip_result == hipSuccess) {
break;
}

if (hip_result != hipErrorNotReady) {
throw std::logic_error(std::string("hipEventQuery failed: ") +
throw std::logic_error(std::string("cudaEventQuery failed: ") +
hipGetErrorString(hip_result));
}

@@ -118,11 +176,25 @@
}
}

void WaitForEventsElastic(
std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
const std::vector<TensorTableEntry>& entries, Timeline& timeline,
const std::function<void()>& error_check_callback) {
WaitForEvents(event_queue, entries, timeline, error_check_callback);
void ClearEvents(std::queue<std::pair<std::string, Event>>& event_queue,
const std::vector<TensorTableEntry>& entries,
Timeline& timeline,
const std::function<void()>& error_check_callback,
bool elastic) {
while (!event_queue.empty()) {
std::string name;
Event event;
std::tie(name, event) = event_queue.front();
event_queue.pop();
if (name != "") {
timeline.ActivityStartAll(entries, name);
}

if (name != "") {
timeline.ActivityEndAll(entries);
}
ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
}
}

void StreamCreate(hipStream_t* stream) {
@@ -176,8 +248,12 @@ class GPUContext::impl {
private:
// We reuse HIP events as it appears that their creation carries non-zero
// cost.
std::unordered_map<int, std::queue<hipEvent_t>> hip_events;
std::unordered_map<std::pair<int, hipStream_t>, std::queue<Event>>
hip_events;
std::unordered_map<std::pair<int, hipStream_t>, bool> prepopulated;
std::unordered_map<std::pair<int, hipStream_t>, std::atomic<uint64_t>> hip_event_idx;
std::mutex hip_events_mutex;
static constexpr int N_HIP_EVENTS_PREPOPULATE = 128;
};

#include "gpu_context_impl.cc"
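The bulk of the hip_operations.cc change replaces the old per-device event queue with a pool keyed by (device, stream), prepopulated on first use so that callback events handed to the framework are rarely reused right away. Below is a standalone sketch of that pattern, not Horovod's actual GPUContext::impl: Horovod hashes the pair key via common/hashes.h and tracks extra Event bookkeeping (event_idx, stream member), both of which are simplified away here so the sketch compiles on its own with hipcc.

// Standalone sketch, not Horovod's implementation: a HIP event pool keyed by
// (device, stream) with prepopulation on first use, mirroring the
// GetGpuEvent/ReleaseGpuEvent changes above.
#include <hip/hip_runtime_api.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <unordered_map>
#include <utility>

// Minimal stand-in for the std::pair hash Horovod gets from common/hashes.h.
struct PairHash {
  template <class A, class B>
  std::size_t operator()(const std::pair<A, B>& p) const {
    return std::hash<A>()(p.first) ^ (std::hash<B>()(p.second) << 1);
  }
};

class HipEventPool {
public:
  // Hands out a reusable event for (current device, stream); on the first
  // request for a key, the queue is prepopulated to limit early event reuse.
  hipError_t Get(std::shared_ptr<hipEvent_t>& out, hipStream_t stream) {
    int device;
    hipError_t status = hipGetDevice(&device);
    if (status != hipSuccess) return status;

    auto key = std::make_pair(device, stream);
    std::lock_guard<std::mutex> guard(mutex_);
    auto& queue = events_[key];
    if (!prepopulated_[key]) {
      for (int i = 0; i < kPrepopulate; ++i) {
        hipEvent_t ev;
        status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
        if (status != hipSuccess) return status;
        queue.push(std::make_shared<hipEvent_t>(ev));
      }
      prepopulated_[key] = true;
    }
    if (!queue.empty()) {
      out = queue.front();
      queue.pop();
      return hipSuccess;
    }
    // Pool exhausted: fall back to creating a fresh event, as the code above does.
    hipEvent_t ev;
    status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
    out = std::make_shared<hipEvent_t>(ev);
    return status;
  }

  // Returns an event to the queue for its (device, stream) key so it can be reused.
  hipError_t Release(std::shared_ptr<hipEvent_t> ev, hipStream_t stream) {
    int device;
    hipError_t status = hipGetDevice(&device);
    if (status != hipSuccess) return status;
    std::lock_guard<std::mutex> guard(mutex_);
    events_[std::make_pair(device, stream)].push(std::move(ev));
    return hipSuccess;
  }

private:
  static constexpr int kPrepopulate = 128;  // matches N_HIP_EVENTS_PREPOPULATE above
  std::unordered_map<std::pair<int, hipStream_t>,
                     std::queue<std::shared_ptr<hipEvent_t>>, PairHash>
      events_;
  std::unordered_map<std::pair<int, hipStream_t>, bool, PairHash> prepopulated_;
  std::mutex mutex_;
};

Keying by stream as well as device matters because events recorded on one stream are later waited on (or handed to the framework) in stream order; pooling per stream keeps an event from being recycled onto a different stream while still in flight.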