diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 0b287fdc65f..42fb3febc70 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -254,6 +254,14 @@ Learning Control Parameters
- random seed for bagging
+- ``feature_fraction_bynode`` :raw-html:`🔗︎`, default = ``false``, type = bool, aliases: ``sub_feature_bynode``, ``colsample_bytree_bynode``
+
+ - set this to ``true`` to randomly select part of the features for each tree node
+
+ - set this to ``false`` to randomly select part of the features once per tree, reusing the same subset for every node of that tree
+
+ - **Note**: setting this to ``true`` cannot speed up training, but setting it to ``false`` can speed up training linearly, since fewer feature histograms need to be constructed per tree
+
- ``feature_fraction`` :raw-html:`🔗︎`, default = ``1.0``, type = double, aliases: ``sub_feature``, ``colsample_bytree``, constraints: ``0.0 < feature_fraction <= 1.0``
- LightGBM will randomly select part of features on each iteration if ``feature_fraction`` smaller than ``1.0``. For example, if you set it to ``0.8``, LightGBM will select 80% of features before training each tree
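The per-tree vs. per-node distinction documented above can be made concrete with a small simulation. This is an illustrative Python sketch of the described sampling behavior, not LightGBM code; the `max(1, ...)` floor mirrors the "at least use one feature" rule in the implementation further down.

```python
import random

def sample_features(n_features, fraction, rng):
    # At least one feature is always kept, mirroring the implementation's floor.
    k = max(1, int(n_features * fraction))
    return sorted(rng.sample(range(n_features), k))

rng = random.Random(3)

# feature_fraction_bynode = false: one subset per tree; every node in
# that tree considers only these features.
tree_subset = sample_features(10, 0.8, rng)

# feature_fraction_bynode = true: a fresh subset per node, so sibling
# nodes of the same tree may consider different features.
node_subsets = [sample_features(10, 0.8, rng) for _ in range(3)]

print(tree_subset)
print(node_subsets)
```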
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 5e92582337e..ad341c1d8be 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -268,6 +268,12 @@ struct Config {
// desc = random seed for bagging
int bagging_seed = 3;
+ // alias = sub_feature_bynode, colsample_bytree_bynode
+ // desc = set this to ``true`` to randomly select part of the features for each tree node
+ // desc = set this to ``false`` to randomly select part of the features once per tree, reusing the same subset for every node of that tree
+ // desc = **Note**: setting this to ``true`` cannot speed up training, but setting it to ``false`` can speed up training linearly, since fewer feature histograms need to be constructed per tree
+ bool feature_fraction_bynode = false;
+
// alias = sub_feature, colsample_bytree
// check = >0.0
// check = <=1.0
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index d598dc907e3..b2957cb6335 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -66,6 +66,8 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"neg_bagging", "neg_bagging_fraction"},
{"subsample_freq", "bagging_freq"},
{"bagging_fraction_seed", "bagging_seed"},
+ {"sub_feature_bynode", "feature_fraction_bynode"},
+ {"colsample_bytree_bynode", "feature_fraction_bynode"},
{"sub_feature", "feature_fraction"},
{"colsample_bytree", "feature_fraction"},
{"early_stopping_rounds", "early_stopping_round"},
@@ -186,6 +188,7 @@ std::unordered_set<std::string> Config::parameter_set({
"neg_bagging_fraction",
"bagging_freq",
"bagging_seed",
+ "feature_fraction_bynode",
"feature_fraction",
"feature_fraction_seed",
"early_stopping_round",
@@ -324,6 +327,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "bagging_seed", &bagging_seed);

+ GetBool(params, "feature_fraction_bynode", &feature_fraction_bynode);
+
GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction >0.0);
CHECK(feature_fraction <=1.0);
@@ -586,6 +591,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[neg_bagging_fraction: " << neg_bagging_fraction << "]\n";
str_buf << "[bagging_freq: " << bagging_freq << "]\n";
str_buf << "[bagging_seed: " << bagging_seed << "]\n";
+ str_buf << "[feature_fraction_bynode: " << feature_fraction_bynode << "]\n";
str_buf << "[feature_fraction: " << feature_fraction << "]\n";
str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
str_buf << "[early_stopping_round: " << early_stopping_round << "]\n";
diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp
index a1fa299bded..1a705e53ac7 100644
--- a/src/treelearner/data_parallel_tree_learner.cpp
+++ b/src/treelearner/data_parallel_tree_learner.cpp
@@ -167,7 +167,12 @@ template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_bests_per_thread(this->num_threads_, SplitInfo());
-
+ std::vector<int8_t> smaller_node_used_features(this->num_features_, 1);
+ std::vector<int8_t> larger_node_used_features(this->num_features_, 1);
+ if (this->config_->feature_fraction_bynode) {
+ smaller_node_used_features = this->GetUsedFeatures();
+ larger_node_used_features = this->GetUsedFeatures();
+ }
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
@@ -193,7 +198,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
- if (smaller_split > smaller_bests_per_thread[tid]) {
+ if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
@@ -213,7 +218,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->larger_leaf_splits_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
- if (larger_split > larger_bests_per_thread[tid]) {
+ if (larger_split > larger_bests_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_bests_per_thread[tid] = larger_split;
}
OMP_LOOP_EX_END();
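Note that per-node selection is implemented purely as an extra guard on the best-split comparison: histograms are still constructed for every per-tree feature, which is consistent with the docs' note that ``feature_fraction_bynode`` cannot speed up training. A toy Python sketch of the guard (illustrative only, not the C++ above):

```python
def pick_best(split_gains, node_used_features):
    # Mimics `if (split > best && node_used_features[feature_index])`:
    # a masked feature can never become the best split at this node.
    best_idx, best_gain = -1, float('-inf')
    for idx, gain in enumerate(split_gains):
        if gain > best_gain and node_used_features[idx]:
            best_idx, best_gain = idx, gain
    return best_idx

# Feature 1 has the highest gain but is masked out at this node.
print(pick_best([0.2, 0.9, 0.5], [1, 0, 1]))  # -> 2
```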
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index f37be94abb1..e03f1cf44f2 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -268,26 +268,34 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vect
return FitByExistingTree(old_tree, gradients, hessians);
}
+std::vector<int8_t> SerialTreeLearner::GetUsedFeatures() {
+ std::vector<int8_t> ret(num_features_, 1);
+ if (config_->feature_fraction >= 1.0f) {
+ return ret;
+ }
+ int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*config_->feature_fraction);
+ // at least use one feature
+ used_feature_cnt = std::max(used_feature_cnt, 1);
+ // initialize used features
+ std::memset(ret.data(), 0, sizeof(int8_t) * num_features_);
+ auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
+ int omp_loop_size = static_cast(sampled_indices.size());
+ #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
+ for (int i = 0; i < omp_loop_size; ++i) {
+ int used_feature = valid_feature_indices_[sampled_indices[i]];
+ int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
+ CHECK(inner_feature_index >= 0);
+ ret[inner_feature_index] = 1;
+ }
+ return ret;
+}
+
void SerialTreeLearner::BeforeTrain() {
// reset histogram pool
histogram_pool_.ResetMap();
- if (config_->feature_fraction < 1) {
- int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*config_->feature_fraction);
- // at least use one feature
- used_feature_cnt = std::max(used_feature_cnt, 1);
- // initialize used features
- std::memset(is_feature_used_.data(), 0, sizeof(int8_t) * num_features_);
- // Get used feature at current tree
- auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
- int omp_loop_size = static_cast(sampled_indices.size());
- #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
- for (int i = 0; i < omp_loop_size; ++i) {
- int used_feature = valid_feature_indices_[sampled_indices[i]];
- int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
- CHECK(inner_feature_index >= 0);
- is_feature_used_[inner_feature_index] = 1;
- }
+ if (config_->feature_fraction < 1 && !config_->feature_fraction_bynode) {
+ is_feature_used_ = GetUsedFeatures();
} else {
#pragma omp parallel for schedule(static, 512) if (num_features_ >= 1024)
for (int i = 0; i < num_features_; ++i) {
@@ -513,6 +521,12 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
#endif
std::vector<SplitInfo> smaller_best(num_threads_);
std::vector<SplitInfo> larger_best(num_threads_);
+ std::vector<int8_t> smaller_node_used_features(num_features_, 1);
+ std::vector<int8_t> larger_node_used_features(num_features_, 1);
+ if (config_->feature_fraction_bynode) {
+ smaller_node_used_features = GetUsedFeatures();
+ larger_node_used_features = GetUsedFeatures();
+ }
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static)
@@ -542,7 +556,7 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
smaller_split.gain -= config_->cegb_tradeoff * CalculateOndemandCosts(real_fidx, smaller_leaf_splits_->LeafIndex());
}
splits_per_leaf_[smaller_leaf_splits_->LeafIndex()*train_data_->num_features() + feature_index] = smaller_split;
- if (smaller_split > smaller_best[tid]) {
+ if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) {
smaller_best[tid] = smaller_split;
}
// only has root leaf
@@ -573,7 +587,7 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
larger_split.gain -= config_->cegb_tradeoff*CalculateOndemandCosts(real_fidx, larger_leaf_splits_->LeafIndex());
}
splits_per_leaf_[larger_leaf_splits_->LeafIndex()*train_data_->num_features() + feature_index] = larger_split;
- if (larger_split > larger_best[tid]) {
+ if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) {
larger_best[tid] = larger_split;
}
OMP_LOOP_EX_END();
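For readers skimming the refactor, the extracted `GetUsedFeatures()` can be read as the following Python sketch. The translation from real feature index to inner feature index via `train_data_->InnerFeatureIndex()` is collapsed to an identity mapping here, so this is an illustration of the sampling logic, not the C++ implementation:

```python
import random

def get_used_features(num_features, valid_feature_indices, fraction, rng):
    # A mask of 1s means "feature may be used"; with fraction >= 1.0
    # every feature stays enabled (the early-return path above).
    if fraction >= 1.0:
        return [1] * num_features
    used_cnt = max(1, int(len(valid_feature_indices) * fraction))
    mask = [0] * num_features
    for i in rng.sample(range(len(valid_feature_indices)), used_cnt):
        mask[valid_feature_indices[i]] = 1  # InnerFeatureIndex() omitted
    return mask

rng = random.Random(2)
print(get_used_features(8, [0, 2, 3, 5, 7], 0.4, rng))
```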
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 52feee1daf1..279b1fd4a68 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -78,6 +78,8 @@ class SerialTreeLearner: public TreeLearner {
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
protected:
+
+ virtual std::vector<int8_t> GetUsedFeatures();
/*!
* \brief Some initial works before training
*/
diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp
index f1c35e71f31..978f2b18e64 100644
--- a/src/treelearner/voting_parallel_tree_learner.cpp
+++ b/src/treelearner/voting_parallel_tree_learner.cpp
@@ -377,6 +377,12 @@ template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_);
std::vector<SplitInfo> larger_best_per_thread(this->num_threads_);
+ std::vector<int8_t> smaller_node_used_features(this->num_features_, 1);
+ std::vector<int8_t> larger_node_used_features(this->num_features_, 1);
+ if (this->config_->feature_fraction_bynode) {
+ smaller_node_used_features = this->GetUsedFeatures();
+ larger_node_used_features = this->GetUsedFeatures();
+ }
// find best split from local aggregated histograms
OMP_INIT_EX();
@@ -405,7 +411,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_splits_global_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
- if (smaller_split > smaller_bests_per_thread[tid]) {
+ if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
}
@@ -429,7 +435,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_leaf_splits_global_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
- if (larger_split > larger_best_per_thread[tid]) {
+ if (larger_split > larger_best_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_best_per_thread[tid] = larger_split;
}
}
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index d7448d053a7..63f1468132a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1584,3 +1584,30 @@ def constant_metric(preds, train_data):
decreasing_metric(preds, train_data)],
early_stopping_rounds=5, verbose_eval=False)
self.assertEqual(gbm.best_iteration, 1)
+
+ def test_node_level_subcol(self):
+ X, y = load_breast_cancer(True)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+ params = {
+ 'objective': 'binary',
+ 'metric': 'binary_logloss',
+ 'feature_fraction': 0.8,
+ 'feature_fraction_bynode': True,
+ 'verbose': -1
+ }
+ lgb_train = lgb.Dataset(X_train, y_train)
+ lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
+ evals_result = {}
+ gbm = lgb.train(params, lgb_train,
+ num_boost_round=25,
+ valid_sets=lgb_eval,
+ verbose_eval=False,
+ evals_result=evals_result)
+ ret = log_loss(y_test, gbm.predict(X_test))
+ self.assertLess(ret, 0.13)
+ self.assertAlmostEqual(evals_result['valid_0']['binary_logloss'][-1], ret, places=5)
+ params['feature_fraction'] = 0.5
+ gbm2 = lgb.train(params, lgb_train,
+ num_boost_round=25)
+ ret2 = log_loss(y_test, gbm2.predict(X_test))
+ self.assertNotEqual(ret, ret2)
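Beyond the unit test, a quick manual sanity check (again assuming a build that includes this patch) is to train two models that differ only in `feature_fraction_bynode` and confirm that their predictions diverge:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.RandomState(42)
X = rng.rand(300, 20)
y = (X[:, 0] + rng.rand(300) > 1.0).astype(int)

base = {'objective': 'binary', 'feature_fraction': 0.5, 'verbose': -1, 'seed': 7}

by_tree = lgb.train({**base, 'feature_fraction_bynode': False},
                    lgb.Dataset(X, y), num_boost_round=20)
by_node = lgb.train({**base, 'feature_fraction_bynode': True},
                    lgb.Dataset(X, y), num_boost_round=20)

# Same data and seed; only the sampling granularity differs, so the
# two models should not produce identical predictions.
print(np.allclose(by_tree.predict(X), by_node.predict(X)))
```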