In [3]:
import os
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import sys

logdir = "../../05mortality_analysis/MedMamba/logs/medmamba3d-v2" 
data_path = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/6multiplied"
data_csv = "/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/data/v2_analysis/01gather_data/mortality_metadata.csv"

data = pd.read_csv(data_csv, index_col=0)
data['path'] = f"{data_path}/" + data['Patient'] + '/' + data['filename'] + '.nii.gz'
data = data[['path', 'mortality_12m', 'Patient']]

# define the dataset (same as training)
class CTScanDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        sample = self.data[idx]
        image_path, label = sample['image'], sample['label']
        image = nib.load(image_path).get_fdata()
        image = np.clip(image, -175, 250)
        image = (image + 175) / 425.0
        image = np.expand_dims(image, axis=0)  # (1, H, W, D)
        image = torch.tensor(image, dtype=torch.float32)
        image = F.interpolate(image.unsqueeze(0), size=(96, 96, 96), mode='trilinear', align_corners=False).squeeze(0)
        return image, torch.tensor(label, dtype=torch.long)

# import MedMamba3D model
sys.path.append("/projects/b1038/Pulmonary/ksenkow/CLAD_serial_CT/code/mamba_env/MedMamba")
from MedMamba3D import VSSM3D as MedMamba3D

# define the model wrapper
class MortalityModel(nn.Module):
    def __init__(self):
        super(MortalityModel, self).__init__()
        self.model = MedMamba3D(in_chans=1, num_classes=2)
    def forward(self, x):
        return self.model(x)

patient_mortality = data.groupby('Patient')['mortality_12m'].max()
k_folds = 5
stratified_kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
patient_to_idx = defaultdict(list)
for idx, patient in enumerate(data["Patient"].values):
    patient_to_idx[patient].append(idx)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

fold_auroc = []
fold_precision_class0 = []
fold_precision_class1 = []
fold_recall_class0 = []
fold_recall_class1 = []
fold_f1_class0 = []
fold_f1_class1 = []

# evaluate each fold - testset
for fold, (_, test_patient_idx) in enumerate(stratified_kfold.split(np.zeros(len(patient_mortality)), patient_mortality)):
    print(f"evaluating fold {fold+1}/{k_folds}")
    test_patients = patient_mortality.index[test_patient_idx]
    test_data = data[data['Patient'].isin(test_patients)]
    test_samples = [{"image": row["path"], "label": row["mortality_12m"]} for _, row in test_data.iterrows()]
    
    test_ds = CTScanDataset(test_samples)
    test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)
    
    # load model checkpoint
    model = MortalityModel().to(device)
    model_path = os.path.join(logdir, f"best_model_fold_{fold}.pth")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    
    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"Fold {fold+1} Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)  # logits
            all_outputs.append(outputs.cpu())
            all_labels.append(labels.cpu())
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    probs = torch.softmax(all_outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)
    true_labels = all_labels.numpy()
    
    # compute AUROC (for positive class)
    auroc = roc_auc_score(true_labels, probs[:, 1])
    fold_auroc.append(auroc)
    
    # compute per class precision, recall, and F1
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average=None)
    fold_precision_class0.append(precision[0])
    fold_precision_class1.append(precision[1])
    fold_recall_class0.append(recall[0])
    fold_recall_class1.append(recall[1])
    fold_f1_class0.append(f1[0])
    fold_f1_class1.append(f1[1])
    
    print(f"Fold {fold+1} AUROC: {auroc:.4f}")
    print(classification_report(true_labels, preds))

# ensemble metrics (mean ± std)
mean_auroc, std_auroc = np.mean(fold_auroc), np.std(fold_auroc)
mean_prec0, std_prec0 = np.mean(fold_precision_class0), np.std(fold_precision_class0)
mean_prec1, std_prec1 = np.mean(fold_precision_class1), np.std(fold_precision_class1)
mean_recall0, std_recall0 = np.mean(fold_recall_class0), np.std(fold_recall_class0)
mean_recall1, std_recall1 = np.mean(fold_recall_class1), np.std(fold_recall_class1)
mean_f1_0, std_f1_0 = np.mean(fold_f1_class0), np.std(fold_f1_class0)
mean_f1_1, std_f1_1 = np.mean(fold_f1_class1), np.std(fold_f1_class1)

table_data_classes = {
    "Class": ["0 (Negative)", "1 (Positive)"],
    "Precision (mean±std)": [
        f"{mean_prec0:.4f} ± {std_prec0:.4f}",
        f"{mean_prec1:.4f} ± {std_prec1:.4f}"
    ],
    "Recall (mean±std)": [
        f"{mean_recall0:.4f} ± {std_recall0:.4f}",
        f"{mean_recall1:.4f} ± {std_recall1:.4f}"
    ],
    "F1-Score (mean±std)": [
        f"{mean_f1_0:.4f} ± {std_f1_0:.4f}",
        f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"
    ]
}
df_classes = pd.DataFrame(table_data_classes)
print("Class specific metrics across 5 folds")
print(df_classes.to_string(index=False))

# ensembl summary
summary_data = {
    "Model": ["MedMamba3D"],
    "AUROC": [f"{mean_auroc:.4f} ± {std_auroc:.4f}"],
    "Sensitivity": [f"{mean_recall1:.4f} ± {std_recall1:.4f}"],  # Recall for class 1
    "Specificity": [f"{mean_recall0:.4f} ± {std_recall0:.4f}"],  # Recall for class 0
    "Precision": [f"{mean_prec1:.4f} ± {std_prec1:.4f}"],         # Precision for class 1
    "F1-Score": [f"{mean_f1_1:.4f} ± {std_f1_1:.4f}"]
}
df_summary = pd.DataFrame(summary_data)
print("Ensembl summary")
print(df_summary.to_string(index=False))

evaluating fold 1/5


Fold 1 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [01:01<00:00,  1.34s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1 AUROC: 0.4990
              precision    recall  f1-score   support

           0       0.74      1.00      0.85       270
           1       0.00      0.00      0.00        97

    accuracy                           0.74       367
   macro avg       0.37      0.50      0.42       367
weighted avg       0.54      0.74      0.62       367

evaluating fold 2/5


Fold 2 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 51/51 [01:24<00:00,  1.66s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2 AUROC: 0.4493
              precision    recall  f1-score   support

           0       0.83      1.00      0.90       333
           1       0.00      0.00      0.00        70

    accuracy                           0.83       403
   macro avg       0.41      0.50      0.45       403
weighted avg       0.68      0.83      0.75       403

evaluating fold 3/5


Fold 3 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [01:08<00:00,  1.67s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3 AUROC: 0.5733
              precision    recall  f1-score   support

           0       0.80      1.00      0.89       260
           1       0.00      0.00      0.00        64

    accuracy                           0.80       324
   macro avg       0.40      0.50      0.45       324
weighted avg       0.64      0.80      0.71       324

evaluating fold 4/5


Fold 4 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [01:12<00:00,  1.40s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4 AUROC: 0.6514
              precision    recall  f1-score   support

           0       0.82      1.00      0.90       337
           1       0.00      0.00      0.00        73

    accuracy                           0.82       410
   macro avg       0.41      0.50      0.45       410
weighted avg       0.68      0.82      0.74       410

evaluating fold 5/5


Fold 5 Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:58<00:00,  1.50s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5 AUROC: 0.5436
              precision    recall  f1-score   support

           0       0.85      1.00      0.92       261
           1       0.00      0.00      0.00        45

    accuracy                           0.85       306
   macro avg       0.43      0.50      0.46       306
weighted avg       0.73      0.85      0.79       306

Class specific metrics across 5 folds
       Class Precision (mean±std) Recall (mean±std) F1-Score (mean±std)
0 (Negative)      0.8079 ± 0.0395   1.0000 ± 0.0000     0.8932 ± 0.0247
1 (Positive)      0.0000 ± 0.0000   0.0000 ± 0.0000     0.0000 ± 0.0000
Ensembl summary
     Model           AUROC     Sensitivity     Specificity       Precision        F1-Score
MedMamba3D 0.5433 ± 0.0684 0.0000 ± 0.0000 1.0000 ± 0.0000 0.0000 ± 0.0000 0.0000 ± 0.0000
