Skip to content

Commit

Permalink
Treat position bias via GAM in LambdaMART (#5929)
Browse files Browse the repository at this point in the history
* Update dataset.h

* Update metadata.cpp

* Update rank_objective.hpp

* Update metadata.cpp

* Update rank_objective.hpp

* Update metadata.cpp

* Update dataset.h

* Update rank_objective.hpp

* Update metadata.cpp

* Update test_engine.py

* Update test_engine.py

* Add files via upload

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update _rank.train.position

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update _rank.train.position

* Update _rank.train.position

* Update test_engine.py

* Update _rank.train.position

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update the position of import statement

* Update rank_objective.hpp

* Update config.h

* Update config_auto.cpp

* Update rank_objective.hpp

* Update rank_objective.hpp

* update documentation

* remove extra blank line

* Update src/io/metadata.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update src/io/metadata.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* remove _rank.train.position

* add position in python API

* fix set_positions in basic.py

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* Update Advanced-Topics.rst

* remove List from _LGBM_PositionType

* move new position parameter to the last in Dataset constructor

* add position_filename as a parameter

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update Advanced-Topics.rst

* Update src/objective/rank_objective.hpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update src/io/metadata.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update metadata.cpp

* Update python-package/lightgbm/basic.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update python-package/lightgbm/basic.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update python-package/lightgbm/basic.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update python-package/lightgbm/basic.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update src/io/metadata.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* more infomrative fatal message

address more comments

* update documentation for more flexible position specification

* fix SetPosition

add tests for get_position and set_position

* remove position_filename

* remove useless changes

* Update python-package/lightgbm/basic.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* remove useless files

* move position file when position set in Dataset

* warn when positions are overwritten

* skip ranking with position test in cuda

* split test case

* remove useless import

* Update test_engine.py

* Update test_engine.py

* Update test_engine.py

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update Parameters.rst

* Update rank_objective.hpp

* Update config.h

* update config_auto.cppp

* Update docs/Advanced-Topics.rst

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* fix randomness in test case for gpu

---------

Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
3 people committed Sep 4, 2023
1 parent 1881a50 commit 7e34d23
Show file tree
Hide file tree
Showing 10 changed files with 522 additions and 11 deletions.
41 changes: 41 additions & 0 deletions docs/Advanced-Topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,44 @@ Recommendations for gcc Users (MinGW, \*nix)
--------------------------------------------

- Refer to `gcc Tips <./gcc-Tips.rst>`__.

Support for Position Bias Treatment
------------------------------------

Often the relevance labels provided in Learning-to-Rank tasks might be derived from implicit user feedback (e.g., clicks) and therefore might be biased due to their position/location on the screen when having been presented to a user.
LightGBM can make use of positional data.

For example, consider the case where you expect that the first 3 results from a search engine will be visible in users' browsers without scrolling, and all other results for a query would require scrolling.

LightGBM could be told to account for the position bias from results being "above the fold" by providing a ``positions`` array encoded as follows:

::

0
0
0
1
1
0
0
0
1
...

Where ``0 = "above the fold"`` and ``1 = "requires scrolling"``.
The specific values are not important, as long as they are consistent across all observations in the training data.
An encoding like ``100 = "above the fold"`` and ``17 = "requires scrolling"`` would result in exactly the same trained model.

In that way, ``positions`` in LightGBM's API are similar to a categorical feature.
Just as with non-ordinal categorical features, an integer representation is just used for memory and computational efficiency... LightGBM does not care about the absolute or relative magnitude of the values.

Unlike a categorical feature, however, ``positions`` are used to adjust the target to reduce the bias in predictions made by the trained model.

The position file corresponds with training data file line by line, and has one position per line. And if the name of training data file is ``train.txt``, the position file should be named as ``train.txt.position`` and placed in the same folder as the data file.
In this case, LightGBM will load the position file automatically if it exists. The positions can also be specified through the ``Dataset`` constructor when using Python API. If the positions are specified in both approaches, the ``.position`` file will be ignored.

Currently, implemented is an approach to model position bias by using an idea of Generalized Additive Models (`GAM <https://en.wikipedia.org/wiki/Generalized_additive_model>`_) to linearly decompose the document score ``s`` into the sum of a relevance component ``f`` and a positional component ``g``: ``s(x, pos) = f(x) + g(pos)`` where the former component depends on the original query-document features and the latter depends on the position of an item.
During the training, the compound scoring function ``s(x, pos)`` is fit with a standard ranking algorithm (e.g., LambdaMART) which boils down to jointly learning the relevance component ``f(x)`` (it is later returned as an unbiased model) and the position factors ``g(pos)`` that help better explain the observed (biased) labels.
Similar score decomposition ideas have previously been applied for classification & pointwise ranking tasks with assumptions of binary labels and binary relevance (a.k.a. "two-tower" models, refer to the papers: `Towards Disentangling Relevance and Bias in Unbiased Learning to Rank <https://arxiv.org/abs/2212.13937>`_, `PAL: a position-bias aware learning framework for CTR prediction in live recommender systems <https://dl.acm.org/doi/10.1145/3298689.3347033>`_, `A General Framework for Debiasing in CTR Prediction <https://arxiv.org/abs/2112.02767>`_).
In LightGBM, we adapt this idea to general pairwise Lerarning-to-Rank with arbitrary ordinal relevance labels.
Besides, GAMs have been used in the context of explainable ML (`Accurate Intelligible Models with Pairwise Interactions <https://www.cs.cornell.edu/~yinlou/papers/lou-kdd13.pdf>`_) to linearly decompose the contribution of each feature (and possibly their pairwise interactions) to the overall score, for subsequent analysis and interpretation of their effects in the trained models.
4 changes: 4 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,10 @@ Objective Parameters

- separate by ``,``

- ``lambdarank_position_bias_regularization`` :raw-html:`<a id="lambdarank_position_bias_regularization" title="Permalink to this parameter" href="#lambdarank_position_bias_regularization">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, constraints: ``lambdarank_position_bias_regularization >= 0.0``

- used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors.

Metric Parameters
-----------------

Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,10 @@ struct Config {
// desc = separate by ``,``
std::vector<double> label_gain;

// check = >=0.0
// desc = used only in ``lambdarank`` application when positional information is provided and position bias is modeled. Larger values reduce the inferred position bias factors.
double lambdarank_position_bias_regularization = 0.0;

#ifndef __NVCC__
#pragma endregion

Expand Down
43 changes: 43 additions & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class Metadata {

void SetQuery(const data_size_t* query, data_size_t len);

void SetPosition(const data_size_t* position, data_size_t len);

/*!
* \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score.
Expand Down Expand Up @@ -213,6 +215,38 @@ class Metadata {
}
}

/*!
* \brief Get positions, if does not exist then return nullptr
* \return Pointer of positions
*/
inline const data_size_t* positions() const {
if (!positions_.empty()) {
return positions_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get position IDs, if does not exist then return nullptr
* \return Pointer of position IDs
*/
inline const std::string* position_ids() const {
if (!position_ids_.empty()) {
return position_ids_.data();
} else {
return nullptr;
}
}

/*!
* \brief Get Number of different position IDs
* \return number of different position IDs
*/
inline size_t num_position_ids() const {
return position_ids_.size();
}

/*!
* \brief Get data boundaries on queries, if not exists, will return nullptr
* we assume data will order by query,
Expand Down Expand Up @@ -289,6 +323,8 @@ class Metadata {
private:
/*! \brief Load wights from file */
void LoadWeights();
/*! \brief Load positions from file */
void LoadPositions();
/*! \brief Load query boundaries from file */
void LoadQueryBoundaries();
/*! \brief Calculate query weights from queries */
Expand All @@ -309,10 +345,16 @@ class Metadata {
data_size_t num_data_;
/*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_;
/*! \brief Number of positions, used to check correct position file */
data_size_t num_positions_;
/*! \brief Label data */
std::vector<label_t> label_;
/*! \brief Weights data */
std::vector<label_t> weights_;
/*! \brief Positions data */
std::vector<data_size_t> positions_;
/*! \brief Position identifiers */
std::vector<std::string> position_ids_;
/*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */
Expand All @@ -328,6 +370,7 @@ class Metadata {
/*! \brief mutex for threading safe call */
std::mutex mutex_;
bool weight_load_from_file_;
bool position_load_from_file_;
bool query_load_from_file_;
bool init_score_load_from_file_;
#ifdef USE_CUDA
Expand Down
68 changes: 60 additions & 8 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
np.ndarray,
pd_Series
]
_LGBM_PositionType = Union[
np.ndarray,
pd_Series
]
_LGBM_InitScoreType = Union[
List[float],
List[List[float]],
Expand Down Expand Up @@ -577,7 +581,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
"label": _C_API_DTYPE_FLOAT32,
"weight": _C_API_DTYPE_FLOAT32,
"init_score": _C_API_DTYPE_FLOAT64,
"group": _C_API_DTYPE_INT32
"group": _C_API_DTYPE_INT32,
"position": _C_API_DTYPE_INT32
}

"""String name to int feature importance type mapper"""
Expand Down Expand Up @@ -1525,7 +1530,8 @@ def __init__(
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True
free_raw_data: bool = True,
position: Optional[_LGBM_PositionType] = None,
):
"""Initialize Dataset.
Expand Down Expand Up @@ -1565,13 +1571,16 @@ def __init__(
Other parameters for Dataset.
free_raw_data : bool, optional (default=True)
If True, raw data is freed after constructing inner Dataset.
position : numpy 1-D array, pandas Series or None, optional (default=None)
Position of items used in unbiased learning-to-rank task.
"""
self._handle: Optional[_DatasetHandle] = None
self.data = data
self.label = label
self.reference = reference
self.weight = weight
self.group = group
self.position = position
self.init_score = init_score
self.feature_name: _LGBM_FeatureNameConfiguration = feature_name
self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature
Expand Down Expand Up @@ -1836,7 +1845,8 @@ def _lazy_init(
predictor: Optional[_InnerPredictor],
feature_name: _LGBM_FeatureNameConfiguration,
categorical_feature: _LGBM_CategoricalFeatureConfiguration,
params: Optional[Dict[str, Any]]
params: Optional[Dict[str, Any]],
position: Optional[_LGBM_PositionType]
) -> "Dataset":
if data is None:
self._handle = None
Expand Down Expand Up @@ -1925,6 +1935,8 @@ def _lazy_init(
self.set_weight(weight)
if group is not None:
self.set_group(group)
if position is not None:
self.set_position(position)
if isinstance(predictor, _InnerPredictor):
if self._predictor is None and init_score is not None:
_log_warning("The init_score will be overridden by the prediction of init_model.")
Expand Down Expand Up @@ -2219,7 +2231,7 @@ def construct(self) -> "Dataset":
if self.used_indices is None:
# create valid
self._lazy_init(data=self.data, label=self.label, reference=self.reference,
weight=self.weight, group=self.group,
weight=self.weight, group=self.group, position=self.position,
init_score=self.init_score, predictor=self._predictor,
feature_name=self.feature_name, categorical_feature='auto', params=self.params)
else:
Expand All @@ -2242,6 +2254,8 @@ def construct(self) -> "Dataset":
self.get_data()
if self.group is not None:
self.set_group(self.group)
if self.position is not None:
self.set_position(self.position)
if self.get_label() is None:
raise ValueError("Label should not be None.")
if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor:
Expand All @@ -2256,7 +2270,8 @@ def construct(self) -> "Dataset":
self._lazy_init(data=self.data, label=self.label, reference=None,
weight=self.weight, group=self.group,
init_score=self.init_score, predictor=self._predictor,
feature_name=self.feature_name, categorical_feature=self.categorical_feature, params=self.params)
feature_name=self.feature_name, categorical_feature=self.categorical_feature,
params=self.params, position=self.position)
if self.free_raw_data:
self.data = None
self.feature_name = self.get_feature_name()
Expand All @@ -2269,7 +2284,8 @@ def create_valid(
weight: Optional[_LGBM_WeightType] = None,
group: Optional[_LGBM_GroupType] = None,
init_score: Optional[_LGBM_InitScoreType] = None,
params: Optional[Dict[str, Any]] = None
params: Optional[Dict[str, Any]] = None,
position: Optional[_LGBM_PositionType] = None
) -> "Dataset":
"""Create validation data align with current Dataset.
Expand All @@ -2292,14 +2308,16 @@ def create_valid(
Init score for Dataset.
params : dict or None, optional (default=None)
Other parameters for validation Dataset.
position : numpy 1-D array, pandas Series or None, optional (default=None)
Position of items used in unbiased learning-to-rank task.
Returns
-------
valid : Dataset
Validation Dataset with reference to self.
"""
ret = Dataset(data, label=label, reference=self,
weight=weight, group=group, init_score=init_score,
weight=weight, group=group, position=position, init_score=init_score,
params=params, free_raw_data=self.free_raw_data)
ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical
Expand Down Expand Up @@ -2434,7 +2452,7 @@ def set_field(
'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.'
)
else:
dtype = np.int32 if field_name == 'group' else np.float32
dtype = np.int32 if (field_name == 'group' or field_name == 'position') else np.float32
data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)

ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr]
Expand Down Expand Up @@ -2727,6 +2745,28 @@ def set_group(
self.set_field('group', group)
return self

def set_position(
self,
position: Optional[_LGBM_PositionType]
) -> "Dataset":
"""Set position of Dataset (used for ranking).
Parameters
----------
position : numpy 1-D array, pandas Series or None, optional (default=None)
Position of items used in unbiased learning-to-rank task.
Returns
-------
self : Dataset
Dataset with set position.
"""
self.position = position
if self._handle is not None and position is not None:
position = _list_to_1d_numpy(position, dtype=np.int32, name='position')
self.set_field('position', position)
return self

def get_feature_name(self) -> List[str]:
"""Get the names of columns (features) in the Dataset.
Expand Down Expand Up @@ -2853,6 +2893,18 @@ def get_group(self) -> Optional[np.ndarray]:
self.group = np.diff(self.group)
return self.group

def get_position(self) -> Optional[np.ndarray]:
"""Get the position of the Dataset.
Returns
-------
position : numpy 1-D array or None
Position of items used in unbiased learning-to-rank task.
"""
if self.position is None:
self.position = self.get_field('position')
return self.position

def num_data(self) -> int:
"""Get the number of rows in the Dataset.
Expand Down
7 changes: 7 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"lambdarank_truncation_level",
"lambdarank_norm",
"label_gain",
"lambdarank_position_bias_regularization",
"metric",
"metric_freq",
"is_provide_training_metric",
Expand Down Expand Up @@ -619,6 +620,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
label_gain = Common::StringToArray<double>(tmp_str, ',');
}

GetDouble(params, "lambdarank_position_bias_regularization", &lambdarank_position_bias_regularization);
CHECK_GE(lambdarank_position_bias_regularization, 0.0);

GetInt(params, "metric_freq", &metric_freq);
CHECK_GT(metric_freq, 0);

Expand Down Expand Up @@ -754,6 +758,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n";
str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[lambdarank_position_bias_regularization: " << lambdarank_position_bias_regularization << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
Expand Down Expand Up @@ -893,6 +898,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
{"lambdarank_truncation_level", {}},
{"lambdarank_norm", {}},
{"label_gain", {}},
{"lambdarank_position_bias_regularization", {}},
{"metric", {"metrics", "metric_types"}},
{"metric_freq", {"output_freq"}},
{"is_provide_training_metric", {"training_metric", "is_training_metric", "train_metric"}},
Expand Down Expand Up @@ -1035,6 +1041,7 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
{"lambdarank_truncation_level", "int"},
{"lambdarank_norm", "bool"},
{"label_gain", "vector<double>"},
{"lambdarank_position_bias_regularization", "double"},
{"metric", "vector<string>"},
{"metric_freq", "int"},
{"is_provide_training_metric", "bool"},
Expand Down
5 changes: 5 additions & 0 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data,
name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(field_data, num_element);
} else if (name == std::string("position")) {
metadata_.SetPosition(field_data, num_element);
} else {
return false;
}
Expand Down Expand Up @@ -987,6 +989,9 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len,
if (name == std::string("query") || name == std::string("group")) {
*out_ptr = metadata_.query_boundaries();
*out_len = metadata_.num_queries() + 1;
} else if (name == std::string("position")) {
*out_ptr = metadata_.positions();
*out_len = num_data_;
} else {
return false;
}
Expand Down

0 comments on commit 7e34d23

Please sign in to comment.