Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL. Add sampling initialization #10216

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.cc
*/

#include "hist_updater.h"

#include <oneapi/dpl/random>

namespace xgboost {
namespace sycl {
namespace tree {

template<typename GradientSumT>
void HistUpdater<GradientSumT>::InitSampling(
const USMVector<GradientPair, MemoryType::on_device> &gpair,
USMVector<size_t, MemoryType::on_device>* row_indices) {
const size_t num_rows = row_indices->Size();
auto* row_idx = row_indices->Data();
const auto* gpair_ptr = gpair.DataConst();
uint64_t num_samples = 0;
const auto subsample = param_.subsample;
::sycl::event event;

{
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
uint64_t seed = seed_;
seed_ += num_rows;
event = qu_.submit([&](::sycl::handler& cgh) {
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_id(0);

// Create minstd_rand engine
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::bernoulli_distribution coin_flip(subsample);

auto rnd = coin_flip(engine);
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
row_idx[num_samples_ref++] = i;
}
});
});
/* After calling a destructor for flag_buf, content will be copyed to num_samples */
}

row_indices->Resize(&qu_, num_samples, 0, &event);
qu_.wait();
}

// Explicit instantiations for the supported gradient accumulation types.
template class HistUpdater<float>;
template class HistUpdater<double>;

} // namespace tree
} // namespace sycl
} // namespace xgboost
72 changes: 72 additions & 0 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*!
* Copyright 2017-2024 by Contributors
* \file hist_updater.h
*/
#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_
#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/tree_updater.h>
#pragma GCC diagnostic pop

#include <utility>
#include <memory>

#include "../common/partition_builder.h"
#include "split_evaluator.h"

#include "../data.h"

namespace xgboost {
namespace sycl {
namespace tree {

/*! \brief Histogram-based tree updater running on a SYCL device.
 *
 * GradientSumT is the floating-point type used for gradient accumulation
 * (instantiated for float and double in hist_updater.cc).
 */
template<typename GradientSumT>
class HistUpdater {
 public:
  /*!
   * \param qu               SYCL queue used for all device work.
   * \param param            Training parameters; stored as a reference, so the
   *                         caller must keep it alive for this object's lifetime.
   * \param pruner           Tree pruner; ownership is transferred.
   * \param int_constraints_ Feature interaction constraints.
   * \param fmat             Data matrix; stored as a non-owning pointer.
   */
  explicit HistUpdater(::sycl::queue qu,
                       const xgboost::tree::TrainParam& param,
                       std::unique_ptr<TreeUpdater> pruner,
                       FeatureInteractionConstraintHost int_constraints_,
                       DMatrix const* fmat)
    // Initializers are listed in member declaration order (qu_ is declared
    // last, so it is initialized last) to avoid -Wreorder warnings.
    : param_(param),
      tree_evaluator_(qu, param, fmat->Info().num_col_),
      pruner_(std::move(pruner)),
      interaction_constraints_{std::move(int_constraints_)},
      p_last_tree_(nullptr), p_last_fmat_(fmat),
      qu_(qu) {
    builder_monitor_.Init("SYCL::Quantile::HistUpdater");
    kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
    // Use the widest sub-group size the device supports.
    const auto sub_group_sizes =
      qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
    sub_group_size_ = sub_group_sizes.back();
  }

 protected:
  // Device-side Bernoulli row sampling: keeps rows with non-negative hessian
  // with probability param_.subsample (defined in hist_updater.cc).
  void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
                    USMVector<size_t, MemoryType::on_device>* row_indices);

  size_t sub_group_size_;
  const xgboost::tree::TrainParam& param_;
  TreeEvaluator<GradientSumT> tree_evaluator_;
  std::unique_ptr<TreeUpdater> pruner_;
  FeatureInteractionConstraintHost interaction_constraints_;

  // back pointers to tree and data matrix
  const RegTree* p_last_tree_;
  DMatrix const* const p_last_fmat_;

  xgboost::common::Monitor builder_monitor_;
  xgboost::common::Monitor kernel_monitor_;

  // Seed for the device-side RNG; advanced on every InitSampling call so
  // repeated launches produce different samples.
  uint64_t seed_ = 0;

  ::sycl::queue qu_;
};

} // namespace tree
} // namespace sycl
} // namespace xgboost

#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_
1 change: 1 addition & 0 deletions tests/ci_build/conda_env/linux_sycl_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- pytest-timeout
- pytest-cov
- dpcpp_linux-64
- onedpl-devel
104 changes: 104 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/**
* Copyright 2020-2024 by XGBoost contributors
*/
#include <gtest/gtest.h>

#include <oneapi/dpl/random>

#include "../../../plugin/sycl/tree/hist_updater.h"
#include "../../../plugin/sycl/device_manager.h"

#include "../helpers.h"

namespace xgboost::sycl::tree {

// Test shim exposing HistUpdater's protected members so that unit tests can
// call them directly without changing the production class's access control.
template <typename GradientSumT>
class TestHistUpdater : public HistUpdater<GradientSumT> {
 public:
  TestHistUpdater(::sycl::queue qu,
                  const xgboost::tree::TrainParam& param,
                  std::unique_ptr<TreeUpdater> pruner,
                  FeatureInteractionConstraintHost int_constraints_,
                  DMatrix const* fmat) : HistUpdater<GradientSumT>(qu, param, std::move(pruner),
                                                                   int_constraints_, fmat) {}

  // Public forwarder to the protected InitSampling.
  void TestInitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
                        USMVector<size_t, MemoryType::on_device>* row_indices) {
    HistUpdater<GradientSumT>::InitSampling(gpair, row_indices);
  }
};

// Checks InitSampling on device data: (1) the number of sampled rows is close
// to the expected fraction, and (2) two consecutive launches draw different
// samples (the updater advances its internal seed between calls).
template <typename GradientSumT>
void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
  const size_t num_rows = 1u << 12;
  const size_t num_columns = 1;

  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());
  ObjInfo task{ObjInfo::kRegression};

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix();

  FeatureInteractionConstraintHost int_constraints;
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  TestHistUpdater<GradientSumT> updater(qu, param, std::move(pruner), int_constraints, p_fmat.get());

  USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
  USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
  auto* gpair_ptr = gpair.Data();
  // Fill gpair with uniform values in [-1, 1): about half the rows get a
  // negative hessian and must be excluded by InitSampling.
  qu.submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
                       [=](::sycl::item<1> pid) {
      uint64_t i = pid.get_linear_id();

      constexpr uint32_t seed = 777;
      oneapi::dpl::minstd_rand engine(seed, i);
      oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(-1., 1.);
      gpair_ptr[i] = {distr(engine), distr(engine)};
    });
  }).wait();

  updater.TestInitSampling(gpair, &row_indices_0);

  size_t n_samples = row_indices_0.Size();
  // Half of gpairs have neg hess; allow +-20% slack around the expectation.
  ASSERT_LT(n_samples, num_rows * 0.5 * param.subsample * 1.2);
  ASSERT_GT(n_samples, num_rows * 0.5 * param.subsample / 1.2);

  // Check that two launches generate different realisations:
  updater.TestInitSampling(gpair, &row_indices_1);
  if (row_indices_1.Size() == n_samples) {
    std::vector<size_t> row_indices_0_host(n_samples);
    std::vector<size_t> row_indices_1_host(n_samples);
    qu.memcpy(row_indices_0_host.data(), row_indices_0.Data(), n_samples * sizeof(size_t)).wait();
    qu.memcpy(row_indices_1_host.data(), row_indices_1.Data(), n_samples * sizeof(size_t)).wait();

    // The order in row_indices_0 and row_indices_1 can be different,
    // so compare as sets rather than element-wise.
    std::set<size_t> rows;
    for (auto row : row_indices_0_host) {
      rows.insert(row);
    }

    size_t num_diffs = 0;
    for (auto row : row_indices_1_host) {
      if (rows.count(row) == 0) num_diffs++;
    }

    ASSERT_NE(num_diffs, 0);
  }

}

// Run the sampling test for both supported gradient accumulation types.
TEST(SyclHistUpdater, Sampling) {
  xgboost::tree::TrainParam param;
  param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});

  TestHistUpdaterSampling<float>(param);
  TestHistUpdaterSampling<double>(param);
}
} // namespace xgboost::sycl::tree
Loading