From e161ed963a0eba566fcfea45e40b308cf1186803 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 8 Oct 2021 15:34:40 +0800
Subject: [PATCH] Re-implement PR-AUC.

* Support binary/multi-class classification, ranking.
* Add documents.
* Handle missing data.
---
 doc/parameter.rst                         |   8 +-
 src/common/common.h                       |  14 +
 src/metric/auc.cc                         | 337 ++++++++----
 src/metric/auc.cu                         | 636 +++++++++++++++++-----
 src/metric/auc.h                          |  95 +++-
 src/metric/rank_metric.cc                 | 156 ------
 src/metric/rank_metric.cu                 | 190 -------
 tests/cpp/metric/test_auc.cc              | 119 +++-
 tests/cpp/metric/test_rank_metric.cc      |  60 --
 tests/python-gpu/test_gpu_eval_metrics.py |   9 +
 tests/python/test_eval_metrics.py         |  65 +++
 tests/python/test_with_dask.py            |   4 +-
 12 files changed, 1038 insertions(+), 655 deletions(-)

diff --git a/doc/parameter.rst b/doc/parameter.rst
index aff9a5d18687..52f208cdf582 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -397,9 +397,13 @@ Specify the learning task and the corresponding learning objective. The objectiv
 
   - When used with multi-class classification, objective should be ``multi:softprob`` instead of ``multi:softmax``, as the latter doesn't output probability. Also the AUC is calculated by 1-vs-rest with reference class weighted by class prevalence.
   - When used with LTR task, the AUC is computed by comparing pairs of documents to count correctly sorted pairs. This corresponds to pairwise learning to rank. The implementation has some issues: the average AUC over groups and over distributed workers is not well-defined.
   - On a single machine the AUC calculation is exact. In a distributed environment the AUC is a weighted average over the AUC of training rows on each node - therefore, distributed AUC is an approximation sensitive to the distribution of data across workers. Use another metric in distributed environments if precision and reproducibility are important.
-  - If input dataset contains only negative or positive samples the output is `NaN`.
+  - When the input dataset contains only negative or positive samples, the output is `NaN`. The behavior is implementation defined; ``scikit-learn``, for instance, returns :math:`0.5` instead.
+
+- ``aucpr``: `Area under the PR curve `_.
+  Available for classification and learning-to-rank tasks.
+
+  After XGBoost 1.6, the requirements and restrictions for using ``aucpr`` in classification problems are similar to those of ``auc``. For the ranking task, only binary relevance labels :math:`y \in [0, 1]` are supported.
-- ``aucpr``: `Area under the PR curve `_. Available for binary classification and learning-to-rank tasks.
 - ``ndcg``: `Normalized Discounted Cumulative Gain `_
 - ``map``: `Mean Average Precision `_
 - ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.
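[Editor's note: as a reading aid before the implementation diffs, the distributed behaviour described above can be sketched on the host. This is a minimal illustration, not part of the patch: each worker contributes its un-normalized AUC together with the curve area `fp * tp` it can cover, the per-worker pairs are summed (the patch does this with `rabit::Allreduce` in `src/metric/auc.cc`), and the ratio of the sums is the final metric. `Worker` and `AggregateAUC` are hypothetical names introduced only for this example.]

```cpp
#include <iostream>
#include <limits>
#include <vector>

struct Worker {
  float auc;         // un-normalized AUC accumulated on this worker
  float local_area;  // accessible area on this worker: fp * tp
};

// Stand-in for the allreduce-then-normalize step: sum the per-worker
// {auc, local_area} pairs, then divide once at the end.
float AggregateAUC(std::vector<Worker> const &workers) {
  float auc_sum = 0.0f, area_sum = 0.0f;
  for (auto const &w : workers) {
    auc_sum += w.auc;
    area_sum += w.local_area;
  }
  if (area_sum <= 0.0f) {
    // Every worker saw only positive or only negative samples.
    return std::numeric_limits<float>::quiet_NaN();
  }
  return auc_sum / area_sum;
}

int main() {
  // Two workers with different class balances: the result is a weighted
  // average, not the exact single-machine AUC, which is why the documentation
  // warns about sensitivity to data distribution across workers.
  std::vector<Worker> workers{{6.0f, 8.0f}, {1.5f, 2.0f}};
  std::cout << AggregateAUC(workers) << "\n";  // (6 + 1.5) / (8 + 2) = 0.75
}
```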
diff --git a/src/common/common.h b/src/common/common.h index d8d42f75ca6a..8230e532ff69 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(__CUDACC__) #include @@ -86,6 +87,19 @@ XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) { return static_cast(std::ceil(static_cast(a) / b)); } +namespace detail { +template +constexpr auto UnpackArr(std::array &&arr, std::index_sequence) { + return std::make_tuple(std::forward>(arr)[Idx]...); +} +} // namespace detail + +template +constexpr auto UnpackArr(std::array &&arr) { + return detail::UnpackArr(std::forward>(arr), + std::make_index_sequence{}); +} + /* * Range iterator */ diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 7eec9741872f..63315150ff4f 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -14,62 +14,50 @@ #include "rabit/rabit.h" #include "xgboost/host_device_vector.h" #include "xgboost/metric.h" + #include "auc.h" + #include "../common/common.h" #include "../common/math.h" +#include "../common/threading_utils.h" namespace xgboost { namespace metric { - -namespace detail { -template -constexpr auto UnpackArr(std::array &&arr, std::index_sequence) { - return std::make_tuple(std::forward>(arr)[Idx]...); -} -} // namespace detail - -template -constexpr auto UnpackArr(std::array &&arr) { - return detail::UnpackArr(std::forward>(arr), - std::make_index_sequence{}); -} - /** * Calculate AUC for binary classification problem. This function does not normalize the * AUC by 1 / (num_positive * num_negative), instead it returns a tuple for caller to * handle the normalization. */ -std::tuple BinaryAUC(std::vector const &predts, - std::vector const &labels, - std::vector const &weights) { +template +std::tuple +BinaryAUC(common::Span predts, common::Span labels, + OptionalWeights weights, + std::vector const &sorted_idx, Fn &&area_fn) { CHECK(!labels.empty()); CHECK_EQ(labels.size(), predts.size()); + auto p_predts = predts.data(); + auto p_labels = labels.data(); - float auc {0}; - auto const sorted_idx = common::ArgSort( - common::Span(predts), std::greater<>{}); + float auc{0}; - auto get_weight = [&](size_t i) { - return weights.empty() ? 1.0f : weights[sorted_idx[i]]; - }; - float label = labels[sorted_idx.front()]; - float w = get_weight(0); + float label = p_labels[sorted_idx.front()]; + float w = weights[sorted_idx[0]]; float fp = (1.0 - label) * w, tp = label * w; float tp_prev = 0, fp_prev = 0; // TODO(jiaming): We can parallize this if we have a parallel scan for CPU. for (size_t i = 1; i < sorted_idx.size(); ++i) { - if (predts[sorted_idx[i]] != predts[sorted_idx[i-1]]) { - auc += TrapesoidArea(fp_prev, fp, tp_prev, tp); + if (p_predts[sorted_idx[i]] != p_predts[sorted_idx[i - 1]]) { + auc += area_fn(fp_prev, fp, tp_prev, tp); tp_prev = tp; fp_prev = fp; } - label = labels[sorted_idx[i]]; - float w = get_weight(i); + label = p_labels[sorted_idx[i]]; + float w = weights[sorted_idx[i]]; fp += (1.0f - label) * w; tp += label * w; } - auc += TrapesoidArea(fp_prev, fp, tp_prev, tp); + auc += area_fn(fp_prev, fp, tp_prev, tp); if (fp <= 0.0f || tp <= 0.0f) { auc = 0; fp = 0; @@ -87,46 +75,44 @@ std::tuple BinaryAUC(std::vector const &predts, * - Kleiman, Ross and Page, David. 
$AUC_{\mu}$: A Performance Metric for Multi-Class * Machine Learning Models */ -float MultiClassOVR(std::vector const& predts, MetaInfo const& info, size_t n_classes) { +template +float MultiClassOVR(common::Span predts, MetaInfo const &info, + size_t n_classes, int32_t n_threads, + BinaryAUC &&binary_auc) { CHECK_NE(n_classes, 0); - auto const& labels = info.labels_.ConstHostVector(); + auto const &labels = info.labels_.ConstHostVector(); std::vector results(n_classes * 3, 0); auto s_results = common::Span(results); auto local_area = s_results.subspan(0, n_classes); auto tp = s_results.subspan(n_classes, n_classes); auto auc = s_results.subspan(2 * n_classes, n_classes); + auto weights = OptionalWeights{info.weights_.ConstHostSpan()}; if (!info.labels_.Empty()) { - dmlc::OMPException omp_handler; -#pragma omp parallel for - for (omp_ulong c = 0; c < n_classes; ++c) { - omp_handler.Run([&]() { - std::vector proba(info.labels_.Size()); - std::vector response(info.labels_.Size()); - for (size_t i = 0; i < proba.size(); ++i) { - proba[i] = predts[i * n_classes + c]; - response[i] = labels[i] == c ? 1.0f : 0.0; - } - float fp; - std::tie(fp, tp[c], auc[c]) = - BinaryAUC(proba, response, info.weights_.ConstHostVector()); - local_area[c] = fp * tp[c]; - }); - } - omp_handler.Rethrow(); + common::ParallelFor(n_classes, n_threads, [&](auto c) { + std::vector proba(info.labels_.Size()); + std::vector response(info.labels_.Size()); + for (size_t i = 0; i < proba.size(); ++i) { + proba[i] = predts[i * n_classes + c]; + response[i] = labels[i] == c ? 1.0f : 0.0; + } + float fp; + std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights); + local_area[c] = fp * tp[c]; + }); } - // we have 2 averages going in here, first is among workers, second is among classes. - // allreduce sums up fp/tp auc for each class. + // we have 2 averages going in here, first is among workers, second is among + // classes. allreduce sums up fp/tp auc for each class. rabit::Allreduce(results.data(), results.size()); float auc_sum{0}; float tp_sum{0}; for (size_t c = 0; c < n_classes; ++c) { if (local_area[c] != 0) { - // normalize and weight it by prevalence. After allreduce, `local_area` means the - // total covered area (not area under curve, rather it's the accessible area for - // each worker) for each class. + // normalize and weight it by prevalence. After allreduce, `local_area` + // means the total covered area (not area under curve, rather it's the + // accessible area for each worker) for each class. auc_sum += auc[c] / local_area[c] * tp[c]; tp_sum += tp[c]; } else { @@ -142,10 +128,17 @@ float MultiClassOVR(std::vector const& predts, MetaInfo const& info, size return auc_sum; } +std::tuple BinaryROCAUC(common::Span predts, + common::Span labels, + OptionalWeights weights) { + auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); + return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); +} + /** * Calculate AUC for 1 ranking group; */ -float GroupRankingAUC(common::Span predts, +float GroupRankingROC(common::Span predts, common::Span labels, float w) { // on ranking, we just count all pairs. float auc{0}; @@ -174,11 +167,40 @@ float GroupRankingAUC(common::Span predts, return auc; } +/** + * \brief PR-AUC for binary classification. 
+ * + * https://doi.org/10.1371/journal.pone.0092209 + */ +std::tuple BinaryPRAUC(common::Span predts, + common::Span labels, + OptionalWeights weights) { + auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); + float total_pos{0}, total_neg{0}; + for (size_t i = 0; i < labels.size(); ++i) { + auto w = weights[i]; + total_pos += w * labels[i]; + total_neg += w * (1.0f - labels[i]); + } + if (total_pos <= 0 || total_neg <= 0) { + return {1.0f, 1.0f, std::numeric_limits::quiet_NaN()}; + } + auto fn = [total_pos](float fp_prev, float fp, float tp_prev, float tp) { + return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); + }; + + float tp{0}, fp{0}, auc{0}; + std::tie(fp, tp, auc) = BinaryAUC(predts, labels, weights, sorted_idx, fn); + return std::make_tuple(1.0, 1.0, auc); +} + + /** * Cast LTR problem to binary classification problem by comparing pairs. */ +template std::pair RankingAUC(std::vector const &predts, - MetaInfo const &info) { + MetaInfo const &info, int32_t n_threads) { CHECK_GE(info.group_ptr_.size(), 2); uint32_t n_groups = info.group_ptr_.size() - 1; float sum_auc = 0; @@ -189,7 +211,7 @@ std::pair RankingAUC(std::vector const &predts, std::atomic invalid_groups{0}; dmlc::OMPException omp_handler; -#pragma omp parallel for reduction(+:sum_auc) +#pragma omp parallel for reduction(+:sum_auc) num_threads(n_threads) for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) { omp_handler.Run([&]() { size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; @@ -197,30 +219,32 @@ std::pair RankingAUC(std::vector const &predts, auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); float auc; - if (g_labels.size() < 3) { + if (is_roc && g_labels.size() < 3) { // With 2 documents, there's only 1 comparison can be made. So either // TP or FP will be zero. invalid_groups++; auc = 0; } else { - auc = GroupRankingAUC(g_predts, g_labels, w); + if (is_roc) { + auc = GroupRankingROC(g_predts, g_labels, w); + } else { + auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w})); + } + if (std::isnan(auc)) { + invalid_groups++; + auc = 0; + } } sum_auc += auc; }); } omp_handler.Rethrow(); - if (invalid_groups != 0) { - InvalidGroupAUC(); - } - return std::make_pair(sum_auc, n_groups - invalid_groups); } +template class EvalAUC : public Metric { - std::shared_ptr d_cache_; - - public: float Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { float auc {0}; @@ -232,8 +256,10 @@ class EvalAUC : public Metric { // We use the global size to handle empty dataset. std::array meta{info.labels_.Size(), preds.Size()}; rabit::Allreduce(meta.data(), meta.size()); - - if (!info.group_ptr_.empty()) { + if (meta[0] == 0) { + // Empty across all workers, which is not supported. 
+ auc = std::numeric_limits::quiet_NaN(); + } else if (!info.group_ptr_.empty()) { /** * learning to rank */ @@ -243,13 +269,11 @@ class EvalAUC : public Metric { uint32_t valid_groups = 0; if (!info.labels_.Empty()) { CHECK_EQ(info.group_ptr_.back(), info.labels_.Size()); - if (tparam_->gpu_id == GenericParameter::kCpuId) { - std::tie(auc, valid_groups) = - RankingAUC(preds.ConstHostVector(), info); - } else { - std::tie(auc, valid_groups) = GPURankingAUC( - preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); - } + std::tie(auc, valid_groups) = + static_cast(this)->EvalRanking(preds, info); + } + if (valid_groups != info.group_ptr_.size() - 1) { + InvalidGroupAUC(); } std::array results{auc, static_cast(valid_groups)}; @@ -270,45 +294,85 @@ class EvalAUC : public Metric { */ size_t n_classes = meta[1] / meta[0]; CHECK_NE(n_classes, 0); - if (tparam_->gpu_id == GenericParameter::kCpuId) { - auc = MultiClassOVR(preds.ConstHostVector(), info, n_classes); - } else { - auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id, - &this->d_cache_, n_classes); - } + auc = static_cast(this)->EvalMultiClass(preds, info, n_classes); } else { /** * binary classification */ float fp{0}, tp{0}; if (!(preds.Empty() || info.labels_.Empty())) { - if (tparam_->gpu_id == GenericParameter::kCpuId) { - std::tie(fp, tp, auc) = - BinaryAUC(preds.ConstHostVector(), info.labels_.ConstHostVector(), - info.weights_.ConstHostVector()); - } else { - std::tie(fp, tp, auc) = GPUBinaryAUC( - preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); - } + std::tie(fp, tp, auc) = + static_cast(this)->EvalBinary(preds, info); } float local_area = fp * tp; std::array result{auc, local_area}; rabit::Allreduce(result.data(), result.size()); - std::tie(auc, local_area) = UnpackArr(std::move(result)); + std::tie(auc, local_area) = common::UnpackArr(std::move(result)); if (local_area <= 0) { // the dataset across all workers have only positive or negative sample auc = std::numeric_limits::quiet_NaN(); } else { + CHECK_LE(auc, local_area); // normalization auc = auc / local_area; } } if (std::isnan(auc)) { - LOG(WARNING) << "Dataset contains only positive or negative samples."; + LOG(WARNING) << "Dataset is empty, or contains only positive or negative samples."; } return auc; } +}; +class EvalROCAUC : public EvalAUC { + std::shared_ptr d_cache_; + + public: + std::pair EvalRanking(HostDeviceVector const &predts, + MetaInfo const &info) { + float auc{0}; + uint32_t valid_groups = 0; + auto n_threads = tparam_->Threads(); + if (tparam_->gpu_id == GenericParameter::kCpuId) { + std::tie(auc, valid_groups) = + RankingAUC(predts.ConstHostVector(), info, n_threads); + } else { + std::tie(auc, valid_groups) = GPURankingAUC( + predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); + } + return std::make_pair(auc, valid_groups); + } + + float EvalMultiClass(HostDeviceVector const &predts, + MetaInfo const &info, size_t n_classes) { + float auc{0}; + auto n_threads = tparam_->Threads(); + CHECK_NE(n_classes, 0); + if (tparam_->gpu_id == GenericParameter::kCpuId) { + auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads, + BinaryROCAUC); + } else { + auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, + &this->d_cache_, n_classes); + } + return auc; + } + + std::tuple + EvalBinary(HostDeviceVector const &predts, MetaInfo const &info) { + float fp, tp, auc; + if (tparam_->gpu_id == GenericParameter::kCpuId) { + std::tie(fp, tp, auc) = + 
BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(), + OptionalWeights{info.weights_.ConstHostSpan()}); + } else { + std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, + tparam_->gpu_id, &this->d_cache_); + } + return std::make_tuple(fp, tp, auc); + } + + public: char const* Name() const override { return "auc"; } @@ -316,18 +380,19 @@ class EvalAUC : public Metric { XGBOOST_REGISTER_METRIC(EvalAUC, "auc") .describe("Receiver Operating Characteristic Area Under the Curve.") -.set_body([](const char*) { return new EvalAUC(); }); +.set_body([](const char*) { return new EvalROCAUC(); }); #if !defined(XGBOOST_USE_CUDA) std::tuple -GPUBinaryAUC(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr *p_cache) { +GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { common::AssertGPUSupport(); return std::make_tuple(0.0f, 0.0f, 0.0f); } -float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr* cache, +float GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, size_t n_classes) { common::AssertGPUSupport(); return 0; @@ -341,5 +406,85 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, } struct DeviceAUCCache {}; #endif // !defined(XGBOOST_USE_CUDA) + +class EvalAUCPR : public EvalAUC { + std::shared_ptr d_cache_; + + public: + std::tuple + EvalBinary(HostDeviceVector const &predts, MetaInfo const &info) { + float pr, re, auc; + if (tparam_->gpu_id == GenericParameter::kCpuId) { + std::tie(pr, re, auc) = + BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(), + OptionalWeights{info.weights_.ConstHostSpan()}); + } else { + std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, + tparam_->gpu_id, &this->d_cache_); + } + return std::make_tuple(pr, re, auc); + } + + float EvalMultiClass(HostDeviceVector const &predts, + MetaInfo const &info, size_t n_classes) { + if (tparam_->gpu_id == GenericParameter::kCpuId) { + auto n_threads = this->tparam_->Threads(); + return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads, + BinaryPRAUC); + } else { + return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, + &d_cache_, n_classes); + } + } + + std::pair EvalRanking(HostDeviceVector const &predts, + MetaInfo const &info) { + float auc{0}; + uint32_t valid_groups = 0; + auto n_threads = tparam_->Threads(); + if (tparam_->gpu_id == GenericParameter::kCpuId) { + auto labels = info.labels_.ConstHostSpan(); + if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) { + InvalidLabels(); + } + std::tie(auc, valid_groups) = + RankingAUC(predts.ConstHostVector(), info, n_threads); + } else { + std::tie(auc, valid_groups) = GPURankingPRAUC( + predts.ConstDeviceSpan(), info, tparam_->gpu_id, &d_cache_); + } + return std::make_pair(auc, valid_groups); + } + + public: + const char *Name() const override { return "aucpr"; } +}; + +XGBOOST_REGISTER_METRIC(AUCPR, "aucpr") + .describe("Area under PR curve for both classification and rank.") + .set_body([](char const *) { return new EvalAUCPR{}; }); + +#if !defined(XGBOOST_USE_CUDA) +std::tuple +GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + common::AssertGPUSupport(); + return {}; +} + +float GPUMultiClassPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *cache, + size_t n_classes) { + 
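+  // Stub for CPU-only builds: AssertGPUSupport() below fails with a fatal
+  // error explaining that XGBoost was built without CUDA support.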
common::AssertGPUSupport(); + return {}; +} + +std::pair +GPURankingPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *cache) { + common::AssertGPUSupport(); + return {}; +} +#endif } // namespace metric } // namespace xgboost diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 7f57d1a87d5b..fa849b0a1f2d 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -3,6 +3,8 @@ */ #include #include + +#include #include #include #include @@ -19,12 +21,13 @@ namespace xgboost { namespace metric { namespace { -struct GetWeightOp { - common::Span weights; - common::Span sorted_idx; +// Pair of FP/TP +using Pair = thrust::pair; - __device__ float operator()(size_t i) const { - return weights.empty() ? 1.0f : weights[sorted_idx[i]]; +template > +struct PairPlus : public thrust::binary_function { + XGBOOST_DEVICE P operator()(P const& l, P const& r) const { + return thrust::make_pair(l.first + r.first, l.second + r.second); } }; } // namespace @@ -33,8 +36,6 @@ struct GetWeightOp { * A cache to GPU data to avoid reallocating memory. */ struct DeviceAUCCache { - // Pair of FP/TP - using Pair = thrust::pair; // index sorted by prediction value dh::device_vector sorted_idx; // track FP/TP for computation on trapesoid area @@ -64,6 +65,16 @@ struct DeviceAUCCache { } }; +template +void InitCacheOnce(common::Span predts, int32_t device, + std::shared_ptr* p_cache) { + auto& cache = *p_cache; + if (!cache) { + cache.reset(new DeviceAUCCache); + } + cache->Init(predts, is_multi, device); +} + /** * The GPU implementation uses same calculation as CPU with a few more steps to distribute * work across threads: @@ -73,15 +84,11 @@ struct DeviceAUCCache { * which are left coordinates of trapesoids. * - Reduce the scan array into 1 AUC value. 
*/ +template std::tuple GPUBinaryAUC(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr *p_cache) { - auto& cache = *p_cache; - if (!cache) { - cache.reset(new DeviceAUCCache); - } - cache->Init(predts, false, device); - + int32_t device, common::Span d_sorted_idx, + Fn area_fn, std::shared_ptr cache) { auto labels = info.labels_.ConstDeviceSpan(); auto weights = info.weights_.ConstDeviceSpan(); dh::safe_cuda(cudaSetDevice(device)); @@ -89,22 +96,15 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, CHECK(!labels.empty()); CHECK_EQ(labels.size(), predts.size()); - /** - * Create sorted index for each class - */ - auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); - dh::ArgSort(predts, d_sorted_idx); - /** * Linear scan */ - auto get_weight = GetWeightOp{weights, d_sorted_idx}; - using Pair = thrust::pair; - auto get_fp_tp = [=]__device__(size_t i) { + auto get_weight = OptionalWeights{weights}; + auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; float label = labels[idx]; - float w = get_weight(i); + float w = get_weight[d_sorted_idx[i]]; float fp = (1.0 - label) * w; float tp = label * w; @@ -113,7 +113,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, }; // NOLINT auto d_fptp = dh::ToSpan(cache->fptp); dh::LaunchN(d_sorted_idx.size(), - [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); + [=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); }); dh::XGBDeviceAllocator alloc; auto d_unique_idx = dh::ToSpan(cache->unique_idx); @@ -121,24 +121,20 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, auto uni_key = dh::MakeTransformIterator( thrust::make_counting_iterator(0), - [=] __device__(size_t i) { return predts[d_sorted_idx[i]]; }); + [=] XGBOOST_DEVICE(size_t i) { return predts[d_sorted_idx[i]]; }); auto end_unique = thrust::unique_by_key_copy( thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(), dh::tbegin(d_unique_idx), thrust::make_discard_iterator(), dh::tbegin(d_unique_idx)); d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx)); - dh::InclusiveScan( - dh::tbegin(d_fptp), dh::tbegin(d_fptp), - [=] __device__(Pair const &l, Pair const &r) { - return thrust::make_pair(l.first + r.first, l.second + r.second); - }, - d_fptp.size()); + dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp), + PairPlus{}, d_fptp.size()); auto d_neg_pos = dh::ToSpan(cache->neg_pos); // scatter unique negaive/positive values // shift to right by 1 with initial value being 0 - dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) { + dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) { if (d_unique_idx[i] == 0) { // first unique index is 0 assert(i == 0); d_neg_pos[0] = {0, 0}; @@ -154,7 +150,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, }); auto in = dh::MakeTransformIterator( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { float fp, tp; float fp_prev, tp_prev; if (i == 0) { @@ -165,7 +161,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; } - return TrapesoidArea(fp_prev, fp, tp_prev, tp); + return area_fn(fp_prev, fp, tp_prev, tp); }); Pair last = cache->fptp.back(); @@ -173,11 +169,31 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, return std::make_tuple(last.first, last.second, auc); } +std::tuple 
+GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + auto &cache = *p_cache; + InitCacheOnce(predts, device, p_cache); + + /** + * Create sorted index for each class + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::ArgSort(predts, d_sorted_idx); + // Create lambda to avoid pass function pointer. + return GPUBinaryAUC( + predts, info, device, d_sorted_idx, + [] XGBOOST_DEVICE(float x0, float x1, float y0, float y1) { + return TrapezoidArea(x0, x1, y0, y1); + }, + cache); +} + void Transpose(common::Span in, common::Span out, size_t m, - size_t n, int32_t device) { + size_t n) { CHECK_EQ(in.size(), out.size()); CHECK_EQ(in.size(), m * n); - dh::LaunchN(in.size(), [=] __device__(size_t i) { + dh::LaunchN(in.size(), [=] XGBOOST_DEVICE(size_t i) { size_t col = i / m; size_t row = i % m; size_t idx = row * n + col; @@ -204,7 +220,7 @@ float ScaleClasses(common::Span results, common::Span local_area, cache->reducer->AllReduceSum(results.data(), results.data(), results.size()); } auto reduce_in = dh::MakeTransformIterator>( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { if (local_area[i] > 0) { return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]); } @@ -213,12 +229,9 @@ float ScaleClasses(common::Span results, common::Span local_area, float tp_sum; float auc_sum; - thrust::tie(auc_sum, tp_sum) = thrust::reduce( - thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, - thrust::make_pair(0.0f, 0.0f), - [=] __device__(auto const &l, auto const &r) { - return thrust::make_pair(l.first + r.first, l.second + r.second); - }); + thrust::tie(auc_sum, tp_sum) = + thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, + Pair{0.0f, 0.0f}, PairPlus{}); if (tp_sum != 0 && !std::isnan(auc_sum)) { auc_sum /= tp_sum; } else { @@ -227,19 +240,98 @@ float ScaleClasses(common::Span results, common::Span local_area, return auc_sum; } +/** + * Calculate FP/TP for multi-class and PR-AUC ranking. `segment_id` is a function for + * getting class id or group id given scan index. + */ +template +void SegmentedFPTP(common::Span d_fptp, Fn segment_id) { + using Triple = thrust::tuple; + // expand to tuple to include idx + auto fptp_it_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { + return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second); + }); + // shrink down to pair + auto fptp_it_out = thrust::make_transform_output_iterator( + dh::TypedDiscard{}, [d_fptp] XGBOOST_DEVICE(Triple const &t) { + d_fptp[thrust::get<0>(t)] = + thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t)); + return t; + }); + dh::InclusiveScan( + fptp_it_in, fptp_it_out, + [=] XGBOOST_DEVICE(Triple const &l, Triple const &r) { + uint32_t l_gid = segment_id(thrust::get<0>(l)); + uint32_t r_gid = segment_id(thrust::get<0>(r)); + if (l_gid != r_gid) { + return r; + } + + return Triple(thrust::get<0>(r), + thrust::get<1>(l) + thrust::get<1>(r), // fp + thrust::get<2>(l) + thrust::get<2>(r)); // tp + }, + d_fptp.size()); +} + +/** + * Reduce the values of AUC for each group/class. 
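+ * Deltas produced by `area_fn` between consecutive unique predictions are
+ * summed per class/group with a keyed reduction (thrust::reduce_by_key).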
+ */ +template +void SegmentedReduceAUC(common::Span d_unique_idx, + common::Span d_class_ptr, + common::Span d_unique_class_ptr, + std::shared_ptr cache, + Area area_fn, + Seg segment_id, + common::Span d_auc) { + auto d_fptp = dh::ToSpan(cache->fptp); + auto d_neg_pos = dh::ToSpan(cache->neg_pos); + dh::XGBDeviceAllocator alloc; + auto key_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { + size_t class_id = segment_id(d_unique_idx[i]); + return class_id; + }); + auto val_in = dh::MakeTransformIterator( + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { + size_t class_id = segment_id(d_unique_idx[i]); + + float fp, tp, fp_prev, tp_prev; + if (i == d_unique_class_ptr[class_id]) { + // first item is ignored, we use this thread to calculate the last item + thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)]; + thrust::tie(fp_prev, tp_prev) = + d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]]; + } else { + thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; + thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; + } + float auc = area_fn(fp_prev, fp, tp_prev, tp, class_id); + return auc; + }); + thrust::reduce_by_key(thrust::cuda::par(alloc), key_in, + key_in + d_unique_idx.size(), val_in, + thrust::make_discard_iterator(), dh::tbegin(d_auc)); +} + /** * MultiClass implementation is similar to binary classification, except we need to split * up each class in all kernels. */ -float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr* p_cache, - size_t n_classes) { +template +float GPUMultiClassAUCOVR(common::Span predts, + MetaInfo const &info, int32_t device, + common::Span d_class_ptr, size_t n_classes, + std::shared_ptr cache, Fn area_fn) { dh::safe_cuda(cudaSetDevice(device)); - auto& cache = *p_cache; - if (!cache) { - cache.reset(new DeviceAUCCache); - } - cache->Init(predts, true, device); + /** + * Sorted idx + */ + auto d_predts_t = dh::ToSpan(cache->predts_t); + // Index is sorted within class. + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); auto labels = info.labels_.ConstDeviceSpan(); auto weights = info.weights_.ConstDeviceSpan(); @@ -250,7 +342,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info dh::TemporaryArray resutls(n_classes * 4, 0.0f); auto d_results = dh::ToSpan(resutls); dh::LaunchN(n_classes * 4, - [=] __device__(size_t i) { d_results[i] = 0.0f; }); + [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; }); auto local_area = d_results.subspan(0, n_classes); auto fp = d_results.subspan(n_classes, n_classes); auto tp = d_results.subspan(2 * n_classes, n_classes); @@ -258,43 +350,26 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes); } - /** - * Create sorted index for each class - */ - auto d_predts_t = dh::ToSpan(cache->predts_t); - Transpose(predts, d_predts_t, n_samples, n_classes, device); - - dh::TemporaryArray class_ptr(n_classes + 1, 0); - auto d_class_ptr = dh::ToSpan(class_ptr); - dh::LaunchN(n_classes + 1, - [=] __device__(size_t i) { d_class_ptr[i] = i * n_samples; }); - // no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't - // use transform iterator in sorting. 
- auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); - dh::SegmentedArgSort(d_predts_t, d_class_ptr, d_sorted_idx); - /** * Linear scan */ dh::caching_device_vector d_auc(n_classes, 0); - auto s_d_auc = dh::ToSpan(d_auc); - auto get_weight = GetWeightOp{weights, d_sorted_idx}; - using Pair = thrust::pair; + auto get_weight = OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); - auto get_fp_tp = [=]__device__(size_t i) { + auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; size_t class_id = i / n_samples; // labels is a vector of size n_samples. float label = labels[idx % n_samples] == class_id; - float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples]; + float w = get_weight[d_sorted_idx[i] % n_samples]; float fp = (1.0 - label) * w; float tp = label * w; return thrust::make_pair(fp, tp); }; // NOLINT dh::LaunchN(d_sorted_idx.size(), - [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); + [=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); }); /** * Handle duplicated predictions @@ -303,14 +378,14 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info auto d_unique_idx = dh::ToSpan(cache->unique_idx); dh::Iota(d_unique_idx); auto uni_key = dh::MakeTransformIterator>( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { uint32_t class_id = i / n_samples; float predt = d_predts_t[d_sorted_idx[i]]; return thrust::make_pair(class_id, predt); }); // unique values are sparse, so we need a CSR style indptr - dh::TemporaryArray unique_class_ptr(class_ptr.size()); + dh::TemporaryArray unique_class_ptr(d_class_ptr.size()); auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr); auto n_uniques = dh::SegmentedUniqueByKey( thrust::cuda::par(alloc), @@ -324,39 +399,14 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info thrust::equal_to>{}); d_unique_idx = d_unique_idx.subspan(0, n_uniques); - using Triple = thrust::tuple; - // expand to tuple to include class id - auto fptp_it_in = dh::MakeTransformIterator( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { - return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second); - }); - // shrink down to pair - auto fptp_it_out = thrust::make_transform_output_iterator( - dh::TypedDiscard{}, [d_fptp] __device__(Triple const &t) { - d_fptp[thrust::get<0>(t)] = - thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t)); - return t; - }); - dh::InclusiveScan( - fptp_it_in, fptp_it_out, - [=] __device__(Triple const &l, Triple const &r) { - uint32_t l_cid = thrust::get<0>(l) / n_samples; - uint32_t r_cid = thrust::get<0>(r) / n_samples; - if (l_cid != r_cid) { - return r; - } - - return Triple(thrust::get<0>(r), - thrust::get<1>(l) + thrust::get<1>(r), // fp - thrust::get<2>(l) + thrust::get<2>(r)); // tp - }, - d_fptp.size()); + auto get_class_id = [=] XGBOOST_DEVICE(size_t idx) { return idx / n_samples; }; + SegmentedFPTP(d_fptp, get_class_id); // scatter unique FP_PREV/TP_PREV values auto d_neg_pos = dh::ToSpan(cache->neg_pos); // When dataset is not empty, each class must have at least 1 (unique) sample // prediction, so no need to handle special case. 
- dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) { + dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) { if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0 assert(d_unique_idx[i] % n_samples == 0); d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i @@ -375,32 +425,9 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info /** * Reduce the result for each class */ - auto key_in = dh::MakeTransformIterator( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { - size_t class_id = d_unique_idx[i] / n_samples; - return class_id; - }); - auto val_in = dh::MakeTransformIterator( - thrust::make_counting_iterator(0), [=] __device__(size_t i) { - size_t class_id = d_unique_idx[i] / n_samples; - float fp, tp; - float fp_prev, tp_prev; - if (i == d_unique_class_ptr[class_id]) { - // first item is ignored, we use this thread to calculate the last item - thrust::tie(fp, tp) = d_fptp[class_id * n_samples + (n_samples - 1)]; - thrust::tie(fp_prev, tp_prev) = - d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]]; - } else { - thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; - thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; - } - float auc = TrapesoidArea(fp_prev, fp, tp_prev, tp); - return auc; - }); - - thrust::reduce_by_key(thrust::cuda::par(alloc), key_in, - key_in + d_unique_idx.size(), val_in, - thrust::make_discard_iterator(), d_auc.begin()); + auto s_d_auc = dh::ToSpan(d_auc); + SegmentedReduceAUC(d_unique_idx, d_class_ptr, d_unique_class_ptr, cache, + area_fn, get_class_id, s_d_auc); /** * Scale the classes with number of samples for each class. @@ -412,16 +439,58 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info auto tp = d_results.subspan(2 * n_classes, n_classes); auto auc = d_results.subspan(3 * n_classes, n_classes); - dh::LaunchN(n_classes, [=] __device__(size_t c) { + dh::LaunchN(n_classes, [=] XGBOOST_DEVICE(size_t c) { auc[c] = s_d_auc[c]; auto last = d_fptp[n_samples * c + (n_samples - 1)]; fp[c] = last.first; - tp[c] = last.second; - local_area[c] = last.first * last.second; + if (scale) { + local_area[c] = last.first * last.second; + tp[c] = last.second; + } else { + local_area[c] = 1.0f; + tp[c] = 1.0f; + } }); return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes); } +void MultiClassSortedIdx(common::Span predts, + common::Span d_class_ptr, + std::shared_ptr cache) { + size_t n_classes = d_class_ptr.size() - 1; + auto d_predts_t = dh::ToSpan(cache->predts_t); + auto n_samples = d_predts_t.size() / n_classes; + if (n_samples == 0) { + return; + } + Transpose(predts, d_predts_t, n_samples, n_classes); + dh::LaunchN(n_classes + 1, + [=] XGBOOST_DEVICE(size_t i) { d_class_ptr[i] = i * n_samples; }); + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::SegmentedArgSort(d_predts_t, d_class_ptr, d_sorted_idx); +} + +float GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *p_cache, + size_t n_classes) { + auto& cache = *p_cache; + InitCacheOnce(predts, device, p_cache); + + /** + * Create sorted index for each class + */ + dh::TemporaryArray class_ptr(n_classes + 1, 0); + MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache); + + auto fn = [] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, float tp, + size_t /*class_id*/) { + return TrapezoidArea(fp_prev, fp, tp_prev, tp); + }; + return GPUMultiClassAUCOVR(predts, info, device, dh::ToSpan(class_ptr), + n_classes, cache, fn); +} + 
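[Editor's note: the device code above expresses one algorithm in several kernels: a sort by prediction, a scan over weighted FP/TP, and one area contribution per distinct prediction value. The sequential equivalent below is a host-side reading aid only, not the patch's code path; only `TrapezoidArea` mirrors the helper in `src/metric/auc.h`, everything else is named ad hoc.]

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

float TrapezoidArea(float x0, float x1, float y0, float y1) {
  return std::abs(x0 - x1) * (y0 + y1) * 0.5f;
}

// Unweighted binary ROC-AUC by linear scan over predictions sorted descending.
float BinaryROCAUCSketch(std::vector<float> const &predts,
                         std::vector<float> const &labels) {
  std::vector<std::size_t> idx(predts.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](std::size_t l, std::size_t r) { return predts[l] > predts[r]; });

  float fp = 0, tp = 0, fp_prev = 0, tp_prev = 0, auc = 0;
  for (std::size_t i = 0; i < idx.size(); ++i) {
    // Emit a trapezoid only when the prediction value changes, so tied
    // predictions collapse into a single point on the curve -- the same
    // reason the GPU paths deduplicate prediction values first.
    if (i != 0 && predts[idx[i]] != predts[idx[i - 1]]) {
      auc += TrapezoidArea(fp_prev, fp, tp_prev, tp);
      fp_prev = fp;
      tp_prev = tp;
    }
    fp += 1.0f - labels[idx[i]];
    tp += labels[idx[i]];
  }
  auc += TrapezoidArea(fp_prev, fp, tp_prev, tp);
  // Normalize by total_neg * total_pos; undefined when one class is absent.
  return (fp > 0 && tp > 0) ? auc / (fp * tp) : NAN;
}

int main() {
  // Perfectly separated predictions give an AUC of 1.
  std::cout << BinaryROCAUCSketch({0.9f, 0.1f, 0.8f, 0.3f}, {1, 0, 1, 0}) << "\n";
}
```

The PR-AUC variants later in this file follow the same scan; they only swap `TrapezoidArea` for the interpolated `detail::CalcDeltaPRAUC` introduced in `src/metric/auc.h`.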
namespace { struct RankScanItem { size_t idx; @@ -435,10 +504,7 @@ std::pair GPURankingAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { auto& cache = *p_cache; - if (!cache) { - cache.reset(new DeviceAUCCache); - } - cache->Init(predts, false, device); + InitCacheOnce(predts, device, p_cache); dh::caching_device_vector group_ptr(info.group_ptr_); dh::XGBCachingDeviceAllocator alloc; @@ -449,10 +515,10 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, */ auto check_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0), - [=] __device__(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; }); + [=] XGBOOST_DEVICE(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; }); size_t n_valid = thrust::count_if( thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1, - [=] __device__(size_t len) { return len >= 3; }); + [=] XGBOOST_DEVICE(size_t len) { return len >= 3; }); if (n_valid < info.group_ptr_.size() - 1) { InvalidGroupAUC(); } @@ -476,7 +542,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, auto n_threads = common::SegmentedTrapezoidThreads( d_group_ptr, d_threads_group_ptr, std::numeric_limits::max()); // get the coordinate in nested summation - auto get_i_j = [=]__device__(size_t idx, size_t query_group_idx) { + auto get_i_j = [=]XGBOOST_DEVICE(size_t idx, size_t query_group_idx) { auto data_group_begin = d_group_ptr[query_group_idx]; size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin; auto thread_group_begin = d_threads_group_ptr[query_group_idx]; @@ -491,7 +557,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, return thrust::make_pair(i, j); }; // NOLINT auto in = dh::MakeTransformIterator( - thrust::make_counting_iterator(0), [=] __device__(size_t idx) { + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t idx) { bst_group_t query_group_idx = dh::SegmentId(d_threads_group_ptr, idx); auto data_group_begin = d_group_ptr[query_group_idx]; size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin; @@ -519,7 +585,8 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, dh::TemporaryArray d_auc(group_ptr.size() - 1); auto s_d_auc = dh::ToSpan(d_auc); auto out = thrust::make_transform_output_iterator( - dh::TypedDiscard{}, [=] __device__(RankScanItem const &item) -> RankScanItem { + dh::TypedDiscard{}, + [=] XGBOOST_DEVICE(RankScanItem const &item) -> RankScanItem { auto group_id = item.group_id; assert(group_id < d_group_ptr.size()); auto data_group_begin = d_group_ptr[group_id]; @@ -536,7 +603,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, }); dh::InclusiveScan( in, out, - [] __device__(RankScanItem const &l, RankScanItem const &r) { + [] XGBOOST_DEVICE(RankScanItem const &l, RankScanItem const &r) { if (l.group_id != r.group_id) { return r; } @@ -551,5 +618,288 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, dh::tend(s_d_auc), 0.0f); return std::make_pair(auc, n_valid); } + +std::tuple +GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + auto& cache = *p_cache; + InitCacheOnce(predts, device, p_cache); + + /** + * Create sorted index for each class + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::ArgSort(predts, d_sorted_idx); + + auto labels = info.labels_.ConstDeviceSpan(); + auto d_weights = info.weights_.ConstDeviceSpan(); + auto get_weight = OptionalWeights{d_weights}; + auto it = dh::MakeTransformIterator>( + 
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + auto w = get_weight[d_sorted_idx[i]]; + return thrust::make_pair(labels[d_sorted_idx[i]] * w, + (1.0f - labels[d_sorted_idx[i]]) * w); + }); + dh::XGBCachingDeviceAllocator alloc; + float total_pos, total_neg; + thrust::tie(total_pos, total_neg) = + thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(), + Pair{0.0f, 0.0f}, PairPlus{}); + + if (total_pos <= 0.0 || total_neg <= 0.0) { + return {0.0f, 0.0f, 0.0f}; + } + + auto fn = [total_pos] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, + float tp) { + return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); + }; + float fp, tp, auc; + std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache); + return std::make_tuple(1.0, 1.0, auc); +} + +float GPUMultiClassPRAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *p_cache, + size_t n_classes) { + auto& cache = *p_cache; + InitCacheOnce(predts, device, p_cache); + + /** + * Create sorted index for each class + */ + dh::TemporaryArray class_ptr(n_classes + 1, 0); + auto d_class_ptr = dh::ToSpan(class_ptr); + MultiClassSortedIdx(predts, d_class_ptr, cache); + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + + auto d_weights = info.weights_.ConstDeviceSpan(); + + /** + * Get total positive/negative + */ + auto labels = info.labels_.ConstDeviceSpan(); + auto n_samples = info.num_row_; + dh::caching_device_vector> totals(n_classes); + auto key_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [n_samples] XGBOOST_DEVICE(size_t i) { + return i / n_samples; // class id + }); + auto get_weight = OptionalWeights{d_weights}; + auto val_it = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + auto idx = d_sorted_idx[i] % n_samples; + auto w = get_weight[idx]; + auto class_id = i / n_samples; + auto y = labels[idx] == class_id; + return thrust::make_pair(y * w, (1.0f - y) * w); + }); + dh::XGBCachingDeviceAllocator alloc; + thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, + key_it + predts.size(), val_it, + thrust::make_discard_iterator(), totals.begin(), + thrust::equal_to{}, PairPlus{}); + + /** + * Calculate AUC + */ + auto d_totals = dh::ToSpan(totals); + auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, + float tp, size_t class_id) { + auto total_pos = d_totals[class_id].first; + return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, + d_totals[class_id].first); + }; + return GPUMultiClassAUCOVR(predts, info, device, d_class_ptr, + n_classes, cache, fn); +} + +template +std::pair +GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, + common::Span d_group_ptr, int32_t device, + std::shared_ptr cache, Fn area_fn) { + /** + * Sorted idx + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + + auto labels = info.labels_.ConstDeviceSpan(); + auto weights = info.weights_.ConstDeviceSpan(); + + uint32_t n_groups = static_cast(info.group_ptr_.size() - 1); + + /** + * Linear scan + */ + size_t n_samples = labels.size(); + dh::caching_device_vector d_auc(n_groups, 0); + auto get_weight = OptionalWeights{weights}; + auto d_fptp = dh::ToSpan(cache->fptp); + auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) { + size_t idx = d_sorted_idx[i]; + + size_t group_id = dh::SegmentId(d_group_ptr, idx); + float label = labels[idx]; + + float w = get_weight[group_id]; + float fp = (1.0 - label) * w; + float tp = label * w; + return 
thrust::make_pair(fp, tp); + }; // NOLINT + dh::LaunchN(d_sorted_idx.size(), + [=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); }); + + /** + * Handle duplicated predictions + */ + dh::XGBDeviceAllocator alloc; + auto d_unique_idx = dh::ToSpan(cache->unique_idx); + dh::Iota(d_unique_idx); + auto uni_key = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { + auto idx = d_sorted_idx[i]; + bst_group_t group_id = dh::SegmentId(d_group_ptr, idx); + float predt = predts[idx]; + return thrust::make_pair(group_id, predt); + }); + + // unique values are sparse, so we need a CSR style indptr + dh::TemporaryArray unique_class_ptr(d_group_ptr.size()); + auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr); + auto n_uniques = dh::SegmentedUniqueByKey( + thrust::cuda::par(alloc), + dh::tbegin(d_group_ptr), + dh::tend(d_group_ptr), + uni_key, + uni_key + d_sorted_idx.size(), + dh::tbegin(d_unique_idx), + d_unique_class_ptr.data(), + dh::tbegin(d_unique_idx), + thrust::equal_to>{}); + d_unique_idx = d_unique_idx.subspan(0, n_uniques); + + auto get_group_id = [=] XGBOOST_DEVICE(size_t idx) { + return dh::SegmentId(d_group_ptr, idx); + }; + SegmentedFPTP(d_fptp, get_group_id); + + // scatter unique FP_PREV/TP_PREV values + auto d_neg_pos = dh::ToSpan(cache->neg_pos); + dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) { + if (thrust::binary_search(thrust::seq, d_unique_class_ptr.cbegin(), + d_unique_class_ptr.cend(), + i)) { // first unique index is 0 + d_neg_pos[d_unique_idx[i]] = {0, 0}; + return; + } + auto group_idx = dh::SegmentId(d_group_ptr, d_unique_idx[i]); + d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; + if (i == LastOf(group_idx, d_unique_class_ptr)) { + // last one needs to be included. + size_t last = d_unique_idx[LastOf(group_idx, d_unique_class_ptr)]; + d_neg_pos[LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1]; + return; + } + }); + + /** + * Reduce the result for each group + */ + auto s_d_auc = dh::ToSpan(d_auc); + SegmentedReduceAUC(d_unique_idx, d_group_ptr, d_unique_class_ptr, cache, + area_fn, get_group_id, s_d_auc); + + /** + * Scale the groups with number of samples for each group. 
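+ * A group contributes nothing and is counted as invalid when its curve has
+ * zero area (only positive or only negative documents) or it has fewer than
+ * two documents.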
+ */ + float auc; + uint32_t invalid_groups; + { + auto it = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) { + float fp, tp; + thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; + float area = fp * tp; + auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g]; + if (area > 0 && n_documents >= 2) { + return thrust::make_pair(s_d_auc[g], static_cast(0)); + } + return thrust::make_pair(0.0f, static_cast(1)); + }); + thrust::tie(auc, invalid_groups) = thrust::reduce( + thrust::cuda::par(alloc), it, it + n_groups, + thrust::pair(0.0f, 0), PairPlus{}); + } + return std::make_pair(auc, n_groups - invalid_groups); +} + +std::pair +GPURankingPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache) { + dh::safe_cuda(cudaSetDevice(device)); + if (predts.empty()) { + return std::make_pair(0.0f, static_cast(0)); + } + + auto &cache = *p_cache; + InitCacheOnce(predts, device, p_cache); + + dh::device_vector group_ptr(info.group_ptr_.size()); + thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); + auto d_group_ptr = dh::ToSpan(group_ptr); + CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR."; + size_t n_groups = info.group_ptr_.size() - 1; + + /** + * Create sorted index for each group + */ + auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); + dh::SegmentedArgSort(predts, d_group_ptr, d_sorted_idx); + + dh::XGBDeviceAllocator alloc; + auto labels = info.labels_.ConstDeviceSpan(); + if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels), + dh::tend(labels), PRAUCLabelInvalid{})) { + InvalidLabels(); + } + /** + * Get total positive/negative for each group. + */ + auto d_weights = info.weights_.ConstDeviceSpan(); + dh::caching_device_vector> totals(n_groups); + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_group_ptr, i); }); + auto val_it = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + float w = 1.0f; + if (!d_weights.empty()) { + // Avoid a binary search if the groups are not weighted. 
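+          // LTR weights are per query group rather than per document,
+          // hence the segment lookup to map a document to its group.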
+ auto g = dh::SegmentId(d_group_ptr, i); + w = d_weights[g]; + } + auto y = labels[i]; + return thrust::make_pair(y * w, (1.0f - y) * w); + }); + thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, + key_it + predts.size(), val_it, + thrust::make_discard_iterator(), totals.begin(), + thrust::equal_to{}, PairPlus{}); + + /** + * Calculate AUC + */ + auto d_totals = dh::ToSpan(totals); + auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, + float tp, size_t group_id) { + auto total_pos = d_totals[group_id].first; + return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, + d_totals[group_id].first); + }; + return GPURankingPRAUCImpl(predts, info, d_group_ptr, n_groups, cache, fn); +} } // namespace metric } // namespace xgboost diff --git a/src/metric/auc.h b/src/metric/auc.h index d549ac426fbc..db399a060a09 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -3,7 +3,9 @@ */ #ifndef XGBOOST_METRIC_AUC_H_ #define XGBOOST_METRIC_AUC_H_ +#include #include +#include #include #include #include @@ -12,32 +14,115 @@ #include "xgboost/base.h" #include "xgboost/span.h" #include "xgboost/data.h" +#include "xgboost/metric.h" +#include "../common/common.h" +#include "../common/threading_utils.h" namespace xgboost { namespace metric { -XGBOOST_DEVICE inline float TrapesoidArea(float x0, float x1, float y0, float y1) { +/*********** + * ROC AUC * + ***********/ +XGBOOST_DEVICE inline float TrapezoidArea(float x0, float x1, float y0, float y1) { return std::abs(x0 - x1) * (y0 + y1) * 0.5f; } struct DeviceAUCCache; std::tuple -GPUBinaryAUC(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr *p_cache); +GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache); -float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr* cache, +float GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, size_t n_classes); std::pair GPURankingAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *cache); +/********** + * PR AUC * + **********/ +std::tuple +GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *p_cache); + +float GPUMultiClassPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *cache, + size_t n_classes); + +std::pair +GPURankingPRAUC(common::Span predts, MetaInfo const &info, + int32_t device, std::shared_ptr *cache); + +namespace detail { +XGBOOST_DEVICE inline float CalcH(float fp_a, float fp_b, float tp_a, + float tp_b) { + return (fp_b - fp_a) / (tp_b - tp_a); +} + +XGBOOST_DEVICE inline float CalcB(float fp_a, float h, float tp_a, float total_pos) { + return (fp_a - h * tp_a) / total_pos; +} + +XGBOOST_DEVICE inline float CalcA(float h) { return h + 1; } + +XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp, + float tp_prev, float tp, + float total_pos) { + float pr_prev = tp_prev / total_pos; + float pr = tp / total_pos; + + float h{0}, a{0}, b{0}; + + if (tp == tp_prev) { + a = 1.0; + b = 0.0; + } else { + h = detail::CalcH(fp_prev, fp, tp_prev, tp); + a = detail::CalcA(h); + b = detail::CalcB(fp_prev, h, tp_prev, total_pos); + } + + float area = 0; + if (b != 0.0) { + area = (pr - pr_prev - + b / a * (std::log(a * pr + b) - std::log(a * pr_prev + b))) / + a; + } else { + area = (pr - pr_prev) / a; + } + return area; +} +} // namespace detail + inline void InvalidGroupAUC() { LOG(INFO) << "Invalid 
group with less than 3 samples is found on worker " << rabit::GetRank() << ". Calculating AUC value requires at " << "least 2 pairs of samples."; } + +struct PRAUCLabelInvalid { + XGBOOST_DEVICE bool operator()(float y) { return y < 0.0f || y > 1.0f; } +}; + +inline void InvalidLabels() { + LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank."; +} + +struct OptionalWeights { + common::Span weights; + float dft { 1.0f }; + + explicit OptionalWeights(common::Span w) : weights{w} {} + explicit OptionalWeights(float w) : dft{w} {} + + XGBOOST_DEVICE float operator[](size_t i) const { + return weights.empty() ? dft : weights[i]; + } +}; } // namespace metric } // namespace xgboost #endif // XGBOOST_METRIC_AUC_H_ diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 193938c0f8e6..6d8b5d3da124 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -388,166 +388,10 @@ struct EvalCox : public Metric { } }; -/*! \brief Area Under PR Curve, for both classification and rank computed on CPU */ -struct EvalAucPR : public Metric { - // implementation of AUC-PR for weighted data - // translated from PRROC R Package - // see https://doi.org/10.1371/journal.pone.0092209 - private: - // This is used to compute the AUCPR metrics on the GPU - for ranking tasks and - // for training jobs that run on the GPU. - std::unique_ptr aucpr_gpu_; - - template - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed, - const std::vector &gptr) { - const auto ngroups = static_cast(gptr.size() - 1); - - // sum of all AUC's across all query groups - double sum_auc = 0.0; - int auc_error = 0; - - const auto &h_labels = info.labels_.ConstHostVector(); - const auto &h_preds = preds.ConstHostVector(); - - dmlc::OMPException exc; - #pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1) - { - exc.Run([&]() { - // Each thread works on a distinct group and sorts the predictions in that group - PredIndPairContainer rec; - #pragma omp for schedule(static) - for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) { - exc.Run([&]() { - double total_pos = 0.0; - double total_neg = 0.0; - // Same thread can work on multiple groups one after another; hence, resize - // the predictions array based on the current group - rec.resize(gptr[group_id + 1] - gptr[group_id]); - #pragma omp parallel for schedule(static) reduction(+:total_pos, total_neg) \ - if (!omp_in_parallel()) // NOLINT - for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { - exc.Run([&]() { - const bst_float wt = WeightPolicy::GetWeightOfInstance(info, j, group_id); - total_pos += wt * h_labels[j]; - total_neg += wt * (1.0f - h_labels[j]); - rec[j - gptr[group_id]] = {h_preds[j], j}; - }); - } - - // we need pos > 0 && neg > 0 - if (total_pos <= 0.0 || total_neg <= 0.0) { - auc_error += 1; - return; - } - - XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); - - // calculate AUC - double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; - for (size_t j = 0; j < rec.size(); ++j) { - const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); - tp += wt * h_labels[rec[j].second]; - fp += wt * (1.0f - h_labels[rec[j].second]); - if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || - j == rec.size() - 1) { - if (tp == prevtp) { - a = 1.0; - b = 0.0; - } else { - h = (fp - prevfp) / (tp - prevtp); - a = 1.0 + h; - b = (prevfp - h * prevtp) / total_pos; - } - if (0.0 != b) { - 
sum_auc += (tp / total_pos - prevtp / total_pos - - b / a * (std::log(a * tp / total_pos + b) - - std::log(a * prevtp / total_pos + b))) / a; - } else { - sum_auc += (tp / total_pos - prevtp / total_pos) / a; - } - prevtp = tp; - prevfp = fp; - } - } - // sanity check - if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) { - CHECK(!auc_error) << "AUC-PR: error in calculation"; - } - }); - } - }); - } - exc.Rethrow(); - - // Report average AUC-PR across all groups - // In distributed mode, workers which only contains pos or neg samples - // will be ignored when aggregate AUC-PR. - bst_float dat[2] = {0.0f, 0.0f}; - if (auc_error < static_cast(ngroups)) { - dat[0] = static_cast(sum_auc); - dat[1] = static_cast(static_cast(ngroups) - auc_error); - } - if (distributed) { - rabit::Allreduce(dat, 2); - } - CHECK_GT(dat[1], 0.0f) - << "AUC-PR: the dataset only contains pos or neg samples"; - CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0"; - return dat[0] / dat[1]; - } - - public: - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "label size predict size not match"; - std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.Size()); - - const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.Size()) - << "EvalAucPR: group structure must match number of prediction"; - - // For ranking task, weights are per-group - // For binary classification task, weights are per-instance - const bool is_ranking_task = - !info.group_ptr_.empty() && info.weights_.Size() != info.num_row_; - - // Check if we have a GPU assignment; else, revert back to CPU - if (tparam_->gpu_id >= 0 && is_ranking_task) { - if (!aucpr_gpu_) { - // Check and see if we have the GPU metric registered in the internal registry - aucpr_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_)); - } - - if (aucpr_gpu_) { - return aucpr_gpu_->Eval(preds, info, distributed); - } - } - - if (is_ranking_task) { - return Eval(preds, info, distributed, gptr); - } else { - return Eval(preds, info, distributed, gptr); - } - } - - const char *Name() const override { return "aucpr"; } -}; - XGBOOST_REGISTER_METRIC(AMS, "ams") .describe("AMS metric for higgs.") .set_body([](const char* param) { return new EvalAMS(param); }); -XGBOOST_REGISTER_METRIC(AucPR, "aucpr") -.describe("Area under PR curve for both classification and rank.") -.set_body([](const char*) { return new EvalAucPR(); }); - XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") .set_body([](const char* param) { return new EvalPrecision("pre", param); }); diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 834ca3078e0b..57757b22badd 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -274,196 +274,6 @@ struct EvalMAPGpu { } }; -/*! 
diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu
index 834ca3078e0b..57757b22badd 100644
--- a/src/metric/rank_metric.cu
+++ b/src/metric/rank_metric.cu
@@ -274,196 +274,6 @@ struct EvalMAPGpu {
   }
 };
 
-/*! \brief Area Under PR Curve metric computation for ranking datasets */
-struct EvalAucPRGpu : public Metric {
- public:
-  // This function object computes the item's positive/negative precision value
-  class ComputeItemPrecision : public thrust::unary_function<uint32_t, float> {
-   public:
-    // The precision type to be computed
-    enum class PrecisionType {
-      kPositive,
-      kNegative
-    };
-
-    XGBOOST_DEVICE ComputeItemPrecision(PrecisionType ptype,
-                                        uint32_t ngroups,
-                                        const float *dweights,
-                                        const xgboost::common::Span<uint32_t> &dgidxs,
-                                        const float *dlabels)
-      : ptype_(ptype), ngroups_(ngroups), dweights_(dweights), dgidxs_(dgidxs), dlabels_(dlabels) {}
-
-    // Compute precision value for the prediction that was originally at 'idx'
-    __device__ __forceinline__ float operator()(uint32_t idx) const {
-      // For ranking task, weights are per-group
-      // For binary classification task, weights are per-instance
-      const auto wt = dweights_ == nullptr ? 1.0f : dweights_[ngroups_ == 1 ? idx : dgidxs_[idx]];
-      return wt * (ptype_ == PrecisionType::kPositive ? dlabels_[idx] : (1.0f - dlabels_[idx]));
-    }
-
-   private:
-    PrecisionType ptype_;  // Precision type to be computed
-    uint32_t ngroups_;  // Number of groups in the dataset
-    const float *dweights_;  // Instance/group weights
-    const xgboost::common::Span<uint32_t> dgidxs_;  // The group a given instance belongs to
-    const float *dlabels_;  // Unsorted labels in the dataset
-  };
-
-  bst_float Eval(const HostDeviceVector<bst_float> &preds,
-                 const MetaInfo &info,
-                 bool distributed) override {
-    // Sanity check is done by the caller
-    std::vector<unsigned> tgptr(2, 0);
-    tgptr[1] = static_cast<unsigned>(info.labels_.Size());
-    const std::vector<unsigned> &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
-
-    auto device = tparam_->gpu_id;
-    dh::safe_cuda(cudaSetDevice(device));
-
-    info.labels_.SetDevice(device);
-    preds.SetDevice(device);
-    info.weights_.SetDevice(device);
-
-    auto dpreds = preds.ConstDevicePointer();
-    auto dlabels = info.labels_.ConstDevicePointer();
-    auto dweights = info.weights_.ConstDevicePointer();
-
-    // Sort all the predictions
-    dh::SegmentSorter<float> segment_pred_sorter;
-    segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
-
-    const auto &dsorted_preds = segment_pred_sorter.GetItemsSpan();
-    // Original positions of the predictions after they have been sorted
-    const auto &dpreds_orig_pos = segment_pred_sorter.GetOriginalPositionsSpan();
-
-    // Group info on device
-    const auto &dgroups = segment_pred_sorter.GetGroupsSpan();
-    uint32_t ngroups = segment_pred_sorter.GetNumGroups();
-    const auto &dgroup_idx = segment_pred_sorter.GetGroupSegmentsSpan();
-
-    // First, aggregate the positive and negative precision for each group
-    dh::caching_device_vector<double> total_pos(ngroups, 0);
-    dh::caching_device_vector<double> total_neg(ngroups, 0);
-
-    // Allocator to be used for managing space overhead while performing transformed reductions
-    dh::XGBCachingDeviceAllocator<char> alloc;
-
-    // Compute each elements positive precision value and reduce them across groups concurrently.
-    ComputeItemPrecision pos_prec_functor(ComputeItemPrecision::PrecisionType::kPositive,
-                                          ngroups, dweights, dgroup_idx, dlabels);
-    auto end_range =
-      thrust::reduce_by_key(thrust::cuda::par(alloc),
-                            dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
-                            thrust::make_transform_iterator(
-                              // The indices need not be sequential within a group, as we care only
-                              // about the sum of positive precision values within a group
-                              dh::tcbegin(segment_pred_sorter.GetOriginalPositionsSpan()),
-                              pos_prec_functor),
-                            thrust::make_discard_iterator(),  // We don't care for the group indices
-                            total_pos.begin());  // Sum of positive precision values in the group
-    CHECK(end_range.second - total_pos.begin() == total_pos.size());
-
-    // Compute each elements negative precision value and reduce them across groups concurrently.
-    ComputeItemPrecision neg_prec_functor(ComputeItemPrecision::PrecisionType::kNegative,
-                                          ngroups, dweights, dgroup_idx, dlabels);
-    end_range =
-      thrust::reduce_by_key(thrust::cuda::par(alloc),
-                            dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
-                            thrust::make_transform_iterator(
-                              // The indices need not be sequential within a group, as we care only
-                              // about the sum of negative precision values within a group
-                              dh::tcbegin(segment_pred_sorter.GetOriginalPositionsSpan()),
-                              neg_prec_functor),
-                            thrust::make_discard_iterator(),  // We don't care for the group indices
-                            total_neg.begin());  // Sum of negative precision values in the group
-    CHECK(end_range.second - total_neg.begin() == total_neg.size());
-
-    const auto *dtotal_pos = total_pos.data().get();
-    const auto *dtotal_neg = total_neg.data().get();
-
-    // AUC sum for each group
-    dh::caching_device_vector<double> sum_auc(ngroups, 0);
-    // AUC error across all groups
-    dh::caching_device_vector<int> auc_error(1, 0);
-    auto *dsum_auc = sum_auc.data().get();
-    auto *dauc_error = auc_error.data().get();
-
-    int device_id = -1;
-    dh::safe_cuda(cudaGetDevice(&device_id));
-    // For each group item compute the aggregated precision
-    dh::LaunchN<1, 32>(ngroups, nullptr, [=] __device__(uint32_t gidx) {
-      // We need pos > 0 && neg > 0
-      if (dtotal_pos[gidx] <= 0.0 || dtotal_neg[gidx] <= 0.0) {
-        atomicAdd(dauc_error, 1);
-      } else {
-        auto gbegin = dgroups[gidx];
-        auto gend = dgroups[gidx + 1];
-        // Calculate AUC
-        double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
-        for (auto i = gbegin; i < gend; ++i) {
-          const auto wt = dweights == nullptr
-                              ? 1.0f
-                              : dweights[ngroups == 1 ? dpreds_orig_pos[i] : gidx];
-          tp += wt * dlabels[dpreds_orig_pos[i]];
-          fp += wt * (1.0f - dlabels[dpreds_orig_pos[i]]);
-          if ((i < gend - 1 && dsorted_preds[i] != dsorted_preds[i + 1]) || (i == gend - 1)) {
-            if (tp == prevtp) {
-              a = 1.0;
-              b = 0.0;
-            } else {
-              h = (fp - prevfp) / (tp - prevtp);
-              a = 1.0 + h;
-              b = (prevfp - h * prevtp) / dtotal_pos[gidx];
-            }
-            if (0.0 != b) {
-              dsum_auc[gidx] += (tp / dtotal_pos[gidx] - prevtp / dtotal_pos[gidx] -
-                                 b / a * (std::log(a * tp / dtotal_pos[gidx] + b) -
-                                          std::log(a * prevtp / dtotal_pos[gidx] + b))) / a;
-            } else {
-              dsum_auc[gidx] += (tp / dtotal_pos[gidx] - prevtp / dtotal_pos[gidx]) / a;
-            }
-            prevtp = tp;
-            prevfp = fp;
-          }
-        }
-
-        // Sanity check
-        if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
-          // Check if we have any metric error thus far
-          auto current_auc_error = atomicAdd(dauc_error, 0);
-          KERNEL_CHECK(!current_auc_error);
-        }
-      }
-    });
-
-    const auto hsum_auc = thrust::reduce(thrust::cuda::par(alloc), sum_auc.begin(), sum_auc.end());
-    const auto hauc_error = auc_error.back();  // Copy it back to host
-
-    // Report average AUC-PR across all groups
-    // In distributed mode, workers which only contains pos or neg samples
-    // will be ignored when aggregate AUC-PR.
-    bst_float dat[2] = {0.0f, 0.0f};
-    if (hauc_error < static_cast<int>(ngroups)) {
-      dat[0] = static_cast<bst_float>(hsum_auc);
-      dat[1] = static_cast<bst_float>(static_cast<int>(ngroups) - hauc_error);
-    }
-    if (distributed) {
-      rabit::Allreduce<rabit::op::Sum>(dat, 2);
-    }
-    CHECK_GT(dat[1], 0.0f)
-        << "AUC-PR: the dataset only contains pos or neg samples";
-    CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
-    return dat[0] / dat[1];
-  }
-
-  const char* Name() const override {
-    return "aucpr";
-  }
-};
-
-XGBOOST_REGISTER_GPU_METRIC(AucPRGpu, "aucpr")
-.describe("Area under PR curve for rank computed on GPU.")
-.set_body([](const char* param) { return new EvalAucPRGpu(); });
-
 XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
 .describe("precision@k for rank computed on GPU.")
 .set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
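With both device-specific implementations gone, the CPU and GPU evaluations are expected to agree. A quick parity check, mirroring the pattern used in tests/python-gpu/test_gpu_eval_metrics.py; this is an illustrative sketch, not part of the patch, and the dataset and parameters are arbitrary:

    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import make_classification

    X, y = make_classification(256, 8, random_state=2021)
    Xy = xgb.DMatrix(X, y)
    params = {"objective": "binary:logistic", "eval_metric": "aucpr"}

    cpu = xgb.train({**params, "tree_method": "hist"}, Xy, num_boost_round=4)
    gpu = xgb.train({**params, "tree_method": "gpu_hist"}, Xy, num_boost_round=4)

    cpu_aucpr = float(cpu.eval(Xy).split(":")[1])
    gpu_aucpr = float(gpu.eval(Xy).split(":")[1])
    # hist and gpu_hist may grow slightly different trees, so keep the
    # tolerance loose.
    np.testing.assert_allclose(cpu_aucpr, gpu_aucpr, rtol=1e-3)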
diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc
index 76535ad94f3b..a8ae7eeafa8d 100644
--- a/tests/cpp/metric/test_auc.cc
+++ b/tests/cpp/metric/test_auc.cc
@@ -48,7 +48,7 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
               0.5, 1e-10);
 }
 
-TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
+TEST(Metric, DeclareUnifiedTest(MultiClassAUC)) {
   auto tparam = CreateEmptyGenericParam(GPUIDX);
   std::unique_ptr<Metric> uni_ptr{
       Metric::Create("auc", &tparam)};
@@ -64,6 +64,17 @@
                             },
                             {0, 1, 2}),
               1.0f, 1e-10);
+
+  EXPECT_NEAR(GetMetricEval(metric,
+                            {
+                                1.0f, 0.0f, 0.0f,  // p_0
+                                0.0f, 1.0f, 0.0f,  // p_1
+                                0.0f, 0.0f, 1.0f   // p_2
+                            },
+                            {0, 1, 2},
+                            {1.0f, 1.0f, 1.0f}),
+              1.0f, 1e-10);
+
   EXPECT_NEAR(GetMetricEval(metric,
                             {
                                 1.0f, 0.0f, 0.0f,  // p_0
@@ -72,6 +83,7 @@
                             },
                             {2, 1, 0}),
               0.5f, 1e-10);
+
   EXPECT_NEAR(GetMetricEval(metric,
                             {
                                 1.0f, 0.0f, 0.0f,  // p_0
@@ -139,5 +151,110 @@ TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
                             /*weights=*/{}, groups),
               0.769841f, 1e-6);
 }
+
+TEST(Metric, DeclareUnifiedTest(PRAUC)) {
+  auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
+
+  xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);
+  ASSERT_STREQ(metric->Name(), "aucpr");
+  EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
+              0.5f, 0.001f);
+  EXPECT_NEAR(GetMetricEval(
+                  metric,
+                  {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
+                  {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
+              0.2908445f, 0.001f);
+  EXPECT_NEAR(GetMetricEval(
+                  metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
+                           0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
+                           0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
+                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
+              0.2769199f, 0.001f);
+  auto auc = GetMetricEval(metric, {0, 1}, {});
+  ASSERT_TRUE(std::isnan(auc));
+
+  // AUCPR with instance weights
+  EXPECT_NEAR(GetMetricEval(metric,
+                            {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
+                             0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f},
+                            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0},
+                            {1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f,
+                             4.5f}),  // weights
+              0.694435f, 0.001f);
+
+  // Both groups contain only pos or neg samples.
+  auc = GetMetricEval(metric,
+                      {0, 0.1f, 0.3f, 0.5f, 0.7f},
+                      {1, 1, 0, 0, 0},
+                      {},
+                      {0, 2, 5});
+  ASSERT_TRUE(std::isnan(auc));
+  delete metric;
+}
+
+TEST(Metric, DeclareUnifiedTest(MultiClassPRAUC)) {
+  auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
+
+  std::unique_ptr<Metric> metric{Metric::Create("aucpr", &tparam)};
+
+  float auc = 0;
+  std::vector<float> labels {1.0f, 0.0f, 2.0f};
+  HostDeviceVector<float> predts{
+      0.0f, 1.0f, 0.0f,
+      1.0f, 0.0f, 0.0f,
+      0.0f, 0.0f, 1.0f,
+  };
+  auc = GetMetricEval(metric.get(), predts, labels, {});
+  EXPECT_EQ(auc, 1.0f);
+
+  auc = GetMetricEval(metric.get(), predts, labels, {1.0f, 1.0f, 1.0f});
+  EXPECT_EQ(auc, 1.0f);
+
+  predts.HostVector() = {
+      0.0f, 1.0f, 0.0f,
+      1.0f, 0.0f, 0.0f,
+      0.0f, 0.0f, 1.0f,
+      0.0f, 0.0f, 1.0f,
+  };
+  labels = {1.0f, 0.0f, 2.0f, 1.0f};
+  auc = GetMetricEval(metric.get(), predts, labels, {1.0f, 2.0f, 3.0f, 4.0f});
+  ASSERT_GT(auc, 0.699);
+}
+
+TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) {
+  auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
+
+  std::unique_ptr<Metric> metric{Metric::Create("aucpr", &tparam)};
+
+  std::vector<float> labels {1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f};
+  std::vector<uint32_t> groups {0, 2, 6};
+
+  float auc = 0;
+  auc = GetMetricEval(metric.get(), {1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f}, labels, {}, groups);
+  EXPECT_EQ(auc, 1.0f);
+
+  auc = GetMetricEval(metric.get(), {1.0f, 0.5f, 0.8f, 0.3f, 0.2f, 1.0f}, labels, {}, groups);
+  EXPECT_EQ(auc, 1.0f);
+
+  auc = GetMetricEval(metric.get(), {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+                      {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {}, groups);
+  ASSERT_TRUE(std::isnan(auc));
+
+  // Incorrect label
+  ASSERT_THROW(GetMetricEval(metric.get(), {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+                             {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 3.0f}, {}, groups),
+               dmlc::Error);
+
+  // AUCPR with groups and no weights
+  EXPECT_NEAR(GetMetricEval(
+                  metric.get(), {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
+                                 0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
+                                 0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
+                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1},
+                  {},                     // weights
+                  {0, 2, 5, 9, 14, 20}),  // group info
+              0.556021f, 0.001f);
+}
 }  // namespace metric
 }  // namespace xgboost
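The `std::isnan` assertions above have a user-visible counterpart: with a single-class evaluation set, the metric reports NaN instead of raising. A small sketch of that behaviour, assuming synthetic data; it is illustrative and not part of the patch:

    import numpy as np
    import xgboost as xgb

    # Degenerate input: only positive samples, so PR-AUC is undefined.
    X = np.random.default_rng(0).normal(size=(8, 2))
    y = np.ones(8)
    Xy = xgb.DMatrix(X, y)
    booster = xgb.train(
        {"objective": "binary:logistic", "eval_metric": "aucpr"}, Xy,
        num_boost_round=1,
    )
    result = booster.eval(Xy)  # e.g. "[0]\teval-aucpr:nan"
    assert np.isnan(float(result.split(":")[1]))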
diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc
index c8c97bef1930..e7eef166ddab 100644
--- a/tests/cpp/metric/test_rank_metric.cc
+++ b/tests/cpp/metric/test_rank_metric.cc
@@ -24,66 +24,6 @@ TEST(Metric, AMS) {
 }
 #endif
 
-TEST(Metric, DeclareUnifiedTest(AUCPR)) {
-  auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
-  xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);
-  ASSERT_STREQ(metric->Name(), "aucpr");
-  EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
-  EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
-              0.5f, 0.001f);
-  EXPECT_NEAR(
-      GetMetricEval(metric,
-                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
-                    {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
-      0.2908445f, 0.001f);
-  EXPECT_NEAR(GetMetricEval(
-                  metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
-                           0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
-                           0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
-                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
-              0.2769199f, 0.001f);
-  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
-  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
-  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {1, 1}));
-
-  // AUCPR with instance weights
-  EXPECT_NEAR(GetMetricEval(
-                  metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
-                           0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f},
-                  {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0},
-                  {1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f}),  // weights
-              0.694435f, 0.001f);
-
-  // AUCPR with groups and no weights
-  EXPECT_NEAR(GetMetricEval(
-                  metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
-                           0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
-                           0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
-                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1},
-                  {},  // weights
-                  {0, 2, 5, 9, 14, 20}),  // group info
-              0.556021f, 0.001f);
-
-  // AUCPR with groups and weights
-  EXPECT_NEAR(GetMetricEval(
-                  metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
-                           0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f},  // predictions
-                  {0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0},
-                  {1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f},  // weights
-                  {0, 2, 5, 9, 14}),  // group info
-              0.8150615f, 0.001f);
-
-  // Exception scenarios for grouped datasets
-  EXPECT_ANY_THROW(GetMetricEval(metric,
-                                 {0, 0.1f, 0.3f, 0.5f, 0.7f},
-                                 {1, 1, 0, 0, 0},
-                                 {},
-                                 {0, 2, 5}));
-
-  delete metric;
-}
-
-
 TEST(Metric, DeclareUnifiedTest(Precision)) {
   // When the limit for precision is not given, it takes the limit at
   // std::numeric_limits<unsigned>::max(); hence all values are very small
diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py
index f2b605c8b302..1282e115a775 100644
--- a/tests/python-gpu/test_gpu_eval_metrics.py
+++ b/tests/python-gpu/test_gpu_eval_metrics.py
@@ -47,3 +47,12 @@ def test_roc_auc_ltr(self, n_samples):
 
         gpu_auc = float(gpu.eval(Xy).split(":")[1])
         np.testing.assert_allclose(cpu_auc, gpu_auc)
+
+    def test_pr_auc_binary(self):
+        self.cpu_test.run_pr_auc_binary("gpu_hist")
+
+    def test_pr_auc_multi(self):
+        self.cpu_test.run_pr_auc_multi("gpu_hist")
+
+    def test_pr_auc_ltr(self):
+        self.cpu_test.run_pr_auc_ltr("gpu_hist")
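For the ranking case, the C++ tests above score each query group separately and average the results, skipping groups that contain only one class, which is the convention the removed CPU and GPU versions also followed. A hedged sketch of that averaging, reusing `pr_auc_reference` from the earlier snippet; the exact weighting in the new implementation may differ:

    import numpy as np

    def ltr_pr_auc(preds, labels, group_ptr):
        """Average PR-AUC over query groups; group_ptr uses the cumulative
        format of the C++ tests, e.g. [0, 2, 6]."""
        aucs = []
        for lo, hi in zip(group_ptr[:-1], group_ptr[1:]):
            group_auc = pr_auc_reference(preds[lo:hi], labels[lo:hi])
            if not np.isnan(group_auc):  # skip single-class groups
                aucs.append(group_auc)
        return float(np.mean(aucs)) if aucs else float("nan")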
diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py
index 877f3ef33447..1c98fa846d51 100644
--- a/tests/python/test_eval_metrics.py
+++ b/tests/python/test_eval_metrics.py
@@ -239,6 +239,7 @@ def run_roc_auc_multi(self, tree_method, n_samples, weighted):
         np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
 
         X = rng.randn(*X.shape)
+
         score = booster.predict(xgb.DMatrix(X, weight=weights))
         skl_auc = roc_auc_score(
             y, score, average="weighted", sample_weight=weights, multi_class="ovr"
@@ -251,3 +252,60 @@
     )
     def test_roc_auc_multi(self, n_samples, weighted):
         self.run_roc_auc_multi("hist", n_samples, weighted)
+
+    def run_pr_auc_binary(self, tree_method):
+        from sklearn.metrics import precision_recall_curve, auc
+        from sklearn.datasets import make_classification
+        X, y = make_classification(128, 4, n_classes=2, random_state=1994)
+        clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
+        clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
+        evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
+
+        y_score = clf.predict_proba(X)[:, 1]  # get the positive column
+        precision, recall, _ = precision_recall_curve(y, y_score)
+        prauc = auc(recall, precision)
+        # Interpolation results are slightly different from sklearn, but overall
+        # should be similar.
+        np.testing.assert_allclose(prauc, evals_result, rtol=1e-2)
+
+        clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
+        clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
+        evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
+        np.testing.assert_allclose(0.99, evals_result, rtol=1e-2)
+
+    def test_pr_auc_binary(self):
+        self.run_pr_auc_binary("hist")
+
+    def run_pr_auc_multi(self, tree_method):
+        from sklearn.datasets import make_classification
+        X, y = make_classification(
+            64, 16, n_informative=8, n_classes=3, random_state=1994
+        )
+        # No reference implementation is available for comparison; just check
+        # that XGBoost converges to 1.0.
+        clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
+        clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
+        evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
+        np.testing.assert_allclose(1.0, evals_result, rtol=1e-2)
+
+    def test_pr_auc_multi(self):
+        self.run_pr_auc_multi("hist")
+
+    def run_pr_auc_ltr(self, tree_method):
+        from sklearn.datasets import make_classification
+        X, y = make_classification(128, 4, n_classes=2, random_state=1994)
+        ltr = xgb.XGBRanker(tree_method=tree_method, n_estimators=16)
+        groups = np.array([32, 32, 64])
+        ltr.fit(
+            X,
+            y,
+            group=groups,
+            eval_set=[(X, y)],
+            eval_group=[groups],
+            eval_metric="aucpr",
+        )
+        results = ltr.evals_result()["validation_0"]["aucpr"]
+        assert results[-1] >= 0.99
+
+    def test_pr_auc_ltr(self):
+        self.run_pr_auc_ltr("hist")
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 343eff97b984..ddb7732d801c 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -587,7 +587,7 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None:
     cls = xgb.dask.DaskXGBClassifier(
         tree_method=tree_method, n_estimators=2, use_label_encoder=False
     )
-    cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
+    cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
 
     # multiclass
     X_, y_ = make_classification(
@@ -618,7 +618,7 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None:
     cls = xgb.dask.DaskXGBClassifier(
         tree_method=tree_method, n_estimators=2, use_label_encoder=False
     )
-    cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
+    cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
 
 
 def test_empty_dmatrix_auc() -> None:
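The dask changes above only exercise training with both metrics on partitioned and empty DMatrices. For completeness, a minimal end-to-end sketch of the same usage; the local cluster and random data are placeholders, not taken from the patch:

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    import xgboost as xgb

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(500, 10))
        y = (da.random.random(1000, chunks=500) > 0.5).astype("int64")
        clf = xgb.dask.DaskXGBClassifier(tree_method="hist", n_estimators=4)
        clf.client = client
        # Evaluate both ROC-AUC and PR-AUC on the training data each round.
        clf.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(X, y)])
        print(clf.evals_result()["validation_0"]["aucpr"])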