Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL. Add basic functional for QuantileHistMaker #10174

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.cc
*/
#include <vector>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "xgboost/tree_updater.h"
#pragma GCC diagnostic pop

#include "xgboost/logging.h"

#include "updater_quantile_hist.h"
#include "../data.h"

namespace xgboost {
namespace sycl {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);

DMLC_REGISTER_PARAMETER(HistMakerTrainParam);

void QuantileHistMaker::Configure(const Args& args) {
const DeviceOrd device_spec = ctx_->Device();
qu_ = device_manager.GetQueue(device_spec);

param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
}

void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
LOG(FATAL) << "Not Implemented yet";
}

bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) {
LOG(FATAL) << "Not Implemented yet";
}

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
.describe("Grow tree using quantized histogram with SYCL.")
.set_body(
[](Context const* ctx, ObjInfo const * task) {
return new QuantileHistMaker(ctx, task);
});
} // namespace tree
} // namespace sycl
} // namespace xgboost
105 changes: 105 additions & 0 deletions plugin/sycl/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.h
*/
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_

#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>

#include <vector>

#include "../data/gradient_index.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"

#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"

namespace xgboost {
namespace sycl {
namespace tree {

// training parameters specific to this algorithm
struct HistMakerTrainParam
: public XGBoostParameter<HistMakerTrainParam> {
bool single_precision_histogram = false;
// declare parameters
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
}
};

/*! \brief construct a tree using quantized feature values with SYCL backend*/
class QuantileHistMaker: public TreeUpdater {
public:
QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
TreeUpdater(ctx), task_{task} {
updater_monitor_.Init("SYCLQuantileHistMaker");
}
void Configure(const Args& args) override;

void Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
DMatrix* dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override;

bool UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) override;

void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
try {
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
} catch (std::out_of_range& e) {
// XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing.
// We add this compatibility check because it's just recently that we (developers) began
// persuade R users away from using saveRDS() for model serialization. Hopefully, one day,
// everyone will be using xgb.save().
LOG(WARNING) << "Attempted to load interal configuration for a model file that was generated "
<< "by a previous version of XGBoost. A likely cause for this warning is that the model "
<< "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST "
<< "using saveRDS() or pickle.dump() so that the model remains accessible in current and "
<< "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the "
<< "long term. For more details and explanation, see "
<< "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html";
this->hist_maker_param_.UpdateAllowUnknown(Args{});
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not necessary, we drop the configuration in learner if the versions don't match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = ToJson(param_);
out["sycl_hist_train_param"] = ToJson(hist_maker_param_);
}

char const* Name() const override {
return "grow_quantile_histmaker_sycl";
}

protected:
HistMakerTrainParam hist_maker_param_;
// training parameter
xgboost::tree::TrainParam param_;

xgboost::common::Monitor updater_monitor_;

::sycl::queue qu_;
DeviceManager device_manager;
ObjInfo const *task_{nullptr};
};


} // namespace tree
} // namespace sycl
} // namespace xgboost

#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
55 changes: 55 additions & 0 deletions tests/cpp/plugin/test_sycl_quantile_hist_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
* Copyright 2020-2024 by XGBoost contributors
*/
#include <gtest/gtest.h>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/json.h>
#include <xgboost/task.h>
#include "../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker
#pragma GCC diagnostic pop

namespace xgboost::sycl::tree {
TEST(SyclQuantileHistMaker, Basic) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};

ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl");
}

TEST(SyclQuantileHistMaker, JsonIO) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

ObjInfo task{ObjInfo::kRegression};
Json config {Object()};
{
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
updater->Configure({{"max_depth", std::to_string(42)}});
updater->Configure({{"single_precision_histogram", std::to_string(true)}});
updater->SaveConfig(&config);
}

{
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
updater->LoadConfig(config);

Json new_config {Object()};
updater->SaveConfig(&new_config);

ASSERT_EQ(config, new_config);

auto max_depth = atoi(get<String const>(new_config["train_param"]["max_depth"]).c_str());
ASSERT_EQ(max_depth, 42);

auto single_precision_histogram = atoi(get<String const>(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str());
ASSERT_EQ(single_precision_histogram, 1);
}

}
} // namespace xgboost::sycl::tree
Loading