# In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, Flatten,
                                     Dense, LSTM, MultiHeadAttention, Concatenate, Reshape)
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (confusion_matrix, precision_score, recall_score, f1_score,
                           roc_curve, roc_auc_score, precision_recall_curve,
                           average_precision_score, classification_report)
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from tqdm import tqdm
from scipy.stats import ttest_ind
import pandas as pd
import shap
from tensorflow.keras.saving import register_keras_serializable
from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import numpy as np
import os
from skimage.transform import resize
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm
import librosa
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# =============================================================================
# --- Configuration ---
# =============================================================================

# Dataset identifiers; each one is a key of ``dataset_mapping`` below.
ITALIAN_DATASET = "ITALIAN_DATASET"
UAMS_DATASET = "UAMS_DATASET"
NEUROVOZ_DATASET = "NEUROVOZ_DATASET"
MPOWER_DATASET = "MPOWER_DATASET"
SYNTHETIC_DATASET = "SYNTHETIC_DATASET"

# Mode identifiers; the selected MODE is embedded in the feature-file name
# and in the results directory name.
MODE_ALL_VALIDS = "ALL_VALIDS"
MODE_A = "A"

# Feature-set identifiers; the selected FEATURE_MODE is embedded in the same
# file/directory names as MODE.
FEATURE_MODE_BASIC = "BASIC"
FEATURE_MODE_ALL = "ALL"
FEATURE_MODE_DEFAULT = "DEFAULT"

# Sub-directory name for this model's results.
MODEL_NAME = "nca_cnn_lstm"

# Select configuration
DATASET = UAMS_DATASET
MODE = MODE_A
FEATURE_MODE = FEATURE_MODE_DEFAULT

# Map dataset identifiers to the directory names used on disk
dataset_mapping = {
    NEUROVOZ_DATASET: "Neurovoz",
    UAMS_DATASET: "UAMS",
    MPOWER_DATASET: "mPower",
    SYNTHETIC_DATASET: "Synthetic",
    ITALIAN_DATASET: "Italian"
}

dataset = dataset_mapping[DATASET]

# Path Setup:
#   input   <cwd>/<dataset>/data/features_<MODE>_<FEATURE_MODE>.npz
#   outputs <cwd>/<dataset>/results_<MODE>_<FEATURE_MODE>/<MODEL_NAME>/...
FEATURES_FILE_PATH = os.path.join(os.getcwd(), dataset, "data", f"features_{MODE}_{FEATURE_MODE}.npz")
MODEL_PATH = os.path.join(os.getcwd(), dataset, f"results_{MODE}_{FEATURE_MODE}", MODEL_NAME)
os.makedirs(MODEL_PATH, exist_ok=True)

# File paths
EVALUATION_FILE_PATH = os.path.join(MODEL_PATH, "evaluation.csv")
HISTORY_SAVE_PATH = os.path.join(MODEL_PATH, "history.csv")
BEST_MODEL_PATH = os.path.join(MODEL_PATH, "best_model.keras")
PLOTS_PATH = os.path.join(MODEL_PATH, "plots")
SHAP_OUTPUT_PATH = os.path.join(MODEL_PATH, "shap_analysis")
GRADCAM_OUTPUT_PATH = os.path.join(MODEL_PATH, "gradcam_analysis")
ANALYSIS_PATH = os.path.join(MODEL_PATH, "comprehensive_analysis")

# Create output directories (no-op if they already exist)
for path in [PLOTS_PATH, SHAP_OUTPUT_PATH, GRADCAM_OUTPUT_PATH, ANALYSIS_PATH]:
    os.makedirs(path, exist_ok=True)

# Hyperparameters
EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 0.001
DROPOUT_RATE = 0.5
L2_STRENGTH = 0.01

# Callbacks. Both monitor val_auc (higher is better) so that the weights
# EarlyStopping restores at the end of fit() come from the same epoch that
# ModelCheckpoint saved to BEST_MODEL_PATH. Monitoring val_loss here would let
# the in-memory model and the saved "best" model come from different epochs.
# val_auc is also the early-stopping criterion used in run_cross_validation.
# mode='max' is explicit so the direction does not depend on auto-detection.
checkpoint_cb = ModelCheckpoint(BEST_MODEL_PATH, monitor='val_auc', mode='max',
                               save_best_only=True, verbose=1)
early_stop_cb = EarlyStopping(monitor='val_auc', mode='max', patience=10,
                              restore_best_weights=True)

# =============================================================================
# --- Data Loading ---
# =============================================================================

def load_data(feature_file_path, seed=None):
    """Load the feature array and labels from a ``.npz`` feature file.

    The file must contain ``labels``, ``mel_spectrogram`` and ``mfcc`` arrays;
    the two feature arrays are concatenated along their last axis.

    If ``feature_file_path`` does not exist, random placeholder data is
    returned instead so the rest of the pipeline can be exercised. That data
    is pure noise (labels are independent of the features), so any metric
    computed on it is meaningless.

    NOTE(review): the placeholder is laid out (samples, 60 rows, 94 columns)
    and the plotting helpers treat axis 1 as features, while real data is
    concatenated along the LAST axis — confirm the stored arrays' layout is
    consistent with that.

    Args:
        feature_file_path: Path to the ``.npz`` feature file.
        seed: Optional seed for the placeholder data. ``None`` (default) draws
            from NumPy's global RNG exactly as before, so ``np.random.seed``
            still applies; an int gives reproducible placeholder data
            independent of the global RNG.

    Returns:
        Tuple ``(X, y)`` of the feature array and the label array.
    """
    print(f"--- Loading data from {feature_file_path} ---")

    if not os.path.exists(feature_file_path):
        # Printed rather than warned: the module calls
        # warnings.filterwarnings('ignore'), which would hide a warning.
        print(f"WARNING: feature file not found: {feature_file_path}")
        print("Creating dummy data for demonstration...")
        n_samples = 1000
        n_features = 60  # mel_spectrogram (30) + mfcc (30)
        n_timesteps = 94

        rng = np.random if seed is None else np.random.RandomState(seed)
        X = rng.randn(n_samples, n_features, n_timesteps)
        y = rng.randint(0, 2, n_samples)

        return X, y

    with np.load(feature_file_path) as data:
        labels = data['labels']
        mel_spectrogram = data['mel_spectrogram']
        mfcc = data['mfcc']
        X = np.concatenate((mel_spectrogram, mfcc), axis=-1)
        return X, labels

# =============================================================================
# --- Model Architecture ---
# =============================================================================

@register_keras_serializable()
class ParkinsonDetectorModel(Model):
    """Hybrid CNN + self-attention + LSTM binary classifier.

    Each sample is 2-D (``input_shape`` has two entries, no batch axis) and is
    given a trailing channel axis. It then passes through two conv blocks:
    two 5x5 Conv2D layers with 64 filters, ReLU, L2 regularisation and 'same'
    padding, followed by MaxPooling2D(5) and Dropout.

    The resulting feature map feeds three branches whose outputs are
    concatenated:

    * CNN branch: the feature map, flattened.
    * Attention branch: the map reshaped to a sequence of H*W positions with
      one 64-channel vector each, run through 2-head self-attention, then
      flattened.
    * LSTM branch: the same sequence through two stacked 128-unit LSTMs,
      keeping only the final output, then Dropout.

    The concatenation passes through a 128-unit ReLU Dense bottleneck and a
    single sigmoid unit, giving one probability per sample.

    NOTE(review): the layer attribute names below are presumably what Keras
    uses to match weights when a saved model (e.g. best_model.keras) is
    reloaded. Renaming or reordering them likely breaks loading existing
    checkpoints — confirm before refactoring.
    """

    def __init__(self, input_shape, **kwargs):
        """Create all sub-layers.

        Args:
            input_shape: Per-sample shape ``(rows, cols)`` without the batch
                axis. It is stored so ``get_config`` can serialise it.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        super(ParkinsonDetectorModel, self).__init__(**kwargs)
        # Stored under a separate name so it is available to get_config().
        self.input_shape_config = input_shape

        # Add a channel axis so the 2-D input can be fed to Conv2D.
        self.reshape_in = Reshape((input_shape[0], input_shape[1], 1))
        # Conv block 1
        self.conv1a = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv1b = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.pool1 = MaxPooling2D(5)
        self.drop1 = Dropout(DROPOUT_RATE)
        # Conv block 2. The explicit layer names are presumably there so
        # analysis code can look these layers up by name — confirm.
        self.conv2a = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same')
        self.conv2b = Conv2D(64, 5, activation='relu', kernel_regularizer=l2(L2_STRENGTH), padding='same', name='last_conv_layer')
        self.pool2 = MaxPooling2D(5, name='cnn_output')
        self.drop2 = Dropout(DROPOUT_RATE)
        # CNN branch
        self.flatten_cnn = Flatten()
        # Self-attention branch
        self.attention = MultiHeadAttention(num_heads=2, key_dim=64, name='attention_output')
        self.flatten_att = Flatten()
        # LSTM branch: the first LSTM returns the full sequence, the second
        # only its final output.
        self.lstm1 = LSTM(128, return_sequences=True)
        self.lstm2 = LSTM(128, return_sequences=False, name='lstm_output')
        self.drop_lstm = Dropout(DROPOUT_RATE)
        # Fusion head
        self.concat = Concatenate()
        self.dense_bottleneck = Dense(128, activation='relu', name='bottleneck_features')
        self.dense_output = Dense(1, activation='sigmoid')

    def call(self, inputs, training=False):
        """Forward pass; ``training`` is forwarded to the Dropout layers only."""
        x = self.reshape_in(inputs)
        x = self.conv1a(x)
        x = self.conv1b(x)
        x = self.pool1(x)
        x = self.drop1(x, training=training)
        x = self.conv2a(x)
        x = self.conv2b(x)
        cnn_branch_output = self.pool2(x)
        x = self.drop2(cnn_branch_output, training=training)

        cnn_flat = self.flatten_cnn(x)
        # (batch, H, W, C) -> (batch, H*W, C): every spatial position of the
        # pooled feature map becomes one sequence step.
        shape = tf.shape(x)
        sequence = tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])

        # Self-attention: the sequence attends to itself.
        att_branch_output = self.attention(query=sequence, key=sequence, value=sequence)
        att_flat = self.flatten_att(att_branch_output)

        lstm_seq = self.lstm1(sequence)
        lstm_branch_output = self.lstm2(lstm_seq)
        lstm_out = self.drop_lstm(lstm_branch_output, training=training)

        concatenated = self.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = self.dense_bottleneck(concatenated)
        final_output = self.dense_output(bottleneck)

        return final_output

    def get_config(self):
        """Return the base Model config plus ``input_shape`` for rebuilding."""
        config = super(ParkinsonDetectorModel, self).get_config()
        config.update({"input_shape": self.input_shape_config})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild from ``get_config()`` output.

        Every key is passed to the constructor; keys other than
        ``input_shape`` reach ``Model.__init__`` through ``**kwargs``.
        """
        return cls(**config)

def build_model(input_shape: tuple) -> Model:
    """Wrap a ``ParkinsonDetectorModel`` in a functional Keras ``Model``.

    Args:
        input_shape: Per-sample input shape, without the batch axis.

    Returns:
        A functional ``Model`` mapping an ``Input`` of ``input_shape`` to the
        detector's output.
    """
    print("--- Building the model ---")
    model_input = Input(shape=input_shape)
    model = Model(
        inputs=model_input,
        outputs=ParkinsonDetectorModel(input_shape=input_shape)(model_input),
    )
    print("Model built successfully.")
    return model

# =============================================================================
# --- Comprehensive Plotting Functions ---
# =============================================================================

def plot_training_history(history_df, save_path):
    """Plot loss, accuracy and AUC curves from a Keras training history.

    Produces a 2x2 figure:

    * one panel each for loss, accuracy and AUC (training vs. validation);
    * a combined panel with loss on the left axis (blue) and accuracy on the
      right axis (red).

    Args:
        history_df: DataFrame indexed by epoch with columns ``loss``,
            ``val_loss``, ``accuracy``, ``val_accuracy``, ``auc`` and
            ``val_auc`` (e.g. ``pd.DataFrame(history.history)``).
        save_path: Destination image path (saved at 300 dpi).
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History Analysis', fontsize=16, fontweight='bold')
    epochs = history_df.index

    try:
        # Panels 1-3: one metric each, training vs. validation
        for ax, metric, name in (
            (axes[0, 0], 'loss', 'Loss'),
            (axes[0, 1], 'accuracy', 'Accuracy'),
            (axes[1, 0], 'auc', 'AUC'),
        ):
            ax.plot(epochs, history_df[metric], label=f'Training {name}', linewidth=2)
            ax.plot(epochs, history_df[f'val_{metric}'], label=f'Validation {name}', linewidth=2)
            ax.set_title(f'Model {name}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(name)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Panel 4: loss and accuracy on twin y-axes. Line colours are set
        # explicitly so they match the coloured axis labels. The palette cycle
        # set with sns.set_palette at module level does not guarantee blue/red.
        ax_loss = axes[1, 1]
        ax_acc = ax_loss.twinx()
        ax_loss.plot(epochs, history_df['loss'], color='blue', label='Train Loss', alpha=0.7)
        ax_loss.plot(epochs, history_df['val_loss'], color='cornflowerblue', label='Val Loss', alpha=0.7)
        ax_acc.plot(epochs, history_df['accuracy'], color='red', label='Train Acc', linestyle='--', alpha=0.7)
        ax_acc.plot(epochs, history_df['val_accuracy'], color='salmon', label='Val Acc', linestyle='--', alpha=0.7)
        ax_loss.set_title('Loss vs Accuracy')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss', color='blue')
        ax_acc.set_ylabel('Accuracy', color='red')
        # One combined legend on the twin (top-most) axis, so it is not drawn
        # underneath the accuracy lines and the two legends cannot overlap.
        lines = [*ax_loss.get_lines(), *ax_acc.get_lines()]
        ax_acc.legend(lines, [line.get_label() for line in lines], loc='best')

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails
        plt.close(fig)
    print(f"Training history plot saved to {save_path}")

def plot_model_performance(y_true, y_pred_proba, save_dir):
    """Save a grid of diagnostic plots for a binary classifier.

    Panels:

    * confusion matrix at threshold 0.5;
    * ROC curve;
    * precision-recall curve;
    * per-class histogram of predicted probabilities;
    * precision/recall/F1 versus threshold;
    * class counts;
    * per-sample absolute error.

    The figure is written to ``<save_dir>/comprehensive_performance.png``.

    Args:
        y_true: Ground-truth labels (0 = healthy, 1 = Parkinson).
        y_pred_proba: Predicted probability of class 1, one value per sample.
            Any shape that flattens to that works, e.g. ``model.predict``
            output of shape (n, 1).
        save_dir: Output directory.
    """
    # Flatten both inputs so boolean masks and per-sample indexing are 1-D
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred_proba = np.asarray(y_pred_proba, dtype=float).ravel()
    y_pred = (y_pred_proba > 0.5).astype(int)
    class_names = ['Healthy', 'Parkinson']
    # ROC/PR curves and AUC are undefined unless both classes are present
    both_classes_present = np.unique(y_true).size == 2

    fig = plt.figure(figsize=(20, 15))
    gs = gridspec.GridSpec(3, 3, hspace=0.3, wspace=0.3)

    try:
        # 1. Confusion Matrix. labels=[0, 1] keeps it 2x2 even if a class is
        # absent; the local name avoids shadowing the matplotlib.cm import.
        ax1 = fig.add_subplot(gs[0, 0])
        conf_mat = confusion_matrix(y_true, y_pred, labels=[0, 1])
        sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', ax=ax1,
                    xticklabels=class_names, yticklabels=class_names)
        ax1.set_title('Confusion Matrix')
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('Actual')

        # 2. ROC Curve
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        if both_classes_present:
            fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
            auc_score = roc_auc_score(y_true, y_pred_proba)
            ax2.plot(fpr, tpr, label=f'ROC Curve (AUC = {auc_score:.3f})', linewidth=2)
            ax2.legend()
        else:
            ax2.text(0.5, 0.5, 'Undefined: only one class present',
                     ha='center', va='center', transform=ax2.transAxes)
        ax2.set_xlabel('False Positive Rate')
        ax2.set_ylabel('True Positive Rate')
        ax2.set_title('ROC Curve')
        ax2.grid(True, alpha=0.3)

        # 3. Precision-Recall Curve
        ax3 = fig.add_subplot(gs[0, 2])
        if both_classes_present:
            precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
            avg_precision = average_precision_score(y_true, y_pred_proba)
            ax3.plot(recall, precision, label=f'PR Curve (AP = {avg_precision:.3f})', linewidth=2)
            ax3.legend()
        else:
            ax3.text(0.5, 0.5, 'Undefined: only one class present',
                     ha='center', va='center', transform=ax3.transAxes)
        ax3.set_xlabel('Recall')
        ax3.set_ylabel('Precision')
        ax3.set_title('Precision-Recall Curve')
        ax3.grid(True, alpha=0.3)

        # 4. Prediction Distribution (skip a class with no samples: a density
        # histogram of an empty array is all-NaN)
        ax4 = fig.add_subplot(gs[1, 0])
        for label_value, name in enumerate(class_names):
            class_probs = y_pred_proba[y_true == label_value]
            if class_probs.size:
                ax4.hist(class_probs, bins=30, alpha=0.7, label=name, density=True)
        ax4.axvline(x=0.5, color='red', linestyle='--', label='Threshold')
        ax4.set_xlabel('Prediction Probability')
        ax4.set_ylabel('Density')
        ax4.set_title('Prediction Distribution')
        ax4.legend()

        # 5. Threshold Analysis
        ax5 = fig.add_subplot(gs[1, 1])
        thresholds = np.linspace(0.1, 0.9, 50)
        precisions, recalls, f1s = [], [], []
        for thresh in thresholds:
            y_pred_thresh = (y_pred_proba > thresh).astype(int)
            precisions.append(precision_score(y_true, y_pred_thresh, zero_division=0))
            recalls.append(recall_score(y_true, y_pred_thresh, zero_division=0))
            f1s.append(f1_score(y_true, y_pred_thresh, zero_division=0))

        ax5.plot(thresholds, precisions, label='Precision', linewidth=2)
        ax5.plot(thresholds, recalls, label='Recall', linewidth=2)
        ax5.plot(thresholds, f1s, label='F1-Score', linewidth=2)
        ax5.axvline(x=0.5, color='red', linestyle='--', alpha=0.7, label='Default Threshold')
        ax5.set_xlabel('Threshold')
        ax5.set_ylabel('Score')
        ax5.set_title('Threshold Analysis')
        ax5.legend()
        ax5.grid(True, alpha=0.3)

        # 6. Class Distribution (minlength=2 keeps one bar per class name even
        # when a class is missing)
        ax6 = fig.add_subplot(gs[1, 2])
        class_counts = np.bincount(y_true, minlength=2)
        ax6.bar(class_names, class_counts, color=['lightblue', 'lightcoral'])
        ax6.set_title('Class Distribution')
        ax6.set_ylabel('Count')
        for i, count in enumerate(class_counts):
            ax6.text(i, count + 0.01 * max(class_counts), str(count), ha='center')

        # 7. Error Analysis: |label - probability|, red where misclassified
        ax7 = fig.add_subplot(gs[2, :])
        errors = np.abs(y_true - y_pred_proba)
        point_colors = np.where(y_true != y_pred, 'red', 'blue').tolist()
        ax7.scatter(range(len(errors)), errors, alpha=0.6, c=point_colors, s=20)
        ax7.set_xlabel('Sample Index')
        ax7.set_ylabel('Prediction Error')
        ax7.set_title('Prediction Errors (Red: Misclassified, Blue: Correct)')
        ax7.grid(True, alpha=0.3)

        fig.suptitle('Comprehensive Model Performance Analysis', fontsize=16, fontweight='bold')
        fig.savefig(os.path.join(save_dir, 'comprehensive_performance.png'), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if a panel or savefig fails
        plt.close(fig)
    print(f"Comprehensive performance plot saved to {save_dir}")

def save_metrics_to_csv(y_true, y_pred_proba, filename, threshold=0.5):
    """Compute binary-classification metrics and write them to a CSV file.

    The CSV has two columns, ``Metric`` and ``Value``. It covers:

    * TP/TN/FP/FN counts, each with its share of all samples;
    * precision, recall, specificity, F1, ROC AUC and accuracy.

    The file is written as UTF-8 with a BOM.

    Args:
        y_true: Ground-truth labels (0/1).
        y_pred_proba: Predicted probability of class 1, one per sample (any
            shape that flattens to that, e.g. ``model.predict`` output).
        filename: Output CSV path.
        threshold: Probability above which a sample is predicted as class 1.

    Notes:
        AUC is written as ``nan`` when ``y_true`` contains only one class,
        since ROC AUC is undefined (and sklearn raises) in that case.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred_proba = np.asarray(y_pred_proba, dtype=float).ravel()
    y_pred_binary = (y_pred_proba > threshold).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
    # both y_true and the predictions, so the 4-way unpack cannot fail.
    cm = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    total_samples = cm.sum()

    def percent(count):
        # Share of all samples, in percent (0 for empty input)
        return (count / total_samples) * 100 if total_samples > 0 else 0

    tn_percent = percent(tn)
    fp_percent = percent(fp)
    fn_percent = percent(fn)
    tp_percent = percent(tp)

    precision = precision_score(y_true, y_pred_binary, zero_division=0)
    recall = recall_score(y_true, y_pred_binary, zero_division=0)
    f1 = f1_score(y_true, y_pred_binary, zero_division=0)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    accuracy = (tp + tn) / total_samples if total_samples > 0 else 0
    auc = roc_auc_score(y_true, y_pred_proba) if np.unique(y_true).size == 2 else float('nan')

    report_data = {
        'Metric': [
            'True Positive (TP)', 'True Negative (TN)', 'False Positive (FP)', 'False Negative (FN)',
            'Precision', 'Recall (Sensitivity)', 'Specificity', 'F1-Score', 'AUC', 'Accuracy'
        ],
        'Value': [
            f"{tp} ({tp_percent:.2f}%)", f"{tn} ({tn_percent:.2f}%)",
            f"{fp} ({fp_percent:.2f}%)", f"{fn} ({fn_percent:.2f}%)",
            f"{precision:.4f}", f"{recall:.4f}", f"{specificity:.4f}",
            f"{f1:.4f}", f"{auc:.4f}", f"{accuracy:.4f}"
        ]
    }

    df = pd.DataFrame(report_data)
    df.to_csv(filename, index=False, encoding='utf-8-sig')
    print(f"Evaluation results saved to {filename}")

# =============================================================================
# --- Feature Analysis ---
# =============================================================================

def generate_feature_map_info(X_shape):
    """Describe how the rows (axis 1) of the input split into feature groups.

    Axis 1 is assumed to hold mel-spectrogram rows first and MFCC rows
    second, split evenly; the MFCC group gets the extra row when the count
    is odd.

    Args:
        X_shape: Shape of the input array, indexed (sample, row, column).

    Returns:
        Dict with these keys:

        * ``color_mask``: int array (rows, columns) holding each row's group
          index, repeated across every column.
        * ``feature_layout``: group name -> row count.
        * ``feature_names``: group names in order.
        * ``colors``: a 'Paired' colormap with one entry per group.
        * ``legend_patches``: one legend patch per group.
        * ``total_rows``: the number of rows.
    """
    n_rows, n_cols = X_shape[1], X_shape[2]
    half = n_rows // 2
    feature_layout = {'mel_spectrogram': half, 'mfcc': n_rows - half}

    colors = plt.get_cmap('Paired', len(feature_layout))
    feature_names = list(feature_layout)

    # Group index for every row, then broadcast across all columns
    group_per_row = np.repeat(np.arange(len(feature_layout)), list(feature_layout.values()))
    color_mask = np.tile(group_per_row[:, np.newaxis], (1, n_cols))

    legend_patches = [
        mpatches.Patch(color=colors(idx), label=f"{name} ({count} rows)")
        for idx, (name, count) in enumerate(feature_layout.items())
    ]

    return dict(
        color_mask=color_mask,
        feature_layout=feature_layout,
        feature_names=feature_names,
        colors=colors,
        legend_patches=legend_patches,
        total_rows=n_rows,
    )

def plot_feature_analysis(X, y, save_dir, seed=None):
    """Save per-class feature statistics and correlation plots.

    Writes two figures to ``save_dir``:

    * ``feature_analysis.png``: per-row mean and variance (over samples and
      axis 2) for each class, plus the first sample of each class as an image.
    * ``feature_correlations.png``: per-class correlation matrices of up to 50
      randomly chosen flattened input positions (the same positions for both
      classes).

    Args:
        X: Feature array indexed (sample, row, column). Axis 1 is labelled
            "Features" and axis 2 "Time Steps" in the plots.
        y: Binary labels (0 = healthy, 1 = Parkinson), one per sample.
        save_dir: Output directory.
        seed: Optional seed for choosing the correlated positions. ``None``
            (default) draws from NumPy's global RNG as before.
    """
    print("Generating feature analysis plots...")
    X = np.asarray(X)
    y = np.asarray(y).ravel()

    # Feature statistics by class
    healthy_features = X[y == 0]
    parkinson_features = X[y == 1]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Feature means by class (NaN, i.e. nothing drawn, for an empty class)
    healthy_mean = np.mean(healthy_features, axis=(0, 2))
    parkinson_mean = np.mean(parkinson_features, axis=(0, 2))

    axes[0, 0].plot(healthy_mean, label='Healthy', linewidth=2)
    axes[0, 0].plot(parkinson_mean, label='Parkinson', linewidth=2)
    axes[0, 0].set_title('Average Feature Values by Class')
    axes[0, 0].set_xlabel('Feature Index')
    axes[0, 0].set_ylabel('Average Value')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # 2. Feature variance by class
    healthy_var = np.var(healthy_features, axis=(0, 2))
    parkinson_var = np.var(parkinson_features, axis=(0, 2))

    axes[0, 1].plot(healthy_var, label='Healthy', linewidth=2)
    axes[0, 1].plot(parkinson_var, label='Parkinson', linewidth=2)
    axes[0, 1].set_title('Feature Variance by Class')
    axes[0, 1].set_xlabel('Feature Index')
    axes[0, 1].set_ylabel('Variance')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # 3. Sample spectrograms
    if len(healthy_features) > 0:
        axes[1, 0].imshow(healthy_features[0], aspect='auto', cmap='viridis')
        axes[1, 0].set_title('Sample Healthy Spectrogram')
        axes[1, 0].set_xlabel('Time Steps')
        axes[1, 0].set_ylabel('Features')

    if len(parkinson_features) > 0:
        axes[1, 1].imshow(parkinson_features[0], aspect='auto', cmap='viridis')
        axes[1, 1].set_title('Sample Parkinson Spectrogram')
        axes[1, 1].set_xlabel('Time Steps')
        axes[1, 1].set_ylabel('Features')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'feature_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Feature correlation analysis
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Flatten each sample. An explicit per-sample size (instead of -1) lets
    # reshape succeed when a class has no samples.
    n_positions = int(np.prod(X.shape[1:]))
    healthy_flat = healthy_features.reshape(len(healthy_features), n_positions)
    parkinson_flat = parkinson_features.reshape(len(parkinson_features), n_positions)

    # Sample positions for correlation (too many for a full correlation matrix)
    n_sample_features = min(50, n_positions)
    rng = np.random if seed is None else np.random.RandomState(seed)
    sample_indices = rng.choice(n_positions, n_sample_features, replace=False)

    for ax, flat, title in (
        (axes[0], healthy_flat, 'Healthy Feature Correlations'),
        (axes[1], parkinson_flat, 'Parkinson Feature Correlations'),
    ):
        ax.set_title(title)
        # A correlation needs at least two samples of the class
        if len(flat) < 2:
            ax.text(0.5, 0.5, 'Not enough samples', ha='center', va='center',
                    transform=ax.transAxes)
            continue
        # atleast_2d: corrcoef returns a scalar when only one position is sampled
        corr = np.atleast_2d(np.corrcoef(flat[:, sample_indices].T))
        im = ax.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
        plt.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'feature_correlations.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Feature analysis plots saved to {save_dir}")

# =============================================================================
# --- SHAP Analysis ---
# =============================================================================

def run_full_shap_analysis(model, X_train, X_test, y_test, output_path,
                          samples_per_class=50, top_n=20):
    """Explain model predictions with SHAP's ``GradientExplainer``.

    Draws up to ``samples_per_class`` random test samples from each class. It
    computes one SHAP map per sample, against a background of the first (up
    to) 50 training samples, and writes to ``output_path``:

    * ``shap_global_importance.png``: the ``top_n`` input positions with the
      largest mean |SHAP|, labelled ``<feature group>_T<column>``.
    * ``shap_class_comparison.png``: per-class mean SHAP maps and their
      difference. Only written when both classes were sampled.
    * ``shap_statistics.csv``: sample counts and summary statistics. Class
      statistics are 0 for a class that was not sampled.

    Args:
        model: Keras model to explain.
        X_train: Training inputs; the first 50 form the SHAP background.
        X_test: Test inputs, indexed (sample, row, column).
        y_test: Binary test labels.
        output_path: Output directory (created if missing).
        samples_per_class: Maximum samples drawn per class.
        top_n: Number of positions shown in the global-importance plot.
    """
    print("\n--- Running Full SHAP Analysis ---")
    os.makedirs(output_path, exist_ok=True)

    feature_map_info = generate_feature_map_info(X_test.shape)

    # Balanced sample selection
    healthy_indices = np.where(y_test == 0)[0]
    parkinson_indices = np.where(y_test == 1)[0]

    num_healthy = min(samples_per_class, len(healthy_indices))
    num_parkinson = min(samples_per_class, len(parkinson_indices))

    selected_healthy = np.random.choice(healthy_indices, num_healthy, replace=False) if num_healthy > 0 else np.array([])
    selected_parkinson = np.random.choice(parkinson_indices, num_parkinson, replace=False) if num_parkinson > 0 else np.array([])

    final_indices = np.concatenate([selected_healthy, selected_parkinson]).astype(int)
    np.random.shuffle(final_indices)

    # Nothing to explain (e.g. empty test set or samples_per_class=0); the
    # reshape below would otherwise fail on an empty array.
    if len(final_indices) == 0:
        print("No test samples available for SHAP analysis. Skipping.")
        return

    test_samples = X_test[final_indices]
    y_true_samples = y_test[final_indices]

    print(f"Calculating SHAP values for {len(test_samples)} balanced samples...")

    # Use a smaller background set for faster computation
    background = X_train[:min(50, len(X_train))].astype(np.float32)
    explainer = shap.GradientExplainer(model, background)

    shap_values_list = []
    for sample in tqdm(test_samples, desc="SHAP Progress"):
        sample_batch = np.expand_dims(sample, axis=0).astype(np.float32)
        sv = explainer.shap_values(sample_batch)
        # Some shap versions return one array per model output
        if isinstance(sv, list):
            sv = sv[0]
        shap_values_list.append(sv.squeeze())

    shap_values = np.array(shap_values_list)

    # Global importance analysis: mean |SHAP| per input position
    flat_shap = shap_values.reshape(len(shap_values), -1)
    mean_abs_shap = np.mean(np.abs(flat_shap), axis=0)
    top_indices = np.argsort(mean_abs_shap)[::-1][:top_n]

    # (row, column) coordinates of the top positions
    coords = [np.unravel_index(i, shap_values.shape[1:]) for i in top_indices]

    def get_feature_name(row_idx, feature_layout):
        """Return the feature group that contains row ``row_idx``."""
        cum = 0
        for name, nrows in feature_layout.items():
            if row_idx < cum + nrows:
                return name
            cum += nrows
        return "Unknown"

    labels = []
    for row_idx, time_idx in coords:
        fname = get_feature_name(row_idx, feature_map_info['feature_layout'])
        labels.append(f"{fname}_T{time_idx}")

    # Plot global importance
    plt.figure(figsize=(12, 8))
    bars = plt.bar(range(len(top_indices)), mean_abs_shap[top_indices])
    plt.xticks(range(len(top_indices)), labels, rotation=45, ha='right')
    plt.title(f'Top-{top_n} Most Important Features (Global SHAP)')
    plt.xlabel('Feature (Type_TimeStep)')
    plt.ylabel('Mean |SHAP Value|')

    # Color bars by feature type
    colors = ['lightblue' if 'mel' in label else 'lightcoral' for label in labels]
    for bar, color in zip(bars, colors):
        bar.set_color(color)

    plt.tight_layout()
    plt.savefig(os.path.join(output_path, 'shap_global_importance.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Class-specific analysis. Each class mean is computed independently so
    # the statistics below are available even when only one class was sampled.
    hc_mask = (y_true_samples == 0)
    pd_mask = (y_true_samples == 1)
    has_hc = bool(np.any(hc_mask))
    has_pd = bool(np.any(pd_mask))

    hc_shap = shap_values[hc_mask].mean(axis=0) if has_hc else None
    pd_shap = shap_values[pd_mask].mean(axis=0) if has_pd else None
    diff_shap = pd_shap - hc_shap if (has_hc and has_pd) else None

    if has_hc and has_pd:
        # Create comprehensive SHAP visualization
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # Healthy average
        im1 = axes[0, 0].imshow(hc_shap, cmap='RdBu_r', aspect='auto')
        axes[0, 0].set_title('Average SHAP - Healthy')
        axes[0, 0].set_xlabel('Time Steps')
        axes[0, 0].set_ylabel('Features')
        plt.colorbar(im1, ax=axes[0, 0])

        # Parkinson average
        im2 = axes[0, 1].imshow(pd_shap, cmap='RdBu_r', aspect='auto')
        axes[0, 1].set_title('Average SHAP - Parkinson')
        axes[0, 1].set_xlabel('Time Steps')
        axes[0, 1].set_ylabel('Features')
        plt.colorbar(im2, ax=axes[0, 1])

        # Difference, on a colour scale symmetric around zero
        max_diff = np.max(np.abs(diff_shap))
        im3 = axes[1, 0].imshow(diff_shap, cmap='seismic', aspect='auto',
                               vmin=-max_diff, vmax=max_diff)
        axes[1, 0].set_title('SHAP Difference (Parkinson - Healthy)')
        axes[1, 0].set_xlabel('Time Steps')
        axes[1, 0].set_ylabel('Features')
        plt.colorbar(im3, ax=axes[1, 0])

        # Feature map
        axes[1, 1].imshow(feature_map_info['color_mask'],
                         cmap=feature_map_info['colors'], aspect='auto')
        axes[1, 1].set_title('Feature Map')
        axes[1, 1].set_xlabel('Time Steps')
        axes[1, 1].set_ylabel('Features')
        axes[1, 1].legend(handles=feature_map_info['legend_patches'],
                         loc='center left', bbox_to_anchor=(1, 0.5))

        plt.tight_layout()
        plt.savefig(os.path.join(output_path, 'shap_class_comparison.png'),
                   dpi=300, bbox_inches='tight')
        plt.close()

    # Save SHAP summary statistics
    shap_stats = {
        'total_samples_analyzed': len(test_samples),
        'healthy_samples': int(np.sum(hc_mask)),
        'parkinson_samples': int(np.sum(pd_mask)),
        'mean_abs_shap_healthy': float(np.mean(np.abs(hc_shap))) if has_hc else 0,
        'mean_abs_shap_parkinson': float(np.mean(np.abs(pd_shap))) if has_pd else 0,
        'max_shap_difference': float(np.max(np.abs(diff_shap))) if diff_shap is not None else 0
    }

    pd.DataFrame([shap_stats]).to_csv(os.path.join(output_path, 'shap_statistics.csv'), index=False)

    print(f"SHAP analysis complete. Results saved to {output_path}")

# =============================================================================
# --- Grad-CAM Analysis ---
# =============================================================================

def run_gradcam_analysis(model, X_test, y_test, output_path, num_samples=50):
    """Compute and plot class-averaged Grad-CAM heatmaps for the CNN branch.

    Steps:

    1. Find the ``ParkinsonDetectorModel`` layer inside ``model``.
    2. Draw the same number of random Parkinson (label 1) and healthy
       (label 0) test samples: at most ``num_samples`` each, limited by the
       smaller class.
    3. Compute one Grad-CAM heatmap per sample from the output of ``conv2b``.
    4. Average the heatmaps per class and upsample them to the input's
       (axis 1, axis 2) size.
    5. Save ``gradcam_comprehensive.png`` in ``output_path``.

    Args:
        model: Keras model whose ``layers`` include a ParkinsonDetectorModel,
            as produced by ``build_model``.
        X_test: Test inputs, indexed (sample, axis 1, axis 2).
        y_test: Binary test labels.
        output_path: Directory for the output figure (created if missing).
        num_samples: Maximum number of samples per class.

    Returns:
        None. Prints a message and returns early if no ParkinsonDetectorModel
        layer is found.
    """
    print("\n--- Running Grad-CAM Analysis ---")
    os.makedirs(output_path, exist_ok=True)

    # Find ParkinsonDetectorModel layer (matched by its type's name)
    parkinson_detector = None
    for layer in model.layers:
        if 'ParkinsonDetectorModel' in str(type(layer)):
            parkinson_detector = layer
            break

    if parkinson_detector is None:
        print("ParkinsonDetectorModel not found. Skipping Grad-CAM analysis.")
        return

    def get_conv_and_output(inputs):
        # Replays ParkinsonDetectorModel.call step by step, with dropout in
        # inference mode, so the conv2b activations can be returned together
        # with the final prediction.
        # NOTE(review): must be kept in sync with ParkinsonDetectorModel.call
        # whenever the architecture changes.
        x = parkinson_detector.reshape_in(inputs)
        x = parkinson_detector.conv1a(x)
        x = parkinson_detector.conv1b(x)
        x = parkinson_detector.pool1(x)
        x = parkinson_detector.drop1(x, training=False)
        x = parkinson_detector.conv2a(x)
        conv_output = parkinson_detector.conv2b(x)  # Grad-CAM target layer
        x = parkinson_detector.pool2(conv_output)
        x = parkinson_detector.drop2(x, training=False)
        cnn_flat = parkinson_detector.flatten_cnn(x)

        shape = tf.shape(x)
        sequence = tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])
        att_out = parkinson_detector.attention(query=sequence, key=sequence, value=sequence)
        att_flat = parkinson_detector.flatten_att(att_out)
        lstm_seq = parkinson_detector.lstm1(sequence)
        lstm_out = parkinson_detector.lstm2(lstm_seq)
        lstm_out = parkinson_detector.drop_lstm(lstm_out, training=False)
        concatenated = parkinson_detector.concat([cnn_flat, att_flat, lstm_out])
        bottleneck = parkinson_detector.dense_bottleneck(concatenated)
        final_output = parkinson_detector.dense_output(bottleneck)

        return conv_output, final_output

    # Sample selection: both classes get the same count. If either class is
    # empty, n_samples is 0 and no heatmaps are computed for either class.
    parkinson_indices = np.where(y_test == 1)[0]
    healthy_indices = np.where(y_test == 0)[0]

    n_samples = min(num_samples, min(len(parkinson_indices), len(healthy_indices)))

    selected_pd = np.random.choice(parkinson_indices, n_samples, replace=False) if len(parkinson_indices) > 0 else []
    selected_hc = np.random.choice(healthy_indices, n_samples, replace=False) if len(healthy_indices) > 0 else []

    def calculate_gradcam_heatmaps(indices):
        # One Grad-CAM heatmap (conv2b spatial resolution) per test index.
        heatmaps = []
        for i in tqdm(indices, desc="Computing Grad-CAM"):
            img = X_test[i:i+1]  # keep the batch axis (batch of one)
            with tf.GradientTape() as tape:
                img_tensor = tf.cast(img, tf.float32)
                tape.watch(img_tensor)
                conv_outputs, preds = get_conv_and_output(img_tensor)
                # Target score: the sigmoid output (class-1 probability)
                loss = preds[:, 0]

            # Channel weights = gradient averaged over batch, height and width
            grads = tape.gradient(loss, conv_outputs)
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            conv_outputs_np = conv_outputs[0].numpy()
            pooled_grads_np = pooled_grads.numpy()

            # Weighted sum of the channel activation maps
            heatmap = np.zeros(conv_outputs_np.shape[:-1])
            for j in range(conv_outputs_np.shape[-1]):
                heatmap += pooled_grads_np[j] * conv_outputs_np[:, :, j]

            # Keep positive evidence only, then scale to [0, 1]; the epsilon
            # avoids division by zero for an all-zero map.
            heatmap = np.maximum(heatmap, 0)
            heatmap /= (heatmap.max() + 1e-10)
            heatmaps.append(heatmap)

        return heatmaps

    # Calculate heatmaps
    pd_heatmaps = calculate_gradcam_heatmaps(selected_pd) if len(selected_pd) > 0 else []
    hc_heatmaps = calculate_gradcam_heatmaps(selected_hc) if len(selected_hc) > 0 else []

    # Average heatmaps (all-zero input-sized map if a class has no samples)
    avg_pd_heatmap = np.mean(pd_heatmaps, axis=0) if pd_heatmaps else np.zeros((X_test.shape[1], X_test.shape[2]))
    avg_hc_heatmap = np.mean(hc_heatmaps, axis=0) if hc_heatmaps else np.zeros((X_test.shape[1], X_test.shape[2]))

    # Resize to original input dimensions (order=3 spline interpolation)
    original_height, original_width = X_test.shape[1], X_test.shape[2]

    upscaled_pd = resize(avg_pd_heatmap, (original_height, original_width),
                        order=3, mode='reflect', anti_aliasing=True)
    upscaled_hc = resize(avg_hc_heatmap, (original_height, original_width),
                        order=3, mode='reflect', anti_aliasing=True)

    # Create visualization: row 0 = Parkinson, row 1 = healthy
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Sample inputs
    if len(selected_pd) > 0:
        sample_pd = X_test[selected_pd[0]]
        axes[0, 0].imshow(sample_pd, cmap='viridis', aspect='auto')
        axes[0, 0].set_title(f'Parkinson Sample {selected_pd[0]}')
        axes[0, 0].set_xlabel('Time Steps')
        axes[0, 0].set_ylabel('Features')

    # Parkinson heatmap
    im1 = axes[0, 1].imshow(upscaled_pd, cmap='jet', aspect='auto')
    axes[0, 1].set_title(f'Avg. Parkinson Attention ({len(selected_pd)} samples)')
    axes[0, 1].set_xlabel('Time Steps')
    axes[0, 1].set_ylabel('Features')
    plt.colorbar(im1, ax=axes[0, 1])

    # Feature map
    feature_map_info = generate_feature_map_info(X_test.shape)
    axes[0, 2].imshow(feature_map_info['color_mask'], cmap=feature_map_info['colors'], aspect='auto')
    axes[0, 2].set_title('Feature Layout')
    axes[0, 2].set_xlabel('Time Steps')
    axes[0, 2].set_ylabel('Features')
    axes[0, 2].legend(handles=feature_map_info['legend_patches'],
                     loc='center left', bbox_to_anchor=(1, 0.5))

    # Healthy samples
    if len(selected_hc) > 0:
        sample_hc = X_test[selected_hc[0]]
        axes[1, 0].imshow(sample_hc, cmap='viridis', aspect='auto')
        axes[1, 0].set_title(f'Healthy Sample {selected_hc[0]}')
        axes[1, 0].set_xlabel('Time Steps')
        axes[1, 0].set_ylabel('Features')

    # Healthy heatmap
    im2 = axes[1, 1].imshow(upscaled_hc, cmap='jet', aspect='auto')
    axes[1, 1].set_title(f'Avg. Healthy Attention ({len(selected_hc)} samples)')
    axes[1, 1].set_xlabel('Time Steps')
    axes[1, 1].set_ylabel('Features')
    plt.colorbar(im2, ax=axes[1, 1])

    # Difference heatmap, on a colour scale symmetric around zero
    diff_heatmap = upscaled_pd - upscaled_hc
    max_diff = np.max(np.abs(diff_heatmap))
    im3 = axes[1, 2].imshow(diff_heatmap, cmap='seismic', aspect='auto',
                           vmin=-max_diff, vmax=max_diff)
    axes[1, 2].set_title('Attention Difference (PD - HC)')
    axes[1, 2].set_xlabel('Time Steps')
    axes[1, 2].set_ylabel('Features')
    plt.colorbar(im3, ax=axes[1, 2])

    plt.suptitle('Grad-CAM Analysis: Model Attention Patterns', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, 'gradcam_comprehensive.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Grad-CAM analysis complete. Results saved to {output_path}")

# =============================================================================
# --- Cross-Validation Analysis ---
# =============================================================================

def run_cross_validation(X, y, n_folds=5):
    """Run stratified k-fold cross-validation and persist per-fold metrics.

    For each fold a fresh model is built with ``build_model``, compiled with
    Adam + binary cross-entropy, and trained with early stopping on the
    fold's validation AUC (best weights restored). Accuracy, precision,
    recall, F1, ROC-AUC and specificity are computed at a 0.5 decision
    threshold, written to ``ANALYSIS_PATH`` as CSV (per fold and mean/std
    summary), and plotted as one bar chart per metric.

    Args:
        X: Feature array indexable by integer index arrays; ``X.shape[1]`` and
            ``X.shape[2]`` are used as the model input shape.
        y: Binary (0/1) label array aligned with ``X``.
        n_folds: Number of stratified folds.

    Returns:
        tuple: ``(cv_df, summary_df)`` - the per-fold metrics DataFrame and a
        one-row DataFrame with ``<metric>_mean`` / ``<metric>_std`` columns.
    """
    print(f"\n--- Running {n_folds}-Fold Cross-Validation ---")

    # Single source of truth for metric names: drives collection, the CSV
    # column order, the summary statistics and the plot grid.
    metric_names = ['accuracy', 'precision', 'recall', 'f1', 'auc', 'specificity']

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    cv_results = {'fold': [], **{name: [] for name in metric_names}}

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
        print(f"\nTraining fold {fold}/{n_folds}")

        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        y_train_fold, y_val_fold = y[train_idx], y[val_idx]

        # A fresh, untrained model per fold so folds do not share weights.
        model = build_model(input_shape=(X.shape[1], X.shape[2]))
        model.compile(
            optimizer=Adam(learning_rate=LEARNING_RATE),
            loss='binary_crossentropy',
            metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
        )

        # AUC must be maximised; state it explicitly instead of relying on
        # mode='auto', whose direction inference differs across Keras versions.
        early_stop = EarlyStopping(monitor='val_auc', mode='max', patience=10,
                                   restore_best_weights=True)

        model.fit(
            X_train_fold, y_train_fold,
            validation_data=(X_val_fold, y_val_fold),
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            callbacks=[early_stop],
            verbose=0
        )

        # Flatten the model output (presumably an (n, 1) sigmoid column) so
        # sklearn receives plain 1-D arrays.
        y_pred_proba = np.ravel(model.predict(X_val_fold, verbose=0))
        y_pred = (y_pred_proba > 0.5).astype(int)

        # labels=[0, 1] guarantees a 2x2 matrix, so the 4-way unpacking cannot
        # fail even if a class is missing from a fold's truth and predictions.
        tn, fp, fn, tp = confusion_matrix(y_val_fold, y_pred, labels=[0, 1]).ravel()

        fold_metrics = {
            'accuracy': (tp + tn) / (tp + tn + fp + fn),
            'precision': precision_score(y_val_fold, y_pred, zero_division=0),
            'recall': recall_score(y_val_fold, y_pred, zero_division=0),
            'f1': f1_score(y_val_fold, y_pred, zero_division=0),
            'auc': roc_auc_score(y_val_fold, y_pred_proba),
            'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
        }

        cv_results['fold'].append(fold)
        for name in metric_names:
            cv_results[name].append(fold_metrics[name])

    cv_df = pd.DataFrame(cv_results)

    # One-row summary: mean and sample std (pandas default ddof=1) per metric.
    summary_stats = {}
    for name in metric_names:
        summary_stats[f'{name}_mean'] = cv_df[name].mean()
        summary_stats[f'{name}_std'] = cv_df[name].std()
    summary_df = pd.DataFrame([summary_stats])

    os.makedirs(ANALYSIS_PATH, exist_ok=True)
    cv_df.to_csv(os.path.join(ANALYSIS_PATH, 'cross_validation_results.csv'), index=False)
    summary_df.to_csv(os.path.join(ANALYSIS_PATH, 'cross_validation_summary.csv'), index=False)

    # 2x3 grid, one bar chart per metric, filled row by row.
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    for ax, name in zip(axes.flat, metric_names):
        mean_value = cv_df[name].mean()
        ax.bar(cv_df['fold'], cv_df[name])
        ax.axhline(y=mean_value, color='red', linestyle='--',
                   label=f'Mean: {mean_value:.3f}')
        ax.set_title(f'{name.upper()} across folds')
        ax.set_xlabel('Fold')
        ax.set_ylabel(name.upper())
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(ANALYSIS_PATH, 'cross_validation_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Cross-validation complete. Results saved to {ANALYSIS_PATH}")
    return cv_df, summary_df

# =============================================================================
# --- Model Architecture Visualization ---
# =============================================================================

def visualize_model_architecture(model, save_path):
    """Render ``model`` to an image file with ``tf.keras.utils.plot_model``.

    Best-effort: failures are reported on stdout and never raised, so a
    missing optional dependency (pydot / Graphviz) does not abort the run.

    Some Keras versions do not raise when Graphviz is unavailable. They only
    print an installation notice and return without writing anything.
    Success is therefore judged by whether ``save_path`` was actually
    (re)written, not merely by the absence of an exception.

    Args:
        model: The Keras model to draw.
        save_path: Destination image path; its extension selects the format.

    Returns:
        bool: True if the diagram file was written, False otherwise.
    """
    def _mtime(path):
        # Modification time, or None if the file does not exist.
        return os.path.getmtime(path) if os.path.exists(path) else None

    # Remember any pre-existing (stale) diagram so it is not mistaken for
    # fresh output when plot_model silently writes nothing.
    mtime_before = _mtime(save_path)

    try:
        tf.keras.utils.plot_model(
            model,
            to_file=save_path,
            show_shapes=True,
            show_layer_names=True,
            rankdir='TB',
            expand_nested=True,
            dpi=96
        )
    except Exception as e:
        print(f"Could not create model architecture diagram: {e}")
        return False

    mtime_after = _mtime(save_path)
    if mtime_after is None or mtime_after == mtime_before:
        print("Could not create model architecture diagram: "
              f"{save_path} was not written (are pydot and Graphviz installed?)")
        return False

    print(f"Model architecture diagram saved to {save_path}")
    return True

# =============================================================================
# --- Main Execution ---
# =============================================================================

def main():
    """Run the end-to-end training, evaluation and explainability pipeline.

    Steps:
      1. Load the features and hold out a stratified test split.
      2. Carve a stratified validation split out of the remaining training
         data and train the model. Early stopping and checkpointing see only
         the validation split.
      3. Evaluate on the untouched test split.
      4. Run k-fold cross-validation.
      5. Run SHAP and Grad-CAM on the saved best model, if one exists.
      6. Write a text summary.

    All paths and hyper-parameters come from module-level configuration.
    """
    import traceback  # only needed to report explainability failures

    print("="*60)
    print("PARKINSON'S DISEASE DETECTION - COMPREHENSIVE ANALYSIS")
    print("="*60)

    # Load data
    X, y = load_data(FEATURES_FILE_PATH)
    print(f"Loaded data: {X.shape} features, {len(y)} labels")
    print(f"Class distribution: {np.bincount(y)}")

    # Hold out a stratified test split used only for the final evaluation.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # The training callbacks pick the best epoch from validation metrics
    # (checkpoint_cb saves the best model on val_auc). Feeding them the test
    # split would leak it into model selection and inflate the reported test
    # metrics. A separate stratified validation split therefore comes from the
    # training data.
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )
    print(f"\nData split: Train {len(y_fit)}, Validation {len(y_val)}, Test {len(y_test)}")

    # Exploratory feature plots over the full dataset
    plot_feature_analysis(X, y, PLOTS_PATH)

    # Build the model and draw its architecture
    model = build_model(input_shape=(X_train.shape[1], X_train.shape[2]))
    model.summary()
    visualize_model_architecture(model, os.path.join(PLOTS_PATH, 'model_architecture.png'))

    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    # Train model
    print("\n--- Starting Model Training ---")
    history = model.fit(
        X_fit, y_fit,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[checkpoint_cb, early_stop_cb],
        verbose=1
    )

    # Save and plot the per-epoch training history
    history_df = pd.DataFrame(history.history)
    history_df.to_csv(HISTORY_SAVE_PATH, index_label='epoch')
    print(f"Training history saved to {HISTORY_SAVE_PATH}")
    plot_training_history(history_df, os.path.join(PLOTS_PATH, 'training_history.png'))

    # Evaluate on the held-out test split
    print("\n--- Evaluating Model ---")
    y_pred_probabilities = model.predict(X_test)
    save_metrics_to_csv(y_test, y_pred_probabilities, EVALUATION_FILE_PATH)
    plot_model_performance(y_test, y_pred_probabilities, PLOTS_PATH)

    # k-fold cross-validation over the whole dataset; it writes its own
    # CSVs and plots, so the returned DataFrames are not needed here.
    run_cross_validation(X, y, n_folds=5)

    # Explainability on the best checkpoint saved during training
    if os.path.exists(BEST_MODEL_PATH):
        print("\n--- Loading Best Model for Explainability Analysis ---")
        try:
            best_model = load_model(BEST_MODEL_PATH,
                                    custom_objects={'ParkinsonDetectorModel': ParkinsonDetectorModel})

            # SHAP Analysis
            run_full_shap_analysis(best_model, X_train, X_test, y_test,
                                   SHAP_OUTPUT_PATH, samples_per_class=50, top_n=20)

            # Grad-CAM Analysis
            run_gradcam_analysis(best_model, X_test, y_test, GRADCAM_OUTPUT_PATH, num_samples=50)

        except Exception as e:
            # Best-effort: explainability must not abort the summary report,
            # but keep the traceback so the failure can be diagnosed.
            print(f"Error in explainability analysis: {e}")
            traceback.print_exc()
    else:
        print(f"\nBest model not found at {BEST_MODEL_PATH}; skipping explainability analysis.")

    # Generate final summary report
    print("\n--- Generating Summary Report ---")
    eval_df = pd.read_csv(EVALUATION_FILE_PATH)

    summary_report = {
        'Dataset': dataset,
        'Mode': MODE,
        'Feature_Mode': FEATURE_MODE,
        'Total_Samples': len(y),
        'Training_Samples': len(y_fit),
        'Validation_Samples': len(y_val),
        'Test_Samples': len(y_test),
        'Input_Shape': f"{X.shape[1]}x{X.shape[2]}",
        'Model_Parameters': model.count_params(),
        'Training_Epochs': len(history_df),
        'Best_Val_AUC': history_df['val_auc'].max(),
        'Final_Test_Metrics': eval_df.to_dict('records')
    }

    with open(os.path.join(ANALYSIS_PATH, 'analysis_summary.txt'), 'w', encoding='utf-8') as f:
        f.writelines(f"{key}: {value}\n" for key, value in summary_report.items())

    print("\n" + "="*60)
    print("ANALYSIS COMPLETE!")
    print("="*60)
    print(f"All results saved to: {MODEL_PATH}")
    print(f"- Plots: {PLOTS_PATH}")
    print(f"- SHAP Analysis: {SHAP_OUTPUT_PATH}")
    print(f"- Grad-CAM Analysis: {GRADCAM_OUTPUT_PATH}")
    print(f"- Comprehensive Analysis: {ANALYSIS_PATH}")
    print("="*60)

# Script entry point: run the full pipeline when executed directly.
# Note: inside a Jupyter kernel __name__ is also '__main__', so running this
# cell executes the whole pipeline as well.
if __name__ == '__main__':
    main()


PARKINSON'S DISEASE DETECTION - COMPREHENSIVE ANALYSIS
--- Loading data from D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\data\features_A_DEFAULT.npz ---
Loaded data: (328, 30, 188) features, 328 labels
Class distribution: [164 164]

Data split: Train 262, Test 66
Generating feature analysis plots...
Feature analysis plots saved to D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_cnn_lstm\plots
--- Building the model ---
Model built successfully.


You must install graphviz (see instructions at https://graphviz.gitlab.io/download/) for `plot_model` to work.
Model architecture diagram saved to D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_cnn_lstm\plots\model_architecture.png

--- Starting Model Training ---
Epoch 1/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 733ms/step - accuracy: 0.5322 - auc: 0.5194 - loss: 8.9707
Epoch 1: val_auc improved from None to 0.58770, saving model to D:\Projects\Voice\Parkinson-s-Disease-Detector-Using-AI\Parkinson-s-Disease-Detector-Using-AI\1\UAMS\results_A_DEFAULT\nca_cnn_lstm\best_model.keras
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 1s/step - accuracy: 0.5267 - auc: 0.5081 - loss: 7.3131 - val_accuracy: 0.5000 - val_auc: 0.5877 - val_loss: 2.8764
Epoch 2/30
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 674ms/step - accuracy: 0.5691 - auc: 0.6384 - loss: 3.0623
Epo