diff --git a/CMakeLists.txt b/CMakeLists.txt
index e8ded61d08..7707ed0a30 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -350,6 +350,10 @@ if(HAVE_CUDA OR HAVE_SUB_PROJECT_CUDA)
     add_subdirectory(horovod/common/ops/cuda)
 endif()
 
+if(HAVE_ROCM)
+    add_subdirectory(horovod/common/ops/rocm)
+endif()
+
 # if we need compatible c++ abi
 # Duplicate gloo folder and add it as a new sub-project
 if(HAVE_GLOO AND ((DEFINED Tensorflow_CXX11 AND NOT Tensorflow_CXX11) OR (DEFINED Pytorch_CXX11 AND NOT Pytorch_CXX11) OR (DEFINED Mxnet_CXX11 AND NOT Mxnet_CXX11)))
diff --git a/horovod/common/common.h b/horovod/common/common.h
index 8876a02c9e..46565437e7 100644
--- a/horovod/common/common.h
+++ b/horovod/common/common.h
@@ -33,9 +33,12 @@ using gpuError_t = cudaError_t;
 using gpuEvent_t = cudaEvent_t;
 using gpuStream_t = cudaStream_t;
+using gpuPointerAttribute_t = cudaPointerAttributes;
 #define gpuEventCreateWithFlags cudaEventCreateWithFlags
 #define gpuEventDisableTiming cudaEventDisableTiming
 #define gpuEventRecord cudaEventRecord
+#define gpuEventQuery cudaEventQuery
+#define gpuErrorNotReady cudaErrorNotReady
 #define gpuEventSynchronize cudaEventSynchronize
 #define gpuStreamWaitEvent cudaStreamWaitEvent
 #define HVD_GPU_CHECK(x)                                                       \
@@ -45,15 +48,17 @@ using gpuStream_t = cudaStream_t;
       throw std::logic_error(std::string("GPU Error:") + cudaGetErrorString(cuda_result)); \
     }                                                                          \
   } while (0)
-#endif
 #elif HAVE_ROCM
 #include <hip/hip_runtime_api.h>
 using gpuError_t = hipError_t;
 using gpuEvent_t = hipEvent_t;
 using gpuStream_t = hipStream_t;
+using gpuPointerAttribute_t = hipPointerAttribute_t;
 #define gpuEventCreateWithFlags hipEventCreateWithFlags
 #define gpuEventDisableTiming hipEventDisableTiming
 #define gpuEventRecord hipEventRecord
+#define gpuEventQuery hipEventQuery
+#define gpuErrorNotReady hipErrorNotReady
 #define gpuEventSynchronize hipEventSynchronize
 #define gpuStreamWaitEvent hipStreamWaitEvent
 #define HVD_GPU_CHECK(x)                                                       \
@@ -64,6 +69,7 @@ using gpuStream_t = hipStream_t;
     }                                                                          \
   }                                                                            \
   } while (0)
 #endif
+#endif
 
 namespace horovod {
diff --git a/horovod/common/ops/cuda/cuda_kernels.cu b/horovod/common/ops/cuda/cuda_kernels.cu
index 0a6e65dddf..f27b208c6d 100644
--- a/horovod/common/ops/cuda/cuda_kernels.cu
+++ b/horovod/common/ops/cuda/cuda_kernels.cu
@@ -13,6 +13,9 @@
 // limitations under the License.
 // =============================================================================
 
+// ATTENTION: Any change here might obsolete hip_kernels.cu in rocm folder.
+// Please keep this file synced with hip_kernels.cu.
+
 #include "cuda_kernels.h"
 
 #include <stdexcept>
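Reviewer note: the new `gpuEventQuery`/`gpuErrorNotReady`/`gpuPointerAttribute_t` aliases in common.h exist so shared code can poll GPU events without branching on the backend. The sketch below is illustrative only, not part of this patch; `PollEvent` is a hypothetical helper and the include path is assumed, under either a HAVE_CUDA or HAVE_ROCM build.

```cpp
// Minimal sketch, assuming common.h's aliases: poll an event in a
// backend-neutral way. Maps to cudaEventQuery/cudaErrorNotReady on CUDA
// builds and hipEventQuery/hipErrorNotReady on ROCm builds.
#include "horovod/common/common.h"

namespace horovod {
namespace common {

// Returns true once the event has completed, false while work on its stream
// is still pending, and throws via HVD_GPU_CHECK on any real error.
inline bool PollEvent(gpuEvent_t event) {
  gpuError_t result = gpuEventQuery(event);
  if (result == gpuErrorNotReady) {
    return false;
  }
  HVD_GPU_CHECK(result);
  return true;
}

} // namespace common
} // namespace horovod
```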
diff --git a/horovod/common/ops/cuda/cuda_kernels.h b/horovod/common/ops/cuda/cuda_kernels.h
index 70c678ad4b..b0f47dfc8d 100644
--- a/horovod/common/ops/cuda/cuda_kernels.h
+++ b/horovod/common/ops/cuda/cuda_kernels.h
@@ -13,6 +13,9 @@
 // limitations under the License.
 // =============================================================================
 
+// ATTENTION: Any change here might obsolete hip_kernels.h in rocm folder.
+// Please keep this file synced with hip_kernels.h.
+
 #ifndef CUDA_KERNELS_H
 #define CUDA_KERNELS_H
diff --git a/horovod/common/ops/gpu_operations.cc b/horovod/common/ops/gpu_operations.cc
index 0078d604f7..0adef04198 100644
--- a/horovod/common/ops/gpu_operations.cc
+++ b/horovod/common/ops/gpu_operations.cc
@@ -18,6 +18,9 @@
 #if HAVE_CUDA
 #include "cuda/cuda_kernels.h"
 #endif
+#if HAVE_ROCM
+#include "rocm/hip_kernels.h"
+#endif
 
 #include <thread>
 
@@ -151,7 +154,7 @@ bool GPUAllreduce::Enabled(const ParameterManager& param_manager,
   return entries[0].device != CPU_DEVICE_ID;
 }
 
-#if HAVE_CUDA
+#if HAVE_GPU
 void GPUAllreduce::MemcpyInFusionBuffer(
     const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
     void*& buffer_data, size_t& buffer_len) {
@@ -185,10 +188,17 @@ void GPUAllreduce::MemcpyInFusionBuffer(
 
     if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
       // Perform batched d2d memcpy
+#if HAVE_CUDA
       BatchedD2DMemcpyCudaImpl(
           d2d_params, count,
           gpu_context_->streams[global_state_->current_nccl_stream]
                                [first_entry.device]);
+#elif HAVE_ROCM
+      BatchedD2DMemcpyROCmImpl(
+          d2d_params, count,
+          gpu_context_->streams[global_state_->current_nccl_stream]
+                               [first_entry.device]);
+#endif
       // TODO: https://github.com/horovod/horovod/issues/2230
       // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
       //                          cudaGetLastError());
@@ -213,7 +223,7 @@ void GPUAllreduce::MemcpyInFusionBuffer(
 }
 #endif
 
-#if HAVE_CUDA
+#if HAVE_GPU
 void GPUAllreduce::ScaleMemcpyInFusionBuffer(
     const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
     void*& buffer_data, size_t& buffer_len, double scale_factor) {
@@ -246,10 +256,17 @@ void GPUAllreduce::ScaleMemcpyInFusionBuffer(
 
     if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
       // Perform batched d2d memcpy
+#if HAVE_CUDA
       BatchedScaledD2DMemcpyCudaImpl(
          d2d_params, count, scale_factor, first_entry.tensor->dtype(),
          gpu_context_->streams[global_state_->current_nccl_stream]
                               [first_entry.device]);
+#elif HAVE_ROCM
+      BatchedScaledD2DMemcpyROCmImpl(
+          d2d_params, count, scale_factor, first_entry.tensor->dtype(),
+          gpu_context_->streams[global_state_->current_nccl_stream]
+                               [first_entry.device]);
+#endif
       // TODO: https://github.com/horovod/horovod/issues/2230
       // gpu_context_->ErrorCheck("BatchedScaledD2DMemcpyCudaImpl",
       //                          cudaGetLastError());
@@ -290,7 +307,7 @@ void GPUAllreduce::MemcpyEntryInFusionBuffer(
           ->streams[global_state_->current_nccl_stream][first_entry.device]);
 }
 
-#if HAVE_CUDA
+#if HAVE_GPU
 void GPUAllreduce::MemcpyOutFusionBuffer(
     const void* buffer_data, std::vector<TensorTableEntry>& entries) {
   if (global_state_->batch_d2d_memcopies) {
@@ -316,10 +333,17 @@ void GPUAllreduce::MemcpyOutFusionBuffer(
 
     if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
       // Perform batched d2d memcpy
+#if HAVE_CUDA
       BatchedD2DMemcpyCudaImpl(
           d2d_params, count,
           gpu_context_->streams[global_state_->current_nccl_stream]
                                [first_entry.device]);
+#elif HAVE_ROCM
+      BatchedD2DMemcpyROCmImpl(
+          d2d_params, count,
+          gpu_context_->streams[global_state_->current_nccl_stream]
+                               [first_entry.device]);
+#endif
       // TODO: https://github.com/horovod/horovod/issues/2230
       // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
       //                          cudaGetLastError());
@@ -338,7 +362,7 @@ void GPUAllreduce::MemcpyOutFusionBuffer(
 }
 #endif
 
-#if HAVE_CUDA
+#if HAVE_GPU
 void GPUAllreduce::ScaleMemcpyOutFusionBuffer(
     void* buffer_data, size_t buffer_len, double scale_factor,
     std::vector<TensorTableEntry>& entries) {
@@ -366,10 +390,17 @@ void GPUAllreduce::ScaleMemcpyOutFusionBuffer(
 
     if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int)entries.size()) {
       // Perform batched d2d memcpy
+#if HAVE_CUDA
       BatchedScaledD2DMemcpyCudaImpl(
          d2d_params, count, scale_factor, first_entry.tensor->dtype(),
          gpu_context_->streams[global_state_->current_nccl_stream]
                               [first_entry.device]);
+#elif HAVE_ROCM
+      BatchedScaledD2DMemcpyROCmImpl(
+          d2d_params, count, scale_factor, first_entry.tensor->dtype(),
+          gpu_context_->streams[global_state_->current_nccl_stream]
+                               [first_entry.device]);
+#endif
       // TODO: https://github.com/horovod/horovod/issues/2230
       // gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl",
      //                          cudaGetLastError());
diff --git a/horovod/common/ops/gpu_operations.h b/horovod/common/ops/gpu_operations.h
index f9342cde33..709aaa36fa 100644
--- a/horovod/common/ops/gpu_operations.h
+++ b/horovod/common/ops/gpu_operations.h
@@ -160,7 +160,7 @@ class GPUAllreduce : public AllreduceOp {
                const Response& response) const override;
 
 protected:
-#if HAVE_CUDA
+#if HAVE_GPU
   void MemcpyInFusionBuffer(const std::vector<TensorTableEntry>& entries,
                             const void*& fused_input_data, void*& buffer_data,
                             size_t& buffer_len) override;
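Reviewer note: every fusion-buffer method above now repeats the same `#if HAVE_CUDA` / `#elif HAVE_ROCM` dispatch inline. Pulled out as a sketch, it amounts to the wrapper below; `BatchedD2DMemcpyGpu` is hypothetical (this patch does not add it), and the include paths assume the layout of horovod/common/ops/.

```cpp
// Hypothetical convenience wrapper illustrating the backend dispatch used in
// gpu_operations.cc; the real patch writes this #if block at each call site.
#include "../common.h"  // assumed path, for gpuStream_t
#if HAVE_CUDA
#include "cuda/cuda_kernels.h"
#elif HAVE_ROCM
#include "rocm/hip_kernels.h"
#endif

namespace horovod {
namespace common {

inline void BatchedD2DMemcpyGpu(BatchedD2DParams& params, int num_copies,
                                gpuStream_t stream) {
#if HAVE_CUDA
  BatchedD2DMemcpyCudaImpl(params, num_copies, stream);  // CUDA build
#elif HAVE_ROCM
  BatchedD2DMemcpyROCmImpl(params, num_copies, stream);  // ROCm build
#endif
}

} // namespace common
} // namespace horovod
```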
diff --git a/horovod/common/ops/hip_operations.cc b/horovod/common/ops/hip_operations.cc
index 0479c6151e..d53dde3bc5 100644
--- a/horovod/common/ops/hip_operations.cc
+++ b/horovod/common/ops/hip_operations.cc
@@ -14,17 +14,18 @@
 // limitations under the License.
 // =============================================================================
 
+#include "../hashes.h"
 #include "../message.h"
 #include "gpu_operations.h"
+#include "rocm/hip_kernels.h"
 
 #include <thread>
 
 namespace horovod {
 namespace common {
-
 class GPUContext::impl {
 public:
-  hipError_t GetGpuEvent(hipEvent_t* event) {
+  hipError_t GetGpuEvent(Event* event, hipStream_t stream) {
     int device;
     auto status = hipGetDevice(&device);
     if (status != hipSuccess) {
@@ -34,18 +35,38 @@ class GPUContext::impl {
     auto& mutex = hip_events_mutex;
     {
       std::lock_guard<std::mutex> guard(mutex);
-      auto& queue = hip_events[device];
+      auto key = std::make_pair(device, stream);
+      auto& queue = hip_events[key];
+      if (!prepopulated[key]) {
+        // On first call for device and stream pair, prepopulate event queue.
+        // This is to minimize event reuse of callback events passed to
+        // framework.
+        for (int i = 0; i < N_HIP_EVENTS_PREPOPULATE; ++i) {
+          hipEvent_t ev;
+          status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
+          queue.emplace(std::make_shared<hipEvent_t>(ev), stream);
+        }
+        prepopulated[key] = true;
+      }
       if (!queue.empty()) {
         *event = queue.front();
+        event->event_idx = ++hip_event_idx[key];
         queue.pop();
         return hipSuccess;
       }
     }
 
-    return hipEventCreateWithFlags(event, hipEventDisableTiming);
+    hipEvent_t ev;
+    status = hipEventCreateWithFlags(&ev, hipEventDisableTiming);
+    event->event = std::make_shared<hipEvent_t>(ev);
+    event->stream = stream;
+    auto key2 = std::make_pair(device, stream);
+    event->event_idx = ++hip_event_idx[key2];
+
+    return status;
   }
 
-  hipError_t ReleaseGpuEvent(hipEvent_t event) {
+  hipError_t ReleaseGpuEvent(Event event) {
     int device;
     auto status = hipGetDevice(&device);
     if (status != hipSuccess) {
@@ -55,7 +76,7 @@ class GPUContext::impl {
     auto& mutex = hip_events_mutex;
     {
       std::lock_guard<std::mutex> guard(mutex);
-      auto& queue = hip_events[device];
+      auto& queue = hip_events[std::make_pair(device, event.stream)];
       queue.push(event);
     }
 
@@ -69,22 +90,59 @@ class GPUContext::impl {
     }
   }
 
-  void RecordEvent(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
+  void RecordEvent(std::queue<std::pair<std::string, Event>>& event_queue,
                    std::string name, hipStream_t& stream) {
-    hipEvent_t event;
-    ErrorCheck("GetGpuEvent", GetGpuEvent(&event));
-    ErrorCheck("hipEventRecord", hipEventRecord(event, stream));
+    Event event;
+    ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
+    ErrorCheck("hipEventRecord",
+               hipEventRecord(*(event.event), event.stream));
     event_queue.emplace(name, event);
   }
 
-  void
-  WaitForEvents(std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
+  Event RecordEvent(hipStream_t& stream) {
+    Event event;
+    ErrorCheck("GetGpuEvent", GetGpuEvent(&event, stream));
+    ErrorCheck("hipEventRecord",
+               hipEventRecord(*(event.event), event.stream));
+    return event;
+  }
+
+  void WaitForEvents(std::queue<std::pair<std::string, Event>>& event_queue,
                 const std::vector<TensorTableEntry>& entries, Timeline& timeline,
                 const std::function<void()>& error_check_callback) {
     while (!event_queue.empty()) {
       std::string name;
-      hipEvent_t event;
+      Event event;
+      std::tie(name, event) = event_queue.front();
+      event_queue.pop();
+      if (name != "") {
+        timeline.ActivityStartAll(entries, name);
+      }
+
+      hipError_t hip_result = hipEventSynchronize(*(event.event));
+      if (hip_result != hipSuccess) {
+        throw std::logic_error(std::string("cudaEventSynchronize failed: ") +
+                               hipGetErrorString(hip_result));
+      }
+      if (error_check_callback) {
+        error_check_callback();
+      }
+
+      if (name != "") {
+        timeline.ActivityEndAll(entries);
+      }
+      ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
+    }
+  }
+
+  void WaitForEventsElastic(
+      std::queue<std::pair<std::string, Event>>& event_queue,
+      const std::vector<TensorTableEntry>& entries, Timeline& timeline,
+      const std::function<void()>& error_check_callback) {
+    while (!event_queue.empty()) {
+      std::string name;
+      Event event;
       std::tie(name, event) = event_queue.front();
       event_queue.pop();
       if (name != "") {
@@ -95,13 +153,13 @@ class GPUContext::impl {
       // complete
       hipError_t hip_result;
       while (true) {
-        hip_result = hipEventQuery(event);
+        hip_result = hipEventQuery(*(event.event));
         if (hip_result == hipSuccess) {
           break;
         }
 
         if (hip_result != hipErrorNotReady) {
-          throw std::logic_error(std::string("hipEventQuery failed: ") +
+          throw std::logic_error(std::string("cudaEventQuery failed: ") +
                                  hipGetErrorString(hip_result));
         }
 
@@ -118,11 +176,25 @@ class GPUContext::impl {
     }
   }
 
-  void WaitForEventsElastic(
-      std::queue<std::pair<std::string, hipEvent_t>>& event_queue,
-      const std::vector<TensorTableEntry>& entries, Timeline& timeline,
-      const std::function<void()>& error_check_callback) {
-    WaitForEvents(event_queue, entries, timeline, error_check_callback);
+  void ClearEvents(std::queue<std::pair<std::string, Event>>& event_queue,
+                   const std::vector<TensorTableEntry>& entries,
+                   Timeline& timeline,
+                   const std::function<void()>& error_check_callback,
+                   bool elastic) {
+    while (!event_queue.empty()) {
+      std::string name;
+      Event event;
+      std::tie(name, event) = event_queue.front();
+      event_queue.pop();
+      if (name != "") {
+        timeline.ActivityStartAll(entries, name);
+      }
+
+      if (name != "") {
+        timeline.ActivityEndAll(entries);
+      }
+      ErrorCheck("ReleaseGpuEvent", ReleaseGpuEvent(event));
+    }
   }
 
   void StreamCreate(hipStream_t* stream) {
@@ -176,8 +248,12 @@ class GPUContext::impl {
 private:
   // We reuse HIP events as it appears that their creation carries non-zero
   // cost.
-  std::unordered_map<int, std::queue<hipEvent_t>> hip_events;
+  std::unordered_map<std::pair<int, hipStream_t>, std::queue<Event>>
+      hip_events;
+  std::unordered_map<std::pair<int, hipStream_t>, bool> prepopulated;
+  std::unordered_map<std::pair<int, hipStream_t>, std::atomic<uint64_t>> hip_event_idx;
   std::mutex hip_events_mutex;
+  static constexpr int N_HIP_EVENTS_PREPOPULATE = 128;
 };
 
 #include "gpu_context_impl.cc"
diff --git a/horovod/common/ops/rocm/CMakeLists.txt b/horovod/common/ops/rocm/CMakeLists.txt
new file mode 100644
index 0000000000..e0bbf65836
--- /dev/null
+++ b/horovod/common/ops/rocm/CMakeLists.txt
@@ -0,0 +1,24 @@
+message(STATUS "Build Horovod for ROCm")
+if (NOT DEFINED HCC_PATH)
+    if (DEFINED ENV{HCC_PATH})
+        set(HCC_PATH $ENV{HCC_PATH} CACHE PATH "Path to which HCC has been installed")
+    else()
+        set(HCC_PATH "${ROCM_PATH}/hcc" CACHE PATH "Path to which HCC has been set")
+    endif()
+    set(HCC_HOME "${HCC_PATH}")
+endif()
+
+list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
+set(HIP_CLANG_PATH "${ROCM_PATH}/llvm/bin")
+set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
+set(HIP_HCC_FLAGS ${HIP_HCC_FLAGS};-D__HIP_PLATFORM_HIPCC__=1;-fPIC)
+find_package(HIP QUIET REQUIRED)
+set(HIP_HIPCC_FLAGS ${HIP_HIPCC_FLAGS};-fPIC)
+list(APPEND HIP_HCC_FLAGS_RELEASE -O3 -fPIC)
+list(APPEND HIP_HCC_FLAGS_DEBUG -G -fPIC)
+
+list(APPEND HIP_HIPCC_FLAGS -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC)
+hip_add_library(horovod_cuda_kernels STATIC hip_kernels.cu)
+target_compile_definitions(horovod_cuda_kernels PRIVATE _GLIBCXX_USE_CXX11_ABI=1)
+hip_add_library(compatible_horovod_cuda_kernels STATIC hip_kernels.cu)
+target_compile_definitions(compatible_horovod_cuda_kernels PRIVATE _GLIBCXX_USE_CXX11_ABI=0)
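Reviewer note: the reworked `GetGpuEvent`/`ReleaseGpuEvent` above is essentially an object pool keyed by a (device, stream) pair that is prepopulated on first use. The CPU-only analogue below illustrates the pattern with hypothetical names; the real code pools `hipEvent_t` handles wrapped in `Event` and uses `std::unordered_map` with the pair hash from ../hashes.h, while this sketch uses `std::map` to stay self-contained.

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>

// Toy stand-in for a pooled GPU event handle.
struct FakeEvent { int id; };

class EventPool {
 public:
  std::shared_ptr<FakeEvent> Get(int device, void* stream) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto key = std::make_pair(device, stream);
    auto& queue = pool_[key];
    if (!prepopulated_[key]) {
      // First request for this (device, stream) pair: create a batch up
      // front so later requests rarely pay the creation cost.
      for (int i = 0; i < kPrepopulate; ++i) {
        queue.push(std::make_shared<FakeEvent>(FakeEvent{next_id_++}));
      }
      prepopulated_[key] = true;
    }
    if (queue.empty()) {
      return std::make_shared<FakeEvent>(FakeEvent{next_id_++});
    }
    auto event = queue.front();
    queue.pop();
    return event;
  }

  void Release(int device, void* stream, std::shared_ptr<FakeEvent> event) {
    std::lock_guard<std::mutex> guard(mutex_);
    pool_[std::make_pair(device, stream)].push(std::move(event));
  }

 private:
  static constexpr int kPrepopulate = 128;  // mirrors N_HIP_EVENTS_PREPOPULATE
  std::map<std::pair<int, void*>, std::queue<std::shared_ptr<FakeEvent>>> pool_;
  std::map<std::pair<int, void*>, bool> prepopulated_;
  std::mutex mutex_;
  int next_id_ = 0;
};
```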
+ +#include "hip_kernels.h" + +#include +#include + +namespace horovod { +namespace common { + +template +__device__ void batched_memcpy_d(size_t idx, const void* in, void* out, size_t size) { + + const T* input = reinterpret_cast(in); + T* output = reinterpret_cast(out); + const size_t num_elements = size / sizeof(T); + + for (size_t i = idx; i < num_elements; i += blockDim.x * blocks_per_copy) { + output[i] = input[i]; + } + + // Deal with any remaining bytes + size_t remainder = size % sizeof(T); + if (remainder > 0 && idx < remainder) { + const unsigned char* input_r = reinterpret_cast(input + num_elements); + unsigned char* output_r = reinterpret_cast(output + num_elements); + output_r[idx] = input_r[idx]; + } +} + +template +__global__ void batched_memcpy_k(BatchedD2DParams params) { + const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x; + + const size_t size = params.sizes[blockIdx.x / blocks_per_copy]; + const void* input = params.in[blockIdx.x / blocks_per_copy]; + void* output = params.out[blockIdx.x / blocks_per_copy]; + + // Check alignment relative to 16 bytes + size_t align_in = reinterpret_cast(input) % BATCHED_D2D_PADDING; + size_t align_out = reinterpret_cast(output) % BATCHED_D2D_PADDING; + + // Select load/store size based on the misaligned buffer + size_t align = (align_out == 0) ? align_in : align_out; + if (align_in && align_out) { + // If both are misaligned, use unsigned char (this should not occur + // as fusion buffer locations should be aligned by applying BATCH_D2D_PADDING + // during construction.) + align = 1; + } + + if (align % 16 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 8 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 4 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 2 == 0) { + batched_memcpy_d(idx, input, output, size); + } else { + batched_memcpy_d(idx, input, output, size); + } +} + +#define NTHREADS_D2D_KERNEL 1024 +#define BLOCKS_PER_COPY_D2D_KERNEL 8 +void BatchedD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream) +{ + batched_memcpy_k<<>>(params); +} + +template +__global__ void scale_buffer_k(const T* input, T* output, int64_t num_elements, const TS scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = scale_factor * input[i]; + } +} + +// Specialization for half2 +__global__ void scale_buffer_half2_k(const __half* input, __half* output, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + const __half2* input_h2 = reinterpret_cast(input); + __half2* output_h2 = reinterpret_cast<__half2 *>(output); + const __half2 scale_factor_h2 = __halves2half2(scale_factor, scale_factor); + + for (size_t i = idx; i < num_elements / 2; i += gridDim.x * blockDim.x) { + output_h2[i] = __hmul2(scale_factor_h2, input_h2[i]); + } + + // Deal with last element if num_elements is odd + if (idx == 0 && num_elements % 2) { + output[num_elements - 1] = __hmul(scale_factor, input[num_elements - 1]); + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i])); + } +#endif +} + +// Specialization for architectures without __half compute +template<> +__global__ void scale_buffer_k(const __half* input, __half* 
output, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = scale_factor * input[i]; + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i])); + } +#endif +} + +#define NTHREADS_SCALE_BUFFER_KERNEL 512 +void ScaleBufferROCmImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, + DataType dtype, hipStream_t stream) { + const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + const int threads = NTHREADS_SCALE_BUFFER_KERNEL; + switch (dtype) { + case HOROVOD_UINT8: + scale_buffer_k<<>>((const uint8_t*) fused_input_data, (uint8_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT8: + scale_buffer_k<<>>((const int8_t*) fused_input_data, (int8_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT32: + scale_buffer_k<<>>((const int32_t*) fused_input_data, (int32_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_INT64: + scale_buffer_k<<>>((const int64_t*) fused_input_data, (int64_t*) buffer_data, + num_elements, scale_factor); + break; + case HOROVOD_FLOAT16: + { + __half scale_factor_half = __float2half((float) scale_factor); + if ((size_t) fused_input_data % 4 == 0 && (size_t) buffer_data % 4 == 0) { + // If alignment allows, use half2 specialized kernel + int64_t num_elements_h2 = (num_elements + 1) / 2; + int64_t blocks_h2 = (num_elements_h2 + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + scale_buffer_half2_k<<>>((const __half*) fused_input_data, (__half*) buffer_data, + num_elements, scale_factor_half); + } else { + scale_buffer_k<<>>((const __half*) fused_input_data, (__half*) buffer_data, + num_elements, scale_factor_half); + } + break; + } + case HOROVOD_FLOAT32: + scale_buffer_k<<>>((const float*) fused_input_data, (float*) buffer_data, + num_elements, (float) scale_factor); + break; + case HOROVOD_FLOAT64: + scale_buffer_k<<>>((const double*) fused_input_data, (double*) buffer_data, + num_elements, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " not supported by ScaleBufferROCmImpl."); + } +} + +template +__device__ void batched_scaled_memcpy_d(size_t idx, const T* input, T* output, size_t size, const TS scale_factor) { + + const int64_t num_words = size / sizeof(TL); + const TL* read_ptr = reinterpret_cast(input); + TL* write_ptr = reinterpret_cast(output); + for (size_t i = idx; i < num_words; i += blockDim.x * blocks_per_copy) { + // Load word + TL word = read_ptr[i]; + T* val = reinterpret_cast(&word); + + // Scale elements in word + for (int j = 0; j < sizeof(TL) / sizeof(T); ++j) { + val[j] *= scale_factor; + } + + // Write word + write_ptr[i] = word; + } + + // Deal with any remaining elements + size_t remainder = (size % sizeof(TL)) / sizeof(T); + if (remainder > 0 && idx < remainder) { + const T* input_r = reinterpret_cast(read_ptr + num_words); + T* output_r = reinterpret_cast(write_ptr + num_words); + output_r[idx] = scale_factor * input_r[idx]; + } +} + +// Specialization for architectures without __half compute +template +__device__ void batched_scaled_memcpy_d(size_t idx, const __half* input, __half* output, size_t size, const __half scale_factor) { + + 
const int64_t num_words = size / sizeof(TL); + const TL* read_ptr = reinterpret_cast(input); + TL* write_ptr = reinterpret_cast(output); + for (size_t i = idx; i < num_words; i += blockDim.x * blocks_per_copy) { + // Load word + TL word = read_ptr[i]; + __half* val = reinterpret_cast<__half*>(&word); + + // Scale elements in word + for (int j = 0; j < sizeof(TL) / sizeof(__half); ++j) { +#if __CUDA_ARCH__ > 530 + val[j] *= scale_factor; +#else + val[j] = __float2half(__half2float(scale_factor) * __half2float(val[j])); +#endif + } + + // Write word + write_ptr[i] = word; + } + + // Deal with any remaining elements + size_t remainder = (size % sizeof(TL)) / sizeof(__half); + if (remainder > 0 && idx < remainder) { + const __half* input_r = reinterpret_cast(read_ptr + num_words); + __half* output_r = reinterpret_cast<__half*>(write_ptr + num_words); +#if __CUDA_ARCH__ > 530 + output_r[idx] = scale_factor * input_r[idx]; +#else + output_r[idx] = __float2half(__half2float(scale_factor) * __half2float(input_r[idx])); +#endif + } +} + +template +__global__ void batched_scaled_memcpy_k(BatchedD2DParams params, TS scale_factor) { + const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x; + + const size_t size = params.sizes[blockIdx.x / blocks_per_copy]; + const T* input = reinterpret_cast(params.in[blockIdx.x / blocks_per_copy]); + T* output = reinterpret_cast(params.out[blockIdx.x / blocks_per_copy]); + + // Check alignment relative to 16 bytes + size_t align_in = reinterpret_cast(input) % BATCHED_D2D_PADDING; + size_t align_out = reinterpret_cast(output) % BATCHED_D2D_PADDING; + + // Select load/store size based on the misaligned buffer + size_t align = (align_out == 0) ? align_in : align_out; + if (align_in && align_out) { + + // If both are misaligned, use datatype size + align = sizeof(T); + } + + if (align % 16 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 8 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 4 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else if (align % 2 == 0) { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } else { + batched_scaled_memcpy_d(idx, input, output, size, scale_factor); + } +} + +void BatchedScaledD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, double scale_factor, + DataType dtype, hipStream_t stream) { + const int64_t blocks = num_copies * BLOCKS_PER_COPY_D2D_KERNEL; + const int threads = NTHREADS_D2D_KERNEL; + switch (dtype) { + case HOROVOD_UINT8: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT8: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT32: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_INT64: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + case HOROVOD_FLOAT16: { + __half scale_factor_half = __float2half((float) scale_factor); + batched_scaled_memcpy_k<__half, BLOCKS_PER_COPY_D2D_KERNEL><<>>(params, scale_factor_half); + break; + } + case HOROVOD_FLOAT32: + batched_scaled_memcpy_k<<>>(params, (float) scale_factor); + break; + case HOROVOD_FLOAT64: + batched_scaled_memcpy_k<<>>(params, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " not supported by BatchedScaledD2DMemcpyROCmImpl."); + } +} + +} // namespace common +} // namespace horovod + diff --git a/horovod/common/ops/rocm/hip_kernels.h 
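Reviewer note: as a usage sketch for the kernels above (device buffer pointers are hypothetical and error checking is omitted), a caller packs up to BATCHED_D2D_CAPACITY copy descriptors into `BatchedD2DParams` and issues a single fused launch instead of one `hipMemcpyAsync` per tensor.

```cpp
#include <hip/hip_runtime.h>
#include "hip_kernels.h"

using horovod::common::BatchedD2DParams;
using horovod::common::BatchedD2DMemcpyROCmImpl;

// Sketch: fuse two device-to-device copies (dst0 <- src0, dst1 <- src1)
// into one batched kernel launch on `stream`.
void ExampleBatchedCopy(void* dst0, void* src0, size_t bytes0,
                        void* dst1, void* src1, size_t bytes1,
                        hipStream_t stream) {
  BatchedD2DParams params;
  params.out[0] = dst0; params.in[0] = src0; params.sizes[0] = bytes0;
  params.out[1] = dst1; params.in[1] = src1; params.sizes[1] = bytes1;
  BatchedD2DMemcpyROCmImpl(params, /*num_copies=*/2, stream);
}
```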
diff --git a/horovod/common/ops/rocm/hip_kernels.h b/horovod/common/ops/rocm/hip_kernels.h
new file mode 100644
index 0000000000..2b578b1678
--- /dev/null
+++ b/horovod/common/ops/rocm/hip_kernels.h
@@ -0,0 +1,51 @@
+// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+// ATTENTION: Any change here might obsolete cuda_kernels.h in cuda folder.
+// Please keep this file synced with cuda_kernels.h.
+
+#ifndef HIP_KERNELS_H
+#define HIP_KERNELS_H
+
+#include <hip/hip_runtime.h>
+
+#include "../../message.h"
+
+#define BATCHED_D2D_CAPACITY 160
+#define BATCHED_D2D_PADDING 16
+
+namespace horovod {
+namespace common {
+
+struct BatchedD2DParams {
+  void* out[BATCHED_D2D_CAPACITY];
+  void* in[BATCHED_D2D_CAPACITY];
+  size_t sizes[BATCHED_D2D_CAPACITY];
+};
+
+// Performs a batched d2d memcopy
+void BatchedD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, hipStream_t stream);
+
+// Scales buffer by scalar
+void ScaleBufferROCmImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements,
+                         double scale_factor, DataType dtype, hipStream_t stream);
+
+void BatchedScaledD2DMemcpyROCmImpl(BatchedD2DParams& params, int num_copies, double scale_factor,
+                                    DataType dtype, hipStream_t stream);
+
+} // namespace common
+} // namespace horovod
+
+#endif // HIP_KERNELS_H
diff --git a/horovod/tensorflow/CMakeLists.txt b/horovod/tensorflow/CMakeLists.txt
index 875d5a7e47..fdf7ebe43d 100644
--- a/horovod/tensorflow/CMakeLists.txt
+++ b/horovod/tensorflow/CMakeLists.txt
@@ -52,6 +52,13 @@ if(HAVE_CUDA)
         list(APPEND TF_LINKER_LIBS compatible_horovod_cuda_kernels)
     endif()
 endif()
+if(HAVE_ROCM)
+    if (Tensorflow_CXX11)
+        list(APPEND TF_LINKER_LIBS horovod_cuda_kernels)
+    else()
+        list(APPEND TF_LINKER_LIBS compatible_horovod_cuda_kernels)
+    endif()
+endif()
 set(CMAKE_CXX_FLAGS "${Tensorflow_COMPILE_FLAGS} ${CMAKE_CXX_FLAGS}")
 parse_version(${Tensorflow_VERSION} VERSION_DEC)
 add_definitions(-DTENSORFLOW_VERSION=${VERSION_DEC})
diff --git a/horovod/tensorflow/mpi_ops.cc b/horovod/tensorflow/mpi_ops.cc
index a226ad9c10..535736743a 100644
--- a/horovod/tensorflow/mpi_ops.cc
+++ b/horovod/tensorflow/mpi_ops.cc
@@ -27,6 +27,10 @@
 #define EIGEN_USE_GPU
 #endif // HAVE_CUDA || HAVE_ROCM
 
+#if HAVE_ROCM
+#define EIGEN_USE_HIP
+#endif
+
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/shape_inference.h"
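Reviewer note: the header above declares the ROCm counterparts of the CUDA kernel entry points. A minimal usage sketch for the scaling path, with hypothetical device pointers and no error handling:

```cpp
#include <hip/hip_runtime.h>
#include "hip_kernels.h"

using namespace horovod::common;

// Sketch: scale `count` floats in d_buffer by 1/world_size on `stream`,
// writing the result in place (input and output may alias, since the kernel
// reads and writes the same element index).
void ExampleScale(float* d_buffer, int64_t count, int world_size,
                  hipStream_t stream) {
  ScaleBufferROCmImpl(d_buffer, d_buffer, count, 1.0 / world_size,
                      HOROVOD_FLOAT32, stream);
}
```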
"../common/utils/env_parser.h" #include "./custom_call_config_generated.h" +#elif HAVE_ROCM + +#include + +#define OMPI_SKIP_MPICXX +#include "../common/operations.h" +#include "../common/utils/env_parser.h" +#include "./custom_call_config_generated.h" +#endif // HAVE_CUDA using namespace tensorflow; @@ -335,7 +346,7 @@ class HVDCustomCallRendezvous { // outstanding `Wait` call due to its blocking nature to simplify the // implementation. Consequently, this method always operates on the very // first item in the queue. - void Wait(string tensor_name, CUstream stream) { + void Wait(string tensor_name, gpuStream_t stream) { uint64 key_hash = GetRendezvousKeyHash(tensor_name); { @@ -371,7 +382,11 @@ class HVDCustomCallRendezvous { delete queue; } if (event) { +#if HAVE_CUDA CUDA_CALL(cudaStreamWaitEvent(stream, *event, /*flags=*/0)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipStreamWaitEvent(stream, *event, /*flags=*/0)); +#endif } delete payload; } @@ -403,21 +418,28 @@ class HVDCustomCallRendezvous { class XLAReadyEvent : public common::ReadyEvent { public: - XLAReadyEvent(cudaStream_t stream) : stream_(stream) { + XLAReadyEvent(gpuStream_t stream) : stream_(stream) { +#if HAVE_CUDA CUDA_CALL(cudaEventCreate(&event_)); CUDA_CALL(cudaEventRecord(event_, stream)); } ~XLAReadyEvent() { CUDA_CALL(cudaEventDestroy(event_)); } +#elif HAVE_ROCM + HVD_GPU_CHECK(hipEventCreate(&event_)); + HVD_GPU_CHECK(hipEventRecord(event_, stream)); + } + ~XLAReadyEvent() { HVD_GPU_CHECK(hipEventDestroy(event_)); } +#endif bool Ready() const override { - cudaError_t result = cudaEventQuery(event_); - return cudaErrorNotReady != result; + gpuError_t result = gpuEventQuery(event_); + return gpuErrorNotReady != result; } gpuEvent_t event() const override { return event_; } private: - cudaStream_t stream_; // Not Owned. - cudaEvent_t event_; // Owned. + gpuStream_t stream_; // Not Owned. + gpuEvent_t event_; // Owned. }; class XLATensor : public common::Tensor { @@ -475,11 +497,19 @@ class XLAPersistentBuffer : public common::PersistentBuffer { XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) : device_(device) { int restore_device; +#if HAVE_CUDA CUDA_CALL(cudaGetDevice(&restore_device)); CUDA_CALL(cudaSetDevice(device)); // Simply call cudaMalloc for persistent buffer. CUDA_CALL(cudaMalloc((void**)&buffer_, size)); CUDA_CALL(cudaSetDevice(restore_device)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipGetDevice(&restore_device)); + HVD_GPU_CHECK(hipSetDevice(device)); + // Simply call hipMalloc for persistent buffer. + HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); + HVD_GPU_CHECK(hipSetDevice(restore_device)); +#endif } const void* XLAPersistentBuffer::AccessData( @@ -509,18 +539,22 @@ XLAOpContext::AllocateZeros(int64_t num_elements, common::DataType dtype, "AllocateZeros is not supported for XLA."); } -common::ReadyEvent* RecordReadyEvent(cudaStream_t stream) { +common::ReadyEvent* RecordReadyEvent(gpuStream_t stream) { return new XLAReadyEvent(stream); } int GetDeviceOrdinal(void* ptr) { - cudaPointerAttributes attrs; + gpuPointerAttribute_t attrs; +#if HAVE_CUDA CUDA_CALL(cudaPointerGetAttributes(&attrs, ptr)); +#elif HAVE_ROCM + HVD_GPU_CHECK(hipPointerGetAttributes(&attrs, ptr)); +#endif return attrs.device; } // Implements for the `HVDAllreduce` HLO CustomCall. 
-void CallbackHVDAllreduce(CUstream stream, void** buffers, const char* opaque, +void CallbackHVDAllreduce(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { CHECK(common::CheckInitialized().ok()); CustomCallConfig config; @@ -553,7 +587,7 @@ void CallbackHVDAllreduce(CUstream stream, void** buffers, const char* opaque, } // Implements for the `HVDAllreduceDone` HLO CustomCall. -void CallbackHVDAllreduceDone(CUstream stream, void** /*buffers*/, +void CallbackHVDAllreduceDone(gpuStream_t stream, void** /*buffers*/, const char* opaque, size_t opaque_len) { // Blocking until the request is done processing by the Horovod runtime. VLOG(2) << "hvd-allreduce-done - Start"; @@ -563,13 +597,17 @@ void CallbackHVDAllreduceDone(CUstream stream, void** /*buffers*/, VLOG(2) << "hvd-allreduce-done - End"; } +#if HAVE_CUDA XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); +#elif HAVE_ROCM +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); +#endif } // namespace } // namespace tensorflow } // namespace horovod -#endif // TENSORFLOW_VERSION >= 2006000000 -#endif // HAVE_CUDA #endif // HAVE_GPU +#endif // TENSORFLOW_VERSION >= 2006000000 diff --git a/horovod/torch/CMakeLists.txt b/horovod/torch/CMakeLists.txt index aadab89854..d0f86fae2d 100644 --- a/horovod/torch/CMakeLists.txt +++ b/horovod/torch/CMakeLists.txt @@ -53,6 +53,13 @@ if(HAVE_CUDA) list(APPEND PYTORCH_LINKER_LIBS compatible_horovod_cuda_kernels) endif() endif() +if(HAVE_ROCM) + if (Pytorch_CXX11) + list(APPEND PYTORCH_LINKER_LIBS horovod_cuda_kernels) + else() + list(APPEND PYTORCH_LINKER_LIBS compatible_horovod_cuda_kernels) + endif() +endif() parse_version(${Pytorch_VERSION} VERSION_DEC) add_definitions(-DPYTORCH_VERSION=${VERSION_DEC} -DTORCH_API_INCLUDE_EXTENSION_H=1) set(Pytorch_CXX11 ${Pytorch_CXX11} PARENT_SCOPE)
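Reviewer note: the `gpuPointerAttribute_t` alias plus the `GetDeviceOrdinal` change in xla_mpi_ops.cc above form a small backend-neutral idiom worth calling out. A standalone sketch (`DeviceOf` is hypothetical and the include path is assumed; only the runtime calls shown are real CUDA/HIP APIs):

```cpp
#include "horovod/common/common.h"

// Hypothetical standalone version of GetDeviceOrdinal: ask the runtime which
// device a pointer belongs to, through the aliased attribute struct so the
// same code builds against CUDA (cudaPointerAttributes) or ROCm
// (hipPointerAttribute_t).
int DeviceOf(const void* ptr) {
  gpuPointerAttribute_t attrs;
#if HAVE_CUDA
  HVD_GPU_CHECK(cudaPointerGetAttributes(&attrs, ptr));
#elif HAVE_ROCM
  HVD_GPU_CHECK(hipPointerGetAttributes(&attrs, ptr));
#endif
  return attrs.device;
}
```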