[CUDA] Add binary objective for cuda_exp (#5425)
* add binary objective for cuda_exp

* include <string> and <vector>

* exchange include ordering

* fix length of score to copy in evaluation

* fix EvalOneMetric

* fix cuda binary objective and prediction when boosting on gpu

* Add white space

* fix BoostFromScore for CUDABinaryLogloss

update log in test_register_logger

* include <algorithm>

* simplify shared memory buffer
shiyu1994 committed Aug 31, 2022
1 parent 81d4d4d commit 2b8fe8b
Showing 11 changed files with 446 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -398,6 +398,8 @@ endif()
if(USE_CUDA_EXP)
src/boosting/cuda/*.cpp
src/boosting/cuda/*.cu
+ src/objective/cuda/*.cpp
+ src/objective/cuda/*.cu
src/treelearner/cuda/*.cpp
src/treelearner/cuda/*.cu
src/io/cuda/*.cu
27 changes: 27 additions & 0 deletions include/LightGBM/cuda/cuda_objective_function.hpp
@@ -0,0 +1,27 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

namespace LightGBM {

class CUDAObjectiveInterface {
public:
virtual void ConvertOutputCUDA(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {}
};

} // namespace LightGBM

#endif // USE_CUDA_EXP

#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
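The interface above is the hook the boosting code uses to run output conversion on the device. As a hedged sketch of how the new objective plugs in (its header, src/objective/cuda/cuda_binary_objective.hpp, is among the 11 changed files but is not excerpted here), a concrete CUDA objective would derive from both its CPU counterpart and this interface, then override the conversion hook:

#include <LightGBM/cuda/cuda_objective_function.hpp>
#include "../binary_objective.hpp"

namespace LightGBM {

// Sketch only, not verbatim from this commit: pair the CPU objective's math
// with the CUDA interface's device-side conversion hook.
class CUDABinaryLogloss : public CUDAObjectiveInterface, public BinaryLogloss {
 public:
  explicit CUDABinaryLogloss(const Config& config);
  void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;
};

}  // namespace LightGBM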
9 changes: 9 additions & 0 deletions include/LightGBM/objective_function.h
@@ -93,6 +93,15 @@ class ObjectiveFunction {
* \brief Whether boosting is done on CUDA
*/
virtual bool IsCUDAObjective() const { return false; }

+ #ifdef USE_CUDA_EXP
+ /*!
+  * \brief Get output convert function for CUDA version
+  */
+ virtual std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const {
+   return [] (data_size_t, const double*, double*) {};
+ }
+ #endif  // USE_CUDA_EXP
};

} // namespace LightGBM
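With this default in place, the booster can always ask an objective for a device-side conversion callable; CPU objectives return the no-op above. A minimal member-function sketch of how a CUDA objective could override it, binding the virtual ConvertOutputCUDA hook from the new interface into the returned closure (illustrative; the concrete override is not shown in this excerpt):

// Member-function fragment, sketch only:
std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
  return [this](data_size_t num_data, const double* input, double* output) {
    ConvertOutputCUDA(num_data, input, output);  // launches the objective's conversion kernel
  };
}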
26 changes: 19 additions & 7 deletions src/boosting/gbdt.cpp
@@ -607,22 +607,26 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
}
}

- std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
+ #ifdef USE_CUDA_EXP
+ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const {
+ #else
+ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t /*num_data*/) const {
+ #endif  // USE_CUDA_EXP
#ifdef USE_CUDA_EXP
const bool evaluation_on_cuda = metric->IsCUDAMetric();
if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) {
#endif // USE_CUDA_EXP
return metric->Eval(score, objective_function_);
#ifdef USE_CUDA_EXP
} else if (boosting_on_gpu_ && !evaluation_on_cuda) {
-   const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_);
+   const size_t total_size = static_cast<size_t>(num_data) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > host_score_.size()) {
host_score_.resize(total_size, 0.0f);
}
CopyFromCUDADeviceToHost<double>(host_score_.data(), score, total_size, __FILE__, __LINE__);
return metric->Eval(host_score_.data(), objective_function_);
} else {
-   const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_);
+   const size_t total_size = static_cast<size_t>(num_data) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > cuda_score_.Size()) {
cuda_score_.Resize(total_size);
}
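This hunk is the "fix length of score to copy in evaluation" bullet from the commit message: the copy is now sized from the dataset being evaluated (num_data) instead of the training set (num_data_), which matters because validation sets can have a different number of rows. The copy helper comes from LightGBM/cuda/cuda_utils.h; a minimal sketch of its assumed behavior, a checked cudaMemcpy that reports the __FILE__/__LINE__ call site on failure:

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Hedged sketch of the cuda_utils.h helper as used throughout this commit.
template <typename T>
void CopyFromCUDADeviceToHost(T* dst_host, const T* src_device, size_t count,
                              const char* file, const int line) {
  const cudaError_t status =
      cudaMemcpy(dst_host, src_device, count * sizeof(T), cudaMemcpyDeviceToHost);
  if (status != cudaSuccess) {
    std::fprintf(stderr, "[CUDA] %s at %s:%d\n", cudaGetErrorString(status), file, line);
    std::abort();
  }
}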
@@ -641,7 +645,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output) {
for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
-     auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
+     auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data());
for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter
@@ -658,7 +662,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
-       auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
+       auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score(), valid_score_updater_[i]->num_data());
auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf;
@@ -698,15 +702,15 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
std::vector<double> ret;
if (data_idx == 0) {
for (auto& sub_metric : training_metrics_) {
-     auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
+     auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data());
for (auto score : scores) {
ret.push_back(score);
}
}
} else {
auto used_idx = data_idx - 1;
for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
-     auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
+     auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score(), valid_score_updater_[used_idx]->num_data());
for (auto score : test_scores) {
ret.push_back(score);
}
@@ -760,6 +764,14 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
num_data = valid_score_updater_[used_idx]->num_data();
*out_len = static_cast<int64_t>(num_data) * num_class_;
}
+ #ifdef USE_CUDA_EXP
+ std::vector<double> host_raw_scores;
+ if (boosting_on_gpu_) {
+   host_raw_scores.resize(static_cast<size_t>(*out_len), 0.0);
+   CopyFromCUDADeviceToHost<double>(host_raw_scores.data(), raw_scores, static_cast<size_t>(*out_len), __FILE__, __LINE__);
+   raw_scores = host_raw_scores.data();
+ }
+ #endif  // USE_CUDA_EXP
if (objective_function_ != nullptr) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
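The block added above is the "fix cuda binary objective and prediction when boosting on gpu" bullet at work: when boosting runs on the GPU, raw_scores points at device memory, but the conversion loop that follows runs on the CPU, so the scores are first staged into host_raw_scores. For context, a hedged sketch of the CPU-side conversion that loop applies for binary logloss (quoted from memory of binary_objective.hpp, not from this diff), the same sigmoid the CUDA conversion kernel computes on-device:

#include <cmath>

// Sketch of BinaryLogloss's CPU output conversion: raw score -> probability.
inline void ConvertOutputBinary(const double sigmoid, const double* input, double* output) {
  output[0] = 1.0 / (1.0 + std::exp(-sigmoid * input[0]));
}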
2 changes: 1 addition & 1 deletion src/boosting/gbdt.h
@@ -443,7 +443,7 @@ class GBDT : public GBDTBase {
* \brief eval results for one metric
*/
- virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const;
+ virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const;

/*!
* \brief Print metric result of current iteration
2 changes: 1 addition & 1 deletion src/objective/binary_objective.hpp
@@ -189,7 +189,7 @@ class BinaryLogloss: public ObjectiveFunction {

data_size_t NumPositiveData() const override { return num_pos_data_; }

- private:
+ protected:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of positive samples */
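Relaxing these members from private to protected lets the new CUDABinaryLogloss subclass reuse the CPU objective's bookkeeping directly; for example, its Init override below reads label_weights_ when staging the unbalanced-label weights on the device.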
82 changes: 82 additions & 0 deletions src/objective/cuda/cuda_binary_objective.cpp
@@ -0,0 +1,82 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/

#ifdef USE_CUDA_EXP

#include "cuda_binary_objective.hpp"

#include <string>
#include <vector>

namespace LightGBM {

CUDABinaryLogloss::CUDABinaryLogloss(const Config& config):
BinaryLogloss(config), ova_class_id_(-1) {
cuda_label_ = nullptr;
cuda_ova_label_ = nullptr;
cuda_weights_ = nullptr;
cuda_boost_from_score_ = nullptr;
cuda_sum_weights_ = nullptr;
cuda_label_weights_ = nullptr;
}

CUDABinaryLogloss::CUDABinaryLogloss(const Config& config, const int ova_class_id):
BinaryLogloss(config, [ova_class_id](label_t label) { return static_cast<int>(label) == ova_class_id; }), ova_class_id_(ova_class_id) {}

CUDABinaryLogloss::CUDABinaryLogloss(const std::vector<std::string>& strs): BinaryLogloss(strs) {}

CUDABinaryLogloss::~CUDABinaryLogloss() {
DeallocateCUDAMemory<label_t>(&cuda_ova_label_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_label_weights_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_boost_from_score_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_sum_weights_, __FILE__, __LINE__);
}

void CUDABinaryLogloss::Init(const Metadata& metadata, data_size_t num_data) {
BinaryLogloss::Init(metadata, num_data);
if (ova_class_id_ == -1) {
cuda_label_ = metadata.cuda_metadata()->cuda_label();
cuda_ova_label_ = nullptr;
} else {
InitCUDAMemoryFromHostMemory<label_t>(&cuda_ova_label_, metadata.cuda_metadata()->cuda_label(), static_cast<size_t>(num_data), __FILE__, __LINE__);
LaunchResetOVACUDALableKernel();
cuda_label_ = cuda_ova_label_;
}
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
AllocateCUDAMemory<double>(&cuda_boost_from_score_, 1, __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_boost_from_score_, 0, 1, __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_sum_weights_, 1, __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_sum_weights_, 0, 1, __FILE__, __LINE__);
if (label_weights_[0] != 1.0f || label_weights_[1] != 1.0f) {
InitCUDAMemoryFromHostMemory<double>(&cuda_label_weights_, label_weights_, 2, __FILE__, __LINE__);
} else {
cuda_label_weights_ = nullptr;
}
}

void CUDABinaryLogloss::GetGradients(const double* scores, score_t* gradients, score_t* hessians) const {
LaunchGetGradientsKernel(scores, gradients, hessians);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

double CUDABinaryLogloss::BoostFromScore(int) const {
LaunchBoostFromScoreKernel();
SynchronizeCUDADevice(__FILE__, __LINE__);
double boost_from_score = 0.0f;
CopyFromCUDADeviceToHost<double>(&boost_from_score, cuda_boost_from_score_, 1, __FILE__, __LINE__);
double pavg = 0.0f;
CopyFromCUDADeviceToHost<double>(&pavg, cuda_sum_weights_, 1, __FILE__, __LINE__);
Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, boost_from_score);
return boost_from_score;
}

void CUDABinaryLogloss::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
LaunchConvertOutputCUDAKernel(num_data, input, output);
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
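The kernels behind LaunchGetGradientsKernel, LaunchBoostFromScoreKernel, and LaunchConvertOutputCUDAKernel live in src/objective/cuda/cuda_binary_objective.cu, one of the 11 changed files but not excerpted here. For reference, the CPU BinaryLogloss derives the initial score as initscore = log(pavg / (1 - pavg)) / sigmoid from the weighted positive fraction pavg, which matches the pavg -> initscore log line in BoostFromScore above; the CUDA version presumably performs the same weighted reduction on the device. A minimal sketch of what the output-conversion kernel could look like (names and launch configuration are assumptions, not the commit's actual code):

#include <cuda_runtime.h>
#include <cstdint>

typedef int32_t data_size_t;  // LightGBM's row-index type, from meta.h

// Hedged sketch: element-wise sigmoid over raw scores, i.e. the work
// ConvertOutputCUDA must do for binary logloss.
__global__ void ConvertOutputKernel(const data_size_t num_data, const double sigmoid,
                                    const double* input, double* output) {
  const data_size_t i = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (i < num_data) {
    output[i] = 1.0 / (1.0 + exp(-sigmoid * input[i]));
  }
}

void LaunchConvertOutputKernelSketch(const data_size_t num_data, const double sigmoid,
                                     const double* input, double* output) {
  const int block_size = 512;  // assumed block size
  const int num_blocks = (num_data + block_size - 1) / block_size;
  ConvertOutputKernel<<<num_blocks, block_size>>>(num_data, sigmoid, input, output);
}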
