# Deep Brain Stimulation: Classification of STN-DBS ON and OFF states
For the classification of the STN-DBS ON/OFF states, each extracted feature map was vectorized by reshaping its 3D data into a 1D vector. A brain mask was then applied to remove the zero-valued voxels surrounding the brain, so that each element of a vector represents a specific voxel of the corresponding connectivity map. For each measure, nine classification algorithms were implemented (LR, KNN, NB, DT, RF, XGB, GB, SVC and LDA), using their default parameters in all cases.

#### Libraries

In [None]:
import matplotlib.pyplot as plt
import nibabel as nib
import nilearn.masking
import pandas as pd
import seaborn as sns
from aeon.visualisation import plot_critical_difference, plot_significance
from functions import *
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
from nilearn import plotting
from scipy.stats import friedmanchisquare

#### Classification performance: ROC AUC heatmaps (Figure 2A, 2B)

In [None]:
# Classifier abbreviations, in the order used throughout the score workbooks.
classifier_abbrevs = ["LR", "KNN", "NB", "DT", "RF", "XGB", "GB", "SVC", "LDA"]
# Label list handed to plot_auc_heatmap: the nine classifiers repeated ten times.
# NOTE(review): the factor 10 presumably equals the number of feature measures
# (sheets) per workbook -- confirm against functions.plot_auc_heatmap.
model_list = classifier_abbrevs * 10

# Heatmap for the train workbook first, then for the test workbook.
for file_path in (
    "../results/scores/classification_performance_train.xlsx",
    "../results/scores/classification_performance_test.xlsx",
):
    plot_auc_heatmap(file_path, model_list)

#### Classification performance: accuracy, F1 and recall (Figure 2C, 2D)

In [None]:
model_names = ["LR", "KNN", "NB", "DT", "RF", "XGB", "GB", "SVC", "LDA"]

# (workbook, metric columns) per split. The train columns carry a "_mean"
# suffix -- presumably averages over cross-validation folds -- while the test
# columns are presumably single held-out scores. NOTE(review): confirm.
split_specs = [
    (
        "../results/scores/classification_performance_train.xlsx",
        ["Accuracy_mean", "F1_mean", "Recall_mean"],
    ),
    (
        "../results/scores/classification_performance_test.xlsx",
        ["Accuracy", "F1", "Recall"],
    ),
]

# Train split first, then test split.
for file_path, metrics in split_specs:
    plot_metrics(file_path, metrics, model_names)

#### Critical difference diagrams (Figure 3A, 3B)

In [None]:
# Critical-difference and pairwise-significance plots of ROC AUC across the
# nine classifiers, with one "dataset" per feature measure (workbook sheet).
#
# Choose the split to analyse: "test" for the held-out figures, "train" for
# the training-set figures. (Previously both paths were assigned back to back,
# so the train assignment was silently dead.)
split = "test"
file_path = f"../results/scores/classification_performance_{split}.xlsx"
xls = pd.ExcelFile(file_path)
datasets = xls.sheet_names
classifiers = ["LR", "KNN", "NB", "DT", "RF", "XGB", "GB", "SVC", "LDA"]

# Collect one AUC vector (one value per classifier) per sheet.
all_results = []
for sheet in datasets:
    df = xls.parse(sheet)
    df.set_index(df.columns[0], inplace=True)
    # Take exactly ONE AUC column per sheet so each row stays aligned with
    # `datasets`. Appending both (or neither) column would shift the rows.
    # NOTE(review): "ROC_AUC" presumably marks test sheets and "AUC_mean" the
    # cross-validated train sheets -- confirm in the workbooks.
    if "ROC_AUC" in df.columns:
        auc_column = "ROC_AUC"
    elif "AUC_mean" in df.columns:
        auc_column = "AUC_mean"
    else:
        raise KeyError(f"sheet {sheet!r} has neither a 'ROC_AUC' nor an 'AUC_mean' column")
    all_results.append(df[auc_column].to_numpy())

# Rows = measures (sheets), columns = classifiers.
auc_df = pd.DataFrame(all_results, index=datasets, columns=classifiers)
auc = auc_df.to_numpy()

# Friedman test: each classifier is a treatment, each measure a block.
friedman_stat, p_value = friedmanchisquare(*auc_df.to_numpy().T)
print(f"Friedman test statistic: {friedman_stat}, p-value = {p_value}")

# Critical difference diagram. aeon draws this on its own figure, so no
# plt.figure() is created beforehand; an extra one would only show up empty.
plot_critical_difference(
    auc,
    classifiers,
    lower_better=False,
    correction="holm",
    alpha=0.05,
)
# plt.savefig(f"critical_difference_diagram_{split}.svg", format="svg", bbox_inches="tight")
plt.show()

# Pairwise significance plot. Pass the same test settings as the CD diagram so
# both panels agree on which differences are significant; otherwise aeon's
# defaults would apply here.
plot_significance(
    auc,
    classifiers,
    lower_better=False,
    correction="holm",
    alpha=0.05,
)
# plt.savefig(f"significance_test_{split}.svg", format="svg", bbox_inches="tight")
plt.show()

#### ROC AUC distributions: violin and box plots (Figure 3C, 3D)

In [None]:
# Distribution of ROC AUC per classifier across all feature measures, shown
# as a violin plot with an overlaid median/IQR box plot.
#
# Choose the split to plot: "test" for held-out scores, "train" for training
# scores. (Previously both paths were assigned back to back, leaving the train
# assignment dead.)
split = "test"
file_path = f"../results/scores/classification_performance_{split}.xlsx"
xls = pd.ExcelFile(file_path)
classifiers = ["LR", "KNN", "NB", "DT", "RF", "XGB", "GB", "SVC", "LDA"]
datasets = xls.sheet_names

# Collect one AUC vector (one value per classifier) per sheet.
all_results = []
for sheet in datasets:
    df = xls.parse(sheet)
    df.set_index(df.columns[0], inplace=True)
    # Take exactly ONE AUC column per sheet so rows stay aligned with
    # `datasets`. NOTE(review): "ROC_AUC" presumably marks test sheets and
    # "AUC_mean" the cross-validated train sheets -- confirm in the workbooks.
    if "ROC_AUC" in df.columns:
        auc_column = "ROC_AUC"
    elif "AUC_mean" in df.columns:
        auc_column = "AUC_mean"
    else:
        raise KeyError(f"sheet {sheet!r} has neither a 'ROC_AUC' nor an 'AUC_mean' column")
    all_results.append(df[auc_column].to_numpy())

# Rows = measures, columns = classifiers. Melt to long format (one row per
# classifier/measure pair) for seaborn. This replaces the former no-op double
# transpose.
auc_df = pd.DataFrame(all_results, index=datasets, columns=classifiers)
df_long = auc_df.melt(var_name="Classifier", value_name="AUC")

# Set figure size
plt.figure(figsize=(8, 3.5))

# Violin plot of the full AUC distribution per classifier
sns.violinplot(data=df_long, x="Classifier", y="AUC", inner=None, linewidth=1.5, color="lightgray")

# Overlay a narrow box plot for median (white line) and IQR (black box)
sns.boxplot(
    data=df_long,
    x="Classifier",
    y="AUC",
    width=0.12,
    showcaps=False,
    showfliers=False,
    boxprops={"facecolor": "black", "edgecolor": "None", "linewidth": 0},
    whiskerprops={"color": "black", "linewidth": 1},
    medianprops={"color": "white", "linewidth": 3},
)

plt.xticks(rotation=45, size=12)
plt.yticks(size=12)
plt.title("AUC Distribution Across Classifiers")
plt.grid(axis="y", linestyle="--", alpha=0.7)
# Fixed y-range; the upper limit above 1 leaves room for the violin tails.
plt.ylim(0.16, 1.13)
# Save and show plot
# plt.savefig(f"violin_plot_{split}.svg", format="svg", bbox_inches="tight", dpi=300)
plt.show()

#### Feature importance (Figure 4)

In [None]:
# Import feature map
nii_file = "../results/maps/ECM_norm_LDA_importance_map.nii.gz"

fig = plt.figure(figsize=(15, 5))
display = plotting.plot_glass_brain(None, figure=fig)
display.add_overlay(nii_file, alpha=0.8, cmap="RdYlBu_r", colorbar=True)
display.add_contours(
    nii_file,
    cmap="RdYlBu_r",
)

# Show and save plot
plt.savefig("ECM_norm_LDA.svg")
plt.show()


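#### Generate feature-importance maps for Figure 4
Run this cell before the Figure 4 cell above, which loads one of the NIfTI maps written here.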

In [None]:
# Path to feature importances
feature_importances_path = "../results/scores/classification_features_importances_test.xlsx"

# Loop over all measures
for measure in measures:
    print(f"Processing feature importance for measure: {measure}")
    df_feat = pd.read_excel(feature_importances_path, sheet_name=measure)

    # For each model in the sheet
    for model_name in df_feat.columns:
        importance_scores = df_feat[model_name].values

        # Unmask to 3D brain image
        tmp_img = nilearn.masking.unmask(importance_scores, mask_img)

        # Save NIfTI image
        nii_data = tmp_img.get_fdata()
        nii_affine = tmp_img.affine
        importance_nii = nib.Nifti1Image(nii_data, affine=nii_affine)
        nib.save(importance_nii, "../results/maps/" + f"{measure}_{model_name}_importance_map.nii.gz")

print("Feature importance maps saved successfully.")