# B0 and B1 Mapping using WASABI

Simultaneous mapping of the water shift (B0) and B1 (WASABI) using an off-resonant preparation pulse. The pulse induces Rabi oscillations, which appear as sinc-like oscillations along the frequency-offset dimension. B0 can then be estimated from the symmetry axis of these oscillations and B1 from their oscillation frequency.

### Imports

In [None]:
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import MRzeroCore as mr0
import numpy as np
import torch
from einops import rearrange
from mrpro.algorithms.reconstruction import DirectReconstruction
from mrpro.data import KData
from mrpro.data import SpatialDimension
from mrpro.data.traj_calculators import KTrajectoryCartesian
from mrpro.operators import DictionaryMatchOp
from mrpro.operators.models import WASABI

from mrseq.scripts.b0_b1_wasabi import main as create_seq
from mrseq.utils import sys_defaults

### Settings
We are going to use a numerical phantom with a matrix size of 128 x 128. 

In [None]:
# Image grid used for the numerical phantom and the reconstruction.
image_matrix_size = [128, 128]

# WASABI saturation offsets (Hz): 31 equidistant frequencies from -240 Hz to +240 Hz, 16 Hz apart.
frequency_offsets = np.linspace(start=-240, stop=240, num=31)
# Offset (Hz) of the additional reference image, far away from the offsets above.
norm_offset = -35_000.0

# Scratch directory holding the simulated raw data.
tmp = tempfile.TemporaryDirectory()
fname_mrd = Path(tmp.name).joinpath('b0_b1_wasabi.mrd')

### Create the digital phantom

We use the standard Brainweb phantom from [MRzero](https://github.com/MRsources/MRzero-Core).

In [None]:
# Load the Brainweb-based numerical phantom provided by MRzero. `image_matrix_size` is passed to
# load_phantom, presumably as the target in-plane grid size (confirm in the MRzero documentation).
phantom = mr0.util.load_phantom(image_matrix_size)

### Create the WASABI sequence

To create the WASABI sequence, we use the previously imported [b0_b1_wasabi script](../src/mrseq/scripts/b0_b1_wasabi.py).


In [None]:
# Create the WASABI sequence with the mrseq script. The field of view is taken from the phantom.
# The readout uses image_matrix_size[0] samples (x) and phase encoding uses image_matrix_size[1]
# lines (y). This matches the x/y matrix sizes the reconstruction sets on the k-space header, so a
# non-square image matrix also works.
sequence, fname_seq = create_seq(
    system=sys_defaults,
    test_report=False,
    timing_check=False,
    fov_xy=float(phantom.size.numpy()[0]),
    n_readout=image_matrix_size[0],
    n_phase_encoding=image_matrix_size[1],
    frequency_offsets=frequency_offsets,
    norm_offset=norm_offset,
)

### Simulate the sequence
Now, we pass the sequence and the phantom to the MRzero simulation and save the simulated signal as an (ISMR)MRD file.

In [None]:
# Import the .seq file written by the sequence script into MRzero.
mr0_sequence = mr0.Sequence.import_file(f'{fname_seq.with_suffix(".seq")}')
# Simulate the acquisition of the phantom; `accuracy` is passed straight to MRzero's simulate.
signal, ktraj_adc = mr0.util.simulate(mr0_sequence, phantom, accuracy=0.1)
# Store the simulated signal as an (ISMR)MRD file so that MRpro can read it.
mr0.sig_to_mrd(fname_mrd, signal, sequence)

### Reconstruct the images at different frequency offsets

We use [MRpro](https://github.com/PTB-MR/MRpro) for the image reconstruction.

In [None]:
# Read the simulated raw data, using a Cartesian trajectory calculator.
kdata = KData.from_file(fname_mrd, trajectory=KTrajectoryCartesian())
# The encoded readout (x) is twice the image size. This presumably reflects 2x readout
# oversampling in the sequence (assumption: confirm in the sequence script). The reconstruction
# matrix is the plain image matrix.
kdata.header.encoding_matrix = SpatialDimension(x=2 * image_matrix_size[0], y=image_matrix_size[1], z=1)
kdata.header.recon_matrix = SpatialDimension(x=image_matrix_size[0], y=image_matrix_size[1], z=1)
# Direct reconstruction without coil sensitivity maps (csm=None).
recon = DirectReconstruction(kdata, csm=None)
idata = recon(kdata)

We can now plot the images at different offset frequencies.

In [None]:
# Magnitude images (rss over coils), one per acquired frequency offset.
idat = idata.rss().abs().numpy().squeeze()
# Offsets in the order used for the images: the reference at `norm_offset` first, then the WASABI offsets.
offsets = np.concatenate(([norm_offset], frequency_offsets))

# Lay the images out on 4 rows. Ceiling division gives every image an axis even when the number
# of images is not a multiple of 4; floor division would create too few axes and raise IndexError.
n_rows = 4
n_cols = -(-idat.shape[0] // n_rows)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
ax = ax.flatten()
for cax, image, offset in zip(ax, idat, offsets):
    cax.imshow(image, cmap='gray')
    cax.set_title(f'Offset = {int(offset)} Hz')
    cax.set_xticks([])
    cax.set_yticks([])
# Hide the axes left over in the last row.
for cax in ax[idat.shape[0] :]:
    cax.set_visible(False)

### Estimate B0 and B1
We use a dictionary matching approach to estimate the B0 and relative B1 maps. Afterward, we compare them to the input maps of the phantom.

In [None]:
def to_image_orientation(phantom_map: np.ndarray) -> np.ndarray:
    """Flip, transpose and shift a 2D phantom map so it can be compared with the reconstructed images."""
    flipped = phantom_map.squeeze()[::-1, ::-1]
    return np.roll(rearrange(flipped, 'x y -> y x'), shift=(1, 1), axis=(0, 1))


# Dictionary grid over (B0 shift, relative B1, c, d). torch.meshgrid builds the full Cartesian grid
# of combinations. Giving every 1D grid its varying axis at the END (shapes (n,), (1, n), ...) would
# broadcast them element-wise instead, yielding only n entries in which B0, c and d are tied together.
# DictionaryMatchOp matches by the normalized dot product (see its MRpro documentation), so only
# the ratio d / c matters. c is therefore fixed to 1 and only d is swept, which keeps the dictionary small.
# NOTE(review): the grid ranges are assumptions - widen them if matched values pile up at a grid edge.
# NOTE(review): WASABI is used with its default pulse parameters - confirm they match the sequence.
b0_grid = torch.linspace(-100.0, 100.0, 101)  # Hz, 2 Hz steps
rb1_grid = torch.linspace(0.1, 1.5, 29)  # steps of 0.05
c_grid = torch.tensor([1.0])
d_grid = torch.linspace(0.0, 3.0, 13)  # steps of 0.25

dictionary = DictionaryMatchOp(WASABI(offsets=torch.as_tensor(offsets, dtype=torch.float32)))
dictionary.append(*torch.meshgrid(b0_grid, rb1_grid, c_grid, d_grid, indexing='ij'))

# Match the voxels chunk by chunk. The similarity matrix has (voxels x dictionary entries)
# elements and can become very large if all voxels are matched at once.
voxel_signals = idata.data.flatten(start_dim=1)  # (offsets, all other dimensions flattened)
chunk_matches = [dictionary(chunk) for chunk in voxel_signals.split(512, dim=1)]
b0_match, rb1_match, c_match, a_match = (
    torch.cat(parameter_chunks).reshape(idata.data.shape[1:]) for parameter_chunks in zip(*chunk_matches)
)

# Phantom ground truth in the orientation of the reconstructed images.
b0_input = to_image_orientation(phantom.B0.numpy())
rb1_input = np.abs(to_image_orientation(phantom.B1.numpy()))
# Object mask from the proton density. The B0 map is signed, so thresholding B0 > 0 would drop
# every voxel with a negative frequency shift.
obj_mask = (to_image_orientation(phantom.PD.numpy()) > 0).astype(b0_input.dtype)

# Restrict input and measurement to the same voxels so that differences and errors are comparable.
b0_input = b0_input * obj_mask
rb1_input = rb1_input * obj_mask
b0_measured = b0_match.numpy().squeeze() * obj_mask
rb1_measured = rb1_match.abs().numpy().squeeze() * obj_mask

# Overview of all matched parameters inside the object.
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
matched_maps = {
    'B0 shift (Hz)': b0_measured,
    'relative B1': rb1_measured,
    'c (fixed to 1)': c_match.numpy().squeeze() * obj_mask,
    'd': a_match.numpy().squeeze() * obj_mask,
}
for cax, (title, parameter_map) in zip(ax, matched_maps.items()):
    im = cax.imshow(parameter_map)
    fig.colorbar(im, ax=cax)
    cax.set_title(title)
    cax.set_xticks([])
    cax.set_yticks([])

# Comparison with the phantom. Input and measured maps share one color scale.
fig, ax = plt.subplots(2, 3, figsize=(15, 6))
for cax in ax.flatten():
    cax.set_xticks([])
    cax.set_yticks([])

im = ax[0, 0].imshow(b0_input, vmin=-40, vmax=40, cmap='bwr')
fig.colorbar(im, ax=ax[0, 0], label='Input B0 (Hz)')

im = ax[0, 1].imshow(b0_measured, vmin=-40, vmax=40, cmap='bwr')
fig.colorbar(im, ax=ax[0, 1], label='Measured B0 (Hz)')

im = ax[0, 2].imshow(b0_measured - b0_input, vmin=-1.8, vmax=1.8, cmap='bwr')
fig.colorbar(im, ax=ax[0, 2], label='Difference B0 (Hz)')

b0_relative_error = np.sum(np.abs(b0_input - b0_measured)) / np.sum(np.abs(b0_input))
print(f'Relative error B0 {b0_relative_error}')

im = ax[1, 0].imshow(rb1_input, vmin=0, vmax=1.2, cmap='magma')
fig.colorbar(im, ax=ax[1, 0], label='Input relative B1 ')

im = ax[1, 1].imshow(rb1_measured, vmin=0, vmax=1.2, cmap='magma')
fig.colorbar(im, ax=ax[1, 1], label='Measured relative B1')

im = ax[1, 2].imshow(rb1_measured - rb1_input, vmin=-0.2, vmax=0.2, cmap='bwr')
fig.colorbar(im, ax=ax[1, 2], label='Difference relative B1')

rb1_relative_error = np.sum(np.abs(rb1_input - rb1_measured)) / np.sum(np.abs(rb1_input))
print(f'Relative error relative B1 {rb1_relative_error}')
# TODO: turn these into `assert ..._relative_error < 0.08` checks once the dictionary grid has been
# validated against the simulation.