From 7e34d23c05599ce3a8a6f22cdba29e103f57d218 Mon Sep 17 00:00:00 2001 From: Pavel Metrikov <46672636+metpavel@users.noreply.github.com> Date: Mon, 4 Sep 2023 02:05:46 -0700 Subject: [PATCH] Treat position bias via GAM in LambdaMART (#5929) * Update dataset.h * Update metadata.cpp * Update rank_objective.hpp * Update metadata.cpp * Update rank_objective.hpp * Update metadata.cpp * Update dataset.h * Update rank_objective.hpp * Update metadata.cpp * Update test_engine.py * Update test_engine.py * Add files via upload * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update _rank.train.position * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update _rank.train.position * Update _rank.train.position * Update test_engine.py * Update _rank.train.position * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update the position of import statement * Update rank_objective.hpp * Update config.h * Update config_auto.cpp * Update rank_objective.hpp * Update rank_objective.hpp * update documentation * remove extra blank line * Update src/io/metadata.cpp Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * remove _rank.train.position * add position in python API * fix set_positions in basic.py * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * Update Advanced-Topics.rst * remove List from 
_LGBM_PositionType * move new position parameter to the last in Dataset constructor * add position_filename as a parameter * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Advanced-Topics.rst * Update src/objective/rank_objective.hpp Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * Update metadata.cpp * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * Update src/io/metadata.cpp Co-authored-by: James Lamb * more infomrative fatal message address more comments * update documentation for more flexible position specification * fix SetPosition add tests for get_position and set_position * remove position_filename * remove useless changes * Update python-package/lightgbm/basic.py Co-authored-by: James Lamb * remove useless files * move position file when position set in Dataset * warn when positions are overwritten * skip ranking with position test in cuda * split test case * remove useless import * Update test_engine.py * Update test_engine.py * Update test_engine.py * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * Update Parameters.rst * Update rank_objective.hpp * Update config.h * update config_auto.cppp * Update docs/Advanced-Topics.rst Co-authored-by: James Lamb * fix randomness in test case for gpu --------- Co-authored-by: shiyu1994 Co-authored-by: James Lamb --- docs/Advanced-Topics.rst | 41 ++++++ docs/Parameters.rst | 4 + include/LightGBM/config.h | 4 + include/LightGBM/dataset.h | 43 ++++++ python-package/lightgbm/basic.py | 68 +++++++-- src/io/config_auto.cpp | 7 + src/io/dataset.cpp | 5 + src/io/metadata.cpp | 97 +++++++++++++ src/objective/rank_objective.hpp | 96 ++++++++++++- 
tests/python_package_test/test_engine.py | 168 ++++++++++++++++++++++- 10 files changed, 522 insertions(+), 11 deletions(-) diff --git a/docs/Advanced-Topics.rst b/docs/Advanced-Topics.rst index d1787b99847..345a1361bfa 100644 --- a/docs/Advanced-Topics.rst +++ b/docs/Advanced-Topics.rst @@ -77,3 +77,44 @@ Recommendations for gcc Users (MinGW, \*nix) -------------------------------------------- - Refer to `gcc Tips <./gcc-Tips.rst>`__. + +Support for Position Bias Treatment +------------------------------------ + +Often the relevance labels provided in Learning-to-Rank tasks might be derived from implicit user feedback (e.g., clicks) and therefore might be biased due to their position/location on the screen when having been presented to a user. +LightGBM can make use of positional data. + +For example, consider the case where you expect that the first 3 results from a search engine will be visible in users' browsers without scrolling, and all other results for a query would require scrolling. + +LightGBM could be told to account for the position bias from results being "above the fold" by providing a ``positions`` array encoded as follows: + +:: + + 0 + 0 + 0 + 1 + 1 + 0 + 0 + 0 + 1 + ... + +Where ``0 = "above the fold"`` and ``1 = "requires scrolling"``. +The specific values are not important, as long as they are consistent across all observations in the training data. +An encoding like ``100 = "above the fold"`` and ``17 = "requires scrolling"`` would result in exactly the same trained model. + +In that way, ``positions`` in LightGBM's API are similar to a categorical feature. +Just as with non-ordinal categorical features, an integer representation is just used for memory and computational efficiency... LightGBM does not care about the absolute or relative magnitude of the values. + +Unlike a categorical feature, however, ``positions`` are used to adjust the target to reduce the bias in predictions made by the trained model. 
+ +The position file corresponds with the training data file line by line, and has one position per line. If the name of the training data file is ``train.txt``, the position file should be named as ``train.txt.position`` and placed in the same folder as the data file. +In this case, LightGBM will load the position file automatically if it exists. The positions can also be specified through the ``Dataset`` constructor when using the Python API. If the positions are specified through both approaches, the ``.position`` file will be ignored. + +The currently implemented approach models position bias using the idea of Generalized Additive Models (`GAM `_) to linearly decompose the document score ``s`` into the sum of a relevance component ``f`` and a positional component ``g``: ``s(x, pos) = f(x) + g(pos)`` where the former component depends on the original query-document features and the latter depends on the position of an item. +During the training, the compound scoring function ``s(x, pos)`` is fit with a standard ranking algorithm (e.g., LambdaMART) which boils down to jointly learning the relevance component ``f(x)`` (it is later returned as an unbiased model) and the position factors ``g(pos)`` that help better explain the observed (biased) labels. +Similar score decomposition ideas have previously been applied for classification & pointwise ranking tasks with assumptions of binary labels and binary relevance (a.k.a. "two-tower" models, refer to the papers: `Towards Disentangling Relevance and Bias in Unbiased Learning to Rank `_, `PAL: a position-bias aware learning framework for CTR prediction in live recommender systems `_, `A General Framework for Debiasing in CTR Prediction `_). +In LightGBM, we adapt this idea to general pairwise Learning-to-Rank with arbitrary ordinal relevance labels. 
+Besides, GAMs have been used in the context of explainable ML (`Accurate Intelligible Models with Pairwise Interactions `_) to linearly decompose the contribution of each feature (and possibly their pairwise interactions) to the overall score, for subsequent analysis and interpretation of their effects in the trained models. diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 5eecc27889b..7d825f9f135 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1137,6 +1137,10 @@ Objective Parameters - separate by ``,`` +- ``lambdarank_position_bias_regularization`` :raw-html:`🔗︎`, default = ``0.0``, type = double, constraints: ``lambdarank_position_bias_regularization >= 0.0`` + + - used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + Metric Parameters ----------------- diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index e0157839625..343abf51e17 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -965,6 +965,10 @@ struct Config { // desc = separate by ``,`` std::vector label_gain; + // check = >=0.0 + // desc = used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors. + double lambdarank_position_bias_regularization = 0.0; + #ifndef __NVCC__ #pragma endregion diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 825c5c6ebcf..e7baa42dc2e 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -114,6 +114,8 @@ class Metadata { void SetQuery(const data_size_t* query, data_size_t len); + void SetPosition(const data_size_t* position, data_size_t len); + /*! * \brief Set initial scores * \param init_score Initial scores, this class will manage memory for init_score. @@ -213,6 +215,38 @@ class Metadata { } } + /*! 
+ * \brief Get positions, if does not exist then return nullptr + * \return Pointer of positions + */ + inline const data_size_t* positions() const { + if (!positions_.empty()) { + return positions_.data(); + } else { + return nullptr; + } + } + + /*! + * \brief Get position IDs, if does not exist then return nullptr + * \return Pointer of position IDs + */ + inline const std::string* position_ids() const { + if (!position_ids_.empty()) { + return position_ids_.data(); + } else { + return nullptr; + } + } + + /*! + * \brief Get Number of different position IDs + * \return number of different position IDs + */ + inline size_t num_position_ids() const { + return position_ids_.size(); + } + /*! * \brief Get data boundaries on queries, if not exists, will return nullptr * we assume data will order by query, @@ -289,6 +323,8 @@ class Metadata { private: /*! \brief Load wights from file */ void LoadWeights(); + /*! \brief Load positions from file */ + void LoadPositions(); /*! \brief Load query boundaries from file */ void LoadQueryBoundaries(); /*! \brief Calculate query weights from queries */ @@ -309,10 +345,16 @@ class Metadata { data_size_t num_data_; /*! \brief Number of weights, used to check correct weight file */ data_size_t num_weights_; + /*! \brief Number of positions, used to check correct position file */ + data_size_t num_positions_; /*! \brief Label data */ std::vector label_; /*! \brief Weights data */ std::vector weights_; + /*! \brief Positions data */ + std::vector positions_; + /*! \brief Position identifiers */ + std::vector position_ids_; /*! \brief Query boundaries */ std::vector query_boundaries_; /*! \brief Query weights */ @@ -328,6 +370,7 @@ class Metadata { /*! 
\brief mutex for threading safe call */ std::mutex mutex_; bool weight_load_from_file_; + bool position_load_from_file_; bool query_load_from_file_; bool init_score_load_from_file_; #ifdef USE_CUDA diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 6ff448dfeb3..2f061bdacf3 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -62,6 +62,10 @@ np.ndarray, pd_Series ] +_LGBM_PositionType = Union[ + np.ndarray, + pd_Series +] _LGBM_InitScoreType = Union[ List[float], List[List[float]], @@ -577,7 +581,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va "label": _C_API_DTYPE_FLOAT32, "weight": _C_API_DTYPE_FLOAT32, "init_score": _C_API_DTYPE_FLOAT64, - "group": _C_API_DTYPE_INT32 + "group": _C_API_DTYPE_INT32, + "position": _C_API_DTYPE_INT32 } """String name to int feature importance type mapper""" @@ -1525,7 +1530,8 @@ def __init__( feature_name: _LGBM_FeatureNameConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', params: Optional[Dict[str, Any]] = None, - free_raw_data: bool = True + free_raw_data: bool = True, + position: Optional[_LGBM_PositionType] = None, ): """Initialize Dataset. @@ -1565,6 +1571,8 @@ def __init__( Other parameters for Dataset. free_raw_data : bool, optional (default=True) If True, raw data is freed after constructing inner Dataset. + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. 
""" self._handle: Optional[_DatasetHandle] = None self.data = data @@ -1572,6 +1580,7 @@ def __init__( self.reference = reference self.weight = weight self.group = group + self.position = position self.init_score = init_score self.feature_name: _LGBM_FeatureNameConfiguration = feature_name self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature @@ -1836,7 +1845,8 @@ def _lazy_init( predictor: Optional[_InnerPredictor], feature_name: _LGBM_FeatureNameConfiguration, categorical_feature: _LGBM_CategoricalFeatureConfiguration, - params: Optional[Dict[str, Any]] + params: Optional[Dict[str, Any]], + position: Optional[_LGBM_PositionType] ) -> "Dataset": if data is None: self._handle = None @@ -1925,6 +1935,8 @@ def _lazy_init( self.set_weight(weight) if group is not None: self.set_group(group) + if position is not None: + self.set_position(position) if isinstance(predictor, _InnerPredictor): if self._predictor is None and init_score is not None: _log_warning("The init_score will be overridden by the prediction of init_model.") @@ -2219,7 +2231,7 @@ def construct(self) -> "Dataset": if self.used_indices is None: # create valid self._lazy_init(data=self.data, label=self.label, reference=self.reference, - weight=self.weight, group=self.group, + weight=self.weight, group=self.group, position=self.position, init_score=self.init_score, predictor=self._predictor, feature_name=self.feature_name, categorical_feature='auto', params=self.params) else: @@ -2242,6 +2254,8 @@ def construct(self) -> "Dataset": self.get_data() if self.group is not None: self.set_group(self.group) + if self.position is not None: + self.set_position(self.position) if self.get_label() is None: raise ValueError("Label should not be None.") if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor: @@ -2256,7 +2270,8 @@ def construct(self) -> "Dataset": self._lazy_init(data=self.data, label=self.label, reference=None, 
weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor, - feature_name=self.feature_name, categorical_feature=self.categorical_feature, params=self.params) + feature_name=self.feature_name, categorical_feature=self.categorical_feature, + params=self.params, position=self.position) if self.free_raw_data: self.data = None self.feature_name = self.get_feature_name() @@ -2269,7 +2284,8 @@ def create_valid( weight: Optional[_LGBM_WeightType] = None, group: Optional[_LGBM_GroupType] = None, init_score: Optional[_LGBM_InitScoreType] = None, - params: Optional[Dict[str, Any]] = None + params: Optional[Dict[str, Any]] = None, + position: Optional[_LGBM_PositionType] = None ) -> "Dataset": """Create validation data align with current Dataset. @@ -2292,6 +2308,8 @@ def create_valid( Init score for Dataset. params : dict or None, optional (default=None) Other parameters for validation Dataset. + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. Returns ------- @@ -2299,7 +2317,7 @@ def create_valid( Validation Dataset with reference to self. """ ret = Dataset(data, label=label, reference=self, - weight=weight, group=group, init_score=init_score, + weight=weight, group=group, position=position, init_score=init_score, params=params, free_raw_data=self.free_raw_data) ret._predictor = self._predictor ret.pandas_categorical = self.pandas_categorical @@ -2434,7 +2452,7 @@ def set_field( 'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.' 
) else: - dtype = np.int32 if field_name == 'group' else np.float32 + dtype = np.int32 if (field_name == 'group' or field_name == 'position') else np.float32 data = _list_to_1d_numpy(data, dtype=dtype, name=field_name) ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr] @@ -2727,6 +2745,28 @@ def set_group( self.set_field('group', group) return self + def set_position( + self, + position: Optional[_LGBM_PositionType] + ) -> "Dataset": + """Set position of Dataset (used for ranking). + + Parameters + ---------- + position : numpy 1-D array, pandas Series or None, optional (default=None) + Position of items used in unbiased learning-to-rank task. + + Returns + ------- + self : Dataset + Dataset with set position. + """ + self.position = position + if self._handle is not None and position is not None: + position = _list_to_1d_numpy(position, dtype=np.int32, name='position') + self.set_field('position', position) + return self + def get_feature_name(self) -> List[str]: """Get the names of columns (features) in the Dataset. @@ -2853,6 +2893,18 @@ def get_group(self) -> Optional[np.ndarray]: self.group = np.diff(self.group) return self.group + def get_position(self) -> Optional[np.ndarray]: + """Get the position of the Dataset. + + Returns + ------- + position : numpy 1-D array or None + Position of items used in unbiased learning-to-rank task. + """ + if self.position is None: + self.position = self.get_field('position') + return self.position + def num_data(self) -> int: """Get the number of rows in the Dataset. 
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 0906ba4b643..8182c9b52b9 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -304,6 +304,7 @@ const std::unordered_set& Config::parameter_set() { "lambdarank_truncation_level", "lambdarank_norm", "label_gain", + "lambdarank_position_bias_regularization", "metric", "metric_freq", "is_provide_training_metric", @@ -619,6 +620,9 @@ void Config::GetMembersFromString(const std::unordered_map(tmp_str, ','); } + GetDouble(params, "lambdarank_position_bias_regularization", &lambdarank_position_bias_regularization); + CHECK_GE(lambdarank_position_bias_regularization, 0.0); + GetInt(params, "metric_freq", &metric_freq); CHECK_GT(metric_freq, 0); @@ -754,6 +758,7 @@ std::string Config::SaveMembersToString() const { str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n"; str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n"; str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n"; + str_buf << "[lambdarank_position_bias_regularization: " << lambdarank_position_bias_regularization << "]\n"; str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n"; str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n"; str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n"; @@ -893,6 +898,7 @@ const std::unordered_map>& Config::paramet {"lambdarank_truncation_level", {}}, {"lambdarank_norm", {}}, {"label_gain", {}}, + {"lambdarank_position_bias_regularization", {}}, {"metric", {"metrics", "metric_types"}}, {"metric_freq", {"output_freq"}}, {"is_provide_training_metric", {"training_metric", "is_training_metric", "train_metric"}}, @@ -1035,6 +1041,7 @@ const std::unordered_map& Config::ParameterTypes() { {"lambdarank_truncation_level", "int"}, {"lambdarank_norm", "bool"}, {"label_gain", "vector"}, + {"lambdarank_position_bias_regularization", "double"}, {"metric", "vector"}, {"metric_freq", "int"}, {"is_provide_training_metric", 
"bool"}, diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 9e590b79821..d5aa707adcc 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -937,6 +937,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, name = Common::Trim(name); if (name == std::string("query") || name == std::string("group")) { metadata_.SetQuery(field_data, num_element); + } else if (name == std::string("position")) { + metadata_.SetPosition(field_data, num_element); } else { return false; } @@ -987,6 +989,9 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, if (name == std::string("query") || name == std::string("group")) { *out_ptr = metadata_.query_boundaries(); *out_len = metadata_.num_queries() + 1; + } else if (name == std::string("position")) { + *out_ptr = metadata_.positions(); + *out_len = num_data_; } else { return false; } diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 2a589fa24ef..1fc47c46787 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -15,7 +16,9 @@ Metadata::Metadata() { num_init_score_ = 0; num_data_ = 0; num_queries_ = 0; + num_positions_ = 0; weight_load_from_file_ = false; + position_load_from_file_ = false; query_load_from_file_ = false; init_score_load_from_file_ = false; #ifdef USE_CUDA @@ -28,6 +31,7 @@ void Metadata::Init(const char* data_filename) { // for lambdarank, it needs query data for partition data in distributed learning LoadQueryBoundaries(); LoadWeights(); + LoadPositions(); CalculateQueryWeights(); LoadInitialScore(data_filename_); } @@ -214,6 +218,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector 0 && num_positions_ != num_all_data) { + positions_.clear(); + num_positions_ = 0; + Log::Fatal("Positions size (%i) doesn't match data size (%i)", num_positions_, num_data_); + } + // get local positions + if (!positions_.empty()) { + auto old_positions = positions_; + 
num_positions_ = num_data_; + positions_ = std::vector(num_data_); + #pragma omp parallel for schedule(static, 512) + for (int i = 0; i < static_cast(used_data_indices.size()); ++i) { + positions_[i] = old_positions[used_data_indices[i]]; + } + old_positions.clear(); + } + } if (query_load_from_file_) { // check query boundries if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) { @@ -489,6 +519,47 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { #endif // USE_CUDA } +void Metadata::SetPosition(const data_size_t* positions, data_size_t len) { + std::lock_guard lock(mutex_); + // save to nullptr + if (positions == nullptr || len == 0) { + positions_.clear(); + num_positions_ = 0; + return; + } + #ifdef USE_CUDA + Log::Fatal("Positions in learning to rank is not supported in CUDA version yet."); + #endif // USE_CUDA + if (num_data_ != len) { + Log::Fatal("Positions size (%i) doesn't match data size (%i)", len, num_data_); + } + if (positions_.empty()) { + positions_.resize(num_data_); + } else { + Log::Warning("Overwritting positions in dataset."); + } + num_positions_ = num_data_; + + position_load_from_file_ = false; + + position_ids_.clear(); + std::unordered_map map_id2pos; + for (data_size_t i = 0; i < num_positions_; ++i) { + if (map_id2pos.count(positions[i]) == 0) { + int pos = static_cast(map_id2pos.size()); + map_id2pos[positions[i]] = pos; + position_ids_.push_back(std::to_string(positions[i])); + } + } + + Log::Debug("number of unique positions found = %ld", position_ids_.size()); + + #pragma omp parallel for schedule(static, 512) if (num_positions_ >= 1024) + for (data_size_t i = 0; i < num_positions_; ++i) { + positions_[i] = map_id2pos.at(positions[i]); + } +} + void Metadata::InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len) { if (!queries) { Log::Fatal("Passed null queries"); @@ -528,6 +599,32 @@ void Metadata::LoadWeights() { weight_load_from_file_ = true; } 
+void Metadata::LoadPositions() { + num_positions_ = 0; + std::string position_filename(data_filename_); + // default position file name + position_filename.append(".position"); + TextReader reader(position_filename.c_str(), false); + reader.ReadAllLines(); + if (reader.Lines().empty()) { + return; + } + Log::Info("Loading positions from %s ...", position_filename.c_str()); + num_positions_ = static_cast(reader.Lines().size()); + positions_ = std::vector(num_positions_); + position_ids_ = std::vector(); + std::unordered_map map_id2pos; + for (data_size_t i = 0; i < num_positions_; ++i) { + std::string& line = reader.Lines()[i]; + if (map_id2pos.count(line) == 0) { + map_id2pos[line] = static_cast(position_ids_.size()); + position_ids_.push_back(line); + } + positions_[i] = map_id2pos.at(line); + } + position_load_from_file_ = true; +} + void Metadata::LoadInitialScore(const std::string& data_filename) { num_init_score_ = 0; std::string init_score_filename(data_filename); diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp index 653fc6e8609..6bd5324812f 100644 --- a/src/objective/rank_objective.hpp +++ b/src/objective/rank_objective.hpp @@ -25,7 +25,10 @@ namespace LightGBM { class RankingObjective : public ObjectiveFunction { public: explicit RankingObjective(const Config& config) - : seed_(config.objective_seed) {} + : seed_(config.objective_seed) { + learning_rate_ = config.learning_rate; + position_bias_regularization_ = config.lambdarank_position_bias_regularization; + } explicit RankingObjective(const std::vector&) : seed_(0) {} @@ -37,12 +40,20 @@ class RankingObjective : public ObjectiveFunction { label_ = metadata.label(); // get weights weights_ = metadata.weights(); + // get positions + positions_ = metadata.positions(); + // get position ids + position_ids_ = metadata.position_ids(); + // get number of different position ids + num_position_ids_ = static_cast(metadata.num_position_ids()); // get boundries query_boundaries_ = 
metadata.query_boundaries(); if (query_boundaries_ == nullptr) { Log::Fatal("Ranking tasks require query information"); } num_queries_ = metadata.num_queries(); + // initialize position bias vectors + pos_biases_.resize(num_position_ids_, 0.0); } void GetGradients(const double* score, score_t* gradients, @@ -51,7 +62,13 @@ class RankingObjective : public ObjectiveFunction { for (data_size_t i = 0; i < num_queries_; ++i) { const data_size_t start = query_boundaries_[i]; const data_size_t cnt = query_boundaries_[i + 1] - query_boundaries_[i]; - GetGradientsForOneQuery(i, cnt, label_ + start, score + start, + std::vector score_adjusted; + if (num_position_ids_ > 0) { + for (data_size_t j = 0; j < cnt; ++j) { + score_adjusted.push_back(score[start + j] + pos_biases_[positions_[start + j]]); + } + } + GetGradientsForOneQuery(i, cnt, label_ + start, num_position_ids_ > 0 ? score_adjusted.data() : score + start, gradients + start, hessians + start); if (weights_ != nullptr) { for (data_size_t j = 0; j < cnt; ++j) { @@ -62,6 +79,9 @@ class RankingObjective : public ObjectiveFunction { } } } + if (num_position_ids_ > 0) { + UpdatePositionBiasFactors(gradients, hessians); + } } virtual void GetGradientsForOneQuery(data_size_t query_id, data_size_t cnt, @@ -69,6 +89,8 @@ class RankingObjective : public ObjectiveFunction { const double* score, score_t* lambdas, score_t* hessians) const = 0; + virtual void UpdatePositionBiasFactors(const score_t* /*lambdas*/, const score_t* /*hessians*/) const {} + const char* GetName() const override = 0; std::string ToString() const override { @@ -88,8 +110,20 @@ class RankingObjective : public ObjectiveFunction { const label_t* label_; /*! \brief Pointer of weights */ const label_t* weights_; + /*! \brief Pointer of positions */ + const data_size_t* positions_; + /*! \brief Pointer of position IDs */ + const std::string* position_ids_; + /*! \brief Pointer of label */ + data_size_t num_position_ids_; /*! 
\brief Query boundaries */ const data_size_t* query_boundaries_; + /*! \brief Position bias factors */ + mutable std::vector pos_biases_; + /*! \brief Learning rate to update position bias factors */ + double learning_rate_; + /*! \brief Position bias regularization */ + double position_bias_regularization_; }; /*! @@ -253,9 +287,67 @@ class LambdarankNDCG : public RankingObjective { } } + void UpdatePositionBiasFactors(const score_t* lambdas, const score_t* hessians) const override { + /// get number of threads + int num_threads = 1; + #pragma omp parallel + #pragma omp master + { + num_threads = omp_get_num_threads(); + } + // create per-thread buffers for first and second derivatives of utility w.r.t. position bias factors + std::vector bias_first_derivatives(num_position_ids_ * num_threads, 0.0); + std::vector bias_second_derivatives(num_position_ids_ * num_threads, 0.0); + std::vector instance_counts(num_position_ids_ * num_threads, 0); + #pragma omp parallel for schedule(guided) + for (data_size_t i = 0; i < num_data_; i++) { + // get thread ID + const int tid = omp_get_thread_num(); + size_t offset = static_cast(positions_[i] + tid * num_position_ids_); + // accumulate first derivatives of utility w.r.t. position bias factors, for each position + bias_first_derivatives[offset] -= lambdas[i]; + // accumulate second derivatives of utility w.r.t. 
position bias factors, for each position + bias_second_derivatives[offset] -= hessians[i]; + instance_counts[offset]++; + } + #pragma omp parallel for schedule(guided) + for (data_size_t i = 0; i < num_position_ids_; i++) { + double bias_first_derivative = 0.0; + double bias_second_derivative = 0.0; + int instance_count = 0; + // aggregate derivatives from per-thread buffers + for (int tid = 0; tid < num_threads; tid++) { + size_t offset = static_cast(i + tid * num_position_ids_); + bias_first_derivative += bias_first_derivatives[offset]; + bias_second_derivative += bias_second_derivatives[offset]; + instance_count += instance_counts[offset]; + } + // L2 regularization on position bias factors + bias_first_derivative -= pos_biases_[i] * position_bias_regularization_ * instance_count; + bias_second_derivative -= position_bias_regularization_ * instance_count; + // do Newton-Raphson step to update position bias factors + pos_biases_[i] += learning_rate_ * bias_first_derivative / (std::abs(bias_second_derivative) + 0.001); + } + LogDebugPositionBiasFactors(); + } + const char* GetName() const override { return "lambdarank"; } protected: + void LogDebugPositionBiasFactors() const { + std::stringstream message_stream; + message_stream << std::setw(15) << "position" + << std::setw(15) << "bias_factor" + << std::endl; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + for (int i = 0; i < num_position_ids_; ++i) { + message_stream << std::setw(15) << position_ids_[i] + << std::setw(15) << pos_biases_[i]; + Log::Debug(message_stream.str().c_str()); + message_stream.str(""); + } + } /*! \brief Sigmoid param */ double sigmoid_; /*! 
\brief Normalize the lambdas or not */ diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e9e7179a9b6..25413d7ea07 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -9,6 +9,7 @@ import re from os import getenv from pathlib import Path +from shutil import copyfile import numpy as np import psutil @@ -19,7 +20,7 @@ from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split import lightgbm as lgb -from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame +from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series from .utils import (SERIALIZERS, dummy_obj, load_breast_cancer, load_digits, load_iris, logistic_sigmoid, make_synthetic_regression, mse_obj, pickle_and_unpickle_object, sklearn_multiclass_custom_objective, @@ -747,6 +748,171 @@ def test_ranking_prediction_early_stopping(): np.testing.assert_allclose(ret_early, ret_early_more_strict) +# Simulates position bias for a given ranking dataset. +# The ouput dataset is identical to the input one with the exception for the relevance labels. +# The new labels are generated according to an instance of a cascade user model: +# for each query, the user is simulated to be traversing the list of documents ranked by a baseline ranker +# (in our example it is simply the ordering by some feature correlated with relevance, e.g., 34) +# and clicks on that document (new_label=1) with some probability 'pclick' depending on its true relevance; +# at each position the user may stop the traversal with some probability pstop. For the non-clicked documents, +# new_label=0. Thus the generated new labels are biased towards the baseline ranker. +# The positions of the documents in the ranked lists produced by the baseline, are returned. 
+def simulate_position_bias(file_dataset_in, file_query_in, file_dataset_out, baseline_feature):
+    """Write a copy of a ranking dataset whose labels are clicks simulated by a cascade user model.
+
+    For each query, documents are ordered by descending value of ``baseline_feature`` and visited from
+    the top. Until the simulated user stops, a visited document is clicked (new label 1) with a
+    probability chosen by its original label, and after each visited document the user stops with
+    probability 0.2. Every other document gets label 0. The random generator is seeded with 10.
+
+    :param file_dataset_in: input dataset; each row is a label followed by 'index:value' feature tokens
+    :param file_query_in: query file; each row is the number of consecutive dataset rows of one query
+    :param file_dataset_out: output dataset; same feature tokens, label replaced by the simulated click
+    :param baseline_feature: index of the feature whose value defines the baseline ranking (0.0 if absent)
+    :return: list with the 0-based position of every dataset row in the baseline ranking of its query
+    """
+    # click probability for each true relevance grade (5-grade scale); any other grade maps to 0.9
+    pclick_by_label = {0: 0.4, 1: 0.6, 2: 0.7, 3: 0.8}
+    # an instantiation of a cascade model where the user stops with probability 0.2 after observing each document
+    pstop = 0.2
+
+    random.seed(10)
+    positions_all = []
+    # context managers close all three files, including the two inputs, even if parsing fails
+    with open(file_dataset_in, 'r') as f_dataset_in, \
+            open(file_query_in, 'r') as f_query_in, \
+            open(file_dataset_out, 'w') as f_dataset_out:
+        for line in f_query_in:
+            docs_num = int(line)
+            lines = [f_dataset_in.readline().split() for _ in range(docs_num)]
+            index_values = []
+            for index, features in enumerate(lines):
+                val = 0.0
+                # features[0] is the label, so only the 'index:value' tokens after it are scanned
+                for feature_val in features[1:]:
+                    feature_val_split = feature_val.split(":")
+                    if int(feature_val_split[0]) == baseline_feature:
+                        val = float(feature_val_split[1])
+                index_values.append([index, val])
+            # list.sort is stable, so documents with equal values keep their original order
+            index_values.sort(key=lambda x: -x[1])
+            positions = [0] * docs_num
+            stop = False
+            for pos, (index, _) in enumerate(index_values):
+                new_label = 0
+                if not stop:
+                    label = int(lines[index][0])
+                    if random.random() < pclick_by_label.get(label, 0.9):
+                        new_label = 1
+                    stop = random.random() < pstop
+                lines[index][0] = str(new_label)
+                positions[index] = pos
+            for features in lines:
+                f_dataset_out.write(' '.join(features) + '\n')
+            positions_all.extend(positions)
+    return positions_all
+
+
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Positions in learning to rank is not supported in CUDA version yet')
+def test_ranking_with_position_information_with_file(tmp_path):
+    rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
+    params = {
+        'objective': 'lambdarank',
+        'verbose': -1,
+        'eval_at': [3],
+        'metric': 'ndcg',
+        'bagging_freq': 1,
+        'bagging_fraction': 0.9,
+        'min_data_in_leaf': 50,
+        'min_sum_hessian_in_leaf': 5.0
+    }
+
+    # simulate position bias for the train dataset and put the train dataset with biased labels to temp directory
+    positions = simulate_position_bias(str(rank_example_dir / 'rank.train'), str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train'), baseline_feature=34)
+    copyfile(str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train.query'))
+    copyfile(str(rank_example_dir / 'rank.test'), str(tmp_path / 'rank.test'))
+    copyfile(str(rank_example_dir / 'rank.test.query'), str(tmp_path / 'rank.test.query'))
+
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    gbm_baseline = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+
+    # only now write one position per training row to 'rank.train.position', next to the training data
+    with open(str(tmp_path / 'rank.train.position'), 'w') as f_positions_out:
+        f_positions_out.writelines(str(pos) + '\n' for pos in positions)
+
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    gbm_unbiased_with_file = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+
+    # the performance of the unbiased LambdaMART should outperform the plain LambdaMART on the dataset with position bias
+    assert gbm_baseline.best_score['valid_0']['ndcg@3'] + 0.03 <= gbm_unbiased_with_file.best_score['valid_0']['ndcg@3']
+
+    # add extra row to position file
+    with open(str(tmp_path / 'rank.train.position'), 'a') as f_positions_out:
+        f_positions_out.write('pos_1000\n')
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    with pytest.raises(lgb.basic.LightGBMError, match=r"Positions size \(3006\) doesn't match data size"):
+        lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+
+
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Positions in learning to rank is not supported in CUDA version yet')
+def test_ranking_with_position_information_with_dataset_constructor(tmp_path):
+    rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
+    params = {
+        'objective': 'lambdarank',
+        'verbose': -1,
+        'eval_at': [3],
+        'metric': 'ndcg',
+        'bagging_freq': 1,
+        'bagging_fraction': 0.9,
+        'min_data_in_leaf': 50,
+        'min_sum_hessian_in_leaf': 5.0,
+        'num_threads': 1,
+        'deterministic': True,
+        'seed': 0
+    }
+
+    # simulate position bias for the train dataset and put the train dataset with biased labels to temp directory
+    positions = simulate_position_bias(str(rank_example_dir / 'rank.train'), str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train'), baseline_feature=34)
+    copyfile(str(rank_example_dir / 'rank.train.query'), str(tmp_path / 'rank.train.query'))
+    copyfile(str(rank_example_dir / 'rank.test'), str(tmp_path / 'rank.test'))
+    copyfile(str(rank_example_dir / 'rank.test.query'), str(tmp_path / 'rank.test.query'))
+
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    gbm_baseline = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+
+    positions = np.array(positions)
+
+    # test setting positions through Dataset constructor with numpy array
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params, position=positions)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    gbm_unbiased = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+
+    # the performance of the unbiased LambdaMART should outperform the plain LambdaMART on the dataset with position bias
+    assert gbm_baseline.best_score['valid_0']['ndcg@3'] + 0.03 <= gbm_unbiased.best_score['valid_0']['ndcg@3']
+
+    if PANDAS_INSTALLED:
+        # test setting positions through Dataset constructor with pandas Series
+        lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params, position=pd_Series(positions))
+        lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+        gbm_unbiased_pandas_series = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+        assert gbm_unbiased.best_score['valid_0']['ndcg@3'] == gbm_unbiased_pandas_series.best_score['valid_0']['ndcg@3']
+
+    # test setting positions through set_position
+    lgb_train = lgb.Dataset(str(tmp_path / 'rank.train'), params=params)
+    lgb_valid = [lgb_train.create_valid(str(tmp_path / 'rank.test'))]
+    lgb_train.set_position(positions)
+    gbm_unbiased_set_position = lgb.train(params, lgb_train, valid_sets=lgb_valid, num_boost_round=50)
+    assert gbm_unbiased.best_score['valid_0']['ndcg@3'] == gbm_unbiased_set_position.best_score['valid_0']['ndcg@3']
+
+    # test get_position works
+    positions_from_get = lgb_train.get_position()
+    np.testing.assert_array_equal(positions_from_get, positions)
+
+
 def test_early_stopping():
     X, y = load_breast_cancer(return_X_y=True)
     params = {