### Setup/Imports

In [None]:
import os
import gc
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import scipy.io
from pipeline.paths import Directories, Files
from pipeline.utils import read_scans_agg_file
import scipy.stats
plt.rcParams['figure.dpi'] = 100

### Config

In [None]:
# --- Experiment selection ---
# PHASE and DATA_VERSION select the working directory (phase<PHASE>/DS<DATA_VERSION>);
# SPLIT_TO_ANALYZE selects which list of scans is read from the aggregation file.
PHASE = "7"
DATA_VERSION = "14"
SPLIT_TO_ANALYZE = 'VALIDATION'  # Options: 'TRAIN', 'VALIDATION', 'TEST'

# One entry per model to evaluate. Keys read by the functions in this notebook:
#   'name'               - display name used in progress messages and plot titles
#   'type'               - 'stochastic' (sample index passed as passthrough_num) or
#                          'ensemble' (sample index appended to the version name)
#   'domain'             - 'PROJ', 'FDK' or 'IMAG'
#   'model_version_root' - version string used to locate the saved results
#   'count'              - number of predictions (samples / ensemble members) to load
MODELS_TO_ANALYZE = [
    {
        'name': 'MC Dropout 15%',
        'type': 'stochastic',
        'domain': 'IMAG',
        'model_version_root': 'MK7_MCDROPOUT_15_pct',
        'count': 50,
    },
    {
        'name': 'MC Dropout 30%',
        'type': 'stochastic',
        'domain': 'IMAG',
        'model_version_root': 'MK7_MCDROPOUT_30_pct',
        'count': 50,
    },
    {
        'name': 'MC Dropout 50%',
        'type': 'stochastic',
        'domain': 'IMAG',
        'model_version_root': 'MK7_MCDROPOUT_50_pct',
        'count': 50,
    },
]

# Path to the .pt file containing tumor locations.
# Indexed as [patient, scan, (x, y, z)].
# NOTE(review): the previous comment called this a "5D tensor", but the listed
# axes describe three dimensions - confirm the actual layout.
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"
TUMOR_LOCATIONS_FILE = 'D:/NoahSilverberg/ngCBCT/3D_recon/tumor_location_FF.pt'

# --- Advanced Config ---
SCANS_AGG_FILE = 'scans_to_agg_FF.txt'
SSIM_KWARGS = {"k1": 0.03, "k2": 0.06, "kernel_size": 11}

# --- Setup ---
# Create Directories and Files objects.
# Windows paths are written as raw strings: "\P" is not a valid escape sequence
# and triggers a SyntaxWarning in recent Python versions. The resulting path
# strings are unchanged.
phase_dataver_dir = os.path.join(WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}")
DIRECTORIES = Directories(
    # mat_projections_dir=os.path.join("H:\Public/Noah", "mat"),
    # pt_projections_dir=os.path.join("H:\Public/Noah", "prj_pt"),
    # projections_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "projections"),
    # projections_model_dir=os.path.join('H:\Public/Noah/phase7/DS14', "models", "projections"),
    # projections_results_dir=os.path.join('H:\Public/Noah/phase7/DS14', "results", "projections"),
    # projections_gated_dir=os.path.join("H:\Public/Noah", "gated", "prj_mat"),
    # reconstructions_dir=os.path.join('H:\Public/Noah/phase7/DS14', "reconstructions"),
    reconstructions_gated_dir=os.path.join(r"H:\Public/Noah", "gated", "fdk_recon"),
    # images_aggregate_dir=os.path.join(phase_dataver_dir, "aggregates", "images"),
    # images_model_dir=os.path.join('H:\Public/Noah/phase7/DS14', "models", "images"),
    images_results_dir=os.path.join(r'H:\Public/Noah/phase7/DS14', "results", "images"),
    error_results_dir=os.path.join(r'H:\Public/Noah/phase7/DS14', "results", "error_results"),
)
FILES = Files(DIRECTORIES)

# Load the list of scans and keep only the split selected above.
scans_agg, scan_type_agg = read_scans_agg_file(SCANS_AGG_FILE)
analysis_scans = scans_agg[SPLIT_TO_ANALYZE]

# Load tumor locations.
# weights_only=False unpickles arbitrary objects; only use with trusted files.
tumor_locations = torch.load(TUMOR_LOCATIONS_FILE, weights_only=False)

# Restrict CUDA to the physical GPU with PCI bus index 1; inside this process
# that GPU is then addressed as "cuda:0".
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE} named '{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}'")

print("\nConfiguration loaded.")
print(f"Analyzing {len(analysis_scans)} scans from the '{SPLIT_TO_ANALYZE}' split.")
print(f"Found {len(MODELS_TO_ANALYZE)} model(s) to analyze.")

## Function definitions, etc.

### Data loading/prep functions

In [None]:
def load_ground_truth(files_obj: Files, scan_info, domain, slice_idx=None):
    """
    Load the ground-truth tensor for one scan in the requested domain.

    Args:
        files_obj: Path helper used to resolve the ground-truth file location.
        scan_info: Triple ``(patient, scan, scan_type)``.
        domain: ``'PROJ'`` loads the gated projections from a .mat file;
            ``'FDK'`` and ``'IMAG'`` both load the gated FDK reconstruction.
        slice_idx: Optional index along the first axis; only applied when the
            loaded tensor is 3-D.

    Returns:
        The loaded tensor, or the selected slice of it.

    Raises:
        ValueError: If ``domain`` is not one of the supported values.
    """
    patient, scan, scan_type = scan_info

    if domain == 'PROJ':
        mat_path = files_obj.get_projections_results_filepath('fdk', patient, scan, scan_type, gated=True)
        projections = scipy.io.loadmat(mat_path)['prj']
        volume = torch.from_numpy(projections).detach().permute(1, 0, 2)
    elif domain in ('FDK', 'IMAG'):
        # The IMAG-domain ground truth is the FDK of the gated projection, so
        # both domains share the same loading and normalization path.
        recon_path = files_obj.get_recon_filepath("fdk", patient, scan, scan_type, gated=True, ensure_exists=False)
        volume = torch.load(recon_path).detach()
        # Drop 20 entries from each end of the first axis.
        volume = volume[20:-20, :, :]
        if scan_type == "FF":
            # FF scans are additionally cropped by 128 on each side in-plane.
            volume = volume[:, 128:-128, 128:-128]
        # Clip to [0, 0.04] and rescale so values lie in [0, 1].
        volume = 25. * torch.clip(volume, min=0.0, max=0.04)
    else:
        raise ValueError(f"Unknown domain: {domain}")

    wants_slice = slice_idx is not None and volume.ndim == 3
    return volume[slice_idx] if wants_slice else volume

def load_predictions(files_obj: Files, model_config, scan_info, slice_idx=None):
    """
    Load every prediction of one model for one scan and stack them.

    Args:
        files_obj: Path helper used to resolve the prediction file locations.
        model_config: Dict with keys ``'name'``, ``'type'``, ``'domain'``,
            ``'model_version_root'`` and ``'count'``.
        scan_info: Triple ``(patient, scan, scan_type)``.
        slice_idx: Optional index along the volume's first axis; only applied
            when the stacked tensor is 4-D.

    Returns:
        Tensor with the predictions stacked along a new leading axis, or the
        selected slice of every prediction.

    Raises:
        ValueError: If the configured domain is not supported.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']
    count = model_config['count']
    model_type = model_config['type']

    is_ensemble = model_type == 'ensemble'
    is_stochastic = model_type == 'stochastic'

    def read_one(index):
        # Ensemble members live under numbered version names; stochastic
        # models reuse one version and are told apart by passthrough number.
        version = f"{root}_{index+1:02d}" if is_ensemble else root
        pass_idx = index if is_stochastic else None

        if domain == 'PROJ':
            path = files_obj.get_projections_results_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_idx, ensure_exists=False)
            return torch.from_numpy(scipy.io.loadmat(path)['prj']).detach().permute(1, 0, 2)
        if domain == 'FDK':
            path = files_obj.get_recon_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_idx, ensure_exists=False)
            recon = torch.load(path).detach()
            recon = recon[20:-20, :, :]
            if scan_type == "FF":
                recon = recon[:, 128:-128, 128:-128]
            return 25. * torch.clip(recon, min=0.0, max=0.04)
        if domain == 'IMAG':
            # This assumes the results are saved with the ID model version name
            path = files_obj.get_images_results_filepath(version, patient, scan, passthrough_num=pass_idx, ensure_exists=False)
            image = torch.squeeze(torch.load(path).detach(), dim=1)
            return torch.permute(image, (0, 2, 1))
        raise ValueError(f"Unknown domain: {domain}")

    print(f"Loading {count} predictions for {model_config['name']}...")

    loaded = [read_one(i) for i in tqdm(range(count), desc="Loading predictions", leave=False)]
    stacked = torch.stack(loaded)

    if slice_idx is None or stacked.ndim != 4:
        return stacked
    return stacked[:, slice_idx, :, :]

# Progress message printed once the definitions above have been executed.
print("Data loading functions defined.")

### Metric calculation functions: Image quality and pre-calibration uncertainties

In [None]:
def calculate_ause_sparsification(uncertainty, errors, device=None):
    """
    Calculates the Area Under the Sparsification Error curve (AUSE).

    Elements are removed one at a time, highest first, and the mean error of
    the remaining elements is tracked. This is done once ordering by predicted
    uncertainty (model curve) and once ordering by the error itself (oracle
    curve). Both curves are normalized by the overall mean error, and the AUSE
    is the mean absolute gap between them.

    Args:
        uncertainty (np.ndarray): Predicted uncertainty per element.
        errors (np.ndarray): Error magnitude per element, same number of
            elements as ``uncertainty``.
        device: Torch device used for sorting. Defaults to CUDA when it is
            available and to the CPU otherwise, so the function also runs on
            hosts without a GPU.

    Returns:
        The AUSE as a NumPy float; 0 means the uncertainty ranks the errors
        exactly like the oracle.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    uncertainty_flat = uncertainty.flatten()
    errors_flat = errors.flatten()

    # Normalize by overall MAE so the curve starts at 1
    overall_mae = np.mean(errors_flat)

    def get_sparsification_curve_fast(sorted_errs):
        n_pixels = len(sorted_errs)
        cumulative_errors = np.cumsum(sorted_errs)
        total_error_sum = cumulative_errors[-1]
        # Error mass already removed before step k (nothing at step 0).
        sum_errors_removed = np.insert(cumulative_errors[:-1], 0, 0)
        sum_errors_remaining = total_error_sum - sum_errors_removed
        n_remaining = np.arange(n_pixels, 0, -1)
        curve = sum_errors_remaining / n_remaining
        if overall_mae > 0:
            curve = curve / overall_mae  # Normalize the curve
        return curve

    # Sorting is done with torch so that it can run on the GPU when available.
    uncertainty_tensor = torch.from_numpy(uncertainty_flat).to(device)
    errors_tensor = torch.from_numpy(errors_flat).to(device)

    # Model curve (sorted by uncertainty)
    model_sorted_indices = torch.argsort(uncertainty_tensor, descending=True)
    model_sorted_errors = errors_tensor[model_sorted_indices].cpu().numpy()
    model_curve = get_sparsification_curve_fast(model_sorted_errors)

    # Oracle curve (sorted by error)
    oracle_sorted_errors = torch.sort(errors_tensor, descending=True)[0].cpu().numpy()
    oracle_curve = get_sparsification_curve_fast(oracle_sorted_errors)

    # The AUSE is the area between the two normalized curves
    ause = np.mean(np.abs(model_curve - oracle_curve))
    return ause

def _average_ranks(values):
    """Return 1-based ranks of a 1-D tensor as float64, with tied values sharing their average rank."""
    _, inverse, counts = torch.unique(values, sorted=True, return_inverse=True, return_counts=True)
    # 1-based position of the last element of each tie group in sorted order.
    last_position = torch.cumsum(counts, dim=0).to(torch.float64)
    # A group occupying positions (last - count + 1) .. last has this mean rank.
    group_rank = last_position - (counts.to(torch.float64) - 1.0) / 2.0
    return group_rank[inverse]


def calculate_spearman_correlation(uncertainty, errors, device):
    """
    Calculates the Spearman's Rank Correlation Coefficient between
    the uncertainty and the absolute error.

    Ranks are computed with ties sharing their average rank, so tied values
    (e.g. many identical zeros) no longer receive arbitrary, backend-dependent
    orderings. Ranks are kept in float64 because float32 cannot represent
    integers above 2**24 exactly.

    Args:
        uncertainty (np.ndarray): Predicted uncertainty per element.
        errors (np.ndarray): Error per element.
        device: Torch device on which the ranking is computed.

    Returns:
        float: Spearman's rho. NaN when fewer than two elements are given or
        when either input is constant.
    """
    uncertainty_flat = torch.from_numpy(uncertainty.flatten()).to(device)
    errors_flat = torch.from_numpy(errors.flatten()).to(device)

    n = uncertainty_flat.numel()
    if n < 2:
        return float("nan")  # Correlation is undefined for fewer than two points

    # obtain tie-aware ranks
    xr = _average_ranks(uncertainty_flat)
    yr = _average_ranks(errors_flat)

    # demean
    xr = xr - xr.mean()
    yr = yr - yr.mean()

    # Spearman's rho is the Pearson correlation of the ranks
    cov = (xr * yr).sum() / (n - 1)
    rho = cov / (xr.std(unbiased=True) * yr.std(unbiased=True))

    return rho.item()

def calculate_pearson_correlation(uncertainty, errors, device):
    """
    Calculates the Pearson Correlation Coefficient between
    the uncertainty and the absolute error on the given device.

    Args:
        uncertainty (np.ndarray): Predicted uncertainty per element.
        errors (np.ndarray): Error per element.
        device: Torch device on which the computation runs.

    Returns:
        float: Pearson's r. Returns 0.0 when fewer than two elements are given
        or when either input has zero standard deviation.
    """
    uncertainty_flat = torch.from_numpy(uncertainty.flatten()).to(device)
    errors_flat = torch.from_numpy(errors.flatten()).to(device)

    n = uncertainty_flat.numel()
    if n < 2:
        return 0.0  # Not enough data to compute correlation

    # Demean
    uncertainty_demeaned = uncertainty_flat - uncertainty_flat.mean()
    errors_demeaned = errors_flat - errors_flat.mean()

    # Compute covariance and standard deviations using N-1 for unbiased estimate
    cov = (uncertainty_demeaned * errors_demeaned).sum() / (n - 1)
    std_uncertainty = torch.std(uncertainty_flat, unbiased=True)
    std_errors = torch.std(errors_flat, unbiased=True)
    denominator = std_uncertainty * std_errors

    # Guard the degenerate case explicitly. Adding a fixed epsilon to the
    # denominator would bias the coefficient toward zero whenever the product
    # of the standard deviations is itself small.
    if denominator.item() == 0.0:
        return 0.0

    rho = cov / denominator
    return rho.item()


import torchmetrics
import torchmetrics.image

def calculate_volume_metrics_2_pass(files_obj, model_config, scan_info, gt_volume, device):
    """
    Calculates stats and metrics using a two-pass algorithm for variance for improved stability.
    Also includes a PSNR calculation that filters out infinite values.
    Pass 1: Calculate the mean of all predictions.
    Pass 2: Calculate variance and other metrics using the pre-calculated mean.

    Predictions are streamed from disk one at a time in each pass, so only a
    single prediction is held on the device at once.

    Args:
        files_obj: Path helper used to resolve the prediction file locations.
        model_config: Dict with keys 'type', 'domain', 'model_version_root' and 'count'.
        scan_info: Triple (patient, scan, scan_type).
        gt_volume: Ground-truth tensor; moved to ``device``.
        device: Torch device used for all computations.

    Returns:
        tuple: ``(metrics, mean_volume, uncertainty_volume_map)`` where
        ``metrics`` is a dict of floats, ``mean_volume`` is the mean prediction
        and ``uncertainty_volume_map`` is the per-element population standard
        deviation over the predictions (zeros when only one prediction exists).
        The ``mean_*`` metric keys are only present when ``gt_volume`` has more
        than two dimensions.

    Raises:
        ValueError: If 'count' is below 1, if the domain is unknown, or if no
            finite PSNR value could be computed.
    """
    n_samples = model_config['count']
    if n_samples < 1:
        # A count of zero would otherwise divide by zero and return NaN volumes.
        raise ValueError(f"Model config 'count' must be at least 1, but got {n_samples}.")

    # Fail fast on an unsupported domain instead of hitting an unbound
    # variable inside the generator on the first iteration.
    if model_config['domain'] not in ('PROJ', 'FDK', 'IMAG'):
        raise ValueError(f"Unknown domain: {model_config['domain']}")

    gt_volume = gt_volume.to(device) # Ensure GT is on the correct device

    def prediction_generator():
        # This generator yields one prediction at a time, moved to `device`.
        patient, scan, scan_type = scan_info
        domain = model_config['domain']
        root = model_config['model_version_root']
        model_type = model_config['type']
        for i in range(n_samples):
            passthrough_num = None
            model_version = root
            if model_type == 'ensemble':
                model_version = f"{root}_{i+1:02d}"
            elif model_type == 'stochastic':
                passthrough_num = i

            if domain == 'PROJ':
                pred_path = files_obj.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.from_numpy(scipy.io.loadmat(pred_path)['prj']).detach().permute(1, 0, 2)
            elif domain == 'FDK':
                pred_path = files_obj.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                pred = pred[20:-20, :, :]
                if scan_type == "FF":
                    pred = pred[:, 128:-128, 128:-128]
                pred = 25. * torch.clip(pred, min=0.0, max=0.04)
            elif domain == 'IMAG':
                pred_path = files_obj.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                pred = torch.load(pred_path).detach()
                pred = torch.squeeze(pred, dim=1)
                pred = torch.permute(pred, (0, 2, 1))
            else:
                raise ValueError(f"Unknown domain: {domain}")
            yield pred.to(device)

    # --- Pass 1: Calculate Mean ---
    print("Pass 1: Calculating mean prediction...")
    mean_volume = torch.zeros_like(gt_volume)
    # Using a simple sum and divide for the mean
    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 1/2 (Mean)", leave=False):
        mean_volume += pred_volume
    mean_volume /= n_samples

    # --- Pass 2: Calculate Variance and Metrics ---
    print("Pass 2: Calculating variance and metrics...")
    sum_sq_diff_volume = torch.zeros_like(gt_volume)
    sample_avg_ssims, sample_avg_psnrs, sample_avg_mses, sample_avg_maes = [], [], [], []

    # Initialize metrics on the specified device
    data_range = 1.0
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=data_range, **SSIM_KWARGS).to(device)

    # PSNR metric configured with reduction='none' so that 'inf' values can be filtered.
    # NOTE(review): torchmetrics documents `reduction` as only taking effect when
    # `dim` is given; with `dim` unset this may return a single volume-level value
    # rather than per-slice values - confirm against the installed version.
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=data_range, reduction='none').to(device)

    for pred_volume in tqdm(prediction_generator(), total=n_samples, desc="Pass 2/2 (Var & Metrics)", leave=False):
        # Variance calculation
        diff = pred_volume - mean_volume
        sum_sq_diff_volume += diff * diff

        # Per-sample metrics
        if gt_volume.ndim > 2:
            # Insert a channel axis so the volume is treated as a batch of 2-D images.
            pred_vol_batch = pred_volume.unsqueeze(1)
            gt_vol_batch = gt_volume.unsqueeze(1)

            sample_avg_ssims.append(ssim_metric(pred_vol_batch, gt_vol_batch).item())
            sample_avg_mses.append(torch.mean((gt_volume - pred_volume)**2).item())
            sample_avg_maes.append(torch.mean(torch.abs(gt_volume - pred_volume)).item())

            # --- Robust PSNR Calculation ---
            psnr_per_slice = psnr_metric(pred_vol_batch, gt_vol_batch)
            finite_psnrs = psnr_per_slice[torch.isfinite(psnr_per_slice)] # Filter out inf values
            if finite_psnrs.numel() > 0:
                sample_avg_psnrs.append(torch.mean(finite_psnrs).item())
            else:
                # For debugging, print abs difference between GT and prediction
                abs_diff = torch.sum(torch.abs(gt_volume - pred_volume))
                print(f"Absolute difference between GT and prediction: {abs_diff.item()}")
                # Raise error
                raise ValueError("All slices in the prediction are perfect matches, leading to infinite PSNR values.")

    # Finalize variance and uncertainty
    if n_samples > 1:
        # Using n_samples for population standard deviation, as in the original code.
        # For sample standard deviation, use (n_samples - 1).
        variance_volume_map = sum_sq_diff_volume / n_samples
        uncertainty_volume_map = torch.sqrt(variance_volume_map)
    else:
        uncertainty_volume_map = torch.zeros_like(mean_volume)

    # --- Calculate metrics for the mean prediction ---
    metrics = {}
    if gt_volume.ndim > 2:
        mean_vol_batch = mean_volume.unsqueeze(1)
        gt_vol_batch = gt_volume.unsqueeze(1)

        metrics['mean_ssim'] = ssim_metric(mean_vol_batch, gt_vol_batch).item()
        metrics['mean_mse'] = torch.mean((gt_volume - mean_volume)**2).item()
        metrics['mean_mae'] = torch.mean(torch.abs(gt_volume - mean_volume)).item()

        # Robust PSNR for the mean prediction
        mean_psnr_per_slice = psnr_metric(mean_vol_batch, gt_vol_batch)
        finite_mean_psnrs = mean_psnr_per_slice[torch.isfinite(mean_psnr_per_slice)]
        if finite_mean_psnrs.numel() > 0:
            metrics['mean_psnr'] = torch.mean(finite_mean_psnrs).item()
        else:
            # For debugging, print abs difference between GT and mean prediction
            abs_diff = torch.sum(torch.abs(gt_volume - mean_volume))
            print(f"Absolute difference between GT and mean prediction: {abs_diff.item()}")
            # Raise error
            raise ValueError("All slices in the mean prediction are perfect matches, leading to infinite PSNR values.")

    # --- Aggregate per-sample metrics ---
    metrics['sample_avg_ssim'] = np.mean(sample_avg_ssims) if sample_avg_ssims else 0
    metrics['sample_avg_psnr'] = np.mean(sample_avg_psnrs) if sample_avg_psnrs else 0
    metrics['sample_avg_mse'] = np.mean(sample_avg_mses) if sample_avg_mses else 0
    metrics['sample_avg_mae'] = np.mean(sample_avg_maes) if sample_avg_maes else 0
    metrics['mean_std'] = torch.mean(uncertainty_volume_map).item()
    # Root mean variance: square root of the mean predicted variance.
    metrics['rmv'] = torch.sqrt(torch.mean(uncertainty_volume_map**2)).item()

    return metrics, mean_volume, uncertainty_volume_map

def calculate_evidential_volume_metrics(files_obj, model_config, scan_info, gt_volume, device):
    """
    Compute image-quality and uncertainty statistics for an evidential model.

    The saved result is a dict holding the tensors 'gamma', 'nu', 'alpha' and
    'beta'. 'gamma' is used as the mean prediction and the remaining three are
    combined into a per-element predictive standard deviation.

    Args:
        files_obj: Path helper used to resolve the result file location.
        model_config: Dict with keys 'domain' and 'model_version_root'.
        scan_info: Triple (patient, scan, scan_type).
        gt_volume: Ground-truth tensor; moved to ``device``.
        device: Torch device used for loading and computation.

    Returns:
        tuple: ``(metrics, mean_volume, uncertainty_volume_map)``.

    Raises:
        NotImplementedError: If the configured domain is not 'IMAG'.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']

    if domain != 'IMAG':
        raise NotImplementedError("Evidential regression is only implemented for the IMAG domain.")

    # The evidential outputs are stored as one final result (no passthrough number).
    result_path = files_obj.get_images_results_filepath(root, patient, scan, passthrough_num=None, ensure_exists=False)
    outputs = torch.load(result_path, map_location=device)

    # Swap the last two axes of every output tensor.
    gamma, nu, alpha, beta = (
        torch.permute(outputs[key], (0, 2, 1))
        for key in ('gamma', 'nu', 'alpha', 'beta')
    )

    mean_volume = gamma.detach()

    # Total variance: Var[y] = E[sigma^2] + Var[mu] = (beta / (alpha - 1)) * (1 + 1/nu).
    # The small epsilons keep both denominators away from zero.
    expected_noise_var = beta / (alpha - 1.0 + 1e-6)
    mean_spread_factor = 1.0 + (1.0 / (nu + 1e-6))
    variance_volume_map = expected_noise_var * mean_spread_factor
    uncertainty_volume_map = torch.sqrt(variance_volume_map).detach()

    gt_volume = gt_volume.to(device)

    data_range = 1.0
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=data_range, **SSIM_KWARGS).to(device)
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=data_range, reduction='none').to(device)

    # Add a channel axis so the metrics treat the volume as a batch of images.
    mean_vol_batch = mean_volume.unsqueeze(1)
    gt_vol_batch = gt_volume.unsqueeze(1)
    residual = gt_volume - mean_volume

    metrics = {
        'mean_ssim': ssim_metric(mean_vol_batch, gt_vol_batch).item(),
        'mean_mse': torch.mean(residual**2).item(),
        'mean_mae': torch.mean(torch.abs(residual)).item(),
    }

    # Average only the finite PSNR values; report inf when none are finite.
    psnr_values = psnr_metric(mean_vol_batch, gt_vol_batch)
    finite_psnr = psnr_values[torch.isfinite(psnr_values)]
    metrics['mean_psnr'] = torch.mean(finite_psnr).item() if finite_psnr.numel() > 0 else float('inf')

    # There are no individual samples, so the "sample" columns mirror the
    # mean-prediction metrics to keep the result schema consistent.
    for suffix in ('ssim', 'psnr', 'mse', 'mae'):
        metrics[f'sample_avg_{suffix}'] = metrics[f'mean_{suffix}']
    metrics['mean_std'] = torch.mean(uncertainty_volume_map).item()
    metrics['rmv'] = torch.sqrt(torch.mean(uncertainty_volume_map**2)).item()

    return metrics, mean_volume, uncertainty_volume_map

def calculate_error_model_metrics(files_obj, model_config, scan_info, gt_volume, device):
    """
    Compute image-quality and uncertainty statistics for an auxiliary error model.

    The mean prediction is loaded from the primary model's saved result and
    the uncertainty map is the saved output of the error-prediction model.
    For this model type the 'domain' config field holds the primary model's
    version string instead of a domain name.

    Args:
        files_obj: Path helper used to resolve the result file locations.
        model_config: Dict with keys 'domain' (primary model version),
            'model_version_root' (error model version) and 'count'.
        scan_info: Triple (patient, scan, scan_type).
        gt_volume: Ground-truth tensor; moved to ``device``.
        device: Torch device used for loading and computation.

    Returns:
        tuple: ``(metrics, mean_volume, uncertainty_volume_map)``.

    Raises:
        ValueError: If 'count' is not 1 or if 'domain' is left as 'IMAG'.
    """
    patient, scan, scan_type = scan_info
    primary_model_version = model_config['domain']
    error_model_version = model_config['model_version_root']

    if model_config['count'] != 1:
        raise ValueError(f"Error models must have a count of 1, but got {model_config['count']}.")

    if model_config['domain'] == 'IMAG':
        raise ValueError("For 'error' model type, 'domain' should specify the primary model version, not 'IMAG'.")

    def load_volume(path):
        # Drop the singleton axis at position 1 (if any) and swap the last two axes.
        loaded = torch.load(path, map_location=device)
        return torch.permute(torch.squeeze(loaded, dim=1), (0, 2, 1))

    # Both models are loaded as single results (no passthrough number).
    mean_volume = load_volume(
        files_obj.get_images_results_filepath(primary_model_version, patient, scan, passthrough_num=None, ensure_exists=False)
    )
    uncertainty_volume_map = load_volume(
        files_obj.get_error_results_filepath(error_model_version, patient, scan, passthrough_num=None, ensure_exists=False)
    )

    gt_volume = gt_volume.to(device)

    data_range = 1.0
    ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=data_range, **SSIM_KWARGS).to(device)
    psnr_metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=data_range, reduction='none').to(device)

    # Add a channel axis so the metrics treat the volume as a batch of images.
    mean_vol_batch = mean_volume.unsqueeze(1)
    gt_vol_batch = gt_volume.unsqueeze(1)
    residual = gt_volume - mean_volume

    metrics = {
        'mean_ssim': ssim_metric(mean_vol_batch, gt_vol_batch).item(),
        'mean_mse': torch.mean(residual**2).item(),
        'mean_mae': torch.mean(torch.abs(residual)).item(),
    }

    # Average only the finite PSNR values; report inf when none are finite.
    psnr_values = psnr_metric(mean_vol_batch, gt_vol_batch)
    finite_psnr = psnr_values[torch.isfinite(psnr_values)]
    metrics['mean_psnr'] = torch.mean(finite_psnr).item() if finite_psnr.numel() > 0 else float('inf')

    # There are no individual samples, so the "sample" columns mirror the
    # mean-prediction metrics to keep the result schema consistent.
    for suffix in ('ssim', 'psnr', 'mse', 'mae'):
        metrics[f'sample_avg_{suffix}'] = metrics[f'mean_{suffix}']
    metrics['mean_std'] = torch.mean(uncertainty_volume_map).item()
    metrics['rmv'] = torch.sqrt(torch.mean(uncertainty_volume_map**2)).item()

    return metrics, mean_volume, uncertainty_volume_map

def plot_error_histogram(errors, model_name, scan_name, combined=False):
    """
    Plots histograms of the raw (not absolute) errors with and without percentile range limits.

    The left panel restricts the x-range to the 0.1-99.9 percentile interval
    (when there is data); the right panel shows the full range. The y-axis
    label of each panel reflects whether that panel is actually drawn on a
    log scale.

    Args:
        errors (np.ndarray): Flat array of signed errors (GT - prediction).
        model_name (str): Model name shown in the titles.
        scan_name (str): Scan name shown in the titles unless ``combined``.
        combined (bool): If True, titles say "All Scans" instead of ``scan_name``.
    """
    fig, axs = plt.subplots(1, 2, figsize=(18, 5))

    scan_title_part = "All Scans" if combined else scan_name

    # Histogram with percentile range limits
    if errors.size > 0:
        left_log = False
        range_lims = np.percentile(errors, [0.1, 99.9])
        axs[0].hist(errors, bins=150, log=left_log, range=range_lims)
        axs[0].set_title(f'Raw Errors (0.1-99.9% Range)\n{model_name} - {scan_title_part}')
    else:
        # np.percentile cannot be evaluated on empty input, so fall back to auto range.
        left_log = True
        axs[0].hist(errors, bins=150, log=left_log)
        axs[0].set_title(f'Raw Errors (Auto Range)\n{model_name} - {scan_title_part}')
    axs[0].set_xlabel('Error (GT - Prediction)')
    axs[0].set_ylabel('Frequency (log scale)' if left_log else 'Frequency')
    axs[0].grid(True, linestyle=':')

    # Histogram with full range (no limits); drawn on a linear scale.
    axs[1].hist(errors, bins=150, log=False)
    axs[1].set_title(f'Raw Errors (Full Range)\n{model_name} - {scan_title_part}')
    axs[1].set_xlabel('Error (GT - Prediction)')
    axs[1].set_ylabel('Frequency')
    axs[1].grid(True, linestyle=':')

    plt.tight_layout()
    plt.show()

def plot_std_histogram(std_devs, model_name, scan_name, combined=False):
    """
    Draw two side-by-side histograms of the predicted standard deviations.

    The left panel limits the x-range to [0, 99.9th percentile] when there is
    data; the right panel uses the full range.

    Args:
        std_devs (np.ndarray): Flat array of predicted standard deviations.
        model_name (str): Model name shown in the titles.
        scan_name (str): Scan name shown in the titles unless ``combined``.
        combined (bool): If True, titles say "All Scans" instead of ``scan_name``.
    """
    fig, (clipped_ax, full_ax) = plt.subplots(1, 2, figsize=(18, 5))

    scan_title_part = scan_name if not combined else "All Scans"

    # Left panel: percentile-limited range (auto range when there is no data).
    if std_devs.size > 0:
        upper_limit = np.percentile(std_devs, 99.9)
        clipped_ax.hist(std_devs, bins=150, log=False, range=[0, upper_limit])
        clipped_ax.set_title(f'Std Deviations (0-99.9% Range)\n{model_name} - {scan_title_part}')
    else:
        clipped_ax.hist(std_devs, bins=150, log=False)
        clipped_ax.set_title(f'Std Deviations (Auto Range)\n{model_name} - {scan_title_part}')

    # Right panel: full range.
    full_ax.hist(std_devs, bins=150, log=False)
    full_ax.set_title(f'Std Deviations (Full Range)\n{model_name} - {scan_title_part}')

    # Shared axis decoration.
    for ax in (clipped_ax, full_ax):
        ax.set_xlabel('Predicted Standard Deviation')
        ax.set_ylabel('Frequency')
        ax.grid(True, linestyle=':')

    plt.tight_layout()
    plt.show()

# Progress message printed once the definitions above have been executed.
print("Metric calculation functions defined.")

### Uncertainty calibration and post-calibration metric functions

In [None]:
import numpy as np
import scipy.stats
from scipy.interpolate import interp1d
from scipy.optimize import isotonic_regression

# --- Calibration Methods ---

def calculate_platt_scaler(errors, std_devs, eps=1e-6):
    """
    Calculates the optimal scaling factor 'T' for variance scaling.

    T is the single multiplier for the predicted standard deviations that
    minimizes the Gaussian NLL on the given data:
    ``T = sqrt(mean(error^2 / variance))``.

    Args:
        errors (np.ndarray): The errors (ground_truth - prediction). Only the
            squared value is used, so signed and absolute errors are equivalent.
        std_devs (np.ndarray): The predicted standard deviations.
        eps (float): Small constant added to every variance to avoid division
            by zero. Note that it dominates the result for elements whose
            predicted variance is comparable to or smaller than ``eps``.

    Returns:
        float: The scaling factor T.
    """
    errors_flat = errors.flatten()
    std_devs_flat = std_devs.flatten()

    # To avoid division by zero, add a small epsilon to the variance
    variances_flat = std_devs_flat**2 + eps

    # Calculate T^2 = mean(error^2 / variance)
    t_squared = np.mean(errors_flat**2 / variances_flat)

    return np.sqrt(t_squared)


def train_isotonic_regression(predicted_std, observed_errors, device):
    """
    Trains an Isotonic Regression model using scipy.optimize.isotonic_regression.
    """
    pred_std_flat = predicted_std.flatten()
    obs_err_flat = observed_errors.flatten()

    # Sort by predicted standard deviation
    print("Sorting erorrs...")
    sort_indices = torch.argsort(torch.from_numpy(pred_std_flat).to(device)).cpu().numpy()
    sorted_pred_std = pred_std_flat[sort_indices]
    sorted_obs_err_sq = obs_err_flat[sort_indices]**2

    # Apply scipy's isotonic regression
    print("Applying isotonic regression...")
    calibrated_variances = isotonic_regression(sorted_obs_err_sq).x
    calibrated_std = np.sqrt(calibrated_variances)

    # Create an interpolation function to map new predictions
    unique_pred_std, unique_indices = np.unique(sorted_pred_std, return_index=True)
    unique_calib_std = calibrated_std[unique_indices]
    
    # check that the unique pred std are sorted
    if not np.all(np.diff(unique_pred_std) >= 0):
        raise ValueError("Predicted standard deviations are not sorted. Ensure the input is sorted before applying isotonic regression.")

    print("Creating interpolation model...")
    iso_model = interp1d(unique_pred_std, unique_calib_std, kind='linear', bounds_error=False, 
                         fill_value=(unique_calib_std[0], unique_calib_std[-1]), assume_sorted=True)

    return iso_model

def train_cdf_isotonic_regression(ground_truth, mean_pred, uncal_std, device):
    """
    Trains an Isotonic Regression model on the cumulative probabilities,
    as described in Kuleshov et al., 2018.

    Returns:
        scipy.interpolate.interp1d: The trained isotonic regression model.
    """
    gt_flat = ground_truth.flatten()
    pred_flat = mean_pred.flatten()
    uncert_flat = uncal_std.flatten() + 1e-9 # Avoid zero std dev

    print("Calculating predicted CDF values for calibration training...")
    # Get the predicted CDF value for each ground truth point
    pred_cdfs = scipy.stats.norm.cdf(gt_flat, loc=pred_flat, scale=uncert_flat)
    
    # Sort the CDF values to prepare for isotonic regression
    sorted_indices = torch.argsort(torch.from_numpy(pred_cdfs).to(device)).cpu().numpy()
    sorted_pred_cdfs = pred_cdfs[sorted_indices]

    # The target values for a perfectly calibrated model would be uniformly spaced
    n_points = len(sorted_pred_cdfs)
    empirical_cdfs = np.arange(n_points) / n_points

    print("Training Isotonic Regression model on CDF values...")
    # The isotonic_regression function in scipy returns the calibrated values directly
    calibrated_cdfs = isotonic_regression(empirical_cdfs).x
    
    # Create an interpolation function to map any new predicted cdf to a calibrated one
    print("Creating interpolation model for CDF calibration...")
    iso_cdf_model = interp1d(
        sorted_pred_cdfs, 
        calibrated_cdfs,
        kind='linear',
        bounds_error=False,
        fill_value=(0.0, 1.0),
        assume_sorted=True
    )
    
    return iso_cdf_model


# TODO: go through this and make sure it makes sense (especially that it actually does what Kuleshov does [obviously other than the ppf part])
def apply_cdf_isotonic_regression(iso_cdf_model, uncal_uncertainty_map, target_confidence=0.95):
    """
    Uses a trained CDF isotonic model to find a new scaling factor for the standard deviation.

    The central interval of the uncalibrated Gaussian at ``target_confidence``
    spans +/- z_nominal standard deviations. The calibration model gives the
    coverage that interval actually achieves. Under a Gaussian error
    assumption, that coverage corresponds to +/- z_observed *true* standard
    deviations, so the true std is ``z_nominal / z_observed`` times the
    predicted one. Coverage below the target therefore widens the std
    (factor > 1) and coverage above the target narrows it.

    Args:
        iso_cdf_model: Callable mapping a predicted CDF value to the observed
            (calibrated) CDF value.
        uncal_uncertainty_map: Uncalibrated standard deviations.
        target_confidence (float): Nominal central-interval confidence, in (0, 1).

    Returns:
        ``uncal_uncertainty_map`` multiplied by the scalar scaling factor.

    Raises:
        ValueError: If ``target_confidence`` is not in (0, 1), or if the
            observed coverage is not strictly between 0 and 1 (no finite,
            positive scaling factor exists in that case).
    """
    if not 0.0 < target_confidence < 1.0:
        raise ValueError(f"target_confidence must be in (0, 1), but got {target_confidence}.")

    print(f"Finding new std dev scaling factor for {target_confidence*100}% confidence...")

    # Tail probabilities of the nominal central interval.
    p_lower_orig = (1.0 - target_confidence) / 2.0
    p_upper_orig = 1.0 - p_lower_orig

    # Coverage the nominal interval actually achieves according to the model.
    calibrated_prob = float(iso_cdf_model(p_upper_orig) - iso_cdf_model(p_lower_orig))
    if not 0.0 < calibrated_prob < 1.0:
        raise ValueError(
            f"Observed coverage of the nominal interval is {calibrated_prob}; "
            "it must lie strictly between 0 and 1 to derive a scaling factor."
        )

    # z-score whose central Gaussian interval has the observed coverage.
    z_observed = scipy.stats.norm.ppf((1.0 + calibrated_prob) / 2.0)

    # z-score of the nominal interval.
    z_nominal = scipy.stats.norm.ppf(p_upper_orig)

    # z_nominal * sigma_pred == z_observed * sigma_true
    scaling_factor = z_nominal / z_observed

    print(f"CDF calibration scaling factor: {scaling_factor:.4f}")

    return uncal_uncertainty_map * scaling_factor


# --- New Evaluation Metrics ---

def calculate_nll(ground_truth, mean_pred, uncertainty_map):
    """
    Mean Gaussian negative log-likelihood of the ground truth under the prediction.

    Each element is scored against a normal distribution whose mean is the
    matching element of ``mean_pred`` and whose standard deviation is the
    matching element of ``uncertainty_map``. All three inputs are flattened
    first, and the per-element scores are averaged.
    """
    targets = ground_truth.flatten()
    means = mean_pred.flatten()
    sigmas = uncertainty_map.flatten()

    # The epsilon keeps both the logarithm and the division finite when a
    # predicted std dev is exactly zero.
    var = sigmas**2 + 1e-9
    squared_residual = (targets - means)**2

    log_normaliser = np.log(2 * np.pi * var)
    per_point_nll = 0.5 * (log_normaliser + squared_residual / var)

    return np.mean(per_point_nll)


def calculate_all_eces(ground_truth, mean_pred, uncertainty_map, n_bins=20):
    """
    Computes four Expected Calibration Error (ECE) variants for Gaussian regression.

    The predicted CDF is evaluated at every ground-truth value. For ``n_bins``
    evenly spaced confidence levels in [0, 1], the fraction of points whose
    CDF value is at or below the level is compared with the level itself.

    Returns a dict with the absolute and squared gaps, each averaged either
    uniformly over the levels (``unweighted``) or weighted by the share of
    points whose CDF value falls in the interval ending at that level
    (``weighted``).
    """
    targets = ground_truth.flatten()
    means = mean_pred.flatten()
    sigmas = uncertainty_map.flatten() + 1e-9  # keep the scale strictly positive

    # Predicted CDF evaluated at each observed value.
    cdf_at_truth = scipy.stats.norm.cdf(targets, loc=means, scale=sigmas)

    levels = np.linspace(0, 1, n_bins)
    observed = np.array([np.mean(cdf_at_truth <= level) for level in levels])

    # Weight of level i = share of points in (levels[i-1], levels[i]];
    # index 0 has no interval below it and keeps weight zero.
    weights = np.zeros(n_bins)
    for idx, (low, high) in enumerate(zip(levels[:-1], levels[1:]), start=1):
        weights[idx] = np.mean((cdf_at_truth > low) & (cdf_at_truth <= high))

    weight_total = np.sum(weights)
    if weight_total > 0:
        weights /= weight_total

    gap = levels - observed
    abs_gap = np.abs(gap)
    sq_gap = gap**2

    return {
        'ece_weighted_abs': np.sum(weights * abs_gap),
        'ece_weighted_sq': np.sum(weights * sq_gap),
        'ece_unweighted_abs': np.mean(abs_gap),
        'ece_unweighted_sq': np.mean(sq_gap),
    }


def calculate_ence(ground_truth, mean_pred, uncertainty_map, device, n_bins=20):
    """
    Calculates the Expected Normalized Calibration Error (ENCE) using quantile-based binning
    as described in Levi et al., 2020.

    Points are sorted by predicted uncertainty and split into ``n_bins``
    near-equal groups. For each group the root mean variance (RMV) is compared
    with the root mean squared error (RMSE), and the relative gaps
    ``|RMV - RMSE| / RMV`` are averaged over the groups that contain points.

    Args:
        ground_truth: Array of reference values.
        mean_pred: Array of predicted means.
        uncertainty_map: NumPy array of predicted standard deviations (it is
            passed to ``torch.from_numpy`` for sorting).
        device: Torch device on which the argsort is executed.
        n_bins: Number of quantile bins.

    Returns:
        The mean normalized calibration error over the non-empty bins, or NaN
        when there is no data at all.
    """
    gt_flat = ground_truth.flatten()
    pred_flat = mean_pred.flatten()
    uncert_flat = uncertainty_map.flatten()

    # Ensure there are enough unique values for binning
    if len(np.unique(uncert_flat)) < n_bins:
        print("Warning: Number of unique uncertainties is less than n_bins. ENCE may be unreliable.")

    # Get the indices that would sort the uncertainties
    sorted_indices = torch.argsort(torch.from_numpy(uncert_flat).to(device)).cpu().numpy()

    # np.array_split handles sizes not divisible by n_bins; with fewer points
    # than bins it returns empty trailing bins.
    binned_indices = np.array_split(sorted_indices, n_bins)

    normalized_errors = []
    for bin_indices in binned_indices:
        if len(bin_indices) == 0:
            continue

        # Root Mean Variance (RMV) in bin; epsilon guards the division below.
        rmv_j = np.sqrt(np.mean(uncert_flat[bin_indices]**2)) + 1e-9

        # Root Mean Squared Error (RMSE) in bin
        rmse_j = np.sqrt(np.mean((gt_flat[bin_indices] - pred_flat[bin_indices])**2))

        normalized_errors.append(np.abs(rmv_j - rmse_j) / rmv_j)

    if not normalized_errors:
        # No data: report "undefined" rather than a misleading perfect score of 0.
        return np.nan

    # Average over the bins that actually contributed, so empty bins do not
    # drag the metric towards zero.
    return sum(normalized_errors) / len(normalized_errors)


def calculate_mpiw(uncertainty_map, confidence_levels=(0.68, 0.95)):
    """
    Calculates the Mean Prediction Interval Width (MPIW) for given confidence levels.

    For each level the two-sided Gaussian interval width ``2 * z * std`` is
    averaged over every element of ``uncertainty_map``.

    Args:
        uncertainty_map: Array of predicted standard deviations.
        confidence_levels: Iterable of central-interval confidence levels,
            each as a fraction (e.g. 0.95).

    Returns:
        Dict mapping ``'mpiw_<percent>'`` to the mean interval width, where
        ``<percent>`` is the level expressed as a whole-number percentage.
    """
    uncert_flat = uncertainty_map.flatten()
    results = {}
    for level in confidence_levels:
        # Get the z-score for the confidence level (e.g., 1.96 for 95%)
        z_score = scipy.stats.norm.ppf(1 - (1 - level) / 2)

        # Width of the prediction interval
        widths = 2 * z_score * uncert_flat

        # Round away float noise before truncating: 0.57 * 100 evaluates to
        # 56.99999999999999, which a bare int() would turn into 56.
        percent = int(round(level * 100, 6))

        # Store the mean width
        results[f'mpiw_{percent}'] = np.mean(widths)

    return results

# Progress marker printed once the preceding definitions in this cell have executed.
print("✅ Calibration and advanced uncertainty metric functions are defined.")

### Visualization functions

In [None]:
def plot_mean_comparison(mean_pred, ground_truth, uncertainty_map, model_name, scan_name, slice_idx, tumor_coords_xy=None, log_scale=False, clip_pct=None):
    """
    Shows ground truth, mean prediction, absolute error and uncertainty side by side.

    With ``log_scale`` the error and uncertainty panels are passed through
    ``log1p``. With ``clip_pct`` both are clipped to [0, that percentile] of
    their own values. When ``tumor_coords_xy`` is truthy, a red arrow pointing
    at the (offset) location is drawn on all four panels.
    """
    abs_error = np.abs(ground_truth - mean_pred)
    std_map = uncertainty_map

    if log_scale:
        abs_error, std_map = np.log1p(abs_error), np.log1p(std_map)

    # Clip the error and uncertainty maps for better visualization
    if clip_pct is not None:
        abs_error = np.clip(abs_error, 0, np.percentile(abs_error, clip_pct))
        std_map = np.clip(std_map, 0, np.percentile(std_map, clip_pct))

    error_title = 'Log1p Absolute Error Map' if log_scale else 'Absolute Error Map'
    std_title = 'Log1p Uncertainty Map (Std Dev)' if log_scale else 'Uncertainty (Std Dev)'

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Mean vs. GT)', fontsize=16)

    def draw_panel(ax, image, cmap, title):
        """Render one image panel with its title and colorbar, axes hidden."""
        handle = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        fig.colorbar(handle, ax=ax)

    draw_panel(axes[0], ground_truth, 'gray', 'Ground Truth')

    if tumor_coords_xy:
        tip_x, tip_y = tumor_coords_xy
        # Pull the arrow head back by 6 px in each direction
        tip_x -= 6
        tip_y -= 6
        for ax in axes:
            ax.annotate('', xy=(tip_x, tip_y), xytext=(tip_x - 30, tip_y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05, width=1, headwidth=5, headlength=5))

    draw_panel(axes[1], mean_pred, 'gray', 'Mean Prediction')
    draw_panel(axes[2], abs_error, 'magma', error_title)
    draw_panel(axes[3], std_map, 'viridis', std_title)

    plt.tight_layout()
    # plt.savefig(f"{model_name}_{scan_name}_slice{slice_idx}_clip_{clip_pct}_mean_comparison.png", dpi=400)
    plt.show()

def plot_ssim_map(ssim_map, model_name, scan_name):
    """Displays an SSIM map on a fixed [0, 1] viridis colour scale with a colorbar."""
    fig, ax = plt.subplots(figsize=(3, 3))
    image = ax.imshow(ssim_map, cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'SSIM Map - {model_name} - {scan_name}')
    fig.colorbar(image, ax=ax)
    ax.axis('off')
    plt.show()

def plot_calibration_curve(ground_truth, mean_pred, uncertainty_map, model_name, scan_name, n_levels=20):
    """
    Plots the calibration curve with a marginal histogram below the x-axis
    showing the distribution of the predicted CDF values.

    Args:
        ground_truth: Array of reference values.
        mean_pred: Array of predicted means.
        uncertainty_map: Array of predicted standard deviations.
        model_name: Model label used in the title.
        scan_name: Scan label used in the title.
        n_levels: Number of evenly spaced confidence levels on [0, 1] at which
            the observed frequency is evaluated.
    """
    gt_flat = ground_truth.flatten()
    pred_flat = mean_pred.flatten()
    # Same epsilon as calculate_all_eces: a std dev of exactly zero would make
    # the normal CDF undefined (NaN), and those points would silently drop out
    # of both the curve and the histogram.
    uncert_flat = uncertainty_map.flatten() + 1e-9

    # Calculate the predicted CDF value for each point
    pred_cdfs = scipy.stats.norm.cdf(gt_flat, loc=pred_flat, scale=uncert_flat)

    # --- Create figure with two subplots, sharing the x-axis ---
    fig, (ax_cal, ax_hist) = plt.subplots(
        2, 1,
        figsize=(8, 8),
        sharex=True,
        gridspec_kw={'height_ratios': [3, 1]} # Main plot is 3x taller
    )

    # --- Main Calibration Plot (top) ---
    expected_confidence_levels = np.linspace(0, 1, n_levels)
    observed_frequencies = np.array([np.mean(pred_cdfs <= p_j) for p_j in expected_confidence_levels])

    ax_cal.plot([0, 1], [0, 1], '--', color='grey', label='Perfectly Calibrated')
    ax_cal.plot(expected_confidence_levels, observed_frequencies, '-o', label='Model Calibration')
    ax_cal.set_ylabel('Observed Confidence Level')
    ax_cal.set_title(f'Calibration Plot - {model_name} - {scan_name}')
    ax_cal.legend()
    ax_cal.grid(True, linestyle=':')

    # --- Marginal Histogram (bottom) ---
    # Log y-scale so sparsely populated CDF ranges remain visible.
    ax_hist.hist(pred_cdfs, bins=50, range=(0,1), density=True, color='steelblue', alpha=0.8)
    ax_hist.set_xlabel('Expected Confidence Level (Predicted CDF)')
    ax_hist.set_ylabel('Density')
    ax_hist.set_yscale('log')
    # ax_hist.set_yticks([]) # Hide y-ticks for clarity

    # Final adjustments
    plt.tight_layout()
    plt.show()

def plot_sparsification_curve(uncertainty, errors, model_name, scan_name, device):
    """
    Draws the model and oracle sparsification curves that underlie AUSE.

    Pixels are removed in descending order of predicted uncertainty (model) or
    of true error (oracle). Each curve gives the mean error of the pixels that
    remain, divided by the overall mean error when that is positive. At most
    roughly 1000 evenly strided points are plotted.
    """
    def remaining_error_curve(ranked_errors, normaliser):
        """Mean error of what is left after dropping the first k entries, for every k."""
        count = len(ranked_errors)
        running_total = np.cumsum(ranked_errors)
        grand_total = running_total[-1]

        # Error mass already removed before each step: 0, e0, e0 + e1, ...
        removed_so_far = np.empty_like(running_total)
        removed_so_far[0] = 0
        removed_so_far[1:] = running_total[:-1]

        remaining_mean = (grand_total - removed_so_far) / np.arange(count, 0, -1)
        if normaliser > 0:
            remaining_mean = remaining_mean / normaliser
        return remaining_mean

    flat_uncertainty = uncertainty.flatten()
    flat_errors = errors.flatten()
    mean_error = np.mean(flat_errors)

    # Model ranking: most uncertain pixels are removed first.
    uncertainty_rank = torch.argsort(
        torch.from_numpy(flat_uncertainty).to(device), descending=True
    ).cpu().numpy()
    model_curve = remaining_error_curve(flat_errors[uncertainty_rank], mean_error)

    # Oracle ranking: largest errors are removed first.
    errors_descending = torch.sort(
        torch.from_numpy(flat_errors).to(device), descending=True
    ).values.cpu().numpy()
    oracle_curve = remaining_error_curve(errors_descending, mean_error)

    # X-axis: fraction of pixels removed
    fraction_removed = np.linspace(0, 1, len(model_curve))

    # Thin the curves with a constant stride once they exceed 1000 points.
    if len(fraction_removed) > 1000:
        stride = len(fraction_removed) // 1000
        thin = slice(None, None, stride)
        fraction_removed = fraction_removed[thin]
        model_curve = model_curve[thin]
        oracle_curve = oracle_curve[thin]

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(fraction_removed, model_curve, label='Model (Sort by Uncertainty)')
    ax.plot(fraction_removed, oracle_curve, '--', label='Oracle (Sort by Error)')
    ax.set_xlabel('Fraction of Pixels Removed')
    ax.set_ylabel('Mean Absolute Error of Remaining Pixels')
    ax.set_title(f'Sparsification Curve - {model_name} - {scan_name}')
    ax.legend()
    ax.grid(True, linestyle=':')
    plt.show()

def plot_samples_comparison(ground_truth, mean_pred, samples, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Shows the ground truth, the mean prediction and individual samples in one row.

    Every panel shares a grayscale range spanning the ground truth and the
    mean prediction, so intensities are directly comparable.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        samples (list of np.ndarray): A list of 2D sample prediction slices.
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.
        slice_idx (int): The index of the slice for the title.
        tumor_coords_xy (tuple, optional): (x, y) coordinates for the tumor arrow.
    """
    n_samples = len(samples)
    # One column each for GT and mean, plus one per sample.
    n_panels = n_samples + 2

    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4.5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (GT, Mean, and Samples)', fontsize=16)

    # Shared intensity window taken from the ground truth and the mean.
    low = min(ground_truth.min(), mean_pred.min())
    high = max(ground_truth.max(), mean_pred.max())

    panels = [('Ground Truth', ground_truth), ('Mean Prediction', mean_pred)]
    panels.extend((f'Sample {k + 1}', samples[k]) for k in range(n_samples))

    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image, cmap='gray', vmin=low, vmax=high)
        ax.set_title(title)
        ax.axis('off')

    # Tumor arrow on every panel.
    if tumor_coords_xy:
        tip_x, tip_y = tumor_coords_xy
        # Shift so the arrow doesn't overlap the tumor
        tip_x -= 6
        tip_y -= 6
        for ax in axes:
            ax.annotate('', xy=(tip_x, tip_y), xytext=(tip_x - 30, tip_y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05,
                                        width=1, headwidth=5, headlength=5))

    plt.show()

def plot_worst_samples_comparison(ground_truth, mean_pred, worst_samples_data, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Plots the ground truth, mean prediction, and the worst-performing sample predictions.

    All panels share one grayscale range spanning the ground truth and the
    mean prediction. The tumor arrow uses the same 6-pixel offset as
    plot_mean_comparison and plot_samples_comparison, so the arrow head does
    not cover the tumor.

    Args:
        ground_truth (np.ndarray): The 2D ground truth slice.
        mean_pred (np.ndarray): The 2D mean prediction slice.
        worst_samples_data (list): A list of tuples, where each tuple is 
                                   (loss, sample_slice_numpy, sample_index).
        model_name (str): The name of the model for the title.
        scan_name (str): The name of the scan for the title.
        slice_idx (int): The index of the slice for the title.
        tumor_coords_xy (tuple, optional): (x, y) coordinates for the tumor arrow.
    """
    num_samples = len(worst_samples_data)
    # One column each for GT and mean, plus one per worst sample.
    num_cols = 2 + num_samples

    fig, axes = plt.subplots(1, num_cols, figsize=(4 * num_cols, 5), constrained_layout=True)
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Top {num_samples} Worst Samples by SmoothL1Loss)', fontsize=16)

    # Determine a consistent grayscale range
    vmin = min(ground_truth.min(), mean_pred.min())
    vmax = max(ground_truth.max(), mean_pred.max())

    # --- Plot Ground Truth ---
    axes[0].imshow(ground_truth, cmap='gray', vmin=vmin, vmax=vmax)
    axes[0].set_title('Ground Truth')
    axes[0].axis('off')

    # --- Plot Mean Prediction ---
    axes[1].imshow(mean_pred, cmap='gray', vmin=vmin, vmax=vmax)
    axes[1].set_title('Mean Prediction')
    axes[1].axis('off')

    # --- Plot Worst Samples ---
    for ax, (loss, sample_slice, sample_idx) in zip(axes[2:], worst_samples_data):
        ax.imshow(sample_slice, cmap='gray', vmin=vmin, vmax=vmax)
        # Add the loss and original sample number to the title
        ax.set_title(f'Sample #{sample_idx}\nLoss: {loss:.4f}')
        ax.axis('off')

    # --- Add Tumor Arrow ---
    if tumor_coords_xy:
        x, y = tumor_coords_xy
        x -= 6 # Shift so the arrow doesn't overlap the tumor
        y -= 6
        for ax in axes:
            ax.annotate('', xy=(x, y), xytext=(x - 30, y - 30),
                        arrowprops=dict(facecolor='red', edgecolor='red', shrink=0.05, 
                                        width=1, headwidth=5, headlength=5))

    plt.show()

def plot_combined_calibration_curves(
    ground_truth, mean_pred,
    platt_uncertainty_map, iso_uncertainty_map, iso_cdf_uncertainty_map,
    model_name, scan_name, n_bins=20
):
    """
    Compares three calibrated uncertainty maps on one two-panel figure.

    Left panel: observed vs. expected confidence level for each method, with
    the unweighted absolute ECE of each method in the title. Right panel:
    per-bin RMSE against per-bin RMV for each method, with each ENCE in the
    title. Sorting for the right panel runs on the module-level DEVICE.
    """
    fig, (ece_ax, ence_ax) = plt.subplots(1, 2, figsize=(18, 7))
    fig.suptitle(f'Calibration Comparison for {model_name} on {scan_name}', fontsize=16)

    calibrated_maps = {
        'STD Scaling': platt_uncertainty_map,
        'Isotonic (Var)': iso_uncertainty_map,
        'Isotonic (CDF)': iso_cdf_uncertainty_map,
    }
    # Short labels used in the panel titles, paired with the dict keys above.
    title_labels = (
        ('STD', 'STD Scaling'),
        ('Iso (Var)', 'Isotonic (Var)'),
        ('Iso (CDF)', 'Isotonic (CDF)'),
    )

    targets = ground_truth.flatten()
    means = mean_pred.flatten()

    # --- ECE panel ---
    levels = np.linspace(0, 1, n_bins + 1)
    ece_by_method = {}
    for method, sigma_map in calibrated_maps.items():
        cdf_at_truth = scipy.stats.norm.cdf(targets, loc=means, scale=sigma_map.flatten() + 1e-9)
        coverage = np.array([np.mean(cdf_at_truth <= level) for level in levels])
        ece_ax.plot(levels, coverage, '-o', label=method, alpha=0.8)

        all_eces = calculate_all_eces(ground_truth, mean_pred, sigma_map, n_bins=n_bins)
        ece_by_method[method] = all_eces['ece_unweighted_abs']

    ece_ax.plot([0, 1], [0, 1], '--', color='grey', label='Perfect')
    ece_ax.set_title(
        'ECE Plot\n'
        + ' | '.join(f'{short} ECE: {ece_by_method[key]:.4f}' for short, key in title_labels)
    )
    ece_ax.set_xlabel('Expected Confidence Level')
    ece_ax.set_ylabel('Observed Confidence Level')
    ece_ax.legend()
    ece_ax.grid(True, linestyle=':')
    ece_ax.axis('equal')
    ece_ax.set_xlim([0, 1])
    ece_ax.set_ylim([0, 1])

    # --- ENCE panel ---
    ence_by_method = {}
    plotted_values = []
    for method, sigma_map in calibrated_maps.items():
        sigmas = sigma_map.flatten()

        # Equal-count bins over the points ordered by predicted uncertainty.
        order = torch.argsort(torch.from_numpy(sigmas).to(DEVICE)).cpu().numpy()
        occupied_bins = [idx for idx in np.array_split(order, n_bins) if len(idx) > 0]

        rmv_per_bin = [np.sqrt(np.mean(sigmas[idx]**2)) for idx in occupied_bins]
        rmse_per_bin = [np.sqrt(np.mean((targets[idx] - means[idx])**2)) for idx in occupied_bins]

        ence_ax.plot(rmv_per_bin, rmse_per_bin, '-o', label=method, alpha=0.8, zorder=3)
        plotted_values += rmv_per_bin
        plotted_values += rmse_per_bin

        ence_by_method[method] = calculate_ence(ground_truth, mean_pred, sigma_map, DEVICE, n_bins=n_bins)

    if plotted_values:
        # Diagonal reference sized to the data with 10% headroom.
        upper = np.max(plotted_values) * 1.1
        ence_ax.plot([0, upper], [0, upper], '--', color='grey', label='Perfect')
        ence_ax.set_xlim(left=0, right=upper)
        ence_ax.set_ylim(bottom=0, top=upper)

    ence_ax.set_title(
        'ENCE Plot (RMSE vs. RMV)\n'
        + ' | '.join(f'{short} ENCE: {ence_by_method[key]:.4f}' for short, key in title_labels)
    )
    ence_ax.set_xlabel('Root Mean Variance (RMV) per Bin')
    ence_ax.set_ylabel('Root Mean Squared Error (RMSE) per Bin')
    ence_ax.legend()
    ence_ax.grid(True, linestyle=':')
    ence_ax.axis('equal')

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

def plot_combined_calibration_curves_old(
    ground_truth, mean_pred,
    platt_uncertainty_map, iso_uncertainty_map,
    model_name, scan_name, n_bins=20
):
    """
    Two-method predecessor of plot_combined_calibration_curves.

    Left panel: observed vs. expected confidence level for STD scaling and
    isotonic regression, with each unweighted absolute ECE in the title.
    Right panel: per-bin RMSE against per-bin RMV, with each ENCE in the
    title. Sorting for the right panel runs on the module-level DEVICE.
    """
    fig, (ece_ax, ence_ax) = plt.subplots(1, 2, figsize=(16, 7))
    fig.suptitle(f'Calibration Comparison for {model_name} on {scan_name}', fontsize=16)

    calibrated_maps = {'STD Scaling': platt_uncertainty_map, 'Isotonic': iso_uncertainty_map}
    # Short labels used in the panel titles, paired with the dict keys above.
    title_labels = (('STD', 'STD Scaling'), ('Isotonic', 'Isotonic'))

    targets = ground_truth.flatten()
    means = mean_pred.flatten()

    # --- ECE panel ---
    levels = np.linspace(0, 1, n_bins + 1)
    ece_by_method = {}
    for method, sigma_map in calibrated_maps.items():
        cdf_at_truth = scipy.stats.norm.cdf(targets, loc=means, scale=sigma_map.flatten() + 1e-9)
        coverage = np.array([np.mean(cdf_at_truth <= level) for level in levels])
        ece_ax.plot(levels, coverage, '-o', label=method, alpha=0.8)

        all_eces = calculate_all_eces(ground_truth, mean_pred, sigma_map, n_bins=n_bins)
        ece_by_method[method] = all_eces['ece_unweighted_abs']

    ece_ax.plot([0, 1], [0, 1], '--', color='grey', label='Perfect')
    ece_ax.set_title(
        ' | '.join(
            ['ECE Plot'] + [f'{short} ECE: {ece_by_method[key]:.4f}' for short, key in title_labels]
        )
    )
    ece_ax.set_xlabel('Expected Confidence Level')
    ece_ax.set_ylabel('Observed Confidence Level')
    ece_ax.legend()
    ece_ax.grid(True, linestyle=':')
    ece_ax.axis('equal')
    ece_ax.set_xlim([0, 1])
    ece_ax.set_ylim([0, 1])

    # --- ENCE panel ---
    ence_by_method = {}
    plotted_values = []
    for method, sigma_map in calibrated_maps.items():
        sigmas = sigma_map.flatten()

        # Equal-count bins over the points ordered by predicted uncertainty.
        order = torch.argsort(torch.from_numpy(sigmas).to(DEVICE)).cpu().numpy()
        occupied_bins = [idx for idx in np.array_split(order, n_bins) if len(idx) > 0]

        rmv_per_bin = [np.sqrt(np.mean(sigmas[idx]**2)) for idx in occupied_bins]
        rmse_per_bin = [np.sqrt(np.mean((targets[idx] - means[idx])**2)) for idx in occupied_bins]

        ence_ax.plot(rmv_per_bin, rmse_per_bin, '-o', label=method, alpha=0.8, zorder=3)
        plotted_values += rmv_per_bin
        plotted_values += rmse_per_bin

        ence_by_method[method] = calculate_ence(ground_truth, mean_pred, sigma_map, DEVICE, n_bins=n_bins)

    if plotted_values:
        # Diagonal reference sized to the data with 10% headroom.
        upper = np.max(plotted_values) * 1.1
        ence_ax.plot([0, upper], [0, upper], '--', color='grey', label='Perfect')
        ence_ax.set_xlim(left=0, right=upper)
        ence_ax.set_ylim(bottom=0, top=upper)

    ence_ax.set_title(
        ' | '.join(
            ['ENCE Plot (RMSE vs. RMV)']
            + [f'{short} ENCE: {ence_by_method[key]:.4f}' for short, key in title_labels]
        )
    )
    ence_ax.set_xlabel('Root Mean Variance (RMV) per Bin')
    ence_ax.set_ylabel('Root Mean Squared Error (RMSE) per Bin')
    ence_ax.legend()
    ence_ax.grid(True, linestyle=':')
    ence_ax.axis('equal')

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Progress marker printed once the plotting helpers above have been defined.
print("Visualization functions defined.")

## Pre-calibration processing & results

NOTES FOR BBB:
pi=0.75, sigma1=1e-1, sigma2=1e-3, beta=1e-1 was good
pi=0.75, sigma1=1e-1, sigma2=1e-3, beta=1e-2 was good
pi=0.25, sigma1=1e-1, sigma2=1e-3, beta=1e-2 was good

pi=0.75, sigma1=1e-1, sigma2=1e-3, beta=1e-3 had hallucinations
pi=0.75, sigma1=1e-1, sigma2=1e-3, beta=1e0 had minor hallucinations
pi=0.75, beta=1e-2 combined with any of sigma1=5e-1, sigma1=5e-2, sigma2=1e-2, or sigma2=1e-4 all hallucinated (i.e., whenever I changed sigma1 or sigma2)
pi=0.5 has been bad for all beta (hallucinations)
pi=0.25, sigma1=1e-1, sigma2=1e-3, beta=1e-1 had large, but not very extreme hallucinations

### Main loop: Pre-calibration

In [None]:
# This list will store dictionaries of results for each scan and model
all_results = []
# These dictionaries will store all pixel/voxel data for combined histograms.
# Both are keyed by model name; the loop below appends one flattened array per
# scan to each (signed errors gt - mean prediction, and predicted std devs).
all_model_errors = {}
all_model_stds = {}


for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    model_type = model_config['type']
    
    # Initialize lists for the current model
    all_model_errors[model_name] = []
    all_model_stds[model_name] = []
    scan_results = []

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"
        
        # --- Determine Domain and Plotting Slice ---
        is_visual_domain = domain != "PROJ"
        plot_slice_idx = None
        tumor_xy = None
        
        if is_visual_domain:
            if 'tumor_locations' in locals() and tumor_locations is not None:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    tumor_xy = (loc[1].item(), loc[0].item())
                    plot_slice_idx = int(loc[2].item()) - 20
                except (IndexError, TypeError):
                    raise Exception(f"Could not find tumor location for {scan_name}.")
            else:
                raise Exception(f"Could not find tumor location for {scan_name}.")
        
        # --- Data Loading (Ground Truth Only) ---
        # For 'error' type, the domain is implicitly 'IMAG' for the ground truth
        gt_domain = 'IMAG' if model_type == 'error' else domain
        gt_volume = load_ground_truth(FILES, scan_info, gt_domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()
        
        # --- Iterative & Evidential Metric Calculation ---
        if model_type in ['stochastic', 'ensemble']:
            print("Calculating metrics iteratively for stochastic/ensemble model...")
            iq_metrics, mean_pred_vol_th, uncertainty_map_vol_th = calculate_volume_metrics_2_pass(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        elif model_type == 'evidential':
            print("Calculating metrics for evidential model...")
            iq_metrics, mean_pred_vol_th, uncertainty_map_vol_th = calculate_evidential_volume_metrics(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        elif model_type == 'error':
            print("Calculating metrics for error-prediction model...")
            iq_metrics, mean_pred_vol_th, uncertainty_map_vol_th = calculate_error_model_metrics(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        mean_pred_vol = mean_pred_vol_th.cpu().numpy()
        uncertainty_map_vol = uncertainty_map_vol_th.cpu().numpy()
        del mean_pred_vol_th, uncertainty_map_vol_th
        
        # --- Raw Errors, Data Collection, and Per-Scan Histograms ---
        raw_errors_vol = gt_volume_np - mean_pred_vol
        
        # Store flattened arrays for later analysis
        all_model_errors[model_name].append(raw_errors_vol.flatten())
        all_model_stds[model_name].append(uncertainty_map_vol.flatten())
        
        # Plot per-scan histograms
        plot_error_histogram(raw_errors_vol.flatten(), model_name, scan_name)
        plot_std_histogram(uncertainty_map_vol.flatten(), model_name, scan_name)

        # --- Uncertainty Metric Calculation ---
        errors_vol = np.abs(raw_errors_vol)

        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)

        print("Calculating Spearman's correlation (uncertainty vs. error)...")
        spearman_val = calculate_spearman_correlation(uncertainty_map_vol, errors_vol, DEVICE)
        
        print("Calculating Pearson's correlation (uncertainty vs. error)...")
        pearson_val = calculate_pearson_correlation(uncertainty_map_vol, errors_vol, DEVICE)

        print("Calculating MPIW...")
        mpiw_metrics = calculate_mpiw(uncertainty_map_vol, confidence_levels=[0.95])

        # --- Store Results ---
        scan_result = {
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            'spearman_corr_uncert_err': spearman_val,
            'pearson_corr_uncert_err': pearson_val,
            **mpiw_metrics,
        }
        scan_results.append(scan_result)
        
        # --- Visualization ---
        print(f"\n--- Results for {model_name} on {scan_name} ---")
        
        # plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)
        # plot_sparsification_curve(uncertainty_map_vol, errors_vol, model_name, scan_name, DEVICE)
        
        if is_visual_domain:
            gt_slice_np = gt_volume_np[plot_slice_idx]
            mean_pred_slice = mean_pred_vol[plot_slice_idx]
            uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
            
            plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice, 
                                 model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)

            plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice, 
                                 model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy, log_scale=False, clip_pct=99)

            if model_type not in ['evidential', 'error']:
                # --- Load and plot a few samples for visual comparison ---
                print("Loading and plotting samples...")
                SAMPLES_TO_PLOT = 3
                sample_slices_for_plotting = []
                
                # Determine the number of samples to load (can't be more than what's available)
                num_to_load = min(SAMPLES_TO_PLOT, model_config['count'])

                for i in range(num_to_load):
                    passthrough_num = None
                    model_version = model_config['model_version_root']

                    if model_type == 'ensemble':
                        model_version = f"{model_config['model_version_root']}_{i+1:02d}"
                    elif model_type == 'stochastic':
                        passthrough_num = i
                    
                    # This loading logic is copied from your metrics function
                    pred = None
                    if domain == 'FDK':
                        pred_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type_agg, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                        pred = torch.load(pred_path).detach()
                        pred = pred[20:-20, :, :]
                        if scan_type_agg == "FF":
                            pred = pred[:, 128:-128, 128:-128]
                        pred = 25. * torch.clip(pred, min=0.0, max=0.04)
                    elif domain == 'IMAG':
                        pred_path = FILES.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                        pred = torch.load(pred_path).detach()
                        pred = torch.squeeze(pred, dim=1)
                        pred = torch.permute(pred, (0, 2, 1))
                    
                    # Get the specific slice and convert to a numpy array for plotting
                    if pred is not None:
                        sample_slice = pred[plot_slice_idx].cpu().numpy()
                        sample_slices_for_plotting.append(sample_slice)

                # Call the new plotting function
                if sample_slices_for_plotting:
                    plot_samples_comparison(
                        ground_truth=gt_slice_np,
                        mean_pred=mean_pred_slice,
                        samples=sample_slices_for_plotting,
                        model_name=model_name,
                        scan_name=scan_name,
                        slice_idx=plot_slice_idx,
                        tumor_coords_xy=tumor_xy
                    )

                # print("Finding and plotting worst samples by Smooth L1 Loss...")
                # WORST_SAMPLES_TO_PLOT = 5

                # # List to store tuples of (loss, sample_slice_numpy, sample_index)
                # worst_samples_data = []
                # # Get the ground truth slice as a tensor on the correct device
                # gt_slice_tensor = gt_volume[plot_slice_idx].to(DEVICE)

                # # Loop through all available samples to find the worst ones
                # for i in tqdm(range(model_config['count']), desc="Finding Worst Samples", leave=False):
                #     passthrough_num = None
                #     model_version = model_config['model_version_root']
                #     model_type = model_config['type']

                #     if model_type == 'ensemble':
                #         model_version = f"{model_config['model_version_root']}_{i+1:02d}"
                #     elif model_type == 'stochastic':
                #         passthrough_num = i
                    
                #     # Load the prediction volume as a tensor on the GPU
                #     pred_vol_tensor = None
                #     if domain == 'FDK':
                #         pred_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type_agg, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
                #         pred_vol_tensor = torch.load(pred_path, map_location=DEVICE).detach()
                #         pred_vol_tensor = pred_vol_tensor[20:-20, :, :]
                #         pred_vol_tensor = 25. * torch.clip(pred_vol_tensor, min=0.0, max=0.04)
                #     elif domain == 'IMAG':
                #         pred_path = FILES.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
                #         pred_vol_tensor = torch.load(pred_path, map_location=DEVICE).detach()
                #         pred_vol_tensor = torch.squeeze(pred_vol_tensor, dim=1)
                #         pred_vol_tensor = torch.permute(pred_vol_tensor, (0, 2, 1))

                #     if pred_vol_tensor is not None:
                #         pred_slice_tensor = pred_vol_tensor[plot_slice_idx]
                        
                #         # Calculate Smooth L1 Loss for the current slice
                #         import torch.nn.functional as F
                #         loss = F.smooth_l1_loss(pred_slice_tensor, gt_slice_tensor, reduction='mean').item()

                #         # Keep track of the top 5 worst samples (highest loss)
                #         if len(worst_samples_data) < WORST_SAMPLES_TO_PLOT:
                #             worst_samples_data.append((loss, pred_slice_tensor.cpu().numpy(), i))
                #         else:
                #             # Find the sample with the minimum loss currently in our list
                #             min_loss_in_list = min(worst_samples_data, key=lambda x: x[0])
                #             if loss > min_loss_in_list[0]:
                #                 # If current sample is worse, replace the "best of the worst"
                #                 worst_samples_data.remove(min_loss_in_list)
                #                 worst_samples_data.append((loss, pred_slice_tensor.cpu().numpy(), i))

                # # Sort the final list from worst to best for plotting
                # worst_samples_data.sort(key=lambda x: x[0], reverse=True)

                # # Call the new plotting function with the results
                # if worst_samples_data:
                #     plot_worst_samples_comparison(
                #         ground_truth=gt_slice_np,
                #         mean_pred=mean_pred_slice,
                #         worst_samples_data=worst_samples_data,
                #         model_name=model_name,
                #         scan_name=scan_name,
                #         slice_idx=plot_slice_idx,
                #         tumor_coords_xy=tumor_xy
                #     )

        # --- Clean up memory ---
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol, raw_errors_vol
        gc.collect()
        
    all_results.extend(scan_results)

# Convert results to a pandas DataFrame for easier analysis.
# `all_results` is the list extended with each scan's results in the loop above;
# the DataFrame built from it is what the summary cells below group by
# 'model_name'.
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
# Lift pandas' column/width limits so the wide results table is not truncated.
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# Bare expression as the last statement of the cell: the notebook renders it.
results_df

### Summary/Comparison

In [None]:
print("\n\n--- Combined Analysis Across All Scans ---")

# Per-model summary: pooled histograms, cross-scan correlations between the
# uncertainty metrics and the image-quality metrics, and matching scatter plots.
# We iterate over the model configs (not just their names) so that each model's
# own 'type' is available when choosing which metrics to correlate.
for summary_config in MODELS_TO_ANALYZE:
    model_name = summary_config['name']
    print(f"\n{'='*25}\nMODEL: {model_name}\n{'='*25}")

    # --- 1. Combined Histograms ---
    print("\n--- Combined Histograms ---")
    # Concatenate all errors and stds for the current model
    if all_model_errors[model_name]:
        combined_errors = np.concatenate(all_model_errors[model_name])
        plot_error_histogram(combined_errors, model_name, "", combined=True)
    else:
        print("No error data to plot for combined histogram.")

    if all_model_stds[model_name]:
        combined_stds = np.concatenate(all_model_stds[model_name])
        plot_std_histogram(combined_stds, model_name, "", combined=True)
    else:
        print("No std dev data to plot for combined histogram.")


    # --- 2. Cross-Scan Correlations ---
    print("\n--- Cross-Scan Correlation Results ---")
    # One row per scan for this model; correlations are computed across rows.
    model_df = results_df[results_df['model_name'] == model_name].copy()

    if len(model_df) < 2:
        print("Cannot calculate cross-scan correlations with fewer than 2 scans.")
        continue

    # Define the metrics to correlate
    uncert_metrics_to_correlate = ['mean_std', 'rmv']
    iq_metrics_to_correlate = ['mean_ssim', 'mean_psnr', 'sample_avg_ssim', 'sample_avg_psnr']

    # For evidential models, sample_avg is the same as mean, so we remove duplicates.
    # This must look at the *current* model's type: checking only the first entry
    # of MODELS_TO_ANALYZE would apply the wrong metric set to every other model
    # whenever the list mixes model types.
    if summary_config['type'] == 'evidential':
        iq_metrics_to_correlate = ['mean_ssim', 'mean_psnr']

    correlation_results = []

    for uncert_metric in uncert_metrics_to_correlate:
        for iq_metric in iq_metrics_to_correlate:
            # Ensure the columns exist before trying to access them
            if uncert_metric in model_df.columns and iq_metric in model_df.columns:
                x = model_df[uncert_metric]
                y = model_df[iq_metric]

                pearson_corr, pearson_p = scipy.stats.pearsonr(x, y)
                spearman_corr, spearman_p = scipy.stats.spearmanr(x, y)

                correlation_results.append({
                    'Uncertainty Metric': uncert_metric,
                    'IQ Metric': iq_metric,
                    'Pearson Correlation': pearson_corr,
                    'Pearson p-value': pearson_p,
                    'Spearman Correlation': spearman_corr,
                    'Spearman p-value': spearman_p,
                })

    if correlation_results:
        corr_df = pd.DataFrame(correlation_results)
        display(corr_df.round(4))
    else:
        print("Could not compute any correlations.")

    # --- 3. Cross-Scan Correlation Scatter Plots ---
    print("\n--- Cross-Scan Correlation Scatter Plots ---")
    # One figure per uncertainty metric, one panel per image-quality metric.
    for uncert_metric in uncert_metrics_to_correlate:
        num_iq_metrics = len(iq_metrics_to_correlate)
        fig, axes = plt.subplots(1, num_iq_metrics, figsize=(5 * num_iq_metrics, 4.5))
        fig.suptitle(f'{model_name}: {uncert_metric} vs. Image Quality Metrics', fontsize=16)

        # plt.subplots returns a bare Axes (not an array) for a single panel
        if num_iq_metrics == 1:
            axes = [axes]

        for i, iq_metric in enumerate(iq_metrics_to_correlate):
            # A panel is left empty when either column is absent from the results
            if uncert_metric in model_df.columns and iq_metric in model_df.columns:
                axes[i].scatter(model_df[uncert_metric], model_df[iq_metric], alpha=0.7)
                axes[i].set_xlabel(uncert_metric)
                axes[i].set_ylabel(iq_metric)
                axes[i].set_title(f'{uncert_metric} vs. {iq_metric}')
                axes[i].grid(True, linestyle=':')

        # Leave head-room at the top of the figure for the suptitle
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

# --- 4. Cross-Model Histogram Comparison ---
print(f"\n{'='*25}\nCROSS-MODEL HISTOGRAMS\n{'='*25}")

# Each entry describes one overlay figure:
#   (per-model arrays, range function or None, title, x-axis label)
# The range function is evaluated on each model's own pooled values, so the
# clipped figures use per-model histogram ranges; None means the full range.
_cross_model_figures = (
    (
        all_model_errors,
        lambda pooled: np.percentile(pooled, [0.1, 99.9]),
        'Combined Raw Error Distribution (All Models, 0.1-99.9% Range)',
        'Error (GT - Prediction)',
    ),
    (
        all_model_errors,
        None,
        'Combined Raw Error Distribution (All Models, Full Range)',
        'Error (GT - Prediction)',
    ),
    (
        all_model_stds,
        lambda pooled: [0, np.percentile(pooled, 99.9)],
        'Combined Std Deviation Distribution (All Models, (0-99.9% Range))',
        'Predicted Standard Deviation',
    ),
    (
        all_model_stds,
        None,
        'Combined Std Deviation Distribution (All Models, Full Range)',
        'Predicted Standard Deviation',
    ),
)

for per_model_arrays, range_fn, figure_title, x_label in _cross_model_figures:
    plt.figure(figsize=(12, 6))
    for model_name, per_scan_chunks in per_model_arrays.items():
        # Models with no collected data contribute nothing to the overlay
        if not per_scan_chunks:
            continue
        pooled_values = np.concatenate(per_scan_chunks)
        hist_range = range_fn(pooled_values) if range_fn is not None else None
        plt.hist(pooled_values, bins=200, log=True, range=hist_range, alpha=0.6, label=model_name)
    plt.title(figure_title)
    plt.xlabel(x_label)
    plt.ylabel('Frequency')
    plt.legend()
    plt.grid(True, linestyle=':')
    plt.show()

In [None]:
# results_df = pd.read_csv('MCdropout_FDK_results_val.csv')

# Aggregate every numeric metric per model: mean and std across that model's rows.
numeric_cols = results_df.select_dtypes(include=[np.number]).columns
per_model_metrics = results_df.groupby('model_name')[numeric_cols]
summary = per_model_metrics.agg(['mean', 'std']).reset_index()

# save results_df and summary to CSV files
# results_df.to_csv('MCdropout_FDK_results_val.csv', index=False)
# summary.to_csv('MCdropout_FDK_summary_val.csv', index=False)

# summaries = [
#     ('BBB pi=0.75 mu=0.0 sigma1=1e-1 sigma2=1e-3 beta=1e-2 (50)', 'BBB_summary_test.csv'),
#     ('MC Dropoout 30% (50)', 'MCdropout_30_summary_test.csv'),
#     ('Ensemble (10)', 'ensemble_summary_test.csv'),
# ]

# # Read and concatenate all summary files
# import pandas as pd
# summary_list = []
# for name, path in summaries:
#     summary_df = pd.read_csv(path, header=[0, 1])  # Read with multi-index columns

#     # Note there might be multiple rows
#     # We only take the row that matches the model name
#     # Find the row(s) where the model_name matches the given name (allow partial match for ensemble size)
#     row_mask = summary_df[('model_name', 'Unnamed: 0_level_1')] == name
#     summary_df = summary_df[row_mask].reset_index(drop=True)

#     summary_df['model_name'] = name  # Add model name column
#     summary_list.append(summary_df)
# summary = pd.concat(summary_list, ignore_index=True)
# numeric_cols = [col for col, id in summary.select_dtypes(include=[np.number]).columns if id == 'mean']

# Build the display table: one "mean ± std" string per metric and model.
_four_decimals = '{:.4f}'.format
summary_display = pd.DataFrame()
summary_display['model_name'] = summary['model_name']
for metric in numeric_cols:
    mean_text = summary[(metric, 'mean')].map(_four_decimals)
    std_text = summary[(metric, 'std')].map(_four_decimals)
    summary_display[metric] = mean_text + ' ± ' + std_text

for banner_line in (
    "\n\n=======================================================",
    "               Model Comparison Summary",
    "=======================================================",
):
    print(banner_line)

# Show every column at full width so the formatted strings are not truncated.
for option_name in ('display.max_columns', 'display.max_colwidth'):
    pd.set_option(option_name, None)
display(summary_display)

BBB NOTES:
BBB pi=0.75, sigma1=0.1, sigma2=0.001, beta=0.01: decent (small hallucination) <-- KEEP
BBB pi=0.25, sigma1=0.1, sigma2=0.001, beta=0.01: decent (bad metrics)
BBB pi=0.5, sigma1=0.1, sigma2=0.001, beta=0.01: bad (big hallucination)
BBB pi=0.5, sigma1=0.3, sigma2=0.001, beta=0.01: bad (hallucination)
BBB pi=0.5, sigma1=0.03, sigma2=0.001, beta=0.01: decent (ok-ish image quality) <-- KEEP
BBB pi=0.5, sigma1=0.01, sigma2=0.003, beta=0.01: decent (ok-ish image quality) <-- KEEP
BBB pi=0.5, sigma1=0.01, sigma2=0.03, beta=0.01: bad texture in air
BBB pi=0.5, sigma1=0.01, sigma2=0.001, beta=0.001: bad texture in air, and hallucination
BBB pi=0.5, sigma1=0.01, sigma2=0.001, beta=0.1: hallucination
BBB pi=0.75, sigma1=0.3, sigma2=0.001, beta=0.01: hallucination
BBB pi=0.75, sigma1=0.03, sigma2=0.001, beta=0.01: minor hallucination outside <-- KEEP
BBB pi=0.75, sigma1=0.1, sigma2=0.003, beta=0.01: minor hallucinations outside <-- KEEP
BBB pi=0.75, sigma1=0.1, sigma2=0.0003, beta=0.01: hallucination
BBB pi=0.75, sigma1=0.1, sigma2=0.001, beta=0.001: hallucination
BBB pi=0.75, sigma1=0.01, sigma2=0.001, beta=0.1: minor hallucination outside <-- KEEP

### Compare results across models/ensembles

In [None]:
# Look at the trends of image quality and uncertainty metrics over ensemble size
import pandas as pd
import matplotlib.pyplot as plt

# Columns come in pairs: "<metric>" holds the mean and "<metric>.1" the std
# (e.g. "mean_ssim" / "mean_ssim.1"). The first row after the header is not
# data and is dropped below.
_raw_summary = pd.read_csv('MCdropout_30_summary_val.csv')
columns = _raw_summary.columns[1:]  # everything except the leading 'model_name' column
results_df = _raw_summary.iloc[1:]  # skip the first row after the headers

# The ensemble size is the integer in parentheses inside the model name.
ensemble_sizes = (
    results_df['model_name']
    .str.extract(r'\((\d+)\)')
    .astype(int)
    .values
    .flatten()
)

# Group the mean columns (those without the '.1' suffix) by metric family.
ssim_cols, psnr_cols, mae_cols, mse_cols, ause_cols, corr_cols = (
    [name for name in columns if family in name and not name.endswith('.1')]
    for family in ('ssim', 'psnr', 'mae', 'mse', 'ause', 'corr')
)

def plot_metrics(metric_cols, title):
    """Plot mean ± std error bars for each metric column against ensemble size.

    Reads the notebook-level ``results_df`` (summary rows) and
    ``ensemble_sizes`` (one integer per row, in row order).

    Rows are drawn in ascending ensemble-size order. The points are joined by
    a line, so plotting them in raw CSV order would zigzag whenever the CSV is
    not already sorted. This matches the sorting done by the cross-CSV
    comparison cell.

    Args:
        metric_cols: Names of the mean columns to plot; the std of each is
            read from the column named ``f"{col}.1"``.
        title: Figure title.
    """
    # Stable sort keeps the original row order among equal ensemble sizes.
    order = np.argsort(ensemble_sizes, kind='stable')
    sorted_sizes = ensemble_sizes[order]

    plt.figure(figsize=(8, 5))
    for col in metric_cols:
        metric_mean = results_df[col].astype(float).to_numpy()[order]
        metric_std = results_df[f"{col}.1"].astype(float).to_numpy()[order]
        plt.errorbar(sorted_sizes, metric_mean, yerr=metric_std, label=col, fmt='-o', capsize=5)
    plt.title(title)
    plt.xlabel('Ensemble Size')
    plt.xticks(sorted_sizes)
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot grouped metrics: one figure per metric family, skipping any family for
# which no columns were found.
for family_cols, family_title in (
    (ssim_cols, 'SSIM Metrics Over Ensemble Size'),
    (psnr_cols, 'PSNR Metrics Over Ensemble Size'),
    (mae_cols, 'MAE Metrics Over Ensemble Size'),
    (mse_cols, 'MSE Metrics Over Ensemble Size'),
    (ause_cols, 'AUSE Metrics Over Ensemble Size'),
    (corr_cols, 'Correlation Metrics Over Ensemble Size'),
):
    if family_cols:
        plot_metrics(family_cols, family_title)

In [None]:
import matplotlib.pyplot as plt

# Summary CSVs to overlay; (un)comment entries to change the comparison set.
csv_files = [
    # 'MCdropout_15_summary_val.csv',
    # 'MCdropout_30_summary_val.csv',
    # 'MCdropout_50_summary_val.csv',
    # 'ensemble_summary_val.csv',
    'MCdropout_30_summary_test.csv',
    # 'ensemble_summary_test.csv',
    'BBB_summary_test.csv'
    # Add more CSV file paths here
]

# One DataFrame per CSV, in the same order as csv_files.
results_dfs = list(map(pd.read_csv, csv_files))
# Legend label = the file name up to its first '.'.
labels = [csv_path.partition('.')[0] for csv_path in csv_files]
# Up to 10 distinct colors
colors = plt.cm.tab10.colors

def get_metric_cols(columns, metric):
    """Return the mean columns belonging to one metric family.

    Args:
        columns: Iterable of column-name strings.
        metric: Substring identifying the family (e.g. ``'ssim'``).

    Returns:
        The names, in their original order, that contain ``metric`` and do not
        end with ``'.1'`` (the suffix carried by the paired std columns).
    """
    selected = []
    for name in columns:
        is_std_column = name.endswith('.1')
        if metric in name and not is_std_column:
            selected.append(name)
    return selected

# Figure title for each metric family; the keys double as the substrings used
# to pick that family's columns.
metric_titles = dict(
    ssim='SSIM Metrics Over Ensemble Size',
    psnr='PSNR Metrics Over Ensemble Size',
    mae='MAE Metrics Over Ensemble Size',
    mse='MSE Metrics Over Ensemble Size',
    ause='AUSE Metrics Over Ensemble Size',
    corr='Correlation Metrics Over Ensemble Size',
)
# Families are plotted in the order the titles are declared above.
metrics = list(metric_titles)

# Parse every summary once, up front. The ensemble sizes, their sort order and
# the shared x-limit do not depend on which metric family is being plotted.
parsed_summaries = []
for summary_df in results_dfs:
    value_columns = summary_df.columns[1:]  # everything except 'model_name'
    data_rows = summary_df.iloc[1:]         # first row after the header is not data
    sizes = data_rows['model_name'].str.extract(r'\((\d+)\)').astype(int).values.flatten()
    order = np.argsort(sizes)
    parsed_summaries.append((value_columns, data_rows, sizes[order], order))

# Largest ensemble size that every CSV reaches; curves are cut off there so
# all models are compared over the same x-range.
max_x = min(sorted_sizes.max() for _, _, sorted_sizes, _ in parsed_summaries)

for metric in metrics:
    plt.figure(figsize=(8, 5))
    # Smallest x actually drawn by any curve in this figure. Tracked explicitly
    # rather than relying on a variable left over from the innermost loop,
    # which is undefined or stale when the last CSV has no column for this
    # metric and otherwise only reflects the last CSV.
    plotted_min_sizes = []
    for i, (value_columns, data_rows, sorted_sizes, order) in enumerate(parsed_summaries):
        # Only plot up to max_x
        mask = sorted_sizes <= max_x
        for j, col in enumerate(get_metric_cols(value_columns, metric)):
            sorted_metric_mean = data_rows[col].astype(float).values[order]
            # First column of a family is solid, any further ones dashed;
            # colour identifies the CSV.
            linestyle = '-' if j == 0 else '--'
            plt.plot(
                sorted_sizes[mask], sorted_metric_mean[mask],
                label=f"{labels[i]}: {col}",
                color=colors[i % len(colors)], linestyle=linestyle, marker='o'
            )
            if mask.any():
                plotted_min_sizes.append(sorted_sizes[mask].min())
    plt.title(metric_titles[metric])
    plt.xlabel('Ensemble Size')
    plt.xlim(left=None, right=max_x)
    if plotted_min_sizes:
        # Integer ticks spanning every curve that was drawn
        plt.xticks(np.arange(min(plotted_min_sizes), max_x + 1))
        plt.legend()
    plt.grid(True)
    plt.show()


## Post-calibration processing & results

In [None]:
### Calibration Analysis (Train on Validation, Evaluate on Test)
# This cell performs a separate analysis focused on uncertainty calibration.
# 1. It trains calibration models (Platt, Isotonic) for each model on the VALIDATION data.
# 2. It then evaluates the calibrated uncertainties (Platt, Isotonic) on the TEST data.
#
# Relies on helpers defined elsewhere in the notebook (load_ground_truth,
# calculate_volume_metrics_2_pass, calculate_evidential_volume_metrics,
# calculate_platt_scaler, train_isotonic_regression, the calculate_* metric
# functions and plot_combined_calibration_curves_old) and on FILES / DEVICE.

# This list will store dictionaries of post-calibration results
# (one dict per model and test scan).
calibration_results = []

# Load the scan lists for validation and testing, assuming they don't change
# The first returned value is indexed by split name; the second is unused here.
all_scans, _ = read_scans_agg_file(SCANS_AGG_FILE)
validation_scans = all_scans['VALIDATION']
test_scans = all_scans['TEST']

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    
    # ======================================================================
    # 1. CALIBRATION TRAINING on the VALIDATION set
    # ======================================================================
    print(f"\n--- Training calibration for model: {model_name} ---")
    
    # Per-scan flattened arrays, concatenated after the loop.
    all_val_errors = []
    all_val_uncertainties = []
    # NOTE(review): the means and GT lists are only consumed by the
    # commented-out CDF calibration below, so they currently just hold memory.
    all_val_means = []
    all_val_gt = []

    for scan_info in tqdm(validation_scans, desc=f"Gathering validation data for {model_name}", leave=False):
        # Load GT and calculate uncalibrated predictions for this validation scan
        gt_volume = load_ground_truth(FILES, scan_info, domain)
        # The 'calculate_volume_metrics_2_pass' is efficient for getting mean and std dev
        # Both helpers return a 3-tuple; the first element is ignored here.
        if model_config['type'] in ['stochastic', 'ensemble']:
            _, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_2_pass(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        elif model_config['type'] == 'evidential':
            _, mean_pred_vol, uncertainty_map_vol = calculate_evidential_volume_metrics(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        else:
            raise ValueError(f"Unknown model type: {model_config['type']}")
        
        # Move results to CPU and convert to numpy for calibration training
        # Each tensor is deleted as soon as its numpy copy exists.
        gt_volume_np = gt_volume.cpu().numpy()
        del gt_volume
        mean_pred_vol_np = mean_pred_vol.cpu().numpy()
        del mean_pred_vol
        uncertainty_map_vol_np = uncertainty_map_vol.cpu().numpy()
        del uncertainty_map_vol

        # Voxel-wise absolute error of the mean prediction
        errors_vol = np.abs(gt_volume_np - mean_pred_vol_np)
        
        all_val_errors.append(errors_vol.flatten())
        all_val_uncertainties.append(uncertainty_map_vol_np.flatten())
        # Store means and GT for CDF calibration
        all_val_means.append(mean_pred_vol_np.flatten())
        all_val_gt.append(gt_volume_np.flatten())


    # Consolidate all validation data into single arrays
    val_errors_full = np.concatenate(all_val_errors)
    val_uncertainties_full = np.concatenate(all_val_uncertainties)
    # Consolidate means and GT
    val_means_full = np.concatenate(all_val_means)
    val_gt_full = np.concatenate(all_val_gt)
    
    # Train the calibration models using the full validation dataset
    print("Calculating Platt Scaler...")
    # Scalar factor; it multiplies the uncertainty map on the test set below.
    platt_scaler = calculate_platt_scaler(val_errors_full, val_uncertainties_full)
    print(f"Platt Scaler T={platt_scaler:.4f}")
    
    print("Training Isotonic Regression model on variance...")
    # NOTE(review): argument order is (uncertainties, errors) here but
    # (errors, uncertainties) for calculate_platt_scaler above; confirm both
    # against the helper signatures.
    isotonic_model = train_isotonic_regression(val_uncertainties_full, val_errors_full, DEVICE)
    print("Done.")

    # # Train the CDF-based Isotonic Regression model
    # print("Training Isotonic Regression model on CDF...")
    # iso_cdf_model = train_cdf_isotonic_regression(val_gt_full, val_means_full, val_uncertainties_full, DEVICE)
    # print("Done.")
        
    # Clean up memory from the training phase
    del all_val_errors, all_val_uncertainties, val_errors_full, val_uncertainties_full
    del all_val_means, all_val_gt, val_means_full, val_gt_full
    gc.collect()

    # ======================================================================
    # 2. FINAL EVALUATION on the TEST set
    # ======================================================================
    print(f"\n--- Evaluating model: {model_name} on the TEST set ---")
    
    for scan_info in tqdm(test_scans, desc=f"Analyzing Test Scans for {model_name}"):
        # scan_info unpacks to three items; the third is not needed here.
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"
        
        # --- Data Loading & Base Uncalibrated Calculation ---
        gt_volume = load_ground_truth(FILES, scan_info, domain)
        gt_volume_np = gt_volume.cpu().numpy()
        
        if model_config['type'] in ['stochastic', 'ensemble']:
            _, mean_pred_vol, uncal_uncertainty_map = calculate_volume_metrics_2_pass(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        elif model_config['type'] == 'evidential':
            _, mean_pred_vol, uncal_uncertainty_map = calculate_evidential_volume_metrics(
                FILES, model_config, scan_info, gt_volume.to(DEVICE), DEVICE
            )
        else:
            raise ValueError(f"Unknown model type: {model_config['type']}")
        del gt_volume
        
        mean_pred_vol_np = mean_pred_vol.cpu().numpy()
        del mean_pred_vol
        uncal_uncertainty_map_np = uncal_uncertainty_map.cpu().numpy()
        del uncal_uncertainty_map
        
        # --- Apply Calibrations ---
        print("Applying calibrations...")
        # Platt: multiply by the fitted scalar. Isotonic: the fitted model is
        # called directly on the uncalibrated map.
        platt_uncertainty_map_np = uncal_uncertainty_map_np * platt_scaler
        iso_uncertainty_map_np = isotonic_model(uncal_uncertainty_map_np)
        # # --- Apply the CDF-based calibration ---
        # iso_cdf_uncertainty_map_np = apply_cdf_isotonic_regression(iso_cdf_model, uncal_uncertainty_map_np)

        # --- Store results for this scan ---
        scan_result = {'model_name': model_name, 'scan_name': scan_name}
        
        # --- Calculate and Store All Metrics (Platt, Isotonic) ---
        # The dict keys become the suffix of every result column for that
        # calibration (e.g. 'nll_platt', 'nll_iso_var').
        # NOTE(review): the uncalibrated map is not in this dict, so no
        # uncalibrated metrics are produced by this cell.
        calibrations = {
            'platt': platt_uncertainty_map_np,
            'iso_var': iso_uncertainty_map_np,
            # 'iso_cdf': iso_cdf_uncertainty_map_np,
        }

        # We need the error map for AUSE calculation
        errors_vol = np.abs(gt_volume_np - mean_pred_vol_np)
        
        print("Evaluating metrics for each calibration...")
        for cal_name, uncertainty_map in calibrations.items():
            print("Calculating metrics for calibration:", cal_name)

            print("Calculating AUSE...")
            ause_val = calculate_ause_sparsification(uncertainty_map, errors_vol)
            scan_result[f'ause_{cal_name}'] = ause_val

            print("Calculating ECE...")
            # calculate_all_eces returns a dict; each entry is stored under
            # '<key>_<n_bins>bins_<cal_name>'.
            for n_bins in [10, 20, 50]:
                eces = calculate_all_eces(gt_volume_np, mean_pred_vol_np, uncertainty_map, n_bins)
                for key, val in eces.items():
                    scan_result[f'{key}_{n_bins}bins_{cal_name}'] = val

            print("Calculating ENCE...")         
            for n_bins in [10, 20, 50]:
                scan_result[f'ence_{n_bins}bins_{cal_name}'] = calculate_ence(gt_volume_np, mean_pred_vol_np, uncertainty_map, DEVICE, n_bins)
            
            print("Calculating NLL...")
            scan_result[f'nll_{cal_name}'] = calculate_nll(gt_volume_np, mean_pred_vol_np, uncertainty_map)
            
            print("Calculating MPIW...")
            # calculate_mpiw returns a dict, stored under '<key>_<cal_name>'.
            mpiws = calculate_mpiw(uncertainty_map, confidence_levels=[0.68, 0.95])
            for key, val in mpiws.items():
                scan_result[f'{key}_{cal_name}'] = val
            print("Done with metrics for", cal_name)

        calibration_results.append(scan_result)

        # Combined calibration plots (Platt and isotonic) for this scan
        print(f"--- Generating combined calibration plots for {model_name} on {scan_name} ---")
        # plot_combined_calibration_curves(
        #     ground_truth=gt_volume_np,
        #     mean_pred=mean_pred_vol_np,
        #     platt_uncertainty_map=calibrations['platt'],
        #     iso_uncertainty_map=calibrations['iso_var'],
        #     iso_cdf_uncertainty_map=calibrations['iso_cdf'],
        #     model_name=model_name,
        #     scan_name=scan_name,
        #     n_bins=20
        # )
        plot_combined_calibration_curves_old(
            ground_truth=gt_volume_np,
            mean_pred=mean_pred_vol_np,
            platt_uncertainty_map=calibrations['platt'],
            iso_uncertainty_map=calibrations['iso_var'],
            model_name=model_name,
            scan_name=scan_name,
            n_bins=20
        )
        
        # --- Clean up memory ---
        del gt_volume_np, mean_pred_vol_np, uncal_uncertainty_map_np, errors_vol
        del platt_uncertainty_map_np, iso_uncertainty_map_np #, iso_cdf_uncertainty_map_np
        gc.collect()

# Convert results to a pandas DataFrame for easier analysis
calibration_results_df = pd.DataFrame(calibration_results)

print("\n\n✅ Calibration analysis complete.")
# Bare expression as the last statement of the cell: the notebook renders it.
calibration_results_df

In [None]:
### Aggregate Results Summary

# Select only the numeric columns for aggregation
numeric_cols = calibration_results_df.select_dtypes(include=[np.number]).columns
# Per-model mean and std of every metric across that model's test scans.
summary = calibration_results_df.groupby('model_name')[numeric_cols].agg(['mean', 'std']).reset_index()

# Save results_df and summary to CSV files
# NOTE(review): the output names are hard-coded with a 'BBB_' prefix whatever
# MODELS_TO_ANALYZE contains; rename them before running other models so
# existing files are not overwritten.
calibration_results_df.to_csv('BBB_results_calibration.csv', index=False)
summary.to_csv('BBB_summary_calibration.csv', index=False)

# Prepare a display DataFrame with formatted 'mean ± std' strings
summary_display = pd.DataFrame()
summary_display['model_name'] = summary['model_name']

for col in numeric_cols:
    mean_col = (col, 'mean')
    std_col = (col, 'std')
    # NaN values are not special-cased: they are rendered as the text "nan".
    summary_display[col] = summary[mean_col].map('{:.4f}'.format) + ' ± ' + summary[std_col].map('{:.4f}'.format)

# "\n\n" must be real newlines; the doubled backslashes used previously
# printed the characters '\n\n' literally in front of the banner.
print("\n\n=======================================================")
print("               Model Comparison Summary")
print("=======================================================")

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)
display(summary_display)

In [None]:
# Example: Combine multiple summary DataFrames from different CSVs
import pandas as pd

# List of summary CSV files to compare
summary_csv_files = [
    'MCdropout_30_summary_calibration.csv',
    'ensemble_summary_calibration.csv',
    'BBB_summary_calibration.csv',
    # Add more summary CSV file paths here
]
# Metrics to show. Each name is the mean column; its std is '<name>.1'.
# NOTE(review): the calibration cell in this notebook labels the isotonic
# variant 'iso_var', so summaries written by it carry '..._iso_var' columns.
# The '_iso' names below only match CSVs produced with an 'iso' label;
# confirm which one the files on disk use.
columns_to_include = [
    # 'ece_weighted_abs_20bins_platt',
    'ece_unweighted_abs_20bins_platt',
    'ence_20bins_platt',
    'nll_platt',
    # 'mpiw_68_platt',
    'mpiw_95_platt',
    # 'ece_weighted_abs_20bins_iso',
    'ece_unweighted_abs_20bins_iso',
    'ence_20bins_iso',
    'nll_iso',
    # 'mpiw_68_iso',
    'mpiw_95_iso',
]


def _load_summary_csv(path):
    """Load one summary CSV as a numeric DataFrame (plus 'model_name')."""
    # Discard the first non-header row. The explicit copy gives an independent
    # frame, so the column assignments below do not hit a slice of the
    # original (which triggers pandas' SettingWithCopyWarning).
    frame = pd.read_csv(path).iloc[1:].copy()
    # Convert all columns to numeric (except 'model_name'); non-numeric
    # content raises instead of being coerced silently.
    for col in frame.columns:
        if col != 'model_name':
            frame[col] = pd.to_numeric(frame[col], errors='raise')
    return frame


# Load each summary CSV as a DataFrame
summaries = [_load_summary_csv(path) for path in summary_csv_files]

# Concatenate all summary DataFrames row-wise
if summaries:
    combined_summary = pd.concat(summaries, ignore_index=True)
    print("\n\n=======================================================")
    print("         Combined Model Calibration Summaries")
    print("=======================================================")
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_colwidth', None)

    # Only format metrics whose mean AND std columns exist in the loaded CSVs;
    # report the rest instead of failing with a KeyError.
    available_columns = [
        col for col in columns_to_include
        if col in combined_summary.columns and f'{col}.1' in combined_summary.columns
    ]
    missing_columns = [col for col in columns_to_include if col not in available_columns]
    if missing_columns:
        print(f"Warning: skipping columns missing from the loaded summaries: {missing_columns}")

    # Display the dataframe with values formatted as 'mean ± std'
    for col in available_columns:
        std_col = f'{col}.1'
        combined_summary[col] = combined_summary[col].map('{:.4f}'.format) + ' ± ' + combined_summary[std_col].map('{:.4f}'.format)

    # Keep only the relevant columns
    combined_summary = combined_summary[['model_name'] + available_columns]
    display(combined_summary)
else:
    print("No summary CSV files provided.")