In [3]:
# --- CONFIG ---
# Input volume file and trained checkpoint.
vol_path = r"U:\users\taki\vizualization\SlicesY-B17.123120 reb unten.tif"  # <-- set this
ckpt_path = "deblur3d_unet.pt"  # model checkpoint

# UNet width / depth: must be the same values that were used for training.
base = 24
levels = 4

# Tiled inference geometry, presumably (z, y, x) like the (D, H, W) volume.
# Shrink `tile` if GPU memory runs out; `overlap` between tiles smooths seams.
tile = (64, 256, 256)
overlap = (32, 128, 128)

# Voxel size (z, y, x); used as the napari layer scale.
spacing = (1.0, 1.0, 1.0)

import os, torch, torch.nn as nn, torch.nn.functional as F
from deblur3d.models import UNet3D_Residual, ControlledUNet3D
from deblur3d.infer import deblur_volume_tiled
from deblur3d.data import read_volume_float01

# --- Fail fast: verify the checkpoint before the (large) volume read ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not os.path.exists(ckpt_path):
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

# --- Load volume via your data module ---
vol = read_volume_float01(vol_path)  # returns (D,H,W) float32 in [0,1]
print("Input volume:", vol.shape, vol.dtype, f"min/max {vol.min():.3f}/{vol.max():.3f}")

# --- Load trained base net and wrap with controller ---
net = UNet3D_Residual(in_ch=1, base=base, levels=levels).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
obj = torch.load(ckpt_path, map_location=device)
# Accept either a bare state_dict or a dict that stores it under "state_dict".
net.load_state_dict(obj.get("state_dict", obj))
net.eval()
ctrl = ControlledUNet3D(net).to(device).eval()

# --- Define three control presets ---
# Each preset is (napari layer name, kwargs forwarded to ControlledUNet3D).
# α = 1.0 is the baseline. NOTE(review): confirm in ControlledUNet3D how
# `strength` scales the correction before interpreting 2.0 / 5.0.
presets = [
    (f"Deblurred α={strength:.1f}", dict(strength=strength))
    for strength in (1.0, 2.0, 5.0)
]

# Small wrapper so deblur_volume_tiled can call it like a net
class NetWithControl(nn.Module):
    """Bind fixed control kwargs to a ControlledUNet3D behind a one-argument forward.

    The tiled inference helper calls its model like a plain net (``model(x)``);
    this adapter supplies one preset's control keywords on every call.

    Args:
        ctrl: The controller to wrap. It is assigned as a module attribute, so
            ``.eval()`` / ``.to()`` on the wrapper also reach it.
        **ctrl_kwargs: Keywords passed unchanged to ``ctrl`` on each forward
            (in this notebook, ``strength=...`` from the presets).
    """

    def __init__(self, ctrl: ControlledUNet3D, **ctrl_kwargs):
        super().__init__()
        self.ctrl, self.ctrl_kwargs = ctrl, ctrl_kwargs

    @torch.no_grad()
    def forward(self, x):
        """Return ``self.ctrl(x, **self.ctrl_kwargs)``, computed without autograd."""
        controller, frozen_kwargs = self.ctrl, self.ctrl_kwargs
        return controller(x, **frozen_kwargs)

# --- Run tiled inference for each preset ---
# `preds` collects (layer name, deblurred volume) in preset order.
preds = []
for name, kw in presets:
    # Freeze this preset's control kwargs so the tiler sees a single-input module.
    net_ctrl = NetWithControl(ctrl, **kw)
    net_ctrl.eval()
    out = deblur_volume_tiled(
        net_ctrl,
        vol,
        tile=tile,
        overlap=overlap,
        device=device.type,
        use_amp=False,
    )
    preds += [(name, out)]

# --- Visualize in napari ---
import napari

v = napari.Viewer(ndisplay=2)
# Raw input is added first, so it sits at the bottom of the layer list.
L_in = v.add_image(vol, name="Input", colormap="gray", scale=spacing)
for name, arr in preds:
    layer = v.add_image(
        arr,
        name=name,
        colormap="gray",
        scale=spacing,
        opacity=0.8,
    )
    # Reuse the input's display window so all presets are shown on the same
    # intensity scale (values outside it saturate in the display).
    layer.contrast_limits = L_in.contrast_limits
napari.run()

Input volume: (183, 979, 1546) float32 min/max 0.000/1.000


