Skip to content

Commit

Permalink
adding initial support for pairwise gradients and NDCG eval with pair…
Browse files Browse the repository at this point in the history
…wise scores
  • Loading branch information
metpavel committed Feb 9, 2024
1 parent 1699c06 commit d5b6f0a
Show file tree
Hide file tree
Showing 3 changed files with 401 additions and 5 deletions.
3 changes: 3 additions & 0 deletions include/LightGBM/objective_function.h
Expand Up @@ -110,6 +110,9 @@ class ObjectiveFunction {
#endif // USE_CUDA
};

/*!
* \brief Fill the pointwise (per-item) scores of one query from its pairwise scores.
*        NOTE(review): only the declaration is visible here; the parameter roles below are
*        inferred from the NDCG metric call site and should be confirmed against the definition.
* \param query_id Index of the query being processed
* \param score_pointwise Output buffer for this query's per-item scores (callers pass a pointer already offset to the query's first item)
* \param score Pairwise scores of this query (callers pass a pointer already offset to the query's first pair)
* \param cnt_pointwise Number of items in the query
* \param selected_pairs_cnt Number of entries in selected_pairs
* \param selected_pairs Indices of the pairs to use; presumably relative to the start of this query's pairs (the metric passes 0..cnt-1) - confirm
* \param paired_index_map For each pair, the two item indices it refers to (callers pass a pointer already offset to the query's first pair)
* \param truncation_level Callers pass config.lambdarank_truncation_level
* \param sigma Callers pass config.sigmoid
*/
void UpdatePointwiseScoresForOneQuery(data_size_t query_id, double* score_pointwise, const double* score, data_size_t cnt_pointwise,
data_size_t selected_pairs_cnt, const data_size_t* selected_pairs, const std::pair<data_size_t, data_size_t>* paired_index_map, int truncation_level, double sigma);

} // namespace LightGBM

#endif // LightGBM_OBJECTIVE_FUNCTION_H_
47 changes: 43 additions & 4 deletions src/metric/rank_metric.hpp
Expand Up @@ -9,10 +9,12 @@
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/objective_function.h>

#include <string>
#include <sstream>
#include <vector>
#include <numeric>

namespace LightGBM {

Expand All @@ -26,6 +28,9 @@ class NDCGMetric:public Metric {
DCGCalculator::DefaultLabelGain(&label_gain);
// initialize DCG calculator
DCGCalculator::Init(label_gain);
pairwise_scores_ = config.objective == std::string("pairwise_lambdarank");
sigmoid_ = config.sigmoid;
truncation_level_ = config.lambdarank_truncation_level;
}

~NDCGMetric() {
Expand All @@ -34,14 +39,14 @@ class NDCGMetric:public Metric {
for (auto k : eval_at_) {
name_.emplace_back(std::string("ndcg@") + std::to_string(k));
}
num_data_ = num_data;
num_data_ = pairwise_scores_? metadata.pointwise_query_boundaries()[metadata.num_queries()] : num_data;
// get label
label_ = metadata.label();
num_queries_ = metadata.num_queries();
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_);
// get query boundaries
query_boundaries_ = metadata.query_boundaries();
query_boundaries_ = pairwise_scores_? metadata.pointwise_query_boundaries() : metadata.query_boundaries();
if (query_boundaries_ == nullptr) {
Log::Fatal("The NDCG metric requires query information");
}
Expand Down Expand Up @@ -73,6 +78,12 @@ class NDCGMetric:public Metric {
}
}
}
if (pairwise_scores_) {
paired_index_map_ = metadata.paired_ranking_item_index_map();
scores_pointwise_.resize(num_data_, 0.0);
num_data_pairwise_ = num_data;
query_boundaries_pairwise_ = metadata.query_boundaries();
}
}

const std::vector<std::string>& GetName() const override {
Expand Down Expand Up @@ -101,9 +112,19 @@ class NDCGMetric:public Metric {
result_buffer_[tid][j] += 1.0f;
}
} else {
if (pairwise_scores_) {
  // In pairwise mode query_boundaries_ holds the pointwise (per-item) boundaries,
  // while query_boundaries_pairwise_ holds the boundaries of the paired rows.
  const data_size_t start_pointwise = query_boundaries_[i];
  const data_size_t cnt_pointwise = query_boundaries_[i + 1] - start_pointwise;
  const data_size_t start_pairwise = query_boundaries_pairwise_[i];
  // The pair count must come from the pairwise boundaries; deriving it from
  // query_boundaries_ would just repeat the item count of the query.
  const data_size_t cnt_pairwise = query_boundaries_pairwise_[i + 1] - start_pairwise;
  // Use every pair of this query: indices 0..cnt_pairwise-1.
  std::vector<data_size_t> all_pairs(cnt_pairwise);
  std::iota(all_pairs.begin(), all_pairs.end(), 0);
  UpdatePointwiseScoresForOneQuery(i, scores_pointwise_.data() + start_pointwise, score + start_pairwise, cnt_pointwise, cnt_pairwise, all_pairs.data(), paired_index_map_ + start_pairwise, truncation_level_, sigmoid_);
}

// calculate DCG
// Labels and scores must both point at the first item of query i; the pointwise
// buffer is therefore offset by the same boundary as the labels.
DCGCalculator::CalDCG(eval_at_, label_ + query_boundaries_[i],
                      pairwise_scores_ ? scores_pointwise_.data() + query_boundaries_[i] : score + query_boundaries_[i],
                      query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
// calculate NDCG
for (size_t j = 0; j < eval_at_.size(); ++j) {
Expand All @@ -121,9 +142,18 @@ class NDCGMetric:public Metric {
result_buffer_[tid][j] += 1.0f;
}
} else {
if (pairwise_scores_) {
  // In pairwise mode query_boundaries_ holds the pointwise (per-item) boundaries,
  // while query_boundaries_pairwise_ holds the boundaries of the paired rows.
  const data_size_t start_pointwise = query_boundaries_[i];
  const data_size_t cnt_pointwise = query_boundaries_[i + 1] - start_pointwise;
  const data_size_t start_pairwise = query_boundaries_pairwise_[i];
  // The pair count must come from the pairwise boundaries; deriving it from
  // query_boundaries_ would just repeat the item count of the query.
  const data_size_t cnt_pairwise = query_boundaries_pairwise_[i + 1] - start_pairwise;
  // Use every pair of this query: indices 0..cnt_pairwise-1.
  std::vector<data_size_t> all_pairs(cnt_pairwise);
  std::iota(all_pairs.begin(), all_pairs.end(), 0);
  UpdatePointwiseScoresForOneQuery(i, scores_pointwise_.data() + start_pointwise, score + start_pairwise, cnt_pointwise, cnt_pairwise, all_pairs.data(), paired_index_map_ + start_pairwise, truncation_level_, sigmoid_);
}
// calculate DCG
// Labels and scores must both point at the first item of query i; the pointwise
// buffer is therefore offset by the same boundary as the labels.
DCGCalculator::CalDCG(eval_at_, label_ + query_boundaries_[i],
                      pairwise_scores_ ? scores_pointwise_.data() + query_boundaries_[i] : score + query_boundaries_[i],
                      query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
// calculate NDCG
for (size_t j = 0; j < eval_at_.size(); ++j) {
Expand Down Expand Up @@ -162,6 +192,15 @@ class NDCGMetric:public Metric {
std::vector<data_size_t> eval_at_;
/*! \brief Cache the inverse max dcg for all queries */
std::vector<std::vector<double>> inverse_max_dcgs_;
/*! \brief True when config.objective equals "pairwise_lambdarank"; the metric then evaluates on pointwise scores derived from the pairwise ones */
bool pairwise_scores_;
/*! \brief Copy of config.sigmoid; passed as the sigma argument of UpdatePointwiseScoresForOneQuery */
double sigmoid_;
/*! \brief Copy of config.lambdarank_truncation_level; passed to UpdatePointwiseScoresForOneQuery */
int truncation_level_;
/*! \brief Per-item score buffer, resized to num_data_ in pairwise mode and filled query by query during evaluation.
*          NOTE(review): presumably mutable because it is written from a const evaluation method - confirm */
mutable std::vector<double> scores_pointwise_;
/*! \brief Taken from metadata.paired_ranking_item_index_map(); only assigned in pairwise mode */
const std::pair<data_size_t, data_size_t>* paired_index_map_;
/*! \brief The num_data value received at initialization (num_data_ itself is replaced by the pointwise item count in pairwise mode); only assigned in pairwise mode */
data_size_t num_data_pairwise_;
/*! \brief Taken from metadata.query_boundaries(); kept separately because query_boundaries_ holds metadata.pointwise_query_boundaries() in pairwise mode. Only assigned in pairwise mode */
const data_size_t* query_boundaries_pairwise_;
};

} // namespace LightGBM
Expand Down

0 comments on commit d5b6f0a

Please sign in to comment.