In [1]:
import os
import numpy as np
import pandas as pd
import torch
from collections import defaultdict
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged,
    Resized, ToTensord
)
from monai.data import DataLoader, CacheDataset
import torch.nn as nn
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
labels_file = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"
logdir = "./logs/swin_unetr_balanced"
num_classes = 2
k_folds = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = pd.read_csv(labels_file, index_col=0)
data['path'] = f"{data_path}/" + data['Patient'] + "/" + data['filename'] + ".nii.gz"
data = data[['path', 'mortality_12m', 'Patient']]

# identify patients with at least one positive scan
patients_with_positive = set(data[data['mortality_12m'] == 1]['Patient'])
# patients with only negatives
negative_only_patients = set(data['Patient']) - patients_with_positive
DOWNSAMPLE_RATIO = 3
downsampled_negatives = np.random.choice(
    list(negative_only_patients),
    size=min(len(negative_only_patients), len(patients_with_positive) * DOWNSAMPLE_RATIO),
    replace=False
)
balanced_data = data[(data['Patient'].isin(patients_with_positive)) | (data['Patient'].isin(downsampled_negatives))]

# stratified group k fold
patient_labels = balanced_data.groupby("Patient")["mortality_12m"].max()
unique_patients = patient_labels.index.to_numpy()
unique_labels = patient_labels.values

skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
# map patients to indices in balanced_data
patient_to_idx = defaultdict(list)
for idx, patient in enumerate(balanced_data["Patient"].values):
    patient_to_idx[patient].append(idx)

def prepare_data(df):
    return [{"image": row["path"], "label": row["mortality_12m"]} for _, row in df.iterrows()]


# base transform
base_transform = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    ScaleIntensityRanged(keys=["image"],
                         a_min=-175, a_max=250, 
                         b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image"], spatial_size=(160, 160, 160)),
    ToTensord(keys=["image"]),
])

# define the model
from monai.networks.nets import SwinUNETR
def get_model():
    model = SwinUNETR(
        img_size=(160, 160, 160),
        in_channels=1,
        out_channels=48,
        feature_size=48,
        use_checkpoint=True
    )
    model.classification_head = nn.Sequential(
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(start_dim=1, end_dim=-1),
        nn.Linear(48, num_classes)
    )
    return model


fold_auroc = []
fold_precision_class0 = []
fold_precision_class1 = []
fold_recall_class0 = []
fold_recall_class1 = []
fold_f1_class0 = []
fold_f1_class1 = []

# evaluation Loop
for fold, (train_val_idx, test_idx) in enumerate(skf.split(unique_patients, unique_labels)):
    print(f"evaluating fold: {fold+1}/{k_folds}")
    # get IDs for test pnts
    test_patients = unique_patients[test_idx]
    test_df = balanced_data[balanced_data['Patient'].isin(test_patients)]
    test_data = prepare_data(test_df)
    
    # test dataset & dataloader created using base transform
    test_ds = CacheDataset(data=test_data, transform=base_transform, cache_rate=1.0, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=5, shuffle=False, num_workers=4)
    
    # load model checkpoint for this fold
    model = get_model().to(device)
    model_path = os.path.join(logdir, f"best_model_fold_{fold}.pth")
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    
    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Fold {fold+1} Testing"):
            inputs = batch["image"].to(device)
            labels_batch = batch["label"].to(device)
            outputs = model.classification_head(model(inputs))
            all_outputs.append(outputs.cpu())
            all_labels.append(labels_batch.cpu())
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    # compute softmax probabilities and predictions
    probs = torch.softmax(all_outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)
    true_labels = all_labels.numpy()
    
    # compute AUROC using probability for positive class (index 1)
    auroc = roc_auc_score(true_labels, probs[:, 1])
    fold_auroc.append(auroc)
    
    # compute class precision, recall, and F1-score
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average=None)
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])
    
    print(f"Fold {fold+1} AUROC: {auroc:.4f}")
    print(classification_report(true_labels, preds))

# ensembl metrics (mean ± std)
mean_auroc = np.mean(fold_auroc)
std_auroc = np.std(fold_auroc)
mean_prec0 = np.mean(fold_precision_class0)
std_prec0 = np.std(fold_precision_class0)
mean_prec1 = np.mean(fold_precision_class1)
std_prec1 = np.std(fold_precision_class1)
mean_recall0 = np.mean(fold_recall_class0)
std_recall0 = np.std(fold_recall_class0)
mean_recall1 = np.mean(fold_recall_class1)
std_recall1 = np.std(fold_recall_class1)
mean_f1_0 = np.mean(fold_f1_class0)
std_f1_0 = np.std(fold_f1_class0)
mean_f1_1 = np.mean(fold_f1_class1)
std_f1_1 = np.std(fold_f1_class1)


table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        f"{mean_prec0:.4f} ± {std_prec0:.4f}",
        f"{mean_prec1:.4f} ± {std_prec1:.4f}"
    ],
    "Recall (mean±std)": [
        f"{mean_recall0:.4f} ± {std_recall0:.4f}",
        f"{mean_recall1:.4f} ± {std_recall1:.4f}"
    ],
    "F1-Score (mean±std)": [
        f"{mean_f1_0:.4f} ± {std_f1_0:.4f}",
        f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"
    ]
}
df_classes = pd.DataFrame(table_data_classes)
print("Class Specific Metrics Across 5 Folds")
print(df_classes.to_string(index=False))

# ensembl summary
summary_data = {
    "Model": ["Swin UNETR (Balanced)"],
    "AUROC": [f"{mean_auroc:.4f} ± {std_auroc:.4f}"],
    "Sensitivity": [f"{mean_recall1:.4f} ± {std_recall1:.4f}"],  # recall for class 1
    "Specificity": [f"{mean_recall0:.4f} ± {std_recall0:.4f}"],  # recall for class 0
    "Precision": [f"{mean_prec1:.4f} ± {std_prec1:.4f}"],         # precision for class 1
    "F1-Score": [f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"]
}
df_summary = pd.DataFrame(summary_data)
print("Ensembl Summary")
print(df_summary.to_string(index=False))

  from .autonotebook import tqdm as notebook_tqdm


evaluating fold: 1/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 306/306 [02:12<00:00,  2.31it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 1 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:01<00:00,  1.01it/s]


Fold 1 AUROC: 0.6234
              precision    recall  f1-score   support

           0       0.78      1.00      0.87       235
           1       1.00      0.04      0.08        71

    accuracy                           0.78       306
   macro avg       0.89      0.52      0.48       306
weighted avg       0.83      0.78      0.69       306

evaluating fold: 2/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 310/310 [02:15<00:00,  2.29it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 2 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:59<00:00,  1.04it/s]


Fold 2 AUROC: 0.5875
              precision    recall  f1-score   support

           0       0.69      1.00      0.82       211
           1       1.00      0.05      0.10        99

    accuracy                           0.70       310
   macro avg       0.85      0.53      0.46       310
weighted avg       0.79      0.70      0.59       310

evaluating fold: 3/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 315/315 [02:37<00:00,  2.00it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 3 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [01:01<00:00,  1.03it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3 AUROC: 0.6904
              precision    recall  f1-score   support

           0       0.77      1.00      0.87       242
           1       0.00      0.00      0.00        73

    accuracy                           0.77       315
   macro avg       0.38      0.50      0.43       315
weighted avg       0.59      0.77      0.67       315

evaluating fold: 4/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 305/305 [02:17<00:00,  2.22it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 4 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 61/61 [00:58<00:00,  1.04it/s]


Fold 4 AUROC: 0.6275
              precision    recall  f1-score   support

           0       0.80      0.99      0.88       241
           1       0.62      0.08      0.14        64

    accuracy                           0.80       305
   macro avg       0.71      0.53      0.51       305
weighted avg       0.76      0.80      0.73       305

evaluating fold: 5/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 392/392 [02:51<00:00,  2.29it/s]
  checkpoint = torch.load(model_path, map_location=device)
Fold 5 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:16<00:00,  1.04it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5 AUROC: 0.6488
              precision    recall  f1-score   support

           0       0.89      1.00      0.94       350
           1       0.00      0.00      0.00        42

    accuracy                           0.89       392
   macro avg       0.45      0.50      0.47       392
weighted avg       0.80      0.89      0.84       392

Class Specific Metrics Across 5 Folds
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.7860 ± 0.0647   0.9975 ± 0.0050     0.8777 ± 0.0401
1 (Positive)      0.5250 ± 0.4500   0.0342 ± 0.0303     0.0632 ± 0.0550
Ensembl Summary
                Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
Swin UNETR (Balanced) 0.6355 ± 0.0338 0.0342 ± 0.0303 0.9975 ± 0.0050 0.5250 ± 0.4500 0.0632 ± 0.0550
