Skip to content

Commit

Permalink
fix pairwise dataset bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Feb 29, 2024
1 parent 0aaf090 commit 8714bfb
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 34 deletions.
25 changes: 24 additions & 1 deletion include/LightGBM/pairwise_ranking_feature_group.h
Expand Up @@ -72,14 +72,37 @@ class PairwiseRankingFeatureGroup: public FeatureGroup {
}

inline void FinishLoad() {
// TODO(shiyu1994)
CHECK(!is_multi_val_);
bin_data_->FinishLoad();
}

inline BinIterator* FeatureGroupIterator() {
// TODO(shiyu1994)
return nullptr;
}

/*!
* \brief Push one record, will auto convert to bin and push to bin data
* \param tid Thread id
* \param sub_feature_idx Index of the subfeature
* \param line_idx Index of record
* \param bin feature bin value of record
*/
inline void PushBinData(int tid, int sub_feature_idx, data_size_t line_idx, uint32_t bin) {
if (bin == bin_mappers_[sub_feature_idx]->GetMostFreqBin()) {
return;
}
if (bin_mappers_[sub_feature_idx]->GetMostFreqBin() == 0) {
bin -= 1;
}
if (is_multi_val_) {
multi_bin_data_[sub_feature_idx]->Push(tid, line_idx, bin + 1);
} else {
bin += bin_offsets_[sub_feature_idx];
bin_data_->Push(tid, line_idx, bin);
}
}

private:
void CreateBinData(int num_data, bool is_multi_val, bool force_dense, bool force_sparse) override;

Expand Down
4 changes: 2 additions & 2 deletions src/io/dataset.cpp
Expand Up @@ -858,7 +858,7 @@ void Dataset::CreatePairWiseRankingData(const Dataset* dataset) {
for (int i = 0; i < num_groups_; ++i) {
int original_group_index = i % dataset->num_groups_;
int original_group_feature_start = dataset->group_feature_start_[original_group_index];
const int is_first_or_second_in_pairing = original_group_index / dataset->num_groups_; // 0 for first, 1 for second
const int is_first_or_second_in_pairing = i / dataset->num_groups_; // 0 for first, 1 for second
group_feature_start_[i] = cur_feature_index;
for (int feature_index_in_group = 0; feature_index_in_group < dataset->group_feature_cnt_[original_group_index]; ++feature_index_in_group) {
const BinMapper* feature_bin_mapper = dataset->FeatureBinMapper(original_group_feature_start + feature_index_in_group);
Expand All @@ -869,7 +869,7 @@ void Dataset::CreatePairWiseRankingData(const Dataset* dataset) {
feature2subfeature_.push_back(dataset->feature2subfeature_[original_group_feature_start + feature_index_in_group]);
cur_feature_index += 1;
}
feature_groups_.emplace_back(new PairwiseRankingFeatureGroup(*dataset->feature_groups_[original_group_index].get(), num_data_, is_first_or_second_in_pairing, metadata_.paired_ranking_item_index_map_size(), metadata_.paired_ranking_item_index_map()));
feature_groups_.emplace_back(new PairwiseRankingFeatureGroup(*dataset->feature_groups_[original_group_index].get(), dataset->num_data(), is_first_or_second_in_pairing, metadata_.paired_ranking_item_index_map_size(), metadata_.paired_ranking_item_index_map()));
num_total_bin += dataset->FeatureGroupNumBin(original_group_index);
group_bin_boundaries_.push_back(num_total_bin);
group_feature_cnt_[i] = dataset->group_feature_cnt_[original_group_index];
Expand Down
9 changes: 0 additions & 9 deletions src/io/dataset_loader.cpp
Expand Up @@ -293,12 +293,6 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
// need to check training data
CheckDataset(dataset.get(), is_load_from_binary);

if (config_.objective == std::string("pairwise_lambdarank")) {
std::unique_ptr<Dataset> original_dataset(dataset.release());
dataset.reset(new Dataset());
dataset->CreatePairWiseRankingData(original_dataset.get());
}

return dataset.release();
}

Expand Down Expand Up @@ -357,9 +351,6 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
// check meta data
dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);

// TODO(shiyu1994)
Log::Warning("Pairwise ranking with validation set is not supported yet.");

return dataset.release();
}

Expand Down
6 changes: 3 additions & 3 deletions src/io/metadata.cpp
Expand Up @@ -911,11 +911,11 @@ data_size_t Metadata::BuildPairwiseFeatureRanking(const Metadata& metadata) {
}
}

data_size_t num_pairs_in_query = 0;
query_boundaries_.clear();
query_boundaries_.push_back(0);
num_queries_ = 0;
for (data_size_t query_index = 0; query_index < num_queries_; ++query_index) {
for (data_size_t query_index = 0; query_index < metadata.num_queries(); ++query_index) {
data_size_t num_pairs_in_query = 0;
const data_size_t query_start = query_boundaries[query_index];
const data_size_t query_end = query_boundaries[query_index + 1];
for (data_size_t item_index_i = query_start; item_index_i < query_end; ++item_index_i) {
Expand All @@ -926,7 +926,7 @@ data_size_t Metadata::BuildPairwiseFeatureRanking(const Metadata& metadata) {
}
const label_t label_j = label_[item_index_j];
if (label_i != label_j) {
paired_ranking_item_index_map_.push_back(std::pair<data_size_t, data_size_t>{item_index_i - query_start, item_index_j - query_start});
paired_ranking_item_index_map_.push_back(std::pair<data_size_t, data_size_t>{item_index_i, item_index_j});
++num_pairs_in_query;
++num_data_;
}
Expand Down
12 changes: 8 additions & 4 deletions src/io/pairwise_lambdarank_bin.hpp
Expand Up @@ -32,7 +32,7 @@ class PairwiseRankingFirstIterator: public BinIterator {
unpaired_bin_iterator_.reset(unpaired_bin_->GetIterator(min_bin, max_bin, most_freq_bin));
unpaired_bin_iterator_->Reset(0);
paired_ranking_item_index_map_ = paired_ranking_item_index_map;
prev_index_ = 0;
prev_index_ = -1;
prev_val_ = 0;
}

Expand All @@ -41,7 +41,7 @@ class PairwiseRankingFirstIterator: public BinIterator {
uint32_t Get(data_size_t idx) {
const data_size_t data_index = paired_ranking_item_index_map_[idx].first;
if (data_index != prev_index_) {
CHECK_GT(data_index, prev_index_);
CHECK_GE(data_index, prev_index_);
prev_val_ = unpaired_bin_iterator_->Get(data_index);
}
prev_index_ = data_index;
Expand All @@ -51,7 +51,7 @@ class PairwiseRankingFirstIterator: public BinIterator {
uint32_t RawGet(data_size_t idx) {
const data_size_t data_index = paired_ranking_item_index_map_[idx].first;
if (data_index != prev_index_) {
CHECK_GT(data_index, prev_index_);
CHECK_GE(data_index, prev_index_);
prev_val_ = unpaired_bin_iterator_->RawGet(data_index);
}
prev_index_ = data_index;
Expand All @@ -60,7 +60,7 @@ class PairwiseRankingFirstIterator: public BinIterator {

void Reset(data_size_t idx) {
unpaired_bin_iterator_->Reset(idx);
prev_index_ = 0;
prev_index_ = -1;
prev_val_ = 0;
}

Expand Down Expand Up @@ -134,6 +134,10 @@ class PairwiseRankingBin: public BIN_TYPE {

void Push(int tid, data_size_t idx, uint32_t value) override;

void FinishLoad() override {
unpaired_bin_->FinishLoad();
}

void CopySubrow(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override;

void SaveBinaryToFile(BinaryWriter* writer) const override;
Expand Down
18 changes: 4 additions & 14 deletions src/io/pairwise_ranking_feature_group.cpp
Expand Up @@ -16,22 +16,12 @@ PairwiseRankingFeatureGroup::PairwiseRankingFeatureGroup(const FeatureGroup& oth

CreateBinData(num_original_data, is_multi_val_, !is_sparse_, is_sparse_);

// copy from original bin data
const int num_threads = OMP_NUM_THREADS();
std::vector<std::vector<std::unique_ptr<BinIterator>>> bin_iterators(num_threads);
for (int i = 0; i < num_threads; ++i) {
for (int j = 0; j < num_feature_; ++j) {
bin_iterators[i].emplace_back(other.SubFeatureIterator(j));
bin_iterators[i].back()->Reset(0);
}
}

Threading::For<data_size_t>(0, num_original_data, 512, [this, &other] (int block_index, data_size_t block_start, data_size_t block_end) {
for (int feature_index = 0; feature_index < num_feature_; ++feature_index) {
std::unique_ptr<BinIterator> bin_iterator(other.SubFeatureIterator(feature_index));
bin_iterator->Reset(block_start);
for (data_size_t index = block_start; index < block_end; ++index) {
PushData(block_index, feature_index, index, bin_iterator->RawGet(index));
PushBinData(block_index, feature_index, index, bin_iterator->Get(index));
}
}
});
Expand All @@ -50,7 +40,7 @@ void PairwiseRankingFeatureGroup::CreateBinData(int num_data, bool is_multi_val,
multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingFirstBin(
num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
} else {
multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingSecondBin(
multi_bin_data_.emplace_back(Bin::CreateSparsePairwiseRankingSecondBin(
num_data, bin_mappers_[i]->num_bin() + addi, num_data_, paired_ranking_item_index_map_));
}
} else {
Expand All @@ -69,14 +59,14 @@ void PairwiseRankingFeatureGroup::CreateBinData(int num_data, bool is_multi_val,
(!force_dense && num_feature_ == 1 &&
bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
is_sparse_ = true;
if (is_first_or_second_in_pairing_) {
if (is_first_or_second_in_pairing_ == 0) {
bin_data_.reset(Bin::CreateSparsePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
} else {
bin_data_.reset(Bin::CreateSparsePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
}
} else {
is_sparse_ = false;
if (is_first_or_second_in_pairing_) {
if (is_first_or_second_in_pairing_ == 0) {
bin_data_.reset(Bin::CreateDensePairwiseRankingFirstBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
} else {
bin_data_.reset(Bin::CreateDensePairwiseRankingSecondBin(num_data, num_total_bin_, num_data_, paired_ranking_item_index_map_));
Expand Down
3 changes: 3 additions & 0 deletions src/io/sparse_bin.hpp
Expand Up @@ -844,6 +844,9 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
bin_data_->InitIndex(start_idx, &i_delta_, &cur_pos_);
if (start_idx == 1920 || start_idx == 4320) {
Log::Warning("i_delta_ = %d, cur_pos_ = %d, start_idx = %d", i_delta_, cur_pos_, start_idx);
}
}

template <typename VAL_T>
Expand Down
2 changes: 1 addition & 1 deletion src/metric/dcg_calculator.cpp
Expand Up @@ -149,7 +149,7 @@ void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
if (delta > kEpsilon) {
Log::Fatal("label should be int type (met %f) for ranking task,\n"
"for the gain of label, please set the label_gain parameter", label[i]);
"for the gain of label, please set the label_gain parameter, i = %d", label[i], i);
}

if (label[i] < 0) {
Expand Down

0 comments on commit 8714bfb

Please sign in to comment.