Setup and imports

In [1]:
from __future__ import annotations

import sys
from pathlib import Path

import numpy as np
import pandas as pd

import shap
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import StratifiedKFold

import lightgbm as lgb

# Project imports
# PROJECT_ROOT is the parent of the current working directory, so this only
# points at the repository root when the kernel is started from a folder that
# sits directly under it (e.g. notebooks/).
PROJECT_ROOT = Path.cwd().parents[0]  # assumes notebooks/ is one level below repo root
# Prepend the root to sys.path only once, so re-running this cell does not
# insert duplicate entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# These imports must stay below the sys.path tweak above; presumably `src` is a
# package located under PROJECT_ROOT and is not installed in the environment.
from src.config import get_paths, set_seed
from src.io import read_table

Set up paths and load datasets

In [2]:
# Seed via the project helper and resolve the project path container.
set_seed(42)
paths = get_paths()

# Load processed datasets (CSV recommended for maximum compatibility)
X = read_table(paths.processed / "X_radiomics.csv")
y_group = read_table(paths.processed / "y_group.csv")["group"]

LABELS_ORDER = ["GC", "G2", "G5", "G7", "G14"]

# Features and labels come from two separate files, so verify that they line
# up before anything is fitted on them.
if len(X) != len(y_group):
    raise ValueError(
        f"Row count mismatch: X_radiomics has {len(X)} rows, "
        f"y_group has {len(y_group)} rows."
    )

# Fail early on labels outside LABELS_ORDER (typos, missing values); otherwise
# they would silently become extra classes in the classifier.
unexpected_labels = set(y_group.unique()) - set(LABELS_ORDER)
if unexpected_labels:
    raise ValueError(
        "y_group contains labels not listed in LABELS_ORDER: "
        f"{sorted(map(str, unexpected_labels))}"
    )

print("Loaded datasets:")
print(f"- X_radiomics: {X.shape}")
print(f"- y_group:     {y_group.shape} | unique={sorted(y_group.unique().tolist())}")

# Output folder (NOT versioned in GitHub; included in Zenodo release)
SHAP_DIR = paths.root / "supplementary" / "shap"
SHAP_DIR.mkdir(parents=True, exist_ok=True)
print(f"SHAP outputs will be saved to: {SHAP_DIR.resolve()}")

Loaded datasets:
- X_radiomics: (571, 105)
- y_group:     (571,) | unique=['G14', 'G2', 'G5', 'G7', 'GC']
SHAP outputs will be saved to: C:\Users\modre\Documents\masseter\supplementary\shap


Fit final LightGBM classifier (with MinMax scaling)

In [7]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
import lightgbm as lgb

# Two named pipeline stages: min-max scaling followed by the LightGBM
# classifier. The step names ("scaler", "model") are looked up again below
# and in later cells through `clf.named_steps`.
scaling_step = ("scaler", MinMaxScaler())
model_step = ("model", lgb.LGBMClassifier(n_estimators=500, random_state=42))
clf = Pipeline(steps=[scaling_step, model_step])

# Fit on every available sample; no hold-out split is made in this cell.
clf.fit(X, y_group)

print("Final LightGBM classifier fitted on full dataset.")
print("Classes:", clf.named_steps["model"].classes_)


[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002183 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 19285
[LightGBM] [Info] Number of data points in the train set: 571, number of used features: 104
[LightGBM] [Info] Start training from score -1.646909
[LightGBM] [Info] Start training from score -1.646909
[LightGBM] [Info] Start training from score -1.551599
[LightGBM] [Info] Start training from score -1.646909
[LightGBM] [Info] Start training from score -1.559897
Final LightGBM classifier fitted on full dataset.
Classes: ['G14' 'G2' 'G5' 'G7' 'GC']


Prepare scaled features and extract trained model

In [8]:
import pandas as pd

# Re-apply the scaler that was fitted inside the pipeline, so the SHAP inputs
# match what the model saw during training.
scaler = clf.named_steps["scaler"]

# Build the DataFrame in a single step (no intermediate ndarray bound to the
# same name) and carry over both the column names and the row index of X, so
# rows stay traceable to the original samples.
X_scaled = pd.DataFrame(scaler.transform(X), columns=X.columns, index=X.index)

# Fitted LightGBM estimator, extracted for use by the SHAP explainer.
lgbm_model = clf.named_steps["model"]

print("Prepared scaled feature matrix:", X_scaled.shape)

Prepared scaled feature matrix: (571, 105)


Compute SHAP values (multiclass)

In [9]:
import shap

# Tree-based explainer built directly on the fitted LightGBM estimator and
# evaluated on the scaled feature matrix.
explainer = shap.TreeExplainer(lgbm_model)
shap_values = explainer.shap_values(X_scaled)

classes_ = list(lgbm_model.classes_)
print("Model classes:", classes_)

# The result is either a list with one array per class or a single array;
# report shapes for whichever form was returned. (The recorded run returned a
# single array of shape (571, 105, 5).)
if not isinstance(shap_values, list):
    print("SHAP array shape:", shap_values.shape)
else:
    for i in range(len(shap_values)):
        print(f"Class {classes_[i]}: shap_values shape = {shap_values[i].shape}")


Model classes: ['G14', 'G2', 'G5', 'G7', 'GC']
SHAP array shape: (571, 105, 5)


Save global importance (mean |SHAP| aggregated across classes)

In [11]:
import numpy as np
import pandas as pd

# Normalize SHAP output to a 3D array: (n_classes, n_samples, n_features)
# The layout is identified by matching against the full (n_samples, n_features)
# shape of X_scaled instead of the first axis alone, so a mismatching array is
# rejected rather than silently assumed to be class-first.
expected_2d = tuple(X_scaled.shape)

if isinstance(shap_values, list):
    # list of (n_samples, n_features), one entry per class
    shap_3d = np.stack(shap_values, axis=0)
elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
    if shap_values.shape[:2] == expected_2d:
        # (n_samples, n_features, n_classes) -> (n_classes, n_samples, n_features)
        shap_3d = np.transpose(shap_values, (2, 0, 1))
    elif shap_values.shape[1:] == expected_2d:
        # already (n_classes, n_samples, n_features)
        shap_3d = shap_values
    else:
        raise RuntimeError(
            f"SHAP array shape {shap_values.shape} does not match "
            f"X_scaled shape {expected_2d} in any supported layout."
        )
else:
    raise RuntimeError(f"Unexpected SHAP output type/shape: {type(shap_values)} / {getattr(shap_values, 'shape', None)}")

# Final guard (also covers the list branch): the trailing axes must line up
# with the feature matrix, otherwise the importance table would be mislabeled.
if shap_3d.ndim != 3 or shap_3d.shape[1:] != expected_2d:
    raise RuntimeError(
        f"Normalized SHAP shape {shap_3d.shape} is inconsistent with "
        f"X_scaled shape {expected_2d}."
    )

n_classes, n_samples, n_features = shap_3d.shape
print("Normalized SHAP shape (n_classes, n_samples, n_features):", shap_3d.shape)

# Global importance: mean(|SHAP|) across samples and classes
abs_mean_global = np.mean(np.abs(shap_3d), axis=(0, 1))  # (n_features,)

# One row per feature, most important first.
imp_global = (
    pd.DataFrame({"feature": X_scaled.columns, "mean_abs_shap": abs_mean_global})
    .sort_values("mean_abs_shap", ascending=False)
    .reset_index(drop=True)
)

out_csv = SHAP_DIR / "shap_classification_feature_importance_global.csv"
imp_global.to_csv(out_csv, index=False)

print(f"Saved: {out_csv}")
display(imp_global.head(20))


Normalized SHAP shape (n_classes, n_samples, n_features): (5, 571, 105)
Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_feature_importance_global.csv


Unnamed: 0,feature,mean_abs_shap
0,shapeSurfaceVolumeRatio_delta,1.112775
1,shapeMaximum2DDiameterRow_delta,0.883955
2,glszmSizeZoneNonUniformity_delta,0.564584
3,ngtdmCoarseness_delta,0.435105
4,shapeMinorAxisLength_delta,0.281548
5,glszmZonePercentage_delta,0.263695
6,glrlmShortRunLowGrayLevelEmphasis_delta,0.263583
7,glrlmLongRunHighGrayLevelEmphasis_delta,0.258656
8,gldmLargeDependenceLowGrayLevelEmphasis_delta,0.253399
9,glrlmRunVariance_delta,0.246861


Global SHAP summary bar plot

In [12]:
import matplotlib.pyplot as plt
import numpy as np
import shap

# Signed aggregate across classes: mean over classes -> (n_samples, n_features).
# Still defined here because the global beeswarm cell reads `shap_agg`.
shap_agg = shap_3d.mean(axis=0)

# For the importance bar chart, take |SHAP| BEFORE averaging over classes.
# Averaging signed values first lets contributions of opposite sign in
# different classes cancel out, which understates importance and makes the
# figure disagree with the mean(|SHAP|) table saved as the global CSV.
shap_abs_agg = np.abs(shap_3d).mean(axis=0)  # (n_samples, n_features)

plt.figure()
shap.summary_plot(
    shap_abs_agg,
    X_scaled,
    plot_type="bar",
    show=False,
    max_display=30
)

out_png = SHAP_DIR / "shap_classification_summary_bar_global.png"
plt.tight_layout()
plt.savefig(out_png, dpi=300, bbox_inches="tight")
# Close the figure so it is not rendered inline and does not accumulate in memory.
plt.close()

print(f"Saved: {out_png}")


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_summary_bar_global.png


Global SHAP beeswarm plot

In [13]:
# Destination of the global beeswarm figure inside the SHAP output folder.
out_png = SHAP_DIR / "shap_classification_beeswarm_global.png"

# Draw onto a fresh figure; show=False keeps it open so it can be saved below.
plt.figure()
beeswarm_kwargs = {"show": False, "max_display": 30}
# `shap_agg` is the class-averaged (n_samples, n_features) SHAP matrix
# computed in the summary bar cell.
shap.summary_plot(shap_agg, X_scaled, **beeswarm_kwargs)

plt.tight_layout()
plt.savefig(out_png, bbox_inches="tight", dpi=300)
plt.close()

print(f"Saved: {out_png}")


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_global.png


Per-class SHAP beeswarm plots

In [14]:
# Class labels in the order stored on the fitted model (printed in an earlier cell).
classes_ = [label for label in lgbm_model.classes_]

# One beeswarm figure per class; position i along the first axis of `shap_3d`
# holds the (n_samples, n_features) SHAP matrix of class i.
for i, cls in enumerate(classes_):
    out_png = SHAP_DIR / f"shap_classification_beeswarm_{cls}.png"
    class_shap = shap_3d[i]

    plt.figure()
    shap.summary_plot(class_shap, X_scaled, max_display=25, show=False)

    plt.tight_layout()
    plt.savefig(out_png, bbox_inches="tight", dpi=300)
    # Close each figure inside the loop so open figures do not pile up.
    plt.close()
    print(f"Saved: {out_png}")


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_G14.png


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_G2.png


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_G5.png


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_G7.png


  shap.summary_plot(


Saved: c:\Users\modre\Documents\masseter\supplementary\shap\shap_classification_beeswarm_GC.png
