# Geo-Radio Classification

This notebook is used to train the classification model on the ASOCA Dataset.
The notebook assumes the previous scripts have been run, which save the following artifacts:
- a segmentation model trained using anatomix, currently pointing to `saved_models/segmentation/anatomix_trained_MM-WHS.pth`
- a registration model trained using Atlas-ISTN, currently pointing to `output/mm-whs/full-stn/train/model/stn.pt`
- an atlas labelmap created using Atlas-ISTN, currently pointing to `output/mm-whs/full-stn/train/model/atlas_labelmap_final.nii.gz`

If you have run `anatomix-fine-tuning.py` and `atlas-istn-anatomix.py`, these files will already have been generated for you.

The ASOCA images to be processed (together with their segmentation labelmaps) are specified by the config CSV file `data/config/inference.csv`.

# Imports and Global Config

In [None]:
import sys
sys.path.append("anatomix")

import torch
from monai.data import ThreadDataLoader, CacheDataset
from monai.transforms import Lambdad
from nets.stn import FullSTN3D
from img.datasets import ImageSegmentationOneHotDataset

# Trade speed for reproducibility: request deterministic cuDNN kernels and
# disable the cuDNN autotuner (which may choose different algorithms per run).
cudnn_backend = torch.backends.cudnn
cudnn_backend.deterministic = True
cudnn_backend.benchmark = False

In [None]:
# Isotropic voxel spacing passed to the dataset loader and to SimpleITK
# (presumably millimetres - confirm against the dataset code).
spacing = (2.0,) * 3
# Cubic crop size; also the STN input size.
crop_size = (96,) * 3
# Number of one-hot label channels (channel 0 plus the 7 structures in class_mapping).
num_classes = 8

# Class Mapping 
*(should match "class_mapping" in `data/config/config.json`)*

In [None]:
# Label index -> cardiac substructure name. Index 0 (presumably background)
# is intentionally absent; indices start at 1.
class_mapping = dict(
    enumerate(
        [
            "myocardium",
            "left atrium",
            "left ventricle",
            "right atrium",
            "right ventricle",
            "aorta",
            "pulmonary artery",
        ],
        start=1,
    )
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Load STN

In [None]:
# Load the trained Atlas-ISTN spatial transformer network (STN).
# Its input is the one-hot subject labelmap concatenated with the one-hot atlas,
# each with channel 0 dropped, hence 2 * (num_classes - 1) input channels.
stn_path = "output/mm-whs/full-stn/train/model/stn.pt"
stn = FullSTN3D(input_size=crop_size, input_channels=2*(num_classes-1), device=device).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when this notebook falls back to the CPU.
stn.load_state_dict(torch.load(stn_path, map_location=device))
stn.eval()

# Load dataset

In [None]:
# Identity "normaliser": image intensities are passed through unchanged.
passthrough_normalizer = Lambdad(keys=["image"], func=lambda x: x)
dataset_test_base = ImageSegmentationOneHotDataset(
    "data/config/inference.csv",
    num_classes,
    crop_size,
    spacing,
    normalizer=passthrough_normalizer,
    binarize=0,
    augmentation=False,
)
# Cache every item (cache_rate=1.0) so repeated passes reuse the cached samples.
dataset_test = CacheDataset(
    data=dataset_test_base,
    transform=None,
    cache_rate=1.0,
    num_workers=4,
)
# Expose the underlying dataset's get_sample helper on the cached wrapper.
dataset_test.get_sample = dataset_test_base.get_sample
dataloader_test = ThreadDataLoader(dataset_test, batch_size=1, shuffle=False)

# Extract radiomic and deformation data
This loop will iterate over each CT test volume in `inference.csv`, creating features per volume in a list.
For each test volume, we store the following features for downstream classification:
- **label**, either "Diseased" or "Normal" (obtained by parsing the file name).
- **struct_disp**, a dictionary keyed per substructure storing the displacement vectors (from the registration to the atlas) of the voxels inside that substructure.
- **radiomics**, a dictionary keyed per substructure storing the respective radiomics features.

The file name (**fname**) and the full displacement field (**full_disp**) are stored as well.

In [None]:
import torch
import numpy as np
import SimpleITK as sitk
from radiomics import featureextractor
from tqdm import tqdm


# PyRadiomics extractor configuration: fixed bin width of 25, no resampling.
radiomics_settings = {
    'binWidth': 25,
    'resampledPixelSpacing': None,
    'interpolator': sitk.sitkLinear,
    'verbose': False
}
extractor = featureextractor.RadiomicsFeatureExtractor(**radiomics_settings)

# Load the atlas labelmap and turn it into a one-hot tensor [1, C, D, H, W] on `device`.
atlas_label_itk = sitk.ReadImage("output/mm-whs/full-stn/train/model/atlas_labelmap_final.nii.gz")

arr_lab = sitk.GetArrayFromImage(atlas_label_itk)
if arr_lab.ndim == 4:  # already one-hot in last dim
    atlas_label = (
        torch.from_numpy(arr_lab)
        .permute(3, 0, 1, 2)
        .unsqueeze(0)
        .float()
        .to(device)
    )
else:  # integer labelmap -> one-hot with num_classes channels
    labels_int = torch.from_numpy(arr_lab).long()
    one_hot    = torch.nn.functional.one_hot(labels_int, num_classes=num_classes)
    atlas_label = one_hot.permute(3, 0, 1, 2).unsqueeze(0).float().to(device)

# Match the atlas batch dimension to the dataloader's batch size.
example_batch = next(iter(dataloader_test))
batch_size = example_batch["image"].size(0)
atlas_label = atlas_label.repeat(batch_size, 1, 1, 1, 1)

# Precompute identity grid once (for displacement = T - grid)
identity_grid = stn.grid.unsqueeze(0)
identity_grid = stn.move_grid_dims(identity_grid)
identity_grid = identity_grid.repeat(batch_size, 1, 1, 1, 1).to(device)

# Ordered radiomic feature names shared by every subject and structure.
# The list is fixed by the first successful extraction, so all feature vectors
# are aligned and equally long. It is defined up front so that an empty mask
# on the very first structure no longer raises NameError.
SEMANTIC_FEATURES = None

subjects = []

for batch in tqdm(dataloader_test, desc="extracting features"):
    image_tensor = batch["image"].to(device)
    label_onehot  = batch["labelmap"].to(device)
    fname         = batch["fname"][0]

    # Class label parsed from the file name.
    img_type = "Diseased" if "Diseased" in fname else "Normal"

    # Register the subject labelmap to the atlas (channel 0 dropped from both)
    # and read back the full sampling grid T. Inference only: no autograd graph.
    src = label_onehot[:, 1:, ...]
    tgt = atlas_label[:, 1:, ...]
    with torch.no_grad():
        _ = stn(torch.cat((src, tgt), dim=1))
        T = stn.get_T()
    full_disp = T - identity_grid
    disp_np = full_disp[0].detach().cpu().numpy()

    # Per-structure displacement: the displacement vectors of all voxels inside
    # each structure's mask (assumes disp_np is [D, H, W, 3] - TODO confirm).
    struct_disp = {}
    for L in class_mapping:
        maskL = label_onehot[0, L].bool().cpu().numpy()
        struct_disp[L] = disp_np[maskL]

    # Per-structure radiomics
    img_np = image_tensor[0, 0].detach().cpu().numpy()
    sitk_img = sitk.GetImageFromArray(img_np)
    sitk_img.SetSpacing(spacing)

    radiomics = {}
    for L in class_mapping:
        mask_np = label_onehot[0, L].cpu().numpy().astype(np.uint8)
        if mask_np.sum() == 0:
            # No voxels: an all-NaN vector is filled in after the loop, once the
            # feature list (and therefore its length) is known.
            radiomics[L] = None
            continue
        sitk_mask = sitk.GetImageFromArray(mask_np)
        sitk_mask.CopyInformation(sitk_img)
        result = extractor.execute(sitk_img, sitk_mask)
        if SEMANTIC_FEATURES is None:
            # Keep only original features (computed on the unfiltered CT image):
            # - First-order statistics
            # - Shape descriptors (3D)
            # - GLCM (Gray Level Co-occurrence Matrix)
            # - GLRLM (Gray Level Run Length Matrix)
            # - GLSZM (Gray Level Size Zone Matrix)
            # - NGTDM (Neighbouring Gray Tone Difference Matrix)
            # - GLDM (Gray Level Dependence Matrix)
            SEMANTIC_FEATURES = sorted(k for k in result.keys() if k.startswith("original_"))
        radiomics[L] = np.array(
            [float(result.get(feat_name, float("nan"))) for feat_name in SEMANTIC_FEATURES],
            dtype=float,
        )

    # Collect everything into a single dict for this subject
    subjects.append({
        "fname":     fname,
        "label":     img_type,
        "full_disp": disp_np,       # full displacement field for this subject
        "struct_disp": struct_disp, # dict L->(n_vox_L, 3)
        "radiomics":   radiomics    # dict L->(len(SEMANTIC_FEATURES),)
    })

if SEMANTIC_FEATURES is None:
    raise RuntimeError("Every structure mask was empty; no radiomic features could be extracted.")

# Structures with empty masks get an all-NaN vector of the common length.
for subj in subjects:
    for L, vec in subj["radiomics"].items():
        if vec is None:
            subj["radiomics"][L] = np.full((len(SEMANTIC_FEATURES),), np.nan, dtype=float)

# After this, `subjects` is a list of length N (test cases),
# and each `subjects[i]` contains all the deformation + radiomics for that case.

# After this loop, `subjects` is a list of length N (test cases),
# and each `subjects[i]` contains all the deformation + radiomics for that case.

# MLP classification
Here, we use `optuna` for hyperparameter optimisation, searching for the MLP classifier that best distinguishes diseased from normal cases.
We run 5-fold stratified cross-validation repeated over 3 seeds, to obtain more reliable estimates from this small dataset.

In [None]:
import torch
import numpy as np
import pandas as pd
import optuna
import random
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from torchvision.ops import MLP
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Maximum number of deformation singular values kept per structure.
MAX_DEF_PC = 3

rows = []
for subj in subjects:
    row = {}

    # Deformation features: the top MAX_DEF_PC singular values of each
    # structure's (n_vox, 3) displacement matrix, zero-padded. The matrix is
    # neither centred nor normalised, so the values reflect displacement
    # magnitude and structure size as well as directional spread.
    for L, name in class_mapping.items():
        disp_vox = subj["struct_disp"][L]
        sing_vals = np.zeros(MAX_DEF_PC, dtype=float)  # no voxels -> all zeros
        if disp_vox.shape[0] > 0:
            # Only singular values are needed; compute_uv=False skips the
            # (n_vox, 3) U matrix, which the original computed and discarded.
            s = np.linalg.svd(disp_vox, compute_uv=False)
            n_comp = min(len(s), MAX_DEF_PC)
            sing_vals[:n_comp] = s[:n_comp]

        # Columns def_pc1_<name> ... def_pc{MAX_DEF_PC}_<name>
        for pc_idx in range(MAX_DEF_PC):
            row[f"def_pc{pc_idx+1}_{name}"] = float(sing_vals[pc_idx])

    # Radiomic features per structure (all-NaN for structures with empty masks).
    for L, name in class_mapping.items():
        rad_vec = subj["radiomics"][L]  # shape = (len(SEMANTIC_FEATURES),)
        for idx, feat_name in enumerate(SEMANTIC_FEATURES):
            row[f"{feat_name}_{name}"] = float(rad_vec[idx])

    row["label"] = subj["label"]  # "Normal" or "Diseased"
    rows.append(row)

df_full = pd.DataFrame.from_records(rows)
df = df_full  # alias kept for any cell that still refers to `df`
print(df_full.isna().any()[lambda x: x])  # columns that contain NaNs
print(df_full.head(4))
print("Shape of df_full:", df_full.shape)

y_global = np.array(df_full["label"])
N = len(y_global)

seeds = [10, 101, 202]

def get_folds(seed):
    """Return the 5 stratified (train_idx, val_idx) splits of the N subjects for `seed`."""
    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    # Stratification only needs the labels; the dummy X just fixes the sample count.
    placeholder_X = np.zeros((N, 1))
    return list(splitter.split(placeholder_X, y_global))

def objective(trial):
    """Optuna objective for the deformation + radiomics MLP classifier.

    Samples MLP hyperparameters and the number of deformation singular values
    per structure, then runs 5-fold stratified CV for each seed in `seeds`.
    In every fold the features are standardised on the training split only.
    Any remaining non-finite values (e.g. NaN radiomics for structures with
    empty masks) are set to 0, i.e. the training-fold mean after scaling.

    Returns:
        Mean over seeds of the mean validation accuracy across folds. The
        per-seed means, per-fold metrics and sampled hyperparameters are
        stored as trial user attributes.
    """
    hidden_units = trial.suggest_int("hidden_units", 8, 512, step=8)
    lr           = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    dropout      = trial.suggest_float("dropout", 0.0, 0.5)
    num_layers   = trial.suggest_int("num_layers", 1, 12)
    num_epochs   = trial.suggest_int("num_epochs", 100, 400, step=25)

    # Number of deformation singular values used per structure in this trial.
    def_pc_amt = trial.suggest_int("def_pc_amt", 1, MAX_DEF_PC)

    selected_cols = []
    for L, name in class_mapping.items():
        for i in range(def_pc_amt):
            selected_cols.append(f"def_pc{i+1}_{name}")
        for feat_name in SEMANTIC_FEATURES:
            selected_cols.append(f"{feat_name}_{name}")

    # (N_samples, def_pc_amt*n_struct + F_sem*n_struct) float32 feature matrix.
    X_np = df_full[selected_cols].to_numpy(dtype=np.float32)
    y_np = y_global.copy()  # strings "Normal"/"Diseased"
    N_samples = len(y_np)
    # Binary targets: 1.0 = Diseased, 0.0 = Normal.
    y_all = torch.from_numpy((y_np == "Diseased").astype(np.float32)).to(device)

    seed_means = []
    per_seed_fold_metrics = []

    for s in seeds:
        random.seed(s)
        np.random.seed(s)
        torch.manual_seed(s)

        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=s)
        # StratifiedKFold only uses y; the dummy X just supplies the sample count.
        folds = list(skf.split(np.zeros((N_samples, 1)), y_np))

        fold_accuracy_list = []
        fold_metrics_list = []

        for train_idx, val_idx in folds:
            y_train = y_all[train_idx]
            y_val   = y_all[val_idx]

            # Fit the scaler on the training split only (no leakage). The scaler
            # ignores NaNs when fitting but passes them through, and a single NaN
            # input makes the loss and every prediction NaN. Replace non-finite
            # values with 0, which is the training-fold mean after scaling.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_np[train_idx])
            X_val_scaled   = scaler.transform(X_np[val_idx])
            X_train_scaled = np.where(np.isfinite(X_train_scaled), X_train_scaled, 0.0)
            X_val_scaled   = np.where(np.isfinite(X_val_scaled), X_val_scaled, 0.0)

            X_train = torch.from_numpy(X_train_scaled).float().to(device)
            X_val   = torch.from_numpy(X_val_scaled).float().to(device)

            # MLP: num_layers hidden layers of width hidden_units, then one logit.
            layers = [hidden_units] * num_layers + [1]
            model = MLP(X_np.shape[1], layers, dropout=dropout).to(device)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            criterion = torch.nn.BCEWithLogitsLoss()

            # Full-batch training loop
            model.train()
            for _epoch in range(num_epochs):
                logits = model(X_train).squeeze(1)
                loss = criterion(logits, y_train)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Validation (probability threshold 0.5)
            model.eval()
            with torch.no_grad():
                val_probs = torch.sigmoid(model(X_val).squeeze(1)).cpu().numpy()
            y_pred = (val_probs >= 0.5).astype(int)
            y_true = y_val.cpu().numpy().astype(int)

            acc  = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred, zero_division=0)
            rec  = recall_score(y_true, y_pred, zero_division=0)
            f1   = f1_score(y_true, y_pred, zero_division=0)
            # labels=[0, 1] keeps the matrix 2x2 even when a fold contains or
            # predicts only one class, so the 4-way unpack cannot fail.
            cm   = confusion_matrix(y_true, y_pred, labels=[0, 1])
            tn, fp, fn, tp = cm.ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

            fold_accuracy_list.append(acc)
            fold_metrics_list.append({
                "accuracy":        acc,
                "precision":       prec,
                "recall":          rec,
                "f1":              f1,
                "sensitivity":     sensitivity,
                "specificity":     specificity,
                "confusion_matrix": cm
            })

        seed_means.append(np.mean(fold_accuracy_list))
        per_seed_fold_metrics.append(fold_metrics_list)

    trial.set_user_attr("seed_means", seed_means)
    trial.set_user_attr("per_seed_fold_metrics", per_seed_fold_metrics)
    trial.set_user_attr("hyperparams", {
        "hidden_units": hidden_units,
        "lr":            lr,
        "dropout":       dropout,
        "num_layers":    num_layers,
        "num_epochs":    num_epochs,
        "def_pc_amt":    def_pc_amt
    })

    return np.mean(seed_means)

# Run Optuna
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=500, show_progress_bar=True)

# Pick the best trial by mean seed accuracy, tie-breaking on the lowest std of
# the per-seed means. Only COMPLETE trials are considered. Optuna marks a trial
# FAIL (value None, no user attributes) e.g. when the objective returns NaN,
# and such a trial would crash the comparison below.
trials = study.trials
completed_trials = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]
if not completed_trials:
    raise RuntimeError("No Optuna trial completed; cannot select a best trial.")
best_trial = max(completed_trials, key=lambda t: (t.value, -np.std(t.user_attrs["seed_means"])))

best_score       = best_trial.value
best_std_seed    = np.std(best_trial.user_attrs["seed_means"])
best_hyperparams = best_trial.user_attrs["hyperparams"]
best_fold_info   = {
    'seeds': seeds,
    'seed_means': best_trial.user_attrs["seed_means"],
    'per_seed_fold_metrics': best_trial.user_attrs["per_seed_fold_metrics"]
}

# Aggregate all metrics over seeds & folds
all_accuracies    = []
all_precisions    = []
all_recalls       = []
all_f1s           = []
all_sensitivities = []
all_specificities = []

for seed_metrics in best_fold_info['per_seed_fold_metrics']:
    for m in seed_metrics:
        all_accuracies.append(m['accuracy'])
        all_precisions.append(m['precision'])
        all_recalls.append(m['recall'])
        all_f1s.append(m['f1'])
        all_sensitivities.append(m['sensitivity'])
        all_specificities.append(m['specificity'])

acc_mean,  acc_std  = np.mean(all_accuracies),    np.std(all_accuracies)
prec_mean, prec_std = np.mean(all_precisions),   np.std(all_precisions)
rec_mean,  rec_std  = np.mean(all_recalls),      np.std(all_recalls)
f1_mean,   f1_std   = np.mean(all_f1s),          np.std(all_f1s)
sens_mean, sens_std = np.mean(all_sensitivities), np.std(all_sensitivities)
spec_mean, spec_std = np.mean(all_specificities), np.std(all_specificities)

print("\n=== Best Hyperparameters (by avg-seed, tie-break on lowest std) ===")
for k, v in best_hyperparams.items():
    print(f"{k:<12}: {v}")
print(f"Avg of seed-means = {best_score:.4f}")
print(f"Std of seed-means = {best_std_seed:.4f}\n")

print("=== Aggregate Metrics over all seeds & folds (mean ± std) ===")
print(f"Accuracy    : {acc_mean:.4f} ± {acc_std:.4f}")
print(f"Precision   : {prec_mean:.4f} ± {prec_std:.4f}")
print(f"Recall      : {rec_mean:.4f} ± {rec_std:.4f}")
print(f"F1 Score    : {f1_mean:.4f} ± {f1_std:.4f}")
print(f"Sensitivity : {sens_mean:.4f} ± {sens_std:.4f}")
print(f"Specificity : {spec_mean:.4f} ± {spec_std:.4f}\n")

for seed_idx, s in enumerate(best_fold_info['seeds']):
    print(f"--- Seed {s} (mean CV acc = {best_fold_info['seed_means'][seed_idx]:.4f}) ---")
    for fold_idx, m in enumerate(best_fold_info['per_seed_fold_metrics'][seed_idx], start=1):
        print(f"Fold {fold_idx}:")
        print(f"  Accuracy    = {m['accuracy']:.4f}")
        print(f"  Precision   = {m['precision']:.4f}")
        print(f"  Recall      = {m['recall']:.4f}")
        print(f"  F1 Score    = {m['f1']:.4f}")
        print(f"  Sensitivity = {m['sensitivity']:.4f}")
        print(f"  Specificity = {m['specificity']:.4f}")
        print(f"  Confusion Matrix:\n{m['confusion_matrix']}\n")


# ResNet Baseline
We provide a ResNet-50 model as an image-only baseline to compare our model against.

In [None]:
import torch
import numpy as np
import random
import optuna
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix
)
from monai.networks.nets import resnet
from torch.utils.data import DataLoader, SubsetRandomSampler

# The baseline reuses the cached test dataset built earlier in the notebook.
dataset = dataset_test
all_indices = list(range(len(dataset)))

# Keep only volumes whose file name states the class (1 = Diseased,
# 0 = Normal); any other volume is left out of the baseline.
valid_indices = []
labels = []
for i in all_indices:
    fname = dataset[i]['fname']
    if "Diseased" in fname:
        target = 1
    elif "Normal" in fname:
        target = 0
    else:
        continue
    valid_indices.append(i)
    labels.append(target)
labels = np.array(labels)
N = len(valid_indices)

seeds = [10, 101, 202]

def make_resnet50_3d():
    """Build a freshly initialised 3-D ResNet-50 for single-channel volumes.

    The network has a single output unit, i.e. one logit per volume.
    """
    net_kwargs = dict(spatial_dims=3, n_input_channels=1, num_classes=1)
    return resnet.resnet50(**net_kwargs)

def get_folds(seed):
    """Return the 5 stratified (train_idx, val_idx) splits of the N labelled volumes.

    The indices are positions into `valid_indices` / `labels`, not dataset indices.
    """
    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    # Only the labels drive stratification; the dummy X just fixes the sample count.
    placeholder_X = np.zeros((N, 1))
    return list(splitter.split(placeholder_X, labels))

def objective(trial):
    """Optuna objective for the image-only 3-D ResNet-50 baseline.

    Samples the learning rate, weight decay and number of epochs. For each seed
    in `seeds`, trains a fresh network on every stratified CV fold (batch size
    1) and evaluates it on the held-out fold with a 0.5 probability threshold.

    Returns:
        Mean over seeds of the mean validation accuracy across folds. The
        per-seed means, per-fold metrics and sampled hyperparameters are
        stored as trial user attributes.
    """
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    num_epochs = trial.suggest_categorical("num_epochs", [25, 50, 100])

    seed_means = []
    per_seed_fold_metrics = []

    for s in seeds:
        random.seed(s)
        np.random.seed(s)
        torch.manual_seed(s)

        folds = get_folds(s)
        fold_accs = []
        fold_metrics = []

        for train_idx, val_idx in folds:
            # Fold positions index into valid_indices; map them to dataset indices.
            train_dataset_indices = [valid_indices[i] for i in train_idx]
            val_dataset_indices = [valid_indices[i] for i in val_idx]

            train_loader = DataLoader(
                dataset,
                batch_size=1,
                sampler=SubsetRandomSampler(train_dataset_indices),
                num_workers=2,
                pin_memory=torch.cuda.is_available()
            )
            val_loader = DataLoader(
                dataset,
                batch_size=1,
                sampler=SubsetRandomSampler(val_dataset_indices),
                num_workers=2,
                pin_memory=torch.cuda.is_available()
            )

            model = make_resnet50_3d().to(device)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            criterion = torch.nn.BCEWithLogitsLoss()

            for epoch in range(1, num_epochs + 1):
                model.train()
                for batch in train_loader:
                    imgs = batch['image'].float().to(device)
                    fname = batch['fname'][0]
                    # Target parsed from the file name: 1.0 = Diseased, 0.0 = Normal.
                    lbl = torch.tensor(
                        [1.0 if "Diseased" in fname else 0.0],
                        device=device
                    ).view(-1)
                    optimizer.zero_grad()
                    logits = model(imgs).view(-1)
                    loss = criterion(logits, lbl)
                    loss.backward()
                    optimizer.step()

            model.eval()
            all_preds, all_trues = [], []
            with torch.no_grad():
                for batch in val_loader:
                    imgs = batch['image'].float().to(device)
                    fname = batch['fname'][0]
                    true_lbl = 1 if "Diseased" in fname else 0
                    logits = model(imgs).view(-1)
                    prob = torch.sigmoid(logits).cpu().item()
                    pred = 1 if prob >= 0.5 else 0
                    all_preds.append(pred)
                    all_trues.append(true_lbl)

            y_true = np.array(all_trues)
            y_pred = np.array(all_preds)
            acc  = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred, zero_division=0)
            rec  = recall_score(y_true, y_pred, zero_division=0)
            f1   = f1_score(y_true, y_pred, zero_division=0)
            # labels=[0, 1] keeps the matrix 2x2 even when a fold contains or
            # predicts only one class, so the 4-way unpack cannot fail.
            cm   = confusion_matrix(y_true, y_pred, labels=[0, 1])
            tn, fp, fn, tp = cm.ravel()
            sensitivity = tp/(tp+fn) if (tp+fn)>0 else 0.0
            specificity = tn/(tn+fp) if (tn+fp)>0 else 0.0

            fold_accs.append(acc)
            fold_metrics.append({
                'accuracy': acc,
                'precision': prec,
                'recall': rec,
                'f1': f1,
                'sensitivity': sensitivity,
                'specificity': specificity,
                'confusion_matrix': cm
            })

        seed_means.append(np.mean(fold_accs))
        per_seed_fold_metrics.append(fold_metrics)

    trial.set_user_attr("seed_means", seed_means)
    trial.set_user_attr("per_seed_fold_metrics", per_seed_fold_metrics)
    trial.set_user_attr("hyperparams", {
        'lr': lr,
        'weight_decay': weight_decay,
        'num_epochs': num_epochs
    })

    return np.mean(seed_means)

study = optuna.create_study(direction="maximize")
# We do 3 trials for brevity, but can be increased further if you have more compute/time.
# Else, you can reduce the number of seeds (from 3) or reduce the number of folds (from 5).
study.optimize(objective, n_trials=3)

# Pick the best trial by mean seed accuracy, tie-breaking on the lowest std of
# the per-seed means. Only COMPLETE trials are considered. A FAIL trial has
# value None and no user attributes, which would crash the comparison below.
trials = study.trials
completed_trials = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]
if not completed_trials:
    raise RuntimeError("No Optuna trial completed; cannot select a best trial.")
best_trial = max(completed_trials, key=lambda t: (t.value, -np.std(t.user_attrs["seed_means"])))

best_score       = best_trial.value
best_std_seed    = np.std(best_trial.user_attrs["seed_means"])
best_hyperparams = best_trial.user_attrs["hyperparams"]
best_fold_info   = {
    'seeds': seeds,
    'seed_means': best_trial.user_attrs["seed_means"],
    'per_seed_fold_metrics': best_trial.user_attrs["per_seed_fold_metrics"]
}

# Print results
print("\n=== Best Hyperparameters (by avg-seed, tie-break on lowest std) ===")
for k, v in best_hyperparams.items():
    print(f"{k:<12}: {v}")
print(f"Avg of seed-means = {best_score:.4f}")
print(f"Std of seed-means = {best_std_seed:.4f}\n")

# Aggregate all metrics over seeds & folds
all_accuracies = []
all_precisions = []
all_recalls    = []
all_f1s        = []
all_sensitivities = []
all_specificities = []

for seed_metrics in best_fold_info['per_seed_fold_metrics']:
    for m in seed_metrics:
        all_accuracies.append(m['accuracy'])
        all_precisions.append(m['precision'])
        all_recalls.append(m['recall'])
        all_f1s.append(m['f1'])
        all_sensitivities.append(m['sensitivity'])
        all_specificities.append(m['specificity'])

acc_mean,  acc_std  = np.mean(all_accuracies),    np.std(all_accuracies)
prec_mean, prec_std = np.mean(all_precisions),   np.std(all_precisions)
rec_mean,  rec_std  = np.mean(all_recalls),      np.std(all_recalls)
f1_mean,   f1_std   = np.mean(all_f1s),          np.std(all_f1s)
sens_mean, sens_std = np.mean(all_sensitivities), np.std(all_sensitivities)
spec_mean, spec_std = np.mean(all_specificities), np.std(all_specificities)

print("=== Aggregate Metrics over all seeds & folds (mean ± std) ===")
print(f"Accuracy    : {acc_mean:.4f} ± {acc_std:.4f}")
print(f"Precision   : {prec_mean:.4f} ± {prec_std:.4f}")
print(f"Recall      : {rec_mean:.4f} ± {rec_std:.4f}")
print(f"F1 Score    : {f1_mean:.4f} ± {f1_std:.4f}")
print(f"Sensitivity : {sens_mean:.4f} ± {sens_std:.4f}")
print(f"Specificity : {spec_mean:.4f} ± {spec_std:.4f}\n")

for seed_idx, s in enumerate(best_fold_info['seeds']):
    print(f"--- Seed {s} (mean CV acc = {best_fold_info['seed_means'][seed_idx]:.4f}) ---")
    for fold_idx, m in enumerate(best_fold_info['per_seed_fold_metrics'][seed_idx], start=1):
        print(f"Fold {fold_idx}:")
        print(f"  Accuracy    = {m['accuracy']:.4f}")
        print(f"  Precision   = {m['precision']:.4f}")
        print(f"  Recall      = {m['recall']:.4f}")
        print(f"  F1 Score    = {m['f1']:.4f}")
        print(f"  Sensitivity = {m['sensitivity']:.4f}")
        print(f"  Specificity = {m['specificity']:.4f}")
        print(f"  Confusion Matrix:\n{m['confusion_matrix']}\n")
