In [1]:
import torch
import torch.nn as nn
import numpy as np
from efficientNetV2 import MRIClassifier
import pickle
import glob
import nibabel as nib
from torchvision import transforms
import torch.nn.functional as F
from sklearn.metrics import (
    roc_auc_score, 
    precision_recall_curve, 
    auc, 
    precision_score, 
    recall_score, 
    f1_score, 
    confusion_matrix,
    roc_curve,
    average_precision_score,
    accuracy_score
)

In [2]:
# Paths of the synthetic NIfTI volumes to score. glob returns matches in
# arbitrary order, so sort them to make the iteration order reproducible.
synth_nii_files = sorted(glob.glob("/space/mcdonald-syn01/1/projects/jsawant/DSC250/nii_test/*.nii"))
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
# Build the classifier on the target device, restore the saved weights from the
# fold-1 checkpoint, and switch the network to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file, and the output recorded
# under this cell appears to be a warning pointing at this call. Only load
# checkpoints from a trusted source; consider `weights_only=True` if the file
# holds nothing but tensors and plain containers (TODO confirm).
weights_path = "/space/mcdonald-syn01/1/projects/jsawant/DSC250/classifier/runs/run_20250305_143032_efficientNetV2/fold_1/best_model.pth"
model = MRIClassifier(dropout_rate=0.5)
model = model.to(device)
saved_state = torch.load(weights_path, map_location=device)
model.load_state_dict(saved_state['model_state_dict'])
model.eval()
# Only the state dict was needed; drop the reference to the loaded checkpoint.
del saved_state

  checkpoint = torch.load("/space/mcdonald-syn01/1/projects/jsawant/DSC250/classifier/runs/run_20250305_143032_efficientNetV2/fold_1/best_model.pth",


In [4]:
def preprocess(image, start=9, stop=124):
    """Crop a 3-D volume along its second axis and lay it out for the model.

    Parameters
    ----------
    image : numpy.ndarray
        Volume with exactly three dimensions.
    start, stop : int, optional
        Half-open bounds of the crop applied to axis 1. The defaults keep
        indices 9..123 (115 slices when the axis is long enough).

    Returns
    -------
    torch.Tensor
        float32 tensor in which the cropped axis comes first, followed by the
        original axes 0 and 2, with a singleton dimension prepended
        (presumably the batch/channel axis the classifier expects -- confirm
        against the model definition).

    Raises
    ------
    ValueError
        If ``image`` does not have exactly three dimensions.
    """
    if image.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got {image.ndim} dimension(s)")
    cropped = image[:, start:stop, :]
    # Explicit float32 conversion instead of the legacy torch.Tensor(...) constructor.
    volume = torch.as_tensor(cropped, dtype=torch.float32)
    # Move the cropped axis to the front, then add the leading singleton axis.
    return volume.permute(1, 0, 2).unsqueeze(0)

In [10]:
# Score every synthetic volume with the classifier.
#   preds  : file name -> hard label (1 if the sigmoid score exceeds the threshold, else 0)
#   scores : file name -> sigmoid score of the model's first output
# No intensity normalisation is applied to the volume before preprocessing.
DECISION_THRESHOLD = 0.50

# No labels are collected in this loop, so this list stays empty; it is kept
# only because another cell in the notebook reads the name.
ground_truths = []
preds = {}
scores = {}
with torch.no_grad():
    for file in synth_nii_files:
        volume = nib.load(file).get_fdata()
        # preprocess() crops axis 1 to indices 9..123 and adds a leading singleton axis.
        batch = preprocess(volume).to(device)
        # The model returns a pair; only the first element is used here.
        logit, _ = model(batch)
        score = torch.sigmoid(logit).item()
        pred = 1 if score > DECISION_THRESHOLD else 0
        # Key results by the file's base name (e.g. "7.nii").
        name = file.split('/')[-1]
        preds[name] = pred
        scores[name] = score
        
                
    

In [16]:
# Summary classification metrics computed from ground-truth labels, hard
# predictions and continuous scores.
# NOTE(review): the inference cell of this notebook leaves `ground_truths` as an
# empty list and builds `preds` / `scores` as dicts keyed by file name, so this
# cell presumably ran against differently-shaped variables from an earlier
# kernel state -- confirm it still runs (and reproduces the saved output) after
# Restart & Run All.
metrics = {
    'accuracy': accuracy_score(ground_truths, preds),
    'auc_roc': roc_auc_score(ground_truths, scores),
    'auc_pr': average_precision_score(ground_truths, scores),
    'ppv': precision_score(ground_truths, preds),  # same computation as 'precision' below
    'sensitivity': recall_score(ground_truths, preds),  # same computation as 'recall' below
    'specificity': recall_score(ground_truths, preds, pos_label=0),  # recall of class 0
    'precision': precision_score(ground_truths, preds),
    'recall': recall_score(ground_truths, preds),
    'f1_score': f1_score(ground_truths, preds),
}
metrics

{'accuracy': 0.25,
 'auc_roc': 1.0,
 'auc_pr': 1.0,
 'ppv': 0.25,
 'sensitivity': 1.0,
 'specificity': 0.0,
 'precision': 0.25,
 'recall': 1.0,
 'f1_score': 0.4}

In [11]:
# Display the file name -> predicted label (0/1) mapping built by the inference loop.
preds

{'7.nii': 1,
 '37.nii': 0,
 '14.nii': 0,
 '40.nii': 0,
 '4.nii': 0,
 '43.nii': 0,
 '9.nii': 0,
 '1.nii': 0,
 '20.nii': 0,
 '5.nii': 0,
 '24.nii': 0,
 '11.nii': 0,
 '35.nii': 1,
 '29.nii': 1,
 '6.nii': 1,
 '32.nii': 1,
 '47.nii': 1,
 '12.nii': 1,
 '36.nii': 1,
 '30.nii': 1,
 '16.nii': 1,
 '25.nii': 1,
 '13.nii': 1,
 '44.nii': 1,
 '18.nii': 1,
 '2.nii': 0,
 '17.nii': 0,
 '45.nii': 0,
 '26.nii': 0,
 '48.nii': 0,
 '19.nii': 1,
 '31.nii': 1,
 '21.nii': 0,
 '15.nii': 1,
 '8.nii': 0,
 '39.nii': 0,
 '38.nii': 1,
 '28.nii': 1,
 '42.nii': 1,
 '23.nii': 1,
 '27.nii': 1,
 '46.nii': 1,
 '33.nii': 1,
 '41.nii': 1,
 '10.nii': 1,
 '34.nii': 1,
 '3.nii': 0,
 '22.nii': 1}

In [12]:
# Display the file name -> sigmoid score mapping built by the inference loop.
scores

{'7.nii': 0.8529288172721863,
 '37.nii': 0.21806910634040833,
 '14.nii': 0.10483557730913162,
 '40.nii': 0.3451373279094696,
 '4.nii': 0.0553562268614769,
 '43.nii': 0.1448381394147873,
 '9.nii': 0.1828383356332779,
 '1.nii': 0.26868191361427307,
 '20.nii': 0.29189783334732056,
 '5.nii': 0.1257357895374298,
 '24.nii': 0.30551886558532715,
 '11.nii': 0.3618048131465912,
 '35.nii': 0.9861770272254944,
 '29.nii': 0.9922845363616943,
 '6.nii': 0.6665353178977966,
 '32.nii': 0.997718870639801,
 '47.nii': 0.995389461517334,
 '12.nii': 0.9931092262268066,
 '36.nii': 0.987528383731842,
 '30.nii': 0.9494888782501221,
 '16.nii': 0.9711605310440063,
 '25.nii': 0.9952627420425415,
 '13.nii': 0.9712305665016174,
 '44.nii': 0.9926679730415344,
 '18.nii': 0.7629279494285583,
 '2.nii': 0.22633160650730133,
 '17.nii': 0.05715594440698624,
 '45.nii': 0.4523855149745941,
 '26.nii': 0.18338485062122345,
 '48.nii': 0.47499752044677734,
 '19.nii': 0.8880132436752319,
 '31.nii': 0.6575129628181458,
 '21.nii'

In [13]:
import pandas as pd

# Table whose `ID`, `Real/synthetic` and `label` columns are used to attach
# labels to the model outputs.
csv = "/space/mcdonald-syn01/1/projects/jsawant/DSC250/qualitative_results.csv"
df = pd.read_csv(csv)

# Fail fast with a readable message if a column the label join relies on is absent.
required_columns = {"ID", "Real/synthetic", "label"}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"{csv} is missing expected columns: {sorted(missing_columns)}"

In [22]:
# Align the model outputs with the labels in `df`, restricted to rows marked
# "Synthetic". Result files are looked up as "<ID>.nii"; a label of "HC" maps
# to 0 and any other label maps to 1.
synthetic_rows = df[df['Real/synthetic'] == "Synthetic"]
file_names = [str(case_id) + '.nii' for case_id in synthetic_rows['ID']]

# Report every case without a model output at once instead of stopping on the first.
unscored = [name for name in file_names if name not in preds or name not in scores]
if unscored:
    raise KeyError(f"no model output for: {unscored}")

scores_arr = np.array([scores[name] for name in file_names])
preds_arr = np.array([preds[name] for name in file_names])
gts_arr = np.array([0 if label == "HC" else 1 for label in synthetic_rows['label']])

In [23]:
# Accuracy on the synthetic cases: the fraction of predictions that match the
# labels. Divide by the actual number of cases rather than a hard-coded 24 so
# the value stays correct if the set of synthetic rows changes.
np.mean(gts_arr == preds_arr)

0.7916666666666666

In [28]:
# ROC AUC of the sigmoid scores against the labels of the synthetic cases.
# Stored as `roc_auc` rather than `auc`: the import cell binds `auc` to
# sklearn.metrics.auc, and rebinding it to a float would break any later
# call to that function.
roc_auc = roc_auc_score(gts_arr, scores_arr)
roc_auc

0.8819444444444444