[CUDA] Set GPU device ID in threads (#6028)
* set GPU device ID in OpenMP threads

* move SetCUDADevice outside for loop

---------

Co-authored-by: James Lamb <jaylamb20@gmail.com>
shiyu1994 and jameslamb committed Aug 13, 2023
1 parent fe838d8 commit 5c9e61d
Showing 6 changed files with 95 additions and 71 deletions.
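Background for the change: the CUDA runtime tracks the "current device" per host thread, so a device selected on the main thread is not seen by OpenMP worker threads; unless each thread calls cudaSetDevice itself, its allocations and kernel launches go to device 0. The sketch below is a minimal standalone illustration of that behaviour, not LightGBM code, and it assumes a machine with at least two GPUs so that device 1 exists.

```cpp
// Minimal sketch: the CUDA current device is a per-thread property, so each
// OpenMP thread must select the device itself before issuing CUDA calls.
#include <cuda_runtime.h>
#include <omp.h>
#include <cstdio>

int main() {
  cudaSetDevice(1);  // main thread switches away from the default device 0
  #pragma omp parallel num_threads(4)
  {
    int before = -1;
    cudaGetDevice(&before);   // worker threads still report device 0 here
    cudaSetDevice(1);         // each thread has to pick the device itself
    int after = -1;
    cudaGetDevice(&after);
    std::printf("thread %d: before=%d after=%d\n",
                omp_get_thread_num(), before, after);
  }
  return 0;
}
```

The commit applies this by storing the configured device ID and calling SetCUDADevice at the start of each OpenMP parallel region (and in the CUDA metric and objective constructors), instead of relying on a single call from the main thread.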
1 change: 1 addition & 0 deletions include/LightGBM/cuda/cuda_column_data.hpp
@@ -98,6 +98,7 @@ class CUDAColumnData {

  void ResizeWhenCopySubrow(const data_size_t num_used_indices);

+  int gpu_device_id_;
  int num_threads_;
  data_size_t num_data_;
  int num_columns_;
3 changes: 3 additions & 0 deletions include/LightGBM/cuda/cuda_metric.hpp
@@ -9,6 +9,7 @@

#ifdef USE_CUDA

+#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/metric.h>

namespace LightGBM {
@@ -19,6 +20,8 @@ class CUDAMetricInterface: public HOST_METRIC {
  explicit CUDAMetricInterface(const Config& config): HOST_METRIC(config) {
    cuda_labels_ = nullptr;
    cuda_weights_ = nullptr;
+    const int gpu_device_id = config.gpu_device_id >= 0 ? config.gpu_device_id : 0;
+    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
5 changes: 4 additions & 1 deletion include/LightGBM/cuda/cuda_objective_function.hpp
@@ -21,7 +21,10 @@ namespace LightGBM {
template <typename HOST_OBJECTIVE>
class CUDAObjectiveInterface: public HOST_OBJECTIVE {
 public:
-  explicit CUDAObjectiveInterface(const Config& config): HOST_OBJECTIVE(config) {}
+  explicit CUDAObjectiveInterface(const Config& config): HOST_OBJECTIVE(config) {
+    const int gpu_device_id = config.gpu_device_id >= 0 ? config.gpu_device_id : 0;
+    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
+  }

  explicit CUDAObjectiveInterface(const std::vector<std::string>& strs): HOST_OBJECTIVE(strs) {}

2 changes: 2 additions & 0 deletions include/LightGBM/cuda/cuda_utils.h
@@ -28,6 +28,8 @@ inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort =

void SetCUDADevice(int gpu_device_id, const char* file, int line);

+int GetCUDADevice(const char* file, int line);
+
template <typename T>
void AllocateCUDAMemory(T** out_ptr, size_t size, const char* file, const int line) {
  void* tmp_ptr = nullptr;
6 changes: 6 additions & 0 deletions src/cuda/cuda_utils.cpp
@@ -26,6 +26,12 @@ void SetCUDADevice(int gpu_device_id, const char* file, int line) {
  }
}

+int GetCUDADevice(const char* file, int line) {
+  int cur_gpu_device_id = 0;
+  CUDASUCCESS_OR_FATAL_OUTER(cudaGetDevice(&cur_gpu_device_id));
+  return cur_gpu_device_id;
+}
+
}  // namespace LightGBM

#endif  // USE_CUDA
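The new GetCUDADevice helper simply wraps cudaGetDevice with the usual file/line error reporting. One way such a query can be paired with SetCUDADevice is to remember the caller's device and restore it when a scope ends; the guard below is only an illustrative sketch written for this note (RestoreDeviceGuard and DoWorkOnDevice are hypothetical names, not part of this commit).

```cpp
// Illustrative sketch: save the current device on entry, restore it on exit.
#include <cuda_runtime.h>

class RestoreDeviceGuard {
 public:
  RestoreDeviceGuard() { cudaGetDevice(&saved_device_); }   // remember caller's device
  ~RestoreDeviceGuard() { cudaSetDevice(saved_device_); }   // restore it on scope exit
 private:
  int saved_device_ = 0;
};

void DoWorkOnDevice(int gpu_device_id) {
  RestoreDeviceGuard guard;       // capture whatever device the caller was using
  cudaSetDevice(gpu_device_id);   // switch for the duration of this function
  // ... issue CUDA calls on gpu_device_id ...
}

int main() {
  DoWorkOnDevice(0);  // the caller's device is restored when the guard goes out of scope
  return 0;
}
```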
149 changes: 79 additions & 70 deletions src/io/cuda/cuda_column_data.cpp
@@ -12,11 +12,8 @@ namespace LightGBM {
CUDAColumnData::CUDAColumnData(const data_size_t num_data, const int gpu_device_id) {
  num_threads_ = OMP_NUM_THREADS();
  num_data_ = num_data;
-  if (gpu_device_id >= 0) {
-    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
-  } else {
-    SetCUDADevice(0, __FILE__, __LINE__);
-  }
+  gpu_device_id_ = gpu_device_id >= 0 ? gpu_device_id : 0;
+  SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
  cuda_used_indices_ = nullptr;
  cuda_data_by_column_ = nullptr;
  cuda_column_bit_type_ = nullptr;
@@ -117,37 +114,41 @@ void CUDAColumnData::Init(const int num_columns,
  feature_mfb_is_na_ = feature_mfb_is_na;
  data_by_column_.resize(num_columns_, nullptr);
  OMP_INIT_EX();
-  #pragma omp parallel for schedule(static) num_threads(num_threads_)
-  for (int column_index = 0; column_index < num_columns_; ++column_index) {
-    OMP_LOOP_EX_BEGIN();
-    const int8_t bit_type = column_bit_type[column_index];
-    if (column_data[column_index] != nullptr) {
-      // is dense column
-      if (bit_type == 4) {
-        column_bit_type_[column_index] = 8;
-        InitOneColumnData<false, true, uint8_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
-      } else if (bit_type == 8) {
-        InitOneColumnData<false, false, uint8_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
-      } else if (bit_type == 16) {
-        InitOneColumnData<false, false, uint16_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
-      } else if (bit_type == 32) {
-        InitOneColumnData<false, false, uint32_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
-      } else {
-        Log::Fatal("Unknow column bit type %d", bit_type);
-      }
-    } else {
-      // is sparse column
-      if (bit_type == 8) {
-        InitOneColumnData<true, false, uint8_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
-      } else if (bit_type == 16) {
-        InitOneColumnData<true, false, uint16_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
-      } else if (bit_type == 32) {
-        InitOneColumnData<true, false, uint32_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
-      } else {
-        Log::Fatal("Unknow column bit type %d", bit_type);
-      }
-    }
-    OMP_LOOP_EX_END();
-  }
+  #pragma omp parallel num_threads(num_threads_)
+  {
+    SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
+    #pragma omp for schedule(static)
+    for (int column_index = 0; column_index < num_columns_; ++column_index) {
+      OMP_LOOP_EX_BEGIN();
+      const int8_t bit_type = column_bit_type[column_index];
+      if (column_data[column_index] != nullptr) {
+        // is dense column
+        if (bit_type == 4) {
+          column_bit_type_[column_index] = 8;
+          InitOneColumnData<false, true, uint8_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
+        } else if (bit_type == 8) {
+          InitOneColumnData<false, false, uint8_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
+        } else if (bit_type == 16) {
+          InitOneColumnData<false, false, uint16_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
+        } else if (bit_type == 32) {
+          InitOneColumnData<false, false, uint32_t>(column_data[column_index], nullptr, &data_by_column_[column_index]);
+        } else {
+          Log::Fatal("Unknow column bit type %d", bit_type);
+        }
+      } else {
+        // is sparse column
+        if (bit_type == 8) {
+          InitOneColumnData<true, false, uint8_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
+        } else if (bit_type == 16) {
+          InitOneColumnData<true, false, uint16_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
+        } else if (bit_type == 32) {
+          InitOneColumnData<true, false, uint32_t>(nullptr, column_bin_iterator[column_index], &data_by_column_[column_index]);
+        } else {
+          Log::Fatal("Unknow column bit type %d", bit_type);
+        }
+      }
+      OMP_LOOP_EX_END();
+    }
+  }
  OMP_THROW_EX();
  feature_to_column_ = feature_to_column;
@@ -182,24 +183,28 @@ void CUDAColumnData::CopySubrow(
  AllocateCUDAMemory<data_size_t>(&cuda_used_indices_, num_used_indices_size, __FILE__, __LINE__);
  data_by_column_.resize(num_columns_, nullptr);
  OMP_INIT_EX();
-  #pragma omp parallel for schedule(static) num_threads(num_threads_)
-  for (int column_index = 0; column_index < num_columns_; ++column_index) {
-    OMP_LOOP_EX_BEGIN();
-    const uint8_t bit_type = column_bit_type_[column_index];
-    if (bit_type == 8) {
-      uint8_t* column_data = nullptr;
-      AllocateCUDAMemory<uint8_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    } else if (bit_type == 16) {
-      uint16_t* column_data = nullptr;
-      AllocateCUDAMemory<uint16_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    } else if (bit_type == 32) {
-      uint32_t* column_data = nullptr;
-      AllocateCUDAMemory<uint32_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    }
-    OMP_LOOP_EX_END();
-  }
+  #pragma omp parallel num_threads(num_threads_)
+  {
+    SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
+    #pragma omp for schedule(static)
+    for (int column_index = 0; column_index < num_columns_; ++column_index) {
+      OMP_LOOP_EX_BEGIN();
+      const uint8_t bit_type = column_bit_type_[column_index];
+      if (bit_type == 8) {
+        uint8_t* column_data = nullptr;
+        AllocateCUDAMemory<uint8_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      } else if (bit_type == 16) {
+        uint16_t* column_data = nullptr;
+        AllocateCUDAMemory<uint16_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      } else if (bit_type == 32) {
+        uint32_t* column_data = nullptr;
+        AllocateCUDAMemory<uint32_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      }
+      OMP_LOOP_EX_END();
+    }
+  }
  OMP_THROW_EX();
  InitCUDAMemoryFromHostMemory<void*>(&cuda_data_by_column_, data_by_column_.data(), data_by_column_.size(), __FILE__, __LINE__);
@@ -221,27 +226,31 @@ void CUDAColumnData::ResizeWhenCopySubrow(const data_size_t num_used_indices) {
  DeallocateCUDAMemory<data_size_t>(&cuda_used_indices_, __FILE__, __LINE__);
  AllocateCUDAMemory<data_size_t>(&cuda_used_indices_, num_used_indices_size, __FILE__, __LINE__);
  OMP_INIT_EX();
-  #pragma omp parallel for schedule(static) num_threads(num_threads_)
-  for (int column_index = 0; column_index < num_columns_; ++column_index) {
-    OMP_LOOP_EX_BEGIN();
-    const uint8_t bit_type = column_bit_type_[column_index];
-    if (bit_type == 8) {
-      uint8_t* column_data = reinterpret_cast<uint8_t*>(data_by_column_[column_index]);
-      DeallocateCUDAMemory<uint8_t>(&column_data, __FILE__, __LINE__);
-      AllocateCUDAMemory<uint8_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    } else if (bit_type == 16) {
-      uint16_t* column_data = reinterpret_cast<uint16_t*>(data_by_column_[column_index]);
-      DeallocateCUDAMemory<uint16_t>(&column_data, __FILE__, __LINE__);
-      AllocateCUDAMemory<uint16_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    } else if (bit_type == 32) {
-      uint32_t* column_data = reinterpret_cast<uint32_t*>(data_by_column_[column_index]);
-      DeallocateCUDAMemory<uint32_t>(&column_data, __FILE__, __LINE__);
-      AllocateCUDAMemory<uint32_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
-      data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
-    }
-    OMP_LOOP_EX_END();
-  }
+  #pragma omp parallel num_threads(num_threads_)
+  {
+    SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
+    #pragma omp for schedule(static)
+    for (int column_index = 0; column_index < num_columns_; ++column_index) {
+      OMP_LOOP_EX_BEGIN();
+      const uint8_t bit_type = column_bit_type_[column_index];
+      if (bit_type == 8) {
+        uint8_t* column_data = reinterpret_cast<uint8_t*>(data_by_column_[column_index]);
+        DeallocateCUDAMemory<uint8_t>(&column_data, __FILE__, __LINE__);
+        AllocateCUDAMemory<uint8_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      } else if (bit_type == 16) {
+        uint16_t* column_data = reinterpret_cast<uint16_t*>(data_by_column_[column_index]);
+        DeallocateCUDAMemory<uint16_t>(&column_data, __FILE__, __LINE__);
+        AllocateCUDAMemory<uint16_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      } else if (bit_type == 32) {
+        uint32_t* column_data = reinterpret_cast<uint32_t*>(data_by_column_[column_index]);
+        DeallocateCUDAMemory<uint32_t>(&column_data, __FILE__, __LINE__);
+        AllocateCUDAMemory<uint32_t>(&column_data, num_used_indices_size, __FILE__, __LINE__);
+        data_by_column_[column_index] = reinterpret_cast<void*>(column_data);
+      }
+      OMP_LOOP_EX_END();
+    }
+  }
  OMP_THROW_EX();
  DeallocateCUDAMemory<void*>(&cuda_data_by_column_, __FILE__, __LINE__);
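The same restructuring appears in Init, CopySubrow, and ResizeWhenCopySubrow: a single "#pragma omp parallel for" is split into "#pragma omp parallel" plus an inner "#pragma omp for", so that per-thread setup (here, SetCUDADevice) runs once per thread rather than once per loop iteration (the "move SetCUDADevice outside for loop" bullet in the commit message). A generic sketch of the pattern, with PerThreadSetup and ProcessColumn as placeholder names rather than LightGBM functions:

```cpp
// Per-thread setup runs once per OpenMP thread, then the loop iterations are
// shared among the threads of the same parallel region.
#include <omp.h>
#include <cstdio>

static void PerThreadSetup() {
  // Stand-in for SetCUDADevice(gpu_device_id_, __FILE__, __LINE__).
  std::printf("thread %d: setup\n", omp_get_thread_num());
}

static void ProcessColumn(int column) {
  std::printf("thread %d: column %d\n", omp_get_thread_num(), column);
}

int main() {
  const int num_columns = 8;
  #pragma omp parallel num_threads(4)
  {
    PerThreadSetup();            // executed exactly once by every thread in the team
    #pragma omp for schedule(static)
    for (int column = 0; column < num_columns; ++column) {
      ProcessColumn(column);
    }
  }
  return 0;
}
```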
