In [None]:
# Libraries

In [None]:
import os
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.keras.losses import Loss
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, BatchNormalization, 
                                     ReLU, Add, Dropout, LSTM, MultiHeadAttention, Reshape,Flatten)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, roc_curve, precision_recall_curve, classification_report,
    average_precision_score,
)
import tensorflow as tf
import shap
from lime import lime_image
from skimage.segmentation import mark_boundaries
from tf_explain.core.grad_cam import GradCAM
import json

In [None]:
# Data Loading and Preprocessing

In [None]:
# Dataset root. Set the APTOS_ROOT environment variable to point the notebook at the
# data on another machine; the default keeps the original local location.
root_dir = os.environ.get(
    "APTOS_ROOT",
    r'D:\Multi-Class Diabetic Retinopathy Classification\data\aptos2019-blindness-detection',
)
train_dir = os.path.join(root_dir, "train")
test_dir = os.path.join(root_dir, "test")

In [None]:
# Image preprocessing and augmentation
def preprocess_image(image):
    """Centre-crop an (H, W, C) image to a square, resize it to 224x224 and scale by 1/255.

    Used as ``ImageDataGenerator.preprocessing_function``. Keras runs that hook after the
    image has been resized to ``target_size`` and augmented. With ``target_size=(224, 224)``
    the input is therefore already square, and the crop keeps the whole frame.

    Args:
        image: rank-3 array (height, width, channels); unpacking fails for other ranks.

    Returns:
        The 224x224 crop divided by 255.0.
    """
    height, width, _ = image.shape
    half_side = min(width // 2, height // 2)
    top = height // 2 - half_side
    left = width // 2 - half_side
    square = image[top:top + 2 * half_side, left:left + 2 * half_side]

    resized = cv2.resize(square, (224, 224))
    return resized / 255.0

# Training-time augmentation. `validation_split` reserves 20% of the images under
# train_dir for the 'validation' subset.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=90,
    zoom_range=0.2,
    brightness_range=[0.8, 1.2],
    preprocessing_function=preprocess_image,
    validation_split=0.2,
)

# Test images get the same deterministic preprocessing as training (centre crop, resize,
# /255) but no augmentation, so train and test inputs are prepared identically.
# A separate `rescale=1/255` path would drift from training whenever preprocess_image changes.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_image)

In [None]:
# Load and preprocess the training and testing images
# Training subset: augmented, shuffled batches.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=64,
    class_mode='categorical',
    subset='training',
    shuffle=True,
)

# Validation subset. It uses the same directory and the same validation_split (0.2) as
# train_datagen, so Keras selects the same hold-out files. However, it reads them through
# a generator WITHOUT random augmentation. val_loss drives checkpointing, early stopping
# and LR decay, so it must be measured on unmodified images.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    validation_split=0.2,  # must match train_datagen's validation_split
)
validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=64,
    class_mode='categorical',
    subset='validation',
    shuffle=False,
)

# Test set: fixed order (shuffle=False) so predictions line up with test_generator.classes.
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=64,
    class_mode='categorical',
    shuffle=False,
)
print("Class indices mapping:", train_generator.class_indices)

In [None]:
# Model Definition

In [None]:
# residual block with L2 regularization
def residual_block(x, filters, strides=1):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with L2-regularised (0.001) convolutions.

    The main path narrows to ``filters`` channels and widens back to ``filters * 4``.
    A 1x1 projection (conv + batch norm) replaces the identity shortcut whenever the
    stride is not 1 or the input channel count differs from ``filters * 4``, so both
    paths can be added.

    Args:
        x: input Keras tensor (channels-last).
        filters: bottleneck width; the block outputs ``filters * 4`` channels.
        strides: stride of the first 1x1 conv and of the projection shortcut.

    Returns:
        Keras tensor: ReLU(shortcut + main path).
    """
    out_channels = filters * 4

    def conv_bn(tensor, n_filters, kernel, stride=1):
        # Conv2D ('same' padding, L2 weight decay) followed by BatchNormalization.
        tensor = Conv2D(n_filters, kernel_size=kernel, strides=stride, padding='same',
                        kernel_regularizer=l2(0.001))(tensor)
        return BatchNormalization()(tensor)

    needs_projection = strides != 1 or x.shape[-1] != out_channels
    shortcut = conv_bn(x, out_channels, (1, 1), strides) if needs_projection else x

    # Bottleneck: reduce with a 1x1 conv (this one carries the stride).
    y = conv_bn(x, filters, (1, 1), strides)
    y = ReLU()(y)
    # Spatial 3x3 conv.
    y = conv_bn(y, filters, (3, 3))
    y = ReLU()(y)
    # Restore the channel count with a 1x1 conv; no activation before the addition.
    y = conv_bn(y, out_channels, (1, 1))

    merged = Add()([shortcut, y])
    return ReLU()(merged)



In [None]:
# Input layer: 224x224 RGB images (matches the generators' target_size).
# NOTE: `input` shadows the Python builtin; the name is kept because the Model cell uses it.
input = Input(shape=(224, 224, 3))


def spatial_self_attention(x, num_heads=2, key_dim=16):
    """2x2 max-pool, self-attention over the flattened spatial grid, then reshape back.

    The target grid shape is read from the pooled tensor instead of being hard-coded,
    so every token returns to the spatial position it came from. MultiHeadAttention
    outputs the query's channel count by default, so the channel dimension is unchanged.

    Args:
        x: Keras tensor of shape (batch, H, W, C) with static H, W, C.
        num_heads: number of attention heads.
        key_dim: size of each head for query and key.

    Returns:
        Keras tensor of shape (batch, H // 2, W // 2, C).
    """
    x = MaxPooling2D(pool_size=(2, 2))(x)
    grid_shape = tuple(x.shape[1:])        # (H, W, C) after pooling
    x = Reshape((-1, grid_shape[-1]))(x)   # (H*W, C) token sequence for MHA
    x = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x, x)
    return Reshape(grid_shape)(x)


# Stem: 7x7/2 conv + 3x3/2 max-pool -> 56x56x32
x = Conv2D(32, kernel_size=(7, 7), strides=(2, 2), padding='same')(input)
x = BatchNormalization()(x)
x = ReLU()(x)
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)

# Self-attention on the pooled 28x28x32 grid
x = spatial_self_attention(x)

# Residual Block Stages
# Stage 1 -> 28x28x128
x = residual_block(x, 32)
x = residual_block(x, 32)

# Stage 2 -> 14x14x256
x = residual_block(x, 64, strides=2)
x = residual_block(x, 64)

# Self-attention on the pooled 7x7x256 grid. The output must be reshaped back to
# (7, 7, 256). A hard-coded (14, 14, 64) target has the same element count but silently
# scrambles positions and channels. Checkpoints trained with that layout do not load
# into this graph.
x = spatial_self_attention(x)

# Stage 3 -> 4x4x512
x = residual_block(x, 128, strides=2)
x = residual_block(x, 128)

# Stage 4 -> 2x2x1024
x = residual_block(x, 256, strides=2)
x = residual_block(x, 256)

# Global Average Pooling
x = GlobalAveragePooling2D()(x)

# Sequence of length 1 for the LSTM (batch, time_steps=1, features); with one step there
# is no recurrence across time.
x = Reshape((1, -1))(x)
x = LSTM(32, return_sequences=False)(x)

# Classification head: 3 output classes.
# NOTE(review): must equal the number of class sub-folders in train_dir.
x = Dense(128, activation='relu', kernel_regularizer=l2(0.001))(x)
output = Dense(3, activation='softmax', kernel_regularizer=l2(0.001))(x)

In [None]:
# Wrap the functional graph (input -> output) into a Keras Model and print its layers.
model = Model(input, output)
model.summary()

In [None]:
class CustomLoss(Loss):
    """Loss subclass that delegates to Keras' categorical cross-entropy.

    ``y_true`` is one-hot (the generators use class_mode='categorical') and ``y_pred``
    holds probabilities (the model ends in a softmax). The ``Loss`` base class applies
    the reduction over the batch.
    """

    def call(self, y_true, y_pred):
        # Per-sample cross-entropy via Keras' built-in categorical_crossentropy.
        per_sample_loss = K.categorical_crossentropy(y_true, y_pred)
        return per_sample_loss


# Instantiate the custom loss function
custom_loss = CustomLoss()

# Compile the model
optimizer = Adam(learning_rate=0.0002)
model.compile(optimizer=optimizer, loss=custom_loss, metrics=['accuracy'])

In [None]:
# Model Training

In [None]:
# Save the model to model_best.h5 whenever val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    'model_best.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Stop after 10 epochs without a val_loss improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)

# Multiply the learning rate by 0.1 after 5 stagnant epochs, never going below 2e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=0.00002,
    verbose=1,
)


In [None]:

fit_callbacks = [checkpoint, early_stopping, reduce_lr]

# Train for up to 45 epochs; one epoch covers every batch of each generator once.
# NOTE(review): `workers` / `max_queue_size` were removed from Model.fit in Keras 3
# (TF >= 2.16) — drop them if this raises a TypeError.
history = model.fit(
    train_generator,
    epochs=45,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=fit_callbacks,
    workers=6,
    max_queue_size=30,
)

In [None]:
# Model Evaluation

In [None]:
# Output folder for every evaluation figure, relative to the working directory.
save_dir = 'plots'
plots_dir_missing = not os.path.exists(save_dir)
if plots_dir_missing:
    os.mkdir(save_dir)

In [None]:
def plot_history_curve(history, metric, title, ylabel, filename, save_dir):
    """Plot the training and validation curves of one Keras metric and save the figure.

    Args:
        history: the ``History`` object returned by ``model.fit``.
        metric: key in ``history.history`` (e.g. ``'accuracy'``); ``'val_' + metric``
            must be present too.
        title: figure title.
        ylabel: y-axis label, also used in the legend ("Training <ylabel>", ...).
        filename: PNG file name written inside ``save_dir``.
        save_dir: existing output directory.
    """
    plt.figure(figsize=(12, 5))
    plt.plot(history.history[metric], label=f'Training {ylabel}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {ylabel}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, filename), dpi=600, bbox_inches='tight')
    plt.show()
    plt.close()


# Accuracy curves, then loss curves, one figure each.
plot_history_curve(history, 'accuracy', 'Training and Validation Accuracy', 'Accuracy',
                   'accuracy_plot.png', save_dir)
plot_history_curve(history, 'loss', 'Training and Validation Loss', 'Loss',
                   'loss_plot.png', save_dir)

In [None]:
# Model Loading

In [None]:
# Reload the checkpoint with the lowest val_loss. CustomLoss must be registered so Keras
# can deserialize the compiled loss.
model = load_model(
    'model_best.h5',
    custom_objects=dict(CustomLoss=CustomLoss),
)

In [None]:
# Prediction and Inference

In [None]:
# Class names; assumes class_indices is ordered by class index (Keras builds it that way).
class_labels = list(test_generator.class_indices.keys())
# Integer ground-truth labels in file order (test_generator uses shuffle=False).
true_labels = test_generator.classes

# Per-class probabilities for every test image, in the same order, and the argmax class.
predictions = model.predict(test_generator)
predicted_classes = predictions.argmax(axis=1)

In [None]:
# Confusion matrix over every class index, so its shape is always (n_classes, n_classes),
# even when a class is missing from both labels and predictions. Also used by the
# confusion-matrix plot further down.
cm = confusion_matrix(true_labels, predicted_classes, labels=np.arange(len(class_labels)))


def _class_rates_from_confusion(conf_mat):
    """Per-class specificity and per-class accuracy (recall) from a square confusion matrix.

    Rows are true classes and columns are predicted classes (sklearn's convention).
    Classes with an empty denominator get 0.0.

    Returns:
        (specificity_per_class, accuracy_per_class, total_tn, total_fp), where the two
        totals are summed over classes for micro-averaged specificity.
    """
    total = np.sum(conf_mat)
    true_per_class = conf_mat.sum(axis=1)   # row sums: samples whose true label is the class
    pred_per_class = conf_mat.sum(axis=0)   # column sums: samples predicted as the class
    specificity_per_class, accuracy_per_class = [], []
    total_tn, total_fp = 0, 0
    for i in range(conf_mat.shape[0]):
        tp = conf_mat[i, i]
        fp = pred_per_class[i] - tp
        tn = total - (true_per_class[i] + pred_per_class[i] - tp)
        specificity_per_class.append(tn / (tn + fp) if (tn + fp) > 0 else 0.0)
        accuracy_per_class.append(tp / true_per_class[i] if true_per_class[i] > 0 else 0.0)
        total_tn += tn
        total_fp += fp
    return specificity_per_class, accuracy_per_class, total_tn, total_fp


def evaluate_model(true_labels, predicted_classes, class_labels, save_dir):
    """Print per-class and averaged classification metrics and save the sklearn report as JSON.

    Everything is computed from the arguments, not from notebook globals, and over all
    ``len(class_labels)`` classes, so classes absent from the data do not break indexing.

    Args:
        true_labels: integer ground-truth class index per sample.
        predicted_classes: integer predicted class index per sample.
        class_labels: class names, indexed by class id.
        save_dir: directory for ``classification_report.json`` (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    label_ids = np.arange(len(class_labels))
    conf_mat = confusion_matrix(true_labels, predicted_classes, labels=label_ids)

    # Micro average accuracy
    micro_accuracy = accuracy_score(true_labels, predicted_classes)

    # Generate classification report
    classification_report_dict = classification_report(
        true_labels, predicted_classes, labels=label_ids, target_names=class_labels,
        output_dict=True, zero_division=0
    )

    # Specificity and accuracy (= recall) per class, plus totals for micro specificity
    specificity_per_class, accuracy_per_class, total_tn, total_fp = _class_rates_from_confusion(conf_mat)
    macro_specificity = np.mean(specificity_per_class)
    micro_specificity = total_tn / (total_tn + total_fp) if (total_tn + total_fp) > 0 else 0.0

    # Save classification report as JSON
    json_save_path = os.path.join(save_dir, "classification_report.json")
    with open(json_save_path, "w") as f:
        json.dump(classification_report_dict, f, indent=4)
    print(f"Classification Report saved to {json_save_path}")

    # Print classification report
    print("Classification Report:")
    print(f"{'Class':<20} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Specificity':<12} {'Accuracy':<12} {'Support':<12}")
    for idx, cls in enumerate(class_labels):
        metrics = classification_report_dict[cls]
        print(f"{cls:<20} {metrics['precision']:<12.2f} {metrics['recall']:<12.2f} {metrics['f1-score']:<12.2f} {specificity_per_class[idx]:<12.2f} {accuracy_per_class[idx]:<12.2f} {metrics['support']:<12}")

    # Macro average accuracy = mean of per-class accuracies
    macro_accuracy = np.mean(accuracy_per_class)

    print(f"\nMicro Average Accuracy: {micro_accuracy:.2f}")
    print(f"Macro Average Accuracy: {macro_accuracy:.2f}")

    # Macro and micro averages over the same full label set as the per-class table
    macro_precision = precision_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)
    macro_recall = recall_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)
    macro_f1 = f1_score(true_labels, predicted_classes, labels=label_ids, average='macro', zero_division=0)

    micro_precision = precision_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)
    micro_recall = recall_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)
    micro_f1 = f1_score(true_labels, predicted_classes, labels=label_ids, average='micro', zero_division=0)

    print("\nMacro Average Metrics:")
    print(f"Precision: {macro_precision:.2f}, Recall: {macro_recall:.2f}, F1-Score: {macro_f1:.2f}, Specificity: {macro_specificity:.2f}")

    print("\nMicro Average Metrics:")
    print(f"Precision: {micro_precision:.2f}, Recall: {micro_recall:.2f}, F1-Score: {micro_f1:.2f}, Specificity: {micro_specificity:.2f}")

evaluate_model(true_labels, predicted_classes, class_labels, save_dir)


In [None]:
# Confusion Matrix
def plot_confusion_matrix(cm, class_labels, title='Confusion Matrix', cmap=plt.cm.Blues, save_dir=None):
    """Draw a confusion matrix as an annotated heat map and optionally save it.

    Args:
        cm: square confusion matrix (rows = true class, columns = predicted class).
        class_labels: tick labels for both axes, indexed by class id.
        title: figure title.
        cmap: matplotlib colormap for the cells.
        save_dir: if truthy, ``confusion_matrix.png`` is written there (created if missing).
    """
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_labels))
    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels)

    # White text on dark cells (above half the max count), black on light cells.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, f"{cm[i, j]}", horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'), dpi=600, bbox_inches='tight')
    plt.show()

# ROC-AUC
def plot_roc_auc_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest ROC curves, with AUC in the legend, for every class and save them.

    Args:
        true_labels: integer class index per sample.
        predictions: 2-D array of scores, one column per class (indexed ``[:, i]``).
        class_labels: class names, aligned with the columns of ``predictions``.
        save_dir: output directory (created if missing); writes ``roc_auc_curve.png``.

    A class whose one-vs-rest labels hold a single value (no positives or no negatives
    in the test set) has no defined ROC curve, and ``roc_auc_score`` raises on it, so
    such classes are skipped with a message.
    """
    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    plt.figure(figsize=(12, 8))
    for i, class_name in enumerate(class_labels):
        y_true = binarized_labels[:, i]
        if np.unique(y_true).size < 2:
            print(f"Skipping ROC curve for class {class_name}: only one label value in the test set.")
            continue
        fpr, tpr, _ = roc_curve(y_true, predictions[:, i])
        roc_auc = roc_auc_score(y_true, predictions[:, i])
        plt.plot(fpr, tpr, lw=2, label=f'ROC Curve for class {class_name} (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC-AUC Curves')
    plt.legend(loc='lower right')
    plt.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, 'roc_auc_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()

# Precision-Recall
def plot_precision_recall_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest precision-recall curves for every class and save them.

    Each legend entry reports the class's average precision (AP), computed by sklearn
    from the labels and scores. Passing recall/precision arrays to
    ``tf.keras.metrics.AUC`` as if they were labels and predictions does not produce
    the area under the PR curve.

    Args:
        true_labels: integer class index per sample.
        predictions: 2-D array of scores, one column per class (indexed ``[:, i]``).
        class_labels: class names, aligned with the columns of ``predictions``.
        save_dir: output directory (created if missing); writes ``precision_recall_curve.png``.

    Classes with no positive test samples have an undefined PR curve and are skipped.
    """
    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    plt.figure(figsize=(12, 8))
    for i, class_name in enumerate(class_labels):
        y_true = binarized_labels[:, i]
        if not np.any(y_true):
            print(f"Skipping PR curve for class {class_name}: no positive samples in the test set.")
            continue
        scores = predictions[:, i]
        precision, recall, _ = precision_recall_curve(y_true, scores)
        ap = average_precision_score(y_true, scores)
        plt.plot(recall, precision, lw=2, label=f'Precision-Recall Curve for class {class_name} (AP = {ap:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves')
    plt.legend(loc='lower left')
    plt.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, 'precision_recall_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()

# Render and save the test-set evaluation figures (title defaults to 'Confusion Matrix').
plot_confusion_matrix(cm, class_labels, save_dir=save_dir)
plot_roc_auc_curves(true_labels=true_labels, predictions=predictions,
                    class_labels=class_labels, save_dir=save_dir)
plot_precision_recall_curves(true_labels=true_labels, predictions=predictions,
                             class_labels=class_labels, save_dir=save_dir)

In [None]:
#lime

In [None]:
# LIME: explain one test image per class by perturbing superpixels and fitting a local
# surrogate model to the classifier's outputs.
explainer = lime_image.LimeImageExplainer()


def model_predict(image_batch):
    """Classifier function handed to LIME: batch of images -> model.predict output."""
    return model.predict(image_batch, verbose=0)


def collect_class_examples(generator, class_labels, per_class):
    """Return ``{class_name: [image, ...]}`` with up to ``per_class`` images per class.

    Batches are read by index (``generator[i]``) in a single pass. The scan therefore
    always starts at the first batch and terminates even if some class never reaches
    ``per_class`` images; Keras directory iterators cycle forever under ``for``.
    Labels are assumed one-hot (class_mode='categorical').
    """
    examples = {label: [] for label in class_labels}
    for batch_index in range(len(generator)):
        batch_images, batch_labels = generator[batch_index]
        for img, onehot in zip(batch_images, batch_labels):
            name = class_labels[int(np.argmax(onehot))]
            if len(examples[name]) < per_class:
                examples[name].append(img)
        if all(len(found) >= per_class for found in examples.values()):
            break
    return examples


class_labels = list(test_generator.class_indices.keys())
num_images = 1
class_images = collect_class_examples(test_generator, class_labels, num_images)

# Own output folder; the shared `save_dir` global is left untouched.
lime_dir = os.path.join("plots", "lime")
os.makedirs(lime_dir, exist_ok=True)

for class_name, examples in class_images.items():
    for idx, img in enumerate(examples):
        explanation = explainer.explain_instance(
            img.astype('double'),
            model_predict,
            top_labels=1,
            num_samples=1000,
        )
        top_label = explanation.top_labels[0]  # class LIME explains: the top-scoring one
        temp, mask = explanation.get_image_and_mask(
            top_label,
            positive_only=False,
            num_features=10,
        )
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        # Test images are already scaled to [0, 1], so `temp` is shown as-is; the
        # `temp / 2 + 0.5` from LIME's tutorial assumes inputs in [-1, 1].
        plt.imshow(mark_boundaries(temp, mask))
        plt.title(f"Explained (pred: {class_labels[top_label]}) | true: {class_name} | Image {idx + 1}")
        plt.axis('off')

        save_path = os.path.join(lime_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()


In [None]:
#gradcam

In [None]:
class_labels = list(test_generator.class_indices.keys())

# Collect up to `num_images` (image, true class index) pairs per class.
num_images = 1
class_images = {label: [] for label in class_labels}

# Read batches by index in one pass. Iterating a Keras directory iterator with `for`
# cycles indefinitely, so a class with too few test images would hang this cell.
for batch_index in range(len(test_generator)):
    images, labels = test_generator[batch_index]
    for img, onehot in zip(images, labels):
        label_index = int(np.argmax(onehot))
        class_name = class_labels[label_index]
        if len(class_images[class_name]) < num_images:
            class_images[class_name].append((img, label_index))
    if all(len(class_images[label]) >= num_images for label in class_labels):
        break

# Own output folder; the shared `save_dir` global is left untouched.
grad_cam_dir = os.path.join("plots", "grad_cam")
os.makedirs(grad_cam_dir, exist_ok=True)

# Last layer whose auto-generated name contains 'conv2d', i.e. the deepest Conv2D.
target_layer_name = [layer.name for layer in model.layers if 'conv2d' in layer.name][-1]

grad_cam = GradCAM()

for class_name, samples in class_images.items():
    for idx, (img, label_index) in enumerate(samples):
        data = ([img], None)  # tf-explain expects (images, labels)
        cam_result = grad_cam.explain(
            data, model, class_index=label_index, layer_name=target_layer_name
        )

        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(img)
        # NOTE(review): tf-explain's explain() already returns a heat-map image; blending it
        # at alpha=0.5 over the original keeps the anatomy visible.
        plt.imshow(cam_result, cmap='jet', alpha=0.5)
        plt.title(f"GradCAM: {class_name} | Image {idx + 1}")
        plt.axis('off')

        save_path = os.path.join(grad_cam_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

In [None]:
# SHAP

In [None]:
batch_size = 10  # number of test images to explain

# Gather images from the first test batches by index. DirectoryIterator's next() never
# raises StopIteration (it cycles) and resumes wherever earlier cells left it, so a
# bounded, indexed scan is both terminating and reproducible.
test_data = []
for batch_index in range(len(test_generator)):
    if len(test_data) >= batch_size:
        break
    images, _ = test_generator[batch_index]
    test_data.extend(images)

if not test_data:
    raise ValueError("No data available from the generator. Please check the generator configuration.")
if len(test_data) < batch_size:
    print("Data generator is exhausted. Collected:", len(test_data), "samples.")
test_data = np.array(test_data[:batch_size])

print("Test data shape:", test_data.shape)

# Partition explainer with an inpainting masker shaped like a single image.
masker = shap.maskers.Image("inpaint_telea", test_data[0].shape)
shap_explainer = shap.Explainer(model, masker)

print("Generating SHAP values...")
shap_values = shap_explainer(test_data)
print("SHAP values generated successfully!")

# Sum absolute attributions over the last axis (one slice per model output) to get
# one overall importance per pixel/channel.
overall_shap_values = np.sum(np.abs(shap_values.values), axis=-1)

flattened_shap_values = overall_shap_values.reshape(test_data.shape[0], -1)
flattened_test_data = test_data.reshape(test_data.shape[0], -1)

print("Plotting SHAP summary plot...")
plt.figure(figsize=(10, 6))
shap.summary_plot(
    flattened_shap_values,
    flattened_test_data,
    plot_type="dot",
    show=False
)
plt.title("SHAP Summary Plot (Overall)")

# Own output folder; the shared `save_dir` global is left untouched.
shap_dir = os.path.join("plots", "shap_summary")
os.makedirs(shap_dir, exist_ok=True)

save_path = os.path.join(shap_dir, "shap_summary_overall.png")
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()

print(f"Overall SHAP summary plot saved at {save_path}")

In [None]:
# Example test predictions

In [None]:
# Display up to the first 20 test images with their true and predicted labels.
# Capped at the test-set size so a small test folder does not raise IndexError.
num_to_show = min(20, len(test_generator.filenames))
for i in range(num_to_show):
    fig = plt.figure()
    plt.imshow(image.load_img(os.path.join(test_dir, test_generator.filenames[i])))
    plt.title(f"True Label: {class_labels[true_labels[i]]}, Predicted Label: {class_labels[np.argmax(predictions[i])]}")
    plt.axis('off')
    plt.show()
    plt.close(fig)  # free the figure; 20 open figures otherwise accumulate

In [None]:
import gc

# Free the trained model. First drop the notebook's reference so the model is actually
# unreachable, then reset Keras' global state and run the garbage collector.
# The guard keeps the cell re-runnable (a second bare `del model` raises NameError).
if 'model' in globals():
    del model
K.clear_session()
gc.collect()

In [None]:
from numba import cuda

# Select GPU 0 through numba and close its CUDA context(s).
# NOTE(review): TensorFlow in this kernel may be unable to use the GPU afterwards —
# restart the kernel before training again.
gpu_index = 0
cuda.select_device(gpu_index)
cuda.close()