In [19]:
import numpy as np
import torch
import math
import torch.nn.functional as F
import torch.nn as nn
import random
import time
import os

In [2]:
# Step 1 — Data I/O + prep (device-ready)
path_npz = "scan20_splits.npz"
device   = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_step1(npz_path, device):
    data = np.load(npz_path, allow_pickle=False)
    # raw numpy
    ksp_np  = data["ksp"]    # (P,C,X,Y,Z) complex64
    sens_np = data["sens"]   # (C,X,Y,Z)   complex64
    th_np   = data["theta"]  # (P,Y,Z)
    la_np   = data["lam"]    # (P,Y,Z)
    va_np   = data["val"]    # (P,Y,Z)
    om_np   = data["omega"]  # (P,Y,Z)
    seed    = int(data["split_seed"])

    P, C, X, Y, Z = ksp_np.shape

    # to torch (on device)
    ksp_t  = torch.from_numpy(ksp_np).to(device)          # complex64
    sens_t = torch.from_numpy(sens_np).to(device)         # complex64

    # masks → bool on device
    theta_t = torch.from_numpy(th_np.astype(np.bool_)).to(device)
    lam_t   = torch.from_numpy(la_np.astype(np.bool_)).to(device)
    val_t   = torch.from_numpy(va_np.astype(np.bool_)).to(device)
    omega_t = torch.from_numpy(om_np.astype(np.bool_)).to(device)

    # broadcast masks: (P,1,1,Y,Z)
    def bmask(m): return m.view(P, 1, 1, Y, Z)
    th_b = bmask(theta_t)
    la_b = bmask(lam_t)
    val_b= bmask(val_t)
    om_b = bmask(omega_t)

    # temporal cues (sin/cos per phase) → (P,2,Y,Z) float32
    phases = torch.arange(P, device=device, dtype=torch.float32)
    ang    = 2.0 * math.pi * phases / float(P)            # (P,)
    s      = torch.sin(ang).view(P,1,1,1)                 # (P,1,1,1)
    c      = torch.cos(ang).view(P,1,1,1)                 # (P,1,1,1)
    ones_yz= torch.ones((1,1,Y,Z), device=device, dtype=torch.float32)
    sin_map= (s * ones_yz).squeeze(1)                     # (P,1,Y,Z)
    cos_map= (c * ones_yz).squeeze(1)                     # (P,1,Y,Z)
    cues_2d= torch.cat([sin_map, cos_map], dim=1).contiguous()  # (P,2,Y,Z)

    # quick summary
    s2 = (sens_t.conj()*sens_t).real.sum(dim=0).mean().item()
    print(f"Loaded keys: {list(data.keys())}")
    print(f"P={P}, C={C}, X={X}, Y={Y}, Z={Z}")
    print(f"ksp dtype: {ksp_t.dtype}  sens dtype: {sens_t.dtype}")
    print(f"θ/λ/val/Ω shapes: {tuple(theta_t.shape)} {tuple(lam_t.shape)} {tuple(val_t.shape)} {tuple(omega_t.shape)}")
    print(f"cues_2d shape: {tuple(cues_2d.shape)}  (P,2,Y,Z)")
    print(f"mean|Σ|S|^2 - 1| ≈ {abs(s2 - 1.0):.6e}")
    print(f"Broadcast masks: th_b/la_b/val_b/om_b → {(P,1,1,Y,Z)}")

    bmasks = {"theta": th_b, "lam": la_b, "val": val_b, "omega": om_b}
    return ksp_t, sens_t, bmasks, cues_2d, seed

# run
ksp_t, sens_t, bmasks, cues_2d, split_seed = load_step1(path_npz, device)

Loaded keys: ['ksp', 'sens', 'theta', 'lam', 'val', 'omega', 'split_seed']
P=20, C=24, X=160, Y=128, Z=72
ksp dtype: torch.complex64  sens dtype: torch.complex64
θ/λ/val/Ω shapes: (20, 128, 72) (20, 128, 72) (20, 128, 72) (20, 128, 72)
cues_2d shape: (20, 256, 72)  (P,2,Y,Z)
mean|Σ|S|^2 - 1| ≈ 0.000000e+00
Broadcast masks: th_b/la_b/val_b/om_b → (20, 1, 1, 128, 72)


In [3]:
# Step 2 — FFT & SENSE operators (3D, orthonormal)
# --- centered 3D FFTs (no explicit shifts; data/masks assumed consistent) ---
def fft3c(x: torch.Tensor) -> torch.Tensor:
    """3D FFT with 'ortho' norm. x: (..., X, Y, Z) complex."""
    return torch.fft.fftn(x, dim=(-3, -2, -1), norm="ortho")

def ifft3c(x: torch.Tensor) -> torch.Tensor:
    """3D iFFT with 'ortho' norm. x: (..., X, Y, Z) complex."""
    return torch.fft.ifftn(x, dim=(-3, -2, -1), norm="ortho")

# --- SENSE Encoding (A) and Adjoint (AH) ---
def A(x_img: torch.Tensor, mask_yz: torch.Tensor, sens: torch.Tensor) -> torch.Tensor:
    """
    Forward encoding: x -> masked multi-coil k-space.
      x_img: (X,Y,Z) complex
      mask_yz: (1,1,Y,Z) bool/float (broadcast over C,X)
      sens: (C,X,Y,Z) complex
    Returns: (C,X,Y,Z) complex
    """
    coil_img = sens * x_img.unsqueeze(0)            # (C,X,Y,Z)
    ksp      = fft3c(coil_img)                      # (C,X,Y,Z)
    return ksp * mask_yz                            # broadcast over C,X

def AH(y_ksp: torch.Tensor, mask_yz: torch.Tensor, sens: torch.Tensor) -> torch.Tensor:
    """
    Adjoint: masked multi-coil k-space -> combined image.
      y_ksp: (C,X,Y,Z) complex
      mask_yz: (1,1,Y,Z) bool/float
      sens: (C,X,Y,Z) complex
    Returns: (X,Y,Z) complex
    """
    y_mask = y_ksp * mask_yz                        # (C,X,Y,Z)
    coil_img = ifft3c(y_mask)                       # (C,X,Y,Z)
    x_img = (coil_img * sens.conj()).sum(dim=0)     # (X,Y,Z)
    return x_img

# --- Convenience: zero-filled SENSE adjoint for a single phase ---
def zf_image(ksp_p: torch.Tensor, mask_yz_p: torch.Tensor, sens: torch.Tensor) -> torch.Tensor:
    """
    Zero-filled reconstruction (adjoint of A) for one phase.
      ksp_p:  (C,X,Y,Z) complex
      mask_yz_p: (1,1,Y,Z) bool/float
      sens:   (C,X,Y,Z) complex
    Returns: (X,Y,Z) complex
    """
    return AH(ksp_p, mask_yz_p, sens)

# # --- Quick sanity check (uses tensors from Step 1) ---
# with torch.no_grad():
#     P, C, X, Y, Z = ksp_t.shape
#     p = 0
#     y_p   = ksp_t[p]                 # (C,X,Y,Z)
#     m_th  = bmasks["theta"][p]       # (1,1,Y,Z)
#     x_zf  = zf_image(y_p, m_th, sens_t)      # (X,Y,Z)
#     y_hat = A(x_zf, m_th, sens_t)            # (C,X,Y,Z)

#     print(f"[Check] x_zf: {tuple(x_zf.shape)}, y_hat: {tuple(y_hat.shape)}  dtypes={x_zf.dtype}/{y_hat.dtype}")

In [4]:
def kspace_metrics(y_hat: torch.Tensor,
                   y: torch.Tensor,
                   mask_acq: torch.Tensor) -> tuple[float, float]:
    """
    Compute k-space metrics:
      - MSE over acquired samples (Ω_MSE)
      - Energy over unacquired samples (Ω^c_E)

    Args:
        y_hat    : (C, X, Y, Z) complex — predicted k-space
        y        : (C, X, Y, Z) complex — reference/true k-space
        mask_acq : (Y, Z) or (1,1,Y,Z), bool or float — acquisition mask Ω

    Returns:
        (omse, oce) as floats
          omse = mean squared error over Ω
          oce  = mean squared magnitude over Ω^c
    """
    # Ensure mask is float and broadcastable to (C,X,Y,Z)
    if mask_acq.ndim == 2:
        m = mask_acq[None, None, ...]
    elif mask_acq.ndim == 4:
        m = mask_acq
    else:
        raise ValueError("mask_acq must be (Y,Z) or (1,1,Y,Z)")
    m = m.to(dtype=torch.float32, device=y_hat.device)

    # Squared magnitudes (real scalars)
    diff    = y_hat - y
    sq_diff = diff.real.pow(2) + diff.imag.pow(2)     # (C,X,Y,Z)
    sq_pred = y_hat.real.pow(2) + y_hat.imag.pow(2)   # (C,X,Y,Z)

    # Complement mask
    one_m = 1.0 - m

    # Normalize by element counts across C and X
    C, X = y_hat.shape[0], y_hat.shape[1]
    num_acq = (m.sum() * (C * X)).item()
    num_un  = (one_m.sum() * (C * X)).item()

    omse = ((sq_diff * m).sum().item()   / max(num_acq, 1.0)) if num_acq > 0 else float('nan')
    oce  = ((sq_pred * one_m).sum().item() / max(num_un,  1.0)) if num_un  > 0 else 0.0
    return omse, oce

In [5]:
class ConvBNAct(nn.Module):
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        p = k//2
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p)
        self.gn   = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.act  = nn.LeakyReLU(0.1, inplace=True)
    def forward(self, x):
        return self.act(self.gn(self.conv(x)))

class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvBNAct(in_ch,  out_ch)
        self.c2 = ConvBNAct(out_ch, out_ch)
    def forward(self, x):
        x = self.c1(x)
        x = self.c2(x)
        return x

class Pseudo3DUNet2p5D(nn.Module):
    """
    Input:  (B, 2*k, H, W) — k neighboring slices, real/imag stacked as channels
    Output: (B, 2,   H, W) — residual for center slice (real, imag)
    """
    def __init__(self, in_ch=10, base=32, out_ch=2, residual_scale=0.1):
        super().__init__()
        self.enc1 = UNetBlock(in_ch,   base)
        self.enc2 = UNetBlock(base,    base*2)
        self.enc3 = UNetBlock(base*2,  base*4)

        self.down1 = nn.Conv2d(base,   base,   3, stride=2, padding=1)
        self.down2 = nn.Conv2d(base*2, base*2, 3, stride=2, padding=1)

        self.dec2 = UNetBlock(base*4 + base*2, base*2)
        self.dec1 = UNetBlock(base*2 + base,   base)

        self.out  = nn.Conv2d(base, out_ch, 1)
        self.res_scale = residual_scale

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)                  # (B, B, H, W)
        x  = self.down1(e1)                # (B, B, H/2, W/2)
        e2 = self.enc2(x)                  # (B, 2B, H/2, W/2)
        x  = self.down2(e2)                # (B, 2B, H/4, W/4)
        e3 = self.enc3(x)                  # (B, 4B, H/4, W/4)

        # Decoder
        x  = F.interpolate(e3, scale_factor=2, mode='bilinear', align_corners=False)
        x  = torch.cat([x, e2], dim=1)
        x  = self.dec2(x)

        x  = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x  = torch.cat([x, e1], dim=1)
        x  = self.dec1(x)

        out = self.out(x)
        return out * self.res_scale

In [6]:
# @torch.no_grad()
def _neighbor_indices_X(X: int, K: int, x0: int):
    """Clamp-replicate neighbor indices along X for center x0."""
    r = K // 2
    idxs = [min(max(x0 + d, 0), X - 1) for d in range(-r, r + 1)]
    return idxs  # length K

def _make_25d_stack(x_img: torch.Tensor, K: int) -> torch.Tensor:
    """
    Build 2.5D stack over (Y,Z) with batch=X:
      x_img: (X,Y,Z) complex
      return: (X, 2*K, Y, Z) float  [real/imag per neighbor slice]
    """
    assert x_img.ndim == 3 and torch.is_complex(x_img)
    X, Y, Z = x_img.shape
    r = K // 2

    # Collect per-center stacks
    stacks = []
    for x0 in range(X):
        idxs = _neighbor_indices_X(X, K, x0)      # length K
        nb   = x_img[idxs, ...]                   # (K, Y, Z) complex
        ch   = torch.stack([nb.real, nb.imag], dim=1)  # (K, 2, Y, Z)
        ch   = ch.reshape(K * 2, Y, Z)                 # (2K, Y, Z)
        stacks.append(ch)
    inp = torch.stack(stacks, dim=0)              # (X, 2K, Y, Z)
    return inp

def apply_prior_25d_cued(prior_25d, x_img: torch.Tensor, cues_yz: torch.Tensor, K: int) -> torch.Tensor:
    """
    Apply temporal-cued 2.5D prior on (Y,Z) planes, batched over X.
      prior_25d: Pseudo3DUNet2p5D(in_ch=2*K + 2)
      x_img: (X,Y,Z) complex
      cues_yz: (2, Y, Z) float  [sin, cos] for this phase
      K: odd number of through-plane neighbors (e.g., 3, 5)
    Returns: residual r (X,Y,Z) complex
    """
    X, Y, Z = x_img.shape
    stack_25d = _make_25d_stack(x_img, K)                  # (X, 2K, Y, Z), float
    cues = cues_yz.unsqueeze(0).expand(X, -1, -1, -1)      # (X, 2,  Y, Z)

    inp = torch.cat([stack_25d, cues], dim=1)              # (X, 2K+2, Y, Z)
    out = prior_25d(inp)                                   # (X, 2,    Y, Z) float
    r   = torch.complex(out[:, 0], out[:, 1])              # (X, Y, Z) complex
    return r

In [7]:
# # --- Fix cues shape: (P, 2*Y, Z) -> (P, 2, Y, Z) ---
# P, Cx, Z = cues_2d.shape            # currently (P, 2*Y, Z)
# Y = Cx // 2
# assert Cx == 2 * Y, "Expected second dim to be 2*Y"

# sin_c = cues_2d[:, :Y, :].contiguous()
# cos_c = cues_2d[:, Y:, :].contiguous()
# cues_2d_fixed = torch.stack([sin_c, cos_c], dim=1)   # (P, 2, Y, Z)

# # device/dtype alignment
# cues_2d_fixed = cues_2d_fixed.to(ksp_t.device).float()

# # --- Re-run the prior sanity check (same K and tmp prior) ---
# p = 7
# K = 3
# prior_tmp = Pseudo3DUNet2p5D(in_ch=2*K + 2, base=16, out_ch=2, residual_scale=0.1).eval().to(ksp_t.device)

# with torch.no_grad():
#     x_zf  = AH(ksp_t[p], bmasks['theta'][p], sens_t)            # (X,Y,Z) complex
#     cues_p = cues_2d_fixed[p]                        # (2,Y,Z) float
#     r = apply_prior_25d_cued(prior_tmp, x_zf, cues_p, K)  # (X,Y,Z) complex
#     print(f"[Sanity] x_zf={tuple(x_zf.shape)}  r={tuple(r.shape)}  dtype={r.dtype}")

In [8]:
# ---- minimal complex-safe CG for (A^H Ω A + λI) x = rhs ----
# @torch.no_grad()
def cg_solve(rhs, th_mask, sens, lam, x0=None, iters=8, tol=1e-6):
    """
    rhs: (X,Y,Z) complex
    th_mask: (1,1,Y,Z) bool/0-1
    sens: (C,X,Y,Z) complex
    lam: scalar float
    """
    def M(x):
        Ax  = A(x, th_mask, sens)
        AHx = AH(Ax, th_mask, sens)
        return AHx + lam * x

    x = rhs.clone() if x0 is None else x0.clone()
    r = rhs - M(x)
    p = r.clone()
    rs_old = (r.conj() * r).real.sum()

    for _ in range(iters):
        Ap = M(p)
        denom = (p.conj() * Ap).real.sum().clamp_min(1e-20)
        alpha = (rs_old / denom)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r.conj() * r).real.sum()
        if torch.sqrt(rs_new) < tol * torch.sqrt((rhs.conj()*rhs).real.sum().clamp_min(1e-20)):
            break
        beta = rs_new / rs_old
        p = r + beta * p
        rs_old = rs_new
    return x

# ---- single-phase unrolled forward with cues ----
# @torch.no_grad()
# def unroll_one_phase_with_cues(prior_25d, p, K=3, lam=3e-2, unroll=5, iters=8, tol=1e-6, device="cuda"):
#     """
#     Returns: x (X,Y,Z) complex, and metrics (Ω_MSE, Ωc_E)
#     """
#     # data for this phase
#     y_full = ksp_t[p]                     # (C,X,Y,Z) complex
#     m_th   = bmasks['theta'][p]                      # (1,1,Y,Z)
#     y      = y_full * m_th                # acquired-only
#     Ah_y   = AH(y, m_th, sens_t)          # (X,Y,Z)

#     # ZF init
#     x = Ah_y.clone()

#     # cues for this phase: (2,Y,Z) float
#     cues_p = cues_2d_fixed[p]             # from your fixed cues tensor (P,2,Y,Z)

#     # unroll
#     for _ in range(unroll):
#         r   = apply_prior_25d_cued(prior_25d, x, cues_p, K)   # (X,Y,Z) complex residual
#         x_t = x + r                                           # proximal target
#         rhs = Ah_y + lam * x_t
#         x   = cg_solve(rhs, m_th, sens_t, lam, x0=x, iters=iters, tol=tol)

#     # k-space prediction and metrics
#     y_hat = A(x, m_th, sens_t)      # predicted k-space
#     omse, oce = kspace_metrics(y_hat, y_full, m_th)

#     return x, omse, oce

# # ---- quick demo on one phase (e.g., p=7) ----
# p_demo = 7
# lam_demo = 3e-2
# x_rec, omse, oce = unroll_one_phase_with_cues(
#     prior_25d=prior_tmp, p=p_demo, K=3, lam=lam_demo, unroll=3, iters=8, tol=1e-6,
#     device=ksp_t.device.type
# )
# print(f"[phase {p_demo+1:02d}] Ω_MSE={omse:.3e}  Ωc_E={oce:.3e}  | x={tuple(x_rec.shape)}")

In [9]:
def kspace_l1_l2_loss(y_pred, y_true, alpha=0.5, eps=1e-6):
    diff = y_pred - y_true
    mag  = torch.abs(diff)
    l1   = mag.mean()
    l2   = (mag**2).mean()
    return alpha*l1 + (1-alpha)*l2

def solve_cg(rhs, mask, sens, lam, iters, tol, x0=None):
    out = cg_solve(rhs, mask, sens, lam, iters=iters, tol=tol, x0=x0)
    return out[0] if isinstance(out, tuple) else out

In [10]:
# After loading data and defining (P, C, X, Y, Z) and device
# cues_2d should be (P, 2, Y, Z). If it's (P, 2*Y, Z), reshape it.

P, C, X, Y, Z = ksp_t.shape

def normalize_cues_shape(cues_2d, Y, Z, device):
    # ensure torch float on the right device
    if isinstance(cues_2d, np.ndarray):
        cues_2d = torch.from_numpy(cues_2d)
    cues_2d = cues_2d.to(device=device, dtype=torch.float32)

    # fix shape: (P, 2*Y, Z) -> (P, 2, Y, Z)
    if cues_2d.ndim == 3 and cues_2d.shape[1] == 2*Y and cues_2d.shape[2] == Z:
        P = cues_2d.shape[0]
        cues_2d = cues_2d.view(P, 2, Y, Z).contiguous()

    # sanity check
    assert cues_2d.ndim == 4 and cues_2d.shape[1:] == (2, Y, Z), \
        f"cues_2d must be (P,2,Y,Z), got {tuple(cues_2d.shape)}"
    return cues_2d

# use it
cues_2d = normalize_cues_shape(cues_2d, Y, Z, ksp_t.device)
print("cues_2d fixed shape:", tuple(cues_2d.shape))  # -> (P, 2, Y, Z)

cues_2d fixed shape: (20, 2, 128, 72)


In [None]:
def train_one_epoch_cued(
    prior_25d, opt,
    ksp, sens, bmasks, cues_2d,
    A, AH, solve_cg, apply_prior_25d_cued, kspace_l1_l2_loss,
    lam, K,
    unroll=5, cg_iters=8, cg_tol=1e-3, alpha=0.5, device="cuda"
):
    prior_25d.train()
    P = ksp.shape[0]
    order = torch.randperm(P, device=ksp.device)

    running_loss = 0.0
    steps = 0
    gn_sum = 0.0

    for p in order.tolist():
        m_th  = bmasks["theta"][p]          # (1,1,Y,Z) bool/float
        m_L   = bmasks["lam"][p]            # Λ mask for loss
        cues  = cues_2d[p]                  # (2,Y,Z) float
        y_th  = ksp[p] * m_th

        # --- CG-SENSE init (Θ) ---
        Ah_y  = AH(y_th, m_th, sens)                 # (X,Y,Z)
        x     = solve_cg(Ah_y, m_th, sens, lam, cg_iters, cg_tol, x0=None)
        x_dc = x.clone()

        # --- Unrolled proximal gradient with cues ---
        for _ in range(unroll):
            r    = apply_prior_25d_cued(prior_25d, x_dc, cues, K)   # (X,Y,Z)
            x_t  = x + 0.05 * r
            rhs  = Ah_y + lam * x_t
            x_dc    = solve_cg(rhs, m_th, sens, lam, cg_iters, cg_tol, x0=x_dc)

        # --- Λ loss in k-space ---
        y_L    = ksp[p] * m_L
        y_hatL = A(x, m_L, sens)
        loss   = kspace_l1_l2_loss(y_hatL, y_L, alpha)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        running_loss += float(loss.item()); steps += 1

        # grad norm (for logging)
        gn = 0.0
        for pmt in prior_25d.parameters():
            if pmt.grad is not None:
                gn += float(pmt.grad.detach().pow(2).sum().item())
        gn_sum += gn**0.5

    return {"train_L": running_loss / max(steps, 1), "grad": gn_sum/max(1,steps)}


In [None]:
@torch.no_grad()
def validate_all_phases_cued(
    prior_25d,
    ksp, sens, bmasks, cues_2d,
    A, AH, solve_cg, apply_prior_25d_cued, kspace_l1_l2_loss, kspace_metrics,
    lam, K,
    unroll=5, cg_iters=25, cg_tol=1e-6, alpha=0.5
):
    prior_25d.eval()
    P = ksp.shape[0]
    losses, omse_list, oce_list = [], [], []

    for p in range(P):
        m_val = bmasks["val"][p]            # (1,1,Y,Z)
        cues  = cues_2d[p]                  # (2,Y,Z)

        # CG-SENSE anchor with validation mask
        y_val = ksp[p] * m_val
        Ah_y  = AH(y_val, m_val, sens)
        x     = solve_cg(Ah_y, m_val, sens, lam, cg_iters, cg_tol, x0=None)
        x_dc = x.clone()

        # Unroll with validation mask
        for _ in range(unroll):
            r    = apply_prior_25d_cued(prior_25d, x_dc, cues, K)   # (X,Y,Z)
            # x_t  = x + r
            x_t  = x + 0.05 * r
            rhs  = Ah_y + lam * x_t
            x_dc    = solve_cg(rhs, m_val, sens, lam, cg_iters, cg_tol, x0=x_dc)

        # Λ loss on validation mask
        y_hatV = A(x, m_val, sens)
        losses.append(float(kspace_l1_l2_loss(y_hatV, y_val, alpha).item()))

        # k-space metrics on full prediction & val mask
        y_hat_full = A(x, torch.ones_like(m_val), sens)  # (C,X,Y,Z)
        y_true     = ksp[p]                              # (C,X,Y,Z)
        omse, oce  = kspace_metrics(y_hat_full, y_true, m_val)  # uses your helper
        omse_list.append(omse)
        oce_list.append(oce)

    val_loss = float(np.mean(losses)) if losses else float('nan')
    return val_loss, float(np.mean(omse_list)), float(np.mean(oce_list))

In [13]:
def make_optimizer_and_warmcos_scheduler(
    params,
    total_epochs: int,
    warmup_epochs: int = 3,
    lr_start: float = 5e-4,
    lr_min: float = 5e-6,
    warmup_start_factor: float = 0.2,
):
    """
    Linear warmup for `warmup_epochs`, then cosine anneal to `lr_min`.
    Returns: (optimizer, scheduler)
    """
    opt = torch.optim.Adam(params, lr=lr_start)

    sched_warm = torch.optim.lr_scheduler.LinearLR(
        opt, start_factor=warmup_start_factor, total_iters=warmup_epochs
    )
    sched_cos = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=max(1, total_epochs - warmup_epochs), eta_min=lr_min
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        opt, schedulers=[sched_warm, sched_cos], milestones=[warmup_epochs]
    )
    return opt, scheduler

In [14]:
def make_lambda_schedule(
    T: int,
    lam_max: float = 3e-2,
    lam_min: float = 0.0,
    warmup_epochs: int = 0,
):
    """
    Returns a function lam_sched(epoch) that gives λ for that epoch.
    Linear warmup to lam_max for `warmup_epochs`, then cosine decay to lam_min.
    """
    T_cos = max(1, T - warmup_epochs)

    def lam_sched(epoch: int) -> float:
        if epoch < warmup_epochs:
            # linear ramp 0 → lam_max
            return lam_max * (epoch + 1) / max(1, warmup_epochs)
        # cosine phase
        t = epoch - warmup_epochs
        cos_term = 0.5 * (1 + math.cos(math.pi * t / T_cos))
        return lam_min + (lam_max - lam_min) * cos_term

    return lam_sched

In [15]:
# --- hyperparams ---
epochs       = 60
K            = 3                 # neighbors on each side → in_ch = 2*K + 2
unroll_tr    = 5
unroll_val   = 5
cg_tr_iters  = 8
cg_val_iters = 25
cg_tol_tr    = 1e-6
cg_tol_val   = 1e-6
alpha        = 0.5               # L1/L2 mix in k-space loss
lam          = 3e-2              # data-consistency weight (fixed for now)

# --- model & optim ---
prior = Pseudo3DUNet2p5D(in_ch=2*K + 2, base=32, out_ch=2, residual_scale=0.1).to(ksp_t.device)

opt, scheduler = make_optimizer_and_warmcos_scheduler(
    prior.parameters(), total_epochs=epochs, warmup_epochs=3, lr_start=5e-4, lr_min=5e-6
)

lam_sched = make_lambda_schedule(T=epochs, lam_max=lam, lam_min=0.0, warmup_epochs=0)

best_val = float('inf')

def save_ckpt(path, epoch, val_loss, model, opt):
    torch.save({
        "epoch": epoch,
        "val_loss": float(val_loss),
        "model_state": model.state_dict(),
        "opt_state": opt.state_dict(),
    }, path)

for ep in range(epochs):
    t0 = time.time()
    lam = lam_sched(ep)

    # ---- one epoch of training ----
    tr = train_one_epoch_cued(
        prior, opt,
        ksp_t, sens_t, bmasks, cues_2d,
        A, AH, solve_cg, apply_prior_25d_cued, kspace_l1_l2_loss,
        lam=lam, K=K,
        unroll=unroll_tr, cg_iters=cg_tr_iters, cg_tol=cg_tol_tr,
        alpha=alpha, device=ksp_t.device.type
    )

    # ---- validation ----
    vL, v_omse, v_oce = validate_all_phases_cued(
        prior,
        ksp_t, sens_t, bmasks, cues_2d,
        A, AH, solve_cg, apply_prior_25d_cued, kspace_l1_l2_loss, kspace_metrics,
        lam=lam, K=K,
        unroll=unroll_val, cg_iters=cg_val_iters, cg_tol=cg_tol_val,
        alpha=alpha
    )

    scheduler.step()

    # ---- ckpts ----
    save_ckpt("latest_cued_prior.pt", ep+1, vL, prior, opt)
    if vL < best_val - 1e-9:
        best_val = vL
        save_ckpt("best_cued_prior.pt", ep+1, vL, prior, opt)

    dt = time.time() - t0
    print(f"[Epoch {ep:3d}] train_Λ={tr['train_L']:.3e}  ‖∇‖={tr['grad']:.3e}  "
          f"val={vL:.3e}  (Ω_MSE={v_omse:.3e}, Ωc_E={v_oce:.3e})  "
          f"λ={lam:.2e} LR={opt.param_groups[0]['lr']:.1e}  time={dt:.1f}s")

[Epoch   0] train_Λ=4.639e-05  ‖∇‖=3.747e-03  val=3.564e-06  (Ω_MSE=4.556e-06, Ωc_E=1.151e-04)  λ=3.00e-02 LR=2.3e-04  time=105.5s
[Epoch   1] train_Λ=1.069e-05  ‖∇‖=9.434e-04  val=1.791e-06  (Ω_MSE=1.147e-06, Ωc_E=2.976e-05)  λ=3.00e-02 LR=3.7e-04  time=105.0s




[Epoch   2] train_Λ=6.385e-06  ‖∇‖=7.658e-04  val=9.131e-07  (Ω_MSE=2.642e-07, Ωc_E=6.863e-06)  λ=2.99e-02 LR=5.0e-04  time=104.9s
[Epoch   3] train_Λ=4.064e-06  ‖∇‖=6.272e-04  val=6.969e-07  (Ω_MSE=1.308e-07, Ωc_E=3.407e-06)  λ=2.98e-02 LR=5.0e-04  time=104.9s
[Epoch   4] train_Λ=2.883e-06  ‖∇‖=4.957e-04  val=4.645e-07  (Ω_MSE=4.225e-08, Ωc_E=1.288e-06)  λ=2.97e-02 LR=5.0e-04  time=105.1s
[Epoch   5] train_Λ=2.125e-06  ‖∇‖=3.960e-04  val=3.620e-07  (Ω_MSE=1.576e-08, Ωc_E=6.132e-07)  λ=2.95e-02 LR=5.0e-04  time=105.0s
[Epoch   6] train_Λ=1.801e-06  ‖∇‖=3.754e-04  val=3.896e-07  (Ω_MSE=2.408e-08, Ωc_E=6.921e-07)  λ=2.93e-02 LR=4.9e-04  time=104.9s
[Epoch   7] train_Λ=1.826e-06  ‖∇‖=4.187e-04  val=3.364e-07  (Ω_MSE=1.226e-08, Ωc_E=4.282e-07)  λ=2.90e-02 LR=4.9e-04  time=104.8s
[Epoch   8] train_Λ=1.783e-06  ‖∇‖=4.613e-04  val=5.159e-07  (Ω_MSE=6.456e-08, Ωc_E=1.510e-06)  λ=2.87e-02 LR=4.9e-04  time=105.1s
[Epoch   9] train_Λ=1.700e-06  ‖∇‖=4.073e-04  val=4.150e-07  (Ω_MSE=3.387e-08, Ωc_E

In [34]:
@torch.no_grad()
def infer_and_save_all_phases_cued(
    prior, ksp_t, sens_t, bmasks, cues_2d,
    K=3, lam=1e-2, unroll=5, cg_iters=25, cg_tol=1e-6,
    out_dir="infer_cued"
):
    """
    Saves: out_dir/phase_XX_recon_img.npy  (complex image volume, shape: (X,Y,Z))
    """
    os.makedirs(out_dir, exist_ok=True)
    prior.eval()
    P = ksp_t.shape[0]

    for p in range(P):
        y      = ksp_t[p]               # (C,X,Y,Z) complex
        S      = sens_t                 # (C,X,Y,Z) complex
        m_yz   = bmasks["omega"][p]     # (1,1,Y,Z)
        cues_p = cues_2d[p]             # (2,Y,Z) float

        # CG-SENSE init
        rhs = AH(y, m_yz, S)
        x   = solve_cg(rhs, m_yz, S, lam, iters=cg_iters, tol=cg_tol, x0=rhs)

        # Unroll
        for _ in range(unroll):
            r   = apply_prior_25d_cued(prior, x, cues_p, K)   # (X,Y,Z) complex
            x_t = x + r
            b   = rhs + lam * x_t
            x   = solve_cg(b, m_yz, S, lam, iters=cg_iters, tol=cg_tol, x0=None)

        # Metrics over Ω for logging only
        y_hat = A(x, m_yz, S)
        omse, oce = kspace_metrics(y_hat, y, m_yz)
        print(f"[phase {p+1:02d}] Ω_MSE={omse:.3e}  Ωc_E={oce:.3e}")

        # Save complex image volume
        out_path = os.path.join(out_dir, f"phase_{p+1:02d}_recon_img.npy")
        np.save(out_path, x.detach().cpu().numpy())
        print(f"Saved → {out_path}")

    print(f"Done. Per-phase recons saved in: {out_dir}")

In [35]:
# ---- config (match training) ----
K        = 3
IN_CH    = 2*K + 2          # 2K (re/im neighbors) + 2 (sin,cos)
UNROLL   = 5
CG_ITERS = 25
CG_TOL   = 1e-6
LAM_DC   = 1e-6             # data-consistency λ
OUT_DIR  = "infer_cued"
DEVICE   = ksp_t.device

# ---- model: build + load best checkpoint ----
prior = Pseudo3DUNet2p5D(in_ch=IN_CH, base=32, out_ch=2, residual_scale=0.1).to(DEVICE).eval()
ckpt   = torch.load("best_cued_prior.pt", map_location=DEVICE)
state  = ckpt.get("model_state", ckpt.get("state_dict", ckpt))
prior.load_state_dict(state, strict=False)

os.makedirs(OUT_DIR, exist_ok=True)
P = ksp_t.shape[0]

infer_and_save_all_phases_cued(
    prior, ksp_t, sens_t, bmasks, cues_2d,
    K=K, lam=LAM_DC, unroll=UNROLL,
    cg_iters=CG_ITERS, cg_tol=CG_TOL,
    out_dir=OUT_DIR
)

[phase 01] Ω_MSE=3.097e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_01_recon_img.npy
[phase 02] Ω_MSE=3.426e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_02_recon_img.npy
[phase 03] Ω_MSE=3.141e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_03_recon_img.npy
[phase 04] Ω_MSE=3.981e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_04_recon_img.npy
[phase 05] Ω_MSE=2.554e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_05_recon_img.npy
[phase 06] Ω_MSE=3.652e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_06_recon_img.npy
[phase 07] Ω_MSE=2.868e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_07_recon_img.npy
[phase 08] Ω_MSE=3.153e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_08_recon_img.npy
[phase 09] Ω_MSE=2.785e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_09_recon_img.npy
[phase 10] Ω_MSE=3.753e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_10_recon_img.npy
[phase 11] Ω_MSE=2.750e-09  Ωc_E=0.000e+00
Saved → infer_cued/phase_11_recon_img.npy
[phase 12] Ω_MSE=3.598e-09  Ωc_E=0.000e+00
Saved → infer_cued/pha