Skip to content

Commit

Permalink
Disable column sample by node for the exact tree method. (#10083)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 1, 2024
1 parent 8189126 commit 3941b31
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_feature_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that("training with feature weights works", {
expect_lt(importance[1, Frequency], importance[9, Frequency])
}

for (tm in c("hist", "approx", "exact")) {
for (tm in c("hist", "approx")) {
test(tm)
}
})
2 changes: 1 addition & 1 deletion doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Parameters for Tree Booster
- All ``colsample_by*`` parameters have a range of (0, 1], the default value of 1, and specify the fraction of columns to be subsampled.
- ``colsample_bytree`` is the subsample ratio of columns when constructing each tree. Subsampling occurs once for every tree constructed.
- ``colsample_bylevel`` is the subsample ratio of columns for each level. Subsampling occurs once for every new depth level reached in a tree. Columns are subsampled from the set of columns chosen for the current tree.
- ``colsample_bynode`` is the subsample ratio of columns for each node (split). Subsampling occurs once every time a new split is evaluated. Columns are subsampled from the set of columns chosen for the current level.
- ``colsample_bynode`` is the subsample ratio of columns for each node (split). Subsampling occurs once every time a new split is evaluated. Columns are subsampled from the set of columns chosen for the current level. This is not supported by the exact tree method.
- ``colsample_by*`` parameters work cumulatively. For instance,
the combination ``{'colsample_bytree':0.5, 'colsample_bylevel':0.5,
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
Expand Down
18 changes: 9 additions & 9 deletions src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class ColMaker: public TreeUpdater {
if (dmat->Info().HasCategorical()) {
LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method");
}
if (param->colsample_bynode - 1.0 != 0.0) {
LOG(FATAL) << "column sample by node is not yet supported by the exact tree method";
}
this->LazyGetColumnDensity(dmat);
// rescale learning rate according to size of trees
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
Expand Down Expand Up @@ -440,9 +443,8 @@ class ColMaker: public TreeUpdater {
}

// update the solution candidate
virtual void UpdateSolution(const SortedCSCPage &batch,
const std::vector<bst_feature_t> &feat_set,
const std::vector<GradientPair> &gpair, DMatrix *) {
void UpdateSolution(SortedCSCPage const &batch, const std::vector<bst_feature_t> &feat_set,
const std::vector<GradientPair> &gpair) {
// start enumeration
const auto num_features = feat_set.size();
CHECK(this->ctx_);
Expand All @@ -466,17 +468,15 @@ class ColMaker: public TreeUpdater {
}
});
}

// find splits at current level, do split per level
inline void FindSplit(int depth,
const std::vector<int> &qexpand,
const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
RegTree *p_tree) {
void FindSplit(bst_node_t depth, const std::vector<int> &qexpand,
std::vector<GradientPair> const &gpair, DMatrix *p_fmat, RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();

auto feat_set = column_sampler_->GetFeatureSet(depth);
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
this->UpdateSolution(batch, feat_set->HostVector(), gpair);
}
// after this each thread's stemp will get the best candidates, aggregate results
this->SyncBestSolution(qexpand);
Expand Down
18 changes: 16 additions & 2 deletions tests/python/test_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,24 @@ class TestTreeMethod:
def test_exact(self, param, num_rounds, dataset):
if dataset.name.endswith("-l1"):
return
param['tree_method'] = 'exact'
param["tree_method"] = "exact"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result['train'][dataset.metric])
assert tm.non_increasing(result["train"][dataset.metric])

def test_exact_sample_by_node_error(self) -> None:
X, y, w = tm.make_regression(128, 12, False)
with pytest.raises(ValueError, match="column sample by node"):
xgb.train(
{"tree_method": "exact", "colsample_bynode": 0.999},
xgb.DMatrix(X, y, weight=w),
)

xgb.train(
{"tree_method": "exact", "colsample_bynode": 1.0},
xgb.DMatrix(X, y, weight=w),
num_boost_round=2,
)

@given(
exact_parameter_strategy,
Expand Down

0 comments on commit 3941b31

Please sign in to comment.