In [3]:
import os
import numpy as np
import napari
import time, os

npz_path = "scan20_splits.npz"
out_dir  = "cg_baseline"
os.makedirs(out_dir, exist_ok=True)

In [4]:
def fft3c(img):   # (..., X, Y, Z) -> k-space
    return np.fft.fftn(img, axes=(-3, -2, -1), norm="ortho")

def ifft3c(ksp):  # (..., X, Y, Z) -> image
    return np.fft.ifftn(ksp, axes=(-3, -2, -1), norm="ortho")

In [5]:
def A_omega(x, sens, mask_omega_yz):
    X = x.shape[0]
    m3d = np.broadcast_to(mask_omega_yz[None, :, :], (X, *mask_omega_yz.shape))  # (X,Y,Z)
    xC  = x[None, ...] * sens                      # (C,X,Y,Z)
    y   = fft3c(xC)                             # (C,X,Y,Z)
    return y * m3d[None, ...]                      # (C,X,Y,Z)

def AH_omega(y, sens, mask_omega_yz):
    X = y.shape[1]
    m3d = np.broadcast_to(mask_omega_yz[None, :, :], (X, *mask_omega_yz.shape))  # (X,Y,Z)
    y    = y * m3d[None, ...]                     # (C,X,Y,Z)
    imgC = ifft3c(y)                           # (C,X,Y,Z)
    x    = np.sum(np.conj(sens) * imgC, axis=0)   # (X,Y,Z)
    return x

In [6]:
def cg_solve(normal_op, b, lam, maxiter=40, tol=1e-6):
    """
    Solve (N + lam*I) x = b with complex CG in NumPy.
    normal_op(x) applies A^H Ω A to x (no λ inside).
    """
    x = np.zeros_like(b)
    r = b - (normal_op(x) + lam * x)
    p = r.copy()
    rs_old = np.vdot(r, r).real

    for _ in range(maxiter):
        Ap    = normal_op(p) + lam * p
        denom = np.vdot(p, Ap).real
        alpha = rs_old / max(denom, 1e-12)
        x     = x + alpha * p
        r     = r - alpha * Ap
        rs_new = np.vdot(r, r).real
        if np.sqrt(rs_new) < tol * np.sqrt(rs_old + 1e-20):
            break
        beta = rs_new / max(rs_old, 1e-20)
        p    = r + beta * p
        rs_old = rs_new
    return x

In [7]:
data = np.load(npz_path)
ksp   = data["ksp"]                  # (P,C,X,Y,Z) complex64
sens  = data["sens"]                 # (C,X,Y,Z)   complex64
omega = data["omega"].astype(bool)   # (P,Y,Z)     bool

P, C, X, Y, Z = ksp.shape
print(f"Loaded: P={P}, C={C}, X={X}, Y={Y}, Z={Z}")

# Regularization & CG params
lambda_cg = 1e-6
cg_iters  = 40
cg_tol    = 1e-6

Loaded: P=20, C=24, X=160, Y=128, Z=72


In [8]:
omse_list, oce_list = [], []
for p in range(P):
    y_meas = ksp[p]        # (C,X,Y,Z)
    mΩ_yz  = omega[p]      # (Y,Z) bool

    # b = A^H Ω y
    b = AH_omega(y_meas, sens, mΩ_yz)  # (X,Y,Z)

    # normal op: N(x) = A^H Ω A x
    def normal_op_np(x):
        return AH_omega(A_omega(x, sens, mΩ_yz), sens, mΩ_yz)

    # Solve (N + λI) x = b
    x_rec = cg_solve(normal_op_np, b, lam=lambda_cg, maxiter=cg_iters, tol=cg_tol)  # (X,Y,Z)

    # Metrics (Ω-MSE on acquired, Ω^c energy on unacquired)
    m3d     = np.broadcast_to(mΩ_yz[None, :, :], (X, Y, Z))         # (X,Y,Z)
    one_m3d = np.logical_not(m3d)

    y_hat   = A_omega(x_rec, sens, mΩ_yz)                         # (C,X,Y,Z)
    diff    = y_hat - (y_meas * m3d[None, ...])

    sq_diff = (diff.real**2 + diff.imag**2)
    sq_yhat = (y_hat.real**2 + y_hat.imag**2)

    num_acq = int(m3d.sum()) * C
    num_un  = int(one_m3d.sum()) * C

    omse = (sq_diff * m3d[None, ...]).sum().item() / max(num_acq, 1)
    oce  = (sq_yhat * one_m3d[None, ...]).sum().item() / max(num_un, 1)

    omse_list.append(omse); oce_list.append(oce)
    print(f"[phase {p+1:02d}] Ω_MSE={omse:.3e}  Ωc_E={oce:.3e}")

    # Save |x| with fftshift ONLY IN IMAGE DOMAIN (center all three axes)
    mag = np.abs(x_rec).astype(np.float32)                 # (X,Y,Z)
    mag = np.fft.fftshift(mag, axes=(0, 1, 2))
    np.save(os.path.join(out_dir, f"phase_{p+1:02d}_cg_sos.npy"), x_rec)

[phase 01] Ω_MSE=1.885e-09  Ωc_E=0.000e+00
[phase 02] Ω_MSE=1.766e-09  Ωc_E=0.000e+00
[phase 03] Ω_MSE=1.912e-09  Ωc_E=0.000e+00
[phase 04] Ω_MSE=2.075e-09  Ωc_E=0.000e+00
[phase 05] Ω_MSE=1.534e-09  Ωc_E=0.000e+00
[phase 06] Ω_MSE=2.081e-09  Ωc_E=0.000e+00
[phase 07] Ω_MSE=1.712e-09  Ωc_E=0.000e+00
[phase 08] Ω_MSE=1.791e-09  Ωc_E=0.000e+00
[phase 09] Ω_MSE=1.654e-09  Ωc_E=0.000e+00
[phase 10] Ω_MSE=1.918e-09  Ωc_E=0.000e+00
[phase 11] Ω_MSE=1.754e-09  Ωc_E=0.000e+00
[phase 12] Ω_MSE=1.882e-09  Ωc_E=0.000e+00
[phase 13] Ω_MSE=1.652e-09  Ωc_E=0.000e+00
[phase 14] Ω_MSE=1.870e-09  Ωc_E=0.000e+00
[phase 15] Ω_MSE=1.688e-09  Ωc_E=0.000e+00
[phase 16] Ω_MSE=1.894e-09  Ωc_E=0.000e+00
[phase 17] Ω_MSE=1.712e-09  Ωc_E=0.000e+00
[phase 18] Ω_MSE=1.892e-09  Ωc_E=0.000e+00
[phase 19] Ω_MSE=1.722e-09  Ωc_E=0.000e+00
[phase 20] Ω_MSE=1.700e-09  Ωc_E=0.000e+00


In [9]:
om = np.array(omse_list); oc = np.array(oce_list)
print(f"\nλ_cg = {lambda_cg:.1e}, iters={cg_iters}, tol={cg_tol:.1e}")
print(f"Ω_MSE mean±sd = {om.mean():.3e} ± {om.std():.3e}")
print(f"Ωc_E  mean±sd = {oc.mean():.3e} ± {oc.std():.3e}")
print(f"Saved per-phase recon to: {out_dir}/phase_XX_cg_sos.npy")


λ_cg = 1.0e-06, iters=40, tol=1.0e-06
Ω_MSE mean±sd = 1.805e-09 ± 1.375e-10
Ωc_E  mean±sd = 0.000e+00 ± 0.000e+00
Saved per-phase recon to: cg_baseline/phase_XX_cg_sos.npy


In [None]:
recon = np.load('cg_baseline_numpy_3d/phase_17_cg_sos.npy')
v1 = napari.Viewer()
v1.add_image(np.abs(recon), name='phase_07_cg_sos')
napari.run()