# Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction

*M. Zimmermann, Z. Abbas, K. Dzieciol and N. J. Shah, "Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction," in IEEE Transactions on Medical Imaging, vol. 37, no. 2, pp. 626-637, Feb. 2018, doi: 10.1109/TMI.2017.2771504.*

# Load data

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import os
import sys

sys.path.insert(0, "/home/jovyan/jail/src")

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch

from jail.conopt.aux.fourier import nonuniform_fourier_transform_adjoint
from jail.conopt.tfs.fourier import nonuniform_transfer_function
from jail.recon.mirage import MIRAGE
from jail.recon.monitoring import ConvergenceMonitor
from jail.vis import MultiPlotter

# Global torch runtime settings for this notebook: autograd tracking is
# switched off for every following cell, and intra-op CPU parallelism is
# limited to 16 threads.
torch.set_grad_enabled(False)
torch.set_num_threads(16)

In [None]:
# Session to process (sub-directory name under the sessions root).
fname = "7T1541_pulseq_hypsecn_overdrive"
# Number of leading spokes read from the k-space data "d" and trajectory "k".
nUS = 8
# Index of the slice to reconstruct.
iS = 80
# Torch device for all tensors; "cuda:3" is the GPU alternative.
device = "cpu"

In [None]:
# Session root plus the preprocessed-input and image-output file names
# (relative to the session root); the output name encodes slice and spokes.
full_session_dir = f"/home/jovyan/qrage/sessions/{fname}"
h5_preproc_fname = "preproc/mz_me_mpnrage3d_grappa.h5"
h5_image_fname = f"images/mz_me_mpnrage3d_grappa_{iS}_{nUS}.h5"

In [None]:
# Resolve both relative file names against the session directory.
full_h5_preproc_fname, full_h5_h5_image_fname = (
    os.path.join(full_session_dir, rel_fname)
    for rel_fname in (h5_preproc_fname, h5_image_fname)
)

In [None]:
# Echo the resolved input and output paths, one per line.
print(full_h5_preproc_fname, full_h5_h5_image_fname, sep="\n")

In [None]:
# Read the coil sensitivities C, trajectory k and k-space data d for slice iS
# from the preprocessed HDF5 file and move them to `device` as torch tensors.
with h5py.File(
    full_h5_preproc_fname,
    "r",
    libver="latest",
    swmr=True,
) as h5_preproc_file:
    # In-plane image size comes from C; coil, spoke, readout, partition, slice
    # and TI/TE counts come from d (these are the values used downstream).
    C_shape = h5_preproc_file["C"].shape
    nX, nY = C_shape[1], C_shape[2]
    nC, spokes, baseresolution, nZ, nS, nTI, nTE = h5_preproc_file["d"].shape

    # C and d share the coil axis (combined over dim 0 later) and are both
    # sliced with iS on axis 4; fail early with a readable message rather
    # than an empty slice or an opaque reshape/broadcast error further down.
    if C_shape[0] != nC:
        raise ValueError("C has %s coils but d has %s." % (C_shape[0], nC))
    if not 0 <= iS < min(nS, C_shape[4]):
        raise IndexError(
            "Slice index iS=%s is out of range for %s slices." % (iS, nS)
        )
    if not 1 <= nUS <= spokes:
        raise ValueError(
            "nUS=%s must be between 1 and the %s available spokes."
            % (nUS, spokes)
        )

    # Image-domain shape used for the reconstruction.
    shape = (nX, nY, nZ, 1, nTI, nTE)

    print("Reconstructing slice %s of %s." % (iS, nS))

    # k-space samples per (TI, TE) contrast: nUS spokes x baseresolution.
    nK = baseresolution * nUS

    # Read data: only slice iS and the first nUS spokes.
    C = h5_preproc_file["C"][:, :, :, :, iS : iS + 1, :, :]
    k = h5_preproc_file["k"][:, :nUS, :, :, :, :, :]
    # Scales the raw data by 1e4 (division by 1e-4); presumably to put the
    # signal in a range suited to the fixed regularization weights used
    # later -- TODO confirm.
    d = h5_preproc_file["d"][:, :nUS, :, :, iS : iS + 1, :, :] / 1e-4

    # Flatten spokes x readout into one sample axis of length nK.
    # Leading 2 of k is presumably the (kx, ky) coordinate pair -- confirm.
    # The d reshape only matches in element count when nZ == 1.
    k = k.reshape((2, nK, 1, nTI, nTE))
    d = d.reshape((nC, 1, 1, nK, 1, nTI, nTE))

    C = torch.tensor(C, dtype=torch.complex64, device=device)
    k = torch.tensor(k, dtype=torch.float32, device=device)
    d = torch.tensor(d, dtype=torch.complex64, device=device)

In [None]:
# Magnitude of the coil sensitivity maps, laid out along axis 0 of C (the
# coil axis). The panel grid is (1, nC), so it follows the coil count read
# from the data file rather than assuming 8 coils.
MultiPlotter(
    np.abs(C[..., 0, 0].cpu().numpy()),
    (1, nC),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Phase (radians) of the coil sensitivity maps, laid out along axis 0 of C
# (the coil axis). The panel grid is (1, nC), so it follows the coil count
# read from the data file rather than assuming 8 coils.
MultiPlotter(
    np.angle(C[..., 0, 0].cpu().numpy()),
    (1, nC),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Adjoint non-uniform Fourier transform of the multi-coil data d along
# trajectory k onto the (nX, nY, nZ) grid; the last argument is presumably
# the output shape (nC, nX, nY, nZ, 1, nTI, nTE) -- confirm in jail.
AHd = nonuniform_fourier_transform_adjoint(
    k,
    d,
    (nX, nY, nZ),
    (nC, nX, nY, nZ, 1, nTI, nTE),
)
# Coil combination: sum over the coil axis (dim 0) of conj(C) * image.
AHd = (C.conj() * AHd).sum(dim=0)

In [None]:
# Transfer function H for trajectory k. The shape tuple ends with the
# number of samples nK. oversampling=(2, 2, 1) is presumably a per-axis
# (x, y, z) grid oversampling factor -- TODO confirm in jail.conopt.
H = nonuniform_transfer_function(
    k,
    (nX, nY, nZ, 1, nTI, nTE, nK),
    oversampling=(2, 2, 1),
)

In [None]:
# Magnitude of H on an (nTE, nTI) panel grid; axis=(3, 2) presumably maps
# axes 3 and 2 of the indexed array onto that grid. Colour range fixed to [0, 1].
MultiPlotter(
    np.abs(H[0, :, :, 0, 0, :, :].cpu().numpy()),
    (nTE, nTI),
    cmap="viridis",
    vmin=0,
    vmax=1,
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Magnitude of the coil-combined adjoint image (first z position) on an
# (nTE, nTI) panel grid, grayscale, colour range fixed to [0, 5].
MultiPlotter(
    np.abs(AHd[..., 0, 0, :, :].cpu().numpy()),
    (nTE, nTI),
    cmap="gray",
    vmin=0,
    vmax=5,
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Phase (radians) of the coil-combined adjoint image (first z position) on
# an (nTE, nTI) panel grid, colour range fixed to the full [-pi, pi] interval.
MultiPlotter(
    np.angle(AHd[..., 0, 0, :, :].cpu().numpy()),
    (nTE, nTI),
    cmap="gray",
    vmin=-np.pi,
    vmax=np.pi,
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Boolean support mask: True wherever at least one coil sensitivity is non-zero.
support = C.abs().amax(dim=0) > 0

# Convergence monitor. First argument: the mask broadcast to the full image
# shape as a float tensor (mask + zeros); presumably the reference the
# monitor compares iterates against -- TODO confirm in jail.recon.monitoring.
cm = ConvergenceMonitor(
    support + torch.zeros(shape, device=device),
    support,
    logfile=None,
)

# MIRAGE solver for the image shape (nX, nY, nZ, 1, nTI, nTE) stored in
# `shape`, using the coil maps, the adjoint image AHd and the transfer
# function H. An alternative configuration passes lambda_hankel=None and
# lambda_casorati=None (presumably disabling those regularizers -- confirm).
solver = MIRAGE(
    C[..., 0, 0],
    AHd,
    H,
    shape,
    lambda_wavelet=1e-3,
    lambda_hankel=1e-1,
    lambda_casorati=1e-2,
    weight_wavelet=0.5,
    weight_hankel=0.5,
    weight_casorati=0.5,
    inner_iter=5,
    outer_iter=250,
    callback=cm.callback,
    device=device,
)

In [None]:
# Run the MIRAGE reconstruction with the settings passed to the MIRAGE
# constructor in the previous cell (cm.callback was registered there as the
# callback). The solution is read afterwards from solver.solver.results["v"].
solver.solve()

In [None]:
# Reinterpret the solver's solution "v" as complex64 and reshape it to the
# image shape given to MIRAGE, `shape` = (nX, nY, nZ, 1, nTI, nTE).
# That 6-axis layout is what the `[..., 0, 0, :, :]` plot indexing expects.
z_image = solver.solver.results["v"].view(torch.complex64).reshape(shape)

In [None]:
# Magnitude of the reconstructed image (first z position) on an (nTE, nTI)
# panel grid, grayscale, colour range fixed to [0, 2].
MultiPlotter(
    np.abs(z_image[..., 0, 0, :, :].cpu().numpy()),
    (nTE, nTI),
    cmap="gray",
    vmin=0,
    vmax=2,
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Save pyplot's *current* figure as a transparent 1200-dpi PNG. When the
# cells are run top to bottom, that is normally the most recently created
# figure (the reconstruction plot).
plt.savefig(
    "/home/jovyan/reconstruction_idea.png",
    transparent=True,
    dpi=1200,
)