In [1]:
# --- DATA LOADERS (train/val) ---
import torch
from torch.utils.data import DataLoader, random_split
from deblur3d.data import TiffDataset
from deblur3d.transforms import GaussianIsoBlurCPUTransform

# ==== config ====
manifest_path   = r"C:\Users\taki\DeepDeBlur3D\manifest.parquet"
patch_size      = (64, 256, 256)   # (D,H,W)
batch_size      = 4
num_workers     = 0                # safe: blur runs on CPU
pin_memory      = True
seed            = 42
val_frac        = 0.15             # used if no split column
samples_per_ep  = None             # e.g. 2000 to cap per-epoch length
balance_mode    = "volume"    # "slice_count" or "volume"
# optional quick filtering / size limiting
filter_query    = "size_8bit_GB < 0.5"             # e.g. "n_slices >= 128 and size_8bit_GB < 6"
random_subset   = 0.5             # e.g. 0.2 (20%) or 500 (count)

# CPU-only isotropic Gaussian blur (augmentation)
blur_tf = GaussianIsoBlurCPUTransform(
    fwhm_range=(6, 12), radius_mult=3,
    add_noise=True, poisson_gain_range=(400, 900),
    read_noise_std_range=(0.004, 0.012),
)

g = torch.Generator().manual_seed(seed)


def safe_collate(batch):
    sharps, blurs = zip(*batch)  # list of tensors
    sharps = [s.contiguous().clone().float() for s in sharps]
    blurs  = [b.contiguous().clone().float() for b in blurs]
    return torch.stack(sharps, 0), torch.stack(blurs, 0)


def _make_loader(ds, shuffle, batch_size=batch_size):
    if ds is None: return None
    bs = min(batch_size, len(ds)) if len(ds) > 0 else 1
    return DataLoader(
        ds, batch_size=bs, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
        persistent_workers=False, collate_fn=safe_collate
    )


# Try split-based datasets first (expects 'split' column == 'train'/'val')
try:
    ds_train = TiffDataset(
        manifest_path,
        split="train",
        patch_size=patch_size,
        blur_transform=blur_tf,
        balance=balance_mode,
        samples_per_epoch=samples_per_ep,
        filter_query=filter_query,
        random_subset=random_subset,
        seed=seed,
    )
    ds_val = TiffDataset(
        manifest_path,
        split="val",
        patch_size=patch_size,
        blur_transform=GaussianIsoBlurCPUTransform(fwhm_range=(0, 0), add_noise=False),  # no blur on val
        balance="volume",
        samples_per_epoch=None,
        filter_query=filter_query,
        random_subset=random_subset,
        seed=seed,
    )
    use_split = True
except Exception:
    # Fallback: no split column → single dataset then random split by volumes
    ds_full = TiffDataset(
        manifest_path,
        split=None,
        patch_size=patch_size,
        blur_transform=blur_tf,
        balance=balance_mode,
        samples_per_epoch=None,           # split first, then optionally cap train below
        filter_query=filter_query,
        random_subset=random_subset,
        seed=seed,
    )
    n = len(ds_full)
    n_val = max(1, int(round(val_frac * n)))
    n_tr  = max(1, n - n_val)
    ds_train, ds_val = random_split(ds_full, [n_tr, n_val], generator=g)
    # optional: cap train length after split for faster epochs
    if samples_per_ep is not None and isinstance(ds_train, torch.utils.data.Subset):
        ds_train.dataset.samples_per_epoch = samples_per_ep
    use_split = False

loader_train = _make_loader(ds_train, shuffle=True)
loader_val   = _make_loader(ds_val,   shuffle=False)

print(f"Train vols: {len(ds_train) if use_split else len(ds_train.dataset) if isinstance(ds_train, torch.utils.data.Subset) else len(ds_train)}")
print(f"Val vols:   {len(ds_val) if use_split else len(ds_val.dataset) if isinstance(ds_val, torch.utils.data.Subset) else len(ds_val)}")
print(f"batch_size={loader_train.batch_size}, workers={num_workers}, pin_memory={pin_memory}")

Train vols: 369
Val vols:   66
batch_size=4, workers=0, pin_memory=True


In [2]:
# ==== BOOTSTRAP: model + loss + training helpers ====
import os, time, platform
import torch, torch.nn.functional as F
from torch.cuda.amp import GradScaler
import mlflow

from deblur3d.models import UNet3D_Residual
from deblur3d.losses import DeblurLoss

# --- device / amp ---
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = torch.cuda.is_available()
scaler  = GradScaler(enabled=use_amp)

# --- model / opt / sched / loss ---
net = UNet3D_Residual(in_ch=1, base=16, levels=8).to(device)  # <- keep in sync with model_params
opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.99))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50)
criterion = DeblurLoss(w_l1=0.8, w_ssim=0.2, w_freq=0.05, id_weight=0.1)

# --- tiny helpers used in your loop ---
def _freeze_requirements(path="env/requirements.txt"):
    """Writes `pip freeze` to file; returns path or None if fails."""
    try:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        import subprocess, sys
        req = subprocess.check_output([sys.executable, "-m", "pip", "freeze"], text=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(req)
        return path
    except Exception:
        return None

def _preview_triplet(b, p, s, save_path="preview_epoch.png"):
    """Save a 3-panel mid-slice preview: input (blurred), pred, sharp."""
    import matplotlib.pyplot as plt
    with torch.no_grad():
        b = b.detach().float(); p = p.detach().float(); s = s.detach().float()
        # (B,1,D,H,W) -> pick first sample, central Z
        b0, p0, s0 = b[0,0], p[0,0], s[0,0]
        z = b0.shape[0] // 2
        fig = plt.figure(figsize=(10, 3.2), dpi=120)
        ax = fig.add_subplot(1,3,1); ax.imshow(b0[z].cpu(), cmap="gray"); ax.set_title("Input"); ax.axis("off")
        ax = fig.add_subplot(1,3,2); ax.imshow(p0[z].cpu(), cmap="gray"); ax.set_title("Pred");  ax.axis("off")
        ax = fig.add_subplot(1,3,3); ax.imshow(s0[z].cpu(), cmap="gray"); ax.set_title("Sharp"); ax.axis("off")
        plt.tight_layout()
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
    return save_path

# --- MLflow tracking URI default (local folder) if not set elsewhere ---
if 'TRACKING_URI' in globals() and TRACKING_URI:
    mlflow.set_tracking_uri(TRACKING_URI)
else:
    mlruns_path = os.path.abspath("mlruns").replace("\\", "/")
    mlflow.set_tracking_uri("file:///" + mlruns_path)


In [3]:
# ---------- MLflow setup ----------
EXPERIMENT_NAME = "deblur3d_microCT"
RUN_NAME        = "unet3d_residual_base24L4"   # change per run
TRACKING_URI    =   "http://127.0.0.1:5000/"

if TRACKING_URI:
    mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
# ---------- Log static params ----------


model_params = {
    "model": type(net).__name__,
    "in_ch": 1, "base": 16, "levels": 8,   # ← match the net above
    "optimizer": "AdamW", "lr": 1e-3, "weight_decay": 1e-4, "betas": (0.9, 0.99),
    "scheduler": "CosineAnnealingLR", "epochs": 20,
    "amp": bool(use_amp), "device": str(device),
}
data_params = {
    "patch_size": tuple(loader_train.dataset.patch_size) if hasattr(loader_train.dataset, "patch_size") else None,
    "batch_size": loader_train.batch_size,
    "num_workers": loader_train.num_workers,
    "train_volumes": len(getattr(loader_train.dataset, "vols", getattr(loader_train.dataset, "paths", []))),
    "val_exists": loader_val is not None,
}
loss_params = {
    "loss": type(criterion).__name__,
    **{k:getattr(criterion,k) for k in ("w_l1","w_ssim","w_freq","idw") if hasattr(criterion,k)},
}

# ---------- Training loop with MLflow ----------
with mlflow.start_run(run_name=RUN_NAME) as run:
    run_id = run.info.run_id
    print("MLflow run_id:", run_id)

    # log params once
    mlflow.log_params(model_params)
    mlflow.log_params(data_params)
    mlflow.log_params(loss_params)

    # environment snapshot
    env = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }
    mlflow.log_dict(env, "env.json")
    req_path = _freeze_requirements()
    if req_path:
        mlflow.log_artifact(req_path, artifact_path="env")

    best_psnr = -1.0
    best_path = "deblur3d_unet.pt"

    def to_ch(x): return x.unsqueeze(1)  # (B,1,D,H,W)

    for epoch in range(1, 51):
        net.train()
        t0 = time.time()
        tr_loss = 0.0
        nvox = 0

        for sharp, blurred in loader_train:
            sharp   = to_ch(sharp).to(device, non_blocking=True)
            blurred = to_ch(blurred).to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = net(blurred)
                loss = criterion(pred, sharp, blurred)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)

            bs = sharp.size(0)
            tr_loss += loss.item() * bs
            nvox    += bs

        sched.step()
        epoch_time = time.time() - t0
        train_loss_epoch = tr_loss / max(nvox, 1)

        # validation (PSNR)
        net.eval(); psnr_sum, nvox = 0.0, 0
        with torch.no_grad():
            for sharp, blurred in (loader_val or []):  # handle None
                sharp   = to_ch(sharp).to(device, non_blocking=True)
                blurred = to_ch(blurred).to(device, non_blocking=True)
                pred = net(blurred)
                mse  = F.mse_loss(pred, sharp, reduction='none').mean(dim=(1,2,3,4))
                psnr = 10 * torch.log10(1.0 / (mse + 1e-12))
                psnr_sum += psnr.sum().item()
                nvox += sharp.size(0)
        psnr_epoch = (psnr_sum / nvox) if nvox > 0 else float("nan")

        # --- log metrics ---
        mlflow.log_metric("train/loss", train_loss_epoch, step=epoch)
        mlflow.log_metric("time/epoch_sec", epoch_time, step=epoch)
        if not (psnr_epoch != psnr_epoch):  # NaN check
            mlflow.log_metric("val/PSNR_dB", psnr_epoch, step=epoch)

        print(f"Epoch {epoch:03d} | train {train_loss_epoch:.4f} | PSNR {psnr_epoch:.2f} dB | {epoch_time:.1f}s")

        # --- checkpoint & preview on improvement ---
        improved = (not (psnr_epoch != psnr_epoch)) and (psnr_epoch > best_psnr)
        if improved:
            best_psnr = psnr_epoch
            torch.save({"epoch": epoch, "state_dict": net.state_dict()}, best_path)
            mlflow.log_artifact(best_path, artifact_path="checkpoints")

            # small preview artifact
            try:
                # take a tiny batch from val or train
                sample_batch = next(iter(loader_val if loader_val is not None else loader_train))
                s = to_ch(sample_batch[0]).to(device)
                b = to_ch(sample_batch[1]).to(device)
                with torch.no_grad():
                    p = net(b)
                fig_path = _preview_triplet(b, p, s, save_path="preview_epoch.png")
                mlflow.log_artifact(fig_path, artifact_path="figures")
            except Exception as e:
                print("preview logging failed:", e)

    # (optional) log the final model weights artifact one more time
    if os.path.exists(best_path):
        mlflow.log_artifact(best_path, artifact_path="checkpoints_final")


MLflow run_id: 19593107f6b24f5c8cc15d5bcaa6d952
🏃 View run unet3d_residual_base24L4 at: http://127.0.0.1:5000/#/experiments/139863174919178366/runs/19593107f6b24f5c8cc15d5bcaa6d952
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/139863174919178366


RuntimeError: Calculated padded input size per channel: (1 x 4 x 4). Kernel size: (2 x 2 x 2). Kernel size can't be greater than actual input size