In [7]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from my_models import my_ResNet_CNN  # Assuming we're using resnet18, but you can use any model

# 1. Define the dataset and dataloader

# Evaluation preprocessing: resize every image to 128x128 and convert it to a
# tensor. No augmentation (rotation / colour jitter) is applied here.
# NOTE(review): this must match the preprocessing used at training time
# (e.g. any Normalize step) — confirm against the training script.
data_transforms = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Labelled images laid out one sub-folder per class (ImageFolder convention);
# ImageFolder derives the class indices from the sub-folder names.
root_dir = 'old/Red_Cell_Morphology_clean/SMA_cells/labelled'  # replace with your path
datasets = ImageFolder(root_dir, transform=data_transforms)

# One image per batch, in a fixed (unshuffled) order.
dataloader = DataLoader(
    datasets,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

# Show the class-name -> index mapping ImageFolder assigned.
print(datasets.class_to_idx)

# Class-name -> label-index convention used when scoring predictions.
# It has to agree with the indices ImageFolder assigned to the class
# sub-folders (datasets.class_to_idx); otherwise positives and negatives are
# swapped and every metric is silently wrong, so fail loudly on a mismatch.
label_map = {"SMA": 1, "Non-SMA": 0}
for class_name, class_idx in label_map.items():
    if datasets.class_to_idx.get(class_name) != class_idx:
        raise ValueError(
            f"label_map[{class_name!r}] = {class_idx} does not match the "
            f"dataset's class_to_idx {datasets.class_to_idx}"
        )

# 2. Load the model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = my_ResNet_CNN()
# map_location remaps the stored tensors onto `device`. Without it, a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine, even
# though the line above deliberately falls back to the CPU.
model.load_state_dict(
    torch.load('Experiments_log/model_weights_42.pth', map_location=device)
)
model = model.to(device)
# Inference mode for layers such as dropout / batch-norm.
model.eval()

# 3. Classify the images and calculate metrics

correct = 0
total = 0
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0

# Index of the positive ("SMA") class; sensitivity is measured on it.
positive_label = label_map["SMA"]

with torch.no_grad():
    for imgs, labels in dataloader:
        # The DataLoader already yields a leading batch dimension; one more
        # leading dimension is added here, as in the original notebook.
        # NOTE(review): presumably my_ResNet_CNN(..., mode='test') expects this
        # extra dimension — confirm against my_models.
        imgs = imgs.unsqueeze(0).to(device)
        labels = labels.to(device)
        outputs = model(imgs, mode='test')

        # Rounding thresholds each score at 0.5. This assumes the model emits
        # one probability in [0, 1] per image — TODO confirm; raw logits
        # would need a sigmoid (or a threshold of 0) instead.
        predicted = outputs.round().int().flatten()
        if predicted.numel() != labels.numel():
            raise ValueError(
                f"model produced {predicted.numel()} predictions for "
                f"{labels.numel()} labels (output shape {tuple(outputs.shape)})"
            )

        # Vectorised confusion-matrix update, counted per sample so it stays
        # correct for any batch size.
        hits = predicted == labels
        positives = labels == positive_label
        true_positives += (hits & positives).sum().item()
        false_negatives += (~hits & positives).sum().item()
        true_negatives += (hits & ~positives).sum().item()
        false_positives += (~hits & ~positives).sum().item()
        correct += hits.sum().item()
        total += labels.numel()

# A metric whose denominator is zero (empty dataset, or a class with no
# samples) is undefined; report NaN instead of raising ZeroDivisionError.
accuracy = correct / total if total else float("nan")
sensitivity = (
    true_positives / (true_positives + false_negatives)
    if true_positives + false_negatives
    else float("nan")
)
specificity = (
    true_negatives / (true_negatives + false_positives)
    if true_negatives + false_positives
    else float("nan")
)

print(f"Accuracy: {accuracy:.2f}")
print(f"Sensitivity: {sensitivity:.2f}")
print(f"Specificity: {specificity:.2f}")


{'Non-SMA': 0, 'SMA': 1}
Accuracy: 0.46
Sensitivity: 0.01
Specificity: 0.99
