Skip to content

Commit

Permalink
Add sample_weight to Coordinate Descent solver (Lasso and ElasticNe…
Browse files Browse the repository at this point in the history
…t) (rapidsai#4867)

Linking rapidsai#669.
This PR adds a `sample_weight` parameter to the C++ Coordinate Descent solver, which is used by Lasso and ElasticNet.
It also adds tests at the C++ and Python levels.
It also removes some `cudaStream_t` parameters where the stream can be obtained from the raft handle instead.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4867
  • Loading branch information
lowener committed Aug 31, 2022
1 parent d81465d commit 1cbcb63
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 90 deletions.
56 changes: 53 additions & 3 deletions cpp/include/cuml/solvers/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,6 +101,54 @@ void sgdPredictBinaryClass(raft::handle_t& handle,
double* preds,
int loss);

/**
* Fits a linear, lasso, or elastic-net regression model using the Coordinate Descent solver.
*
* i.e. finds coefficients that minimize the following loss function:
*
* f(coef) = 1/2 * || labels - input * coef ||^2
* + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2
* + alpha * l1_ratio * ||coef||_1
*
*
* @param handle
* Reference of raft::handle_t
* @param input
* pointer to an array in column-major format (size of n_rows, n_cols)
* @param n_rows
* n_samples or rows in input
* @param n_cols
* n_features or columns in input
* @param labels
* pointer to an array for labels (size of n_rows)
* @param coef
* pointer to an array for coefficients (size of n_cols). This will be filled with
* coefficients once the function is executed.
* @param intercept
* pointer to a scalar for intercept. This will be filled
* once the function is executed
* @param fit_intercept
* boolean parameter to control if the intercept will be fitted or not
* @param normalize
* boolean parameter to control if the data will be normalized or not;
* NB: the input is scaled by the column-wise biased sample standard deviation estimator.
* @param epochs
* Maximum number of iterations that solver will run
* @param loss
* enum selecting the loss function. Only the linear regression loss function is supported
* right now
* @param alpha
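* regularization strength; it scales both the L1 and the L2 penalty terms of the loss above
* @param l1_ratio
* fraction of alpha applied to the L1 penalty; the remaining (1 - l1_ratio) * alpha is applied to the L2 penalty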
* @param shuffle
* boolean parameter to control whether coordinates will be picked randomly or not
* @param tol
* tolerance to stop the solver
* @param sample_weight
* device pointer to sample weight vector of length n_rows (nullptr for uniform weights).
* This vector is modified in place during the computation.
*/
void cdFit(raft::handle_t& handle,
float* input,
int n_rows,
Expand All @@ -115,7 +163,8 @@ void cdFit(raft::handle_t& handle,
float alpha,
float l1_ratio,
bool shuffle,
float tol);
float tol,
float* sample_weight = nullptr);

void cdFit(raft::handle_t& handle,
double* input,
Expand All @@ -131,7 +180,8 @@ void cdFit(raft::handle_t& handle,
double alpha,
double l1_ratio,
bool shuffle,
double tol);
double tol,
double* sample_weight = nullptr);

void cdPredict(raft::handle_t& handle,
const float* input,
Expand Down
77 changes: 58 additions & 19 deletions cpp/src/solver/cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@
#include <raft/common/nvtx.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/linalg/add.hpp>
#include <raft/linalg/axpy.hpp>
#include <raft/linalg/eltwise.hpp>
#include <raft/linalg/gemm.hpp>
#include <raft/linalg/gemv.hpp>
#include <raft/linalg/multiply.hpp>
#include <raft/linalg/subtract.hpp>
#include <raft/linalg/unary_op.hpp>
#include <raft/matrix/math.hpp>
#include <raft/matrix/matrix.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/axpy.cuh>
#include <raft/linalg/eltwise.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/gemv.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/linalg/power.cuh>
#include <raft/linalg/sqrt.cuh>
#include <raft/linalg/subtract.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/math.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/stats/sum.cuh>

namespace ML {
namespace Solver {
Expand Down Expand Up @@ -123,8 +127,9 @@ __global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc,
* boolean parameter to control whether coordinates will be picked randomly or not
* @param tol
* tolerance to stop the solver
* @param stream
* cuda stream
* @param sample_weight
* device pointer to sample weight vector of length n_rows (nullptr for uniform weights).
* This vector is modified in place during the computation.
*/
template <typename math_t>
void cdFit(const raft::handle_t& handle,
Expand All @@ -142,20 +147,30 @@ void cdFit(const raft::handle_t& handle,
math_t l1_ratio,
bool shuffle,
math_t tol,
cudaStream_t stream)
math_t* sample_weight = nullptr)
{
raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
ASSERT(loss == ML::loss_funct::SQRD_LOSS,
"Parameter loss: Only SQRT_LOSS function is supported for now");

cudaStream_t stream = handle.get_stream();
rmm::device_uvector<math_t> residual(n_rows, stream);
rmm::device_uvector<math_t> squared(n_cols, stream);
rmm::device_uvector<math_t> mu_input(0, stream);
rmm::device_uvector<math_t> mu_labels(0, stream);
rmm::device_uvector<math_t> norm2_input(0, stream);
math_t h_sum_sw = 0;

if (sample_weight != nullptr) {
rmm::device_scalar<math_t> sum_sw(stream);
raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, true, stream);
raft::update_host(&h_sum_sw, sum_sw.data(), 1, stream);

raft::linalg::multiplyScalar(
sample_weight, sample_weight, (math_t)n_rows / h_sum_sw, n_rows, stream);
}
if (fit_intercept) {
mu_input.resize(n_cols, stream);
mu_labels.resize(1, stream);
Expand All @@ -171,7 +186,20 @@ void cdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize);
normalize,
sample_weight);
}
if (sample_weight != nullptr) {
raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream);
raft::matrix::matrixVectorBinaryMult(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a * b; },
stream,
labels,
sample_weight);
}

std::vector<int> ri(n_cols);
Expand Down Expand Up @@ -254,6 +282,20 @@ void cdFit(const raft::handle_t& handle,
if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break;
}

if (sample_weight != nullptr) {
raft::matrix::matrixVectorBinaryDivSkipZero(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a / b; },
stream,
labels,
sample_weight);
raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
raft::linalg::multiplyScalar(sample_weight, sample_weight, h_sum_sw / n_rows, n_rows, stream);
}

if (fit_intercept) {
GLM::postProcessData(handle,
input,
Expand Down Expand Up @@ -293,8 +335,6 @@ void cdFit(const raft::handle_t& handle,
* @param loss
* enum selecting the loss function. Only the linear regression loss function is supported
* right now.
* @param stream
* cuda stream
*/
template <typename math_t>
void cdPredict(const raft::handle_t& handle,
Expand All @@ -304,15 +344,14 @@ void cdPredict(const raft::handle_t& handle,
const math_t* coef,
math_t intercept,
math_t* preds,
ML::loss_funct loss,
cudaStream_t stream)
ML::loss_funct loss)
{
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
ASSERT(loss == ML::loss_funct::SQRD_LOSS,
"Parameter loss: Only SQRT_LOSS function is supported for now");

Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream);
Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, handle.get_stream());
}

}; // namespace Solver
Expand Down
74 changes: 38 additions & 36 deletions cpp/src/solver/solver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -297,28 +297,29 @@ void cdFit(raft::handle_t& handle,
float alpha,
float l1_ratio,
bool shuffle,
float tol)
float tol,
float* sample_weight)
{
ASSERT(loss == 0, "Parameter loss: Only SQRT_LOSS function is supported for now");

ML::loss_funct loss_funct = ML::loss_funct::SQRD_LOSS;

cdFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
epochs,
loss_funct,
alpha,
l1_ratio,
shuffle,
tol,
handle.get_stream());
cdFit<float>(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
epochs,
loss_funct,
alpha,
l1_ratio,
shuffle,
tol,
sample_weight);
}

void cdFit(raft::handle_t& handle,
Expand All @@ -335,28 +336,29 @@ void cdFit(raft::handle_t& handle,
double alpha,
double l1_ratio,
bool shuffle,
double tol)
double tol,
double* sample_weight)
{
ASSERT(loss == 0, "Parameter loss: Only SQRT_LOSS function is supported for now");

ML::loss_funct loss_funct = ML::loss_funct::SQRD_LOSS;

cdFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
epochs,
loss_funct,
alpha,
l1_ratio,
shuffle,
tol,
handle.get_stream());
cdFit<double>(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
epochs,
loss_funct,
alpha,
l1_ratio,
shuffle,
tol,
sample_weight);
}

void cdPredict(raft::handle_t& handle,
Expand All @@ -375,7 +377,7 @@ void cdPredict(raft::handle_t& handle,
ASSERT(false, "glm.cu: other functions are not supported yet.");
}

cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct, handle.get_stream());
cdPredict<float>(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct);
}

void cdPredict(raft::handle_t& handle,
Expand All @@ -394,7 +396,7 @@ void cdPredict(raft::handle_t& handle,
ASSERT(false, "glm.cu: other functions are not supported yet.");
}

cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct, handle.get_stream());
cdPredict<double>(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct);
}

} // namespace Solver
Expand Down

0 comments on commit 1cbcb63

Please sign in to comment.