Skip to content

Commit

Permalink
add nccl synchronization in tree training
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Feb 20, 2024
1 parent 02b725b commit 3bfb784
Show file tree
Hide file tree
Showing 18 changed files with 539 additions and 79 deletions.
2 changes: 1 addition & 1 deletion .ci/check_python_dists.sh
Expand Up @@ -25,7 +25,7 @@ if [ $PY_MINOR_VER -gt 7 ]; then
pydistcheck \
--inspect \
--ignore 'compiled-objects-have-debug-symbols,distro-too-large-compressed' \
--max-allowed-size-uncompressed '100M' \
--max-allowed-size-uncompressed '500M' \
--max-allowed-files 800 \
${DIST_DIR}/* || exit 1
elif { test $(uname -m) = "aarch64"; }; then
Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_objective_function.hpp
Expand Up @@ -19,7 +19,7 @@
namespace LightGBM {

template <typename HOST_OBJECTIVE>
class CUDAObjectiveInterface: public HOST_OBJECTIVE, NCCLInfo {
class CUDAObjectiveInterface: public HOST_OBJECTIVE, public NCCLInfo {
public:
explicit CUDAObjectiveInterface(const Config& config): HOST_OBJECTIVE(config) {
if (config.num_gpus <= 1) {
Expand Down
46 changes: 45 additions & 1 deletion include/LightGBM/cuda/cuda_utils.hu
Expand Up @@ -102,6 +102,8 @@ void CopyFromCUDADeviceToCUDADeviceAsync(T* dst_ptr, const T* src_ptr, size_t si

void SynchronizeCUDADevice(const char* file, const int line);

void SynchronizeCUDAStream(cudaStream_t cuda_stream, const char* file, const int line);

template <typename T>
void SetCUDAMemory(T* dst_ptr, int value, size_t size, const char* file, const int line) {
CUDASUCCESS_OR_FATAL_OUTER(cudaMemset(reinterpret_cast<void*>(dst_ptr), value, size * sizeof(T)));
Expand Down Expand Up @@ -220,7 +222,7 @@ class NCCLInfo {
public:
NCCLInfo() {}

void SetNCCLInfo(
virtual void SetNCCLInfo(
ncclComm_t nccl_communicator,
int nccl_gpu_rank,
int local_gpu_rank,
Expand All @@ -242,6 +244,48 @@ class NCCLInfo {
data_size_t global_num_data_ = 0;
};

cudaStream_t CUDAStreamCreate();

void CUDAStreamDestroy(cudaStream_t cuda_stream);

void NCCLGroupStart();

void NCCLGroupEnd();

template <typename T>
void NCCLAllReduce(const T* send_buffer, T* recv_buffer, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ncclAllReduce(reinterpret_cast<const void*>(send_buffer), reinterpret_cast<void*>(recv_buffer), count, datatype, op, comm, stream));
}

template <typename T>
void NCCLAllReduce(const T* send_buffer, T* recv_buffer, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm) {
cudaStream_t nccl_stream;
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&nccl_stream));
NCCLCHECK(ncclAllReduce(reinterpret_cast<const void*>(send_buffer), reinterpret_cast<void*>(recv_buffer), count, datatype, op, comm, nccl_stream));
CUDASUCCESS_OR_FATAL(cudaStreamSynchronize(nccl_stream));
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(nccl_stream));
}

template <typename T>
T NCCLAllReduce(T send_value, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
CUDAVector<T> send_buffer(1);
CopyFromHostToCUDADevice<T>(send_buffer.RawData(), &send_value, 1, __FILE__, __LINE__);
NCCLAllReduce<T>(send_buffer.RawDataReadOnly(), send_buffer.RawData(), 1, datatype, op, comm, stream);
T recv_value = 0;
CopyFromCUDADeviceToHost<T>(&recv_value, send_buffer.RawDataReadOnly(), 1, __FILE__, __LINE__);
return recv_value;
}

template <typename T>
T NCCLAllReduce(T send_value, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm) {
CUDAVector<T> send_buffer(1);
CopyFromHostToCUDADevice<T>(send_buffer.RawData(), &send_value, 1, __FILE__, __LINE__);
NCCLAllReduce<T>(send_buffer.RawDataReadOnly(), send_buffer.RawData(), 1, datatype, op, comm);
T recv_value = 0;
CopyFromCUDADeviceToHost<T>(&recv_value, send_buffer.RawDataReadOnly(), 1, __FILE__, __LINE__);
return recv_value;
}

} // namespace LightGBM

#endif // USE_CUDA
Expand Down
3 changes: 1 addition & 2 deletions src/boosting/cuda/nccl_gbdt_component.hpp
Expand Up @@ -49,10 +49,9 @@ class NCCLGBDTComponent: public NCCLInfo {
hessians_.reset(new CUDAVector<score_t>(num_data_in_gpu_));
tree_learner_.reset(new CUDASingleGPUTreeLearner(config, boosting_on_gpu));

tree_learner_->SetNCCLInfo(nccl_communicator_, nccl_gpu_rank_, local_gpu_rank_, gpu_device_id_, train_data->num_data());

objective_function_->Init(dataset_->metadata(), dataset_->num_data());
tree_learner_->Init(dataset_.get(), is_constant_hessian);
tree_learner_->SetNCCLInfo(nccl_communicator_, nccl_gpu_rank_, local_gpu_rank_, gpu_device_id_, train_data->num_data());
}

ObjectiveFunction* objective_function() { return objective_function_.get(); }
Expand Down
22 changes: 22 additions & 0 deletions src/cuda/cuda_utils.cpp
Expand Up @@ -13,6 +13,10 @@ void SynchronizeCUDADevice(const char* file, const int line) {
gpuAssert(cudaDeviceSynchronize(), file, line);
}

void SynchronizeCUDAStream(cudaStream_t cuda_stream, const char* file, const int line) {
gpuAssert(cudaStreamSynchronize(cuda_stream), file, line);
}

void PrintLastCUDAError() {
const char* error_name = cudaGetErrorName(cudaGetLastError());
Log::Fatal(error_name);
Expand All @@ -32,6 +36,24 @@ int GetCUDADevice(const char* file, int line) {
return cur_gpu_device_id;
}

cudaStream_t CUDAStreamCreate() {
cudaStream_t cuda_stream;
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream));
return cuda_stream;
}

void CUDAStreamDestroy(cudaStream_t cuda_stream) {
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_stream));
}

void NCCLGroupStart() {
NCCLCHECK(ncclGroupStart());
}

void NCCLGroupEnd() {
NCCLCHECK(ncclGroupEnd());
}

} // namespace LightGBM

#endif // USE_CUDA
16 changes: 14 additions & 2 deletions src/objective/cuda/cuda_binary_objective.cu
Expand Up @@ -98,9 +98,21 @@ double CUDABinaryLogloss::LaunchCalcInitScoreKernel(const int /*class_id*/) cons
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (cuda_weights_ == nullptr) {
BoostFromScoreKernel_2_BinaryLogloss<false><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
if (nccl_communicator_ == nullptr) {
BoostFromScoreKernel_2_BinaryLogloss<false><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
} else {
NCCLAllReduce<double>(cuda_boost_from_score_, cuda_boost_from_score_, 1, ncclFloat64, ncclSum, nccl_communicator_);
const data_size_t global_num_data = NCCLAllReduce<data_size_t>(num_data_, ncclInt32, ncclSum, nccl_communicator_);
BoostFromScoreKernel_2_BinaryLogloss<false><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, global_num_data, sigmoid_);
}
} else {
BoostFromScoreKernel_2_BinaryLogloss<true><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
if (nccl_communicator_ == nullptr) {
BoostFromScoreKernel_2_BinaryLogloss<true><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
} else {
NCCLAllReduce<double>(cuda_boost_from_score_, cuda_boost_from_score_, 1, ncclFloat64, ncclSum, nccl_communicator_);
NCCLAllReduce<double>(cuda_sum_weights_, cuda_sum_weights_, 1, ncclFloat64, ncclSum, nccl_communicator_);
BoostFromScoreKernel_2_BinaryLogloss<true><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
}
}
SynchronizeCUDADevice(__FILE__, __LINE__);
double boost_from_score = 0.0f;
Expand Down
22 changes: 15 additions & 7 deletions src/treelearner/cuda/cuda_data_partition.cpp
Expand Up @@ -18,13 +18,15 @@ CUDADataPartition::CUDADataPartition(
const int num_total_bin,
const int num_leaves,
const int num_threads,
const bool use_quantized_grad,
hist_t* cuda_hist):

num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_total_bin_(num_total_bin),
num_leaves_(num_leaves),
num_threads_(num_threads),
use_quantized_grad_(use_quantized_grad),
cuda_hist_(cuda_hist) {
CalcBlockDim(num_data_);
max_num_split_indices_blocks_ = grid_dim_;
Expand Down Expand Up @@ -59,7 +61,6 @@ CUDADataPartition::CUDADataPartition(
cuda_block_data_to_left_offset_ = nullptr;
cuda_block_data_to_right_offset_ = nullptr;
cuda_out_data_indices_in_leaf_ = nullptr;
cuda_split_info_buffer_ = nullptr;
cuda_num_data_ = nullptr;
}

Expand All @@ -75,7 +76,6 @@ CUDADataPartition::~CUDADataPartition() {
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_left_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_right_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_out_data_indices_in_leaf_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_split_info_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_num_data_, __FILE__, __LINE__);
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[0]));
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[1]));
Expand All @@ -102,7 +102,7 @@ void CUDADataPartition::Init() {
AllocateCUDAMemory<hist_t*>(&cuda_hist_pool_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
CopyFromHostToCUDADevice<hist_t*>(cuda_hist_pool_, &cuda_hist_, 1, __FILE__, __LINE__);

AllocateCUDAMemory<int>(&cuda_split_info_buffer_, 16, __FILE__, __LINE__);
cuda_split_info_buffer_.Resize(18);

AllocateCUDAMemory<double>(&cuda_leaf_output_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);

Expand Down Expand Up @@ -159,7 +159,9 @@ void CUDADataPartition::Split(
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients) {
double* right_leaf_sum_of_gradients,
data_size_t* global_left_leaf_num_data,
data_size_t* global_right_leaf_num_data) {
CalcBlockDim(num_data_in_leaf);
global_timer.Start("GenDataToLeftBitVector");
GenDataToLeftBitVector(num_data_in_leaf,
Expand Down Expand Up @@ -187,7 +189,9 @@ void CUDADataPartition::Split(
left_leaf_sum_of_hessians,
right_leaf_sum_of_hessians,
left_leaf_sum_of_gradients,
right_leaf_sum_of_gradients);
right_leaf_sum_of_gradients,
global_left_leaf_num_data,
global_right_leaf_num_data);
global_timer.Stop("SplitInner");
}

Expand Down Expand Up @@ -238,7 +242,9 @@ void CUDADataPartition::SplitInner(
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients) {
double* right_leaf_sum_of_gradients,
data_size_t* global_left_leaf_num_data,
data_size_t* global_right_leaf_num_data) {
LaunchSplitInnerKernel(
num_data_in_leaf,
best_split_info,
Expand All @@ -253,7 +259,9 @@ void CUDADataPartition::SplitInner(
left_leaf_sum_of_hessians,
right_leaf_sum_of_hessians,
left_leaf_sum_of_gradients,
right_leaf_sum_of_gradients);
right_leaf_sum_of_gradients,
global_left_leaf_num_data,
global_right_leaf_num_data);
++cur_num_leaves_;
}

Expand Down
77 changes: 61 additions & 16 deletions src/treelearner/cuda/cuda_data_partition.cu
Expand Up @@ -782,6 +782,7 @@ __global__ void AggregateBlockOffsetKernel1(
}
}

template <bool USE_NCCL, bool USE_GRAD_DISCRETIZED>
__global__ void SplitTreeStructureKernel(const int left_leaf_index,
const int right_leaf_index,
data_size_t* block_to_left_offset_buffer,
Expand Down Expand Up @@ -822,11 +823,17 @@ __global__ void SplitTreeStructureKernel(const int left_leaf_index,
cuda_split_info_buffer_for_hessians[3] = best_split_info->right_sum_gradients;
}

if (cuda_leaf_num_data[left_leaf_index] < cuda_leaf_num_data[right_leaf_index]) {
bool left_is_smaller = USE_NCCL ?
cuda_split_info_buffer[16] < cuda_split_info_buffer[17] :
cuda_leaf_num_data[left_leaf_index] < cuda_leaf_num_data[right_leaf_index];

if (left_is_smaller) {
if (global_thread_index == 0) {
hist_t* parent_hist_ptr = cuda_hist_pool[left_leaf_index];
cuda_hist_pool[right_leaf_index] = parent_hist_ptr;
cuda_hist_pool[left_leaf_index] = cuda_hist + 2 * right_leaf_index * num_total_bin;
cuda_hist_pool[left_leaf_index] = USE_GRAD_DISCRETIZED ?
cuda_hist + right_leaf_index * num_total_bin :
cuda_hist + 2 * right_leaf_index * num_total_bin;
smaller_leaf_splits->hist_in_leaf = cuda_hist_pool[left_leaf_index];
larger_leaf_splits->hist_in_leaf = cuda_hist_pool[right_leaf_index];
} else if (global_thread_index == 1) {
Expand Down Expand Up @@ -958,7 +965,9 @@ void CUDADataPartition::LaunchSplitInnerKernel(
double* left_leaf_sum_of_hessians_ref,
double* right_leaf_sum_of_hessians_ref,
double* left_leaf_sum_of_gradients_ref,
double* right_leaf_sum_of_gradients_ref) {
double* right_leaf_sum_of_gradients_ref,
data_size_t* global_left_leaf_num_data,
data_size_t* global_right_leaf_num_data) {
int num_blocks_final_ref = grid_dim_ - 1;
int num_blocks_final_aligned = 1;
while (num_blocks_final_ref > 0) {
Expand Down Expand Up @@ -986,6 +995,20 @@ void CUDADataPartition::LaunchSplitInnerKernel(
}
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDADataPartition::AggregateBlockOffsetKernel");

if (nccl_communicator_ != nullptr) {
NCCLGroupStart();
NCCLAllReduce<data_size_t>(
cuda_leaf_num_data_ + left_leaf_index,
cuda_split_info_buffer_.RawData() + 16,
1, ncclInt32, ncclSum, nccl_communicator_, cuda_streams_[0]);
NCCLAllReduce<data_size_t>(
cuda_leaf_num_data_ + right_leaf_index,
cuda_split_info_buffer_.RawData() + 17,
1, ncclInt32, ncclSum, nccl_communicator_, cuda_streams_[0]);
NCCLGroupEnd();
}

global_timer.Start("CUDADataPartition::SplitInnerKernel");
SplitInnerKernel<<<grid_dim_, block_dim_, 0, cuda_streams_[1]>>>(
left_leaf_index, right_leaf_index, cuda_leaf_data_start_, cuda_leaf_num_data_, cuda_data_indices_,
Expand All @@ -995,22 +1018,40 @@ void CUDADataPartition::LaunchSplitInnerKernel(
SynchronizeCUDADevice(__FILE__, __LINE__);

global_timer.Start("CUDADataPartition::SplitTreeStructureKernel");
SplitTreeStructureKernel<<<4, 5, 0, cuda_streams_[0]>>>(left_leaf_index, right_leaf_index,
cuda_block_data_to_left_offset_,
cuda_block_data_to_right_offset_, cuda_leaf_data_start_, cuda_leaf_data_end_,
cuda_leaf_num_data_, cuda_out_data_indices_in_leaf_,
best_split_info,
smaller_leaf_splits,
larger_leaf_splits,
num_total_bin_,
cuda_hist_,
cuda_hist_pool_,
cuda_leaf_output_, cuda_split_info_buffer_);

#define SPLI_TREE_ARGS \
left_leaf_index, right_leaf_index, \
cuda_block_data_to_left_offset_, \
cuda_block_data_to_right_offset_, cuda_leaf_data_start_, cuda_leaf_data_end_, \
cuda_leaf_num_data_, cuda_out_data_indices_in_leaf_, \
best_split_info, \
smaller_leaf_splits, \
larger_leaf_splits, \
num_total_bin_, \
cuda_hist_, \
cuda_hist_pool_, \
cuda_leaf_output_, cuda_split_info_buffer_.RawData()

if (nccl_communicator_ != nullptr) {
if (use_quantized_grad_) {
SplitTreeStructureKernel<true, true><<<4, 5, 0, cuda_streams_[0]>>>(SPLI_TREE_ARGS);
} else {
SplitTreeStructureKernel<true, false><<<4, 5, 0, cuda_streams_[0]>>>(SPLI_TREE_ARGS);
}
} else {
if (use_quantized_grad_) {
SplitTreeStructureKernel<false, true><<<4, 5, 0, cuda_streams_[0]>>>(SPLI_TREE_ARGS);
} else {
SplitTreeStructureKernel<false, false><<<4, 5, 0, cuda_streams_[0]>>>(SPLI_TREE_ARGS);
}
}

#undef SPLI_TREE_ARGS
global_timer.Stop("CUDADataPartition::SplitTreeStructureKernel");
std::vector<int> cpu_split_info_buffer(16);
std::vector<int> cpu_split_info_buffer(18);
const double* cpu_sum_hessians_info = reinterpret_cast<const double*>(cpu_split_info_buffer.data() + 8);
global_timer.Start("CUDADataPartition::CopyFromCUDADeviceToHostAsync");
CopyFromCUDADeviceToHostAsync<int>(cpu_split_info_buffer.data(), cuda_split_info_buffer_, 16, cuda_streams_[0], __FILE__, __LINE__);
CopyFromCUDADeviceToHostAsync<int>(cpu_split_info_buffer.data(), cuda_split_info_buffer_.RawData(), 16, cuda_streams_[0], __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDADataPartition::CopyFromCUDADeviceToHostAsync");
const data_size_t left_leaf_num_data = cpu_split_info_buffer[1];
Expand All @@ -1029,6 +1070,10 @@ void CUDADataPartition::LaunchSplitInnerKernel(
*right_leaf_sum_of_hessians_ref = cpu_sum_hessians_info[1];
*left_leaf_sum_of_gradients_ref = cpu_sum_hessians_info[2];
*right_leaf_sum_of_gradients_ref = cpu_sum_hessians_info[3];
if (nccl_communicator_ != nullptr) {
*global_left_leaf_num_data = cpu_split_info_buffer[16];
*global_right_leaf_num_data = cpu_split_info_buffer[17];
}
}

template <bool USE_BAGGING>
Expand Down

0 comments on commit 3bfb784

Please sign in to comment.