In [1]:
import os
import torch
import pandas as pd
import numpy as np
from collections import defaultdict
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, Resized, ToTensord
from monai.data import DataLoader, CacheDataset
import torch.nn as nn
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

class Modified3DDenseNet(nn.Module):
    """Thin ``nn.Module`` wrapper around MONAI's ``DenseNet121`` in 3D mode.

    The wrapped network is built with ``spatial_dims=3`` and ``in_channels=1``
    and is stored as ``self.densenet``, so every entry of this module's state
    dict is prefixed with ``densenet.``.

    Args:
        num_classes: Number of output units passed to the backbone as
            ``out_channels``.
    """

    def __init__(self, num_classes=2):
        # Function-scope import: MONAI's network module is resolved only when
        # a model instance is actually created.
        from monai.networks.nets import DenseNet121

        super().__init__()
        backbone_kwargs = {
            "spatial_dims": 3,
            "in_channels": 1,
            "out_channels": num_classes,
        }
        self.densenet = DenseNet121(**backbone_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped DenseNet and return its raw output."""
        logits = self.densenet(x)
        return logits

def load_data(labels_file, data_path, task):
    """Read the label table and return aligned per-row paths, labels and patients.

    Rows whose ``task`` value is missing are dropped. For each remaining row
    the image path is built as ``{data_path}/{Patient}/{filename}.nii.gz``.

    Args:
        labels_file: Path to a CSV file. Its first column is used as the
            index; it must contain the columns ``Patient``, ``filename`` and
            the column named by ``task``.
        data_path: Root directory that is prefixed to every image path.
        task: Name of the label column. Non-missing values are cast to ``int``.

    Returns:
        A tuple ``(file_list, labels, patient_ids)`` of three equally long
        arrays, one entry per retained CSV row.

    Raises:
        KeyError: If one of the required columns is absent from the CSV.
    """
    df = pd.read_csv(labels_file, index_col=0)
    # Take an explicit copy of the filtered rows: assigning a new column to
    # the result of boolean indexing otherwise raises SettingWithCopyWarning.
    df = df[df[task].notna()].copy()
    df['path'] = f'{data_path}/' + df['Patient'].astype(str) + '/' + df['filename'].astype(str) + '.nii.gz'
    file_list = df['path'].values
    labels = df[task].astype(int).values
    patient_ids = df['Patient'].values
    return file_list, labels, patient_ids

def prepare_data(files, labels):
    """Pair each file with its label as an ``{"image", "label"}`` record.

    ``files`` and ``labels`` are walked in lockstep with ``zip``, so the
    result has as many records as the shorter of the two inputs.
    """
    records = []
    for image_path, target in zip(files, labels):
        records.append({"image": image_path, "label": target})
    return records


# --- Run configuration -------------------------------------------------------
# Label column to evaluate; it is also part of the checkpoint directory name.
task = "mortality_12m"

# Input locations: image root, label table, and the directory holding the
# per-fold checkpoints.
data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
labels_file = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"
logdir = "../../05mortality_analysis/DenseNet/logs/DenseNet-2/densenet_mortality_12m_20250308_165417"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing applied to the "image" entry of every sample, in order:
# load from disk, move channels first, clip intensities to [-175, 250] and
# rescale them to [0, 1], resize to 128^3, convert to a tensor.
preprocessing_steps = [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image"], spatial_size=(128, 128, 128)),
    ToTensord(keys=["image"]),
]
transforms = Compose(preprocessing_steps)

# Load the per-scan table, then collapse it to one label per patient so the
# cross-validation split can be stratified at patient level.
file_list, labels, patient_ids = load_data(labels_file, data_path, task)
df_patients = pd.DataFrame({"patient": patient_ids, "label": labels})
# max() over a patient's rows: the patient counts as positive if any of
# their rows carries label 1.
patient_labels = df_patients.groupby("patient")["label"].max()
unique_patients, unique_labels = patient_labels.index.to_numpy(), patient_labels.values

# Fixed seed so the fold assignment is the same on every run.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold metric accumulators; each gets one value appended per fold.
(
    fold_auroc,
    fold_precision_class0,
    fold_precision_class1,
    fold_recall_class0,
    fold_recall_class1,
    fold_f1_class0,
    fold_f1_class1,
) = ([] for _ in range(7))

# Evaluation loop: for every fold, load that fold's saved checkpoint, run it on
# all rows belonging to the fold's held-out patients, and record row-level
# AUROC plus per-class precision / recall / F1.
# NOTE(review): this assumes training used the same StratifiedKFold settings,
# so that fold_{i}_best_model.pth never saw fold i's test patients - confirm.
n_folds = skf.get_n_splits()
class_ids = [0, 1]  # fixed label order so index 0/1 below always mean class 0/1
# train_val_idx is unused here; only the held-out part of each split is scored.
for fold, (train_val_idx, test_idx) in enumerate(skf.split(unique_patients, unique_labels)):
    print(f"Evaluating fold {fold+1}/{n_folds}")
    # Folds are defined on patients; map them back to every row of those patients.
    test_patients = unique_patients[test_idx]
    test_mask = np.isin(patient_ids, test_patients)
    test_files = file_list[test_mask]
    test_labels = labels[test_mask]

    test_data = prepare_data(test_files, test_labels)
    test_ds = CacheDataset(data=test_data, transform=transforms, cache_rate=1.0, num_workers=8)
    test_loader = DataLoader(
        test_ds,
        batch_size=8,
        shuffle=False,
        num_workers=8,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
        pin_memory=device.type == "cuda",
    )

    # Load the best model for this fold.
    model = Modified3DDenseNet(num_classes=2).to(device)
    model_path = os.path.join(logdir, f"fold_{fold}_best_model.pth")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source; switch to weights_only=True
    # if the checkpoint is known to contain nothing but tensors and plain
    # containers.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Fold {fold+1} Testing"):
            inputs = batch["image"].to(device)
            outputs = model(inputs)
            all_outputs.append(outputs.cpu())
            # Labels are only needed for metrics on the host, so they are not
            # sent to the device and back.
            all_labels.append(batch["label"])
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)

    probs = torch.softmax(all_outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)
    true_labels = all_labels.numpy()

    # Compute AUROC using the probability of the positive class (index 1).
    auroc = roc_auc_score(true_labels, probs[:, 1])
    fold_auroc.append(auroc)

    # Compute per-class precision, recall and F1. Passing labels explicitly
    # guarantees a length-2 result in class order even if a class is absent
    # from a fold; zero_division=0 scores an undefined metric as 0 without
    # emitting a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, preds, labels=class_ids, average=None, zero_division=0
    )
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])

    print(f"Fold {fold+1} AUROC: {auroc:.4f}")
    print(classification_report(true_labels, preds, labels=class_ids, zero_division=0))

# Cross-fold summary: mean and standard deviation of every per-fold metric,
# printed as one per-class table and one single-row model summary.
def _mean_std(values):
    """Return ``(np.mean(values), np.std(values))``."""
    return np.mean(values), np.std(values)


def _fmt_pm(mean, std):
    """Format a mean/std pair as ``"mean ± std"`` with four decimals."""
    return f"{mean:.4f} ± {std:.4f}"


mean_auroc, std_auroc = _mean_std(fold_auroc)
mean_prec0, std_prec0 = _mean_std(fold_precision_class0)
mean_prec1, std_prec1 = _mean_std(fold_precision_class1)
mean_recall0, std_recall0 = _mean_std(fold_recall_class0)
mean_recall1, std_recall1 = _mean_std(fold_recall_class1)
mean_f1_0, std_f1_0 = _mean_std(fold_f1_class0)
mean_f1_1, std_f1_1 = _mean_std(fold_f1_class1)

# Per-class table: one row for class 0, one for class 1.
table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        _fmt_pm(mean_prec0, std_prec0),
        _fmt_pm(mean_prec1, std_prec1),
    ],
    "Recall (mean±std)": [
        _fmt_pm(mean_recall0, std_recall0),
        _fmt_pm(mean_recall1, std_recall1),
    ],
    "F1-Score (mean±std)": [
        _fmt_pm(mean_f1_0, std_f1_0),
        _fmt_pm(mean_f1_1, std_f1_1),
    ],
}
df_classes = pd.DataFrame(table_data_classes)
print("Class specific metrics across 5 folds")
print(df_classes.to_string(index=False))

# Single-row model summary. Sensitivity is the recall of class 1 and
# specificity is the recall of class 0; precision and F1 refer to class 1.
summary_data = {
    "Model": ["Modified3DDenseNet"],
    "AUROC": [_fmt_pm(mean_auroc, std_auroc)],
    "Sensitivity": [_fmt_pm(mean_recall1, std_recall1)],
    "Specificity": [_fmt_pm(mean_recall0, std_recall0)],
    "Precision": [_fmt_pm(mean_prec1, std_prec1)],
    "F1-Score": [_fmt_pm(mean_f1_1, std_f1_1)],
}
df_summary = pd.DataFrame(summary_data)
print("Ensembl summary")
print(df_summary.to_string(index=False))

  from .autonotebook import tqdm as notebook_tqdm


Evaluating fold 1/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 367/367 [02:06<00:00,  2.89it/s]
Fold 1 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:06<00:00,  7.39it/s]


Fold 1 AUROC: 0.6037
              precision    recall  f1-score   support

           0       0.76      0.90      0.83       270
           1       0.44      0.22      0.29        97

    accuracy                           0.72       367
   macro avg       0.60      0.56      0.56       367
weighted avg       0.68      0.72      0.68       367

Evaluating fold 2/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 403/403 [02:36<00:00,  2.58it/s]
Fold 2 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 51/51 [00:03<00:00, 15.35it/s]


Fold 2 AUROC: 0.5843
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       333
           1       0.34      0.33      0.34        70

    accuracy                           0.77       403
   macro avg       0.60      0.60      0.60       403
weighted avg       0.77      0.77      0.77       403

Evaluating fold 3/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 324/324 [02:11<00:00,  2.47it/s]
Fold 3 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:02<00:00, 13.99it/s]


Fold 3 AUROC: 0.7037
              precision    recall  f1-score   support

           0       0.87      0.75      0.80       260
           1       0.34      0.53      0.42        64

    accuracy                           0.71       324
   macro avg       0.61      0.64      0.61       324
weighted avg       0.76      0.71      0.73       324

Evaluating fold 4/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [02:30<00:00,  2.72it/s]
Fold 4 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:03<00:00, 15.81it/s]


Fold 4 AUROC: 0.6408
              precision    recall  f1-score   support

           0       0.88      0.85      0.86       337
           1       0.40      0.44      0.42        73

    accuracy                           0.78       410
   macro avg       0.64      0.65      0.64       410
weighted avg       0.79      0.78      0.78       410

Evaluating fold 5/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 306/306 [01:42<00:00,  3.00it/s]
Fold 5 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:03<00:00, 12.52it/s]


Fold 5 AUROC: 0.6533
              precision    recall  f1-score   support

           0       0.87      0.91      0.89       261
           1       0.31      0.24      0.28        45

    accuracy                           0.81       306
   macro avg       0.59      0.58      0.58       306
weighted avg       0.79      0.81      0.80       306

Class specific metrics across 5 folds
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.8477 ± 0.0433   0.8561 ± 0.0566     0.8498 ± 0.0310
1 (Positive)      0.3667 ± 0.0439   0.3518 ± 0.1184     0.3466 ± 0.0604
Ensembl summary
             Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
Modified3DDenseNet 0.6372 ± 0.0415 0.3518 ± 0.1184 0.8561 ± 0.0566 0.3667 ± 0.0439 0.3466 ± 0.0604
