### Setup/Imports

In [None]:
import os
import gc

import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from tqdm import tqdm
import pandas as pd
import scipy.io
import scipy.stats

from pipeline.paths import Directories, Files
from pipeline.utils import read_scans_agg_file

### Config

In [None]:
PHASE = "7"
DATA_VERSION = "13"
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"
SPLIT_TO_ANALYZE = 'VALIDATION'  # Options: 'TRAIN', 'VALIDATION', 'TEST'
# Which scans of the split to analyze. slice(-1, None) keeps only the last scan
# (quick-iteration setting); use slice(None) to analyze the entire split.
SCAN_SUBSET = slice(-1, None)

MODELS_TO_ANALYZE = [
    {
        'name': 'MC Dropout 30%',
        'type': 'stochastic',  # one model version, `count` stochastic passthroughs
        'domain': 'FDK',
        'model_version_root': 'MK7_MCDROPOUT_30_pct_NEW',
        'count': 20,
    },
    {
        'name': 'Ensemble',
        'type': 'ensemble',  # `count` members saved as <root>_01, <root>_02, ...
        'domain': 'FDK',
        'model_version_root': 'MK7',
        'count': 5,
    },
    # Add other models here
]

# Path to the .pt file containing tumor locations.
# Indexed as tumor_locations[patient, scan] -> (x, y, z), i.e. a 3D tensor of
# shape [n_patients, n_scans, 3].
TUMOR_LOCATIONS_FILE = 'D:/NoahSilverberg/ngCBCT/3D_recon/tumor_location.pt'

# --- Advanced Config ---
SCANS_AGG_FILE = 'scans_to_agg.txt'
SSIM_KWARGS = {"K1": 0.03, "K2": 0.06, "win_size": 15}

# --- Setup ---
# Create Directories and Files objects
phase_dataver_dir = os.path.join(WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}")
DIRECTORIES = Directories(
    projections_results_dir=os.path.join(phase_dataver_dir, "results", "projections"),
    projections_gated_dir=os.path.join(WORK_ROOT, "gated", "prj_mat"),
    reconstructions_dir=os.path.join(phase_dataver_dir, "reconstructions"),
    reconstructions_gated_dir=os.path.join(WORK_ROOT, "gated", "fdk_recon"),
    images_results_dir=os.path.join(phase_dataver_dir, "results", "images"),
)
FILES = Files(DIRECTORIES)

# Load the list of scans
scans_agg, scan_type_agg = read_scans_agg_file(SCANS_AGG_FILE)
analysis_scans = scans_agg[SPLIT_TO_ANALYZE][SCAN_SUBSET]

# Load tumor locations. They are optional: when unavailable (None), the analysis
# loop falls back to a default plotting slice without a tumor arrow.
# NOTE: weights_only=False unpickles arbitrary objects -- only load a trusted file.
if os.path.exists(TUMOR_LOCATIONS_FILE):
    tumor_locations = torch.load(TUMOR_LOCATIONS_FILE, weights_only=False)
else:
    tumor_locations = None
    print(f"Warning: tumor location file not found: {TUMOR_LOCATIONS_FILE}")

print("\nConfiguration loaded.")
print(f"Analyzing {len(analysis_scans)} scans from the '{SPLIT_TO_ANALYZE}' split.")
print(f"Found {len(MODELS_TO_ANALYZE)} model(s) to analyze.")

### Data loading/prep functions

In [None]:
def load_ground_truth(files_obj: Files, scan_info, domain, slice_idx=None):
    """
    Return the gated (ground-truth) data for one scan in the requested domain.

    Args:
        files_obj: path resolver used to locate the gated result files.
        scan_info: (patient, scan, scan_type) tuple.
        domain: 'PROJ' for projections; 'FDK' or 'IMAG' for image space. Both
            image-space domains compare against the same gated FDK reconstruction.
        slice_idx: optional index into the first axis, applied only to 3D data.

    Returns:
        torch.Tensor with the full data, or the selected slice.

    Raises:
        ValueError: if ``domain`` is not one of 'PROJ', 'FDK', 'IMAG'.
    """
    patient, scan, scan_type = scan_info

    if domain in ('FDK', 'IMAG'):
        recon_path = files_obj.get_recon_filepath("fdk", patient, scan, scan_type, gated=True, ensure_exists=False)
        volume = torch.load(recon_path).detach()
        # Drop 20 slices at each end, clip to [0, 0.04] and rescale to [0, 1].
        volume = 25. * torch.clip(volume[20:-20, :, :], min=0.0, max=0.04)
    elif domain == 'PROJ':
        prj_path = files_obj.get_projections_results_filepath('fdk', patient, scan, scan_type, gated=True)
        volume = torch.from_numpy(scipy.io.loadmat(prj_path)['prj']).detach().permute(1, 0, 2)
    else:
        raise ValueError(f"Unknown domain: {domain}")

    wants_single_slice = slice_idx is not None and volume.ndim == 3
    return volume[slice_idx] if wants_single_slice else volume

def load_predictions(files_obj: Files, model_config, scan_info, slice_idx=None):
    """
    Load every prediction a model produced for one scan and stack them.

    Ensemble models keep member k under ``<root>_<kk>`` (01, 02, ...); stochastic
    models keep pass k under ``<root>`` with ``passthrough_num=k`` (0-based). Any
    other type loads ``<root>`` without a passthrough number ``count`` times.

    Args:
        files_obj: path resolver used to locate the result files.
        model_config: dict with 'name', 'type', 'domain', 'model_version_root', 'count'.
        scan_info: (patient, scan, scan_type) tuple.
        slice_idx: optional index into axis 1, applied only to a 4D stack.

    Returns:
        torch.Tensor with a new leading sample axis (or the selected slice).

    Raises:
        ValueError: if the model's domain is not one of 'PROJ', 'FDK', 'IMAG'.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']
    count = model_config['count']
    model_type = model_config['type']

    def load_sample(idx):
        version = f"{root}_{idx + 1:02d}" if model_type == 'ensemble' else root
        pass_num = idx if model_type == 'stochastic' else None

        if domain == 'PROJ':
            path = files_obj.get_projections_results_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_num, ensure_exists=False)
            return torch.from_numpy(scipy.io.loadmat(path)['prj']).detach().permute(1, 0, 2)
        if domain == 'FDK':
            path = files_obj.get_recon_filepath(version, patient, scan, scan_type, gated=False, passthrough_num=pass_num, ensure_exists=False)
            volume = torch.load(path).detach()[20:-20, :, :]
            return 25. * torch.clip(volume, min=0.0, max=0.04)
        if domain == 'IMAG':
            # Assumes image-domain results are saved under the same model version name.
            path = files_obj.get_images_results_filepath(version, patient, scan, passthrough_num=pass_num, ensure_exists=False)
            images = torch.squeeze(torch.load(path).detach(), dim=1)
            return torch.permute(images, (0, 2, 1))
        raise ValueError(f"Unknown domain: {domain}")

    print(f"Loading {count} predictions for {model_config['name']}...")
    stacked = torch.stack([
        load_sample(idx) for idx in tqdm(range(count), desc="Loading predictions", leave=False)
    ])

    if slice_idx is None or stacked.ndim != 4:
        return stacked
    return stacked[:, slice_idx, :, :]

print("Data loading functions defined.")

### Metric calculation functions

In [None]:
def calculate_ause_sparsification(uncertainty, errors):
    """
    Calculates the Area Under the Sparsification Error curve (AUSE) in O(N log N).

    Pixels are removed in order of decreasing uncertainty (model curve) or
    decreasing error (oracle curve); after each removal the MAE of the remaining
    pixels is recorded. AUSE is the mean absolute gap between the two curves.

    Remaining-error sums are computed as a reversed cumulative sum in float64.
    The naive ``total - prefix_sum`` form in float32 suffers catastrophic
    cancellation near the tail of the curve for volumes with millions of voxels.

    Args:
        uncertainty (np.ndarray): The uncertainty map (e.g., standard deviation).
        errors (np.ndarray): The absolute error map (AE) between prediction and
            ground truth; must have the same number of elements as ``uncertainty``.

    Returns:
        float: The AUSE value (0 means the uncertainty ranks errors perfectly).

    Raises:
        ValueError: if the inputs are empty or differ in size.
    """
    uncertainty_flat = np.asarray(uncertainty).ravel()
    errors_flat = np.asarray(errors).ravel()
    if errors_flat.size == 0:
        raise ValueError("AUSE is undefined for empty inputs.")
    if uncertainty_flat.size != errors_flat.size:
        raise ValueError(
            f"uncertainty ({uncertainty_flat.size}) and errors ({errors_flat.size}) must have the same size."
        )

    # Errors ordered by descending uncertainty (model) and by descending error (oracle).
    model_sorted_errors = errors_flat[np.argsort(uncertainty_flat)[::-1]]
    oracle_sorted_errors = np.sort(errors_flat)[::-1]

    def _sparsification_curve(sorted_errs):
        # remaining_sums[k] == sum(sorted_errs[k:]): the error left after removing
        # the first k pixels. Accumulating from the tail avoids cancellation.
        remaining_sums = np.cumsum(sorted_errs[::-1], dtype=np.float64)[::-1]
        n_remaining = np.arange(sorted_errs.size, 0, -1)
        # MAE of the remaining pixels at each sparsification step.
        return remaining_sums / n_remaining

    model_curve = _sparsification_curve(model_sorted_errors)
    oracle_curve = _sparsification_curve(oracle_sorted_errors)

    # Area between the two curves, normalized by the number of steps.
    return float(np.mean(np.abs(model_curve - oracle_curve)))


def calculate_ece(ground_truth, mean_pred, uncertainty_map, n_levels=20):
    """
    Calculates the calibration error for regression tasks.

    Follows Equations 8 and 9 of Kuleshov et al., 2018, "Accurate Uncertainties
    for Deep Learning Using Calibrated Regression". Each voxel gets a Gaussian
    predictive distribution N(mean_pred, uncertainty_map**2), and the ground truth
    is pushed through that distribution's CDF. The empirical CDF of those values
    is then compared with the ideal diagonal at ``n_levels`` evenly spaced
    confidence levels, using weights w_j = 1.

    Voxels with zero (non-positive) uncertainty have a point-mass prediction, for
    which scipy's Gaussian CDF returns NaN. A NaN would silently count as "above"
    every confidence level and bias the result. For those voxels the CDF value
    is 0 if gt < pred, 1 if gt > pred, and 0.5 on an exact tie.

    Args:
        ground_truth (np.ndarray): The ground truth volume/image.
        mean_pred (np.ndarray): The model's mean prediction volume/image.
        uncertainty_map (np.ndarray): The model's uncertainty (std dev) volume/image.
        n_levels (int): The number of confidence levels to use for the calculation.

    Returns:
        float: Sum of squared differences between expected and observed
        confidence levels (0 means perfectly calibrated).
    """
    gt_flat = np.asarray(ground_truth).ravel()
    pred_flat = np.asarray(mean_pred).ravel()
    uncert_flat = np.asarray(uncertainty_map).ravel()

    # Step 1: predicted CDF value F_t(y_t) of each ground-truth point.
    has_spread = uncert_flat > 0
    # A dummy scale of 1 keeps scipy from producing NaN; those entries are replaced below.
    gaussian_cdfs = scipy.stats.norm.cdf(gt_flat, loc=pred_flat, scale=np.where(has_spread, uncert_flat, 1.0))
    # Point-mass CDF: sign(gt - pred) in {-1, 0, 1} maps to {0, 0.5, 1}.
    point_mass_cdfs = 0.5 * (1.0 + np.sign(gt_flat - pred_flat))
    pred_cdfs = np.where(has_spread, gaussian_cdfs, point_mass_cdfs)

    # Step 2: expected confidence levels p_j (x-axis of the calibration plot).
    expected_confidence_levels = np.linspace(0, 1, n_levels)

    # Step 3: observed frequency hat{p}_j, i.e. the empirical CDF of pred_cdfs at p_j.
    observed_frequencies = np.array([
        np.mean(pred_cdfs <= p_j) for p_j in expected_confidence_levels
    ])

    # Step 4: Equation 9 with w_j = 1. This is the squared distance from the diagonal.
    calibration_error = np.sum((expected_confidence_levels - observed_frequencies) ** 2)

    return float(calibration_error)

def _load_prediction_volume(files_obj, model_config, scan_info, sample_idx):
    """
    Load prediction number ``sample_idx`` (0-based) of a model as a NumPy array.

    Ensemble members live under ``<root>_<kk>`` (kk = sample_idx + 1, two digits);
    stochastic passes live under ``<root>`` with ``passthrough_num=sample_idx``.

    Raises:
        ValueError: if the model's domain is not one of 'PROJ', 'FDK', 'IMAG'.
    """
    patient, scan, scan_type = scan_info
    domain = model_config['domain']
    root = model_config['model_version_root']
    model_type = model_config['type']

    model_version = f"{root}_{sample_idx + 1:02d}" if model_type == 'ensemble' else root
    passthrough_num = sample_idx if model_type == 'stochastic' else None

    if domain == 'PROJ':
        pred_path = files_obj.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
        pred = torch.from_numpy(scipy.io.loadmat(pred_path)['prj']).detach().permute(1, 0, 2)
    elif domain == 'FDK':
        pred_path = files_obj.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, passthrough_num=passthrough_num, ensure_exists=False)
        pred = torch.load(pred_path).detach()
        # Same crop/clip/rescale as the ground truth so the two are comparable.
        pred = 25. * torch.clip(pred[20:-20, :, :], min=0.0, max=0.04)
    elif domain == 'IMAG':
        pred_path = files_obj.get_images_results_filepath(model_version, patient, scan, passthrough_num=passthrough_num, ensure_exists=False)
        pred = torch.load(pred_path).detach()
        pred = torch.permute(torch.squeeze(pred, dim=1), (0, 2, 1))
    else:
        raise ValueError(f"Unknown domain: {domain}")

    return pred.cpu().numpy()


def _slice_averaged_metrics(gt_volume_np, pred_volume_np):
    """
    Return (ssim, psnr, mse, mae) of ``pred`` vs ``gt``, each averaged over axis-0 slices.

    Slices whose ground truth is constant (zero data range) are skipped, because
    SSIM/PSNR are undefined for them. Each value is 0 if no slice qualifies.
    """
    ssims, psnrs, mses, maes = [], [], [], []
    for s in range(gt_volume_np.shape[0]):
        gt_slice = gt_volume_np[s]
        pred_slice = pred_volume_np[s]
        slice_data_range = np.max(gt_slice) - np.min(gt_slice)
        if slice_data_range == 0:
            continue  # Skip empty slices
        diff = gt_slice - pred_slice
        ssims.append(ssim(gt_slice, pred_slice, data_range=slice_data_range, **SSIM_KWARGS))
        psnrs.append(psnr(gt_slice, pred_slice, data_range=slice_data_range))
        mses.append(np.mean(diff ** 2))
        maes.append(np.mean(np.abs(diff)))
    return tuple(np.mean(values) if values else 0 for values in (ssims, psnrs, mses, maes))


def calculate_volume_metrics_iteratively(files_obj: Files, model_config, scan_info, gt_volume_np: np.ndarray):
    """
    Streams a model's predictions one volume at a time to compute statistics and metrics.

    Only one prediction volume is held in memory at a time. The voxel-wise mean
    and (population) standard deviation over samples come from Welford's online
    algorithm.

    Args:
        files_obj: path resolver used to locate prediction files.
        model_config: dict with 'type', 'domain', 'model_version_root', 'count'.
        scan_info: (patient, scan, scan_type) tuple.
        gt_volume_np: ground-truth volume matching the predictions' shape.

    Returns:
        (metrics, mean_volume, uncertainty_volume), where ``metrics`` contains:
        - 'mean_ssim'/'mean_psnr'/'mean_mse'/'mean_mae': slice-averaged metrics of
          the mean prediction;
        - 'sample_avg_*': the same metrics per sample, averaged over samples.
        Metrics are 0 when the data is not a stack of 2D slices.

    Raises:
        ValueError: if the model's domain is unknown.
    """
    metric_names = ('ssim', 'psnr', 'mse', 'mae')
    n_samples = model_config['count']

    # --- Online statistics (Welford) for the entire volume ---
    mean_volume_np = np.zeros_like(gt_volume_np, dtype=np.float32)
    m2_volume_np = np.zeros_like(gt_volume_np, dtype=np.float32)

    # Initialized up front so it is defined even when n_samples == 0.
    is_sliceable = gt_volume_np.ndim > 2
    per_sample_metrics = []  # one (ssim, psnr, mse, mae) tuple per sliceable sample

    for i in tqdm(range(n_samples), total=n_samples, desc="Iterative Metrics", leave=False):
        pred_volume_np = _load_prediction_volume(files_obj, model_config, scan_info, i)

        # Per-slice metrics only make sense for stacks of 2D images.
        is_sliceable = gt_volume_np.ndim > 2 and pred_volume_np.ndim > 2
        if is_sliceable:
            per_sample_metrics.append(_slice_averaged_metrics(gt_volume_np, pred_volume_np))

        # Welford update: running mean and sum of squared deviations (M2).
        delta = pred_volume_np - mean_volume_np
        mean_volume_np += delta / (i + 1)
        m2_volume_np += delta * (pred_volume_np - mean_volume_np)

    # Population std over samples; a single sample carries no spread information.
    uncertainty_volume_map = np.sqrt(m2_volume_np / n_samples) if n_samples > 1 else np.zeros_like(mean_volume_np)

    mean_values = _slice_averaged_metrics(gt_volume_np, mean_volume_np) if is_sliceable else (0, 0, 0, 0)

    metrics = {}
    for k, name in enumerate(metric_names):
        metrics[f'mean_{name}'] = mean_values[k]
    for k, name in enumerate(metric_names):
        metrics[f'sample_avg_{name}'] = (
            np.mean([sample[k] for sample in per_sample_metrics]) if per_sample_metrics else 0
        )

    return metrics, mean_volume_np, uncertainty_volume_map

print("Metric calculation functions defined.")

### Visualization functions

In [None]:
def plot_mean_comparison(mean_pred, ground_truth, uncertainty_map, model_name, scan_name, slice_idx, tumor_coords_xy=None):
    """
    Show GT, mean prediction, |GT - mean| and uncertainty side by side.

    Each panel gets its own colorbar. If ``tumor_coords_xy`` (x, y) is given, a
    red arrow pointing at that position is drawn on every panel.
    """
    panels = [
        (ground_truth, 'gray', 'Ground Truth'),
        (mean_pred, 'gray', 'Mean Prediction'),
        (np.abs(ground_truth - mean_pred), 'magma', 'Absolute Error Map'),
        (uncertainty_map, 'viridis', 'Uncertainty (Std Dev)'),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'{model_name} - {scan_name} - Slice {slice_idx} (Mean vs. GT)', fontsize=16)

    for ax, (image, cmap, title) in zip(axes, panels):
        rendered = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        fig.colorbar(rendered, ax=ax)

    if tumor_coords_xy:
        x, y = tumor_coords_xy
        arrow_style = dict(facecolor='red', edgecolor='red', shrink=0.05, width=1, headwidth=5, headlength=5)
        for ax in axes:
            # Fresh dict per call so no annotation can share state with another.
            ax.annotate('', xy=(x, y), xytext=(x - 30, y - 30), arrowprops=dict(arrow_style))

    plt.tight_layout()
    plt.show()

def plot_ssim_map(ssim_map, model_name, scan_name):
    """Display a per-pixel SSIM map on a fixed [0, 1] color scale."""
    fig, ax = plt.subplots(figsize=(6, 6))
    rendered = ax.imshow(ssim_map, cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'SSIM Map - {model_name} - {scan_name}')
    fig.colorbar(rendered, ax=ax)
    ax.axis('off')
    plt.show()

def plot_calibration_curve(ground_truth, mean_pred, uncertainty_map, model_name, scan_name, n_levels=20):
    """
    Plots the calibration curve as described in Kuleshov et al., 2018.

    Each voxel's ground truth is passed through the CDF of a Gaussian predictive
    distribution N(mean_pred, uncertainty_map**2). The fraction of these values
    at or below each of ``n_levels`` evenly spaced confidence levels is plotted
    against the level itself; a perfectly calibrated model lies on the diagonal.

    Voxels with zero uncertainty (point-mass predictions) get a CDF value of 0,
    0.5 or 1 when the ground truth is below, equal to, or above the prediction.
    Scipy would otherwise return NaN for scale=0, and those voxels would be
    miscounted at every level.
    """
    gt_flat = np.asarray(ground_truth).ravel()
    pred_flat = np.asarray(mean_pred).ravel()
    uncert_flat = np.asarray(uncertainty_map).ravel()

    has_spread = uncert_flat > 0
    # A dummy scale of 1 avoids NaN; those entries are replaced by the point-mass CDF.
    gaussian_cdfs = scipy.stats.norm.cdf(gt_flat, loc=pred_flat, scale=np.where(has_spread, uncert_flat, 1.0))
    point_mass_cdfs = 0.5 * (1.0 + np.sign(gt_flat - pred_flat))
    pred_cdfs = np.where(has_spread, gaussian_cdfs, point_mass_cdfs)

    expected_confidence_levels = np.linspace(0, 1, n_levels)
    observed_frequencies = np.array([np.mean(pred_cdfs <= p_j) for p_j in expected_confidence_levels])

    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], '--', color='grey', label='Perfectly Calibrated')
    plt.plot(expected_confidence_levels, observed_frequencies, '-o', label='Model Calibration')
    plt.xlabel('Expected Confidence Level')
    plt.ylabel('Observed Confidence Level')
    plt.title(f'Calibration Plot - {model_name} - {scan_name}')
    plt.legend()
    plt.grid(True, linestyle=':')
    plt.show()


print("Visualization functions defined.")

### Main loop

In [None]:
# This list will store dictionaries of results for each scan and model
all_results = []

# Slice plotted when no (valid) tumor location is available for a scan.
DEFAULT_PLOT_SLICE = 100
# The FDK/IMAG loaders drop 20 slices from the start of each volume ([20:-20]),
# so tumor z-coordinates are shifted by this amount to index the cropped volume.
RECON_CROP_SLICES = 20

for model_config in MODELS_TO_ANALYZE:
    model_name = model_config['name']
    domain = model_config['domain']
    # Per-slice image plots only make sense for image-space volumes.
    is_visual_domain = domain in ['FDK', 'IMAG']

    for scan_info in tqdm(analysis_scans, desc=f"Analyzing Model: {model_name}"):
        patient, scan, _ = scan_info
        scan_name = f"p{patient}_{scan}"

        # --- Determine Plotting Slice (and tumor arrow position, if known) ---
        plot_slice_idx = None
        tumor_xy = None
        if is_visual_domain:
            plot_slice_idx = DEFAULT_PLOT_SLICE
            # tumor_locations comes from the config cell; it may be absent or None.
            if globals().get('tumor_locations') is not None:
                try:
                    loc = tumor_locations[int(patient), int(scan)]
                    loc_xy = (loc[1].item(), loc[0].item())
                    loc_slice = int(loc[2].item()) - RECON_CROP_SLICES
                except (IndexError, TypeError, ValueError):
                    print(f"Warning: Could not find tumor location for {scan_name}. Plotting will not have an arrow.")
                else:
                    # Commit arrow and slice together so a partial lookup never mixes them.
                    tumor_xy, plot_slice_idx = loc_xy, loc_slice

        # --- Data Loading (Ground Truth Only) ---
        gt_volume = load_ground_truth(FILES, scan_info, domain, slice_idx=None)
        gt_volume_np = gt_volume.cpu().numpy()

        if is_visual_domain:
            # A negative index would silently wrap to the wrong end of the volume and an
            # index past the end would crash plotting, so clamp to a valid slice.
            last_slice = gt_volume_np.shape[0] - 1
            clamped_idx = min(max(plot_slice_idx, 0), last_slice)
            if clamped_idx != plot_slice_idx:
                print(f"Warning: slice {plot_slice_idx} is outside the volume for {scan_name}; plotting slice {clamped_idx} instead.")
                plot_slice_idx = clamped_idx

        # --- Iterative Metric Calculation ---
        print("Calculating metrics for the entire volume...")
        iq_metrics, mean_pred_vol, uncertainty_map_vol = calculate_volume_metrics_iteratively(
            FILES, model_config, scan_info, gt_volume_np
        )

        # --- Uncertainty Metric Calculation ---
        errors_vol = np.abs(gt_volume_np - mean_pred_vol)
        print("Calculating AUSE...")
        ause_val = calculate_ause_sparsification(uncertainty_map_vol, errors_vol)
        print("Calculating ECE...")
        ece_val = calculate_ece(gt_volume_np, mean_pred_vol, uncertainty_map_vol)
        print("DONE")

        # --- Store Results ---
        all_results.append({
            'model_name': model_name,
            'scan_name': scan_name,
            **iq_metrics,
            'ause': ause_val,
            'ece': ece_val,
        })

        # --- Visualization ---
        if is_visual_domain:
            print(f"\n--- Results for {model_name} on {scan_name} (Plotting Slice: {plot_slice_idx}) ---")
            gt_slice_np = gt_volume_np[plot_slice_idx]
            mean_pred_slice = mean_pred_vol[plot_slice_idx]
            uncertainty_map_slice = uncertainty_map_vol[plot_slice_idx]
            _, ssim_map_val = ssim(gt_slice_np, mean_pred_slice,
                                   data_range=(np.max(gt_slice_np) - np.min(gt_slice_np)),
                                   full=True, **SSIM_KWARGS)

            plot_mean_comparison(mean_pred_slice, gt_slice_np, uncertainty_map_slice,
                                 model_name, scan_name, plot_slice_idx, tumor_coords_xy=tumor_xy)
            plot_ssim_map(ssim_map_val, model_name, scan_name)
        else:
            print(f"\n--- Metrics calculated for {model_name} on {scan_name} (PROJ domain, no plots) ---")

        # The calibration curve does not depend on the domain, so it is always plotted.
        plot_calibration_curve(gt_volume_np, mean_pred_vol, uncertainty_map_vol, model_name, scan_name)

        # --- Clean up memory ---
        del gt_volume, gt_volume_np, mean_pred_vol, uncertainty_map_vol, errors_vol
        gc.collect()

# Convert results to a pandas DataFrame for easier analysis
results_df = pd.DataFrame(all_results)

print("\n\n✅ Analysis complete for all models and scans.")
results_df

### Summary/Comparison

In [None]:
if len(MODELS_TO_ANALYZE) > 1:
    # Only aggregate numeric metric columns (model_name / scan_name are strings).
    numeric_cols = results_df.select_dtypes(include=[np.number]).columns
    summary = results_df.groupby('model_name')[numeric_cols].agg(['mean', 'std']).reset_index()

    # One row per model; every metric rendered as "mean ± std" across scans.
    # After reset_index the group key sits under the MultiIndex column ('model_name', '').
    # NOTE: std is NaN (shown as "nan") when a model was evaluated on a single scan.
    summary_display = pd.DataFrame({'model_name': summary[('model_name', '')]})
    for col in numeric_cols:
        summary_display[col] = (
            summary[(col, 'mean')].map('{:.4f}'.format)
            + ' ± '
            + summary[(col, 'std')].map('{:.4f}'.format)
        )

    print("\n\n=======================================================")
    print("               Model Comparison Summary")
    print("=======================================================")

    # `display` is provided by IPython/Jupyter.
    display(summary_display)

else:
    print("\nOnly one model was analyzed. No comparison table to generate.")