diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
new file mode 100644
index 000000000000..f78a7ed0974a
--- /dev/null
+++ b/src/tree/hist/evaluate_splits.h
@@ -0,0 +1,268 @@
+/*!
+ * Copyright 2021 by XGBoost Contributors
+ */
+#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
+#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "../param.h"
+#include "../constraints.h"
+#include "../split_evaluator.h"
+#include "../../common/random.h"
+#include "../../common/hist_util.h"
+#include "../../data/gradient_index.h"
+
+namespace xgboost {
+namespace tree {
+
+template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
+ private:
+  struct NodeEntry {
+    /*! \brief statistics for node entry */
+    GradStats stats;
+    /*! \brief loss of this node, without split */
+    bst_float root_gain{0.0f};
+  };
+
+ private:
+  TrainParam param_;
+  std::shared_ptr<common::ColumnSampler> column_sampler_;
+  TreeEvaluator tree_evaluator_;
+  int32_t n_threads_ {0};
+  FeatureInteractionConstraintHost interaction_constraints_;
+  std::vector<NodeEntry> snode_;
+
+  // if the sum of statistics for non-missing values in the node
+  // is equal to the sum of statistics for all values:
+  // then - there are no missing values
+  // else - there are missing values
+  bool static SplitContainsMissingValues(const GradStats e,
+                                         const NodeEntry &snode) {
+    if (e.GetGrad() == snode.stats.GetGrad() &&
+        e.GetHess() == snode.stats.GetHess()) {
+      return false;
+    } else {
+      return true;
+    }
+  }
+
+  // Enumerate/Scan the split values of a specific feature.
+  // Returns the sum of gradients corresponding to the data points that contain
+  // a non-missing value for the particular feature fid.
+  template <int d_step>
+  GradStats EnumerateSplit(
+      const GHistIndexMatrix &gmat, const common::GHistRow<GradientSumT> &hist,
+      const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx,
+      bst_node_t nidx,
+      TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
+    static_assert(d_step == +1 || d_step == -1, "Invalid step.");
+
+    // aliases
+    const std::vector<uint32_t> &cut_ptr = gmat.cut.Ptrs();
+    const std::vector<bst_float> &cut_val = gmat.cut.Values();
+
+    // statistics on both sides of split
+    GradStats c;
+    GradStats e;
+    // best split so far
+    SplitEntry best;
+
+    // bin boundaries
+    CHECK_LE(cut_ptr[fidx],
+             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LE(cut_ptr[fidx + 1],
+             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    // imin: index (offset) of the minimum value for feature fid
+    // need this for backward enumeration
+    const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
+    // ibegin, iend: smallest/largest cut points for feature fid
+    // use int to allow for value -1
+    int32_t ibegin, iend;
+    if (d_step > 0) {
+      ibegin = static_cast<int32_t>(cut_ptr[fidx]);
+      iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
+    } else {
+      ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
+      iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
+    }
+
+    for (int32_t i = ibegin; i != iend; i += d_step) {
+      // start working
+      // try to find a split
+      e.Add(hist[i].GetGrad(), hist[i].GetHess());
+      if (e.GetHess() >= param_.min_child_weight) {
+        c.SetSubstract(snode.stats, e);
+        if (c.GetHess() >= param_.min_child_weight) {
+          bst_float loss_chg;
+          bst_float split_pt;
+          if (d_step > 0) {
+            // forward enumeration: split at right bound of each bin
+            loss_chg = static_cast<bst_float>(
+                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{e},
+                                        GradStats{c}) -
+                snode.root_gain);
+            split_pt = cut_val[i];
+            best.Update(loss_chg, fidx, split_pt, d_step == -1, e, c);
+          } else {
+            // backward enumeration: split at left bound of each bin
+            loss_chg = static_cast<bst_float>(
+                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{c},
+                                        GradStats{e}) -
+                snode.root_gain);
+            if (i == imin) {
+              // for leftmost bin, left bound is the smallest feature value
+              split_pt = gmat.cut.MinValues()[fidx];
+            } else {
+              split_pt = cut_val[i - 1];
+            }
+            best.Update(loss_chg, fidx, split_pt, d_step == -1, c, e);
+          }
+        }
+      }
+    }
+    p_best->Update(best);
+
+    return e;
+  }
+
+ public:
+  void EvaluateSplits(const common::HistCollection<GradientSumT> &hist,
+                      GHistIndexMatrix const &gidx, const RegTree &tree,
+                      std::vector<ExpandEntry>* p_entries) {
+    auto& entries = *p_entries;
+    // All nodes are on the same level, so we can store the shared ptr.
+    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
+        entries.size());
+    for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
+      auto nidx = entries[nidx_in_set].nid;
+      features[nidx_in_set] =
+          column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
+    }
+    CHECK(!features.empty());
+    const size_t grain_size =
+        std::max<size_t>(1, features.front()->Size() / n_threads_);
+    common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) {
+      return features[nidx_in_set]->Size();
+    }, grain_size);
+
+    std::vector<ExpandEntry> tloc_candidates(omp_get_max_threads() * entries.size());
+    for (size_t i = 0; i < entries.size(); ++i) {
+      for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
+        tloc_candidates[i * n_threads_ + j] = entries[i];
+      }
+    }
+    auto evaluator = tree_evaluator_.GetEvaluator();
+
+    common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
+      auto tidx = omp_get_thread_num();
+      auto entry = &tloc_candidates[n_threads_ * nidx_in_set + tidx];
+      auto best = &entry->split;
+      auto nidx = entry->nid;
+      auto histogram = hist[nidx];
+      auto features_set = features[nidx_in_set]->ConstHostSpan();
+      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
+        auto fidx = features_set[fidx_in_set];
+        if (interaction_constraints_.Query(nidx, fidx)) {
+          auto grad_stats = EnumerateSplit<+1>(gidx, histogram, snode_[nidx],
+                                               best, fidx, nidx, evaluator);
+          if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
+            EnumerateSplit<-1>(gidx, histogram, snode_[nidx], best, fidx, nidx,
+                               evaluator);
+          }
+        }
+      }
+    });
+
+    for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
+         ++nidx_in_set) {
+      for (auto tidx = 0; tidx < n_threads_; ++tidx) {
+        entries[nidx_in_set].split.Update(
+            tloc_candidates[n_threads_ * nidx_in_set + tidx].split);
+      }
+    }
+  }
+  // Add splits to the tree, handles all statistics
+  void ApplyTreeSplit(ExpandEntry candidate, RegTree *p_tree) {
+    auto evaluator = tree_evaluator_.GetEvaluator();
+    RegTree &tree = *p_tree;
+
+    GradStats parent_sum = candidate.split.left_sum;
+    parent_sum.Add(candidate.split.right_sum);
+    auto base_weight =
+        evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
+
+    auto left_weight = evaluator.CalcWeight(
+        candidate.nid, param_, GradStats{candidate.split.left_sum});
+    auto right_weight = evaluator.CalcWeight(
+        candidate.nid, param_, GradStats{candidate.split.right_sum});
+
+    tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(),
+                    candidate.split.split_value, candidate.split.DefaultLeft(),
+                    base_weight, left_weight * param_.learning_rate,
+                    right_weight * param_.learning_rate,
+                    candidate.split.loss_chg, parent_sum.GetHess(),
+                    candidate.split.left_sum.GetHess(),
+                    candidate.split.right_sum.GetHess());
+
+    // Set up child constraints
+    auto left_child = tree[candidate.nid].LeftChild();
+    auto right_child = tree[candidate.nid].RightChild();
+    tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
+                             tree[candidate.nid].SplitIndex(), left_weight,
+                             right_weight);
+
+    auto max_node = std::max(left_child, tree[candidate.nid].RightChild());
+    max_node = std::max(candidate.nid, max_node);
+    snode_.resize(tree.GetNodes().size());
+    snode_.at(left_child).stats = candidate.split.left_sum;
+    snode_.at(left_child).root_gain = evaluator.CalcGain(
+        candidate.nid, param_, GradStats{candidate.split.left_sum});
+    snode_.at(right_child).stats = candidate.split.right_sum;
+    snode_.at(right_child).root_gain = evaluator.CalcGain(
+        candidate.nid, param_, GradStats{candidate.split.right_sum});
+
+    interaction_constraints_.Split(candidate.nid,
+                                   tree[candidate.nid].SplitIndex(), left_child,
+                                   right_child);
+  }
+
+  auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
+  auto const& Stats() const { return snode_; }
+
+  float InitRoot(GradStats const& root_sum) {
+    snode_.resize(1);
+    auto root_evaluator = tree_evaluator_.GetEvaluator();
+
+    snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
+    snode_[0].root_gain = root_evaluator.CalcGain(RegTree::kRoot, param_,
+                                                  GradStats{snode_[0].stats});
+    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_,
+                                            GradStats{snode_[0].stats});
+    return weight;
+  }
+
+ public:
+  // The column sampler must be constructed by the caller since we need to
+  // preserve the rng for the entire training session.
+  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info,
+                         int32_t n_threads,
+                         std::shared_ptr<common::ColumnSampler> sampler,
+                         bool skip_0_index = false)
+      : param_{param}, column_sampler_{std::move(sampler)},
+        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_),
+                        GenericParameter::kCpuId},
+        n_threads_{n_threads} {
+    interaction_constraints_.Configure(param, info.num_col_);
+    column_sampler_->Init(info.num_col_, info.feature_weigths.HostVector(),
+                          param_.colsample_bynode, param_.colsample_bylevel,
+                          param_.colsample_bytree, skip_0_index);
+  }
+};
+}  // namespace tree
+}  // namespace xgboost
+#endif  // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
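A note on the enumeration in the new header: `EnumerateSplit` is instantiated twice per feature. The forward pass (`d_step == +1`) accumulates bins left-to-right, so rows with a missing value for the feature implicitly fall into the right child; the backward pass does the opposite, and it only runs when `SplitContainsMissingValues` detects a gap between the node's total statistics and the non-missing sum. Below is a minimal standalone sketch of that two-pass control flow; the `Stats`/`ScanSplits` names and the unregularized gain formula are illustrative stand-ins, not XGBoost's API:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One (grad, hess) accumulator per histogram bin.
struct Stats {
  double grad{0.0}, hess{0.0};
  void Add(Stats const& o) { grad += o.grad; hess += o.hess; }
};

// Scan bins in one direction. step == +1 accumulates the left child, so
// rows with missing values implicitly fall right; step == -1 does the
// opposite. The gain is an unregularized stand-in: G^2/H per side.
double ScanSplits(std::vector<Stats> const& bins, Stats parent, int step) {
  double best = 0.0;
  Stats acc;  // running sum over the scanned side
  int begin = step > 0 ? 0 : static_cast<int>(bins.size()) - 1;
  int end = step > 0 ? static_cast<int>(bins.size()) : -1;
  for (int i = begin; i != end; i += step) {
    acc.Add(bins[i]);
    Stats rest{parent.grad - acc.grad, parent.hess - acc.hess};
    if (acc.hess > 0 && rest.hess > 0) {
      double gain = acc.grad * acc.grad / acc.hess +
                    rest.grad * rest.grad / rest.hess;
      best = std::max(best, gain);
    }
  }
  return best;
}

int main() {
  std::vector<Stats> bins{{1.0, 1.0}, {-2.0, 1.0}, {0.5, 1.0}};
  // The parent sum also covers rows whose value is missing for this
  // feature, so it exceeds the sum over the bins.
  Stats parent{0.5, 4.0};
  Stats non_missing;
  for (auto const& b : bins) non_missing.Add(b);

  double best = ScanSplits(bins, parent, +1);  // missing values go right
  // Mirrors SplitContainsMissingValues: a second pass is only needed when
  // the bins do not account for the full parent statistics.
  if (non_missing.grad != parent.grad || non_missing.hess != parent.hess) {
    best = std::max(best, ScanSplits(bins, parent, -1));  // missing go left
  }
  assert(best > 0.0);
  return 0;
}
```

Whichever direction wins determines `SplitEntry::DefaultLeft()`, which is why `best.Update` passes `d_step == -1` as the default-left flag.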
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index e430d95f5329..ae59323b6407 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -51,11 +51,8 @@ template <typename GradientSumT>
 void QuantileHistMaker::SetBuilder(const size_t n_trees,
                                    std::unique_ptr<Builder<GradientSumT>>* builder,
                                    DMatrix *dmat) {
-  builder->reset(new Builder<GradientSumT>(
-      n_trees,
-      param_,
-      std::move(pruner_),
-      int_constraint_, dmat));
+  builder->reset(
+      new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat));
   if (rabit::IsDistributed()) {
     (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer<GradientSumT>());
     (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder<GradientSumT>());
@@ -75,6 +72,7 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
     builder->Update(gmat, column_matrix_, gpair, dmat, tree);
   }
 }
+
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree *> &trees) {
@@ -93,7 +91,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
   // rescale learning rate according to size of trees
   float lr = param_.learning_rate;
   param_.learning_rate = lr / trees.size();
-  int_constraint_.Configure(param_, dmat->Info().num_col_);
+
   // build tree
   const size_t n_trees = trees.size();
   if (hist_maker_param_.single_precision_histogram) {
@@ -296,12 +294,9 @@ void QuantileHistMaker::Builder<GradientSumT>::SetHistRowsAdder(
 template <typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
-    const GHistIndexMatrix &gmat,
-    const DMatrix& fmat,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h,
-    int *num_leaves, std::vector<CPUExpandEntry> *expand) {
-
+    const GHistIndexMatrix &gmat, const DMatrix &fmat, RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h, int *num_leaves,
+    std::vector<CPUExpandEntry> *expand) {
   CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);
 
   nodes_for_explicit_hist_build_.clear();
@@ -315,10 +310,40 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
   BuildLocalHistograms(gmat, p_tree, gpair_h);
   hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
 
-  this->InitNewNode(CPUExpandEntry::kRootNid, gmat, gpair_h, fmat, *p_tree);
+  {
+    auto nid = CPUExpandEntry::kRootNid;
+    GHistRowT hist = hist_[nid];
+    GradientPairT grad_stat;
+    if (data_layout_ == DataLayout::kDenseDataZeroBased ||
+        data_layout_ == DataLayout::kDenseDataOneBased) {
+      const std::vector<uint32_t> &row_ptr = gmat.cut.Ptrs();
+      const uint32_t ibegin = row_ptr[fid_least_bins_];
+      const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+      auto begin = hist.data();
+      for (uint32_t i = ibegin; i < iend; ++i) {
+        const GradientPairT et = begin[i];
+        grad_stat.Add(et.GetGrad(), et.GetHess());
+      }
+    } else {
+      const RowSetCollection::Elem e = row_set_collection_[nid];
+      for (const size_t *it = e.begin; it < e.end; ++it) {
+        grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
+      }
+    }
+    histred_.Allreduce(&grad_stat, 1);
+
+    auto weight = evaluator_->InitRoot(GradStats{grad_stat});
+    p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
+    p_tree->Stat(RegTree::kRoot).base_weight = weight;
+    (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
+
+    std::vector<CPUExpandEntry> entries{node};
+    builder_monitor_.Start("EvaluateSplits");
+    evaluator_->EvaluateSplits(hist_, gmat, *p_tree, &entries);
+    builder_monitor_.Stop("EvaluateSplits");
+    node = entries.front();
+  }
 
-  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
-  node.loss_chg = snode_[CPUExpandEntry::kRootNid].best.loss_chg;
   expand->push_back(node);
   ++(*num_leaves);
 }
@@ -369,25 +394,10 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
     RegTree *p_tree,
     int *num_leaves,
     std::vector<CPUExpandEntry>* nodes_for_apply_split) {
-  auto evaluator = tree_evaluator_.GetEvaluator();
   for (auto const& entry : expand) {
-    int nid = entry.nid;
-
     if (entry.IsValid(param_, *num_leaves)) {
-      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
-    } else {
       nodes_for_apply_split->push_back(entry);
-
-      NodeEntry& e = snode_[nid];
-      bst_float left_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.left_sum}) * param_.learning_rate;
-      bst_float right_leaf_weight =
-          evaluator.CalcWeight(nid, param_, GradStats{e.best.right_sum}) * param_.learning_rate;
-      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                         right_leaf_weight, e.best.loss_chg, e.stats.GetHess(),
-                         e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
-      // - 1 parent + 2 new children
+      evaluator_->ApplyTreeSplit(entry, p_tree);
       (*num_leaves)++;
     }
   }
@@ -425,26 +435,6 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
   builder_monitor_.Stop("SplitSiblings");
 }
 
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
-    const GHistIndexMatrix &gmat,
-    const DMatrix& fmat,
-    const std::vector<GradientPair> &gpair_h,
-    const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree) {
-  for (auto const& candidate : nodes_for_apply_split) {
-    const int nid = candidate.nid;
-    const int cleft = (*p_tree)[nid].LeftChild();
-    const int cright = (*p_tree)[nid].RightChild();
-
-    InitNewNode(cleft, gmat, gpair_h, fmat, *p_tree);
-    InitNewNode(cright, gmat, gpair_h, fmat, *p_tree);
-    bst_uint featureid = snode_[nid].best.SplitIndex();
-    tree_evaluator_.AddSplit(nid, cleft, cright, featureid,
-                             snode_[cleft].weight, snode_[cright].weight);
-    interaction_constraints_.Split(nid, featureid, cleft, cright);
-  }
-}
-
 template <typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
@@ -484,20 +474,13 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
       hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
     }
 
-    BuildNodeStats(gmat, *p_fmat, gpair_h, nodes_for_apply_split, p_tree);
-    EvaluateSplits(nodes_to_evaluate, gmat, hist_, *p_tree);
+    builder_monitor_.Start("EvaluateSplits");
+    evaluator_->EvaluateSplits(hist_, gmat, *p_tree, &nodes_to_evaluate);
+    builder_monitor_.Stop("EvaluateSplits");
 
    for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
-      const CPUExpandEntry candidate = nodes_for_apply_split[i];
-      const int nid = candidate.nid;
-      const int cleft = (*p_tree)[nid].LeftChild();
-      const int cright = (*p_tree)[nid].RightChild();
-      CPUExpandEntry left_node = nodes_to_evaluate[i*2 + 0];
-      CPUExpandEntry right_node = nodes_to_evaluate[i*2 + 1];
-
-      left_node.loss_chg = snode_[cleft].best.loss_chg;
-      right_node.loss_chg = snode_[cright].best.loss_chg;
-
+      CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
+      CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1);
       driver.Push(left_node);
       driver.Push(right_node);
     }
@@ -521,9 +504,6 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
     gpair_local_ = *gpair_ptr;
     gpair_ptr = &gpair_local_;
   }
-  tree_evaluator_ =
-      TreeEvaluator(param_, p_fmat->Info().num_col_, GenericParameter::kCpuId);
-  interaction_constraints_.Reset();
   p_last_fmat_mutable_ = p_fmat;
 
   this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
@@ -533,12 +513,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
   } else {
     ExpandTree(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
   }
-  for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
-    p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
-    p_tree->Stat(nid).base_weight = snode_[nid].weight;
-    p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.GetHess());
-  }
 
   pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
   builder_monitor_.Stop("Update");
 }
@@ -761,14 +736,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
   // store a pointer to the tree
   p_last_tree_ = &tree;
   if (data_layout_ == DataLayout::kDenseDataOneBased) {
-    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, true);
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, true});
   } else {
-    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, false);
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, false});
   }
+
   if (data_layout_ == DataLayout::kDenseDataZeroBased ||
       data_layout_ == DataLayout::kDenseDataOneBased) {
     /* specialized code for dense data: */
@@ -793,95 +768,10 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
     }
     CHECK_GT(min_nbins_per_feature, 0U);
   }
-  {
-    snode_.reserve(256);
-    snode_.clear();
-  }
 
   builder_monitor_.Stop("InitData");
 }
 
-// if sum of statistics for non-missing values in the node
-// is equal to sum of statistics for all values:
-// then - there are no missing values
-// else - there are missing values
-template <typename GradientSumT>
-bool QuantileHistMaker::Builder<GradientSumT>::SplitContainsMissingValues(
-    const GradStats e, const NodeEntry &snode) {
-  if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
-    return false;
-  } else {
-    return true;
-  }
-}
-
-// nodes_set - set of nodes to be processed in parallel
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
-    const std::vector<CPUExpandEntry>& nodes_set,
-    const GHistIndexMatrix& gmat,
-    const HistCollection<GradientSumT>& hist,
-    const RegTree& tree) {
-  builder_monitor_.Start("EvaluateSplits");
-
-  const size_t n_nodes_in_set = nodes_set.size();
-  const size_t nthread = std::max(1, this->nthread_);
-
-  using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
-  std::vector<FeatureSetType> features_sets(n_nodes_in_set);
-  best_split_tloc_.resize(nthread * n_nodes_in_set);
-
-  // Generate feature set for each tree node
-  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
-
-    for (unsigned tid = 0; tid < nthread; ++tid) {
-      best_split_tloc_[nthread*nid_in_set + tid] = snode_[nid].best;
-    }
-  }
-
-  // Create 2D space (# of nodes to process x # of features to process)
-  // to process them in parallel
-  const size_t grain_size = std::max<size_t>(1, features_sets[0]->Size() / nthread);
-  common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
-    return features_sets[nid_in_set]->Size();
-  }, grain_size);
-
-  auto evaluator = tree_evaluator_.GetEvaluator();
-  // Start parallel enumeration for all tree nodes in the set and all features
-  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    const auto tid = static_cast<unsigned>(omp_get_thread_num());
-    GHistRowT node_hist = hist[nid];
-
-    for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
-      if (interaction_constraints_.Query(nid, fid)) {
-        auto grad_stats = this->EnumerateSplit<+1>(
-            gmat, node_hist, snode_[nid],
-            &best_split_tloc_[nthread * nid_in_set + tid], fid, nid, evaluator);
-        if (SplitContainsMissingValues(grad_stats, snode_[nid])) {
-          this->EnumerateSplit<-1>(
-              gmat, node_hist, snode_[nid],
-              &best_split_tloc_[nthread * nid_in_set + tid], fid, nid,
-              evaluator);
-        }
-      }
-    }
-  });
-
-  // Find Best Split across threads for each node in nodes set
-  for (unsigned nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    for (unsigned tid = 0; tid < nthread; ++tid) {
-      snode_[nid].best.Update(best_split_tloc_[nthread*nid_in_set + tid]);
-    }
-  }
-
-  builder_monitor_.Stop("EvaluateSplits");
-}
-
 template <typename GradientSumT>
 void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
     const std::vector<CPUExpandEntry>& nodes,
@@ -988,139 +877,6 @@ void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUExpandEntry> nodes,
-template <typename GradientSumT>
-void QuantileHistMaker::Builder<GradientSumT>::InitNewNode(int nid,
-                                                           const GHistIndexMatrix& gmat,
-                                                           const std::vector<GradientPair>& gpair,
-                                                           const DMatrix& fmat,
-                                                           const RegTree& tree) {
-  builder_monitor_.Start("InitNewNode");
-  {
-    snode_.resize(tree.param.num_nodes, NodeEntry(param_));
-  }
-
-  {
-    GHistRowT hist = hist_[nid];
-    GradientPairT grad_stat;
-    if (tree[nid].IsRoot()) {
-      if (data_layout_ == DataLayout::kDenseDataZeroBased
-          || data_layout_ == DataLayout::kDenseDataOneBased) {
-        const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
-        const uint32_t ibegin = row_ptr[fid_least_bins_];
-        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
-        auto begin = hist.data();
-        for (uint32_t i = ibegin; i < iend; ++i) {
-          const GradientPairT et = begin[i];
-          grad_stat.Add(et.GetGrad(), et.GetHess());
-        }
-      } else {
-        const RowSetCollection::Elem e = row_set_collection_[nid];
-        for (const size_t* it = e.begin; it < e.end; ++it) {
-          grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess());
-        }
-      }
-      histred_.Allreduce(&grad_stat, 1);
-      snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
-    } else {
-      int parent_id = tree[nid].Parent();
-      if (tree[nid].IsLeftChild()) {
-        snode_[nid].stats = snode_[parent_id].best.left_sum;
-      } else {
-        snode_[nid].stats = snode_[parent_id].best.right_sum;
-      }
-    }
-  }
-
-  // calculating the weights
-  {
-    auto evaluator = tree_evaluator_.GetEvaluator();
-    bst_uint parentid = tree[nid].Parent();
-    snode_[nid].weight = static_cast<float>(
-        evaluator.CalcWeight(parentid, param_, GradStats{snode_[nid].stats}));
-    snode_[nid].root_gain = static_cast<float>(
-        evaluator.CalcGain(parentid, param_, GradStats{snode_[nid].stats}));
-  }
-  builder_monitor_.Stop("InitNewNode");
-}
-
-// Enumerate the split values of specific feature.
-// Returns the sum of gradients corresponding to the data points that contains a non-missing value
-// for the particular feature fid.
-template <typename GradientSumT>
-template <int d_step>
-GradStats QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
-    const GHistIndexMatrix &gmat, const GHistRowT &hist, const NodeEntry &snode,
-    SplitEntry *p_best, bst_uint fid, bst_uint nodeID,
-    TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
-  CHECK(d_step == +1 || d_step == -1);
-
-  // aliases
-  const std::vector<uint32_t>& cut_ptr = gmat.cut.Ptrs();
-  const std::vector<bst_float>& cut_val = gmat.cut.Values();
-
-  // statistics on both sides of split
-  GradStats c;
-  GradStats e;
-  // best split so far
-  SplitEntry best;
-
-  // bin boundaries
-  CHECK_LE(cut_ptr[fid],
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  CHECK_LE(cut_ptr[fid + 1],
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  // imin: index (offset) of the minimum value for feature fid
-  // need this for backward enumeration
-  const auto imin = static_cast<int32_t>(cut_ptr[fid]);
-  // ibegin, iend: smallest/largest cut points for feature fid
-  // use int to allow for value -1
-  int32_t ibegin, iend;
-  if (d_step > 0) {
-    ibegin = static_cast<int32_t>(cut_ptr[fid]);
-    iend = static_cast<int32_t>(cut_ptr[fid + 1]);
-  } else {
-    ibegin = static_cast<int32_t>(cut_ptr[fid + 1]) - 1;
-    iend = static_cast<int32_t>(cut_ptr[fid]) - 1;
-  }
-
-  for (int32_t i = ibegin; i != iend; i += d_step) {
-    // start working
-    // try to find a split
-    e.Add(hist[i].GetGrad(), hist[i].GetHess());
-    if (e.GetHess() >= param_.min_child_weight) {
-      c.SetSubstract(snode.stats, e);
-      if (c.GetHess() >= param_.min_child_weight) {
-        bst_float loss_chg;
-        bst_float split_pt;
-        if (d_step > 0) {
-          // forward enumeration: split at right bound of each bin
-          loss_chg = static_cast<bst_float>(
-              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{e},
-                                      GradStats{c}) -
-              snode.root_gain);
-          split_pt = cut_val[i];
-          best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
-        } else {
-          // backward enumeration: split at left bound of each bin
-          loss_chg = static_cast<bst_float>(
-              evaluator.CalcSplitGain(param_, nodeID, fid, GradStats{c},
-                                      GradStats{e}) -
-              snode.root_gain);
-          if (i == imin) {
-            // for leftmost bin, left bound is the smallest feature value
-            split_pt = gmat.cut.MinValues()[fid];
-          } else {
-            split_pt = cut_val[i - 1];
-          }
-          best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
-        }
-      }
-    }
-  }
-  p_best->Update(best);
-
-  return e;
-}
 
 template struct QuantileHistMaker::Builder<float>;
 template struct QuantileHistMaker::Builder<double>;
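The inlined root-statistics block that replaces `InitNewNode` above keeps the old dense-layout shortcut: in a dense matrix every row lands in exactly one bin of every feature, so summing the histogram bins of the feature with the fewest bins (`fid_least_bins_`) already yields the node's total gradient, with no second pass over the rows. A self-contained sketch of why that works (all names hypothetical; values are dyadic so the floating-point sums compare exactly):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct GradPair { double grad, hess; };

int main() {
  // Three rows, dense layout: every row contributes to exactly one bin
  // of every feature.
  std::vector<GradPair> gpair{{1.0, 0.5}, {-0.5, 1.0}, {2.0, 0.25}};
  // Bin index of each row for the feature with the fewest bins
  // (the role played by fid_least_bins_ in the diff).
  std::vector<int> bin_of_row{0, 1, 0};

  std::vector<GradPair> hist(2, GradPair{0.0, 0.0});
  for (std::size_t r = 0; r < gpair.size(); ++r) {
    hist[bin_of_row[r]].grad += gpair[r].grad;
    hist[bin_of_row[r]].hess += gpair[r].hess;
  }

  // Summing that one feature's bins reproduces the total over all rows,
  // which is why the root statistics need no extra pass over the data.
  GradPair from_hist{0.0, 0.0}, from_rows{0.0, 0.0};
  for (auto const& b : hist) { from_hist.grad += b.grad; from_hist.hess += b.hess; }
  for (auto const& g : gpair) { from_rows.grad += g.grad; from_rows.hess += g.hess; }
  assert(from_hist.grad == from_rows.grad && from_hist.hess == from_rows.hess);
  return 0;
}
```

For sparse layouts the shortcut does not hold (rows may skip bins), which is why the `else` branch still sums `gpair_h` over the node's row set before the `Allreduce`.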
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 6c55a5bb1af2..3c82f57f972f 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -20,6 +20,8 @@
 #include "xgboost/data.h"
 #include "xgboost/json.h"
 
+
+#include "hist/evaluate_splits.h"
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
@@ -121,19 +123,23 @@ struct CPUExpandEntry {
   static const int kEmptyNid = -1;
   int nid;
   int depth;
-  bst_float loss_chg;
+  SplitEntry split;
+
+  CPUExpandEntry() = default;
   CPUExpandEntry(int nid, int depth, bst_float loss_chg)
-      : nid(nid), depth(depth), loss_chg(loss_chg) {}
+      : nid(nid), depth(depth) {
+    split.loss_chg = loss_chg;
+  }
   bool IsValid(TrainParam const &param, int32_t num_leaves) const {
-    bool ret = loss_chg <= kRtEps ||
-               (param.max_depth > 0 && this->depth == param.max_depth) ||
-               (param.max_leaves > 0 && num_leaves == param.max_leaves);
-    return ret;
+    bool invalid = split.loss_chg <= kRtEps ||
+                   (param.max_depth > 0 && this->depth == param.max_depth) ||
+                   (param.max_leaves > 0 && num_leaves == param.max_leaves);
+    return !invalid;
   }
 
   bst_float GetLossChange() const {
-    return loss_chg;
+    return split.loss_chg;
   }
 
   int GetNodeId() const {
@@ -214,39 +220,17 @@ class QuantileHistMaker: public TreeUpdater {
   DMatrix const* p_last_dmat_ {nullptr};
   bool is_gmat_initialized_ {false};
 
-  // data structure
-  struct NodeEntry {
-    /*! \brief statics for node entry */
-    GradStats stats;
-    /*! \brief loss of this node, without split */
-    bst_float root_gain;
-    /*! \brief weight calculated related to current data */
-    float weight;
-    /*! \brief current best solution */
-    SplitEntry best;
-    // constructor
-    explicit NodeEntry(const TrainParam&)
-        : root_gain(0.0f), weight(0.0f) {}
-  };
   // actual builder that runs the algorithm
-  template<typename GradientSumT>
+  template <typename GradientSumT>
   struct Builder {
    public:
     using GHistRowT = GHistRow<GradientSumT>;
     using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
     // constructor
-    explicit Builder(const size_t n_trees,
-                     const TrainParam& param,
-                     std::unique_ptr<TreeUpdater> pruner,
-                     FeatureInteractionConstraintHost int_constraints_,
-                     DMatrix const* fmat)
-        : n_trees_(n_trees),
-          param_(param),
-          tree_evaluator_(param, fmat->Info().num_col_, GenericParameter::kCpuId),
-          pruner_(std::move(pruner)),
-          interaction_constraints_{std::move(int_constraints_)},
-          p_last_tree_(nullptr), p_last_fmat_(fmat) {
+    explicit Builder(const size_t n_trees, const TrainParam &param,
+                     std::unique_ptr<TreeUpdater> pruner, DMatrix const *fmat)
+        : n_trees_(n_trees), param_(param), pruner_(std::move(pruner)),
+          p_last_tree_(nullptr), p_last_fmat_(fmat) {
       builder_monitor_.Init("Quantile::Builder");
     }
     // update one tree, growing
@@ -290,11 +274,6 @@ class QuantileHistMaker: public TreeUpdater {
                           std::vector<GradientPair>* gpair,
                           std::vector<size_t>* row_indices);
 
-    void EvaluateSplits(const std::vector<CPUExpandEntry>& nodes_set,
-                        const GHistIndexMatrix& gmat,
-                        const HistCollection<GradientSumT>& hist,
-                        const RegTree& tree);
-
     template <bool any_missing>
     void ApplySplit(std::vector<CPUExpandEntry> nodes,
                     const GHistIndexMatrix& gmat,
@@ -308,26 +287,6 @@ class QuantileHistMaker: public TreeUpdater {
     void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
                              const GHistIndexMatrix& gmat,
                             std::vector<int32_t>* split_conditions);
-    void InitNewNode(int nid,
-                     const GHistIndexMatrix& gmat,
-                     const std::vector<GradientPair>& gpair,
-                     const DMatrix& fmat,
-                     const RegTree& tree);
-
-    // Enumerate the split values of specific feature
-    // Returns the sum of gradients corresponding to the data points that contains a non-missing
-    // value for the particular feature fid.
-    template <int d_step>
-    GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRowT &hist,
-                             const NodeEntry &snode, SplitEntry *p_best, bst_uint fid,
-                             bst_uint nodeID,
-                             TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const;
-
-    // if sum of statistics for non-missing values in the node
-    // is equal to sum of statistics for all values:
-    // then - there are no missing values
-    // else - there are missing values
-    bool SplitContainsMissingValues(const GradStats e, const NodeEntry& snode);
 
     template <bool any_missing>
     void BuildLocalHistograms(const GHistIndexMatrix &gmat,
@@ -352,10 +311,6 @@ class QuantileHistMaker: public TreeUpdater {
                          int *num_leaves,
                          std::vector<CPUExpandEntry>* nodes_for_apply_split);
 
-    void BuildNodeStats(const GHistIndexMatrix &gmat,
-                        const DMatrix& fmat,
-                        const std::vector<GradientPair> &gpair_h,
-                        const std::vector<CPUExpandEntry>& nodes_for_apply_split, RegTree *p_tree);
     template <bool any_missing>
     void ExpandTree(const GHistIndexMatrix& gmat,
                     const ColumnMatrix& column_matrix,
@@ -368,31 +323,24 @@ class QuantileHistMaker: public TreeUpdater {
     const TrainParam& param_;
     // number of omp thread used during training
     int nthread_;
-    common::ColumnSampler column_sampler_;
+    std::shared_ptr<common::ColumnSampler> column_sampler_{
+        std::make_shared<common::ColumnSampler>()};
+
+    std::vector<size_t> unused_rows_;
     // the internal row sets
     RowSetCollection row_set_collection_;
-    // tree rows that were not used for current training
-    std::vector<size_t> unused_rows_;
-    // feature vectors for subsampled prediction
-    std::vector<RegTree::FVec> feat_vecs_;
-    // the temp space for split
-    std::vector<RowSetCollection::Split> row_split_tloc_;
-    std::vector<SplitEntry> best_split_tloc_;
-    /*! \brief TreeNode Data: statistics for each constructed node */
-    std::vector<NodeEntry> snode_;
     std::vector<GradientPair> gpair_local_;
     /*! \brief culmulative histogram of gradients. */
     HistCollection<GradientSumT> hist_;
     /*! \brief culmulative local parent histogram of gradients. */
     HistCollection<GradientSumT> hist_local_worker_;
-    TreeEvaluator tree_evaluator_;
     /*! \brief feature with least # of bins. to be used for dense specialization
                of InitNewNode() */
     uint32_t fid_least_bins_;
 
     GHistBuilder<GradientSumT> hist_builder_;
     std::unique_ptr<TreeUpdater> pruner_;
-    FeatureInteractionConstraintHost interaction_constraints_;
+    std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
 
     static constexpr size_t kPartitionBlockSize = 2048;
     common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
@@ -402,10 +350,6 @@ class QuantileHistMaker: public TreeUpdater {
     DMatrix const* const p_last_fmat_;
     DMatrix* p_last_fmat_mutable_;
 
-    using ExpandQueue =
-        std::priority_queue<CPUExpandEntry, std::vector<CPUExpandEntry>,
-                            std::function<bool(CPUExpandEntry, CPUExpandEntry)>>;
-
     // key is the node id which should be calculated by Subtraction Trick, value is the node which
     // provides the evidence for subtraction
     std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
@@ -438,7 +382,6 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr<Builder<double>> double_builder_;
 
   std::unique_ptr<TreeUpdater> pruner_;
-  FeatureInteractionConstraintHost int_constraint_;
 };
 
 template <typename GradientSumT>
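One behavioural detail worth calling out in the header changes above: the old `CPUExpandEntry::IsValid` returned `true` when an entry should *not* be expanded (callers then turned it into a leaf), while the new version returns `true` when the split is worth applying, which is what the refactored `AddSplitsToTree` branches on directly. A compilable sketch of the new predicate's semantics; the free-standing `IsValid` and its parameters mirror, but are not, the members in the diff:

```cpp
#include <cassert>

struct Entry { float loss_chg; int depth; };

// Mirrors the refactored predicate: true means "worth splitting".
// kRtEps, max_depth and max_leaves mimic the XGBoost parameters of the
// same names; the struct is a stand-in, not the real CPUExpandEntry.
bool IsValid(Entry const& e, int max_depth, int max_leaves, int num_leaves) {
  constexpr float kRtEps = 1e-6f;
  bool invalid = e.loss_chg <= kRtEps ||
                 (max_depth > 0 && e.depth == max_depth) ||
                 (max_leaves > 0 && num_leaves == max_leaves);
  return !invalid;
}

int main() {
  assert(IsValid({1.0f, 2}, /*max_depth=*/6, /*max_leaves=*/0, /*num_leaves=*/3));
  assert(!IsValid({0.0f, 2}, 6, 0, 3));  // no positive loss reduction
  assert(!IsValid({1.0f, 6}, 6, 0, 3));  // depth budget exhausted
  assert(!IsValid({1.0f, 2}, 6, 4, 4));  // leaf budget exhausted
  return 0;
}
```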
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
new file mode 100644
index 000000000000..c9228edf992d
--- /dev/null
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -0,0 +1,112 @@
+#include <gtest/gtest.h>
+#include <numeric>
+#include "../../../../src/tree/hist/evaluate_splits.h"
+#include "../../../../src/tree/updater_quantile_hist.h"
+#include "../../../../src/common/hist_util.h"
+#include "../../helpers.h"
+
+namespace xgboost {
+namespace tree {
+
+template <typename GradientSumT> void TestEvaluateSplits() {
+  int static constexpr kRows = 8, kCols = 16;
+  auto orig = omp_get_max_threads();
+  int32_t n_threads = std::min(omp_get_max_threads(), 4);
+  omp_set_num_threads(n_threads);
+  auto sampler = std::make_shared<common::ColumnSampler>();
+
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{{}});
+  param.min_child_weight = 0;
+  param.reg_lambda = 0;
+
+  auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
+
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler};
+  common::HistCollection<GradientSumT> hist;
+  std::vector<GradientPair> row_gpairs = {
+      {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
+      {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}};
+
+  size_t constexpr kMaxBins = 4;
+  // dense, no missing values
+
+  GHistIndexMatrix gmat(dmat.get(), kMaxBins);
+  common::RowSetCollection row_set_collection;
+  std::vector<size_t> &row_indices = *row_set_collection.Data();
+  row_indices.resize(kRows);
+  std::iota(row_indices.begin(), row_indices.end(), 0);
+  row_set_collection.Init();
+
+  auto hist_builder = GHistBuilder<GradientSumT>(n_threads, gmat.cut.Ptrs().back());
+  hist.Init(gmat.cut.Ptrs().back());
+  hist.AddHistRow(0);
+  hist.AllocateAllData();
+  hist_builder.template BuildHist<false>(row_gpairs, row_set_collection[0],
+                                         gmat, hist[0]);
+
+  // Compute total gradient for all data points
+  GradientPairPrecise total_gpair;
+  for (const auto &e : row_gpairs) {
+    total_gpair += GradientPairPrecise(e);
+  }
+
+  RegTree tree;
+  std::vector<CPUExpandEntry> entries(1);
+  entries.front().nid = 0;
+  entries.front().depth = 0;
+
+  evaluator.InitRoot(GradStats{total_gpair});
+  evaluator.EvaluateSplits(hist, gmat, tree, &entries);
+
+  auto best_loss_chg =
+      evaluator.Evaluator().CalcSplitGain(
+          param, 0, entries.front().split.SplitIndex(),
+          entries.front().split.left_sum, entries.front().split.right_sum) -
+      evaluator.Stats().front().root_gain;
+  ASSERT_EQ(entries.front().split.loss_chg, best_loss_chg);
+  ASSERT_GT(entries.front().split.loss_chg, 16.2f);
+
+  // Assert that's the best split
+  for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
+    GradStats left, right;
+    for (size_t j = gmat.cut.Ptrs()[i-1]; j < gmat.cut.Ptrs()[i]; ++j) {
+      auto loss_chg =
+          evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
+          evaluator.Stats().front().root_gain;
+      ASSERT_GE(best_loss_chg, loss_chg);
+      left.Add(hist[0][j].GetGrad(), hist[0][j].GetHess());
+      right.SetSubstract(GradStats{total_gpair}, left);
+    }
+  }
+
+  omp_set_num_threads(orig);
+}
+
+TEST(HistEvaluator, Evaluate) {
+  TestEvaluateSplits<float>();
+  TestEvaluateSplits<double>();
+}
+
+TEST(HistEvaluator, Apply) {
+  RegTree tree;
+  int static constexpr kNRows = 8, kNCols = 16;
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{{}});
+  auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{
+      param, dmat->Info(), 4, sampler};
+
+  CPUExpandEntry entry{0, 0, 10.0f};
+  entry.split.left_sum = GradStats{0.4, 0.6f};
+  entry.split.right_sum = GradStats{0.5, 0.7f};
+
+  evaluator_.ApplyTreeSplit(entry, &tree);
+  ASSERT_EQ(tree.NumExtraNodes(), 2);
+  ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
+  ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
+}
+}  // namespace tree
+}  // namespace xgboost
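The brute-force loop at the end of `TestEvaluateSplits` above is the load-bearing assertion: every cut point is a candidate split, and `ASSERT_GE` checks that no candidate beats the evaluator's reported best. A distilled, self-contained version of that pattern follows; the `Stats`/`Gain` helpers are hypothetical, and the gain reduces to G²/H per side because the test disables regularization (`reg_lambda = 0`, `min_child_weight = 0`):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Stats { double grad{0}, hess{0}; };

// Unregularized gain contribution of one side: G^2 / H.
double Gain(Stats const& s) { return s.hess > 0 ? s.grad * s.grad / s.hess : 0.0; }

int main() {
  std::vector<Stats> hist{{1.23, 0.24}, {0.5, 0.52}, {0.64, 0.68}, {0.1, 1.16}};
  Stats total;
  for (auto const& b : hist) { total.grad += b.grad; total.hess += b.hess; }

  // First pass: find the best loss change over all bin boundaries.
  double best = 0.0;
  Stats left;
  for (std::size_t i = 0; i + 1 < hist.size(); ++i) {
    left.grad += hist[i].grad; left.hess += hist[i].hess;
    Stats right{total.grad - left.grad, total.hess - left.hess};
    double chg = Gain(left) + Gain(right) - Gain(total);
    if (chg > best) best = chg;
  }

  // Second pass: the property the test asserts with ASSERT_GE -- no
  // single candidate may exceed the reported best.
  left = Stats{};
  for (std::size_t i = 0; i + 1 < hist.size(); ++i) {
    left.grad += hist[i].grad; left.hess += hist[i].hess;
    Stats right{total.grad - left.grad, total.hess - left.hess};
    assert(best >= Gain(left) + Gain(right) - Gain(total));
  }
  return 0;
}
```

This replaces the older `TestEvaluateSplit` in `test_quantile_hist.cc` (removed below), which performed the same check by re-scanning raw rows instead of histogram bins.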
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 33772025735f..decde1db1515 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -26,12 +26,9 @@ class QuantileHistMock : public QuantileHistMaker {
     using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
     using GHistRowT = typename RealImpl::GHistRowT;
 
-    BuilderMock(const TrainParam& param,
-                std::unique_ptr<TreeUpdater> pruner,
-                FeatureInteractionConstraintHost int_constraint,
-                DMatrix const* fmat)
-        : RealImpl(1, param, std::move(pruner),
-                   std::move(int_constraint), fmat) {}
+    BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
+                DMatrix const *fmat)
+        : RealImpl(1, param, std::move(pruner), fmat) {}
 
    public:
     void TestInitData(const GHistIndexMatrix& gmat,
@@ -336,92 +333,6 @@ class QuantileHistMock : public QuantileHistMaker {
       }
     }
 
-    void TestEvaluateSplit(const RegTree& tree) {
-      std::vector<GradientPair> row_gpairs =
-          { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
-            {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
-      size_t constexpr kMaxBins = 4;
-      auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
-      // dense, no missing values
-
-      GHistIndexMatrix gmat(dmat.get(), kMaxBins);
-
-      RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
-      this->hist_.AddHistRow(0);
-      this->hist_.AllocateAllData();
-      this->hist_builder_.template BuildHist<false>(row_gpairs, this->row_set_collection_[0],
-                                                    gmat, this->hist_[0]);
-
-      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
-
-      /* Compute correct split (best_split) using the computed histogram */
-      const size_t num_row = dmat->Info().num_row_;
-      const size_t num_feature = dmat->Info().num_col_;
-      CHECK_EQ(num_row, row_gpairs.size());
-      // Compute total gradient for all data points
-      GradientPairPrecise total_gpair;
-      for (const auto& e : row_gpairs) {
-        total_gpair += GradientPairPrecise(e);
-      }
-      // Now enumerate all feature*threshold combination to get best split
-      // To simplify logic, we make some assumptions:
-      // 1) no missing values in data
-      // 2) no regularization, i.e. set min_child_weight, reg_lambda, reg_alpha,
-      //    and max_delta_step to 0.
-      bst_float best_split_gain = 0.0f;
-      size_t best_split_threshold = std::numeric_limits<size_t>::max();
-      size_t best_split_feature = std::numeric_limits<size_t>::max();
-      // Enumerate all features
-      for (size_t fid = 0; fid < num_feature; ++fid) {
-        const size_t bin_id_min = gmat.cut.Ptrs()[fid];
-        const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
-        // Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
-        // choice of thresholds for feature fid
-        for (size_t split_thresh = bin_id_min;
-             split_thresh < bin_id_max; ++split_thresh) {
-          // left_sum, right_sum: Gradient sums for data points whose feature
-          // value is left/right side of the split threshold
-          GradientPairPrecise left_sum, right_sum;
-          for (size_t rid = 0; rid < num_row; ++rid) {
-            for (size_t offset = gmat.row_ptr[rid];
-                 offset < gmat.row_ptr[rid + 1]; ++offset) {
-              const size_t bin_id = gmat.index[offset];
-              if (bin_id >= bin_id_min && bin_id < bin_id_max) {
-                if (bin_id <= split_thresh) {
-                  left_sum += GradientPairPrecise(row_gpairs[rid]);
-                } else {
-                  right_sum += GradientPairPrecise(row_gpairs[rid]);
-                }
-              }
-            }
-          }
-          // Now compute gain (change in loss)
-          auto evaluator = this->tree_evaluator_.GetEvaluator();
-          const auto split_gain = evaluator.CalcSplitGain(
-              this->param_, 0, fid, GradStats(left_sum), GradStats(right_sum));
-          if (split_gain > best_split_gain) {
-            best_split_gain = split_gain;
-            best_split_feature = fid;
-            best_split_threshold = split_thresh;
-          }
-        }
-      }
-
-      /* Now compare against result given by EvaluateSplit() */
-      CPUExpandEntry node(CPUExpandEntry::kRootNid,
-                          tree.GetDepth(0),
-                          this->snode_[0].best.loss_chg);
-      RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
-      ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
-      ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
-    }
-
-    void TestEvaluateSplitParallel(const RegTree &tree) {
-      omp_set_num_threads(2);
-      TestEvaluateSplit(tree);
-      omp_set_num_threads(1);
-    }
-
     void TestApplySplit(const RegTree& tree) {
       std::vector<GradientPair> row_gpairs =
           { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@@ -441,7 +352,6 @@ class QuantileHistMock : public QuantileHistMaker {
       RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
       this->hist_.AddHistRow(0);
       this->hist_.AllocateAllData();
-      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
 
       const size_t num_row = dmat->Info().num_row_;
       // split by feature 0
@@ -513,7 +423,6 @@ class QuantileHistMock : public QuantileHistMaker {
           new BuilderMock(
               param_,
              std::move(pruner_),
-              int_constraint_,
               dmat_.get()));
       if (batch) {
         float_builder_->SetHistSynchronizer(new BatchHistSynchronizer<float>());
@@ -527,7 +436,6 @@ class QuantileHistMock : public QuantileHistMaker {
           new BuilderMock(
              param_,
              std::move(pruner_),
-              int_constraint_,
              dmat_.get()));
      if (batch) {
        double_builder_->SetHistSynchronizer(new BatchHistSynchronizer<double>());
@@ -622,23 +530,13 @@ class QuantileHistMock : public QuantileHistMaker {
     }
   }
 
-  void TestEvaluateSplit() {
-    RegTree tree = RegTree();
-    tree.param.UpdateAllowUnknown(cfg_);
-    if (double_builder_) {
-      double_builder_->TestEvaluateSplit(tree);
-    } else {
-      float_builder_->TestEvaluateSplit(tree);
-    }
-  }
-
   void TestApplySplit() {
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
     if (double_builder_) {
       double_builder_->TestApplySplit(tree);
     } else {
-      float_builder_->TestEvaluateSplit(tree);
+      float_builder_->TestApplySplit(tree);
     }
   }
 };
@@ -716,19 +614,6 @@ TEST(QuantileHist, BuildHist) {
   maker_float.TestBuildHist();
 }
 
-TEST(QuantileHist, EvalSplits) {
-  std::vector<std::pair<std::string, std::string>> cfg
-      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
-       {"split_evaluator", "elastic_net"},
-       {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
-       {"min_child_weight", "0"}};
-  QuantileHistMock maker(cfg);
-  maker.TestEvaluateSplit();
-  const bool single_precision_histogram = true;
-  QuantileHistMock maker_float(cfg, single_precision_histogram);
-  maker_float.TestEvaluateSplit();
-}
-
 TEST(QuantileHist, ApplySplit) {
   std::vector<std::pair<std::string, std::string>> cfg
       {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},