Adding support for batched D2D memcopy kernel on GPU. #2435

Merged (3 commits, Nov 13, 2020)
Changes shown below are from 2 of the 3 commits.
1 change: 1 addition & 0 deletions horovod/common/common.h
@@ -79,6 +79,7 @@ namespace common {
#define HOROVOD_HIERARCHICAL_ALLREDUCE "HOROVOD_HIERARCHICAL_ALLREDUCE"
#define HOROVOD_HIERARCHICAL_ALLGATHER "HOROVOD_HIERARCHICAL_ALLGATHER"
#define HOROVOD_CACHE_CAPACITY "HOROVOD_CACHE_CAPACITY"
#define HOROVOD_BATCH_D2D_MEMCOPIES "HOROVOD_BATCH_D2D_MEMCOPIES"
#define HOROVOD_NUM_NCCL_STREAMS "HOROVOD_NUM_NCCL_STREAMS"
#define HOROVOD_CPU_OPERATIONS "HOROVOD_CPU_OPERATIONS"
#define HOROVOD_CONTROLLER "HOROVOD_CONTROLLER"
26 changes: 23 additions & 3 deletions horovod/common/controller.cc
@@ -27,6 +27,11 @@
#include "logging.h"
#include "operations.h"

#if HAVE_CUDA
#include "ops/cuda/cuda_kernels.h"
#endif


namespace horovod {
namespace common {

@@ -199,7 +204,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
}

// Fuse responses as normal.
response_list = FuseResponses(responses);
response_list = FuseResponses(responses, state);
response_list.set_shutdown(cache_coordinator.should_shut_down());
} else {
// There are uncached messages coming in, need communication to figure out
@@ -306,7 +311,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
responses.push_back(std::move(join_response));
state.joined_size = 0;
}
response_list = FuseResponses(responses);
response_list = FuseResponses(responses, state);
response_list.set_shutdown(should_shut_down);

// Broadcast final results to other ranks.
@@ -683,7 +688,8 @@ void Controller::CoordinateCacheAndState(CacheCoordinator& cache_coordinator) {
}
}

ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
ResponseList Controller::FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state) {
ResponseList response_list;
while (!responses.empty()) {

@@ -696,6 +702,12 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
// Attempt to add more responses to this fused response.

tensor_size = response.tensor_sizes()[0] * GetTypeSize(response.tensor_type());
#if HAVE_CUDA
if (state.batch_d2d_memcopies) {
// Add 16 byte pad for batched memcpy op
tensor_size = BATCHED_D2D_PADDING * ((tensor_size + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING);
}
#endif
std::deque<Response> skipped_responses;
int64_t skipped_size = 0;
while (!responses.empty()) {
@@ -706,6 +718,14 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
? 0
: new_response.tensor_sizes()[0] *
GetTypeSize(new_response.tensor_type());

#if HAVE_CUDA
if (state.batch_d2d_memcopies) {
// Add 16 byte pad for batched memcpy op
new_tensor_size = BATCHED_D2D_PADDING * ((new_tensor_size + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING);
}
#endif

if (response.response_type() == new_response.response_type() &&
response.devices() == new_response.devices() &&
response.tensor_type() == new_response.tensor_type() &&
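
A minimal sketch (not part of the PR) of the round-up arithmetic added to FuseResponses above: each fused tensor is padded to the next multiple of BATCHED_D2D_PADDING (16 bytes) so that every entry's offset in the fusion buffer stays 16-byte aligned. The constant and function names below are illustrative.

#include <cassert>
#include <cstdint>

// Mirrors the expression added to FuseResponses above.
constexpr int64_t kBatchedD2DPadding = 16;

constexpr int64_t RoundUpToPadding(int64_t size_bytes) {
  return kBatchedD2DPadding *
         ((size_bytes + kBatchedD2DPadding - 1) / kBatchedD2DPadding);
}

int main() {
  assert(RoundUpToPadding(12) == 16);  // a 12-byte tensor occupies a 16-byte slot
  assert(RoundUpToPadding(32) == 32);  // already aligned, no extra padding
  assert(RoundUpToPadding(33) == 48);  // one byte over rounds up to a full slot
  return 0;
}
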
3 changes: 2 additions & 1 deletion horovod/common/controller.h
@@ -162,7 +162,8 @@ class Controller : public std::enable_shared_from_this<Controller> {
// exist on any worker.
void CoordinateCacheAndState(CacheCoordinator& cache_coordinator);

ResponseList FuseResponses(std::deque<Response>& responses);
ResponseList FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state);

// Return the total byte size of the final allgathered output tensor
int64_t
3 changes: 3 additions & 0 deletions horovod/common/global_state.h
@@ -110,6 +110,9 @@ struct HorovodGlobalState {
// benefit from a smaller chunk size.
int64_t adasum_mpi_chunk_size = 1<<30;

// Enable use of batched d2d memcopy kernel on GPU
bool batch_d2d_memcopies = true;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
8 changes: 8 additions & 0 deletions horovod/common/operations.cc
@@ -504,6 +504,14 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
"allgather and hierarchical allreduce.";
}

// Set flag to enable use of batched memcopy kernel on GPU
auto horovod_batch_d2d_memcopies =
std::getenv(HOROVOD_BATCH_D2D_MEMCOPIES);
if (horovod_batch_d2d_memcopies != nullptr &&
std::strtol(horovod_batch_d2d_memcopies, nullptr, 10) > 0) {
state.batch_d2d_memcopies = true;

Review comments on the line above:

Collaborator: Since the default is now true, we should also change this behavior so that if the user specifies HOROVOD_BATCH_D2D_MEMCOPIES=0, it sets the value to false.

Collaborator (Author): Yep, good catch. Just fixed this.


}

// Enable auto-tuning.
auto horovod_autotune = std::getenv(HOROVOD_AUTOTUNE);
if (horovod_autotune != nullptr &&
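
A sketch of what the fix agreed on in the review exchange above could look like (not necessarily the literal follow-up commit): parse HOROVOD_BATCH_D2D_MEMCOPIES so that a value of 0 explicitly disables the batched path, while any positive value, or leaving the variable unset, keeps the new default of true. Names follow the surrounding diff.

// Set flag to enable use of batched memcopy kernel on GPU.
// "0" disables the batched kernel; any positive value (or unset) keeps it enabled.
auto horovod_batch_d2d_memcopies =
    std::getenv(HOROVOD_BATCH_D2D_MEMCOPIES);
if (horovod_batch_d2d_memcopies != nullptr) {
  state.batch_d2d_memcopies =
      std::strtol(horovod_batch_d2d_memcopies, nullptr, 10) > 0;
}
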
5 changes: 1 addition & 4 deletions horovod/common/ops/adasum_gpu_operations.cc
@@ -86,10 +86,7 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector<TensorTableEntry>& entries,
buffer_len = (size_t)first_entry.output->size();
}

int64_t num_elements = 0;
for (auto& e : entries) {
num_elements += e.tensor->shape().num_elements();
}
int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype());

if (response.prescale_factor() != 1.0) {
// Execute prescaling op
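
Because each entry in the fusion buffer is now padded, summing tensor->shape().num_elements() over the entries can undercount the region the prescale/postscale ops must cover, which is presumably why this diff (and the matching ones below) derives num_elements from the padded buffer_len instead. A worked example with illustrative values, not taken from the PR:

#include <cstdint>

int64_t PaddedElementCountExample() {
  // Two FP32 tensors of 5 and 3 elements occupy 20 and 12 payload bytes,
  // but 32 and 16 bytes in the padded fusion buffer.
  const int64_t buffer_len = 32 + 16;           // padded fusion buffer, in bytes
  const int64_t num_elements = buffer_len / 4;  // 12 elements (FP32 is 4 bytes)
  // Summing the tensor shapes would give only 5 + 3 = 8 elements, so a
  // contiguous scale over the first 8 elements (32 bytes) would cover the
  // first tensor plus its padding and miss the second tensor entirely.
  return num_elements;
}
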
62 changes: 62 additions & 0 deletions horovod/common/ops/cuda/cuda_kernels.cu
@@ -21,6 +21,68 @@
namespace horovod {
namespace common {

template<typename T, int blocks_per_copy>
__device__ void batched_memcpy_d(size_t idx, const void* in, void* out, size_t size) {

const T* input = reinterpret_cast<const T *>(in);
T* output = reinterpret_cast<T *>(out);
const size_t num_elements = size / sizeof(T);

for (size_t i = idx; i < num_elements; i += blockDim.x * blocks_per_copy) {
output[i] = input[i];
}

// Deal with any remaining bytes
size_t remainder = size % sizeof(T);
if (remainder > 0 && idx < remainder) {
const unsigned char* input_r = reinterpret_cast<const unsigned char *>(input + num_elements);
unsigned char* output_r = reinterpret_cast<unsigned char *>(output + num_elements);
output_r[idx] = input_r[idx];
}
}

template<int blocks_per_copy>
__global__ void batched_memcpy_k(BatchedD2DParams params) {
const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x;

const size_t size = params.sizes[blockIdx.x / blocks_per_copy];
const void* input = params.in[blockIdx.x / blocks_per_copy];
void* output = params.out[blockIdx.x / blocks_per_copy];

// Check alignment relative to 16 bytes
size_t align_in = reinterpret_cast<size_t>(input) % BATCHED_D2D_PADDING;
size_t align_out = reinterpret_cast<size_t>(output) % BATCHED_D2D_PADDING;

// Select load/store size based on the misaligned buffer
size_t align = (align_out == 0) ? align_in : align_out;
if (align_in && align_out) {
// If both are misaligned, use unsigned char (this should not occur
// as fusion buffer locations should be aligned by applying BATCHED_D2D_PADDING
// during construction.)
align = 1;
}

if (align % 16 == 0) {
batched_memcpy_d<ulonglong2, blocks_per_copy>(idx, input, output, size);
} else if (align % 8 == 0) {
batched_memcpy_d<unsigned long long, blocks_per_copy>(idx, input, output, size);
} else if (align % 4 == 0) {
batched_memcpy_d<unsigned int, blocks_per_copy>(idx, input, output, size);
} else if (align % 2 == 0) {
batched_memcpy_d<unsigned short, blocks_per_copy>(idx, input, output, size);
} else {
batched_memcpy_d<unsigned char, blocks_per_copy>(idx, input, output, size);
}
}

#define NTHREADS_D2D_KERNEL 1024
#define BLOCKS_PER_COPY_D2D_KERNEL 8
void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, cudaStream_t stream)
{
batched_memcpy_k<BLOCKS_PER_COPY_D2D_KERNEL><<<num_copies * BLOCKS_PER_COPY_D2D_KERNEL,
NTHREADS_D2D_KERNEL, 0, stream>>>(params);
}

template<typename T, typename TS>
__global__ void scale_buffer_k(const T* input, T* output, int64_t num_elements, const TS scale_factor) {

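
A minimal host-side usage sketch (assumed, not taken from the PR) of the launch above: each copy gets BLOCKS_PER_COPY_D2D_KERNEL blocks of NTHREADS_D2D_KERNEL threads, and inside the kernel blockIdx.x / blocks_per_copy selects which (in, out, size) triple a block serves. The helper function and buffer names are hypothetical; the include path assumes the caller sits next to cuda_kernels.h.

#include <cstddef>
#include <cuda_runtime.h>
#include "cuda_kernels.h"  // BatchedD2DParams, BatchedD2DMemcpyCudaImpl

using horovod::common::BatchedD2DParams;
using horovod::common::BatchedD2DMemcpyCudaImpl;

// Enqueue two device-to-device copies as a single kernel launch.
void CopyTwoBuffers(void* dst0, const void* src0, size_t bytes0,
                    void* dst1, const void* src1, size_t bytes1,
                    cudaStream_t stream) {
  BatchedD2DParams params;
  params.out[0] = dst0; params.in[0] = const_cast<void*>(src0); params.sizes[0] = bytes0;
  params.out[1] = dst1; params.in[1] = const_cast<void*>(src1); params.sizes[1] = bytes1;

  // Launches 2 * BLOCKS_PER_COPY_D2D_KERNEL blocks: blocks 0-7 handle copy 0,
  // blocks 8-15 handle copy 1. With 16-byte-aligned pointers the kernel copies
  // via ulonglong2 loads/stores and falls back to narrower types otherwise.
  BatchedD2DMemcpyCudaImpl(params, /*num_copies=*/2, stream);
}
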
15 changes: 15 additions & 0 deletions horovod/common/ops/cuda/cuda_kernels.h
@@ -16,11 +16,26 @@
#ifndef CUDA_KERNELS_H
#define CUDA_KERNELS_H

#include <cuda_runtime.h>

#include "../../message.h"

#define BATCHED_D2D_CAPACITY 160
#define BATCHED_D2D_PADDING 16

namespace horovod {
namespace common {

struct BatchedD2DParams {
void* out[BATCHED_D2D_CAPACITY];
void* in[BATCHED_D2D_CAPACITY];
size_t sizes[BATCHED_D2D_CAPACITY];
};

// Performs a batched d2d memcopy
void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, cudaStream_t stream);

// Scales buffer by scalar
void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements,
double scale_factor, DataType dtype, cudaStream_t stream);

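
A side note on the capacity chosen above (my reading, not stated in the visible diff): BatchedD2DParams is passed to the kernel by value, and with 160 slots its size is 3 * 160 * 8 = 3840 bytes, which stays under the 4 KB limit CUDA has traditionally placed on __global__ function parameters. A quick compile-time check, assuming 8-byte pointers and size_t, that could be dropped into any translation unit including cuda_kernels.h:

static_assert(sizeof(horovod::common::BatchedD2DParams) ==
                  3 * BATCHED_D2D_CAPACITY * sizeof(void*),
              "sketch assumes 8-byte pointers and size_t");  // 3840 bytes
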
5 changes: 1 addition & 4 deletions horovod/common/ops/ddl_operations.cc
@@ -66,10 +66,7 @@ Status DDLAllreduce::Execute(std::vector<TensorTableEntry>& entries, const Respo
buffer_len = (size_t) first_entry.output->size();
}

int64_t num_elements = 0;
for (auto& e : entries) {
num_elements += e.tensor->shape().num_elements();
}
int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype());

if (response.prescale_factor() != 1.0) {
// Execute prescaling op
97 changes: 97 additions & 0 deletions horovod/common/ops/gpu_operations.cc
@@ -15,6 +15,9 @@
// =============================================================================

#include "gpu_operations.h"
#if HAVE_CUDA
#include "cuda/cuda_kernels.h"
#endif

#include <thread>

@@ -95,13 +98,107 @@ bool GPUAllreduce::Enabled(const ParameterManager& param_manager,
return entries[0].device != CPU_DEVICE_ID;
}

#if HAVE_CUDA
void GPUAllreduce::MemcpyInFusionBuffer(const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
void*& buffer_data, size_t& buffer_len) {
// Access the fusion buffer.
auto& first_entry = entries[0];
auto buffer = global_state_->fusion_buffer.GetBuffer(
first_entry.device, first_entry.context->framework(), global_state_->current_nccl_stream);
buffer_data = const_cast<void*>(buffer->AccessData(first_entry.context));

if (global_state_->batch_d2d_memcopies) {
int64_t offset = 0;
int idx = 0;
int count = 0;

BatchedD2DParams d2d_params;
auto& first_entry = entries[0];
for (auto& e : entries) {
void* buffer_data_at_offset = (uint8_t*)buffer_data + offset;

// Set input/output pointers and sizes
d2d_params.out[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset;
d2d_params.in[idx % BATCHED_D2D_CAPACITY] = (void*) e.tensor->data();
d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size();

offset += BATCHED_D2D_PADDING * ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING);
idx++;
count++;

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int) entries.size()) {
// Perform batched d2d memcpy
BatchedD2DMemcpyCudaImpl(d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]);
gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", cudaGetLastError());
count = 0;
}
}
buffer_len = (size_t)offset;

} else {
int64_t offset = 0;
for (auto& e : entries) {
void* buffer_data_at_offset = (uint8_t*) buffer_data + offset;
MemcpyEntryInFusionBuffer(entries, e, buffer_data_at_offset);
offset += e.tensor->size();
}

buffer_len = (size_t) offset;
}

// Set the input data to originate from the buffer.
fused_input_data = buffer_data;
}
#endif


void GPUAllreduce::MemcpyEntryInFusionBuffer(const std::vector<TensorTableEntry>& entries,
const TensorTableEntry& e, void* buffer_data_at_offset) {
auto& first_entry = entries[0];
gpu_context_->MemcpyAsyncD2D(buffer_data_at_offset, e.tensor->data(), (size_t) e.tensor->size(),
gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]);
}

#if HAVE_CUDA
void GPUAllreduce::MemcpyOutFusionBuffer(const void* buffer_data, std::vector<TensorTableEntry>& entries) {
if (global_state_->batch_d2d_memcopies) {
int64_t offset = 0;
int idx = 0;
int count = 0;

BatchedD2DParams d2d_params;
auto& first_entry = entries[0];
for (auto& e : entries) {
void* buffer_data_at_offset = (uint8_t*)buffer_data + offset;

// Set input/output pointers and sizes
d2d_params.out[idx % BATCHED_D2D_CAPACITY] = (void*)(e.output->data());
d2d_params.in[idx % BATCHED_D2D_CAPACITY] = buffer_data_at_offset;
d2d_params.sizes[idx % BATCHED_D2D_CAPACITY] = e.tensor->size();

offset += BATCHED_D2D_PADDING * ((e.tensor->size() + BATCHED_D2D_PADDING - 1) / BATCHED_D2D_PADDING);
idx++;
count++;

if (idx % BATCHED_D2D_CAPACITY == 0 || idx == (int) entries.size()) {
// Perform batched d2d memcpy
BatchedD2DMemcpyCudaImpl(d2d_params, count, gpu_context_->streams[global_state_->current_nccl_stream][first_entry.device]);
gpu_context_->ErrorCheck("BatchedD2DMemcpyCudaImpl", cudaGetLastError());
count = 0;
}
}

} else {
int64_t offset = 0;
for (auto& e : entries) {
void* buffer_data_at_offset = (uint8_t*) buffer_data + offset;
MemcpyEntryOutFusionBuffer(entries, buffer_data_at_offset, e);
offset += e.tensor->size();
}
}
}
#endif

void GPUAllreduce::MemcpyEntryOutFusionBuffer(const std::vector<TensorTableEntry>& entries,
const void* buffer_data_at_offset, TensorTableEntry& e) {
auto& first_entry = entries[0];
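
The two fusion-buffer paths above share the same accumulate-and-flush pattern: fill BatchedD2DParams slot by slot and launch the batched kernel whenever the struct is full or the last entry has been queued. A condensed, generic sketch of that pattern (the CopyRequest type and BatchedCopy helper are illustrative, not Horovod API):

#include <cstddef>
#include <vector>
#include <cuda_runtime.h>
#include "cuda/cuda_kernels.h"  // BatchedD2DParams, BatchedD2DMemcpyCudaImpl

struct CopyRequest {
  void* dst;
  void* src;
  size_t bytes;
};

void BatchedCopy(const std::vector<CopyRequest>& requests, cudaStream_t stream) {
  horovod::common::BatchedD2DParams params;
  int count = 0;
  for (size_t idx = 0; idx < requests.size(); ++idx) {
    params.out[count] = requests[idx].dst;
    params.in[count] = requests[idx].src;
    params.sizes[count] = requests[idx].bytes;
    ++count;
    // Flush once the parameter struct is full or the last request is queued.
    if (count == BATCHED_D2D_CAPACITY || idx + 1 == requests.size()) {
      horovod::common::BatchedD2DMemcpyCudaImpl(params, count, stream);
      count = 0;
    }
  }
}
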
7 changes: 7 additions & 0 deletions horovod/common/ops/gpu_operations.h
@@ -136,6 +136,13 @@ class GPUAllreduce : public AllreduceOp {
const Response& response) const override;

protected:
#if HAVE_CUDA
void MemcpyInFusionBuffer(const std::vector<TensorTableEntry>& entries, const void*& fused_input_data,
void*& buffer_data, size_t& buffer_len) override;

void MemcpyOutFusionBuffer(const void* buffer_data, std::vector<TensorTableEntry>& entries) override;
#endif

void MemcpyEntryInFusionBuffer(const std::vector<TensorTableEntry>& entries,
const TensorTableEntry& e, void* buffer_data_at_offset) override;

3 changes: 2 additions & 1 deletion horovod/common/ops/mpi_gpu_operations.cc
@@ -34,7 +34,6 @@ Status MPI_GPUAllreduce::Execute(std::vector<TensorTableEntry>& entries, const R
const void* fused_input_data;
void* buffer_data;
size_t buffer_len;
int64_t num_elements = NumElements(entries);

// Copy memory into the fusion buffer.
auto& timeline = global_state_->timeline;
@@ -51,6 +50,8 @@ Status MPI_GPUAllreduce::Execute(std::vector<TensorTableEntry>& entries, const R
buffer_len = (size_t) first_entry.output->size();
}

int64_t num_elements = buffer_len / DataType_Size(first_entry.tensor->dtype());

if (response.prescale_factor() != 1.0) {
// Execute prescaling op
ScaleBuffer(response.prescale_factor(), entries, fused_input_data, buffer_data, num_elements);