# Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction

*M. Zimmermann, Z. Abbas, K. Dzieciol and N. J. Shah, "Accelerated Parameter Mapping of Multiple-Echo Gradient-Echo Data Using Model-Based Iterative Reconstruction," in IEEE Transactions on Medical Imaging, vol. 37, no. 2, pp. 626-637, Feb. 2018, doi: 10.1109/TMI.2017.2771504.*

# Load data

In [None]:
# Reload edited project modules automatically before each cell runs, and use
# the interactive "widget" matplotlib backend for zoomable notebook figures.
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

# Put the in-repository sources first on the import path so that the `juart`
# imports below resolve from "../src" (presumably without installing juart).
# NOTE(review): the path is relative to the current working directory, so the
# notebook must be launched from its own folder — confirm.
sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr
from s3fs import S3FileSystem

from juart.conopt.functional.fourier import nonuniform_fourier_transform_adjoint
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.recon.mirage import MIRAGE
from juart.recon.monitoring import ConvergenceMonitor
from juart.vis import MultiPlotter

# Cap PyTorch's intra-op CPU parallelism at a fixed thread count.
_N_TORCH_THREADS = 16
torch.set_num_threads(_N_TORCH_THREADS)

# Disable autograd for the whole session: nothing in this notebook calls
# backward(), so recording gradient graphs would only cost memory.
torch.set_grad_enabled(False)

In [None]:
# Acquisition session to reconstruct (another available one: "7T1026").
session = "7T1029"

# Number of leading spokes taken from the data; the number of k-space samples
# used later is baseresolution * nUS.
nUS = 8

# 0-based index of the slice to reconstruct.
iS = 80

In [None]:
# Store-relative path templates; "%s" is later filled in with the session ID.
# Alternative preprocessed input:
#   "qrage/sessions/%s/preproc.zarr/mz_me_mpnrage3d_grappa.zarr"
zarr_preproc_fname = "/".join(
    ("qrage", "sessions", "%s", "preproc.zarr", "preproc.zarr")
)
zarr_image_fname = "/".join(("qrage", "recons", "%s", "test.zarr", "images.zarr"))

In [None]:
# Show the concrete store paths for the selected session. The variables hold
# "%s" templates, so printing them directly would not show which session is
# actually being read.
print(zarr_preproc_fname % session)
print(zarr_image_fname % session)

In [None]:
# Connect to the FZ Juelich S3 endpoint with authenticated (non-anonymous)
# access. How s3fs obtains the credentials is described at
# https://s3fs.readthedocs.io/en/latest/#credentials
# The filesystem is created in asynchronous mode and then wrapped as a zarr
# store (zarr's FsspecStore expects an async-capable fsspec filesystem).
fs = S3FileSystem(
    endpoint_url="https://s3.fz-juelich.de",
    anon=False,
    asynchronous=True,
)
store = zarr.storage.FsspecStore(fs)

In [None]:
# Open the preprocessed data of the selected session read-only.
zarr_preproc_file = zarr.open_group(
    store, path=zarr_preproc_fname % session, mode="r"
)

# Array dimensions. C (used below as coil maps in a conj(C)-weighted coil sum)
# is indexed as (nC, nX, nY, nZ, nS, ...); d is unpacked as
# (nC, spokes, baseresolution, nZ, nS, nTI, nTE).
nC, nX, nY, nZ, nS = zarr_preproc_file["C"].shape[:5]
nC_d, spokes, baseresolution, nZ_d, nS_d, nTI, nTE = zarr_preproc_file["d"].shape

# The coil maps and the data must agree on coils, z extent and slice count;
# otherwise the coil combination and the reshapes below are meaningless.
if (nC_d, nZ_d, nS_d) != (nC, nZ, nS):
    raise ValueError(
        f"Inconsistent shapes: C has (nC, nZ, nS) = {(nC, nZ, nS)}, "
        f"d has {(nC_d, nZ_d, nS_d)}."
    )
# Out-of-range values would silently produce empty/truncated slices and only
# fail later inside an unrelated reshape.
if not 0 <= iS < nS:
    raise ValueError(f"Slice index iS={iS} is out of range for {nS} slices.")
if not 1 <= nUS <= spokes:
    raise ValueError(
        f"nUS={nUS} must be between 1 and the number of spokes ({spokes})."
    )

# Only the first inversion time / echo time is reconstructed here.
nTI, nTE = 1, 1

shape = (nX, nY, nZ, 1, nTI, nTE)

print(f"Reconstructing slice {iS} of {nS}.")

# Number of k-space samples: nUS spokes x baseresolution points per spoke.
nK = baseresolution * nUS

# The raw data is divided by this factor.
# NOTE(review): presumably brings the data into a convenient numeric range for
# the regularization weights used below — confirm.
data_scale = 1e-4

# Read only the selected slice, the first nUS spokes and the first TI/TE.
C = zarr_preproc_file["C"][:, :, :, :, iS : iS + 1, :, :]
k = zarr_preproc_file["k"][:, :nUS, :, :, :, :nTI, :nTE]
d = zarr_preproc_file["d"][:, :nUS, :, :, iS : iS + 1, :nTI, :nTE] / data_scale

# Flatten spokes x readout into a single axis of nK samples. The coil axis is
# sized by nC, not a hard-coded count, so other coil numbers also work.
# NOTE(review): these reshapes only succeed when the remaining z (and, for k,
# slice) dimensions are singleton — confirm for 3D data.
k = k.reshape((2, nK, 1, nTI, nTE))
d = d.reshape((nC, 1, 1, nK, 1, nTI, nTE))

C = torch.tensor(C, dtype=torch.complex64)
k = torch.tensor(k, dtype=torch.float32)
d = torch.tensor(d, dtype=torch.complex64)

In [None]:
# Magnitude of the coil maps C for the first TI/TE, laid out as a 1 x nC grid
# along axis 0 (the coil axis). The grid follows nC instead of a hard-coded 8,
# so the plot works for any coil count.
MultiPlotter(
    torch.abs(C[..., 0, 0]).cpu().numpy(),
    (1, nC),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Magnitude of the coil maps C for the first TI/TE, one panel per coil (grid
# sized by nC rather than a hard-coded 8).
# NOTE(review): this cell is identical to the preceding magnitude plot; it was
# possibly meant to show something else (e.g. another TI/TE) — confirm.
MultiPlotter(
    torch.abs(C[..., 0, 0]).cpu().numpy(),
    (1, nC),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Phase of the coil maps C for the first TI/TE, one panel per coil (grid sized
# by nC rather than a hard-coded 8). The color range is pinned to [-pi, pi]
# so that panels are comparable, as in the AHd phase plot.
MultiPlotter(
    torch.angle(C[..., 0, 0]).cpu().numpy(),
    (1, nC),
    axis=0,
    fig=plt.figure(figsize=(10, 3)),
    vmin=-np.pi,
    vmax=np.pi,
    cbar_size="2.5%",
    cbar_pad=0.1,
)

In [None]:
# Adjoint nonuniform Fourier transform of the multi-coil data (one image per
# coil), followed by a coil combination weighted with the conjugate coil maps.
# Only the coil maps of the first TI/TE are used, kept as singleton trailing
# dims. This matches the coil maps handed to MIRAGE (C[..., 0, 0]) and
# guarantees that AHd has exactly `shape` even if C stores extra TI/TE planes.
AHd = nonuniform_fourier_transform_adjoint(
    k, d, (nX, nY, nZ), (nC, nX, nY, nZ, 1, nTI, nTE)
)
AHd = torch.sum(torch.conj(C[..., 0:1, 0:1]) * AHd, dim=0)

In [None]:
# Transfer function of the sampling pattern k, computed with
# oversampling=(2, 2, 1) (presumably 2x in x and y, none in z).
# NOTE(review): presumably the kernel of the normal operator A^H A used by
# MIRAGE — confirm against juart.conopt.tfs.fourier.
tf_shape = (nX, nY, nZ, 1, nTI, nTE, nK)
H = nonuniform_transfer_function(k, tf_shape, oversampling=(2, 2, 1))

In [None]:
# Magnitude of the transfer function on a fixed [0, 1] color scale, arranged
# as an nTE x nTI grid.
# NOTE(review): the indexing assumes H carries a leading singleton axis in
# front of dims mirroring `shape` — confirm.
H_mag = torch.abs(H[0, :, :, 0, 0, :, :]).cpu().numpy()
MultiPlotter(
    H_mag,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    vmin=0,
    vmax=1,
    cmap="viridis",
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Magnitude of the coil-combined adjoint image AHd (first z position), as an
# nTE x nTI grid on a fixed [0, 5] gray scale.
AHd_mag = torch.abs(AHd[..., 0, 0, :, :]).cpu().numpy()
MultiPlotter(
    AHd_mag,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    vmin=0,
    vmax=5,
    cmap="gray",
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Phase of the coil-combined adjoint image AHd (first z position), as an
# nTE x nTI grid with the color range pinned to [-pi, pi].
AHd_phase = torch.angle(AHd[..., 0, 0, :, :]).cpu().numpy()
MultiPlotter(
    AHd_phase,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    vmin=-np.pi,
    vmax=np.pi,
    cmap="gray",
    cbar_pad=0.1,
    cbar_size="2.5%",
)

In [None]:
# Support mask: voxels where at least one coil map is non-zero.
support = torch.abs(C).amax(dim=0) > 0

# The mask as a float tensor broadcast to the image shape, handed to the
# convergence monitor together with the boolean mask; no log file is written.
# NOTE(review): the meaning of ConvergenceMonitor's first argument is not
# visible here — confirm against juart.recon.monitoring.
support_image = torch.zeros(shape) + support
cm = ConvergenceMonitor(support_image, support, logfile=None)

solver = MIRAGE(
    C[..., 0, 0],  # coil maps of the first TI/TE
    AHd,
    H,
    # Only the wavelet term gets a lambda; lambda_*=None presumably disables
    # the Hankel and Casorati terms, so their weights would then have no
    # effect — confirm.
    lambda_wavelet=1e-3,
    weight_wavelet=0.5,
    lambda_hankel=None,
    weight_hankel=0.5,
    lambda_casorati=None,
    weight_casorati=0.5,
    # Iteration limits (presumably outer ADMM loop, inner CG solves).
    admm_maxiter=50,
    cg_maxiter=5,
    callback=cm.callback,
)

In [None]:
# Run the MIRAGE reconstruction. The solver was constructed with
# callback=cm.callback, so the ConvergenceMonitor is presumably invoked as
# the iterations progress.
solver.solve()

In [None]:
# Fetch the solver's "v" result and reinterpret it as a complex image of
# `shape`.
# NOTE(review): .view(torch.complex64) on real data needs a float32 tensor
# with an even, contiguous last dimension; presumably "v" stores interleaved
# real/imaginary parts — confirm against MIRAGE.
_v_raw = solver.solver.results["v"]
z_image = _v_raw.view(torch.complex64).reshape(shape)

In [None]:
# Magnitude of the reconstructed image (first z position), as an nTE x nTI
# grid on a fixed [0, 2] gray scale.
z_mag = torch.abs(z_image[..., 0, 0, :, :]).cpu().numpy()
MultiPlotter(
    z_mag,
    (nTE, nTI),
    axis=(3, 2),
    fig=plt.figure(figsize=(10, 6)),
    vmin=0,
    vmax=2,
    cmap="gray",
    cbar_pad=0.1,
    cbar_size="2.5%",
)