Skip to content

Commit

Permalink
Old CUDA.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 9, 2021
1 parent 8b5a15a commit 1b5808c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/metric/auc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,7 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {
auto n_threads = tparam_->Threads();
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auto labels = info.labels_.ConstHostSpan();
if (std::any_of(labels.cbegin(), labels.cend(),
[](float y) { return y < 0.0f || y > 1.0f; })) {
if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
InvalidLabels();
}
std::tie(auc, valid_groups) =
Expand Down
5 changes: 2 additions & 3 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -897,9 +897,8 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,

dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels_.ConstDeviceSpan();
if (thrust::any_of(
thrust::cuda::par(alloc), dh::tbegin(labels), dh::tend(labels),
[] XGBOOST_DEVICE(float y) { return y < 0.0f || y > 1.0f; })) {
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels),
dh::tend(labels), PRAUCLabelInvalid{})) {
InvalidLabels();
}
auto d_weights = info.weights_.ConstDeviceSpan();
Expand Down
4 changes: 4 additions & 0 deletions src/metric/auc.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ inline void InvalidGroupAUC() {
<< "least 2 pairs of samples.";
}

struct PRAUCLabelInvalid {
XGBOOST_DEVICE bool operator()(float y) { return y < 0.0f || y > 1.0f; }
};

inline void InvalidLabels() {
LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank.";
}
Expand Down

0 comments on commit 1b5808c

Please sign in to comment.