Ridge Regression support in oneapi (#2743)
DDJHB committed May 21, 2024
1 parent 908955d commit d6f4dc3
Showing 18 changed files with 357 additions and 103 deletions.
@@ -39,7 +39,7 @@ template class BatchContainer<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 }
 namespace internal
 {
-template class BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
+template class DAAL_EXPORT BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 }
 } // namespace training
 } // namespace linear_regression
@@ -41,7 +41,7 @@ template class BatchContainer<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 
 namespace internal
 {
-template class BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
+template class DAAL_EXPORT BatchKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 
 } // namespace internal
 } // namespace training
@@ -40,7 +40,7 @@ template class OnlineContainer<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 
 namespace internal
 {
-template class OnlineKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
+template class DAAL_EXPORT OnlineKernel<DAAL_FPTYPE, normEqDense, DAAL_CPU>;
 
 } // namespace internal
 } // namespace training
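Note: the three hunks above add DAAL_EXPORT to the explicit instantiations of the DAAL CPU training kernels so that their symbols stay visible across the shared-library boundary, which the new oneAPI ridge path below relies on when it calls these kernels directly. A minimal sketch of the general pattern, assuming DAAL_EXPORT expands to the usual export attributes (MY_EXPORT and Kernel here are hypothetical stand-ins, not DAAL code):

#if defined(_WIN32)
#define MY_EXPORT __declspec(dllexport)
#else
#define MY_EXPORT __attribute__((visibility("default")))
#endif

// An illustrative kernel template.
template <typename T>
class Kernel {
public:
    T compute(T x) const { return x * x; }
};

// Exported explicit instantiation: the compiler emits Kernel<float> here
// and marks its symbols visible, so a binary linking against this shared
// library (built, e.g., with -fvisibility=hidden) can still resolve them.
template class MY_EXPORT Kernel<float>;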
1 change: 1 addition & 0 deletions cpp/oneapi/dal/algo/linear_regression/backend/cpu/BUILD
@@ -15,5 +15,6 @@ dal_module(
         "@onedal//cpp/daal:core",
         "@onedal//cpp/daal/src/algorithms/linear_model:kernel",
         "@onedal//cpp/daal/src/algorithms/linear_regression:kernel",
+        "@onedal//cpp/daal/src/algorithms/ridge_regression:kernel"
     ],
 )
@@ -16,6 +16,7 @@
 
 #include <daal/src/algorithms/linear_regression/linear_regression_train_kernel.h>
 #include <daal/src/algorithms/linear_regression/linear_regression_hyperparameter_impl.h>
+#include <daal/src/algorithms/ridge_regression/ridge_regression_train_kernel.h>
 
 #include "oneapi/dal/backend/interop/common.hpp"
 #include "oneapi/dal/backend/interop/error_converter.hpp"
@@ -37,21 +38,26 @@ namespace be = dal::backend;
 namespace pr = be::primitives;
 namespace interop = dal::backend::interop;
 namespace daal_lr = daal::algorithms::linear_regression;
+namespace daal_rr = daal::algorithms::ridge_regression;
 
-using daal_hyperparameters_t = daal_lr::internal::Hyperparameter;
+using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter;
 
-constexpr auto daal_method = daal_lr::training::normEqDense;
+constexpr auto daal_lr_method = daal_lr::training::normEqDense;
+constexpr auto daal_rr_method = daal_rr::training::normEqDense;
 
 template <typename Float, daal::CpuType Cpu>
-using online_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_method, Cpu>;
+using online_lr_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_lr_method, Cpu>;
+
+template <typename Float, daal::CpuType Cpu>
+using online_rr_kernel_t = daal_rr::training::internal::OnlineKernel<Float, daal_rr_method, Cpu>;
 
 template <typename Float, typename Task>
-static daal_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
+static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
     using daal_lr::internal::HyperparameterId;
 
     const std::int64_t block = params.get_cpu_macro_block();
 
-    daal_hyperparameters_t daal_hyperparameter;
+    daal_lr_hyperparameters_t daal_hyperparameter;
     auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block);
     interop::status_to_exception(status);
 
@@ -68,36 +74,58 @@ static train_result<Task> call_daal_kernel(const context_cpu& ctx,
     using model_t = model<Task>;
     using model_impl_t = detail::model_impl<Task>;
 
-    const bool beta = desc.get_compute_intercept();
+    const bool compute_intercept = desc.get_compute_intercept();
 
     const auto response_count = input.get_partial_xty().get_row_count();
     const auto ext_feature_count = input.get_partial_xty().get_column_count();
 
-    const auto feature_count = ext_feature_count - beta;
+    const auto feature_count = ext_feature_count - compute_intercept;
 
     const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
     auto betas_arr = array<Float>::zeros(betas_size);
 
-    const daal_hyperparameters_t& hp = convert_parameters<Float>(params);
-
     auto xtx_daal_table = interop::convert_to_daal_table<Float>(input.get_partial_xtx());
     auto xty_daal_table = interop::convert_to_daal_table<Float>(input.get_partial_xty());
     auto betas_daal_table =
         interop::convert_to_daal_homogen_table(betas_arr, response_count, feature_count + 1);
 
-    {
-        const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
-            constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
-            return online_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
-                                                                      *xty_daal_table,
-                                                                      *xtx_daal_table,
-                                                                      *xty_daal_table,
-                                                                      *betas_daal_table,
-                                                                      beta,
-                                                                      &hp);
-        });
-
-        interop::status_to_exception(status);
-    }
+    double alpha = desc.get_alpha();
+    if (alpha != 0.0) {
+        auto ridge_matrix_array = array<Float>::full(1, static_cast<Float>(alpha));
+        auto ridge_matrix = interop::convert_to_daal_homogen_table<Float>(ridge_matrix_array, 1, 1);
+
+        {
+            const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
+                constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
+                return online_rr_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
+                                                                             *xty_daal_table,
+                                                                             *xtx_daal_table,
+                                                                             *xty_daal_table,
+                                                                             *betas_daal_table,
+                                                                             compute_intercept,
+                                                                             *ridge_matrix);
+            });
+
+            interop::status_to_exception(status);
+        }
+    }
+    else {
+        const daal_lr_hyperparameters_t& hp = convert_parameters<Float>(params);
+
+        {
+            const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
+                constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
+                return online_lr_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
+                                                                             *xty_daal_table,
+                                                                             *xtx_daal_table,
+                                                                             *xty_daal_table,
+                                                                             *betas_daal_table,
+                                                                             compute_intercept,
+                                                                             &hp);
+            });
+
+            interop::status_to_exception(status);
+        }
+    }
 
     auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
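The new branch above routes finalization to DAAL's ridge kernel whenever alpha is nonzero, passing the penalty as a 1x1 table; mathematically, finalization then solves the regularized normal equations (X^T X + alpha I) w = X^T y instead of the plain least-squares system. A self-contained toy illustration of that algebra (not the DAAL kernel, just the two-feature case solved by hand):

#include <cstdio>

int main() {
    // Stand-ins for the accumulated partial_xtx / partial_xty.
    double xtx[2][2] = { { 4.0, 2.0 }, { 2.0, 3.0 } };
    double xty[2] = { 10.0, 8.0 };
    const double alpha = 0.5; // the penalty read from desc.get_alpha()

    // Ridge: add alpha to the diagonal of X^T X.
    for (int i = 0; i < 2; ++i)
        xtx[i][i] += alpha;

    // Solve the 2x2 system (X^T X + alpha I) w = X^T y by Cramer's rule.
    const double det = xtx[0][0] * xtx[1][1] - xtx[0][1] * xtx[1][0];
    const double w0 = (xty[0] * xtx[1][1] - xtx[0][1] * xty[1]) / det;
    const double w1 = (xtx[0][0] * xty[1] - xty[0] * xtx[1][0]) / det;
    std::printf("w = (%.4f, %.4f)\n", w0, w1);
}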
@@ -62,14 +62,14 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
                                                    const partial_train_input<Task>& input) {
     using dal::detail::check_mul_overflow;
 
-    const bool beta = desc.get_compute_intercept();
+    const bool compute_intercept = desc.get_compute_intercept();
 
     const auto feature_count = input.get_data().get_column_count();
     const auto response_count = input.get_responses().get_column_count();
 
     const daal_hyperparameters_t& hp = convert_parameters<Float>(params);
 
-    const auto ext_feature_count = feature_count + beta;
+    const auto ext_feature_count = feature_count + compute_intercept;
 
     const bool has_xtx_data = input.get_prev().get_partial_xtx().has_data();
     if (has_xtx_data) {
@@ -85,7 +85,7 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
                                                    *y_daal_table,
                                                    *daal_xtx,
                                                    *daal_xty,
-                                                   beta,
+                                                   compute_intercept,
                                                    &hp);
 
         interop::status_to_exception(status);
@@ -117,7 +117,7 @@ static partial_train_result<Task> call_daal_kernel(const context_cpu& ctx,
                                                    *y_daal_table,
                                                    *xtx_daal_table,
                                                    *xty_daal_table,
-                                                   beta,
+                                                   compute_intercept,
                                                    &hp);
 
         interop::status_to_exception(status);
@@ -16,6 +16,7 @@
 
 #include <daal/src/algorithms/linear_regression/linear_regression_train_kernel.h>
 #include <daal/src/algorithms/linear_regression/linear_regression_hyperparameter_impl.h>
+#include <daal/src/algorithms/ridge_regression/ridge_regression_train_kernel.h>
 
 #include "oneapi/dal/backend/interop/common.hpp"
 #include "oneapi/dal/backend/interop/error_converter.hpp"
@@ -39,21 +40,26 @@ namespace be = dal::backend;
 namespace pr = be::primitives;
 namespace interop = dal::backend::interop;
 namespace daal_lr = daal::algorithms::linear_regression;
+namespace daal_rr = daal::algorithms::ridge_regression;
 
-using daal_hyperparameters_t = daal_lr::internal::Hyperparameter;
+using daal_lr_hyperparameters_t = daal_lr::internal::Hyperparameter;
 
-constexpr auto daal_method = daal_lr::training::normEqDense;
+constexpr auto daal_lr_method = daal_lr::training::normEqDense;
+constexpr auto daal_rr_method = daal_rr::training::normEqDense;
 
 template <typename Float, daal::CpuType Cpu>
-using online_kernel_t = daal_lr::training::internal::OnlineKernel<Float, daal_method, Cpu>;
+using batch_lr_kernel_t = daal_lr::training::internal::BatchKernel<Float, daal_lr_method, Cpu>;
+
+template <typename Float, daal::CpuType Cpu>
+using batch_rr_kernel_t = daal_rr::training::internal::BatchKernel<Float, daal_rr_method, Cpu>;
 
 template <typename Float, typename Task>
-static daal_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
+static daal_lr_hyperparameters_t convert_parameters(const detail::train_parameters<Task>& params) {
     using daal_lr::internal::HyperparameterId;
 
     const std::int64_t block = params.get_cpu_macro_block();
 
-    daal_hyperparameters_t daal_hyperparameter;
+    daal_lr_hyperparameters_t daal_hyperparameter;
     auto status = daal_hyperparameter.set(HyperparameterId::denseUpdateStepBlockSize, block);
     interop::status_to_exception(status);
 
@@ -97,33 +103,41 @@ static train_result<Task> call_daal_kernel(const context_cpu& ctx,
     auto x_daal_table = interop::convert_to_daal_table<Float>(data);
     auto y_daal_table = interop::convert_to_daal_table<Float>(resp);
 
-    const daal_hyperparameters_t& hp = convert_parameters<Float>(params);
-
-    {
-        const auto status = interop::call_daal_kernel<Float, online_kernel_t>(ctx,
-                                                                              *x_daal_table,
-                                                                              *y_daal_table,
-                                                                              *xtx_daal_table,
-                                                                              *xty_daal_table,
-                                                                              intp,
-                                                                              &hp);
-
-        interop::status_to_exception(status);
-    }
-
-    {
-        const auto status = dal::backend::dispatch_by_cpu(ctx, [&](auto cpu) {
-            constexpr auto cpu_type = interop::to_daal_cpu_type<decltype(cpu)>::value;
-            return online_kernel_t<Float, cpu_type>().finalizeCompute(*xtx_daal_table,
-                                                                      *xty_daal_table,
-                                                                      *xtx_daal_table,
-                                                                      *xty_daal_table,
-                                                                      *betas_daal_table,
-                                                                      intp,
-                                                                      &hp);
-        });
-
-        interop::status_to_exception(status);
-    }
+    double alpha = desc.get_alpha();
+    if (alpha != 0.0) {
+        auto ridge_matrix_array = array<Float>::full(1, static_cast<Float>(alpha));
+        auto ridge_matrix = interop::convert_to_daal_homogen_table<Float>(ridge_matrix_array, 1, 1);
+
+        {
+            const auto status =
+                interop::call_daal_kernel<Float, batch_rr_kernel_t>(ctx,
+                                                                    *x_daal_table,
+                                                                    *y_daal_table,
+                                                                    *xtx_daal_table,
+                                                                    *xty_daal_table,
+                                                                    *betas_daal_table,
+                                                                    intp,
+                                                                    *ridge_matrix);
+
+            interop::status_to_exception(status);
+        }
+    }
+    else {
+        const daal_lr_hyperparameters_t& hp = convert_parameters<Float>(params);
+
+        {
+            const auto status =
+                interop::call_daal_kernel<Float, batch_lr_kernel_t>(ctx,
+                                                                    *x_daal_table,
+                                                                    *y_daal_table,
+                                                                    *xtx_daal_table,
+                                                                    *xty_daal_table,
+                                                                    *betas_daal_table,
+                                                                    intp,
+                                                                    &hp);
+
+            interop::status_to_exception(status);
+        }
+    }
 
     auto betas_table = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
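From the public API side, the ridge path in the batch kernel above is selected purely by the descriptor's alpha value. A hypothetical usage sketch; set_alpha and set_compute_intercept are inferred from the get_alpha and get_compute_intercept reads in this diff and are not themselves shown here:

#include "oneapi/dal/algo/linear_regression.hpp"
#include "oneapi/dal/table/homogen.hpp"
#include "oneapi/dal/train.hpp"

namespace dal = oneapi::dal;

int main() {
    // Tiny in-memory training set: 4 rows, 2 features, 1 response.
    const float x[] = { 1.f, 2.f, 2.f, 1.f, 3.f, 4.f, 4.f, 3.f };
    const float y[] = { 5.f, 4.f, 11.f, 10.f };
    const auto x_table = dal::homogen_table::wrap(x, 4, 2);
    const auto y_table = dal::homogen_table::wrap(y, 4, 1);

    // A nonzero alpha should dispatch training to the ridge kernel;
    // alpha == 0.0 keeps the ordinary least-squares path.
    const auto desc = dal::linear_regression::descriptor<float>{}
                          .set_compute_intercept(true)
                          .set_alpha(0.5);

    const auto result = dal::train(desc, x_table, y_table);
    (void)result; // result.get_model() carries the fitted coefficients
}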
@@ -27,6 +27,7 @@
 #include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"
 #include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp"
 #include "oneapi/dal/algo/linear_regression/backend/gpu/update_kernel.hpp"
+#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp"
 
 namespace oneapi::dal::linear_regression::backend {
 
@@ -47,14 +48,14 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
 
     auto& queue = ctx.get_queue();
 
-    const bool beta = desc.get_compute_intercept();
+    const bool compute_intercept = desc.get_compute_intercept();
 
     constexpr auto uplo = pr::mkl::uplo::upper;
     constexpr auto alloc = sycl::usm::alloc::device;
 
     const auto response_count = input.get_partial_xty().get_row_count();
     const auto ext_feature_count = input.get_partial_xty().get_column_count();
-    const auto feature_count = ext_feature_count - beta;
+    const auto feature_count = ext_feature_count - compute_intercept;
 
     const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };
 
@@ -69,9 +70,21 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
     const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
     auto betas_arr = array<Float>::zeros(queue, betas_size, alloc);
 
+    double alpha = desc.get_alpha();
+    sycl::event ridge_event;
+    if (alpha != 0.0) {
+        ridge_event = add_ridge_penalty<Float>(queue, xtx_nd, compute_intercept, alpha);
+    }
+
     auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
     auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
-    auto solve_event = pr::solve_system<uplo>(queue, beta, xtx_nd, xty_nd, nxtx, nxty, {});
+    auto solve_event = pr::solve_system<uplo>(queue,
+                                              compute_intercept,
+                                              xtx_nd,
+                                              xty_nd,
+                                              nxtx,
+                                              nxty,
+                                              { ridge_event });
     sycl::event::wait_and_throw({ solve_event });
 
     auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
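add_ridge_penalty comes from the newly included gpu/misc.hpp, whose body this page does not show. Presumably it adds alpha to the diagonal of the accumulated X^T X before the solve, leaving the intercept term unpenalized; the sketch below is a guess at that behavior in plain SYCL, not the actual oneDAL primitive (the name add_ridge_penalty_sketch and the row-major USM layout are assumptions):

#include <sycl/sycl.hpp>
#include <cstdint>

template <typename Float>
sycl::event add_ridge_penalty_sketch(sycl::queue& q,
                                     Float* xtx, // device USM, row-major
                                     std::int64_t ext_feature_count,
                                     bool compute_intercept,
                                     double alpha) {
    // When an intercept is fitted, one row/column of the extended X^T X
    // corresponds to the all-ones column and must stay unpenalized.
    const std::int64_t penalized = ext_feature_count - (compute_intercept ? 1 : 0);
    return q.parallel_for(sycl::range<1>(penalized), [=](sycl::id<1> idx) {
        const std::int64_t i = idx[0];
        xtx[i * ext_feature_count + i] += static_cast<Float>(alpha);
    });
}

The returned event is then handed to the solver as a dependency, exactly as the { ridge_event } argument to solve_system does above.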
(The diff for the remaining changed files is not loaded on this page.)