diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 50399cfe9aea..303118f9eb2b 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -5,11 +5,26 @@ #ifndef XGBOOST_METRIC_METRIC_COMMON_H_ #define XGBOOST_METRIC_METRIC_COMMON_H_ +#include +#include +#include +#include + #include "../common/common.h" namespace xgboost { namespace metric { +// Ranking config to be used on device and host +struct EvalRankConfig { + public: + // Parsed from metric name, the top-n number of instances within a group after + // ranking to use for evaluation. + unsigned topn{std::numeric_limits::max()}; + std::string name; + bool minus{false}; +}; + class PackedReduceResult { double residue_sum_; double weights_sum_; diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index b0cf5ca3984a..5ffdc003fb49 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -13,9 +13,13 @@ #include "xgboost/host_device_vector.h" #include "../common/math.h" +#include "metric_common.h" namespace { +using PredIndPair = std::pair; +using PredIndPairContainer = std::vector; + /* * Adapter to access instance weights. * @@ -31,9 +35,6 @@ namespace { * of type PredIndPairContainer */ -using PredIndPairContainer - = std::vector>; - class PerInstanceWeightPolicy { public: inline static xgboost::bst_float @@ -91,20 +92,20 @@ struct EvalAMS : public Metric { using namespace std; // NOLINT(*) const auto ndata = static_cast(info.labels_.Size()); - std::vector > rec(ndata); + PredIndPairContainer rec(ndata); - const std::vector& h_preds = preds.HostVector(); -#pragma omp parallel for schedule(static) + const auto &h_preds = preds.ConstHostVector(); + #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < ndata; ++i) { rec[i] = std::make_pair(h_preds[i], i); } - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); + XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); auto ntop = static_cast(ratio_ * ndata); if (ntop == 0) ntop = ndata; const double br = 10.0; unsigned thresindex = 0; double s_tp = 0.0, b_fp = 0.0, tams = 0.0; - const auto& labels = info.labels_.HostVector(); + const auto& labels = info.labels_.ConstHostVector(); for (unsigned i = 0; i < static_cast(ndata-1) && i < ntop; ++i) { const unsigned ridx = rec[i].second; const bst_float wt = info.GetWeight(ridx); @@ -139,72 +140,77 @@ struct EvalAMS : public Metric { float ratio_; }; -/*! \brief Area Under Curve, for both classification and rank */ +/*! \brief Area Under Curve, for both classification and rank computed on CPU */ struct EvalAuc : public Metric { private: template bst_float Eval(const HostDeviceVector &preds, const MetaInfo &info, - bool distributed) { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "label size predict size not match"; - std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.Size()); - - const std::vector &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.Size()) - << "EvalAuc: group structure must match number of prediction"; - const auto ngroup = static_cast(gptr.size() - 1); + bool distributed, + const std::vector &gptr) { + const auto ngroups = static_cast(gptr.size() - 1); // sum of all AUC's across all query groups double sum_auc = 0.0; int auc_error = 0; - // each thread takes a local rec - std::vector> rec; - const auto& labels = info.labels_.HostVector(); - const std::vector& h_preds = preds.HostVector(); - for (bst_omp_uint group_id = 0; group_id < ngroup; ++group_id) { - rec.clear(); - for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { - rec.emplace_back(h_preds[j], j); - } - XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst); - // calculate AUC - double sum_pospair = 0.0; - double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; - for (size_t j = 0; j < rec.size(); ++j) { - const bst_float wt - = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); - const bst_float ctr = labels[rec[j].second]; - // keep bucketing predictions in same bucket - if (j != 0 && rec[j].first != rec[j - 1].first) { - sum_pospair += buf_neg * (sum_npos + buf_pos *0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - buf_neg = buf_pos = 0.0f; + const auto& labels = info.labels_.ConstHostVector(); + const auto &h_preds = preds.ConstHostVector(); + + #pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1) + { + // Each thread works on a distinct group and sorts the predictions in that group + PredIndPairContainer rec; + #pragma omp for schedule(static) + for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) { + // Same thread can work on multiple groups one after another; hence, resize + // the predictions array based on the current group + rec.resize(gptr[group_id + 1] - gptr[group_id]); + #pragma omp parallel for schedule(static) if (!omp_in_parallel()) + for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { + rec[j - gptr[group_id]] = {h_preds[j], j}; + } + + if (omp_in_parallel()) { + std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); + } else { + XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); + } + + // calculate AUC + double sum_pospair = 0.0; + double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; + for (size_t j = 0; j < rec.size(); ++j) { + const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); + const bst_float ctr = labels[rec[j].second]; + // keep bucketing predictions in same bucket + if (j != 0 && rec[j].first != rec[j - 1].first) { + sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); + sum_npos += buf_pos; + sum_nneg += buf_neg; + buf_neg = buf_pos = 0.0f; + } + buf_pos += ctr * wt; + buf_neg += (1.0f - ctr) * wt; + } + sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); + sum_npos += buf_pos; + sum_nneg += buf_neg; + // check weird conditions + if (sum_npos <= 0.0 || sum_nneg <= 0.0) { + auc_error += 1; + } else { + // this is the AUC + sum_auc += sum_pospair / (sum_npos * sum_nneg); } - buf_pos += ctr * wt; - buf_neg += (1.0f - ctr) * wt; - } - sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5); - sum_npos += buf_pos; - sum_nneg += buf_neg; - // check weird conditions - if (sum_npos <= 0.0 || sum_nneg <= 0.0) { - auc_error += 1; - continue; } - // this is the AUC - sum_auc += sum_pospair / (sum_npos * sum_nneg); } // Report average AUC across all groups // In distributed mode, workers which only contains pos or neg samples // will be ignored when aggregate AUC. bst_float dat[2] = {0.0f, 0.0f}; - if (auc_error < static_cast(ngroup)) { + if (auc_error < static_cast(ngroups)) { dat[0] = static_cast(sum_auc); - dat[1] = static_cast(static_cast(ngroup) - auc_error); + dat[1] = static_cast(static_cast(ngroups) - auc_error); } if (distributed) { rabit::Allreduce(dat, 2); @@ -218,127 +224,133 @@ struct EvalAuc : public Metric { bst_float Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) + << "label size predict size not match"; + std::vector tgptr(2, 0); + tgptr[1] = static_cast(info.labels_.Size()); + + const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; + CHECK_EQ(gptr.back(), info.labels_.Size()) + << "EvalAuc: group structure must match number of prediction"; + // For ranking task, weights are per-group // For binary classification task, weights are per-instance const bool is_ranking_task = !info.group_ptr_.empty() && info.weights_.Size() != info.num_row_; + if (is_ranking_task) { - return Eval(preds, info, distributed); + return Eval(preds, info, distributed, gptr); } else { - return Eval(preds, info, distributed); + return Eval(preds, info, distributed, gptr); } } - const char* Name() const override { - return "auc"; - } + + const char *Name() const override { return "auc"; } }; /*! \brief Evaluate rank list */ -struct EvalRankList : public Metric { +struct EvalRank : public Metric, public EvalRankConfig { public: bst_float Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { CHECK_EQ(preds.Size(), info.labels_.Size()) << "label size predict size not match"; + // quick consistency when group is not available std::vector tgptr(2, 0); tgptr[1] = static_cast(preds.Size()); - const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; + const auto &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; + CHECK_NE(gptr.size(), 0U) << "must specify group when constructing rank file"; CHECK_EQ(gptr.back(), preds.Size()) - << "EvalRanklist: group structure must match number of prediction"; - const auto ngroup = static_cast(gptr.size() - 1); + << "EvalRank: group structure must match number of prediction"; + + const auto ngroups = static_cast(gptr.size() - 1); // sum statistics double sum_metric = 0.0f; - const auto& labels = info.labels_.HostVector(); - const std::vector& h_preds = preds.HostVector(); -#pragma omp parallel reduction(+:sum_metric) + const auto &labels = info.labels_.ConstHostVector(); + const auto &h_preds = preds.ConstHostVector(); + + #pragma omp parallel reduction(+:sum_metric) { // each thread takes a local rec - std::vector< std::pair > rec; + PredIndPairContainer rec; #pragma omp for schedule(static) - for (bst_omp_uint k = 0; k < ngroup; ++k) { + for (bst_omp_uint k = 0; k < ngroups; ++k) { rec.clear(); for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { rec.emplace_back(h_preds[j], static_cast(labels[j])); } - sum_metric += this->EvalMetric(rec); + sum_metric += this->EvalGroup(&rec); } } + if (distributed) { bst_float dat[2]; dat[0] = static_cast(sum_metric); - dat[1] = static_cast(ngroup); + dat[1] = static_cast(ngroups); // approximately estimate the metric using mean rabit::Allreduce(dat, 2); return dat[0] / dat[1]; } else { - return static_cast(sum_metric) / ngroup; + return static_cast(sum_metric) / ngroups; } } + const char* Name() const override { - return name_.c_str(); + return name.c_str(); } protected: - explicit EvalRankList(const char* name, const char* param) { + explicit EvalRank(const char* name, const char* param) { using namespace std; // NOLINT(*) - minus_ = false; + if (param != nullptr) { std::ostringstream os; - if (sscanf(param, "%u[-]?", &topn_) == 1) { + if (sscanf(param, "%u[-]?", &topn) == 1) { os << name << '@' << param; - name_ = os.str(); + this->name = os.str(); } else { - topn_ = std::numeric_limits::max(); os << name << param; - name_ = os.str(); + this->name = os.str(); } if (param[strlen(param) - 1] == '-') { - minus_ = true; + minus = true; } } else { - name_ = name; - topn_ = std::numeric_limits::max(); + this->name = name; } } - /*! \return evaluation metric, given the pair_sort record, (pred,label) */ - virtual bst_float EvalMetric(std::vector > &pair_sort) const = 0; // NOLINT(*) - protected: - unsigned topn_; - std::string name_; - bool minus_; + virtual double EvalGroup(PredIndPairContainer *recptr) const = 0; }; /*! \brief Precision at N, for both classification and rank */ -struct EvalPrecision : public EvalRankList{ +struct EvalPrecision : public EvalRank { public: - explicit EvalPrecision(const char *name) : EvalRankList("pre", name) {} + explicit EvalPrecision(const char* name, const char* param) : EvalRank(name, param) {} - protected: - bst_float EvalMetric(std::vector< std::pair > &rec) const override { + double EvalGroup(PredIndPairContainer *recptr) const override { + PredIndPairContainer &rec(*recptr); // calculate Precision std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); unsigned nhit = 0; - for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) { + for (size_t j = 0; j < rec.size() && j < this->topn; ++j) { nhit += (rec[j].second != 0); } - return static_cast(nhit) / topn_; + return static_cast(nhit) / this->topn; } }; /*! \brief NDCG: Normalized Discounted Cumulative Gain at N */ -struct EvalNDCG : public EvalRankList{ - public: - explicit EvalNDCG(const char *name) : EvalRankList("ndcg", name) {} - - protected: - inline bst_float CalcDCG(const std::vector > &rec) const { +struct EvalNDCG : public EvalRank { + private: + double CalcDCG(const PredIndPairContainer &rec) const { double sumdcg = 0.0; - for (size_t i = 0; i < rec.size() && i < this->topn_; ++i) { + for (size_t i = 0; i < rec.size() && i < this->topn; ++i) { const unsigned rel = rec[i].second; if (rel != 0) { sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0); @@ -346,13 +358,18 @@ struct EvalNDCG : public EvalRankList{ } return sumdcg; } - virtual bst_float EvalMetric(std::vector > &rec) const { // NOLINT(*) - XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst); - bst_float dcg = this->CalcDCG(rec); - XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpSecond); - bst_float idcg = this->CalcDCG(rec); + + public: + explicit EvalNDCG(const char* name, const char* param) : EvalRank(name, param) {} + + double EvalGroup(PredIndPairContainer *recptr) const override { + PredIndPairContainer &rec(*recptr); + std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); + double dcg = CalcDCG(rec); + std::stable_sort(rec.begin(), rec.end(), common::CmpSecond); + double idcg = CalcDCG(rec); if (idcg == 0.0f) { - if (minus_) { + if (this->minus) { return 0.0f; } else { return 1.0f; @@ -363,28 +380,28 @@ struct EvalNDCG : public EvalRankList{ }; /*! \brief Mean Average Precision at N, for both classification and rank */ -struct EvalMAP : public EvalRankList { +struct EvalMAP : public EvalRank { public: - explicit EvalMAP(const char *name) : EvalRankList("map", name) {} + explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {} - protected: - bst_float EvalMetric(std::vector< std::pair > &rec) const override { + double EvalGroup(PredIndPairContainer *recptr) const override { + PredIndPairContainer &rec(*recptr); std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); unsigned nhits = 0; double sumap = 0.0; for (size_t i = 0; i < rec.size(); ++i) { if (rec[i].second != 0) { nhits += 1; - if (i < this->topn_) { - sumap += static_cast(nhits) / (i + 1); + if (i < this->topn) { + sumap += static_cast(nhits) / (i + 1); } } } if (nhits != 0) { sumap /= nhits; - return static_cast(sumap); + return sumap; } else { - if (minus_) { + if (this->minus) { return 0.0f; } else { return 1.0f; @@ -404,12 +421,12 @@ struct EvalCox : public Metric { using namespace std; // NOLINT(*) const auto ndata = static_cast(info.labels_.Size()); - const std::vector &label_order = info.LabelAbsSort(); + const auto &label_order = info.LabelAbsSort(); // pre-compute a sum for the denominator double exp_p_sum = 0; // we use double because we might need the precision with large datasets - const std::vector& h_preds = preds.HostVector(); + const auto &h_preds = preds.ConstHostVector(); for (omp_ulong i = 0; i < ndata; ++i) { exp_p_sum += h_preds[i]; } @@ -417,7 +434,7 @@ struct EvalCox : public Metric { double out = 0; double accumulated_sum = 0; bst_omp_uint num_events = 0; - const auto& labels = info.labels_.HostVector(); + const auto& labels = info.labels_.ConstHostVector(); for (bst_omp_uint i = 0; i < ndata; ++i) { const size_t ind = label_order[i]; const auto label = labels[ind]; @@ -442,7 +459,7 @@ struct EvalCox : public Metric { } }; -/*! \brief Area Under PR Curve, for both classification and rank */ +/*! \brief Area Under PR Curve, for both classification and rank computed on CPU */ struct EvalAucPR : public Metric { // implementation of AUC-PR for weighted data // translated from PRROC R Package @@ -451,72 +468,79 @@ struct EvalAucPR : public Metric { template bst_float Eval(const HostDeviceVector &preds, const MetaInfo &info, - bool distributed) { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) - << "label size predict size not match"; - std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.Size()); - const std::vector &gptr = - info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.Size()) - << "EvalAucPR: group structure must match number of prediction"; - const auto ngroup = static_cast(gptr.size() - 1); + bool distributed, + const std::vector &gptr) { + const auto ngroups = static_cast(gptr.size() - 1); + // sum of all AUC's across all query groups double sum_auc = 0.0; int auc_error = 0; - // each thread takes a local rec - std::vector> rec; - const auto& h_labels = info.labels_.HostVector(); - const std::vector& h_preds = preds.HostVector(); - - for (bst_omp_uint group_id = 0; group_id < ngroup; ++group_id) { - double total_pos = 0.0; - double total_neg = 0.0; - rec.clear(); - for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { - const bst_float wt - = WeightPolicy::GetWeightOfInstance(info, j, group_id); - total_pos += wt * h_labels[j]; - total_neg += wt * (1.0f - h_labels[j]); - rec.emplace_back(h_preds[j], j); - } - XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst); - // we need pos > 0 && neg > 0 - if (0.0 == total_pos || 0.0 == total_neg) { - auc_error += 1; - continue; - } - // calculate AUC - double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; - for (size_t j = 0; j < rec.size(); ++j) { - const bst_float wt - = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); - tp += wt * h_labels[rec[j].second]; - fp += wt * (1.0f - h_labels[rec[j].second]); - if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) { - if (tp == prevtp) { - a = 1.0; - b = 0.0; - } else { - h = (fp - prevfp) / (tp - prevtp); - a = 1.0 + h; - b = (prevfp - h * prevtp) / total_pos; - } - if (0.0 != b) { - sum_auc += (tp / total_pos - prevtp / total_pos - - b / a * (std::log(a * tp / total_pos + b) - - std::log(a * prevtp / total_pos + b))) / a; - } else { - sum_auc += (tp / total_pos - prevtp / total_pos) / a; + + const auto &h_labels = info.labels_.ConstHostVector(); + const auto &h_preds = preds.ConstHostVector(); + + #pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1) + { + // Each thread works on a distinct group and sorts the predictions in that group + PredIndPairContainer rec; + #pragma omp for schedule(static) + for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) { + double total_pos = 0.0; + double total_neg = 0.0; + // Same thread can work on multiple groups one after another; hence, resize + // the predictions array based on the current group + rec.resize(gptr[group_id + 1] - gptr[group_id]); + #pragma omp parallel for schedule(static) reduction(+:total_pos, total_neg) \ + if (!omp_in_parallel()) // NOLINT + for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) { + const bst_float wt = WeightPolicy::GetWeightOfInstance(info, j, group_id); + total_pos += wt * h_labels[j]; + total_neg += wt * (1.0f - h_labels[j]); + rec[j - gptr[group_id]] = {h_preds[j], j}; + } + + // we need pos > 0 && neg > 0 + if (total_pos <= 0.0 || total_neg <= 0.0) { + auc_error += 1; + continue; + } + + if (omp_in_parallel()) { + std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); + } else { + XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); + } + + // calculate AUC + double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; + for (size_t j = 0; j < rec.size(); ++j) { + const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id); + tp += wt * h_labels[rec[j].second]; + fp += wt * (1.0f - h_labels[rec[j].second]); + if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) { + if (tp == prevtp) { + a = 1.0; + b = 0.0; + } else { + h = (fp - prevfp) / (tp - prevtp); + a = 1.0 + h; + b = (prevfp - h * prevtp) / total_pos; + } + if (0.0 != b) { + sum_auc += (tp / total_pos - prevtp / total_pos - + b / a * (std::log(a * tp / total_pos + b) - + std::log(a * prevtp / total_pos + b))) / a; + } else { + sum_auc += (tp / total_pos - prevtp / total_pos) / a; + } + prevtp = tp; + prevfp = fp; } - prevtp = tp; - prevfp = fp; } - } - // sanity check - if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) { - CHECK(!auc_error) << "AUC-PR: error in calculation"; + // sanity check + if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) { + CHECK(!auc_error) << "AUC-PR: error in calculation"; + } } } @@ -524,9 +548,9 @@ struct EvalAucPR : public Metric { // In distributed mode, workers which only contains pos or neg samples // will be ignored when aggregate AUC-PR. bst_float dat[2] = {0.0f, 0.0f}; - if (auc_error < static_cast(ngroup)) { + if (auc_error < static_cast(ngroups)) { dat[0] = static_cast(sum_auc); - dat[1] = static_cast(static_cast(ngroup) - auc_error); + dat[1] = static_cast(static_cast(ngroups) - auc_error); } if (distributed) { rabit::Allreduce(dat, 2); @@ -541,20 +565,31 @@ struct EvalAucPR : public Metric { bst_float Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) + << "label size predict size not match"; + std::vector tgptr(2, 0); + tgptr[1] = static_cast(info.labels_.Size()); + + const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_; + CHECK_EQ(gptr.back(), info.labels_.Size()) + << "EvalAucPR: group structure must match number of prediction"; + // For ranking task, weights are per-group // For binary classification task, weights are per-instance const bool is_ranking_task = !info.group_ptr_.empty() && info.weights_.Size() != info.num_row_; + if (is_ranking_task) { - return Eval(preds, info, distributed); + return Eval(preds, info, distributed, gptr); } else { - return Eval(preds, info, distributed); + return Eval(preds, info, distributed, gptr); } } + const char *Name() const override { return "aucpr"; } }; - XGBOOST_REGISTER_METRIC(AMS, "ams") .describe("AMS metric for higgs.") .set_body([](const char* param) { return new EvalAMS(param); }); @@ -569,15 +604,15 @@ XGBOOST_REGISTER_METRIC(AucPR, "aucpr") XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") -.set_body([](const char* param) { return new EvalPrecision(param); }); +.set_body([](const char* param) { return new EvalPrecision("pre", param); }); XGBOOST_REGISTER_METRIC(NDCG, "ndcg") .describe("ndcg@k for rank.") -.set_body([](const char* param) { return new EvalNDCG(param); }); +.set_body([](const char* param) { return new EvalNDCG("ndcg", param); }); XGBOOST_REGISTER_METRIC(MAP, "map") .describe("map@k for rank.") -.set_body([](const char* param) { return new EvalMAP(param); }); +.set_body([](const char* param) { return new EvalMAP("map", param); }); XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .describe("Negative log partial likelihood of Cox proportioanl hazards model.") diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index bca61c7b3ce1..8b62d265fa06 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -121,11 +121,13 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, xgboost::bst_float GetMetricEval(xgboost::Metric * metric, xgboost::HostDeviceVector preds, std::vector labels, - std::vector weights) { + std::vector weights, + std::vector groups) { xgboost::MetaInfo info; info.num_row_ = labels.size(); info.labels_.HostVector() = labels; info.weights_.HostVector() = weights; + info.group_ptr_ = groups; return metric->Eval(preds, info, false); } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 8001076bae52..809b6fa7a938 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -81,7 +81,8 @@ xgboost::bst_float GetMetricEval( xgboost::Metric * metric, xgboost::HostDeviceVector preds, std::vector labels, - std::vector weights = std::vector ()); + std::vector weights = std::vector(), + std::vector groups = std::vector()); namespace xgboost { bool IsNear(std::vector::const_iterator _beg1, diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 792deb5a62d7..1861eb398fed 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -34,6 +34,29 @@ TEST(Metric, AUC) { EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0})); + // AUC with instance weights + EXPECT_NEAR(GetMetricEval(metric, + {0.9f, 0.1f, 0.4f, 0.3f}, + {0, 0, 1, 1}, + {1.0f, 3.0f, 2.0f, 4.0f}), + 0.75f, 0.001f); + + // AUC for a ranking task without weights + EXPECT_NEAR(GetMetricEval(metric, + {0.9f, 0.1f, 0.4f, 0.3f, 0.7f}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, + {}, + {0, 2, 5}), + 0.4741f, 0.001f); + + // AUC for a ranking task with weights/group + EXPECT_NEAR(GetMetricEval(metric, + {0.9f, 0.1f, 0.4f, 0.3f, 0.7f}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, + {1, 2}, + {0, 2, 5}), + 0.4741f, 0.001f); + delete metric; } @@ -58,9 +81,37 @@ TEST(Metric, AUCPR) { EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0})); + // AUCPR with instance weights + EXPECT_NEAR(GetMetricEval( + metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f, + 0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f}, + {0, 0.1f, 0.2f, 0.3f, 0.4f, 0.3f, 0.1f, 0.2f, 0.4f, 0, 0.2f, 0.3f, 1, 0}, + {1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f}), // weights + 0.425919f, 0.001f); + + // AUCPR with groups and no weights + EXPECT_NEAR(GetMetricEval( + metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f, + 0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f, + 0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f}, + {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}, + {}, // weights + {0, 2, 5, 9, 14, 20}), // group info + 0.556021f, 0.001f); + + // AUCPR with groups and weights + EXPECT_NEAR(GetMetricEval( + metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f, + 0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f}, // predictions + {0, 0.1f, 0.2f, 0.3f, 0.4f, 0.3f, 0.1f, 0.2f, 0.4f, 0, 0.2f, 0.3f, 1, 0}, + {1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f}, // weights + {0, 2, 5, 9, 14}), // group info + 0.423391f, 0.001f); + delete metric; } + TEST(Metric, Precision) { // When the limit for precision is not given, it takes the limit at // std::numeric_limits::max(); hence all values are very small @@ -159,6 +210,14 @@ TEST(Metric, MAP) { xgboost::HostDeviceVector{}, std::vector{}), 1, 1e-10); + // Rank metric with group info + EXPECT_NEAR(GetMetricEval(metric, + {0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f}, + {2, 7, 1, 0, 5, 0}, // Labels + {}, // Weights + {0, 2, 5, 6}), // Group info + 0.8611f, 0.001f); + delete metric; metric = xgboost::Metric::Create("map@-", &tparam); ASSERT_STREQ(metric->Name(), "map-");