In [1]:
import os
import torch
import pandas as pd
import numpy as np
from collections import defaultdict
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, Resized, ToTensord
from monai.data import DataLoader, CacheDataset
import torch.nn as nn
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
labels_file = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"
logdir = "./logs/swin_unetr_unfrozen" 
num_classes = 2
k_folds = 4

data = pd.read_csv(labels_file, index_col=0)
data['path'] = f"{data_path}/" + data['Patient'] + "/" + data['filename'] + ".nii.gz"
data = data[['path', 'mortality_12m', 'Patient']]

def prepare_data(df):
    return [{"image": row["path"], "label": row["mortality_12m"]} for _, row in df.iterrows()]

# stratified group k fold
targets = data.groupby("Patient")["mortality_12m"].max().values
stratified_kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
patient_to_idx = defaultdict(list)
for idx, patient in enumerate(data["Patient"].values):
    patient_to_idx[patient].append(idx)

# same standardization as training
transform = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    ScaleIntensityRanged(keys=["image"],
                         a_min=-175, a_max=250, 
                         b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image"], spatial_size=(96, 96, 96)),
    ToTensord(keys=["image"]),
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from monai.networks.nets import SwinUNETR
def get_model():
    model = SwinUNETR(img_size=(96, 96, 96),
                      in_channels=1,
                      out_channels=48,
                      feature_size=48,
                      use_checkpoint=True)
    model.classification_head = nn.Sequential(
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(48, num_classes)
    )
    return model

fold_auroc = []
fold_precision_class0 = []
fold_precision_class1 = []
fold_recall_class0 = []
fold_recall_class1 = []
fold_f1_class0 = []
fold_f1_class1 = []

# evaluation loop
for fold, (train_val_idx, test_idx) in enumerate(stratified_kfold.split(np.zeros(len(targets)), targets)):
    if fold == 3:
        continue
    print(f"evaluating fold {fold+1}/{k_folds}")
    test_patients = [list(patient_to_idx.keys())[i] for i in test_idx]
    test_df = data[data["Patient"].isin(test_patients)]
    test_data = prepare_data(test_df)
    
    test_ds = CacheDataset(data=test_data, transform=transform, cache_rate=1.0, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=5, shuffle=False, num_workers=4)
    
    # load model checkpoint for this fold
    model = get_model().to(device)
    model_path = os.path.join(logdir, f"best_model_fold_{fold}.pth")
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    
    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Fold {fold+1} Testing"):
            inputs = batch["image"].to(device)
            labels_batch = batch["label"].to(device)
            # forward pass through the classification head
            outputs = model.classification_head(model(inputs))
            all_outputs.append(outputs.cpu())
            all_labels.append(labels_batch.cpu())
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    # compute probabilities & predictions
    probs = torch.softmax(all_outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)
    true_labels = all_labels.numpy()
    
    # compute AUROC (using probability for positive class, index 1)
    auroc = roc_auc_score(true_labels, probs[:, 1])
    fold_auroc.append(auroc)
    
    # compute per class precision, recall, and F1
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average=None)
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])
    
    print(f"Fold {fold+1} AUROC: {auroc:.4f}")
    print(classification_report(true_labels, preds))

# ensemble metrics (mean ± std)
mean_auroc = np.mean(fold_auroc)
std_auroc = np.std(fold_auroc)
mean_prec0 = np.mean(fold_precision_class0)
std_prec0 = np.std(fold_precision_class0)
mean_prec1 = np.mean(fold_precision_class1)
std_prec1 = np.std(fold_precision_class1)
mean_recall0 = np.mean(fold_recall_class0)
std_recall0 = np.std(fold_recall_class0)
mean_recall1 = np.mean(fold_recall_class1)
std_recall1 = np.std(fold_recall_class1)
mean_f1_0 = np.mean(fold_f1_class0)
std_f1_0 = np.std(fold_f1_class0)
mean_f1_1 = np.mean(fold_f1_class1)
std_f1_1 = np.std(fold_f1_class1)


table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        f"{mean_prec0:.4f} ± {std_prec0:.4f}",
        f"{mean_prec1:.4f} ± {std_prec1:.4f}"
    ],
    "Recall (mean±std)": [
        f"{mean_recall0:.4f} ± {std_recall0:.4f}",
        f"{mean_recall1:.4f} ± {std_recall1:.4f}"
    ],
    "F1-Score (mean±std)": [
        f"{mean_f1_0:.4f} ± {std_f1_0:.4f}",
        f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"
    ]
}
df_classes = pd.DataFrame(table_data_classes)
print("class specific metrics across 5 folds")
print(df_classes.to_string(index=False))

# ensembl results
summary_data = {
    "Model": ["Swin UNETR (Unfrozen)"],
    "AUROC": [f"{mean_auroc:.4f} ± {std_auroc:.4f}"],
    "Sensitivity": [f"{mean_recall1:.4f} ± {std_recall1:.4f}"],  # recall for class 1
    "Specificity": [f"{mean_recall0:.4f} ± {std_recall0:.4f}"],  # recall for class 0
    "Precision": [f"{mean_prec1:.4f} ± {std_prec1:.4f}"],         # precision for class 1
    "F1-Score": [f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"]
}
df_summary = pd.DataFrame(summary_data)
print("Ensembl summary")
print(df_summary.to_string(index=False))

  from .autonotebook import tqdm as notebook_tqdm


evaluating fold 1/4


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 510/510 [03:39<00:00,  2.32it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 1 Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [00:25<00:00,  4.08it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1 AUROC: 0.6431
              precision    recall  f1-score   support

           0       0.87      1.00      0.93       442
           1       0.00      0.00      0.00        68

    accuracy                           0.87       510
   macro avg       0.43      0.50      0.46       510
weighted avg       0.75      0.87      0.80       510

evaluating fold 2/4


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 417/417 [03:21<00:00,  2.07it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 2 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:18<00:00,  4.54it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2 AUROC: 0.7406
              precision    recall  f1-score   support

           0       0.75      1.00      0.86       314
           1       0.00      0.00      0.00       103

    accuracy                           0.75       417
   macro avg       0.38      0.50      0.43       417
weighted avg       0.57      0.75      0.65       417

evaluating fold 3/4


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 411/411 [03:24<00:00,  2.01it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 3 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:18<00:00,  4.43it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3 AUROC: 0.5108
              precision    recall  f1-score   support

           0       0.83      1.00      0.91       342
           1       0.00      0.00      0.00        69

    accuracy                           0.83       411
   macro avg       0.42      0.50      0.45       411
weighted avg       0.69      0.83      0.76       411

class specific metrics across 5 folds
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.8173 ± 0.0476   1.0000 ± 0.0000     0.8987 ± 0.0292
1 (Positive)      0.0000 ± 0.0000   0.0000 ± 0.0000     0.0000 ± 0.0000
Ensembl summary
                Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
Swin UNETR (Unfrozen) 0.6315 ± 0.0942 0.0000 ± 0.0000 1.0000 ± 0.0000 0.0000 ± 0.0000 0.0000 ± 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
