In [None]:
# ---------- MLflow setup ----------
# Experiment that groups all runs of this project.
EXPERIMENT_NAME = "deblur3d_microCT"
# Human-readable label of this particular run; change per run.
RUN_NAME = "unet3d_residual_base24L4"
# Tracking server, e.g. "http://your-mlflow:5000"; None keeps the local
# ./mlruns store.
TRACKING_URI = None

# The tracking URI is only overridden when a non-empty value was supplied.
if TRACKING_URI:
    mlflow.set_tracking_uri(TRACKING_URI)

mlflow.set_experiment(EXPERIMENT_NAME)
# ---------- Log static params ----------
# Three flat dictionaries describing the run configuration.

# Model / optimisation settings. Only the model class name, the AMP flag and
# the device are read from live objects; every other value is a hand-written
# literal and must be kept in sync with the actual model, optimizer and
# scheduler construction.
model_params = dict(
    model=type(net).__name__,
    in_ch=1,
    base=24,
    levels=4,
    optimizer="AdamW",
    lr=2e-4,
    weight_decay=1e-4,
    betas=(0.9, 0.99),
    scheduler="CosineAnnealingLR",
    epochs=50,
    amp=bool(use_amp),
    device=str(device),
)

# Data pipeline settings, read from the training loader and its dataset.
_train_ds = loader_train.dataset
data_params = dict(
    # None when the dataset does not expose a patch_size attribute.
    patch_size=tuple(_train_ds.patch_size) if hasattr(_train_ds, "patch_size") else None,
    batch_size=loader_train.batch_size,
    num_workers=loader_train.num_workers,
    # Length of the dataset's "vols" attribute, else of "paths", else 0.
    train_volumes=len(getattr(_train_ds, "vols", getattr(_train_ds, "paths", []))),
    val_exists=loader_val is not None,
)

# Loss settings: the criterion class name plus whichever of the optional
# weight attributes the criterion object actually has.
loss_params = {"loss": type(criterion).__name__}
loss_params.update(
    (attr, getattr(criterion, attr))
    for attr in ("w_l1", "w_ssim", "w_freq", "idw")
    if hasattr(criterion, attr)
)

# ---------- Training loop with MLflow ----------
# A single MLflow run wraps the whole session: parameters and an environment
# snapshot are logged once, then per-epoch metrics; whenever validation PSNR
# improves, the weights are checkpointed and a preview figure is uploaded.
with mlflow.start_run(run_name=RUN_NAME) as run:
    run_id = run.info.run_id
    print("MLflow run_id:", run_id)

    # log params once
    mlflow.log_params(model_params)
    mlflow.log_params(data_params)
    mlflow.log_params(loss_params)

    # environment snapshot
    env = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }
    mlflow.log_dict(env, "env.json")
    req_path = _freeze_requirements()
    if req_path:
        mlflow.log_artifact(req_path, artifact_path="env")

    # Start from -inf rather than -1.0: PSNR is negative whenever MSE > 1,
    # and such a run must still be able to produce a checkpoint.
    best_psnr = float("-inf")
    best_path = "deblur3d_unet.pt"
    # Set only when THIS run writes best_path, so a stale file left behind by
    # an earlier run is never uploaded as this run's final checkpoint.
    checkpoint_saved = False

    # Loop length follows the logged "epochs" parameter so the two cannot drift.
    num_epochs = int(model_params["epochs"])

    # Explicit None check: truth-testing a DataLoader calls its __len__,
    # which raises TypeError for datasets without a defined length.
    val_batches = loader_val if loader_val is not None else ()

    def to_ch(x):
        """Insert a channel axis at dim 1: (B,D,H,W) -> (B,1,D,H,W)."""
        return x.unsqueeze(1)

    for epoch in range(1, num_epochs + 1):
        # ----- training -----
        net.train()
        t0 = time.time()
        tr_loss = 0.0
        n_train = 0  # number of training samples seen this epoch

        for sharp, blurred in loader_train:
            sharp   = to_ch(sharp).to(device, non_blocking=True)
            blurred = to_ch(blurred).to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = net(blurred)
                loss = criterion(pred, sharp, blurred)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients rather than the loss-scaled ones.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

            # Weight the batch loss by batch size so the epoch value is a
            # per-sample average even when the last batch is smaller.
            bs = sharp.size(0)
            tr_loss += loss.item() * bs
            n_train += bs

        sched.step()
        epoch_time = time.time() - t0
        train_loss_epoch = tr_loss / max(n_train, 1)

        # ----- validation (PSNR) -----
        net.eval()
        psnr_sum, n_val = 0.0, 0
        with torch.no_grad():
            for sharp, blurred in val_batches:
                sharp   = to_ch(sharp).to(device, non_blocking=True)
                blurred = to_ch(blurred).to(device, non_blocking=True)
                pred = net(blurred)
                # Per-sample MSE over channel and spatial axes.
                mse  = F.mse_loss(pred, sharp, reduction='none').mean(dim=(1,2,3,4))
                # Peak value of 1.0 in the numerator; presumably intensities
                # are scaled to [0, 1] (confirm against the dataset). The
                # epsilon keeps a perfect reconstruction finite.
                psnr = 10 * torch.log10(1.0 / (mse + 1e-12))
                psnr_sum += psnr.sum().item()
                n_val += sharp.size(0)
        psnr_epoch = (psnr_sum / n_val) if n_val > 0 else float("nan")
        # NaN is the only float unequal to itself; this covers both "no
        # validation samples" and a NaN coming out of the network.
        psnr_valid = psnr_epoch == psnr_epoch

        # --- log metrics ---
        mlflow.log_metric("train/loss", train_loss_epoch, step=epoch)
        mlflow.log_metric("time/epoch_sec", epoch_time, step=epoch)
        if psnr_valid:
            mlflow.log_metric("val/PSNR_dB", psnr_epoch, step=epoch)

        print(f"Epoch {epoch:03d} | train {train_loss_epoch:.4f} | PSNR {psnr_epoch:.2f} dB | {epoch_time:.1f}s")

        # --- checkpoint & preview on improvement ---
        improved = psnr_valid and (psnr_epoch > best_psnr)
        if improved:
            best_psnr = psnr_epoch
            torch.save({"epoch": epoch, "state_dict": net.state_dict()}, best_path)
            checkpoint_saved = True
            mlflow.log_artifact(best_path, artifact_path="checkpoints")

            # Small preview artifact. Best-effort: a plotting failure must not
            # abort training, so it is reported and skipped. The same file
            # name is reused on every improvement.
            try:
                # take a tiny batch from val or train
                sample_batch = next(iter(loader_val if loader_val is not None else loader_train))
                s = to_ch(sample_batch[0]).to(device)
                b = to_ch(sample_batch[1]).to(device)
                with torch.no_grad():
                    p = net(b)
                fig_path = _preview_triplet(b, p, s, save_path="preview_epoch.png")
                mlflow.log_artifact(fig_path, artifact_path="figures")
            except Exception as e:
                print("preview logging failed:", e)

    # (optional) log the final model weights artifact one more time, but only
    # if this run actually produced the file.
    if checkpoint_saved and os.path.exists(best_path):
        mlflow.log_artifact(best_path, artifact_path="checkpoints_final")
