Skip to content

Commit

Permalink
Replace the branched UpdatePosition template calls with a single dispatched call (UpdatePositionDispatched)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShvetsKS committed May 15, 2022
1 parent 912105c commit 3b4ead6
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 123 deletions.
19 changes: 0 additions & 19 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,13 @@ void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, siz
}
}

// Zero out the histogram bins in the half-open range [begin, end).
//
// @param dest_hist pointer to the start of the (flat) histogram buffer
// @param begin     first bin index to clear
// @param end       one-past-last bin index to clear
void ClearHist(double* dest_hist,
               size_t begin, size_t end) {
  for (size_t bin_id = begin; bin_id < end; ++bin_id) {
    dest_hist[bin_id] = 0.0;
  }
}

// template<typename GradientSumT>
void ReduceHist(double* dest_hist,
const std::vector<std::vector<uint16_t>>& local_threads_mapping,
std::vector<std::vector<std::vector<double>>>* histograms,
Expand Down Expand Up @@ -116,18 +110,5 @@ void ReduceHist(double* dest_hist,
}
}

// template void ReduceHist(float* dest_hist,
// const std::vector<std::vector<uint16_t>>& local_threads_mapping,
// std::vector<std::vector<std::vector<float>>>* histograms,
// const size_t node_displace,
// const std::vector<uint16_t>& threads_id_for_node,
// size_t begin, size_t end);
// template void ReduceHist(double* dest_hist,
// const std::vector<std::vector<uint16_t>>& local_threads_mapping,
// std::vector<std::vector<std::vector<double>>>* histograms,
// const size_t node_displace,
// const std::vector<uint16_t>& threads_id_for_node,
// size_t begin, size_t end);

} // namespace common
} // namespace xgboost
5 changes: 1 addition & 4 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,13 @@ class HistCollection {
std::vector<size_t> row_ptr_;
};


/*!
 * \brief Accumulate per-thread partial histograms for one tree node into
 *        a single destination histogram.
 *
 * \param dest_hist             destination histogram buffer (flat array of bins)
 * \param local_threads_mapping per-thread mapping used to locate each thread's
 *                              partial histogram slot in *histograms
 * \param histograms            per-thread partial histogram buffers to reduce
 * \param node_id               id of the tree node whose histogram is reduced
 * \param threads_id_for_node   ids of the threads that built partials for this node
 * \param begin                 first bin index to reduce
 * \param end                   one-past-last bin index to reduce
 *
 * NOTE(review): parameter semantics above are inferred from the names and
 * signature — the definition lives in hist_util.cc; confirm against it.
 */
void ReduceHist(double* dest_hist,
                const std::vector<std::vector<uint16_t>>& local_threads_mapping,
                std::vector<std::vector<std::vector<double>>>* histograms,
                const size_t node_id,
                const std::vector<uint16_t>& threads_id_for_node,
                size_t begin, size_t end);

/*! \brief Zero the histogram bins in the half-open range [begin, end). */
void ClearHist(double* dest_hist,
               size_t begin, size_t end);
/*!
Expand All @@ -468,7 +466,6 @@ void ClearHist(double* dest_hist,
*/
class ParallelGHistBuilder {
public:
// using GHistRowT = GHistRow<GradientSumT>;
std::vector<std::vector<std::vector<double>>> histograms_buffer;
std::vector<std::vector<uint16_t>> local_threads_mapping;

Expand Down
3 changes: 0 additions & 3 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ class GloablApproxBuilder {
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);

size_t n_nodes = p_last_tree_->GetNodes().size();

auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
Expand Down Expand Up @@ -275,7 +273,6 @@ class GloablApproxBuilder {
evaluator_.ApplyTreeSplit(candidate, p_tree);
applied[candidate.nid] = candidate;
applied_vec.push_back(candidate);
is_applied = true;
CHECK_EQ(applied[candidate.nid].nid, candidate.nid);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
Expand Down
132 changes: 35 additions & 97 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,53 +164,20 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
const bool is_loss_guided = static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy)
!= TrainParam::kDepthWise;
std::vector<uint16_t> complete_node_ids;
if (param_.max_depth == 0) {
size_t max_nid = 0;
int max_nid_child = 0;
size_t it = 0;
for (auto const& entry : expand) {
max_nid = std::max(max_nid, static_cast<size_t>(2*entry.nid + 2));
if (entry.IsValid(param_, *num_leaves)) {
nodes_for_apply_split->push_back(entry);
evaluator_->ApplyTreeSplit(entry, p_tree);
++(*num_leaves);
++it;
max_nid_child = std::max(max_nid_child,
static_cast<int>(std::max((*p_tree)[entry.nid].LeftChild(),
(*p_tree)[entry.nid].RightChild())));
}
}
(*num_leaves) -= it;
for (auto const& entry : expand) {
if (entry.IsValid(param_, *num_leaves)) {
(*num_leaves)++;
complete_node_ids.push_back((*p_tree)[entry.nid].LeftChild());
complete_node_ids.push_back((*p_tree)[entry.nid].RightChild());
*is_left_small = entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess();
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess() || is_loss_guided) {
smalest_nodes_mask[(*p_tree)[entry.nid].LeftChild()] = true;
} else {
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = true;
}
}
}

} else {
for (auto const& entry : expand) {
if (entry.IsValid(param_, *num_leaves)) {
nodes_for_apply_split->push_back(entry);
evaluator_->ApplyTreeSplit(entry, p_tree);
(*num_leaves)++;
complete_node_ids.push_back((*p_tree)[entry.nid].LeftChild());
complete_node_ids.push_back((*p_tree)[entry.nid].RightChild());
*is_left_small = entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess();
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess() || is_loss_guided) {
smalest_nodes_mask[(*p_tree)[entry.nid].LeftChild()] = true;
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = false;
} else {
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = true;
smalest_nodes_mask[ (*p_tree)[entry.nid].LeftChild()] = false;
}
for (auto const& entry : expand) {
if (entry.IsValid(param_, *num_leaves)) {
nodes_for_apply_split->push_back(entry);
evaluator_->ApplyTreeSplit(entry, p_tree);
(*num_leaves)++;
complete_node_ids.push_back((*p_tree)[entry.nid].LeftChild());
complete_node_ids.push_back((*p_tree)[entry.nid].RightChild());
*is_left_small = entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess();
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess() || is_loss_guided) {
smalest_nodes_mask[(*p_tree)[entry.nid].LeftChild()] = true;
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = false;
} else {
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = true;
smalest_nodes_mask[ (*p_tree)[entry.nid].LeftChild()] = false;
}
}
}
Expand Down Expand Up @@ -245,15 +212,13 @@ void QuantileHistMaker::Builder::SplitSiblings(
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
nodes_to_evaluate->push_back(left_node);
nodes_to_evaluate->push_back(right_node);
bool is_loss_guide = static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy) ==
TrainParam::kDepthWise ? false : true;
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
}
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
}
}
monitor_->Stop("SplitSiblings");
}
Expand Down Expand Up @@ -308,47 +273,20 @@ void QuantileHistMaker::Builder::ExpandTree(
size_t page_id{0};
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
CommonRowPartitioner &partitioner = this->partitioner_.at(page_id);
if (is_loss_guide) {
if (page.cut.HasCategorical()) {
partitioner.UpdatePosition<any_missing, BinIdxType, true, true>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
} else {
partitioner.UpdatePosition<any_missing, BinIdxType, true, false>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
}
} else {
if (page.cut.HasCategorical()) {
partitioner.UpdatePosition<any_missing, BinIdxType, false, true>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
} else {
partitioner.UpdatePosition<any_missing, BinIdxType, false, false>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
}
}
partitioner.UpdatePositionDispatched({any_missing,
static_cast<common::BinTypeSize>(sizeof(BinIdxType)),
is_loss_guide, page.cut.HasCategorical()},
this->ctx_,
page,
nodes_for_apply_split,
p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small,
true);
++page_id;
}

Expand Down

0 comments on commit 3b4ead6

Please sign in to comment.