# HSMA 6 - Session 4G - Exercise 2A - Stroke Dataset Explainable AI

The data loaded in this exercise covers seven acute stroke units and records whether each patient received clot-busting treatment for stroke. There are many features, and a description of each can be found in the file stroke_data_feature_descriptions.csv.

Train a model to predict whether or not a stroke patient receives clot-busting treatment. The code provided below uses an XGBoost classifier, which is an ensemble of gradient-boosted decision trees. Use the prompts below to write each section of code.

Run the code below to import the dataset. 

In [None]:
import pandas as pd
import numpy as np

# Download data
# (not required if running locally and have previously downloaded data)

download_required = True

if download_required:

    # Download processed data:
    address = 'https://raw.githubusercontent.com/MichaelAllen1966/' + \
                '2004_titanic/master/jupyter_notebooks/data/hsma_stroke.csv'
    data = pd.read_csv(address)

    # Create a data subfolder if one does not already exist
    import os
    data_directory = './data/'
    os.makedirs(data_directory, exist_ok=True)

    # Save data to data subfolder
    data.to_csv(data_directory + 'hsma_stroke.csv', index=False)

# Load data
data = pd.read_csv('data/hsma_stroke.csv')
# Make all data 'float' type
data = data.astype(float)

Preview the data.

In [None]:
data.head()

In [None]:
# Import machine learning methods

from xgboost.sklearn import XGBClassifier

from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree

import plotly.express as px
import matplotlib.pyplot as plt

from sklearn.metrics import auc, roc_curve, RocCurveDisplay, f1_score, precision_score, \
                            recall_score, confusion_matrix, ConfusionMatrixDisplay, \
                            classification_report


In [None]:
# Additional imports for explainable AI
from sklearn.inspection import PartialDependenceDisplay, permutation_importance

# Import shap for shapley values
import shap

# Load the JavaScript that shap needs for its interactive plots (e.g. force plots) later on
shap.initjs()

In [None]:
X = data.drop('Clotbuster given', axis=1) # X = all 'data' except the 'Clotbuster given' column
y = data['Clotbuster given'] # y = 'Clotbuster given' column from 'data'
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Fit an XGBoost classifier to the training data
model = XGBClassifier(random_state=42)
model = model.fit(X_train, y_train)

# Predict training and test set labels
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print(f'Accuracy of predicting training data = {accuracy_train}')
print(f'Accuracy of predicting test data = {accuracy_test}')

## Explainable AI

### Explore Feature Importance

#### Importance with MDI

In [None]:
 ## YOUR CODE HERE
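MDI (mean decrease in impurity) importance comes straight from a fitted tree ensemble's `feature_importances_` attribute. The sketch below is one possible approach, not the reference solution. So that it runs on its own, it assumes synthetic data from `make_classification` with hypothetical column names, and swaps in scikit-learn's `GradientBoostingClassifier` for XGBoost. The notebook's `XGBClassifier` also exposes `feature_importances_`, but there the values follow XGBoost's `importance_type` setting (gain-based by default for tree models in recent versions) rather than scikit-learn's impurity calculation.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Synthetic stand-in for the stroke data; feature names are hypothetical
X_arr, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                               random_state=42)
X = pd.DataFrame(X_arr, columns=[f"feature_{i}" for i in range(6)])

model = GradientBoostingClassifier(random_state=42).fit(X, y)

# MDI: total impurity reduction from splits on each feature, accumulated
# over the trees and normalised so the importances sum to 1
mdi = pd.Series(model.feature_importances_, index=X.columns, name="MDI")
print(mdi.sort_values(ascending=False))
```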

#### Importance with PFI

In [None]:
 ## YOUR CODE HERE
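Permutation feature importance (PFI) shuffles one feature at a time in held-out data and measures how much the model's score drops. A minimal sketch using the already-imported `permutation_importance`, assuming the same synthetic stand-in data and `GradientBoostingClassifier`. With the notebook's own objects the call would be `permutation_importance(model, X_test, y_test, ...)`. The choice of `n_repeats=10` is arbitrary.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (hypothetical feature names) and model
X_arr, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                               random_state=42)
X = pd.DataFrame(X_arr, columns=[f"feature_{i}" for i in range(6)])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=42)
model = GradientBoostingClassifier(random_state=42).fit(X_train, y_train)

# Shuffle each feature in the test set n_repeats times and record the drop
# in the model's default score (accuracy for a classifier)
pfi = permutation_importance(model, X_test, y_test, n_repeats=10,
                             random_state=42)

pfi_table = pd.DataFrame({"PFI mean": pfi.importances_mean,
                          "PFI std": pfi.importances_std},
                         index=X.columns).sort_values("PFI mean", ascending=False)
print(pfi_table)
```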

### PDP + ICE

In [None]:
 ## YOUR CODE HERE
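A partial dependence plot (PDP) shows the average predicted probability as one feature is varied across a grid of values, and individual conditional expectation (ICE) curves show the same thing for each row separately. This sketch assumes synthetic data and a `GradientBoostingClassifier` in place of the stroke model. It uses scikit-learn's `partial_dependence` to get the numbers behind the plot and checks that the PDP is the mean of the ICE curves; `grid_resolution=20` is an arbitrary choice.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import partial_dependence

# Synthetic stand-in data (hypothetical feature names) and model
X_arr, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                               random_state=42)
X = pd.DataFrame(X_arr, columns=[f"feature_{i}" for i in range(6)])
model = GradientBoostingClassifier(random_state=42).fit(X, y)

# kind="both" returns every ICE curve plus their average (the PDP) for column 0
pd_result = partial_dependence(model, X, features=[0], kind="both",
                               grid_resolution=20)

ice = pd_result["individual"][0]   # shape (n_rows, n_grid_points)
pdp = pd_result["average"][0]      # shape (n_grid_points,)

# The PDP is the mean of the ICE curves at each grid point
print(np.allclose(ice.mean(axis=0), pdp))
```

To draw the plot rather than inspect the numbers, `PartialDependenceDisplay.from_estimator(model, X, features=[0], kind="both")` (imported earlier in the notebook) overlays the ICE curves and the PDP.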

### SHAP

#### Create the SHAP explainer and the shap values.

In [None]:
explainer =  ## YOUR CODE HERE

shap_values =  ## YOUR CODE HERE

shap_values

#### Return just the numeric SHAP values

In [None]:
shap_values_numeric = ## YOUR CODE HERE
shap_values_numeric

### Feature table

Create a table that shows feature importance for MDI, PFI and SHAP.

In [None]:
 ## YOUR CODE HERE

#### Display the top 10 features according to MDI, PFI and SHAP.

In [None]:
## YOUR CODE HERE

### SHAP Plots

#### Global: Beeswarm

In [None]:
## YOUR CODE HERE

#### Global: Bar

In [None]:
## YOUR CODE HERE

#### Bar: by factor

In [None]:
## YOUR CODE HERE

#### Local: Waterfall plots

In [None]:
## YOUR CODE HERE

#### Local: Force Plots

In [None]:
## YOUR CODE HERE

#### Global: Force Plots

In [None]:
## YOUR CODE HERE

### Dependence Plots

#### Simple scatter of a single feature

Choose a single feature to create a scatter (dependence) plot for.

In [None]:
## YOUR CODE HERE

Colour the plot by the most strongly interacting feature.

In [None]:
## YOUR CODE HERE

Colour the plot by a single feature of your choice.

In [None]:
## YOUR CODE HERE

In [None]:
## YOUR CODE HERE