In [1]:
import os
import torch
import pandas as pd
import numpy as np
from collections import defaultdict, OrderedDict
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd,
    ScaleIntensityRanged, Resized, ToTensord
)
from monai.networks.nets import SwinUNETR
from monai.data import DataLoader, CacheDataset
from sklearn.metrics import (
    classification_report, roc_auc_score,
    precision_recall_fscore_support
)
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
from tqdm import tqdm

# Input locations: per-patient NIfTI volumes, fold checkpoints, and scan metadata.
data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
logdir = "./logs/swin_unetr"
data_csv = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"

# Read the metadata (first CSV column becomes the index), derive each scan's
# volume path as <data_path>/<Patient>/<filename>.nii.gz, and keep only the
# columns the evaluation needs.
data = (
    pd.read_csv(data_csv, index_col=0)
    .assign(
        path=lambda df: data_path + "/" + df["Patient"] + "/" + df["filename"] + ".nii.gz"
    )
    .loc[:, ["path", "mortality_12m", "Patient"]]
)

# stratified group k fold
# Split at the *patient* level (not per scan) so that all scans of one patient
# land in the same test fold. Presumably this must reproduce the training
# script's split exactly (TODO confirm), because fold i is scored below with
# best_model_fold_{i}.pth.
k_folds = 5
# One stratification label per patient: max of mortality_12m over that
# patient's scans. groupby sorts by Patient by default, so this array is in
# sorted-Patient order.
targets = data.groupby("Patient")["mortality_12m"].max().values
# Fixed seed so the patient split is reproducible across runs.
stratified_kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
# Patient -> positional row indices in `data`. Keys keep first-appearance
# order in `data`, which is not necessarily sorted order.
patient_to_idx = defaultdict(list)
for idx, patient in enumerate(data["Patient"].values):
    patient_to_idx[patient].append(idx)
# NOTE(review): the fold loop maps split positions (indices into `targets`,
# i.e. sorted-Patient order) onto list(patient_to_idx.keys())
# (first-appearance order). Unless `data` is already sorted by Patient, these
# orders differ. The folds are then still a valid patient partition, but the
# stratification labels do not belong to the patients actually placed in each
# fold. Do NOT fix this here alone: if training used this same code, changing
# the split would leak training patients into the test folds. When retraining,
# change both scripts together, e.g. to sklearn's StratifiedGroupKFold or to
# groupby("Patient", sort=False).

# Preprocessing pipeline. Per the original note it mirrors the training
# transforms, so every parameter here must stay in sync with training.
# NOTE(review): intensities are clipped to [-175, 250] and rescaled to [0, 1].
# If the volumes are in Hounsfield units, most lung parenchyma lies below -175
# and would be flattened to 0. Confirm this window is intended for lung CT.
transform = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        Resized(keys=["image"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["image"]),
    ]
)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Per-fold metric accumulators; the evaluation loop appends one value per fold
# to each. The generator yields a fresh, independent list for every name.
(
    fold_auroc,
    fold_precision_class0,
    fold_precision_class1,
    fold_recall_class0,
    fold_recall_class1,
    fold_f1_class0,
    fold_f1_class1,
) = ([] for _ in range(7))

# Evaluate each fold: rebuild that fold's held-out scans, load its checkpoint,
# score them, and append the fold's metrics to the accumulators.
# Built once here instead of once per test patient inside the comprehension
# (the old form was quadratic). Same first-appearance order as before, so the
# fold assignment is unchanged.
patient_ids = list(patient_to_idx.keys())
for fold, (_, test_patient_idx) in enumerate(stratified_kfold.split(np.zeros(len(targets)), targets)):
    print(f"evaluating fold {fold+1}/{k_folds}")

    # Split positions -> patient IDs -> every scan belonging to those patients.
    test_patients = [patient_ids[i] for i in test_patient_idx]
    test_data = data[data['Patient'].isin(test_patients)]
    test_files = [
        {"image": row["path"], "label": row["mortality_12m"]}
        for _, row in test_data.iterrows()
    ]

    test_ds = CacheDataset(data=test_files, transform=transform, cache_rate=1.0, num_workers=8)
    test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=8, pin_memory=False)

    # The architecture must match the one the checkpoint was saved from.
    model = SwinUNETR(
        img_size=(96, 96, 96),
        in_channels=1,
        out_channels=48,
        feature_size=48,
        use_checkpoint=True
    )
    # 2-class head applied to the backbone's 48-channel output: global average
    # pool -> flatten -> linear. It is attached before load_state_dict so the
    # checkpoint's head weights are restored as well.
    model.classification_head = nn.Sequential(
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(48, 2)
    )
    model_path = os.path.join(logdir, f"best_model_fold_{fold}.pth")
    # The file is loaded straight into load_state_dict, i.e. a plain state_dict,
    # so restrict unpickling to tensors/containers. This is safer and silences
    # torch's weights_only FutureWarning.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Testing Fold {fold+1}"):
            inputs = batch["image"].to(device)
            # Labels are only used for CPU-side metrics; no device round-trip.
            labels = batch["label"]
            logits = model.classification_head(model(inputs))
            all_outputs.append(logits.cpu())
            all_labels.append(labels)

    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    probabilities = torch.softmax(all_outputs, dim=1).numpy()
    predictions = probabilities.argmax(axis=1)
    true_labels = all_labels.numpy()

    # AUROC for the positive class (index 1). It is undefined when a fold holds
    # only one class; record NaN rather than aborting the remaining folds.
    if len(np.unique(true_labels)) < 2:
        auroc = float("nan")
        print(f"Fold {fold+1}: only one class present in labels; AUROC undefined")
    else:
        auroc = roc_auc_score(true_labels, probabilities[:, 1])
    fold_auroc.append(auroc)

    # Per-class precision/recall/F1. labels=[0, 1] guarantees two entries even
    # when a class is absent from the labels and predictions of a fold
    # (otherwise precision[1] would raise IndexError). zero_division=0 makes
    # the "class never predicted" case an explicit 0.0 instead of a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, labels=[0, 1], average=None, zero_division=0
    )
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])

    print(f"Fold {fold+1} - AUROC: {auroc:.4f}")
    print(classification_report(true_labels, predictions, labels=[0, 1], zero_division=0))


def _mean_std(values):
    """Return (mean, std) of per-fold metric values.

    The std is NumPy's default population standard deviation (ddof=0).
    """
    arr = np.asarray(values)
    return arr.mean(), arr.std()


# Summary statistics across folds for each per-fold metric list.
mean_auroc, std_auroc = _mean_std(fold_auroc)

mean_prec0, std_prec0 = _mean_std(fold_precision_class0)
mean_prec1, std_prec1 = _mean_std(fold_precision_class1)
mean_recall0, std_recall0 = _mean_std(fold_recall_class0)
mean_recall1, std_recall1 = _mean_std(fold_recall_class1)
mean_f1_0, std_f1_0 = _mean_std(fold_f1_class0)
mean_f1_1, std_f1_1 = _mean_std(fold_f1_class1)


# Per-class metrics formatted as "mean ± std" over folds. The fold count comes
# from k_folds rather than a hard-coded literal, so the header stays correct
# if the number of folds changes.
print(f"Class Specific Metrics Across {k_folds} Folds:")
table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        f"{mean_prec0:.4f} ± {std_prec0:.4f}",
        f"{mean_prec1:.4f} ± {std_prec1:.4f}"
    ],
    "Recall (mean±std)": [
        f"{mean_recall0:.4f} ± {std_recall0:.4f}",
        f"{mean_recall1:.4f} ± {std_recall1:.4f}"
    ],
    "F1-Score (mean±std)": [
        f"{mean_f1_0:.4f} ± {std_f1_0:.4f}",
        f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"
    ]
}
df_classes = pd.DataFrame(table_data_classes)
print(df_classes.to_string(index=False))


# One-row summary centered on the positive class (label 1): sensitivity is
# class-1 recall, specificity is class-0 recall, and precision/F1 are the
# class-1 values. Each cell is "mean ± std" over the cross-validation folds.
summary_data = {
    "Model": ["Swin UNETR"],
    "AUROC": [f"{mean_auroc:.4f} ± {std_auroc:.4f}"],
    "Sensitivity": [f"{mean_recall1:.4f} ± {std_recall1:.4f}"],
    "Specificity": [f"{mean_recall0:.4f} ± {std_recall0:.4f}"],
    "Precision": [f"{mean_prec1:.4f} ± {std_prec1:.4f}"],
    "F1-Score": [f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"]
}
df_summary = pd.DataFrame(summary_data)
# These are fold-level statistics from cross-validation, not an ensemble's
# predictions, so the header says so (and no longer carries the "Ensembl" typo).
print(f"Summary Across {k_folds} Folds")
print(df_summary.to_string(index=False))

  from .autonotebook import tqdm as notebook_tqdm


evaluating fold 1/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 433/433 [02:21<00:00,  3.05it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))
Testing Fold 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [00:32<00:00,  1.70it/s]


Fold 1 - AUROC: 0.6358
              precision    recall  f1-score   support

           0       0.89      0.94      0.91       378
           1       0.30      0.16      0.21        55

    accuracy                           0.85       433
   macro avg       0.59      0.55      0.56       433
weighted avg       0.81      0.85      0.82       433

evaluating fold 2/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 368/368 [02:09<00:00,  2.84it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))
Testing Fold 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:16<00:00,  2.84it/s]


Fold 2 - AUROC: 0.7014
              precision    recall  f1-score   support

           0       0.78      1.00      0.87       284
           1       0.75      0.04      0.07        84

    accuracy                           0.78       368
   macro avg       0.76      0.52      0.47       368
weighted avg       0.77      0.78      0.69       368

evaluating fold 3/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [01:46<00:00,  2.94it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))
Testing Fold 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:13<00:00,  2.82it/s]


Fold 3 - AUROC: 0.6413
              precision    recall  f1-score   support

           0       0.82      0.99      0.90       248
           1       0.82      0.14      0.24        64

    accuracy                           0.82       312
   macro avg       0.82      0.57      0.57       312
weighted avg       0.82      0.82      0.76       312

evaluating fold 4/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 337/337 [01:52<00:00,  3.00it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))
Testing Fold 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43/43 [00:14<00:00,  2.88it/s]


Fold 4 - AUROC: 0.5597
              precision    recall  f1-score   support

           0       0.78      0.98      0.87       257
           1       0.58      0.09      0.15        80

    accuracy                           0.77       337
   macro avg       0.68      0.53      0.51       337
weighted avg       0.73      0.77      0.70       337

evaluating fold 5/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 360/360 [01:56<00:00,  3.10it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))
Testing Fold 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45/45 [00:15<00:00,  2.83it/s]


Fold 5 - AUROC: 0.5201
              precision    recall  f1-score   support

           0       0.84      0.94      0.88       294
           1       0.39      0.18      0.25        66

    accuracy                           0.80       360
   macro avg       0.61      0.56      0.57       360
weighted avg       0.75      0.80      0.77       360

Class Specific Metrics Across 5 Folds:
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.8184 ± 0.0409   0.9698 ± 0.0251     0.8865 ± 0.0171
1 (Positive)      0.5677 ± 0.2002   0.1219 ± 0.0535     0.1839 ± 0.0669
Ensembl Summary
     Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
Swin UNETR 0.6117 ± 0.0642 0.1219 ± 0.0535 0.9698 ± 0.0251 0.5677 ± 0.2002 0.1839 ± 0.0669
