diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0454bc4e7278..dad92e630425 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -223,8 +223,8 @@ __device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree, while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); bool is_missing = common::CheckNAN(fvalue); - nidx = GetNextNode(tree.d_tree, nidx, fvalue, is_missing, - tree.cats); + nidx = GetNextNode(tree.d_tree, nidx, fvalue, + is_missing, tree.cats); n = tree.d_tree[nidx]; } return nidx; diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index 56cb1ea7fd72..76851f160b24 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -8,7 +8,7 @@ namespace xgboost { namespace predictor { -template +template inline XGBOOST_DEVICE bst_node_t GetNextNode( common::Span tree, bst_node_t nid, float fvalue, bool is_missing, RegTree::CategoricalSplitMatrix const& cats) { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 9fbfcee87673..6719574b1a06 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1060,8 +1060,9 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, while (!(*this)[nid].IsLeaf()) { split_index = (*this)[nid].SplitIndex(); - nid = predictor::GetNextNode(nodes, nid, feat.GetFvalue(split_index), - feat.IsMissing(split_index), cats); + nid = predictor::GetNextNode(nodes, nid, + feat.GetFvalue(split_index), + feat.IsMissing(split_index), cats); bst_float new_value = this->node_mean_values_[nid]; // update feature weight out_contribs[split_index] += new_value - node_value; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 29fb998d279b..cb7daf4e943e 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -129,9 +129,9 @@ class TreeRefresher: public TreeUpdater { // traverse tree while (!tree[pid].IsLeaf()) { unsigned split_index = tree[pid].SplitIndex(); - pid = - predictor::GetNextNode(nodes, pid, feat.GetFvalue(split_index), - feat.IsMissing(split_index), cats); + pid = predictor::GetNextNode( + nodes, pid, feat.GetFvalue(split_index), feat.IsMissing(split_index), + cats); gstats[pid].Add(gpair[ridx]); } }