Decouple Boosting Types (fixes #3128) (#4827)
* add parameter data_sample_strategy

* abstract GOSS as a sample strategy (GOSS1), together with the original GOSS (normal Bagging has not been abstracted yet, so do NOT use it for now)

* abstract Bagging as a subclass (BAGGING), but original Bagging members in GBDT are still kept

* fix some variables

* remove GOSS(as boost) and Bagging logic in GBDT

* rename GOSS1 to GOSS(as sample strategy)

* add warning about using GOSS as boosting_type

* fix a stray `;` bug

* remove CHECK when "gradients != nullptr"

* rename DataSampleStrategy to avoid confusion

* remove and add some comments, following convention

* fix bug in GBDT::ResetConfig (ObjectiveFunction inconsistency bet…

* add std::ignore to avoid compiler warnings (and potential failures)

* update Makevars and vcxproj

* handle constant hessian

move resize of gradient vectors out of sample strategy

* mark override for IsHessianChange

* fix lint errors

* rerun parameter_generator.py

* update config_auto.cpp

* delete redundant blank line

* update num_data_ when train_data_ is updated

set gradients and hessians when GOSS

* check bagging_freq is not zero

* reset config_ value

merge ResetBaggingConfig and ResetGOSS

* remove useless check

* add tests in test_engine.py

* remove whitespace in blank line

* remove arguments verbose_eval and evals_result

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

reduce num_boost_round

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update src/boosting/sample_strategy.cpp

modify warning about setting goss as `boosting_type`

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update tests/python_package_test/test_engine.py

replace load_boston() with make_regression()

remove value checks of mean_squared_error in test_sample_strategy_with_boosting()

* Update tests/python_package_test/test_engine.py

add value checks of mean_squared_error in test_sample_strategy_with_boosting()

* modify warning about using goss as boosting type

* Update tests/python_package_test/test_engine.py

add random_state=42 for make_regression()

reduce the threshold of mean_squared_error

* Update src/boosting/sample_strategy.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* remove goss from boosting types in documentation

* Update src/boosting/bagging.hpp

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update src/boosting/bagging.hpp

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update src/boosting/goss.hpp

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update src/boosting/goss.hpp

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* rename GOSS to GOSSStrategy

* update doc

* address comments

* fix table in doc

* Update include/LightGBM/config.h

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update documentation

* update test case

* revert useless change in test_engine.py

* add tests for evaluation results in test_sample_strategy_with_boosting

* include <string>

* change to assert_allclose in test_goss_boosting_and_strategy_equivalent

* more tolerance in result checking, due to minor differences in results of GPU versions

* change == to np.testing.assert_allclose

* fix test case

* set gpu_use_dp to true

* change --report to --report-level for rstcheck

* use gpu_use_dp=true in test_goss_boosting_and_strategy_equivalent (see the sketch after this commit list)

* revert unexpected changes of non-ascii characters

* revert unexpected changes of non-ascii characters

* remove useless changes

* allocate gradients_pointer_ and hessians_pointer_ when necessary

* add spaces

* remove redundant virtual

* include <LightGBM/utils/log.h> for USE_CUDA

* check for  in test_goss_boosting_and_strategy_equivalent

* check for identity in test_sample_strategy_with_boosting

* remove cuda option in test_sample_strategy_with_boosting

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update tests/python_package_test/test_engine.py

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* ResetGradientBuffers after ResetSampleConfig

* ResetGradientBuffers after ResetSampleConfig

* ResetGradientBuffers after bagging

* remove useless code

* check objective_function_ instead of gradients

* enable rf with goss

simplify params in test cases

* remove useless changes

* allow rf with feature subsampling alone

* change position of ResetGradientBuffers

* check for dask

* add parameter types for data_sample_strategy

Co-authored-by: Guangda Liu <v-guangdaliu@microsoft.com>
Co-authored-by: Yu Shi <shiyu_k1994@qq.com>
Co-authored-by: GuangdaLiu <90019144+GuangdaLiu@users.noreply.github.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
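
For context, the equivalence test referenced in the commit list checks that the legacy spelling `boosting=goss` (which now emits a warning and is handled as `gbdt` plus the `goss` sample strategy) trains identically to the new decoupled spelling. A minimal sketch of that check, not the actual test code (the helper name `training_curve` is made up for illustration):

```python
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=500, n_features=10, random_state=42)


def training_curve(params):
    # Record the training-set l2 metric at every iteration.
    train_set = lgb.Dataset(X, y)
    evals_result = {}
    lgb.train(params, train_set, num_boost_round=10, valid_sets=[train_set],
              callbacks=[lgb.record_evaluation(evals_result)])
    return np.array(evals_result["training"]["l2"])


base = {"objective": "regression", "verbose": -1}
legacy = training_curve({**base, "boosting": "goss"})  # old spelling, now warns
decoupled = training_curve({**base, "boosting": "gbdt",
                            "data_sample_strategy": "goss"})
np.testing.assert_allclose(legacy, decoupled)
```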
6 people committed Dec 28, 2022
1 parent a2ae6b9 commit fffd066
Showing 23 changed files with 694 additions and 433 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
@@ -26,6 +26,7 @@ OBJECTS = \
boosting/gbdt_model_text.o \
boosting/gbdt_prediction.o \
boosting/prediction_early_stop.o \
boosting/sample_strategy.o \
io/bin.o \
io/config.o \
io/config_auto.o \
1 change: 1 addition & 0 deletions R-package/src/Makevars.win.in
@@ -27,6 +27,7 @@ OBJECTS = \
boosting/gbdt_model_text.o \
boosting/gbdt_prediction.o \
boosting/prediction_early_stop.o \
boosting/sample_strategy.o \
io/bin.o \
io/config.o \
io/config_auto.o \
2 changes: 1 addition & 1 deletion docs/Development-Guide.rst
@@ -19,7 +19,7 @@ Important Classes
+-------------------------+----------------------------------------------------------------------------------------+
| ``Bin`` | Data structure used for storing feature discrete values (converted from float values) |
+-------------------------+----------------------------------------------------------------------------------------+
| ``Boosting`` | Boosting interface (GBDT, DART, GOSS, etc.) |
| ``Boosting`` | Boosting interface (GBDT, DART, etc.) |
+-------------------------+----------------------------------------------------------------------------------------+
| ``Config`` | Stores parameters and configurations |
+-------------------------+----------------------------------------------------------------------------------------+
14 changes: 10 additions & 4 deletions docs/Parameters.rst
@@ -127,18 +127,24 @@ Core Parameters

- label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)

- ``boosting`` :raw-html:`<a id="boosting" title="Permalink to this parameter" href="#boosting">&#x1F517;&#xFE0E;</a>`, default = ``gbdt``, type = enum, options: ``gbdt``, ``rf``, ``dart``, ``goss``, aliases: ``boosting_type``, ``boost``
- ``boosting`` :raw-html:`<a id="boosting" title="Permalink to this parameter" href="#boosting">&#x1F517;&#xFE0E;</a>`, default = ``gbdt``, type = enum, options: ``gbdt``, ``rf``, ``dart``, aliases: ``boosting_type``, ``boost``

- ``gbdt``, traditional Gradient Boosting Decision Tree, aliases: ``gbrt``

- ``rf``, Random Forest, aliases: ``random_forest``

- ``dart``, `Dropouts meet Multiple Additive Regression Trees <https://arxiv.org/abs/1505.01866>`__

- ``goss``, Gradient-based One-Side Sampling

- **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations

- ``data_sample_strategy`` :raw-html:`<a id="data_sample_strategy" title="Permalink to this parameter" href="#data_sample_strategy">&#x1F517;&#xFE0E;</a>`, default = ``bagging``, type = enum, options: ``bagging``, ``goss``

- ``bagging``, Randomly Bagging Sampling

- **Note**: ``bagging`` is only effective when ``bagging_freq > 0`` and ``bagging_fraction < 1.0``

- ``goss``, Gradient-based One-Side Sampling

- ``data`` :raw-html:`<a id="data" title="Permalink to this parameter" href="#data">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string, aliases: ``train``, ``train_data``, ``train_data_file``, ``data_filename``

- path of training data, LightGBM will train from this data
@@ -268,7 +274,7 @@ Learning Control Parameters

- ``num_threads`` is relatively small, e.g. ``<= 16``

- you want to use small ``bagging_fraction`` or ``goss`` boosting to speed up
- you want to use small ``bagging_fraction`` or ``goss`` sample strategy to speed up

- **Note**: setting this to ``true`` will double the memory cost for Dataset object. If you have not enough memory, you can try setting ``force_col_wise=true``

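
The practical effect of the decoupling documented above (the motivation in #3128) is that GOSS can now be combined with any boosting type rather than replacing it. A minimal usage sketch against the Python API, with hypothetical data:

```python
import lightgbm as lgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1_000, n_features=10, random_state=42)

# The sampling strategy is now orthogonal to the boosting type, so
# combinations like DART + GOSS (previously impossible) are expressible.
params = {
    "objective": "regression",
    "boosting": "dart",              # gbdt, dart, or rf
    "data_sample_strategy": "goss",  # bagging (default) or goss
    "verbose": -1,
}
booster = lgb.train(params, lgb.Dataset(X, y), num_boost_round=50)
```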
13 changes: 10 additions & 3 deletions include/LightGBM/config.h
@@ -153,14 +153,21 @@ struct Config {
// [doc-only]
// type = enum
// alias = boosting_type, boost
// options = gbdt, rf, dart, goss
// options = gbdt, rf, dart
// desc = ``gbdt``, traditional Gradient Boosting Decision Tree, aliases: ``gbrt``
// desc = ``rf``, Random Forest, aliases: ``random_forest``
// desc = ``dart``, `Dropouts meet Multiple Additive Regression Trees <https://arxiv.org/abs/1505.01866>`__
// desc = ``goss``, Gradient-based One-Side Sampling
// descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
std::string boosting = "gbdt";

// [doc-only]
// type = enum
// options = bagging, goss
// desc = ``bagging``, Randomly Bagging Sampling
// descl2 = **Note**: ``bagging`` is only effective when ``bagging_freq > 0`` and ``bagging_fraction < 1.0``
// desc = ``goss``, Gradient-based One-Side Sampling
std::string data_sample_strategy = "bagging";

// alias = train, train_data, train_data_file, data_filename
// desc = path of training data, LightGBM will train from this data
// desc = **Note**: can be used only in CLI version
@@ -263,7 +270,7 @@ struct Config {
// desc = enabling this is recommended when:
// descl2 = the number of data points is large, and the total number of bins is relatively small
// descl2 = ``num_threads`` is relatively small, e.g. ``<= 16``
// descl2 = you want to use small ``bagging_fraction`` or ``goss`` boosting to speed up
// descl2 = you want to use small ``bagging_fraction`` or ``goss`` sample strategy to speed up
// desc = **Note**: setting this to ``true`` will double the memory cost for Dataset object. If you have not enough memory, you can try setting ``force_col_wise=true``
// desc = **Note**: when both ``force_col_wise`` and ``force_row_wise`` are ``false``, LightGBM will firstly try them both, and then use the faster one. To remove the overhead of testing set the faster one to ``true`` manually
// desc = **Note**: this parameter cannot be used at the same time with ``force_col_wise``, choose only one of them
3 changes: 2 additions & 1 deletion include/LightGBM/cuda/cuda_utils.h
@@ -10,10 +10,10 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <LightGBM/utils/log.h>
#endif // USE_CUDA || USE_CUDA_EXP

#ifdef USE_CUDA_EXP
#include <LightGBM/utils/log.h>
#include <vector>
#endif // USE_CUDA_EXP

@@ -124,6 +124,7 @@ class CUDAVector {
}
if (size == 0) {
Clear();
return;
}
T* new_data = nullptr;
AllocateCUDAMemory<T>(&new_data, size, __FILE__, __LINE__);
83 changes: 83 additions & 0 deletions include/LightGBM/sample_strategy.h
@@ -0,0 +1,83 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/

#ifndef LIGHTGBM_SAMPLE_STRATEGY_H_
#define LIGHTGBM_SAMPLE_STRATEGY_H_

#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
#include <LightGBM/config.h>
#include <LightGBM/dataset.h>
#include <LightGBM/tree_learner.h>
#include <LightGBM/objective_function.h>

#include <memory>
#include <vector>

namespace LightGBM {

class SampleStrategy {
public:
SampleStrategy() : balanced_bagging_(false), bagging_runner_(0, bagging_rand_block_), need_resize_gradients_(false) {}

virtual ~SampleStrategy() {}

static SampleStrategy* CreateSampleStrategy(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function, int num_tree_per_iteration);

virtual void Bagging(int iter, TreeLearner* tree_learner, score_t* gradients, score_t* hessians) = 0;

virtual void ResetSampleConfig(const Config* config, bool is_change_dataset) = 0;

bool is_use_subset() const { return is_use_subset_; }

data_size_t bag_data_cnt() const { return bag_data_cnt_; }

std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>>& bag_data_indices() { return bag_data_indices_; }

#ifdef USE_CUDA_EXP
CUDAVector<data_size_t>& cuda_bag_data_indices() { return cuda_bag_data_indices_; }
#endif // USE_CUDA_EXP

void UpdateObjectiveFunction(const ObjectiveFunction* objective_function) {
objective_function_ = objective_function;
}

void UpdateTrainingData(const Dataset* train_data) {
train_data_ = train_data;
num_data_ = train_data->num_data();
}

virtual bool IsHessianChange() const = 0;

bool NeedResizeGradients() const { return need_resize_gradients_; }

protected:
const Config* config_;
const Dataset* train_data_;
const ObjectiveFunction* objective_function_;
std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>> bag_data_indices_;
data_size_t bag_data_cnt_;
data_size_t num_data_;
int num_tree_per_iteration_;
std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_;
bool balanced_bagging_;
const int bagging_rand_block_ = 1024;
std::vector<Random> bagging_rands_;
ParallelPartitionRunner<data_size_t, false> bagging_runner_;
/*! \brief whether need to resize the gradient vectors */
bool need_resize_gradients_;

#ifdef USE_CUDA_EXP
/*! \brief Buffer for bag_data_indices_ on GPU, used only with cuda_exp */
CUDAVector<data_size_t> cuda_bag_data_indices_;
#endif // USE_CUDA_EXP
};

} // namespace LightGBM

#endif // LIGHTGBM_SAMPLE_STRATEGY_H_
2 changes: 2 additions & 0 deletions python-package/lightgbm/dask.py
@@ -1042,6 +1042,8 @@ def _lgb_dask_fit(
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
if not DASK_INSTALLED:
raise LightGBMError('dask is required for lightgbm.dask')
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')

1 change: 0 additions & 1 deletion python-package/lightgbm/sklearn.py
@@ -382,7 +382,6 @@ def __init__(
boosting_type : str, optional (default='gbdt')
'gbdt', traditional Gradient Boosting Decision Tree.
'dart', Dropouts meet Multiple Additive Regression Trees.
'goss', Gradient-based One-Side Sampling.
'rf', Random Forest.
num_leaves : int, optional (default=31)
Maximum tree leaves for base learners.
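
With `'goss'` removed from the documented `boosting_type` values, scikit-learn users would select the strategy via the estimators' extra keyword arguments, which are forwarded to the underlying booster parameters — a brief sketch:

```python
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1_000, n_features=10, random_state=42)

# Extra kwargs on the sklearn estimators are passed through as booster
# parameters, so the new data_sample_strategy parameter works directly.
model = LGBMRegressor(boosting_type="gbdt", data_sample_strategy="goss")
model.fit(X, y)
```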
