In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
from tqdm import tqdm
from scipy.stats import ttest_ind

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, Flatten,
                                     Dense, LSTM, MultiHeadAttention, Concatenate, Reshape)
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.saving import register_keras_serializable

from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline

# =============================================================================
# --- 🚀 CONFIGURATION ---
# =============================================================================
# Datasets used for TRAINING. One or more names; their feature files are
# loaded and concatenated along the sample axis.
# Example: TRAIN_DATASETS = ["Italian", "mPower"]
TRAIN_DATASETS = ["UAMS", "Neurovoz"]

# Single dataset used for TESTING.
# - Set to a different dataset name (e.g., "mPower") for a true unseen test.
# - Set to None to use a validation split from the training data for the final evaluation.
TEST_DATASET = None

# Parameters that select the feature file: features_{MODE}_{FEATURE_MODE}.npz
MODE = "ALL_VALIDS"
FEATURE_MODE = "ALL"

# Unique name for this training run; models, histories and plots are stored under it.
RUN_ID = f"trained_on_{'_'.join(TRAIN_DATASETS)}"
RESULTS_PATH = os.path.join(os.getcwd(), "runs", RUN_ID)
PLOTS_PATH = os.path.join(RESULTS_PATH, "plots")  # Centralized folder for all plots
os.makedirs(RESULTS_PATH, exist_ok=True)
os.makedirs(PLOTS_PATH, exist_ok=True)

# Hyperparameters
EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 0.001
DROPOUT_RATE = 0.5
L2_STRENGTH = 0.01

# Reproducibility: seed Python's `random`, NumPy and TensorFlow in one call so
# that weight initialisation, shuffling and the np.random.choice sampling used
# by the explainability helpers are repeatable across "Restart & Run All".
SEED = 42
tf.keras.utils.set_random_seed(SEED)


KeyboardInterrupt: 

In [None]:

# =============================================================================
# --- Your Custom CNN Models ---
# =============================================================================
@register_keras_serializable()
class ParkinsonDetectorModel(Model):
    """End-to-end CNN / attention / LSTM binary classifier for the original 2-D features.

    Two Conv2D pairs (5x5 kernels, 64 filters), each followed by 5x5 max-pooling
    and dropout, produce a feature map that feeds three parallel branches:
    the flattened map itself, multi-head self-attention over the flattened
    spatial positions, and two stacked LSTMs over the same sequence. The branch
    outputs are concatenated into a 128-unit bottleneck followed by a single
    sigmoid unit.

    Args:
        input_shape: Two-element sequence giving the shape of one sample
            (no batch axis, no channel axis). Kept for ``get_config``.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, input_shape, **kwargs):
        super().__init__(**kwargs)
        self.input_shape_config = input_shape
        # Conv2D needs a channel axis, so a trailing axis of size 1 is appended.
        self.reshape_in = Reshape((input_shape[0], input_shape[1], 1))
        self.conv1a = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv1b = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.pool1 = MaxPooling2D(5)
        self.drop1 = Dropout(DROPOUT_RATE)
        self.conv2a = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv2b = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same', name="last_conv_layer")
        self.pool2 = MaxPooling2D(5)
        self.drop2 = Dropout(DROPOUT_RATE)
        self.flatten_cnn = Flatten()
        self.attention = MultiHeadAttention(num_heads=2, key_dim=64)
        self.flatten_att = Flatten()
        self.lstm1 = LSTM(128, return_sequences=True)
        self.lstm2 = LSTM(128, return_sequences=False)
        self.drop_lstm = Dropout(DROPOUT_RATE)
        self.concat = Concatenate()
        self.dense_bottleneck = Dense(128, activation='relu', name='bottleneck_features')
        self.dense_output = Dense(1, activation='sigmoid')

    def call(self, inputs, extract_features=False, grad_cam=False, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of samples shaped ``(batch, input_shape[0], input_shape[1])``.
            extract_features: If True (and ``grad_cam`` is False), return the
                128-unit bottleneck activations instead of the prediction.
            grad_cam: If True, return ``(prediction, last_conv_output)`` where
                ``last_conv_output`` is the output of ``conv2b``. Takes
                precedence over ``extract_features``.
            training: Keras training flag. It is forwarded to the dropout
                layers so that they are active while fitting and disabled at
                inference. Previously the layers were always called with
                ``training=False`` and therefore never dropped anything.

        Returns:
            The sigmoid output, the bottleneck features, or the
            ``(output, last_conv_output)`` pair, depending on the flags.
        """
        x = self.reshape_in(inputs)

        # Convolutional block 1
        x = self.conv1a(x)
        x = self.conv1b(x)
        x = self.pool1(x)
        x = self.drop1(x, training=training)

        # Convolutional block 2; the un-pooled conv2b output is kept for Grad-CAM.
        x = self.conv2a(x)
        last_conv_output = self.conv2b(x)
        x = self.pool2(last_conv_output)
        x = self.drop2(x, training=training)

        # Branch 1: flattened CNN feature map.
        cnn_flat = self.flatten_cnn(x)

        # Collapse the two spatial axes into one sequence axis: (batch, H*W, C).
        shape = tf.shape(x)
        sequence = tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])

        # Branch 2: self-attention over the sequence.
        att_out = self.attention(query=sequence, key=sequence, value=sequence)
        att_flat = self.flatten_att(att_out)

        # Branch 3: stacked LSTMs over the same sequence.
        lstm_seq = self.lstm1(sequence)
        lstm_out = self.lstm2(lstm_seq)
        lstm_out = self.drop_lstm(lstm_out, training=training)

        concatenated = self.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = self.dense_bottleneck(concatenated)
        final_output = self.dense_output(bottleneck)

        if grad_cam:
            return final_output, last_conv_output
        if extract_features:
            return bottleneck
        return final_output

    def get_config(self):
        """Return the base config extended with the constructor's ``input_shape``."""
        config = super().get_config()
        config.update({"input_shape": self.input_shape_config})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the model by passing the config entries to the constructor."""
        return cls(**config)

@register_keras_serializable()
class ParkinsonDetectorModelNCA(Model):
    """CNN / attention / LSTM binary classifier sized for the small NCA input.

    Same layer layout as the model used for the original features, but with
    3x3 convolution kernels and 2x2 max-pooling so that a small 2-D input is
    not pooled away. The pooled feature map feeds three parallel branches
    (flattened map, multi-head self-attention, stacked LSTMs) that are
    concatenated into a 128-unit bottleneck followed by a single sigmoid unit.

    Args:
        input_shape: Two-element sequence giving the shape of one sample
            (no batch axis, no channel axis). Kept for ``get_config``.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, input_shape, **kwargs):
        super().__init__(**kwargs)
        self.input_shape_config = input_shape
        # Conv2D needs a channel axis, so a trailing axis of size 1 is appended.
        self.reshape_in = Reshape((input_shape[0], input_shape[1], 1))
        self.conv1a = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv1b = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.pool1 = MaxPooling2D(2)
        self.drop1 = Dropout(DROPOUT_RATE)
        self.conv2a = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv2b = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same', name="last_conv_layer")
        self.pool2 = MaxPooling2D(2)
        self.drop2 = Dropout(DROPOUT_RATE)
        self.flatten_cnn = Flatten()
        self.attention = MultiHeadAttention(num_heads=2, key_dim=64)
        self.flatten_att = Flatten()
        self.lstm1 = LSTM(128, return_sequences=True)
        self.lstm2 = LSTM(128, return_sequences=False)
        self.drop_lstm = Dropout(DROPOUT_RATE)
        self.concat = Concatenate()
        self.dense_bottleneck = Dense(128, activation='relu', name='bottleneck_features')
        self.dense_output = Dense(1, activation='sigmoid')

    def call(self, inputs, extract_features=False, grad_cam=False, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of samples shaped ``(batch, input_shape[0], input_shape[1])``.
            extract_features: If True (and ``grad_cam`` is False), return the
                128-unit bottleneck activations instead of the prediction.
            grad_cam: If True, return ``(prediction, last_conv_output)`` where
                ``last_conv_output`` is the output of ``conv2b``. Takes
                precedence over ``extract_features``.
            training: Keras training flag. It is forwarded to the dropout
                layers so that they are active while fitting and disabled at
                inference. Previously the layers were always called with
                ``training=False`` and therefore never dropped anything.

        Returns:
            The sigmoid output, the bottleneck features, or the
            ``(output, last_conv_output)`` pair, depending on the flags.
        """
        x = self.reshape_in(inputs)

        # Convolutional block 1
        x = self.conv1a(x)
        x = self.conv1b(x)
        x = self.pool1(x)
        x = self.drop1(x, training=training)

        # Convolutional block 2; the un-pooled conv2b output is kept for Grad-CAM.
        x = self.conv2a(x)
        last_conv_output = self.conv2b(x)
        x = self.pool2(last_conv_output)
        x = self.drop2(x, training=training)

        # Branch 1: flattened CNN feature map.
        cnn_flat = self.flatten_cnn(x)

        # Collapse the two spatial axes into one sequence axis: (batch, H*W, C).
        shape = tf.shape(x)
        sequence = tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])

        # Branch 2: self-attention over the sequence.
        att_out = self.attention(query=sequence, key=sequence, value=sequence)
        att_flat = self.flatten_att(att_out)

        # Branch 3: stacked LSTMs over the same sequence.
        lstm_seq = self.lstm1(sequence)
        lstm_out = self.lstm2(lstm_seq)
        lstm_out = self.drop_lstm(lstm_out, training=training)

        concatenated = self.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = self.dense_bottleneck(concatenated)
        final_output = self.dense_output(bottleneck)

        if grad_cam:
            return final_output, last_conv_output
        if extract_features:
            return bottleneck
        return final_output

    def get_config(self):
        """Return the base config extended with the constructor's ``input_shape``."""
        config = super().get_config()
        config.update({"input_shape": self.input_shape_config})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the model by passing the config entries to the constructor."""
        return cls(**config)

# =============================================================================
# --- Data Loading & Helper Functions ---
# =============================================================================
def load_single_dataset(dataset_name, mode, feature_mode):
    """Load one dataset's feature archive from ``<cwd>/<dataset_name>/data``.

    The archive ``features_{mode}_{feature_mode}.npz`` must provide the arrays
    ``mel_spectrogram``, ``mfcc`` and ``labels``. The first two are joined
    along axis 1 to form the feature array.

    Returns:
        ``(X, labels)`` on success, or ``None`` (after printing a warning)
        when the archive does not exist.
    """
    archive_name = f"features_{mode}_{feature_mode}.npz"
    path = os.path.join(os.getcwd(), dataset_name, "data", archive_name)

    if os.path.exists(path):
        print(f"--- Loading data from {path} ---")
        with np.load(path) as archive:
            feature_parts = (archive['mel_spectrogram'], archive['mfcc'])
            X = np.concatenate(feature_parts, axis=1)
            labels = archive['labels']
        print(f"Loaded {dataset_name} successfully. Shape: {X.shape}")
        return X, labels

    print(f"WARNING: Data file not found at {path}. Skipping.")
    return None

def load_and_combine_data(dataset_names, mode, feature_mode):
    """Load every named dataset and stack the results along the sample axis.

    Datasets whose feature file cannot be loaded are skipped.

    Returns:
        ``(combined_X, combined_y)`` concatenated along axis 0.

    Raises:
        ValueError: If none of the datasets could be loaded.
    """
    attempts = [load_single_dataset(name, mode, feature_mode) for name in dataset_names]
    loaded_pairs = [pair for pair in attempts if pair]

    if not loaded_pairs:
        raise ValueError("No training data could be loaded. Aborting.")

    combined_X = np.concatenate([features for features, _ in loaded_pairs], axis=0)
    combined_y = np.concatenate([targets for _, targets in loaded_pairs], axis=0)
    print(f"\n--- All training data combined. Final shape: X={combined_X.shape}, y={combined_y.shape} ---")
    return combined_X, combined_y

# =============================================================================
# --- Plotting and Evaluation Functions ---
# =============================================================================
def plot_and_save_history(history, model_name, path):
    """Save a Keras training history as CSV and as a three-panel PNG.

    Args:
        history: Object exposing a ``history`` dict of per-epoch metric lists
            (as returned by ``Model.fit``).
        model_name: Prefix for the output file names and the panel titles.
        path: Existing directory that receives ``{model_name}_history.csv``
            and ``{model_name}_history.png``.

    Panels are drawn for loss, accuracy and AUC. A metric that was not
    recorded gets its panel switched off, and the validation curve is only
    drawn when the matching ``val_*`` column exists, so a run fitted without
    validation data no longer fails with ``KeyError``.
    """
    history_df = pd.DataFrame(history.history)
    history_df.to_csv(os.path.join(path, f"{model_name}_history.csv"))

    # NOTE: this changes the global matplotlib style for all subsequent figures.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    # (history column, short legend label, panel title suffix)
    panels = [
        ('loss', 'Loss', 'Model Loss'),
        ('accuracy', 'Acc', 'Model Accuracy'),
        ('auc', 'AUC', 'Model AUC'),
    ]
    for ax, (metric, short_label, title_suffix) in zip(axes, panels):
        if metric not in history_df.columns:
            ax.axis('off')  # nothing to show for this metric
            continue
        ax.plot(history_df[metric], label=f'Train {short_label}')
        val_metric = f'val_{metric}'
        if val_metric in history_df.columns:
            ax.plot(history_df[val_metric], label=f'Val {short_label}', linestyle='--')
        ax.set_title(f'{model_name} - {title_suffix}')
        ax.set_xlabel('Epoch')
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(path, f"{model_name}_history.png"), dpi=300)
    plt.close(fig)
    print(f"✅ Saved training history and plot for {model_name}.")

def plot_and_save_confusion_matrix(y_true, y_pred, model_name, path):
    """Plot count and row-normalised confusion matrices side by side and save them.

    Args:
        y_true: Ground-truth labels, encoded as 0 (Healthy) / 1 (Parkinson).
        y_pred: Predicted labels in the same encoding.
        model_name: Prefix for the figure titles and the output file name.
        path: Existing directory that receives ``{model_name}_confusion_matrix.png``.

    The matrix is always built for labels ``[0, 1]`` so that it stays 2x2 (and
    matches the two tick labels) even when one class is missing from the
    evaluated data. Rows without any true samples are shown as 0% instead of
    producing NaN from a division by zero.
    """
    class_names = ['Healthy', 'Parkinson']
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    # Row-wise normalisation, guarded against empty rows.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm, row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals > 0,
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
                xticklabels=class_names, yticklabels=class_names)
    axes[0].set_title(f'{model_name}\nConfusion Matrix (Counts)')
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')

    sns.heatmap(cm_percent, annot=True, fmt='.2%', cmap='Blues', ax=axes[1],
                xticklabels=class_names, yticklabels=class_names)
    axes[1].set_title(f'{model_name}\nConfusion Matrix (Percentages)')
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')

    fig.tight_layout()
    fig.savefig(os.path.join(path, f"{model_name}_confusion_matrix.png"), dpi=300)
    plt.close(fig)
    print(f"✅ Saved confusion matrix for {model_name}.")

def plot_and_save_roc_curve(y_true, y_pred_proba, model_name, path):
    """Draw the ROC curve with its AUC in the legend and save it as PNG.

    The figure is written to ``{path}/{model_name}_roc_curve.png``; the dashed
    diagonal marks the chance level.
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:0.3f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'{model_name} - ROC Curve')
    ax.legend(loc="lower right")
    ax.grid(True)

    target_file = os.path.join(path, f"{model_name}_roc_curve.png")
    fig.savefig(target_file, dpi=300)
    plt.close(fig)
    print(f"✅ Saved ROC curve for {model_name}.")

# =============================================================================
# --- Explainability Functions ---
# =============================================================================
def run_full_shap_analysis(model, X_train, X_test, y_test, output_path, num_samples=50,
                           top_k=20, background_size=50):
    """Compute SHAP values for a random test subset and save summary plots.

    Args:
        model: Subclassed Keras model exposing ``input_shape_config``.
        X_train: Training samples; the first ``background_size`` rows are the
            SHAP background set.
        X_test: Samples to explain (a random subset of at most ``num_samples``).
        y_test: Labels for ``X_test``, encoded 0 (healthy) / 1 (Parkinson's).
        output_path: Directory for the PNG files (created if missing).
        num_samples: Maximum number of test samples to explain.
        top_k: Number of highest-ranked positions in the global bar plot;
            clipped to the number of available positions.
        background_size: Number of leading ``X_train`` rows used as background.

    Writes ``shap_global_bar.png`` plus one mean-SHAP heatmap per class that is
    present in the sampled subset.
    """
    print("\n--- Running SHAP Analysis ---")
    os.makedirs(output_path, exist_ok=True)

    if len(X_test) == 0:
        print("-> No test samples available; skipping SHAP analysis.")
        return

    # Wrap the subclassed model in a functional model to make it compatible with SHAP.
    inputs = tf.keras.Input(shape=tuple(model.input_shape_config))
    outputs = model(inputs)
    functional_model = Model(inputs, outputs)

    # Flatten the labels so that column-vector labels still give 1-D class masks.
    y_test = np.asarray(y_test).ravel()
    idx = np.random.choice(len(X_test), min(num_samples, len(X_test)), replace=False)
    test_samples, y_true_samples = X_test[idx], y_test[idx]

    explainer = shap.GradientExplainer(functional_model, X_train[:background_size])
    shap_values = explainer.shap_values(test_samples)

    # The explainer may return a list with one entry per model output.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values)

    # --- Global importance: mean |SHAP| per input position ---
    mean_abs_shap = np.mean(np.abs(shap_values.reshape(shap_values.shape[0], -1)), axis=0)
    n_top = min(top_k, mean_abs_shap.size)  # never request more bars than positions
    top_idx = np.argsort(mean_abs_shap)[::-1][:n_top]
    grid_shape = (shap_values.shape[1], shap_values.shape[2])
    coords = [np.unravel_index(i, grid_shape) for i in top_idx]
    labels = [f"T{t} F{f}" for t, f in coords]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(n_top), mean_abs_shap[top_idx])
    ax.set_xticks(range(n_top))
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_title(f"Top-{n_top} Global SHAP Features")
    ax.set_ylabel("Mean |SHAP value|")
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, "shap_global_bar.png"), dpi=300)
    plt.close(fig)
    print("-> Saved SHAP global bar plot.")

    # --- Per-class mean SHAP heatmaps ---
    class_plots = [
        (y_true_samples == 0, "Average SHAP - Healthy", "shap_summary_healthy.png"),
        (y_true_samples == 1, "Average SHAP - Parkinson's", "shap_summary_parkinson.png"),
    ]
    for mask, title, file_name in class_plots:
        if not np.any(mask):
            continue
        class_mean = shap_values[mask].mean(axis=0).squeeze()
        fig, ax = plt.subplots()
        image = ax.imshow(class_mean, cmap="bwr", aspect="auto")
        fig.colorbar(image, ax=ax)
        ax.set_title(title)
        fig.savefig(os.path.join(output_path, file_name), dpi=300)
        plt.close(fig)
    print("-> Saved SHAP class heatmaps.")
    print("--- SHAP Analysis Complete ---")


def run_gradcam_analysis(model, X_test, y_test, output_path, num_samples=30):
    """Average Grad-CAM heatmaps over correctly classified samples of each class.

    Args:
        model: Keras model whose ``call`` accepts ``grad_cam=True`` and then
            returns ``(prediction, last_conv_output)``.
        X_test: Samples to analyse.
        y_test: Labels for ``X_test``, encoded 0 (healthy) / 1 (Parkinson's).
        output_path: Directory for the PNG file (created if missing).
        num_samples: Maximum number of true positives / true negatives to average.

    Writes ``gradcam_average_comparison.png`` with the mean heatmap of the
    sampled true positives next to that of the sampled true negatives.
    """
    print("\n--- Running Grad-CAM Analysis ---")
    os.makedirs(output_path, exist_ok=True)

    # Flatten the labels so that column-vector labels do not broadcast against
    # the 1-D predictions into an N x N mask.
    y_test = np.asarray(y_test).ravel()
    y_pred = (model.predict(X_test) > 0.5).astype(int).flatten()
    tp_idx = np.where((y_test == 1) & (y_pred == 1))[0]
    tn_idx = np.where((y_test == 0) & (y_pred == 0))[0]

    def get_avg_heatmap(indices):
        """Return the mean normalised Grad-CAM map over ``indices``."""
        heatmaps = []
        for i in tqdm(indices, desc="Grad-CAM Progress", leave=False):
            img_array = X_test[i:i+1]  # keep the batch axis
            with tf.GradientTape() as tape:
                final_preds, last_conv_output = model(img_array, grad_cam=True)
                # NOTE(review): the watch happens after the forward pass; the
                # gradient below presumably relies on the tape having recorded
                # the conv output through the layer variables — confirm.
                tape.watch(last_conv_output)
                loss = final_preds[0]
            grads = tape.gradient(loss, last_conv_output)
            # Channel weights: gradient averaged over batch and spatial axes.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            heatmap = tf.squeeze(last_conv_output[0] @ pooled_grads[..., tf.newaxis])
            # Rectify first and normalise by the max of the rectified map, so
            # a map without positive evidence becomes all zeros rather than
            # -0.0 / NaN.
            heatmap = tf.maximum(heatmap, 0)
            heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-10)
            heatmaps.append(heatmap.numpy())
        # Fallback for an empty selection: a blank map sized like one input sample.
        return np.mean(heatmaps, axis=0) if heatmaps else np.zeros(X_test.shape[1:3])

    tp_sample = np.random.choice(tp_idx, min(num_samples, len(tp_idx)), replace=False)
    tn_sample = np.random.choice(tn_idx, min(num_samples, len(tn_idx)), replace=False)
    avg_tp_heatmap = get_avg_heatmap(tp_sample)
    avg_tn_heatmap = get_avg_heatmap(tn_sample)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    im1 = axes[0].imshow(avg_tp_heatmap, cmap='jet', aspect='auto')
    axes[0].set_title("Avg Grad-CAM for Parkinson's (TP)")
    fig.colorbar(im1, ax=axes[0])
    im2 = axes[1].imshow(avg_tn_heatmap, cmap='jet', aspect='auto')
    axes[1].set_title("Avg Grad-CAM for Healthy (TN)")
    fig.colorbar(im2, ax=axes[1])
    fig.suptitle("Average Model Attention by Class")
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, "gradcam_average_comparison.png"), dpi=300)
    plt.close(fig)
    print("-> Saved average Grad-CAM comparison.")
    print("--- Grad-CAM Analysis Complete ---")


In [None]:

# =============================================================================
# --- Main Execution ---
# =============================================================================
if __name__ == '__main__':
    # =========================================================================
    # --- PHASE 1: TRAINING ---
    # =========================================================================
    print("="*80)
    print(f"🚀 STARTING TRAINING PHASE | RUN_ID: {RUN_ID}")
    print("="*80)

    # Load every training dataset and stack them into one pool.
    X_train_full, y_train_full = load_and_combine_data(TRAIN_DATASETS, MODE, FEATURE_MODE)

    # Stratified 80/20 split of the pool into training and validation parts.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full,
        y_train_full,
        test_size=0.2,
        random_state=42,
        stratify=y_train_full,
    )

    # Flattened (samples, d1*d2) views for the scikit-learn based models.
    n_samples, d1, d2 = X_train.shape
    X_train_2d = X_train.reshape((n_samples, d1*d2))
    X_val_2d = X_val.reshape((X_val.shape[0], d1*d2))

    print(f"\nTraining data prepared. Train: {X_train.shape}, Validation: {X_val.shape}")
    print("\nClass distribution in main training set:")
    print(pd.Series(y_train).value_counts())
    


In [None]:
    # --- Train Exp 1: k-NN ---
    print("\n" + "-"*80)
    print("TRAINING MODEL 1: k-NN (Baseline)")
    print("-"*80)

    # Flattened view of the complete training pool (also used by later cells).
    X_train_full_2d = X_train_full.reshape((X_train_full.shape[0], d1*d2))

    # With TEST_DATASET=None the evaluation phase scores on (X_val, y_val), which
    # is a slice of X_train_full. Fitting on the full pool would then evaluate a
    # nearest-neighbour model on rows it has memorised, so only the training
    # split is used in that case. With an external test set the full pool is kept.
    if TEST_DATASET is None:
        X_knn_fit, y_knn_fit = X_train_2d, y_train
    else:
        X_knn_fit, y_knn_fit = X_train_full_2d, y_train_full

    pipeline_knn = Pipeline([
        ('scaler', StandardScaler()),
        ('smote', SMOTE(random_state=42)),
        ('nca', NeighborhoodComponentsAnalysis(random_state=42, max_iter=200)),
        ('classifier', KNeighborsClassifier()),
    ])
    param_dist_knn = {
        'nca__n_components': [10, 20, 30, 40],
        'classifier__n_neighbors': [3, 5, 7],
        'classifier__weights': ['distance'],
        'classifier__metric': ['manhattan'],
    }
    search_knn = RandomizedSearchCV(pipeline_knn, param_dist_knn, n_iter=10, cv=3, scoring='accuracy',
                                    n_jobs=1, random_state=42, verbose=1)
    search_knn.fit(X_knn_fit, y_knn_fit)
    joblib.dump(search_knn.best_estimator_, os.path.join(RESULTS_PATH, "model_1_knn.joblib"))
    print("✅ Best k-NN pipeline trained and saved.")


In [None]:

    # --- Train Exp 2: CNN ---
    print("\n" + "-"*80)
    print("TRAINING MODEL 2: CNN (End-to-End)")
    print("-"*80)

    MODEL_CNN_PATH = os.path.join(RESULTS_PATH, "model_2_cnn.keras")

    model_cnn = ParkinsonDetectorModel(input_shape=(d1, d2))
    model_cnn.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
    )

    # Only the epoch with the highest validation AUC is kept on disk.
    history_cnn = model_cnn.fit(
        X_train,
        y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[
            ModelCheckpoint(MODEL_CNN_PATH, save_best_only=True, monitor='val_auc', mode='max', verbose=1),
        ],
    )
    plot_and_save_history(history_cnn, "model_2_cnn", RESULTS_PATH)
    


In [None]:
    # --- Train Exp 3: CNN + k-NN ---
    print("\n" + "-"*80)
    print("TRAINING MODEL 3: CNN Feature Extractor + k-NN")
    print("-"*80)

    # Reload the best CNN checkpoint and use its bottleneck layer as the feature extractor.
    best_model_cnn_extractor = tf.keras.models.load_model(MODEL_CNN_PATH)

    # With TEST_DATASET=None the evaluation phase scores on (X_val, y_val), which
    # is a slice of X_train_full. Fitting the k-NN stage on features of the full
    # pool would leak those rows, so only the training split is used in that case.
    if TEST_DATASET is None:
        X_feature_source, y_feature_source = X_train, y_train
    else:
        X_feature_source, y_feature_source = X_train_full, y_train_full

    X_train_features = best_model_cnn_extractor(X_feature_source, extract_features=True).numpy()
    search_cnn_knn = RandomizedSearchCV(pipeline_knn, param_dist_knn, n_iter=10, cv=3, scoring='accuracy',
                                        n_jobs=1, random_state=42, verbose=1)
    search_cnn_knn.fit(X_train_features, y_feature_source)
    joblib.dump(search_cnn_knn.best_estimator_, os.path.join(RESULTS_PATH, "model_3_cnn_knn.joblib"))
    print("✅ Best CNN+k-NN pipeline trained and saved.")

In [None]:
    # --- Train Exp 4: NCA + CNN ---
    print("\n" + "-"*80)
    print("TRAINING MODEL 4: NCA pre-processing + CNN")
    print("-"*80)

    # The NCA output vector is reshaped into a square "image" for the CNN.
    NCA_COMPONENTS = 64
    nca_img_dim = int(np.sqrt(NCA_COMPONENTS))
    if nca_img_dim * nca_img_dim != NCA_COMPONENTS:
        raise ValueError(f"NCA_COMPONENTS must be a perfect square, got {NCA_COMPONENTS}.")

    nca_preprocessor = Pipeline([
        ('scaler', StandardScaler()),
        ('nca', NeighborhoodComponentsAnalysis(n_components=NCA_COMPONENTS, random_state=42, max_iter=200)),
    ])

    # Fit the preprocessor and the CNN on the training split only. X_val is the
    # validation set that drives checkpoint selection (and, with
    # TEST_DATASET=None, the final evaluation); it is a slice of X_train_full,
    # so fitting on the full pool would make `val_auc` a training-set metric.
    X_train_nca = nca_preprocessor.fit_transform(X_train_2d, y_train)
    joblib.dump(nca_preprocessor, os.path.join(RESULTS_PATH, "model_4_nca_preprocessor.joblib"))
    X_train_nca_3d = X_train_nca.reshape(-1, nca_img_dim, nca_img_dim)
    X_val_nca = nca_preprocessor.transform(X_val_2d)
    X_val_nca_3d = X_val_nca.reshape(-1, nca_img_dim, nca_img_dim)

    MODEL_NCA_CNN_PATH = os.path.join(RESULTS_PATH, "model_4_nca_cnn.keras")
    model_nca_cnn = ParkinsonDetectorModelNCA(input_shape=(nca_img_dim, nca_img_dim))
    model_nca_cnn.compile(optimizer=Adam(learning_rate=LEARNING_RATE), loss='binary_crossentropy',
                          metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])
    history_nca_cnn = model_nca_cnn.fit(
        X_train_nca_3d, y_train,
        validation_data=(X_val_nca_3d, y_val),
        epochs=EPOCHS, batch_size=BATCH_SIZE,
        callbacks=[ModelCheckpoint(MODEL_NCA_CNN_PATH, save_best_only=True, monitor='val_auc', mode='max', verbose=1)],
    )
    plot_and_save_history(history_nca_cnn, "model_4_nca_cnn", RESULTS_PATH)
    


In [None]:
    # =========================================================================
    # --- PHASE 2: TESTING AND EXPLAINABILITY ---
    # =========================================================================
    print("\n\n" + "="*80)
    if TEST_DATASET is None:
        print(f"🔬 TESTING & EXPLAINING ON VALIDATION SPLIT FROM: {', '.join(TRAIN_DATASETS)}")
        X_test, y_test = X_val, y_val
    else:
        print(f"🔬 TESTING & EXPLAINING ON UNSEEN DATASET: {TEST_DATASET}")
        test_data = load_single_dataset(TEST_DATASET, MODE, FEATURE_MODE)
        if test_data is None:
            # Raise instead of calling exit(): inside a notebook kernel `exit`
            # is IPython's kernel-shutdown helper, not a way to stop the cell.
            raise SystemExit("No test data found. Aborting.")
        X_test, y_test = test_data
    print("="*80)

    test_results = {}

    def _evaluate_and_plot(display_name, file_tag, y_pred, y_pred_proba):
        """Record accuracy under `display_name`, print the report and save the plots."""
        accuracy = accuracy_score(y_test, y_pred)
        test_results[display_name] = accuracy
        print(f"Accuracy on Test Set: {accuracy:.4f}\n{classification_report(y_test, y_pred, digits=4)}")
        plot_and_save_confusion_matrix(y_test, y_pred, file_tag, PLOTS_PATH)
        plot_and_save_roc_curve(y_test, y_pred_proba, file_tag, PLOTS_PATH)
        return accuracy

    # Flattened view of the test samples for the scikit-learn pipelines.
    X_test_2d = X_test.reshape((X_test.shape[0], -1))

    # --- Test Model 1: k-NN ---
    print("\n" + "-"*80); print("EVALUATING MODEL 1: k-NN (Baseline)"); print("-"*80)
    model_1 = joblib.load(os.path.join(RESULTS_PATH, "model_1_knn.joblib"))
    y_pred_1 = model_1.predict(X_test_2d)
    y_pred_proba_1 = model_1.predict_proba(X_test_2d)[:, 1]
    acc1 = _evaluate_and_plot("1: k-NN (Baseline)", "model_1_knn", y_pred_1, y_pred_proba_1)

    # --- Test Model 2: CNN ---
    print("\n" + "-"*80); print("EVALUATING MODEL 2: CNN (End-to-End)"); print("-"*80)
    model_2 = tf.keras.models.load_model(os.path.join(RESULTS_PATH, "model_2_cnn.keras"))
    # Flatten the (N, 1) sigmoid output so labels and scores are 1-D like y_test.
    y_pred_proba_2 = model_2.predict(X_test).ravel()
    y_pred_2 = (y_pred_proba_2 > 0.5).astype("int32")
    acc2 = _evaluate_and_plot("2: CNN (End-to-End)", "model_2_cnn", y_pred_2, y_pred_proba_2)

    # --- Test Model 3: CNN + k-NN ---
    print("\n" + "-"*80); print("EVALUATING MODEL 3: CNN + k-NN"); print("-"*80)
    # Model 3 uses the Model 2 checkpoint as extractor; reuse the loaded instance.
    model_3_extractor = model_2
    model_3_knn = joblib.load(os.path.join(RESULTS_PATH, "model_3_cnn_knn.joblib"))
    X_test_features = model_3_extractor(X_test, extract_features=True).numpy()
    y_pred_3 = model_3_knn.predict(X_test_features)
    y_pred_proba_3 = model_3_knn.predict_proba(X_test_features)[:, 1]
    acc3 = _evaluate_and_plot("3: CNN + k-NN", "model_3_cnn_knn", y_pred_3, y_pred_proba_3)

    # --- Test Model 4: NCA + CNN ---
    print("\n" + "-"*80); print("EVALUATING MODEL 4: NCA + CNN"); print("-"*80)
    model_4_preprocessor = joblib.load(os.path.join(RESULTS_PATH, "model_4_nca_preprocessor.joblib"))
    model_4_cnn = tf.keras.models.load_model(os.path.join(RESULTS_PATH, "model_4_nca_cnn.keras"))
    X_test_2d_4 = X_test_2d
    X_test_nca = model_4_preprocessor.transform(X_test_2d_4)
    X_test_nca_3d = X_test_nca.reshape(-1, nca_img_dim, nca_img_dim)
    y_pred_proba_4 = model_4_cnn.predict(X_test_nca_3d).ravel()
    y_pred_4 = (y_pred_proba_4 > 0.5).astype("int32")
    acc4 = _evaluate_and_plot("4: NCA + CNN", "model_4_nca_cnn", y_pred_4, y_pred_proba_4)
    


In [None]:
    # --- Explainability for CNN-based models ---
    print("\n\n" + "="*80)
    print("🔬 RUNNING EXPLAINABILITY ANALYSIS")
    print("="*80)
    explain_path_cnn = os.path.join(PLOTS_PATH, "explainability_cnn_models")

    # Model 2 is also the feature extractor of Model 3, so one analysis covers both.
    print("\n--- Explaining Model 2 (CNN) and Model 3 (CNN+k-NN) Feature Extractor ---")
    run_full_shap_analysis(model_2, X_train, X_test, y_test, os.path.join(explain_path_cnn, "model_2_cnn"))
    run_gradcam_analysis(model_2, X_test, y_test, os.path.join(explain_path_cnn, "model_2_cnn"))

    # Model 4 is explained in the NCA-transformed input space.
    print("\n--- Explaining Model 4 (NCA + CNN) ---")
    X_train_full_nca_3d = model_4_preprocessor.transform(X_train_full_2d).reshape(-1, nca_img_dim, nca_img_dim)
    run_full_shap_analysis(
        model_4_cnn, X_train_full_nca_3d, X_test_nca_3d, y_test,
        os.path.join(explain_path_cnn, "model_4_nca_cnn"),
    )
    run_gradcam_analysis(
        model_4_cnn, X_test_nca_3d, y_test,
        os.path.join(explain_path_cnn, "model_4_nca_cnn"),
    )

    # --- Final Summary ---
    print("\n\n" + "="*80)
    print("🏆 FINAL SUMMARY - PERFORMANCE ON TEST SET")
    print("="*80)
    best_model_name = max(test_results, key=test_results.get)
    best_accuracy = test_results[best_model_name]
    print(f"{'Model':<40} | {'Test Accuracy':<20}")
    print("-" * 65)
    for name, acc in test_results.items():
        print(f"{name:<40} | {acc:<20.4f}")
    print("-" * 65)
    print(f"\n🚀 Best Performing Model: '{best_model_name}' with an accuracy of {best_accuracy:.4f}")
    print("="*80)
