In [None]:
#%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import seaborn as sns
import json

import warnings
# NOTE: this silences *every* warning category for the rest of the session.
warnings.simplefilter(action='ignore')

# Accelerate sklearn operations on Intel CPUs.
# The patch has to be applied BEFORE anything is imported from sklearn:
# names imported earlier stay bound to the stock (unpatched) implementations.
from sklearnex import patch_sklearn
patch_sklearn()

from sklearn.model_selection import cross_val_predict, KFold
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelEncoder

from xgboost import XGBClassifier
import shap
import sage

# load JS visualization code to notebook
#shap.initjs() 

### Define classes

In [None]:
# Pick 2 of 'boys', 'girls' and 'mixed'; kept in alphabetical order because
# the data / parameter file names below are built from classes[0] and classes[1].
classes = sorted(['girls', 'mixed'])

### Load and format data

In [None]:
# load the features selected for this pair of classes
fn = f"data/selected_features_{classes[0]}_{classes[1]}.csv" # see binary_xgboost_optimizer.py
features_df = pd.read_csv(fn, index_col="stimulus_id")

# make features names shorter for plotting
features_df.columns = [col.replace("_Mean", "").replace("Mean", "") for col in features_df.columns]
features_df.columns = [col.replace("_peak_Peak", "_Peak") for col in features_df.columns]

# Right-align every feature name to a common width by padding with spaces on the left,
# so the labels line up in the plots. The width is at least 25 characters (the value the
# analysis was originally written for) and grows automatically if the selected feature
# set for another class pair contains a longer name.
MIN_LABEL_WIDTH = 25
feature_cols = [col for col in features_df.columns if col != "target"]
label_width = max([MIN_LABEL_WIDTH] + [len(col) for col in feature_cols])
features_df.columns = [col.rjust(label_width) if col != "target" else col for col in features_df.columns]

# format: feature matrix X and integer-encoded target y
X = features_df.drop("target", axis=1)
y = features_df["target"]
label_encoder = LabelEncoder()
label_encoder = label_encoder.fit(y)
if len(label_encoder.classes_) != 2:
    # everything downstream (binary:logistic objective, inverse_transform([0, 1])) assumes 2 classes
    raise ValueError(
        f"Expected exactly 2 target classes in {fn}, found {list(label_encoder.classes_)}"
    )
y = label_encoder.transform(y)

In [None]:
# Recover the original class names behind the encoded labels 0 and 1
class0, class1 = label_encoder.inverse_transform([0, 1])
print(f"Class 0: {class0}\nClass 1: {class1}")

### Hyperparameter tuning

In [None]:
# Best hyperparameters found by an optuna study with 2000 trials
# (produced by binary_xgboost_optimizer.py), stored as JSON per class pair.
fn = f"xgboost_params_2000_{classes[0]}_{classes[1]}.json"
with open(fn) as params_file:
    best_params = json.loads(params_file.read())

### Leave-one-out cross-validation

In [22]:
# Fixed settings that take precedence over whatever the tuned parameters contain
base_params = {
    "verbosity": 0,
    "objective": "binary:logistic",
    "tree_method": "exact",  # use exact for small dataset
}
best_params = {**best_params, **base_params}

# KFold with as many splits as samples: every sample is held out exactly once
loocv = KFold(n_splits=len(X))
model = XGBClassifier(**best_params)
y_pred = cross_val_predict(model, X, y, cv=loocv)

In [None]:
# Confusion matrix of the cross-validated predictions, labelled with the original class names
class_labels = label_encoder.inverse_transform([0,1])
cm = confusion_matrix(y, y_pred)
cm_df = pd.DataFrame(cm, index=class_labels, columns=class_labels)
loo_f1 = f1_score(y, y_pred)

pl.figure(figsize=(5.5,4))
sns.heatmap(cm_df, annot=True, cmap='Blues', fmt='g', cbar=False)
pl.title(f'F1 score:{loo_f1:.2f}')
pl.xlabel('Predicted label')
pl.ylabel('True label')
pl.show()

### Train the model on the whole dataset

In [None]:
# Refit on every sample; this is the model explained by SAGE and SHAP below
model = XGBClassifier(**best_params)
model.fit(X, y)

### SAGE analysis 
Global feature importance

In [None]:
# SAGE works on plain arrays, so hand it the raw feature values
X_array = X.values

# Imputer (handles missing features) and permutation-based estimator with a fixed seed
imputer = sage.MarginalImputer(model, X_array)
estimator = sage.PermutationEstimator(imputer, loss='cross entropy', random_state=42)

# Calculate the SAGE values and show them for all features
sage_values = estimator(X_array, y)
sage_values.plot(X.columns, figsize=(10,10))

In [None]:

# One row per feature, most important (largest SAGE value) first
sage_values_df = (
    pd.DataFrame({'feature name': X.columns, 'SAGE value': sage_values.values})
    .sort_values(by=['SAGE value'], ascending=False)
)

# give SAGE order to SHAP beeswarm plot
sage_ordered_features = list(sage_values_df['feature name'])

# feature name -> column position in X
col2num = dict(zip(X.columns, range(len(X.columns))))


In [None]:
# Bar plot restricted to the 10 features with the largest SAGE values
top10_features = sage_ordered_features[:10]
label_fontsize = 30

pl.figure(figsize=(10,10))
sns.barplot(data=sage_values_df, x="SAGE value", y="feature name", order=top10_features)
pl.xlabel('SAGE value', fontsize=label_fontsize)
pl.ylabel('')
pl.xticks(fontsize=label_fontsize)
pl.yticks(fontsize=label_fontsize)
pl.savefig(f'sage_top_10_{class0}_{class1}.pdf', format='pdf', bbox_inches='tight')
pl.show()

In [None]:
# Keep the SAGE order within each group but push every MFCC feature to the end,
# so that the truncated SHAP beeswarm plot lists the non-MFCC features first.
ordered_not_mfccs, ordered_mfccs = [], []
for feature in sage_ordered_features:
    if 'mfcc' in feature:
        ordered_mfccs.append(feature)
    else:
        ordered_not_mfccs.append(feature)
re_ordered_features = ordered_not_mfccs + ordered_mfccs

### SHAP analysis
Local feature contributions

In [None]:
# Explain the model fitted on the whole dataset, keeping the (padded) column names as labels
feature_labels = X.columns
explainer = shap.Explainer(model, feature_names=feature_labels)
shap_values = explainer(X)


In [None]:
# (to inspect a single prediction instead, e.g. the first one, in log odds:
#  shap.plots.waterfall(shap_values[0], max_display=20))

# summarize the effects of all the features, rows following the SAGE order with MFCCs last
beeswarm_order = [col2num[name] for name in re_ordered_features]
ax = shap.plots.beeswarm(
    shap_values,
    order=beeswarm_order,
    max_display=11,
    color_bar=False,
    show=False,
)

In [None]:
# modified from https://github.com/shap/shap/blob/master/shap/plots/_beeswarm.py
#
# Custom beeswarm plot of the SHAP values: one row per feature (SAGE order, MFCCs last),
# one dot per sample, dots coloured by the (percentile-trimmed) feature value.

import colorcet as cc
from matplotlib.cm import ScalarMappable
from matplotlib.colors import to_rgb, LinearSegmentedColormap

# colormap: a slice of the reversed CET_CBTL1 palette
clist = [to_rgb(c) for c in cc.CET_CBTL1[::-1][30:-120]]
cm = LinearSegmentedColormap.from_list("", clist, N=256)
display(cm)

from shap.plots._labels import labels
import scipy.sparse  # import the submodule explicitly; a bare `import scipy` does not guarantee it is loaded

# --- plot settings ---
color_bar = False
max_display = 10
all_fontsizes = 30
plot_size = (10, 10)
row_height = 0.4
alpha = 1
axis_color="#333333"
dots_size = 25
color = cm
color_bar_label=labels["FEATURE_VALUE"]
nbins = 100
jitter_seed = 42  # fixed seed so that the saved figure is identical on every run


def _beeswarm_offsets(shaps, rng, nbins=100, row_height=0.4):
    """Return one vertical offset per SHAP value so that dots falling in the same bin stack up.

    The values are quantised into `nbins` bins over their range; within a bin the dots are
    placed alternately below and above the row centre (0, -1, +1, -2, +2, ...), and the
    result is rescaled so that the offsets stay within 0.9 * row_height.
    `rng` is a numpy Generator used only to break ties between dots of the same bin.
    """
    n_points = len(shaps)
    if n_points == 0:
        return np.zeros(0)
    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
    order = np.argsort(quant + rng.standard_normal(n_points) * 1e-6)
    layer = 0
    last_bin = -1
    ys = np.zeros(n_points)
    for ind in order:
        if quant[ind] != last_bin:
            layer = 0
        ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
        layer += 1
        last_bin = quant[ind]
    ys *= 0.9 * (row_height / np.max(ys + 1))
    return ys


def _trimmed_color_range(fvalues):
    """Return (vmin, vmax) for colouring: 5th-95th percentile, widened if the range collapses."""
    vmin = np.nanpercentile(fvalues, 5)
    vmax = np.nanpercentile(fvalues, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(fvalues, 1)
        vmax = np.nanpercentile(fvalues, 99)
        if vmin == vmax:
            # NaN-aware so that missing values cannot turn the range into NaN
            vmin = np.nanmin(fvalues)
            vmax = np.nanmax(fvalues)
    if vmin > vmax: # fixes rare numerical precision issues
        vmin = vmax
    return vmin, vmax


shap_exp = shap_values
# we make a copy here, because later there are places that might modify this array
values = np.copy(shap_exp.values)
features = shap_exp.data
if scipy.sparse.issparse(features):
    features = features.toarray()
feature_names = shap_exp.feature_names

num_features = values.shape[1]

# checked once up front (it does not depend on the feature being drawn)
if features is not None and features.shape[0] != values.shape[0]:
    raise ValueError("Feature and SHAP matrices must have the same number of rows!")

feature_order = [col2num[col] for col in re_ordered_features]

feature_inds = feature_order[:max_display]

# build our y-tick labels
yticklabels = [feature_names[i] for i in feature_inds]

# draw on a fresh figure instead of whatever figure happens to be current
shap_fig, shap_ax = pl.subplots(figsize=plot_size)
shap_ax.axvline(x=0, color="#999999", zorder=-1)

rng = np.random.default_rng(jitter_seed)

# make the beeswarm dots (first feature in the order ends up on the top row)
for pos, i in enumerate(reversed(feature_inds)):
    shap_ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    # shuffle the samples so that the drawing order (which dot ends up on top) is not systematic
    perm = rng.permutation(values.shape[0])
    shaps = values[perm, i]
    ys = _beeswarm_offsets(shaps, rng, nbins=nbins, row_height=row_height)
    rasterize = len(shaps) > 500

    if features is None:
        # no feature values to colour by: nothing is drawn for this row
        continue

    fvalues = np.asarray(features[perm, i], dtype=np.float64)  # make sure this can be numeric

    # trim the color range, but prevent the color range from collapsing
    vmin, vmax = _trimmed_color_range(fvalues)

    # plot the nan fvalues in the interaction feature as YELLOW
    nan_mask = np.isnan(fvalues)
    shap_ax.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#FDFD96",
                    s=dots_size, alpha=alpha, linewidth=0,
                    zorder=3, rasterized=rasterize)

    # plot the non-nan fvalues colored by the trimmed feature value
    valid = ~nan_mask
    cvals = np.clip(fvalues[valid], vmin, vmax)
    shap_ax.scatter(shaps[valid], pos + ys[valid],
                    cmap=color, vmin=vmin, vmax=vmax, s=dots_size,
                    c=cvals, alpha=alpha, linewidth=0,
                    zorder=3, rasterized=rasterize)


# draw the color bar
if color_bar and features is not None:
    m = ScalarMappable(cmap=color)
    m.set_array([0, 1])
    cb = shap_fig.colorbar(m, ax=shap_ax, ticks=[0, 1], aspect=40)
    cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
    cb.set_label(color_bar_label, size=all_fontsizes, labelpad=0)
    cb.ax.tick_params(labelsize=all_fontsizes, length=0)
    cb.set_alpha(1)
    cb.outline.set_visible(False)

# axes cosmetics: only the bottom spine is kept
shap_ax.xaxis.set_ticks_position('bottom')
shap_ax.yaxis.set_ticks_position('none')
for side in ('right', 'top', 'left'):
    shap_ax.spines[side].set_visible(False)
shap_ax.tick_params(color=axis_color, labelcolor=axis_color)
shap_ax.set_yticks(range(len(feature_inds)))
shap_ax.set_yticklabels(list(reversed(yticklabels)), fontsize=all_fontsizes)
shap_ax.tick_params('y', length=20, width=0.5, which='major')
shap_ax.tick_params('x', labelsize=all_fontsizes)
shap_ax.set_ylim(-1, len(feature_inds))
shap_ax.set_xlabel("SHAP value", fontsize=all_fontsizes)
shap_fig.savefig(f'shap_top_10_{class0}_{class1}.pdf', format='pdf', bbox_inches='tight')

#### Interactions

Note that when there are no interactions present, the SHAP interaction values are just a diagonal matrix with the SHAP values on the diagonal.

In [None]:
# get interaction values (one feature x feature matrix per sample)
shap_interaction_values = explainer.shap_interaction_values(X)

# Check that SHAP interaction values essentially sum to the marginal predictions.
# The raw margin (log odds) is requested from the model directly: converting the
# probabilities back with log(p / (1 - p)) loses precision and yields inf/nan
# as soon as a predicted probability saturates at 0 or 1.
pred_logodds = model.predict(X, output_margin=True)
max_abs_error = np.abs(
    shap_interaction_values.sum((1, 2)) + explainer.expected_value - pred_logodds
).max()
print(max_abs_error)

# get mean absolute value of interaction values (rounded to 1 decimal for the annotations)
shap_interaction_mean = np.abs(shap_interaction_values).mean(0)
shap_interaction_mean = np.round(shap_interaction_mean, decimals=1)
shap_interaction_df = pd.DataFrame(shap_interaction_mean, index=X.columns, columns=X.columns)

# reorder rows and columns according to SAGE order
shap_interaction_df = shap_interaction_df.reindex(sage_ordered_features)
shap_interaction_df = shap_interaction_df[sage_ordered_features]

# plot heatmap
pl.figure(figsize=(10,10))
sns.heatmap(shap_interaction_df, annot=True, cmap='Blues', fmt='g')
pl.title('Interaction effects')
pl.show()