Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SHAP calculation to GBT regression #2460

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6e86e4e
WIP: Add SHAP contributions and interactions
ahuber21 Jul 11, 2023
a36406e
no more segfaults
ahuber21 Sep 14, 2023
396230f
fix pred_interactions
ahuber21 Sep 18, 2023
a2ec071
add fast treeshap v1
ahuber21 Sep 19, 2023
b922070
Add combinationSum calculation for Fast TreeSHAP v2
ahuber21 Sep 21, 2023
d31a577
daal_calloc -> daal_malloc
ahuber21 Sep 25, 2023
b7389da
support shap contribution calculation with Fast TreeSHAP v2
ahuber21 Sep 26, 2023
0b45b88
Consistently add cover to daaal APIs, add output parameters to end of…
ahuber21 Sep 27, 2023
d396842
align tree cfl/reg APIs
ahuber21 Sep 27, 2023
63da212
restore .gitignore from master
ahuber21 Sep 27, 2023
9032573
cleanup for review
ahuber21 Sep 28, 2023
2cb96e1
add newline
ahuber21 Sep 28, 2023
b873df4
remove defaultLeft value that's not needed
ahuber21 Sep 28, 2023
7cf533f
Update model builder examples
ahuber21 Sep 28, 2023
55f3cb2
Add backwards-compatible model builder API & deprecate decls
ahuber21 Sep 28, 2023
c130bb0
fix: remove dead code
ahuber21 Oct 4, 2023
22b4fb9
fix: simplify number of nodes calculation
ahuber21 Oct 4, 2023
dbe0af3
chore: typos and code style
ahuber21 Oct 4, 2023
e635831
Fix bazel build
ahuber21 Oct 5, 2023
060042c
fix: remove dead member variable in GbtDecisionTree
ahuber21 Oct 6, 2023
13bb552
feat: add first unit tests for model builders
ahuber21 Oct 6, 2023
4438cb6
revert dal_module back to daal_module
ahuber21 Oct 6, 2023
65fcdc8
feat: execute dal unit tests in CI
ahuber21 Oct 6, 2023
f8f648c
reorganize how tests are executed
ahuber21 Oct 6, 2023
65e7806
add license
ahuber21 Oct 6, 2023
e418ec4
Fix new_ts: nodeIsLeaf/nodeIsDummyLeaf internal usage & classificatio…
ahuber21 Oct 9, 2023
bc04a68
Update TreeVisitor with node cover value
ahuber21 Oct 9, 2023
6473d69
remove deprecation version in comment
ahuber21 Oct 10, 2023
7bc51f7
remove skipping of XGBoost base_score tree
ahuber21 Oct 11, 2023
bc02093
feature: proper support for XGBoost's base_score value
ahuber21 Oct 12, 2023
30f4f28
Update code attributions / cite / license
ahuber21 Oct 13, 2023
1a3906f
typo
ahuber21 Oct 13, 2023
932dfa5
chore: remove resIncrement from GBT predict
ahuber21 Oct 16, 2023
0d229ed
Document functions and separate declarations and implementations
ahuber21 Oct 18, 2023
6c87c7c
review comments #1
ahuber21 Oct 18, 2023
a367244
review comments #2 - fix pImpl idiom
ahuber21 Oct 18, 2023
10a9984
refactor: replace boolean parameters with DAAL_UINT64 flag
ahuber21 Oct 19, 2023
3ef5a9c
fix: usage of bias/margin for LightGBM models
ahuber21 Oct 19, 2023
73326ad
review comments #2
ahuber21 Oct 19, 2023
3ba321b
fixup endless for loop
ahuber21 Oct 19, 2023
48167a5
use TArray, introduce TreeShapVersion enum
ahuber21 Oct 19, 2023
5e94bcf
use TArray where possible
ahuber21 Oct 19, 2023
6a7b7c9
fix: move data field to implementation class
ahuber21 Oct 20, 2023
d65f806
Update cpp/daal/include/algorithms/tree_utils/tree_utils.h
ahuber21 Oct 20, 2023
f0a42a9
add typedef to shorten statements
ahuber21 Oct 20, 2023
b4a5058
provide doxygen description of gbt classification funtions
Oct 20, 2023
137e1bb
fix some typos
ahuber21 Oct 20, 2023
2c7b59a
consistently use size_t for node indexing; unsigned -> uint32_t
ahuber21 Oct 23, 2023
5f19b8c
fix: don't include test in release
ahuber21 Oct 26, 2023
a0c2c7b
fix multiline comments
ahuber21 Oct 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .ci/pipeline/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ jobs:
--test_thread_mode=par
displayName: 'cpp-examples-thread-release-dynamic'

- script: |
bazel test //cpp/daal:tests
displayName: 'daal-tests-algorithms'

- script: |
bazel test //cpp/oneapi/dal:tests \
--config=host \
Expand Down
37 changes: 30 additions & 7 deletions cpp/daal/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
package(default_visibility = ["//visibility:public"])
load("@onedal//dev/bazel:dal.bzl",
"dal_test_suite",
"dal_collect_test_suites",
)
load("@onedal//dev/bazel:daal.bzl",
"daal_module",
"daal_static_lib",
Expand Down Expand Up @@ -28,7 +32,7 @@ daal_module(
deps = select({
"@config//:backend_ref": [ "@openblas//:openblas",
],
"//conditions:default": [ "@micromkl//:mkl_thr",
"//conditions:default": [ "@micromkl//:mkl_thr",
],
}),
)
Expand All @@ -54,7 +58,7 @@ daal_module(
"DAAL_HIDE_DEPRECATED",
],
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":public_includes",
"@openblas//:headers",
],
Expand Down Expand Up @@ -123,11 +127,11 @@ daal_module(
hdrs = glob(["src/sycl/**/*.h", "src/sycl/**/*.cl"]),
srcs = glob(["src/sycl/**/*.cpp"]),
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
],
"//conditions:default": [
"//conditions:default": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
"@micromkl_dpc//:headers",
Expand All @@ -146,13 +150,13 @@ daal_module(
"TBB_USE_ASSERT=0",
],
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
"@tbb//:tbbmalloc",
],
"//conditions:default": [
"//conditions:default": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
Expand Down Expand Up @@ -269,7 +273,7 @@ daal_dynamic_lib(
],
def_file = select({
"@config//:backend_ref": "src/threading/export_lnx32e.ref.def",
"//conditions:default": "src/threading/export_lnx32e.mkl.def",
"//conditions:default": "src/threading/export_lnx32e.mkl.def",
}),
)

Expand Down Expand Up @@ -316,3 +320,22 @@ filegroup(
":thread_static",
],
)

# Catch2-based unit test suite built from every .cpp file directly under test/.
dal_test_suite(
name = "unit_tests",
framework = "catch2",
srcs = glob([
"test/*.cpp",
]),
)

# Aggregate target "tests": includes the local :unit_tests suite.
# NOTE(review): presumably `modules` are paths relative to `root` whose own
# test suites get collected as well - confirm against dal_collect_test_suites.
dal_collect_test_suites(
name = "tests",
root = "@onedal//cpp/daal/src/algorithms",
modules = [
"dtrees/gbt/regression"
],
tests = [
":unit_tests",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -107,49 +107,79 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] classLabel Class label to be predicted
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover)
{
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, resId);
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* Create Leaf node and add it to certain tree without specifying a cover value.
* Forwards to the addLeafNode() overload that takes a cover, passing 0 as the cover.
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] classLabel Class label to be predicted
* \return Node identifier
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
{
return addLeafNode(treeId, parentId, position, classLabel, 0);
}

/**
* Create Leaf node and add it to certain tree
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] proba Array with probability values for each class
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover)
{
NodeId resId;
_status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, resId);
_status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* Create Leaf node and add it to certain tree without specifying a cover value.
* Forwards to the addLeafNodeByProba() overload that takes a cover, passing 0 as the cover.
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] proba Array with probability values for each class
* \return Node identifier
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
{
return addLeafNodeByProba(treeId, parentId, position, proba, 0);
}

/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue)
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue,
const int defaultLeft, const double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* Create Split node and add it to certain tree without specifying defaultLeft or cover.
* Forwards to the addSplitNode() overload that takes both, passing 0 for defaultLeft and 0 for cover.
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \return Node identifier
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue)
{
return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
}

void setNFeatures(size_t nFeatures)
{
if (!_model.get())
Expand Down Expand Up @@ -184,11 +214,12 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(const size_t nClasses, const size_t nTrees);
services::Status createTreeInternal(const size_t nNodes, TreeId & resId);
services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, NodeId & res);
services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel,
const double cover, NodeId & res);
services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba,
NodeId & res);
const double cover, NodeId & res);
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, NodeId & res);
const double featureValue, const int defaultLeft, const double cover, NodeId & res);

private:
size_t _nClasses;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public classifier::Model
*/
virtual size_t getNumberOfTrees() const = 0;

/**
* \brief Set the Prediction Bias term
*
* \param[in] value Global prediction bias
*/
virtual void setPredictionBias(double value) = 0;

/**
* \brief Get the Prediction Bias term
*
* \return Global prediction bias
*/
virtual double getPredictionBias() const = 0;

protected:
Model() : classifier::Model() {}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover)
{
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, response, resId);
_status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* Create Leaf node and add it to certain tree without specifying a cover value.
* Forwards to the addLeafNode() overload that takes a cover, passing 0 as the cover.
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
* \return Node identifier
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
{
return addLeafNode(treeId, parentId, position, response, 0);
}

/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
Expand All @@ -127,16 +136,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0)
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}

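/**
* Create Split node and add it to certain tree without specifying defaultLeft or cover.
* Forwards to the addSplitNode() overload that takes both, passing 0 for defaultLeft and 0 for cover.
* \DAAL_DEPRECATED
*/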
DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue)
ahuber21 marked this conversation as resolved.
Show resolved Hide resolved
{
return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
}

/**
* Get built model
* \return Model pointer
Expand All @@ -159,9 +177,9 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(size_t nFeatures, size_t nIterations, size_t nClasses);
services::Status createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res,
int defaultLeft);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, const double cover, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
const double cover, NodeId & res);
services::Status convertModelInternal();
size_t _nClasses;
size_t _nIterations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ enum Method
defaultDense = 0 /*!< Default method */
};

/**
* <a name="DAAL-ENUM-ALGORITHMS__GBT__CLASSIFICATION__PREDICTION__RESULTTOCOMPUTEID"></a>
* Available identifiers to specify the result to compute - results are mutually exclusive.
* Each identifier occupies its own bit.
* NOTE(review): the bit-flag encoding would allow OR-ing identifiers together, while this
* description calls them mutually exclusive - confirm that combined values are rejected.
*/
enum ResultToComputeId
{
predictionResult = (1 << 0), /*!< Compute the regular prediction */
shapContributions = (1 << 1), /*!< Compute SHAP contribution values */
shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */
};

/**
* \brief Contains version 2.0 of the Intel(R) oneAPI Data Analytics Library interface.
*/
Expand All @@ -70,9 +81,12 @@ namespace interface2
/* [Parameter source code] */
struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter
{
Parameter(size_t nClasses = 2) : daal::algorithms::classifier::Parameter(nClasses), nIterations(0) {}
Parameter(const Parameter & o) : daal::algorithms::classifier::Parameter(o), nIterations(o.nIterations) {}
size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
typedef daal::algorithms::classifier::Parameter super;

Parameter(size_t nClasses = 2) : super(nClasses), nIterations(0), resultsToCompute(predictionResult) {}
Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {}
size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */
};
/* [Parameter source code] */
} // namespace interface2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public algorithms::regression::Model
*/
virtual size_t getNumberOfTrees() const = 0;

/**
* \brief Set the Prediction Bias term
*
* \param[in] value Global prediction bias
*/
virtual void setPredictionBias(double value) = 0;

/**
* \brief Get the Prediction Bias term
*
* \return Global prediction bias
*/
virtual double getPredictionBias() const = 0;

protected:
Model();
};
Expand Down
Loading
Loading