# Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction

*M. Zimmermann, Z. Abbas, K. Dzieciol and N. J. Shah, "Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction," in IEEE Transactions on Medical Imaging, vol. 37, no. 2, pp. 626-637, Feb. 2018, doi: 10.1109/TMI.2017.2771504.*

# Load data

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr
from s3fs import S3FileSystem

from juart.conopt.functional.fourier import nonuniform_fourier_transform_adjoint
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.dl.data.inference import DatasetInference
from juart.recon.monitoring import ConvergenceMonitor
from juart.recon.regsense import REGSENSE
from juart.recon.sense import SENSE
from juart.vis import MultiPlotter

# Global PyTorch settings for this notebook:
# cap the number of CPU threads PyTorch uses for intra-op parallelism ...
torch.set_num_threads(16)
# ... and switch autograd off; nothing below calls backward().
torch.set_grad_enabled(False)

In [None]:
# Run parameters consumed when the dataset is constructed below.
# NOTE(review): whether `slices` is a (start, stop) range or a list of slice
# indices depends on DatasetInference — confirm.
slices = (80, 81)
num_spokes = 8
session = "7T1029"
# nTI and nTE are not fixed here; they are unpacked from AHd.shape further
# down (a previously hard-coded value was nTI, nTE = 19, 9).

In [None]:
# Build the inference dataset from the S3 endpoint.
# The first argument is a path template; its "%s" placeholder is presumably
# filled with each session name by DatasetInference (TODO confirm).
dataset = DatasetInference(
    "qrage/sessions/%s/preproc.zarr/preproc.zarr",
    # Use the `session` chosen in the parameter cell instead of repeating the
    # literal, so changing the session in one place is enough.
    [session],
    slices,
    num_spokes,
    endpoint_url="https://s3.fz-juelich.de",
    backend="s3",
)

In [None]:
# Fetch the first item of the dataset; the following cells read it by string key.
data = dataset[0]

In [None]:
# Pull the reconstruction inputs out of the sample and insert singleton axes.
# NOTE(review): the meaning of each axis is not visible in this notebook —
# presumably the extra axes align the arrays for the solvers below; confirm
# against DatasetInference.
C = data["sensitivity_maps"][:, :, :, None]  # new axis at position 3
k = data["kspace_trajectory"][:, :, None, :, :]  # new axis at position 2
AHd = data["images_regridded"][None, :, :, :, None, :, :]  # new axes at positions 0 and 4

In [None]:
# Unpack every axis size of AHd except the leading one; this requires AHd to
# have exactly 7 axes. nS is the size of the singleton axis inserted at
# position 4 when AHd was built.
# NOTE(review): the names suggest spatial sizes (nX, nY, nZ) and inversion-time /
# echo-time counts (nTI, nTE) — confirm.
nX, nY, nZ, nS, nTI, nTE = AHd.shape[1:]

In [None]:
# Magnitude of the sensitivity maps C, converted to a NumPy array on the CPU.
# NOTE(review): the (1, 8) grid with axis=0 presumably tiles 8 entries along
# C's first axis in a single row — confirm against MultiPlotter.
coil_magnitude = C.abs().cpu().numpy()
MultiPlotter(
    coil_magnitude,
    (1, 8),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Phase (complex angle) of the sensitivity maps C, converted to a NumPy array
# on the CPU and shown with the same (1, 8) layout as the magnitude plot.
coil_phase = C.angle().cpu().numpy()
MultiPlotter(
    coil_phase,
    (1, 8),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Transfer function for the trajectory k, evaluated for AHd's shape with an
# oversampling of 2 passed for two axes.
# NOTE(review): presumably the k-space kernel of the normal operator — confirm
# against juart.conopt.tfs.fourier.
H = nonuniform_transfer_function(k, AHd.shape, oversampling=(2, 2))

In [None]:
# Magnitude of H with index 0 taken on axes 0, 3 and 4, leaving a 4-D array.
# NOTE(review): axis=(3, 2) together with the (nTE, nTI) grid presumably puts
# the last axis on rows and the one before it on columns — confirm.
transfer_magnitude = H[0, :, :, 0, 0, :, :].abs().cpu().numpy()
MultiPlotter(
    transfer_magnitude,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="viridis",
    vmin=0,
    vmax=1,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Magnitude of the regridded images AHd with index 0 taken on axes 0, 3 and 4,
# shown on an (nTE, nTI) grid with a fixed gray scale of [0, 2].
regridded_magnitude = AHd[0, :, :, 0, 0, :, :].abs().cpu().numpy()
MultiPlotter(
    regridded_magnitude,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=0,
    vmax=2,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Phase of the regridded images AHd for the same selection as the magnitude
# plot; the color range spans the full angle interval [-pi, pi].
regridded_phase = AHd[0, :, :, 0, 0, :, :].angle().cpu().numpy()
MultiPlotter(
    regridded_phase,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=-np.pi,
    vmax=np.pi,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Unregularized SENSE solver built from the sensitivity maps C, the regridded
# data AHd and the transfer function H; it is solved in the next cell.
cg_solver = SENSE(C, AHd, H)

In [None]:
# Solve the SENSE problem, then reinterpret the result as complex64 and give
# it the same shape as AHd.
# NOTE(review): presumably solve() returns real/imaginary parts interleaved in
# a real-valued tensor — confirm.
cg_solution = cg_solver.solve()
cg_image = cg_solution.view(torch.complex64).reshape(AHd.shape)

In [None]:
# Magnitude of the SENSE reconstruction with index 0 taken on axes 0, 3 and 4,
# leaving a 4-D array whose last two axes have sizes nTI and nTE.
# The grid is (nTE, nTI), matching axis=(3, 2) and every other plot in this
# notebook; the previous (nTI, nTE) swapped the row and column counts.
MultiPlotter(
    torch.abs(cg_image[0, :, :, 0, 0, :, :]).cpu().numpy(),
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=0,
    vmax=2,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Optional convergence monitoring (disabled). To enable it, build the monitor
# and pass callback=cm.callback to REGSENSE. Note that `shape` is not defined
# in this notebook (presumably AHd.shape — confirm):
#
#   support = torch.abs(C).max(dim=0).values > 0
#   cm = ConvergenceMonitor(support + torch.zeros(shape), support, logfile=None)

# Regularization strengths, weights and iteration limits for the regularized
# reconstruction, collected in one place.
regsense_options = dict(
    lambda_wavelet=1e-3,
    lambda_hankel=1e-1,
    lambda_casorati=1e-2,
    weight_wavelet=0.5,
    weight_hankel=0.5,
    weight_casorati=0.5,
    cg_maxiter=5,
    admm_maxiter=50,
)

# Same inputs as the SENSE solver, plus the options above.
solver = REGSENSE(C, AHd, H, **regsense_options)

In [None]:
# Run the MIRAGE reconstruction with the REGSENSE solver configured above;
# the next cell reads the result back from the solver object.
solver.solve()

In [None]:
# Read the "v" entry from the inner solver's results, reinterpret it as
# complex64 and give it the same shape as AHd.
# NOTE(review): presumably "v" is the final image estimate — confirm against
# REGSENSE.
v_result = solver.solver.results["v"]
z_image = v_result.view(torch.complex64).reshape(AHd.shape)

In [None]:
# Magnitude of the regularized reconstruction with index 0 taken on axes 0, 3
# and 4, shown on an (nTE, nTI) grid with the same gray scale of [0, 2] used
# for the regridded and SENSE images.
regularized_magnitude = z_image[0, :, :, 0, 0, :, :].abs().cpu().numpy()
MultiPlotter(
    regularized_magnitude,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    cmap="gray",
    vmin=0,
    vmax=2,
    cbar_size="2.5%",
    cbar_pad=0.1,
)