In [None]:
# module1 Libraries
import os
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.keras.losses import Loss
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import (Input, Attention, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Activation, Dense, GRU, Multiply, MaxPooling2D, BatchNormalization, Lambda,Bidirectional,
                                     ReLU, Add, Dropout, LSTM, MultiHeadAttention, Reshape,Flatten,LayerNormalization)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, roc_curve, precision_recall_curve, classification_report
)
from tensorflow.keras.initializers import he_normal
import tensorflow as tf
import shap
from lime import lime_image
from skimage.segmentation import mark_boundaries
from tf_explain.core.grad_cam import GradCAM
import json

In [None]:
# module2 Data Loading and Preprocessing
# Dataset root. Set the APTOS_ROOT environment variable to point at a local copy
# instead of editing this machine-specific default path.
root_dir = os.environ.get(
    "APTOS_ROOT",
    r'D:\Multi-Class Diabetic Retinopathy Classification\data\aptos2019-blindness-detection',
)
train_dir = os.path.join(root_dir, "train")
test_dir = os.path.join(root_dir, "test")

# module3 Image preprocessing and augmentation
import cv2
import numpy as np

import cv2
import numpy as np
def preprocess_image(image, target_size=(224, 224)):
    """Center-crop an image to a square, resize it, and scale pixels by 1/255.

    Intended as ``ImageDataGenerator(preprocessing_function=...)``, which calls
    it on one image at a time and expects an array of the same shape back.
    Keep ``target_size`` equal to the generator's ``target_size`` so the output
    shape matches the input shape.

    Args:
        image: Image array of shape (H, W, C); presumably pixel values in
            [0, 255] as loaded by the generator.
        target_size: Output (height, width). Defaults to (224, 224).

    Returns:
        Float array of shape (target_height, target_width, C), divided by 255.
    """
    height, width = image.shape[:2]

    # Largest centered square. Using min(h, w) directly (rather than twice the
    # half-size) keeps the whole shorter side when a dimension is odd.
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    cropped_image = image[top:top + side, left:left + side]

    target_height, target_width = target_size
    # cv2.resize takes dsize as (width, height).
    resized_image = cv2.resize(cropped_image, (target_width, target_height), interpolation=cv2.INTER_CUBIC)

    return resized_image / 255.0

train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=90,
    preprocessing_function=preprocess_image,  # center-crop/resize, then scale by 1/255
    validation_split=0.2
)

# Run test images through the same preprocess_image as training, instead of a
# separate rescale=1/255 path, so the two pipelines cannot drift apart if
# preprocess_image changes. For the 224x224 images the generators load, this
# should give the same pixels as the previous rescale-only pipeline.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_image)

In [None]:
# module4 Load and preprocess the training and testing images
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

# Validation images must not be randomly flipped/rotated: val_loss drives model
# checkpointing and early stopping, and augmentation makes it noisy. This
# generator applies the same preprocessing and the same validation_split as
# train_datagen (Keras picks the split deterministically from the sorted file
# list), so subset='validation' selects the same held-out files without
# augmentation.
validation_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    validation_split=0.2,  # must match train_datagen's validation_split
)

validation_generator = validation_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False
)
print("Class indices mapping:", train_generator.class_indices)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, GlobalMaxPooling2D,Conv1D,
                                     Dense, Reshape, Multiply, Concatenate, Activation, Lambda, Dropout)
from tensorflow.keras.models import Model

from tensorflow.keras.regularizers import l2

import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Reshape, Multiply, Add, ReLU, Conv2D, Concatenate, BatchNormalization

def se_block(x, reduction_ratio=16):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools ``x`` to one value per channel, passes that through a
    bottleneck MLP (ReLU, then sigmoid), and multiplies each channel of ``x``
    by the resulting weight.

    Args:
        x: Keras tensor; the last axis is treated as channels.
        reduction_ratio: Reduction factor for the bottleneck hidden layer.

    Returns:
        ``x`` rescaled channel-wise (same shape as ``x``).
    """
    filters = x.shape[-1]
    # Keep at least one hidden unit. When filters < reduction_ratio,
    # filters // reduction_ratio is 0, and a zero-unit Dense layer is invalid.
    hidden_units = max(1, filters // reduction_ratio)

    squeeze = GlobalAveragePooling2D()(x)
    excitation = Dense(hidden_units, activation='relu', kernel_initializer='he_normal')(squeeze)
    excitation = Dense(filters, activation='sigmoid', kernel_initializer='he_normal')(excitation)
    excitation = Reshape((1, 1, filters))(excitation)
    return Multiply()([x, excitation])

def eca_block(x, kernel_size=3):
    """Efficient Channel Attention.

    Treats the global-average-pooled channel descriptor as a 1-D sequence. A
    single sigmoid Conv1D (width ``kernel_size``) over that sequence produces
    one weight per channel, which rescales ``x``.

    Args:
        x: Keras tensor; the last axis is treated as channels.
        kernel_size: Width of the 1-D convolution across channels.

    Returns:
        ``x`` rescaled channel-wise.
    """
    channels = x.shape[-1]

    # One value per channel, laid out as a length-`channels` sequence.
    descriptor = Reshape((channels, 1))(GlobalAveragePooling2D()(x))

    channel_weights = Conv1D(
        1,
        kernel_size=kernel_size,
        padding="same",
        activation='sigmoid',
        kernel_initializer='he_normal',
    )(descriptor)

    # Broadcast the weights back over the spatial dimensions.
    return Multiply()([x, Reshape((1, 1, channels))(channel_weights)])

def cbam_block(x, reduction_ratio=16):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Channel attention: the global average- and max-pooled descriptors each go
    through a shared ReLU bottleneck Dense layer. The two hidden activations
    are summed and projected to ``filters`` sigmoid weights. (Summing before
    the shared linear output layer equals summing two full MLP outputs, up to
    the output bias.)

    Spatial attention: the channel-wise mean and max maps are concatenated and
    passed through a 7x7 sigmoid convolution to weight each spatial position.

    Args:
        x: Keras tensor; the last axis is treated as channels.
        reduction_ratio: Reduction factor for the channel-attention bottleneck.

    Returns:
        ``x`` after channel and spatial reweighting.
    """
    filters = x.shape[-1]
    # Keep at least one hidden unit. When filters < reduction_ratio,
    # filters // reduction_ratio is 0, and a zero-unit Dense layer is invalid.
    hidden_units = max(1, filters // reduction_ratio)

    # --- channel attention ---
    avg_pool = GlobalAveragePooling2D()(x)
    max_pool = GlobalMaxPooling2D()(x)
    shared_dense = Dense(hidden_units, activation='relu', kernel_initializer='he_normal')
    shared_dense_out_avg = shared_dense(avg_pool)
    shared_dense_out_max = shared_dense(max_pool)
    channel_attention = Dense(filters, activation='sigmoid', kernel_initializer='he_normal')(Add()([shared_dense_out_avg, shared_dense_out_max]))
    channel_attention = Reshape((1, 1, filters))(channel_attention)
    x = Multiply()([x, channel_attention])

    # --- spatial attention ---
    # NOTE(review): tf.reduce_* on Keras tensors relies on TF2-Keras wrapping TF
    # ops as layers automatically; presumably needs Lambda layers under Keras 3.
    avg_pool_spatial = tf.reduce_mean(x, axis=-1, keepdims=True)
    max_pool_spatial = tf.reduce_max(x, axis=-1, keepdims=True)
    concat = Concatenate(axis=-1)([avg_pool_spatial, max_pool_spatial])
    spatial_attention = Conv2D(1, kernel_size=7, strides=1, padding='same', activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(concat)
    return Multiply()([x, spatial_attention])

def cbam_se_eca_block(x, use_cbam=False, use_se=False, use_eca=False):
    """Optionally apply CBAM, SE and/or ECA attention, in that order.

    With the default arguments no attention module is applied and ``x`` is
    returned unchanged, matching the previous all-commented-out body. Enable
    individual modules with the flags instead of un-commenting code.

    Args:
        x: Keras tensor (last axis = channels).
        use_cbam: Apply ``cbam_block`` first.
        use_se: Then apply ``se_block``.
        use_eca: Then apply ``eca_block``.

    Returns:
        ``x`` after the selected attention modules.
    """
    if use_cbam:
        x = cbam_block(x)
    if use_se:
        x = se_block(x)
    if use_eca:
        x = eca_block(x)
    return x

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU, Add, 
                                     GlobalAveragePooling2D, Dense, Dropout)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def residual_block(x, filters, kernel_size=3, stride=1, downsample=False, regularizer=l2(0.001)):
    """Basic two-convolution residual block.

    Main path: Conv(kernel_size, stride) -> BN -> ReLU -> Conv(kernel_size, 1) -> BN.
    Shortcut: identity, or a 1x1 Conv(stride) -> BN projection. The two paths
    are added and passed through a final ReLU.

    Args:
        x: Input Keras tensor (assumed channels-last).
        filters: Output channels of both main-path convolutions.
        kernel_size: Kernel size of the main-path convolutions.
        stride: Stride of the first main-path convolution and of the projection.
        downsample: Force a projection shortcut. A projection is also added
            automatically when ``stride != 1`` or when the input channel count
            differs from ``filters``, since an identity shortcut could not be
            added to the main path in those cases.
        regularizer: Kernel regularizer for every convolution. The default
            instance is created once when the function is defined and shared
            across calls.

    Returns:
        Output Keras tensor with ``filters`` channels.
    """
    shortcut = x
    in_channels = x.shape[-1]

    x = Conv2D(filters, kernel_size, strides=stride, padding='same',
               kernel_initializer='he_normal',
               kernel_regularizer=regularizer)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(filters, kernel_size, strides=1, padding='same',
               kernel_initializer='he_normal',
               kernel_regularizer=regularizer)(x)
    x = BatchNormalization()(x)

    # Project the shortcut whenever its shape cannot match the main path.
    if downsample or stride != 1 or in_channels != filters:
        shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same',
                          kernel_initializer='he_normal',
                          kernel_regularizer=regularizer)(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x


def build_model(input_shape=(224, 224, 3), num_classes=5, regularizer=l2(0.001)):
    """Build a ResNet-18-style image classifier that outputs raw logits.

    Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
    Body: four stages of two residual blocks each (64, 128, 256, 512 filters).
    The first block of every stage after the first uses stride 2 with a
    projection shortcut.
    Head: global average pooling -> Dense(512, relu) -> Dropout(0.3) ->
    Dense(num_classes) with no activation.

    Args:
        input_shape: Shape of one input image.
        num_classes: Number of output logits.
        regularizer: Kernel regularizer used by every conv and dense layer.

    Returns:
        An uncompiled ``Model``.
    """
    inputs = Input(shape=input_shape)

    # Stem
    stem = Conv2D(64, kernel_size=7, strides=2, padding='same',
                  kernel_initializer='he_normal',
                  kernel_regularizer=regularizer)(inputs)
    stem = BatchNormalization()(stem)
    stem = ReLU()(stem)
    features = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(stem)

    # (filters, stride of the stage's first block)
    stages = ((64, 1), (128, 2), (256, 2), (512, 2))
    for stage_filters, first_stride in stages:
        features = residual_block(features, filters=stage_filters, stride=first_stride,
                                  downsample=first_stride != 1, regularizer=regularizer)
        features = residual_block(features, filters=stage_filters,
                                  downsample=False, regularizer=regularizer)

    # Classification head
    head = GlobalAveragePooling2D()(features)
    head = Dense(512, activation='relu', kernel_initializer='he_normal',
                 kernel_regularizer=regularizer)(head)
    head = Dropout(0.3)(head)
    logits = Dense(num_classes, activation=None, kernel_initializer='he_normal',
                   kernel_regularizer=regularizer)(head)

    return Model(inputs, logits)

# Per-class sample counts of the training subset; they feed the Balanced
# Softmax loss. They are derived from the generator rather than hard-coded, so
# they always match the current train/validation split and class set.
num_classes = len(train_generator.class_indices)
class_counts = np.bincount(train_generator.classes, minlength=num_classes).astype(np.float32)

# Print the number of samples per class
for i, count in enumerate(class_counts):
    print(f"类别 {i}: {int(count)} 样本")

# Also print the whole distribution array
print("训练集类别样本分布:", class_counts)

# Balanced Softmax adds log(count) to every logit, so an empty class would
# contribute log(0) = -inf. Fail early with a clear message instead.
if np.any(class_counts == 0):
    raise ValueError(
        f"Training subset has no samples for class index(es) "
        f"{np.flatnonzero(class_counts == 0).tolist()}; Balanced Softmax needs every class count > 0."
    )

# Build the model
model = build_model(input_shape=(224, 224, 3), num_classes=num_classes)
model.summary()

class BalancedSoftmaxLoss(tf.keras.losses.Loss):
    """Balanced Softmax cross-entropy for class-imbalanced training.

    During training, the log of each class's training-sample count is added to
    the logits before the softmax cross-entropy. This compensates for the
    training label prior (Ren et al., 2020, "Balanced Meta-Softmax for
    Long-Tailed Visual Recognition"). Nothing is added at inference, so the
    model's raw logits are used for prediction.

    Args:
        class_counts: Per-class training-sample counts in class-index order;
            the length must equal the number of logits.
        name: Loss name.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, class_counts, name='balanced_softmax_loss', **kwargs):
        super(BalancedSoftmaxLoss, self).__init__(name=name, **kwargs)
        self.class_counts = tf.constant(class_counts, dtype=tf.float32)
        # The log prior never changes, so compute it once instead of per batch.
        self._log_class_counts = tf.math.log(self.class_counts)

    @staticmethod
    def _to_sparse_labels(y_true):
        """Convert ``y_true`` to int64 class indices of shape (batch,).

        Accepts integer labels of shape (batch,) or (batch, 1), and one-hot
        targets of shape (batch, num_classes). The rank is read statically.
        For rank-2 targets the width is checked at run time, because
        generator-fed targets may not carry a static last dimension. A
        (batch, 1) column must be flattened, not argmax-ed: argmax of a
        single column is always 0.
        """
        if y_true.shape.rank == 1:
            return tf.cast(y_true, tf.int64)
        return tf.cond(
            tf.equal(tf.shape(y_true)[-1], 1),
            # (batch, 1) column of integer labels -> (batch,)
            lambda: tf.cast(tf.reshape(y_true, [-1]), tf.int64),
            # (batch, num_classes) one-hot targets -> index of the hot entry
            lambda: tf.argmax(y_true, axis=-1),
        )

    def call(self, y_true, logits):
        """Per-sample cross-entropy of the prior-adjusted logits."""
        labels = self._to_sparse_labels(y_true)
        adjusted_logits = logits + tf.cast(self._log_class_counts, logits.dtype)
        return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=adjusted_logits)

    def get_config(self):
        """Include ``class_counts`` so ``load_model`` can rebuild the loss."""
        config = super(BalancedSoftmaxLoss, self).get_config()
        config.update({"class_counts": self.class_counts.numpy().tolist()})
        return config

# Compile the model with the custom Balanced Softmax loss.
# Note: the cosine-annealing LearningRateScheduler used during training sets the
# learning rate from epoch 0 onward, so this initial value is overridden in fit().
model.compile(
    optimizer=Adam(learning_rate=0.0002),
    metrics=['accuracy'],
    loss=BalancedSoftmaxLoss(class_counts=class_counts),
)



In [None]:
# module8 Callbacks

import math
from tensorflow.keras.callbacks import LearningRateScheduler

# 定义余弦退火函数
def cosine_annealing(epoch, lr, total_epochs=70, initial_lr=0.00013, min_lr=0.00001):
    """Cosine-annealed learning rate for ``LearningRateScheduler``.

    Decays from ``initial_lr`` at epoch 0 to ``min_lr`` at
    ``epoch == total_epochs`` along half a cosine period, then holds ``min_lr``.

    Args:
        epoch: Zero-based epoch index supplied by Keras.
        lr: Current optimizer learning rate supplied by Keras. It is ignored:
            the schedule is absolute, so the compile-time learning rate is
            overridden from the first epoch.
        total_epochs: Length of the decay; keep it equal to ``epochs`` in ``fit``.
        initial_lr: Learning rate at epoch 0.
        min_lr: Floor learning rate reached at ``total_epochs``.

    Returns:
        The learning rate (float) for ``epoch``.

    Raises:
        ValueError: If ``total_epochs`` is not positive.
    """
    if total_epochs <= 0:
        raise ValueError(f"total_epochs must be positive, got {total_epochs}")
    # Clamp so the rate does not climb back up if training runs past total_epochs.
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + (initial_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# The cosine-annealing schedule replaces ReduceLROnPlateau.
cosine_scheduler = LearningRateScheduler(cosine_annealing, verbose=1)

# Stop once val_loss has not improved for 10 epochs, and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)

# Keep the full model with the lowest val_loss on disk (reloaded for evaluation).
checkpoint = ModelCheckpoint(
    'model_best.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Train with the cosine schedule; callbacks run in list order.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=70,
    steps_per_epoch=len(train_generator),
    validation_steps=len(validation_generator),
    callbacks=[checkpoint, early_stopping, cosine_scheduler],
    max_queue_size=20,
    workers=6,
)


In [None]:
# module10 Load the best model
# The custom loss class must be registered so the saved training config can be
# deserialized.
model = load_model(
    'model_best.h5',
    custom_objects=dict(BalancedSoftmaxLoss=BalancedSoftmaxLoss),
)

In [None]:
# Model Evaluation

In [None]:
save_dir = 'plots'
os.makedirs(save_dir, exist_ok=True)


def _plot_training_curve(history, metric, title, ylabel, filename, save_dir):
    """Plot the training and validation curves of one metric and save the figure.

    Args:
        history: ``History`` object returned by ``model.fit``.
        metric: Key in ``history.history``; ``'val_' + metric`` is plotted too.
        title: Figure title.
        ylabel: Y-axis label; also used in the legend entries.
        filename: File name written inside ``save_dir``.
        save_dir: Output directory (must exist).
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(history.history[metric], label=f'Training {ylabel}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {ylabel}')
    ax.set_title(title, fontsize=25)
    ax.set_xlabel('Epoch', fontsize=20)
    ax.set_ylabel(ylabel, fontsize=20)
    ax.tick_params(axis='both', labelsize=20)
    ax.legend(fontsize=20)
    ax.grid(True)
    fig.savefig(os.path.join(save_dir, filename), dpi=600, bbox_inches='tight')
    plt.show()
    plt.close(fig)


# Accuracy curves
_plot_training_curve(history, 'accuracy', 'Training and Validation Accuracy', 'Accuracy', 'accuracy_plot.png', save_dir)
# Loss curves
_plot_training_curve(history, 'loss', 'Training and Validation Loss', 'Loss', 'loss_plot.png', save_dir)

In [None]:
# Prediction and Inference

In [None]:
# Make predictions on the test data. The network's last Dense layer has no
# activation, so predict() returns logits. Convert them to softmax
# probabilities so the per-class scores used downstream are normalized
# probabilities rather than unbounded logits.
test_logits = model.predict(test_generator)
predictions = tf.nn.softmax(test_logits, axis=-1).numpy()
# Labels line up with predictions because test_generator was built with shuffle=False.
true_labels = test_generator.classes
class_labels = list(test_generator.class_indices.keys())
# Softmax is monotonic per row, so this argmax equals the argmax of the logits.
predicted_classes = np.argmax(predictions, axis=1)

In [None]:
# Calculate the confusion matrix (rows = true class, columns = predicted class).
# Fixing `labels` to every class index keeps it len(class_labels) x
# len(class_labels), even when a class never occurs in the labels or predictions.
cm = confusion_matrix(true_labels, predicted_classes, labels=np.arange(len(class_labels)))

def _per_class_specificity_and_recall(cm):
    """Per-class specificity and recall from a square confusion matrix.

    Rows of ``cm`` are true classes and columns are predicted classes (the
    sklearn ``confusion_matrix`` convention). A class whose denominator is 0
    gets 0.0.

    Returns:
        (specificity_per_class, recall_per_class, total_tn, total_fp). The
        totals are summed over classes for micro-averaged specificity.
    """
    cm = np.asarray(cm)
    n_classes = len(cm)
    true_positive = np.diag(cm)
    row_totals = cm.sum(axis=1)  # samples whose true label is the class
    col_totals = cm.sum(axis=0)  # samples predicted as the class
    false_positive = col_totals - true_positive
    true_negative = cm.sum() - (row_totals + col_totals - true_positive)

    spec_denominator = true_negative + false_positive
    specificity = np.divide(true_negative, spec_denominator,
                            out=np.zeros(n_classes, dtype=float), where=spec_denominator > 0)
    recall = np.divide(true_positive, row_totals,
                       out=np.zeros(n_classes, dtype=float), where=row_totals > 0)
    return specificity, recall, int(true_negative.sum()), int(false_positive.sum())


def evaluate_model(true_labels, predicted_classes, class_labels, save_dir):
    """Print and save multi-class classification metrics.

    The confusion matrix is computed here from the arguments, so the result
    never depends on a global ``cm``. Prints per-class precision, recall, F1,
    specificity and accuracy (per-class accuracy = recall = TP / true count),
    followed by micro and macro averages. Writes sklearn's classification
    report to ``<save_dir>/classification_report.json``.

    Args:
        true_labels: Integer ground-truth class indices.
        predicted_classes: Integer predicted class indices, same length.
        class_labels: Class names; position i names class index i.
        save_dir: Directory for the JSON report (created if missing).

    Returns:
        dict of the summary metrics that are printed.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Fix the label set to every class index so the matrix and report always
    # have one entry per class name, even if a class is absent from the data.
    label_ids = np.arange(len(class_labels))
    cm = confusion_matrix(true_labels, predicted_classes, labels=label_ids)

    # Micro average accuracy
    micro_accuracy = accuracy_score(true_labels, predicted_classes)

    # Generate classification report
    classification_report_dict = classification_report(
        true_labels, predicted_classes, labels=label_ids, target_names=class_labels,
        output_dict=True, zero_division=0
    )

    # Specificity and accuracy per class, plus pooled TN/FP for micro specificity
    specificity_per_class, accuracy_per_class, total_tn, total_fp = _per_class_specificity_and_recall(cm)
    macro_specificity = np.mean(specificity_per_class)
    micro_specificity = total_tn / (total_tn + total_fp) if (total_tn + total_fp) > 0 else 0.0

    # Save classification report as JSON
    json_save_path = os.path.join(save_dir, "classification_report.json")
    with open(json_save_path, "w") as f:
        json.dump(classification_report_dict, f, indent=4)
    print(f"Classification Report saved to {json_save_path}")

    # Print classification report
    print("Classification Report:")
    print(f"{'Class':<20} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Specificity':<12} {'Accuracy':<12} {'Support':<12}")
    for idx, cls in enumerate(class_labels):
        class_metrics = classification_report_dict[cls]
        print(f"{cls:<20} {class_metrics['precision']:<12.2f} {class_metrics['recall']:<12.2f} {class_metrics['f1-score']:<12.2f} {specificity_per_class[idx]:<12.2f} {accuracy_per_class[idx]:<12.2f} {class_metrics['support']:<12}")

    macro_accuracy = np.mean(accuracy_per_class)
    print(f"\nMicro Average Accuracy: {micro_accuracy:.2f}")
    print(f"Macro Average Accuracy: {macro_accuracy:.2f}")

    # Macro and micro averages over the same fixed label set
    macro_precision = precision_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)
    macro_recall = recall_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)
    macro_f1 = f1_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)

    micro_precision = precision_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)
    micro_recall = recall_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)
    micro_f1 = f1_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)

    print("\nMacro Average Metrics:")
    print(f"Precision: {macro_precision:.2f}, Recall: {macro_recall:.2f}, F1-Score: {macro_f1:.2f}, Specificity: {macro_specificity:.2f}")

    print("\nMicro Average Metrics:")
    print(f"Precision: {micro_precision:.2f}, Recall: {micro_recall:.2f}, F1-Score: {micro_f1:.2f}, Specificity: {micro_specificity:.2f}")

    return {
        "micro_accuracy": micro_accuracy,
        "macro_accuracy": macro_accuracy,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "macro_specificity": macro_specificity,
        "micro_precision": micro_precision,
        "micro_recall": micro_recall,
        "micro_f1": micro_f1,
        "micro_specificity": micro_specificity,
    }

# Print and save the full metric report for the test set.
evaluate_model(
    true_labels=true_labels,
    predicted_classes=predicted_classes,
    class_labels=class_labels,
    save_dir=save_dir,
)


In [None]:
#Confusion Matrix
def plot_confusion_matrix(cm, class_labels, title='Confusion Matrix', cmap=plt.cm.Blues, save_dir=None):
    """Draw a confusion-matrix heatmap with the count printed in each cell.

    Args:
        cm: Square matrix; rows are true classes, columns are predicted classes.
        class_labels: Tick labels in class-index order.
        title: Figure title.
        cmap: Matplotlib colormap for the heatmap.
        save_dir: If given, the figure is saved to
            ``<save_dir>/confusion_matrix.png`` (directory created if needed).
    """
    cm = np.asarray(cm)
    fig, ax = plt.subplots(figsize=(8, 6))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(heatmap, ax=ax)

    tick_marks = np.arange(len(class_labels))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_labels, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_labels)

    # Use white text on dark cells so counts stay readable.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, f"{cm[i, j]}", horizontalalignment="center", verticalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, "confusion_matrix.png"), dpi=600, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# ROC-AUC 
def plot_roc_auc_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest ROC curves with their AUC for every class and save the figure.

    Args:
        true_labels: Integer ground-truth class indices.
        predictions: (n_samples, n_classes) per-class scores; column i scores class i.
        class_labels: Class names in class-index order.
        save_dir: Directory for ``roc_auc_curve.png`` (created if missing).
    """
    predictions = np.asarray(predictions)
    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, class_name in enumerate(class_labels):
        class_truth = binarized_labels[:, i]
        # ROC/AUC is undefined unless both positives and negatives are present
        # (roc_auc_score raises), so skip that class with a note.
        if np.unique(class_truth).size < 2:
            print(f"Skipping ROC curve for class {class_name}: needs both positive and negative samples.")
            continue
        fpr, tpr, _ = roc_curve(class_truth, predictions[:, i])
        roc_auc = roc_auc_score(class_truth, predictions[:, i])
        ax.plot(fpr, tpr, lw=2, label=f'ROC Curve for class {class_name} (area = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC-AUC Curves')
    ax.legend(loc='lower right')
    ax.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'roc_auc_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# Precision-Recall
def plot_precision_recall_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest precision-recall curves with average precision and save the figure.

    Args:
        true_labels: Integer ground-truth class indices.
        predictions: (n_samples, n_classes) per-class scores; column i scores class i.
        class_labels: Class names in class-index order.
        save_dir: Directory for ``precision_recall_curve.png`` (created if missing).
    """
    predictions = np.asarray(predictions)
    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, class_name in enumerate(class_labels):
        class_truth = binarized_labels[:, i]
        # Recall is undefined without any positive sample for this class.
        if not class_truth.any():
            print(f"Skipping precision-recall curve for class {class_name}: no positive samples.")
            continue
        precision, recall, _ = precision_recall_curve(class_truth, predictions[:, i])
        # Average precision (step-wise area under the PR curve): the sum of
        # recall steps weighted by precision. precision_recall_curve returns
        # recall in decreasing order, hence the minus sign. This is the formula
        # sklearn's average_precision_score uses. (The previous
        # tf.keras.metrics.AUC()(recall, precision) treated recall as labels and
        # did not measure this area.)
        pr_auc = -np.sum(np.diff(recall) * precision[:-1])
        ax.plot(recall, precision, lw=2, label=f'Precision-Recall Curve for class {class_name} (area = {pr_auc:.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curves')
    ax.legend(loc='lower left')
    ax.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'precision_recall_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# Render and save the confusion matrix, per-class ROC curves and per-class PR curves.
plot_confusion_matrix(cm, class_labels, save_dir=save_dir, title='Confusion Matrix')
plot_roc_auc_curves(true_labels=true_labels, predictions=predictions,
                    class_labels=class_labels, save_dir=save_dir)
plot_precision_recall_curves(true_labels=true_labels, predictions=predictions,
                             class_labels=class_labels, save_dir=save_dir)

In [None]:
#lime

In [None]:
# Fixed random_state so the LIME perturbation samples (and explanations) are reproducible.
explainer = lime_image.LimeImageExplainer(random_state=42)

def model_predict(image_batch):
    """LIME ``classifier_fn``: map a batch of images to class probabilities.

    The network's last Dense layer has no activation (it outputs logits), but
    LIME expects prediction probabilities, so a softmax is applied here.
    ``verbose=0`` silences Keras' progress bar, which would otherwise print
    on every perturbation batch LIME evaluates.

    Args:
        image_batch: Array of perturbed images, shaped like the model input
            with a leading batch axis.

    Returns:
        NumPy array of shape (batch, num_classes) whose rows sum to 1.
    """
    logits = model.predict(image_batch, verbose=0)
    return tf.nn.softmax(logits, axis=-1).numpy()

class_labels = list(test_generator.class_indices.keys())

# Collect up to `num_images` test images per class to explain.
num_images = 1
class_images = {label: [] for label in class_labels}

# Index batches explicitly. Iterating a DirectoryIterator with `for` never ends
# on its own, so a class with no test images would make the search hang.
for batch_index in range(len(test_generator)):
    images, labels = test_generator[batch_index]
    for img, label in zip(images, labels):
        class_name = class_labels[np.argmax(label)]
        if len(class_images[class_name]) < num_images:
            class_images[class_name].append(img)
    if all(len(class_images[label]) >= num_images for label in class_labels):
        break

save_dir = "plots/lime"
os.makedirs(save_dir, exist_ok=True)

for class_name, images in class_images.items():
    for idx, img in enumerate(images):
        explanation = explainer.explain_instance(
            img.astype('double'),
            model_predict,
            top_labels=1,
            num_samples=1000
        )
        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=False,
            num_features=10
        )
        fig, (ax_orig, ax_expl) = plt.subplots(1, 2, figsize=(12, 6))

        ax_orig.imshow(img)
        ax_orig.set_title(f"Original: {class_name}")
        ax_orig.axis('off')

        # The test generator already scales pixels to [0, 1], so draw the
        # superpixel boundaries on `temp` directly. LIME's `temp / 2 + 0.5`
        # tutorial idiom assumes inputs in [-1, 1]; here it would wash the image out.
        ax_expl.imshow(mark_boundaries(temp, mask))
        ax_expl.set_title(f"Explained: {class_name} | Image {idx + 1}")
        ax_expl.axis('off')

        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)


In [None]:
#gradcam

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tf_explain.core.grad_cam import GradCAM

save_dir = "plots/grad_cam_tuned"
os.makedirs(save_dir, exist_ok=True)

# Explain the last Conv2D layer listed in model.layers (the deepest feature map).
target_layer_name = [layer.name for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)][-1]
grad_cam = GradCAM()

for class_name, images in class_images.items():
    # `class_images` (built in the LIME cell) maps a class name to a list of
    # bare image arrays, not (image, label) pairs. Unpacking an image array
    # into two names raises ValueError, so recover the class index from the
    # generator's mapping instead.
    label_index = test_generator.class_indices[class_name]
    for idx, img in enumerate(images):
        # tf-explain returns the Grad-CAM heatmap already colour-mapped and
        # blended onto the input image (uint8 RGB). Request the JET colormap
        # and a lighter image weight here, instead of re-colour-mapping that
        # blended image with OpenCV afterwards.
        # NOTE(review): colormap/image_weight require a recent tf-explain; confirm the installed version.
        cam_overlay = grad_cam.explain(
            ([img], None),
            model,
            class_index=label_index,
            layer_name=target_layer_name,
            colormap=cv2.COLORMAP_JET,
            image_weight=0.4,
        )

        fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(12, 6))

        ax_orig.imshow(img)
        ax_orig.set_title(f"Original: {class_name}")
        ax_orig.axis('off')

        ax_cam.imshow(cam_overlay)
        ax_cam.set_title(f"GradCAM Enhanced: {class_name} | Image {idx + 1}")
        ax_cam.axis('off')

        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)


In [None]:
import shap
import matplotlib.pyplot as plt
import numpy as np
import os

# 重新可视化 SHAP 解释
def plot_shap_values(image, shap_values, save_path, class_name, idx, gain=3.0):
    """Plot an image beside a contrast-enhanced SHAP attribution map and save it.

    The SHAP values are averaged over the last (channel) axis, divided by their
    maximum absolute value, and multiplied by ``gain``. The colour range is
    fixed to [-1, 1], so values beyond +/-1 after scaling saturate. The map is
    therefore relative within the image, not an absolute SHAP magnitude.

    Args:
        image: Image to display.
        shap_values: SHAP attributions for ``image``; assumed (H, W, C). TODO:
            confirm against the explainer output.
        save_path: File path for the saved figure.
        class_name: Class name used in the titles.
        idx: Zero-based image index; shown 1-based in the title.
        gain: Contrast multiplier applied after normalisation (default 3, the
            previously hard-coded value).
    """
    # Average over channels (RGB -> one attribution map), as floats.
    shap_map = np.mean(np.asarray(shap_values, dtype=float), axis=-1)

    abs_max = np.max(np.abs(shap_map))
    if abs_max > 0:
        shap_map = shap_map / abs_max
    shap_map = shap_map * gain

    fig, (ax_orig, ax_shap) = plt.subplots(1, 2, figsize=(12, 6))

    # Original image
    ax_orig.imshow(image)
    ax_orig.set_title(f"Original: {class_name}")
    ax_orig.axis('off')

    # SHAP heatmap over a faded copy of the image
    ax_shap.imshow(image, alpha=0.4)
    overlay = ax_shap.imshow(shap_map, cmap="seismic", alpha=0.6, vmin=-1, vmax=1)
    fig.colorbar(overlay, ax=ax_shap, label="SHAP value")
    ax_shap.set_title(f"SHAP Enhanced: {class_name} | Image {idx + 1}")
    ax_shap.axis('off')

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# Build the SHAP inputs in this cell so it also runs on a fresh kernel: one test
# image per class (collected into `class_images` by the LIME cell), explained
# against a small background batch from the test set.
shap_class_names = [name for name, imgs in class_images.items() if imgs]
sample_images = np.stack([class_images[name][0] for name in shap_class_names])
# NOTE(review): test_generator is unshuffled, so its first batch comes from the
# first class folder(s); a more mixed background set may be preferable.
background_images = test_generator[0][0][:16]
shap_explainer = shap.GradientExplainer(model, background_images)
shap_values = shap_explainer.shap_values(sample_images)

num_images = len(sample_images)
save_dir = "plots/shap_enhanced"
os.makedirs(save_dir, exist_ok=True)

for i, class_name in enumerate(shap_class_names):
    # Explain each image with respect to its own class. Depending on the shap
    # version, shap_values is either a list with one (n, H, W, C) array per
    # model output, or a single (n, H, W, C, n_outputs) array.
    class_index = test_generator.class_indices[class_name]
    if isinstance(shap_values, list):
        image_shap_values = shap_values[class_index][i]
    else:
        image_shap_values = np.asarray(shap_values)[i, ..., class_index]
    save_path = os.path.join(save_dir, f"shap_enhanced_image_{i+1}.png")
    plot_shap_values(sample_images[i], image_shap_values, save_path, class_name, i)