Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL. Optimize gradients calculations. #10325

Merged
merged 6 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions plugin/sycl/common/linalg_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
/**
* Copyright 2021-2024, XGBoost Contributors
* \file linalg_op.h
*/
#ifndef PLUGIN_SYCL_COMMON_LINALG_OP_H_
#define PLUGIN_SYCL_COMMON_LINALG_OP_H_

#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

#include "../data.h"

#include <CL/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace linalg {

/**
 * \brief Launch geometry of a work-group based kernel.
 *
 * Both members carry default initializers, so a default-constructed object
 * never holds indeterminate values. The struct remains an aggregate, therefore
 * brace initialization in the form `{n_workgroups, workgroup_size}` keeps working.
 */
struct WorkGroupsParams {
  /** \brief Number of work-groups to launch. */
  size_t n_workgroups = 0;
  /** \brief Number of work-items inside every work-group. */
  size_t workgroup_size = 0;
};

/**
 * \brief Submit a work-group based kernel that calls `fn` once per work-item.
 *
 * \param qu       queue the kernel is submitted to
 * \param flag_ptr host integer wrapped into a one-element buffer; a write
 *                 accessor to that buffer is handed to `fn`
 * \param events   events the kernel has to wait for
 * \param wg       number of work-groups and work-group size of the launch
 * \param fn       callable invoked as `fn(global_index, flag_accessor)`
 * \return event of the submitted command group
 *
 * The launch covers n_workgroups * workgroup_size work-items; `fn` receives
 * every global index of that range and is responsible for any bounds handling.
 * NOTE(review): the buffer over flag_ptr goes out of scope when this function
 * returns; per SYCL buffer semantics this is expected to wait for the kernel
 * and write the value back to *flag_ptr - confirm against the runtime in use.
 */
template <typename Fn>
::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
                              const std::vector<::sycl::event>& events,
                              const WorkGroupsParams& wg, Fn &&fn) {
  ::sycl::buffer<int, 1> flag_buf(flag_ptr, 1);

  return qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(events);
    auto flag_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);

    const ::sycl::range<1> num_groups(wg.n_workgroups);
    const ::sycl::range<1> group_size(wg.workgroup_size);
    cgh.parallel_for_work_group<>(num_groups, group_size,
                                  [=](::sycl::group<1> group) {
      group.parallel_for_work_item([&](::sycl::h_item<1> item) {
        const size_t global_idx = item.get_global_id()[0];
        fn(global_idx, flag_acc);
      });
    });
  });
}

/**
 * \brief Placeholder convertible to any type.
 *
 * The conversion operator is only declared, so the type is usable in
 * unevaluated contexts (decltype) exclusively.
 */
struct Argument {
  template <typename T>
  operator T&&() const;
};

/**
 * \brief Fallback: chosen when Fn cannot be called with the requested
 *        number of placeholder arguments.
 */
template <typename Fn, typename Indices, typename = void>
struct ArgumentsPassedImpl : std::false_type {};

/**
 * \brief Chosen when calling Fn with one Argument per index is well-formed.
 */
template <typename Fn, size_t ...Idx>
struct ArgumentsPassedImpl<
    Fn, std::index_sequence<Idx...>,
    std::void_t<decltype(std::declval<Fn>()(((void)Idx, Argument{})...))>>
    : std::true_type {};

/**
 * \brief `value` is true if Fn is callable with exactly N arguments of
 *        arbitrary type, false otherwise.
 */
template <typename Fn, size_t N>
struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

template <typename OutputDType, typename InputDType,
size_t BatchSize, size_t MaxNumInputs>
class BatchProcessingHelper {
public:
static constexpr size_t kBatchSize = BatchSize;
using InputType = HostDeviceVector<InputDType>;
using OutputType = HostDeviceVector<OutputDType>;

private:
template <size_t NumInput = 0>
void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
/*
* Some inputs may have less than 1 sample per output symbol.
*/
const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;

const InputDType* in_host_ptr = input.HostPointer() +
batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;

events_[NumInput] =
qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
events_[MaxNumInputs - 2]);
}

template <size_t NumInput = 0, class... InputTypes>
void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
const InputTypes&... other_inputs) {
// Make copy for the first input in the list
Host2Buffers<NumInput>(in_buffer_ptr, input);
// Recurent call for next inputs
InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
Host2Buffers<NumInput+1>(next_input, other_inputs...);
}

void Buffers2Host(OutputType* output) {
const size_t n_samples = batch_size_ * sample_rates_[0];
OutputDType* out_host_ptr = output->HostPointer() + batch_begin_* sample_rates_[0];
events_[MaxNumInputs - 1] =
qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
events_[MaxNumInputs - 2]);
}

void Buffers2Host(InputType* output) {
const size_t n_samples = batch_size_ * sample_rates_[1];
InputDType* out_host_ptr = output->HostPointer() + batch_begin_* sample_rates_[1];
events_[MaxNumInputs - 1] =
qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
events_[MaxNumInputs - 2]);
}

template <size_t NumInputs = 1, typename Fn, class... InputTypes>
void Call(Fn &&fn, const InputDType* input, const InputTypes*... other_inputs) {
static_assert(NumInputs <= MaxNumInputs,
"To many arguments in the passed function");
/* Passed lambda may have less inputs than MaxNumInputs,
* need to pass only requared number of arguments
*/
// 1 for events, 1 for batch_size, 1 for output
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_,
out_buffer_.Data(), input, other_inputs...);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
Call<NumInputs+1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
}
}

template <size_t NumInputs = 1, typename Fn, class... InputTypes>
void Call(Fn &&fn, InputDType* io, const InputDType* input, const InputTypes*... other_inputs) {
static_assert(NumInputs <= MaxNumInputs,
"To many arguments in the passed function");
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_,
io, input, other_inputs...);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - NumInputs];
Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
}
}

template <size_t NumInputs = 1, typename Fn>
void Call(Fn &&fn, InputDType* io) {
static_assert(NumInputs <= MaxNumInputs,
"To many arguments in the passed function");
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - 1];
Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input);
}
}

public:
BatchProcessingHelper() = default;

// The first element of sample_rate always corresonds to output sample rate
void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
assert(sample_rate.size() == MaxNumInputs + 1);
sample_rates_ = sample_rate;
qu_ = qu;
events_.resize(MaxNumInputs + 2);
out_buffer_.Resize(qu, kBatchSize * sample_rate.front());

in_buff_offsets_[0] = 0;
for (size_t i = 1; i < MaxNumInputs; ++i) {
in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
}
const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
in_buffer_.Resize(qu, in_buff_size);
}

/*
* Batch-wise proces on sycl device
* output = fn(inputs)
*/
template <typename Fn, class... InputTypes>
void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
ndata_ = output->Size() / sample_rates_.front();
const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
for (size_t batch = 0; batch < nBatch; ++batch) {
batch_begin_ = batch * kBatchSize;
batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
batch_size_ = batch_end_ - batch_begin_;

// Iteratively copy all inputs to device buffers
Host2Buffers(in_buffer_.Data(), inputs...);
// Pack buffers and call function
// We shift input pointer to keep the same order of inputs after packing
Call(std::forward<Fn>(fn), in_buffer_.DataConst() + in_buff_offsets_.back());
// Copy results to host
Buffers2Host(output);
}
}

/*
* Batch-wise proces on sycl device
* input = fn(input, other_inputs)
*/
template <typename Fn, class... InputTypes>
void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
ndata_ = input->Size();
const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
for (size_t batch = 0; batch < nBatch; ++batch) {
batch_begin_ = batch * kBatchSize;
batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
batch_size_ = batch_end_ - batch_begin_;

// Iteratively copy all inputs to device buffers.
// inputs are pased by const reference
Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
// Pack buffers and call function
// We shift input pointer to keep the same order of inputs after packing
Call(std::forward<Fn>(fn), in_buffer_.Data());
// Copy results to host
Buffers2Host(input);
}
}

private:
std::array<int, MaxNumInputs> in_buff_offsets_;
std::vector<int> sample_rates_;
size_t ndata_;
size_t batch_begin_;
size_t batch_end_;
// is not equal to kBatchSize for the last batch
size_t batch_size_;
::sycl::queue* qu_;
std::vector<::sycl::event> events_;
USMVector<InputDType, MemoryType::on_device> in_buffer_;
USMVector<OutputDType, MemoryType::on_device> out_buffer_;
};

} // namespace linalg
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_LINALG_OP_H_
90 changes: 59 additions & 31 deletions plugin/sycl/objective/multiclass_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

#include "../../../src/objective/multiclass_param.h"

#include "../common/linalg_op.h"

#include "../device_manager.h"
#include "../data.h"
#include <CL/sycl.hpp>

namespace xgboost {
Expand All @@ -32,6 +35,15 @@ namespace obj {
DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);

class SoftmaxMultiClassObj : public ObjFunction {
mutable bool are_buffs_init = false;

// Sets up the buffers of batch_processor_ on the first call only;
// later calls return immediately and their sample_rate is not used.
void InitBuffers(const std::vector<int>& sample_rate) const {
  if (are_buffs_init) return;

  batch_processor_.InitBuffers(&qu_, sample_rate);
  are_buffs_init = true;
}

public:
// Creates the objective and stores output_prob in output_prob_.
explicit SoftmaxMultiClassObj(bool output_prob)
: output_prob_(output_prob) {}
Expand All @@ -44,7 +56,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
linalg::Matrix<GradientPair>* out_gpair) override {
xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
if (preds.Size() == 0) return;
if (info.labels.Size() == 0) return;

Expand All @@ -66,54 +78,68 @@ class SoftmaxMultiClassObj : public ObjFunction {
<< "Number of weights should be equal to number of data points.";
}

::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
out_gpair->Size());
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());

int flag = 1;
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu_.submit([&](::sycl::handler& cgh) {
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];

bst_float const * point = &preds_acc[idx * nclass];
auto objective_fn = [=, &flag]
(const std::vector<::sycl::event>& events,
size_t ndata,
GradientPair* out_gpair,
const bst_float* preds,
const bst_float* labels,
const bst_float* weights) {
const size_t wg_size = 32;
const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
[=] (size_t idx, auto flag) {
const bst_float* pred = preds + idx * nclass;

// Part of Softmax function
bst_float wmax = std::numeric_limits<bst_float>::min();
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
float wsum = 0.0f;
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
auto label = labels_acc[idx];
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
bst_float wsum = 0.0f;
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
bst_float label = labels[idx];

if (label < 0 || label >= nclass) {
flag_buf_acc[0] = 0;
AtomicRef<int> flag_ref(flag[0]);
flag_ref = 0;
label = 0;
}
bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];

bst_float wt = is_null_weight ? 1.0f : weights[idx];
for (int k = 0; k < nclass; ++k) {
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
bst_float p = expf(pred[k] - wmax) / static_cast<float>(wsum);
const float eps = 1e-16f;
const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
p = label == k ? p - 1.0f : p;
out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
out_gpair[idx * nclass + k] = GradientPair(p * wt, h);
}
});
}).wait();
});
};

// out_gpair and preds have nclass points per sample
// labels and weights have 1 point per sample
InitBuffers({nclass, nclass, 1, 1});
if (is_null_weight) {
// Output is passed by pointer
// Inputs are passed by const reference
batch_processor_.Calculate(std::move(objective_fn),
out_gpair->Data(),
preds,
*(info.labels.Data()));
} else {
batch_processor_.Calculate(std::move(objective_fn),
out_gpair->Data(),
preds,
*(info.labels.Data()),
info.weights_);
}
// flag_buf is destroyed, content is copied to the "flag"
qu_.wait_and_throw();

if (flag == 0) {
LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
}
}

// Delegates to Transform, passing io_preds together with output_prob_.
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
  Transform(io_preds, output_prob_);
}
Expand Down Expand Up @@ -190,6 +216,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
sycl::DeviceManager device_manager;

mutable ::sycl::queue qu_;
static constexpr size_t kBatchSize = 1u << 22;
mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
};

XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
Expand Down
Loading
Loading