# Estimation of coil sensitivity maps

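In this example, we simulate a Cartesian FLASH acquisition with multiple receiver coils and use the reconstructed coil images to estimate the coil sensitivity maps.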

### Imports

In [None]:
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import MRzeroCore as mr0
import numpy as np
import torch
from einops import rearrange
from mrpro.algorithms.reconstruction import DirectReconstruction
from mrpro.data import CsmData
from mrpro.data import KData
from mrpro.data import SpatialDimension
from mrpro.data.traj_calculators import KTrajectoryCartesian
from mrpro.phantoms.coils import birdcage_2d

from mrseq.scripts.cartesian_flash import main as create_seq
from mrseq.utils import sys_defaults

### Settings
We are going to use a numerical phantom with a matrix size of 120 x 120. The simulated raw data is written to an (ISMR)MRD file in a temporary directory.

In [None]:
# In-plane matrix size of the numerical phantom, ordered as [x, y].
image_matrix_size = [120, 120]

# Scratch directory for the simulated raw data; `tmp` is kept as a variable so
# the directory stays alive for the rest of the notebook.
tmp = tempfile.TemporaryDirectory()
fname_mrd = Path(tmp.name).joinpath('csm_flash.mrd')

### Create the digital phantom

We use the standard Brainweb phantom from [MRzero](https://github.com/MRsources/MRzero-Core) and attach the sensitivity maps of six simulated birdcage coils to it.

In [None]:
# Simulate the sensitivity maps of six birdcage coils on the phantom grid.
n_x, n_y = image_matrix_size
im_dims = SpatialDimension(z=1, y=n_y, x=n_x)
coil_maps = birdcage_2d(6, image_dimensions=im_dims, relative_radius=0.8)

phantom = mr0.util.load_phantom(image_matrix_size)

# The coil maps are ordered (coils z y x); the phantom gets them as (coils x y z),
# mirrored along both in-plane axes and shifted by one pixel.
# NOTE(review): the flip and roll presumably compensate for the differing
# orientation conventions of MRpro and MRzero - confirm.
coil_sens = rearrange(coil_maps[0], 'coils z y x -> coils x y z').flip(dims=(1, 2))
phantom.coil_sens = coil_sens.roll(shifts=(-1, -1), dims=(1, 2))

### Create the Cartesian FLASH sequence

To create the Cartesian FLASH sequence, we use the previously imported [cartesian_flash script](../src/mrseq/scripts/cartesian_flash.py).


In [None]:
# The field of view of the sequence is taken from the physical size of the
# phantom (entries 0 and 2 of `phantom.size`).
phantom_size = phantom.size.numpy()

sequence, fname_seq = create_seq(
    system=sys_defaults,
    fov_xy=float(phantom_size[0]),
    fov_z=float(phantom_size[2]),
    n_readout=image_matrix_size[0],
    n_phase_encoding=image_matrix_size[1],
    n_slice_encoding=1,  # single slice, i.e. a 2D acquisition
    test_report=False,
    timing_check=False,
)

### Simulate the sequence
Now, we pass the sequence and the phantom to the MRzero simulation and save the simulated signal as an (ISMR)MRD file.

In [None]:
# MRzero reads the sequence back from the '.seq' file written by the script above.
seq_file = fname_seq.with_suffix('.seq')
mr0_sequence = mr0.Sequence.import_file(str(seq_file))

# Run the simulation and store the simulated signal in the MRD file.
signal, ktraj_adc = mr0.util.simulate(mr0_sequence, phantom, accuracy=1e-1)
mr0.sig_to_mrd(fname_mrd, signal, sequence)

### Reconstruct the images of the different coils

We use [MRpro](https://github.com/PTB-MR/MRpro) for the image reconstruction.

In [None]:
kdata = KData.from_file(fname_mrd, trajectory=KTrajectoryCartesian())

# Overwrite the matrix sizes in the header: the encoding matrix is twice as
# large along x as the reconstruction matrix.
# NOTE(review): the factor of two presumably reflects readout oversampling - confirm.
n_readout, n_phase_encoding = image_matrix_size[0], image_matrix_size[1]
kdata.header.recon_matrix = SpatialDimension(z=1, y=n_phase_encoding, x=n_readout)
kdata.header.encoding_matrix = SpatialDimension(z=1, y=n_phase_encoding, x=n_readout * 2)

# No coil sensitivity maps are passed (csm=None), so the result is used below
# as one image per coil.
recon = DirectReconstruction(kdata, csm=None)
idata = recon(kdata)

We can now plot the different coil images.

In [None]:
# Magnitude images with singleton dimensions removed; the first remaining axis
# is iterated over below (one panel per entry).
idat = idata.data.abs().numpy().squeeze()

# 2 x 3 grid, flattened so that panels can be addressed with a single index.
fig, ax = plt.subplots(2, 3, figsize=(9, 6))
ax = ax.flatten()
for i, coil_image in enumerate(idat):
    ax[i].imshow(coil_image, cmap='gray')
    ax[i].set_xticks([])
    ax[i].set_yticks([])

### Estimate the Coil Sensitivity Maps

In [None]:
# Estimate the coil sensitivity maps from the individual coil images.
csm = CsmData.from_idata_inati(idata, smoothing_width=9)

# Build an object mask from the proton density of the phantom. The PD map is
# reversed along both axes, transposed from (x, y) to (y, x) and shifted by one
# pixel, mirroring the flip/roll that was applied to the coil maps when they
# were assigned to the phantom.
pd_input = np.roll(rearrange(phantom.PD.numpy().squeeze()[::-1, ::-1], 'x y -> y x'), shift=(1, 1), axis=(0, 1))
obj_mask = np.zeros_like(pd_input)
obj_mask[pd_input > 0] = 1

# Compare magnitudes inside the object only. Each set of maps is normalised to
# its own maximum so that input and estimate are on the same scale.
csm_measured = csm.data.squeeze().abs().numpy() * obj_mask[None]
csm_measured /= csm_measured.max()
csm_input = coil_maps.squeeze().abs().numpy() * obj_mask[None]
csm_input /= csm_input.max()

# Rows: input maps, estimated maps, difference. Columns: coils.
fig, ax = plt.subplots(3, 6, figsize=(20, 6))
for cax in ax.flatten():
    cax.set_xticks([])
    cax.set_yticks([])

# 'gray' (not 'grey') matches the coil image plot and is the colormap name
# accepted by all Matplotlib versions.
for idx in range(csm.shape[1]):
    im = ax[0, idx].imshow(csm_input[idx], vmin=0, vmax=1, cmap='gray')
    fig.colorbar(im, ax=ax[0, idx], label='Input CSM (a.u.)')

    im = ax[1, idx].imshow(csm_measured[idx], vmin=0, vmax=1, cmap='gray')
    fig.colorbar(im, ax=ax[1, idx], label='Estimated CSM (a.u.)')

    im = ax[2, idx].imshow(csm_measured[idx] - csm_input[idx], vmin=-0.1, vmax=0.1, cmap='bwr')
    fig.colorbar(im, ax=ax[2, idx], label='Difference CSM (a.u.)')

# Summed absolute deviation relative to the summed input magnitude. The assert
# acts as a regression check when the notebook is executed.
relative_error = np.sum(np.abs(csm_input - csm_measured)) / np.sum(np.abs(csm_input))
print(f'Relative error {relative_error}')
assert relative_error < 0.02