SYCL. Optimize gradients calculations. #10325

Merged: 6 commits, Jun 8, 2024
Changes from 2 commits
117 changes: 81 additions & 36 deletions plugin/sycl/objective/multiclass_obj.cc
@@ -23,6 +23,7 @@
#include "../../../src/objective/multiclass_param.h"

#include "../device_manager.h"
#include "../data.h"
#include <CL/sycl.hpp>

namespace xgboost {
@@ -32,6 +33,20 @@ namespace obj {
DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);

class SoftmaxMultiClassObj : public ObjFunction {
static constexpr size_t kBatchSize = 1u << 22;
mutable bool are_buffs_init = false;

void InitBuffers() const {
if (!are_buffs_init) {
events_.resize(5);
preds_.Resize(&qu_, kBatchSize);
labels_.Resize(&qu_, kBatchSize);
weights_.Resize(&qu_, kBatchSize);
out_gpair_.Resize(&qu_, kBatchSize);
are_buffs_init = true;
}
}

public:
explicit SoftmaxMultiClassObj(bool output_prob)
: output_prob_(output_prob) {}
@@ -47,6 +62,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
linalg::Matrix<GradientPair>* out_gpair) override {
if (preds.Size() == 0) return;
if (info.labels.Size() == 0) return;
InitBuffers();

CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match.\n"
@@ -66,47 +82,70 @@
<< "Number of weights should be equal to number of data points.";
}

::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
out_gpair->Size());
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());
bst_float* preds_ptr = preds_.Data();
bst_float* labels_ptr = labels_.Data();
bst_float* weights_ptr = weights_.Data();
GradientPair* out_gpair_ptr = out_gpair_.Data();

int flag = 1;
int wg_size = 32;
const size_t nBatch = ndata / kBatchSize + (ndata % kBatchSize > 0);
@trivialfis (Member) commented on May 30, 2024:

Is it possible to have a separate function or a set of helpers for calculating batches, independent of the objective functions? It would be great if the objective functions could start with a simpler structure and avoid things that are not strictly related to the math itself (batching, memcpy, etc.). I know some of the existing objective implementations do the same, but I hope we can improve them as more and more objectives are defined in xgboost.

The PR author (Contributor) replied:

Hi, I have introduced the BatchProcessingHelper class. It conceals most of the details of the batch-processing implementation.
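For orientation, the batch bookkeeping repeated by both kernels in this PR (computing nBatch, begin, end, and batch_size) could be factored into a small helper along the following lines. This is a hypothetical sketch for illustration only, not the actual BatchProcessingHelper class mentioned in the reply, which is not shown in this two-commit view.

// Hypothetical sketch of batch-boundary bookkeeping (illustration only;
// not the actual BatchProcessingHelper referenced in the review thread).
#include <algorithm>
#include <cstddef>

struct BatchBounds {
  std::size_t begin;
  std::size_t end;  // half-open range [begin, end) over rows
  std::size_t Size() const { return end - begin; }
};

inline std::size_t NumBatches(std::size_t ndata, std::size_t batch_size) {
  return ndata / batch_size + (ndata % batch_size > 0 ? 1 : 0);
}

inline BatchBounds GetBatch(std::size_t batch, std::size_t ndata,
                            std::size_t batch_size) {
  const std::size_t begin = batch * batch_size;
  return {begin, std::min(begin + batch_size, ndata)};
}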

{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu_.submit([&](::sycl::handler& cgh) {
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];

bst_float const * point = &preds_acc[idx * nclass];

// Part of Softmax function
bst_float wmax = std::numeric_limits<bst_float>::min();
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
float wsum = 0.0f;
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
auto label = labels_acc[idx];
if (label < 0 || label >= nclass) {
flag_buf_acc[0] = 0;
label = 0;
}
bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
for (int k = 0; k < nclass; ++k) {
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
const float eps = 1e-16f;
const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
p = label == k ? p - 1.0f : p;
out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
}
for (size_t batch = 0; batch < nBatch; ++batch) {
const size_t begin = batch * kBatchSize;
const size_t end = (batch == nBatch - 1) ? ndata : begin + kBatchSize;
const size_t batch_size = end - begin;
int nwgs = (batch_size / wg_size + (batch_size % wg_size > 0));

events_[0] = qu_.memcpy(preds_ptr, preds.HostPointer() + begin * nclass,
batch_size * nclass * sizeof(bst_float), events_[3]);
events_[1] = qu_.memcpy(labels_ptr, info.labels.Data()->HostPointer() + begin,
batch_size * sizeof(bst_float), events_[3]);
if (!is_null_weight) {
events_[2] = qu_.memcpy(weights_ptr, info.weights_.HostPointer() + begin,
info.weights_.Size() * sizeof(bst_float), events_[3]);
}


events_[3] = qu_.submit([&](::sycl::handler& cgh) {
cgh.depends_on(events_);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for_work_group<>(::sycl::range<1>(nwgs), ::sycl::range<1>(wg_size),
[=](::sycl::group<1> group) {
group.parallel_for_work_item([&](::sycl::h_item<1> item) {
const size_t idx = item.get_global_id()[0];

const bst_float* pred = preds_ptr + idx * nclass;

// Part of Softmax function
bst_float wmax = std::numeric_limits<bst_float>::min();
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
bst_float wsum = 0.0f;
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
bst_float label = labels_ptr[idx];

if (label < 0 || label >= nclass) {
AtomicRef<int> flag_ref(flag_buf_acc[0]);
flag_ref = 0;
label = 0;
}

bst_float wt = is_null_weight ? 1.0f : weights_ptr[idx];
for (int k = 0; k < nclass; ++k) {
bst_float p = expf(pred[k] - wmax) / static_cast<float>(wsum);
const float eps = 1e-16f;
const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
p = label == k ? p - 1.0f : p;
out_gpair_ptr[idx * nclass + k] = GradientPair(p * wt, h);
}
});
});
});
}).wait();
events_[4] = qu_.memcpy(out_gpair->Data()->HostPointer() + begin * nclass, out_gpair_ptr,
batch_size * nclass * sizeof(GradientPair), events_[3]);
}
qu_.wait_and_throw();
}
// flag_buf is destroyed, content is copied to the "flag"

@@ -190,6 +229,12 @@ class SoftmaxMultiClassObj : public ObjFunction {
sycl::DeviceManager device_manager;

mutable ::sycl::queue qu_;
mutable std::vector<::sycl::event> events_;
// Buffers
mutable USMVector<bst_float, MemoryType::on_device> preds_;
mutable USMVector<bst_float, MemoryType::on_device> labels_;
mutable USMVector<bst_float, MemoryType::on_device> weights_;
mutable USMVector<GradientPair, MemoryType::on_device> out_gpair_;
};

XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
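For clarity, the per-row math performed by the multiclass kernel above is the standard softmax gradient and Hessian: p_k - 1{k == label} and 2 * p_k * (1 - p_k), both scaled by the instance weight. A host-side C++ reference of the same computation (a sketch for illustration, not code from this PR):

// Reference-only sketch of the per-row softmax gradient/Hessian computed
// by the SYCL kernel above (not part of this PR).
#include <algorithm>
#include <cmath>
#include <vector>

struct GradHess { float grad; float hess; };

std::vector<GradHess> SoftmaxGradientRow(const std::vector<float>& pred,
                                         int label, float weight) {
  const int nclass = static_cast<int>(pred.size());
  const float wmax = *std::max_element(pred.begin(), pred.end());
  float wsum = 0.0f;
  for (int k = 0; k < nclass; ++k) wsum += std::exp(pred[k] - wmax);

  const float eps = 1e-16f;
  std::vector<GradHess> out(nclass);
  for (int k = 0; k < nclass; ++k) {
    float p = std::exp(pred[k] - wmax) / wsum;                      // softmax probability
    const float h = std::max(2.0f * p * (1.0f - p) * weight, eps);  // clamped Hessian
    if (k == label) p -= 1.0f;                                      // gradient: p - 1{k == label}
    out[k] = {p * weight, h};
  }
  return out;
}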
177 changes: 116 additions & 61 deletions plugin/sycl/objective/regression_obj.cc
@@ -28,6 +28,7 @@
#include "../../../src/objective/regression_param.h"

#include "../device_manager.h"
#include "../data.h"

#include <CL/sycl.hpp>

@@ -41,6 +42,19 @@ template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
static constexpr size_t kBatchSize = 1u << 22;
mutable bool are_buffs_init = false;

void InitBuffers() const {
if (!are_buffs_init) {
events_.resize(5);
preds_.Resize(&qu_, kBatchSize);
labels_.Resize(&qu_, kBatchSize);
weights_.Resize(&qu_, kBatchSize);
out_gpair_.Resize(&qu_, kBatchSize);
are_buffs_init = true;
}
}

public:
RegLossObj() = default;
@@ -54,62 +68,83 @@ class RegLossObj : public ObjFunction {
const MetaInfo &info,
int iter,
linalg::Matrix<GradientPair>* out_gpair) override {
if (info.labels.Size() == 0) return;
CHECK_EQ(preds.Size(), info.labels.Size())
<< " " << "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
<< "Loss: " << Loss::Name();

size_t const ndata = preds.Size();
auto const n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);

// TODO(razdoburdin): add label_correct check
label_correct_.Resize(1);
label_correct_.Fill(1);

bool is_null_weight = info.weights_.Size() == 0;

::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
out_gpair->Size());
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());

auto scale_pos_weight = param_.scale_pos_weight;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
<< "Number of weights should be equal to number of data points.";
}
if (info.labels.Size() == 0) return;
CHECK_EQ(preds.Size(), info.labels.Size())
<< " " << "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
<< "Loss: " << Loss::Name();

InitBuffers();
size_t const ndata = preds.Size();
auto const n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);

// TODO(razdoburdin): add label_correct check
label_correct_.Resize(1);
label_correct_.Fill(1);

bool is_null_weight = info.weights_.Size() == 0;

bst_float* preds_ptr = preds_.Data();
bst_float* labels_ptr = labels_.Data();
bst_float* weights_ptr = weights_.Data();
GradientPair* out_gpair_ptr = out_gpair_.Data();

auto scale_pos_weight = param_.scale_pos_weight;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
<< "Number of weights should be equal to number of data points.";
}

int flag = 1;
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu_.submit([&](::sycl::handler& cgh) {
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
bst_float p = Loss::PredTransform(preds_acc[idx]);
bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
bst_float label = labels_acc[idx];
if (label == 1.0f) {
w *= scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
flag_buf_acc[0] = 0;
}
out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
int flag = 1;
const int wg_size = 32;
const size_t nBatch = ndata / kBatchSize + (ndata % kBatchSize > 0);
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
for (size_t batch = 0; batch < nBatch; ++batch) {
const size_t begin = batch * kBatchSize;
const size_t end = (batch == nBatch - 1) ? ndata : begin + kBatchSize;
const size_t batch_size = end - begin;
int nwgs = (batch_size / wg_size + (batch_size % wg_size > 0));

events_[0] = qu_.memcpy(preds_ptr, preds.HostPointer() + begin,
batch_size * sizeof(bst_float), events_[3]);
events_[1] = qu_.memcpy(labels_ptr, info.labels.Data()->HostPointer() + begin,
batch_size * sizeof(bst_float), events_[3]);
if (!is_null_weight) {
events_[2] = qu_.memcpy(weights_ptr, info.weights_.HostPointer() + begin,
info.weights_.Size() * sizeof(bst_float), events_[3]);
}

events_[3] = qu_.submit([&](::sycl::handler& cgh) {
cgh.depends_on(events_);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for_work_group<>(::sycl::range<1>(nwgs), ::sycl::range<1>(wg_size),
[=](::sycl::group<1> group) {
group.parallel_for_work_item([&](::sycl::h_item<1> item) {
const size_t idx = item.get_global_id()[0];

const bst_float pred = Loss::PredTransform(preds_ptr[idx]);
bst_float weight = is_null_weight ? 1.0f : weights_ptr[idx/n_targets];
const bst_float label = labels_ptr[idx];
if (label == 1.0f) {
weight *= scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
AtomicRef<int> flag_ref(flag_buf_acc[0]);
flag_ref = 0;
}
out_gpair_ptr[idx] = GradientPair(Loss::FirstOrderGradient(pred, label) * weight,
Loss::SecondOrderGradient(pred, label) * weight);
});
});
});
}).wait();
}
// flag_buf is destroyed, content is copied to the "flag"
events_[4] = qu_.memcpy(out_gpair->Data()->HostPointer() + begin, out_gpair_ptr,
batch_size * sizeof(GradientPair), events_[3]);
}
qu_.wait_and_throw();
}
// flag_buf is destroyed, content is copied to the "flag"

if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
@@ -124,15 +159,29 @@
void PredTransform(HostDeviceVector<float> *io_preds) const override {
size_t const ndata = io_preds->Size();
if (ndata == 0) return;
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
InitBuffers();

::sycl::event event;
bst_float* preds_ptr = preds_.Data();
const size_t nBatch = ndata / kBatchSize + (ndata % kBatchSize > 0);
for (size_t batch = 0; batch < nBatch; ++batch) {
const size_t begin = batch * kBatchSize;
const size_t end = (batch == nBatch - 1) ? ndata : begin + kBatchSize;
const size_t batch_size = end - begin;

qu_.submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
event = qu_.memcpy(preds_ptr, io_preds->HostPointer() + begin,
batch_size * sizeof(bst_float), event);

event = qu_.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
preds_ptr[idx] = Loss::PredTransform(preds_ptr[idx]);
});
});
}).wait();
event = qu_.memcpy(io_preds->HostPointer(), preds_ptr, batch_size*sizeof(bst_float), event);
}
qu_.wait_and_throw();
}

float ProbToMargin(float base_score) const override {
@@ -163,6 +212,12 @@ class RegLossObj : public ObjFunction {
sycl::DeviceManager device_manager;

mutable ::sycl::queue qu_;
mutable std::vector<::sycl::event> events_;
// Buffers
mutable USMVector<bst_float, MemoryType::on_device> preds_;
mutable USMVector<bst_float, MemoryType::on_device> labels_;
mutable USMVector<bst_float, MemoryType::on_device> weights_;
mutable USMVector<GradientPair, MemoryType::on_device> out_gpair_;
};

XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
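Both objectives in this PR follow the same per-batch pipeline: copy a slice of predictions, labels, and weights into device USM buffers, run the gradient kernel, and copy the resulting gradient pairs back, with SYCL events chaining the steps so each batch's host-to-device copy waits for the previous work on the shared buffers. A minimal standalone sketch of that pattern, applied to a generic element-wise transform rather than the actual XGBoost objectives (function and variable names are illustrative):

// Minimal sketch of the batched copy-in / compute / copy-out pattern used
// above, applied to a generic element-wise transform (illustration only).
#include <CL/sycl.hpp>
#include <cstddef>

void SquareBatched(::sycl::queue& q, const float* host_in, float* host_out,
                   std::size_t ndata, std::size_t batch_size) {
  float* dev = ::sycl::malloc_device<float>(batch_size, q);
  ::sycl::event copy_in, compute, copy_out;  // default-constructed events are already complete
  const std::size_t n_batches = ndata / batch_size + (ndata % batch_size > 0 ? 1 : 0);
  for (std::size_t b = 0; b < n_batches; ++b) {
    const std::size_t begin = b * batch_size;
    const std::size_t end = (b == n_batches - 1) ? ndata : begin + batch_size;
    const std::size_t size = end - begin;
    // The host->device copy must wait until the previous batch has been read back.
    copy_in = q.memcpy(dev, host_in + begin, size * sizeof(float), copy_out);
    // The kernel waits for the copy, then transforms the buffer in place.
    compute = q.submit([&](::sycl::handler& cgh) {
      cgh.depends_on(copy_in);
      cgh.parallel_for(::sycl::range<1>(size), [=](::sycl::id<1> i) {
        dev[i] = dev[i] * dev[i];
      });
    });
    // Copy this batch's results back while the loop sets up the next batch.
    copy_out = q.memcpy(host_out + begin, dev, size * sizeof(float), compute);
  }
  q.wait_and_throw();
  ::sycl::free(dev, q);
}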