## Setup and Configuration
This initial block sets up the environment, defines file paths, and specifies the validation scans and model versions to be analyzed.

In [None]:
from pipeline.paths import Directories, Files
import os
import torch
import numpy as np
import gc
import matplotlib.pyplot as plt
from sklearn.isotonic import IsotonicRegression

# Restrict this process to a single physical GPU. With PCI bus ordering, the
# index given in CUDA_VISIBLE_DEVICES is exposed to PyTorch as "cuda:0".
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

if not torch.cuda.is_available():
    print("CUDA is not available. Please check your PyTorch installation. Using CPU instead.")
    CUDA_DEVICE = torch.device("cpu")
else:
    CUDA_DEVICE = torch.device("cuda:0")
    print(f"CUDA is available. Using device: {CUDA_DEVICE}")

# Experiment identifiers; they select the phase<PHASE>/DS<DATA_VERSION> folder below.
PHASE = '7'
DATA_VERSION = '13'

# Root of the working tree for this project.
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"

# Location of the raw MATLAB data (NSG_CBCT).
NSG_CBCT_PATH = "D:/MitchellYu/NSG_CBCT"

# Everything specific to this phase / data version lives under this folder.
PHASE_DATAVER_DIR = os.path.join(WORK_ROOT, "phase" + PHASE, "DS" + DATA_VERSION)

# Only the reconstruction directories are needed by this notebook. The
# commented-out entries are kept for reference (other directory arguments and
# an alternative storage location for the reconstructions).
DIRECTORIES = Directories(
    reconstructions_dir=os.path.join(PHASE_DATAVER_DIR, "reconstructions"),
    reconstructions_gated_dir=os.path.join(WORK_ROOT, "gated", "fdk_recon"),
    # --- alternative reconstruction locations ---
    # reconstructions_dir=os.path.join("H:\\", "Public", "Noah", "reconstructions"),
    # reconstructions_gated_dir=os.path.join("H:\\", "Public", "Noah", "gated", "fdk_recon"),
    # --- not used here ---
    # mat_projections_dir=os.path.join(NSG_CBCT_PATH, "data/prj/HF/mat"),
    # pt_projections_dir=os.path.join(WORK_ROOT, "prj_pt"),
    # projections_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "projections"),
    # projections_model_dir=os.path.join(PHASE_DATAVER_DIR, "models", "projections"),
    # projections_results_dir=os.path.join(PHASE_DATAVER_DIR, "results", "projections"),
    # projections_gated_dir=os.path.join(WORK_ROOT, "gated", "prj_mat"),
    # images_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "images"),
    # images_model_dir=os.path.join(PHASE_DATAVER_DIR, "models", "images"),
    # images_results_dir=os.path.join(PHASE_DATAVER_DIR, "results", "images"),
)

FILES = Files(DIRECTORIES)

# (patient, scan) pairs to analyse.
# Alternative set:
# VAL_SCANS = [('02', '01'), ('02', '02'), ('16', '01'), ('16', '02'), ('22', '01'), ('22', '02')]
VAL_SCANS = [
    ('08', '01'),
    ('10', '01'),
    ('14', '01'),
    ('14', '02'),
    ('15', '01'),
    ('20', '01'),
]

SCAN_TYPE = 'HF'
MODEL_VERSIONS = ['MK7_MCDROPOUT']

# Number of stochastic forward passes per scan for MC dropout. Only meaningful
# with a single entry in MODEL_VERSIONS; set to None to instead compare one
# reconstruction per model version.
PASSTHROUGH_COUNT = 50

## Data Loading
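This block gathers, for each validation scan, the file paths of the model's reconstructions (one per MC Dropout pass — `PASSTHROUGH_COUNT` = 50 here — or one per model version when `PASSTHROUGH_COUNT` is `None`) and the path of the corresponding ground-truth reconstruction (the gated FDK reconstruction).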

In [2]:
# Lookup tables keyed by (patient, scan):
#   recon_paths_dict -> list of reconstruction file paths to aggregate
#   recon_names_dict -> human-readable label for each of those paths
#   gt_paths_dict    -> path of the gated FDK reconstruction used as reference
recon_paths_dict, recon_names_dict, gt_paths_dict = {}, {}, {}

for patient, scan in VAL_SCANS:
    print(f"Processing patient {patient}, scan {scan}")
    scan_key = (patient, scan)

    if PASSTHROUGH_COUNT is not None:
        # MC dropout: several stochastic passes of the first model version.
        recon_paths = [
            FILES.get_recon_filepath(
                MODEL_VERSIONS[0], patient, scan, SCAN_TYPE, gated=False, passthrough_num=pass_idx
            )
            for pass_idx in range(PASSTHROUGH_COUNT)
        ]
        recon_names = [f"Passthrough {pass_idx+1}" for pass_idx in range(PASSTHROUGH_COUNT)]
    else:
        # Model comparison: one reconstruction per model version.
        recon_paths = [
            FILES.get_recon_filepath(version, patient, scan, SCAN_TYPE, gated=False)
            for version in MODEL_VERSIONS
        ]
        recon_names = list(MODEL_VERSIONS)

    gt_path = FILES.get_recon_filepath('fdk', patient, scan, SCAN_TYPE, gated=True)

    gt_paths_dict[scan_key] = gt_path
    recon_paths_dict[scan_key] = recon_paths
    recon_names_dict[scan_key] = recon_names

## Analysis Loop
The main loop iterates through each scan. For each scan, it loads every reconstruction pass (`PASSTHROUGH_COUNT` of them, 50 as configured above), computes the voxel-wise mean (the prediction) and standard deviation (the uncertainty) across passes, and then performs several analyses to evaluate the quality of the uncertainty estimates.

In [None]:
# For every (patient, scan): load all reconstructions, compute the voxel-wise
# mean and standard deviation across them, compare the mean against the ground
# truth, and evaluate how well the standard deviation predicts the error.

# Radius (in pixels, measured from the slice centre) of the circular region
# that is analysed in every slice.
_FOV_RADIUS_PX = 225
# Number of uncertainty bins used by every reliability diagram.
_NUM_CALIBRATION_BINS = 20
# Maximum number of voxels used to fit the isotonic regression, and the seed
# for drawing them (fixed so that repeated runs give the same calibration).
_ISOTONIC_FIT_SAMPLES = 100_000
_ISOTONIC_FIT_SEED = 0


def _binned_calibration(pred_std, squared_error, num_bins=_NUM_CALIBRATION_BINS):
    """Bin voxels by predicted std and summarise each occupied bin.

    The range [min(pred_std), max(pred_std)] is split into ``num_bins``
    equal-width bins. Bins are half-open ``[lower, upper)`` except the last
    one, which also contains the maximum so that no voxel is left out.

    Args:
        pred_std: 1-D array of predicted standard deviations.
        squared_error: 1-D array of squared errors, aligned with ``pred_std``.
        num_bins: number of equal-width bins.

    Returns:
        ``(mean_std, rmse)``: two 1-D arrays holding, for every non-empty bin
        in increasing order, the mean predicted std and the root mean squared
        error of the voxels in that bin.
    """
    edges = np.linspace(np.min(pred_std), np.max(pred_std), num_bins + 1)
    mean_std = np.zeros(num_bins)
    rmse = np.zeros(num_bins)
    occupied = np.zeros(num_bins, dtype=bool)

    for b in range(num_bins):
        in_bin = pred_std >= edges[b]
        if b < num_bins - 1:
            in_bin &= pred_std < edges[b + 1]
        # The last bin has no upper bound: everything left is <= the maximum.
        if in_bin.any():
            occupied[b] = True
            mean_std[b] = np.mean(pred_std[in_bin])
            rmse[b] = np.sqrt(np.mean(squared_error[in_bin]))

    # Select by occupancy rather than by "mean > 0", so that a bin whose
    # predicted std is exactly zero is still reported.
    return mean_std[occupied], rmse[occupied]


for (patient, scan), recon_paths in recon_paths_dict.items():
    print(f"\nProcessing reconstructions for patient {patient}, scan {scan}")

    if not recon_paths:
        print(f"No reconstructions listed for patient {patient}, scan {scan}; skipping.")
        continue

    # Load ground truth. map_location="cpu" restores the tensor directly on the
    # CPU regardless of the device it was saved from.
    gt_path = gt_paths_dict[(patient, scan)]
    gt_recon = torch.load(gt_path, map_location="cpu").numpy()

    # Load every reconstruction straight into one preallocated stack. This
    # avoids holding both a Python list of arrays and its stacked copy.
    reconstructions = None
    for idx, recon_path in enumerate(recon_paths):
        recon = torch.load(recon_path, map_location="cpu").numpy()
        if reconstructions is None:
            reconstructions = np.empty((len(recon_paths),) + recon.shape, dtype=recon.dtype)
        reconstructions[idx] = recon
    del recon

    print(f"Loaded {len(reconstructions)} reconstructions for patient {patient}, scan {scan}")

    # --- Basic Calculations ---
    # Mean and std across the first axis (across models or passthroughs)
    mean_recon = np.mean(reconstructions, axis=0)
    std_recon = np.std(reconstructions, axis=0)

    # Signed error of the mean prediction with respect to the ground truth
    error = mean_recon - gt_recon

    # Only voxels within _FOV_RADIUS_PX of the slice centre (last two dims)
    # are analysed. The 2-D disc is broadcast over all slices without copying.
    z, y, x = gt_recon.shape
    yy, xx = np.ogrid[:y, :x]
    center = np.array([y // 2, x // 2])
    dist_from_center = np.sqrt((yy - center[0])**2 + (xx - center[1])**2)
    slice_mask = dist_from_center <= _FOV_RADIUS_PX
    mask = np.broadcast_to(slice_mask, gt_recon.shape)

    # Boolean indexing already yields new 1-D arrays; no extra flatten needed.
    flat_std = std_recon[mask]
    flat_error = error[mask]
    flat_abs_error = np.abs(flat_error)
    flat_squared_error = flat_error**2

    # Release the full volumes before the analyses allocate their own buffers
    del reconstructions, mean_recon, std_recon, error, gt_recon
    gc.collect()

    # --- Analysis 1: Standardized Error Histogram ---
    # Distribution of the errors normalized by the raw, uncalibrated std.
    # If the std were a calibrated Gaussian sigma, ~68% / ~95% / ~99.7% of the
    # voxels would fall within 1 / 2 / 3 standardized units.
    standardized_error = flat_error / (flat_std + 1e-8)
    plt.figure(figsize=(10, 6))
    plt.hist(standardized_error, bins=100, color='gray', alpha=0.7, range=(-5, 5))
    plt.title(f'UNCALIBRATED Standardized Error Histogram for Patient {patient}, Scan {scan}')
    plt.xlabel('Standardized Error (Error / Predicted Std Dev)')
    plt.ylabel('Frequency')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

    percent_within_1 = np.mean(np.abs(standardized_error) < 1) * 100
    percent_within_2 = np.mean(np.abs(standardized_error) < 2) * 100
    percent_within_3 = np.mean(np.abs(standardized_error) < 3) * 100
    print(f"Patient {patient}, Scan {scan} - UNCALIBRATED Coverage:")
    print(f"  Percent with abs standardized error < 1: {percent_within_1:.2f}% (Expected for Gaussian: ~68%)")
    print(f"  Percent with abs standardized error < 2: {percent_within_2:.2f}% (Expected for Gaussian: ~95%)")
    print(f"  Percent with abs standardized error < 3: {percent_within_3:.2f}% (Expected for Gaussian: ~99.7%)")

    # --- Analysis 2: Uncertainty vs. Error 2D Histogram (Relative Uncertainty) ---
    # Useful for comparing configurations (e.g. dropout rates): a tighter
    # diagonal means the uncertainty ranks the errors better, regardless of
    # its absolute scale.
    plt.figure(figsize=(10, 8))
    plt.hist2d(flat_std, flat_abs_error, bins=50, cmap='inferno')
    plt.colorbar(label='Voxel Count')
    plt.title(f'Predicted Uncertainty vs. Actual Error for Patient {patient}, Scan {scan}')
    plt.xlabel('Predicted Standard Deviation (Uncertainty)')
    plt.ylabel('Absolute Error')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

    # --- Analysis 3: Calibration Plot (Reliability Diagram) ---
    # Voxels are binned by predicted uncertainty; for each bin the average
    # predicted std is plotted against the observed RMSE. A perfectly
    # calibrated model would fall on the y=x line.
    raw_mean_std, raw_rmse = _binned_calibration(flat_std, flat_squared_error)

    plt.figure(figsize=(8, 8))
    plt.plot(raw_mean_std, raw_rmse, 'o-', label='Model Calibration', color='royalblue')
    # Plot the ideal y=x line for perfect calibration
    lims = [
        min(np.min(raw_mean_std), np.min(raw_rmse)),
        max(np.max(raw_mean_std), np.max(raw_rmse)),
    ]
    plt.plot(lims, lims, 'k--', label='Perfect Calibration (y=x)')
    plt.title(f'Calibration Plot for Patient {patient}, Scan {scan}')
    plt.xlabel('Average Predicted Standard Deviation (per bin)')
    plt.ylabel('Root Mean Squared Error (per bin)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.axis('equal')
    plt.show()

    # --- Analysis 4: Sparsification Plot ---
    # Voxels are sorted by increasing uncertainty; the most uncertain fraction
    # is removed and the MSE of the remaining voxels is reported. A curve that
    # drops faster indicates that the uncertainty ranks the errors better.
    sorted_indices = np.argsort(flat_std)
    sorted_squared_error = flat_squared_error[sorted_indices]
    fractions_removed = np.linspace(0, 0.5, 51)
    mse_remaining = np.zeros_like(fractions_removed)
    for i, frac in enumerate(fractions_removed):
        num_to_keep = int((1 - frac) * len(sorted_squared_error))
        mse_remaining[i] = np.mean(sorted_squared_error[:num_to_keep]) if num_to_keep > 0 else 0
    del sorted_indices, sorted_squared_error
    plt.figure(figsize=(10, 6))
    plt.plot(fractions_removed * 100, mse_remaining, 'o-', color='crimson')
    plt.title(f'Sparsification Plot for Patient {patient}, Scan {scan}')
    plt.xlabel('Percent of Most Uncertain Voxels Removed (%)')
    plt.ylabel('Mean Squared Error (on remaining voxels)')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

    # --- Analysis 5: Post-Hoc Uncertainty Calibration ---
    # Learn a mapping from the raw predicted uncertainty to the observed error
    # so that the calibrated uncertainty has a meaningful absolute scale.

    # Method 1: Linear Calibration
    # Fits a line y = a*x + b mapping raw std to absolute error.
    # NOTE(review): the regression target is |error|, so the fitted value is an
    # estimate of the mean absolute error, not of a standard deviation. For
    # Gaussian errors E|e| = sigma * sqrt(2/pi), i.e. this under-estimates
    # sigma by ~20%, which by itself keeps the "< 1" coverage below 68% and
    # the reliability curve above y=x. Behaviour left unchanged here.
    p_linear = np.polyfit(flat_std, flat_abs_error, 1)
    calibrated_std_linear = np.polyval(p_linear, flat_std)
    # Ensure non-negative uncertainty
    calibrated_std_linear[calibrated_std_linear < 0] = 0

    # Method 2: Isotonic Regression
    # Non-parametric, monotonic mapping from predicted variance (std^2) to
    # observed squared error (error^2).
    ir = IsotonicRegression(out_of_bounds='clip')
    # Fit on a random subset to keep the fit fast; the mapping is then applied
    # to every voxel. Generator.choice avoids permuting the whole index range.
    rng = np.random.default_rng(_ISOTONIC_FIT_SEED)
    fit_indices = rng.choice(
        flat_std.size, size=min(_ISOTONIC_FIT_SAMPLES, flat_std.size), replace=False
    )
    ir.fit(flat_std[fit_indices]**2, flat_squared_error[fit_indices])
    calibrated_var_isotonic_full = ir.predict(flat_std**2)
    calibrated_std_isotonic = np.sqrt(calibrated_var_isotonic_full)

    print("\n--- Calibration Results ---")

    # --- Visualize the Calibrated Results ---
    # Reliability diagrams after each calibration method; points closer to the
    # y=x line indicate better calibration.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    fig.suptitle(f'Calibration Plots for Patient {patient}, Scan {scan}')

    # Linear Calibration Plot
    lin_mean_std, lin_rmse = _binned_calibration(calibrated_std_linear, flat_squared_error)
    ax1.plot(lin_mean_std, lin_rmse, 'o-', label='Linear Calibration', color='green')
    lims1 = [0, max(np.max(lin_mean_std), np.max(lin_rmse))]
    ax1.plot(lims1, lims1, 'k--', label='Perfect Calibration')
    ax1.set_title('After Linear Calibration')
    ax1.set_xlabel('Average Predicted Standard Deviation')
    ax1.set_ylabel('Root Mean Squared Error')
    ax1.legend()
    ax1.grid(True, linestyle='--', alpha=0.6)
    ax1.axis('equal')

    # Isotonic Calibration Plot
    iso_mean_std, iso_rmse = _binned_calibration(calibrated_std_isotonic, flat_squared_error)
    ax2.plot(iso_mean_std, iso_rmse, 'o-', label='Isotonic Calibration', color='purple')
    lims2 = [0, max(np.max(iso_mean_std), np.max(iso_rmse))]
    ax2.plot(lims2, lims2, 'k--', label='Perfect Calibration')
    ax2.set_title('After Isotonic Regression')
    ax2.set_xlabel('Average Predicted Standard Deviation')
    ax2.set_ylabel('Root Mean Squared Error')
    ax2.legend()
    ax2.grid(True, linestyle='--', alpha=0.6)
    ax2.axis('equal')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # --- Recalculate Coverage Statistics with Calibrated Uncertainty ---
    calibrated_standardized_error_linear = flat_error / (calibrated_std_linear + 1e-8)
    calibrated_standardized_error_isotonic = flat_error / (calibrated_std_isotonic + 1e-8)

    percent_within_1_linear = np.mean(np.abs(calibrated_standardized_error_linear) < 1) * 100
    percent_within_1_isotonic = np.mean(np.abs(calibrated_standardized_error_isotonic) < 1) * 100

    print(f"Patient {patient}, Scan {scan} - CALIBRATED Coverage:")
    print(f"  [Linear] Percent with abs standardized error < 1: {percent_within_1_linear:.2f}%")
    print(f"  [Isotonic] Percent with abs standardized error < 1: {percent_within_1_isotonic:.2f}% (Ideal is ~68%)")

    # Clean up memory before next loop iteration
    del flat_std, flat_error, flat_abs_error, flat_squared_error, standardized_error
    del calibrated_std_linear, calibrated_std_isotonic, calibrated_var_isotonic_full
    del calibrated_standardized_error_linear, calibrated_standardized_error_isotonic
    gc.collect()