Skip to content

Commit

Permalink
align tree cfl/reg APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Sep 27, 2023
1 parent 8dfc415 commit 00d8087
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] classLabel Class label to be predicted
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover)
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const int defaultLeft,
const double cover)
{
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId);
Expand All @@ -123,6 +126,7 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] proba Array with probability values for each class
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover)
Expand All @@ -140,13 +144,15 @@ class DAAL_EXPORT ModelBuilder
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue,
const double cover)
const int defaultLeft, const double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, cover, resId);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
Expand Down Expand Up @@ -190,7 +196,7 @@ class DAAL_EXPORT ModelBuilder
services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba,
const double cover, NodeId & res);
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, const double cover, const int defaultLeft, NodeId & res);
const double featureValue, const int defaultLeft, const double cover, NodeId & res);

private:
size_t _nClasses;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ class DAAL_EXPORT ModelBuilder
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] cover Cover of the node (sum_hess)
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover of the node (sum_hess)
* \return Node identifier
*/
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft = 0)
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, cover, resId, defaultLeft);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
Expand Down Expand Up @@ -160,8 +160,8 @@ class DAAL_EXPORT ModelBuilder
services::Status initialize(size_t nFeatures, size_t nIterations);
services::Status createTreeInternal(size_t nNodes, TreeId & resId);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, double cover,
int defaultLeft, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
double cover, NodeId & res);
services::Status convertModelInternal();
};
/** @} */
Expand Down
2 changes: 1 addition & 1 deletion cpp/daal/src/algorithms/dtrees/dtrees_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ void setProbabilities(const size_t treeId, const size_t nodeId, const size_t res
}

services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
size_t featureIndex, double featureValue, double cover, int defaultLeft, size_t & res)
size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res)
{
const size_t noParent = static_cast<size_t>(-1);
services::Status s;
Expand Down
2 changes: 1 addition & 1 deletion cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, doubl
void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover);

services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
size_t featureIndex, double featureValue, double cover, int defaultLeft, size_t & res);
size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res);

void setProbabilities(const size_t treeId, const size_t nodeId, const size_t response, const data_management::DataCollectionPtr probTbl,
const double * const prob);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, c
}

services::Status ModelBuilder::addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, const double cover, const int defaultLeft, NodeId & res)
const double featureValue, const int defaultLeft, const double cover, NodeId & res)
{
decision_forest::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<decision_forest::classification::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, defaultLeft, res);
featureValue, defaultLeft, cover, res);
}

} // namespace interface2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parent
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::classification::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, defaultLeft, res);
featureValue, defaultLeft, cover, res);
}

} // namespace interface1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentI
}

services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue,
double cover, int defaultLeft, NodeId & res)
int defaultLeft, double cover, NodeId & res)
{
gbt::regression::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::regression::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, defaultLeft, res);
featureValue, defaultLeft, cover, res);
}

} // namespace interface1
Expand Down

0 comments on commit 00d8087

Please sign in to comment.