-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYCL] Add sampling initialization (#10216)
--------- Co-authored-by: Dmitry Razdoburdin <>
- Loading branch information
1 parent
59d7b8d
commit 58513dc
Showing
4 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/*! | ||
* Copyright 2017-2024 by Contributors | ||
* \file hist_updater.cc | ||
*/ | ||
|
||
#include "hist_updater.h" | ||
|
||
#include <oneapi/dpl/random> | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace tree { | ||
|
||
template<typename GradientSumT> | ||
void HistUpdater<GradientSumT>::InitSampling( | ||
const USMVector<GradientPair, MemoryType::on_device> &gpair, | ||
USMVector<size_t, MemoryType::on_device>* row_indices) { | ||
const size_t num_rows = row_indices->Size(); | ||
auto* row_idx = row_indices->Data(); | ||
const auto* gpair_ptr = gpair.DataConst(); | ||
uint64_t num_samples = 0; | ||
const auto subsample = param_.subsample; | ||
::sycl::event event; | ||
|
||
{ | ||
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1); | ||
uint64_t seed = seed_; | ||
seed_ += num_rows; | ||
event = qu_.submit([&](::sycl::handler& cgh) { | ||
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); | ||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), | ||
[=](::sycl::item<1> pid) { | ||
uint64_t i = pid.get_id(0); | ||
|
||
// Create minstd_rand engine | ||
oneapi::dpl::minstd_rand engine(seed, i); | ||
oneapi::dpl::bernoulli_distribution coin_flip(subsample); | ||
|
||
auto rnd = coin_flip(engine); | ||
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) { | ||
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]); | ||
row_idx[num_samples_ref++] = i; | ||
} | ||
}); | ||
}); | ||
/* After calling a destructor for flag_buf, content will be copyed to num_samples */ | ||
} | ||
|
||
row_indices->Resize(&qu_, num_samples, 0, &event); | ||
qu_.wait(); | ||
} | ||
|
||
template class HistUpdater<float>; | ||
template class HistUpdater<double>; | ||
|
||
} // namespace tree | ||
} // namespace sycl | ||
} // namespace xgboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/*! | ||
* Copyright 2017-2024 by Contributors | ||
* \file hist_updater.h | ||
*/ | ||
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_ | ||
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_ | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include <xgboost/tree_updater.h> | ||
#pragma GCC diagnostic pop | ||
|
||
#include <utility> | ||
#include <memory> | ||
|
||
#include "../common/partition_builder.h" | ||
#include "split_evaluator.h" | ||
|
||
#include "../data.h" | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace tree { | ||
|
||
template<typename GradientSumT> | ||
class HistUpdater { | ||
public: | ||
explicit HistUpdater(::sycl::queue qu, | ||
const xgboost::tree::TrainParam& param, | ||
std::unique_ptr<TreeUpdater> pruner, | ||
FeatureInteractionConstraintHost int_constraints_, | ||
DMatrix const* fmat) | ||
: qu_(qu), param_(param), | ||
tree_evaluator_(qu, param, fmat->Info().num_col_), | ||
pruner_(std::move(pruner)), | ||
interaction_constraints_{std::move(int_constraints_)}, | ||
p_last_tree_(nullptr), p_last_fmat_(fmat) { | ||
builder_monitor_.Init("SYCL::Quantile::HistUpdater"); | ||
kernel_monitor_.Init("SYCL::Quantile::HistUpdater"); | ||
const auto sub_group_sizes = | ||
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>(); | ||
sub_group_size_ = sub_group_sizes.back(); | ||
} | ||
|
||
protected: | ||
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair, | ||
USMVector<size_t, MemoryType::on_device>* row_indices); | ||
|
||
size_t sub_group_size_; | ||
const xgboost::tree::TrainParam& param_; | ||
TreeEvaluator<GradientSumT> tree_evaluator_; | ||
std::unique_ptr<TreeUpdater> pruner_; | ||
FeatureInteractionConstraintHost interaction_constraints_; | ||
|
||
// back pointers to tree and data matrix | ||
const RegTree* p_last_tree_; | ||
DMatrix const* const p_last_fmat_; | ||
|
||
xgboost::common::Monitor builder_monitor_; | ||
xgboost::common::Monitor kernel_monitor_; | ||
|
||
uint64_t seed_ = 0; | ||
|
||
::sycl::queue qu_; | ||
}; | ||
|
||
} // namespace tree | ||
} // namespace sycl | ||
} // namespace xgboost | ||
|
||
#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ dependencies: | |
- pytest-timeout | ||
- pytest-cov | ||
- dpcpp_linux-64 | ||
- onedpl-devel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/** | ||
* Copyright 2020-2024 by XGBoost contributors | ||
*/ | ||
#include <gtest/gtest.h>

#include <algorithm>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include <oneapi/dpl/random>

#include "../../../plugin/sycl/tree/hist_updater.h"
#include "../../../plugin/sycl/device_manager.h"

#include "../helpers.h"
|
||
namespace xgboost::sycl::tree { | ||
|
||
template <typename GradientSumT> | ||
class TestHistUpdater : public HistUpdater<GradientSumT> { | ||
public: | ||
TestHistUpdater(::sycl::queue qu, | ||
const xgboost::tree::TrainParam& param, | ||
std::unique_ptr<TreeUpdater> pruner, | ||
FeatureInteractionConstraintHost int_constraints_, | ||
DMatrix const* fmat) : HistUpdater<GradientSumT>(qu, param, std::move(pruner), | ||
int_constraints_, fmat) {} | ||
|
||
void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair, | ||
USMVector<size_t, MemoryType::on_device>* row_indices) { | ||
HistUpdater<GradientSumT>::InitSampling(gpair, row_indices); | ||
} | ||
}; | ||
|
||
template <typename GradientSumT> | ||
void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) { | ||
const size_t num_rows = 1u << 12; | ||
const size_t num_columns = 1; | ||
|
||
Context ctx; | ||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); | ||
|
||
DeviceManager device_manager; | ||
auto qu = device_manager.GetQueue(ctx.Device()); | ||
ObjInfo task{ObjInfo::kRegression}; | ||
|
||
auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix(); | ||
|
||
FeatureInteractionConstraintHost int_constraints; | ||
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)}; | ||
|
||
TestHistUpdater<GradientSumT> updater(qu, param, std::move(pruner), int_constraints, p_fmat.get()); | ||
|
||
USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows); | ||
USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows); | ||
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows); | ||
auto* gpair_ptr = gpair.Data(); | ||
qu.submit([&](::sycl::handler& cgh) { | ||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), | ||
[=](::sycl::item<1> pid) { | ||
uint64_t i = pid.get_linear_id(); | ||
|
||
constexpr uint32_t seed = 777; | ||
oneapi::dpl::minstd_rand engine(seed, i); | ||
oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(-1., 1.); | ||
gpair_ptr[i] = {distr(engine), distr(engine)}; | ||
}); | ||
}).wait(); | ||
|
||
updater.TestInitSampling(gpair, &row_indices_0); | ||
|
||
size_t n_samples = row_indices_0.Size(); | ||
// Half of gpairs have neg hess | ||
ASSERT_LT(n_samples, num_rows * 0.5 * param.subsample * 1.2); | ||
ASSERT_GT(n_samples, num_rows * 0.5 * param.subsample / 1.2); | ||
|
||
// Check if two lanunches generate different realisations: | ||
updater.TestInitSampling(gpair, &row_indices_1); | ||
if (row_indices_1.Size() == n_samples) { | ||
std::vector<size_t> row_indices_0_host(n_samples); | ||
std::vector<size_t> row_indices_1_host(n_samples); | ||
qu.memcpy(row_indices_0_host.data(), row_indices_0.Data(), n_samples * sizeof(size_t)).wait(); | ||
qu.memcpy(row_indices_1_host.data(), row_indices_1.Data(), n_samples * sizeof(size_t)).wait(); | ||
|
||
// The order in row_indices_0 and row_indices_1 can be different | ||
std::set<size_t> rows; | ||
for (auto row : row_indices_0_host) { | ||
rows.insert(row); | ||
} | ||
|
||
size_t num_diffs = 0; | ||
for (auto row : row_indices_1_host) { | ||
if (rows.count(row) == 0) num_diffs++; | ||
} | ||
|
||
ASSERT_NE(num_diffs, 0); | ||
} | ||
|
||
} | ||
|
||
// Runs the sampling check for both gradient-sum precisions with subsample set to 0.7.
TEST(SyclHistUpdater, Sampling) {
  const Args sampling_args{{"subsample", "0.7"}};

  xgboost::tree::TrainParam param;
  param.UpdateAllowUnknown(sampling_args);

  TestHistUpdaterSampling<float>(param);
  TestHistUpdaterSampling<double>(param);
}
} // namespace xgboost::sycl::tree |