# FastTreeSHAP in Census Income Data

This notebook demonstrates the usage of FastTreeSHAP v1, FastTreeSHAP v2 and the original TreeSHAP, and compares them in detail, on classification problems using sklearn and xgboost. It also discusses automatic algorithm selection. The census income data come from https://archive.ics.uci.edu/ml/datasets/census+income.

## Load Python libraries

In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score
import xgboost as xgb
import fasttreeshap
import time

## Pre-process training and testing data

In [None]:
# source of data: https://archive.ics.uci.edu/ml/datasets/census+income
# the separator is a regex, so it is written as a raw string: in a plain string "\s" is an invalid
# escape sequence (SyntaxWarning since Python 3.12)
train = pd.read_csv("../data/adult_data.txt", sep = r",\s+", header = None, engine = "python")
# skiprows = 1 skips the first line of the testing file (presumably not a data row — verify against the file)
test = pd.read_csv("../data/adult_test.txt", sep = r",\s+", header = None, skiprows = 1, engine = "python")
# column 14 holds the income label; labels in the testing file carry a trailing "."
label_train = train[14].map({"<=50K": 0, ">50K": 1}).tolist()
label_test = test[14].map({"<=50K.": 0, ">50K.": 1}).tolist()
# keep the first 13 columns as features (drops column 13 and the label column 14)
train = train.iloc[:, :-2]
test = test.iloc[:, :-2]

# one-hot-encoding on categorical features
feature_names = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation", 
                 "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week"]
train.columns = feature_names
test.columns = feature_names
categorical_feature_names = ["workclass", "education", "marital-status", "occupation", "relationship", "race", "sex"]
def dummy_transform(df):
    """Return a copy of `df` with every categorical column replaced by its one-hot dummy columns.

    Dummy columns are named after the category values and appended at the end of the frame in the
    order of `categorical_feature_names`; the "?" category, when present, gets no dummy column.
    The input frame is not modified.
    """
    for name in categorical_feature_names:
        dummy_df = pd.get_dummies(df[name]).drop(columns = "?", errors = "ignore")
        df = pd.concat([df.drop(columns = name), dummy_df], axis = 1)
    return df
train = dummy_transform(train)
# align the testing columns with the training columns: a category observed in only one split would
# otherwise yield mismatched feature columns; dummies absent from the testing set are filled with 0
test = dummy_transform(test).reindex(columns = train.columns, fill_value = 0)
print("Training data has {} rows and {} columns.".format(train.shape[0], train.shape[1])) 
print("Testing data has {} rows and {} columns.".format(test.shape[0], test.shape[1])) 

## Train a random forest model and compute SHAP values

In [None]:
# Random forest hyperparameters: number of trees and maximum depth of any tree.
n_estimators, max_depth = 100, 8

In [None]:
# Fit a random forest on the training set, then report AUC and accuracy on the testing set.
rf_model = RandomForestClassifier(random_state=0, max_depth=max_depth, n_estimators=n_estimators)
rf_model.fit(train, label_train)
print("AUC on testing set is {:.2f}.".format(
    roc_auc_score(label_test, rf_model.predict_proba(test)[:, 1])))
print("Accuracy on testing set is {:.2f}.".format(
    accuracy_score(label_test, rf_model.predict(test))))

In [None]:
# Count leaves over all trees: total node count minus the nodes whose left child index is positive.
shap_explainer = fasttreeshap.TreeExplainer(rf_model)
num_leaves = np.sum(shap_explainer.model.num_nodes) - np.sum(shap_explainer.model.children_left > 0)
print("Total number of leaves is {}.".format(num_leaves))

In [None]:
# FastTreeSHAP v2 is more memory-hungry than TreeSHAP and FastTreeSHAP v1, so estimate its footprint:
# (maximum number of leaves in a tree) x 2^max_depth entries of 8 bytes each.
max_node = max(shap_explainer.model.num_nodes)
max_leaves = (max_node + 1) // 2
max_combinations = 2 ** max_depth
memory = max_leaves * max_combinations * 8
# report the estimate in the largest unit (B / KB / MB / GB) that keeps the number below 1024
print(
    "Memory usage of FastTreeSHAP v2 is around {:.2f}B.".format(memory) if memory < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}KB.".format(memory / 1024) if memory / 1024 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}MB.".format(memory / 1024**2) if memory / 1024**2 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}GB.".format(memory / 1024**3)
)

### Compute SHAP values via different versions of TreeSHAP

In [None]:
num_sample = 1_000  # how many testing rows to explain

In [None]:
# Baseline SHAP values from FastTreeSHAP v0, i.e., the original TreeSHAP algorithm
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v0")
shap_values_v0 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v0.shape

In [None]:
# SHAP values from FastTreeSHAP v1 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v1")
shap_values_v1 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v1.shape

In [None]:
# FastTreeSHAP v1 is correct if its maximum absolute deviation from v0 is (numerically) zero
print("Maximum difference of SHAP values between v1 and v0 is {:.2e}.".format(
    np.abs(shap_values_v1 - shap_values_v0).max()))

In [None]:
# SHAP values from FastTreeSHAP v2 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v2")
shap_values_v2 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v2.shape

In [None]:
# FastTreeSHAP v2 is correct if its maximum absolute deviation from v0 is (numerically) zero
print("Maximum difference of SHAP values between v2 and v0 is {:.2e}.".format(
    np.abs(shap_values_v2 - shap_values_v0).max()))

In [None]:
# SHAP values with the TreeSHAP algorithm chosen automatically ("auto")
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="auto")
shap_values_auto = shap_explainer(test.iloc[:num_sample]).values
shap_values_auto.shape

In [None]:
# The automatically selected algorithm must agree with v0; here "auto" turns out to pick "v2"
print("Maximum difference of SHAP values between auto and v0 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v0).max()))

### Compare running times of different versions of TreeSHAP in computing SHAP values

In [None]:
# compute SHAP values/SHAP interaction values via TreeSHAP algorithm with version "algorithm_version"
def run_fasttreeshap(model, sample, interactions, algorithm_version, num_round, num_sample, shortcut = False):
    """Time repeated SHAP (interaction) value computations with one TreeSHAP algorithm.

    Parameters
    ----------
    model : tree model accepted by ``fasttreeshap.TreeExplainer``.
    sample : pandas.DataFrame whose first ``num_sample`` rows are explained.
    interactions : bool; compute SHAP interaction values instead of SHAP values.
    algorithm_version : str passed as ``algorithm`` to ``TreeExplainer`` (e.g. "v0", "v1", "v2", "auto").
    num_round : int; number of timed repetitions.
    num_sample : int; number of leading rows of ``sample`` to explain.
    shortcut : bool passed as ``shortcut`` to ``TreeExplainer``.

    Prints the running time of every round, then the mean and standard deviation over all rounds.
    """
    shap_explainer = fasttreeshap.TreeExplainer(model, algorithm = algorithm_version, shortcut = shortcut)
    # slice once, outside the timed region, so that only the SHAP computation is measured
    explained_rows = sample.iloc[:num_sample]
    run_time = np.zeros(num_round)
    for i in range(num_round):
        # perf_counter is a monotonic high-resolution clock, unlike the adjustable wall clock time.time()
        start = time.perf_counter()
        shap_explainer(explained_rows, interactions = interactions)  # result discarded: only timing matters
        run_time[i] = time.perf_counter() - start
        print("Round {} takes {:.3f} sec.".format(i + 1, run_time[i]))
    print("Average running time of {} is {:.3f} sec (std {:.3f} sec){}.".format(
        algorithm_version, np.mean(run_time), np.std(run_time), " (with shortcut)" if shortcut else ""))

In [None]:
# Benchmark setup: explain 1,000 testing rows; repeat each timing 5 times to get mean and std of running time.
num_sample = 1_000
num_round = 5

In [None]:
# Time the original TreeSHAP (FastTreeSHAP v0) over `num_round` rounds and report its average running time
run_fasttreeshap(model=rf_model, sample=test, interactions=False, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample)

In [None]:
# Time FastTreeSHAP v1 over `num_round` rounds and report its average running time
run_fasttreeshap(model=rf_model, sample=test, interactions=False, algorithm_version="v1",
                 num_round=num_round, num_sample=num_sample)

In [None]:
# Time FastTreeSHAP v2 over `num_round` rounds and report its average running time
run_fasttreeshap(model=rf_model, sample=test, interactions=False, algorithm_version="v2",
                 num_round=num_round, num_sample=num_sample)

In [None]:
# Time the automatically selected algorithm over `num_round` rounds; here "auto" turns out to pick "v2"
run_fasttreeshap(model=rf_model, sample=test, interactions=False, algorithm_version="auto",
                 num_round=num_round, num_sample=num_sample)

### Compute SHAP interaction values via different versions of TreeSHAP

In [None]:
num_sample = 10  # interaction values are computed for only a handful of testing rows

In [None]:
# Baseline SHAP interaction values from FastTreeSHAP v0, i.e., the original TreeSHAP algorithm
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v0")
shap_interaction_values_v0 = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_v0.shape

In [None]:
# SHAP interaction values from FastTreeSHAP v1 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v1")
shap_interaction_values_v1 = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_v1.shape

In [None]:
# FastTreeSHAP v1 interaction values are correct if they deviate from v0 only numerically
print("Maximum difference of SHAP interaction values between v1 and v0 is {:.2e}.".format(
    np.abs(shap_interaction_values_v1 - shap_interaction_values_v0).max()))

In [None]:
# SHAP interaction values with automatic algorithm selection: v1 always beats v0,
# and v2 does not support interactions
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="auto")
shap_interaction_values_auto = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_auto.shape

In [None]:
# The automatically selected algorithm must reproduce the v0 interaction values
print("Maximum difference of SHAP interaction values between auto and v0 is {:.2e}.".format(
    np.abs(shap_interaction_values_auto - shap_interaction_values_v0).max()))

### Compare running times of different versions of TreeSHAP in computing SHAP interaction values

In [None]:
# Interaction benchmark setup: explain 10 testing rows; repeat each timing 5 times.
num_sample, num_round = 10, 5

In [None]:
# Time SHAP interaction values from the original TreeSHAP (FastTreeSHAP v0) over `num_round` rounds
run_fasttreeshap(model=rf_model, sample=test, interactions=True, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample)

In [None]:
# Time SHAP interaction values from FastTreeSHAP v1 over `num_round` rounds
run_fasttreeshap(model=rf_model, sample=test, interactions=True, algorithm_version="v1",
                 num_round=num_round, num_sample=num_sample)

In [None]:
# Time SHAP interaction values from the automatically selected algorithm over `num_round` rounds;
# v1 always beats v0, and v2 does not support interactions
run_fasttreeshap(model=rf_model, sample=test, interactions=True, algorithm_version="auto",
                 num_round=num_round, num_sample=num_sample)

## Train an xgboost model and compute SHAP values

In [None]:
# xgboost hyperparameters: number of trees and maximum depth of any tree.
n_estimators, max_depth = 100, 8

In [None]:
# Fit an xgboost classifier with `n_estimators` trees of depth at most `max_depth`, then report test metrics.
xgb_model = xgb.XGBClassifier(
    max_depth=max_depth,
    learning_rate=0.1,
    n_estimators=n_estimators,
    n_jobs=4,
    subsample=1,
    colsample_bytree=1,
    colsample_bylevel=1,
    reg_alpha=0,
    reg_lambda=1,
    scale_pos_weight=1,
    random_state=0,
)
xgb_model.fit(train, label_train)
print("AUC on testing set is {:.2f}.".format(
    roc_auc_score(label_test, xgb_model.predict_proba(test)[:, 1])))
print("Accuracy on testing set is {:.2f}.".format(
    accuracy_score(label_test, xgb_model.predict(test))))

In [None]:
# Count leaves over all trees: total node count minus the nodes whose left child index is positive.
shap_explainer = fasttreeshap.TreeExplainer(xgb_model)
num_leaves = np.sum(shap_explainer.model.num_nodes) - np.sum(shap_explainer.model.children_left > 0)
print("Total number of leaves is {}.".format(num_leaves))

In [None]:
# FastTreeSHAP v2 is more memory-hungry than TreeSHAP and FastTreeSHAP v1, so estimate its footprint:
# (maximum number of leaves in a tree) x 2^max_depth entries of 8 bytes each.
max_node = max(shap_explainer.model.num_nodes)
max_leaves = (max_node + 1) // 2
max_combinations = 2 ** max_depth
memory = max_leaves * max_combinations * 8
# report the estimate in the largest unit (B / KB / MB / GB) that keeps the number below 1024
print(
    "Memory usage of FastTreeSHAP v2 is around {:.2f}B.".format(memory) if memory < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}KB.".format(memory / 1024) if memory / 1024 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}MB.".format(memory / 1024**2) if memory / 1024**2 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}GB.".format(memory / 1024**3)
)

### Compute SHAP values via different versions of TreeSHAP

In [None]:
num_sample = 1_000  # how many testing rows to explain

In [None]:
# Reference SHAP values from the "shortcut" path, i.e., the original TreeSHAP implemented inside the
# xgboost library; this path computes in parallel
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v0", shortcut=True)
shap_values_shortcut = shap_explainer(test.iloc[:num_sample]).values
shap_values_shortcut.shape

In [None]:
# SHAP values from FastTreeSHAP v0, i.e., the original TreeSHAP as implemented in the shap library
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v0", shortcut=False)
shap_values_v0 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v0.shape

In [None]:
# justify the correctness of FastTreeSHAP v0 by comparing its SHAP values with those of "shortcut"
print("Mean and maximum differences of SHAP values between v0 and shortcut are {:.2e} and {:.2e}.".format(
    np.mean(np.abs(shap_values_v0 - shap_values_shortcut)), np.max(np.abs(shap_values_v0 - shap_values_shortcut))))

In [None]:
# SHAP values from FastTreeSHAP v1 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v1", shortcut=False)
shap_values_v1 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v1.shape

In [None]:
# FastTreeSHAP v1 is correct if its maximum absolute deviation from v0 is (numerically) zero
print("Maximum difference of SHAP values between v1 and v0 is {:.2e}.".format(
    np.abs(shap_values_v1 - shap_values_v0).max()))

In [None]:
# SHAP values from FastTreeSHAP v2 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v2", shortcut=False)
shap_values_v2 = shap_explainer(test.iloc[:num_sample]).values
shap_values_v2.shape

In [None]:
# FastTreeSHAP v2 is correct if its maximum absolute deviation from v0 is (numerically) zero
print("Maximum difference of SHAP values between v2 and v0 is {:.2e}.".format(
    np.abs(shap_values_v2 - shap_values_v0).max()))

In [None]:
# SHAP values with the TreeSHAP algorithm chosen automatically ("auto")
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="auto", shortcut=False)
shap_values_auto = shap_explainer(test.iloc[:num_sample]).values
shap_values_auto.shape

In [None]:
# The automatically selected algorithm must agree with v0; here "auto" turns out to pick "v2"
print("Maximum difference of SHAP values between auto and v0 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v0).max()))

### Compare running times of different versions of TreeSHAP in computing SHAP values

In [None]:
# Benchmark setup: explain 1,000 testing rows; repeat each timing 5 times to get mean and std of running time.
num_sample, num_round = 1_000, 5

In [None]:
# Time the "shortcut" TreeSHAP (the original TreeSHAP inside the xgboost library) over `num_round` rounds.
# "shortcut" computes in parallel (parallel computing is still a work in progress in the FastTreeSHAP
# package), so any speedup of "shortcut" over the original TreeSHAP without shortcut is mainly due to
# parallel computing.
run_fasttreeshap(model=xgb_model, sample=test, interactions=False, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample, shortcut=True)

In [None]:
# Time the original TreeSHAP (FastTreeSHAP v0, no shortcut) over `num_round` rounds
run_fasttreeshap(model=xgb_model, sample=test, interactions=False, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

In [None]:
# Time FastTreeSHAP v1 over `num_round` rounds
run_fasttreeshap(model=xgb_model, sample=test, interactions=False, algorithm_version="v1",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

In [None]:
# Time FastTreeSHAP v2 over `num_round` rounds
run_fasttreeshap(model=xgb_model, sample=test, interactions=False, algorithm_version="v2",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

In [None]:
# Time the automatically selected algorithm over `num_round` rounds; here "auto" turns out to pick "v2"
run_fasttreeshap(model=xgb_model, sample=test, interactions=False, algorithm_version="auto",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

### Compute SHAP interaction values via different versions of TreeSHAP

In [None]:
num_sample = 10  # interaction values are computed for only a handful of testing rows

In [None]:
# Reference SHAP interaction values from the "shortcut" path, i.e., the original TreeSHAP implemented
# inside the xgboost library; this path computes in parallel
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v0", shortcut=True)
shap_interaction_values_shortcut = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_shortcut.shape

In [None]:
# SHAP interaction values from FastTreeSHAP v0, i.e., the original TreeSHAP in the shap library
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v0", shortcut=False)
shap_interaction_values_v0 = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_v0.shape

In [None]:
# justify the correctness of FastTreeSHAP v0 by comparing its SHAP interaction values with those of "shortcut"
print("Mean and maximum differences of SHAP interaction values between v0 and shortcut are {:.2e} and {:.2e}.".format(
    np.mean(np.abs(shap_interaction_values_v0 - shap_interaction_values_shortcut)), 
    np.max(np.abs(shap_interaction_values_v0 - shap_interaction_values_shortcut))))

In [None]:
# SHAP interaction values from FastTreeSHAP v1 for the same testing rows
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="v1", shortcut=False)
shap_interaction_values_v1 = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_v1.shape

In [None]:
# FastTreeSHAP v1 interaction values are correct if they deviate from v0 only numerically
print("Maximum difference of SHAP interaction values between v1 and v0 is {:.2e}.".format(
    np.abs(shap_interaction_values_v1 - shap_interaction_values_v0).max()))

In [None]:
# SHAP interaction values with automatic algorithm selection: v1 always beats v0,
# and v2 does not support interactions
shap_explainer = fasttreeshap.TreeExplainer(xgb_model, algorithm="auto", shortcut=False)
shap_interaction_values_auto = shap_explainer(test.iloc[:num_sample], interactions=True).values
shap_interaction_values_auto.shape

In [None]:
# The automatically selected algorithm must reproduce the v0 interaction values
print("Maximum difference of SHAP interaction values between auto and v0 is {:.2e}.".format(
    np.abs(shap_interaction_values_auto - shap_interaction_values_v0).max()))

### Compare running times of different versions of TreeSHAP in computing SHAP interaction values

In [None]:
# Interaction benchmark setup: explain 10 testing rows; repeat each timing 5 times.
num_sample, num_round = 10, 5

In [None]:
# Time SHAP interaction values from the "shortcut" TreeSHAP (inside the xgboost library) over `num_round` rounds.
# "shortcut" computes in parallel (parallel computing is still a work in progress in the FastTreeSHAP
# package), so any speedup of "shortcut" over the original TreeSHAP without shortcut is mainly due to
# parallel computing.
run_fasttreeshap(model=xgb_model, sample=test, interactions=True, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample, shortcut=True)

In [None]:
# Time SHAP interaction values from the original TreeSHAP (FastTreeSHAP v0, no shortcut) over `num_round` rounds
run_fasttreeshap(model=xgb_model, sample=test, interactions=True, algorithm_version="v0",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

In [None]:
# Time SHAP interaction values from FastTreeSHAP v1 over `num_round` rounds
run_fasttreeshap(model=xgb_model, sample=test, interactions=True, algorithm_version="v1",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

In [None]:
# Time SHAP interaction values from the automatically selected algorithm over `num_round` rounds;
# v1 always beats v0, and v2 does not support interactions
run_fasttreeshap(model=xgb_model, sample=test, interactions=True, algorithm_version="auto",
                 num_round=num_round, num_sample=num_sample, shortcut=False)

## Deep dive into automatic algorithm selection

The default value of the argument `algorithm` in the class `TreeExplainer` is `auto`, indicating that the TreeSHAP algorithm is automatically selected from `"v0"`, `"v1"` and `"v2"` according to the number of samples to be explained and the constraint on the allocated memory.

Specifically, `"v1"` is always preferred to `"v0"` in all use cases, and `"v2"` is preferred to `"v1"` when the number of samples to be explained is sufficiently large: <img src="https://latex.codecogs.com/svg.latex?M>2^{D+1}/D,"/> and the memory constraint is also satisfied: <img src="https://latex.codecogs.com/svg.latex?L2^D\cdot8Byte<0.25\cdot Total\,Memory."/> Here *M* is the number of samples to be explained, *D* is the maximum depth of any tree, and *L* is the maximum number of leaves in any tree. A more detailed discussion of the above criteria can be found in the [FastTreeSHAP](https://arxiv.org/abs/2109.09847) paper.

### Automatic algorithm selection in moderate models with varying number of samples to be explained

In moderate-sized models (i.e., when the memory constraint is not a concern), `"auto"` selects `"v2"` when the number of samples to be explained exceeds the threshold defined above, and selects `"v1"` otherwise.

In [None]:
# Moderate-sized random forest: number of trees and maximum depth of any tree.
n_estimators, max_depth = 100, 8

In [None]:
# Fit the moderate-sized random forest; the fitted estimator is the cell's displayed output.
rf_model = RandomForestClassifier(random_state=0, max_depth=max_depth, n_estimators=n_estimators)
rf_model.fit(train, label_train)

In [None]:
# Estimated FastTreeSHAP v2 footprint for the moderate model, expected to show memory is not a big concern:
# (maximum number of leaves in a tree) x 2^max_depth entries of 8 bytes each.
shap_explainer = fasttreeshap.TreeExplainer(rf_model)
max_node = max(shap_explainer.model.num_nodes)
max_leaves = (max_node + 1) // 2
max_combinations = 2 ** max_depth
memory = max_leaves * max_combinations * 8
# report the estimate in the largest unit (B / KB / MB / GB) that keeps the number below 1024
print(
    "Memory usage of FastTreeSHAP v2 is around {:.2f}B.".format(memory) if memory < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}KB.".format(memory / 1024) if memory / 1024 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}MB.".format(memory / 1024**2) if memory / 1024**2 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}GB.".format(memory / 1024**3)
)

When the number of samples to be explained is 100 (above the threshold $2^{D+1}/D = 2^9/8 = 64$ for $D = 8$), `"auto"` selects `"v2"` as the most appropriate TreeSHAP algorithm.

In [None]:
num_sample = 100  # how many testing rows to explain

In [None]:
# SHAP values on the first `num_sample` testing rows from each algorithm, starting with the original TreeSHAP (v0)
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v0")
shap_values_v0 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v1
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v1")
shap_values_v1 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v2
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v2")
shap_values_v2 = shap_explainer(test.iloc[:num_sample]).values

# ... and finally the automatically selected algorithm
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="auto")
shap_values_auto = shap_explainer(test.iloc[:num_sample]).values

In [None]:
# With 100 samples to explain, "auto" selects "v2", so its SHAP values should coincide with those of v2
print("Maximum difference of SHAP values between auto and v0 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v0).max()))
print("Maximum difference of SHAP values between auto and v1 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v1).max()))
print("Maximum difference of SHAP values between auto and v2 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v2).max()))

When the number of samples to be explained is 50 (below the threshold $2^{D+1}/D = 64$ for $D = 8$), `"auto"` selects `"v1"` as the most appropriate TreeSHAP algorithm.

In [None]:
num_sample = 50  # how many testing rows to explain

In [None]:
# SHAP values on the first `num_sample` testing rows from each algorithm, starting with the original TreeSHAP (v0)
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v0")
shap_values_v0 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v1
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v1")
shap_values_v1 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v2
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v2")
shap_values_v2 = shap_explainer(test.iloc[:num_sample]).values

# ... and finally the automatically selected algorithm
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="auto")
shap_values_auto = shap_explainer(test.iloc[:num_sample]).values

In [None]:
# With 50 samples to explain, "auto" selects "v1", so its SHAP values should coincide with those of v1
print("Maximum difference of SHAP values between auto and v0 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v0).max()))
print("Maximum difference of SHAP values between auto and v1 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v1).max()))
print("Maximum difference of SHAP values between auto and v2 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v2).max()))

### Automatic algorithm selection in very large models

In very large models, `"auto"` selects `"v1"` instead of `"v2"` when a potential memory risk is detected, regardless of how many samples are to be explained.

In [None]:
# Very large random forest: same number of trees, but much deeper trees (depth up to 20).
n_estimators, max_depth = 100, 20

In [None]:
# Fit the deep random forest; the fitted estimator is the cell's displayed output.
rf_model = RandomForestClassifier(random_state=0, max_depth=max_depth, n_estimators=n_estimators)
rf_model.fit(train, label_train)

In [None]:
# Estimated FastTreeSHAP v2 footprint for the deep model, expected to reveal a potential memory risk:
# (maximum number of leaves in a tree) x 2^max_depth entries of 8 bytes each.
shap_explainer = fasttreeshap.TreeExplainer(rf_model)
max_node = max(shap_explainer.model.num_nodes)
max_leaves = (max_node + 1) // 2
max_combinations = 2 ** max_depth
memory = max_leaves * max_combinations * 8
# report the estimate in the largest unit (B / KB / MB / GB) that keeps the number below 1024
print(
    "Memory usage of FastTreeSHAP v2 is around {:.2f}B.".format(memory) if memory < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}KB.".format(memory / 1024) if memory / 1024 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}MB.".format(memory / 1024**2) if memory / 1024**2 < 1024 else
    "Memory usage of FastTreeSHAP v2 is around {:.2f}GB.".format(memory / 1024**3)
)

In [None]:
num_sample = 10  # how many testing rows to explain

In [None]:
# SHAP values on the first `num_sample` testing rows from each algorithm, starting with the original TreeSHAP (v0)
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v0")
shap_values_v0 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v1
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v1")
shap_values_v1 = shap_explainer(test.iloc[:num_sample]).values

# ... then FastTreeSHAP v2 (requested explicitly despite the memory estimate above)
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="v2")
shap_values_v2 = shap_explainer(test.iloc[:num_sample]).values

# ... and finally the automatically selected algorithm
shap_explainer = fasttreeshap.TreeExplainer(rf_model, algorithm="auto")
shap_values_auto = shap_explainer(test.iloc[:num_sample]).values

In [None]:
# A requested "v2" is automatically switched to "v1" when a potential memory risk is detected,
# so v2 results should coincide with v1
print("Maximum difference of SHAP values between v2 and v0 is {:.2e}.".format(
    np.abs(shap_values_v2 - shap_values_v0).max()))
print("Maximum difference of SHAP values between v2 and v1 is {:.2e}.".format(
    np.abs(shap_values_v2 - shap_values_v1).max()))

In [None]:
# With a potential memory risk detected, "auto" selects "v1" as the most appropriate TreeSHAP algorithm
print("Maximum difference of SHAP values between auto and v0 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v0).max()))
print("Maximum difference of SHAP values between auto and v1 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v1).max()))
print("Maximum difference of SHAP values between auto and v2 is {:.2e}.".format(
    np.abs(shap_values_auto - shap_values_v2).max()))