# **Advancing Automated Cervical Cancer Diagnosis Using Attention CNN and SHAP-Enhanced Feature Selection**

## Project Overview

This project focuses on the automated classification of cervical cancer using Liquid-Based Cytology (LBC) Pap smear images. The dataset, provided by Hussain et al., comprises **963 high-resolution images** from **460 patients**, categorized into four classes:

- **High-grade Squamous Intraepithelial Lesion (HSIL)**
- **Low-grade Squamous Intraepithelial Lesion (LSIL)**
- **Negative for Intraepithelial Lesion or Malignancy (NILM)**
- **Squamous Cell Carcinoma (SCC)**


## Objectives

- Preprocess high-resolution Pap smear images to enhance quality and ensure uniform input for model training.
- Extract discriminative features using advanced CNN architectures.
- Evaluate the performance of various classifiers, including a custom **AttCNN** designed for medical image analysis.
- Compare the effectiveness of different feature extractor-classifier combinations using metrics such as:
  - **Accuracy (ACC)**
  - **Area Under the Curve (AUC)**
  - **Precision (PRE)**
  - **Specificity (SP)**
  - **Sensitivity (SN)**
  - **F1 Score**
  - **Matthews Correlation Coefficient (MCC)**
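In the multiclass setting, SN, SP, PRE and F1 are computed one-vs-rest for each class from the confusion matrix and then macro-averaged. MCC uses its multiclass (K-category) generalization, which is the form scikit-learn's `matthews_corrcoef` computes. The NumPy sketch below assumes that rows are true labels and columns are predictions, which is scikit-learn's convention. The helper names `ovr_metrics` and `multiclass_mcc` are hypothetical.

```python
import numpy as np

def ovr_metrics(cm):
    """Per-class one-vs-rest metrics from a square confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp          # true class k, predicted as another class
    fp = cm.sum(axis=0) - tp          # predicted as k, but true class is another class
    tn = cm.sum() - tp - fn - fp
    zeros = np.zeros_like(tp)
    # Guard empty denominators: report 0 instead of dividing by zero
    sn = np.divide(tp, tp + fn, out=zeros.copy(), where=(tp + fn) > 0)
    sp = np.divide(tn, tn + fp, out=zeros.copy(), where=(tn + fp) > 0)
    pre = np.divide(tp, tp + fp, out=zeros.copy(), where=(tp + fp) > 0)
    f1 = np.divide(2 * pre * sn, pre + sn, out=zeros.copy(), where=(pre + sn) > 0)
    return {"SN": sn, "SP": sp, "PRE": pre, "F1": f1}

def multiclass_mcc(cm):
    """Multiclass MCC: (c*s - t.p) / sqrt((s^2 - p.p) * (s^2 - t.t))."""
    cm = np.asarray(cm, dtype=float)
    t = cm.sum(axis=1)                # true count per class
    p = cm.sum(axis=0)                # predicted count per class
    c, s = np.trace(cm), cm.sum()     # correct predictions, total samples
    denom = np.sqrt((s**2 - p @ p) * (s**2 - t @ t))
    return 0.0 if denom == 0 else (c * s - t @ p) / denom

cm = np.array([[5, 1], [2, 4]])       # toy 2-class matrix
m = ovr_metrics(cm)
macro = {k: float(v.mean()) for k, v in m.items()}   # macro average = unweighted mean over classes
print(macro, round(multiclass_mcc(cm), 4))
```

For two classes the multiclass formula reduces to the familiar binary MCC, (TP·TN − FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)).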



## Dataset

The dataset consists of **963 LBC images** captured at **400x magnification**, split into training (75%) and testing (25%) sets. The training set is augmented to **2000 images** to balance classes and improve model robustness.  
The classes include:

- **NILM:** 613 images (normal)
- **HSIL:** 113 images (abnormal)
- **LSIL:** 163 images (abnormal)
- **SCC:** 74 images (abnormal)
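With SCC at only 74 images, a per-class (stratified) split keeps every class represented in both sets. The sketch below shows the resulting counts. It assumes the 75/25 split was stratified and that the 2,000-image augmented training set is balanced at 500 images per class, but neither detail is stated above. `stratified_split_counts` and `TARGET_PER_CLASS` are hypothetical names.

```python
CLASS_COUNTS = {"NILM": 613, "HSIL": 113, "LSIL": 163, "SCC": 74}
TEST_FRACTION = 0.25
TARGET_PER_CLASS = 2000 // len(CLASS_COUNTS)  # assumption: equal share of the 2,000 augmented images

def stratified_split_counts(counts, test_fraction):
    """Split each class separately so the test set keeps the original class ratios."""
    split = {}
    for name, n in counts.items():
        n_test = int(round(n * test_fraction))
        split[name] = (n - n_test, n_test)
    return split

split = stratified_split_counts(CLASS_COUNTS, TEST_FRACTION)
for name, (n_train, n_test) in split.items():
    # Augmented copies needed to lift this class to the assumed balanced target
    extra = max(TARGET_PER_CLASS - n_train, 0)
    print(f"{name}: train={n_train}, test={n_test}, augment +{extra}")
```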

The dataset can be found at:  
**Hussain, Elima; B. Mahanta, Lipi; Borah, Himakshi; Ray Das, Chandana (2019), “Liquid based cytology pap smear images for multi-class diagnosis of cervical cancer”, Mendeley Data, V2, doi: [10.17632/zddtpgzv63.2](https://doi.org/10.17632/zddtpgzv63.2)**

This notebook implements the **preprocessing pipeline**, **feature extraction**, **SHAP-based feature selection** (retaining the 500 highest-ranked feature-map channels), **classification** with AttCNN, and **result visualization**.


import os
import cv2
import joblib
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score, roc_auc_score, matthews_corrcoef, confusion_matrix,
    classification_report, precision_score, f1_score
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import lightgbm as lgb
from xgboost import XGBClassifier
from sklearn.neighbors import KNeighborsClassifier
from lightgbm import LGBMClassifier
from sklearn.manifold import TSNE
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, LSTM, Dense, Dropout, Flatten, MaxPooling2D,
                                     TimeDistributed, GlobalAveragePooling2D, Multiply,
                                     BatchNormalization, LeakyReLU, Bidirectional, RNN)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import DenseNet121, DenseNet201, VGG16, ResNet152V2

warnings.filterwarnings('ignore')


In [None]:
# Save directory
save_dir = "models/shap/v2/"
os.makedirs(save_dir, exist_ok=True)
save_path = "figures/shap/v2/"
os.makedirs(save_path, exist_ok=True)

In [None]:
train_path = "dataset/Liquid based cytology/train/"
test_path = 'dataset/Liquid based cytology/valid/'
IMG_HEIGHT, IMG_WIDTH = 224, 224
NUM_CLASSES = 4
BATCH_SIZE = 16
EPOCHS = 50
CLASSES = NUM_CLASSES  # alias used by the metric helpers

In [None]:
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    height_shift_range=0.2,
    width_shift_range=0.2,
    zoom_range=0.2,
    rotation_range=20,
    shear_range=0.2
)
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)

In [None]:
train_data = train_datagen.flow_from_directory(
    train_path,  
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    # subset = 'training',
    shuffle=False  # Keep order for feature extraction
)

test_data = test_datagen.flow_from_directory(
    test_path,  
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    # subset='validation',
    shuffle=False
)
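`flow_from_directory` numbers the classes by sorting the sub-folder names alphanumerically and exposes the mapping as `class_indices`. `os.listdir`, by contrast, returns entries in arbitrary order. Any list of class names used to label plots should therefore be derived from `class_indices`, or at least sorted. The sketch below uses hypothetical folder names.

```python
# Hypothetical sub-folder names, listed out of order as os.listdir might return them
folders = ["SCC", "NILM", "HSIL", "LSIL"]

# flow_from_directory sorts class folders alphanumerically and numbers them 0..K-1
class_indices = {name: idx for idx, name in enumerate(sorted(folders))}

# Recover names in label order from the mapping (same pattern works with train_data.class_indices)
names_by_index = sorted(class_indices, key=class_indices.get)
print(class_indices)
print(names_by_index)
```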

In [None]:
def build_extended_densenet121(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)):
    densenet121_base = DenseNet121(weights='imagenet', include_top=False, input_shape=input_shape)
    x = densenet121_base.get_layer('conv5_block16_concat').output  # last dense-block output of DenseNet121 (7x7x1024 at 224x224)
    x = BatchNormalization()(x)
    x = Conv2D(1024, (3, 3), padding='valid', activation='relu', name='densenet_custom_conv1')(x)
    x = Conv2D(1024, (3, 3), padding='valid', activation='relu', name='densenet_custom_conv2')(x)
    x = BatchNormalization()(x)
    model = Model(inputs=densenet121_base.input, outputs=x, name='Extended_DenseNet121')
    return model

def extract_features(generator, model):
    features = model.predict(generator, steps=len(generator))
    labels = generator.labels
    return features, labels
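The extracted feature shape follows from simple arithmetic. A DenseNet backbone downsamples its input by 32, so a 224×224 image becomes a 7×7 map. Each `padding='valid'` 3×3 convolution then trims one pixel from every border. The check below assumes the two custom 1024-filter convolutions defined above. It also gives the flattened length that is later passed to XGBoost/SHAP.

```python
def valid_conv_out(size, kernel=3, stride=1):
    """Output length of a 'valid' (no padding) convolution along one spatial axis."""
    return (size - kernel) // stride + 1

h = w = 224 // 32          # DenseNet reduces spatial size by 32 -> 7
channels = 1024            # both custom Conv2D layers have 1024 filters
for _ in range(2):         # two custom 3x3 'valid' convolutions
    h, w = valid_conv_out(h), valid_conv_out(w)

feature_shape = (h, w, channels)
flat_len = h * w * channels
print(feature_shape, flat_len)
```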

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc
import seaborn as sns
import pandas as pd
import os
from sklearn.manifold import TSNE

def visualize_confusion_matrix(y_test, y_pred, model_name=''):
    plt.rcParams['font.family'] = 'DejaVu Serif'

    title = ''.join(c.upper() if c.isalpha() else c for c in model_name)
    # Compute confusion matrix
    cm = confusion_matrix(y_test, y_pred)

    # Class labels follow flow_from_directory's convention: sub-folder names sorted alphanumerically
    class_labels = sorted(os.listdir(train_path))

    # Create figure and axes
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot heatmap with counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels,
                annot_kws={"size": 16}, ax=ax)

    # Compute percentages and add them below counts in all cells
    percentages = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            percentage_text = f"{percentages[i, j]:.1f}%"
            ax.text(j + 0.5, i + 0.7, percentage_text, ha='center', va='center',
                color='black', fontsize=16, bbox=dict(facecolor='white', edgecolor='white', boxstyle='round,pad=0.3'))


    # Set labels and title
    ax.set_title(title, fontsize=20, fontweight='bold', pad=16)
    ax.set_ylabel('True Label', fontsize=20)
    ax.set_xlabel('Predicted Label', fontsize=17)
    ax.tick_params(axis='both', labelsize=14)

    # Sky-blue border (no grid as requested)
    for spine in ax.spines.values():
        spine.set_edgecolor('skyblue')
        spine.set_linewidth(1.2)

    # Save and show
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f'{model_name}_confusion_matrix.png'), dpi=1000, bbox_inches='tight')
    plt.show()

    # Reset font
    plt.rcParams['font.family'] = 'sans-serif'

def visualize_features_scatter(features, labels, class_names, title, filename='scatter_plot'):
    title = ''.join(c.upper() if c.isalpha() else c for c in title)

    if len(features.shape) > 2:
        features_flat = features.reshape(features.shape[0], -1)
    else:
        features_flat = features

    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    features_2d = tsne.fit_transform(features_flat)
    print(f"Features reduced to 2D: {features_2d.shape}")

    plt.rcParams['font.family'] = 'DejaVu Serif'

    unique_labels = np.unique(labels)
    fig, ax = plt.subplots(figsize=(10, 8))

    colors = ['#E63946', '#87CEEB', '#9B5DE5', '#2A9D8F']
    if len(unique_labels) > len(colors):
        raise ValueError("More classes than predefined colors! Add more to the list.")

    for class_idx in unique_labels:
        mask = labels == class_idx
        class_name = class_names[class_idx]  # Get name by index
        print(f"Plotting class: {class_name}, Points: {np.sum(mask)}")
        ax.scatter(features_2d[mask, 0], features_2d[mask, 1],
                   label=f'{class_name}', color=colors[class_idx], alpha=0.6, s=45,zorder=3)

    ax.set_title(title, fontsize=20, fontweight='bold', pad=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=16)
    ax.set_ylabel('t-SNE Component 2', fontsize=16)

    # Customized legend
    legend = ax.legend(
        title='Classes',
        fontsize=16,
        title_fontsize=16,
        loc='upper left',
        #bbox_to_anchor=(1.25, -0.02)
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('#cec0c0')      # Border color
    legend.get_frame().set_linewidth(1.2)
    legend.get_frame().set_alpha(0.5)
    # ax.legend(title='Classes', fontsize=16, title_fontsize=16, loc='lower right', bbox_to_anchor=(1.30, -0.01))
    ax.tick_params(axis='both', labelsize=12)

    ax.minorticks_on()
    ax.tick_params(which='minor', length=0)
    ax.grid(which='major', linestyle='-', linewidth=0.6, color='#cec0c0', alpha=0.9)
    ax.grid(which='minor', linestyle=':', linewidth=0.5, color='#cec0c0', alpha=0.5)
    ax.set_xticks(np.arange(-70, 80, 10))  # -70 to 70 in steps of 10
    ax.set_yticks(np.arange(-50, 60, 10))  # -50 to 50 in steps of 10

    for spine in ax.spines.values():
        spine.set_edgecolor('#cec0c0')
        spine.set_linewidth(1.2)

    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f'{filename}.png'), dpi=1000, bbox_inches='tight')
    plt.show()

    plt.rcParams['font.family'] = 'sans-serif'

def plot_roc_curves(roc_data, y_test_cat, extractor=''):
    # Set font
    plt.rcParams['font.family'] = 'DejaVu Serif'

    classifiers = list(roc_data.keys())

    # Common FPR points for interpolation
    mean_fpr = np.linspace(0, 1, 100)

    # Create figure and axes
    fig, ax = plt.subplots(figsize=(10, 9))

    # Plot ROC curves for each classifier
    for clf_name in classifiers:
        y_pred_probs = roc_data[clf_name]
        mean_tpr = np.zeros_like(mean_fpr)
        macro_aucs = []

        num_classes = y_test_cat.shape[1]
        for i in range(num_classes):
            fpr, tpr, _ = roc_curve(y_test_cat[:, i], y_pred_probs[:, i])
            macro_aucs.append(auc(fpr, tpr))
            mean_tpr += np.interp(mean_fpr, fpr, tpr)  # Interpolate TPR values

        mean_tpr /= num_classes  # Average over classes
        macro_auc = np.mean(macro_aucs)  # Compute macro AUC

        ax.plot(mean_fpr, mean_tpr, lw=2, label=f'{clf_name} (AUC = {macro_auc:.4f})')

    # Plot diagonal line
    ax.plot([0, 1], [0, 1], 'k--', lw=2)

    # Set axis limits and ticks
    ax.set_xlim([-0.02, 1.0])
    ax.set_ylim([0.0, 1.02])
    ax.set_xticks(np.arange(0.0, 1.1, 0.1))
    ax.set_yticks(np.arange(0.0, 1.1, 0.1))

    # Set labels and title
    ax.set_xlabel('False Positive Rate', fontsize=16)
    ax.set_ylabel('True Positive Rate', fontsize=16)
    ax.set_title('ROC Curves per Classifier', fontsize=20, fontweight='bold')
    ax.tick_params(axis='both', labelsize=14)

    # Customized legend
    legend = ax.legend(
        fontsize=17,
        title_fontsize=17,
        loc='lower right'
        # bbox_to_anchor=(1.42, -0.01)
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('#cec0c0')      # Border color
    legend.get_frame().set_linewidth(1.2)
    legend.get_frame().set_alpha(0.7)

    # External legend at bottom right
    # ax.legend(loc='lower right', bbox_to_anchor=(1.42, -0.01), fontsize=12)

    # Light grey grid
    ax.minorticks_on()
    ax.tick_params(which='minor', length=0)
    ax.grid(which='major', linestyle='-', linewidth=0.6, color='#cec0c0', alpha=0.9)
    ax.grid(which='minor', linestyle=':', linewidth=0.5, color='#cec0c0', alpha=0.5)

    # Light grey border
    for spine in ax.spines.values():
        spine.set_edgecolor('#cec0c0')
        spine.set_linewidth(1.2)

    # Save and show
    os.makedirs(save_path, exist_ok=True)
    filename = extractor + '_roc_curves.png'
    plt.savefig(os.path.join(save_path, filename), dpi=1000, bbox_inches='tight')
    plt.show()

    # Reset font
    plt.rcParams['font.family'] = 'sans-serif'

def plot_accuracy_comparison(accuracy_dict):
    # Set font
    plt.rcParams['font.family'] = 'DejaVu Serif'

    # Convert accuracy_dict to DataFrame
    df = pd.DataFrame(accuracy_dict)

    # Create figure and axes
    fig, ax = plt.subplots(figsize=(13, 10))  # Increased width to fit the legend
    # Light grey horizontal grid
    ax.minorticks_on()
    ax.tick_params(which='minor', length=0)
    ax.grid(which='major', axis='y', linestyle='-', linewidth=0.6, color='#cec0c0', alpha=1)
    ax.grid(which='minor', axis='y', linestyle=':', linewidth=0.5, color='#cec0c0', alpha=0.8)

    # Plot barplot
    sns.barplot(
        x='Feature Extractor', 
        y='Accuracy', 
        hue='Classifier', 
        data=df, 
        ax=ax, 
        ci=None, 
        zorder=3  # Bars drawn on top of grid
    )

    # Set labels and title
    ax.set_title('Accuracy Comparison Across Feature Extractors and Classifiers', fontsize=22, fontweight='bold', pad=16)
    ax.set_xlabel('Feature Extractor', fontsize=16)
    ax.set_ylabel('Accuracy', fontsize=16)

    # Customized legend
    legend = ax.legend(
        fontsize=18,
        title_fontsize=19,
        loc='lower right',
        bbox_to_anchor=(1.38, -0.01)
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('#cec0c0')      # Border color
    legend.get_frame().set_linewidth(1.2)
    legend.get_frame().set_alpha(0.95)

    # External legend at bottom right
    # ax.legend(title='Classifier', fontsize=16, title_fontsize=16, loc='lower right', bbox_to_anchor=(1.36, -0.01), ncol=1)

    # Set ticks (0.0 to 1.0 inclusive)
    ax.tick_params(axis='both', labelsize=16)
    ax.set_yticks(np.arange(0.0, 1.1, 0.1))

    # Light grey border
    for spine in ax.spines.values():
        spine.set_edgecolor('#cec0c0')
        spine.set_linewidth(1.2)

    # Save and show
    plt.tight_layout()  # Prevents clipping
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, 'accuracy_comparison.png'), dpi=1000, bbox_inches='tight')
    plt.show()

    # Reset font
    plt.rcParams['font.family'] = 'sans-serif'

def plot_classifier_comparison(metrics_dict, extractor = ''):
    # Set font
    plt.rcParams['font.family'] = 'DejaVu Serif'

    # Convert metrics_dict to DataFrame and melt it
    metrics_df = pd.DataFrame(metrics_dict)
    df_melted = metrics_df.melt(id_vars='Classifier', var_name='Metric', value_name='Value')

    # Create figure and axes
    fig, ax = plt.subplots(figsize=(18, 10))  # Increased width to fit the legend
    


    # Plot barplot
    sns.barplot(x='Classifier', y='Value', hue='Metric', data=df_melted, ax=ax,zorder=3,width=0.9)

    # Set labels and title
    ax.set_title('Classifier Performance Comparison', fontsize=22, fontweight='bold', pad=16)
    ax.set_xlabel('Classifier', fontsize=22)
    ax.set_ylabel('Score', fontsize=22)

    # Customized legend
    legend = ax.legend(
        fontsize=20,
        title_fontsize=20,
        loc='lower right',
        bbox_to_anchor=(1.19, -0.01)
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('#cec0c0')      # Border color
    legend.get_frame().set_linewidth(1.2)
    legend.get_frame().set_alpha(0.95)

    # External legend at bottom right
    # ax.legend(title='Metric', fontsize=15, title_fontsize=15, loc='lower right', bbox_to_anchor=(1.54, -0.01))

    # Set ticks
    ax.tick_params(axis='both', labelsize=17)
    ax.set_yticks(np.arange(0.0, 1.1, 0.1))
    # Light grey horizontal grid
    ax.minorticks_on()
    ax.tick_params(which='minor', length=0)
    ax.grid(which='major', axis='y', linestyle='-', linewidth=0.6, color='#cec0c0', alpha=1)
    ax.grid(which='minor', axis='y', linestyle=':', linewidth=0.5, color='#cec0c0', alpha=0.8)

    
    # Light grey border
    for spine in ax.spines.values():
        spine.set_edgecolor('#cec0c0')
        spine.set_linewidth(1.2)

    # Save and show
    plt.tight_layout()  # Prevents clipping
    os.makedirs(save_path, exist_ok=True)
    filename = extractor + '_classifier_comparison.png'
    plt.savefig(os.path.join(save_path, filename), dpi=1000, bbox_inches='tight')
    plt.show()

    # Reset font
    plt.rcParams['font.family'] = 'sans-serif'
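`plot_roc_curves` reports the macro AUC as the mean of the per-class one-vs-rest AUCs. The curve it draws, however, is the class-averaged TPR interpolated on a common FPR grid. The two quantities are close but not identical. The NumPy-only sketch below covers both steps and assumes one-hot labels with one probability column per class. `roc_points` is a simplified, hypothetical stand-in for scikit-learn's `roc_curve`: it keeps every threshold instead of dropping collinear points, which does not change the AUC.

```python
import numpy as np

def roc_points(y_true, scores):
    """ROC curve for one binary problem: FPR/TPR at every distinct score threshold."""
    order = np.argsort(-scores, kind="mergesort")
    y, s = y_true[order], scores[order]
    cut = np.r_[np.flatnonzero(np.diff(s)), y.size - 1]  # last index of each distinct score
    tps = np.cumsum(y)[cut]
    fps = (cut + 1) - tps
    return np.r_[0.0, fps / fps[-1]], np.r_[0.0, tps / tps[-1]]

def trapezoid_auc(x, y):
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))

def macro_roc(y_onehot, probs, grid_size=100):
    mean_fpr = np.linspace(0, 1, grid_size)
    mean_tpr = np.zeros_like(mean_fpr)
    aucs = []
    for k in range(y_onehot.shape[1]):
        fpr, tpr = roc_points(y_onehot[:, k], probs[:, k])
        aucs.append(trapezoid_auc(fpr, tpr))
        mean_tpr += np.interp(mean_fpr, fpr, tpr)   # put every class on the same FPR grid
    mean_tpr /= y_onehot.shape[1]
    return mean_fpr, mean_tpr, float(np.mean(aucs))

rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=200)
y_onehot = np.eye(4)[labels]
probs = rng.dirichlet(np.ones(4), size=200)          # random scores, for illustration only
fpr_grid, tpr_grid, macro_auc = macro_roc(y_onehot, probs)
print(f"macro AUC (mean of per-class AUCs): {macro_auc:.4f}")
print(f"area under averaged curve: {trapezoid_auc(fpr_grid, tpr_grid):.4f}")
```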

In [None]:
def calc_multiclass_metrics(y_true, y_pred):
    # Fix the label set so the matrix is always CLASSES x CLASSES; without it, a class
    # absent from both y_true and y_pred would shrink the matrix and shift the indices
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(CLASSES))

    sensitivities = []
    specificities = []
    for i in range(CLASSES):
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp   # class i predicted as another class
        fp = cm[:, i].sum() - tp   # other classes predicted as class i
        tn = cm.sum() - (tp + fp + fn)

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        sensitivities.append(sensitivity)
        specificities.append(specificity)

    macro_sensitivity = np.mean(sensitivities)
    macro_specificity = np.mean(specificities)

    missing = sorted(set(range(CLASSES)) - set(np.unique(y_true).tolist()))
    if missing:
        print(f"Warning: no true samples for classes {missing}; their sensitivity is reported as 0")

    return sensitivities, specificities, macro_sensitivity, macro_specificity
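`visualize_confusion_matrix` annotates each cell with a row-normalized percentage, which is the share of each true class assigned to each predicted class. The diagonal of that matrix is therefore the per-class sensitivity. Dividing by a zero row sum, which happens when a class is absent from `y_test`, produces NaN. The sketch below shows a guarded version; the helper name `row_percentages` is hypothetical.

```python
import numpy as np

def row_percentages(cm):
    """Row-normalize a confusion matrix to percentages; rows with no samples stay at 0."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0) * 100

cm = np.array([[8, 2, 0], [1, 9, 0], [0, 0, 0]])  # toy matrix; third class has no true samples
pct = row_percentages(cm)
print(pct.round(1))
```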

In [None]:
class Attention2D(tf.keras.layers.Layer):
    def __init__(self, filters, name_prefix, **kwargs):
        super(Attention2D, self).__init__(**kwargs)
        self.filters = filters
        self.name_prefix = name_prefix
        self.conv1 = Conv2D(filters, kernel_size=1, padding='same', name=f'{name_prefix}_attn_conv')
        self.bn = BatchNormalization(name=f'{name_prefix}_attn_bn')
        self.relu = LeakyReLU(name=f'{name_prefix}_attn_lrelu')
        self.conv2 = Conv2D(filters, kernel_size=1, padding='same', activation='sigmoid')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv2(x)
        return inputs * x  # element-wise gating; avoids creating a new Multiply layer on every call

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super(Attention2D, self).get_config()
        config.update({"filters": self.filters, "name_prefix": self.name_prefix})
        return config
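`Attention2D` is a per-pixel gate. Two 1×1 convolutions, with BatchNorm and LeakyReLU between them, produce a sigmoid map that has as many channels as the input. This map rescales every activation by a factor in (0, 1), so `filters` must match the channel count of the incoming tensor. A 1×1 convolution is a matrix multiply applied at each pixel, which means the gating can be sketched in NumPy. In this sketch, random weights stand in for trained ones, BatchNorm is omitted, and the LeakyReLU slope of 0.3 is assumed to match the Keras default.

```python
import numpy as np

def leaky_relu(x, alpha=0.3):      # assumed to match Keras' default slope
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def conv1x1(x, w, b):
    """1x1 convolution on an (H, W, C_in) map == per-pixel matmul with a (C_in, C_out) kernel."""
    return np.einsum("hwc,cd->hwd", x, w) + b

rng = np.random.default_rng(42)
H, W, C = 3, 3, 8                  # toy map; the notebook uses 256 and 128 filters
x = rng.normal(size=(H, W, C))
w1, b1 = rng.normal(size=(C, C)) * 0.1, np.zeros(C)
w2, b2 = rng.normal(size=(C, C)) * 0.1, np.zeros(C)

gate = sigmoid(conv1x1(leaky_relu(conv1x1(x, w1, b1)), w2, b2))  # BatchNorm omitted
out = x * gate                     # element-wise reweighting, shape unchanged
print(out.shape, gate.min().round(3), gate.max().round(3))
```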

In [None]:
def build_2d_cnn_lstm(input_shape, num_classes, lr=1e-5):
    inputs = Input(shape=input_shape)

    # Add a singleton time axis: (H, W, C) -> (1, H, W, C), so the feature map is
    # processed as a length-1 sequence by the TimeDistributed layers and the LSTM
    x = tf.keras.layers.Reshape((1,) + tuple(input_shape))(inputs)
    x = TimeDistributed(Conv2D(256, (3, 3), padding='same', activation='relu'))(x)
    x = TimeDistributed(BatchNormalization())(x)
    x = TimeDistributed(Attention2D(256, 'attn1'))(x)

    x = TimeDistributed(Conv2D(128, (3, 3), padding='same', activation='relu'))(x)
    x = TimeDistributed(BatchNormalization())(x)
    x = TimeDistributed(Attention2D(128, 'attn2'))(x)

    x = TimeDistributed(Flatten())(x)
    x = LSTM(256, return_sequences=False)(x)
    x = Dropout(0.5)(x)

    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(learning_rate=lr), loss='categorical_crossentropy', metrics=['accuracy'])
    return model

In [None]:
def train_cnn_lstm(X_train, y_train_cat, X_test, y_test, y_test_cat, model_name,
                   epochs=EPOCHS, batch=BATCH_SIZE, lr=1e-5):
    input_shape = X_train.shape[1:]
    model = build_2d_cnn_lstm(input_shape, NUM_CLASSES, lr=lr)
    # Defined but not passed to fit(); enabling them would tune on the test split,
    # because it doubles as validation_data below
    callbacks = [
        EarlyStopping(patience=20, restore_best_weights=True),
        ReduceLROnPlateau(factor=0.5, patience=5)
    ]
    history = model.fit(
        X_train, y_train_cat,
        validation_data=(X_test, y_test_cat),
        epochs=epochs,
        batch_size=batch,
        # callbacks=callbacks,
        verbose=1
    )

    plt.rcParams['font.family'] = 'DejaVu Serif'
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    ax1.plot(history.history['loss'], label='Train Loss', color='#E63946')
    ax1.plot(history.history['val_loss'], label='Validation Loss', color='#2A9D8F')
    ax1.set_title('AttCNN Loss', fontsize=18, fontweight='bold')
    ax1.set_xlabel('Epoch', fontsize=16)
    ax1.set_ylabel('Loss', fontsize=16)
    # ax1.legend(fontsize=15)
    ax1.tick_params(axis='both', labelsize=12)
    ax1.grid(which='major', linestyle='-', linewidth=0.6, color='#cbe2ff', alpha=1)
    ax1.grid(which='minor', linestyle=':', linewidth=0.5, color='#cbe2ff', alpha=0.8)
    for spine in ax1.spines.values():
        spine.set_edgecolor('skyblue')
        spine.set_linewidth(1.2)

    # Customized legend
    legend = ax1.legend(
        fontsize=16,
        title_fontsize=16
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('skyblue')      # Border color
    legend.get_frame().set_linewidth(0.5)
    legend.get_frame().set_alpha(0.1)
    
   
    ax2.plot(history.history['accuracy'], label='Train Accuracy', color='#E63946')
    ax2.plot(history.history['val_accuracy'], label='Validation Accuracy', color='#2A9D8F')
    ax2.set_title('AttCNN Accuracy', fontsize=16, fontweight='bold')
    ax2.set_xlabel('Epoch', fontsize=16)
    ax2.set_ylabel('Accuracy', fontsize=16)
    

    # Customized legend
    legend = ax2.legend(
        fontsize=16,
        title_fontsize=16
    )
    legend.get_frame().set_facecolor('white')        # Legend box background
    legend.get_frame().set_edgecolor('skyblue')      # Border color
    legend.get_frame().set_linewidth(0.5)
    legend.get_frame().set_alpha(0.1)
    # ax2.legend(fontsize=15)
    ax2.tick_params(axis='both', labelsize=12)
    ax2.grid(which='major', linestyle='-', linewidth=0.6, color='#cbe2ff', alpha=1)
    ax2.grid(which='minor', linestyle=':', linewidth=0.5, color='#cbe2ff', alpha=0.8)
    for spine in ax2.spines.values():
        spine.set_edgecolor('skyblue')
        spine.set_linewidth(1.2)
    plt.tight_layout()
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f'{model_name}_att_cnn_training_curves.png'), dpi=1000, bbox_inches='tight')
    plt.show()
    plt.rcParams['font.family'] = 'sans-serif'

    y_pred_probs = model.predict(X_test)
    y_pred = np.argmax(y_pred_probs, axis=1)
    classes = np.arange(CLASSES)

    print(f"\n--- {model_name}_att_cnn ---")
    print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
    print("Classification Report:\n", classification_report(y_test, y_pred, digits=4))
    sensitivities, specificities, macro_sens, macro_spec = calc_multiclass_metrics(y_test, y_pred)
    for i, cls in enumerate(classes):
        print(f"Class {cls} - Sensitivity: {sensitivities[i]:.4f}, Specificity: {specificities[i]:.4f}")
    print(f"Macro-Averaged Sensitivity: {macro_sens:.4f}")
    print(f"Macro-Averaged Specificity: {macro_spec:.4f}")
    auc_score = roc_auc_score(y_test_cat, y_pred_probs, multi_class='ovr')
    print(f"AUC (OvR): {auc_score:.4f}")
    mcc = matthews_corrcoef(y_test, y_pred)
    print(f"MCC: {mcc:.4f}")

    visualize_confusion_matrix(y_test, y_pred, f"{model_name}_att_cnn")
    model.save(os.path.join(save_dir, f"{model_name}_att_cnn.h5"))
    print(f"Model saved as {save_dir}{model_name}_att_cnn.h5")

    return {
        'metrics': {
            'accuracy': accuracy_score(y_test, y_pred),
            'auc': auc_score,
            'precision': precision_score(y_test, y_pred, average='macro'),
            'specificity': macro_spec,
            'sensitivity': macro_sens,
            'f1': f1_score(y_test, y_pred, average='macro'),
            'mcc': mcc
        },
        'y_pred_probs': y_pred_probs
    }
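After SHAP selection, each AttCNN input is a 3×3×500 feature map. Because of the singleton time axis, the LSTM sees a sequence of length one, which amounts to a single LSTM step from a zero state. The parameter count below is a back-of-the-envelope calculation using the standard Keras formulas:

- Conv2D: k·k·C_in·C_out + C_out
- BatchNormalization: 4·C, of which two are trainable
- LSTM: 4·u·(d + u + 1)
- Dense: d·u + u

The 3×3×500 input shape is an assumption carried over from the extractor and top-500 selection settings.

```python
def conv2d_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out

def bn_params(c):
    return 4 * c   # gamma, beta (trainable) + moving mean, moving variance

def lstm_params(input_dim, units):
    return 4 * units * (input_dim + units + 1)

def attention2d_params(c):
    return 2 * conv2d_params(1, c, c) + bn_params(c)   # two 1x1 convs + BatchNorm

H, W, C_IN, N_CLASSES = 3, 3, 500, 4
counts = {
    "conv_256": conv2d_params(3, C_IN, 256),
    "bn_256": bn_params(256),
    "attn1": attention2d_params(256),
    "conv_128": conv2d_params(3, 256, 128),
    "bn_128": bn_params(128),
    "attn2": attention2d_params(128),
    "lstm": lstm_params(H * W * 128, 256),   # 'same' padding keeps 3x3; Flatten -> 1152
    "dense": 256 * N_CLASSES + N_CLASSES,
}
for name, n in counts.items():
    print(f"{name:>9}: {n:,}")
print(f"{'total':>9}: {sum(counts.values()):,}")
```

Under these assumptions, the first convolution and the LSTM together hold most of the weights.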

In [None]:
import shap
def shap_feature_selection_2d(X_train_2d, y_train, X_test_2d, n_features_to_select, f_name):
    X_train_flat = X_train_2d.reshape(X_train_2d.shape[0], -1)
    X_test_flat = X_test_2d.reshape(X_test_2d.shape[0], -1)

    model = XGBClassifier(eval_metric='mlogloss', tree_method='hist', device='cuda')  # XGBoost >= 2.0 GPU syntax; 'gpu_hist' is deprecated
    print("Training XGBoost for SHAP...")
    model.fit(X_train_flat, y_train)

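    # booster.attributes() only returns user-set attributes, so check the saved configuration instead
    booster_config = model.get_booster().save_config().lower()
    if 'cuda' not in booster_config and 'gpu_hist' not in booster_config:
        print("Warning: XGBoost is not configured for GPU; training runs on CPU.")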

    explainer = shap.TreeExplainer(model, X_train_flat)
    # Estimate importances on training data only; ranking channels on the test split
    # would leak test information into feature selection
    print("Computing SHAP values (on CPU) for a 200-sample subset of the training features...")
    X_shap_sample = shap.sample(X_train_flat, 200)
    shap_values = explainer.shap_values(X_shap_sample)

    # Plot SHAP summary
    plt.rcParams['font.family'] = 'sans-serif'
    class_names = sorted(os.listdir(train_path))
    shap.summary_plot(shap_values, X_shap_sample, plot_type="bar", class_names=class_names, show=False)
    plt.title('SHAP Feature Importance Across Classes', fontsize=18, fontweight='bold', pad=16)
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f_name), dpi=1000, bbox_inches='tight')
    plt.show()
    plt.rcParams['font.family'] = 'sans-serif'

    # Mean |SHAP| per flat feature, averaged over classes. Older SHAP versions return a
    # list of (n, features) arrays; newer ones return a single (n, features, classes) array
    if isinstance(shap_values, list):
        shap_abs = np.mean([np.abs(shap_val) for shap_val in shap_values], axis=0)
    else:
        shap_abs = np.abs(shap_values)
        if shap_abs.ndim == 3:
            shap_abs = shap_abs.mean(axis=-1)

    # reshape(N, -1) is C-ordered, so channel c sits at flat indices c, c + C, c + 2C, ...
    # Restore the (H, W, C) layout and average over samples and spatial positions
    shap_per_channel = shap_abs.reshape(shap_abs.shape[0], *X_train_2d.shape[1:]).mean(axis=(0, 1, 2))

    # top_indices = np.argsort(shap_per_channel)[::-1][:n_features_to_select]
    # X_train_selected = X_train_2d[..., top_indices]
    # X_test_selected = X_test_2d[..., top_indices]

    # print(f"Selected features shape (train): {X_train_selected.shape}")
    # print(f"Selected features shape (test): {X_test_selected.shape}")
    # return X_train_selected, X_test_selected, top_indices
    return shap_per_channel
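When a (N, H, W, C) array is flattened in NumPy's default C order, the channel becomes the fastest-varying axis. The features of channel c therefore sit at flat indices c, c + C, c + 2C, and so on, rather than in one contiguous block. The sketch below performs per-channel SHAP aggregation and top-k channel selection in a way that respects this layout. It uses random stand-ins for the |SHAP| values, and all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H, W, C = 5, 3, 3, 6
features = rng.normal(size=(N, H, W, C))
flat = features.reshape(N, -1)

# Channel c occupies every C-th column of the flattened matrix
c = 2
assert np.array_equal(flat[:, c::C], features[..., c].reshape(N, -1))

shap_abs = np.abs(rng.normal(size=flat.shape))                    # stand-in for |SHAP| per flat feature
per_channel = shap_abs.reshape(N, H, W, C).mean(axis=(0, 1, 2))   # average over samples and positions
same = np.array([shap_abs[:, k::C].mean() for k in range(C)])     # equivalent strided computation

k = 3
top = np.argsort(per_channel)[::-1][:k]     # most important channels first
selected = features[..., top]               # apply the same indices to train and test features
print(top, selected.shape)
```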

In [None]:
# Extract features
extended_densenet121 = build_extended_densenet121()
train_features_2d, y_train = extract_features(train_data, extended_densenet121)
test_features_2d, y_test = extract_features(test_data, extended_densenet121)
print(f"Train features shape: {train_features_2d.shape}")
print(f"Test features shape: {test_features_2d.shape}")

In [None]:
# Prepare one-hot labels for CNN-LSTM
y_train_cat = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test_cat = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

In [None]:
shap_per_channel = shap_feature_selection_2d(
    train_features_2d, y_train, test_features_2d,
    n_features_to_select=500, f_name='shap_feature_importance_500.png')

In [None]:
X_train_2d = train_features_2d
X_test_2d = test_features_2d

In [None]:
top_indices = np.argsort(shap_per_channel)[::-1][:500]
X_train_selected = X_train_2d[..., top_indices]
X_test_selected = X_test_2d[..., top_indices]

print(f"Selected features shape (train): {X_train_selected.shape}")
print(f"Selected features shape (test): {X_test_selected.shape}")

In [None]:
class_names = sorted(train_data.class_indices, key=train_data.class_indices.get)  # names in label-index order

In [None]:
# Visualize features
visualize_features_scatter(
    X_train_selected, y_train, class_names,
    title='t-SNE of Selected Features with SHAP',
    filename='tsne-selected_shap_features_500'
)

In [None]:
# Train classifiers and collect metrics and predicted probabilities
metrics_dict = {
    'Classifier': [],
    'Accuracy': [],
    'AUC':[],
    'Precision':[],
    'Specificity': [],
    'Sensitivity': [],
    'F1': [],
    'MCC': []
}
roc_data = {}

In [None]:
result = train_cnn_lstm(X_train_selected, y_train_cat, X_test_selected, y_test, y_test_cat, 'shap-v2', epochs=50, batch=16, lr=0.00001)
metrics_dict['Classifier'].append('AttCNN')
metrics_dict['Accuracy'].append(result['metrics']['accuracy'])
metrics_dict['AUC'].append(result['metrics']['auc'])
metrics_dict['Precision'].append(result['metrics']['precision'])
metrics_dict['Specificity'].append(result['metrics']['specificity'])
metrics_dict['Sensitivity'].append(result['metrics']['sensitivity'])
metrics_dict['F1'].append(result['metrics']['f1'])
metrics_dict['MCC'].append(result['metrics']['mcc'])
roc_data['AttCNN'] = result['y_pred_probs']