In [None]:
import os
import time
import copy
import tqdm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F

from torchvision.models import resnet50, densenet121, mobilenet_v2
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torch.optim import AdamW

from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, auc, f1_score

# Pick the first CUDA GPU when the runtime reports one; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# When testing the model, you must set it to include only the original image
transform = transforms.Compose([
        transforms.Resize([224,224]),  
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
    ])

batch_size = 512

In [None]:
# Set up the file directory containing the mtf translation image
test_transfer = ImageFolder(root='../MTF_spl/test', transform=transform) 
test_loader = torch.utils.data.DataLoader(test_transfer, batch_size=batch_size, shuffle=True, num_workers=4)

class_names = test_transfer.classes

In [None]:
# Turn the model test 100 times to represent the median value of each result.
def evaluate(model, test_loader, num_evaluations=100):
    model.eval()  
    
    # List to store each evaluation result
    test_losses = []
    test_accuracies = []
    confusion_matrices = []
    roc_aucs = []
    f1_scores = []

    for _ in range(num_evaluations):
        test_loss, test_accuracy, cm, roc_auc, f1 = single_evaluation(model, test_loader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        confusion_matrices.append(cm)
        roc_aucs.append(roc_auc)
        f1_scores.append(f1)

    # Calculate the median values
    median_test_loss = np.median(test_losses)
    median_test_accuracy = np.median(test_accuracies)
    median_confusion_matrix = np.median(confusion_matrices, axis=0)
    median_roc_auc = np.median(roc_aucs)
    median_f1_score = np.median(f1_scores)

    return median_test_loss, median_test_accuracy, median_confusion_matrix, median_roc_auc, median_f1_score

def single_evaluation(model, test_loader):
    test_loss = 0 
    correct = 0   
    all_predictions = []
    all_targets = []
    
    with torch.no_grad(): 
        for data, target in test_loader:  
            data, target = data.to(device), target.to(device)  
            output = model(data) 
            
            test_loss += F.cross_entropy(output, target, reduction='sum').item() 
            
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item() 
            
            all_predictions.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
   
    test_loss /= len(test_loader.dataset) 
    test_accuracy = 100. * correct / len(test_loader.dataset) 
    
    cm = confusion_matrix(all_targets, all_predictions)
    
    f1 = f1_score(all_targets, all_predictions, average='weighted')
    
    fpr, tpr, _ = roc_curve(all_targets, all_predictions)
    roc_auc = auc(fpr, tpr)

    return test_loss, test_accuracy, cm, roc_auc, f1

In [None]:
ResNet50_epoch_100=torch.load('ResNet50_epoch_100.pt') 
ResNet50_epoch_100.eval()  
test_loss, test_accuracy, cm, roc_auc, f1  = evaluate(ResNet50_epoch_100, test_loader)

tn, fp, fn, tp = cm.ravel()
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)

print('ResNet50_epoch_100 test acc:  ', test_accuracy)
print('ResNet50_epoch_100 test loss: ', test_loss)

print("Confusion Matrix:")
print(f"Specificity: {specificity}")
print(f"Sensitivity: {sensitivity}")
print(f"AUC: {roc_auc}")
print(f"f1 score : {f1}")

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="0.1f", cmap="Blues", 
            xticklabels=class_names, yticklabels=class_names, 
            annot_kws={"size":15})
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('ResNet50_epoch_100 Confusion Matrix')
plt.show()