In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore")

import imodels

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from tqdm import tqdm

import shap

from datasets import DATASETS_CLASSIFICATION

In [None]:
def _positive_class_shap(shap_values):
    """Return the positive-class (class index 1) SHAP matrix from a TreeExplainer result.

    Older shap releases return a list holding one (n_samples, n_features) array
    per class. Newer releases (0.45+) return a single
    (n_samples, n_features, n_classes) array. On that layout, indexing with
    ``[1]`` would silently pick the second *sample*, so both layouts are handled.
    A 2-D array (single-output explanation) is returned unchanged.
    """
    if isinstance(shap_values, list):
        return shap_values[1]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        return shap_values[:, :, 1]
    return shap_values


def _plot_shap_summary(model, X_test, cols, figure_name, title, save_path=None):
    """Draw a SHAP summary plot of ``model`` on ``X_test`` into the figure ``figure_name``.

    If ``save_path`` is given, the figure is also written there. Any missing
    parent directory is created first.
    """
    import os

    plt.figure(figure_name)
    plt.clf()  # reuse the named figure cleanly if the cell is re-run
    shap_values = _positive_class_shap(shap.TreeExplainer(model).shap_values(X_test))
    shap.summary_plot(shap_values, X_test, feature_names=cols, max_display=10, sort=False, show=False)
    plt.title(title)
    plt.xlabel("SHAP value")
    if save_path is not None:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight", facecolor="white", edgecolor="auto")


def shap_clusters(dataset_name, database_name, save=False, random_state=None,
                  save_dir="../graphs/claim_4/SHAP_interpretation/"):
    """Compare SHAP summary plots of a random forest and its hierarchically-shrunk version.

    Loads the dataset via ``imodels.util.data_util.get_clean_dataset`` and holds
    out 50 test samples. It fits a 50-tree ``RandomForestClassifier`` on 2/3 of
    the remaining training rows. It then fits ``imodels.HSTreeClassifierCV`` on a
    copy of that forest and draws one positive-class SHAP summary plot for each
    model on the held-out samples.

    Args:
        dataset_name: dataset identifier passed to ``get_clean_dataset``.
        database_name: source identifier passed to ``get_clean_dataset``.
        save: if True, write both figures to ``save_dir``.
        random_state: seed (or ``np.random.RandomState``) for the data splits and
            the forest. ``None`` keeps the original non-deterministic behaviour.
        save_dir: output directory for saved figures (created if missing).
    """
    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)
    cols = np.array(cols)

    # Hold out exactly 50 samples for explanation; the test labels are not needed.
    X_train, X_test, y_train, _ = train_test_split(X, y, test_size=50, random_state=random_state)
    # Fit on a random 2/3 of the training pool (the remaining 1/3 is discarded).
    idx_all = list(range(X_train.shape[0]))
    idx_train, _ = train_test_split(idx_all, test_size=0.33, random_state=random_state)
    X_t, y_t = X_train[idx_train], y_train[idx_train]

    RF = RandomForestClassifier(n_estimators=50, random_state=random_state)
    RF.fit(X_t, y_t)
    # NOTE(review): if HSTreeClassifierCV refits the copied forest, a fixed
    # random_state makes its trees match RF's. Confirm against the imodels source.
    hsRF = imodels.HSTreeClassifierCV(deepcopy(RF), reg_param_list=[50, 100, 200, 300, 400, 500])
    hsRF.fit(X_t, y_t)
    print(f"Optimal lambda: {hsRF.reg_param}")

    for model, suffix in ((RF, "RF"), (hsRF.estimator_, "hsRF")):
        save_path = os.path.join(save_dir, dataset_name + "_" + suffix) if save else None
        _plot_shap_summary(model, X_test, cols, dataset_name + "_" + suffix, suffix, save_path)

In [None]:
# Generate the RF vs. hsRF SHAP summary plots for every configured classification dataset.
for dataset_name, database_name in DATASETS_CLASSIFICATION.values():
    shap_clusters(dataset_name=dataset_name, database_name=database_name, save=False)