In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules
# (except those excluded with %aimport) before each cell executes, so edits
# to local modules such as `datasets` take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


import imodels

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from tqdm import tqdm

import shap

from datasets import DATASETS_CLASSIFICATION

In [None]:
def _positive_class_shap(shap_values):
    """Extract the SHAP matrix for class index 1 from a TreeExplainer result.

    Handles both return conventions of ``TreeExplainer.shap_values`` for
    classifiers: the legacy list with one ``(n_samples, n_features)`` array
    per class, and the newer single ``(n_samples, n_features, n_outputs)``
    array. A plain 2-D result (single-output model) is returned unchanged.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1])
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        return shap_values[..., 1]
    return shap_values


def feature_variability(dataset_name, database_name, N=100, save=False, random_state=None):
    """Plot the SHAP variability of a random forest with and without hierarchical shrinkage.

    The dataset is loaded through ``imodels.util.data_util.get_clean_dataset``
    and 50 samples are held out. For each of ``N`` rounds, a random forest and
    an ``imodels.HSTreeClassifierCV`` wrapping a copy of it are fitted on a
    random 67% subsample of the remaining data. The per-feature standard
    deviation of the class-1 SHAP values over the held-out samples is
    accumulated and finally averaged over the rounds. The result is drawn as a
    grouped bar chart for the first ``min(n_features, 10)`` features, in
    dataset column order.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier forwarded to ``get_clean_dataset``. Also used as
        the matplotlib figure name and as the output file name.
    database_name : str
        Data source identifier forwarded to ``get_clean_dataset``.
    N : int, default=100
        Number of subsample/refit rounds to average over.
    save : bool, default=False
        If true, write the figure to ``../figures/SHAP_variability/<dataset_name>``.
        The directory is created when missing.
    random_state : int or None, default=None
        Seed for the train/test split, the per-round subsampling and the
        random forests. ``None`` keeps the unseeded behaviour (global NumPy
        RNG).

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be at least 1, got {}".format(N))

    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)
    cols = np.array(cols)

    # One shared RandomState makes the whole sequence of draws reproducible
    # while still giving every round a different subsample and forest.
    rng = None if random_state is None else np.random.RandomState(random_state)

    X_train, X_test, y_train, _ = train_test_split(X, y, test_size=50, random_state=rng)

    variance_RF = np.zeros(len(cols))
    variance_hsRF = np.zeros(len(cols))
    idx_all = list(range(X_train.shape[0]))

    for _ in tqdm(range(N)):
        idx_train, _unused = train_test_split(idx_all, test_size=0.33, random_state=rng)
        X_t, y_t = X_train[idx_train], y_train[idx_train]

        RF = RandomForestClassifier(n_estimators=50, random_state=rng)
        RF.fit(X_t, y_t)
        # Deep copy so the shrinkage wrapper works on its own forest and
        # leaves the baseline RF untouched.
        hsRF = imodels.HSTreeClassifierCV(deepcopy(RF))
        hsRF.fit(X_t, y_t)

        # NOTE(review): the std is taken across held-out samples within each
        # round, then averaged over rounds -- confirm this is the intended
        # definition of "variability".
        shap_values_RF = _positive_class_shap(shap.TreeExplainer(RF).shap_values(X_test))
        variance_RF += shap_values_RF.std(axis=0)
        shap_values_hsRF = _positive_class_shap(
            shap.TreeExplainer(hsRF.estimator_).shap_values(X_test)
        )
        variance_hsRF += shap_values_hsRF.std(axis=0)

    variance_RF /= N
    variance_hsRF /= N

    # Only the first (at most) ten features are shown to keep the chart readable.
    n_take = min(len(cols), 10)
    xaxis = np.arange(n_take)

    # Reuse (and clear) the figure named after the dataset so re-running the
    # function does not stack bars on top of an earlier plot.
    fig = plt.figure(dataset_name)
    fig.clf()
    ax = fig.add_subplot(111)
    ax.bar(xaxis - 0.2, variance_RF[:n_take], width=0.4, color="firebrick", label="RF")
    ax.bar(xaxis + 0.2, variance_hsRF[:n_take], width=0.4, color="black", label="hsRF")
    ax.set_xticks(xaxis)
    ax.set_xticklabels(cols[:n_take], rotation=90)
    ax.legend()
    ax.set_ylabel("SHAP Variability")

    if save:
        import os

        out_path = "../figures/SHAP_variability/"+dataset_name
        # Create the target directory up front; otherwise savefig fails only
        # after all the expensive refits have already been done.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", facecolor="white", edgecolor="auto")

In [None]:
# Produce and save the SHAP-variability figure for every classification
# dataset registered in DATASETS_CLASSIFICATION.
for dataset_name, database_name in DATASETS_CLASSIFICATION.values():
    feature_variability(
        dataset_name=dataset_name,
        database_name=database_name,
        save=True,
    )