Skip to content

Commit

Permalink
[SYCL] Add sampling initialization (#10216)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin committed Apr 24, 2024
1 parent 59d7b8d commit 58513dc
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 0 deletions.
58 changes: 58 additions & 0 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.cc
*/

#include "hist_updater.h"

#include <oneapi/dpl/random>

namespace xgboost {
namespace sycl {
namespace tree {

template<typename GradientSumT>
void HistUpdater<GradientSumT>::InitSampling(
const USMVector<GradientPair, MemoryType::on_device> &gpair,
USMVector<size_t, MemoryType::on_device>* row_indices) {
const size_t num_rows = row_indices->Size();
auto* row_idx = row_indices->Data();
const auto* gpair_ptr = gpair.DataConst();
uint64_t num_samples = 0;
const auto subsample = param_.subsample;
::sycl::event event;

{
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
uint64_t seed = seed_;
seed_ += num_rows;
event = qu_.submit([&](::sycl::handler& cgh) {
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_id(0);

// Create minstd_rand engine
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::bernoulli_distribution coin_flip(subsample);

auto rnd = coin_flip(engine);
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
row_idx[num_samples_ref++] = i;
}
});
});
/* After calling a destructor for flag_buf, content will be copyed to num_samples */
}

row_indices->Resize(&qu_, num_samples, 0, &event);
qu_.wait();
}

template class HistUpdater<float>;
template class HistUpdater<double>;

} // namespace tree
} // namespace sycl
} // namespace xgboost
72 changes: 72 additions & 0 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.h
*/
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop

#include <utility>
#include <memory>

#include "../common/partition_builder.h"
#include "split_evaluator.h"

#include "../data.h"

namespace xgboost {
namespace sycl {
namespace tree {

template<typename GradientSumT>
class HistUpdater {
public:
explicit HistUpdater(::sycl::queue qu,
const xgboost::tree::TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraints_,
DMatrix const* fmat)
: qu_(qu), param_(param),
tree_evaluator_(qu, param, fmat->Info().num_col_),
pruner_(std::move(pruner)),
interaction_constraints_{std::move(int_constraints_)},
p_last_tree_(nullptr), p_last_fmat_(fmat) {
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
const auto sub_group_sizes =
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
sub_group_size_ = sub_group_sizes.back();
}

protected:
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
USMVector<size_t, MemoryType::on_device>* row_indices);

size_t sub_group_size_;
const xgboost::tree::TrainParam& param_;
TreeEvaluator<GradientSumT> tree_evaluator_;
std::unique_ptr<TreeUpdater> pruner_;
FeatureInteractionConstraintHost interaction_constraints_;

// back pointers to tree and data matrix
const RegTree* p_last_tree_;
DMatrix const* const p_last_fmat_;

xgboost::common::Monitor builder_monitor_;
xgboost::common::Monitor kernel_monitor_;

uint64_t seed_ = 0;

::sycl::queue qu_;
};

} // namespace tree
} // namespace sycl
} // namespace xgboost

#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_
1 change: 1 addition & 0 deletions tests/ci_build/conda_env/linux_sycl_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- pytest-timeout
- pytest-cov
- dpcpp_linux-64
- onedpl-devel
104 changes: 104 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/**
* Copyright 2020-2024 by XGBoost contributors
*/
#include <gtest/gtest.h>

#include <oneapi/dpl/random>

#include "../../../plugin/sycl/tree/hist_updater.h"
#include "../../../plugin/sycl/device_manager.h"

#include "../helpers.h"

namespace xgboost::sycl::tree {

template <typename GradientSumT>
class TestHistUpdater : public HistUpdater<GradientSumT> {
public:
TestHistUpdater(::sycl::queue qu,
const xgboost::tree::TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
FeatureInteractionConstraintHost int_constraints_,
DMatrix const* fmat) : HistUpdater<GradientSumT>(qu, param, std::move(pruner),
int_constraints_, fmat) {}

void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
USMVector<size_t, MemoryType::on_device>* row_indices) {
HistUpdater<GradientSumT>::InitSampling(gpair, row_indices);
}
};

template <typename GradientSumT>
void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
const size_t num_rows = 1u << 12;
const size_t num_columns = 1;

Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
ObjInfo task{ObjInfo::kRegression};

auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();

FeatureInteractionConstraintHost int_constraints;
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

TestHistUpdater<GradientSumT> updater(qu, param, std::move(pruner), int_constraints, p_fmat.get());

USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
auto* gpair_ptr = gpair.Data();
qu.submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_linear_id();

constexpr uint32_t seed = 777;
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(-1., 1.);
gpair_ptr[i] = {distr(engine), distr(engine)};
});
}).wait();

updater.TestInitSampling(gpair, &row_indices_0);

size_t n_samples = row_indices_0.Size();
// Half of gpairs have neg hess
ASSERT_LT(n_samples, num_rows * 0.5 * param.subsample * 1.2);
ASSERT_GT(n_samples, num_rows * 0.5 * param.subsample / 1.2);

// Check if two lanunches generate different realisations:
updater.TestInitSampling(gpair, &row_indices_1);
if (row_indices_1.Size() == n_samples) {
std::vector<size_t> row_indices_0_host(n_samples);
std::vector<size_t> row_indices_1_host(n_samples);
qu.memcpy(row_indices_0_host.data(), row_indices_0.Data(), n_samples * sizeof(size_t)).wait();
qu.memcpy(row_indices_1_host.data(), row_indices_1.Data(), n_samples * sizeof(size_t)).wait();

// The order in row_indices_0 and row_indices_1 can be different
std::set<size_t> rows;
for (auto row : row_indices_0_host) {
rows.insert(row);
}

size_t num_diffs = 0;
for (auto row : row_indices_1_host) {
if (rows.count(row) == 0) num_diffs++;
}

ASSERT_NE(num_diffs, 0);
}

}

TEST(SyclHistUpdater, Sampling) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});

TestHistUpdaterSampling<float>(param);
TestHistUpdaterSampling<double>(param);
}
} // namespace xgboost::sycl::tree

0 comments on commit 58513dc

Please sign in to comment.