Re-implement ROC-AUC. (#6747)
* Re-implement ROC-AUC.

* Binary
* MultiClass
* LTR
* Add documentation.

This PR resolves a few issues:
  - Define a value when the dataset is invalid, which can happen when the dataset is empty
  or contains only positive or only negative samples.
  - Define ROC-AUC for multi-class classification.
  - Define a weighted average for the distributed setting.
  - Provide a correct implementation for the learning-to-rank task.  The previous
  implementation was just binary classification with averaging across groups,
  which doesn't measure how well each group is ordered (a reference sketch of the
  binary definition follows below).
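
For reference, a minimal host-side sketch of the binary definition referred to throughout this page: the AUC is the probability that a randomly chosen positive sample is ranked above a randomly chosen negative one, with ties counted as one half, and the result is NaN when only one class is present. `BinaryAUC` is a hypothetical, O(n_pos * n_neg) pair-counting illustration, not the implementation added by this commit.

#include <cstddef>
#include <limits>
#include <vector>

// Pair-counting binary ROC-AUC sketch: the fraction of (positive, negative)
// pairs where the positive sample is scored higher; ties contribute 0.5.
// Returns NaN when the labels contain only one class, matching the behaviour
// documented for the re-implemented metric.
double BinaryAUC(std::vector<float> const &scores, std::vector<int> const &labels) {
  size_t n_pos = 0, n_neg = 0;
  for (int y : labels) {
    if (y == 1) { ++n_pos; } else { ++n_neg; }
  }
  if (n_pos == 0 || n_neg == 0) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  double correct = 0.0;
  for (size_t i = 0; i < labels.size(); ++i) {
    if (labels[i] != 1) { continue; }
    for (size_t j = 0; j < labels.size(); ++j) {
      if (labels[j] == 1) { continue; }
      if (scores[i] > scores[j]) {
        correct += 1.0;
      } else if (scores[i] == scores[j]) {
        correct += 0.5;
      }
    }
  }
  return correct / (static_cast<double>(n_pos) * static_cast<double>(n_neg));
}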
trivialfis committed Mar 20, 2021
1 parent 4ee8340 commit bcc0277
Showing 27 changed files with 1,622 additions and 461 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -14,6 +14,7 @@
#include "../src/metric/elementwise_metric.cc"
#include "../src/metric/multiclass_metric.cc"
#include "../src/metric/rank_metric.cc"
#include "../src/metric/auc.cc"
#include "../src/metric/survival_metric.cc"

// objectives
10 changes: 9 additions & 1 deletion doc/parameter.rst
@@ -400,7 +400,15 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``error@t``: a binary classification threshold different from 0.5 can be specified by providing a numerical value through 't'.
- ``merror``: Multiclass classification error rate. It is calculated as ``#(wrong cases)/#(all cases)``.
- ``mlogloss``: `Multiclass logloss <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html>`_.
- ``auc``: `Area under the curve <http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve>`_. Available for binary classification and learning-to-rank tasks.
- ``auc``: `Receiver Operating Characteristic Area under the Curve <http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve>`_.
Available for classification and learning-to-rank tasks.

- When used with binary classification, the objective should be ``binary:logistic`` or a similar function that works on probability.
- When used with multi-class classification, the objective should be ``multi:softprob`` instead of ``multi:softmax``, as the latter doesn't output probabilities. The AUC is calculated by 1-vs-rest, with the reference class weighted by its prevalence.
- When used with an LTR task, the AUC is computed by comparing pairs of documents and counting correctly sorted pairs. This corresponds to pairwise learning to rank. Averaging the AUC over groups and over distributed workers is not well-defined, so the implementation has some known limitations there.
- On a single machine the AUC calculation is exact. In a distributed environment the AUC is a weighted average over the AUC of training rows on each node; therefore, distributed AUC is an approximation that is sensitive to the distribution of data across workers. Use another metric in distributed environments if precision and reproducibility are important.
- If the input dataset contains only negative or only positive samples, the output is ``NaN``.

- ``aucpr``: `Area under the PR curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_. Available for binary classification and learning-to-rank tasks.
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
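
To make the multi-class definition above concrete, here is a hedged sketch of weighted 1-vs-rest averaging: each class is treated as the positive class in turn, a binary AUC is computed against the rest, and the per-class values are averaged with weights equal to class prevalence. `WeightedOvrAUC` is a hypothetical illustration reusing the `BinaryAUC` sketch from the commit-message section; the exact handling of degenerate classes in XGBoost may differ.

#include <cstddef>
#include <vector>

// Defined in the earlier sketch: pair-counting binary ROC-AUC.
double BinaryAUC(std::vector<float> const &scores, std::vector<int> const &labels);

// Weighted one-vs-rest multi-class AUC sketch.  proba[i][c] is the predicted
// probability of class c for row i (e.g. the output of multi:softprob), and
// labels[i] is the true class index.
double WeightedOvrAUC(std::vector<std::vector<float>> const &proba,
                      std::vector<int> const &labels, int n_classes) {
  double weighted_sum = 0.0;
  for (int c = 0; c < n_classes; ++c) {
    std::vector<float> scores;
    std::vector<int> binary_labels;
    size_t n_pos = 0;
    for (size_t i = 0; i < labels.size(); ++i) {
      scores.push_back(proba[i][c]);
      int is_pos = labels[i] == c ? 1 : 0;
      n_pos += is_pos;
      binary_labels.push_back(is_pos);
    }
    if (n_pos == 0 || n_pos == labels.size()) {
      // Degenerate class: the binary AUC is undefined (NaN).  Skipped here for
      // simplicity; the real metric treats this as an invalid dataset.
      continue;
    }
    double prevalence = static_cast<double>(n_pos) / static_cast<double>(labels.size());
    weighted_sum += prevalence * BinaryAUC(scores, binary_labels);
  }
  return weighted_sum;  // prevalence weights sum to 1 when no class is skipped
}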
12 changes: 7 additions & 5 deletions src/common/common.h
@@ -8,6 +8,7 @@

#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/span.h>

#include <algorithm>
#include <exception>
@@ -163,13 +164,14 @@ inline void AssertOneAPISupport() {
#endif // XGBOOST_USE_ONEAPI
}

template <typename Idx, typename V, typename Comp = std::less<V>>
std::vector<Idx> ArgSort(std::vector<V> const &array, Comp comp = std::less<V>{}) {
template <typename Idx, typename Container,
typename V = typename Container::value_type,
typename Comp = std::less<V>>
std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
std::vector<Idx> result(array.size());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(
result.begin(), result.end(),
[&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); });
auto op = [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); };
XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
return result;
}
} // namespace common
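
The generalised `ArgSort` above accepts any container with `size()`, `operator[]` and a `value_type` (including `Span`), and now sorts the index vector with a parallel stable sort. The following standalone sketch shows the same argsort idiom with `std::stable_sort` standing in for the `XGBOOST_PARALLEL_STABLE_SORT` macro; `ArgSortSketch` is a hypothetical name used only for illustration.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Standalone argsort sketch: returns indices such that array[result[0]],
// array[result[1]], ... is ordered according to comp.  The sort is stable, so
// equal keys keep their original relative order.
template <typename Idx, typename Container,
          typename V = typename Container::value_type,
          typename Comp = std::less<V>>
std::vector<Idx> ArgSortSketch(Container const &array, Comp comp = std::less<V>{}) {
  std::vector<Idx> result(array.size());
  std::iota(result.begin(), result.end(), static_cast<Idx>(0));
  std::stable_sort(result.begin(), result.end(),
                   [&](Idx l, Idx r) { return comp(array[l], array[r]); });
  return result;
}

// Usage, mirroring the call site in random.h below: descending argsort of keys.
//   std::vector<float> keys{0.3f, 0.9f, 0.1f};
//   auto idx = ArgSortSketch<size_t>(keys, std::greater<>{});  // {1, 0, 2}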
150 changes: 121 additions & 29 deletions src/common/device_helpers.cuh
@@ -1198,6 +1198,62 @@ size_t SegmentedUnique(Inputs &&...inputs) {
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
}

/**
 * \brief Unique by key for many groups of data.  Has the same constraints as `SegmentedUnique`.
 *
 * \param exec               thrust execution policy
 * \param key_segments_first start iterator to the segment pointer
 * \param key_segments_last  end iterator to the segment pointer
 * \param key_first          start iterator to the keys used for comparison
 * \param key_last           end iterator to the keys used for comparison
 * \param val_first          start iterator to the values
 * \param key_segments_out   output iterator for the new segment pointer
 * \param val_out            output iterator for the values
 * \param comp               binary comparison operator
*/
template <typename DerivedPolicy, typename SegInIt, typename SegOutIt,
typename KeyInIt, typename ValInIt, typename ValOutIt, typename Comp>
size_t SegmentedUniqueByKey(
const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
SegInIt key_segments_first, SegInIt key_segments_last, KeyInIt key_first,
KeyInIt key_last, ValInIt val_first, SegOutIt key_segments_out,
ValOutIt val_out, Comp comp) {
using Key =
thrust::pair<size_t,
typename thrust::iterator_traits<KeyInIt>::value_type>;

auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
[=] __device__(size_t i) {
size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i);
return thrust::make_pair(seg, *(key_first + i));
});
size_t segments_len = key_segments_last - key_segments_first;
thrust::fill(thrust::device, key_segments_out,
key_segments_out + segments_len, 0);
size_t n_inputs = std::distance(key_first, key_last);
// Accumulate the number of unique elements in each segment, avoiding an
// intermediate array for `reduce_by_key`.  This is limited by the types that
// atomicAdd supports.  For example, size_t is not supported as of CUDA 10.2.
auto reduce_it = thrust::make_transform_output_iterator(
thrust::make_discard_iterator(),
detail::SegmentedUniqueReduceOp<Key, SegOutIt>{key_segments_out});
auto uniques_ret = thrust::unique_by_key_copy(
exec, unique_key_it, unique_key_it + n_inputs, val_first, reduce_it,
val_out, [=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
// In the same segment.
return comp(thrust::get<1>(l), thrust::get<1>(r));
}
return false;
});
auto n_uniques = uniques_ret.second - val_out;
CHECK_LE(n_uniques, n_inputs);
thrust::exclusive_scan(exec, key_segments_out,
key_segments_out + segments_len, key_segments_out, 0);
return n_uniques;
}

template <typename Policy, typename InputIt, typename Init, typename Func>
auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
@@ -1215,36 +1271,73 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
return aggregate;
}

// Wrapper around cub scan to work around the `int` limit on `num_items`.
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
typename OffsetT>
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
OffsetT num_items) {
size_t bytes = 0;
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
cub::NullType(), num_items, nullptr,
false)));
dh::TemporaryArray<char> storage(bytes);
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
d_out, scan_op, cub::NullType(),
num_items, nullptr, false)));
}

template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
}

template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> values, xgboost::common::Span<IdxT> sorted_idx) {
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
size_t bytes = 0;
Iota(sorted_idx);
CHECK_LT(sorted_idx.size(), 1 << 31);
TemporaryArray<U> out(values.size());

using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;

TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx.data());

if (accending) {
cub::DeviceRadixSort::SortPairs(nullptr, bytes, values.data(),
out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size());
void *d_temp_storage = nullptr;
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
dh::TemporaryArray<char> storage(bytes);
cub::DeviceRadixSort::SortPairs(storage.data().get(), bytes, values.data(),
out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size());
d_temp_storage = storage.data().get();
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
} else {
cub::DeviceRadixSort::SortPairsDescending(
nullptr, bytes, values.data(), out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size());
void *d_temp_storage = nullptr;
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes);
cub::DeviceRadixSort::SortPairsDescending(
storage.data().get(), bytes, values.data(), out.data().get(),
sorted_idx.data(), sorted_idx.data(), sorted_idx.size());
d_temp_storage = storage.data().get();
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
}
}

namespace detail {
// Wrapper around cub sort for easier `descending` sort
template <bool descending, typename KeyT, typename ValueT, typename OffsetIteratorT>
// Wrapper around cub sort for easier `descending` sort and `size_t num_items`.
template <bool descending, typename KeyT, typename ValueT,
typename OffsetIteratorT>
void DeviceSegmentedRadixSortPair(
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT
KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets, int begin_bit = 0,
@@ -1253,12 +1346,12 @@ void DeviceSegmentedRadixSortPair(
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
d_values_out);
using OffsetT = size_t;
dh::safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, OffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
d_values, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, nullptr, false)));
safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, OffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
d_values, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, nullptr, false)));
}
} // namespace detail

@@ -1270,12 +1363,11 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
size_t n_groups = group_ptr.size() - 1;
size_t bytes = 0;
Iota(sorted_idx);
CHECK_LT(sorted_idx.size(), 1 << 31);
TemporaryArray<U> values_out(values.size());
TemporaryArray<std::remove_const_t<U>> values_out(values.size());
detail::DeviceSegmentedRadixSortPair<!accending>(
nullptr, bytes, values.data(), values_out.data().get(),
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
group_ptr.data(), group_ptr.data() + 1);
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(),
group_ptr.data() + 1);
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
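
The device-side `SegmentedUniqueByKey` above is easier to follow with its host-side meaning in mind: within each segment (given by CSR-style offsets), consecutive entries whose keys compare equal are dropped, the surviving values are kept, and new offsets are produced for the compacted segments. `SegmentedUniqueByKeySketch` below is a hypothetical CPU sketch of those semantics, not of the thrust implementation.

#include <cstddef>
#include <vector>

// CPU sketch of segmented unique-by-key semantics.  `segments` holds CSR-style
// offsets into `keys`/`values`.  Consecutive duplicates inside a segment
// (according to `eq`) are dropped; duplicates across a segment boundary are
// kept, because uniqueness is only defined within a segment.
template <typename K, typename V, typename Eq>
void SegmentedUniqueByKeySketch(std::vector<std::size_t> const &segments,
                                std::vector<K> const &keys,
                                std::vector<V> const &values,
                                std::vector<std::size_t> *out_segments,
                                std::vector<V> *out_values, Eq eq) {
  out_segments->assign(1, 0);
  out_values->clear();
  for (std::size_t s = 0; s + 1 < segments.size(); ++s) {
    std::size_t kept = 0;
    for (std::size_t i = segments[s]; i < segments[s + 1]; ++i) {
      if (i == segments[s] || !eq(keys[i - 1], keys[i])) {
        out_values->push_back(values[i]);
        ++kept;
      }
    }
    out_segments->push_back(out_segments->back() + kept);
  }
}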
3 changes: 3 additions & 0 deletions src/common/math.h
@@ -26,6 +26,9 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}

template <typename T>
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }

/*!
* \brief Equality test for both integer and floating point.
*/
2 changes: 1 addition & 1 deletion src/common/random.h
@@ -99,7 +99,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(
auto k = std::log(u) / w;
keys[i] = k;
}
auto ind = ArgSort<size_t>(keys, std::greater<>{});
auto ind = ArgSort<size_t>(Span<float>{keys}, std::greater<>{});
ind.resize(n);

std::vector<T> results(ind.size());
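
The one-line change above only adapts the call site to the Span-based `ArgSort`; the surrounding routine is the exponential-keys trick for weighted sampling without replacement (Efraimidis-Spirakis): draw u ~ U(0, 1) for each item, key it by log(u) / w, and keep the n items with the largest keys. `WeightedSampleSketch` is a hypothetical standalone sketch of that idea, assuming strictly positive weights.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <random>
#include <utility>
#include <vector>

// Weighted sampling without replacement via exponential keys: the n items
// with the largest log(u) / w keys form the sample.  Larger weights push the
// key towards zero (i.e. upwards), so heavier items are more likely to be kept.
std::vector<std::size_t> WeightedSampleSketch(std::vector<double> const &weights,
                                              std::size_t n, std::uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  std::vector<std::pair<double, std::size_t>> keyed(weights.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    keyed[i] = {std::log(unif(rng)) / weights[i], i};
  }
  std::size_t k = std::min(n, keyed.size());
  std::partial_sort(keyed.begin(), keyed.begin() + k, keyed.end(), std::greater<>{});
  std::vector<std::size_t> picked(k);
  for (std::size_t i = 0; i < k; ++i) {
    picked[i] = keyed[i].second;
  }
  return picked;
}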
84 changes: 84 additions & 0 deletions src/common/ranking_utils.cuh
@@ -0,0 +1,84 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
#define XGBOOST_COMMON_RANKING_UTILS_H_

#include <cub/cub.cuh>
#include "xgboost/base.h"
#include "device_helpers.cuh"
#include "./math.h"

namespace xgboost {
namespace common {
/**
* \param n Number of items (length of the base)
 * \param h height
*/
XGBOOST_DEVICE inline size_t DiscreteTrapezoidArea(size_t n, size_t h) {
n -= 1; // without diagonal entries
h = std::min(n, h); // Specific for ranking.
size_t total = ((n - (h - 1)) + n) * h / 2;
return total;
}

/**
 * Used for mapping many groups of trapezoid-shaped computations onto CUDA blocks.  The
 * trapezoid must be in the upper right corner.
*
* Equivalent to loops like:
*
* \code
 * for (size_t i = 0; i < h; ++i) {
* for (size_t j = i + 1; j < n; ++j) {
* do_something();
* }
* }
* \endcode
*/
template <typename U>
inline size_t
SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
xgboost::common::Span<size_t> out_group_threads_ptr,
size_t h) {
CHECK_GE(group_ptr.size(), 1);
CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
dh::LaunchN(
dh::CurrentDevice(), group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) {
if (idx == 0) {
out_group_threads_ptr[0] = 0;
return;
}

size_t cnt = static_cast<size_t>(group_ptr[idx] - group_ptr[idx - 1]);
out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
});
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
size_t total = 0;
dh::safe_cuda(cudaMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), cudaMemcpyDeviceToHost));
return total;
}

/**
 * Called inside a kernel to obtain the (i, j) coordinate from a flat index into the trapezoid grid.
*/
XGBOOST_DEVICE inline void UnravelTrapeziodIdx(size_t i_idx, size_t n,
size_t *out_i, size_t *out_j) {
auto &i = *out_i;
auto &j = *out_j;
double idx = static_cast<double>(i_idx);
double N = static_cast<double>(n);

i = std::ceil(-(0.5 - N + std::sqrt(common::Sqr(N - 0.5) + 2.0 * (-idx - 1.0)))) - 1.0;

auto I = static_cast<double>(i);
size_t n_elems = -0.5 * common::Sqr(I) + (N - 0.5) * I;

j = idx - n_elems + i + 1;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
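
A quick host-side sanity check may help here: `DiscreteTrapezoidArea` is the closed form for the number of (i, j) pairs with i < h and i < j < n, i.e. exactly the document pairs that the pairwise ranking AUC compares within a group, and `UnravelTrapeziodIdx` inverts that flattening analytically on the device. The sketch below only verifies the count against the brute-force loops; `TrapezoidArea` and `TrapezoidAreaBruteForce` are hypothetical names for illustration, assuming n >= 2 and h >= 1.

#include <algorithm>
#include <cassert>
#include <cstddef>

// Closed-form pair count, mirroring DiscreteTrapezoidArea above: a trapezoid
// in the upper-right corner of the n x n pair matrix, excluding the diagonal.
// Assumes n >= 2 and h >= 1, as in the ranking use case.
std::size_t TrapezoidArea(std::size_t n, std::size_t h) {
  n -= 1;              // drop the diagonal entries
  h = std::min(n, h);  // cap the height at the number of off-diagonal rows
  return ((n - (h - 1)) + n) * h / 2;
}

// Brute-force enumeration of the same pairs, i.e. the loops that the CUDA
// grid flattens; each iteration corresponds to one device thread.
std::size_t TrapezoidAreaBruteForce(std::size_t n, std::size_t h) {
  std::size_t count = 0;
  for (std::size_t i = 0; i < h && i + 1 < n; ++i) {
    for (std::size_t j = i + 1; j < n; ++j) {
      ++count;
    }
  }
  return count;
}

// e.g. assert(TrapezoidArea(8, 3) == TrapezoidAreaBruteForce(8, 3));  // both 18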
4 changes: 3 additions & 1 deletion src/data/data.cc
@@ -400,7 +400,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
group_ptr_.push_back(i);
}
}
group_ptr_.push_back(query_ids.size());
if (group_ptr_.back() != query_ids.size()) {
group_ptr_.push_back(query_ids.size());
}
} else if (!std::strcmp(key, "label_lower_bound")) {
auto& labels = labels_lower_bound_.HostVector();
labels.resize(num);
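
For context, the change above only appends the final boundary when the loop has not already produced it, so the last group is never recorded twice. `BuildGroupPtr` below is a hypothetical sketch of the overall idea (building CSR-style group boundaries from query IDs that arrive grouped), not the actual `MetaInfo::SetInfo` code.

#include <cstddef>
#include <cstdint>
#include <vector>

// Build CSR-style group boundaries from query IDs whose rows are contiguous
// per query.  The trailing guard mirrors the data.cc change: append the final
// boundary only if it is not already the last recorded offset.
std::vector<std::size_t> BuildGroupPtr(std::vector<std::uint32_t> const &query_ids) {
  std::vector<std::size_t> group_ptr{0};
  for (std::size_t i = 1; i < query_ids.size(); ++i) {
    if (query_ids[i] != query_ids[i - 1]) {
      group_ptr.push_back(i);
    }
  }
  if (group_ptr.back() != query_ids.size()) {
    group_ptr.push_back(query_ids.size());
  }
  return group_ptr;
}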
