# Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction

*M. Zimmermann, Z. Abbas, K. Dzieciol and N. J. Shah, "Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction," in IEEE Transactions on Medical Imaging, vol. 37, no. 2, pp. 626-637, Feb. 2018, doi: 10.1109/TMI.2017.2771504.*

# Load data

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import os
import sys

sys.path.insert(0, "../src")

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch

from juart.conopt.functional.fourier import nonuniform_fourier_transform_adjoint
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.recon.sense import SENSE
from juart.recon.mirage import MIRAGE
from juart.recon.monitoring import ConvergenceMonitor
from juart.vis import MultiPlotter

# Global PyTorch configuration for this notebook: autograd is switched off
# (nothing below calls backward), and PyTorch's CPU thread count is set to 16.
torch.set_grad_enabled(False)
torch.set_num_threads(16)

In [None]:
# Session identifier; it is substituted into the session directory path.
fname = "7T1566"
# Index of the slice that is loaded and reconstructed.
iS = 80
# Number of leading entries kept along the `spokes` axis of the data (`[:, :nUS]`).
nUS = 64

In [None]:
# Directory of the selected session.
full_session_dir = f"/home/projects/qrage/sessions/{fname}"
# Session-relative path of the preprocessed HDF5 file that is read below.
h5_preproc_fname = "preproc/mz_me_mpnrage3d_grappa_pytorch.h5"
# Session-relative path of the image file; its name encodes slice index and nUS.
h5_image_fname = f"images/mz_me_mpnrage3d_grappa_{iS}_{nUS}.h5"

In [None]:
# Resolve both session-relative file names against the session directory.
full_h5_preproc_fname, full_h5_h5_image_fname = (
    os.path.join(full_session_dir, relative_name)
    for relative_name in (h5_preproc_fname, h5_image_fname)
)

In [None]:
# Show both resolved paths, one per line.
print(full_h5_preproc_fname, full_h5_h5_image_fname, sep="\n")

In [None]:
# Load everything needed for slice `iS` from the preprocessed HDF5 file and
# convert it to torch tensors. Only a single (TI, TE) index pair is read.
with h5py.File(
    full_h5_preproc_fname,
    "r",
    libver="latest",
    swmr=True,
) as h5_preproc_file:
    # The in-plane grid size comes from "C"; all other dimensions come from "d".
    nX, nY = h5_preproc_file["C"].shape[1:3]
    nC, spokes, baseresolution, nZ, nS, nTI, nTE = h5_preproc_file["d"].shape

    # Fail early with a clear message instead of an opaque reshape error below.
    if not 0 <= iS < nS:
        raise IndexError("Slice index iS=%s is out of range (nS=%s)." % (iS, nS))
    if nUS > spokes:
        raise ValueError(
            "Requested nUS=%s spokes, but only %s are available." % (nUS, spokes)
        )

    print("Reconstructing slice %s of %s." % (iS, nS))

    # Number of samples once the (spoke, readout) axes are flattened.
    nK = baseresolution * nUS

    # Restrict the read to one TI and one TE index. Slices (rather than ints)
    # keep those axes, so nTI and nTE are overridden to 1 to match.
    iTI = slice(1, 2)
    iTE = slice(0, 1)
    nTI = 1
    nTE = 1

    shape = (nX, nY, nZ, 1, nTI, nTE)

    # Read data
    C = h5_preproc_file["C"][:, :, :, :, iS : iS + 1, :, :]
    k = h5_preproc_file["k"][:, :nUS, :, :, :, iTI, iTE]
    # NOTE(review): dividing by 1e-4 rescales the data by a factor of 1e4; the
    # reason for this constant is not visible here — confirm.
    d = h5_preproc_file["d"][:, :nUS, :, :, iS : iS + 1, iTI, iTE] / 1e-4

    # nC = 1

    # Flatten the first nUS spokes and their readout points into one axis of length nK.
    k = k.reshape((2, nK, 1, nTI, nTE))
    d = d.reshape((nC, nK, 1, nTI, nTE))

    C = torch.tensor(C, dtype=torch.complex64)
    k = torch.tensor(k, dtype=torch.float32)
    d = torch.tensor(d, dtype=torch.complex64)

In [None]:
# Magnitude of `C` at index 0 of its last two axes.
# NOTE(review): the (1, 8) grid presumably assumes 8 entries along axis 0
# (one panel per channel) — confirm against MultiPlotter.
channel_maps = C[..., 0, 0].numpy()
MultiPlotter(
    np.abs(channel_maps),
    (1, 8),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%", cbar_pad=0.1,
)

In [None]:
# Phase of `C` at index 0 of its last two axes, laid out like the magnitude plot.
# NOTE(review): the (1, 8) grid presumably assumes 8 entries along axis 0 — confirm.
channel_maps = C[..., 0, 0].numpy()
MultiPlotter(
    np.angle(channel_maps),
    (1, 8),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%", cbar_pad=0.1,
)

In [None]:
# Apply the adjoint nonuniform Fourier transform to the data `d` sampled at `k`
# on an (nX, nY) grid, multiply by conj(C), and sum over axis 0.
# NOTE(review): axis 0 is presumably the coil/channel axis — confirm.
AHd = (
    C.conj()
    * nonuniform_fourier_transform_adjoint(k, d, (nX, nY), modeord=0, isign=1)
).sum(dim=0)

In [None]:
# Transfer function for the sampling locations `k`, computed with 2x
# oversampling along two axes; the shape tuple is `shape` with a leading 1.
H = nonuniform_transfer_function(
    k,
    (1, nX, nY, nZ, 1, nTI, nTE),
    oversampling=(2, 2),
)

In [None]:
# Magnitude of the transfer function `H` (first entry of axes 0, 3 and 4),
# shown on a fixed 0..1 colour range.
transfer_magnitude = np.abs(H[0, :, :, 0, 0, :, :].numpy())
MultiPlotter(
    transfer_magnitude,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="viridis", vmin=0, vmax=1,
    cbar_size="2.5%", cbar_pad=0.1,
)

In [None]:
# Magnitude of `AHd` (index 0 of the two axes before the last two),
# shown in grayscale on a fixed 0..25 range.
adjoint_magnitude = np.abs(AHd[..., 0, 0, :, :].numpy())
MultiPlotter(
    adjoint_magnitude,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray", vmin=0, vmax=25,
    cbar_size="2.5%", cbar_pad=0.1,
)

In [None]:
# Phase of `AHd` (same slicing as the magnitude plot), shown over [-pi, pi].
adjoint_phase = np.angle(AHd[..., 0, 0, :, :].numpy())
MultiPlotter(
    adjoint_phase,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray", vmin=-np.pi, vmax=np.pi,
    cbar_size="2.5%", cbar_pad=0.1,
)

In [None]:
# Set up the SENSE solver from C, the adjoint data AHd and the transfer
# function H, with maxiter=30.
# NOTE(review): presumably a conjugate-gradient solver, per the `cg_` naming — confirm.
cg_solver = SENSE(C, AHd, H, maxiter=30)

In [None]:
# Run the SENSE solver, reinterpret its output as complex64 and reshape it to
# `shape` = (nX, nY, nZ, 1, nTI, nTE).
cg_image = (
    cg_solver.solve()
    .view(torch.complex64)
    .reshape(shape)
)

In [None]:
# Magnitude of the SENSE result, shown in grayscale on a fixed 0..1 range.
# The grid is (nTE, nTI) to match axis=(3, 2) and the other plots that use this
# axis order; the two only differ once nTI != nTE.
# NOTE(review): confirm the grid/axis convention against MultiPlotter.
MultiPlotter(
    torch.abs(cg_image[..., 0, 0, :, :]).cpu().numpy(),
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=0,
    vmax=1,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Boolean mask that is True wherever at least one entry of C along axis 0 is non-zero.
support = C.abs().max(dim=0).values > 0

# The monitor is built from the mask broadcast to `shape`, but it is not attached
# to the solver: the `callback` argument below is commented out.
cm = ConvergenceMonitor(torch.zeros(shape) + support, support, logfile=None)

# Only the wavelet term gets a lambda; the hankel and casorati lambdas are None.
mirage_settings = dict(
    lambda_wavelet=1e-3,
    lambda_hankel=None,
    lambda_casorati=None,
    weight_wavelet=0.5,
    weight_hankel=0.5,
    weight_casorati=0.5,
    cg_maxiter=5,
    admm_maxiter=250,
    # callback=cm.callback,
)
solver = MIRAGE(C, AHd, H, **mirage_settings)

In [None]:
# Run the MIRAGE reconstruction with the settings passed to the constructor.
# The result is read back afterwards from `solver.solver.results["v"]`.
solver.solve()

In [None]:
# Fetch the "v" entry of the inner solver's results, reinterpret it as complex64
# and reshape it to `shape` = (nX, nY, nZ, 1, nTI, nTE).
z_image = (
    solver.solver.results["v"]
    .view(torch.complex64)
    .reshape(shape)
)

In [None]:
# Magnitude comparison of the adjoint data (AHd), the SENSE result (cg_image)
# and the MIRAGE result (z_image), stacked along a new leading axis.
# Each magnitude is computed in torch and moved to the CPU before conversion, so
# this also works when a solver returns tensors on another device. It no longer
# relies on NumPy implicitly converting a tuple of tensors.
# NOTE(review): the 0.1 factor presumably brings AHd into the shared 0..1
# display range — confirm.
comparison = np.stack(
    [
        torch.abs(image).cpu().numpy()
        for image in (
            0.1 * AHd[..., 0, 0, :, :],
            cg_image[..., 0, 0, :, :],
            z_image[..., 0, 0, :, :],
        )
    ]
)
MultiPlotter(
    comparison,
    (nTI * nTE, 3),
    axis=(3, 4, 0),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=0,
    vmax=1,
    cbar_size="2.5%",
    cbar_pad=0.1,
)