In [None]:
# module1 Libraries
import os
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.keras.losses import Loss
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import (Input, Attention, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Activation, Dense, GRU, Multiply, MaxPooling2D, BatchNormalization, Lambda,Bidirectional,
                                     ReLU, Add, Dropout, LSTM, MultiHeadAttention, Reshape,Flatten,LayerNormalization)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, roc_curve, precision_recall_curve, classification_report
)
from tensorflow.keras.initializers import he_normal
import tensorflow as tf
import shap
from lime import lime_image
from skimage.segmentation import mark_boundaries
from tf_explain.core.grad_cam import GradCAM
import json

In [None]:
# module2 Data Loading and Preprocessing
# Dataset root that contains the "train" and "test" folders (one sub-folder per class).
# Set the DR_DATASET_DIR environment variable to point elsewhere instead of editing this path.
root_dir = os.environ.get("DR_DATASET_DIR", r'D:\A\data\dataset')
train_dir = os.path.join(root_dir, "train")
test_dir = os.path.join(root_dir, "test")

# module3 Image preprocessing and augmentation
def preprocess_image(image, target_size=(224, 224)):
    """
    Centre-crop an image to a square, resize it, and divide pixel values by 255.

    Used as ``ImageDataGenerator(preprocessing_function=...)``. Keras calls it with a
    single image array after its own resize and augmentation.

    Args:
        image: array of shape (H, W) or (H, W, C).
        target_size: output (height, width). Defaults to (224, 224).

    Returns:
        Array of spatial size ``target_size``. The input's channel axis is kept if it
        had one. Values are the input values divided by 255.
    """
    target_h, target_w = target_size
    # Only the first two axes are spatial, so 2-D (grayscale) input is accepted as well.
    h, w = image.shape[:2]
    center_x, center_y = w // 2, h // 2
    crop_size = min(center_x, center_y)
    cropped_image = image[
        center_y - crop_size:center_y + crop_size,
        center_x - crop_size:center_x + crop_size
    ]

    if cropped_image.shape[:2] == (target_h, target_w):
        # Already the requested size. This is the usual case after flow_from_directory
        # has resized to target_size; a same-size resize would not change the pixels.
        resized_image = cropped_image
    else:
        # cv2.resize takes dsize as (width, height).
        resized_image = cv2.resize(cropped_image, (target_w, target_h))
        # cv2.resize drops a trailing channel axis of size 1. Restore it so the output
        # keeps the input's rank, which ImageDataGenerator requires.
        if cropped_image.ndim == 3 and resized_image.ndim == 2:
            resized_image = resized_image[..., np.newaxis]

    normalized_image = resized_image / 255.0
    return normalized_image

# Training generator:
#   - random horizontal flips, rotations of up to 90 degrees, and 0.8-1.2 brightness jitter;
#   - then preprocess_image (centre crop + divide by 255);
#   - validation_split=0.1 reserves 10% of the images in train_dir for subset='validation'.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    validation_split=0.1,
    horizontal_flip=True,
    rotation_range=90,
    brightness_range=[0.8, 1.2],
)
# Test generator: rescale pixel values to [0, 1] only, with no augmentation.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

In [None]:
# module4 Load and preprocess the training, validation and testing images
# Options shared by every directory iterator below.
_flow_kwargs = dict(target_size=(224, 224), batch_size=32, class_mode='categorical')

train_generator = train_datagen.flow_from_directory(
    train_dir,
    subset='training',
    shuffle=True,
    **_flow_kwargs
)

# Validation images come from the same split of train_dir, but through a generator
# that applies the same preprocessing and no random augmentation.
# Reusing train_datagen would randomly flip, rotate and brighten the validation images.
# Keep validation_split equal to train_datagen's (0.1) so that the training and
# validation subsets stay disjoint.
validation_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    validation_split=0.1
)
validation_generator = validation_datagen.flow_from_directory(
    train_dir,
    subset='validation',
    shuffle=False,
    **_flow_kwargs
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    shuffle=False,
    **_flow_kwargs
)

# Every later cell assumes that class index i means the same class in training and test.
if test_generator.class_indices != train_generator.class_indices:
    raise ValueError(
        f"Test class folders {test_generator.class_indices} do not match "
        f"training class folders {train_generator.class_indices}"
    )
print("Class indices mapping:", train_generator.class_indices)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, GlobalMaxPooling2D,
                                     Dense, Reshape, Multiply, Concatenate, Activation, Lambda, Dropout, UpSampling2D, MaxPooling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.layers import LayerNormalization, MultiHeadAttention

# ------------------ CBAM attention module ------------------
def cbam_block(input_feature, ratio=8, kernel_size=7):
    """
    CBAM attention module: channel attention first, then spatial attention.

    Args:
        input_feature: feature map. The channel count is read from the last axis.
            Assumes channels-last (batch, H, W, C); TODO confirm if data_format changes.
        ratio: reduction ratio of the shared channel-attention MLP. The hidden layer has
            ``max(1, C // ratio)`` units, so C < ratio does not create a 0-unit layer.
        kernel_size: size of the spatial-attention convolution. Defaults to 7.

    Returns:
        ``input_feature`` multiplied by the channel attention map and then by the
        spatial attention map.
    """
    channel = input_feature.shape[-1]
    hidden_units = max(1, channel // ratio)

    # --- Channel attention ---
    # The same two-layer MLP is applied to the average-pooled and max-pooled descriptors.
    shared_layer_one = Dense(hidden_units,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    avg_pool = shared_layer_one(avg_pool)
    avg_pool = shared_layer_two(avg_pool)

    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    max_pool = shared_layer_one(max_pool)
    max_pool = shared_layer_two(max_pool)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)
    channel_refined_feature = Multiply()([input_feature, cbam_feature])

    # --- Spatial attention ---
    # Per-pixel mean and max across channels form a 2-channel map.
    # A conv with sigmoid activation turns it into a (H, W, 1) mask.
    avg_pool_spatial = Lambda(lambda x: tf.reduce_mean(x, axis=3, keepdims=True))(channel_refined_feature)
    max_pool_spatial = Lambda(lambda x: tf.reduce_max(x, axis=3, keepdims=True))(channel_refined_feature)
    concat = Concatenate(axis=3)([avg_pool_spatial, max_pool_spatial])
    spatial_attention = Conv2D(filters=1, kernel_size=kernel_size, strides=1, padding='same',
                               activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(concat)
    refined_feature = Multiply()([channel_refined_feature, spatial_attention])
    return refined_feature

# ------------------ Residual CBAM block ------------------
def residual_cbam_block(x, filters, kernel_size=3, stride=1, downsample=False):
    """
    Residual block with two convolutions and a CBAM module on the residual branch.

    The shortcut is projected with a 1x1 conv followed by BatchNorm whenever the two
    branches could otherwise disagree in shape. That happens when any of these hold:
      - ``downsample`` is True;
      - ``stride != 1``;
      - the input channel count differs from ``filters``.
    Without the projection, ``Add`` would fail on mismatched shapes.

    Args:
        x: input feature map (channels on the last axis).
        filters: output channels of both convolutions.
        kernel_size: kernel size of the two main-path convolutions.
        stride: stride of the first convolution (and of the projection, if any).
        downsample: force a projection shortcut even when the shapes already match.

    Returns:
        ReLU(main_path + shortcut).
    """
    shortcut = x
    in_channels = x.shape[-1]
    needs_projection = downsample or stride != 1 or in_channels != filters

    # First convolution
    x = Conv2D(filters, kernel_size, strides=stride, padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Second convolution
    x = Conv2D(filters, kernel_size, strides=1, padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    # CBAM attention on the residual branch
    x = cbam_block(x)
    # Match the shortcut's channels and spatial size to the main path when needed
    if needs_projection:
        shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same', kernel_initializer='he_normal')(shortcut)
        shortcut = BatchNormalization()(shortcut)
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x

# ------------------ Edge detection module ------------------
def edge_detection_module(input_tensor):
    """
    Learned "edge" branch applied to the raw input.

    Layers, in order:
      - 3x3 conv with 64 filters;
      - BatchNorm;
      - ReLU;
      - 1x1 conv with sigmoid activation.
    The result is a single-channel map with values in (0, 1) at the input's spatial
    resolution. The kernels are learned; nothing constrains them to be edge filters.
    """
    stages = [
        Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal'),
        BatchNormalization(),
        ReLU(),
        Conv2D(1, (1, 1), padding='same', activation='sigmoid', kernel_initializer='he_normal'),
    ]
    out = input_tensor
    for stage in stages:
        out = stage(out)
    return out

# ------------------ Multi-scale pooling module ------------------
def multi_scale_pooling(x, scales=(1, 2, 3)):
    """
    Multi-scale pooling: average-pool ``x`` at several scales, resize each result back
    to x's spatial size, and concatenate everything with ``x`` along channels.

    Args:
        x: 4-D feature map with a static spatial size (shape[1] and shape[2] must be known).
        scales: pooling window/stride sizes. A tuple default avoids a shared mutable
            default; any iterable of ints works. Scale 1 is an identity pool, so the
            default output contains ``x`` twice.

    Returns:
        Concatenation of ``x`` and one resized pooled map per scale, giving
        ``(len(scales) + 1) * C`` channels. Returns ``x`` itself when ``scales`` is empty.
    """
    target_size = (int(x.shape[1]), int(x.shape[2]))
    pooled_features = [x]
    for scale in scales:
        pooled = tf.keras.layers.AveragePooling2D(pool_size=(scale, scale), strides=(scale, scale), padding='same')(x)
        # target_size does not change inside the loop, so the lambda's late-bound
        # closure always sees the right value.
        upsampled = Lambda(lambda t: tf.image.resize(t, target_size))(pooled)
        pooled_features.append(upsampled)
    if len(pooled_features) == 1:
        # Nothing to merge. Some Keras versions reject single-input Concatenate.
        return x
    return Concatenate()(pooled_features)

# ------------------ Transformer encoder module ------------------
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout_rate):
    """
    Pre-norm Transformer encoder layer with two residual sub-layers.

    Sub-layer 1 (attention):
        LayerNorm -> multi-head self-attention -> dropout, added back to ``inputs``.
    Sub-layer 2 (feed-forward):
        LayerNorm -> Dense(ff_dim, relu) -> dropout -> Dense(model dim) -> dropout,
        added back to the output of sub-layer 1.

    The output has the same shape as ``inputs``.
    """
    model_dim = inputs.shape[-1]

    # Self-attention sub-layer
    normed = LayerNormalization(epsilon=1e-6)(inputs)
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=head_size, dropout=dropout_rate)(normed, normed)
    attended = Dropout(dropout_rate)(attended)
    after_attention = Add()([inputs, attended])

    # Position-wise feed-forward sub-layer
    normed = LayerNormalization(epsilon=1e-6)(after_attention)
    hidden = Dense(ff_dim, activation='relu')(normed)
    hidden = Dropout(dropout_rate)(hidden)
    projected = Dense(model_dim)(hidden)
    projected = Dropout(dropout_rate)(projected)
    return Add()([after_attention, projected])

# ------------------ Context enhancement module ------------------
def context_enhancement_module(input_feature, out_channels, dilation_rates=(1, 6, 12)):
    """
    Context enhancement using dilated convolutions and two gates.

    Steps:
      1. Parallel 3x3 dilated convolutions (one per rate, ReLU) gather multi-scale context.
      2. A 1x1 sigmoid conv over their concatenation produces one weight map, which
         multiplies every branch.
      3. The weighted branches are summed.
      4. The sum is gated per channel by GlobalAveragePooling(input) -> Dense(relu)
         -> Dense(sigmoid).

    Args:
        input_feature: 4-D feature map.
        out_channels: channels of every branch and of the output.
        dilation_rates: one dilated branch per rate. Must not be empty.

    Returns:
        Tensor with ``out_channels`` channels at the input's spatial size.

    Raises:
        ValueError: if ``dilation_rates`` is empty.
    """
    dilation_rates = tuple(dilation_rates)
    if not dilation_rates:
        raise ValueError("dilation_rates must contain at least one rate")

    dilated_features = [Conv2D(out_channels, kernel_size=3, dilation_rate=rate, padding='same',
                               activation='relu', kernel_initializer='he_normal')(input_feature)
                        for rate in dilation_rates]
    # Merge layers need several inputs; with a single branch, use that tensor directly.
    if len(dilated_features) > 1:
        concatenated_features = Concatenate(axis=-1)(dilated_features)
    else:
        concatenated_features = dilated_features[0]
    fusion_weights = Conv2D(out_channels, kernel_size=1, activation='sigmoid', padding='same',
                            kernel_initializer='he_normal')(concatenated_features)
    weighted_features = [Multiply()([feature, fusion_weights]) for feature in dilated_features]
    if len(weighted_features) > 1:
        fused_feature = Add()(weighted_features)
    else:
        fused_feature = weighted_features[0]

    # Channel gate computed from the global context of the module's input
    global_context = GlobalAveragePooling2D()(input_feature)
    global_context = Dense(out_channels, activation='relu', kernel_initializer='he_normal')(global_context)
    global_context = Dense(out_channels, activation='sigmoid', kernel_initializer='he_normal')(global_context)
    global_context = Reshape((1, 1, out_channels))(global_context)
    attention_weighted_feature = Multiply()([fused_feature, global_context])
    return attention_weighted_feature

# ------------------ Model construction ------------------
def build_advanced_model(input_shape=(224, 224, 3), num_classes=2):
    """
    Build the CBAM-ResNet classifier with multi-scale pooling, an edge branch,
    context enhancement and a Transformer encoder.

    Shapes below are for the default 224x224x3 input:
        stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool             -> 56x56x64
        four stages of two residual CBAM blocks (64, 128, 256, 512
        filters); stages 2-4 start with stride 2                      -> 7x7x512
        multi-scale pooling, concatenated with x again                -> 7x7x2560
        learned edge branch on the raw input, 32x32 max-pooled        -> 7x7x1
        context enhancement                                           -> 7x7x512
        Transformer encoder over the 49 spatial positions, added back to its input
        global average pooling -> Dense(512, relu) -> Dropout(0.5) -> softmax

    The Reshape layers use static shapes, so the input height and width must be fixed.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled ``Model``.
    """
    inputs = Input(shape=input_shape)
    
    # Stem convolution and pooling
    x = Conv2D(64, kernel_size=7, strides=2, padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
    
    # Stage 1 residual blocks (no downsampling)
    x = residual_cbam_block(x, filters=64, downsample=False)
    x = residual_cbam_block(x, filters=64, downsample=False)
    
    # Stage 2 residual blocks (first downsampling)
    x = residual_cbam_block(x, filters=128, stride=2, downsample=True)
    x = residual_cbam_block(x, filters=128, downsample=False)
    
    # Stage 3 residual blocks (downsampling)
    x = residual_cbam_block(x, filters=256, stride=2, downsample=True)
    x = residual_cbam_block(x, filters=256, downsample=False)
    
    # Stage 4 residual blocks (downsampling)
    x = residual_cbam_block(x, filters=512, stride=2, downsample=True)
    x = residual_cbam_block(x, filters=512, downsample=False)
    
    # Multi-scale pooling: fuse features from different scales.
    # NOTE(review): multi_scale_pooling's output already starts with x, and scale 1 is an
    # identity pool, so x appears three times in this concatenation. Removing the
    # duplicate changes the architecture and invalidates saved weights.
    ms_features = multi_scale_pooling(x, scales=[1, 2, 3])
    x = Concatenate()([x, ms_features])
    
    # Edge branch: learned single-channel map of the input image, downsampled to x's spatial size.
    edge_features = edge_detection_module(inputs)
    # x has passed through five stride-2 steps (stem conv, stem pool, stages 2-4), i.e. 1/32.
    # A 32x32 'same'-padded pool gives ceil(H/32) x ceil(W/32), which matches x (7x7 for 224x224).
    edge_features_down = MaxPooling2D(pool_size=(32, 32), padding='same')(edge_features)
    x = Concatenate()([x, edge_features_down])
    
    # Context enhancement module
    context_features = context_enhancement_module(x, out_channels=512)
    
    # Transformer encoder: flatten the spatial grid into a token sequence, encode it,
    # then restore the spatial layout.
    target_h = int(context_features.shape[1])
    target_w = int(context_features.shape[2])
    target_c = int(context_features.shape[3])
    x_seq = Reshape((target_h * target_w, target_c))(context_features)
    x_seq = transformer_encoder(x_seq, head_size=128, num_heads=4, ff_dim=512, dropout_rate=0.1)
    x_context = Reshape((target_h, target_w, target_c))(x_seq)
    
    # Fuse the Transformer output with the context features via a residual add.
    x = Add()([context_features, x_context])
    
    # Global average pooling and fully connected classification head
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu', kernel_initializer='he_normal')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs, outputs)
    return model

# Size the output layer from the class folders that train_generator (module4) found.
num_classes = len(list(train_generator.class_indices))
model = build_advanced_model(num_classes=num_classes, input_shape=(224, 224, 3))
model.summary()


In [None]:
# module6 The loss function

class CauchyLoss(tf.keras.losses.Loss):
    """
    Cauchy (Lorentzian) loss: mean over the last axis of log(1 + (y_true - y_pred)^2 / sigma^2).

    ``call`` returns one loss value per sample. Keras can therefore apply
    ``sample_weight`` / ``class_weight`` per sample before its batch reduction.

    Args:
        sigma: scale parameter. Must be > 0.
        name: loss name.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``). This also
            lets ``from_config`` rebuild the loss from the full config returned by
            ``get_config``.

    Raises:
        ValueError: if ``sigma`` is not positive.
    """

    def __init__(self, sigma=1.0, name="cauchy_loss", **kwargs):
        if sigma <= 0:
            raise ValueError(f"sigma must be positive, got {sigma}")
        super().__init__(name=name, **kwargs)
        self.sigma = float(sigma)
        self.sigma_sq = self.sigma ** 2

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        error_sq = tf.square(y_true - y_pred)  # (y_true - y_pred)^2
        # log1p(x) is more accurate than log(1 + x) for small errors.
        loss = tf.math.log1p(error_sq / self.sigma_sq)
        # Reduce only the class axis and keep the batch axis, so per-sample weights
        # apply. Without weights, Keras' batch-mean reduction yields the same scalar.
        return tf.reduce_mean(loss, axis=-1)

    def get_config(self):
        # Store sigma itself (not sigma**2) so that from_config rebuilds the same loss,
        # and keep the base-class fields (name, reduction).
        config = super().get_config()
        config.update({"sigma": self.sigma})
        return config

cauchy_loss = CauchyLoss(sigma=1.5)

# module7 Compile the model with the Cauchy loss and Adam (learning rate 2e-4)
model.compile(
    loss=cauchy_loss,
    optimizer=Adam(learning_rate=2e-4),
    metrics=['accuracy'],
)

In [None]:
# module8 Callbacks
# All three callbacks watch the validation loss, where lower is better.

# Keep only the best checkpoint on disk.
checkpoint = ModelCheckpoint(
    'model_best.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)

# Cut the learning rate by 10x after 5 stagnant epochs, never going below 2e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=2e-5,
    verbose=1,
)


In [None]:
# module9 Train the model
# Balanced class weights, n_samples / (n_classes * n_samples_in_class), computed from the
# training subset and keyed by train_generator's own class indices.
# flow_from_directory numbers class folders in sorted (alphanumeric) order. A hard-coded
# {0: ..., 4: ...} table is only correct if the folders happen to sort as
# No DR, Mild, Moderate, Severe, Proliferative DR.
train_class_counts = np.bincount(train_generator.classes,
                                 minlength=len(train_generator.class_indices))
n_train_samples = int(train_class_counts.sum())
class_weights = {
    class_idx: (n_train_samples / (len(train_class_counts) * int(count)) if count > 0 else 1.0)
    for class_idx, count in enumerate(train_class_counts)
}
print("Class weights:",
      {name: round(class_weights[idx], 2) for name, idx in train_generator.class_indices.items()})

history = model.fit(
    train_generator,
    validation_data=validation_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(validation_generator),
    epochs=45,
    callbacks=[checkpoint, early_stopping, reduce_lr],
    workers=6,
    class_weight=class_weights,
    max_queue_size=20
)


In [None]:
# module10 Load the best checkpoint written by ModelCheckpoint.
# CauchyLoss is registered so the compiled loss can be deserialised.
model = load_model(
    'model_best.h5',
    custom_objects=dict(CauchyLoss=CauchyLoss),
)

In [None]:
# Model Evaluation

In [None]:
save_dir = 'plots'
os.makedirs(save_dir, exist_ok=True)


def _plot_history_metric(history, metric, title, ylabel, filename, save_dir):
    """Plot the training and validation curves of one Keras History metric and save them as a PNG."""
    plt.figure(figsize=(12, 5))
    plt.plot(history.history[metric], label=f'Training {ylabel}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {ylabel}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, filename), dpi=600, bbox_inches='tight')
    plt.show()


# Accuracy curves
_plot_history_metric(history, 'accuracy', 'Training and Validation Accuracy', 'Accuracy',
                     'accuracy_plot.png', save_dir)
# Loss curves
_plot_history_metric(history, 'loss', 'Training and Validation Loss', 'Loss',
                     'loss_plot.png', save_dir)

In [None]:
# Prediction and Inference

In [None]:
# Predict on the test data. test_generator was built with shuffle=False, so the
# prediction rows line up with test_generator.classes.
class_labels = list(test_generator.class_indices)
true_labels = test_generator.classes
predictions = model.predict(test_generator)
predicted_classes = predictions.argmax(axis=1)

In [None]:
# Confusion matrix (rows = true class, columns = predicted class).
# It is computed over every class index, so it is always len(class_labels) x len(class_labels),
# even when a class never occurs in the test set or the predictions.
cm = confusion_matrix(true_labels, predicted_classes, labels=list(range(len(class_labels))))

def evaluate_model(true_labels, predicted_classes, class_labels, save_dir):
    """
    Print per-class and averaged classification metrics, and save the sklearn report as JSON.

    Args:
        true_labels: integer ground-truth class indices.
        predicted_classes: integer predicted class indices, aligned with ``true_labels``.
        class_labels: class names; position i names class index i.
        save_dir: directory for ``classification_report.json``. Created if missing.

    Note:
        The per-class "Accuracy" column is TP / (TP + FN), which equals per-class recall.
    """
    os.makedirs(save_dir, exist_ok=True)
    label_ids = list(range(len(class_labels)))

    # Build the matrix from the arguments instead of the module-level `cm`, so the
    # metrics always describe exactly the labels that were passed in.
    conf_mat = confusion_matrix(true_labels, predicted_classes, labels=label_ids)

    # Micro average accuracy
    micro_accuracy = accuracy_score(true_labels, predicted_classes)

    # Passing `labels` keeps the report aligned with target_names even if a class is absent.
    classification_report_dict = classification_report(
        true_labels, predicted_classes, labels=label_ids, target_names=class_labels,
        output_dict=True, zero_division=0
    )

    # Per-class specificity and accuracy (= recall), plus totals for micro specificity
    total = conf_mat.sum()
    specificity_per_class = []
    accuracy_per_class = []
    total_tn, total_fp = 0, 0
    for i in label_ids:
        true_positive = conf_mat[i, i]
        row_sum = conf_mat[i, :].sum()  # TP + FN
        col_sum = conf_mat[:, i].sum()  # TP + FP
        false_positive = col_sum - true_positive
        true_negative = total - (row_sum + col_sum - true_positive)
        specificity = true_negative / (true_negative + false_positive) if (true_negative + false_positive) > 0 else 0.0
        specificity_per_class.append(specificity)

        accuracy = true_positive / row_sum if row_sum > 0 else 0.0
        accuracy_per_class.append(accuracy)

        total_tn += true_negative
        total_fp += false_positive

    # Macro and micro average specificity
    macro_specificity = np.mean(specificity_per_class)
    micro_specificity = total_tn / (total_tn + total_fp) if (total_tn + total_fp) > 0 else 0.0

    # Save classification report as JSON
    json_save_path = os.path.join(save_dir, "classification_report.json")
    with open(json_save_path, "w") as f:
        json.dump(classification_report_dict, f, indent=4)
    print(f"Classification Report saved to {json_save_path}")

    # Print classification report
    print("Classification Report:")
    print(f"{'Class':<20} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Specificity':<12} {'Accuracy':<12} {'Support':<12}")
    for idx, cls in enumerate(class_labels):
        metrics = classification_report_dict[cls]
        print(f"{cls:<20} {metrics['precision']:<12.2f} {metrics['recall']:<12.2f} {metrics['f1-score']:<12.2f} {specificity_per_class[idx]:<12.2f} {accuracy_per_class[idx]:<12.2f} {metrics['support']:<12}")

    # Macro average accuracy (mean of the per-class values above)
    macro_accuracy = np.mean(accuracy_per_class)

    print(f"\nMicro Average Accuracy: {micro_accuracy:.2f}")
    print(f"Macro Average Accuracy: {macro_accuracy:.2f}")

    # Macro and micro average metrics
    macro_precision = precision_score(true_labels, predicted_classes, average='macro', zero_division=0)
    macro_recall = recall_score(true_labels, predicted_classes, average='macro', zero_division=0)
    macro_f1 = f1_score(true_labels, predicted_classes, average='macro', zero_division=0)

    micro_precision = precision_score(true_labels, predicted_classes, average='micro', zero_division=0)
    micro_recall = recall_score(true_labels, predicted_classes, average='micro', zero_division=0)
    micro_f1 = f1_score(true_labels, predicted_classes, average='micro', zero_division=0)

    print("\nMacro Average Metrics:")
    print(f"Precision: {macro_precision:.2f}, Recall: {macro_recall:.2f}, F1-Score: {macro_f1:.2f}, Specificity: {macro_specificity:.2f}")

    print("\nMicro Average Metrics:")
    print(f"Precision: {micro_precision:.2f}, Recall: {micro_recall:.2f}, F1-Score: {micro_f1:.2f}, Specificity: {micro_specificity:.2f}")

evaluate_model(true_labels, predicted_classes, class_labels, save_dir)


In [None]:
# Confusion matrix
def plot_confusion_matrix(cm, class_labels, title='Confusion Matrix', cmap=plt.cm.Blues, save_dir=None):
    """
    Draw a confusion matrix with per-cell counts and optionally save it.

    Args:
        cm: square count matrix (rows = true class, columns = predicted class).
        class_labels: tick labels; position i names row/column i.
        title: figure title.
        cmap: matplotlib colormap.
        save_dir: if given, the figure is saved there as ``confusion_matrix.png``.
            The directory is created if missing.
    """
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_labels))
    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels)

    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, f"{cm[i, j]}", horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, "confusion_matrix.png"), dpi=600, bbox_inches='tight')
    plt.show()

# ROC-AUC
def plot_roc_auc_curves(true_labels, predictions, class_labels, save_dir):
    """
    Plot one-vs-rest ROC curves, one per class, and save them as ``roc_auc_curve.png``.

    ROC AUC is undefined for a class whose binarised ground truth has only one value
    (all positive or all negative); ``roc_auc_score`` raises for it. Such classes are
    skipped with a printed message.

    Args:
        true_labels: integer ground-truth class indices.
        predictions: (n_samples, n_classes) predicted probabilities.
        class_labels: class names; position i names column i of ``predictions``.
        save_dir: output directory. Created if missing.
    """
    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    plt.figure(figsize=(12, 8))
    for i, name in enumerate(class_labels):
        y_true_i = binarized_labels[:, i]
        if len(np.unique(y_true_i)) < 2:
            print(f"Skipping ROC curve for class {name}: only one label value in the test set.")
            continue
        fpr, tpr, _ = roc_curve(y_true_i, predictions[:, i])
        roc_auc = roc_auc_score(y_true_i, predictions[:, i])
        plt.plot(fpr, tpr, lw=2, label=f'ROC Curve for class {name} (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC-AUC Curves')
    plt.legend(loc='lower right')
    plt.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, 'roc_auc_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()

# Precision-Recall
def plot_precision_recall_curves(true_labels, predictions, class_labels, save_dir):
    """
    Plot one-vs-rest precision-recall curves, one per class, and save them as
    ``precision_recall_curve.png``.

    The legend shows the trapezoidal area under each PR curve. Classes with no
    positive test samples have no meaningful PR curve and are skipped with a message.

    Args:
        true_labels: integer ground-truth class indices.
        predictions: (n_samples, n_classes) predicted probabilities.
        class_labels: class names; position i names column i of ``predictions``.
        save_dir: output directory. Created if missing.
    """
    from sklearn.metrics import auc  # area under an arbitrary monotonic (x, y) curve

    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    plt.figure(figsize=(12, 8))
    for i, name in enumerate(class_labels):
        y_true_i = binarized_labels[:, i]
        if not y_true_i.any():
            print(f"Skipping precision-recall curve for class {name}: no positive samples in the test set.")
            continue
        precision, recall, _ = precision_recall_curve(y_true_i, predictions[:, i])
        # Integrate precision over recall. tf.keras.metrics.AUC is not usable here:
        # it expects (labels, scores), not points on a curve.
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, lw=2, label=f'Precision-Recall Curve for class {name} (area = {pr_auc:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves')
    plt.legend(loc='lower left')
    plt.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, 'precision_recall_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()

# Render and save the confusion-matrix, ROC and precision-recall figures.
plot_confusion_matrix(cm, class_labels, title='Confusion Matrix', save_dir=save_dir)
plot_roc_auc_curves(true_labels=true_labels, predictions=predictions,
                    class_labels=class_labels, save_dir=save_dir)
plot_precision_recall_curves(true_labels=true_labels, predictions=predictions,
                             class_labels=class_labels, save_dir=save_dir)

In [None]:
# LIME explanations

In [None]:
explainer = lime_image.LimeImageExplainer()


def model_predict(image_batch):
    """Prediction function for LIME: return class probabilities for a batch of images."""
    # verbose=0: LIME calls this many times per image; progress bars would flood the output.
    return model.predict(image_batch, verbose=0)


class_labels = list(test_generator.class_indices.keys())

# Collect up to `num_images` test images per class.
num_images = 1
class_images = {label: [] for label in class_labels}

# Index the generator batch by batch. Iterating a Keras DirectoryIterator directly never
# terminates, so a class with fewer than `num_images` test images would hang a `for` loop.
for batch_index in range(len(test_generator)):
    images, labels = test_generator[batch_index]
    for image_array, onehot in zip(images, labels):
        class_name = class_labels[int(np.argmax(onehot))]
        if len(class_images[class_name]) < num_images:
            class_images[class_name].append(image_array)
    if all(len(collected) >= num_images for collected in class_images.values()):
        break

save_dir = "plots/lime"
os.makedirs(save_dir, exist_ok=True)
for class_name, images in class_images.items():
    for idx, img in enumerate(images):
        explanation = explainer.explain_instance(
            img.astype('double'),
            model_predict,
            top_labels=1,
            num_samples=1000
        )
        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=False,
            num_features=10
        )
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        # test_datagen already scales pixels to [0, 1], so show `temp` as-is.
        # `temp / 2 + 0.5` would assume [-1, 1] inputs and wash the image out.
        plt.imshow(mark_boundaries(temp, mask))
        plt.title(f"Explained: {class_name} | Image {idx + 1}")
        plt.axis('off')

        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()


In [None]:
# Grad-CAM explanations

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tf_explain.core.grad_cam import GradCAM

save_dir = "plots/grad_cam_tuned"
os.makedirs(save_dir, exist_ok=True)

# Use the last Conv2D layer in model.layers as the Grad-CAM target.
conv_layer_names = [layer.name for layer in model.layers if isinstance(layer, Conv2D)]
if not conv_layer_names:
    raise ValueError("Grad-CAM needs at least one Conv2D layer in the model.")
target_layer_name = conv_layer_names[-1]
grad_cam = GradCAM()

for class_name, images in class_images.items():
    for idx, entry in enumerate(images):
        # class_images (built in the LIME cell) holds bare image arrays. Pairs of
        # (image, label_index) are accepted too. For a bare image, the target class is
        # the image's own class.
        if isinstance(entry, tuple):
            img, label_index = entry
        else:
            img, label_index = entry, test_generator.class_indices[class_name]

        data = ([img], None)
        cam_result = grad_cam.explain(
            data, model, class_index=label_index, layer_name=target_layer_name
        )
        # NOTE(review): tf_explain's GradCAM.explain presumably returns an already
        # colour-mapped overlay rather than a raw heat map. Verify this; if so, the
        # applyColorMap below re-maps an RGB image.

        # Stretch the result to the full 0-255 range for better contrast.
        cam_result = cam_result - np.min(cam_result)  # minimum becomes 0
        cam_result = cam_result / (np.max(cam_result) + 1e-8)  # avoid division by zero
        cam_result = (cam_result * 255).astype(np.uint8)

        # Colour-map the result (COLORMAP_TURBO is an alternative).
        heatmap = cv2.applyColorMap(cam_result, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV is BGR; matplotlib expects RGB

        # Show the original image next to the heat-map overlay.
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(img, alpha=0.4)  # dim the original image
        plt.imshow(heatmap, alpha=0.6)  # make the Grad-CAM map dominant
        plt.title(f"GradCAM Enhanced: {class_name} | Image {idx + 1}")
        plt.axis('off')

        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()


In [None]:
import shap
import matplotlib.pyplot as plt
import numpy as np
import os

# Re-visualise SHAP explanations
def plot_shap_values(image, shap_values, save_path, class_name, idx):
    """
    Plot an image next to a contrast-enhanced SHAP attribution overlay and save the figure.

    Processing of the attributions:
      - averaged over the channel axis when ``shap_values`` is 3-D;
      - divided by their maximum absolute value;
      - multiplied by 3;
      - drawn with a diverging colormap clipped to [-1, 1]. Attributions above 1/3 of
        the maximum therefore saturate on purpose.

    Args:
        image: background image (passed straight to plt.imshow).
        shap_values: attribution map for ``image``, shape (H, W, C) or (H, W).
        save_path: output file path for the figure.
        class_name: class name shown in the panel titles.
        idx: zero-based image index, shown as ``idx + 1`` in the title.
    """
    # np.array copies, so the caller's attributions are never modified.
    attribution = np.array(shap_values, dtype=float)
    if attribution.ndim == 3:
        attribution = attribution.mean(axis=-1)  # channel mean (RGB -> single map)

    # Normalise by the largest magnitude, then amplify for visibility.
    abs_max = np.max(np.abs(attribution))
    if abs_max > 0:
        attribution = attribution / abs_max
    attribution = attribution * 3

    plt.figure(figsize=(12, 6))

    # Original image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"Original: {class_name}")
    plt.axis('off')

    # Enhanced SHAP heat map over a dimmed copy of the image
    plt.subplot(1, 2, 2)
    plt.imshow(image, alpha=0.4)
    plt.imshow(attribution, cmap="seismic", alpha=0.6, vmin=-1, vmax=1)
    plt.colorbar(label="SHAP value")
    plt.title(f"SHAP Enhanced: {class_name} | Image {idx + 1}")
    plt.axis('off')

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Select test images, compute their SHAP values, and plot them.
# Both `sample_images` and `shap_values` are built in this cell, so it runs on a fresh kernel.
num_images = 5
save_dir = "plots/shap_enhanced"
os.makedirs(save_dir, exist_ok=True)

shap_class_labels = list(test_generator.class_indices.keys())
batch_images, batch_onehot = test_generator[0]
sample_images = batch_images[:num_images]
sample_true = np.argmax(batch_onehot[:num_images], axis=1)

# Background (reference) set for GradientExplainer: the rest of the first test batch,
# or the whole batch if it has no spare images.
background = batch_images[num_images:] if len(batch_images) > num_images else batch_images
shap_explainer = shap.GradientExplainer(model, background)
shap_values = shap_explainer.shap_values(sample_images)

# Explain each image's predicted class.
sample_pred = np.argmax(model.predict(sample_images, verbose=0), axis=1)

for i in range(len(sample_images)):
    class_name = shap_class_labels[sample_true[i]]
    explained_class = sample_pred[i]
    # Depending on the SHAP version, shap_values is either a list with one
    # (N, H, W, C) array per output class, or one array with the class on the last axis.
    if isinstance(shap_values, list):
        image_shap = shap_values[explained_class][i]
    else:
        image_shap = shap_values[i, ..., explained_class]
    save_path = os.path.join(save_dir, f"shap_enhanced_image_{i+1}.png")
    plot_shap_values(sample_images[i], image_shap, save_path, class_name, i)