In [1]:
# --- core ---
import os, time, math
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm


# --- your project modules ---
from deblur3d.losses import DeblurLoss, ssim3d
from deblur3d.models import UNet3D_Residual
from deblur3d.data import MultiPageTiffDataset
from deblur3d.transforms import GaussianIsoBlurCPUTransform


# --- mlflow ---
import mlflow
# MLflow server URI: the MLFLOW_TRACKING_URI env var overrides the local default.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000"))
mlflow.set_experiment("deblur3d_unet")            # experiment name

# --------- config ---------
# Manifest (.xlsx or .csv) listing the volumes, with a 'split' column (train / val).
# Set DEBLUR3D_INDEX to point elsewhere instead of editing this machine-specific default.
index_path      = os.environ.get("DEBLUR3D_INDEX", r"T:\users\taki\Dataset_L\index_with_split.xlsx")
patch_size      = (64, 256, 256)   # presumably (D, H, W) — TODO confirm against the dataset
batch_size      = 8
num_workers     = 4
seed            = 42
epochs          = 10
lr              = 1e-3
weight_decay    = 1e-4
betas           = (0.9, 0.99)
amp_enabled     = torch.cuda.is_available()
save_dir        = Path("./checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)
run_name        = f"unet3d_residual_ps{patch_size}_bs{batch_size}"

# Seed torch's global RNGs (CPU and all CUDA devices) so model weight init and
# DataLoader shuffling are reproducible; the datasets receive `seed` explicitly.
torch.manual_seed(seed)

# Blur transform (CPU)
blur_tf = GaussianIsoBlurCPUTransform(
    fwhm_range=(4, 6),
    radius_mult=3,
    add_noise=True,
    poisson_gain_range=(200, 600),
    read_noise_std_range=(0.004, 0.008),
)

# Loss
criterion = DeblurLoss(
    w_l1=0.7, w_ssim=0.1, w_freq=0.1, id_weight=0.1,
    use_relative_freq=True, freq_alpha=1.0
)

# GPUs used by DataParallel. The model's master copy lives on device_ids[0], and
# DataParallel gathers outputs there by default, so inputs are sent to that device.
device_ids = [0, 1]
assert torch.cuda.device_count() >= len(device_ids), f"Need at least {len(device_ids)} GPUs for DataParallel"
main_device = torch.device(f"cuda:{device_ids[0]}")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Train / Val datasets built from the manifest's 'split' column.
# Both splits use the same blur transform, so validation scores deblurring of degraded inputs.
# NOTE(review): if blur_tf draws a random FWHM / noise level per call, val metrics also vary
# with that draw unless the dataset seeds per index — confirm before comparing epochs closely.
def _make_split_dataset(split, split_seed):
    """Build the MultiPageTiffDataset for one manifest split with the shared patch/blur settings."""
    return MultiPageTiffDataset(
        manifest_path=index_path,
        split=split,
        patch_size=patch_size,
        blur_transform=blur_tf,
        balance=None,              # or "slice_count"
        samples_per_epoch=None,    # int caps per-epoch samples; None -> len(vols)
        seed=split_seed,
    )


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader whose batch size never exceeds the dataset length."""
    return DataLoader(
        dataset,
        batch_size=min(batch_size, len(dataset)),
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,   # keep workers alive between epochs when there are any
    )


ds_train = _make_split_dataset("train", seed)
ds_val   = _make_split_dataset("val", seed + 1)

loader_train = _make_loader(ds_train, shuffle=True)
loader_val   = _make_loader(ds_val, shuffle=False)

print(f"Train vols: {len(ds_train)} | Val vols: {len(ds_val)}")

Train vols: 418 | Val vols: 49


In [3]:
# DataParallel requires the wrapped module's parameters to already sit on device_ids[0],
# so the single-device model is moved to main_device before wrapping.
net_single = UNet3D_Residual(in_ch=1, base=16, levels=4).to(main_device)
net = torch.nn.DataParallel(net_single, device_ids=device_ids)

# AdamW with a cosine decay of the LR over T_max = `epochs` scheduler steps.
opt = torch.optim.AdamW(
    net.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=weight_decay,
)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# Mixed-precision loss scaler; passes calls straight through when amp_enabled is False.
scaler = GradScaler(enabled=amp_enabled)

def to_ch(x):
    """Insert a singleton channel axis at dim 1.

    Loader batches are presumably (B, D, H, W); the result is then (B, 1, D, H, W),
    the layout the 1-channel 3D network consumes.
    """
    return x.unsqueeze(dim=1)

In [4]:
best_psnr = float("-inf")   # any finite validation PSNR beats this on the first epoch
best_path = save_dir / "deblur3d_unet_best.pt"
last_path = save_dir / "deblur3d_unet_last.pt"


def _train_one_epoch(epoch):
    """Run one mixed-precision training pass over ``loader_train``.

    Args:
        epoch: 1-based epoch number (used only for the progress-bar label).

    Returns:
        Sample-weighted mean training loss over the epoch.
    """
    net.train()
    loss_sum, seen = 0.0, 0

    pbar = tqdm(loader_train, desc=f"Epoch {epoch:03d} [train]", leave=False)
    for sharp, blurred in pbar:
        sharp   = to_ch(sharp).to(main_device, non_blocking=True)
        blurred = to_ch(blurred).to(main_device, non_blocking=True)

        # Clear gradients before the forward pass so stale grads left by an interrupted
        # earlier run of this cell cannot leak into the first update.
        opt.zero_grad(set_to_none=True)
        with autocast(enabled=amp_enabled):
            pred = net(blurred)
            loss = criterion(pred, sharp, blurred)

        scaler.scale(loss).backward()
        scaler.unscale_(opt)  # unscale first so the clip threshold applies to the true gradient norm
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        scaler.step(opt)      # GradScaler skips the step if the grads contain inf/NaN
        scaler.update()

        bsz = sharp.size(0)
        loss_sum += loss.item() * bsz
        seen += bsz

        # live progress
        pbar.set_postfix(
            loss=f"{loss_sum / max(seen, 1):.4f}",
            lr=f"{sched.get_last_lr()[0]:.2e}",
        )

    return loss_sum / max(seen, 1)


@torch.no_grad()
def _validate(epoch):
    """Evaluate ``net`` on ``loader_val``.

    Args:
        epoch: 1-based epoch number (used only for the progress-bar label).

    Returns:
        ``(mean_psnr_db, mean_ssim)``, both averaged per sample. PSNR uses a peak of 1.0,
        i.e. it assumes intensities are normalised to [0, 1] — TODO confirm against the dataset.
    """
    net.eval()
    psnr_sum, ssim_sum, seen = 0.0, 0.0, 0

    pbar = tqdm(loader_val, desc=f"Epoch {epoch:03d} [val]  ", leave=False)
    for sharp, blurred in pbar:
        sharp   = to_ch(sharp).to(main_device, non_blocking=True)
        blurred = to_ch(blurred).to(main_device, non_blocking=True)

        # Only the forward pass runs under autocast. Metrics are computed in fp32 so the
        # reported numbers do not depend on how ssim3d's internals behave in reduced precision.
        with autocast(enabled=amp_enabled):
            pred = net(blurred)
        pred, target = pred.float(), sharp.float()

        mse  = F.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3, 4))
        psnr = 10 * torch.log10(1.0 / (mse + 1e-12))
        ssim_vals = ssim3d(pred, target).detach().float().reshape(-1)

        bsz = target.size(0)
        psnr_sum += psnr.sum().item()
        # ssim3d may return one value per sample or a single batch-mean scalar. Summing
        # a scalar and dividing by the sample count would under-report SSIM by ~batch size,
        # so weight the scalar form by the batch size instead.
        if ssim_vals.numel() == bsz:
            ssim_sum += ssim_vals.sum().item()
        else:
            ssim_sum += ssim_vals.mean().item() * bsz
        seen += bsz

        # live progress
        pbar.set_postfix(psnr=f"{psnr.mean().item():.2f}", ssim=f"{ssim_vals.mean().item():.3f}")

    return psnr_sum / max(seen, 1), ssim_sum / max(seen, 1)


def _checkpoint(epoch, best_psnr_so_far):
    """Full training state for resuming; keeps the original 'epoch'/'state_dict' keys."""
    return {
        "epoch": epoch,
        "state_dict": getattr(net, "module", net).state_dict(),   # unwrap DataParallel
        "optimizer": opt.state_dict(),
        "scheduler": sched.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val_psnr": best_psnr_so_far,
    }


with mlflow.start_run(run_name=run_name):
    mlflow.log_params({
        "model": "UNet3D_Residual",
        "patch_size": str(patch_size),
        "batch_size": batch_size,
        "num_workers": num_workers,
        "optimizer": "AdamW",
        "lr": lr,
        "weight_decay": weight_decay,
        "betas": str(betas),
        "epochs": epochs,
        "amp": amp_enabled,
        "index_path": index_path,
        "seed": seed,
    })

    for epoch in range(1, epochs + 1):
        t0 = time.perf_counter()
        # Capture the LR actually used this epoch *before* the scheduler advances it;
        # reading it after sched.step() would log the next epoch's LR.
        epoch_lr = sched.get_last_lr()[0]

        train_loss = _train_one_epoch(epoch)
        sched.step()

        # ---- validation ----
        val_psnr, val_ssim = _validate(epoch)
        epoch_time = time.perf_counter() - t0

        # console summary line
        print(f"Epoch {epoch:03d} | train_loss {train_loss:.4f} | "
              f"val_psnr {val_psnr:.2f} dB | val_ssim {val_ssim:.4f} | "
              f"time {epoch_time:.1f}s")

        # MLflow
        mlflow.log_metrics({
            "train_loss": train_loss,
            "val_psnr": val_psnr,
            "val_ssim": val_ssim,
            "epoch_time_s": epoch_time,
            "lr": epoch_lr,
        }, step=epoch)

        # save best
        is_best = val_psnr > best_psnr
        if is_best:
            best_psnr = val_psnr
            torch.save(_checkpoint(epoch, best_psnr), best_path)
            mlflow.log_artifact(str(best_path))
            mlflow.log_metric("best_val_psnr", best_psnr, step=epoch)
            print(f"  ↳ saved best: {best_path} (PSNR {best_psnr:.2f} dB)")

        # Always keep the latest state so an interrupted run can be resumed.
        torch.save(_checkpoint(epoch, best_psnr), last_path)

    if epochs > 0:
        mlflow.log_artifact(str(last_path))


                                                                                                                       

Epoch 001 | train_loss 0.1277 | val_psnr 24.23 dB | val_ssim 0.0759 | time 362.8s
  ↳ saved best: checkpoints\deblur3d_unet_best.pt (PSNR 24.23 dB)


                                                                                                                       

Epoch 002 | train_loss 0.1128 | val_psnr 25.53 dB | val_ssim 0.0696 | time 347.5s
  ↳ saved best: checkpoints\deblur3d_unet_best.pt (PSNR 25.53 dB)


                                                                                                                       

Epoch 003 | train_loss 0.1066 | val_psnr 25.46 dB | val_ssim 0.0843 | time 347.9s


                                                                                                                       

Epoch 004 | train_loss 0.1074 | val_psnr 24.75 dB | val_ssim 0.0867 | time 362.9s


                                                                                                                       

Epoch 005 | train_loss 0.1027 | val_psnr 26.31 dB | val_ssim 0.0895 | time 366.1s
  ↳ saved best: checkpoints\deblur3d_unet_best.pt (PSNR 26.31 dB)


                                                                                                                       

Epoch 006 | train_loss 0.1014 | val_psnr 24.97 dB | val_ssim 0.0932 | time 389.3s


                                                                                                                       

Epoch 007 | train_loss 0.0994 | val_psnr 25.46 dB | val_ssim 0.0914 | time 366.0s


                                                                                                                       

Epoch 008 | train_loss 0.1085 | val_psnr 25.50 dB | val_ssim 0.0953 | time 346.9s


                                                                                                                       

Epoch 009 | train_loss 0.1014 | val_psnr 25.35 dB | val_ssim 0.0954 | time 352.0s


                                                                                                                       

Epoch 010 | train_loss 0.1029 | val_psnr 25.67 dB | val_ssim 0.0909 | time 348.1s
🏃 View run unet3d_residual_ps(64, 256, 256)_bs8 at: http://127.0.0.1:5000/#/experiments/168148162419857529/runs/859753a69c4c4d7e831074c288449f11
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/168148162419857529


