In [1]:
# --- CONFIG ---
vol_path   = r"U:\users\taki\vizualization\test.tif"   # <-- set this
ckpt_path  = "deblur3d_unet.pt"
base, levels = 24, 4
tile     = (64, 256, 256)
overlap  = (32, 128, 128)
spacing  = (1.0, 1.0, 1.0)

# Baseline params (for 0–1 normalized data)
FWHM_vox = 9.0
sigma    = FWHM_vox / 2.3548
USM_amount = 2
LoG_lambda = 2
Wiener_K   = 0.015
RL_iters   = 10

# ---- Controlled CNN presets (edit these) ----
# strength: residual scaling (1.0 = as trained; <1 gentler; >1 stronger)
# hp_sigma: high-pass Gaussian sigma (vox), hp_gain: multiply HP residual
# lp_gain : mix-in of the low-pass (denoise) branch
cnn_control_presets = [                                  # raw network (no control modulation)
    ("CNN α=2",           dict(strength=2)),
]

# --- imports ---
import os, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from tqdm.auto import tqdm
from deblur3d.data   import read_volume_float01
from deblur3d.models import UNet3D_Residual, ControlledUNet3D
from deblur3d.infer  import deblur_volume_tiled

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device.type == "cuda", "CUDA not available."

# ---------- reusable kernels (GPU) ----------
@torch.no_grad()
def _gauss1d(sigma: float, device, dtype=torch.float32, radius_mult: float = 3.0):
    import math
    sigma = max(1e-6, float(sigma))
    r = max(1, int(math.ceil(radius_mult * sigma)))
    x = torch.arange(-r, r + 1, device=device, dtype=dtype)
    k = torch.exp(-(x * x) / (2.0 * sigma * sigma))
    return (k / (k.sum() + 1e-12)), r

@torch.no_grad()
def gaussian_blur3d_tensor(x: torch.Tensor, sigma: float, pad_mode="reflect"):
    if sigma <= 0: return x
    k1d, r = _gauss1d(sigma, x.device, x.dtype)
    kz = k1d.view(1,1,-1,1,1); ky = k1d.view(1,1,1,-1,1); kx = k1d.view(1,1,1,1,-1)
    C = x.shape[1]
    y = F.conv3d(F.pad(x, (0,0,0,0,r,r), mode=pad_mode), kz, groups=C)
    y = F.conv3d(F.pad(y, (0,0,r,r,0,0), mode=pad_mode), ky, groups=C)
    y = F.conv3d(F.pad(y, (r,r,0,0,0,0), mode=pad_mode), kx, groups=C)
    return y

@torch.no_grad()
def laplacian3d_tensor(x: torch.Tensor, pad_mode="reflect"):
    w = torch.zeros((1,1,3,3,3), device=x.device, dtype=x.dtype)
    w[0,0,1,1,1] = 6.0
    w[0,0,1,1,0] = w[0,0,1,1,2] = -1.0
    w[0,0,1,0,1] = w[0,0,1,2,1] = -1.0
    w[0,0,0,1,1] = w[0,0,2,1,1] = -1.0
    xpad = F.pad(x, (1,1,1,1,1,1), mode=pad_mode)
    return F.conv3d(xpad, w)

# ---------- GPU baselines ----------
@torch.no_grad()
def usm3d_gpu(vol_t: torch.Tensor, sigma, amount):
    x = vol_t.unsqueeze(0).unsqueeze(0)
    base = gaussian_blur3d_tensor(x, sigma)
    y = (x + amount * (x - base)).clamp(0,1)
    return y.squeeze().detach().cpu().numpy().astype(np.float32)

@torch.no_grad()
def log_sharpen3d_gpu(vol_t: torch.Tensor, sigma, lam):
    x = vol_t.unsqueeze(0).unsqueeze(0)
    g  = gaussian_blur3d_tensor(x, sigma)
    L  = laplacian3d_tensor(g)
    y  = (x - lam * L).clamp(0,1)
    return y.squeeze().detach().cpu().numpy().astype(np.float32)

@torch.no_grad()
def wiener_gaussian3d_gpu(vol_t: torch.Tensor, sigma, K=0.01):
    x = vol_t
    D,H,W = x.shape
    X = torch.fft.fftn(x)
    fz = torch.fft.fftfreq(D, d=1.0, device=x.device).view(D,1,1)
    fy = torch.fft.fftfreq(H, d=1.0, device=x.device).view(1,H,1)
    fx = torch.fft.fftfreq(W, d=1.0, device=x.device).view(1,1,W)
    two_pi2 = (2.0 * np.pi) ** 2
    Htf = torch.exp(-0.5 * two_pi2 * (sigma**2) * (fz*fz + fy*fy + fx*fx))
    Y = X * Htf / (Htf*Htf + K)
    y = torch.fft.ifftn(Y).real.clamp(0,1)
    return y.detach().cpu().numpy().astype(np.float32)

@torch.no_grad()
def richardson_lucy3d_gpu(vol_t: torch.Tensor, sigma, n_iter=15):
    x = vol_t.unsqueeze(0).unsqueeze(0)
    psf1d, r = _gauss1d(sigma, x.device, x.dtype)
    kz = psf1d.view(1,1,-1,1,1); ky = psf1d.view(1,1,1,-1,1); kx = psf1d.view(1,1,1,1,-1)
    def psfZ(z): return F.conv3d(F.pad(z, (0,0,0,0,r,r), mode="replicate"), kz)
    def psfY(z): return F.conv3d(F.pad(z, (0,0,r,r,0,0), mode="replicate"), ky)
    def psfX(z): return F.conv3d(F.pad(z, (r,r,0,0,0,0), mode="replicate"), kx)
    def psf_conv(z):  return psfX(psfY(psfZ(z)))
    y = x.clamp_min(1e-6)
    it = tqdm(range(int(n_iter)), desc="RL (GPU) iters", leave=False)
    for _ in it:
        est = psf_conv(y).clamp_min(1e-6)
        ratio = x / est
        y = (y * psf_conv(ratio)).clamp(0,1)
    return y.squeeze().detach().cpu().numpy().astype(np.float32)

# --- helper: accurate CUDA timing ---
def run_timed_cuda(name, fn, *args, **kwargs):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with tqdm(total=1, desc=name, leave=False) as pbar:
        out = fn(*args, **kwargs)
        torch.cuda.synchronize()
        dt = time.perf_counter() - t0
        pbar.update(1)
    return out, dt

# --- load input volume & move to GPU once ---
vol = read_volume_float01(vol_path)
print("Input:", vol.shape, vol.dtype, f"min/max {vol.min():.3f}/{vol.max():.3f}")
vol_t = torch.from_numpy(vol).to(device, dtype=torch.float32)

# --- load trained net + controller ---
assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
net = UNet3D_Residual(in_ch=1, base=base, levels=levels).to(device).eval()
state = torch.load(ckpt_path, map_location=device)
net.load_state_dict(state.get("state_dict", state))
ctrl = ControlledUNet3D(net).to(device).eval()   # controller wraps the trained net

# small module to bind control kwargs so deblur_volume_tiled can call it like a net
class NetWithControl(nn.Module):
    def __init__(self, ctrl: ControlledUNet3D, **ctrl_kwargs):
        super().__init__()
        self.ctrl = ctrl
        self.kw = ctrl_kwargs
    @torch.no_grad()
    def forward(self, x):
        return self.ctrl(x, **self.kw)

times = {}
cnn_results = []  # list of (name, np_array)

# --- CNN base (no control) ---
cnn_base, t_cnn_base = run_timed_cuda("CNN (base) tiled", deblur_volume_tiled, net, vol, tile=tile, overlap=overlap, device=device.type)
times["CNN (base)"] = t_cnn_base
cnn_results.append(("CNN (base)", cnn_base))

# --- CNN controlled variants ---
for name, kwargs in cnn_control_presets:
    if name == "CNN (base)" and not kwargs:
        continue  # already ran above
    net_ctrl = NetWithControl(ctrl, **kwargs).to(device).eval()
    out, dt = run_timed_cuda(f"{name} tiled", deblur_volume_tiled, net_ctrl, vol, tile=tile, overlap=overlap, device=device.type)
    times[name] = dt
    cnn_results.append((name, out))

# --- baselines on GPU ---
usm,  t_usm  = run_timed_cuda(f"USM (GPU) σ={sigma:.2f}, a={USM_amount}", usm3d_gpu, vol_t, sigma, USM_amount)
logb, t_log  = run_timed_cuda(f"LoG (GPU) σ={sigma:.2f}, λ={LoG_lambda}",  log_sharpen3d_gpu, vol_t, sigma, LoG_lambda)
wien, t_win  = run_timed_cuda(f"Wiener (GPU) σ={sigma:.2f}, K={Wiener_K}", wiener_gaussian3d_gpu, vol_t, sigma, Wiener_K)
rl,   t_rl   = run_timed_cuda(f"RL (GPU) σ={sigma:.2f}, iters={RL_iters}", richardson_lucy3d_gpu, vol_t, sigma, RL_iters)

times.update({
    "USM (GPU)": t_usm,
    "LoG (GPU)": t_log,
    "Wiener (GPU)": t_win,
    "RL (GPU)": t_rl,
})

# --- timing summary ---
try:
    import pandas as pd
    from IPython.display import display
    df_times = (
        pd.DataFrame.from_dict(times, orient="index", columns=["seconds"])
        .sort_values("seconds")
    )
    display(df_times.style.format({"seconds": "{:.3f}"}))
except Exception:
    print("Times (s):", {k: f"{v:.3f}" for k, v in times.items()})

# --- visualize in napari ---
import napari
v = napari.Viewer(ndisplay=2)
L_in  = v.add_image(vol,  name="Input",      colormap="gray", scale=spacing)

# add all CNN variants first
for name, arr in cnn_results:
    L = v.add_image(arr, name=name, colormap="gray", scale=spacing, opacity=0.85)
    L.contrast_limits = L_in.contrast_limits

# add baselines
L_usm = v.add_image(usm,  name=f"USM σ={sigma:.2f} a={USM_amount}",  colormap="gray", scale=spacing, opacity=0.85)
L_log = v.add_image(logb, name=f"LoG σ={sigma:.2f} λ={LoG_lambda}",  colormap="gray", scale=spacing, opacity=0.85)
L_win = v.add_image(wien, name=f"Wiener σ={sigma:.2f} K={Wiener_K}", colormap="gray", scale=spacing, opacity=0.85)
L_rl  = v.add_image(rl,   name=f"RL σ={sigma:.2f} iters={RL_iters}", colormap="gray", scale=spacing, opacity=0.85)

for L in (L_usm, L_log, L_win, L_rl):
    L.contrast_limits = L_in.contrast_limits

napari.run()


  from .autonotebook import tqdm as notebook_tqdm


Input: (81, 712, 688) float32 min/max 0.000/1.000


RL (GPU) σ=3.82, iters=10:   0%|                                                                 | 0/1 [00:00<?, ?it/s]
RL (GPU) iters:   0%|                                                                           | 0/10 [00:00<?, ?it/s][A
                                                                                                                       [A

Unnamed: 0,seconds
USM (GPU),0.221
LoG (GPU),0.226
Wiener (GPU),0.477
RL (GPU),1.014
CNN α=2,3.612
CNN (base),5.389


In [2]:
# ===== Micro-CT metrics with input-anchored noise =====
import torch, numpy as np, pandas as pd
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _to_t(x): return torch.from_numpy(x).to(device=device, dtype=torch.float32)

def _center_crop_np(x, maxD=128, maxH=256, maxW=256):
    D,H,W = x.shape
    d = min(D, maxD); h = min(H, maxH); w = min(W, maxW)
    zs = (D - d)//2; ys = (H - h)//2; xs = (W - w)//2
    return x[zs:zs+d, ys:ys+h, xs:xs+w].copy()

@torch.no_grad()
def _tenengrad_3d(x: torch.Tensor):
    k = torch.tensor([1, 2, 1], dtype=x.dtype, device=x.device)
    d = torch.tensor([1, 0,-1], dtype=x.dtype, device=x.device)
    kz = d.view(1,1,3,1,1); ky = d.view(1,1,1,3,1); kx = d.view(1,1,1,1,3)
    x4 = x.unsqueeze(0).unsqueeze(0)
    gz = F.conv3d(F.pad(x4,(0,0,0,0,1,1),'replicate'), kz)
    gy = F.conv3d(F.pad(x4,(0,0,1,1,0,0),'replicate'), ky)
    gx = F.conv3d(F.pad(x4,(1,1,0,0,0,0),'replicate'), kx)
    return (gx*gx + gy*gy + gz*gz).mean().item()

@torch.no_grad()
def _lap_var_3d(x: torch.Tensor):
    w = torch.zeros((1,1,3,3,3), device=x.device, dtype=x.dtype)
    w[0,0,1,1,1] = 6.0
    w[0,0,1,1,0] = w[0,0,1,1,2] = -1.0
    w[0,0,1,0,1] = w[0,0,1,2,1] = -1.0
    w[0,0,0,1,1] = w[0,0,2,1,1] = -1.0
    y = F.conv3d(F.pad(x.unsqueeze(0).unsqueeze(0),(1,1,1,1,1,1),'replicate'), w).squeeze()
    return y.var().item()

@torch.no_grad()
def _gauss_separable(x4, sigma):
    # x4: (1,1,D,H,W)
    import math
    s = max(1e-6, float(sigma))
    r = max(1, int(math.ceil(3*s)))
    z = torch.arange(-r, r+1, device=x4.device, dtype=x4.dtype)
    g = torch.exp(-(z*z)/(2*s*s)); g = g/g.sum()
    kz = g.view(1,1,-1,1,1); ky = g.view(1,1,1,-1,1); kx = g.view(1,1,1,1,-1)
    y = F.conv3d(F.pad(x4,(0,0,0,0,r,r),'reflect'), kz)
    y = F.conv3d(F.pad(y,(0,0,r,r,0,0),'reflect'), ky)
    y = F.conv3d(F.pad(y,(r,r,0,0,0,0),'reflect'), kx)
    return y

@torch.no_grad()
def _build_flat_mask_from_input(x_in: torch.Tensor, flat_pct=0.30, min_vox=32768):
    """Select low-texture voxels from the *input*; same mask reused for all methods."""
    t = x_in.unsqueeze(0).unsqueeze(0)
    # light pre-blur to suppress texture edges influencing the mask
    t_s = _gauss_separable(t, sigma=0.7)
    k = torch.tensor([1,0,-1], dtype=t.dtype, device=t.device).view(1,1,3,1,1)
    gx = F.conv3d(F.pad(t_s,(0,0,0,0,1,1),'replicate'), k).abs()
    gy = F.conv3d(F.pad(t_s,(0,0,1,1,0,0),'replicate'), k.transpose(2,3)).abs()
    gz = F.conv3d(F.pad(t_s,(1,1,0,0,0,0),'replicate'),
                  torch.tensor([1,0,-1], dtype=t.dtype, device=t.device).view(1,1,1,1,3)).abs()
    grad = (gx + gy + gz).squeeze()
    q = torch.quantile(grad, flat_pct, interpolation="nearest")
    m = (grad <= q)
    if m.sum().item() < min_vox:  # ensure enough samples
        ksel = min(min_vox, grad.numel())
        _, idx = torch.topk((-grad).flatten(), k=ksel)  # smallest gradients
        m = torch.zeros_like(grad, dtype=torch.bool).flatten()
        m[idx] = True
        m = m.view_as(grad)
    return m  # bool (D,H,W)

@torch.no_grad()
def _noise_mad_hp_masked(x: torch.Tensor, mask: torch.Tensor, sigma=1.0):
    """MAD of high-pass residual over a fixed mask."""
    x4 = x.unsqueeze(0).unsqueeze(0)
    low = _gauss_separable(x4, sigma=sigma).squeeze()
    hp = (x - low)
    r = hp[mask]
    if r.numel() == 0:
        return float('nan')
    med = r.median()
    mad = (r - med).abs().median() * 1.4826
    # avoid printing 0.0000 due to float underflow in display
    return float(mad.item() + 1e-12)

@torch.no_grad()
def _hf_energy_ratio(x: torch.Tensor, r0=0.6):
    D,H,W = x.shape
    X = torch.fft.fftn(x); P = (X.abs()**2)
    fz = torch.fft.fftfreq(D, d=1.0, device=x.device).view(D,1,1)
    fy = torch.fft.fftfreq(H, d=1.0, device=x.device).view(1,H,1)
    fx = torch.fft.fftfreq(W, d=1.0, device=x.device).view(1,1,W)
    fny = 0.5
    r = torch.sqrt((fz/fny)**2 + (fy/fny)**2 + (fx/fny)**2)
    mask = (r >= r0)
    return (P[mask].sum() / (P.sum() + 1e-12)).item()

def evaluate_methods_no_gt(
    outputs: dict,
    vol_input: np.ndarray,
    crop=(128,256,256),
    hp_sigma_noise=1.0,
    flat_pct=0.30,
    min_vox=32768,
    hf_r0=0.6,
):
    """
    outputs: dict name -> np.ndarray (D,H,W) in [0,1], should include 'Input'
    vol_input: original input volume (np.ndarray) used to anchor the flat mask
    crop: center crop size for metrics (reduces FFT memory)
    hp_sigma_noise: sigma (vox) for LP in high-pass residual for Noise_MAD
    flat_pct: fraction of lowest-gradient voxels (on input) to define 'flat' mask
    min_vox: minimum voxels guaranteed in the flat mask
    hf_r0: normalized radius threshold for HF energy ratio (0..1 of Nyquist)
    """
    # center crop input and move to device
    xin_c = _center_crop_np(vol_input, *crop)
    x_in  = _to_t(xin_c)

    # build a fixed 'flat' mask from INPUT (reused across methods)
    flat_mask = _build_flat_mask_from_input(x_in, flat_pct=flat_pct, min_vox=min_vox)

    # input's own noise for NRF
    noise_in = _noise_mad_hp_masked(x_in, flat_mask, sigma=hp_sigma_noise)

    rows = []
    hf_col = f"HF_ratio@r>{hf_r0}"
    for name, arr in outputs.items():
        arr_c = _center_crop_np(arr, *crop)
        y = _to_t(arr_c)

        row = {"method": name}
        row["Tenengrad"] = _tenengrad_3d(y)
        row["Var(Lap)"]  = _lap_var_3d(y)
        # HF energy ratio at radius hf_r0
        row[hf_col]      = _hf_energy_ratio(y, r0=hf_r0)

        # input-anchored noise + NRF
        noise = _noise_mad_hp_masked(y, flat_mask, sigma=hp_sigma_noise)
        row["Noise_MAD"] = noise
        row["NRF"]       = float(noise / (noise_in + 1e-12))  # <1 = denoised vs input

        rows.append(row)

    import pandas as pd
    df = pd.DataFrame(rows).set_index("method")

    # join runtimes (map common baseline keys; CNN variants usually already match)
    if "times" in globals() and isinstance(times, dict):
        alias = {
            "USM (GPU)": "USM",
            "LoG (GPU)": "LoG",
            "Wiener (GPU)": "Wiener",
            "RL (GPU)": "RL",
            "CNN (base)": "CNN (base)",
        }
        t = pd.Series({alias.get(k, k): v for k, v in times.items()}, name="seconds")
        df = df.join(t, how="left")

    # order columns
    cols = [c for c in ["Tenengrad","Var(Lap)", hf_col, "Noise_MAD","NRF","seconds"] if c in df.columns]
    df = df[cols]

    # sort: sharpness ↑ then NRF ↓ (if present)
    sort_keys, asc = [], []
    if "Tenengrad" in df.columns: sort_keys.append("Tenengrad"); asc.append(False)
    if "NRF" in df.columns:       sort_keys.append("NRF");       asc.append(True)
    if sort_keys:
        df = df.sort_values(by=sort_keys, ascending=asc)

    return df

# ==== Automatic CNR metric (no manual ROIs) ====

@torch.no_grad()
def _otsu_threshold_from_masked(x: torch.Tensor, mask: torch.Tensor, bins=256):
    """Otsu threshold on x[mask] in [0,1]. Returns float threshold (CPU scalar)."""
    vals = x[mask].clamp(0,1).detach().float().cpu().numpy()
    if vals.size < 1024:  # too few voxels -> fallback to mid
        return float(vals.mean())
    hist, edges = np.histogram(vals, bins=bins, range=(0.0,1.0))
    hist = hist.astype(np.float64); w = hist.sum()
    if w <= 0: return 0.5
    p = hist / w
    omega = np.cumsum(p)
    mu = np.cumsum(p * (edges[:-1] + edges[1:]) * 0.5)
    mu_t = mu[-1]
    sigma_b = (mu_t * omega - mu)**2 / (omega * (1.0 - omega) + 1e-12)
    k = np.nanargmax(sigma_b)
    # threshold at bin boundary
    return float((edges[k] + edges[k+1]) * 0.5)

@torch.no_grad()
def _robust_sigma(v: torch.Tensor):
    """Robust σ via MAD (Gaussian equiv). v is 1D tensor."""
    if v.numel() == 0: return torch.tensor(float('nan'), device=v.device)
    med = v.median()
    mad = (v - med).abs().median()
    return mad * 1.4826

@torch.no_grad()
def _auto_cnr_on_fixed_partition(out_y: torch.Tensor,
                                 x_in: torch.Tensor,
                                 flat_mask: torch.Tensor,
                                 thr: float,
                                 robust=True,
                                 min_class_vox=4096):
    """
    Compute CNR on two classes defined ON INPUT (x_in <= thr and > thr), restricted to flat_mask.
    Returns (cnr, mu0, mu1, sig0, sig1, n0, n1).
    """
    m0 = flat_mask & (x_in <= thr)
    m1 = flat_mask & (x_in >  thr)
    # ensure both classes have some voxels; fallback to quantile split on input if needed
    if m0.sum().item() < min_class_vox or m1.sum().item() < min_class_vox:
        q0, q1 = torch.quantile(x_in[flat_mask], torch.tensor([0.3, 0.7], device=x_in.device))
        m0 = flat_mask & (x_in <= q0)
        m1 = flat_mask & (x_in >= q1)

    v0 = out_y[m0]; v1 = out_y[m1]
    if robust:
        s0 = _robust_sigma(v0); s1 = _robust_sigma(v1)
    else:
        s0 = v0.std(unbiased=False); s1 = v1.std(unbiased=False)
    mu0 = v0.mean(); mu1 = v1.mean()
    cnr = (mu1 - mu0).abs() / torch.sqrt(s0*s0 + s1*s1 + 1e-12)
    return float(cnr.item()), float(mu0.item()), float(mu1.item()), float(s0.item()), float(s1.item()), int(v0.numel()), int(v1.numel())

def add_auto_cnr_columns(df: pd.DataFrame,
                         outputs: dict,
                         vol_input: np.ndarray,
                         vx_size: float,
                         crop=(128,256,256),
                         hp_sigma_noise=1.0,  # reuse your noise LP sigma
                         flat_pct=0.30,
                         min_vox=32768,
                         robust=True):
    """
    Appends 'aCNR' and 'aCNR/(2.4*vx)' columns using:
      - flat mask from input (low-gradient voxels)
      - Otsu threshold computed ON INPUT within that mask
      - same spatial voxels used for every method
    """
    # center crop input and move to device
    xin_c = _center_crop_np(vol_input, *crop)
    x_in  = _to_t(xin_c)

    # fixed flat mask from input
    flat_mask = _build_flat_mask_from_input(x_in, flat_pct=flat_pct, min_vox=min_vox)

    # threshold from input (masked Otsu)
    thr = _otsu_threshold_from_masked(x_in, flat_mask, bins=256)

    # compute aCNR per method
    acnr, acnr_norm = {}, {}
    for name, arr in outputs.items():
        arr_c = _center_crop_np(arr, *crop)
        y = _to_t(arr_c)
        cnr, *_ = _auto_cnr_on_fixed_partition(y, x_in, flat_mask, thr, robust=robust)
        acnr[name]      = cnr
        acnr_norm[name] = cnr / (2.4 * float(vx_size) + 1e-12)

    df["aCNR"] = pd.Series(acnr)
    df[f"aCNR/(2.4*vx)"] = pd.Series(acnr_norm)
    return df



In [3]:
# vx_size: voxel edge length in your chosen units (e.g., micrometers); if anisotropic, pass the in-plane or mean.
vx_size = 1.0  # <-- set (e.g., µm per voxel). If your spacing is (z,y,x), you can use np.mean(spacing).

outputs = {"Input": vol, "USM": usm, "LoG": logb, "Wiener": wien, "RL": rl}
for name, arr in cnn_results:  # e.g. ("CNN (base)", ...), ("CNN α=2", ...), ...
    outputs[name] = arr
    
df_metrics = evaluate_methods_no_gt(
    outputs, vol_input=vol,
    crop=(128,256,256),
    hp_sigma_noise=1.0,  # noise LP sigma (vox)
    flat_pct=0.30,       # lowest-gradient fraction for mask
    min_vox=32768,       # ensure stable MAD
    hf_r0=0.6            # HF radius threshold (fraction of Nyquist)
)

df_metrics = add_auto_cnr_columns(
    df_metrics, outputs, vol_input=vol,
    vx_size=vx_size,
    crop=(128,256,256),
    hp_sigma_noise=1.0,
    flat_pct=0.30, min_vox=32768,
    robust=True
)

# Pretty print (adds new cols if present)
fmt_extra = {"aCNR":"{:.3f}", "aCNR/(2.4*vx)":"{:.4f}"}
fmt_all = {
    "Tenengrad":"{:.4e}", "Var(Lap)":"{:.4e}",
    "HF_ratio@r>0.6":"{:.4f}", "Noise_MAD":"{:.5f}", "NRF":"{:.3f}",
    "seconds":"{:.3f}", **fmt_extra
}
display(df_metrics.style.format({k:v for k,v in fmt_all.items() if k in df_metrics.columns}))


Unnamed: 0_level_0,Tenengrad,Var(Lap),HF_ratio@r>0.6,Noise_MAD,NRF,seconds,aCNR,aCNR/(2.4*vx)
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
USM,0.035922,0.037292,0.0044,0.04058,2.971,0.221,1.98,0.8248
CNN α=2,0.019414,0.0099185,0.0039,0.01543,1.13,3.612,2.294,0.9558
Wiener,0.01482,0.0015777,0.0003,0.0094,0.688,0.477,1.98,0.8251
CNN (base),0.013935,0.0086867,0.004,0.01205,0.882,5.389,2.528,1.0532
RL,0.0097358,0.004388,0.0008,0.01423,1.042,1.014,2.316,0.965
Input,0.0056434,0.0044227,0.0008,0.01366,1.0,,2.431,1.013
LoG,0.005342,0.0044068,0.0008,0.0136,0.996,0.226,2.4,1.0002
