In [1]:
#!/usr/bin/env python
# scripts/compare_infer.py
import os, time, numpy as np, torch, torch.nn as nn

from deblur3d.data.io   import read_volume_float01
from deblur3d.models    import UNet3D_Residual, ControlledUNet3D
from deblur3d.infer     import deblur_volume_tiled
from deblur3d.infer.baselines3d import run_baselines  # selectable baselines

# ---------------- CONFIG ----------------
# Everything a reader is expected to tune lives in this section.
# NOTE(review): vol_path is an absolute, machine-specific Windows path and
# ckpt_path is relative to the current working directory — adjust both before
# running on another machine.
vol_path   = r"U:\users\taki\vizualization\CaSO4.tif"
ckpt_path  = r"checkpoints\deblur3d_unet_best.pt"
# Constructor arguments for UNet3D_Residual; presumably they must match the
# architecture the checkpoint was trained with — TODO confirm.
base, levels = 16, 4
# Tiling parameters forwarded to deblur_volume_tiled.
tile     = (64, 256, 256)
overlap  = (32, 128, 128)
# Per-axis scale handed to napari when the volumes are displayed.
spacing  = (1.0, 1.0, 1.0)

# Baselines to run: subset of {"USM","LoG","Wiener","RL"}
baselines_to_run = ["USM"]
# Parameters forwarded to run_baselines; all of them are passed on every call,
# whichever baselines are selected above.
FWHM_vox   = 7.0
USM_amount = 1.0
LoG_lambda = 2.0
Wiener_K   = 0.015
RL_iters   = 10

# Controlled CNN presets (optional)
# Each entry is (display label, keyword arguments); the keyword arguments are
# forwarded unchanged to the ControlledUNet3D call for that variant.
cnn_control_presets = [
    ("CNN α=2", dict(strength=2.0)),
    ("CNN gentle", dict(strength=0.7, hp_sigma=1.5, hp_gain=0.8, lp_gain=0.2)),
    ("CNN hp", dict(strength=1, hp_sigma=1.5, hp_gain=0.8, lp_gain=0.2)),
    ("CNN hp_gain", dict(strength=1, hp_sigma=1.5, hp_gain=2.0, lp_gain=0.2)),
]

use_napari = True  # False to skip viz
# ----------------------------------------

# This script is GPU-only: pick the CUDA device and abort early with a clear
# message when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device.type == "cuda", "CUDA not available."

def run_timed_cuda(name, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and measure its wall-clock duration.

    When CUDA is available the device is synchronized immediately before the
    clock starts and again before it stops, so that work still queued on the
    GPU is included in the measurement. On hosts without CUDA the
    synchronization is skipped and the call is timed as-is instead of raising.

    Args:
        name: Label printed next to the measured time.
        fn: Callable to execute.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Tuple ``(out, dt)``: the return value of ``fn`` and the elapsed time
        in seconds (float).
    """
    # torch.cuda.synchronize() raises on CPU-only hosts, so only sync when a
    # CUDA runtime is actually usable.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()  # drain earlier GPU work before starting the clock
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()  # wait for kernels launched by fn to finish
    dt = time.perf_counter() - t0
    print(f"[{name}] {dt:.3f}s")
    return out, dt

# ---- Load input volume (float32 [0,1]) ----
vol = read_volume_float01(vol_path)
# Sanity print: shape, dtype and value range of the loaded volume.
print("Input:", vol.shape, vol.dtype, f"min/max {vol.min():.3f}/{vol.max():.3f}")

# ---- Load network + controller ----
# Fail early, naming the offending path, if the checkpoint file is missing.
assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
net = UNet3D_Residual(in_ch=1, base=base, levels=levels).to(device).eval()
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state = torch.load(ckpt_path, map_location=device)
# Accept either a wrapper dict holding a "state_dict" entry or a bare state dict.
net.load_state_dict(state.get("state_dict", state))
# Wrap the restored network in the controller; the controlled variants later
# call it with extra keyword arguments.
ctrl = ControlledUNet3D(net).to(device).eval()

class NetWithControl(nn.Module):
    """Adapter that binds a fixed set of keyword arguments to a controller.

    The wrapped controller is called as ``ctrl(x, **kwargs)``; this class
    stores those keyword arguments so the resulting module can be called with
    the input tensor alone.
    """

    def __init__(self, ctrl: ControlledUNet3D, **ctrl_kwargs):
        """Store the controller module and the keyword arguments to replay.

        Args:
            ctrl: Controller module invoked on every forward pass.
            **ctrl_kwargs: Keyword arguments passed to ``ctrl`` on each call.
        """
        super().__init__()
        self.ctrl = ctrl
        self.kw = ctrl_kwargs

    def forward(self, x):
        """Return ``ctrl(x, **kw)`` evaluated with gradient tracking disabled."""
        with torch.no_grad():
            return self.ctrl(x, **self.kw)

# Accumulators shared by all methods: display label -> seconds / output volume.
times = {}
results = {}

# ---- CNN (base) ----
# NOTE(review): this is the first GPU inference call in the script, so its time
# may include one-time CUDA start-up cost (the recorded run shows 10.0 s here
# versus ~6.4 s for the controlled variants) — TODO confirm before comparing.
cnn_base, t_cnn_base = run_timed_cuda(
    "CNN (base) tiled",
    deblur_volume_tiled, net, vol,
    tile=tile, overlap=overlap, device=device.type
)
times["CNN (base)"] = t_cnn_base
results["CNN (base)"] = cnn_base

# ---- CNN (controlled variants) ----
# One tiled inference per preset; each preset's keyword arguments are bound to
# the shared controller through a fresh NetWithControl wrapper.
for name, kwargs in cnn_control_presets:
    net_ctrl = NetWithControl(ctrl, **kwargs).to(device).eval()
    out, dt = run_timed_cuda(
        f"{name} tiled",
        deblur_volume_tiled, net_ctrl, vol,
        tile=tile, overlap=overlap, device=device.type
    )
    times[name] = dt
    results[name] = out

# ---- Selected baselines (GPU) ----
# run_baselines returns its own timings; they are not measured with
# run_timed_cuda like the CNN runs above.
# NOTE(review): the torch.device object is passed here, whereas the CNN calls
# pass the string device.type — confirm both call conventions are intended.
res_b, times_b = run_baselines(
    vol,
    run=baselines_to_run,
    fwhm_vox=FWHM_vox,
    usm_amount=USM_amount,
    log_lambda=LoG_lambda,
    wiener_K=Wiener_K,
    rl_iters=RL_iters,
    device=device,
)
# Suffix baseline labels with " (GPU)" before merging them into the shared dicts.
results.update({f"{k} (GPU)": v for k, v in res_b.items()})
times.update({f"{k} (GPU)": t for k, t in times_b.items()})

# ---- Timing summary ----
# Preferred path: a pandas table sorted from fastest to slowest method.
# If anything in that path raises, fall back to printing the plain dict.
try:
    import pandas as pd
    timing_table = pd.DataFrame.from_dict(times, orient="index", columns=["seconds"])
    df = timing_table.sort_values("seconds")
    print("\nTiming (s):")
    fmt_seconds = "{:.3f}".format
    with pd.option_context("display.max_rows", None):
        print(df.to_string(formatters={"seconds": fmt_seconds}))
except Exception:
    fallback = {}
    for label, secs in times.items():
        fallback[label] = f"{secs:.3f}"
    print("Times (s):", fallback)

# ---- Optional visualization ----
# Show the input and every result as image layers in one napari viewer.
if use_napari:
    # Imported here so napari is only required when visualization is enabled.
    import napari
    v = napari.Viewer(ndisplay=2)
    L_in = v.add_image(vol, name="Input", colormap="gray", scale=spacing)
    for name, arr in results.items():
        L = v.add_image(arr, name=name, colormap="gray", scale=spacing, opacity=0.85)
        # Reuse the input layer's contrast limits so every layer is displayed
        # on the same intensity scale.
        L.contrast_limits = L_in.contrast_limits
    # Hand control to napari's event loop.
    napari.run()


Input: (401, 400, 400) float32 min/max 0.000/1.000
[CNN (base) tiled] 10.009s
[CNN α=2 tiled] 6.305s
[CNN gentle tiled] 6.407s
[CNN hp tiled] 6.445s
[CNN hp_gain tiled] 6.429s

Timing (s):
            seconds
USM (GPU)     0.389
CNN α=2       6.305
CNN gentle    6.407
CNN hp_gain   6.429
CNN hp        6.445
CNN (base)   10.009




In [2]:
from deblur3d.metrics.no_ref3d import evaluate_methods_no_gt, add_auto_cnr_columns

# Build the name -> volume mapping from what the inference cell produced:
# `vol` is the loaded input volume and `results` maps each method label to its
# output. Without this the cell fails with NameError on a fresh
# Restart & Run All, because `outputs` was never defined.
# Presumably the metric helpers expect np.float32 (D,H,W) volumes in [0,1] —
# TODO confirm that every entry of `results` satisfies this.
outputs = {"Input": vol, **results}

# Same crop for both metric passes so their columns describe the same region.
metrics_crop = (128, 256, 256)

df = evaluate_methods_no_gt(
    outputs=outputs,
    vol_input=outputs["Input"],
    crop=metrics_crop,
    hp_sigma_noise=1.0,
    flat_pct=0.30,
    min_vox=32768,
    hf_r0=0.6,
    # Per-method runtimes collected by the inference cell (optional argument).
    # NOTE(review): `times` has no "Input" entry; presumably missing labels are
    # tolerated — confirm against evaluate_methods_no_gt.
    times=times,
)
df = add_auto_cnr_columns(
    df,
    outputs=outputs,
    vol_input=outputs["Input"],
    vx_size=1.0,                # your voxel size (e.g., mm or arbitrary units)
    crop=metrics_crop,
)
display(df)
