From bd1d7b260a4dcc0ce129a1b737bd1ea5f7c278b9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 18 Oct 2021 13:43:25 +0530 Subject: [PATCH 01/20] [WIP] Implement fair XGBoost. This implements only the objective and meta info. --- doc/parameter.rst | 9 +- include/xgboost/data.h | 4 +- include/xgboost/linalg.h | 3 +- include/xgboost/metric.h | 5 +- python-package/xgboost/core.py | 30 +++++- python-package/xgboost/dask.py | 10 +- src/common/common.h | 10 ++ src/common/fair_param.h | 14 +++ src/common/linalg_op.cuh | 15 +++ src/common/linalg_op.h | 21 +++- src/common/math.h | 6 +- src/data/data.cc | 23 ++++- src/data/data.cu | 2 +- src/learner.cc | 49 +++++---- src/metric/auc.cc | 14 +-- src/metric/auc.cu | 10 +- src/metric/auc.h | 12 --- src/metric/elementwise_metric.cu | 104 ++++++++++++++++--- src/objective/regression_obj.cu | 115 ++++++++++++++++++++- tests/cpp/common/test_linalg.cc | 8 +- tests/cpp/common/test_linalg.cu | 11 +- tests/cpp/data/test_metainfo.h | 4 +- tests/cpp/objective/test_regression_obj.cc | 60 ++++++++++- tests/cpp/test_learner.cc | 4 +- 24 files changed, 446 insertions(+), 97 deletions(-) create mode 100644 src/common/fair_param.h diff --git a/doc/parameter.rst b/doc/parameter.rst index 0ec58026a15c..37b55cb19f1c 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -352,6 +352,11 @@ Parameters for Tweedie Regression (``objective=reg:tweedie``) - Set closer to 2 to shift towards a gamma distribution - Set closer to 1 to shift towards a Poisson distribution. +Parameter for Fair Classification (``objective=binary:regularized``) +==================================================================== + +* ``fairness``: The strength of the regularization, must be in the interval [0, 1]. + ************************ Learning Task Parameters ************************ @@ -361,9 +366,10 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``reg:squarederror``: regression with squared loss. - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective. - - ``reg:logistic``: logistic regression + - ``reg:logistic``: logistic regression. - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss. - ``binary:logistic``: logistic regression for binary classification, output probability + - ``binary:regularized``: logistic regression for binary classification with a fairness regularizer, output probability. - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities. - ``count:poisson``: Poisson regression for count data, output mean of Poisson distribution. @@ -400,6 +406,7 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``mape``: `mean absolute percentage error `_ - ``mphe``: `mean Pseudo Huber error `_. Default metric of ``reg:pseudohubererror`` objective. - ``logloss``: `negative log-likelihood `_ + - ``regularized-logloss``: negative log-likelihood with the fairness regularization term. Default metric of ``binary:regularized`` objective. - ``error``: Binary classification error rate. It is calculated as ``#(wrong cases)/#(all cases)``. For the predictions, the evaluation will regard the instances with prediction value larger than 0.5 as positive instances, and the others as negative instances. 
- ``error@t``: a different than 0.5 binary classification threshold value could be specified by providing a numerical value through 't'. - ``merror``: Multiclass classification error rate. It is calculated as ``#(wrong cases)/#(all cases)``. diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7399b8265377..0c931c1cc9a7 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -46,7 +46,7 @@ enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; class MetaInfo { public: /*! \brief number of data fields in MetaInfo */ - static constexpr uint64_t kNumField = 12; + static constexpr uint64_t kNumField = 13; /*! \brief number of rows in the data */ uint64_t num_row_{0}; // NOLINT @@ -63,6 +63,8 @@ class MetaInfo { std::vector group_ptr_; // NOLINT /*! \brief weights of each instance, optional */ HostDeviceVector weights_; // NOLINT + /*! \brief sensitive feature of each instance, optional */ + linalg::Tensor sensitive_features; // NOLINT /*! * \brief initialized margins, * if specified, xgboost will start from this init margin diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 897a7330189d..5caba4ae43a6 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -545,8 +545,7 @@ using VectorView = TensorView; */ template auto MakeVec(T *ptr, size_t s, int32_t device = -1) { - using U = std::remove_const_t> const; - return linalg::TensorView{{ptr, s}, {s}, device}; + return linalg::TensorView{{ptr, s}, {s}, device}; } /** diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h index 42d517819b14..0ce0d11ce807 100644 --- a/include/xgboost/metric.h +++ b/include/xgboost/metric.h @@ -48,7 +48,10 @@ class Metric : public Configurable { * override this function to maintain internal configuration * \param out pointer to output JSON object */ - void SaveConfig(Json*) const override {} + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(this->Name()); + } /*! * \brief evaluate a specific metric diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 22564db80267..56c81fe4b73b 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -526,6 +526,7 @@ def __init__( qid=None, label_lower_bound=None, label_upper_bound=None, + sensitive_feature=None, feature_weights=None, enable_categorical: bool = False, ) -> None: @@ -575,6 +576,8 @@ def __init__( Lower bound for survival training. label_upper_bound : array_like Upper bound for survival training. + sensitive_feature: array_like + Sensitive feature for each training sample. feature_weights : array_like, optional Set feature weights for column sampling. 
enable_categorical: boolean, optional @@ -625,6 +628,7 @@ def __init__( qid=qid, label_lower_bound=label_lower_bound, label_upper_bound=label_upper_bound, + sensitive_feature=sensitive_feature, feature_weights=feature_weights, ) @@ -676,6 +680,7 @@ def set_info( qid=None, label_lower_bound=None, label_upper_bound=None, + sensitive_feature=None, feature_names: FeatNamesT = None, feature_types: Optional[List[str]] = None, feature_weights=None @@ -687,6 +692,8 @@ def set_info( self.set_label(label) if weight is not None: self.set_weight(weight) + if sensitive_feature is not None: + self.set_sensitive_feature(sensitive_feature) if base_margin is not None: self.set_base_margin(base_margin) if group is not None: @@ -836,6 +843,17 @@ def set_weight(self, weight) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, weight, 'weight', 'float') + def set_sensitive_feature(self, sensitive_feature) -> None: + """Set sensitive_feature of each instance. + + Parameters + ---------- + sensitive_feature : array_like + Sensitive feature for each data point + """ + from .data import dispatch_meta_backend + dispatch_meta_backend(self, sensitive_feature, 'sensitive_feature', 'float') + def set_base_margin(self, margin) -> None: """Set base margin of booster to start from. @@ -882,7 +900,15 @@ def get_weight(self) -> np.ndarray: """ return self.get_float_info('weight') - def get_base_margin(self) -> np.ndarray: + def get_sensitive_feature(self) -> np.ndarray: + """Get the sensitive feature of the DMatrix. + Returns + ------- + sensitive_feature : array + """ + return self.get_float_info('sensitive_feature') + + def get_base_margin(self) -> np.ndarray: """Get the base margin of the DMatrix. Returns @@ -1174,6 +1200,7 @@ def __init__( # pylint: disable=super-init-not-called qid=None, label_lower_bound=None, label_upper_bound=None, + sensitive_feature=None, feature_weights=None, enable_categorical: bool = False, ): @@ -1201,6 +1228,7 @@ def __init__( # pylint: disable=super-init-not-called qid=qid, label_lower_bound=label_lower_bound, label_upper_bound=label_upper_bound, + sensitive_feature=sensitive_feature, feature_weights=feature_weights, feature_names=feature_names, feature_types=feature_types, diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index ee08950559d5..b4e6e9165d81 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -316,6 +316,7 @@ def __init__( qid: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, label_upper_bound: Optional[_DaskCollection] = None, + sensitive_feature: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None, enable_categorical: bool = False, ) -> None: @@ -358,6 +359,7 @@ def __init__( feature_weights=feature_weights, label_lower_bound=label_lower_bound, label_upper_bound=label_upper_bound, + sensitive_feature=sensitive_feature, ) def __await__(self) -> Generator: @@ -374,6 +376,7 @@ async def _map_local_data( feature_weights: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, label_upper_bound: Optional[_DaskCollection] = None, + sensitive_feature: Optional[_DaskCollection] = None, ) -> "DaskDMatrix": """Obtain references to local data.""" @@ -427,6 +430,7 @@ def flatten_meta(meta: OpDelayed) -> OpDelayed: qid_parts = flatten_meta(qid) ll_parts = flatten_meta(label_lower_bound) lu_parts = flatten_meta(label_upper_bound) + sf_parts = flatten_meta(sensitive_feature) parts: Dict[str, List[ddelayed.Delayed]] = {"data": X_parts} 
@@ -443,6 +447,7 @@ def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None: append_meta(qid_parts, "qid") append_meta(ll_parts, "label_lower_bound") append_meta(lu_parts, "label_upper_bound") + append_meta(sf_parts, 'sensitive_feature') # At this point, `parts` looks like: # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form @@ -570,7 +575,7 @@ def append(i: int, name: str) -> None: append(i, "qid") append(i, "label_lower_bound") append(i, "label_upper_bound") - + append(i, "sensitive_feature") return result @@ -586,6 +591,7 @@ def __init__( qid: Optional[List[Any]] = None, label_lower_bound: Optional[List[Any]] = None, label_upper_bound: Optional[List[Any]] = None, + sensitive_feature: Optional[List[Any]] = None, feature_names: FeatNamesT = None, feature_types: Optional[Union[Any, List[Any]]] = None, ) -> None: @@ -596,6 +602,7 @@ def __init__( self._qid = qid self._label_lower_bound = label_lower_bound self._label_upper_bound = label_upper_bound + self._sensitive_feature = sensitive_feature self._feature_names = feature_names self._feature_types = feature_types @@ -646,6 +653,7 @@ def next(self, input_data: Callable) -> int: base_margin=self._get("_base_margin"), label_lower_bound=self._get("_label_lower_bound"), label_upper_bound=self._get("_label_upper_bound"), + sensitive_feature=self._get("sensitive_feature"), feature_names=feature_names, feature_types=self._feature_types, ) diff --git a/src/common/common.h b/src/common/common.h index 8230e532ff69..fb7e7fee55da 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -188,6 +188,16 @@ std::vector ArgSort(Container const &array, Comp comp = std::less{}) { XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op); return result; } + +struct OptionalWeights { + Span weights; + float dft{1.0f}; + + explicit OptionalWeights(Span w) : weights{w} {} + explicit OptionalWeights(float w) : dft{w} {} + + XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; } +}; } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_COMMON_H_ diff --git a/src/common/fair_param.h b/src/common/fair_param.h new file mode 100644 index 000000000000..1df22b49a1b2 --- /dev/null +++ b/src/common/fair_param.h @@ -0,0 +1,14 @@ +/*! 
+ * Copyright 2022 by XGBoost Contributors + */ +#include "xgboost/parameter.h" + +namespace xgboost { +struct BinaryRegularizationParam : public XGBoostParameter { + float fairness{0.0f}; + DMLC_DECLARE_PARAMETER(BinaryRegularizationParam) { + DMLC_DECLARE_FIELD(fairness).set_range(0.0f, 1.0f).describe( + "The strength of the regularizer for fairness XGBoost."); + } +}; +} // namespace xgboost diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index dfab58729b56..bdbee205b3ca 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -10,6 +10,21 @@ namespace xgboost { namespace linalg { template void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { + static_assert(std::is_void>::value, + "For function with return, use transform instead."); + if (t.Contiguous()) { + auto ptr = t.Values().data(); + dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { fn(i, ptr[i]); }); + } else { + dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { + T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); + fn(i, v); + }); + } +} + +template +void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { if (t.Contiguous()) { auto ptr = t.Values().data(); dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); }); diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index a74b119e7947..71c9991fbf00 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -1,15 +1,17 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_COMMON_LINALG_OP_H_ #define XGBOOST_COMMON_LINALG_OP_H_ +#include + #include "threading_utils.h" #include "xgboost/linalg.h" namespace xgboost { namespace linalg { template -void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& fn) { +void ElementWiseTransformHost(linalg::TensorView t, int32_t n_threads, Fn&& fn) { if (t.Contiguous()) { auto ptr = t.Values().data(); common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); }); @@ -20,6 +22,21 @@ void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& f }); } } + +template +void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& fn) { + static_assert(std::is_void>::value, + "For function with return, use transform instead."); + if (t.Contiguous()) { + auto ptr = t.Values().data(); + common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); }); + } else { + common::ParallelFor(t.Size(), n_threads, [&](size_t i) { + auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); + fn(i, v); + }); + } +} } // namespace linalg } // namespace xgboost #endif // XGBOOST_COMMON_LINALG_OP_H_ diff --git a/src/common/math.h b/src/common/math.h index 5a98ad329ce4..71a494544be1 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -23,7 +23,11 @@ namespace common { * \return the transformed value. 
*/ XGBOOST_DEVICE inline float Sigmoid(float x) { - return 1.0f / (1.0f + expf(-x)); + float constexpr kEps = 1e-16; // avoid 0 div + x = std::min(-x, 88.7f); // avoid exp overflow + auto denom = expf(x) + 1.0f + kEps; + auto y = 1.0f / denom; + return y; } template diff --git a/src/data/data.cc b/src/data/data.cc index 3d1e3cc2862d..1a796b3d06a2 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -197,6 +197,7 @@ void MetaInfo::Clear() { * | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${base_margin_} | * | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} | * | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} | + * | sensitive_features | kFloat32 | False | ${Shape(0)} | 1 | ${sensitive_features_} | * | feature_names | kStr | False | ${size} | 1 | ${feature_names} | * | feature_types | kStr | False | ${size} | 1 | ${feature_types} | * | feature_weights | kFloat32 | False | ${size} | 1 | ${feature_weights} | @@ -224,7 +225,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt; SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32, {labels_upper_bound_.Size(), 1}, labels_upper_bound_); ++field_cnt; - + SaveTensorField(fo, u8"sensitive_features", DataType::kFloat32, sensitive_features), ++field_cnt; SaveVectorField(fo, u8"feature_names", DataType::kStr, {feature_names.size(), 1}, feature_names); ++field_cnt; SaveVectorField(fo, u8"feature_types", DataType::kStr, @@ -297,7 +298,7 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_); LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); - + LoadTensorField(fi, u8"sensitive_features", DataType::kFloat32, &sensitive_features); LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names); LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names); LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights); @@ -352,6 +353,12 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs); } + // sensitive feature + auto t_sf = this->sensitive_features.View(this->sensitive_features.Data()->DeviceIdx()); + out.sensitive_features.Reshape(ridxs.size()); + out.sensitive_features.Data()->HostVector() = + Gather(this->sensitive_features.Data()->HostVector(), ridxs, t_sf.Stride(0)); + if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; @@ -431,7 +438,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { auto t = p_out->View(GenericParameter::kCpuId); CHECK(t.CContiguous()); // FIXME(jiamingy): Remove the use of this default thread. 
- linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); }); } @@ -487,7 +494,11 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); CHECK(valid) << "Label contains NaN, infinity or a value too large."; return; + } else if (key == "sensitive_feature") { + CopyTensorInfoImpl(arr, &this->sensitive_features); + return; } + // uint info if (key == "group") { linalg::Tensor t; @@ -593,6 +604,8 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, vec = &this->labels.Data()->HostVector(); } else if (!std::strcmp(key, "weight")) { vec = &this->weights_.HostVector(); + } else if (!std::strcmp(key, "sensitive_feature")) { + vec = &this->sensitive_features.Data()->HostVector(); } else if (!std::strcmp(key, "base_margin")) { vec = &this->base_margin_.Data()->HostVector(); } else if (!std::strcmp(key, "label_lower_bound")) { @@ -676,6 +689,8 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col this->weights_.SetDevice(that.weights_.DeviceIdx()); this->weights_.Extend(that.weights_); + linalg::Stack(&this->sensitive_features, that.sensitive_features); + this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx()); this->labels_lower_bound_.Extend(that.labels_lower_bound_); @@ -877,7 +892,7 @@ DMatrix* DMatrix::Load(const std::string& uri, dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1, cache_file); } else { - data::FileIterator iter{fname, uint32_t(partid), uint32_t(npart), + data::FileIterator iter{fname, static_cast(partid), static_cast(npart), file_format}; dmat = new data::SparsePageDMatrix{ &iter, diff --git a/src/data/data.cu b/src/data/data.cu index 475d70313310..55c1c80d0e23 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -49,7 +49,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { } p_out->Reshape(array.shape); auto t = p_out->View(ptr_device); - linalg::ElementWiseKernelDevice(t, [=] __device__(size_t i, T) { + linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) { return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, array.shape)); }); } diff --git a/src/learner.cc b/src/learner.cc index 0ca52a9621b9..86cab0132b73 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -271,6 +271,21 @@ using LearnerAPIThreadLocalStore = using ThreadLocalPredictionCache = dmlc::ThreadLocalStore>; +namespace { +StringView ModelMsg() { + return StringView{ + R"doc( + If you are loading a serialized model (like pickle in Python, RDS in R) generated by + older XGBoost, please export the model by calling `Booster.save_model` from that version + first, then load it back in current version. See: + + https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html + + for more details about differences between saving model and serializing. 
+)doc"}; +} +} // anonymous namespace + class LearnerConfiguration : public Learner { private: std::mutex config_lock_; @@ -369,7 +384,6 @@ class LearnerConfiguration : public Learner { this->ConfigureGBM(old_tparam, args); generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU()); - this->ConfigureMetrics(args); this->need_configuration_ = false; @@ -412,9 +426,17 @@ class LearnerConfiguration : public Learner { metric_names_.resize(n_metrics); metrics_.resize(n_metrics); for (size_t i = 0; i < n_metrics; ++i) { - metric_names_[i]= get(j_metrics[i]); - metrics_[i] = std::unique_ptr( - Metric::Create(metric_names_[i], &generic_parameters_)); + auto old_serialization = IsA(j_metrics[i]); + if (old_serialization) { + LOG(WARNING) << ModelMsg(); + metric_names_[i] = get(j_metrics[i]); + } else { + metric_names_[i] = get(j_metrics[i]["name"]); + } + metrics_[i] = std::unique_ptr(Metric::Create(metric_names_[i], &generic_parameters_)); + if (!old_serialization) { + metrics_[i]->LoadConfig(j_metrics[i]); + } } FromJson(learner_parameters.at("generic_param"), &generic_parameters_); @@ -442,9 +464,9 @@ class LearnerConfiguration : public Learner { auto& objective_fn = learner_parameters["objective"]; obj_->SaveConfig(&objective_fn); - std::vector metrics(metrics_.size()); + std::vector metrics(metrics_.size(), Json{Object{}}); for (size_t i = 0; i < metrics_.size(); ++i) { - metrics[i] = String(metrics_[i]->Name()); + metrics_[i]->SaveConfig(&metrics[i]); } learner_parameters["metrics"] = Array(std::move(metrics)); @@ -703,21 +725,6 @@ class LearnerConfiguration : public Learner { std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT -namespace { -StringView ModelMsg() { - return StringView{ - R"doc( - If you are loading a serialized model (like pickle in Python, RDS in R) generated by - older XGBoost, please export the model by calling `Booster.save_model` from that version - first, then load it back in current version. See: - - https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html - - for more details about differences between saving model and serializing. 
-)doc"}; -} -} // anonymous namespace - class LearnerIO : public LearnerConfiguration { private: std::set saved_configs_ = {"num_round"}; diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 1957bcc9a083..0829db8f1fee 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -33,7 +33,7 @@ namespace metric { template std::tuple BinaryAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights, + common::OptionalWeights weights, std::vector const &sorted_idx, Fn &&area_fn) { CHECK_NE(labels.Size(), 0); CHECK_EQ(labels.Size(), predts.size()); @@ -93,7 +93,7 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, auto tp = results.Slice(linalg::All(), 1); auto auc = results.Slice(linalg::All(), 2); - auto weights = OptionalWeights{info.weights_.ConstHostSpan()}; + auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()}; auto predts_t = linalg::TensorView( predts, {static_cast(info.num_row_), n_classes}, GenericParameter::kCpuId); @@ -140,7 +140,7 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, std::tuple BinaryROCAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights) { + common::OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); } @@ -186,7 +186,7 @@ double GroupRankingROC(common::Span predts, */ std::tuple BinaryPRAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights) { + common::OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); double total_pos{0}, total_neg{0}; for (size_t i = 0; i < labels.Size(); ++i) { @@ -238,7 +238,7 @@ std::pair RankingAUC(std::vector const &predts, if (is_roc) { auc = GroupRankingROC(g_predts, g_labels, w); } else { - auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w})); + auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w})); } if (std::isnan(auc)) { invalid_groups++; @@ -373,7 +373,7 @@ class EvalROCAUC : public EvalAUC { if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(fp, tp, auc) = BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0), - OptionalWeights{info.weights_.ConstHostSpan()}); + common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); @@ -426,7 +426,7 @@ class EvalPRAUC : public EvalAUC { if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(pr, re, auc) = BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0), - OptionalWeights{info.weights_.ConstHostSpan()}); + common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 317ce7db2c84..be89c015c93d 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -99,7 +99,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, /** * Linear scan */ - auto get_weight = OptionalWeights{weights}; + auto get_weight = common::OptionalWeights{weights}; auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; @@ -353,7 +353,7 @@ double GPUMultiClassAUCOVR(common::Span predts, * Linear scan */ dh::caching_device_vector d_auc(n_classes, 0); - auto get_weight = OptionalWeights{weights}; + auto get_weight = 
common::OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; @@ -633,7 +633,7 @@ GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, auto labels = info.labels.View(device); auto d_weights = info.weights_.ConstDeviceSpan(); - auto get_weight = OptionalWeights{d_weights}; + auto get_weight = common::OptionalWeights{d_weights}; auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto w = get_weight[d_sorted_idx[i]]; @@ -687,7 +687,7 @@ double GPUMultiClassPRAUC(common::Span predts, [n_samples] XGBOOST_DEVICE(size_t i) { return i / n_samples; // class id }); - auto get_weight = OptionalWeights{d_weights}; + auto get_weight = common::OptionalWeights{d_weights}; auto val_it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto idx = d_sorted_idx[i] % n_samples; @@ -736,7 +736,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, */ size_t n_samples = labels.Shape(0); dh::caching_device_vector d_auc(n_groups, 0); - auto get_weight = OptionalWeights{weights}; + auto get_weight = common::OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; diff --git a/src/metric/auc.h b/src/metric/auc.h index cde8febf2b9e..c42df6890a39 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -112,18 +112,6 @@ struct PRAUCLabelInvalid { inline void InvalidLabels() { LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank."; } - -struct OptionalWeights { - common::Span weights; - float dft { 1.0f }; - - explicit OptionalWeights(common::Span w) : weights{w} {} - explicit OptionalWeights(float w) : dft{w} {} - - XGBOOST_DEVICE float operator[](size_t i) const { - return weights.empty() ? dft : weights[i]; - } -}; } // namespace metric } // namespace xgboost #endif // XGBOOST_METRIC_AUC_H_ diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index abf888e0b878..2992a89d4404 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -6,15 +6,17 @@ * * The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset. */ +#include #include #include -#include + #include -#include "metric_common.h" -#include "../common/math.h" #include "../common/common.h" +#include "../common/fair_param.h" +#include "../common/math.h" #include "../common/threading_utils.h" +#include "metric_common.h" #if defined(XGBOOST_USE_CUDA) #include // thrust::cuda::par @@ -187,25 +189,94 @@ struct EvalRowMAPE { } }; +namespace { +XGBOOST_DEVICE inline float LogLoss(float y, float py) { + auto xlogy = [](float x, float y) { + float eps = 1e-16; + return (x - 0.0 == 0.0) ? 
0 : (x * std::log(std::max(y, eps))); }; const bst_float pneg = 1.0f - py; return xlogy(-y, py) + xlogy(-(1.0f - y), pneg); } } // anonymous namespace struct EvalRowLogLoss { const char *Name() const { return "logloss"; } - XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { - const bst_float eps = 1e-16f; - const bst_float pneg = 1.0f - py; - if (py < eps) { - return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps); - } else if (pneg < eps) { - return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps); + XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { return LogLoss(y, py); } + static double GetFinal(double esum, double wsum) { + return wsum == 0 ? esum : esum / wsum; + } +}; + +class RegularizedLogLoss : public Metric { + BinaryRegularizationParam param_; + + public: + const char* Name() const override { return "regularized-logloss"; } + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + void LoadConfig(Json const& in) override { FromJson(in["binary_regularized_param"], &param_); } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(this->Name()); + out["binary_regularized_param"] = ToJson(param_); + } + + double Eval(const HostDeviceVector& preds, const MetaInfo& info, + bool distributed) override { + auto sensitive_features = info.sensitive_features.View(tparam_->gpu_id); + auto labels = info.labels.View(tparam_->gpu_id); + auto predts = tparam_->gpu_id == GenericParameter::kCpuId ? preds.ConstHostSpan() + : preds.ConstDeviceSpan(); + auto n_targets = std::max(info.labels.Shape(1), static_cast(1)); + common::OptionalWeights weights(tparam_->gpu_id == GenericParameter::kCpuId + ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()); + float fairness = this->param_.fairness; + auto loss = [=] XGBOOST_DEVICE(size_t i, float wt) { + auto v = (LogLoss(labels(i), predts[i]) * wt) - + (fairness * LogLoss(sensitive_features(i), predts[i]) * wt); + return v; + }; + PackedReduceResult result; + if (tparam_->gpu_id != GenericParameter::kCpuId) { +#if defined(XGBOOST_USE_CUDA) + dh::XGBCachingDeviceAllocator alloc; + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + info.num_row_; + result = thrust::transform_reduce( + thrust::cuda::par(alloc), begin, end, + [=] XGBOOST_DEVICE(size_t idx) { + float weight = weights[idx / n_targets]; + float l = loss(idx, weight); + return PackedReduceResult{l, weight}; + }, + PackedReduceResult{}, thrust::plus()); +#else + common::AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) } else { - return -y * std::log(py) - (1.0f - y) * std::log(pneg); + auto n_threads = tparam_->Threads(); + std::vector score_tloc(n_threads, 0.0); + std::vector weight_tloc(n_threads, 0.0); + common::ParallelFor(info.num_row_, tparam_->Threads(), [&](size_t i) { + float wt = weights[i / n_targets]; + auto t_idx = omp_get_thread_num(); + score_tloc[t_idx] += loss(i, wt); + weight_tloc[t_idx] += wt; + }); + double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); + double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); + result = PackedReduceResult{residue_sum, weights_sum}; } - } - static double GetFinal(double esum, double wsum) { - return wsum == 0 ? 
esum : esum / wsum; + double dat[2]{result.Residue(), result.Weights()}; + if (distributed) { + rabit::Allreduce(dat, 2); + } + return EvalRowLogLoss::GetFinal(dat[0], dat[1]); } }; @@ -409,6 +480,10 @@ XGBOOST_REGISTER_METRIC(LogLoss, "logloss") .describe("Negative loglikelihood for logistic regression.") .set_body([](const char* param) { return new EvalEWiseBase(); }); +XGBOOST_REGISTER_METRIC(RegularizedLogLoss, "regularized-logloss") +.describe("Negative loglikelihood for regularized binary classification.") +.set_body([](const char* param) { return new RegularizedLogLoss(); }); + XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik") .describe("Negative loglikelihood for poisson regression.") .set_body([](const char* param) { return new EvalEWiseBase(); }); @@ -430,6 +505,5 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik") .set_body([](const char* param) { return new EvalEWiseBase(param); }); - } // namespace metric } // namespace xgboost diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index a07de8e446a8..41f8cb511ae7 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -8,20 +8,25 @@ #include #include #include + #include #include #include +#include "../common/common.h" +#include "../common/fair_param.h" +#include "../common/linalg_op.h" +#include "../common/threading_utils.h" +#include "../common/transform.h" +#include "./regression_loss.h" #include "xgboost/host_device_vector.h" #include "xgboost/json.h" #include "xgboost/parameter.h" #include "xgboost/span.h" -#include "../common/transform.h" -#include "../common/common.h" -#include "../common/threading_utils.h" -#include "./regression_loss.h" - +#if defined(XGBOOST_USE_CUDA) +#include "../common/linalg_op.cuh" +#endif // defined(XGBOOST_USE_CUDA) namespace xgboost { namespace obj { @@ -200,6 +205,104 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") return new RegLossObj(); }); // End deprecated +/** + * \brief Implementation of https://arxiv.org/abs/2009.01442 + */ +class RegularizedClassification : public ObjFunction { + BinaryRegularizationParam param_; + + public: + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + struct ObjInfo Task() const override { + return {ObjInfo::kBinary, false}; + } + + uint32_t Targets(MetaInfo const& info) const override { + return std::max(static_cast(1), info.labels.Shape(1)); + } + + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int, + HostDeviceVector* out_gpair) override { + CHECK_EQ(info.sensitive_features.Size(), info.num_row_) + << "Incorrect shape of sensitive features, Expecting: (" << info.num_row_ << "), got: (" + << info.sensitive_features.Size() << ")"; + CHECK_EQ(info.labels.Shape(0), info.num_row_); + CHECK_EQ(info.labels.Shape(1), 1); + + auto fairness = param_.fairness; + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), info.num_row_) + << "Number of weights should be equal to number of data points."; + } + + auto fn = [fairness] XGBOOST_DEVICE( + size_t i, float y, linalg::TensorView predt_t, + common::OptionalWeights weight, linalg::TensorView sensitive, + linalg::TensorView gpair) { + auto predt = common::Sigmoid(predt_t(i)); + auto sf = sensitive(i); + auto grad = (predt - y) + (fairness * (sf - predt)); + auto hess = (1.0f - fairness) * predt * (1.0f - predt); + auto w = weight[i]; + gpair(i) = {grad * w, hess * w}; + }; + + auto sensitive = info.sensitive_features.View(ctx_->gpu_id); + 
out_gpair->SetDevice(ctx_->gpu_id); + out_gpair->Resize(info.num_row_); + preds.SetDevice(ctx_->gpu_id); + info.weights_.SetDevice(ctx_->gpu_id); + if (ctx_->gpu_id != GenericParameter::kCpuId) { +#if defined(XGBOOST_USE_CUDA) + auto gpair = linalg::MakeVec(out_gpair->DevicePointer(), info.num_row_, ctx_->gpu_id); + auto predt = linalg::MakeVec(preds.ConstDevicePointer(), preds.Size(), ctx_->gpu_id); + common::OptionalWeights weight{info.weights_.ConstDeviceSpan()}; + linalg::ElementWiseKernelDevice( + info.labels.View(ctx_->gpu_id), + [=] XGBOOST_DEVICE(size_t i, float y) { fn(i, y, predt, weight, sensitive, gpair); }); +#else + common::AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) + } else { + auto gpair = linalg::MakeVec(out_gpair->HostPointer(), info.num_row_); + auto predt = linalg::MakeVec(preds.ConstHostPointer(), preds.Size(), ctx_->gpu_id); + common::OptionalWeights weight{info.weights_.ConstHostSpan()}; + linalg::ElementWiseKernelHost( + info.labels.HostView(), ctx_->Threads(), + [=] XGBOOST_DEVICE(size_t i, float y) { fn(i, y, predt, weight, sensitive, gpair); }); + } + } + + void PredTransform(HostDeviceVector* io_preds) const override { + auto eps = kRtEps; // undefined in device code. + common::Transform<>::Init( + [eps] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + _preds[_idx] = common::Sigmoid(_preds[_idx]); + }, + common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(), + io_preds->DeviceIdx()) + .Eval(io_preds); + } + + float ProbToMargin(float base_score) const override { + return LogisticClassification::ProbToMargin(base_score); + } + + const char* DefaultEvalMetric() const override { return "regularized-logloss"; } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String("binary:regularized"); + out["binary_regularized_param"] = ToJson(param_); + } + + void LoadConfig(Json const& in) override { FromJson(in["binary_regularized_param"], ¶m_); } +}; + +XGBOOST_REGISTER_OBJECTIVE(FairClassification, "binary:regularized") + .describe("binary classification with fairness.") + .set_body([]() { return new RegularizedClassification(); }); + // declare parameter struct PoissonRegressionParam : public XGBoostParameter { float max_delta_step; @@ -608,4 +711,6 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") .set_body([]() { return new TweedieRegression(); }); } // namespace obj + +DMLC_REGISTER_PARAMETER(BinaryRegularizationParam); } // namespace xgboost diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index a4f3e6ab41fb..110f18fcb289 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -314,11 +314,11 @@ TEST(Linalg, Popc) { TEST(Linalg, Stack) { Tensor l{{2, 3, 4}, kCpuId}; - ElementWiseKernelHost(l.View(kCpuId), omp_get_max_threads(), - [=](size_t i, float v) { return i; }); + ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(), + [=](size_t i, float v) { return i; }); Tensor r_0{{2, 3, 4}, kCpuId}; - ElementWiseKernelHost(r_0.View(kCpuId), omp_get_max_threads(), - [=](size_t i, float v) { return i; }); + ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(), + [=](size_t i, float v) { return i; }); Stack(&l, r_0); diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 9ea6b22dd012..78f6a8c25e4f 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -19,7 +19,7 @@ void TestElementWiseKernel() { // GPU view auto t = l.View(0).Slice(linalg::All(), 1, 
linalg::All()); ASSERT_FALSE(t.CContiguous()); - ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; }); + ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; }); // CPU view t = l.View(GenericParameter::kCpuId).Slice(linalg::All(), 1, linalg::All()); size_t k = 0; @@ -30,10 +30,7 @@ void TestElementWiseKernel() { } t = l.View(0).Slice(linalg::All(), 1, linalg::All()); - ElementWiseKernelDevice(t, [] __device__(size_t i, float v) { - SPAN_CHECK(v == i); - return v; - }); + ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); }); } { /** * Contiguous */ auto t = l.View(0); - ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; }); + ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; }); ASSERT_TRUE(t.CContiguous()); // CPU view t = l.View(GenericParameter::kCpuId); diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h index f070e6f81870..bb86e16eaefa 100644 --- a/tests/cpp/data/test_metainfo.h +++ b/tests/cpp/data/test_metainfo.h @@ -29,14 +29,13 @@ inline void TestMetaInfoStridedData(int32_t device) { auto const& h_result = info.labels.View(-1); ASSERT_EQ(h_result.Shape().size(), 2); auto in_labels = labels.View(-1); - linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) { auto tup = linalg::UnravelIndex(i, h_result.Shape()); auto i0 = std::get<0>(tup); auto i1 = std::get<1>(tup); // Sliced at second dimension. 
auto v_1 = in_margin(i0, 0, i1); CHECK_EQ(v_0, v_1); - return v_0; }); } } diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 6f396ea76e62..115627ff133d 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -131,7 +131,6 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { std::unique_ptr obj { ObjFunction::Create("binary:logitraw", &lparam) }; - obj->Configure(args); CheckObjFunction(obj, @@ -355,6 +354,64 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { } } +TEST(Objective, DeclareUnifiedTest(RegularizedClassification)) { + GenericParameter lparam = CreateEmptyGenericParam(GPUIDX); + Args args{{"fairness", "0.0"}}; + std::unique_ptr obj{ObjFunction::Create("binary:regularized", &lparam)}; + + obj->Configure(args); + CheckConfigReload(obj, "binary:regularized"); + + MetaInfo info; + info.num_row_ = 16; + info.sensitive_features = linalg::Tensor{{info.num_row_}, GPUIDX}; + auto& h_sf = info.sensitive_features.Data()->HostVector(); + for (size_t i = 0; i < h_sf.size(); ++i) { + h_sf[i] = i % 2 == 0; + } + + info.labels = linalg::Tensor{{info.num_row_, static_cast(1)}, GPUIDX}; + auto& h_y = info.labels.Data()->HostVector(); + for (size_t i = 0; i < h_y.size(); ++i) { + h_y[i] = i % 2 != 0; + } + + HostDeviceVector predts; + predts.SetDevice(GPUIDX); + predts.Resize(info.num_row_); + auto& h_predts = predts.HostVector(); + for (size_t i = 0; i < h_y.size(); ++i) { + h_predts[i] = i % 2 != 0; + } + + HostDeviceVector reg_gpair; + obj->GetGradient(predts, info, 0, &reg_gpair); + auto const& h_reg = reg_gpair.ConstHostVector(); + + // fairness == 0 reduces to plain binary:logistic + std::unique_ptr logistic{ObjFunction::Create("binary:logistic", &lparam)}; + logistic->Configure({}); + HostDeviceVector logistic_gpair; + logistic->GetGradient(predts, info, 0, &logistic_gpair); + auto const& h_logistic = logistic_gpair.ConstHostVector(); + for (size_t i = 0; i < h_reg.size(); ++i) { + ASSERT_EQ(h_logistic[i], h_reg[i]); + } + + auto test_regularized = [&]() { + obj->Configure({{"fairness", "1.0"}}); + obj->GetGradient(predts, info, 0, &reg_gpair); + auto const& h_reg = reg_gpair.ConstHostVector(); + for (size_t i = 0; i < h_reg.size(); ++i) { + ASSERT_EQ(h_reg[i].GetHess(), 0.0f); + ASSERT_EQ(h_reg[i].GetGrad(), i % 2 == 0 ? 1.0 : -1.0); + } + }; + test_regularized(); + info.weights_.Resize(info.num_row_, 1.0); + test_regularized(); +} + // CoxRegression not implemented in GPU code, no need for testing. 
#if !defined(__CUDACC__) TEST(Objective, CoxRegressionGPair) { @@ -373,5 +430,4 @@ TEST(Objective, CoxRegressionGPair) { { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f}); } #endif - } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index f7e2215405c2..eaba41b6aa48 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -430,8 +430,8 @@ TEST(Learner, MultiTarget) { size_t constexpr kRows{128}, kCols{10}, kTargets{3}; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); m->Info().labels.Reshape(kRows, kTargets); - linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(), - [](auto i, auto) { return i; }); + linalg::ElementWiseTransformHost(m->Info().labels.HostView(), omp_get_max_threads(), + [](auto i, auto) { return i; }); { std::unique_ptr learner{Learner::Create({m})}; From 061ba22716310b526cfe1afe4dc56ef01b5a19b7 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 9 Feb 2022 23:18:58 +0800 Subject: [PATCH 02/20] Windows. --- tests/cpp/objective/test_regression_obj.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 115627ff133d..8639d24a394c 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -370,7 +370,7 @@ TEST(Objective, DeclareUnifiedTest(RegularizedClassification)) { h_sf[i] = i % 2 == 0; } - info.labels = linalg::Tensor{{info.num_row_, static_cast(1)}, GPUIDX}; + info.labels = linalg::Tensor{{info.num_row_, static_cast(1)}, GPUIDX}; auto& h_y = info.labels.Data()->HostVector(); for (size_t i = 0; i < h_y.size(); ++i) { h_y[i] = i % 2 != 0; From 24cf56232c4d81d998393d95fee7be5f8c7e4a42 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 9 Feb 2022 23:25:12 +0800 Subject: [PATCH 03/20] lint. --- python-package/xgboost/core.py | 2 +- python-package/xgboost/dask.py | 17 ++++++++++------- src/common/fair_param.h | 8 ++++++-- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 56c81fe4b73b..529e53a11e3c 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -501,7 +501,7 @@ def inner_f(*args, **kwargs): return inner_f -class DMatrix: # pylint: disable=too-many-instance-attributes +class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods """Data Matrix used in XGBoost. DMatrix is an internal data structure that is used by XGBoost, diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index b4e6e9165d81..3a1d12accf55 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -441,13 +441,16 @@ def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None: ) parts[name] = m_parts - append_meta(y_parts, "label") - append_meta(w_parts, "weight") - append_meta(margin_parts, "base_margin") - append_meta(qid_parts, "qid") - append_meta(ll_parts, "label_lower_bound") - append_meta(lu_parts, "label_upper_bound") - append_meta(sf_parts, 'sensitive_feature') + for p, n in [ + (y_parts, "label"), + (w_parts, "weight"), + (margin_parts, "base_margin"), + (qid_parts, "qid"), + (ll_parts, "label_lower_bound"), + (lu_parts, "label_upper_bound"), + (sf_parts, "sensitive_feature"), + ]: + append_meta(p, n) # At this point, `parts` looks like: # [(x0, x1, ..), (y0, y1, ..), ..] 
in delayed form diff --git a/src/common/fair_param.h b/src/common/fair_param.h index 1df22b49a1b2..05301bcecadc 100644 --- a/src/common/fair_param.h +++ b/src/common/fair_param.h @@ -1,3 +1,5 @@ +#ifndef XGBOOST_COMMON_FAIR_PARAM_H_ +#define XGBOOST_COMMON_FAIR_PARAM_H_ /*! * Copyright 2022 by XGBoost Contributors */ @@ -7,8 +9,10 @@ namespace xgboost { struct BinaryRegularizationParam : public XGBoostParameter { float fairness{0.0f}; DMLC_DECLARE_PARAMETER(BinaryRegularizationParam) { - DMLC_DECLARE_FIELD(fairness).set_range(0.0f, 1.0f).describe( - "The strength of the regularizer for fairness XGBoost."); + DMLC_DECLARE_FIELD(fairness) + .set_range(0.0f, 1.0f) + .describe("The strength of the regularizer for fairness XGBoost."); } }; } // namespace xgboost +#endif // XGBOOST_COMMON_FAIR_PARAM_H_ From 40423471d84aff13f8c394b39509cbdefde8bd85 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 9 Feb 2022 23:27:28 +0800 Subject: [PATCH 04/20] fix test. --- tests/cpp/objective/test_objective.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index fd110deb1f2e..68adfbf7011f 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -25,11 +25,13 @@ TEST(Objective, PredTransform) { tparam.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); size_t n = 100; - for (const auto &entry : - ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr obj{ - xgboost::ObjFunction::Create(entry->name, &tparam)}; - obj->Configure(Args{{"num_class", "2"}}); + for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { + std::unique_ptr obj{xgboost::ObjFunction::Create(entry->name, &tparam)}; + if (entry->name == "binary:regularized") { + obj->Configure(Args{{"num_class", "2"}, {"fairness", "0.5"}}); + } else { + obj->Configure(Args{{"num_class", "2"}}); + } HostDeviceVector predts; predts.Resize(n, 3.14f); // prediction is performed on host. ASSERT_FALSE(predts.DeviceCanRead()); From 7f66320ff574f3a5eb7aead45bd4ae1862d25560 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 00:12:05 +0800 Subject: [PATCH 05/20] Fixes. 
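Add the missing JSON schema entries for `binary_regularized_param` so saved models validate, and fix the attribute lookup in the dask data iterator: the iterator stores the list as `self._sensitive_feature`, so it has to be fetched as `self._get("_sensitive_feature")`.

For reference, a minimal sketch of the intended single-machine usage of the interface added by this series (synthetic data, and the `fairness` value is purely illustrative, not a recommendation):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 4))
    y = rng.integers(0, 2, size=256).astype(np.float32)
    # Binary sensitive attribute; like the labels, it must lie in [0, 1].
    s = rng.integers(0, 2, size=256).astype(np.float32)

    dtrain = xgb.DMatrix(X, label=y, sensitive_feature=s)
    booster = xgb.train(
        {"objective": "binary:regularized", "fairness": 0.5},
        dtrain,
        num_boost_round=8,
        evals=[(dtrain, "train")],  # reports regularized-logloss by default
    )

With `fairness=0.0` the objective reduces to plain `binary:logistic`; larger values trade predictive accuracy for independence from the sensitive attribute.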
--- doc/model.schema | 19 +++++++++++++++++++ python-package/xgboost/dask.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/doc/model.schema b/doc/model.schema index 86acea967f1e..82e92c1dfbbe 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -204,6 +204,14 @@ } } }, + "binary_regularized_param": { + "type": "object", + "properties": { + "fairness": { + "type": "number" + } + } + }, "aft_loss_param": { "type": "object", "properties": { @@ -378,6 +386,17 @@ "reg_loss_param" ] }, + { + "type": "object", + "properties": { + "name": { "const": "binary:regularized" }, + "binary_regularized_param": { "$ref": "#/definitions/binary_regularized_param"} + }, + "required": [ + "name", + "binary_regularized_param" + ] + }, { "type": "object", "properties": { diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 3a1d12accf55..6fc10eb49344 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -656,7 +656,7 @@ def next(self, input_data: Callable) -> int: base_margin=self._get("_base_margin"), label_lower_bound=self._get("_label_lower_bound"), label_upper_bound=self._get("_label_upper_bound"), - sensitive_feature=self._get("sensitive_feature"), + sensitive_feature=self._get("_sensitive_feature"), feature_names=feature_names, feature_types=self._feature_types, ) From 27506316101e3b7f241b3e0b9e1de847751e8591 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 01:53:26 +0800 Subject: [PATCH 06/20] Missing --- python-package/xgboost/dask.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 6fc10eb49344..e85a16f0047c 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -698,6 +698,7 @@ def __init__( qid: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, label_upper_bound: Optional[_DaskCollection] = None, + sensitive_feature: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None, enable_categorical: bool = False, ) -> None: @@ -711,6 +712,7 @@ def __init__( qid=qid, label_lower_bound=label_lower_bound, label_upper_bound=label_upper_bound, + sensitive_feature=sensitive_feature, missing=missing, silent=silent, feature_weights=feature_weights, From f3cf322b8593eafdc02b2063dc808c581c3b51ca Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 05:13:21 +0800 Subject: [PATCH 07/20] Validation. 
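Validate the inputs once at the start of training: both the labels and the sensitive features have to be in [0, 1]. Also dispatch the metric and objective kernels through the new IsCPU() helper and index the weights by sample id. As a cross-check for the C++ kernels, here is a small NumPy sketch of the closed-form gradient and hessian the objective implements, where the loss is logloss(y, p) - fairness * logloss(s, p) with p = sigmoid(margin); the helper name is ours, not part of the library:

    import numpy as np

    def regularized_grad_hess(margin, y, s, fairness):
        # Gradient/hessian of logloss(y, p) - fairness * logloss(s, p)
        # with respect to the raw margin, where p = sigmoid(margin).
        p = 1.0 / (1.0 + np.exp(-margin))
        grad = (p - y) + fairness * (s - p)
        hess = (1.0 - fairness) * p * (1.0 - p)
        return grad, hess

Note that the hessian vanishes as fairness approaches 1, which is why BinaryRegularizationParam caps the parameter range at [0, 1].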
--- include/xgboost/generic_parameters.h | 2 + include/xgboost/linalg.h | 12 +++++ src/metric/elementwise_metric.cu | 40 +++++++------- src/objective/regression_obj.cu | 80 ++++++++++++++++++++-------- 4 files changed, 92 insertions(+), 42 deletions(-) diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h index dc76b7c3a360..4a0188b23fb4 100644 --- a/include/xgboost/generic_parameters.h +++ b/include/xgboost/generic_parameters.h @@ -42,6 +42,8 @@ struct GenericParameter : public XGBoostParameter { */ int32_t Threads() const; + bool IsCPU() const { return gpu_id == kCpuId; } + // declare parameters DMLC_DECLARE_PARAMETER(GenericParameter) { DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe( diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 5caba4ae43a6..32d0f9fb9f9c 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -548,6 +548,18 @@ auto MakeVec(T *ptr, size_t s, int32_t device = -1) { return linalg::TensorView{{ptr, s}, {s}, device}; } +template +auto MakeVec(HostDeviceVector *data) { + return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(), + data->Size(), data->DeviceIdx()); +} + +template +auto MakeVec(HostDeviceVector const *data) { + return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(), + data->Size(), data->DeviceIdx()); +} + /** * \brief A view over a matrix, specialization of Tensor. * diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 2992a89d4404..9db4a0aa322b 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -228,20 +228,32 @@ class RegularizedLogLoss : public Metric { bool distributed) override { auto sensitive_features = info.sensitive_features.View(tparam_->gpu_id); auto labels = info.labels.View(tparam_->gpu_id); - auto predts = tparam_->gpu_id == GenericParameter::kCpuId ? preds.ConstHostSpan() - : preds.ConstDeviceSpan(); + auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); auto n_targets = std::max(info.labels.Shape(1), static_cast(1)); - common::OptionalWeights weights(tparam_->gpu_id == GenericParameter::kCpuId - ? info.weights_.ConstHostSpan() - : info.weights_.ConstDeviceSpan()); + common::OptionalWeights weights(tparam_->IsCPU() ? 
info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); float fairness = this->param_.fairness; + auto n_samples = info.num_row_; auto loss = [=] XGBOOST_DEVICE(size_t i, float wt) { auto v = (LogLoss(labels(i), predts[i]) * wt) - (fairness * LogLoss(sensitive_features(i), predts[i]) * wt); return v; }; PackedReduceResult result; - if (tparam_->gpu_id != GenericParameter::kCpuId) { + if (tparam_->IsCPU()) { + auto n_threads = tparam_->Threads(); + std::vector score_tloc(n_threads, 0.0); + std::vector weight_tloc(n_threads, 0.0); + common::ParallelFor(info.num_row_, tparam_->Threads(), [&](size_t i) { + float wt = weights[i / n_targets]; + auto t_idx = omp_get_thread_num(); + score_tloc[t_idx] += loss(i, wt); + weight_tloc[t_idx] += wt; + }); + double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); + double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); + result = PackedReduceResult{residue_sum, weights_sum}; + } else { #if defined(XGBOOST_USE_CUDA) dh::XGBCachingDeviceAllocator alloc; thrust::counting_iterator begin(0); @@ -249,7 +261,8 @@ class RegularizedLogLoss : public Metric { result = thrust::transform_reduce( thrust::cuda::par(alloc), begin, end, [=] XGBOOST_DEVICE(size_t idx) { - float weight = weights[idx / n_targets]; + auto sample_id = std::get<0>(linalg::UnravelIndex(idx, labels.Shape())); + float weight = weights[sample_id]; float l = loss(idx, weight); return PackedReduceResult{l, weight}; }, @@ -257,19 +270,6 @@ class RegularizedLogLoss : public Metric { #else common::AssertGPUSupport(); #endif // defined(XGBOOST_USE_CUDA) - } else { - auto n_threads = tparam_->Threads(); - std::vector score_tloc(n_threads, 0.0); - std::vector weight_tloc(n_threads, 0.0); - common::ParallelFor(info.num_row_, tparam_->Threads(), [&](size_t i) { - float wt = weights[i / n_targets]; - auto t_idx = omp_get_thread_num(); - score_tloc[t_idx] += loss(i, wt); - weight_tloc[t_idx] += wt; - }); - double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); - double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); - result = PackedReduceResult{residue_sum, weights_sum}; } double dat[2]{result.Residue(), result.Weights()}; diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 41f8cb511ae7..7f1ee4e68699 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -30,6 +30,17 @@ namespace xgboost { namespace obj { +namespace { +void CheckRegInputs(MetaInfo const& info, HostDeviceVector const& preds) { + CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels."; + CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of predictions."; + CHECK_EQ(info.labels.Shape(1), 1); + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), info.num_row_) + << "Number of weights should be equal to number of data points."; + } +} +} // anonymous namespace #if defined(XGBOOST_USE_CUDA) DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); @@ -211,8 +222,33 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") class RegularizedClassification : public ObjFunction { BinaryRegularizationParam param_; + void ValidateInfo(MetaInfo const& info) const { + HostDeviceVector flag(1, 0); + flag.SetDevice(ctx_->gpu_id); + auto vflag = ctx_->IsCPU() ? 
flag.HostSpan() : flag.DeviceSpan(); + auto sensitive_feat = info.sensitive_features.View(ctx_->gpu_id); + auto check = [=](size_t i, float y) { + if (!LogisticClassification::CheckLabel(y)) { + vflag[0] = 1; + } + if (!LogisticClassification::CheckLabel(sensitive_feat(i))) { + vflag[0] = 1; + } + }; + if (ctx_->IsCPU()) { + linalg::ElementWiseKernelHost(info.labels.HostView(), ctx_->Threads(), check); + } else { + linalg::ElementWiseKernelDevice(info.labels.HostView(), check); + } + if (flag.HostVector()[0] == 1) { + LOG(FATAL) << LogisticClassification::LabelErrorMsg() + << " and sensitive feature must be in [0, 1] for regularized logistic."; + } + } + public: void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + struct ObjInfo Task() const override { return {ObjInfo::kBinary, false}; } @@ -221,21 +257,21 @@ class RegularizedClassification : public ObjFunction { return std::max(static_cast(1), info.labels.Shape(1)); } - void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int, + void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int iter, HostDeviceVector* out_gpair) override { CHECK_EQ(info.sensitive_features.Size(), info.num_row_) << "Incorrect shape of sensitive features, Expecting: (" << info.num_row_ << "), got: (" << info.sensitive_features.Size() << ")"; - CHECK_EQ(info.labels.Shape(0), info.num_row_); - CHECK_EQ(info.labels.Shape(1), 1); + CheckRegInputs(info, preds); + if (iter == 0) { + // Unlike other objectives, no validation during gradient calculation. + this->ValidateInfo(info); + } auto fairness = param_.fairness; - if (!info.weights_.Empty()) { - CHECK_EQ(info.weights_.Size(), info.num_row_) - << "Number of weights should be equal to number of data points."; - } + auto labels = info.labels.View(ctx_->gpu_id); - auto fn = [fairness] XGBOOST_DEVICE( + auto fn = [fairness, labels] XGBOOST_DEVICE( size_t i, float y, linalg::TensorView predt_t, common::OptionalWeights weight, linalg::TensorView sensitive, linalg::TensorView gpair) { @@ -243,33 +279,33 @@ class RegularizedClassification : public ObjFunction { auto sf = sensitive(i); auto grad = (predt - y) + (fairness * (sf - predt)); auto hess = (1.0f - fairness) * predt * (1.0f - predt); - auto w = weight[i]; + auto w = weight[std::get<1>(linalg::UnravelIndex(i, labels.Shape()))]; gpair(i) = {grad * w, hess * w}; }; auto sensitive = info.sensitive_features.View(ctx_->gpu_id); out_gpair->SetDevice(ctx_->gpu_id); out_gpair->Resize(info.num_row_); + auto gpair = linalg::MakeVec(out_gpair); + preds.SetDevice(ctx_->gpu_id); + auto predt = linalg::MakeVec(&preds); info.weights_.SetDevice(ctx_->gpu_id); - if (ctx_->gpu_id != GenericParameter::kCpuId) { + + if (ctx_->IsCPU()) { + common::OptionalWeights weight{info.weights_.ConstHostSpan()}; + linalg::ElementWiseKernelHost(labels, ctx_->Threads(), [=] XGBOOST_DEVICE(size_t i, float y) { + fn(i, y, predt, weight, sensitive, gpair); + }); + } else { #if defined(XGBOOST_USE_CUDA) - auto gpair = linalg::MakeVec(out_gpair->DevicePointer(), info.num_row_, ctx_->gpu_id); - auto predt = linalg::MakeVec(preds.ConstDevicePointer(), preds.Size(), ctx_->gpu_id); common::OptionalWeights weight{info.weights_.ConstDeviceSpan()}; - linalg::ElementWiseKernelDevice( - info.labels.View(ctx_->gpu_id), - [=] XGBOOST_DEVICE(size_t i, float y) { fn(i, y, predt, weight, sensitive, gpair); }); + linalg::ElementWiseKernelDevice(labels, [=] XGBOOST_DEVICE(size_t i, float y) { + fn(i, y, predt, weight, sensitive, gpair); + }); #else 
common::AssertGPUSupport(); #endif // defined(XGBOOST_USE_CUDA) - } else { - auto gpair = linalg::MakeVec(out_gpair->HostPointer(), info.num_row_); - auto predt = linalg::MakeVec(preds.ConstHostPointer(), preds.Size(), ctx_->gpu_id); - common::OptionalWeights weight{info.weights_.ConstHostSpan()}; - linalg::ElementWiseKernelHost( - info.labels.HostView(), ctx_->Threads(), - [=] XGBOOST_DEVICE(size_t i, float y) { fn(i, y, predt, weight, sensitive, gpair); }); } } From 7cf0d0c021e86654c2ca7ee2138105935ce90775 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 05:28:45 +0800 Subject: [PATCH 08/20] Abstract away the dispatch. --- src/common/linalg_op.cuh | 10 ++++++++- src/common/linalg_op.h | 22 ++++++++++++++++++++ src/objective/regression_obj.cu | 37 +++++++++------------------------ 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index bdbee205b3ca..c4557a0b57ad 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -1,9 +1,12 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_COMMON_LINALG_OP_CUH_ #define XGBOOST_COMMON_LINALG_OP_CUH_ + +#include "xgboost/generic_parameters.h" #include "device_helpers.cuh" +#include "linalg_op.h" #include "xgboost/linalg.h" namespace xgboost { @@ -35,6 +38,11 @@ void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_ }); } } + +template +void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView t, Fn&& fn) { + ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn); +} } // namespace linalg } // namespace xgboost #endif // XGBOOST_COMMON_LINALG_OP_CUH_ diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index 71c9991fbf00..05f050772ccc 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -5,7 +5,9 @@ #define XGBOOST_COMMON_LINALG_OP_H_ #include +#include "common.h" #include "threading_utils.h" +#include "xgboost/generic_parameters.h" #include "xgboost/linalg.h" namespace xgboost { @@ -37,6 +39,26 @@ void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& f }); } } + +#if !defined(XGBOOST_USE_CUDA) +template +void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, void* s = nullptr) { + common::AssertGPUSupport(); +} + +template +void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, void* s = nullptr) { + common::AssertGPUSupport(); +} + +template +void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView t, Fn&& fn) { + if (!ctx->IsCPU()) { + common::AssertGPUSupport(); + } + ElementWiseKernelHost(t, ctx->Threads(), fn); +} +#endif // !defined(XGBOOST_USE_CUDA) } // namespace linalg } // namespace xgboost #endif // XGBOOST_COMMON_LINALG_OP_H_ diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 7f1ee4e68699..6baf950348b9 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -271,18 +271,6 @@ class RegularizedClassification : public ObjFunction { auto fairness = param_.fairness; auto labels = info.labels.View(ctx_->gpu_id); - auto fn = [fairness, labels] XGBOOST_DEVICE( - size_t i, float y, linalg::TensorView predt_t, - common::OptionalWeights weight, linalg::TensorView sensitive, - linalg::TensorView gpair) { - auto predt = common::Sigmoid(predt_t(i)); - auto sf = sensitive(i); - auto grad = (predt - y) + (fairness * (sf - predt)); - auto hess = (1.0f - fairness) * predt * (1.0f - predt); - auto w = 
weight[std::get<1>(linalg::UnravelIndex(i, labels.Shape()))]; - gpair(i) = {grad * w, hess * w}; - }; - auto sensitive = info.sensitive_features.View(ctx_->gpu_id); out_gpair->SetDevice(ctx_->gpu_id); out_gpair->Resize(info.num_row_); @@ -291,22 +279,17 @@ class RegularizedClassification : public ObjFunction { preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); info.weights_.SetDevice(ctx_->gpu_id); + common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()}; - if (ctx_->IsCPU()) { - common::OptionalWeights weight{info.weights_.ConstHostSpan()}; - linalg::ElementWiseKernelHost(labels, ctx_->Threads(), [=] XGBOOST_DEVICE(size_t i, float y) { - fn(i, y, predt, weight, sensitive, gpair); - }); - } else { -#if defined(XGBOOST_USE_CUDA) - common::OptionalWeights weight{info.weights_.ConstDeviceSpan()}; - linalg::ElementWiseKernelDevice(labels, [=] XGBOOST_DEVICE(size_t i, float y) { - fn(i, y, predt, weight, sensitive, gpair); - }); -#else - common::AssertGPUSupport(); -#endif // defined(XGBOOST_USE_CUDA) - } + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float y) mutable { + auto p = common::Sigmoid(predt(i)); + auto sf = sensitive(i); + auto grad = (p - y) + (fairness * (sf - p)); + auto hess = (1.0f - fairness) * p * (1.0f - p); + auto w = weight[std::get<1>(linalg::UnravelIndex(i, labels.Shape()))]; + gpair(i) = {grad * w, hess * w}; + }); } void PredTransform(HostDeviceVector* io_preds) const override { From c72c49988fd528ef6b6e0381cc350bd965ccf036 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 05:30:50 +0800 Subject: [PATCH 09/20] fix. --- src/metric/elementwise_metric.cu | 3 ++- src/objective/regression_obj.cu | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 9db4a0aa322b..a57aafd50b64 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -245,7 +245,8 @@ class RegularizedLogLoss : public Metric { std::vector score_tloc(n_threads, 0.0); std::vector weight_tloc(n_threads, 0.0); common::ParallelFor(info.num_row_, tparam_->Threads(), [&](size_t i) { - float wt = weights[i / n_targets]; + auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); + float wt = weights[sample_id]; auto t_idx = omp_get_thread_num(); score_tloc[t_idx] += loss(i, weights[i]); weight_tloc[t_idx] += wt; diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 6baf950348b9..3b1821ded76d 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -287,7 +287,7 @@ class RegularizedClassification : public ObjFunction { auto sf = sensitive(i); auto grad = (p - y) + (fairness * (sf - p)); auto hess = (1.0f - fairness) * p * (1.0f - p); - auto w = weight[std::get<1>(linalg::UnravelIndex(i, labels.Shape()))]; + auto w = weight[std::get<0>(linalg::UnravelIndex(i, labels.Shape()))]; gpair(i) = {grad * w, hess * w}; }); } From c66dbb07003d3176a16dc8394164953963301b2e Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 05:38:53 +0800 Subject: [PATCH 10/20] cleanups. 
--- src/common/linalg_op.cuh | 2 +- src/objective/regression_obj.cu | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index c4557a0b57ad..f0f89df8ab26 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -17,7 +17,7 @@ void ElementWiseKernelDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s "For function with return, use transform instead."); if (t.Contiguous()) { auto ptr = t.Values().data(); - dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { fn(i, ptr[i]); }); + dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); }); } else { dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 3b1821ded76d..22acc7c8fccf 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -222,24 +222,21 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") class RegularizedClassification : public ObjFunction { BinaryRegularizationParam param_; + public: void ValidateInfo(MetaInfo const& info) const { HostDeviceVector flag(1, 0); flag.SetDevice(ctx_->gpu_id); auto vflag = ctx_->IsCPU() ? flag.HostSpan() : flag.DeviceSpan(); - auto sensitive_feat = info.sensitive_features.View(ctx_->gpu_id); - auto check = [=](size_t i, float y) { + auto sensitive = info.sensitive_features.View(ctx_->gpu_id); + auto labels = info.labels.View(ctx_->gpu_id); + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float y) { if (!LogisticClassification::CheckLabel(y)) { vflag[0] = 1; } - if (!LogisticClassification::CheckLabel(sensitive_feat(i))) { + if (!LogisticClassification::CheckLabel(sensitive(i))) { vflag[0] = 1; } - }; - if (ctx_->IsCPU()) { - linalg::ElementWiseKernelHost(info.labels.HostView(), ctx_->Threads(), check); - } else { - linalg::ElementWiseKernelDevice(info.labels.HostView(), check); - } + }); if (flag.HostVector()[0] == 1) { LOG(FATAL) << LogisticClassification::LabelErrorMsg() << " and sensitive feature must be in [0, 1] for regularized logistic."; @@ -282,7 +279,7 @@ class RegularizedClassification : public ObjFunction { common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()}; - linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float y) mutable { + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { auto p = common::Sigmoid(predt(i)); auto sf = sensitive(i); auto grad = (p - y) + (fairness * (sf - p)); From 8c5b5f795abe598a0d97432b0e7ea2817666ce43 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 05:48:53 +0800 Subject: [PATCH 11/20] cleanup. --- src/metric/elementwise_metric.cu | 2 +- src/objective/regression_obj.cu | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index a57aafd50b64..ba3c1f16536f 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -193,7 +193,7 @@ namespace { XGBOOST_DEVICE inline float LogLoss(float y, float py) { auto xlogy = [](float x, float y) { float eps = 1e-16; - return (x - 0.0 == 0.0) ? 0 : (x * std::log(std::max(y, eps))); + return (x - 0.0f == 0.0f) ? 
0.0f : (x * std::log(std::max(y, eps))); }; const bst_float pneg = 1.0f - py; return xlogy(-y, py) + xlogy(-(1.0f - y), pneg); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 22acc7c8fccf..7e2e5663361c 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -290,9 +290,8 @@ class RegularizedClassification : public ObjFunction { } void PredTransform(HostDeviceVector* io_preds) const override { - auto eps = kRtEps; // undefined in device code. common::Transform<>::Init( - [eps] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { + [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = common::Sigmoid(_preds[_idx]); }, common::Range{0, static_cast(io_preds->Size())}, this->ctx_->Threads(), From b9af6cd05842aedf7c97a428d2120165207af46d Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 06:26:10 +0800 Subject: [PATCH 12/20] tidy. --- src/metric/elementwise_metric.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index ba3c1f16536f..4770984535fb 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -229,11 +229,9 @@ class RegularizedLogLoss : public Metric { auto sensitive_features = info.sensitive_features.View(tparam_->gpu_id); auto labels = info.labels.View(tparam_->gpu_id); auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); - auto n_targets = std::max(info.labels.Shape(1), static_cast(1)); common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); float fairness = this->param_.fairness; - auto n_samples = info.num_row_; auto loss = [=] XGBOOST_DEVICE(size_t i, float wt) { auto v = (LogLoss(labels(i), predts[i]) * wt) - (fairness * LogLoss(sensitive_features(i), predts[i]) * wt); From 0b5f3aafb9799574c566d2aa81b0f87fbdaa65ee Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 18:49:43 +0800 Subject: [PATCH 13/20] tests. --- src/metric/elementwise_metric.cu | 44 ++++++++++------- src/objective/regression_obj.cu | 2 +- tests/cpp/metric/test_elementwise_metric.cc | 53 +++++++++++++++++++-- 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 4770984535fb..3da7622c6cc4 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2019 by Contributors + * Copyright 2015-2022 by XGBoost Contributors * \file elementwise_metric.cc * \brief evaluation metrics for elementwise binary or regression. * \author Kailong Chen, Tianqi Chen @@ -227,28 +227,39 @@ class RegularizedLogLoss : public Metric { double Eval(const HostDeviceVector& preds, const MetaInfo& info, bool distributed) override { auto sensitive_features = info.sensitive_features.View(tparam_->gpu_id); + CHECK_EQ(info.labels.Shape(0), info.num_row_); auto labels = info.labels.View(tparam_->gpu_id); + preds.SetDevice(tparam_->gpu_id); auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); + info.weights_.SetDevice(tparam_->gpu_id); common::OptionalWeights weights(tparam_->IsCPU() ? 
info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); float fairness = this->param_.fairness; - auto loss = [=] XGBOOST_DEVICE(size_t i, float wt) { - auto v = (LogLoss(labels(i), predts[i]) * wt) - - (fairness * LogLoss(sensitive_features(i), predts[i]) * wt); - return v; + auto loss = [=] XGBOOST_DEVICE(size_t i) { + size_t sample_id; + size_t target_id; + std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape()); + float wt = weights[sample_id]; + auto sf = sensitive_features(sample_id); + + auto logloss = (LogLoss(labels(sample_id, target_id), predts[i]) * wt); + auto reg = (fairness * LogLoss(sf, predts[i]) * wt); + auto v = logloss - reg; + return std::make_tuple(v, wt); }; PackedReduceResult result; if (tparam_->IsCPU()) { auto n_threads = tparam_->Threads(); std::vector score_tloc(n_threads, 0.0); std::vector weight_tloc(n_threads, 0.0); - common::ParallelFor(info.num_row_, tparam_->Threads(), [&](size_t i) { - auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); - float wt = weights[sample_id]; - auto t_idx = omp_get_thread_num(); - score_tloc[t_idx] += loss(i, weights[i]); - weight_tloc[t_idx] += wt; - }); + common::ParallelFor(info.labels.Shape(0) * info.labels.Shape(1), tparam_->Threads(), + [&](size_t i) { + auto t_idx = omp_get_thread_num(); + float v, wt; + std::tie(v, wt) = loss(i); + score_tloc[t_idx] += v; + weight_tloc[t_idx] += wt; + }); double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); result = PackedReduceResult{residue_sum, weights_sum}; @@ -259,11 +270,10 @@ class RegularizedLogLoss : public Metric { thrust::counting_iterator end = begin + info.num_row_; result = thrust::transform_reduce( thrust::cuda::par(alloc), begin, end, - [=] XGBOOST_DEVICE(size_t idx) { - auto sample_id = std::get<0>(linalg::UnravelIndex(idx, labels.Shape())); - float weight = weights[sample_id]; - float l = loss(idx, weight); - return PackedReduceResult{l, weight}; + [=] XGBOOST_DEVICE(size_t i) { + float v, wt; + std::tie(v, wt) = loss(i); + return PackedReduceResult{v, wt}; }, PackedReduceResult{}, thrust::plus()); #else diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 7e2e5663361c..7ea6635e3bce 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -257,7 +257,7 @@ class RegularizedClassification : public ObjFunction { void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int iter, HostDeviceVector* out_gpair) override { CHECK_EQ(info.sensitive_features.Size(), info.num_row_) - << "Incorrect shape of sensitive features, Expecting: (" << info.num_row_ << "), got: (" + << "Invalid shape of sensitive features, Expecting: (" << info.num_row_ << "), got: (" << info.sensitive_features.Size() << ")"; CheckRegInputs(info, preds); if (iter == 0) { diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index 514b8753ccad..f843c28b59d3 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -1,12 +1,13 @@ /*! 
- * Copyright 2018-2019 XGBoost contributors + * Copyright 2018-2022 by XGBoost contributors */ -#include #include +#include #include #include +#include "../../../src/common/linalg_op.h" #include "../helpers.h" namespace xgboost { @@ -288,8 +289,8 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { HostDeviceVector predt(n_samples * n_targets, 0); - auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); - std::unique_ptr metric{Metric::Create("rmse", &lparam)}; + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + std::unique_ptr metric{Metric::Create("rmse", &ctx)}; metric->Configure({}); auto loss = GetMultiMetricEval(metric.get(), predt, y); @@ -301,5 +302,49 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { ASSERT_FLOAT_EQ(ret, loss); ASSERT_FLOAT_EQ(ret, loss_w); } + +TEST(Metric, DeclareUnifiedTest(RegularizedLogLoss)) { + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + + size_t n_samples = 8, n_targets = 2; + MetaInfo info; + info.num_row_ = n_samples; + float fairness = 0.0; + + std::unique_ptr metric{Metric::Create("regularized-logloss", &ctx)}; + metric->Configure(Args{{"fairness", std::to_string(fairness)}}); + + info.labels = linalg::Tensor{{n_samples, n_targets}, GPUIDX}; + auto h_y = info.labels.HostView(); + info.sensitive_features = linalg::Tensor{{n_samples}, GPUIDX}; + auto h_s = info.sensitive_features.HostView(); + + HostDeviceVector predt(n_samples * n_targets, 0); + linalg::ElementWiseTransformHost( + h_y, ctx.Threads(), [](size_t i, float) { return static_cast(i % 2 == 0 ? 0 : 1); }); + linalg::ElementWiseTransformHost( + h_s, ctx.Threads(), [](size_t i, float) { return static_cast(i % 2 == 1 ? 0 : 1); }); + + double regularized_res{metric->Eval(predt, info, false)}; + auto check = [&]() { + std::unique_ptr logloss{Metric::Create("logloss", &ctx)}; + logloss->Configure({}); + auto res = logloss->Eval(predt, info, false); + ASSERT_EQ(res, regularized_res); + }; + check(); + + info.weights_.Resize(n_samples); + auto &h_w = info.weights_.HostVector(); + std::iota(h_w.begin(), h_w.end(), 0.0f); + + regularized_res = metric->Eval(predt, info, false); + check(); + + fairness = 0.5; + metric->Configure(Args{{"fairness", std::to_string(fairness)}}); + regularized_res = metric->Eval(predt, info, false); + ASSERT_FLOAT_EQ(regularized_res, 10.5261f); +} } // namespace metric } // namespace xgboost From 03ace45866c49138390ce1b516196efc32cebd9b Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 19:19:04 +0800 Subject: [PATCH 14/20] Refactor the elementwise metric. --- src/metric/elementwise_metric.cu | 214 ++++++++++--------------------- 1 file changed, 71 insertions(+), 143 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 3da7622c6cc4..044fa4949fb9 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -32,109 +32,63 @@ namespace metric { // tag the this file, used by force static link later. 
DMLC_REGISTRY_FILE_TAG(elementwise_metric); -template -class ElementWiseMetricsReduction { - public: - explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {} - - PackedReduceResult - CpuReduceMetrics(const HostDeviceVector &weights, - linalg::TensorView labels, - const HostDeviceVector &preds, - int32_t n_threads) const { - size_t ndata = labels.Size(); - auto n_targets = std::max(labels.Shape(1), static_cast(1)); - auto h_labels = labels.Values(); - - const auto& h_weights = weights.HostVector(); - const auto& h_preds = preds.HostVector(); - +namespace { +/** + * \brief Reduce function for element wise metrics. + * + * The loss function should handle all the computation for each sample, including + * applying the weights. A tuple of {error_i, weight_i} is expected as return. + */ +template +PackedReduceResult Reduce(GenericParameter const* ctx, MetaInfo const& info, Fn&& loss) { + PackedReduceResult result; + auto labels = info.labels.View(ctx->gpu_id); + if (ctx->IsCPU()) { + auto n_threads = ctx->Threads(); std::vector score_tloc(n_threads, 0.0); std::vector weight_tloc(n_threads, 0.0); - // We sum over losses over all samples and targets instead of performing this for each // target since the first one approach more accurate while the second approach is used // for approximation in distributed setting. For rmse: // - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm)) // multi-target // - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm) // distributed - common::ParallelFor(ndata, n_threads, [&](size_t i) { - float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f; + common::ParallelFor(info.labels.Size(), ctx->Threads(), [&](size_t i) { auto t_idx = omp_get_thread_num(); - score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt; + size_t sample_id; + size_t target_id; + std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape()); + + float v, wt; + std::tie(v, wt) = loss(i, sample_id, target_id); + score_tloc[t_idx] += v; weight_tloc[t_idx] += wt; }); double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); - - PackedReduceResult res { residue_sum, weights_sum }; - return res; - } - + result = PackedReduceResult{residue_sum, weights_sum}; + } else { #if defined(XGBOOST_USE_CUDA) - - PackedReduceResult DeviceReduceMetrics( - const HostDeviceVector& weights, - linalg::TensorView labels, - const HostDeviceVector& preds) { - size_t n_data = preds.Size(); - auto n_targets = std::max(labels.Shape(1), static_cast(1)); - - thrust::counting_iterator begin(0); - thrust::counting_iterator end = begin + n_data; - - auto s_label = labels.Values(); - auto s_preds = preds.DeviceSpan(); - auto s_weights = weights.DeviceSpan(); - - bool const is_null_weight = weights.Size() == 0; - - auto d_policy = policy_; - dh::XGBCachingDeviceAllocator alloc; - PackedReduceResult result = thrust::transform_reduce( - thrust::cuda::par(alloc), - begin, end, - [=] XGBOOST_DEVICE(size_t idx) { - float weight = is_null_weight ? 
1.0f : s_weights[idx / n_targets]; - - float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]); - residue *= weight; - return PackedReduceResult{ residue, weight }; + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + labels.Size(); + result = thrust::transform_reduce( + thrust::cuda::par(alloc), begin, end, + [=] XGBOOST_DEVICE(size_t i) { + auto idx = linalg::UnravelIndex(i, labels.Shape()); + auto sample_id = std::get<0>(idx); + auto target_id = std::get<1>(idx); + auto res = loss(i, sample_id, target_id); + float v{std::get<0>(res)}, wt{std::get<1>(res)}; + return PackedReduceResult{v, wt}; }, - PackedReduceResult(), - thrust::plus()); - - return result; - } - -#endif // XGBOOST_USE_CUDA - - PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector& weights, - linalg::Tensor const& labels, - const HostDeviceVector& preds) { - PackedReduceResult result; - - if (ctx.gpu_id < 0) { - auto n_threads = ctx.Threads(); - result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads); - } -#if defined(XGBOOST_USE_CUDA) - else { // NOLINT - preds.SetDevice(ctx.gpu_id); - weights.SetDevice(ctx.gpu_id); - - dh::safe_cuda(cudaSetDevice(ctx.gpu_id)); - result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds); - } -#endif // defined(XGBOOST_USE_CUDA) - return result; + PackedReduceResult{}, thrust::plus()); +#else + common::AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) } - - private: - EvalRow policy_; -#if defined(XGBOOST_USE_CUDA) -#endif // defined(XGBOOST_USE_CUDA) -}; + return result; +} +} // anonymous namespace struct EvalRowRMSE { char const *Name() const { @@ -235,52 +189,16 @@ class RegularizedLogLoss : public Metric { common::OptionalWeights weights(tparam_->IsCPU() ? 
info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()); float fairness = this->param_.fairness; - auto loss = [=] XGBOOST_DEVICE(size_t i) { - size_t sample_id; - size_t target_id; - std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape()); - float wt = weights[sample_id]; - auto sf = sensitive_features(sample_id); - - auto logloss = (LogLoss(labels(sample_id, target_id), predts[i]) * wt); - auto reg = (fairness * LogLoss(sf, predts[i]) * wt); - auto v = logloss - reg; - return std::make_tuple(v, wt); - }; - PackedReduceResult result; - if (tparam_->IsCPU()) { - auto n_threads = tparam_->Threads(); - std::vector score_tloc(n_threads, 0.0); - std::vector weight_tloc(n_threads, 0.0); - common::ParallelFor(info.labels.Shape(0) * info.labels.Shape(1), tparam_->Threads(), - [&](size_t i) { - auto t_idx = omp_get_thread_num(); - float v, wt; - std::tie(v, wt) = loss(i); - score_tloc[t_idx] += v; - weight_tloc[t_idx] += wt; - }); - double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); - double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); - result = PackedReduceResult{residue_sum, weights_sum}; - } else { -#if defined(XGBOOST_USE_CUDA) - dh::XGBCachingDeviceAllocator alloc; - thrust::counting_iterator begin(0); - thrust::counting_iterator end = begin + info.num_row_; - result = thrust::transform_reduce( - thrust::cuda::par(alloc), begin, end, - [=] XGBOOST_DEVICE(size_t i) { - float v, wt; - std::tie(v, wt) = loss(i); - return PackedReduceResult{v, wt}; - }, - PackedReduceResult{}, thrust::plus()); -#else - common::AssertGPUSupport(); -#endif // defined(XGBOOST_USE_CUDA) - } - + PackedReduceResult result = + Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) { + float wt = weights[sample_id]; + auto sf = sensitive_features(sample_id); + + auto logloss = (LogLoss(labels(sample_id, target_id), predts[i]) * wt); + auto reg = (fairness * LogLoss(sf, predts[i]) * wt); + auto v = logloss - reg; + return std::make_tuple(v, wt); + }); double dat[2]{result.Residue(), result.Weights()}; if (distributed) { rabit::Allreduce(dat, 2); @@ -435,20 +353,33 @@ struct EvalTweedieNLogLik { * \brief base class of element-wise evaluation * \tparam Derived the name of subclass */ -template +template struct EvalEWiseBase : public Metric { EvalEWiseBase() = default; - explicit EvalEWiseBase(char const* policy_param) : - policy_{policy_param}, reducer_{policy_} {} + explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param} {} - double Eval(const HostDeviceVector &preds, const MetaInfo &info, + double Eval(HostDeviceVector const& preds, const MetaInfo& info, bool distributed) override { CHECK_EQ(preds.Size(), info.labels.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds); + auto labels = info.labels.View(tparam_->gpu_id); + info.weights_.SetDevice(tparam_->gpu_id); + common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()); + preds.SetDevice(tparam_->gpu_id); + auto predts = tparam_->IsCPU() ? 
preds.ConstHostSpan() : preds.ConstDeviceSpan(); - double dat[2] { result.Residue(), result.Weights() }; + auto d_policy = policy_; + auto result = + Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) { + float wt = weights[sample_id]; + float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]); + residue *= wt; + return std::make_tuple(residue, wt); + }); + + double dat[2]{result.Residue(), result.Weights()}; if (distributed) { rabit::Allreduce(dat, 2); @@ -456,13 +387,10 @@ struct EvalEWiseBase : public Metric { return Policy::GetFinal(dat[0], dat[1]); } - const char* Name() const override { - return policy_.Name(); - } + const char* Name() const override { return policy_.Name(); } private: Policy policy_; - ElementWiseMetricsReduction reducer_{policy_}; }; XGBOOST_REGISTER_METRIC(RMSE, "rmse") From a0fcb46cdfd650b0a8ae1763397b308387c38baf Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 19:55:12 +0800 Subject: [PATCH 15/20] Check shape. --- src/metric/elementwise_metric.cu | 1 + tests/cpp/helpers.cc | 4 ++-- tests/cpp/metric/test_elementwise_metric.cc | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 044fa4949fb9..2132ee312634 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -363,6 +363,7 @@ struct EvalEWiseBase : public Metric { CHECK_EQ(preds.Size(), info.labels.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; + CHECK_NE(info.labels.Shape(1), 0); auto labels = info.labels.View(tparam_->gpu_id); info.weights_.SetDevice(tparam_->gpu_id); common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index fe32a0593792..66672047dd11 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -136,8 +136,8 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = - xgboost::linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; + info.labels = xgboost::linalg::Tensor{ + labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; info.weights_.HostVector() = weights; info.group_ptr_ = groups; diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index f843c28b59d3..1ac6fc9eb01f 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -17,14 +17,16 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) std::unique_ptr metric{Metric::Create(name.c_str(), &lparam)}; HostDeviceVector predts; + size_t n_samples = 2048; + MetaInfo info; + info.labels.Reshape(n_samples, 1); auto &h_labels = info.labels.Data()->HostVector(); auto &h_predts = predts.HostVector(); SimpleLCG lcg; SimpleRealUniformDistribution dist{0.0f, 1.0f}; - size_t n_samples = 2048; h_labels.resize(n_samples); h_predts.resize(n_samples); From 3bf3cb5531ed87c0926b21eb75719b6996bab48a Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 20:06:54 +0800 Subject: [PATCH 16/20] cleanup. 
--- tests/cpp/common/test_linalg.cu | 4 +--- tests/cpp/helpers.cc | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 78f6a8c25e4f..ae0eb28a70cd 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #include @@ -40,8 +40,6 @@ void TestElementWiseKernel() { auto t = l.View(0); ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; }); ASSERT_TRUE(t.CContiguous()); - ; - ; // CPU view t = l.View(GenericParameter::kCpuId); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 66672047dd11..05c138781e0d 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2016-2020 XGBoost contributors + * Copyright 2016-2022 by XGBoost contributors */ #include #include From 07bd80bfe2845ae540acd7b8e57de63391a55273 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 10 Feb 2022 20:40:42 +0800 Subject: [PATCH 17/20] check. --- src/metric/elementwise_metric.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 2132ee312634..b8f8ad3bdc9b 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -363,7 +363,9 @@ struct EvalEWiseBase : public Metric { CHECK_EQ(preds.Size(), info.labels.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - CHECK_NE(info.labels.Shape(1), 0); + if (info.labels.Size() != 0) { + CHECK_NE(info.labels.Shape(1), 0); + } auto labels = info.labels.View(tparam_->gpu_id); info.weights_.SetDevice(tparam_->gpu_id); common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() From d7bb90d0c11f2f6b84d71c7d824f8c28441e3ad6 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 11 Feb 2022 17:44:11 +0800 Subject: [PATCH 18/20] multi target. 
--- src/objective/regression_obj.cu | 22 ++-- tests/cpp/objective/test_regression_obj.cc | 135 +++++++++++++++------ 2 files changed, 106 insertions(+), 51 deletions(-) diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 7ea6635e3bce..c46637b818a9 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -34,7 +34,6 @@ namespace { void CheckRegInputs(MetaInfo const& info, HostDeviceVector const& preds) { CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels."; CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels."; - CHECK_EQ(info.labels.Shape(1), 1); if (!info.weights_.Empty()) { CHECK_EQ(info.weights_.Size(), info.num_row_) << "Number of weights should be equal to number of data points."; @@ -80,20 +79,13 @@ class RegLossObj : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels.Size()) - << " " << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " - << "Loss: " << Loss::Name(); + CheckRegInputs(info, preds); size_t const ndata = preds.Size(); out_gpair->Resize(ndata); auto device = ctx_->gpu_id; additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag bool is_null_weight = info.weights_.Size() == 0; - if (!is_null_weight) { - CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) - << "Number of weights should be equal to number of data points."; - } auto scale_pos_weight = param_.scale_pos_weight; additional_input_.HostVector().begin()[1] = scale_pos_weight; additional_input_.HostVector().begin()[2] = is_null_weight; @@ -233,7 +225,8 @@ class RegularizedClassification : public ObjFunction { if (!LogisticClassification::CheckLabel(y)) { vflag[0] = 1; } - if (!LogisticClassification::CheckLabel(sensitive(i))) { + auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); + if (!LogisticClassification::CheckLabel(sensitive(sample_id))) { vflag[0] = 1; } }); @@ -269,22 +262,25 @@ class RegularizedClassification : public ObjFunction { auto labels = info.labels.View(ctx_->gpu_id); auto sensitive = info.sensitive_features.View(ctx_->gpu_id); + out_gpair->SetDevice(ctx_->gpu_id); - out_gpair->Resize(info.num_row_); + out_gpair->Resize(info.labels.Size()); auto gpair = linalg::MakeVec(out_gpair); preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); + info.weights_.SetDevice(ctx_->gpu_id); common::OptionalWeights weight{ctx_->IsCPU() ? 
info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()}; linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { + auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); auto p = common::Sigmoid(predt(i)); - auto sf = sensitive(i); + auto sf = sensitive(sample_id); auto grad = (p - y) + (fairness * (sf - p)); auto hess = (1.0f - fairness) * p * (1.0f - p); - auto w = weight[std::get<0>(linalg::UnravelIndex(i, labels.Shape()))]; + auto w = weight[sample_id]; gpair(i) = {grad * w, hess * w}; }); } diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 8639d24a394c..de386015fbe5 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -355,12 +355,14 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { } TEST(Objective, DeclareUnifiedTest(RegularizedClassification)) { - GenericParameter lparam = CreateEmptyGenericParam(GPUIDX); - Args args{{"fairness", "0.0"}}; - std::unique_ptr obj{ObjFunction::Create("binary:regularized", &lparam)}; + GenericParameter ctx = CreateEmptyGenericParam(GPUIDX); - obj->Configure(args); - CheckConfigReload(obj, "binary:regularized"); + { + Args args{{"fairness", "0.0"}}; + std::unique_ptr obj{ObjFunction::Create("binary:regularized", &ctx)}; + obj->Configure(args); + CheckConfigReload(obj, "binary:regularized"); + } MetaInfo info; info.num_row_ = 16; @@ -369,47 +371,104 @@ TEST(Objective, DeclareUnifiedTest(RegularizedClassification)) { for (size_t i = 0; i < h_sf.size(); ++i) { h_sf[i] = i % 2 == 0; } + HostDeviceVector reg_gpair; - info.labels = linalg::Tensor{{info.num_row_, static_cast(1)}, GPUIDX}; - auto& h_y = info.labels.Data()->HostVector(); - for (size_t i = 0; i < h_y.size(); ++i) { - h_y[i] = i % 2 != 0; - } + // fairness == 0 means unbiased + auto test_unbiased = [&](HostDeviceVector const& predts, + HostDeviceVector* reg_gpair) { + std::unique_ptr obj{ObjFunction::Create("binary:regularized", &ctx)}; + obj->Configure({{"fairness", "0.0"}}); + obj->GetGradient(predts, info, 0, reg_gpair); + auto const& h_reg = reg_gpair->ConstHostVector(); - HostDeviceVector predts; - predts.SetDevice(GPUIDX); - predts.Resize(info.num_row_); - auto& h_predts = predts.HostVector(); - for (size_t i = 0; i < h_y.size(); ++i) { - h_predts[i] = i % 2 != 0; - } + std::unique_ptr logistic{ObjFunction::Create("binary:logistic", &ctx)}; + logistic->Configure({}); - HostDeviceVector reg_gpair; - obj->GetGradient(predts, info, 0, ®_gpair); - auto const& h_reg = reg_gpair.ConstHostVector(); + HostDeviceVector logistic_gpair; + logistic->GetGradient(predts, info, 0, &logistic_gpair); + auto const& h_logistic = logistic_gpair.ConstHostVector(); - // fairness == 0 means unbiased - std::unique_ptr logistic{ObjFunction::Create("binary:logistic", &lparam)}; - logistic->Configure({}); - HostDeviceVector logistic_gpair; - obj->GetGradient(predts, info, 0, &logistic_gpair); - auto const& h_logistic = logistic_gpair.ConstHostVector(); - for (size_t i = 0; i < h_reg.size(); ++i) { - ASSERT_EQ(h_logistic[i], h_reg[i]); - } + for (size_t i = 0; i < h_reg.size(); ++i) { + ASSERT_EQ(h_logistic[i], h_reg[i]); + } + }; - auto test_regularized = [&]() { + auto test_regularized = [&](HostDeviceVector const& predts, + HostDeviceVector* reg_gpair) { + std::unique_ptr obj{ObjFunction::Create("binary:regularized", &ctx)}; obj->Configure({{"fairness", "1.0"}}); - obj->GetGradient(predts, info, 0, ®_gpair); - auto const& h_reg = 
reg_gpair.ConstHostVector(); - for (size_t i = 0; i < h_reg.size(); ++i) { - ASSERT_EQ(h_reg[i].GetHess(), 0.0f); - ASSERT_EQ(h_reg[i].GetGrad(), i % 2 == 0 ? 1.0 : -1.0); + obj->GetGradient(predts, info, 0, reg_gpair); + auto const& h_reg = reg_gpair->ConstHostVector(); + auto h_y = info.labels.HostView(); + size_t strides[] = {h_y.Stride(0), h_y.Stride(1)}; + + for (size_t i = 0; i < info.labels.Shape(1); ++i) { + for (size_t j = 0; j < info.labels.Shape(0); ++j) { + auto offset = linalg::detail::Offset<0ul>(strides, 0ul, j, i); + + ASSERT_EQ(h_reg[offset].GetHess(), 0.0f); + ASSERT_EQ(h_reg[offset].GetGrad(), j % 2 == 0 ? 1.0 : -1.0); + } } }; - test_regularized(); - info.weights_.Resize(info.num_row_, 1.0); - test_regularized(); + + { + info.labels = linalg::Tensor{{info.num_row_, static_cast(1)}, GPUIDX}; + auto& h_y = info.labels.Data()->HostVector(); + for (size_t i = 0; i < h_y.size(); ++i) { + h_y[i] = i % 2 != 0; + } + + HostDeviceVector predts; + predts.SetDevice(GPUIDX); + predts.Resize(info.num_row_); + auto& h_predts = predts.HostVector(); + for (size_t i = 0; i < h_y.size(); ++i) { + h_predts[i] = i % 2 != 0; + } + + info.weights_.Resize(0); + test_unbiased(predts, ®_gpair); + info.weights_.Resize(info.num_row_, 1.0); + test_unbiased(predts, ®_gpair); + + info.weights_.Resize(0); + test_regularized(predts, ®_gpair); + info.weights_.Resize(info.num_row_, 1.0); + test_regularized(predts, ®_gpair); + } + + { + /** + * multi-target, change the shape of labels and predictions. + */ + size_t n_targets = 4; + info.labels.Reshape(info.num_row_, n_targets); + auto h_y = info.labels.HostView(); + + HostDeviceVector predts; + predts.SetDevice(GPUIDX); + predts.Resize(info.labels.Size()); + auto& h_predts = predts.HostVector(); + for (size_t i = 0; i < n_targets; ++i) { + for (size_t j = 0; j < info.num_row_; ++j) { + h_y(j, i) = j % 2 != 0; + size_t strides[] = {h_y.Stride(0), h_y.Stride(1)}; + auto offset = linalg::detail::Offset<0ul>(strides, 0ul, j, i); + h_predts[offset] = j % 2 != 0; + } + } + + info.weights_.Resize(0); + test_unbiased(predts, ®_gpair); + info.weights_.Resize(info.num_row_, 1.0); + test_unbiased(predts, ®_gpair); + + info.weights_.Resize(0); + test_regularized(predts, ®_gpair); + info.weights_.Resize(info.num_row_, 1.0); + test_regularized(predts, ®_gpair); + } } // CoxRegression not implemented in GPU code, no need for testing. From e000d2591b5322ed140fa8c8783750f9f19fe988 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 22 Feb 2022 12:39:42 +0800 Subject: [PATCH 19/20] Rename file. 
--- src/common/{fair_param.h => regularized.h} | 0 src/metric/elementwise_metric.cu | 6 +++--- src/objective/regression_obj.cu | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/common/{fair_param.h => regularized.h} (100%) diff --git a/src/common/fair_param.h b/src/common/regularized.h similarity index 100% rename from src/common/fair_param.h rename to src/common/regularized.h diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index b8f8ad3bdc9b..f3d231fb9f57 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -13,8 +13,8 @@ #include #include "../common/common.h" -#include "../common/fair_param.h" #include "../common/math.h" +#include "../common/regularized.h" #include "../common/threading_utils.h" #include "metric_common.h" @@ -421,8 +421,8 @@ XGBOOST_REGISTER_METRIC(LogLoss, "logloss") .set_body([](const char* param) { return new EvalEWiseBase(); }); XGBOOST_REGISTER_METRIC(RegularizedLogLoss, "regularized-logloss") -.describe("Negative loglikelihood for regularized binary classification.") -.set_body([](const char* param) { return new RegularizedLogLoss(); }); + .describe("Negative loglikelihood for regularized binary classification.") + .set_body([](const char* param) { return new RegularizedLogLoss(); }); XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik") .describe("Negative loglikelihood for poisson regression.") diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index c46637b818a9..7c7afd345fda 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -14,8 +14,8 @@ #include #include "../common/common.h" -#include "../common/fair_param.h" #include "../common/linalg_op.h" +#include "../common/regularized.h" #include "../common/threading_utils.h" #include "../common/transform.h" #include "./regression_loss.h" From 4362debd1674c7ae86833d48893af09d3a246492 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 23 Feb 2022 03:43:31 +0800 Subject: [PATCH 20/20] Lint. --- src/common/regularized.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/regularized.h b/src/common/regularized.h index 05301bcecadc..dab7baee03d0 100644 --- a/src/common/regularized.h +++ b/src/common/regularized.h @@ -1,5 +1,5 @@ -#ifndef XGBOOST_COMMON_FAIR_PARAM_H_ -#define XGBOOST_COMMON_FAIR_PARAM_H_ +#ifndef XGBOOST_COMMON_REGULARIZED_H_ +#define XGBOOST_COMMON_REGULARIZED_H_ /*! * Copyright 2022 by XGBoost Contributors */ @@ -15,4 +15,4 @@ struct BinaryRegularizationParam : public XGBoostParameter