# T2 Mapping - Spiral TSE

T2 mapping using a spiral turbo spin echo (TSE) sequence. For each echo in the TSE echo train, an image is reconstructed, and the resulting series of images is used for T2 mapping.
A long repetition time (TR) ensures that the longitudinal magnetization has fully recovered before the next echo train is excited.
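Each echo image samples the transverse magnetization at a different echo time $TE_n$. Assuming ideal refocusing and full longitudinal recovery between excitations, the signal follows $S_n = M_0 \exp(-TE_n / T_2)$, which is what makes T2 recoverable from the echo images. The sketch below simulates this model for a single voxel with hypothetical echo times (not the ones used by the sequence in this notebook) and estimates T2 with a simple log-linear least-squares fit. The fit is only an illustration; the notebook itself uses dictionary matching further down.

```python
import numpy as np


def mono_exp_signal(m0: float, t2: float, te: np.ndarray) -> np.ndarray:
    """Mono-exponential echo-train signal S_n = M0 * exp(-TE_n / T2)."""
    return m0 * np.exp(-te / t2)


def fit_t2_loglinear(signal: np.ndarray, te: np.ndarray) -> tuple[float, float]:
    """Estimate (M0, T2) from a straight-line fit of log(S) against TE."""
    # np.polyfit returns the coefficients with the highest power first: [slope, intercept]
    slope, intercept = np.polyfit(te, np.log(signal), deg=1)
    return float(np.exp(intercept)), float(-1.0 / slope)


# Hypothetical echo train: 8 echoes with 20 ms echo spacing
echo_times = 0.02 * np.arange(1, 9)  # s
signal = mono_exp_signal(m0=1.0, t2=0.08, te=echo_times)
m0_est, t2_est = fit_t2_loglinear(signal, echo_times)
print(f'M0 = {m0_est:.3f}, T2 = {t2_est * 1000:.1f} ms')
```

With noise-free data the fit returns the input values exactly; with noisy magnitude data the log transform biases low-signal echoes, which is one reason to prefer dictionary matching or a nonlinear fit.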

### Imports

In [None]:
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import MRzeroCore as mr0
import numpy as np
import torch
from cmap import Colormap
from einops import rearrange
from mrpro.algorithms.reconstruction import DirectReconstruction
from mrpro.data import KData
from mrpro.data.traj_calculators import KTrajectoryIsmrmrd
from mrpro.operators import DictionaryMatchOp
from mrpro.operators.models import MonoExponentialDecay

from mrseq.scripts.t2_t1rho_tse_spiral import main as create_seq
from mrseq.utils import combine_ismrmrd_files
from mrseq.utils import sys_defaults

### Settings
We use a numerical phantom with a matrix size of 64 x 64 x 4. The repetition time is set to 20 s to ensure that even tissue with a long T1, such as CSF, is fully relaxed.

In [None]:
image_matrix_size = [64, 64, 4]  # (x, y, z)
repetition_time = 20  # s

tmp = tempfile.TemporaryDirectory()
fname_mrd = Path(tmp.name) / 't2_tse_spiral.mrd'
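To see why such a long TR is needed, the sketch below evaluates the recovered fraction of the equilibrium magnetization, $1 - \exp(-TR/T_1)$, for a few TR values. It assumes, as a simplification, that the longitudinal magnetization starts from zero after each echo train, and it uses an approximate literature value of about 4 s for the T1 of CSF; this value is an assumption and is not read from the phantom.

```python
import numpy as np


def longitudinal_recovery(tr: float, t1: float) -> float:
    """Fraction of M0 recovered after a time TR, starting from zero: 1 - exp(-TR / T1)."""
    return float(1.0 - np.exp(-tr / t1))


t1_csf = 4.0  # s, approximate CSF T1 (assumption)
for tr in (3.0, 5.0, 10.0, 20.0):
    print(f'TR = {tr:4.1f} s -> {100 * longitudinal_recovery(tr, t1_csf):5.1f} % recovered')
```

Under these assumptions, a TR of 20 s leaves less than 1 % of the CSF magnetization unrecovered, whereas a TR of 5 s would recover only about 70 %.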

### Create the digital phantom

We will use a 3D BrainWeb phantom which has to be downloaded first.

In [None]:
!wget https://github.com/MRsources/MRzero-Core/raw/main/documentation/playground_mr0/subject05.npz

In [None]:
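phantom = mr0.VoxelGridPhantom.brainweb('subject05.npz')
# Interpolate in-plane to the image matrix size and to 32 slices along z
phantom = phantom.interpolate(image_matrix_size[0], image_matrix_size[1], 32)
# Keep four of the 32 slices
phantom = phantom.slices([7, 11, 15, 19])
# Assume a homogeneous B1 field
phantom.B1[:] = 1.0
# Cap long T2 values (e.g. CSF) at 0.7 s, inside the 10-800 ms range of the dictionary used below
phantom.T2[phantom.T2 > 0.7] = 0.7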

### Create the spiral TSE sequence

To create the spiral TSE sequence, we use the previously imported [t2_t1rho_tse_spiral script](../src/mrseq/scripts/t2_t1rho_tse_spiral.py).


In [None]:
sequence, fname_seq = create_seq(
    system=sys_defaults,
    spin_lock_times=(0,),
    test_report=False,
    timing_check=False,
    fov_xy=float(phantom.size.numpy()[0]),
    fov_z=float(phantom.size.numpy()[2]),
    tr=repetition_time,
    n_slice_encoding=image_matrix_size[2],
    n_readout=image_matrix_size[0],
    n_spiral_arms=image_matrix_size[1],
)

### Simulate the sequence
Now, we pass the sequence and the phantom to the MRzero simulation and save the simulated signal as an (ISMR)MRD file.

In [None]:
mr0_sequence = mr0.Sequence.import_file(str(fname_seq.with_suffix('.seq')))
signal, ktraj_adc = mr0.util.simulate(mr0_sequence, phantom, accuracy=1e0)
mr0.sig_to_mrd(fname_mrd, signal, sequence)
combine_ismrmrd_files(fname_mrd, Path(f'{fname_seq}_header.h5'))

### Reconstruct the images at different echo times

We use [MRpro](https://github.com/PTB-MR/MRpro) for the image reconstruction.

In [None]:
kdata = KData.from_file(str(fname_mrd).replace('.mrd', '_with_traj.mrd'), trajectory=KTrajectoryIsmrmrd())
recon = DirectReconstruction(kdata, csm=None)
idata = recon(kdata)

Let's visualize the first six spiral arms of the trajectory for the first three echoes.

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for echo in range(3):
    for spiral_arm in range(6):
        # k-space coordinates of one spiral arm acquired at this echo
        kx = kdata.traj.kx[echo, 0, 0, spiral_arm, :].numpy()
        ky = kdata.traj.ky[echo, 0, 0, spiral_arm, :].numpy()
        ax[echo].plot(kx, ky, 'o-', markersize=2)
    ax[echo].set_title(f'Echo {echo + 1}')
    ax[echo].set_aspect('equal', adjustable='box')

We can see that the spiral arms are rotated by the golden angle and that each echo uses a different set of spiral arms.
This makes undersampling artifacts incoherent along the echo dimension.
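The following sketch illustrates golden-angle rotation of spiral arms in plain NumPy. Several details are assumptions, not taken from the sequence script: the full-circle golden angle $\pi(3 - \sqrt{5}) \approx 137.5°$, a simple Archimedean spiral as the base arm, and a hypothetical ordering in which one golden-angle counter runs over the echoes of each excitation (`index = arm * n_echoes + echo`). Because the golden angle is an irrational fraction of a full turn, no two arms, whether from the same echo or from different echoes, share the same rotation.

```python
import numpy as np

# Full-circle golden angle 2*pi/phi**2 = pi*(3 - sqrt(5)) ≈ 137.5° (assumed increment)
GOLDEN_ANGLE = np.pi * (3.0 - np.sqrt(5.0))


def spiral_arm(n_samples: int = 64, n_turns: float = 4.0, k_max: float = 0.5) -> np.ndarray:
    """Hypothetical Archimedean spiral arm as complex k-space positions kx + 1j * ky."""
    t = np.linspace(0.0, 1.0, n_samples)
    return k_max * t * np.exp(2j * np.pi * n_turns * t)


def golden_angle_arms(n_echoes: int, n_arms: int) -> tuple[np.ndarray, np.ndarray]:
    """Rotate the base arm; returns trajectory (echo, arm, sample) and angles (echo, arm)."""
    # Hypothetical acquisition index: the counter advances by one for every echo of every excitation
    index = np.arange(n_arms)[None, :] * n_echoes + np.arange(n_echoes)[:, None]
    angles = np.mod(index * GOLDEN_ANGLE, 2 * np.pi)
    traj = spiral_arm()[None, None, :] * np.exp(1j * angles)[..., None]
    return traj, angles


traj, angles = golden_angle_arms(n_echoes=3, n_arms=6)
print(traj.shape)  # (3, 6, 64): echo, spiral arm, readout sample
print(np.round(np.degrees(angles), 1))
```

Plotting `traj.real` against `traj.imag` for each echo gives a picture comparable to the trajectory plot above, with a different set of rotations per echo.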

We can now plot the images at different echo times.

In [None]:
idat = idata.data.abs().numpy().squeeze()
fig, ax = plt.subplots(idat.shape[1], idat.shape[0], figsize=(4 * idat.shape[0], 4 * idat.shape[1]))
for echo in range(idat.shape[0]):
    for zidx in range(idat.shape[1]):
        ax[zidx, echo].imshow(idat[echo, zidx, :, :], cmap='gray')
        if zidx == 0:
            ax[zidx, echo].set_title(f'TE = {int(idata.header.te[echo] * 1000)} ms')
        ax[zidx, echo].set_xticks([])
        ax[zidx, echo].set_yticks([])

### Estimate the T2 maps
We use dictionary matching to estimate the T2 maps. Afterwards, we compare them with the T2 values of the phantom and check that the relative error is below 1.5 %.

In [None]:
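# Mono-exponential model S(TE) = M0 * exp(-TE / T2); parameter 0 (M0) acts as a scaling factor
dictionary = DictionaryMatchOp(MonoExponentialDecay(decay_time=idata.header.te), index_of_scaling_parameter=0)
# Dictionary entries: M0 = 1 and 1000 T2 values between 10 ms and 800 ms
dictionary.append(torch.tensor(1.0), torch.linspace(0.01, 0.8, 1000)[None, :])
m0_match, t2_match = dictionary(idata.data[:, 0])

# Reorient the phantom T2 map to match the orientation of the reconstructed images
t2_input = np.roll(
    rearrange(phantom.T2.numpy().squeeze()[::-1, ::-1, ::-1], 'x y z -> z y x'), shift=(1, 1, 1), axis=(0, 1, 2)
)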
obj_mask = np.zeros_like(t2_input)
obj_mask[t2_input > 0] = 1
t2_measured = t2_match.numpy().squeeze() * obj_mask

fig, ax = plt.subplots(idat.shape[1], 3, figsize=(15, 3 * idat.shape[1]))
for cax in ax.flatten():
    cax.set_xticks([])
    cax.set_yticks([])

for zidx in range(idat.shape[1]):
    im = ax[zidx, 0].imshow(t2_input[zidx], vmin=0, vmax=0.7, cmap=Colormap('navia').to_mpl())
    fig.colorbar(im, ax=ax[zidx, 0], label='Input T2 (s)')

    im = ax[zidx, 1].imshow(t2_measured[zidx], vmin=0, vmax=0.7, cmap=Colormap('navia').to_mpl())
    fig.colorbar(im, ax=ax[zidx, 1], label='Measured T2 (s)')

    im = ax[zidx, 2].imshow(t2_measured[zidx] - t2_input[zidx], vmin=-0.07, vmax=0.07, cmap='bwr')
    fig.colorbar(im, ax=ax[zidx, 2], label='Difference T2 (s)')

relative_error = np.sum(np.abs(t2_input - t2_measured)) / np.sum(np.abs(t2_input))
print(f'Relative error {relative_error}')
assert relative_error < 0.015
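`DictionaryMatchOp` performs the matching internally. To show the idea without MRpro, the sketch below reimplements a generic version in plain NumPy: each candidate T2 defines an atom $\exp(-TE/T_2)$ normalized to unit length, every voxel is assigned the T2 of the atom with the largest absolute inner product, and M0 follows from a least-squares projection onto the unnormalized atom. This is the textbook formulation and an assumption about, not a copy of, MRpro's implementation; the echo times and voxel values are hypothetical, while the T2 grid mirrors the one used above.

```python
import numpy as np


def build_dictionary(te: np.ndarray, t2_grid: np.ndarray) -> np.ndarray:
    """Atoms exp(-TE / T2), one column per T2 candidate, normalized to unit L2 norm."""
    atoms = np.exp(-te[:, None] / t2_grid[None, :])
    return atoms / np.linalg.norm(atoms, axis=0, keepdims=True)


def match(signals: np.ndarray, te: np.ndarray, t2_grid: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (M0, T2) per voxel for signals of shape (n_echoes, n_voxels)."""
    atoms = build_dictionary(te, t2_grid)
    scores = atoms.T @ signals  # inner products, shape (n_t2, n_voxels)
    best = np.argmax(np.abs(scores), axis=0)
    t2 = t2_grid[best]
    # Least-squares scaling of the unnormalized atom that best explains each voxel
    raw = np.exp(-te[:, None] / t2[None, :])
    m0 = np.sum(raw * signals, axis=0) / np.sum(raw * raw, axis=0)
    return m0, t2


te = 0.02 * np.arange(1, 9)  # s, hypothetical echo times
t2_grid = np.linspace(0.01, 0.8, 1000)  # s, same range as the dictionary above
true_m0 = np.array([1.0, 2.0, 0.5])
true_t2 = np.array([0.05, 0.1, 0.3])
signals = true_m0[None, :] * np.exp(-te[:, None] / true_t2[None, :])
m0_est, t2_est = match(signals, te, t2_grid)
print(np.round(t2_est * 1000, 1), np.round(m0_est, 3))
```

Because the true T2 values do not lie exactly on the grid, the matched values can differ from them by up to one grid step (about 0.8 ms here), which is one source of the small residual error reported above.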