In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

from utils.utils import load_data
from utils.cnn_utils import datasetT1, sigmoid_focal_loss
from utils.cnn_train import eval
from utils.cnn_model import SFCN

from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import average_precision_score, roc_auc_score, brier_score_loss, f1_score, hamming_loss

import torch
from torch import nn
from torch.utils.data import DataLoader

In [None]:
# Filesystem locations used by the rest of the notebook.
# `source_path` is handed to datasetT1 as the image directory; `checkpoints_path`
# is the directory the trained weights ("model.pth") are read from.
user_dir = os.path.expanduser("~")
source_path = f"{user_dir}/t1images/"  # on 48860 it is /scratch/hbnetdata/t1images/
checkpoints_path = "checkpoints/"

In [1]:
# Select the compute device. `device_index` picks which GPU to use when CUDA is
# available; on a host without CUDA we fall back to the CPU so the notebook can
# still be run end to end instead of failing on the first `.to(device)` call.
device_index = 0 # or 1 or 2 or 3
device = f"cuda:{device_index}" if torch.cuda.is_available() else "cpu"

In [16]:
# Fetch the evaluation data. load_data returns three values; only the first (X)
# and the last (Y) are used in this notebook, the middle one is discarded.
dataset_name = 'classification_t1'
X, _, Y = load_data(dataset_name)

n_t1_rows = X.shape[0]
print(f"Size of T1 set: {n_t1_rows}")

Size of T1 set: 2491


In [17]:
# Hold out 25% of the rows as the test set (fixed seed so the split is repeatable).
# Only the test halves are kept; the training halves are discarded.
# NOTE(review): column 0 of X presumably identifies the image and column 0 of Y is
# presumably a non-label column (it is dropped here) — confirm against load_data.
inputs_column = X.iloc[:, 0]
label_columns = Y.iloc[:, 1:]
_, X_test, _, Y_test = train_test_split(
    inputs_column,
    label_columns,
    test_size=0.25,
    random_state=0,
)
n_test_rows = X_test.shape[0]
print(f"Size test set: {n_test_rows}")

Size training set: 1868
Size test set: 623


In [None]:
# Image modality passed to datasetT1. Alternatives listed by the author: 'GM', 'WM', 'CSF'.
modality = "T1w"

In [None]:
# Wrap the held-out split in the project's dataset class, using the configured
# modality and image directory.
dataset_kwargs = {"modality": modality, "source_path": source_path}
test_data = datasetT1(X_test, Y_test, **dataset_kwargs)

---

In [None]:
# Number of samples per batch for the evaluation DataLoader.
batch_size = 8

In [None]:
# Build the network and restore the trained weights for evaluation.
# NOTE(review): output_dim=13 presumably matches the number of label columns in
# Y_test — confirm.
model = SFCN(output_dim=13)
model.to(device)
# Put dropout / batch-norm layers (if any) into inference behaviour before scoring.
model.eval()
# map_location remaps tensors that were saved from another device (e.g. a different
# GPU index, or GPU -> CPU) onto `device`; without it loading fails when the saving
# device is not present on this host.
model.load_state_dict(torch.load(checkpoints_path + "model.pth", map_location=device))

In [None]:
# Choose the loss that is handed to the evaluation routine.
# Supported values for `loss_string`: 'bce' and 'focal'.
loss_string = 'bce'

if loss_string == 'bce':
    loss_fn = nn.BCEWithLogitsLoss()
elif loss_string == 'focal':
    # NOTE(review): sigmoid_focal_loss is invoked with no arguments here. If it is a
    # plain loss function (inputs, targets) rather than a factory/class, this call
    # will raise and the intent was probably `loss_fn = sigmoid_focal_loss` — confirm
    # against utils.cnn_utils.
    loss_fn = sigmoid_focal_loss()
else:
    # Fail here with a clear message instead of a NameError on `loss_fn` further down.
    raise ValueError(f"Unknown loss_string {loss_string!r}; expected 'bce' or 'focal'")

In [None]:
# Adam optimizer over all model parameters.
# NOTE(review): nothing in the cells visible here uses `optimizer`; it may be a
# leftover from the training notebook.
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

---

In [None]:
# Bootstrap evaluation of the trained model on the held-out test set.
#
# For each bootstrap round the test set is resampled with replacement (same size as
# the original, seeded by the round index so the run is repeatable), the model is
# scored on the resample, and one value per metric is recorded. The mean and the
# 2.5th / 97.5th percentiles over all rounds are reported at the end.
#
# Note: `eval` below is the project's utils.cnn_train.eval (imported at the top of the
# notebook), not Python's builtin; it returns (y_prob, y_pred) for a dataloader.
auprc = []
auroc = []
brier = []
hamm = []
f1 = []

n_bootstraps = 100

for bootstrap_idx in range(n_bootstraps):
    X_test_resampled, y_test_resampled = resample(
        X_test, Y_test, replace=True, n_samples=len(Y_test), random_state=bootstrap_idx
    )

    # Build the resampled dataset with the same modality / image directory as the
    # main test dataset, rather than relying on datasetT1's defaults.
    eval_data = datasetT1(
        X_test_resampled, y_test_resampled, modality=modality, source_path=source_path
    )
    eval_dataloader = DataLoader(eval_data, batch_size=batch_size, shuffle=False)
    y_prob, y_pred = eval(eval_dataloader, device, model, loss_fn)

    # Brier score: computed per output column, then averaged over the columns.
    per_label_brier = [
        brier_score_loss(y_test_resampled.iloc[:, label_idx], y_prob[:, label_idx])
        for label_idx in range(y_prob.shape[1])
    ]
    brier.append(np.mean(per_label_brier))

    # Other metrics
    auprc.append(average_precision_score(y_test_resampled, y_prob, average='macro'))
    auroc.append(roc_auc_score(y_test_resampled, y_prob, average='macro'))
    f1.append(f1_score(y_test_resampled, y_pred, average='micro'))
    hamm.append(hamming_loss(y_test_resampled, y_pred))

print("Mean scores for 3D-CNN with 95% confidence intervals:")
for metric_label, metric_values in (
    ("AUPRC macro", auprc),
    ("AUROC macro", auroc),
    ("Brier score", brier),
    ("Hamming loss", hamm),
    ("Micro Avg F1 score", f1),
):
    ci_low, ci_high = np.percentile(metric_values, [2.5, 97.5])
    print("    {}: {:.2f} [{:.2f}, {:.2f}]".format(metric_label, np.mean(metric_values), ci_low, ci_high))