In [2]:
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from monai.data import DataLoader, CacheDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    Resized,
    ToTensord,
)
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import StratifiedKFold
from torchvision.models.video import R3D_18_Weights, r3d_18
from tqdm import tqdm

class Modified3DResNet18(nn.Module):
    """torchvision 3D ResNet-18 (``r3d_18``) adapted to single-channel volumes.

    Two layers of the stock network are replaced:

    * the stem ``Conv3d`` takes 1 input channel instead of 3, and
    * the final ``fc`` layer outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output logits (default 2).
        pretrained: When True (the default, as before), the backbone starts
            from torchvision's default Kinetics weights, and the new 1-channel
            stem is initialised from the pretrained RGB stem. Pass False to
            build a randomly initialised network without loading any weights,
            e.g. when a fine-tuned checkpoint is loaded straight afterwards.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        # Multi-weight API; the boolean ``pretrained=`` argument is deprecated.
        weights = R3D_18_Weights.DEFAULT if pretrained else None
        self.resnet = r3d_18(weights=weights)

        # Replace the 3-channel stem conv with a 1-channel conv of identical geometry.
        original_conv = self.resnet.stem[0]
        single_channel_conv = nn.Conv3d(
            1,
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=False,
        )
        if pretrained:
            # Reuse the pretrained stem instead of discarding it. Convolution is
            # linear in the input channels, so summing the RGB filters gives the
            # response the RGB conv would produce on the grey volume replicated
            # into 3 channels.
            with torch.no_grad():
                single_channel_conv.weight.copy_(
                    original_conv.weight.sum(dim=1, keepdim=True)
                )
        self.resnet.stem[0] = single_channel_conv

        # Resize the classification head to num_classes outputs.
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits from the adapted ResNet.

        ``x`` is expected as a ``(batch, 1, D, H, W)`` tensor. The stem
        ``Conv3d`` accepts exactly one input channel.
        """
        return self.resnet(x)

def load_data(labels_file, data_path, task):
    """Read the label table and return aligned scan paths, labels and patient ids.

    Rows whose ``task`` column is empty are dropped. Each scan path is assembled
    as ``<data_path>/<Patient>/<filename>.nii.gz``.

    Args:
        labels_file: CSV with ``Patient``, ``filename`` and ``task`` columns;
            its first column is used as the index.
        data_path: Root directory that contains one folder per patient.
        task: Name of the label column to use.

    Returns:
        ``(file_list, labels, patient_ids)`` as array-likes of equal length,
        with the labels cast to ``int``.
    """
    table = pd.read_csv(labels_file, index_col=0)
    table = table.loc[table[task].notna()]

    patient_dirs = table["Patient"].astype(str)
    scan_names = table["filename"].astype(str)
    scan_paths = f"{data_path}/" + patient_dirs + "/" + scan_names + ".nii.gz"

    return (
        scan_paths.values,
        table[task].astype(int).values,
        table["Patient"].values,
    )

def prepare_data(files, labels):
    """Pair each image path with its integer label as ``{"image", "label"}`` records.

    The resulting list of dicts is what the dictionary-keyed transforms and
    ``CacheDataset`` in this script consume.
    """
    return [dict(image=path, label=int(lab)) for path, lab in zip(files, labels)]


# --- Inputs ------------------------------------------------------------------
data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
labels_file = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"
# Directory holding one checkpoint per fold, named fold_<k>_best_model.pth.
logdir = "../../05mortality_analysis/ResNet/logs/ResNet-2/resnet18_mortality_12m_20250308_092243"

# --- Preprocessing -----------------------------------------------------------
# Intensity window + min-max rescale (not z-score standardization): values are
# clipped to [-175, 250] (presumably Hounsfield units) and mapped linearly onto
# [0, 1]. Every volume is then resized to 128 x 128 x 128.
# NOTE(review): must mirror the training-time pipeline exactly; confirm against
# the training script.
transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        Resized(keys=["image"], spatial_size=(128, 128, 128)),
        ToTensord(keys=["image"]),
    ]
)

# --- Compute device ----------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Patient-level stratified split. Folds are drawn over unique patients, so all
# scans of one patient end up on the same side of a split. A patient's stratum
# is the max of their scan labels, i.e. positive if any scan is positive
# (for 0/1 labels).
task = "mortality_12m"
file_list, labels, patient_ids = load_data(labels_file, data_path, task)

df_patients = pd.DataFrame({"patient": patient_ids, "label": labels})
patient_labels = df_patients.groupby("patient")["label"].max()
unique_patients, unique_labels = patient_labels.index.to_numpy(), patient_labels.values

# NOTE(review): fold k here is only held out if these splitter settings
# (5 folds, shuffled, seed 42) and the patient ordering match training; confirm.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)


# ---------------------------------------------------------------------------
# Per-fold evaluation: rebuild each fold's held-out patient set with the same
# splitter, load that fold's best checkpoint and score it on those scans.
# ---------------------------------------------------------------------------
n_folds = skf.get_n_splits()

fold_auroc = []
fold_precision_class0 = []
fold_precision_class1 = []
fold_recall_class0 = []
fold_recall_class1 = []
fold_f1_class0 = []
fold_f1_class1 = []

# evaluation loop
for fold, (train_val_idx, test_idx) in enumerate(skf.split(unique_patients, unique_labels)):
    print(f"evaluating fold {fold+1}/{n_folds}")
    # Release the previous fold's fully cached dataset, model and checkpoint
    # tensors before building this fold's. Otherwise two complete caches (and
    # two models/checkpoints on the device) are alive at the same time.
    test_loader = test_ds = model = checkpoint = None

    test_patients = unique_patients[test_idx]

    # Expand the held-out patients to all of their scans.
    test_mask = np.isin(patient_ids, test_patients)
    test_files = file_list[test_mask]
    test_labels = labels[test_mask]

    test_data = prepare_data(test_files, test_labels)
    test_ds = CacheDataset(data=test_data, transform=transforms, cache_rate=1.0, num_workers=8)
    test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)

    # Load this fold's checkpoint. Building the model may first load backbone
    # weights that load_state_dict then overwrites entirely.
    model = Modified3DResNet18(num_classes=2).to(device)
    model_path = os.path.join(logdir, f"fold_{fold}_best_model.pth")
    # SECURITY: weights_only=False fully unpickles the file. Only load
    # checkpoints you produced yourself.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Fold {fold+1} Testing"):
            inputs = batch["image"].to(device)
            outputs = model(inputs)
            all_outputs.append(outputs.cpu())
            # Labels are only used for CPU-side metrics; no device round-trip.
            all_labels.append(batch["label"])
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)

    probs = torch.softmax(all_outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)  # equivalent to thresholding P(class 1) at 0.5
    true_labels = all_labels.numpy()

    # AUROC using probability of positive class (index 1)
    auroc = roc_auc_score(true_labels, probs[:, 1])
    fold_auroc.append(auroc)

    # labels=[0, 1] pins both the length and the order of the per-class arrays.
    # Without it, a class missing from both preds and true_labels would shrink
    # the arrays, so index 1 raises IndexError, or index 0 silently becomes the
    # positive class. zero_division=0 reports 0 (instead of warning) when a
    # class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, preds, labels=[0, 1], average=None, zero_division=0
    )
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])

    print(f"Fold {fold+1} AUROC: {auroc:.4f}")
    print(classification_report(true_labels, preds, labels=[0, 1], zero_division=0))


# Cross-fold aggregate (mean ± std over folds). This is not an ensemble: each
# fold's model is scored only on its own held-out patients.
# NOTE: np.std defaults to ddof=0 (population std over the folds); pass ddof=1
# if the sample std is what should be reported.
mean_auroc, std_auroc = np.mean(fold_auroc), np.std(fold_auroc)
mean_prec0, std_prec0 = np.mean(fold_precision_class0), np.std(fold_precision_class0)
mean_prec1, std_prec1 = np.mean(fold_precision_class1), np.std(fold_precision_class1)
mean_recall0, std_recall0 = np.mean(fold_recall_class0), np.std(fold_recall_class0)
mean_recall1, std_recall1 = np.mean(fold_recall_class1), np.std(fold_recall_class1)
mean_f1_0, std_f1_0 = np.mean(fold_f1_class0), np.std(fold_f1_class0)
mean_f1_1, std_f1_1 = np.mean(fold_f1_class1), np.std(fold_f1_class1)


table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        f"{mean_prec0:.4f} ± {std_prec0:.4f}",
        f"{mean_prec1:.4f} ± {std_prec1:.4f}"
    ],
    "Recall (mean±std)": [
        f"{mean_recall0:.4f} ± {std_recall0:.4f}",
        f"{mean_recall1:.4f} ± {std_recall1:.4f}"
    ],
    "F1-Score (mean±std)": [
        f"{mean_f1_0:.4f} ± {std_f1_0:.4f}",
        f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"
    ]
}
df_classes = pd.DataFrame(table_data_classes)
print(f"Class specific metrics across {n_folds} folds")
print(df_classes.to_string(index=False))

# One-row summary. Sensitivity / specificity are the recalls of class 1 / class 0.
summary_data = {
    "Model": ["Modified3DResNet18"],
    "AUROC": [f"{mean_auroc:.4f} ± {std_auroc:.4f}"],
    "Sensitivity": [f"{mean_recall1:.4f} ± {std_recall1:.4f}"],  # Recall for class 1
    "Specificity": [f"{mean_recall0:.4f} ± {std_recall0:.4f}"],  # Recall for class 0
    "Precision": [f"{mean_prec1:.4f} ± {std_prec1:.4f}"],         # Precision for class 1
    "F1-Score": [f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"]
}
df_summary = pd.DataFrame(summary_data)
print(f"Summary across {n_folds} folds (mean ± std)")
print(df_summary.to_string(index=False))

evaluating fold 1/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 367/367 [02:03<00:00,  2.98it/s]
Fold 1 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:09<00:00,  4.63it/s]


Fold 1 AUROC: 0.6651
              precision    recall  f1-score   support

           0       0.76      0.94      0.84       270
           1       0.47      0.15      0.23        97

    accuracy                           0.73       367
   macro avg       0.61      0.55      0.53       367
weighted avg       0.68      0.73      0.68       367

evaluating fold 2/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 403/403 [02:37<00:00,  2.56it/s]
Fold 2 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 51/51 [00:06<00:00,  8.05it/s]


Fold 2 AUROC: 0.5792
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       333
           1       0.28      0.21      0.24        70

    accuracy                           0.77       403
   macro avg       0.56      0.55      0.55       403
weighted avg       0.75      0.77      0.76       403

evaluating fold 3/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 324/324 [02:11<00:00,  2.47it/s]
Fold 3 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:05<00:00,  7.91it/s]


Fold 3 AUROC: 0.7020
              precision    recall  f1-score   support

           0       0.85      0.73      0.79       260
           1       0.31      0.48      0.38        64

    accuracy                           0.69       324
   macro avg       0.58      0.61      0.58       324
weighted avg       0.75      0.69      0.71       324

evaluating fold 4/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [02:30<00:00,  2.73it/s]
Fold 4 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00,  8.63it/s]


Fold 4 AUROC: 0.6696
              precision    recall  f1-score   support

           0       0.87      0.71      0.78       337
           1       0.27      0.49      0.35        73

    accuracy                           0.67       410
   macro avg       0.57      0.60      0.56       410
weighted avg       0.76      0.67      0.70       410

evaluating fold 5/5


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 306/306 [01:43<00:00,  2.96it/s]
Fold 5 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:04<00:00,  8.06it/s]


Fold 5 AUROC: 0.6840
              precision    recall  f1-score   support

           0       0.87      0.95      0.91       261
           1       0.36      0.18      0.24        45

    accuracy                           0.83       306
   macro avg       0.62      0.56      0.57       306
weighted avg       0.80      0.83      0.81       306

Class specific metrics across 5 folds
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.8373 ± 0.0421   0.8426 ± 0.1010     0.8351 ± 0.0471
1 (Positive)      0.3388 ± 0.0726   0.3048 ± 0.1514     0.2882 ± 0.0618
Ensembl summary
             Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
Modified3DResNet18 0.6600 ± 0.0424 0.3048 ± 0.1514 0.8426 ± 0.1010 0.3388 ± 0.0726 0.2882 ± 0.0618
