In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import time, math
from pathlib import Path

In [3]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [None]:
# Path of the .npz archive opened with np.load in the next cell.
npz_path =  "scan20_splits.npz" 

In [5]:
# Open the archive with pickling disabled and list the array names it contains.
data = np.load(npz_path, allow_pickle=False)
print("Saved keys:", list(data.keys()))

Saved keys: ['ksp', 'sens', 'theta', 'lam', 'val', 'omega', 'split_seed']


In [6]:
def to_torch(x, dtype=None, dev=device):
    """
    Wrap a NumPy array as a torch tensor and move it to `dev`.

    x     : NumPy array handed to torch.from_numpy
    dtype : optional torch dtype; when given, the tensor is cast before the move
    dev   : target device (defaults to the `device` global as it was when this
            cell was executed)
    """
    tensor = torch.from_numpy(x)
    tensor = tensor if dtype is None else tensor.to(dtype)
    return tensor.to(dev, non_blocking=True)

In [7]:
# Convert every array of the archive to a tensor on `device`.
# NOTE: `lam` below is a boolean mask, not the scalar λ that the solver
# functions later in the notebook also call `lam`.
ksp    = to_torch(data["ksp"],   dtype=torch.complex64)   # (P,C,X,Y,Z)
sens   = to_torch(data["sens"],  dtype=torch.complex64)   # (C,X,Y,Z)
theta  = to_torch(data["theta"], dtype=torch.bool)        # (P,Y,Z)
lam  = to_torch(data["lam"],   dtype=torch.bool)        # (P,Y,Z)
val  = to_torch(data["val"], dtype=torch.bool)        # (P,Y,Z)
omega = to_torch(data["omega"], dtype=torch.bool)        # (P,Y,Z)

In [8]:
# Unpack the five k-space dimensions into globals and print the shape/dtype of
# every loaded tensor as a sanity check.
P, C, X, Y, Z = ksp.shape
print(f"ksp:   {tuple(ksp.shape)}  {ksp.dtype}")
print(f"sens:  {tuple(sens.shape)} {sens.dtype}")
print(f"theta: {tuple(theta.shape)} {theta.dtype}")
print(f"lam:   {tuple(lam.shape)} {lam.dtype}")
print(f"val: {tuple(val.shape)} {val.dtype}")
print(f"omega: {tuple(omega.shape)} {omega.dtype}")

ksp:   (20, 34, 128, 120, 80)  torch.complex64
sens:  (34, 128, 120, 80) torch.complex64
theta: (20, 120, 80) torch.bool
lam:   (20, 120, 80) torch.bool
val: (20, 120, 80) torch.bool
omega: (20, 120, 80) torch.bool


In [9]:
def precompute_broadcast_masks(theta, lam_mask, gamma, omega=None):
    """
    Insert two singleton axes after the first axis of every mask.

    Inputs : theta, lam_mask, gamma and (optionally) omega, each (P,Y,Z)
    Returns: dict keyed "theta", "lam", "gamma" (plus "omega" when given) whose
             values are contiguous (P,1,1,Y,Z) tensors.
    Raises : AssertionError when a mask's expanded shape differs from theta's.
    """
    named_masks = [("theta", theta), ("lam", lam_mask), ("gamma", gamma)]
    if omega is not None:
        named_masks.append(("omega", omega))

    # (P,Y,Z) -> (P,1,1,Y,Z), materialised as contiguous memory
    masks = {key: m[:, None, None, ...].contiguous() for key, m in named_masks}

    # every mask must agree with theta on P, Y and Z
    P, _, _, Y, Z = masks["theta"].shape
    expected_shape = (P, 1, 1, Y, Z)
    for key in ("lam", "gamma", "omega"):
        if key in masks:
            assert masks[key].shape == expected_shape

    return masks

# `val` is passed as the `gamma` argument, so it ends up under the key "gamma".
bmasks = precompute_broadcast_masks(theta, lam, val, omega)

In [10]:
# Print the broadcast shape of each mask. The bare `bmasks.keys()` expression
# that used to open this cell was dropped: it was not the cell's last
# expression, so its value was never displayed.
for mask_name in ("theta", "lam", "gamma"):
    print(bmasks[mask_name].shape)

torch.Size([20, 1, 1, 120, 80])
torch.Size([20, 1, 1, 120, 80])
torch.Size([20, 1, 1, 120, 80])


In [None]:
class ConvBNAct(nn.Module):
    """
    Conv2d -> GroupNorm(8 groups) -> LeakyReLU(0.1).

    Despite the "BN" in the name, the normalisation layer is a GroupNorm.
    The padding of k // 2 keeps the spatial size unchanged for odd k.
    """
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=k // 2)
        self.gn = nn.GroupNorm(8, out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        feat = self.conv(x)
        feat = self.gn(feat)
        return self.act(feat)

class UNetBlock(nn.Module):
    """Two ConvBNAct layers back to back: in_ch -> out_ch -> out_ch."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvBNAct(in_ch, out_ch)
        self.c2 = ConvBNAct(out_ch, out_ch)

    def forward(self, x):
        return self.c2(self.c1(x))

class Pseudo3DUNet2p5D(nn.Module):
    """
    Three-level 2D U-Net used as a 2.5D prior.

    Input:  (N, 2*k, H, W) — k neighboring slices, real/imag stacked as channels
    Output: (N, 2,   H, W) — residual for center slice (real, imag), already
            multiplied by `residual_scale`

    in_ch          : number of input channels (2*k)
    base           : channel width of the first level; deeper levels use 2x / 4x
    out_ch         : number of output channels
    residual_scale : constant factor applied to the network output

    H and W may be arbitrary: the decoder upsamples to the exact spatial size of
    the matching skip connection instead of assuming the size is divisible by 4.
    """
    def __init__(self, in_ch=10, base=32, out_ch=2, residual_scale=0.1):
        super().__init__()
        self.enc1 = UNetBlock(in_ch,   base)
        self.enc2 = UNetBlock(base,    base*2)
        self.enc3 = UNetBlock(base*2,  base*4)

        # strided 3x3 convolutions act as learned 2x downsampling
        self.down1 = nn.Conv2d(base,   base,   3, stride=2, padding=1)
        self.down2 = nn.Conv2d(base*2, base*2, 3, stride=2, padding=1)

        self.dec2 = UNetBlock(base*4 + base*2, base*2)
        self.dec1 = UNetBlock(base*2 + base,   base)

        self.out  = nn.Conv2d(base, out_ch, 1)
        self.res_scale = residual_scale

    def forward(self, x):
        # Encoder (N = batch, b = base width)
        e1 = self.enc1(x)                  # (N, b,  H,   W)
        x  = self.down1(e1)                # (N, b,  ~H/2, ~W/2)
        e2 = self.enc2(x)                  # (N, 2b, ~H/2, ~W/2)
        x  = self.down2(e2)                # (N, 2b, ~H/4, ~W/4)
        e3 = self.enc3(x)                  # (N, 4b, ~H/4, ~W/4)

        # Decoder. The strided convs produce ceil(size/2), so a fixed
        # scale_factor=2 would overshoot the skip tensor for odd sizes and make
        # torch.cat fail; upsampling to the skip's own size is identical when
        # the size is even and correct when it is not.
        x  = F.interpolate(e3, size=e2.shape[-2:], mode='bilinear', align_corners=False)
        x  = torch.cat([x, e2], dim=1)
        x  = self.dec2(x)

        x  = F.interpolate(x, size=e1.shape[-2:], mode='bilinear', align_corners=False)
        x  = torch.cat([x, e1], dim=1)
        x  = self.dec1(x)

        out = self.out(x)
        return out * self.res_scale

In [13]:
def A(x: torch.Tensor,
      mask: torch.Tensor,
      sens: torch.Tensor) -> torch.Tensor:
    """
    Forward SENSE operator: image -> masked multi-coil k-space.

    x    : [X, Y, Z] complex64 image
    mask : [1, 1, Y, Z] sampling mask
    sens : [C, X, Y, Z] coil sensitivities
    returns y : [C, X, Y, Z] complex64
    """
    # weight the image by every coil sensitivity -> [C, X, Y, Z]
    coil_images = sens * x[None]
    # orthonormal 3D FFT over the three spatial axes
    kspace = torch.fft.fftn(coil_images, dim=(1, 2, 3), norm="ortho")
    # the mask broadcasts over the coil and X axes
    return kspace * mask


def AH(y: torch.Tensor,
       mask: torch.Tensor,
       sens: torch.Tensor) -> torch.Tensor:
    """
    Adjoint SENSE operator: masked multi-coil k-space -> image.

    y    : [C, X, Y, Z] complex64 k-space
    mask : [1, 1, Y, Z] sampling mask
    sens : [C, X, Y, Z] coil sensitivities
    returns x : [X, Y, Z] complex64
    """
    # discard everything outside the sampling mask
    sampled = y * mask
    # orthonormal inverse 3D FFT back to per-coil images
    coil_images = torch.fft.ifftn(sampled, dim=(1, 2, 3), norm="ortho")
    # combine coils with the conjugate sensitivities
    return (torch.conj(sens) * coil_images).sum(dim=0)


def normal_op(x: torch.Tensor,
              mask: torch.Tensor,
              sens: torch.Tensor,
              lam: float) -> torch.Tensor:
    """
    Regularised normal operator: returns A^H(A(x)) + lam * x.

    x    : [X, Y, Z] complex64
    mask : [1, 1, Y, Z]
    sens : [C, X, Y, Z]
    lam  : float
    """
    measured = A(x, mask, sens)
    back_projected = AH(measured, mask, sens)
    return back_projected + lam * x

In [14]:
def cg_sense(
    rhs, mask, sens, lam,
    iters=8, tol=1e-3, x0=None,
    A=None, AH=None
):
    """
    Conjugate-gradient solve of (A^H A + lam I) x = rhs.

    rhs   : right-hand side (same shape as the image)
    mask  : sampling mask forwarded to A / AH
    sens  : coil sensitivities forwarded to A / AH
    lam   : Tikhonov weight (float or 0-d tensor)
    iters : maximum number of CG iterations; 0 returns the starting point
    tol   : stop once ||r|| / (||rhs|| + 1e-12) drops below this value
    x0    : optional warm start; `rhs` is used when omitted
    A, AH : forward / adjoint operators called as op(v, mask, sens). When left
            as None the module-level `A` / `AH` are looked up at call time, so
            re-running the cell that defines them takes effect without
            re-defining this function.

    Returns (x, info) where info["iters"] is the number of completed CG updates
    and info["relres"] the relative residual after each of them.
    """
    # the parameters shadow the module-level names, hence the globals() lookup
    forward_op = globals()["A"] if A is None else A
    adjoint_op = globals()["AH"] if AH is None else AH

    def normal_op(x):
        return adjoint_op(forward_op(x, mask, sens), mask, sens) + lam * x

    x = rhs.clone() if x0 is None else x0.clone()
    r = rhs - normal_op(x)
    p = r.clone()
    rsold = torch.vdot(r.flatten(), r.flatten()).real
    # loop-invariant denominator of the relative residual
    rhs_norm = rhs.norm() + 1e-12

    relres_traj = []
    done = 0
    for _ in range(iters):
        Ap = normal_op(p)
        pAp = torch.vdot(p.flatten(), Ap.flatten()).real
        # A vanishing curvature term means the search direction is zero (the
        # residual is already exactly zero); dividing by it would fill x with NaN.
        if pAp.item() == 0.0:
            break
        alpha = rsold / pAp
        x = x + alpha * p
        r = r - alpha * Ap
        done += 1
        rsnew = torch.vdot(r.flatten(), r.flatten()).real
        relres = (rsnew.sqrt() / rhs_norm).item()
        relres_traj.append(relres)
        if relres < tol:
            break
        beta = rsnew / rsold
        p = r + beta * p
        rsold = rsnew

    return x, {"iters": done, "relres": relres_traj}

In [15]:
def apply_prior_25d(prior25d: nn.Module, x3d: torch.Tensor, k: int = 5) -> torch.Tensor:
    """
    Run the 2.5D prior on every slice of a complex volume.

    x3d: complex tensor (X, Y, Z)
    k  : number of neighbouring slices (along the last axis) fed per prediction
    returns: complex residual (X, Y, Z)

    For each index z along the last axis a window of k slices is taken from the
    volume after padding that axis by k // 2 on both sides with edge
    replication. The window's real parts followed by its imaginary parts form
    the 2*k input channels; the two output channels are read back as
    (real, imag).
    """
    assert x3d.is_complex()
    _, _, depth = x3d.shape
    half = k // 2

    # replicate-pad the last axis; F.pad needs a 5D tensor for a 3-axis pad spec
    padded = F.pad(x3d[None, None], pad=(half, half, 0, 0, 0, 0), mode='replicate')[0, 0]

    def slab_channels(z):
        window = padded[..., z:z + k]                           # (X, Y, k) complex
        stacked = torch.cat([window.real, window.imag], dim=-1)  # (X, Y, 2k)
        return stacked.permute(2, 0, 1).contiguous()             # (2k, X, Y)

    batch = torch.stack([slab_channels(z) for z in range(depth)], dim=0)  # (Z, 2k, X, Y)

    # one batched forward pass over all slices
    pred = prior25d(batch)                                       # (Z, 2, X, Y)

    # channels (real, imag) -> complex, slices back to the last axis
    return torch.complex(pred[:, 0], pred[:, 1]).permute(1, 2, 0).contiguous()

In [16]:
def kspace_l1_l2_loss(y_pred, y_true, alpha=0.5, eps=1e-6):
    """
    Convex mix of the mean absolute and mean squared error magnitudes:
    alpha * mean(|d|) + (1 - alpha) * mean(|d|^2) with d = y_pred - y_true.

    `eps` is accepted for interface compatibility but is not used.
    """
    err_mag = (y_pred - y_true).abs()
    l1_term = err_mag.mean()
    l2_term = (err_mag ** 2).mean()
    return alpha * l1_term + (1 - alpha) * l2_term

In [None]:
def solve_cg(rhs, mask, sens, lam, iters, tol, x0=None):
    """
    Call cg_sense and return only the solution.

    cg_sense may return either the solution alone or a tuple whose first element
    is the solution; both forms are accepted.
    """
    result = cg_sense(rhs, mask, sens, lam, iters=iters, tol=tol, x0=x0)
    if isinstance(result, tuple):
        return result[0]
    return result

@torch.no_grad()
def _tensor_float(x):
    return float(x.detach().cpu().item())

In [18]:
def make_lambda_schedule(lam_start=3e-2, lam_end=5e-3, T=30):
    """
    Build a half-cosine schedule from lam_start down to lam_end.

    The returned function maps a 1-based epoch number to a value: epoch 1 (and
    anything below it) yields lam_start, epoch T and every later epoch yield
    lam_end.
    """
    last_step = T - 1
    n_intervals = max(last_step, 1)

    def lam_at_epoch(ep):
        step = min(max(ep - 1, 0), last_step)
        t = step / n_intervals
        w = 0.5 * (1 + math.cos(math.pi * t))  # goes from 1 down to 0
        return lam_end + (lam_start - lam_end) * w

    return lam_at_epoch


In [19]:
def save_ckpt(path, epoch, val_loss, model, opt):
    """
    Write a checkpoint to `path` with torch.save.

    The stored dict has the keys "epoch", "val_loss", "model_state"
    (model.state_dict()) and "optim_state" (opt.state_dict()).
    """
    checkpoint = dict(
        epoch=epoch,
        val_loss=val_loss,
        model_state=model.state_dict(),
        optim_state=opt.state_dict(),
    )
    torch.save(checkpoint, path)

In [None]:
def train_one_epoch(
    prior, opt,
    ksp, sens, bmasks,
    A, AH, solve_cg, apply_prior_2d, kspace_l1_l2_loss,
    lam_param,                  # learnable λ parameter (nn.Parameter)
    alpha=0.5,
    unroll=5, cg_iters=8, cg_tol=1e-3,
    k_neighbors=3, device="cuda"
):
    """
    Run one optimisation step per phase, visiting the phases in random order.

    For each phase p:
      1. x_anchor = AH(ksp[p] * mask_theta)                     (zero-filled anchor)
      2. `unroll` times: r = prior(x); x = CG solve with rhs = x_anchor + λ r
      3. loss between A(x, mask_lam) and ksp[p] * mask_lam, then backward/step

    prior, opt        : network and the optimizer that is stepped
    ksp, sens         : k-space indexed by phase, and coil sensitivities
    bmasks            : dict providing bmasks["theta"][p] and bmasks["lam"][p]
    A, AH             : forward / adjoint operators, called as op(v, mask, sens)
    solve_cg          : called as solve_cg(rhs, mask, sens, lam, iters, tol, x0=x)
    apply_prior_2d    : called as apply_prior_2d(prior, x, k=k_neighbors); when
                        k_neighbors is None the `k` keyword is omitted so the
                        callable's own default applies
    kspace_l1_l2_loss : called as loss(y_hat, y_ref, alpha)
    lam_param         : raw parameter; λ = softplus(lam_param) + 1e-8
    k_neighbors       : slab width forwarded to apply_prior_2d
    device            : device on which the random phase order is drawn

    Returns {"train_L": mean loss, "grad": mean per-parameter gradient norm,
             "lam": λ of the last step (or of the current parameter when there
             were no phases)}.
    """
    prior.train()
    P = ksp.shape[0]
    phase_order = torch.randperm(P, device=device)

    # Forward the slab width so the prior input matches the network's in_ch;
    # previously k_neighbors was accepted but silently ignored.
    prior_kwargs = {} if k_neighbors is None else {"k": k_neighbors}

    running_loss = 0.0
    running_gn   = 0.0
    steps = 0

    lam_eps = 1e-8  # keep λ strictly positive
    lam = None

    for p in phase_order.tolist():
        mask_t = bmasks["theta"][p]   # (1,1,Y,Z)
        mask_l = bmasks["lam"][p]     # (1,1,Y,Z)

        # Θ measurements
        y_t   = ksp[p] * mask_t

        # ZF anchor / adjoint
        x_anchor = AH(y_t, mask_t, sens)

        # learnable λ (shared for DC and residual scale); rebuilt every step so
        # each backward pass gets a fresh graph through lam_param
        lam = F.softplus(lam_param) + lam_eps

        # unroll (CG in-graph)
        x = x_anchor.clone()
        for _ in range(unroll):
            r     = apply_prior_2d(prior, x, **prior_kwargs)
            rhs   = x_anchor + lam * r
            x     = solve_cg(rhs, mask_t, sens, lam, cg_iters, cg_tol, x0=x)

        # Λ loss
        y_l   = ksp[p] * mask_l
        y_hat = A(x, mask_l, sens)
        loss  = kspace_l1_l2_loss(y_hat, y_l, alpha)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        # NOTE(review): max_norm=5e-5 is a very tight clip and applies to the
        # prior only (lam_param is not clipped) — confirm this is intended.
        torch.nn.utils.clip_grad_norm_(prior.parameters(), 5e-5)
        opt.step()

        running_loss += float(loss.item())

        # Average per-parameter gradient norm for logging. Measured after
        # clipping, so it reflects the gradients that were actually applied.
        grad_norms = [
            q.grad.detach().norm().item()
            for q in prior.parameters()
            if q.grad is not None
        ]
        if grad_norms:
            running_gn += sum(grad_norms) / len(grad_norms)

        steps += 1

    if lam is None:
        # no phase was processed; report λ implied by the current parameter
        with torch.no_grad():
            lam = F.softplus(lam_param) + lam_eps

    return {
        "train_L": running_loss / max(steps, 1),
        "grad":    running_gn   / max(steps, 1),
        "lam":     float(lam.detach().cpu().item()),
    }

In [None]:
@torch.no_grad()
def validate_all_phases(
    prior,
    ksp, sens, bmasks,
    A, AH, solve_cg, apply_prior_25d, kspace_l1_l2_loss,
    lam_param,                 # learnable λ (nn.Parameter)
    alpha=0.5,
    unroll=5, cg_iters=8, cg_tol=1e-3,
    k_neighbors=3, device="cuda"
):
    """
    Gradient-free evaluation over every phase, in index order.

    Reconstruction mirrors the training loop: a zero-filled anchor from the
    "theta" mask, then `unroll` rounds of prior + CG with rhs = x_anchor + λ r.

    apply_prior_25d is called as apply_prior_25d(prior, x, k=k_neighbors); when
    k_neighbors is None the `k` keyword is omitted so the callable's own default
    applies. `device` is accepted for signature symmetry and is not used.

    Returns (val_loss, omega_mse, omega_c_energy, lam):
      val_loss       : mean loss on the "lam" mask
      omega_mse      : mean |A(x) - ksp|^2 over the "omega" mask
      omega_c_energy : mean |A(x)|^2 over the complement of the "omega" mask
      lam            : softplus(lam_param) + 1e-8 as a float
    Each metric whose sample set is empty is reported as 0.

    NOTE(review): the loss is evaluated on bmasks["lam"], the same mask the
    training loss uses; bmasks["gamma"] is never read here — confirm that the
    held-out mask is not the one meant for validation.
    """
    prior.eval()

    # Forward the slab width so the prior input matches the network's in_ch;
    # previously k_neighbors was accepted but silently ignored.
    prior_kwargs = {} if k_neighbors is None else {"k": k_neighbors}

    P = ksp.shape[0]
    val_losses = []
    mse_acq_list, energy_un_list = [], []

    lam_eps = 1e-8
    lam = F.softplus(lam_param) + lam_eps  # shared scale

    ones_mask = None  # cache a full-ones mask for A(x, ·, sens)

    for p in range(P):
        mask_t  = bmasks["theta"][p]    # (1,1,Y,Z)
        mask_l  = bmasks["lam"][p]      # (1,1,Y,Z)
        mask_om = bmasks["omega"][p]    # (1,1,Y,Z)

        # Θ measurements and ZF anchor
        y_t      = ksp[p] * mask_t
        x_anchor = AH(y_t, mask_t, sens)

        # unroll with λ-scaled residual
        x = x_anchor.clone()
        for _ in range(unroll):
            r     = apply_prior_25d(prior, x, **prior_kwargs)
            rhs   = x_anchor + lam * r
            x     = solve_cg(rhs, mask_t, sens, lam, cg_iters, cg_tol, x0=x)

        # Λ-loss on Λ mask
        y_l   = ksp[p] * mask_l
        y_hat = A(x, mask_l, sens)
        loss  = kspace_l1_l2_loss(y_hat, y_l, alpha)
        val_losses.append(float(loss.item()))

        # --- Ω metrics on the fully sampled prediction ---
        if ones_mask is None:
            ones_mask = torch.ones_like(mask_t, dtype=mask_t.dtype, device=mask_t.device)

        y_full  = A(x, ones_mask, sens)                # (C,X,Y,Z)
        diff2   = (y_full - ksp[p]).abs().pow(2)

        w_acq   = mask_om.to(diff2.dtype)              # (1,1,Y,Z), broadcasts over (C,X)
        w_un    = (1.0 - w_acq)

        # every mask entry stands for C*X k-space samples; clamping keeps an
        # empty set from dividing by zero
        n_coils, n_x = diff2.shape[0], diff2.shape[1]
        denom_acq = (w_acq.sum() * n_coils * n_x).clamp_min(1.0)
        denom_un  = (w_un.sum()  * n_coils * n_x).clamp_min(1.0)

        mse_acq   = (diff2 * w_acq).sum() / denom_acq
        energy_un = (y_full.abs().pow(2) * w_un).sum() / denom_un

        mse_acq_list.append(float(mse_acq.item()))
        energy_un_list.append(float(energy_un.item()))

    val_loss    = sum(val_losses)     / max(len(val_losses), 1)
    mse_mean    = sum(mse_acq_list)   / max(len(mse_acq_list), 1)
    energy_mean = sum(energy_un_list) / max(len(energy_un_list), 1)

    return val_loss, mse_mean, energy_mean, float(lam.detach().cpu().item())

In [None]:
# --- config ---
epochs        = 100
unroll_tr     = 5
unroll_val    = 5
cg_tr_iters   = 8
cg_val_iters  = 25
cg_tol_tr     = 1e-3
cg_tol_val    = 1e-6
alpha         = 0.5
# Number of neighbouring slices fed to the prior. This was None, which made
# `2*k_neighbors` below raise TypeError; 3 matches the inference cell, which
# builds the network with in_ch = 2*3.
k_neighbors   = 3

prior = Pseudo3DUNet2p5D(in_ch=2*k_neighbors).to(device)

def apply_prior_k(net, vol, k=k_neighbors):
    """apply_prior_25d with the slab width defaulting to the configured k_neighbors,
    so the prior input always has the 2*k_neighbors channels the network expects."""
    return apply_prior_25d(net, vol, k=k)

# --- learnable λ: param in softplus space so λ=softplus(param) > 0 ---
def inv_softplus(y):
    """Inverse of softplus for positive y: log(exp(y) - 1), as a tensor on `device`."""
    return torch.log(torch.expm1(torch.as_tensor(y, device=device)))

lam_init = 3e-2
lam_param = torch.nn.Parameter(inv_softplus(lam_init))  # so softplus(lam_param) ≈ 3e-2

# --- optimizer + schedulers (two param groups: prior + λ) ---
start_lr   = 5e-4
lam_lr     = 1e-4   # usually smaller lr for λ is stabler
opt = torch.optim.Adam([
    {"params": prior.parameters(), "lr": start_lr, "weight_decay": 0.0},
    {"params": [lam_param],        "lr": lam_lr,   "weight_decay": 0.0},
])

# linear warm-up for `warmup_epochs`, then cosine annealing for the remainder
warmup_epochs = 3
total_epochs  = epochs
scheduler = torch.optim.lr_scheduler.SequentialLR(
    opt,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.2, total_iters=warmup_epochs),
        torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_epochs - warmup_epochs, eta_min=5e-6),
    ],
    milestones=[warmup_epochs],
)

best_val = float('inf')

for ep in range(epochs):
    t0 = time.time()

    train_stats = train_one_epoch(
        prior, opt,
        ksp, sens, bmasks,
        A, AH, solve_cg, apply_prior_k, kspace_l1_l2_loss,
        lam_param=lam_param, alpha=alpha,
        unroll=unroll_tr, cg_iters=cg_tr_iters, cg_tol=cg_tol_tr,
        k_neighbors=k_neighbors, device=device
    )

    val_loss, val_mΩ, val_eΩc, lam_scalar = validate_all_phases(
        prior, ksp, sens, bmasks,
        A, AH, solve_cg, apply_prior_k, kspace_l1_l2_loss,
        lam_param=lam_param, alpha=alpha,
        unroll=unroll_val, cg_iters=cg_val_iters, cg_tol=cg_tol_val,
        k_neighbors=k_neighbors, device=device
    )

    scheduler.step()
    # checkpoints record the 1-based epoch; lam_param itself is not part of the
    # saved dict (only its optimizer state is)
    save_ckpt("latest_prior.pt", ep+1, val_loss, prior, opt)  # (optional) extend to also save lam_param

    # keep a separate copy whenever the validation loss improves
    if val_loss < best_val - 1e-9:
        best_val = val_loss
        save_ckpt("best_prior.pt", ep+1, val_loss, prior, opt)  # (optional) extend to also save lam_param

    dt = time.time() - t0
    # the LR shown is the prior group's rate after this epoch's scheduler step
    print(
        f"[Epoch {ep:3d}] train_Λ={train_stats['train_L']:.3e}  ‖∇‖={train_stats['grad']:.3e}  "
        f"val={val_loss:.3e}  (Ω_MSE={val_mΩ:.3e}, Ωc_E={val_eΩc:.3e})  "
        f"λ={lam_scalar:.2e}  LR={opt.param_groups[0]['lr']:.1e}  time={dt:.1f}s"
    )

In [None]:
# --- lean inference with per-phase gamma sweep (the loop below currently runs only the first phase) ---
import os, csv, time
import numpy as np
import torch

# ==== config ====
# NOTE: `unroll`, `cg_iters`, `cg_tol`, `k_neighbors` and `energy_cap` /
# `fallback_eta` are read as globals by the helper functions defined below.
npz_path     = "scan20_splits.npz"   # your saved data
out_dir      = "inference_all_phases" # outputs here
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lam_fixed    = 5e-2
gamma_grid   = [2e-2, 3e-2, 5e-2]     # per-phase sweep
unroll       = 5
cg_iters     = 60
cg_tol       = 1e-8
k_neighbors  = 3

energy_cap   = 1.10e-07               # Ωᶜ energy cap (mean power on unacquired)
fallback_eta = 1.0                    # weight if nothing meets the cap

# create the output directory up front; an existing directory is fine
os.makedirs(out_dir, exist_ok=True)

# ==== helpers ====
def to_torch(x):
    """
    Wrap a NumPy array as a tensor and move it to the global `device`.

    This replaces the earlier notebook definition of `to_torch`, which also
    accepted `dtype` and `dev` arguments.
    """
    host_tensor = torch.from_numpy(x)
    return host_tensor.to(device)

@torch.no_grad()
def sos_from_image_and_sens(x_img, sens):
    """
    Root-sum-of-squares magnitude of the sensitivity-weighted image.

    x_img: (X,Y,Z) complex
    sens : (C,X,Y,Z) complex
    returns sqrt(sum_c |sens_c * x_img|^2 + 1e-12), shape (X,Y,Z)
    """
    coil_power = (x_img.unsqueeze(0) * sens).abs() ** 2   # (C,X,Y,Z)
    return torch.sqrt(coil_power.sum(dim=0) + 1e-12)

@torch.no_grad()
def kspace_metrics(x, ksp, mask_omega, A, sens):
    """
    k-space consistency metrics of a reconstruction.

    x: (X,Y,Z) complex, ksp: (C,X,Y,Z) complex, mask_omega: (1,1,Y,Z) float/bool
    A: forward operator called as A(x, mask, sens)

    Returns (mse_acq, energy_un) as Python floats:
      mse_acq   : mean |A(x) - ksp|^2 over the (Y,Z) positions where the mask is set
      energy_un : mean |A(x)|^2 over the positions where it is not
    A metric whose position set is empty is reported as 0.0 rather than NaN
    (the mean of an empty selection), so a fully sampled or fully empty mask
    cannot poison the comparisons made on these values.
    """
    def _masked_mean(values, keep):
        # `keep` is (Y,Z) and indexes the last two axes, i.e. all coils and X
        picked = values[..., keep]
        return picked.mean().item() if picked.numel() > 0 else 0.0

    # predict the full k-space by applying A with an all-ones mask
    y_hat = A(x, torch.ones_like(mask_omega), sens)          # (C,X,Y,Z)
    acquired = mask_omega.bool().squeeze(0).squeeze(0)       # (Y,Z)

    mse_acq   = _masked_mean((y_hat - ksp).abs() ** 2, acquired)
    energy_un = _masked_mean(y_hat.abs() ** 2, ~acquired)
    return mse_acq, energy_un

@torch.no_grad()
def unroll_once(x, x_th, Ah_y, mask_omega, sens, lam, gamma):
    """
    One unrolled iteration: prior residual, then a data-consistency CG solve.

    x      : current estimate (also the CG warm start)
    x_th   : anchor image the gamma-scaled residual is added to
    Ah_y   : adjoint of the acquired data
    lam    : regularisation weight, gamma: residual step size

    Uses the globals `prior`, `k_neighbors`, `cg_iters` and `cg_tol`.
    """
    residual = apply_prior_25d(prior, x, k=k_neighbors)  # image-space residual
    denoised = x_th + gamma * residual
    return solve_cg(Ah_y + lam * denoised, mask_omega, sens, lam, cg_iters, cg_tol, x0=x)

@torch.no_grad()
def infer_one_phase(ksp_p, sens, mask_omega, lam, gammas):
    """
    Reconstruct one phase for every gamma in `gammas` and pick the best one.

    Returns dict(best=..., cg_sos=..., recon_sos=..., all=...):
      best      : entry {gamma, x, mse, e_un} chosen by the rule below
      cg_sos    : SoS volume of the plain CG-SENSE anchor
      recon_sos : SoS volume of the chosen reconstruction
      all       : entries for every gamma, in input order

    Selection: among entries with e_un <= energy_cap take the lowest mse; when
    none qualifies take the lowest mse + fallback_eta * e_un.
    Uses the globals `unroll`, `cg_iters`, `cg_tol`, `energy_cap`, `fallback_eta`.
    """
    # CG-SENSE anchor on Ω (acquired lines)
    Ah_y = AH(ksp_p * mask_omega, mask_omega, sens)
    x_th = solve_cg(Ah_y, mask_omega, sens, lam, cg_iters, cg_tol, x0=None)

    def run_gamma(g):
        # every sweep point restarts from the same anchor
        x = x_th.clone()
        for _ in range(unroll):
            x = unroll_once(x, x_th, Ah_y, mask_omega, sens, lam, g)
        mse_acq, energy_un = kspace_metrics(x, ksp_p, mask_omega, A, sens)
        return dict(gamma=g, x=x, mse=mse_acq, e_un=energy_un)

    results = [run_gamma(g) for g in gammas]

    admissible = [entry for entry in results if entry["e_un"] <= energy_cap]
    if admissible:
        best = min(admissible, key=lambda entry: entry["mse"])
    else:
        best = min(results, key=lambda entry: entry["mse"] + fallback_eta * entry["e_un"])

    # SoS volumes
    anchor_sos = sos_from_image_and_sens(x_th, sens)
    chosen_sos = sos_from_image_and_sens(best["x"], sens)

    return dict(best=best, cg_sos=anchor_sos, recon_sos=chosen_sos, all=results)

# ==== load data ====
# Pickling stays disabled: the arrays read below are plain numeric arrays (the
# same archive is opened with allow_pickle=False at the top of the notebook),
# and unpickling an .npz can execute arbitrary code.
data = np.load(npz_path, allow_pickle=False)
ksp_np   = data["ksp"]    # expect (P,C,X,Y,Z) complex64
sens_np  = data["sens"]   # (C,X,Y,Z) complex64
omega_np = data["omega"]  # (P,Y,Z) float/bool

# adapt shapes if ksp is (P,X,Y,Z,C): detected by axis 1 of ksp not matching
# the coil count of sens
if ksp_np.shape[1] != sens_np.shape[0]:
    # move coils axis to dim=1
    # from (P,X,Y,Z,C) -> (P,C,X,Y,Z)
    ksp_np = np.moveaxis(ksp_np, -1, 1)

P, C, X, Y, Z = ksp_np.shape
print(f"Loaded: P={P}, C={C}, X={X}, Y={Y}, Z={Z}")

# to torch
ksp_t  = to_torch(ksp_np)        # (P,C,X,Y,Z) complex
sens_t = to_torch(sens_np)       # (C,X,Y,Z) complex

# Per-phase metric rows (written to CSV at the end) and a wall-clock start time.
rows = []
t0 = time.time()

# Rebuild the prior with the architecture used at inference time and load the
# trained weights; `unroll_once` reads this `prior` global.
prior = Pseudo3DUNet2p5D(in_ch=2*k_neighbors, base=32, residual_scale=0.1).to(device)

ckpt_path = "best_prior_100_epoch.pt"
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location=device)
prior.load_state_dict(ckpt["model_state"])

with torch.no_grad():
    prior.eval()
    # NOTE(review): range(1) processes only the first phase; presumably a
    # quick-run setting — use range(P) to cover every phase.
    for p in range(1):
        ksp_p  = ksp_t[p]  # (C,X,Y,Z)
        # Ω mask for this phase as float32, with two leading singleton axes
        mask_w = torch.from_numpy(omega_np[p]).to(device=device, dtype=torch.float32)
        mask_w = mask_w.unsqueeze(0).unsqueeze(0)  # (1,1,Y,Z)

        out = infer_one_phase(ksp_p, sens_t, mask_w, lam_fixed, gamma_grid)

        # save SoS volumes
        cg_sos_np    = out["cg_sos"].cpu().numpy()
        recon_sos_np = out["recon_sos"].cpu().numpy()
        np.save(os.path.join(out_dir, f"phase_{p+1:02d}_cg_sos.npy"),    cg_sos_np)
        np.save(os.path.join(out_dir, f"phase_{p+1:02d}_recon_sos.npy"), recon_sos_np)

        # record the metrics of the selected gamma (phases are reported 1-based)
        b = out["best"]
        rows.append({
            "phase":   p+1,
            "lambda":  lam_fixed,
            "gamma":   float(b["gamma"]),
            "Omega_MSE": float(b["mse"]),
            "OmegaC_E":  float(b["e_un"]),
        })
        print(f"phase {p+1:02d} | γ={b['gamma']:.2e}  Ω_MSE={b['mse']:.3e}  Ωc_E={b['e_un']:.3e}")

# write CSV
csv_path = os.path.join(out_dir, "per_phase_metrics.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=["phase","lambda","gamma","Omega_MSE","OmegaC_E"])
    w.writeheader()
    w.writerows(rows)

print(f"\nSaved SoS volumes + CSV to: {out_dir}")
print(f"Elapsed: {time.time()-t0:.1f}s")

Loaded: P=20, C=24, X=160, Y=128, Z=72
phase 01 | γ=2.00e-02  Ω_MSE=5.003e-09  Ωc_E=1.048e-07

Saved SoS volumes + CSV to: inference_all_phases
Elapsed: 591.9s
