From fbb8ce28c9b77ecfa69e075fa94a6fc61252c12e Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Tue, 10 Nov 2020 13:49:36 -0800 Subject: [PATCH] Adding support for batched D2D memcopy kernel on GPU. Signed-off-by: Josh Romero --- horovod/common/common.h | 1 + horovod/common/controller.cc | 26 +++++- horovod/common/controller.h | 3 +- horovod/common/global_state.h | 3 + horovod/common/operations.cc | 8 ++ horovod/common/ops/adasum_gpu_operations.cc | 5 +- horovod/common/ops/cuda/cuda_kernels.cu | 62 +++++++++++++ horovod/common/ops/cuda/cuda_kernels.h | 15 ++++ horovod/common/ops/ddl_operations.cc | 5 +- horovod/common/ops/gpu_operations.cc | 97 +++++++++++++++++++++ horovod/common/ops/gpu_operations.h | 7 ++ horovod/common/ops/mpi_gpu_operations.cc | 3 +- horovod/common/ops/nccl_operations.cc | 10 +-- 13 files changed, 224 insertions(+), 21 deletions(-) diff --git a/horovod/common/common.h b/horovod/common/common.h index f6f5f52722..342d74d793 100644 --- a/horovod/common/common.h +++ b/horovod/common/common.h @@ -79,6 +79,7 @@ namespace common { #define HOROVOD_HIERARCHICAL_ALLREDUCE "HOROVOD_HIERARCHICAL_ALLREDUCE" #define HOROVOD_HIERARCHICAL_ALLGATHER "HOROVOD_HIERARCHICAL_ALLGATHER" #define HOROVOD_CACHE_CAPACITY "HOROVOD_CACHE_CAPACITY" +#define HOROVOD_BATCH_D2D_MEMCOPIES "HOROVOD_BATCH_D2D_MEMCOPIES" #define HOROVOD_NUM_NCCL_STREAMS "HOROVOD_NUM_NCCL_STREAMS" #define HOROVOD_CPU_OPERATIONS "HOROVOD_CPU_OPERATIONS" #define HOROVOD_CONTROLLER "HOROVOD_CONTROLLER" diff --git a/horovod/common/controller.cc b/horovod/common/controller.cc index 26f90466b8..ebeb722624 100644 --- a/horovod/common/controller.cc +++ b/horovod/common/controller.cc @@ -27,6 +27,11 @@ #include "logging.h" #include "operations.h" +#if HAVE_CUDA +#include "ops/cuda/cuda_kernels.h" +#endif + + namespace horovod { namespace common { @@ -199,7 +204,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down, } // Fuse responses as normal. - response_list = FuseResponses(responses); + response_list = FuseResponses(responses, state); response_list.set_shutdown(cache_coordinator.should_shut_down()); } else { // There are uncached messages coming in, need communication to figure out @@ -306,7 +311,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down, responses.push_back(std::move(join_response)); state.joined_size = 0; } - response_list = FuseResponses(responses); + response_list = FuseResponses(responses, state); response_list.set_shutdown(should_shut_down); // Broadcast final results to other ranks. @@ -683,7 +688,8 @@ void Controller::CoordinateCacheAndState(CacheCoordinator& cache_coordinator) { } } -ResponseList Controller::FuseResponses(std::deque& responses) { +ResponseList Controller::FuseResponses(std::deque& responses, + HorovodGlobalState& state) { ResponseList response_list; while (!responses.empty()) { @@ -696,6 +702,12 @@ ResponseList Controller::FuseResponses(std::deque& responses) { // Attempt to add more responses to this fused response. tensor_size = response.tensor_sizes()[0] * GetTypeSize(response.tensor_type()); +#if HAVE_CUDA + if (state.batch_d2d_memcopies) { + // Add 16 byte pad for batched memcpy op + tensor_size = BATCHED_D2D_PADDING * ((tensor_size + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + } +#endif std::deque skipped_responses; int64_t skipped_size = 0; while (!responses.empty()) { @@ -706,6 +718,14 @@ ResponseList Controller::FuseResponses(std::deque& responses) { ? 
0 : new_response.tensor_sizes()[0] * GetTypeSize(new_response.tensor_type()); + +#if HAVE_CUDA + if (state.batch_d2d_memcopies) { + // Add 16 byte pad for batched memcpy op + new_tensor_size = BATCHED_D2D_PADDING * ((new_tensor_size + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + } +#endif + if (response.response_type() == new_response.response_type() && response.devices() == new_response.devices() && response.tensor_type() == new_response.tensor_type() && diff --git a/horovod/common/controller.h b/horovod/common/controller.h index f68c2bd798..9c4a0755c9 100644 --- a/horovod/common/controller.h +++ b/horovod/common/controller.h @@ -162,7 +162,8 @@ class Controller : public std::enable_shared_from_this { // exist on any worker. void CoordinateCacheAndState(CacheCoordinator& cache_coordinator); - ResponseList FuseResponses(std::deque& responses); + ResponseList FuseResponses(std::deque& responses, + HorovodGlobalState& state); // Return the total byte size of the final allgathered output tensor int64_t diff --git a/horovod/common/global_state.h b/horovod/common/global_state.h index 30b6c202d8..cde515ec0b 100644 --- a/horovod/common/global_state.h +++ b/horovod/common/global_state.h @@ -110,6 +110,9 @@ struct HorovodGlobalState { // benefit from a smaller chunk size. int64_t adasum_mpi_chunk_size = 1<<30; + // Enable use of batched d2d memcopy kernel on GPU + bool batch_d2d_memcopies = false; + ~HorovodGlobalState() { // Make sure that the destructor of the background thread is safe to // call. If a thread is still joinable (not detached or complete) its diff --git a/horovod/common/operations.cc b/horovod/common/operations.cc index 7647389622..0e384bd272 100644 --- a/horovod/common/operations.cc +++ b/horovod/common/operations.cc @@ -504,6 +504,14 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { "allgather and hierarchical allreduce."; } + // Set flag to enable use of batched memcopy kernel on GPU + auto horovod_batch_d2d_memcopies = + std::getenv(HOROVOD_BATCH_D2D_MEMCOPIES); + if (horovod_batch_d2d_memcopies != nullptr && + std::strtol(horovod_batch_d2d_memcopies, nullptr, 10) > 0) { + state.batch_d2d_memcopies = true; + } + // Enable auto-tuning. 
auto horovod_autotune = std::getenv(HOROVOD_AUTOTUNE); if (horovod_autotune != nullptr && diff --git a/horovod/common/ops/adasum_gpu_operations.cc b/horovod/common/ops/adasum_gpu_operations.cc index 291469a016..88988ecd8f 100644 --- a/horovod/common/ops/adasum_gpu_operations.cc +++ b/horovod/common/ops/adasum_gpu_operations.cc @@ -86,10 +86,7 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector& entries, buffer_len = (size_t)first_entry.output->size(); } - int64_t num_elements = 0; - for (auto& e : entries) { - num_elements += e.tensor->shape().num_elements(); - } + int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype()); if (response.prescale_factor() != 1.0) { // Execute prescaling op diff --git a/horovod/common/ops/cuda/cuda_kernels.cu b/horovod/common/ops/cuda/cuda_kernels.cu index 5cf83facdf..c19a13473d 100644 --- a/horovod/common/ops/cuda/cuda_kernels.cu +++ b/horovod/common/ops/cuda/cuda_kernels.cu @@ -21,6 +21,68 @@ namespace horovod { namespace common { +template +__device__ void batched_memcpy_d(size_t idx, const void* in, void* out, size_t size) { + + const T* input = reinterpret_cast(in); + T* output = reinterpret_cast(out); + const size_t num_elements = size / sizeof(T); + + for (size_t i = idx; i < num_elements; i += blockDim.x * blocks_per_copy) { + output[i] = input[i]; + } + + // Deal with any remaining bytes + size_t remainder = size % sizeof(T); + if (remainder > 0 && idx < remainder) { + const unsigned char* input_r = reinterpret_cast(input + num_elements); + unsigned char* output_r = reinterpret_cast(output + num_elements); + output_r[idx] = input_r[idx]; + } +} + +template +__global__ void batched_memcpy_k(BatchedD2DParams params) { + const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x; + + const size_t size = params.sizes[blockIdx.x / blocks_per_copy]; + const void* input = params.in[blockIdx.x / blocks_per_copy]; + void* output = params.out[blockIdx.x / blocks_per_copy]; + + // Check alignment relative to 16 bytes + size_t align_in = reinterpret_cast(input) % BATCHED_D2D_PADDING; + size_t align_out = reinterpret_cast(output) % BATCHED_D2D_PADDING; + + // Select load/store size based on the misaligned buffer + size_t align = (align_out == 0) ? align_in : align_out; + if (align_in && align_out) { + // If both are misaligned, use unsigned char (this should not occur + // as fusion buffer locations should be aligned by applying BATCH_D2D_PADDING + // during construction.) 
+ align = 1; + } + + if (align % 16 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 8 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 4 == 0) { + batched_memcpy_d(idx, input, output, size); + } else if (align % 2 == 0) { + batched_memcpy_d(idx, input, output, size); + } else { + batched_memcpy_d(idx, input, output, size); + } +} + +#define NTHREADS_D2D_KERNEL 1024 +#define BLOCKS_PER_COPY_D2D_KERNEL 8 +void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, cudaStream_t stream) +{ + batched_memcpy_k<<>>(params); +} + template __global__ void scale_buffer_k(const T* input, T* output, int64_t num_elements, const TS scale_factor) { diff --git a/horovod/common/ops/cuda/cuda_kernels.h b/horovod/common/ops/cuda/cuda_kernels.h index 6fe480d72c..60636e4d3b 100644 --- a/horovod/common/ops/cuda/cuda_kernels.h +++ b/horovod/common/ops/cuda/cuda_kernels.h @@ -16,11 +16,26 @@ #ifndef CUDA_KERNELS_H #define CUDA_KERNELS_H +#include + #include "../../message.h" +#define BATCHED_D2D_CAPACITY 160 +#define BATCHED_D2D_PADDING 16 + namespace horovod { namespace common { +struct BatchedD2DParams { + void* out[BATCHED_D2D_CAPACITY]; + void* in[BATCHED_D2D_CAPACITY]; + size_t sizes[BATCHED_D2D_CAPACITY]; +}; + +// Performs a batched d2d memcopy +void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, cudaStream_t stream); + +// Scales buffer by scalar void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor, DataType dtype, cudaStream_t stream); diff --git a/horovod/common/ops/ddl_operations.cc b/horovod/common/ops/ddl_operations.cc index 78e920e408..ec151b1d7e 100644 --- a/horovod/common/ops/ddl_operations.cc +++ b/horovod/common/ops/ddl_operations.cc @@ -66,10 +66,7 @@ Status DDLAllreduce::Execute(std::vector& entries, const Respo buffer_len = (size_t) first_entry.output->size(); } - int64_t num_elements = 0; - for (auto& e : entries) { - num_elements += e.tensor->shape().num_elements(); - } + int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype()); if (response.prescale_factor() != 1.0) { // Execute prescaling op diff --git a/horovod/common/ops/gpu_operations.cc b/horovod/common/ops/gpu_operations.cc index 6fad044178..2c6e2a6674 100644 --- a/horovod/common/ops/gpu_operations.cc +++ b/horovod/common/ops/gpu_operations.cc @@ -15,6 +15,9 @@ // ============================================================================= #include "gpu_operations.h" +#if HAVE_CUDA +#include "cuda/cuda_kernels.h" +#endif #include @@ -95,6 +98,60 @@ bool GPUAllreduce::Enabled(const ParameterManager& param_manager, return entries[0].device != CPU_DEVICE_ID; } +#if HAVE_CUDA +void GPUAllreduce::MemcpyInFusionBuffer(const std::vector& entries, const void*& fused_input_data, + void*& buffer_data, size_t& buffer_len) { + // Access the fusion buffer. 
+ auto& first_entry = entries[0]; + auto buffer = global_state_->fusion_buffer.GetBuffer( + first_entry.device, first_entry.context->framework(), global_state_->current_nccl_stream); + buffer_data = const_cast(buffer->AccessData(first_entry.context)); + + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + auto& first_entry = entries[0]; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*) e.tensor->data(); + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += BATCHED_D2D_PADDING * ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int) entries.size()) { + // Perform batched d2d memcpy + BatchedD2DMemcpyCudaImpl(d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]); + gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", cudaGetLastError()); + count = 0; + } + } + buffer_len = (size_t)offset; + + } else { + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*) buffer_data + offset; + MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset); + offset += e.tensor->size(); + } + + buffer_len = (size_t) offset; + } + + // Set the input data to originate from the buffer. + fused_input_data = buffer_data; +} +#endif + + void GPUAllreduce::MemcpyEntryInFusionBuffer(const std::vector& entries, const TensorTableEntry& e, void* buffer_data_at_offset) { auto& first_entry = entries[0]; @@ -102,6 +159,46 @@ void GPUAllreduce::MemcpyEntryInFusionBuffer(const std::vector gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]); } +#if HAVE_CUDA +void GPUAllreduce::MemcpyOutFusionBuffer(const void* buffer_data, std::vector& entries) { + if (global_state_->batch_d2d_memcopies) { + int64_t offset = 0; + int idx = 0; + int count = 0; + + BatchedD2DParams d2d_params; + auto& first_entry = entries[0]; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*)buffer_data + offset; + + // Set input/output pointers and sizes + d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data()); + d2d_params.in[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset; + d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size(); + + offset += BATCHED_D2D_PADDING * ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING); + idx++; + count++; + + if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int) entries.size()) { + // Perform batched d2d memcpy + BatchedD2DMemcpyCudaImpl(d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]); + gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", cudaGetLastError()); + count = 0; + } + } + + } else { + int64_t offset = 0; + for (auto& e : entries) { + void* buffer_data_at_offset = (uint8_t*) buffer_data + offset; + MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e); + offset += e.tensor->size(); + } + } +} +#endif + void GPUAllreduce::MemcpyEntryOutFusionBuffer(const std::vector& entries, const void* buffer_data_at_offset, TensorTableEntry& e) { auto& first_entry = entries[0]; diff --git a/horovod/common/ops/gpu_operations.h b/horovod/common/ops/gpu_operations.h index d630ad9571..d0c8d242b7 100644 --- 
a/horovod/common/ops/gpu_operations.h +++ b/horovod/common/ops/gpu_operations.h @@ -136,6 +136,13 @@ class GPUAllreduce : public AllreduceOp { const Response& response) const override; protected: +#if HAVE_CUDA + void MemcpyInFusionBuffer(const std::vector& entries, const void*& fused_input_data, + void*& buffer_data, size_t& buffer_len) override; + + void MemcpyOutFusionBuffer(const void* buffer_data, std::vector& entries) override; +#endif + void MemcpyEntryInFusionBuffer(const std::vector& entries, const TensorTableEntry& e, void* buffer_data_at_offset) override; diff --git a/horovod/common/ops/mpi_gpu_operations.cc b/horovod/common/ops/mpi_gpu_operations.cc index be7d52bf5e..55509295e8 100644 --- a/horovod/common/ops/mpi_gpu_operations.cc +++ b/horovod/common/ops/mpi_gpu_operations.cc @@ -34,7 +34,6 @@ Status MPI_GPUAllreduce::Execute(std::vector& entries, const R const void* fused_input_data; void* buffer_data; size_t buffer_len; - int64_t num_elements = NumElements(entries); // Copy memory into the fusion buffer. auto& timeline = global_state_->timeline; @@ -51,6 +50,8 @@ Status MPI_GPUAllreduce::Execute(std::vector& entries, const R buffer_len = (size_t) first_entry.output->size(); } + int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype()); + if (response.prescale_factor() != 1.0) { // Execute prescaling op ScaleBuffer(response.prescale_factor(), entries, fused_input_data, buffer_data, num_elements); diff --git a/horovod/common/ops/nccl_operations.cc b/horovod/common/ops/nccl_operations.cc index d54cf053d1..f9dd0d65f5 100644 --- a/horovod/common/ops/nccl_operations.cc +++ b/horovod/common/ops/nccl_operations.cc @@ -148,10 +148,7 @@ Status NCCLAllreduce::Execute(std::vector& entries, buffer_len = (size_t) first_entry.output->size(); } - int64_t num_elements = 0; - for (auto& e : entries) { - num_elements += e.tensor->shape().num_elements(); - } + int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype()); if (response.prescale_factor() != 1.0) { // Execute prescaling op @@ -221,10 +218,7 @@ NCCLHierarchicalAllreduce::Execute(std::vector& entries, buffer_len = (size_t) first_entry.output->size(); } - int64_t num_elements = 0; - for (auto& e : entries) { - num_elements += e.tensor->shape().num_elements(); - } + int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype()); if (response.prescale_factor() != 1.0) { // Execute prescaling op
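
Below is a minimal, self-contained sketch of the batched device-to-device copy mechanism this patch introduces. It is not part of the patch: the Demo* names and the byte-wise copy loop are placeholders, whereas the implementation above (BatchedD2DParams, batched_memcpy_k, BatchedD2DMemcpyCudaImpl) additionally selects wider loads (16/8/4/2-byte) based on source/destination alignment and launches on the active NCCL stream. The sketch only assumes a working CUDA toolkit and should compile with nvcc.

#include <cstdio>
#include <cuda_runtime.h>

#define DEMO_CAPACITY 160        // copies per launch (BATCHED_D2D_CAPACITY in the patch)
#define DEMO_PADDING 16          // fusion buffer alignment (BATCHED_D2D_PADDING in the patch)
#define DEMO_NTHREADS 1024       // threads per block (NTHREADS_D2D_KERNEL in the patch)
#define DEMO_BLOCKS_PER_COPY 8   // blocks cooperating on each copy (BLOCKS_PER_COPY_D2D_KERNEL)

struct DemoD2DParams {
  void* out[DEMO_CAPACITY];
  const void* in[DEMO_CAPACITY];
  size_t sizes[DEMO_CAPACITY];
};

// Each group of DEMO_BLOCKS_PER_COPY blocks services one (in, out, size) triple.
// This sketch copies byte-by-byte; the patch's kernel instead picks a wider
// load/store type when the pointers allow it.
__global__ void demo_batched_memcpy_k(DemoD2DParams params) {
  const int copy_id = blockIdx.x / DEMO_BLOCKS_PER_COPY;
  const size_t idx =
      blockDim.x * (blockIdx.x % DEMO_BLOCKS_PER_COPY) + threadIdx.x;
  const size_t stride = (size_t)blockDim.x * DEMO_BLOCKS_PER_COPY;

  const unsigned char* in = static_cast<const unsigned char*>(params.in[copy_id]);
  unsigned char* out = static_cast<unsigned char*>(params.out[copy_id]);
  for (size_t i = idx; i < params.sizes[copy_id]; i += stride) {
    out[i] = in[i];
  }
}

int main() {
  // Two small device buffers packed into one padded "fusion" buffer, mirroring
  // MemcpyInFusionBuffer with batch_d2d_memcopies enabled.
  const int n = 2;
  const size_t sizes[n] = {1000, 513};   // arbitrary sizes, not multiples of 16
  unsigned char* src[n];
  unsigned char* fusion;
  size_t padded[n];
  size_t total = 0;
  for (int i = 0; i < n; ++i) {
    padded[i] = DEMO_PADDING * ((sizes[i] + DEMO_PADDING - 1) / DEMO_PADDING);
    total += padded[i];
    cudaMalloc(reinterpret_cast<void**>(&src[i]), sizes[i]);
    cudaMemset(src[i], i + 1, sizes[i]);
  }
  cudaMalloc(reinterpret_cast<void**>(&fusion), total);

  DemoD2DParams params;
  size_t offset = 0;
  for (int i = 0; i < n; ++i) {
    params.out[i] = fusion + offset;     // 16-byte aligned by construction
    params.in[i] = src[i];
    params.sizes[i] = sizes[i];
    offset += padded[i];                 // round each entry up to DEMO_PADDING
  }

  // One kernel launch replaces one cudaMemcpyAsync per tensor.
  demo_batched_memcpy_k<<<n * DEMO_BLOCKS_PER_COPY, DEMO_NTHREADS>>>(params);
  printf("kernel status: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));

  cudaFree(fusion);
  for (int i = 0; i < n; ++i) cudaFree(src[i]);
  return 0;
}

The 16-byte rounding applied in FuseResponses and MemcpyInFusionBuffer/MemcpyOutFusionBuffer is what keeps every per-tensor slot in the fusion buffer 16-byte aligned, so the real kernel can usually take its widest-load path; a single launch then replaces one cudaMemcpyAsync per fused tensor, which helps most when fusing many small tensors. The feature is off by default and is enabled by setting HOROVOD_BATCH_D2D_MEMCOPIES=1.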