calculate feature-importances using SHAP-values

In [1]:
# setup

import os
import sys
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from sklearn.ensemble import RandomForestRegressor

# add root to sys path
sys.path.append("/Users/gilanorup/Desktop/Studium/MSc/MA/code/masters_thesis_gn/src")
from config.constants import GIT_DIRECTORY
from data_preparation.feature_set_helpers import stratified_cv_feature_importance
from regression.evaluation_helpers import format_title

# parameters
task_name = "picnicScene"
target = "SemanticFluencyScore"
folds_path = os.path.join(GIT_DIRECTORY, "data/stratified_folds.csv")
scores_path = os.path.join(GIT_DIRECTORY, "data/language_scores_all_subjects.csv")
features_path = os.path.join(GIT_DIRECTORY, f"results/features/filtered/{task_name}_filtered.csv")
save_dir = os.path.join(GIT_DIRECTORY, "results/feature_importance", task_name, target)
os.makedirs(save_dir, exist_ok=True)

# use same model as in model-comparison
# parameters for random forest
rf_params={"n_estimators": 200, "random_state": 42, "min_samples_leaf": 5, "max_features": "sqrt"}
# use same subject intersection
full_subjects_csv = os.path.join(GIT_DIRECTORY, "results", "regression", "model_comparison",
                                 "filtered", "tasks", f"{target}_full_subjects.csv")
full_subjects = set(pd.read_csv(full_subjects_csv)["Subject_ID"])

# load data
df = pd.read_csv(features_path)
fold_df = pd.read_csv(folds_path)
scores_df = pd.read_csv(scores_path)

# one-hot encode demographics
demo = fold_df[["Subject_ID","Gender","Education","Country","Age","Socioeconomic"]].copy()
placeholders = {"no_answer","other"}
for col in ["Gender","Education","Country"]:
    demo[col] = demo[col].astype("string").str.lower().str.strip()
    demo[col] = demo[col].replace(placeholders, pd.NA)
demo = demo.dropna(subset=["Gender","Education","Country"]).copy()

one_hot_drop_first = True
cat_cols = ["Gender","Education","Country"]
num_cols = ["Age","Socioeconomic"]

demo_dummies = pd.get_dummies(demo[cat_cols], columns=cat_cols, drop_first=one_hot_drop_first)

demo = pd.concat(
    [demo[["Subject_ID"] + num_cols].reset_index(drop=True),
     demo_dummies.reset_index(drop=True)],
    axis=1
)

# merge data
df = pd.merge(df, demo, on="Subject_ID", how="inner")
df = pd.merge(df, fold_df[["Subject_ID","fold"]], on="Subject_ID", how="inner")
df = pd.merge(df, scores_df[["Subject_ID", target]], on="Subject_ID", how="inner")

if full_subjects is not None:
    df = df[df["Subject_ID"].isin(full_subjects)].copy()

exclude_features = {"fam_verbs","img_verbs","fam_nouns","img_nouns"}

# full possible feature lists
linguistic_features = {
    "n_words", "ttr", "mattr_10", "mattr_20", "mattr_30", "mattr_40", "mattr_50", "filler_word_ratio",
    "average_word_length", "brunets_index", "honores_statistic", "guirauds_statistic", "light_verb_ratio",
    "empty_word_ratio", "nid_ratio", "adjacent_repetitions", "aoa_content", "aoa_nouns", "aoa_verbs",
    "fam_content", "fam_nouns", "fam_verbs", "img_content", "img_nouns", "img_verbs", "freq_content",
    "freq_nouns", "freq_verbs", "concr_content", "concr_nouns", "concr_verbs", "um_ratio", "uh_ratio",
    "er_ratio", "ah_ratio", "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM", "PART",
    "PRON", "PROPN", "SCONJ", "VERB", "OTHER", "NOUN/VERB", "PRON/NOUN", "DET/NOUN", "AUX/VERB",
    "OPEN/CLOSED", "INFORMATION_WORDS", "article_pause_contentword"
}
acoustic_features = {
    "phonation_rate", "total_speech_duration", "speech_rate_phonemes", "speech_rate_words", "n_pauses",
    "total_pause_duration", "avg_pause_duration", "short_pause_count", "long_pause_count", "pause_word_ratio",
    "pause_ratio", "pause_rate",
    "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_amean", "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_stddevNorm",
    "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_percentile20.0", "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_percentile50.0",
    "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_percentile80.0", "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2",
    "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_meanRisingSlope", "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_stddevRisingSlope",
    "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_meanFallingSlope", "eGeMAPS_F0semitoneFrom27.5Hz_sma3nz_stddevFallingSlope",
    "eGeMAPS_loudness_sma3_amean", "eGeMAPS_loudness_sma3_stddevNorm", "eGeMAPS_loudness_sma3_percentile20.0",
    "eGeMAPS_loudness_sma3_percentile50.0", "eGeMAPS_loudness_sma3_percentile80.0", "eGeMAPS_loudness_sma3_pctlrange0-2",
    "eGeMAPS_loudness_sma3_meanRisingSlope", "eGeMAPS_loudness_sma3_stddevRisingSlope",
    "eGeMAPS_loudness_sma3_meanFallingSlope", "eGeMAPS_loudness_sma3_stddevFallingSlope",
    "eGeMAPS_spectralFlux_sma3_amean", "eGeMAPS_spectralFlux_sma3_stddevNorm",
    "eGeMAPS_mfcc1_sma3_amean", "eGeMAPS_mfcc1_sma3_stddevNorm", "eGeMAPS_mfcc2_sma3_amean",
    "eGeMAPS_mfcc2_sma3_stddevNorm", "eGeMAPS_mfcc3_sma3_amean", "eGeMAPS_mfcc3_sma3_stddevNorm",
    "eGeMAPS_mfcc4_sma3_amean", "eGeMAPS_mfcc4_sma3_stddevNorm",
    "eGeMAPS_jitterLocal_sma3nz_amean", "eGeMAPS_jitterLocal_sma3nz_stddevNorm",
    "eGeMAPS_shimmerLocaldB_sma3nz_amean", "eGeMAPS_shimmerLocaldB_sma3nz_stddevNorm",
    "eGeMAPS_HNRdBACF_sma3nz_amean", "eGeMAPS_HNRdBACF_sma3nz_stddevNorm",
    "eGeMAPS_logRelF0-H1-H2_sma3nz_amean", "eGeMAPS_logRelF0-H1-H2_sma3nz_stddevNorm",
    "eGeMAPS_logRelF0-H1-A3_sma3nz_amean", "eGeMAPS_logRelF0-H1-A3_sma3nz_stddevNorm",
    "eGeMAPS_F1frequency_sma3nz_amean", "eGeMAPS_F1frequency_sma3nz_stddevNorm",
    "eGeMAPS_F1bandwidth_sma3nz_amean", "eGeMAPS_F1bandwidth_sma3nz_stddevNorm",
    "eGeMAPS_F1amplitudeLogRelF0_sma3nz_amean", "eGeMAPS_F1amplitudeLogRelF0_sma3nz_stddevNorm",
    "eGeMAPS_F2frequency_sma3nz_amean", "eGeMAPS_F2frequency_sma3nz_stddevNorm",
    "eGeMAPS_F2bandwidth_sma3nz_amean", "eGeMAPS_F2bandwidth_sma3nz_stddevNorm",
    "eGeMAPS_F2amplitudeLogRelF0_sma3nz_amean", "eGeMAPS_F2amplitudeLogRelF0_sma3nz_stddevNorm",
    "eGeMAPS_F3frequency_sma3nz_amean", "eGeMAPS_F3frequency_sma3nz_stddevNorm",
    "eGeMAPS_F3bandwidth_sma3nz_amean", "eGeMAPS_F3bandwidth_sma3nz_stddevNorm",
    "eGeMAPS_F3amplitudeLogRelF0_sma3nz_amean", "eGeMAPS_F3amplitudeLogRelF0_sma3nz_stddevNorm",
    "eGeMAPS_alphaRatioV_sma3nz_amean", "eGeMAPS_alphaRatioV_sma3nz_stddevNorm",
    "eGeMAPS_hammarbergIndexV_sma3nz_amean", "eGeMAPS_hammarbergIndexV_sma3nz_stddevNorm",
    "eGeMAPS_slopeV0-500_sma3nz_amean", "eGeMAPS_slopeV0-500_sma3nz_stddevNorm",
    "eGeMAPS_slopeV500-1500_sma3nz_amean", "eGeMAPS_slopeV500-1500_sma3nz_stddevNorm",
    "eGeMAPS_spectralFluxV_sma3nz_amean", "eGeMAPS_spectralFluxV_sma3nz_stddevNorm",
    "eGeMAPS_mfcc1V_sma3nz_amean", "eGeMAPS_mfcc1V_sma3nz_stddevNorm",
    "eGeMAPS_mfcc2V_sma3nz_amean", "eGeMAPS_mfcc2V_sma3nz_stddevNorm",
    "eGeMAPS_mfcc3V_sma3nz_amean", "eGeMAPS_mfcc3V_sma3nz_stddevNorm",
    "eGeMAPS_mfcc4V_sma3nz_amean", "eGeMAPS_mfcc4V_sma3nz_stddevNorm",
    "eGeMAPS_alphaRatioUV_sma3nz_amean", "eGeMAPS_hammarbergIndexUV_sma3nz_amean",
    "eGeMAPS_slopeUV0-500_sma3nz_amean", "eGeMAPS_slopeUV500-1500_sma3nz_amean",
    "eGeMAPS_spectralFluxUV_sma3nz_amean",
    "eGeMAPS_loudnessPeaksPerSec", "eGeMAPS_VoicedSegmentsPerSec", "eGeMAPS_MeanVoicedSegmentLengthSec",
    "eGeMAPS_StddevVoicedSegmentLengthSec", "eGeMAPS_MeanUnvoicedSegmentLength",
    "eGeMAPS_StddevUnvoicedSegmentLength", "eGeMAPS_equivalentSoundLevel_dBp"
}

demographic_cols = [c for c in df.columns if c.startswith("Gender_") or c.startswith("Education_") or c.startswith("Country_")] + ["Age","Socioeconomic"]

full_feature_set = (linguistic_features | acoustic_features) - exclude_features
feature_cols = sorted([f for f in full_feature_set if f in df.columns]) + demographic_cols

# coerce numeric and dropna
df[feature_cols] = df[feature_cols].apply(pd.to_numeric, errors="coerce")
df = df.dropna(subset=[target] + feature_cols).copy()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# plot style
plt.style.use("default")
plt.rcParams.update({
    "axes.facecolor": "white",
    "figure.facecolor": "white",
    "axes.edgecolor": "black",
    "axes.labelcolor": "black",
    "xtick.color": "black",
    "ytick.color": "black",
    "font.family": "Arial",
    "savefig.dpi": 300,
    "savefig.bbox": "tight"
})

In [3]:
# run CV + importances
shap_explanation, shap_table = stratified_cv_feature_importance(
    df=df,
    fold_column="fold",
    model_type=RandomForestRegressor,
    model_params=rf_params,
    target_column=target,
    feature_columns=feature_cols,
    save_dir=save_dir,
    task_name=task_name
)

SHAP failed on fold 1: Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe'


Found a NULL input array in _cext_dense_tree_update_weights!


SHAP failed on fold 2: Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe'


Found a NULL input array in _cext_dense_tree_update_weights!


SHAP failed on fold 3: Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe'


Found a NULL input array in _cext_dense_tree_update_weights!


SHAP failed on fold 4: Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe'


Found a NULL input array in _cext_dense_tree_update_weights!


SHAP failed on fold 5: Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe'
no SHAP values available


Found a NULL input array in _cext_dense_tree_update_weights!


In [4]:
# SHAP plots
shap.plots.bar(shap_explanation, max_display=20, show=False)
plt.tight_layout(rect=[0, 0, 0.85, 1])
plt.savefig(os.path.join(save_dir, f"{task_name}_{target}_shap_bar.png"), dpi=300)
plt.close()

shap.summary_plot(shap_explanation, plot_type="bar", show=False)
plt.savefig(os.path.join(save_dir, f"{task_name}_{target}_shap_summary.png"), dpi=300)
plt.close()

shap.summary_plot(shap_explanation, plot_type="violin", show=False)
plt.savefig(os.path.join(save_dir, f"{task_name}_{target}_shap_violin.png"), dpi=300)
plt.close()

shap.plots.beeswarm(shap_explanation, show=False)
plt.tight_layout(rect=[0, 0, 0.85, 1])
plt.savefig(os.path.join(save_dir, f"{task_name}_{target}_shap_beeswarm.png"), dpi=300)
plt.close()

TypeError: The shap_values argument must be an Explanation object, Cohorts object, or dictionary of Explanation objects!

In [None]:
# local SHAP-plot
# pick a random subject from df
rng = np.random.default_rng(0)
rand_idx = int(rng.integers(0, len(df)))
subject_id = df.iloc[rand_idx]["Subject_ID"]
subject_fold = df.iloc[rand_idx]["fold"]

# split train/test respecting the subject's fold
train_df = df[df["fold"] != subject_fold].copy()
test_row = df[df["Subject_ID"] == subject_id].copy()

X_train = train_df[feature_cols]
y_train = train_df[target]
X_test_one = test_row[feature_cols]
y_test_one = test_row[target].iloc[0]

rf = RandomForestRegressor(**rf_params)
rf.fit(X_train, y_train)

background = shap.sample(X_train, min(200, len(X_train)), random_state=0)
explainer = shap.TreeExplainer(
    rf,
    data=background,
    feature_perturbation="interventional",
    model_output="raw"
)
# local explanation for this one instance
ex = explainer(X_test_one)

y_pred_one = rf.predict(X_test_one)[0]
residual_one = y_test_one - y_pred_one
title = (f"Local SHAP Values: {format_title(target)} "
         f"(Subject {subject_id}, {format_title(task_name)})\n"
         f"true={y_test_one:.2f}, pred={y_pred_one:.2f}, resid={residual_one:.2f}")

# waterfall plot
shap.plots.waterfall(ex[0], max_display=15, show=False)
plt.title(title)
plt.tight_layout()
out_path = os.path.join(
    save_dir,
    f"{task_name}_{target}_local_waterfall_subject-{subject_id}.png"
)
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved: {out_path}")