From 202fe96d76d01a048ba8fdc2085aa8e68e56833a Mon Sep 17 00:00:00 2001 From: fis Date: Sun, 10 Oct 2021 00:43:27 +0800 Subject: [PATCH] Fix. --- src/metric/auc.cu | 5 +++-- tests/cpp/metric/test_auc.cc | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 834d3d3710fc..9cb66f32d442 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -735,7 +735,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, size_t group_id = dh::SegmentId(d_group_ptr, idx); float label = labels[idx]; - float w = get_weight(group_id); + float w = weights.empty() ? 1.0f : get_weight(group_id); float fp = (1.0 - label) * w; float tp = label * w; return thrust::make_pair(fp, tp); @@ -830,6 +830,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, auto val_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { auto group_id = dh::SegmentId(d_group_ptr, d_unique_idx[i]); + float fp, tp; float fp_prev, tp_prev; if (i == d_unique_class_ptr[group_id]) { @@ -860,7 +861,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, float fp, tp; thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; float area = fp * tp; - if (area > 0) { + if (area > 0 && d_group_ptr[g + 1] - d_group_ptr[g] >= 3) { return thrust::make_pair(s_d_auc[g], static_cast(0)); } return thrust::make_pair(0.0f, static_cast(1)); diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc index f9cdde75c50d..e29040759255 100644 --- a/tests/cpp/metric/test_auc.cc +++ b/tests/cpp/metric/test_auc.cc @@ -187,7 +187,7 @@ TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) { EXPECT_EQ(auc, 1.0f); auc = GetMetricEval(metric.get(), {0.0f, 1.0f, 0.0f, 1.0f, 1.0f}, labels, {}, groups); - std::cout << auc << std::endl; + EXPECT_NEAR(auc, 0.189, 1e-3); } } // namespace metric } // namespace xgboost