Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 9, 2021
1 parent 1b5808c commit 202fe96
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/metric/auc.cu
Expand Up @@ -735,7 +735,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
size_t group_id = dh::SegmentId(d_group_ptr, idx);
float label = labels[idx];

float w = get_weight(group_id);
float w = weights.empty() ? 1.0f : get_weight(group_id);
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
Expand Down Expand Up @@ -830,6 +830,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
auto val_in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
auto group_id = dh::SegmentId(d_group_ptr, d_unique_idx[i]);

float fp, tp;
float fp_prev, tp_prev;
if (i == d_unique_class_ptr[group_id]) {
Expand Down Expand Up @@ -860,7 +861,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
float fp, tp;
thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)];
float area = fp * tp;
if (area > 0) {
if (area > 0 && d_group_ptr[g + 1] - d_group_ptr[g] >= 3) {
return thrust::make_pair(s_d_auc[g], static_cast<uint32_t>(0));
}
return thrust::make_pair(0.0f, static_cast<uint32_t>(1));
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/metric/test_auc.cc
Expand Up @@ -187,7 +187,7 @@ TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) {
EXPECT_EQ(auc, 1.0f);

auc = GetMetricEval(metric.get(), {0.0f, 1.0f, 0.0f, 1.0f, 1.0f}, labels, {}, groups);
std::cout << auc << std::endl;
EXPECT_NEAR(auc, 0.189, 1e-3);
}
} // namespace metric
} // namespace xgboost

0 comments on commit 202fe96

Please sign in to comment.