### Setup/Imports

In [None]:
import os
import gc
import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from tqdm import tqdm
import pandas as pd
import scipy.io
from pipeline.paths import Directories, Files
from pipeline.utils import read_scans_agg_file
import scipy.stats
plt.rcParams['figure.dpi'] = 200

### Config

In [None]:
# Experiment identifiers; they are combined below into the working directory
# "<WORK_ROOT>/phase<PHASE>/DS<DATA_VERSION>".
PHASE = "7"
DATA_VERSION = "13"
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"
SPLIT_TO_ANALYZE = 'TEST'  # Options: 'TRAIN', 'VALIDATION', 'TEST'

# One dict per model to evaluate. Keys read by the loaders in this notebook:
#   'name'               - display name used in plot titles and result rows
#   'type'               - 'ensemble' loads versions "<root>_01", "<root>_02", ...;
#                          'stochastic' loads passthrough numbers 0..count-1 of one version
#   'domain'             - 'PROJ', 'FDK' or 'IMAG'; selects which result files are read
#   'model_version_root' - model version (or version prefix for ensembles)
#   'count'              - number of predictions (ensemble members / passthroughs) to load
# NOTE(review): 'Dropoout' in the names below looks like a typo for 'Dropout';
# it is left unchanged because the name is emitted verbatim in titles and results.
MODELS_TO_ANALYZE = [
    # {
    #     'name': 'SWAG Learning Rate 1e-2',
    #     'type': 'stochastic',
    #     'domain': 'FDK',
    #     'model_version_root': 'MK7_MCDROPOUT_15_pct_NEW_SWAG_lr1e-2',
    #     'count': 33,
    # },
    {
        'name': 'MC Dropoout 15%',
        'type': 'stochastic',
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_15_pct_NEW',
        'count': 50,
    },
    {
        'name': 'MC Dropoout 30%',
        'type': 'stochastic',
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_30_pct_NEW',
        'count': 50,
    },
    # {
    #     'name': 'MC Dropoout 50%',
    #     'type': 'stochastic',
    #     'domain': 'FDK',
    #     'model_version_root': 'MK7_MCDROPOUT_50_pct_NEW',
    #     'count': 50,
    # },
    {
        'name': 'Ensemble',
        'type': 'ensemble',
        'domain': 'FDK',
        'model_version_root': 'MK7',
        'count': 7,
    },
    # Add other models here
]

# Path to the .pt file containing tumor locations.
# The main loop indexes it as [patient, scan] and then reads three coordinates,
# using index 1 as x, index 0 as y and index 2 as the slice number.
# NOTE(review): an earlier comment described this as "5D ... (x, y, z)"; the
# indexing in this notebook implies three axes - confirm the stored layout.
TUMOR_LOCATIONS_FILE = 'D:/NoahSilverberg/ngCBCT/3D_recon/tumor_location.pt'

# --- Advanced Config ---
SCANS_AGG_FILE = 'scans_to_agg.txt'
# Same SSIM settings spelled for two libraries: SSIM_KWARGS is passed to
# skimage's structural_similarity, SSIM_KWARGS_ to torchmetrics' SSIM metric.
SSIM_KWARGS = {"K1": 0.03, "K2": 0.06, "win_size": 15}
SSIM_KWARGS_ = {"k1": 0.03, "k2": 0.06, "kernel_size": 15}

# --- Setup ---
# Create Directories and Files objects
# All phase/data-version specific outputs live under this directory; the gated
# (ground-truth) inputs are read from "<WORK_ROOT>/gated".
phase_dataver_dir = os.path.join(WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}")
DIRECTORIES = Directories(
    projections_results_dir=os.path.join(phase_dataver_dir, "results", "projections"),
    projections_gated_dir=os.path.join(WORK_ROOT, "gated", "prj_mat"),
    reconstructions_dir=os.path.join(phase_dataver_dir, "reconstructions"),
    reconstructions_gated_dir=os.path.join(WORK_ROOT, "gated", "fdk_recon"),
    images_results_dir=os.path.join(phase_dataver_dir, "results", "images"),
)
FILES = Files(DIRECTORIES)

# Load the list of scans
scans_agg, scan_type_agg = read_scans_agg_file(SCANS_AGG_FILE)
# NOTE(review): the [1:2] slice keeps only the second scan of the split, so the
# analysis below runs on a single scan - remove the slice to analyze the full split.
analysis_scans = scans_agg[SPLIT_TO_ANALYZE][1:2]

# Load tumor locations
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only use it with files from a trusted source.
tumor_locations = torch.load(TUMOR_LOCATIONS_FILE, weights_only=False)

# Prefer the first CUDA device and fall back to the CPU when CUDA is unavailable.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE} named '{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}'")

print(f"\nConfiguration loaded.")
print(f"Analyzing {len(analysis_scans)} scans from the '{SPLIT_TO_ANALYZE}' split.")
print(f"Found {len(MODELS_TO_ANALYZE)} model(s) to analyze.")

### Data loading/prep functions

In [None]:
def load_ground_truth(files_obj: Files, scan_info, domain, slice_idx=None):
    """
    Load the ground-truth data for one scan in the requested domain.

    Args:
        files_obj: Path helper used to resolve the ground-truth file location.
        scan_info: Tuple unpacked as (patient, scan, scan_type).
        domain: 'PROJ' reads the gated projections from a .mat file ('prj'
            variable, first two axes swapped). 'FDK' and 'IMAG' both read the
            gated FDK reconstruction, drop 20 slices from each end of the first
            axis, clip to [0, 0.04] and multiply by 25.
        slice_idx: Optional index along the first axis; only applied when the
            loaded data is 3-D.

    Returns:
        The full tensor, or the single slice `data[slice_idx]` when requested.

    Raises:
        ValueError: If `domain` is not one of 'PROJ', 'FDK', 'IMAG'.
    """
    patient, scan, scan_type = scan_info

    if domain == 'PROJ':
        proj_path = files_obj.get_projections_results_filepath('fdk', patient, scan, scan_type, gated=True)
        mat_contents = scipy.io.loadmat(proj_path)
        volume = torch.from_numpy(mat_contents['prj']).detach().permute(1, 0, 2)
    elif domain in ('FDK', 'IMAG'):
        # The ground truth for the IMAG domain is the FDK of the gated projection,
        # so both domains share the same file and preprocessing.
        recon_path = files_obj.get_recon_filepath("fdk", patient, scan, scan_type, gated=True, ensure_exists=False)
        volume = torch.load(recon_path).detach()
        volume = volume[20:-20, :, :]
        volume = 25. * torch.clip(volume, min=0.0, max=0.04)
    else:
        raise ValueError(f"Unknown domain: {domain}")

    wants_single_slice = slice_idx is not None and volume.ndim == 3
    return volume[slice_idx] if wants_single_slice else volume

def load_predictions(files_obj: Files, model_config, scan_info, slice_idx=None):
    """
    Load every prediction of one model for one scan and stack them.

    Args:
        files_obj: Path helper used to resolve prediction file locations.
        model_config: Dict with keys 'domain', 'model_version_root', 'count',
            'type' and 'name' (see MODELS_TO_ANALYZE).
        scan_info: Tuple unpacked as (patient, scan, scan_type).
        slice_idx: Optional index along the per-prediction first axis; only
            applied when the stacked tensor is 4-D.

    Returns:
        Tensor with the predictions stacked along a new leading axis, or the
        selected slice of each prediction when `slice_idx` is given.

    Raises:
        ValueError: If the configured domain is not 'PROJ', 'FDK' or 'IMAG'.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']
    count = model_config['count']
    model_type = model_config['type']

    def _load_single(index):
        # Ensemble members live under "<root>_NN" (1-based); stochastic models
        # reuse one version and are distinguished by the passthrough number.
        version = f"{root}_{index+1:02d}" if model_type == 'ensemble' else root
        pass_num = index if model_type == 'stochastic' else None

        if domain == 'PROJ':
            path = files_obj.get_projections_results_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_num, ensure_exists=False)
            return torch.from_numpy(scipy.io.loadmat(path)['prj']).detach().permute(1, 0, 2)
        if domain == 'FDK':
            path = files_obj.get_recon_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_num, ensure_exists=False)
            volume = torch.load(path).detach()[20:-20, :, :]
            return 25. * torch.clip(volume, min=0.0, max=0.04)
        if domain == 'IMAG':
            # This assumes the results are saved with the ID model version name
            path = files_obj.get_images_results_filepath(version, patient, scan, passthrough_num=pass_num, ensure_exists=False)
            volume = torch.squeeze(torch.load(path).detach(), dim=1)
            return torch.permute(volume, (0, 2, 1))
        raise ValueError(f"Unknown domain: {domain}")

    print(f"Loading {count} predictions for {model_config['name']}...")

    loaded = [
        _load_single(index)
        for index in tqdm(range(count), desc="Loading predictions", leave=False)
    ]
    stacked = torch.stack(loaded)

    if slice_idx is None or stacked.ndim != 4:
        return stacked
    return stacked[:, slice_idx, :, :]

# Progress marker: shows that this cell ran and the loader functions exist.
print("Data loading functions defined.")

### Metric calculation functions

In [None]:
def calculate_ause_sparsification(uncertainty, errors):
    """
    Calculates the Area Under the Sparsification Error curve (AUSE).

    Two sparsification curves are built by removing pixels one at a time and
    recording the mean error of the pixels that remain: the model curve removes
    pixels in order of decreasing uncertainty, the oracle curve in order of
    decreasing error. Both are divided by the overall mean error (when it is
    positive) so they start at 1. The result is the mean absolute gap between
    the two curves.

    Args:
        uncertainty (np.ndarray): Per-pixel uncertainty values (any shape).
        errors (np.ndarray): Per-pixel absolute errors, same number of elements.

    Returns:
        The AUSE as a numpy floating-point scalar; 0 means the uncertainty
        ranks pixels exactly like the true error does.

    Raises:
        ValueError: If the inputs are empty.
    """
    uncertainty_flat = uncertainty.flatten()
    errors_flat = errors.flatten()

    if errors_flat.size == 0:
        raise ValueError("Cannot compute AUSE for empty inputs.")

    # Normalize by overall MAE so the curve starts at 1
    overall_mae = np.mean(errors_flat, dtype=np.float64)

    def get_sparsification_curve_fast(sorted_errs):
        n_pixels = len(sorted_errs)
        # Accumulate in float64: a running sum kept in a narrower input dtype
        # (e.g. float32) stops absorbing small errors once it grows large,
        # which distorts the curve on volumes with many voxels.
        cumulative_errors = np.cumsum(sorted_errs, dtype=np.float64)
        total_error_sum = cumulative_errors[-1]
        sum_errors_removed = np.concatenate(([0.0], cumulative_errors[:-1]))
        sum_errors_remaining = total_error_sum - sum_errors_removed
        n_remaining = np.arange(n_pixels, 0, -1)
        curve = sum_errors_remaining / n_remaining
        if overall_mae > 0:
            curve = curve / overall_mae  # Normalize the curve
        return curve

    # Sort on the GPU when one is available; otherwise fall back to the CPU so
    # the metric still works on machines without CUDA.
    sort_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    uncertainty_tensor = torch.from_numpy(uncertainty_flat).to(sort_device)
    errors_tensor = torch.from_numpy(errors_flat).to(sort_device)

    # Model curve (sorted by uncertainty)
    model_sorted_indices = torch.argsort(uncertainty_tensor, descending=True)
    model_sorted_errors = errors_tensor[model_sorted_indices].cpu().numpy()
    model_curve = get_sparsification_curve_fast(model_sorted_errors)

    # Oracle curve (sorted by error)
    oracle_sorted_errors = torch.sort(errors_tensor, descending=True)[0].cpu().numpy()
    oracle_curve = get_sparsification_curve_fast(oracle_sorted_errors)

    # The AUSE is the area between the two normalized curves
    ause = np.mean(np.abs(model_curve - oracle_curve))
    return ause

def calculate_ece(ground_truth, mean_pred, uncertainty_map, n_levels=20):
    """
    Calculates the weighted calibration error for regression tasks based on Kuleshov et al., 2018.

    Each point's predictive distribution is taken to be a normal with mean
    `mean_pred` and standard deviation `uncertainty_map`. The predicted CDF is
    evaluated at the ground truth, the observed frequency of CDF values below
    each confidence level is compared with that level, and the squared gaps are
    combined using the fraction of points falling in each confidence bin.

    Args:
        ground_truth: Array of target values.
        mean_pred: Array of predicted means, same number of elements.
        uncertainty_map: Array of predicted standard deviations.
        n_levels (int): Number of evenly spaced confidence levels in [0, 1].

    Returns:
        The weighted calibration error as a numpy float (0 is best).
    """
    gt_flat = np.asarray(ground_truth).reshape(-1)
    pred_flat = np.asarray(mean_pred).reshape(-1)
    uncert_flat = np.asarray(uncertainty_map).reshape(-1)

    # scipy returns NaN for scale <= 0, and NaN fails every comparison below,
    # silently skewing the frequencies and bin weights. A zero standard
    # deviation is a point mass at the prediction, whose CDF is 0 below the
    # prediction and 1 from the prediction upwards, so evaluate that directly.
    degenerate = uncert_flat <= 0
    regular = ~degenerate
    pred_cdfs = np.empty(gt_flat.shape, dtype=np.float64)
    pred_cdfs[regular] = scipy.stats.norm.cdf(
        gt_flat[regular], loc=pred_flat[regular], scale=uncert_flat[regular]
    )
    pred_cdfs[degenerate] = gt_flat[degenerate] >= pred_flat[degenerate]

    expected_confidence_levels = np.linspace(0, 1, n_levels)
    observed_frequencies = np.array([np.mean(pred_cdfs <= p_j) for p_j in expected_confidence_levels])

    # Weight of level i is the fraction of points whose CDF value falls in the
    # half-open bin (level[i-1], level[i]]; level 0 has no bin and keeps weight 0.
    bin_boundaries = np.copy(expected_confidence_levels)
    bin_weights = np.zeros(n_levels)
    for i in range(1, n_levels):
        lower_bound = bin_boundaries[i-1]
        upper_bound = bin_boundaries[i]
        points_in_bin = (pred_cdfs > lower_bound) & (pred_cdfs <= upper_bound)
        bin_weights[i] = np.mean(points_in_bin)

    if np.sum(bin_weights) > 0:
        bin_weights /= np.sum(bin_weights)

    squared_errors = (expected_confidence_levels - observed_frequencies)**2
    weighted_calibration_error = np.sum(bin_weights * squared_errors)
    return weighted_calibration_error

def calculate_spearman_correlation(uncertainty, errors):
    """
    Return Spearman's rank correlation between uncertainty and absolute error.

    Both inputs are flattened first, so any matching shapes are accepted.
    Only the correlation coefficient is returned; the p-value is discarded.
    """
    # spearmanr yields (correlation, p-value); index 0 is the coefficient.
    return scipy.stats.spearmanr(uncertainty.flatten(), errors.flatten())[0]

import torchmetrics
import torchmetrics.image

def calculate_volume_metrics_2_pass(files_obj, model_config, scan_info, gt_volume, device):
    """
    Calculates stats and metrics using a two-pass algorithm for variance for improved stability.
    Also includes a PSNR calculation that ignores non-finite values.
    Pass 1: Calculate the mean of all predictions.
    Pass 2: Calculate variance and other metrics using the pre-calculated mean.

    Predictions are streamed from disk one at a time (each is loaded twice,
    once per pass) so only a single prediction is held on the device at once.

    Args:
        files_obj: Path helper used to resolve prediction file locations.
        model_config: Dict with keys 'count', 'domain', 'model_version_root'
            and 'type' (see MODELS_TO_ANALYZE).
        scan_info: Tuple unpacked as (patient, scan, scan_type).
        gt_volume (torch.Tensor): Ground-truth volume; moved to `device`.
        device: Torch device on which all computation is performed.

    Returns:
        Tuple (metrics, mean_volume, uncertainty_volume_map). `metrics` holds
        'sample_avg_ssim/psnr/mse/mae' (averages over individual predictions,
        0 when none were computed) and, when `gt_volume` has more than two
        dimensions, 'mean_ssim/psnr/mse/mae' for the mean prediction. The two
        tensors are the per-voxel mean and population standard deviation of
        the predictions, on `device`.

    Raises:
        ValueError: If 'count' is smaller than 1 or the domain is unknown.
    """
    n_samples = model_config['count']
    domain = model_config['domain']

    # Fail fast with a clear message: a count below 1 would divide the mean by
    # zero, and an unknown domain would otherwise surface as an unbound-variable
    # error inside the generator (load_predictions raises the same ValueError).
    if n_samples < 1:
        raise ValueError(f"Model config 'count' must be at least 1, got {n_samples}")
    if domain not in ('PROJ', 'FDK', 'IMAG'):
        raise ValueError(f"Unknown domain: {domain}")

    gt_volume = gt_volume.to(device)  # Ensure GT is on the correct device

    def prediction_generator():
        # Yields each prediction, already moved to `device`.
        patient, scan, scan_type = scan_info
        root = model_config['model_version_root']
        model_type = model_config['type']
        for i in range(n_samples):
            passthrough_num = None
            model_version = root
            if model_type == 'ensemble':
                model_version = f"{root}_{i+1:02d}"
            elif model_type == 'stochastic':
                passthrough_num = i

            if domain == 'PROJ':
                pred_path = files_obj.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.from_numpy(scipy.io.loadmat(pred_path)['prj']).detach().permute(1, 0, 2)
            elif domain == 'FDK':
                pred_path = files_obj.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                # Same preprocessing as the ground truth: trim 20 slices per
                # end, clip to [0, 0.04] and scale by 25.
                pred = pred[20:-20, :, :]
                pred = 25. * torch.clip(pred, min=0.0, max=0.04)
            else:  # 'IMAG' (domain was validated above)
                pred_path = files_obj.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                pred = torch.squeeze(pred, dim=1)
                pred = torch.permute(pred, (0, 2, 1))
            yield pred.to(device)

    # --- Pass 1: Calculate Mean ---
    print("Pass 1: Calculating mean prediction...")
    mean_volume = torch.zeros_like(gt_volume)
    # Using a simple sum and divide for the mean
    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 1/2 (Mean)", leave=False):
        mean_volume += pred_volume
    mean_volume /= n_samples

    # --- Pass 2: Calculate Variance and Metrics ---
    print("Pass 2: Calculating variance and metrics...")
    sum_sq_diff_volume = torch.zeros_like(gt_volume)
    sample_avg_ssims, sample_avg_psnrs, sample_avg_mses, sample_avg_maes = [], [], [], []

    # Initialize metrics on the specified device
    data_range = gt_volume.max() - gt_volume.min()
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=data_range, **SSIM_KWARGS_).to(device)

    # NOTE(review): the intent here is a per-slice PSNR, but torchmetrics
    # documents `reduction` as only taking effect when `dim` is given; with
    # `dim` left at its default this presumably yields one PSNR over the whole
    # volume. Behavior is left unchanged - confirm and pass `dim` if per-slice
    # values are really wanted.
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=data_range, reduction='none').to(device)

    def finite_mean_psnr(pred_batch, gt_batch):
        # Mean of the finite PSNR values; identical inputs give an infinite
        # PSNR, which is dropped. If nothing finite remains, report 100.0.
        psnr_values = psnr_metric(pred_batch, gt_batch)
        finite_values = psnr_values[torch.isfinite(psnr_values)]
        if finite_values.numel() > 0:
            return torch.mean(finite_values).item()
        return 100.0

    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 2/2 (Var & Metrics)", leave=False):
        # Variance calculation
        diff = pred_volume - mean_volume
        sum_sq_diff_volume += diff * diff

        # Per-sample metrics
        if gt_volume.ndim > 2:
            # Add a channel axis so each slice is treated as a 1-channel image.
            pred_vol_batch = pred_volume.unsqueeze(1)
            gt_vol_batch = gt_volume.unsqueeze(1)

            sample_avg_ssims.append(ssim_metric(pred_vol_batch, gt_vol_batch).item())
            sample_avg_mses.append(torch.mean((gt_volume - pred_volume)**2).item())
            sample_avg_maes.append(torch.mean(torch.abs(gt_volume - pred_volume)).item())
            sample_avg_psnrs.append(finite_mean_psnr(pred_vol_batch, gt_vol_batch))

    # Finalize variance and uncertainty
    if n_samples > 1:
        # Using n_samples for population standard deviation, as in the original code.
        # For sample standard deviation, use (n_samples - 1).
        variance_volume_map = sum_sq_diff_volume / n_samples
        uncertainty_volume_map = torch.sqrt(variance_volume_map)
    else:
        # A single prediction carries no spread information.
        uncertainty_volume_map = torch.zeros_like(mean_volume)

    # --- Calculate metrics for the mean prediction ---
    metrics = {}
    if gt_volume.ndim > 2:
        mean_vol_batch = mean_volume.unsqueeze(1)
        gt_vol_batch = gt_volume.unsqueeze(1)

        metrics['mean_ssim'] = ssim_metric(mean_vol_batch, gt_vol_batch).item()
        metrics['mean_mse'] = torch.mean((gt_volume - mean_volume)**2).item()
        metrics['mean_mae'] = torch.mean(torch.abs(gt_volume - mean_volume)).item()
        metrics['mean_psnr'] = finite_mean_psnr(mean_vol_batch, gt_vol_batch)

    # --- Aggregate per-sample metrics ---
    metrics['sample_avg_ssim'] = np.mean(sample_avg_ssims) if sample_avg_ssims else 0
    metrics['sample_avg_psnr'] = np.mean(sample_avg_psnrs) if sample_avg_psnrs else 0
    metrics['sample_avg_mse'] = np.mean(sample_avg_mses) if sample_avg_mses else 0
    metrics['sample_avg_mae'] = np.mean(sample_avg_maes) if sample_avg_maes else 0

    return metrics, mean_volume, uncertainty_volume_map

# Progress marker: shows that this cell ran and the metric functions exist.
print("Metric calculation functions defined.")

### Visualization functions

In [None]:
def plot_mean_comparison(mean_pred, ground_truth, uncertainty_map, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show ground truth, mean prediction, absolute error and uncertainty side by side.

    Each of the four panels gets its own colorbar. When `tumor_coords_xy` is a
    non-empty (x, y) pair, a red arrow pointing at that location is drawn on
    every panel.
    """
    abs_error = np.abs(ground_truth - mean_pred)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Mean vs. GT)', fontsize=16)

    def draw_panel(ax, image, cmap, title):
        # One image panel: no axis ticks, individual colorbar.
        handle = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        fig.colorbar(handle, ax=ax)

    draw_panel(axes[0], ground_truth, 'gray', 'Ground Truth')

    if tumor_coords_xy:
        tumor_x, tumor_y = tumor_coords_xy
        arrow_tail = (tumor_x - 30, tumor_y - 30)
        for ax in axes:
            ax.annotate('', xy=(tumor_x, tumor_y), xytext=arrow_tail,
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05, width=1, headwidth=5, headlength=5))

    remaining_panels = (
        (mean_pred, 'gray', 'Mean Prediction'),
        (abs_error, 'magma', 'Absolute Error Map'),
        (uncertainty_map, 'viridis', 'Uncertainty (Std Dev)'),
    )
    for ax, (image, cmap, title) in zip(axes[1:], remaining_panels):
        draw_panel(ax, image, cmap, title)

    plt.tight_layout()
    plt.show()

def plot_ssim_map(ssim_map, model_name, scan_name):
    """Display a 2-D SSIM map with a colorbar, using a color scale fixed to [0, 1]."""
    fig, ax = plt.subplots(figsize=(6, 6))
    image_handle = ax.imshow(ssim_map, cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'SSIM Map - {model_name} - {scan_name}')
    fig.colorbar(image_handle, ax=ax)
    ax.axis('off')
    plt.show()

def plot_calibration_curve(ground_truth, mean_pred, uncertainty_map, model_name, scan_name, n_levels=20):
    """
    Plots the calibration curve with a marginal histogram below the x-axis
    showing the distribution of the predicted CDF values.

    Each point's predictive distribution is taken to be a normal with mean
    `mean_pred` and standard deviation `uncertainty_map`; the CDF is evaluated
    at the ground truth.

    Args:
        ground_truth: Array of target values.
        mean_pred: Array of predicted means, same number of elements.
        uncertainty_map: Array of predicted standard deviations.
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.
        n_levels (int): Number of evenly spaced confidence levels in [0, 1].
    """
    gt_flat = np.asarray(ground_truth).reshape(-1)
    pred_flat = np.asarray(mean_pred).reshape(-1)
    uncert_flat = np.asarray(uncertainty_map).reshape(-1)

    # Calculate the predicted CDF value for each point.
    # scipy returns NaN for scale <= 0, and NaN fails every comparison below,
    # silently lowering the observed frequencies. A zero standard deviation is
    # a point mass at the prediction, whose CDF is 0 below the prediction and
    # 1 from the prediction upwards, so evaluate that directly.
    degenerate = uncert_flat <= 0
    regular = ~degenerate
    pred_cdfs = np.empty(gt_flat.shape, dtype=np.float64)
    pred_cdfs[regular] = scipy.stats.norm.cdf(
        gt_flat[regular], loc=pred_flat[regular], scale=uncert_flat[regular]
    )
    pred_cdfs[degenerate] = gt_flat[degenerate] >= pred_flat[degenerate]

    # --- Create figure with two subplots, sharing the x-axis ---
    fig, (ax_cal, ax_hist) = plt.subplots(
        2, 1,
        figsize=(8, 8),
        sharex=True,
        gridspec_kw={'height_ratios': [3, 1]}  # Main plot is 3x taller
    )

    # --- Main Calibration Plot (top) ---
    expected_confidence_levels = np.linspace(0, 1, n_levels)
    observed_frequencies = np.array([np.mean(pred_cdfs <= p_j) for p_j in expected_confidence_levels])

    ax_cal.plot([0, 1], [0, 1], '--', color='grey', label='Perfectly Calibrated')
    ax_cal.plot(expected_confidence_levels, observed_frequencies, '-o', label='Model Calibration')
    ax_cal.set_ylabel('Observed Confidence Level')
    ax_cal.set_title(f'Calibration Plot - {model_name} - {scan_name}')
    ax_cal.legend()
    ax_cal.grid(True, linestyle=':')

    # --- Marginal Histogram (bottom) ---
    # Log-scaled density so sparsely populated CDF ranges remain visible.
    ax_hist.hist(pred_cdfs, bins=50, range=(0,1), density=True, color='steelblue', alpha=0.8)
    ax_hist.set_xlabel('Expected Confidence Level (Predicted CDF)')
    ax_hist.set_ylabel('Density')
    ax_hist.set_yscale('log')
    # ax_hist.set_yticks([]) # Hide y-ticks for clarity

    # Final adjustments
    plt.tight_layout()
    plt.show()

def plot_sparsification_curve(uncertainty, errors, model_name, scan_name):
    """
    Plots the model and oracle sparsification curves used for AUSE calculation.

    The model curve removes pixels in order of decreasing uncertainty, the
    oracle curve in order of decreasing error; each point is the mean error of
    the pixels still remaining, divided by the overall mean error when that is
    positive.

    Args:
        uncertainty (np.ndarray): Per-pixel uncertainty values (any shape).
        errors (np.ndarray): Per-pixel absolute errors, same number of elements.
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.

    Raises:
        ValueError: If the inputs are empty.
    """
    uncertainty_flat = uncertainty.flatten()
    errors_flat = errors.flatten()

    if errors_flat.size == 0:
        raise ValueError("Cannot plot a sparsification curve for empty inputs.")

    def get_sparsification_curve_fast(sorted_errs, overall_mae):
        n_pixels = len(sorted_errs)
        # Accumulate in float64: a running sum kept in a narrower input dtype
        # (e.g. float32) stops absorbing small errors once it grows large,
        # which distorts the curve on volumes with many voxels.
        cumulative_errors = np.cumsum(sorted_errs, dtype=np.float64)
        total_error_sum = cumulative_errors[-1]
        sum_errors_removed = np.concatenate(([0.0], cumulative_errors[:-1]))
        sum_errors_remaining = total_error_sum - sum_errors_removed
        n_remaining = np.arange(n_pixels, 0, -1)
        curve = sum_errors_remaining / n_remaining
        if overall_mae > 0:
            curve = curve / overall_mae
        return curve

    overall_mae = np.mean(errors_flat, dtype=np.float64)

    # Model curve (sorted by uncertainty)
    model_sorted_indices = np.argsort(uncertainty_flat)[::-1]
    model_sorted_errors = errors_flat[model_sorted_indices]
    model_curve = get_sparsification_curve_fast(model_sorted_errors, overall_mae)

    # Oracle curve (sorted by error)
    oracle_sorted_errors = np.sort(errors_flat)[::-1]
    oracle_curve = get_sparsification_curve_fast(oracle_sorted_errors, overall_mae)

    # X-axis: fraction of pixels removed
    fraction_removed = np.linspace(0, 1, len(model_curve))

    plt.figure(figsize=(7, 6))
    plt.plot(fraction_removed, model_curve, label='Model (Sort by Uncertainty)')
    plt.plot(fraction_removed, oracle_curve, '--', label='Oracle (Sort by Error)')
    plt.xlabel('Fraction of Pixels Removed')
    plt.ylabel('Mean Absolute Error of Remaining Pixels')
    plt.title(f'Sparsification Curve - {model_name} - {scan_name}')
    plt.legend()
    plt.grid(True, linestyle=':')
    plt.show()

def plot_samples_comparison(ground_truth, mean_pred, samples, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show the ground truth, the mean prediction and individual samples in one row.

    All panels share a grayscale range derived from the ground truth and the
    mean prediction, so intensities are directly comparable.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        samples (list of np.ndarray): A list of 2D sample prediction slices.
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.
        slice_idx (int): The index of the slice for the title.
        tumor_coords_xy (tuple, optional): (x, y) coordinates for the tumor arrow.
    """
    # Two fixed panels (GT and mean) followed by one panel per sample.
    n_panels = len(samples) + 2

    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4.5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (GT, Mean, and Samples)', fontsize=16)

    # Common intensity window for every panel.
    gray_window = dict(
        cmap='gray',
        vmin=min(ground_truth.min(), mean_pred.min()),
        vmax=max(ground_truth.max(), mean_pred.max()),
    )

    reference_panels = ((ground_truth, 'Ground Truth'), (mean_pred, 'Mean Prediction'))
    for ax, (image, title) in zip(axes, reference_panels):
        ax.imshow(image, **gray_window)
        ax.set_title(title)
        ax.axis('off')

    for number, (ax, sample) in enumerate(zip(axes[2:], samples), start=1):
        ax.imshow(sample, **gray_window)
        ax.set_title(f'Sample {number}')
        ax.axis('off')

    # Mark the tumor position on every panel.
    if tumor_coords_xy:
        tumor_x, tumor_y = tumor_coords_xy
        for ax in axes:
            ax.annotate('', xy=(tumor_x, tumor_y), xytext=(tumor_x - 30, tumor_y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05,
                                        width=1, headwidth=5, headlength=5))

    plt.show()

def plot_worst_samples_comparison(ground_truth, mean_pred, worst_samples_data, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show the ground truth, the mean prediction and the worst-scoring samples in one row.

    All panels share a grayscale range derived from the ground truth and the
    mean prediction. Each sample panel is titled with its original sample
    index and its loss.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        worst_samples_data (list): A list of tuples, where each tuple is
                                   (loss, sample_slice_numpy, sample_index).
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.
        slice_idx (int): The index of the slice for the title.
        tumor_coords_xy (tuple, optional): (x, y) coordinates for the tumor arrow.
    """
    n_worst = len(worst_samples_data)
    n_panels = n_worst + 2

    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Top {n_worst} Worst Samples by SmoothL1Loss)', fontsize=16)

    # Common intensity window for every panel.
    gray_window = dict(
        cmap='gray',
        vmin=min(ground_truth.min(), mean_pred.min()),
        vmax=max(ground_truth.max(), mean_pred.max()),
    )

    reference_panels = ((ground_truth, 'Ground Truth'), (mean_pred, 'Mean Prediction'))
    for ax, (image, title) in zip(axes, reference_panels):
        ax.imshow(image, **gray_window)
        ax.set_title(title)
        ax.axis('off')

    for ax, (loss, sample_slice, sample_idx) in zip(axes[2:], worst_samples_data):
        ax.imshow(sample_slice, **gray_window)
        # Title carries the loss and the original sample number.
        ax.set_title(f'Sample #{sample_idx}\nLoss: {loss:.4f}')
        ax.axis('off')

    # Mark the tumor position on every panel.
    if tumor_coords_xy:
        tumor_x, tumor_y = tumor_coords_xy
        for ax in axes:
            ax.annotate('', xy=(tumor_x, tumor_y), xytext=(tumor_x - 30, tumor_y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05,
                                        width=1, headwidth=5, headlength=5))

    plt.show()

# Progress marker: shows that this cell ran and the plotting functions exist.
print("Visualization functions defined.")

### Main loop

In [None]:
# This list will store dictionaries of results for each scan and model
all_results = []

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    
    scan_results = []

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"
        
        # --- Determine Domain and Plotting Slice ---
        is_visual_domain = domain in ['FDK', 'IMAG']
        plot_slice_idx = None
        tumor_xy = None
        
        if is_visual_domain:
            if 'tumor_locations' in locals() and tumor_locations is not None:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    tumor_xy = (loc[1].item(), loc[0].item())
                    plot_slice_idx = int(loc[2].item()) - 20
                except (IndexError, TypeError):
                    print(f"Warning: Could not find tumor location for {scan_name}. Plotting will not have an arrow.")
                    plot_slice_idx = 100
            else:
                plot_slice_idx = 100
        
        # --- Data Loading (Ground Truth Only) ---
        gt_volume = load_ground_truth(FILES, scan_info, domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()
        
        # --- Iterative Metric Calculation ---
        print("Calculating metrics iteratively...")
        iq_metrics, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_2_pass(
            FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
        )
        mean_pred_vol = mean_pred_vol.cpu().numpy()
        uncertainty_map_vol = uncertainty_map_vol.cpu().numpy()
        
        # --- Uncertainty Metric Calculation ---
        errors_vol = np.abs(gt_volume_np - mean_pred_vol)
        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)
        # print("Calculating ECE...")
        # ece_val = calculate_ece(gt_volume_np, mean_pred_vol, uncertainty_map_vol)
        print("Calculating Spearman's correlation...")
        spearman_val = calculate_spearman_correlation(uncertainty_map_vol, errors_vol)
        
        # --- Store Results ---
        scan_result = {
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            # 'ece': ece_val,
            'spearman_corr': spearman_val,
        }
        scan_results.append(scan_result)
        
        # --- Visualization ---
        print(f"\n--- Results for {model_name} on {scan_name} ---")
        
        # plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)
        plot_sparsification_curve(uncertainty_map_vol, errors_vol, model_name, scan_name)
        
        if is_visual_domain:
            gt_slice_np = gt_volume_np[plot_slice_idx]
            mean_pred_slice = mean_pred_vol[plot_slice_idx]
            uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
            _, ssim_map_val = ssim(gt_slice_np, mean_pred_slice, 
                                  data_range=(np.max(gt_slice_np) - np.min(gt_slice_np)), 
                                  full=True, **SSIM_KWARGS)
            
            plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice, 
                                 model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)
            plot_ssim_map(ssim_map_val, model_name, scan_name)

            # --- Load and plot a few samples for visual comparison ---
            print("Loading and plotting samples...")
            SAMPLES_TO_PLOT = 5
            sample_slices_for_plotting = []
            
            # Determine the number of samples to load (can't be more than what's available)
            num_to_load = min(SAMPLES_TO_PLOT, model_config['count'])

            for i in range(num_to_load):
                # Resolve which saved prediction belongs to sample i: ensemble
                # members live under "<root>_NN" (1-based), stochastic models
                # reuse one version and differ by passthrough number.
                passthrough_num = None
                model_version = model_config['model_version_root']
                model_type = model_config['type']

                if model_type == 'ensemble':
                    model_version = f"{model_config['model_version_root']}_{i+1:02d}"
                elif model_type == 'stochastic':
                    passthrough_num = i

                # Same loading and preprocessing as the metrics function, so the
                # plotted samples match the data the metrics were computed on.
                pred = None
                if domain == 'FDK':
                    # Use this scan's own type (third element of scan_info),
                    # exactly as the metric functions do for the same files.
                    pred_path = FILES.get_recon_filepath(model_version, patient, scan, scan_info[2], gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                    pred = torch.load(pred_path).detach()
                    pred = pred[20:-20, :, :]
                    pred = 25. * torch.clip(pred, min=0.0, max=0.04)
                elif domain == 'IMAG':
                    pred_path = FILES.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                    # The load call was missing here, which left a dangling
                    # assignment and made the cell a syntax error.
                    pred = torch.load(pred_path).detach()
                    pred = torch.squeeze(pred, dim=1)
                    pred = torch.permute(pred, (0, 2, 1))

                # Get the specific slice and convert to a numpy array for plotting
                if pred is not None:
                    sample_slice = pred[plot_slice_idx].cpu().numpy()
                    sample_slices_for_plotting.append(sample_slice)

            # Call the new plotting function
            if sample_slices_for_plotting:
                plot_samples_comparison(
                    ground_truth=gt_slice_np,
                    mean_pred=mean_pred_slice,
                    samples=sample_slices_for_plotting,
                    model_name=model_name,
                    scan_name=scan_name,
                    slice_idx=plot_slice_idx,
                    tumor_coords_xy=tumor_xy
                )

            print("Finding and plotting worst samples by Smooth L1 Loss...")
            WORST_SAMPLES_TO_PLOT = 5

            # List to store tuples of (loss, sample_slice_numpy, sample_index)
            worst_samples_data = []
            # Get the ground truth slice as a tensor on the correct device
            gt_slice_tensor = gt_volume[plot_slice_idx].to(DEVICE)

            # Loop through all available samples to find the worst ones
            for i in tqdm(range(model_config['count']), desc="Finding Worst Samples", leave=False):
                passthrough_num = None
                model_version = model_config['model_version_root']
                model_type = model_config['type']

                if model_type == 'ensemble':
                    model_version = f"{model_config['model_version_root']}_{i+1:02d}"
                elif model_type == 'stochastic':
                    passthrough_num = i
                
                # Load the prediction volume as a tensor on the GPU
                pred_vol_tensor = None
                if domain == 'FDK':
                    pred_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type_agg, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                    pred_vol_tensor = torch.load(pred_path, map_location=DEVICE).detach()
                    pred_vol_tensor = pred_vol_tensor[20:-20, :, :]
                    pred_vol_tensor = 25. * torch.clip(pred_vol_tensor, min=0.0, max=0.04)
                elif domain == 'IMAG':
                    pred_path = FILES.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                    pred_vol_tensor = torch.load(pred_path, map_location=DEVICE).detach()
                    pred_vol_tensor = torch.squeeze(pred_vol_tensor, dim=1)
                    pred_vol_tensor = torch.permute(pred_vol_tensor, (0, 2, 1))

                if pred_vol_tensor is not None:
                    pred_slice_tensor = pred_vol_tensor[plot_slice_idx]
                    
                    # Calculate Smooth L1 Loss for the current slice
                    import torch.nn.functional as F
                    loss = F.smooth_l1_loss(pred_slice_tensor, gt_slice_tensor, reduction='mean').item()

                    # Keep track of the top 5 worst samples (highest loss)
                    if len(worst_samples_data) < WORST_SAMPLES_TO_PLOT:
                        worst_samples_data.append((loss, pred_slice_tensor.cpu().numpy(), i))
                    else:
                        # Find the sample with the minimum loss currently in our list
                        min_loss_in_list = min(worst_samples_data, key=lambda x: x[0])
                        if loss > min_loss_in_list[0]:
                            # If current sample is worse, replace the "best of the worst"
                            worst_samples_data.remove(min_loss_in_list)
                            worst_samples_data.append((loss, pred_slice_tensor.cpu().numpy(), i))

            # Sort the final list from worst to best for plotting
            worst_samples_data.sort(key=lambda x: x[0], reverse=True)

            # Call the new plotting function with the results
            if worst_samples_data:
                plot_worst_samples_comparison(
                    ground_truth=gt_slice_np,
                    mean_pred=mean_pred_slice,
                    worst_samples_data=worst_samples_data,
                    model_name=model_name,
                    scan_name=scan_name,
                    slice_idx=plot_slice_idx,
                    tumor_coords_xy=tumor_xy
                )

        # --- Clean up memory ---
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol
        gc.collect()
        
    all_results.extend(scan_results)

# Convert results to a pandas DataFrame for easier analysis
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
results_df

In [None]:
# Main analysis cell: for every configured model and every scan, compute image
# quality and uncertainty metrics, render the diagnostic figures, and collect
# one result row per (model, scan) into `results_df`.
#
# Relies on names defined in earlier cells: MODELS_TO_ANALYZE, analysis_scans,
# FILES, scan_type_agg, DEVICE, SSIM_KWARGS, tumor_locations (optional), the
# metric functions and the plotting functions.
import heapq

import torch.nn.functional as F

# Number of individual samples shown per scan in the sample-comparison figure.
SAMPLES_TO_PLOT = 5
# Number of highest-loss samples shown per scan in the worst-samples figure.
WORST_SAMPLES_TO_PLOT = 5
# Slice shown when no usable tumor location is available for a scan.
DEFAULT_PLOT_SLICE_IDX = 100


def _load_prediction_volume(model_config, patient, scan, sample_idx, map_location=None):
    """Load one prediction volume of a model for the given scan.

    For 'ensemble' models the sample index selects the ensemble member
    (version suffix ``_01``, ``_02``, ...); for 'stochastic' models it is
    passed on as the passthrough number.

    Args:
        model_config: Entry of MODELS_TO_ANALYZE.
        patient, scan: Scan identifiers as found in `analysis_scans`.
        sample_idx: Zero-based index of the sample to load.
        map_location: Forwarded to `torch.load` (None keeps its default).

    Returns:
        The prediction volume as a tensor for the 'FDK' and 'IMAG' domains,
        or None for any other domain.
    """
    root = model_config['model_version_root']
    model_type = model_config['type']
    domain = model_config['domain']

    passthrough_num = None
    model_version = root
    if model_type == 'ensemble':
        model_version = f"{root}_{sample_idx + 1:02d}"
    elif model_type == 'stochastic':
        passthrough_num = sample_idx

    if domain == 'FDK':
        # NOTE(review): this uses the global `scan_type_agg`, whereas
        # calculate_volume_metrics_2_pass uses the scan's own scan type
        # from scan_info - confirm that both always agree.
        pred_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type_agg, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
        pred = torch.load(pred_path, map_location=map_location).detach()
        # Same crop and intensity scaling as applied to the ground truth.
        pred = pred[20:-20, :, :]
        return 25. * torch.clip(pred, min=0.0, max=0.04)

    if domain == 'IMAG':
        pred_path = FILES.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
        pred = torch.load(pred_path, map_location=map_location).detach()
        pred = torch.squeeze(pred, dim=1)
        return torch.permute(pred, (0, 2, 1))

    return None


# Tumor locations are optional; without them a default slice is plotted.
have_tumor_locations = globals().get('tumor_locations') is not None

# This list will store dictionaries of results for each scan and model
all_results = []

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    is_visual_domain = domain in ['FDK', 'IMAG']

    scan_results = []

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"

        # --- Determine Plotting Slice ---
        plot_slice_idx = None
        tumor_xy = None

        if is_visual_domain:
            plot_slice_idx = DEFAULT_PLOT_SLICE_IDX
            if have_tumor_locations:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    candidate_xy = (loc[1].item(), loc[0].item())
                    # Volumes are cropped by 20 leading slices on load, so the
                    # tumor z index is shifted by the same amount.
                    candidate_slice = int(loc[2].item()) - 20
                except (IndexError, TypeError, ValueError):
                    # ValueError covers non-numeric ids and NaN coordinates.
                    print(f"Warning: Could not find tumor location for {scan_name}. Plotting will not have an arrow.")
                else:
                    # Only commit both values once the whole lookup succeeded,
                    # so a partial failure never leaves a stale arrow position.
                    tumor_xy = candidate_xy
                    plot_slice_idx = candidate_slice

        # --- Data Loading (Ground Truth Only) ---
        gt_volume = load_ground_truth(FILES, scan_info, domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()

        if is_visual_domain and not 0 <= plot_slice_idx < gt_volume_np.shape[0]:
            # A negative index would silently wrap to the wrong slice and a
            # too-large one would crash after the expensive metric pass.
            print(f"Warning: Slice index {plot_slice_idx} is outside the volume for {scan_name}. Using the middle slice without an arrow.")
            plot_slice_idx = gt_volume_np.shape[0] // 2
            tumor_xy = None

        # --- Iterative Metric Calculation ---
        print("Calculating metrics iteratively...")
        iq_metrics, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_2_pass(
            FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
        )
        mean_pred_vol = mean_pred_vol.cpu().numpy()
        uncertainty_map_vol = uncertainty_map_vol.cpu().numpy()

        # --- Uncertainty Metric Calculation ---
        errors_vol = np.abs(gt_volume_np - mean_pred_vol)
        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)
        # print("Calculating ECE...")
        # ece_val = calculate_ece(gt_volume_np, mean_pred_vol, uncertainty_map_vol)
        print("Calculating Spearman's correlation...")
        spearman_val = calculate_spearman_correlation(uncertainty_map_vol, errors_vol)

        # --- Store Results ---
        scan_result = {
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            # 'ece': ece_val,
            'spearman_corr': spearman_val,
        }
        scan_results.append(scan_result)

        # --- Visualization ---
        print(f"\n--- Results for {model_name} on {scan_name} ---")

        # plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)
        plot_sparsification_curve(uncertainty_map_vol, errors_vol, model_name, scan_name)

        if is_visual_domain:
            gt_slice_np = gt_volume_np[plot_slice_idx]
            mean_pred_slice = mean_pred_vol[plot_slice_idx]
            uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
            _, ssim_map_val = ssim(gt_slice_np, mean_pred_slice,
                                  data_range=(np.max(gt_slice_np) - np.min(gt_slice_np)),
                                  full=True, **SSIM_KWARGS)

            plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice,
                                 model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)
            plot_ssim_map(ssim_map_val, model_name, scan_name)

            # --- Load and plot a few samples for visual comparison ---
            print("Loading and plotting samples...")
            sample_slices_for_plotting = []

            # Can't load more samples than the model provides.
            num_to_load = min(SAMPLES_TO_PLOT, model_config['count'])

            for i in range(num_to_load):
                pred = _load_prediction_volume(model_config, patient, scan, i)
                if pred is not None:
                    # Copy the slice so the list does not keep the whole
                    # volume alive through a NumPy view.
                    sample_slices_for_plotting.append(pred[plot_slice_idx].cpu().numpy().copy())
                del pred

            if sample_slices_for_plotting:
                plot_samples_comparison(
                    ground_truth=gt_slice_np,
                    mean_pred=mean_pred_slice,
                    samples=sample_slices_for_plotting,
                    model_name=model_name,
                    scan_name=scan_name,
                    slice_idx=plot_slice_idx,
                    tumor_coords_xy=tumor_xy
                )

            print("Finding and plotting worst samples by Smooth L1 Loss...")

            # Min-heap of (loss, -sample_index, slice): the root is the "best
            # of the worst", so it is the entry to evict when a worse sample
            # arrives. Sample indices are unique, so the tuple comparison
            # never reaches the array element; negating the index keeps the
            # earlier sample when two losses tie.
            worst_heap = []
            gt_slice_tensor = gt_volume[plot_slice_idx].to(DEVICE)

            for i in tqdm(range(model_config['count']), desc="Finding Worst Samples", leave=False):
                pred_vol_tensor = _load_prediction_volume(model_config, patient, scan, i, map_location=DEVICE)
                if pred_vol_tensor is None:
                    continue

                pred_slice_tensor = pred_vol_tensor[plot_slice_idx]
                loss = F.smooth_l1_loss(pred_slice_tensor, gt_slice_tensor, reduction='mean').item()

                if len(worst_heap) < WORST_SAMPLES_TO_PLOT:
                    heapq.heappush(worst_heap, (loss, -i, pred_slice_tensor.cpu().numpy().copy()))
                elif loss > worst_heap[0][0]:
                    heapq.heapreplace(worst_heap, (loss, -i, pred_slice_tensor.cpu().numpy().copy()))

                # Drop the volume before the next one is loaded.
                del pred_vol_tensor, pred_slice_tensor

            # Worst first, as tuples of (loss, sample_slice_numpy, sample_index)
            worst_samples_data = [
                (entry_loss, entry_slice, -neg_index)
                for entry_loss, neg_index, entry_slice in sorted(worst_heap, key=lambda entry: entry[:2], reverse=True)
            ]

            if worst_samples_data:
                plot_worst_samples_comparison(
                    ground_truth=gt_slice_np,
                    mean_pred=mean_pred_slice,
                    worst_samples_data=worst_samples_data,
                    model_name=model_name,
                    scan_name=scan_name,
                    slice_idx=plot_slice_idx,
                    tumor_coords_xy=tumor_xy
                )

        # --- Clean up memory ---
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol
        gc.collect()
        if torch.cuda.is_available():
            # Return cached GPU blocks so the next scan starts from a clean slate.
            torch.cuda.empty_cache()

    all_results.extend(scan_results)

# Convert results to a pandas DataFrame for easier analysis
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
results_df

### Setup/Imports

In [None]:
import os
import gc
import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from tqdm import tqdm
import pandas as pd
import scipy.io
from pipeline.paths import Directories, Files
from pipeline.utils import read_scans_agg_file
import scipy.stats
# Set the default resolution of every matplotlib figure created in this notebook.
plt.rcParams['figure.dpi'] = 200

### Config

In [None]:
# --- Experiment selection ---
PHASE = "7"
DATA_VERSION = "13"
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"
SPLIT_TO_ANALYZE = 'VALIDATION'  # Options: 'TRAIN', 'VALIDATION', 'TEST'
# Upper bound on the number of scans taken from the split (None = all scans).
MAX_SCANS_TO_ANALYZE = 4

# One entry per model to analyze:
#   'name'               - label used in plot titles and in the results table
#   'type'               - 'stochastic' (sample index is passed as the
#                          passthrough number) or 'ensemble' (sample index
#                          selects the member suffix _01, _02, ...)
#   'domain'             - 'PROJ', 'FDK' or 'IMAG'
#   'model_version_root' - model version name (prefix for ensembles)
#   'count'              - number of samples / ensemble members to load
# NOTE(review): 'MC Dropoout' looks like a typo for 'MC Dropout'. The labels
# are left unchanged here because they end up in results_df and may be
# referenced elsewhere - confirm before renaming.
MODELS_TO_ANALYZE = [
    # {
    #     'name': 'SWAG Learning Rate 1e-2',
    #     'type': 'stochastic',
    #     'domain': 'FDK',
    #     'model_version_root': 'MK7_MCDROPOUT_15_pct_NEW_SWAG_lr1e-2',
    #     'count': 33,
    # },
    {
        'name': 'MC Dropoout 15%',
        'type': 'stochastic',
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_15_pct_NEW',
        'count': 50,
    },
    {
        'name': 'MC Dropoout 30%',
        'type': 'stochastic',
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_30_pct_NEW',
        'count': 50,
    },
    {
        'name': 'MC Dropoout 50%',
        'type': 'stochastic',
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_50_pct_NEW',
        'count': 50,
    },
    {
        'name': 'Ensemble',
        'type': 'ensemble',
        'domain': 'FDK',
        'model_version_root': 'MK7',
        'count': 7,
    },
    # Add other models here
]

# Path to the .pt file containing tumor locations.
# The tensor is indexed as [patient, scan] and yields the (x, y, z) location.
TUMOR_LOCATIONS_FILE = 'D:/NoahSilverberg/ngCBCT/3D_recon/tumor_location.pt'

# --- Advanced Config ---
SCANS_AGG_FILE = 'scans_to_agg.txt'
# Same SSIM settings spelled for two libraries: the first dict uses the
# keyword names of skimage's structural_similarity, the second those of the
# torchmetrics SSIM metric.
SSIM_KWARGS = {"K1": 0.03, "K2": 0.06, "win_size": 15}
SSIM_KWARGS_ = {"k1": 0.03, "k2": 0.06, "kernel_size": 15}

# --- Setup ---
# Create Directories and Files objects
phase_dataver_dir = os.path.join(WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}")
DIRECTORIES = Directories(
    projections_results_dir=os.path.join(phase_dataver_dir, "results", "projections"),
    projections_gated_dir=os.path.join(WORK_ROOT, "gated", "prj_mat"),
    reconstructions_dir=os.path.join(phase_dataver_dir, "reconstructions"),
    reconstructions_gated_dir=os.path.join(WORK_ROOT, "gated", "fdk_recon"),
    images_results_dir=os.path.join(phase_dataver_dir, "results", "images"),
)
FILES = Files(DIRECTORIES)

# Load the list of scans; slicing with None keeps the whole split.
scans_agg, scan_type_agg = read_scans_agg_file(SCANS_AGG_FILE)
analysis_scans = scans_agg[SPLIT_TO_ANALYZE][:MAX_SCANS_TO_ANALYZE]

# Load tumor locations. They are only used to pick the plotted slice and to
# draw an arrow, so a missing file degrades to `None` instead of aborting.
# NOTE(review): assumes the analysis cell of this notebook treats
# `tumor_locations is None` as "no locations available" - confirm.
if os.path.isfile(TUMOR_LOCATIONS_FILE):
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load this file from a trusted location.
    tumor_locations = torch.load(TUMOR_LOCATIONS_FILE, weights_only=False)
else:
    print(f"Warning: tumor locations file not found: {TUMOR_LOCATIONS_FILE}. Continuing without tumor locations.")
    tumor_locations = None

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE} named '{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}'")

print("\nConfiguration loaded.")
print(f"Analyzing {len(analysis_scans)} scans from the '{SPLIT_TO_ANALYZE}' split.")
print(f"Found {len(MODELS_TO_ANALYZE)} model(s) to analyze.")

### Data loading/prep functions

In [None]:
def load_ground_truth(files_obj: Files, scan_info, domain, slice_idx=None):
    """
    Load the reference ("ground truth") data of one scan.

    Args:
        files_obj: Path resolver used to locate the files on disk.
        scan_info: Triple of (patient, scan, scan_type).
        domain: 'PROJ' loads the gated projections from a .mat file; 'FDK'
            and 'IMAG' both load the gated FDK reconstruction.
        slice_idx: If given and the loaded data is 3-D, only this slice
            (first axis) is returned.

    Returns:
        The loaded tensor, or a single slice of it.

    Raises:
        ValueError: If `domain` is not one of the three known domains.
    """
    patient, scan, scan_type = scan_info

    if domain not in ('PROJ', 'FDK', 'IMAG'):
        raise ValueError(f"Unknown domain: {domain}")

    if domain == 'PROJ':
        gt_path = files_obj.get_projections_results_filepath('fdk', patient, scan, scan_type, gated=True)
        projections = scipy.io.loadmat(gt_path)['prj']
        data = torch.from_numpy(projections).detach().permute(1, 0, 2)
    else:
        # The IMAG domain is compared against the same gated FDK
        # reconstruction as the FDK domain, so both share one code path.
        gt_path = files_obj.get_recon_filepath("fdk", patient, scan, scan_type, gated=True, ensure_exists=False)
        volume = torch.load(gt_path).detach()
        # Drop 20 slices at each end, then clip to [0, 0.04] and scale by 25.
        data = 25. * torch.clip(volume[20:-20, :, :], min=0.0, max=0.04)

    wants_single_slice = slice_idx is not None and data.ndim == 3
    return data[slice_idx] if wants_single_slice else data

def load_predictions(files_obj: Files, model_config, scan_info, slice_idx=None):
    """
    Load every prediction of a model for one scan and stack them.

    Args:
        files_obj: Path resolver used to locate the files on disk.
        model_config: Dict with the keys 'domain', 'model_version_root',
            'count', 'type' and 'name'.
        scan_info: Triple of (patient, scan, scan_type).
        slice_idx: If given and the stacked tensor is 4-D, only this slice
            of every prediction is returned.

    Returns:
        Tensor with the predictions stacked along a new first axis.

    Raises:
        ValueError: If the configured domain is unknown.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']
    count = model_config['count']
    model_type = model_config['type']

    def _load_one(index):
        # Ensemble members live under suffixed version names; stochastic
        # models share one version and differ by passthrough number.
        model_version = f"{root}_{index+1:02d}" if model_type == 'ensemble' else root
        passthrough_num = index if model_type == 'stochastic' else None

        if domain == 'PROJ':
            pred_path = files_obj.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
            return torch.from_numpy(scipy.io.loadmat(pred_path)['prj']).detach().permute(1, 0, 2)

        if domain == 'FDK':
            pred_path = files_obj.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
            volume = torch.load(pred_path).detach()
            return 25. * torch.clip(volume[20:-20, :, :], min=0.0, max=0.04)

        if domain == 'IMAG':
            # This assumes the results are saved with the ID model version name
            pred_path = files_obj.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
            images = torch.squeeze(torch.load(pred_path).detach(), dim=1)
            return torch.permute(images, (0, 2, 1))

        raise ValueError(f"Unknown domain: {domain}")

    print(f"Loading {count} predictions for {model_config['name']}...")

    predictions = [
        _load_one(i)
        for i in tqdm(range(count), desc="Loading predictions", leave=False)
    ]
    predictions_tensor = torch.stack(predictions)

    if slice_idx is None or predictions_tensor.ndim != 4:
        return predictions_tensor
    return predictions_tensor[:, slice_idx, :, :]

print("Data loading functions defined.")

### Metric calculation functions

In [None]:
def calculate_ause_sparsification(uncertainty, errors):
    """
    Calculates the Area Under the Sparsification Error curve (AUSE).

    Pixels are removed one by one, most uncertain first (model curve) or
    largest error first (oracle curve). Each curve holds the mean error of
    the remaining pixels, normalized by the overall mean error so that it
    starts at 1. The AUSE is the mean absolute gap between the two curves.

    Args:
        uncertainty: Array-like of per-pixel uncertainties (any shape).
        errors: Array-like of per-pixel absolute errors with the same
            number of elements as `uncertainty`.

    Returns:
        The AUSE as a float; NaN when the inputs are empty.

    Raises:
        ValueError: If the two inputs differ in their number of elements.
    """
    uncertainty_flat = np.asarray(uncertainty).flatten()
    errors_flat = np.asarray(errors).flatten()

    if uncertainty_flat.size != errors_flat.size:
        raise ValueError(
            f"uncertainty and errors must have the same number of elements, "
            f"got {uncertainty_flat.size} and {errors_flat.size}"
        )
    if errors_flat.size == 0:
        return float("nan")

    # Normalize by overall MAE so the curve starts at 1
    overall_mae = np.mean(errors_flat, dtype=np.float64)

    def get_sparsification_curve_fast(sorted_errs):
        n_pixels = len(sorted_errs)
        # Sum of the errors that remain after removing the first k pixels.
        # A reversed cumulative sum in float64 avoids the cancellation of
        # "total - removed", which loses the tail of the curve in float32.
        sum_errors_remaining = np.cumsum(sorted_errs[::-1], dtype=np.float64)[::-1]
        n_remaining = np.arange(n_pixels, 0, -1)
        curve = sum_errors_remaining / n_remaining
        if overall_mae > 0:
            curve = curve / overall_mae  # Normalize the curve
        return curve

    # Sort on the GPU when one is available, otherwise fall back to the CPU
    # instead of failing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    uncertainty_tensor = torch.from_numpy(uncertainty_flat).to(device)
    errors_tensor = torch.from_numpy(errors_flat).to(device)

    # Model curve (sorted by uncertainty)
    model_sorted_indices = torch.argsort(uncertainty_tensor, descending=True)
    model_sorted_errors = errors_tensor[model_sorted_indices].cpu().numpy()
    model_curve = get_sparsification_curve_fast(model_sorted_errors)

    # Oracle curve (sorted by error)
    oracle_sorted_errors = torch.sort(errors_tensor, descending=True)[0].cpu().numpy()
    oracle_curve = get_sparsification_curve_fast(oracle_sorted_errors)

    # The AUSE is the area between the two normalized curves
    ause = np.mean(np.abs(model_curve - oracle_curve))
    return ause

def calculate_ece(ground_truth, mean_pred, uncertainty_map, n_levels=20):
    """
    Weighted calibration error for regression (after Kuleshov et al., 2018).

    Each prediction is treated as a normal distribution with mean
    `mean_pred` and standard deviation `uncertainty_map`. For `n_levels`
    evenly spaced confidence levels in [0, 1] the squared gap between the
    level and the observed fraction of predicted CDF values at or below it
    is computed and weighted by the share of points whose CDF value falls
    in the bin ending at that level.

    Args:
        ground_truth, mean_pred, uncertainty_map: Arrays of equal size.
        n_levels: Number of confidence levels.

    Returns:
        The weighted sum of squared calibration gaps.
    """
    # NOTE(review): a standard deviation of exactly 0 is not a valid scale
    # for scipy's normal distribution and presumably yields NaN here - confirm.
    pred_cdfs = scipy.stats.norm.cdf(
        ground_truth.flatten(),
        loc=mean_pred.flatten(),
        scale=uncertainty_map.flatten(),
    )

    levels = np.linspace(0, 1, n_levels)
    observed = np.array([np.mean(pred_cdfs <= level) for level in levels])

    # weights[k] is the fraction of points in (levels[k-1], levels[k]];
    # the first entry has no bin and stays zero.
    weights = np.zeros(n_levels)
    for idx, (lower, upper) in enumerate(zip(levels[:-1], levels[1:]), start=1):
        weights[idx] = np.mean((pred_cdfs > lower) & (pred_cdfs <= upper))

    total_weight = np.sum(weights)
    if total_weight > 0:
        weights /= total_weight

    return np.sum(weights * (levels - observed) ** 2)

def calculate_spearman_correlation(uncertainty, errors):
    """
    Spearman rank correlation between uncertainty and absolute error.

    Both inputs are flattened before the correlation is computed.

    Returns:
        The correlation coefficient (the p-value is discarded).
    """
    result = scipy.stats.spearmanr(uncertainty.flatten(), errors.flatten())
    # The result is a (correlation, p-value) pair.
    return result[0]

import torchmetrics
import torchmetrics.image

def calculate_volume_metrics_2_pass(files_obj, model_config, scan_info, gt_volume, device):
    """
    Calculates prediction statistics and image-quality metrics in two passes.

    Pass 1 streams all predictions once to compute their mean.
    Pass 2 streams them again to accumulate the squared deviations from that
    mean (for the standard deviation) and the per-sample metrics.
    Only one prediction is held in memory at a time.

    Args:
        files_obj: Path resolver used to locate the prediction files.
        model_config: Dict with the keys 'count', 'domain',
            'model_version_root' and 'type'.
        scan_info: Triple of (patient, scan, scan_type).
        gt_volume: Ground-truth tensor; moved to `device`.
        device: Device on which all computations run.

    Returns:
        Tuple (metrics, mean_volume, uncertainty_volume_map):
        metrics is a dict with 'sample_avg_*' entries and, for inputs with
        more than two dimensions, 'mean_*' entries (ssim, psnr, mse, mae);
        mean_volume is the mean prediction and uncertainty_volume_map the
        population standard deviation across samples (all zeros for a single
        sample), both on `device`.

    Raises:
        ValueError: If the sample count is below 1 or the domain is unknown.
    """
    n_samples = model_config['count']
    if n_samples < 1:
        # Without samples the mean would be a division by zero.
        raise ValueError(f"model_config['count'] must be at least 1, got {n_samples}")

    domain = model_config['domain']
    if domain not in ('PROJ', 'FDK', 'IMAG'):
        # Fail early, consistent with the other loaders, instead of hitting
        # an unbound variable inside the generator.
        raise ValueError(f"Unknown domain: {domain}")

    gt_volume = gt_volume.to(device)  # Ensure GT is on the correct device

    def prediction_generator():
        # Yields one prediction at a time, already moved to `device`.
        patient, scan, scan_type = scan_info
        root = model_config['model_version_root']
        model_type = model_config['type']
        for i in range(n_samples):
            passthrough_num = None
            model_version = root
            if model_type == 'ensemble':
                model_version = f"{root}_{i+1:02d}"
            elif model_type == 'stochastic':
                passthrough_num = i

            if domain == 'PROJ':
                pred_path = files_obj.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.from_numpy(scipy.io.loadmat(pred_path)['prj']).detach().permute(1, 0, 2)
            elif domain == 'FDK':
                pred_path = files_obj.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                pred = pred[20:-20, :, :]
                pred = 25. * torch.clip(pred, min=0.0, max=0.04)
            else:  # 'IMAG'
                pred_path = files_obj.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                pred = torch.squeeze(pred, dim=1)
                pred = torch.permute(pred, (0, 2, 1))
            yield pred.to(device)

    # --- Pass 1: Calculate Mean ---
    print("Pass 1: Calculating mean prediction...")
    mean_volume = torch.zeros_like(gt_volume)
    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 1/2 (Mean)", leave=False):
        mean_volume += pred_volume
    mean_volume /= n_samples

    # --- Pass 2: Calculate Variance and Metrics ---
    print("Pass 2: Calculating variance and metrics...")
    sum_sq_diff_volume = torch.zeros_like(gt_volume)
    sample_avg_ssims, sample_avg_psnrs, sample_avg_mses, sample_avg_maes = [], [], [], []

    # Initialize metrics on the specified device
    data_range = gt_volume.max() - gt_volume.min()
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=data_range, **SSIM_KWARGS_).to(device)

    # NOTE(review): torchmetrics documents `reduction` as taking effect only
    # when `dim` is set. With dim left at its default this metric presumably
    # returns one PSNR for the whole volume rather than one value per slice -
    # confirm, and pass dim=(1, 2, 3) if per-slice values are intended. The
    # call is left unchanged so previously reported numbers stay comparable.
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=data_range, reduction='none').to(device)

    def finite_mean_psnr(pred_batch, target_batch):
        # Mean over the finite PSNR values; identical inputs give an infinite
        # PSNR, in which case a fixed high value of 100 is reported instead.
        psnr_values = psnr_metric(pred_batch, target_batch)
        finite_values = psnr_values[torch.isfinite(psnr_values)]
        if finite_values.numel() > 0:
            return torch.mean(finite_values).item()
        return 100.0

    compute_image_metrics = gt_volume.ndim > 2
    # Add a channel axis so the first axis acts as the batch of 2-D images.
    gt_vol_batch = gt_volume.unsqueeze(1) if compute_image_metrics else None

    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 2/2 (Var & Metrics)", leave=False):
        # Variance calculation
        diff = pred_volume - mean_volume
        sum_sq_diff_volume += diff * diff

        # Per-sample metrics
        if compute_image_metrics:
            pred_vol_batch = pred_volume.unsqueeze(1)

            sample_avg_ssims.append(ssim_metric(pred_vol_batch, gt_vol_batch).item())
            sample_avg_mses.append(torch.mean((gt_volume - pred_volume)**2).item())
            sample_avg_maes.append(torch.mean(torch.abs(gt_volume - pred_volume)).item())
            sample_avg_psnrs.append(finite_mean_psnr(pred_vol_batch, gt_vol_batch))

    # Finalize variance and uncertainty
    if n_samples > 1:
        # Population standard deviation (divide by n_samples);
        # use (n_samples - 1) for the sample standard deviation.
        uncertainty_volume_map = torch.sqrt(sum_sq_diff_volume / n_samples)
    else:
        uncertainty_volume_map = torch.zeros_like(mean_volume)

    # --- Calculate metrics for the mean prediction ---
    metrics = {}
    if compute_image_metrics:
        mean_vol_batch = mean_volume.unsqueeze(1)

        metrics['mean_ssim'] = ssim_metric(mean_vol_batch, gt_vol_batch).item()
        metrics['mean_mse'] = torch.mean((gt_volume - mean_volume)**2).item()
        metrics['mean_mae'] = torch.mean(torch.abs(gt_volume - mean_volume)).item()
        metrics['mean_psnr'] = finite_mean_psnr(mean_vol_batch, gt_vol_batch)

    # --- Aggregate per-sample metrics ---
    metrics['sample_avg_ssim'] = np.mean(sample_avg_ssims) if sample_avg_ssims else 0
    metrics['sample_avg_psnr'] = np.mean(sample_avg_psnrs) if sample_avg_psnrs else 0
    metrics['sample_avg_mse'] = np.mean(sample_avg_mses) if sample_avg_mses else 0
    metrics['sample_avg_mae'] = np.mean(sample_avg_maes) if sample_avg_maes else 0

    return metrics, mean_volume, uncertainty_volume_map

print("Metric calculation functions defined.")

### Visualization functions

In [None]:
def plot_mean_comparison(mean_pred, ground_truth, uncertainty_map, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show ground truth, mean prediction, absolute error and uncertainty
    side by side, each with its own colorbar.

    If `tumor_coords_xy` is given as an (x, y) pair, a red arrow pointing at
    that position is drawn on all four panels.
    """
    abs_error = np.abs(ground_truth - mean_pred)
    panels = (
        (ground_truth, 'gray', 'Ground Truth'),
        (mean_pred, 'gray', 'Mean Prediction'),
        (abs_error, 'magma', 'Absolute Error Map'),
        (uncertainty_map, 'viridis', 'Uncertainty (Std Dev)'),
    )

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Mean vs. GT)', fontsize=16)

    def draw_panel(ax, image, cmap, title):
        handle = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        fig.colorbar(handle, ax=ax)

    draw_panel(axes[0], *panels[0])

    if tumor_coords_xy:
        x, y = tumor_coords_xy
        for ax in axes:
            # Arrow tail sits 30 px up and to the left of the target.
            ax.annotate('', xy=(x, y), xytext=(x - 30, y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05, width=1, headwidth=5, headlength=5))

    for ax, panel in zip(axes[1:], panels[1:]):
        draw_panel(ax, *panel)

    plt.tight_layout()
    plt.show()

    
def plot_ssim_map(ssim_map, model_name, scan_name):
    """Show an SSIM map on a fixed 0-1 color scale with a colorbar."""
    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(ssim_map, cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'SSIM Map - {model_name} - {scan_name}')
    fig.colorbar(image, ax=ax)
    ax.axis('off')
    plt.show()

def plot_calibration_curve(ground_truth, mean_pred, uncertainty_map, model_name, scan_name, n_levels=20):
    """
    Draw a calibration plot with a marginal histogram underneath.

    The upper panel compares `n_levels` expected confidence levels with the
    observed fraction of predicted CDF values at or below each level. The
    lower panel shows the density of the predicted CDF values on a log scale.
    Predictions are modeled as normal distributions with mean `mean_pred`
    and standard deviation `uncertainty_map`.
    """
    # Predicted CDF value of every ground-truth point
    pred_cdfs = scipy.stats.norm.cdf(
        ground_truth.flatten(),
        loc=mean_pred.flatten(),
        scale=uncertainty_map.flatten(),
    )

    # Two stacked panels sharing the x-axis; the upper one is 3x taller.
    fig, (ax_cal, ax_hist) = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(8, 8),
        sharex=True,
        gridspec_kw={'height_ratios': [3, 1]},
    )

    # --- Upper panel: calibration curve ---
    levels = np.linspace(0, 1, n_levels)
    observed = np.array([np.mean(pred_cdfs <= level) for level in levels])

    ax_cal.plot([0, 1], [0, 1], '--', color='grey', label='Perfectly Calibrated')
    ax_cal.plot(levels, observed, '-o', label='Model Calibration')
    ax_cal.set_ylabel('Observed Confidence Level')
    ax_cal.set_title(f'Calibration Plot - {model_name} - {scan_name}')
    ax_cal.legend()
    ax_cal.grid(True, linestyle=':')

    # --- Lower panel: distribution of the predicted CDF values ---
    ax_hist.hist(pred_cdfs, bins=50, range=(0,1), density=True, color='steelblue', alpha=0.8)
    ax_hist.set_xlabel('Expected Confidence Level (Predicted CDF)')
    ax_hist.set_ylabel('Density')
    ax_hist.set_yscale('log')

    plt.tight_layout()
    plt.show()

def plot_sparsification_curve(uncertainty, errors, model_name, scan_name):
    """
    Draw the model and oracle sparsification curves used for AUSE.

    Pixels are removed one at a time, highest-ranked first, and the mean error
    of the pixels still present is tracked. The model curve ranks pixels by
    predicted uncertainty, the oracle curve by the true error. Both curves are
    divided by the overall mean error when that mean is positive.

    Args:
        uncertainty (np.ndarray): Per-pixel uncertainty values.
        errors (np.ndarray): Per-pixel error values, aligned with ``uncertainty``.
        model_name (str): Model label used in the plot title.
        scan_name (str): Scan label used in the plot title.
    """
    flat_uncertainty = uncertainty.flatten()
    flat_errors = errors.flatten()

    def _remaining_error_curve(ordered_errors, baseline_mae):
        # ordered_errors is already sorted in removal order (first = removed first).
        running_total = np.cumsum(ordered_errors)
        grand_total = running_total[-1]
        # Error mass discarded before each step; nothing has been removed at step 0.
        discarded = np.zeros_like(running_total)
        discarded[1:] = running_total[:-1]
        pixels_left = np.arange(len(ordered_errors), 0, -1)
        mean_remaining = (grand_total - discarded) / pixels_left
        if baseline_mae > 0:
            return mean_remaining / baseline_mae
        return mean_remaining

    baseline = np.mean(flat_errors)

    # Model ranking: most uncertain pixels are discarded first.
    by_uncertainty = np.argsort(flat_uncertainty)[::-1]
    model_curve = _remaining_error_curve(flat_errors[by_uncertainty], baseline)

    # Oracle ranking: largest true errors are discarded first.
    oracle_curve = _remaining_error_curve(np.sort(flat_errors)[::-1], baseline)

    # X-axis: fraction of pixels removed
    removed_fraction = np.linspace(0, 1, len(model_curve))

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(removed_fraction, model_curve, label='Model (Sort by Uncertainty)')
    ax.plot(removed_fraction, oracle_curve, '--', label='Oracle (Sort by Error)')
    ax.set_xlabel('Fraction of Pixels Removed')
    ax.set_ylabel('Mean Absolute Error of Remaining Pixels')
    ax.set_title(f'Sparsification Curve - {model_name} - {scan_name}')
    ax.legend()
    ax.grid(True, linestyle=':')
    plt.show()

def plot_samples_comparison(ground_truth, mean_pred, samples, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show ground truth, mean prediction and individual samples side by side.

    All panels share one grayscale window, taken from the combined range of
    the ground truth and the mean prediction.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        samples (list of np.ndarray): 2D sample prediction slices, one panel each.
        model_name (str): Model label for the figure title.
        scan_name (str): Scan label for the figure title.
        slice_idx (int): Slice index shown in the figure title.
        tumor_coords_xy (tuple, optional): (x, y) position the red arrow points at.
    """
    # One panel for GT, one for the mean, then one per sample.
    n_panels = len(samples) + 2

    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4.5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (GT, Mean, and Samples)', fontsize=16)

    # Shared display window so intensities are comparable across panels.
    lo = min(ground_truth.min(), mean_pred.min())
    hi = max(ground_truth.max(), mean_pred.max())

    panels = [('Ground Truth', ground_truth), ('Mean Prediction', mean_pred)]
    for number, sample in enumerate(samples, start=1):
        panels.append((f'Sample {number}', sample))

    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image, cmap='gray', vmin=lo, vmax=hi)
        ax.set_title(title)
        ax.axis('off')

    # Mark the tumor on every panel when coordinates were supplied.
    if tumor_coords_xy:
        tip_x, tip_y = tumor_coords_xy
        tail = (tip_x - 30, tip_y - 30)
        for ax in axes:
            ax.annotate(
                '',
                xy=(tip_x, tip_y),
                xytext=tail,
                arrowprops={
                    'facecolor': 'red', 'edgecolor': 'red', 'shrink': 0.05,
                    'width': 1, 'headwidth': 5, 'headlength': 5,
                },
            )

    plt.show()

def plot_worst_samples_comparison(ground_truth, mean_pred, worst_samples_data, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show ground truth, mean prediction and the worst-scoring samples side by side.

    All panels share one grayscale window, taken from the combined range of
    the ground truth and the mean prediction. Each sample panel is titled with
    its original sample index and its loss value.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        worst_samples_data (list): Tuples of
                                   (loss, sample_slice_numpy, sample_index).
        model_name (str): Model label for the figure title.
        scan_name (str): Scan label for the figure title.
        slice_idx (int): Slice index shown in the figure title.
        tumor_coords_xy (tuple, optional): (x, y) position the red arrow points at.
    """
    n_worst = len(worst_samples_data)
    n_panels = n_worst + 2

    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Top {n_worst} Worst Samples by SmoothL1Loss)', fontsize=16)

    # Shared display window so intensities are comparable across panels.
    lo = min(ground_truth.min(), mean_pred.min())
    hi = max(ground_truth.max(), mean_pred.max())
    show_kwargs = {'cmap': 'gray', 'vmin': lo, 'vmax': hi}

    # Reference panels first: ground truth, then the mean prediction.
    for ax, title, image in ((axes[0], 'Ground Truth', ground_truth),
                             (axes[1], 'Mean Prediction', mean_pred)):
        ax.imshow(image, **show_kwargs)
        ax.set_title(title)
        ax.axis('off')

    # Remaining panels: one per entry, in the order the caller provided.
    for ax, (loss, sample_slice, sample_idx) in zip(axes[2:], worst_samples_data):
        ax.imshow(sample_slice, **show_kwargs)
        ax.set_title(f'Sample #{sample_idx}\nLoss: {loss:.4f}')
        ax.axis('off')

    # Mark the tumor on every panel when coordinates were supplied.
    if tumor_coords_xy:
        tip_x, tip_y = tumor_coords_xy
        tail = (tip_x - 30, tip_y - 30)
        for ax in axes:
            ax.annotate(
                '',
                xy=(tip_x, tip_y),
                xytext=tail,
                arrowprops={
                    'facecolor': 'red', 'edgecolor': 'red', 'shrink': 0.05,
                    'width': 1, 'headwidth': 5, 'headlength': 5,
                },
            )

    plt.show()

# Emit a confirmation message at the end of the cell that defines the plotting helpers.
print("Visualization functions defined.")

### Main loop

In [None]:
# Main analysis cell: for every configured model and every scan in
# `analysis_scans`, load the ground truth, compute image-quality metrics plus
# the mean prediction and uncertainty volume, derive AUSE and Spearman
# correlation, and collect one result row per (model, scan) pair.
# This list will store dictionaries of results for each scan and model
all_results = []

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    
    # Rows for the current model only; merged into all_results after its scans finish.
    scan_results = []

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        # scan_info unpacks into three fields; the third is not used in this cell.
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"
        
        # --- Determine Domain and Plotting Slice ---
        # plot_slice_idx and tumor_xy are only consumed by the visualization
        # code that is currently commented out further down.
        is_visual_domain = domain in ['FDK', 'IMAG']
        plot_slice_idx = None
        tumor_xy = None
        
        if is_visual_domain:
            # At cell (module) scope, locals() is the notebook namespace, so this
            # checks whether tumor_locations was defined by an earlier cell.
            if 'tumor_locations' in locals() and tumor_locations is not None:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    # NOTE(review): components 0 and 1 are swapped here — presumably
                    # converting stored order into the (x, y) order the plot helpers
                    # expect; confirm against the tumor-locations file.
                    tumor_xy = (loc[1].item(), loc[0].item())
                    # NOTE(review): the -20 offset is unexplained here — presumably
                    # compensates for slices cropped from the volume; confirm.
                    plot_slice_idx = int(loc[2].item()) - 20
                except (IndexError, TypeError):
                    print(f"Warning: Could not find tumor location for {scan_name}. Plotting will not have an arrow.")
                    # Fall back to a fixed slice when no location is available.
                    plot_slice_idx = 100
            else:
                plot_slice_idx = 100
        
        # --- Data Loading (Ground Truth Only) ---
        # slice_idx=None is passed to the loader; the result is indexed as a
        # volume in the (disabled) plotting code below.
        gt_volume = load_ground_truth(FILES, scan_info, domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()
        
        # --- Iterative Metric Calculation ---
        print("Calculating metrics iteratively...")
        iq_metrics, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_2_pass(
            FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
        )
        # Move results to host memory as NumPy arrays for the metrics below.
        mean_pred_vol = mean_pred_vol.cpu().numpy()
        uncertainty_map_vol = uncertainty_map_vol.cpu().numpy()
        
        # --- Uncertainty Metric Calculation ---
        # Absolute error of the mean prediction, element-wise over the volume.
        errors_vol = np.abs(gt_volume_np - mean_pred_vol)
        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)
        # print("Calculating ECE...")
        # ece_val = calculate_ece(gt_volume_np, mean_pred_vol, uncertainty_map_vol)
        print("Calculating Spearman's correlation...")
        spearman_val = calculate_spearman_correlation(uncertainty_map_vol, errors_vol)
        
        # --- Store Results ---
        # iq_metrics is expanded in place, so its keys become result columns.
        scan_result = {
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            # 'ece': ece_val,
            'spearman_corr': spearman_val,
        }
        scan_results.append(scan_result)
        
        # # --- Visualization ---
        # print(f"\n--- Results for {model_name} on {scan_name} ---")
        
        # # plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)
        # plot_sparsification_curve(uncertainty_map_vol, errors_vol, model_name, scan_name)
        
        # if is_visual_domain:
        #     gt_slice_np = gt_volume_np[plot_slice_idx]
        #     mean_pred_slice = mean_pred_vol[plot_slice_idx]
        #     uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
        #     _, ssim_map_val = ssim(gt_slice_np, mean_pred_slice, 
        #                           data_range=(np.max(gt_slice_np) - np.min(gt_slice_np)), 
        #                           full=True, **SSIM_KWARGS)
            
        #     plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice, 
        #                          model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)
        #     plot_ssim_map(ssim_map_val, model_name, scan_name)

        # --- Clean up memory ---
        # Drop the per-scan volumes before the next iteration loads new ones.
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol
        gc.collect()
        
    all_results.extend(scan_results)

# Convert results to a pandas DataFrame for easier analysis
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
# Bare expression: the notebook renders the DataFrame as the cell output.
results_df

### Summary/Comparison

In [None]:
# Per-model "mean ± std" table over all scans; only produced when more than
# one model is configured.
if len(MODELS_TO_ANALYZE) <= 1:
    print("\nOnly one model was analyzed. No comparison table to generate.")
else:
    # Placeholder list; it is not read anywhere within this cell.
    summary_data = []

    # Aggregate the numeric metric columns only.
    numeric_cols = results_df.select_dtypes(include=[np.number]).columns
    summary = (
        results_df
        .groupby('model_name')[numeric_cols]
        .agg(['mean', 'std'])
        .reset_index()
    )

    # Build the display table: one "mean ± std" string per metric and model.
    four_decimals = '{:.4f}'.format
    summary_display = pd.DataFrame()
    summary_display['model_name'] = summary['model_name']
    for col in numeric_cols:
        mean_col, std_col = (col, 'mean'), (col, 'std')
        mean_text = summary[mean_col].map(four_decimals)
        std_text = summary[std_col].map(four_decimals)
        summary_display[col] = mean_text + ' ± ' + std_text

    print("\n\n=======================================================")
    print("               Model Comparison Summary")
    print("=======================================================")

    display(summary_display)

In [None]:
# Analysis cell (same statements as the earlier "Main loop" cell): running it
# resets all_results and rebuilds results_df from scratch. For every configured
# model and every scan in `analysis_scans` it loads the ground truth, computes
# image-quality metrics plus the mean prediction and uncertainty volume, and
# derives AUSE and Spearman correlation.
# This list will store dictionaries of results for each scan and model
all_results = []

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    
    # Rows for the current model only; merged into all_results after its scans finish.
    scan_results = []

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        # scan_info unpacks into three fields; the third is not used in this cell.
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"
        
        # --- Determine Domain and Plotting Slice ---
        # plot_slice_idx and tumor_xy are only consumed by the visualization
        # code that is currently commented out further down.
        is_visual_domain = domain in ['FDK', 'IMAG']
        plot_slice_idx = None
        tumor_xy = None
        
        if is_visual_domain:
            # At cell (module) scope, locals() is the notebook namespace, so this
            # checks whether tumor_locations was defined by an earlier cell.
            if 'tumor_locations' in locals() and tumor_locations is not None:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    # NOTE(review): components 0 and 1 are swapped here — presumably
                    # converting stored order into the (x, y) order the plot helpers
                    # expect; confirm against the tumor-locations file.
                    tumor_xy = (loc[1].item(), loc[0].item())
                    # NOTE(review): the -20 offset is unexplained here — presumably
                    # compensates for slices cropped from the volume; confirm.
                    plot_slice_idx = int(loc[2].item()) - 20
                except (IndexError, TypeError):
                    print(f"Warning: Could not find tumor location for {scan_name}. Plotting will not have an arrow.")
                    # Fall back to a fixed slice when no location is available.
                    plot_slice_idx = 100
            else:
                plot_slice_idx = 100
        
        # --- Data Loading (Ground Truth Only) ---
        # slice_idx=None is passed to the loader; the result is indexed as a
        # volume in the (disabled) plotting code below.
        gt_volume = load_ground_truth(FILES, scan_info, domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()
        
        # --- Iterative Metric Calculation ---
        print("Calculating metrics iteratively...")
        iq_metrics, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_2_pass(
            FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
        )
        # Move results to host memory as NumPy arrays for the metrics below.
        mean_pred_vol = mean_pred_vol.cpu().numpy()
        uncertainty_map_vol = uncertainty_map_vol.cpu().numpy()
        
        # --- Uncertainty Metric Calculation ---
        # Absolute error of the mean prediction, element-wise over the volume.
        errors_vol = np.abs(gt_volume_np - mean_pred_vol)
        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)
        # print("Calculating ECE...")
        # ece_val = calculate_ece(gt_volume_np, mean_pred_vol, uncertainty_map_vol)
        print("Calculating Spearman's correlation...")
        spearman_val = calculate_spearman_correlation(uncertainty_map_vol, errors_vol)
        
        # --- Store Results ---
        # iq_metrics is expanded in place, so its keys become result columns.
        scan_result = {
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            # 'ece': ece_val,
            'spearman_corr': spearman_val,
        }
        scan_results.append(scan_result)
        
        # # --- Visualization ---
        # print(f"\n--- Results for {model_name} on {scan_name} ---")
        
        # # plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)
        # plot_sparsification_curve(uncertainty_map_vol, errors_vol, model_name, scan_name)
        
        # if is_visual_domain:
        #     gt_slice_np = gt_volume_np[plot_slice_idx]
        #     mean_pred_slice = mean_pred_vol[plot_slice_idx]
        #     uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
        #     _, ssim_map_val = ssim(gt_slice_np, mean_pred_slice, 
        #                           data_range=(np.max(gt_slice_np) - np.min(gt_slice_np)), 
        #                           full=True, **SSIM_KWARGS)
            
        #     plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice, 
        #                          model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)
        #     plot_ssim_map(ssim_map_val, model_name, scan_name)

        # --- Clean up memory ---
        # Drop the per-scan volumes before the next iteration loads new ones.
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol
        gc.collect()
        
    all_results.extend(scan_results)

# Convert results to a pandas DataFrame for easier analysis
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
# Bare expression: the notebook renders the DataFrame as the cell output.
results_df

### Summary/Comparison

In [None]:
# Per-model "mean ± std" table over all scans; only produced when more than
# one model is configured.
if len(MODELS_TO_ANALYZE) <= 1:
    print("\nOnly one model was analyzed. No comparison table to generate.")
else:
    # Placeholder list; it is not read anywhere within this cell.
    summary_data = []

    # Aggregate the numeric metric columns only.
    numeric_cols = results_df.select_dtypes(include=[np.number]).columns
    summary = (
        results_df
        .groupby('model_name')[numeric_cols]
        .agg(['mean', 'std'])
        .reset_index()
    )

    # Build the display table: one "mean ± std" string per metric and model.
    four_decimals = '{:.4f}'.format
    summary_display = pd.DataFrame()
    summary_display['model_name'] = summary['model_name']
    for col in numeric_cols:
        mean_col, std_col = (col, 'mean'), (col, 'std')
        mean_text = summary[mean_col].map(four_decimals)
        std_text = summary[std_col].map(four_decimals)
        summary_display[col] = mean_text + ' ± ' + std_text

    print("\n\n=======================================================")
    print("               Model Comparison Summary")
    print("=======================================================")

    display(summary_display)