In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
from my_models import my_ResNet_CNN
from my_models_simple import my_ResNet_CNN_simple

from collections import defaultdict
from sklearn.metrics import f1_score, roc_curve, auc
from dataloader import CustomImageDataset
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix


  from .autonotebook import tqdm as notebook_tqdm


In [6]:

# MILSMA 4-4
model_path = 'Experiments_log/model_weights_105.pth'

# MILSMA 12-3
#model_path = 'Experiments_log/20230811_024433/model_weights_4.pth'

# Choose the device first so the checkpoint can be remapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = my_ResNet_CNN()
# map_location makes a checkpoint that was saved from a GPU run loadable on a
# CPU-only machine; without it torch.load tries to restore the tensors onto the
# device they were saved from and fails when CUDA is unavailable.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm

# Map class to 0/1
class_to_idx = {'sma': 1, 'non-sma': 0}  # Dictionary to assign 1 to 'sma' and 0 to 'non-sma' samples

# Preprocessing applied to every image: resize to 128x128, then convert to a tensor.
transform_augm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

new_path = 'edof_new_sma_2023_cells_cleaned'

# Create dataset
dataset = CustomImageDataset(
    root_dir=new_path,
    class_to_idx=class_to_idx,
    min_number_images=20,
    transform=transform_augm,
)

test_loader = DataLoader(dataset, batch_size=1, shuffle=False)  # Adjust batch size as needed

# Per-sample metadata; must provide the 'FASt-Mal-Code' and 'Diagnosis' columns used below.
df_samples = pd.read_csv('new_data1.csv')

# One [image_path, true_label, predicted_probability] record per evaluated item.
data = []

with torch.no_grad():
    for imgs, labels, img_paths in test_loader:
        # The loader yields an iterable of tensors per batch, so each one is moved individually.
        imgs = [tensor.to(device) for tensor in imgs]
        labels = labels.to(device)
        outputs = model(imgs, mode='test')
        probs = outputs.cpu().numpy()
        for img_path, label, prob in zip(img_paths, labels, probs):
            # NOTE(review): prob[0] is treated as the positive-class ('sma') probability — confirm against the model.
            data.append([img_path, label.item(), prob[0]])

img_class_df = pd.DataFrame(data, columns=['image_path', 'true_label', 'predicted_probability'])

# Look up each record's diagnosis by matching image_path against the 'FASt-Mal-Code' column;
# records without a match get NaN.
img_class_df['Diagnosis'] = img_class_df['image_path'].map(df_samples.set_index('FASt-Mal-Code')['Diagnosis'])

# Binarise the probability by rounding to the nearest integer.
img_class_df['predicted label'] = np.round(img_class_df['predicted_probability']).astype(int)

# Alias used by the reporting code further down.
combined_df = img_class_df


def calc_sens_spec_acc(df, pos_label, neg_label, col='true_label'):
    """Compute sensitivity, specificity and accuracy of the binary predictions in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ground-truth column ``col`` and a ``'predicted label'`` column.
    pos_label, neg_label
        Values that denote the positive and the negative class in both columns.
        Rows whose truth or prediction is neither value are ignored.
    col : str, default 'true_label'
        Name of the ground-truth column.

    Returns
    -------
    tuple of float
        ``(sensitivity, specificity, accuracy)``. A metric whose denominator is
        zero (e.g. specificity when ``df`` holds no negative samples) is returned
        as ``nan`` instead of raising ``ZeroDivisionError``.
    """
    actual = df[col]
    predicted = df['predicted label']

    # Count confusion-matrix cells with boolean masks; int() keeps plain Python ints.
    true_positives = int(((actual == pos_label) & (predicted == pos_label)).sum())
    true_negatives = int(((actual == neg_label) & (predicted == neg_label)).sum())
    false_positives = int(((actual == neg_label) & (predicted == pos_label)).sum())
    false_negatives = int(((actual == pos_label) & (predicted == neg_label)).sum())

    def _safe_ratio(numerator, denominator):
        # An undefined metric is reported as NaN rather than crashing the caller.
        return numerator / denominator if denominator else float('nan')

    sensitivity = _safe_ratio(true_positives, true_positives + false_negatives)
    specificity = _safe_ratio(true_negatives, true_negatives + false_positives)
    accuracy = _safe_ratio(
        true_positives + true_negatives,
        true_positives + true_negatives + false_positives + false_negatives,
    )

    return sensitivity, specificity, accuracy

# Report metrics from combined_df, which must provide the 'true_label',
# 'predicted label' and 'Diagnosis' columns.

print('For classes:')
# Calculate sensitivity, specificity, and accuracy for each unique 'true_label',
# treating that label as the positive class and the other binary label (1 - label) as negative.
for label in combined_df['true_label'].unique():
    sensitivity, specificity, accuracy = calc_sens_spec_acc(combined_df, label, 1 - label)
    label_name = 'SMA' if label == 1 else 'Non-SMA'
    print(f"For {label_name}: Sensitivity = {sensitivity:.2f}, Specificity = {specificity:.2f}, Accuracy = {accuracy:.2f}")

print('For subclasses:')
# Calculate sensitivity, specificity, and accuracy for each unique 'Diagnosis',
# comparing it against the 'Severe Malaria Anaemia' rows.
# Rows with a missing diagnosis are skipped: NaN is not a real subclass, and
# comparing against it would select only the 'Severe Malaria Anaemia' rows,
# leaving no negative samples to score.
for diagnosis in combined_df['Diagnosis'].dropna().unique():
    if diagnosis == 'Severe Malaria Anaemia':
        continue
    subset = combined_df[combined_df['Diagnosis'].isin([diagnosis, 'Severe Malaria Anaemia'])]
    sensitivity, specificity, accuracy = calc_sens_spec_acc(subset, 1, 0)
    print(f"For {diagnosis}: Sensitivity = {sensitivity:.2f}, Specificity = {specificity:.2f}, Accuracy = {accuracy:.2f}")

For classes:
For SMA: Sensitivity = 0.82, Specificity = 0.38, Accuracy = 0.58
For Non-SMA: Sensitivity = 0.38, Specificity = 0.82, Accuracy = 0.58
For subclasses:
For No Malaria, Severe Anaemia: Sensitivity = 0.82, Specificity = 0.29, Accuracy = 0.61
For No Malaria, Anaemia: Sensitivity = 0.82, Specificity = 0.00, Accuracy = 0.69
For No Malaria, No Anaemia: Sensitivity = 0.82, Specificity = 1.00, Accuracy = 0.86
For Malaria, No Anaemia: Sensitivity = 0.82, Specificity = 0.00, Accuracy = 0.75
