In [None]:
# module1 Libraries
import os
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow.keras.losses import Loss
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import (Input, Attention, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Activation, Dense, GRU, Multiply, MaxPooling2D, BatchNormalization, Lambda,Bidirectional,
                                     ReLU, Add, Dropout, LSTM, MultiHeadAttention, Reshape,Flatten,LayerNormalization)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, roc_curve, precision_recall_curve, classification_report
)
from tensorflow.keras.initializers import he_normal
import tensorflow as tf
import shap
from lime import lime_image
from skimage.segmentation import mark_boundaries
from tf_explain.core.grad_cam import GradCAM
import json

In [None]:
# module2 Data Loading and Preprocessing
# The dataset root can be overridden through the DATASET_ROOT environment
# variable so the notebook is not tied to a single machine; when the variable
# is unset or empty, the original hard-coded location is used.
root_dir = os.environ.get("DATASET_ROOT") or r'D:\A\data\dataset4'
train_dir = os.path.join(root_dir, "train")
test_dir = os.path.join(root_dir, "test")

# module3 Image preprocessing and augmentation
def preprocess_image(image):
    """Center-crop an image to a square, resize it to 224x224 and divide by 255.

    Args:
        image: Array whose shape unpacks into exactly three values, read as
            (height, width, channels).

    Returns:
        The square center crop, resized with ``cv2.resize`` to 224x224 and
        divided by 255.0.
    """
    height, width, _ = image.shape
    mid_x = width // 2
    mid_y = height // 2

    # Half the side of the square crop: bounded by the shorter half-dimension.
    half_side = min(mid_x, mid_y)
    top, bottom = mid_y - half_side, mid_y + half_side
    left, right = mid_x - half_side, mid_x + half_side

    square = image[top:bottom, left:right]
    return cv2.resize(square, (224, 224)) / 255.0

# Training-time generator: random horizontal flips and rotations of up to
# 120 degrees, preprocess_image applied to every image, and 10% of the data
# reserved for the 'validation' subset.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    rotation_range=120,
    horizontal_flip=True,
    validation_split=0.1,
)

# Test-time generator: only rescales pixel values by 1/255.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

In [None]:
# module4 Load and preprocess the training and testing images
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

# Validation images are read from the same directory with the same split
# fraction, but through a generator WITHOUT random flips/rotations. Building
# the validation subset from train_datagen would augment the validation images
# too, making val_loss (monitored by the checkpoint, early-stopping and
# LR-scheduling callbacks) noisy and not comparable between epochs.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_image,
    validation_split=0.1  # must stay equal to train_datagen's validation_split
)

validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False  # keep file order so predictions line up with .classes
)
print("Class indices mapping:", train_generator.class_indices)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, Reshape, Dense, LSTM, MultiHeadAttention, LayerNormalization, Input, MaxPooling2D
from tensorflow.keras.initializers import he_normal
from tensorflow.keras.models import Model

def identity_block(X, kernel_size, filters, stage, block):
    """Bottleneck residual block whose shortcut is the unmodified input.

    The main path is 1x1 conv -> BN -> ReLU -> (kernel_size) conv -> BN -> ReLU
    -> 1x1 conv -> BN, all with stride 1; the input is then added back and a
    final ReLU is applied.

    Args:
        X: Input tensor (batch normalization is applied over axis 3).
        kernel_size: Kernel size of the middle convolution.
        filters: Sequence of exactly three filter counts, one per convolution.
        stage: Value converted with ``str`` and used in the layer names.
        block: String label used in the layer names.

    Returns:
        The output tensor of the block.
    """
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    n_first, n_middle, n_last = filters
    skip = X

    # One row per convolution: (filters, kernel, padding, name suffix, ReLU after BN?)
    main_path = [
        (n_first, (1, 1), 'valid', '2a', True),
        (n_middle, kernel_size, 'same', '2b', True),
        (n_last, (1, 1), 'valid', '2c', False),  # ReLU comes after the addition
    ]
    for n_filters, kernel, padding, suffix, apply_relu in main_path:
        X = Conv2D(n_filters, kernel, strides=(1, 1), padding=padding,
                   name=conv_prefix + suffix,
                   kernel_initializer=he_normal(seed=0))(X)
        X = BatchNormalization(axis=3, name=bn_prefix + suffix)(X)
        if apply_relu:
            X = Activation('relu')(X)

    merged = Add()([X, skip])
    return Activation('relu')(merged)

def convolutional_block(X, kernel_size, filters, stage, block, strides=(2,2)):
    """Bottleneck residual block with a projection (1x1 conv + BN) shortcut.

    The main path is 1x1 conv (with ``strides``) -> BN -> ReLU -> (kernel_size)
    conv -> BN -> ReLU -> 1x1 conv -> BN. The shortcut applies a 1x1 conv with
    the same ``strides`` and the last filter count, followed by BN. Both paths
    are added and passed through a final ReLU.

    Args:
        X: Input tensor (batch normalization is applied over axis 3).
        kernel_size: Kernel size of the middle convolution.
        filters: Sequence of exactly three filter counts, one per convolution.
        stage: Value converted with ``str`` and used in the layer names.
        block: String label used in the layer names.
        strides: Strides of the first main-path conv and of the shortcut conv.

    Returns:
        The output tensor of the block.
    """
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    n_first, n_middle, n_last = filters
    shortcut = X

    # One row per convolution:
    # (filters, kernel, strides, padding, name suffix, ReLU after BN?)
    main_path = [
        (n_first, (1, 1), strides, 'valid', '2a', True),
        (n_middle, kernel_size, (1, 1), 'same', '2b', True),
        (n_last, (1, 1), (1, 1), 'valid', '2c', False),  # ReLU comes after the addition
    ]
    for n_filters, kernel, step, padding, suffix, apply_relu in main_path:
        X = Conv2D(n_filters, kernel, strides=step, padding=padding,
                   name=conv_prefix + suffix,
                   kernel_initializer=he_normal(seed=0))(X)
        X = BatchNormalization(axis=3, name=bn_prefix + suffix)(X)
        if apply_relu:
            X = Activation('relu')(X)

    # Projection shortcut so both branches have matching shapes for the addition.
    shortcut = Conv2D(n_last, (1, 1), strides=strides, padding='valid',
                      name=conv_prefix + '1',
                      kernel_initializer=he_normal(seed=0))(shortcut)
    shortcut = BatchNormalization(axis=3, name=bn_prefix + '1')(shortcut)

    merged = Add()([X, shortcut])
    return Activation('relu')(merged)

def ResNet50_Light_MHA(input_shape=(224,224,3), classes=4, alpha=0.5,
                       num_heads=2, key_dim=64, lstm_units=128):
    """Build a width-reduced residual CNN followed by self-attention and an LSTM.

    Pipeline: 7x7 stem conv + max pooling -> four residual stages (one
    convolutional block and two identity blocks each) -> 1x1 conv to 512
    channels -> flatten the spatial grid into a sequence -> multi-head
    self-attention with residual connection and layer normalization -> LSTM ->
    softmax classifier.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        classes: Number of units of the final softmax layer.
        alpha: Multiplier applied to the filter counts of every residual stage
            (each product is truncated with ``int``).
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        lstm_units: Number of units of the LSTM that summarizes the sequence.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    X_input = Input(input_shape)

    # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
    X = Conv2D(64, (7,7), strides=(2,2), padding='same', name='conv1',
               kernel_initializer=he_normal(seed=0))(X_input)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3,3), strides=(2,2), padding='same')(X)

    filters_list = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]
    strides_list = [(1,1), (2,2), (2,2), (2,2)]

    # Residual stages 2-5: one projection block and two identity blocks each,
    # with every filter count scaled by alpha.
    for stage, (filters, strides) in enumerate(zip(filters_list, strides_list), start=2):
        scaled_filters = [int(f*alpha) for f in filters]
        X = convolutional_block(X, (3,3), scaled_filters, stage, block='a', strides=strides)
        X = identity_block(X, (3,3), scaled_filters, stage, block='b')
        X = identity_block(X, (3,3), scaled_filters, stage, block='c')

    # 1x1 convolution that maps the backbone output to a fixed 512 channels.
    X = Conv2D(512, (1,1), strides=(1,1), padding='same',
               name='reduce_channels', kernel_initializer=he_normal(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_reduce')(X)
    X = Activation('relu')(X)

    # Flatten the spatial grid into a sequence with one step per location.
    shape = tf.keras.backend.int_shape(X)
    seq_len = shape[1] * shape[2]
    X_seq = Reshape((seq_len, shape[3]))(X)

    # Multi-head self-attention with a residual connection and layer norm.
    mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
    X_attn = mha(query=X_seq, value=X_seq, key=X_seq)
    X_attn = Add()([X_seq, X_attn])
    X_attn = LayerNormalization()(X_attn)

    # LSTM over the attended sequence; only its final output is kept.
    lstm_out = LSTM(lstm_units)(X_attn)

    outputs = Dense(classes, activation='softmax', name='fc')(lstm_out)
    model = Model(inputs=X_input, outputs=outputs)

    return model

# Build the network with the function's default configuration spelled out.
# NOTE(review): classes=4 has to equal the number of class folders found by
# the data generators - confirm against the printed class-index mapping.
model = ResNet50_Light_MHA(input_shape=(224, 224, 3), classes=4, alpha=0.5)

# Display model summary
model.summary()

In [None]:
# module6 The loss function

class CauchyLoss(tf.keras.losses.Loss):
    """Cauchy loss: ``log(1 + (y_true - y_pred)^2 / sigma^2)`` averaged per sample.

    Args:
        sigma: Scale parameter; must be non-zero. Larger values make the loss
            behave more like a squared error for small residuals.
        name: Name of the loss instance.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (for example
            ``reduction``) so that ``from_config(get_config())`` round-trips.

    Raises:
        ValueError: If ``sigma`` is zero (the loss would divide by zero).
    """

    def __init__(self, sigma=1.0, name="cauchy_loss", **kwargs):
        super().__init__(name=name, **kwargs)
        if sigma == 0:
            raise ValueError("sigma must be non-zero")
        # Keep the raw sigma for serialization; precompute sigma^2 for call().
        self.sigma = sigma
        self.sigma_sq = sigma ** 2

    def call(self, y_true, y_pred):
        """Return the per-sample loss (mean over the last axis)."""
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        error_sq = tf.square(y_true - y_pred)  # (y_true - y_pred)^2
        loss = tf.math.log(1.0 + error_sq / self.sigma_sq)
        # Reduce only over the class axis; the base class applies the batch
        # reduction (and any sample weights) afterwards.
        return tf.reduce_mean(loss, axis=-1)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this loss."""
        config = super().get_config()
        # Store sigma itself, not sigma^2: the constructor squares it again.
        config["sigma"] = self.sigma
        return config

# Loss instance used for training (scale parameter sigma = 1.5).
cauchy_loss = CauchyLoss(sigma=1.5)

#7 Compile the model
model.compile(
    loss=cauchy_loss,
    optimizer=Adam(learning_rate=0.0002),
    metrics=['accuracy'],
)

In [None]:
# module8 Callbacks
# All three callbacks watch the validation loss and treat lower as better.

# Keep only the best model seen so far on disk.
checkpoint = ModelCheckpoint(
    'model_best.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)

# Multiply the learning rate by 0.1 after 5 stagnant epochs, never below 2e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=0.00002,
    verbose=1,
)


In [None]:
# module9 Train the model
# One epoch is a full pass over each generator (steps = number of batches).
history = model.fit(
    train_generator,
    epochs=45,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=[checkpoint, early_stopping, reduce_lr],
    max_queue_size=20,
    workers=6,
)


In [None]:
# module10 Load the best model
# The custom loss class has to be supplied so the saved model can be rebuilt.
model = load_model(
    'model_best.h5',
    custom_objects={'CauchyLoss': CauchyLoss},
)

In [None]:
# Model Evaluation

In [None]:
save_dir = 'plots'
os.makedirs(save_dir, exist_ok=True)


def plot_history_metric(history, metric, label, filename, save_dir):
    """Plot the training and validation curves of one metric and save the figure.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain ``metric`` and ``'val_' + metric``.
        metric: Key of the training metric, e.g. ``'accuracy'`` or ``'loss'``.
        label: Human-readable metric name used in the title, legend and y-axis.
        filename: File name of the saved figure inside ``save_dir``.
        save_dir: Existing directory the figure is written to.
    """
    plt.figure(figsize=(12, 5))
    plt.plot(history.history[metric], label=f'Training {label}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    plt.title(f'Training and Validation {label}')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, filename), dpi=600, bbox_inches='tight')
    plt.show()


# Accuracy curves
plot_history_metric(history, 'accuracy', 'Accuracy', 'accuracy_plot.png', save_dir)
# Loss curves
plot_history_metric(history, 'loss', 'Loss', 'loss_plot.png', save_dir)

In [None]:
# Prediction and Inference

In [None]:
# Make predictions on the test data
# Class names in index order, and the integer class of every test file.
class_labels = list(test_generator.class_indices)
true_labels = test_generator.classes

# Class probabilities per test image, and the most probable class for each.
predictions = model.predict(test_generator)
predicted_classes = predictions.argmax(axis=1)

In [None]:
# Calculate confusion matrix
# Passing labels explicitly guarantees one row/column per entry of class_labels,
# even when a class is absent from both the true and the predicted labels.
cm = confusion_matrix(true_labels, predicted_classes, labels=np.arange(len(class_labels)))

def evaluate_model(true_labels, predicted_classes, class_labels, save_dir):
    """Print per-class and averaged classification metrics and save the report.

    The confusion matrix is computed here from the arguments, so the function
    does not depend on any module-level variable.

    Args:
        true_labels: Integer class index of every sample.
        predicted_classes: Predicted integer class index of every sample.
        class_labels: Class names; position ``i`` names class index ``i``.
        save_dir: Directory for ``classification_report.json`` (created if
            missing).

    Returns:
        None. Results are printed and the classification report is written to
        ``<save_dir>/classification_report.json``.
    """
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Fixing the label set keeps every per-class structure aligned with
    # class_labels even if some class never occurs in the data.
    label_ids = np.arange(len(class_labels))
    cm = confusion_matrix(true_labels, predicted_classes, labels=label_ids)
    total = np.sum(cm)

    # Micro average accuracy
    micro_accuracy = accuracy_score(true_labels, predicted_classes)

    # Generate classification report
    classification_report_dict = classification_report(
        true_labels, predicted_classes, labels=label_ids,
        target_names=class_labels, output_dict=True, zero_division=0
    )

    # Calculate specificity and accuracy per class (one-vs-rest counts)
    specificity_per_class = []
    accuracy_per_class = []
    total_tn, total_fp = 0, 0  # For micro-average specificity
    for i in range(len(class_labels)):
        row_total = np.sum(cm[i, :])  # samples whose true class is i
        col_total = np.sum(cm[:, i])  # samples predicted as class i
        true_negative = total - (row_total + col_total - cm[i, i])
        false_positive = col_total - cm[i, i]
        negatives = true_negative + false_positive
        specificity = true_negative / negatives if negatives > 0 else 0.0
        specificity_per_class.append(specificity)

        # Fraction of class-i samples predicted correctly; numerically this is
        # the same quantity as the class recall.
        accuracy = cm[i, i] / row_total if row_total > 0 else 0.0
        accuracy_per_class.append(accuracy)

        total_tn += true_negative
        total_fp += false_positive

    # Calculate macro and micro average specificity
    macro_specificity = np.mean(specificity_per_class)
    micro_specificity = total_tn / (total_tn + total_fp) if (total_tn + total_fp) > 0 else 0.0

    # Save classification report as JSON
    json_save_path = os.path.join(save_dir, "classification_report.json")
    with open(json_save_path, "w") as f:
        json.dump(classification_report_dict, f, indent=4)
    print(f"Classification Report saved to {json_save_path}")

    # Print classification report
    print("Classification Report:")
    print(f"{'Class':<20} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Specificity':<12} {'Accuracy':<12} {'Support':<12}")
    for idx, cls in enumerate(class_labels):
        metrics = classification_report_dict[cls]
        print(f"{cls:<20} {metrics['precision']:<12.2f} {metrics['recall']:<12.2f} {metrics['f1-score']:<12.2f} {specificity_per_class[idx]:<12.2f} {accuracy_per_class[idx]:<12.2f} {metrics['support']:<12}")

    # Calculate macro average accuracy
    macro_accuracy = np.mean(accuracy_per_class)

    # Print micro average accuracy
    print(f"\nMicro Average Accuracy: {micro_accuracy:.2f}")

    # Print macro average accuracy
    print(f"Macro Average Accuracy: {macro_accuracy:.2f}")

    # Macro and Micro average metrics
    macro_precision = precision_score(true_labels, predicted_classes, average='macro', zero_division=0)
    macro_recall = recall_score(true_labels, predicted_classes, average='macro', zero_division=0)
    macro_f1 = f1_score(true_labels, predicted_classes, average='macro', zero_division=0)

    micro_precision = precision_score(true_labels, predicted_classes, average='micro', zero_division=0)
    micro_recall = recall_score(true_labels, predicted_classes, average='micro', zero_division=0)
    micro_f1 = f1_score(true_labels, predicted_classes, average='micro', zero_division=0)

    print("\nMacro Average Metrics:")
    print(f"Precision: {macro_precision:.2f}, Recall: {macro_recall:.2f}, F1-Score: {macro_f1:.2f}, Specificity: {macro_specificity:.2f}")

    print("\nMicro Average Metrics:")
    print(f"Precision: {micro_precision:.2f}, Recall: {micro_recall:.2f}, F1-Score: {micro_f1:.2f}, Specificity: {micro_specificity:.2f}")

# Print the metric tables and write the JSON report for the test predictions.
evaluate_model(
    true_labels=true_labels,
    predicted_classes=predicted_classes,
    class_labels=class_labels,
    save_dir=save_dir,
)


In [None]:
#Confusion Matrix
def plot_confusion_matrix(cm, class_labels, title='Confusion Matrix', cmap=plt.cm.Blues, save_dir=None):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: Square confusion matrix (rows: true class, columns: predicted).
        class_labels: Class names used as tick labels, in matrix order.
        title: Figure title.
        cmap: Matplotlib colormap for the heat map.
        save_dir: If truthy, the directory ``confusion_matrix.png`` is saved to
            (created if missing); otherwise the figure is only shown.
    """
    cm = np.asarray(cm)  # accept nested lists as well as arrays

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_labels))
    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels)

    # Switch the annotation colour on dark cells so the counts stay readable.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, f"{cm[i, j]}", horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    if save_dir:
        # Same directory handling as the other plotting helpers.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'), dpi=600, bbox_inches='tight')
    plt.show()

# ROC-AUC 
def plot_roc_auc_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest ROC curves, one per class, and save the figure.

    Args:
        true_labels: Integer class index of every sample.
        predictions: 2-D array of scores indexed as ``[sample, class]``.
        class_labels: Class names; position ``i`` names class index ``i``.
        save_dir: Directory ``roc_auc_curve.png`` is written to (created if
            it does not exist).
    """
    n_classes = len(class_labels)
    one_hot = tf.keras.utils.to_categorical(true_labels, num_classes=n_classes)

    plt.figure(figsize=(12, 8))
    for class_idx in range(n_classes):
        is_class = one_hot[:, class_idx]
        scores = predictions[:, class_idx]
        fpr, tpr, _ = roc_curve(is_class, scores)
        roc_auc = roc_auc_score(is_class, scores)
        plt.plot(fpr, tpr, lw=2, label=f'ROC Curve for class {class_labels[class_idx]} (area = {roc_auc:.2f})')

    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC-AUC Curves')
    plt.legend(loc='lower right')
    plt.grid(True)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    output_path = os.path.join(save_dir, 'roc_auc_curve.png')
    plt.savefig(output_path, dpi=600, bbox_inches='tight')
    plt.show()

# Precision-Recall
def plot_precision_recall_curves(true_labels, predictions, class_labels, save_dir):
    """Plot one-vs-rest precision-recall curves, one per class, and save them.

    The area reported in the legend is the trapezoidal area under each
    precision-recall curve.

    Args:
        true_labels: Integer class index of every sample.
        predictions: 2-D array of scores indexed as ``[sample, class]``.
        class_labels: Class names; position ``i`` names class index ``i``.
        save_dir: Directory ``precision_recall_curve.png`` is written to
            (created if missing).
    """
    # Local import: `auc` is not among the names imported at the top of the
    # notebook.
    from sklearn.metrics import auc

    binarized_labels = tf.keras.utils.to_categorical(true_labels, num_classes=len(class_labels))
    plt.figure(figsize=(12, 8))
    for i in range(len(class_labels)):
        precision, recall, _ = precision_recall_curve(binarized_labels[:, i], predictions[:, i])
        # Integrate precision over recall. The previous code passed the curve
        # points to a Keras AUC metric, which interprets its arguments as
        # (labels, scores) and therefore did not measure this area.
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, lw=2, label=f'Precision-Recall Curve for class {class_labels[i]} (area = {pr_auc:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves')
    plt.legend(loc='lower left')
    plt.grid(True)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, 'precision_recall_curve.png'), dpi=600, bbox_inches='tight')
    plt.show()

# Draw and save the three evaluation figures for the test predictions.
plot_confusion_matrix(
    cm=cm,
    class_labels=class_labels,
    title='Confusion Matrix',
    save_dir=save_dir,
)
plot_roc_auc_curves(
    true_labels=true_labels,
    predictions=predictions,
    class_labels=class_labels,
    save_dir=save_dir,
)
plot_precision_recall_curves(
    true_labels=true_labels,
    predictions=predictions,
    class_labels=class_labels,
    save_dir=save_dir,
)

In [None]:
# LIME: explain the model's prediction for sample test images of each class

In [None]:
# LIME explainer for image inputs.
explainer = lime_image.LimeImageExplainer()

def model_predict(image_batch):
    """Return the model's predictions for a batch of images.

    Passed to the LIME explainer as its classifier function.
    """
    batch_predictions = model.predict(image_batch)
    return batch_predictions

class_labels = list(test_generator.class_indices.keys())

num_images = 1
class_images = {label: [] for label in class_labels}

# Collect the first `num_images` test images of every class. The generator is
# indexed batch by batch instead of being iterated directly: direct iteration
# never ends, so a class without test images would make the loop hang.
for batch_index in range(len(test_generator)):
    images, labels = test_generator[batch_index]
    for i, label in enumerate(labels):
        label_index = np.argmax(label)
        class_name = class_labels[label_index]
        if len(class_images[class_name]) < num_images:
            class_images[class_name].append(images[i])
    if all(len(class_images[label]) >= num_images for label in class_labels):
        break

save_dir = "plots/lime"
os.makedirs(save_dir, exist_ok=True)

for class_name, images in class_images.items():
    for idx, img in enumerate(images):
        explanation = explainer.explain_instance(
            img.astype('double'),
            model_predict,
            top_labels=1,
            num_samples=1000
        )
        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=False,
            num_features=10
        )
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        # test_datagen rescales images to [0, 1], so the explanation image is
        # shown as is. The `temp / 2 + 0.5` mapping is meant for inputs scaled
        # to [-1, 1] and would wash out the colours here.
        plt.imshow(mark_boundaries(temp, mask))
        plt.title(f"Explained: {class_name} | Image {idx + 1}")
        plt.axis('off')

        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()  # release the figure before the next one is created


In [None]:
# Grad-CAM: class-activation heat maps for sample test images of each class

In [None]:
class_labels = list(test_generator.class_indices.keys())

num_images = 1
class_images = {label: [] for label in class_labels}

# Collect the first `num_images` (image, class index) pairs of every class.
# The generator is indexed batch by batch instead of being iterated directly:
# direct iteration never ends, so a class without test images would make the
# loop hang.
for batch_index in range(len(test_generator)):
    images, labels = test_generator[batch_index]
    for i, label in enumerate(labels):
        label_index = np.argmax(label)
        class_name = class_labels[label_index]
        if len(class_images[class_name]) < num_images:
            class_images[class_name].append((images[i], label_index))
    if all(len(class_images[label]) >= num_images for label in class_labels):
        break

save_dir = "plots/grad_cam"
os.makedirs(save_dir, exist_ok=True)

# Pick the last convolutional layer by type. Every Conv2D layer of this model
# has an explicit name ('conv1', 'res2a_branch2a', ..., 'reduce_channels'), so
# searching for the default 'conv2d' name prefix finds nothing and indexing the
# empty result raises IndexError.
conv_layer_names = [layer.name for layer in model.layers if isinstance(layer, Conv2D)]
if not conv_layer_names:
    raise ValueError("The model has no Conv2D layer to compute Grad-CAM on")
target_layer_name = conv_layer_names[-1]

grad_cam = GradCAM()

for class_name, images in class_images.items():
    for idx, (img, label_index) in enumerate(images):
        # Batch of one image as an array; no labels are needed.
        data = (np.expand_dims(img, axis=0), None)
        cam_result = grad_cam.explain(
            data, model, class_index=label_index, layer_name=target_layer_name
        )

        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title(f"Original: {class_name}")
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(img)
        # NOTE(review): cam_result may already be an RGB heat map blended with
        # the image, in which case cmap has no effect - confirm with tf-explain.
        plt.imshow(cam_result, cmap='jet', alpha=0.5)
        plt.title(f"GradCAM: {class_name} | Image {idx + 1}")
        plt.axis('off')
        save_path = os.path.join(save_dir, f"{class_name}_image_{idx + 1}.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()  # release the figure before the next one is created