[sycl] optimise hist building (#10311)
Co-authored-by: Dmitry Razdoburdin <>
razdoburdin committed May 27, 2024
1 parent 9def441 commit 0058301
Showing 3 changed files with 54 additions and 38 deletions.
86 changes: 53 additions & 33 deletions plugin/sycl/common/hist_util.cc
@@ -64,6 +64,30 @@ template ::sycl::event SubtractionHist(::sycl::queue qu,
                               const GHistRow<double, MemoryType::on_device>& src2,
                               size_t size, ::sycl::event event_priv);
 
+inline auto GetBlocksParameters(const ::sycl::queue& qu, size_t size, size_t max_nblocks) {
+  struct _ {
+    size_t block_size, nblocks;
+  };
+
+  const size_t min_block_size = 32;
+  const size_t max_compute_units =
+    qu.get_device().get_info<::sycl::info::device::max_compute_units>();
+
+  size_t nblocks = max_compute_units;
+
+  size_t block_size = size / nblocks + !!(size % nblocks);
+  if (block_size > (1u << 12)) {
+    nblocks = max_nblocks;
+    block_size = size / nblocks + !!(size % nblocks);
+  }
+  if (block_size < min_block_size) {
+    block_size = min_block_size;
+    nblocks = size / block_size + !!(size % block_size);
+  }
+
+  return _{block_size, nblocks};
+}
+
 // Kernel with buffer using
 template<typename FPType, typename BinIdxType, bool isDense>
 ::sycl::event BuildHistKernel(::sycl::queue qu,
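
Note: the new helper derives a (block_size, nblocks) pair from device properties. It starts with one block per compute unit, re-splits against the caller-supplied max_nblocks cap when a block would exceed 4096 (1u << 12) rows, and never lets a block shrink below 32 rows. It returns a named struct rather than a tuple so callers can copy the fields into plain locals; as the comment inside BuildHistKernel below notes, capturing structured bindings in a lambda is a C++20 extension. Below is a host-side sketch of the same arithmetic, for illustration only: the max_compute_units argument is a hypothetical stand-in for the ::sycl::info::device::max_compute_units query.

#include <cstddef>
#include <cstdio>

// Host-side sketch of GetBlocksParameters; max_compute_units replaces the
// ::sycl::info::device query so the logic can run anywhere.
struct BlocksParameters { std::size_t block_size, nblocks; };

BlocksParameters GetBlocksParametersHost(std::size_t max_compute_units,
                                         std::size_t size,
                                         std::size_t max_nblocks) {
  const std::size_t min_block_size = 32;
  std::size_t nblocks = max_compute_units;                       // one block per compute unit
  std::size_t block_size = size / nblocks + !!(size % nblocks);  // ceil(size / nblocks)
  if (block_size > (1u << 12)) {  // blocks too coarse: spread over max_nblocks instead
    nblocks = max_nblocks;
    block_size = size / nblocks + !!(size % nblocks);
  }
  if (block_size < min_block_size) {  // blocks too fine: enforce 32 rows per block
    block_size = min_block_size;
    nblocks = size / block_size + !!(size % block_size);
  }
  return {block_size, nblocks};
}

int main() {
  // 1,000,000 rows on a hypothetical 64-CU device with a 256-block buffer:
  // 1e6 / 64 = 15625 > 4096, so the cap applies and block_size becomes 3907.
  const auto p = GetBlocksParametersHost(64, 1'000'000, 256);
  std::printf("block_size=%zu nblocks=%zu\n", p.block_size, p.nblocks);
}
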
@@ -73,27 +97,26 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
                               GHistRow<FPType, MemoryType::on_device>* hist,
                               GHistRow<FPType, MemoryType::on_device>* hist_buffer,
                               ::sycl::event event_priv) {
+  using GradientPairT = xgboost::detail::GradientPairInternal<FPType>;
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
-  const GradientPair::ValueT* pgh =
-    reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const auto* pgh = gpair_device.DataConst();
   const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
   const uint32_t* offsets = gmat.index.Offset();
-  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
   const size_t nbins = gmat.nbins;
 
   const size_t max_work_group_size =
     qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
   const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;
 
-  const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
-  const size_t min_block_size = 128;
-  size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
-  const size_t block_size = size / nblocks + !!(size % nblocks);
-  FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer->Data());
+  // Captured structured bindings are a C++20 extension
+  const auto block_params = GetBlocksParameters(qu, size, hist_buffer->Size() / (nbins * 2));
+  const size_t block_size = block_params.block_size;
+  const size_t nblocks = block_params.nblocks;
 
-  auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
+  GradientPairT* hist_buffer_data = hist_buffer->Data();
+  auto event_fill = qu.fill(hist_buffer_data, GradientPairT(0, 0), nblocks * nbins * 2, event_priv);
   auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
     cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
@@ -102,13 +125,14 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
       size_t block = pid.get_global_id(0);
       size_t feat = pid.get_global_id(1);
 
-      FPType* hist_local = hist_buffer_data + block * nbins * 2;
+      GradientPairT* hist_local = hist_buffer_data + block * nbins;
       for (size_t idx = 0; idx < block_size; ++idx) {
         size_t i = block * block_size + idx;
         if (i < size) {
           const size_t icol_start = n_columns * rid[i];
           const size_t idx_gh = rid[i];
 
+          const GradientPairT pgh_row = {pgh[idx_gh].GetGrad(), pgh[idx_gh].GetHess()};
           pid.barrier(::sycl::access::fence_space::local_space);
           const BinIdxType* gr_index_local = gradient_index + icol_start;
 
@@ -118,30 +142,27 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
             idx_bin += offsets[j];
           }
           if (idx_bin < nbins) {
-            hist_local[2 * idx_bin] += pgh[2 * idx_gh];
-            hist_local[2 * idx_bin + 1] += pgh[2 * idx_gh + 1];
+            hist_local[idx_bin] += pgh_row;
           }
         }
       }
     }
     });
   });
 
+  GradientPairT* hist_data = hist->Data();
   auto event_save = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_main);
     cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
       size_t idx_bin = pid.get_id(0);
 
-      FPType gsum = 0.0f;
-      FPType hsum = 0.0f;
+      GradientPairT gpair = {0, 0};
 
       for (size_t j = 0; j < nblocks; ++j) {
-        gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
-        hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
+        gpair += hist_buffer_data[j * nbins + idx_bin];
       }
 
-      hist_data[2 * idx_bin] = gsum;
-      hist_data[2 * idx_bin + 1] = hsum;
+      hist_data[idx_bin] = gpair;
     });
   });
   return event_save;
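
Note: the buffered kernel now accumulates whole gradient pairs instead of interleaved scalars. Each per-block slice of the buffer holds nbins GradientPairT elements addressed as hist_local[idx_bin], where it previously held 2 * nbins floats addressed as hist_local[2 * idx_bin] and hist_local[2 * idx_bin + 1], and the event_save reduction sums one pair per block instead of two scalars. A minimal sketch of the layout change, using a simplified stand-in for xgboost::detail::GradientPairInternal reduced to the one operator the reduction relies on:

#include <cstddef>
#include <vector>

// Simplified stand-in for xgboost::detail::GradientPairInternal<FPType>.
template <typename T>
struct GradientPairT {
  T grad{0}, hess{0};
  GradientPairT& operator+=(const GradientPairT& rhs) {
    grad += rhs.grad;
    hess += rhs.hess;
    return *this;
  }
};

int main() {
  const std::size_t nbins = 4, nblocks = 2;
  // Old layout: one FPType array with 2 * nbins scalars per block; bin b of
  // block j lives at [j * nbins * 2 + 2 * b] (grad) and [... + 1] (hess).
  // New layout: one GradientPairT array with nbins elements per block; bin b
  // of block j lives at [j * nbins + b]:
  std::vector<GradientPairT<float>> hist_buffer(nblocks * nbins);
  hist_buffer[0 * nbins + 2] += {0.5f, 1.0f};   // block 0, bin 2
  hist_buffer[1 * nbins + 2] += {0.25f, 1.0f};  // block 1, bin 2

  // Final reduction over blocks, mirroring event_save:
  std::vector<GradientPairT<float>> hist(nbins);
  for (std::size_t b = 0; b < nbins; ++b)
    for (std::size_t j = 0; j < nblocks; ++j)
      hist[b] += hist_buffer[j * nbins + b];
  return 0;
}

Besides halving the index arithmetic per update, one element per bin lets both the hot loop and the reduction use the pair's operator+= directly.
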
@@ -165,33 +186,36 @@ ::sycl::event BuildHistKernel(::sycl::queue qu,
   FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
   const size_t nbins = gmat.nbins;
 
-  const size_t max_work_group_size =
-    qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
-  const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+  constexpr size_t work_group_size = 32;
+  const size_t n_work_groups = n_columns / work_group_size + (n_columns % work_group_size > 0);
 
   auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
   auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
-    cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
-                       [=](::sycl::item<2> pid) {
-      size_t i = pid.get_id(0);
-      size_t feat = pid.get_id(1);
+    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(size, n_work_groups * work_group_size),
+                                           ::sycl::range<2>(1, work_group_size)),
+                       [=](::sycl::nd_item<2> pid) {
+      const int i = pid.get_global_id(0);
+      auto group = pid.get_group();
 
       const size_t icol_start = n_columns * rid[i];
       const size_t idx_gh = rid[i];
 
+      const FPType pgh_row[2] = {pgh[2 * idx_gh], pgh[2 * idx_gh + 1]};
       const BinIdxType* gr_index_local = gradient_index + icol_start;
 
-      for (size_t j = feat; j < n_columns; j += feat_local) {
+      const size_t group_id = group.get_group_id()[1];
+      const size_t local_id = group.get_local_id()[1];
+      const size_t j = group_id * work_group_size + local_id;
+      if (j < n_columns) {
         uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
         if constexpr (isDense) {
           idx_bin += offsets[j];
         }
         if (idx_bin < nbins) {
           AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
           AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
-          gsum.fetch_add(pgh[2 * idx_gh]);
-          hsum.fetch_add(pgh[2 * idx_gh + 1]);
+          gsum += pgh_row[0];
+          hsum += pgh_row[1];
         }
       }
     });
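
Note: the atomic kernel drops the strided per-item loop over features in favour of an nd_range<2>. Dimension 0 enumerates rows, and dimension 1 tiles the columns into work-groups of 32 items, so each work-item handles at most one column j = group_id * work_group_size + local_id, guarded by j < n_columns. The gradient pair is also read once into pgh_row before the column work, and the atomic updates use AtomicRef's += rather than an explicit fetch_add. A host-side emulation of the new index mapping, with a hypothetical column count, for illustration:

#include <cstddef>
#include <cstdio>

// Host-side emulation of the nd_range<2> feature mapping: each (row, item)
// pair covers at most one feature column.
int main() {
  const std::size_t n_columns = 70;
  constexpr std::size_t work_group_size = 32;
  const std::size_t n_work_groups =
      n_columns / work_group_size + (n_columns % work_group_size > 0);  // ceil: 3

  // Global range dim 1 is n_work_groups * work_group_size = 96 items; the
  // trailing 96 - 70 = 26 items fail the j < n_columns guard and do nothing.
  for (std::size_t group_id = 0; group_id < n_work_groups; ++group_id) {
    for (std::size_t local_id = 0; local_id < work_group_size; ++local_id) {
      const std::size_t j = group_id * work_group_size + local_id;
      if (j < n_columns) {
        // process feature column j for this row
      }
    }
  }
  std::printf("n_work_groups=%zu, items per row=%zu\n",
              n_work_groups, n_work_groups * work_group_size);
}
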
@@ -214,10 +238,6 @@ ::sycl::event BuildHistDispatchKernel(
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
   const size_t nbins = gmat.nbins;
 
-  // max cycle size, while atomics are still effective
-  const size_t max_cycle_size_atomics = nbins;
-  const size_t cycle_size = size;
-
   // TODO(razdoburdin): replace the ad-hoc dispatching criteria with a more suitable one
   bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);
 
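Note: with the per-cycle splitting removed, kernel choice reduces to the single use_atomic predicate kept above. A standalone restatement, for illustration only; the parameter names are mine, and the source's own TODO flags the criterion as ad hoc:

#include <cstddef>

// Restatement of the dispatch predicate; `size` is the number of rows in
// the node, and nbins / n_columns is the average bin count per column.
bool UseAtomicHistKernel(std::size_t size, std::size_t nbins,
                         std::size_t n_columns, std::size_t max_num_bins) {
  const bool few_rows = size < nbins;                           // fewer rows than total bins
  const bool uniform_bins = max_num_bins == nbins / n_columns;  // every column uses the full bin budget
  return few_rows || uniform_bins;
}
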
5 changes: 1 addition & 4 deletions plugin/sycl/tree/hist_updater.cc
@@ -136,10 +136,7 @@ void HistUpdater<GradientSumT>::InitData(

   hist_buffer_.Init(qu_, nbins);
   size_t buffer_size = kBufferSize;
-  if (buffer_size > info.num_row_ / kMinBlockSize + 1) {
-    buffer_size = info.num_row_ / kMinBlockSize + 1;
-  }
-  hist_buffer_.Reset(buffer_size);
+  hist_buffer_.Reset(kBufferSize);
 
   // initialize histogram builder
   hist_builder_ = common::GHistBuilder<GradientSumT>(qu_, nbins);
1 change: 0 additions & 1 deletion plugin/sycl/tree/hist_updater.h
@@ -130,7 +130,6 @@ class HistUpdater {
   DataLayout data_layout_;
 
   constexpr static size_t kBufferSize = 2048;
-  constexpr static size_t kMinBlockSize = 128;
   common::GHistBuilder<GradientSumT> hist_builder_;
   common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
   /*! \brief cumulative histogram of gradients. */
