In [4]:
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, Flatten,
                                     Dense, LSTM, MultiHeadAttention, Concatenate, Reshape)
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from tqdm import tqdm
from scipy.stats import ttest_ind
import pandas as pd
import shap
from tensorflow.keras.saving import register_keras_serializable
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patches as mpatches
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Reshape
import numpy as np
import matplotlib.pyplot as plt
import os
from skimage.transform import resize
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm

# =============================================================================
# --- Configuration ---
# =============================================================================

# --- 1. Core Paths ---
ITALIAN_DATASET = "ITALIAN_DATASET"
UAMS_DATASET = "UAMS_DATASET"
NEUROVOZ_DATASET = "NEUROVOZ_DATASET"
MPOWER_DATASET = "MPOWER_DATASET"
SYNTHETIC_DATASET = "SYNTHETIC_DATASET"

MODE_ALL_VALIDS = "ALL_VALIDS"
MODE_A = "A"

FEATURE_MODE_BASIC = "BASIC"        # mel_spectrogram, mfcc, spectrogram
FEATURE_MODE_ALL = "ALL"            # basic + fsc
FEATURE_MODE_DEFAULT = "DEFAULT"

MODEL_NAME = "nca_smote_parallel"

# /////////// SELCET HERE \\\\\\\\\\\
# ----------------------------------
DATASET = UAMS_DATASET
MODE = MODE_A
FEATURE_MODE = FEATURE_MODE_DEFAULT
# ----------------------------------

dataset = ""
if DATASET == NEUROVOZ_DATASET:
    dataset = "Neurovoz"
elif DATASET == UAMS_DATASET:
    dataset = "UAMS"
elif DATASET == MPOWER_DATASET:
    dataset = "mPower"
elif DATASET == SYNTHETIC_DATASET:
    dataset = "Synthetic"
elif DATASET == ITALIAN_DATASET:
    dataset = "Italian"

# Path Setup
FEATURES_FILE_PATH = os.path.join(os.getcwd(), dataset, "data", f"features_{MODE}_{FEATURE_MODE}.npz")

MODEL_PATH = os.path.join(os.getcwd(), dataset, f"results_{MODE}_{FEATURE_MODE}", MODEL_NAME)
os.makedirs(MODEL_PATH, exist_ok=True)
EVALUATION_FILE_PATH = os.path.join(MODEL_PATH, "evaluation.csv")
HISTORY_SAVE_PATH = os.path.join(MODEL_PATH, "history.csv")
EVALUATION_FILE_PATH = os.path.join(MODEL_PATH, "evaluation.csv")
BEST_MODEL_PATH = os.path.join(MODEL_PATH, "best_model.keras")

SHAP_OUTPUT_PATH = os.path.join(MODEL_PATH, "shap_analysis")
GRADCAM_OUTPUT_PATH = os.path.join(MODEL_PATH, "gradcam_analysis")
os.makedirs(SHAP_OUTPUT_PATH, exist_ok=True)
os.makedirs(GRADCAM_OUTPUT_PATH, exist_ok=True)

# Hyperparameters
EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 0.001
DROPOUT_RATE = 0.5
L2_STRENGTH = 0.01
NCA_COMPONENTS = 128  # Number of components for NCA
USE_SMOTE = True      # Whether to apply SMOTE for balancing

# Model Checkpoint Callback
checkpoint_cb = ModelCheckpoint(BEST_MODEL_PATH, monitor='val_auc', mode='max', save_best_only=True, verbose=1)

def load_data(feature_file_path):
    mel_spectrograms = None
    mfccs = None
    spectrograms = None
    fcs = None

    print(f"--- Loading data from {feature_file_path} ---")
    with np.load(feature_file_path) as data:
        labels = data['labels']

        if 'mel_spectrogram' in data.keys():
            mel_spectrograms = data['mel_spectrogram']
        if 'mfcc' in data:
            mfccs = data['mfcc']
        if 'spectrogram' in data.keys():
            spectrograms = data['spectrogram']

        feature_arrays = []

        if mel_spectrograms is not None:
            feature_arrays.append(mel_spectrograms)
            print(f"Added mel_spectrogram: {mel_spectrograms.shape}")

        if mfccs is not None:
            feature_arrays.append(mfccs)
            print(f"Added mfcc: {mfccs.shape}")

        if spectrograms is not None:
            feature_arrays.append(spectrograms)
            print(f"Added spectrogram: {spectrograms.shape}")

        if feature_arrays:
            X = np.concatenate(feature_arrays, axis=-1)  # concat along last axis
            print(f"Final concatenated shape: {X.shape}")
        else:
            raise ValueError("No valid features found in the file!")
        return X, labels

def apply_nca_smote_preprocessing(X_train, y_train, X_test, n_components=None, use_smote=True, min_dim=16):
    """
    Apply NCA for dimensionality reduction followed by SMOTE for balancing.
    Ensures minimum dimensions for CNN compatibility.

    Args:
        X_train: Training features (samples, height, width)
        y_train: Training labels
        X_test: Test features (samples, height, width) - ONLY TRANSFORMED, NEVER FITTED
        n_components: Number of NCA components
        use_smote: Whether to apply SMOTE after NCA
        min_dim: Minimum dimension for each axis (height, width)

    Returns:
        X_train_processed: Processed training features (samples, height, width, 1)
        y_train_processed: Processed training labels (may be augmented if SMOTE used)
        X_test_processed: Processed test features (samples, height, width, 1)
        nca: Fitted NCA object
        scaler: Fitted StandardScaler object
    """
    print("\n--- Applying NCA + SMOTE Preprocessing ---")

    # Get original shapes
    original_train_shape = X_train.shape
    original_test_shape = X_test.shape
    print(f"Original train shape: {original_train_shape}")
    print(f"Original test shape: {original_test_shape}")

    # Flatten the features for NCA (samples, height*width*channels -> samples, features)
    X_train_flat = X_train.reshape(X_train.shape[0], -1)
    X_test_flat = X_test.reshape(X_test.shape[0], -1)

    print(f"Flattened train shape: {X_train_flat.shape}")
    print(f"Flattened test shape: {X_test_flat.shape}")

    # Step 1: Standardize features (fit ONLY on training data)
    print("Standardizing features...")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_flat)  # FIT on training
    X_test_scaled = scaler.transform(X_test_flat)        # ONLY TRANSFORM test

    # Step 2: Apply NCA (fit ONLY on training data)
    if n_components is None:
        n_components = min(len(X_train) - 1, X_train_flat.shape[1])
    else:
        n_components = min(n_components, len(X_train) - 1, X_train_flat.shape[1])

    print(f"Applying NCA with {n_components} components...")
    nca = NeighborhoodComponentsAnalysis(n_components=n_components, random_state=42, max_iter=200)

    X_train_nca = nca.fit_transform(X_train_scaled, y_train)  # FIT on training
    X_test_nca = nca.transform(X_test_scaled)                 # ONLY TRANSFORM test

    print(f"NCA transformed train shape: {X_train_nca.shape}")
    print(f"NCA transformed test shape: {X_test_nca.shape}")

    # Step 3: Apply SMOTE for balancing (ONLY on training data)
    if use_smote:
        print("Applying SMOTE for class balancing...")
        print(f"Before SMOTE - Class distribution: {np.bincount(y_train)}")

        smote = SMOTE(random_state=42)
        X_train_balanced, y_train_balanced = smote.fit_resample(X_train_nca, y_train)

        print(f"After SMOTE - Class distribution: {np.bincount(y_train_balanced)}")
        print(f"SMOTE balanced train shape: {X_train_balanced.shape}")
    else:
        X_train_balanced = X_train_nca
        y_train_balanced = y_train

    # Step 4: Calculate optimal reshape dimensions for CNN input
    n_features = X_train_balanced.shape[1]

    # Find factors to create CNN-compatible dimensions
    height = max(min_dim, int(np.sqrt(n_features)))
    width = max(min_dim, int(np.ceil(n_features / height)))

    # Ensure minimum dimensions for CNN compatibility
    if height < min_dim:
        height = min_dim
    if width < min_dim:
        width = min_dim

    # Ensure it's CNN-friendly (pad if necessary)
    target_size = height * width
    pad_size = max(0, target_size - n_features)

    if pad_size > 0:
        print(f"Padding with {pad_size} zeros for CNN compatibility...")
        X_train_balanced = np.pad(X_train_balanced, ((0, 0), (0, pad_size)), mode='constant', constant_values=0)
        X_test_nca = np.pad(X_test_nca, ((0, 0), (0, pad_size)), mode='constant', constant_values=0)
    elif pad_size < 0:
        # Truncate if we have too many features
        print(f"Truncating {-pad_size} features for CNN compatibility...")
        X_train_balanced = X_train_balanced[:, :target_size]
        X_test_nca = X_test_nca[:, :target_size]

    # Reshape to 3D for CNN (samples, height, width, channels=1)
    X_train_final = X_train_balanced.reshape(X_train_balanced.shape[0], height, width, 1)
    X_test_final = X_test_nca.reshape(X_test_nca.shape[0], height, width, 1)

    print(f"Final CNN-ready train shape: {X_train_final.shape}")
    print(f"Final CNN-ready test shape: {X_test_final.shape}")
    print("--- NCA + SMOTE Preprocessing Complete ---\n")

    return X_train_final, y_train_balanced, X_test_final, nca, scaler

# =============================================================================
# --- Model Architecture ---
# =============================================================================

@register_keras_serializable()
class ParkinsonDetectorModel(Model):
    def __init__(self, input_shape, **kwargs):
        super(ParkinsonDetectorModel, self).__init__(**kwargs)
        self.input_shape_config = input_shape

        # Calculate adaptive pooling sizes based on input dimensions
        height, width = input_shape[0], input_shape[1]

        # Use smaller pool sizes for small inputs
        if height <= 16 or width <= 16:
            pool_size1 = 2
            pool_size2 = 2
        elif height <= 32 or width <= 32:
            pool_size1 = 3
            pool_size2 = 2
        else:
            pool_size1 = 5
            pool_size2 = 3

        print(f"Using pool sizes: {pool_size1}, {pool_size2} for input shape {input_shape}")

        # CNN layers with adaptive pooling
        self.conv1a = Conv2D(32, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv1b = Conv2D(32, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.pool1 = MaxPooling2D(pool_size1, padding='same')
        self.drop1 = Dropout(DROPOUT_RATE)

        self.conv2a = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv2b = Conv2D(64, 3, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same', name='last_conv_layer')
        self.pool2 = MaxPooling2D(pool_size2, padding='same', name='cnn_output')
        self.drop2 = Dropout(DROPOUT_RATE)

        # Other layers
        self.flatten_cnn = Flatten()
        self.attention = MultiHeadAttention(num_heads=2, key_dim=32, name='attention_output')
        self.flatten_att = Flatten()
        self.lstm1 = LSTM(64, return_sequences=True)
        self.lstm2 = LSTM(64, return_sequences=False, name='lstm_output')
        self.drop_lstm = Dropout(DROPOUT_RATE)
        self.concat = Concatenate()
        self.dense_bottleneck = Dense(64, activation='relu', name='bottleneck_features')
        self.dense_output = Dense(1, activation='sigmoid')

    def build(self, input_shape):
        """Build the model layers"""
        super(ParkinsonDetectorModel, self).build(input_shape)

    def call(self, inputs, training=False):
        # inputs is already (batch, height, width, channels)
        x = self.conv1a(inputs)
        x = self.conv1b(x)
        x = self.pool1(x)
        x = self.drop1(x, training=training)

        x = self.conv2a(x)
        x = self.conv2b(x)
        cnn_branch_output = self.pool2(x)
        x = self.drop2(cnn_branch_output, training=training)

        cnn_flat = self.flatten_cnn(x)

        # Prepare sequence for attention and LSTM
        shape = tf.shape(cnn_branch_output)
        sequence = tf.reshape(cnn_branch_output, [-1, shape[1] * shape[2], shape[3]])

        # Attention branch
        att_branch_output = self.attention(query=sequence, key=sequence, value=sequence)
        att_flat = self.flatten_att(att_branch_output)

        # LSTM branch
        lstm_seq = self.lstm1(sequence)
        lstm_branch_output = self.lstm2(lstm_seq)
        lstm_out = self.drop_lstm(lstm_branch_output, training=training)

        # Combine all branches
        concatenated = self.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = self.dense_bottleneck(concatenated)
        final_output = self.dense_output(bottleneck)

        return final_output

    def get_config(self):
        config = super(ParkinsonDetectorModel, self).get_config()
        config.update({"input_shape": self.input_shape_config})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

def build_model(input_shape: tuple) -> Model:
    """Builds the hybrid model by wrapping the custom class in a Functional API model."""
    print("--- Building the model ---")
    print(f"Input shape: {input_shape}")

    inputs = Input(shape=input_shape)
    parkinson_detector = ParkinsonDetectorModel(input_shape=input_shape)
    outputs = parkinson_detector(inputs)
    model = Model(inputs=inputs, outputs=outputs)

    print("Model built successfully.")
    return model

# =============================================================================
# --- Model Performance ---
# =============================================================================
def save_metrics_to_csv(y_true, y_pred_proba, filename="classification_report.csv", threshold=0.5):
    y_true = np.array(y_true)
    y_pred_binary = (np.array(y_pred_proba) > threshold).astype(int)

    cm = confusion_matrix(y_true, y_pred_binary)
    tn, fp, fn, tp = cm.ravel()

    total_samples = cm.sum()
    if total_samples == 0:
        tn_percent, fp_percent, fn_percent, tp_percent = 0, 0, 0, 0
    else:
        tn_percent = (tn / total_samples) * 100
        fp_percent = (fp / total_samples) * 100
        fn_percent = (fn / total_samples) * 100
        tp_percent = (tp / total_samples) * 100

    precision = precision_score(y_true, y_pred_binary, zero_division=0)
    recall = recall_score(y_true, y_pred_binary, zero_division=0)
    f1 = f1_score(y_true, y_pred_binary, zero_division=0)
    sensitivity = recall

    report_data = {
        'Metric': [
            'True Positive (TP)',
            'True Negative (TN)',
            'False Positive (FP)',
            'False Negative (FN)',
            'Precision',
            'Recall (Sensitivity)',
            'F1-Score'
        ],
        'Value': [
            f"{tp} ({tp_percent:.2f}%)",
            f"{tn} ({tn_percent:.2f}%)",
            f"{fp} ({fp_percent:.2f}%)",
            f"{fn} ({fn_percent:.2f}%)",
            f"{precision:.4f}",
            f"{recall:.4f}",
            f"{f1:.4f}"
        ]
    }
    df = pd.DataFrame(report_data)

    try:
        df.to_csv(filename, index=False, encoding='utf-8-sig')
        print(f"The evaluation results is stored: {filename}")
    except Exception as e:
        print(f"Error while saving the evaluation report: {e}")

# =============================================================================
# --- Model Explainability (SHAP & Grad-CAM) ---
# =============================================================================

def generate_nca_feature_map_info(nca_shape):
    """
    Generate feature layout information for NCA-transformed data.
    """
    height, width = nca_shape[1], nca_shape[2]

    feature_layout = {
        'nca_transformed_features': height
    }

    colors = plt.get_cmap('viridis', 1)
    feature_names = list(feature_layout.keys())
    total_rows = sum(feature_layout.values())

    color_mask = np.zeros((total_rows, width), dtype=int)

    legend_patches = [mpatches.Patch(color=colors(0), label=f"NCA Features ({height} rows)")]

    return {
        'color_mask': color_mask,
        'feature_layout': feature_layout,
        'feature_names': feature_names,
        'colors': colors,
        'legend_patches': legend_patches,
        'total_rows': total_rows
    }
def run_full_shap_analysis(model, X_train, X_test, y_test, output_path, nca_data_shape, samples_per_class=50, top_n=20):
    """
    Run SHAP analysis with balanced sample selection for NCA-transformed data.
    """
    print("\n--- Running Full SHAP Analysis on NCA Data ---")
    os.makedirs(output_path, exist_ok=True)

    feature_map_info = generate_nca_feature_map_info(nca_data_shape)
    legend_patches = feature_map_info['legend_patches']
    total_rows = feature_map_info['total_rows']

    # Balanced sample selection
    healthy_indices = np.where(y_test == 0)[0]
    parkinson_indices = np.where(y_test == 1)[0]
    num_healthy_to_select = min(samples_per_class, len(healthy_indices))
    num_parkinson_to_select = min(samples_per_class, len(parkinson_indices))

    selected_healthy_indices = np.random.choice(healthy_indices, num_healthy_to_select, replace=False)
    selected_parkinson_indices = np.random.choice(parkinson_indices, num_parkinson_to_select, replace=False)
    final_indices = np.concatenate([selected_healthy_indices, selected_parkinson_indices])
    np.random.shuffle(final_indices)

    test_samples = X_test[final_indices]
    y_true_samples = y_test[final_indices]

    print(f"Calculating SHAP values for {len(test_samples)} balanced samples...")
    explainer = shap.GradientExplainer(model, X_train[:50].astype(np.float32))
    shap_values_list = []

    for sample in tqdm(test_samples, desc="SHAP Progress"):
        sample_batch = np.expand_dims(sample, axis=0).astype(np.float32)
        sv = explainer.shap_values(sample_batch)
        if isinstance(sv, list):
            sv = sv[0]

        # Handle the shape properly - squeeze all singleton dimensions except batch
        # Expected shape after processing: (1, height, width)
        sv_squeezed = np.squeeze(sv)  # Remove all singleton dimensions

        # Ensure we have the right shape (height, width)
        if sv_squeezed.ndim == 3:  # (height, width, 1)
            sv_squeezed = np.squeeze(sv_squeezed, axis=-1)  # Remove channel dimension
        elif sv_squeezed.ndim == 1:  # Flattened
            height, width = nca_data_shape[1], nca_data_shape[2]
            sv_squeezed = sv_squeezed.reshape(height, width)

        shap_values_list.append(sv_squeezed)

    shap_values = np.array(shap_values_list)
    print(f"\nSHAP values shape: {shap_values.shape}")

    # Ensure we have the right dimensions
    if shap_values.ndim == 3:  # (samples, height, width)
        height, width = shap_values.shape[1], shap_values.shape[2]
    else:
        print(f"Unexpected SHAP values shape: {shap_values.shape}")
        return

    actual_data_time_steps = width

    # Global Top-N analysis
    flat_shap = shap_values.reshape(shap_values.shape[0], -1)
    mean_abs = np.mean(np.abs(flat_shap), axis=0)
    top_idx = np.argsort(mean_abs)[::-1][:top_n]

    # Fix the coordinate calculation
    coords = [np.unravel_index(i, (height, width)) for i in top_idx]
    labels = [f"R{r}C{c}" for r, c in coords]

    plt.figure(figsize=(12, 6))
    plt.bar(range(len(top_idx)), mean_abs[top_idx])
    plt.xticks(range(len(top_idx)), labels, rotation=45, ha="right")
    plt.title(f"Top-{top_n} Global SHAP Features (NCA transformed)")
    plt.xlabel("Row × Column")
    plt.ylabel("Mean |SHAP value|")
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, "shap_global_bar_nca.png"), dpi=300, bbox_inches="tight")
    plt.close()
    print("-> Saved 'shap_global_bar_nca.png'")

    def plot_aligned_heatmap(heatmap_data, title, filename_suffix, cmap, label, vmin=None, vmax=None):
        fig, axes = plt.subplots(1, 2, figsize=(14, 7))
        ax_shap, ax_feature_map = axes[0], axes[1]

        # SHAP Heatmap
        img = ax_shap.imshow(heatmap_data, cmap=cmap, aspect='auto', interpolation='nearest', vmin=vmin, vmax=vmax)
        ax_shap.set_title(title, fontsize=12)
        ax_shap.set_xlabel(f"NCA Columns ({actual_data_time_steps})", fontsize=10)
        ax_shap.set_ylabel(f"NCA Rows ({total_rows})", fontsize=10)

        divider = make_axes_locatable(ax_shap)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        fig.colorbar(img, cax=cax, label=label)

        # Feature Map
        ax_feature_map.imshow(feature_map_info['color_mask'], cmap=feature_map_info['colors'],
                             aspect='auto', interpolation='nearest')
        ax_feature_map.set_title("NCA Feature Map", fontsize=12)
        ax_feature_map.set_xlabel(f"NCA Columns ({actual_data_time_steps})", fontsize=10)
        ax_feature_map.tick_params(axis='y', labelleft=False)

        ax_feature_map.legend(handles=legend_patches, loc='upper left', bbox_to_anchor=(1.02, 1),
                             borderaxespad=0., fontsize=8)

        fig.suptitle(f"NCA-SHAP Analysis: {title}", fontsize=16, fontweight='bold')
        plt.tight_layout(rect=[0, 0, 1, 0.95])

        plt.savefig(os.path.join(output_path, f"shap_aligned_nca_{filename_suffix}.png"), dpi=300)
        plt.close(fig)
        print(f"-> Saved 'shap_aligned_nca_{filename_suffix}.png'")

    # Generate class-specific and difference maps
    hc_mask, pd_mask = (y_true_samples == 0), (y_true_samples == 1)

    if np.any(hc_mask):
        hc_mean = shap_values[hc_mask].mean(axis=0)
        plot_aligned_heatmap(hc_mean, "Average SHAP - Healthy (NCA)", "summary_healthy", "bwr", "Mean SHAP Value")
    if np.any(pd_mask):
        pd_mean = shap_values[pd_mask].mean(axis=0)
        plot_aligned_heatmap(pd_mean, "Average SHAP - Parkinson (NCA)", "summary_parkinson", "bwr", "Mean SHAP Value")
    if np.any(hc_mask) and np.any(pd_mask):
        diff_map = pd_mean - hc_mean
        max_abs_diff = np.max(np.abs(diff_map))
        plot_aligned_heatmap(diff_map, "SHAP Difference (PD - HC, NCA)", "difference", "seismic", "Δ SHAP (PD - HC)", vmin=-max_abs_diff, vmax=max_abs_diff)

    print("\n--- NCA-SHAP Analysis Complete ---")

def run_gradcam_analysis(model, X_test, y_test, output_path, nca_data_shape, num_samples=50):
    """
    Run Grad-CAM analysis for NCA-transformed data.
    """
    print("\n--- Running Grad-CAM Analysis with NCA Data ---")
    os.makedirs(output_path, exist_ok=True)

    parkinson_detector = None
    for layer in model.layers:
        if 'ParkinsonDetectorModel' in str(type(layer)):
            parkinson_detector = layer
            break

    if parkinson_detector is None:
        print("ParkinsonDetectorModel not found in the model layers.")
        return

    last_conv_layer = parkinson_detector.conv2b
    print(f"Using last conv layer: {last_conv_layer.name}")

    def get_conv_and_output(inputs):
        x = parkinson_detector.conv1a(inputs)
        x = parkinson_detector.conv1b(x)
        x = parkinson_detector.pool1(x)
        x = parkinson_detector.drop1(x, training=False)
        x = parkinson_detector.conv2a(x)
        conv_output = parkinson_detector.conv2b(x)
        x = parkinson_detector.pool2(conv_output)
        x = parkinson_detector.drop2(x, training=False)
        cnn_flat = parkinson_detector.flatten_cnn(x)
        shape = tf.shape(x)
        sequence = tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])
        att_out = parkinson_detector.attention(query=sequence, key=sequence, value=sequence)
        att_flat = parkinson_detector.flatten_att(att_out)
        lstm_seq = parkinson_detector.lstm1(sequence)
        lstm_out = parkinson_detector.lstm2(lstm_seq)
        lstm_out = parkinson_detector.drop_lstm(lstm_out, training=False)
        concatenated = parkinson_detector.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = parkinson_detector.dense_bottleneck(concatenated)
        final_output = parkinson_detector.dense_output(bottleneck)
        return conv_output, final_output

    # Sample selection
    parkinson_indices = np.where(y_test == 1)[0]
    healthy_indices = np.where(y_test == 0)[0]

    selected_pd_indices = list(np.random.choice(parkinson_indices, min(num_samples, len(parkinson_indices)), replace=False))
    selected_hc_indices = list(np.random.choice(healthy_indices, min(num_samples, len(healthy_indices)), replace=False))

    print(f"Selected {len(selected_pd_indices)} PD and {len(selected_hc_indices)} HC samples for Grad-CAM.")

    # Calculate heatmaps
    def calculate_gradcam_heatmaps(indices):
        heatmaps = []
        for i in indices:
            img = X_test[i:i+1]
            with tf.GradientTape() as tape:
                img_tensor = tf.cast(img, tf.float32)
                tape.watch(img_tensor)
                conv_outputs, preds = get_conv_and_output(img_tensor)
                loss = preds[:, 0]
            grads = tape.gradient(loss, conv_outputs)
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            conv_outputs_np = conv_outputs[0].numpy()
            pooled_grads_np = pooled_grads.numpy()
            heatmap = np.zeros(conv_outputs_np.shape[:-1])
            for j in range(conv_outputs_np.shape[-1]):
                heatmap += pooled_grads_np[j] * conv_outputs_np[:, :, j]
            heatmap = np.maximum(heatmap, 0)
            heatmap /= (heatmap.max() + 1e-10)
            heatmaps.append(heatmap)
        return heatmaps

    tp_heatmaps = calculate_gradcam_heatmaps(selected_pd_indices)
    tn_heatmaps = calculate_gradcam_heatmaps(selected_hc_indices)

    avg_pd_heatmap = np.mean(tp_heatmaps, axis=0) if tp_heatmaps else np.zeros((nca_data_shape[1], nca_data_shape[2]))
    avg_hc_heatmap = np.mean(tn_heatmaps, axis=0) if tn_heatmaps else np.zeros((nca_data_shape[1], nca_data_shape[2]))

    # Resize heatmaps to match input dimensions
    original_height, original_width = nca_data_shape[1], nca_data_shape[2]

    upscaled_avg_pd_heatmap = resize(avg_pd_heatmap, (original_height, original_width),
                                     order=3, mode='reflect', anti_aliasing=True)
    upscaled_avg_hc_heatmap = resize(avg_hc_heatmap, (original_height, original_width),
                                     order=3, mode='reflect', anti_aliasing=True)

    # Get sample inputs for visualization
    sample_input_pd = X_test[selected_pd_indices[0]].squeeze() if selected_pd_indices else None
    sample_input_hc = X_test[selected_hc_indices[0]].squeeze() if selected_hc_indices else None

    def normalize_for_display(img_data):
        if img_data is None: return None
        return (img_data - img_data.min()) / (img_data.max() - img_data.min() + 1e-10)

    normalized_input_pd = normalize_for_display(sample_input_pd)
    normalized_input_hc = normalize_for_display(sample_input_hc)

    # Create comprehensive visualization
    fig, axes = plt.subplots(1, 5, figsize=(25, 6))
    fig.suptitle("Grad-CAM: NCA Model Analysis", fontsize=18, fontweight='bold')

    # PD Input
    if normalized_input_pd is not None:
        axes[0].imshow(normalized_input_pd, cmap='gray', aspect='auto', origin='lower')
        axes[0].set_title(f'PD Input (NCA)\n(Sample {selected_pd_indices[0] if selected_pd_indices else ""})')
        axes[0].set_xlabel("NCA Columns")
        axes[0].set_ylabel("NCA Rows")

    # PD Heatmap
    im_pd = axes[1].imshow(upscaled_avg_pd_heatmap, cmap='jet', aspect='auto', origin='lower')
    axes[1].set_title(f'Avg. PD Attention\n({len(selected_pd_indices)} samples)')
    axes[1].set_xlabel("NCA Columns")
    axes[1].set_yticklabels([])
    divider_pd = make_axes_locatable(axes[1])
    cax_pd = divider_pd.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im_pd, cax=cax_pd)

    # HC Input
    if normalized_input_hc is not None:
        axes[2].imshow(normalized_input_hc, cmap='gray', aspect='auto', origin='lower')
        axes[2].set_title(f'HC Input (NCA)\n(Sample {selected_hc_indices[0] if selected_hc_indices else ""})')
        axes[2].set_xlabel("NCA Columns")
        axes[2].set_yticklabels([])

    # HC Heatmap
    im_hc = axes[3].imshow(upscaled_avg_hc_heatmap, cmap='jet', aspect='auto', origin='lower')
    axes[3].set_title(f'Avg. HC Attention\n({len(selected_hc_indices)} samples)')
    axes[3].set_xlabel("NCA Columns")
    axes[3].set_yticklabels([])
    divider_hc = make_axes_locatable(axes[3])
    cax_hc = divider_hc.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im_hc, cax=cax_hc)

    # Feature Map
    feature_map_info = generate_nca_feature_map_info(nca_data_shape)
    if feature_map_info:
        legend_colors = [patch.get_facecolor() for patch in feature_map_info['legend_patches']]
        cmap = ListedColormap(legend_colors)
        axes[4].imshow(feature_map_info['color_mask'], cmap=cmap, aspect='auto',
                      interpolation='nearest', origin='lower')
        axes[4].set_title("NCA Feature\nLayout")
        axes[4].set_xlabel("NCA Columns")
        axes[4].set_yticklabels([])

        fig.legend(handles=feature_map_info['legend_patches'], loc='center left',
                  bbox_to_anchor=(0.93, 0.5), borderaxespad=0.)

    plt.tight_layout(rect=[0, 0, 0.93, 0.93])

    save_path = os.path.join(output_path, "gradcam_nca_full_comparison.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"✅ Saved NCA Grad-CAM comparison to {save_path}")

In [5]:
def generate_nca_feature_map_info(nca_shape, original_feature_info=None):
    """
    Generate feature layout information for NCA-transformed data with original feature type tracking.

    Args:
        nca_shape: Shape of NCA-transformed data (batch, height, width, channels)
        original_feature_info: Dict containing original feature information
    """
    height, width = nca_shape[1], nca_shape[2]

    if original_feature_info is None:
        # Fallback to generic NCA features
        feature_layout = {'nca_transformed_features': height}
        colors = plt.get_cmap('viridis', 1)
        legend_patches = [mpatches.Patch(color=colors(0), label=f"NCA Features ({height} rows)")]
        color_mask = np.zeros((height, width), dtype=int)
    else:
        # Create meaningful feature mapping based on original features
        mel_contrib = original_feature_info.get('mel_spectrogram_contrib', 0)
        mfcc_contrib = original_feature_info.get('mfcc_contrib', 0)
        spectrogram_contrib = original_feature_info.get('spectrogram_contrib', 0)

        total_contrib = mel_contrib + mfcc_contrib + spectrogram_contrib

        if total_contrib > 0:
            # Calculate proportional rows for each feature type
            mel_rows = max(1, int((mel_contrib / total_contrib) * height))
            mfcc_rows = max(1, int((mfcc_contrib / total_contrib) * height))
            spectrogram_rows = height - mel_rows - mfcc_rows  # Remaining rows

            feature_layout = {
                'mel_spectrogram': mel_rows,
                'mfcc': mfcc_rows,
                'spectrogram': spectrogram_rows
            }
        else:
            feature_layout = {'nca_transformed_features': height}

        # Create color mapping
        n_features = len(feature_layout)
        colors = plt.get_cmap('Set3', n_features)

        color_mask = np.zeros((height, width), dtype=int)
        legend_patches = []

        current_row = 0
        for i, (feature_name, rows) in enumerate(feature_layout.items()):
            if rows > 0:
                color_mask[current_row:current_row + rows, :] = i
                legend_patches.append(mpatches.Patch(color=colors(i), label=f"{feature_name} ({rows} rows)"))
                current_row += rows

    return {
        'color_mask': color_mask,
        'feature_layout': feature_layout,
        'feature_names': list(feature_layout.keys()),
        'colors': colors,
        'legend_patches': legend_patches,
        'total_rows': height
    }

def calculate_original_feature_contributions(original_shapes):
    """
    Calculate how much each original feature type contributes to the flattened feature vector.

    Args:
        original_shapes: Dict with keys like 'mel_spectrogram', 'mfcc', 'spectrogram' and their shapes

    Returns:
        Dict with contribution percentages
    """
    contributions = {}
    total_features = 0

    for feature_name, shape in original_shapes.items():
        if len(shape) >= 2:  # (samples, height, width) or (samples, features)
            feature_count = np.prod(shape[1:])  # Product of all dimensions except samples
            contributions[f"{feature_name}_contrib"] = feature_count
            total_features += feature_count

    # Normalize to percentages
    if total_features > 0:
        for key in contributions:
            contributions[key] = contributions[key] / total_features

    return contributions

def apply_nca_smote_preprocessing(X_train, y_train, X_test, n_components=None, use_smote=True, min_dim=16, original_feature_info=None):
    """
    Apply NCA for dimensionality reduction followed by SMOTE for balancing.
    Ensures minimum dimensions for CNN compatibility.

    Args:
        X_train: Training features (samples, height, width)
        y_train: Training labels
        X_test: Test features (samples, height, width) - ONLY TRANSFORMED, NEVER FITTED
        n_components: Number of NCA components
        use_smote: Whether to apply SMOTE after NCA
        min_dim: Minimum dimension for each axis (height, width)
        original_feature_info: Dict containing original feature information

    Returns:
        X_train_processed: Processed training features (samples, height, width, 1)
        y_train_processed: Processed training labels (may be augmented if SMOTE used)
        X_test_processed: Processed test features (samples, height, width, 1)
        nca: Fitted NCA object
        scaler: Fitted StandardScaler object
        feature_info: Enhanced feature information for visualization
    """
    print("\n--- Applying NCA + SMOTE Preprocessing ---")

    # Get original shapes
    original_train_shape = X_train.shape
    original_test_shape = X_test.shape
    print(f"Original train shape: {original_train_shape}")
    print(f"Original test shape: {original_test_shape}")

    # Flatten the features for NCA (samples, height*width*channels -> samples, features)
    X_train_flat = X_train.reshape(X_train.shape[0], -1)
    X_test_flat = X_test.reshape(X_test.shape[0], -1)

    print(f"Flattened train shape: {X_train_flat.shape}")
    print(f"Flattened test shape: {X_test_flat.shape}")

    # Step 1: Standardize features (fit ONLY on training data)
    print("Standardizing features...")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_flat)  # FIT on training
    X_test_scaled = scaler.transform(X_test_flat)        # ONLY TRANSFORM test

    # Step 2: Apply NCA (fit ONLY on training data)
    if n_components is None:
        n_components = min(len(X_train) - 1, X_train_flat.shape[1])
    else:
        n_components = min(n_components, len(X_train) - 1, X_train_flat.shape[1])

    print(f"Applying NCA with {n_components} components...")
    nca = NeighborhoodComponentsAnalysis(n_components=n_components, random_state=42, max_iter=200)

    X_train_nca = nca.fit_transform(X_train_scaled, y_train)  # FIT on training
    X_test_nca = nca.transform(X_test_scaled)                 # ONLY TRANSFORM test

    print(f"NCA transformed train shape: {X_train_nca.shape}")
    print(f"NCA transformed test shape: {X_test_nca.shape}")

    # Step 3: Apply SMOTE for balancing (ONLY on training data)
    if use_smote:
        print("Applying SMOTE for class balancing...")
        print(f"Before SMOTE - Class distribution: {np.bincount(y_train)}")

        smote = SMOTE(random_state=42)
        X_train_balanced, y_train_balanced = smote.fit_resample(X_train_nca, y_train)

        print(f"After SMOTE - Class distribution: {np.bincount(y_train_balanced)}")
        print(f"SMOTE balanced train shape: {X_train_balanced.shape}")
    else:
        X_train_balanced = X_train_nca
        y_train_balanced = y_train

    # Step 4: Calculate optimal reshape dimensions for CNN input
    n_features = X_train_balanced.shape[1]

    # Find factors to create CNN-compatible dimensions
    height = max(min_dim, int(np.sqrt(n_features)))
    width = max(min_dim, int(np.ceil(n_features / height)))

    # Ensure minimum dimensions for CNN compatibility
    if height < min_dim:
        height = min_dim
    if width < min_dim:
        width = min_dim

    # Ensure it's CNN-friendly (pad if necessary)
    target_size = height * width
    pad_size = max(0, target_size - n_features)

    if pad_size > 0:
        print(f"Padding with {pad_size} zeros for CNN compatibility...")
        X_train_balanced = np.pad(X_train_balanced, ((0, 0), (0, pad_size)), mode='constant', constant_values=0)
        X_test_nca = np.pad(X_test_nca, ((0, 0), (0, pad_size)), mode='constant', constant_values=0)
    elif pad_size < 0:
        # Truncate if we have too many features
        print(f"Truncating {-pad_size} features for CNN compatibility...")
        X_train_balanced = X_train_balanced[:, :target_size]
        X_test_nca = X_test_nca[:, :target_size]

    # Reshape to 3D for CNN (samples, height, width, channels=1)
    X_train_final = X_train_balanced.reshape(X_train_balanced.shape[0], height, width, 1)
    X_test_final = X_test_nca.reshape(X_test_nca.shape[0], height, width, 1)

    # Enhanced feature information for visualization
    enhanced_feature_info = original_feature_info.copy() if original_feature_info else {}
    enhanced_feature_info.update({
        'nca_shape': (X_train_final.shape[0], height, width, 1),
        'nca_components': n_components,
        'original_features': n_features,
        'padded_features': target_size
    })

    print(f"Final CNN-ready train shape: {X_train_final.shape}")
    print(f"Final CNN-ready test shape: {X_test_final.shape}")
    print("--- NCA + SMOTE Preprocessing Complete ---\n")

    return X_train_final, y_train_balanced, X_test_final, nca, scaler, enhanced_feature_info

def run_full_shap_analysis(model, X_train, X_test, y_test, output_path, nca_data_shape, feature_info=None, samples_per_class=50, top_n=20):
    """
    Run SHAP analysis with balanced sample selection for NCA-transformed data.
    """
    print("\n--- Running Full SHAP Analysis on NCA Data ---")
    os.makedirs(output_path, exist_ok=True)

    feature_map_info = generate_nca_feature_map_info(nca_data_shape, feature_info)
    legend_patches = feature_map_info['legend_patches']
    total_rows = feature_map_info['total_rows']

    # Balanced sample selection
    healthy_indices = np.where(y_test == 0)[0]
    parkinson_indices = np.where(y_test == 1)[0]
    num_healthy_to_select = min(samples_per_class, len(healthy_indices))
    num_parkinson_to_select = min(samples_per_class, len(parkinson_indices))

    selected_healthy_indices = np.random.choice(healthy_indices, num_healthy_to_select, replace=False)
    selected_parkinson_indices = np.random.choice(parkinson_indices, num_parkinson_to_select, replace=False)
    final_indices = np.concatenate([selected_healthy_indices, selected_parkinson_indices])
    np.random.shuffle(final_indices)

    test_samples = X_test[final_indices]
    y_true_samples = y_test[final_indices]

    print(f"Calculating SHAP values for {len(test_samples)} balanced samples...")
    explainer = shap.GradientExplainer(model, X_train[:50].astype(np.float32))
    shap_values_list = []

    for sample in tqdm(test_samples, desc="SHAP Progress"):
        sample_batch = np.expand_dims(sample, axis=0).astype(np.float32)
        sv = explainer.shap_values(sample_batch)
        if isinstance(sv, list):
            sv = sv[0]

        # Handle the shape properly - squeeze all singleton dimensions except batch
        sv_squeezed = np.squeeze(sv)  # Remove all singleton dimensions

        # Ensure we have the right shape (height, width)
        if sv_squeezed.ndim == 3:  # (height, width, 1)
            sv_squeezed = np.squeeze(sv_squeezed, axis=-1)  # Remove channel dimension
        elif sv_squeezed.ndim == 1:  # Flattened
            height, width = nca_data_shape[1], nca_data_shape[2]
            sv_squeezed = sv_squeezed.reshape(height, width)

        shap_values_list.append(sv_squeezed)

    shap_values = np.array(shap_values_list)
    print(f"\nSHAP values shape: {shap_values.shape}")

    # Ensure we have the right dimensions
    if shap_values.ndim == 3:  # (samples, height, width)
        height, width = shap_values.shape[1], shap_values.shape[2]
    else:
        print(f"Unexpected SHAP values shape: {shap_values.shape}")
        return

    actual_data_time_steps = width

    # Global Top-N analysis
    flat_shap = shap_values.reshape(shap_values.shape[0], -1)
    mean_abs = np.mean(np.abs(flat_shap), axis=0)
    top_idx = np.argsort(mean_abs)[::-1][:top_n]

    # Fix the coordinate calculation
    coords = [np.unravel_index(i, (height, width)) for i in top_idx]
    labels = [f"R{r}C{c}" for r, c in coords]

    plt.figure(figsize=(12, 6))
    plt.bar(range(len(top_idx)), mean_abs[top_idx])
    plt.xticks(range(len(top_idx)), labels, rotation=45, ha="right")
    plt.title(f"Top-{top_n} Global SHAP Features (NCA transformed)")
    plt.xlabel("Row × Column")
    plt.ylabel("Mean |SHAP value|")
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, "shap_global_bar_nca.png"), dpi=300, bbox_inches="tight")
    plt.close()
    print("-> Saved 'shap_global_bar_nca.png'")

    def plot_aligned_heatmap(heatmap_data, title, filename_suffix, cmap, label, vmin=None, vmax=None):
        fig, axes = plt.subplots(1, 2, figsize=(14, 7))
        ax_shap, ax_feature_map = axes[0], axes[1]

        # SHAP Heatmap
        img = ax_shap.imshow(heatmap_data, cmap=cmap, aspect='auto', interpolation='nearest', vmin=vmin, vmax=vmax)
        ax_shap.set_title(title, fontsize=12)
        ax_shap.set_xlabel(f"NCA Columns ({actual_data_time_steps})", fontsize=10)
        ax_shap.set_ylabel(f"NCA Rows ({total_rows})", fontsize=10)

        divider = make_axes_locatable(ax_shap)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        fig.colorbar(img, cax=cax, label=label)

        # Feature Map
        ax_feature_map.imshow(feature_map_info['color_mask'], cmap=feature_map_info['colors'],
                             aspect='auto', interpolation='nearest')
        ax_feature_map.set_title("NCA Feature Map", fontsize=12)
        ax_feature_map.set_xlabel(f"NCA Columns ({actual_data_time_steps})", fontsize=10)
        ax_feature_map.tick_params(axis='y', labelleft=False)

        ax_feature_map.legend(handles=legend_patches, loc='upper left', bbox_to_anchor=(1.02, 1),
                             borderaxespad=0., fontsize=8)

        fig.suptitle(f"NCA-SHAP Analysis: {title}", fontsize=16, fontweight='bold')
        plt.tight_layout(rect=[0, 0, 1, 0.95])

        plt.savefig(os.path.join(output_path, f"shap_aligned_nca_{filename_suffix}.png"), dpi=300)
        plt.close(fig)
        print(f"-> Saved 'shap_aligned_nca_{filename_suffix}.png'")

    # Generate class-specific and difference maps
    hc_mask, pd_mask = (y_true_samples == 0), (y_true_samples == 1)

    if np.any(hc_mask):
        hc_mean = shap_values[hc_mask].mean(axis=0)
        plot_aligned_heatmap(hc_mean, "Average SHAP - Healthy (NCA)", "summary_healthy", "bwr", "Mean SHAP Value")
    if np.any(pd_mask):
        pd_mean = shap_values[pd_mask].mean(axis=0)
        plot_aligned_heatmap(pd_mean, "Average SHAP - Parkinson (NCA)", "summary_parkinson", "bwr", "Mean SHAP Value")
    if np.any(hc_mask) and np.any(pd_mask):
        diff_map = pd_mean - hc_mean
        max_abs_diff = np.max(np.abs(diff_map))
        plot_aligned_heatmap(diff_map, "SHAP Difference (PD - HC, NCA)", "difference", "seismic", "Δ SHAP (PD - HC)", vmin=-max_abs_diff, vmax=max_abs_diff)

    # NEW: Aligned Significance Analysis
    print("\n--- Running Aligned Significance Analysis ---")
    run_aligned_significance_analysis(shap_values, y_true_samples, output_path, feature_map_info, actual_data_time_steps)

    print("\n--- NCA-SHAP Analysis Complete ---")

def run_aligned_significance_analysis(shap_values, y_true_samples, output_path, feature_map_info, actual_data_time_steps, alpha=0.05):
    """
    Run statistical significance analysis on SHAP values between classes.
    """
    print("Computing statistical significance of SHAP differences...")

    # Separate SHAP values by class
    hc_mask = (y_true_samples == 0)
    pd_mask = (y_true_samples == 1)

    if not np.any(hc_mask) or not np.any(pd_mask):
        print("Cannot perform significance analysis: need both classes present")
        return

    hc_shap = shap_values[hc_mask]  # (n_hc, height, width)
    pd_shap = shap_values[pd_mask]  # (n_pd, height, width)

    height, width = shap_values.shape[1], shap_values.shape[2]
    p_values = np.zeros((height, width))
    t_statistics = np.zeros((height, width))

    # Perform t-test for each spatial location
    for i in range(height):
        for j in range(width):
            hc_values = hc_shap[:, i, j]
            pd_values = pd_shap[:, i, j]

            # Perform independent t-test
            t_stat, p_val = ttest_ind(pd_values, hc_values, equal_var=False)
            t_statistics[i, j] = t_stat
            p_values[i, j] = p_val

    # Apply significance threshold
    significant_mask = p_values < alpha

    # Create significance-weighted difference map
    mean_diff = pd_shap.mean(axis=0) - hc_shap.mean(axis=0)
    significant_diff = np.where(significant_mask, mean_diff, 0)

    # Calculate significance statistics
    total_positions = height * width
    significant_positions = np.sum(significant_mask)
    significance_percentage = (significant_positions / total_positions) * 100

    print(f"Significant positions: {significant_positions}/{total_positions} ({significance_percentage:.1f}%)")

    # Create comprehensive significance visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f"Statistical Significance Analysis (α = {alpha})", fontsize=16, fontweight='bold')

    # Row 1: Raw data
    # 1. Mean difference map
    im1 = axes[0,0].imshow(mean_diff, cmap='seismic', aspect='auto', interpolation='nearest')
    axes[0,0].set_title('Mean SHAP Difference\n(PD - HC)')
    axes[0,0].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[0,0].set_ylabel("NCA Rows")
    divider1 = make_axes_locatable(axes[0,0])
    cax1 = divider1.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im1, cax=cax1, label="Δ SHAP")

    # 2. P-values map
    im2 = axes[0,1].imshow(p_values, cmap='viridis_r', aspect='auto', interpolation='nearest', vmax=0.1)
    axes[0,1].set_title('P-values Map\n(darker = more significant)')
    axes[0,1].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[0,1].set_yticklabels([])
    divider2 = make_axes_locatable(axes[0,1])
    cax2 = divider2.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im2, cax=cax2, label="p-value")

    # 3. Feature map
    axes[0,2].imshow(feature_map_info['color_mask'], cmap=feature_map_info['colors'],
                     aspect='auto', interpolation='nearest')
    axes[0,2].set_title("Feature Layout")
    axes[0,2].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[0,2].set_yticklabels([])
    axes[0,2].legend(handles=feature_map_info['legend_patches'], loc='upper left',
                     bbox_to_anchor=(1.02, 1), borderaxespad=0., fontsize=8)

    # Row 2: Significance analysis
    # 4. Significance mask
    im4 = axes[1,0].imshow(significant_mask.astype(int), cmap='RdYlBu_r', aspect='auto',
                          interpolation='nearest', vmin=0, vmax=1)
    axes[1,0].set_title(f'Significance Mask\n({significance_percentage:.1f}% significant)')
    axes[1,0].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[1,0].set_ylabel("NCA Rows")
    divider4 = make_axes_locatable(axes[1,0])
    cax4 = divider4.append_axes("right", size="5%", pad=0.05)
    cbar4 = fig.colorbar(im4, cax=cax4, ticks=[0, 1])
    cbar4.set_ticklabels(['Non-sig.', 'Significant'])

    # 5. Significant differences only
    max_abs_sig_diff = np.max(np.abs(significant_diff)) if np.any(significant_diff) else 1
    im5 = axes[1,1].imshow(significant_diff, cmap='seismic', aspect='auto', interpolation='nearest',
                          vmin=-max_abs_sig_diff, vmax=max_abs_sig_diff)
    axes[1,1].set_title('Significant Differences Only\n(PD - HC, masked)')
    axes[1,1].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[1,1].set_yticklabels([])
    divider5 = make_axes_locatable(axes[1,1])
    cax5 = divider5.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im5, cax=cax5, label="Significant Δ SHAP")

    # 6. T-statistics
    im6 = axes[1,2].imshow(t_statistics, cmap='seismic', aspect='auto', interpolation='nearest')
    axes[1,2].set_title('T-statistics\n(PD vs HC)')
    axes[1,2].set_xlabel(f"NCA Columns ({actual_data_time_steps})")
    axes[1,2].set_yticklabels([])
    divider6 = make_axes_locatable(axes[1,2])
    cax6 = divider6.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im6, cax=cax6, label="t-statistic")

    plt.tight_layout()

    # Save the significance analysis
    significance_path = os.path.join(output_path, "shap_aligned_nca_significance.png")
    plt.savefig(significance_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"-> Saved significance analysis to '{significance_path}'")

    # Save significance statistics
    significance_stats = {
        'total_positions': total_positions,
        'significant_positions': significant_positions,
        'significance_percentage': significance_percentage,
        'alpha_threshold': alpha,
        'max_significant_difference': float(max_abs_sig_diff),
        'mean_p_value': float(np.mean(p_values)),
        'min_p_value': float(np.min(p_values))
    }

    stats_df = pd.DataFrame([significance_stats])
    stats_path = os.path.join(output_path, "significance_statistics.csv")
    stats_df.to_csv(stats_path, index=False)
    print(f"-> Saved significance statistics to '{stats_path}'")

def load_data(feature_file_path):
    """
    Load data and return both features and original feature information for tracking.
    """
    mel_spectrograms = None
    mfccs = None
    spectrograms = None
    original_shapes = {}

    print(f"--- Loading data from {feature_file_path} ---")
    with np.load(feature_file_path) as data:
        labels = data['labels']

        if 'mel_spectrogram' in data.keys():
            mel_spectrograms = data['mel_spectrogram']
            original_shapes['mel_spectrogram'] = mel_spectrograms.shape
            print(f"Added mel_spectrogram: {mel_spectrograms.shape}")

        if 'mfcc' in data:
            mfccs = data['mfcc']
            original_shapes['mfcc'] = mfccs.shape
            print(f"Added mfcc: {mfccs.shape}")

        if 'spectrogram' in data.keys():
            spectrograms = data['spectrogram']
            original_shapes['spectrogram'] = spectrograms.shape
            print(f"Added spectrogram: {spectrograms.shape}")

        feature_arrays = []

        if mel_spectrograms is not None:
            feature_arrays.append(mel_spectrograms)

        if mfccs is not None:
            feature_arrays.append(mfccs)

        if spectrograms is not None:
            feature_arrays.append(spectrograms)

        if feature_arrays:
            X = np.concatenate(feature_arrays, axis=-1)  # concat along last axis
            print(f"Final concatenated shape: {X.shape}")
        else:
            raise ValueError("No valid features found in the file!")

        # Calculate feature contributions
        feature_contributions = calculate_original_feature_contributions(original_shapes)

        return X, labels, feature_contributions


In [6]:
if __name__ == '__main__':

    X, y, original_feature_info = load_data(FEATURES_FILE_PATH)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"\nData split into training ({len(y_train)}) and testing ({len(y_test)}) sets.")

    # Apply NCA + SMOTE preprocessing with feature tracking
    X_train_processed, y_train_processed, X_test_processed, nca_transformer, scaler, enhanced_feature_info = apply_nca_smote_preprocessing(
        X_train, y_train, X_test, n_components=NCA_COMPONENTS, use_smote=USE_SMOTE,
        original_feature_info=original_feature_info
    )

    # Build model with proper 3D input shape (height, width, channels)
    model = build_model(input_shape=(X_train_processed.shape[1], X_train_processed.shape[2], X_train_processed.shape[3]))
    model.summary()

    optimizer = Adam(learning_rate=LEARNING_RATE)
    model.compile(optimizer=optimizer, loss='binary_crossentropy',
                 metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])

    print("\n--- Starting model training with NCA+SMOTE processed data ---")
    history = model.fit(
        X_train_processed, y_train_processed,
        validation_data=(X_test_processed, y_test),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[checkpoint_cb],
        verbose=1
    )
    print("--- Model training finished ---")

    pd.DataFrame(history.history).to_csv(HISTORY_SAVE_PATH, index_label='epoch')
    print(f"\nTraining history saved to '{HISTORY_SAVE_PATH}'")

    print("\n--- Start evaluating model ---")
    y_pred_probabilities = model.predict(X_test_processed)
    save_metrics_to_csv(y_test, y_pred_probabilities, EVALUATION_FILE_PATH)

    if os.path.exists(BEST_MODEL_PATH):
        print("\n--- Loading best saved model for explainability analysis ---")
        best_model = load_model(BEST_MODEL_PATH, custom_objects={'ParkinsonDetectorModel': ParkinsonDetectorModel})

        # Run SHAP analysis with enhanced feature information
        run_full_shap_analysis(model, X_train_processed, X_test_processed, y_test,
                              SHAP_OUTPUT_PATH, X_test_processed.shape, enhanced_feature_info, 50, 20)

        # Run Grad-CAM analysis with enhanced feature information
        run_gradcam_analysis(best_model, X_test_processed, y_test, GRADCAM_OUTPUT_PATH,
                           X_test_processed.shape, 50)
    else:
        print("\nCould not find best model file. Skipping SHAP and Grad-CAM analysis.")


--- Loading data from D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\data\features_A_DEFAULT.npz ---
Added mel_spectrogram: (328, 30, 94)
Added mfcc: (328, 30, 94)
Final concatenated shape: (328, 30, 188)

Data split into training (262) and testing (66) sets.

--- Applying NCA + SMOTE Preprocessing ---
Original train shape: (262, 30, 188)
Original test shape: (66, 30, 188)
Flattened train shape: (262, 5640)
Flattened test shape: (66, 5640)
Standardizing features...
Applying NCA with 128 components...
NCA transformed train shape: (262, 128)
NCA transformed test shape: (66, 128)
Applying SMOTE for class balancing...
Before SMOTE - Class distribution: [131 131]
After SMOTE - Class distribution: [131 131]
SMOTE balanced train shape: (262, 128)
Padding with 128 zeros for CNN compatibility...
Final CNN-ready train shape: (262, 16, 16, 1)
Final CNN-ready test shape: (66, 16, 16, 1)
--- NCA + SMOTE Preprocessing Complete ---

--- Building t


--- Starting model training with NCA+SMOTE processed data ---
Epoch 1/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - accuracy: 0.5623 - auc: 0.5639 - loss: 2.1278
Epoch 1: val_auc improved from None to 0.58494, saving model to D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_smote_parallel\best_model.keras
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 387ms/step - accuracy: 0.5305 - auc: 0.5501 - loss: 2.1097 - val_accuracy: 0.5000 - val_auc: 0.5849 - val_loss: 2.0194
Epoch 2/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - accuracy: 0.4949 - auc: 0.5672 - loss: 2.0126
Epoch 2: val_auc did not improve from 0.58494
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 78ms/step - accuracy: 0.5000 - auc: 0.5653 - loss: 1.9841 - val_accuracy: 0.5000 - val_auc: 0.4261 - val_loss: 1.9255
Epoch 3/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━

Expected: keras_tensor_4
Received: inputs=['Tensor(shape=(1, 16, 16, 1))']
Expected: keras_tensor_4
Received: inputs=['Tensor(shape=(50, 16, 16, 1))']
SHAP Progress: 100%|██████████| 66/66 [00:33<00:00,  1.97it/s]



SHAP values shape: (66, 16, 16)
-> Saved 'shap_global_bar_nca.png'
-> Saved 'shap_aligned_nca_summary_healthy.png'
-> Saved 'shap_aligned_nca_summary_parkinson.png'
-> Saved 'shap_aligned_nca_difference.png'

--- Running Aligned Significance Analysis ---
Computing statistical significance of SHAP differences...
Significant positions: 8/256 (3.1%)
-> Saved significance analysis to 'D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_smote_parallel\shap_analysis\shap_aligned_nca_significance.png'
-> Saved significance statistics to 'D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_smote_parallel\shap_analysis\significance_statistics.csv'

--- NCA-SHAP Analysis Complete ---

--- Running Grad-CAM Analysis with NCA Data ---
Using last conv layer: last_conv_layer
Selected 33 PD and 33 HC samples for Grad-CAM.
✅ Saved NCA Grad-CAM comparison to D:\P