## Pruning sklearn DecisionTreeClassifiers

This post serves two purposes:
1. It shows a quick and simple way to manually prune selected nodes from a tree.
2. It points out the problems this causes when computing SHAP values.

It seems that we need a more comprehensive post-pruning function for trees, an issue I have (unsuccessfully) raised here:

https://github.com/scikit-learn/scikit-learn/issues/18680#issuecomment-716163291

```python
import copy

import numpy as np
import shap
from sklearn import tree

# for plotting
import matplotlib.pyplot as plt

# Boston housing data as bundled with shap (newer shap releases may no longer ship it)
X, y = shap.datasets.boston()
```

```python
max_depth = 3
tree_B1 = tree.DecisionTreeRegressor(random_state=0, max_depth=max_depth)
tree_B1 = tree_B1.fit(X.values, y)
```
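The node IDs used below index scikit-learn's flat, parallel node arrays (`tree_.children_left`, `tree_.children_right`, ...). To see which ID sits where, the following sketch walks the arrays from the root and prints each node's depth. It uses a toy dataset instead of the Boston data (an assumption chosen so that the depth-3 tree is complete), and `node_depths` is a hypothetical helper, not a scikit-learn function.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy stand-in for the Boston data (an assumption made for illustration):
# y = x on 16 equally spaced points yields a complete depth-3 tree with 15 nodes.
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
model = DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)


def node_depths(tree_):
    """Return {node_id: depth} for every node reachable from the root (node 0)."""
    depths = {}
    stack = [(0, 0)]
    while stack:
        node, depth = stack.pop()
        depths[node] = depth
        left = int(tree_.children_left[node])
        right = int(tree_.children_right[node])
        if left != -1:  # -1 means "no child", i.e. the node is a leaf
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
    return depths


t = model.tree_
depths = node_depths(t)
for node in range(t.node_count):
    kind = "leaf" if t.children_left[node] == -1 else "split"
    print(f"node {node:2d}  depth {depths[node]}  {kind:5s}  "
          f"left {int(t.children_left[node]):3d}  right {int(t.children_right[node]):3d}")
```

With the default depth-first builder (no `max_leaf_nodes`), nodes are numbered in pre-order: a node, then its whole left subtree, then its right subtree. In a complete depth-3 tree the depth-2 nodes are therefore 2, 5, 9 and 12, with children (3, 4), (6, 7), (10, 11) and (13, 14).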


We now selectively prune three of the four parent nodes at depth 2, namely the two leftmost (nodes 2 and 5) and the rightmost one (node 12), by turning them into leaves:

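```python
tree_B2 = copy.deepcopy(tree_B1)
# prune the tree: turn nodes 2, 5 and 12 into leaves by cutting the links to their children
for i in [2, 5, 12]:
    tree_B2.tree_.children_left[i] = -1  # -1 is sklearn's "no child" (leaf) marker
    tree_B2.tree_.children_right[i] = -1
    # also set the sample counts of the new leaves to zero
    tree_B2.tree_.n_node_samples[i] = 0
    tree_B2.tree_.weighted_n_node_samples[i] = 0
```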
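The cell above can be wrapped into a small reusable helper. The sketch below is hypothetical (`prune_nodes` is my name, not a scikit-learn function) and runs on toy data rather than the Boston data. Besides cutting the child links, it sets `feature` and `threshold` of each new leaf to -2, the values scikit-learn stores for genuine leaves, because some utilities (for instance `export_text`, as far as I can tell from the scikit-learn source) decide whether a node is a split by looking at `feature` rather than at the child links. Unlike the cell above, it leaves the sample counts alone.

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Sentinels scikit-learn uses in its node arrays (the values of
# sklearn.tree._tree.TREE_LEAF and TREE_UNDEFINED), hard-coded to avoid a private import.
TREE_LEAF = -1
TREE_UNDEFINED = -2


def prune_nodes(model, node_ids):
    """Hypothetical helper: return a copy of `model` in which every node in
    `node_ids` is a leaf. Descendants remain in the arrays, but are orphaned."""
    pruned = copy.deepcopy(model)
    t = pruned.tree_
    for i in node_ids:
        t.children_left[i] = TREE_LEAF
        t.children_right[i] = TREE_LEAF
        # genuine leaves carry these values too; some utilities check them
        t.feature[i] = TREE_UNDEFINED
        t.threshold[i] = TREE_UNDEFINED
    return pruned


# Toy data (assumption): y = x on 16 points gives a complete depth-3 tree
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
model = DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)
pruned = prune_nodes(model, [2, 5, 12])

# predict() walks the child links from the root, so samples now stop at the
# new leaves and receive the stored mean of the former parent node
print(model.predict(X_toy))
print(pruned.predict(X_toy))
```

A parent's `n_node_samples` already equals the sum of its children's, so leaving the counts untouched means each new leaf reports exactly the number of training samples that now end there.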

We plot the original and the pruned tree side by side:
![Original versus pruned tree](figures/BothTrees.png)
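The plotting code for this figure is not included in the post. A minimal sketch with `sklearn.tree.plot_tree` could look as follows; it uses toy data (an assumption) instead of the Boston data. As far as I can tell from the scikit-learn source, `plot_tree` decides between leaf and split from the child links, which is why the cut-off nodes are not drawn.

```python
import copy

import matplotlib
matplotlib.use("Agg")  # render off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree

# Toy data (assumption) instead of the Boston data used in the post
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
tree_full = tree.DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)

# same link-cutting pruning as in the cell above
tree_pruned = copy.deepcopy(tree_full)
for i in [2, 5, 12]:
    tree_pruned.tree_.children_left[i] = -1
    tree_pruned.tree_.children_right[i] = -1

fig, (ax_full, ax_pruned) = plt.subplots(1, 2, figsize=(16, 5))
anns_full = tree.plot_tree(tree_full, ax=ax_full)        # returns the annotation artists
anns_pruned = tree.plot_tree(tree_pruned, ax=ax_pruned)
ax_full.set_title("Original tree")
ax_pruned.set_title("Pruned tree")
# fig.savefig("figures/BothTrees.png")
```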

The pruning procedure above did not actually remove the nodes from the tree data structure; it simply cut the links to them.
It is not clear to me which functions and/or modules access the nodes

* via tree traversal, in which case they would "see" the pruned tree, as opposed to
* directly as arrays, in which case the "shallow" pruning would have no effect.
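Both access patterns can be observed within scikit-learn itself by comparing what the node arrays report with what a walk from the root finds. The sketch below uses toy data and a hypothetical helper `reachable_leaves` (both assumptions). As far as I can tell from the scikit-learn source, `get_n_leaves()` simply counts array entries without children, so it includes the orphaned leaves, whereas `export_text` walks the tree from the root (the new leaves also get the leaf marker -2 in `feature` and `threshold`, which `export_text` checks).

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Toy data (assumption): complete depth-3 tree with nodes 0..14
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
model = DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)

pruned = copy.deepcopy(model)
t = pruned.tree_
for i in [2, 5, 12]:
    t.children_left[i] = -1
    t.children_right[i] = -1
    t.feature[i] = -2      # leaf marker checked by export_text
    t.threshold[i] = -2.0


def reachable_leaves(tree_):
    """Hypothetical helper: leaves found by following child links from the root."""
    leaves, stack = [], [0]
    while stack:
        node = int(stack.pop())
        if tree_.children_left[node] == -1:
            leaves.append(node)
        else:
            stack.extend([tree_.children_right[node], tree_.children_left[node]])
    return sorted(leaves)


# Array view: all 15 nodes are still stored, including orphans 3, 4, 6, 7, 13, 14
print("node_count:", t.node_count)               # 15
print("get_n_leaves():", pruned.get_n_leaves())  # 11: every childless array entry
# Traversal view: only what is reachable from the root
print("reachable leaves:", reachable_leaves(t))  # [2, 5, 10, 11, 12]
print(export_text(pruned))                       # shows only the 4 remaining splits
```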

In the following we investigate the effect of pruning on SHAP values.


### SHAP values

```python
explainer1 = shap.TreeExplainer(tree_B1)
Shap_train1 = explainer1.shap_values(X)

explainer2 = shap.TreeExplainer(tree_B2)
Shap_train2 = explainer2.shap_values(X)
```

Output:

```
Setting feature_perturbation = "tree_path_dependent" because no background data was given.

AssertionError: The background dataset you provided does not cover all the leaves in the model, so TreeExplainer cannot run with the feature_perturbation="tree_path_dependent" option! Try providing a larger background dataset, or using feature_perturbation="interventional".
```
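The message says that some leaf is not covered. One guess, which I have not checked against SHAP's source, is that it concerns the new leaves 2, 5 and 12, whose sample counts the pruning cell set to zero. The sketch below (toy data and the helper `zero_weight_leaves` are assumptions for illustration) repeats that pruning, lists the reachable leaves with zero weight, and then restores their counts from the unpruned tree. Since a parent's count is the sum of its children's, the restored counts match exactly the samples that now end in each new leaf. Whether this alone makes `TreeExplainer` accept the pruned tree is untested here.

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy data (assumption): complete depth-3 tree with nodes 0..14
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
model = DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)

# reproduce the pruning cell, including the zeroed sample counts
pruned = copy.deepcopy(model)
t = pruned.tree_
for i in [2, 5, 12]:
    t.children_left[i] = -1
    t.children_right[i] = -1
    t.n_node_samples[i] = 0
    t.weighted_n_node_samples[i] = 0


def zero_weight_leaves(tree_):
    """Hypothetical helper: reachable leaves whose weighted sample count is zero."""
    found, stack = [], [0]
    while stack:
        node = int(stack.pop())
        if tree_.children_left[node] == -1:
            if tree_.weighted_n_node_samples[node] == 0:
                found.append(node)
        else:
            stack.extend([tree_.children_right[node], tree_.children_left[node]])
    return sorted(found)


before = zero_weight_leaves(t)
print(before)  # [2, 5, 12]

# restore the counts from the unpruned tree (parent count = sum of its children's)
for i in [2, 5, 12]:
    t.n_node_samples[i] = model.tree_.n_node_samples[i]
    t.weighted_n_node_samples[i] = model.tree_.weighted_n_node_samples[i]
after = zero_weight_leaves(t)
print(after)   # []
```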

```python
shap.summary_plot(Shap_train1, X, plot_type="bar")
```

```python
shap.summary_plot(Shap_train2, X, plot_type="bar")
```

What if we deliberately change the values of the pruned nodes (and halve all sample counts)? Does that affect the SHAP values?

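```python
# The node IDs of the pruned (now orphaned) leaves are 3, 4, 6, 7, 13, 14
np.random.seed(123)
# overwrite their stored predictions with random values
tree_B2.tree_.value[[3, 4, 6, 7, 13, 14]] = np.reshape(np.random.normal(0, 20, 6), (6, 1, 1))
# halve the sample counts of all nodes
# (n_node_samples holds integers, so odd counts are truncated)
for i in range(tree_B2.tree_.node_count):
    tree_B2.tree_.n_node_samples[i] = 0.5 * tree_B2.tree_.n_node_samples[i]
    tree_B2.tree_.weighted_n_node_samples[i] = 0.5 * tree_B2.tree_.weighted_n_node_samples[i]
```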
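A useful baseline for this experiment: scikit-learn's own `predict()` follows the child links from the root, so overwriting the values of the orphaned leaves should leave its predictions untouched, and the sample counts do not enter `predict()` at all. If the SHAP values change anyway, SHAP is using something other than the function the tree computes, for example the orphaned values or the sample counts. A sketch on toy data (an assumption, not the Boston data):

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy data (assumption): complete depth-3 tree with nodes 0..14
X_toy = np.arange(16, dtype=float).reshape(-1, 1)
y_toy = np.arange(16, dtype=float)
model = DecisionTreeRegressor(random_state=0, max_depth=3).fit(X_toy, y_toy)

# prune nodes 2, 5 and 12 by cutting their child links
pruned = copy.deepcopy(model)
for i in [2, 5, 12]:
    pruned.tree_.children_left[i] = -1
    pruned.tree_.children_right[i] = -1
before = pruned.predict(X_toy)

# overwrite the orphaned leaves with random values, as in the cell above
rng = np.random.default_rng(123)
orphans = [3, 4, 6, 7, 13, 14]
pruned.tree_.value[orphans] = rng.normal(0, 20, 6).reshape(6, 1, 1)
after = pruned.predict(X_toy)

# the orphaned nodes are never visited during prediction
print(np.array_equal(before, after))
```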


```python
explainer2b = shap.TreeExplainer(tree_B2)
Shap_train2b = explainer2b.shap_values(X)
shap.summary_plot(Shap_train2b, X, plot_type="bar")
```