In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
import shap
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import lime
import lime.lime_tabular

In [2]:
# load data
df = pd.read_csv('data/BankChurners.csv')#.sample(500)
df.drop(df.columns[0], axis=1, inplace=True)
df.drop(df.columns[-2:], axis=1, inplace=True)

df.loc[df['Attrition_Flag'] == "Existing Customer",["Attrition_Flag"]] = 0
df.loc[df['Attrition_Flag'] == "Attrited Customer",["Attrition_Flag"]] = 1

df[["Attrition_Flag"]] = df[["Attrition_Flag"]].astype(int)

In [3]:
np.random.seed(42)

#find categorical variables
categorical = [var for var in df.columns if df[var].dtype=='O']

encoded = pd.get_dummies(df[categorical], prefix=categorical)
df_enc = pd.concat([encoded, df], axis=1)
df_enc.drop(categorical, axis=1, inplace=True)

X = df_enc.drop(["Attrition_Flag"], axis=1)
y = df_enc["Attrition_Flag"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

## Model Training

In [4]:
def train_and_store_model(model, X_train, y_train, X_test):
    model.fit(X_train.values, y_train)
    y_pred = model.predict(X_test)
    return {'model': model, 'predictions': y_pred}

models = {
    'RandomForest': RandomForestClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting': GradientBoostingClassifier(),
}

models_dict = {}

for model_name, model in models.items():
    models_dict[model_name] = train_and_store_model(model, X_train, y_train, X_test)

# Print classification reports
for model_name, model_info in models_dict.items():
    print(f"Classification Report for {model_name}:\n")
    print(classification_report(y_test, model_info['predictions']))

is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
X has feature names, but RandomForestClassifier was fitted without feature names
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
X has feature names, but DecisionTreeClassifier was fitted without feature names
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated an

Classification Report for RandomForest:

              precision    recall  f1-score   support

           0       0.95      0.99      0.97      1699
           1       0.93      0.75      0.83       327

    accuracy                           0.95      2026
   macro avg       0.94      0.87      0.90      2026
weighted avg       0.95      0.95      0.95      2026

Classification Report for DecisionTree:

              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1699
           1       0.80      0.78      0.79       327

    accuracy                           0.93      2026
   macro avg       0.88      0.87      0.88      2026
weighted avg       0.93      0.93      0.93      2026

Classification Report for GradientBoosting:

              precision    recall  f1-score   support

           0       0.97      0.99      0.98      1699
           1       0.93      0.85      0.89       327

    accuracy                           0.97      2026
 

X has feature names, but GradientBoostingClassifier was fitted without feature names
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.
is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.


## Run SHAP

In [5]:
shap_values_dict = {}

# Calculate SHAP values for each model
for model_name, model_info in models_dict.items():
    model = model_info['model']
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)

    # Reshape SHAP values
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    shap_values_dict[model_name] = shap_values

# create a df with the shap values' means
shap_mean_df = pd.DataFrame()

for model_name, shap_values in shap_values_dict.items():
    shap_mean_df[model_name] = np.abs(shap_values).mean(axis=0)
    #normalize
    shap_mean_df[model_name] = shap_mean_df[model_name] / shap_mean_df[model_name].sum()

shap_mean_df.index = X_test.columns
#sort by sum
shap_mean_df = shap_mean_df.sort_values(by=list(shap_mean_df.columns), ascending=False)
shap_mean_df.head(10)

Unnamed: 0,RandomForest,DecisionTree,GradientBoosting
Total_Trans_Ct,0.25238,0.355082,0.343321
Total_Revolving_Bal,0.145293,0.141205,0.128637
Total_Trans_Amt,0.127243,0.173111,0.211898
Total_Ct_Chng_Q4_Q1,0.095349,0.065044,0.069666
Total_Relationship_Count,0.08724,0.079089,0.057417
Avg_Utilization_Ratio,0.046004,0.024136,0.001858
Months_Inactive_12_mon,0.039551,0.010668,0.047181
Total_Amt_Chng_Q4_Q1,0.037168,0.033245,0.054462
Contacts_Count_12_mon,0.028271,0.011519,0.030789
Credit_Limit,0.018979,0.01678,0.00559


In [7]:
# # Plot the feature dependences
# for model, shap_values in shap_values_dict.items():
#     # Create dependence plots for each feature
#     for feature_name in X_test.columns:
#         shap.dependence_plot(feature_name, shap_values, X_test, feature_names=X_test.columns,
#                              interaction_index='auto', show=False)
#         plt.title(f'SHAP Dependence Plot - {model_name} - {feature_name}')
#         plt.show()

## Run LIME

In [8]:
# Create LIME explainer
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=X_train.columns)

# Perform LIME on each model
lime_results_dict = {}

for model_name, model_info in models_dict.items():
    model = model_info['model']
    lime_results = []

    # Iterate over X_test instances
    for instance in X_test.values:
        explanation = explainer.explain_instance(instance, model.predict_proba, num_features=len(X_test.columns))
        lime_results.append(explanation.as_list())

    lime_results_dict[model_name] = lime_results

# Average the LIME results
average_lime_results = {}

for model_name, lime_results in lime_results_dict.items():
    average_weights = {}
    for lime_result in lime_results:
        for feature, weight in lime_result:
            average_weights.setdefault(feature, []).append(weight)

    average_lime_results[model_name] = [(feature, np.mean(weights)) for feature, weights in average_weights.items()]


In [9]:
# Create a df with the LIME values' means
lime_mean_df = pd.DataFrame()

for model_name, lime_results in average_lime_results.items():
    lime_mean_df[model_name] = [weight for feature, weight in lime_results]
    # Normalize
    lime_mean_df[model_name] = lime_mean_df[model_name] / lime_mean_df[model_name].sum()

lime_mean_df.index = [feature for feature, weight in lime_results]

lime_mean_df = lime_mean_df.sort_values(by=list(lime_mean_df.columns), ascending=False)
lime_mean_df

Unnamed: 0,RandomForest,DecisionTree,GradientBoosting
Total_Trans_Ct > 81.00,2.085390,44.232385,2.184697
67.00 < Total_Trans_Ct <= 81.00,1.658041,28.517763,1.219197
Card_Category_Platinum <= 0.00,0.856045,48.933208,0.310428
4.00 < Total_Relationship_Count <= 5.00,0.854747,1.451800,0.309436
Total_Relationship_Count > 5.00,0.632134,4.657491,0.312915
...,...,...,...
Total_Relationship_Count <= 3.00,-1.036007,-9.596153,-0.577343
Total_Trans_Amt <= 2160.00,-1.153148,21.003098,1.461701
Total_Trans_Amt > 4739.00,-1.256607,-36.336932,-2.620934
Total_Revolving_Bal <= 326.00,-1.993379,-16.366899,-1.090073


In [10]:
# Create a dictionary to map indices to feature names
index_to_feature = {}
for index in lime_mean_df.index:
    for feature_name in X_test.columns:
        if feature_name in str(index):
            index_to_feature[index] = feature_name

# Update the indices in lime_mean_df
lime_mean_df.index = lime_mean_df.index.map(lambda x: index_to_feature.get(x, x))

# Take absolute values, merge and normalize them
lime_mean_df = lime_mean_df.abs()
lime_mean_df = lime_mean_df.groupby(lime_mean_df.index).sum()
lime_mean_df = lime_mean_df / lime_mean_df.sum()

lime_mean_df.head(10)

Unnamed: 0,RandomForest,DecisionTree,GradientBoosting
Avg_Open_To_Buy,0.009522,0.006667,0.003441
Avg_Utilization_Ratio,0.012838,0.002583,0.001659
Card_Category_Blue,0.0,0.0,0.0
Card_Category_Gold,0.003046,0.000616,0.003826
Card_Category_Platinum,0.042927,0.20007,0.026449
Card_Category_Silver,0.003379,0.001915,0.001599
Contacts_Count_12_mon,0.012411,0.00676,0.018745
Credit_Limit,0.009829,0.005919,0.002769
Customer_Age,0.004495,0.004027,0.005795
Dependent_count,0.008588,0.007446,0.000878


## Merge SHAP and LIME results

In [11]:
# Merge shap_mean_df and lime_mean_df
shap_lime_df = pd.concat([shap_mean_df.assign(method='SHAP'), lime_mean_df.assign(method='LIME')], axis=0).reset_index()

# Melt the dataframe for visualization
shap_lime_df = shap_lime_df.melt(id_vars=['index', 'method'], var_name='model', value_name='value')

shap_lime_df

Unnamed: 0,index,method,model,value
0,Total_Trans_Ct,SHAP,RandomForest,0.252380
1,Total_Revolving_Bal,SHAP,RandomForest,0.145293
2,Total_Trans_Amt,SHAP,RandomForest,0.127243
3,Total_Ct_Chng_Q4_Q1,SHAP,RandomForest,0.095349
4,Total_Relationship_Count,SHAP,RandomForest,0.087240
...,...,...,...,...
217,Total_Ct_Chng_Q4_Q1,LIME,GradientBoosting,0.059222
218,Total_Relationship_Count,LIME,GradientBoosting,0.076696
219,Total_Revolving_Bal,LIME,GradientBoosting,0.110467
220,Total_Trans_Amt,LIME,GradientBoosting,0.267335


In [12]:
# plot the results
fig = px.bar(shap_lime_df.sort_values(by='value', ascending=True),
                y='index', 
                x='value', 
                color='model', 
                barmode='group',
                facet_col='method', 
                title="Top SHAP and LIME features")
fig.show()