
# Introduction

In this exercise, you will be introduced to MRI reconstruction. We start by examining the raw k-space data and the coil sensitivity maps, and then build the multi-coil forward and adjoint operators. Finally, we solve a linear and regularized reconstruction problems, which shows where machine learning can later be plugged in.

First, we install the dependencies and download the data.

Data were acquired on a 3T Siemens Magnetom Vida at the Institute of Biomedical Imaging, Graz University of Technology, Austria. The data should only be used for educational purposes.

In [None]:
# install dependencies
!pip install PyWavelets git+https://github.com/khammernik/medutils.git

In [None]:
# download data
!wget -O brain_cartesian_2D.h5 https://www.dropbox.com/s/hclfv3re91qb1v3/brain_cartesian_2D.h5?dl=1


In [None]:
# Download ESPIRiT code for coil sensitivity map estimation
!git clone https://github.com/mikgroup/espirit-python.git
!cp /content/espirit-python/espirit.py .

# Magnetic Resonance Image (MRI) Reconstruction

The goal is to recover the clean image $x$ from undersampled k-space data $y$, which are corrupted by additive white Gaussian noise $n$,
$$ y = Ax + n. $$
The raw data $y$ were acquired with multiple receive coils. The linear operator $A$ maps from image space to k-space.

![MRI Inverse Problem](https://github.com/midas-tum/esmrmb_lmr_2022/blob/main/images/mri_inverse_problem.png?raw=true)



## Data Loading

In the first step, we examine the available data regarding their shape and data type. Note that we are dealing with complex-valued data here.

In [None]:
import h5py
import numpy as np
import medutils
np.random.seed(1001)

ds = h5py.File('./brain_cartesian_2D.h5', 'r')
kspace = ds['kspace'][()]
ds.close()

print(f'K-Space:')
print(f'dtype={kspace.dtype}')
print(f'(nCoils, nFE, nPE)={kspace.shape}')
nCoils, nFE, nPE = kspace.shape

We observe that we have 16 coils, the number of frequency encoding (readout) points `nFE` equals 640, and the number of phase encoding steps `nPE` is 330. We will come back to this later.

## Data Visualization
For data visualization, you are free to use any plotting library such as `matplotlib`, or the provided `medutils` package. The `medutils` package has some useful functions for visualization:
- `kshow`: display k-space data on a logarithmic scale
- `imshow`: display the magnitude of an image
- `plot_array`: arrange the images of a 3D array side by side
- `ksave`: save k-space
- `imsave`: save images

We will first visualize the `kspace`. The data were acquired with 16 coils. The vertical direction is the frequency encoding (FE) direction, and the horizontal direction is the phase encoding (PE) direction.

In [None]:
medutils.visualization.kshow(medutils.visualization.plot_array(kspace, M=2, N=8), title='K-space', figsize=(40,20))

## Transforming k-space to image space
Let us now transform the `kspace` to images. For this, we require the centered 2D inverse Fourier transform. Applying `ifft2c` to the k-space yields the individual coil images.

**Task 1: Write the function `ifft2c(kspace)`**

In [None]:
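def ifft2c(kspace):
  #TODO implement the centered inverse FFT.
  # Shift only the two spatial axes (-2, -1) so that a leading coil axis is left untouched
  return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace, axes=(-2, -1)), norm='ortho'), axes=(-2, -1))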

coil_img = ifft2c(kspace)
medutils.visualization.imshow(medutils.visualization.plot_array(coil_img, M=2, N=8), title='Coil Images', figsize=(40,20))

You might notice two things. First, only part of each coil image is bright, because each coil element is sensitive only in a limited spatial region. Second, there are large dark areas along the vertical (column) direction. This extended field of view in the readout direction, also termed frequency-encoding direction, comes for free: it does not cost any additional acquisition time and is acquired by default on MRI scanners. Assuming a base resolution of 320, the number of frequency-encoding samples is (at least) doubled, here to 640. This frequency oversampling increases the field of view in this direction. After transforming to the image domain, only the central part needs to be visualized. Therefore, from now on we display only the central part.
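Cropping only for display means every later computation still runs on the full 640-sample readout. An alternative that this notebook does not use, sketched here under the assumption of the `(nCoils, nFE, nPE)` layout from above, is to remove the readout oversampling once: inverse-transform along the frequency-encoding axis, keep the central half, and transform back. The helpers `ifft1c`, `fft1c` and `remove_readout_oversampling` are hypothetical, and random data stand in for the real k-space.

```python
import numpy as np

def ifft1c(x, axis):
    """Centered, orthonormal 1D inverse FFT along `axis` (hypothetical helper)."""
    return np.fft.fftshift(
        np.fft.ifft(np.fft.ifftshift(x, axes=axis), axis=axis, norm='ortho'),
        axes=axis)

def fft1c(x, axis):
    """Centered, orthonormal 1D FFT along `axis` (hypothetical helper)."""
    return np.fft.fftshift(
        np.fft.fft(np.fft.ifftshift(x, axes=axis), axis=axis, norm='ortho'),
        axes=axis)

def remove_readout_oversampling(kspace, axis=1):
    """Keep the central half of the field of view along the readout axis."""
    n = kspace.shape[axis]
    hybrid = ifft1c(kspace, axis)               # image space along FE, still k-space along PE
    start = n // 4
    cropped = np.take(hybrid, np.arange(start, start + n // 2), axis=axis)
    return fft1c(cropped, axis)                 # back to k-space with n // 2 readout samples

# Random complex data standing in for a (nCoils, nFE, nPE) k-space
rng = np.random.default_rng(0)
kspace_demo = rng.standard_normal((4, 64, 30)) + 1j * rng.standard_normal((4, 64, 30))
kspace_cropped = remove_readout_oversampling(kspace_demo)
print(kspace_cropped.shape)  # (4, 32, 30)
```

Applied to the real data, this would reduce `nFE=640` to 320 readout samples and halve the cost of every FFT along that axis; the sensitivity maps would then need the same crop.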

## Root-Sum-of-Squares Reconstruction

Now, we calculate the root-sum-of-squares (RSS) reconstruction $x_{rss}$
$$ x_{rss} = \sqrt{ \sum_{c=1}^{\text{nCoils}} \vert x_c \vert ^ 2 }, $$

where $x_c$ are the individual coil images. Note that the RSS reconstruction discards the phase information of the complex-valued data.

We now visualize only the central part of the reconstructed image of size `[nFE//2, nPE]`. 

**Task 2: Implement the RSS reconstruction `rss(coil_img)`**

In [None]:
def rss(coil_img):
  #TODO implement the rss reconstruction
  return np.sqrt(np.sum(np.abs(coil_img)**2, 0))

x_rss = rss(coil_img)
medutils.visualization.imshow(medutils.visualization.center_crop(x_rss, (nFE//2, nPE)), figsize=(10,10), title='RSS reconstruction')

# Sensitivity Map Estimation

The coil sensitivity maps (`smaps`) are smooth maps that describe in which spatial regions the individual coil elements are sensitive. We need this information for our multi-coil MRI forward and adjoint operators. We use the [Python implementation](https://github.com/mikgroup/espirit-python) of ESPIRiT [1,2] to estimate these coil sensitivity maps.

[1] Uecker et al. [ESPIRiT—an eigenvalue approach to autocalibrating parallel MRI: Where SENSE meets GRAPPA](https://onlinelibrary.wiley.com/doi/10.1002/mrm.24751). Magn Reson Med 71(3):990-1001, 2014.

[2] https://github.com/mikgroup/espirit-python

In [None]:
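import espirit
# espirit expects k-space of shape (sx, sy, sz, nCoils): move the coil axis last and add a singleton sz axis
kspace_espirit = np.transpose(kspace, (1, 2, 0))[:,:,np.newaxis]
# Arguments: kernel size 8, calibration region size 24, rank threshold 0.05, eigenvalue crop threshold 0
smaps_espirit = espirit.espirit(kspace_espirit, 8, 24, 0.05, 0)

# Keep the first set of maps and move the coil axis back to the front: (nCoils, nFE, nPE)
smaps = smaps_espirit[:,:,0,:,0]
smaps = np.transpose(smaps, (2, 0, 1))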

Let us visualize the coil sensitivity maps for our k-space.

In [None]:
medutils.visualization.imshow(medutils.visualization.plot_array(smaps), title='Sensitivity Maps  (compressed)', figsize=(40,20))

# Multi-Coil Operators and Sensitivity-Weighted Coil Combination

Now we have all the ingredients to combine the coil images! Do you remember how the multi-coil forward and adjoint MRI operators are defined? They are required in the next steps.

**Task 3: Implement the multi-coil forward operator $A$ in `mriForwardOp(image, smaps, mask)` and the adjoint operator $A^H$ in `mriAdjointOp(kspace, smaps, mask)`. Please check the lecture slides for details on the implementation.**

*Hint: Start with the implementation of the adjoint operator and make use of the previously written function `ifft2c`. Then, define a function `fft2c` which is the 2D centered Fourier transform before you continue with the forward operator.*

## Suggested Readings

Pruessmann et al. [SENSE: Sensitivity encoding for fast MRI](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291522-2594%28199911%2942%3A5%3C952%3A%3AAID-MRM16%3E3.0.CO%3B2-S). Magnetic Resonance in Medicine, 42(5):952-962, 1999.


In [None]:
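def mriAdjointOp(kspace, smaps, mask):
  # TODO implement
  # A^H y = sum_c conj(S_c) * F^{-1}(M * y_c): mask, inverse FFT per coil, coil combination
  return np.sum(ifft2c(kspace * mask)*np.conj(smaps), axis=0)

def fft2c(image):
  # TODO implement
  # Centered, orthonormal 2D FFT over the two spatial axes only
  return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image, axes=(-2, -1)), norm='ortho'), axes=(-2, -1))

def mriForwardOp(image, smaps, mask):
  # TODO implement
  # A x = M * F(S_c x) for every coil c
  return fft2c(smaps * image) * mask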

Next, check whether the adjoint operator works as expected: the result should be a coil-combined image. At this point no undersampling is involved, i.e., the mask is set to all ones.

In [None]:
img_cc = mriAdjointOp(kspace, smaps, np.ones_like(kspace))
medutils.visualization.imshow(medutils.visualization.center_crop(img_cc, (nFE//2, nPE)), title='Combined image (sensitivity-weighted)', figsize=(8,8))

**Task 4: Adjointness check**

Now, also check whether the operators are adjoint using the following equation:
$$ \langle Au, v\rangle = \langle u, A^Hv\rangle,$$
where $u$ and $v$ are complex random arrays. The variable $u$ should have the same size as the image $x$, and $v$ the same size as the k-space $y$. Note that the sampling mask and sensitivity maps are kept constant. To create random numbers, use `np.random.randn`. Print the results for the left-hand side and right-hand side of the equation.

*Hint: To get the correct result, use `np.conj` when computing the complex-valued inner product. Also, the forward and inverse Fourier transforms have to be scaled the same way (`norm='ortho'`).*
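To see why the hint insists on `norm='ortho'`, the following self-contained sketch tests whether `ifft2` acts as the adjoint of `fft2`, once with orthonormal scaling and once with NumPy's default scaling (no factor on the forward transform, `1/N` on the inverse). The array size and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(1001)
shape = (8, 6)
u_demo = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
v_demo = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

def inner(a, b):
    """Complex inner product <a, b> = sum(conj(a) * b)."""
    return np.sum(np.conj(a) * b)

results = {}
for norm in ('ortho', None):  # None selects NumPy's default ('backward') scaling
    lhs_n = inner(np.fft.fft2(u_demo, norm=norm), v_demo)   # <F u, v>
    rhs_n = inner(u_demo, np.fft.ifft2(v_demo, norm=norm))  # <u, F^{-1} v>
    results[norm] = (lhs_n, rhs_n)
    print(f'norm={norm}: {lhs_n:.4f} vs {rhs_n:.4f}')
```

With the default scaling, the two sides differ by exactly the number of pixels (here 48), because the unnormalized inverse FFT is the adjoint divided by `N`.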

In [None]:
# TODO implement the adjointness check
u = np.random.randn(*img_cc.shape) + 1j * np.random.randn(*img_cc.shape)
v = np.random.randn(*kspace.shape) + 1j * np.random.randn(*kspace.shape)


lhs = np.sum(np.conj(mriForwardOp(u, smaps, np.ones_like(kspace))) * v)
rhs = np.sum(np.conj(u) * mriAdjointOp(v, smaps, np.ones_like(kspace)))

print(f'{lhs} == {rhs}')

# Undersampling
Now we will get to the most exciting part of this exercise - undersampling the k-space! We will generate undersampling masks for acceleration $R\in\lbrace 2,3\rbrace$. 

**Task 5: Your task is to play around with different undersampling patterns. Generate a sampling pattern in the function `generate_mask(R, nPE, nFE, mode)`, where `R` is the acceleration factor and `mode` is an integer selecting one of the following patterns:**
1. Randomly sample each line, i.e., draw 1 (sampled) with probability `1/R` and 0 (not sampled) with probability `1-1/R`.
2. Sample only a dense block of `nRef=20` lines in the center of k-space.
3. Combine 1. and 2.
4. Sample only every `R`-th line.
5. Combine 2. and 4.

To create the mask, simply generate a 1D line of size `nPE`. For each mode, compute and print the effective acceleration `Reff`, defined as `nPE` divided by the number of sampled lines.

The code to get the full mask of size `[nFE, nPE]` is given below.

To continue with the tasks on iterative reconstruction, please use `mode=3` with `R=3`, as well as `mode=5` with `R=3`.
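As a worked example of the effective acceleration, the sketch below rebuilds the deterministic patterns 4 and 5 for the data size used here (`n_pe`, `r`, `n_ref` mirror `nPE=330`, `R=3`, `nRef=20`). Mode 4 samples 110 lines, so `Reff=3`; mode 5 adds the 20 central lines, 7 of which are already sampled, giving 123 lines and `Reff` of about 2.68. The random patterns 1 and 3 depend on the seed and are left out.

```python
import numpy as np

n_pe, r, n_ref = 330, 3, 20

# Mode 4: every r-th phase-encoding line
mask4 = np.zeros(n_pe)
mask4[::r] = 1

# Mode 5: every r-th line plus a fully sampled central block of n_ref lines
mask5 = mask4.copy()
mask5[n_pe//2 - n_ref//2 : n_pe//2 + n_ref//2] = 1

for name, m in (('mode 4', mask4), ('mode 5', mask5)):
    n_lines = int(m.sum())
    print(f'{name}: {n_lines} lines, Reff={n_pe / n_lines:.3f}')
# mode 4: 110 lines, Reff=3.000
# mode 5: 123 lines, Reff=2.683
```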

In [None]:
# TODO generate undersampling masks
def generate_mask(R, nPE, nFE, mode):
  nRef = 20
  if mode == 1:
    mask = np.random.choice([1, 0],(nPE),p=[1/R, 1-1/R])
  elif mode == 2:
    mask = np.zeros(nPE)
    mask[nPE//2-nRef//2:nPE//2+nRef//2] = 1
  elif mode == 3:
    mask = np.random.choice([1, 0],(nPE),p=[1/R, 1-1/R])
    mask[nPE//2-nRef//2:nPE//2+nRef//2] = 1
  elif mode == 4:
    mask = np.zeros(nPE)
    mask[::R] = 1
  elif mode == 5:
    mask = np.zeros(nPE)
    mask[::R] = 1
    mask[nPE//2-nRef//2:nPE//2+nRef//2] = 1
  else:
    raise ValueError(f'Mode {mode} not defined')

  Reff = nPE/np.sum(mask)
  print(f'Reff={Reff}')

  mask = mask.reshape(1, nPE).repeat(nFE, axis=0)

  return mask

Now, generate the mask and visualize it (we show only the first 40 rows along the frequency-encoding direction, since all rows are identical).

In [None]:
np.random.seed(1001)
mask = generate_mask(R=3, nPE=nPE, nFE=nFE, mode=5)

medutils.visualization.imshow(mask[:40,:], 'Undersampling mask', figsize=(20,20))

**Task 6: Zero-Filling solution**

Now you are ready to compute the zero-filling solution by applying the adjoint operator, with the generated undersampling mask `mask`, to the data. Play around with the mask configurations above. How do the images change?
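One way to anticipate the result for the regular patterns (modes 4 and 5): keeping only every `R`-th k-space sample and zero-filling the rest superimposes `R` copies of the object, shifted by `N/R` pixels. The 1D sketch below checks this identity numerically on a random signal using NumPy's uncentered FFT; the length `N` and `R` are arbitrary, with `N` divisible by `R`. Random patterns (modes 1 and 3) tend to produce noise-like, incoherent artifacts instead.

```python
import numpy as np

rng = np.random.default_rng(0)
N, R = 12, 3                      # signal length must be divisible by R
sig = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# Keep only every R-th k-space sample (zero filling elsewhere)
mask_1d = np.zeros(N)
mask_1d[::R] = 1
sig_zf = np.fft.ifft(np.fft.fft(sig) * mask_1d)

# Prediction: average of R circularly shifted copies, spaced N/R samples apart
sig_alias = sum(np.roll(sig, -j * (N // R)) for j in range(R)) / R

print(np.allclose(sig_zf, sig_alias))  # True
```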

In [None]:
#TODO Apply the adjoint operator to the data and use the newly created undersampling mask.
img_cc_us = mriAdjointOp(kspace, smaps, mask)
img_cc_us = medutils.visualization.center_crop(img_cc_us, (nFE//2, nPE))
medutils.visualization.imshow(img_cc_us, 'Undersampled image (zero filling)', figsize=(10,10))

# Linear and Regularized Reconstruction
Now, we are ready to implement linear and regularized reconstruction. We additionally need the discrete gradient operator `D`, implemented with forward differences, and its adjoint `DT`, as well as the multi-coil MRI forward and adjoint operators `A` and `AH`, respectively. Note that `A` and `AH` use the `mask` generated above.

In [None]:
def nabla(x):
    dx = np.pad(x[:,1:], [[0, 0],[0, 1]], mode='edge')
    dy = np.pad(x[1:], [[0, 1],[0, 0]], mode='edge')
    return np.concatenate([dx[None,...] - x, dy[None,...] - x], 0)

def nablaT(x):
    assert x.shape[0] == 2
    dx = np.pad(x[0,:,:-1], [[0, 0],[1, 1]], mode='constant')
    dy = np.pad(x[1,:-1], [[1, 1],[0, 0]], mode='constant')
    return dx[:,:-1] - dx[:,1:] + dy[:-1] - dy[1:]

D = lambda x: nabla(np.real(x)) + 1j * nabla(np.imag(x))
DT= lambda x: nablaT(np.real(x)) + 1j * nablaT(np.imag(x))
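Analogous to the adjointness check of the MRI operators, one can verify numerically that `nablaT` is the adjoint of `nabla`, i.e. $\langle \nabla x, p\rangle = \langle x, \nabla^T p\rangle$. The sketch repeats the two definitions from the cell above so that it runs on its own; the array sizes are arbitrary, and real-valued arrays suffice because `D` and `DT` apply the same operators to the real and imaginary parts separately.

```python
import numpy as np

def nabla(x):
    # Forward differences along columns (dx) and rows (dy); zero in the last column/row
    dx = np.pad(x[:,1:], [[0, 0],[0, 1]], mode='edge')
    dy = np.pad(x[1:], [[0, 1],[0, 0]], mode='edge')
    return np.concatenate([dx[None,...] - x, dy[None,...] - x], 0)

def nablaT(x):
    # Adjoint of nabla (negative divergence)
    assert x.shape[0] == 2
    dx = np.pad(x[0,:,:-1], [[0, 0],[1, 1]], mode='constant')
    dy = np.pad(x[1,:-1], [[1, 1],[0, 0]], mode='constant')
    return dx[:,:-1] - dx[:,1:] + dy[:-1] - dy[1:]

rng = np.random.default_rng(0)
img_test = rng.standard_normal((7, 5))
p_test = rng.standard_normal((2, 7, 5))

lhs_grad = np.sum(nabla(img_test) * p_test)   # <nabla x, p>
rhs_grad = np.sum(img_test * nablaT(p_test))  # <x, nablaT p>
print(np.isclose(lhs_grad, rhs_grad))         # True
```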

In [None]:
A = lambda x: mriForwardOp(x, smaps, mask)
AH = lambda x: mriAdjointOp(x, smaps, mask)

## Solving the linear reconstruction problem
Consider the following minimization problem:

$$ \min_x  E(x,y) = \min_x \frac{1}{2} \Vert Ax - y \Vert_2^2 .$$

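In image denoising, where $A$ is the identity, such problems can still be solved in closed form. For MRI reconstruction, this is no longer feasible: the normal equations $A^HAx = A^Hy$ involve the large multi-coil operator $A^HA$, which is impractical to form and invert explicitly. Instead, we use a first-order optimization method and solve the problem by gradient descent: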
$$ x^{k+1} = x^{k} - \alpha \nabla_x E(x,y) $$
$$ x^{k+1} = x^{k} - \alpha A^H (Ax^k - y) $$

**Task 7: Implement Gradient Descent in `opt_linear` to solve the linear reconstruction problem and run the optimization for `max_iter=50` iterations and a step size of `alpha=1`**
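Why is `alpha=1` a safe step size? Gradient descent on this quadratic converges for $0 < \alpha < 2/L$ with $L = \Vert A^HA \Vert$. The orthonormal FFT and the binary mask do not increase the norm, so $L \le 1$ whenever the sensitivity maps satisfy $\sum_c \vert S_c \vert^2 \le 1$ pixel-wise, which we assume holds for the ESPIRiT maps since they are normalized over coils. The toy sketch below estimates $L$ by power iteration on a small random operator built with that normalization. All names are hypothetical, and uncentered FFTs are used because centering does not change norms.

```python
import numpy as np

rng = np.random.default_rng(0)
n_coils, ny, nx = 4, 16, 12

# Random sensitivities, normalized so that sum_c |S_c|^2 = 1 at every pixel
smaps_toy = rng.standard_normal((n_coils, ny, nx)) + 1j * rng.standard_normal((n_coils, ny, nx))
smaps_toy /= np.sqrt(np.sum(np.abs(smaps_toy)**2, axis=0, keepdims=True))

def A_toy(x, mask):
    # Coil weighting, orthonormal 2D FFT, k-space sampling
    return np.fft.fft2(smaps_toy * x, norm='ortho') * mask

def AH_toy(k, mask):
    # Adjoint: sampling, inverse FFT, conjugate coil combination
    return np.sum(np.conj(smaps_toy) * np.fft.ifft2(k * mask, norm='ortho'), axis=0)

def power_iteration(mask, n_iter=100):
    """Estimate the largest eigenvalue L of A^H A."""
    x = rng.standard_normal((ny, nx)) + 1j * rng.standard_normal((ny, nx))
    x /= np.linalg.norm(x)
    L = 0.0
    for _ in range(n_iter):
        y = AH_toy(A_toy(x, mask), mask)
        L = np.linalg.norm(y)   # ||A^H A x|| for unit-norm x
        x = y / L
    return L

mask_full = np.ones((ny, nx))
mask_us = np.zeros((ny, nx))
mask_us[:, ::3] = 1             # keep every third column

L_full = power_iteration(mask_full)
L_us = power_iteration(mask_us)
print(f'L (fully sampled) = {L_full:.4f}')  # 1.0000, since A^H A = I
print(f'L (undersampled)  = {L_us:.4f}')    # at most 1
```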

In [None]:
def opt_linear(y, max_iter, alpha):
    x = np.zeros_like(AH(y))

    #TODO implement gradient descent to solve the linear reconstruction problem
    for _ in range(max_iter):
        x = x - alpha * AH(A(x) - y)
    return x

In [None]:
alpha=1
img_linear = opt_linear(kspace, max_iter=50, alpha=alpha)
img_linear = medutils.visualization.center_crop(img_linear, (nFE//2, nPE))
medutils.visualization.imshow(img_linear, f'Linear reconstruction alpha={alpha}', figsize=(10,10))

## L2-H1 Regularization

Now, we regularize the least-squares problem with a regularizer of the form $\mathcal{R}(x)=\frac{1}{2} \Vert \nabla x \Vert_2^2$, where $\nabla$ denotes the discrete image gradient (`D` in the code).
Consider now the following minimization problem

$$ \min_x  \mathcal{D}(x,y) + \lambda \mathcal{R}(x) = \min_x \frac{1}{2}\Vert Ax - y \Vert_2^2 + \frac{\lambda}{2}\Vert \nabla x \Vert_2^2.$$

We solve this by Gradient Descent:
$$ x^{k+1} = x^{k} - \alpha \left( \nabla_x \mathcal{D}(x^k,y) + \lambda \nabla_x \mathcal{R}(x^k) \right) $$
$$ x^{k+1} = x^{k} - \alpha \left( A^H (Ax^k - y) + \lambda \nabla^T \nabla x^k \right) $$

**Task 8: Implement gradient descent to solve the L2-H1 regularized problem and run the optimization for `max_iter=200` iterations, a step size of `alpha=1.0` and a regularization parameter of `lambd=0.01`.**

*Note that these parameter settings are not optimal, and the difference from the linear reconstruction might be minimal. You can play around with the hyper-parameters. This example shows the properties of L2-H1 regularization and illustrates that it is hard to find a good set of hyper-parameters (step size, regularization parameter, number of iterations).*
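The update above relies on $\nabla_x \tfrac{1}{2}\Vert \nabla x\Vert_2^2 = \nabla^T\nabla x$. A quick way to confirm this (and to catch sign errors in `DT`) is a finite-difference check on a small real-valued image. The sketch repeats the `nabla`/`nablaT` definitions from above; the image size, the probe direction and the step `h` are arbitrary choices, and `R_h1` is a hypothetical helper name.

```python
import numpy as np

def nabla(x):
    dx = np.pad(x[:,1:], [[0, 0],[0, 1]], mode='edge')
    dy = np.pad(x[1:], [[0, 1],[0, 0]], mode='edge')
    return np.concatenate([dx[None,...] - x, dy[None,...] - x], 0)

def nablaT(x):
    assert x.shape[0] == 2
    dx = np.pad(x[0,:,:-1], [[0, 0],[1, 1]], mode='constant')
    dy = np.pad(x[1,:-1], [[1, 1],[0, 0]], mode='constant')
    return dx[:,:-1] - dx[:,1:] + dy[:-1] - dy[1:]

def R_h1(x):
    """H1 regularizer 0.5 * ||nabla x||_2^2."""
    return 0.5 * np.sum(nabla(x)**2)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((6, 5))
d = rng.standard_normal((6, 5))    # random probe direction
h = 1e-3

# Central difference of R_h1 along d vs. the analytic directional derivative <nablaT(nabla(x0)), d>
fd = (R_h1(x0 + h * d) - R_h1(x0 - h * d)) / (2 * h)
analytic = np.sum(nablaT(nabla(x0)) * d)
print(np.isclose(fd, analytic, rtol=1e-6))  # True
```

Because the regularizer is quadratic, the central difference is exact up to rounding, so the check does not depend on how small `h` is.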

In [None]:
def opt_reg_l2(y, max_iter, alpha, lambd):
    x = np.zeros_like(AH(y))
    # TODO implement gradient descent for L2-H1 regularization
    for _ in range(max_iter):
        x = x - alpha * AH(A(x) - y) - alpha * lambd * DT(D(x))
    return x

In [None]:
alpha = 1.0
lambd = 0.01
img_reg_l2 = opt_reg_l2(kspace, max_iter=200, alpha=alpha, lambd=lambd)
img_reg_l2 = medutils.visualization.center_crop(img_reg_l2, (nFE//2, nPE))
medutils.visualization.imshow(img_reg_l2, f'L2H1 reconstruction alpha={alpha} lambda={lambd}', figsize=(10,10))

## Sparse MRI: Wavelet Thresholding
Medical images are not sparse per se, but they often have a sparse representation in some transform domain. One example is the wavelet transform, which yields a multi-level representation. We provide the `plot_wavedec` function to show what the wavelet coefficients look like at different scales and orientations.

In each iteration, we first take a gradient step with respect to the data-consistency term. This is followed by a wavelet transform, and the *detail* wavelet coefficients are suppressed by soft-thresholding, i.e.,

$$
\text{thresh}(x) = \frac{x}{\vert x \vert}\max(\vert x \vert - \alpha\lambda , 0)
$$

**Task 9: Implement the soft-thresholding in `soft_thresh(x, tau)`**

*Hint: Note that the absolute value can be zero, so a small epsilon is advisable to avoid division by zero.*
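As a worked example with $\alpha\lambda = 1$: the coefficient $3+4i$ has magnitude $5$, so it keeps its phase and its magnitude shrinks by 1 to 4, while a coefficient with magnitude at most 1 is set to zero:

$$
\text{thresh}(3+4i) = \frac{3+4i}{5}\max(5-1,\,0) = 0.8\,(3+4i) = 2.4+3.2i, \qquad
\text{thresh}(0.5i) = \frac{0.5i}{0.5}\max(0.5-1,\,0) = 0.
$$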

### Suggested Readings

Lustig et al. [Compressed Sensing MRI](https://ieeexplore.ieee.org/document/4472246), IEEE Signal Processing Magazine 25(2):72-82, 2008.

Lustig et al. [Sparse MRI: The application of compressed sensing for rapid MR imaging](https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21391). Magnetic Resonance in Medicine 58(6):1182-1195, 2007.

In [None]:
def soft_thresh(x, tau):
    #TODO: implement soft-thresholding
    return x / np.maximum(np.abs(x), 1e-9) * np.maximum(np.abs(x) - tau, 0)

In [None]:
import pywt

def plot_wavedec(img, wavelet='db4', level=2):
    # Crop the readout oversampling of the given image (not a global) before the decomposition
    img = medutils.visualization.center_crop(img, (nFE//2, nPE))
    coeffs = pywt.wavedecn(img, wavelet=wavelet, level=level)
    # normalize coeffs per band for display
    coeffs[0] /= np.max(np.abs(coeffs[0]))
    for lvl in range(1, len(coeffs)):
        for key in coeffs[lvl].keys():
            coeffs[lvl][key] /= np.max(np.abs(coeffs[lvl][key]))
    arr, coeff_slices = pywt.coeffs_to_array(coeffs)
    medutils.visualization.imshow(arr, figsize=(10,10))

def opt_reg_wavelet(y, max_iter, alpha, lambd, wavelet='db4', level=3):
    x = np.zeros_like(AH(y))
    wavelet_object = pywt.Wavelet(wavelet)
    threshold = alpha * lambd

    for _ in range(max_iter):
        x = x - alpha * (AH(A(x) - y))
        coeffs = pywt.wavedecn(x, wavelet_object, level=level)
        array, coeff_slices = pywt.coeffs_to_array(coeffs)
        denoised_array=soft_thresh(array, threshold)
        denoised_coeffs = pywt.array_to_coeffs(denoised_array, coeff_slices, output_format='wavedecn')
        denoised_coeffs[0] = coeffs[0]
        x = pywt.waverecn(denoised_coeffs, wavelet_object)
        
    return x
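The loop in `opt_reg_wavelet` alternates a gradient step on the data term with soft-thresholding in a transform domain, in the spirit of the iterative shrinkage-thresholding algorithm (ISTA); since it thresholds bi-orthogonal wavelet detail coefficients and keeps the approximation band, it is not exactly ISTA. Below is a minimal sketch of plain ISTA on a toy 1D problem, assuming the signal itself is sparse (identity transform) and the operator is a randomly undersampled orthonormal DFT. With step size 1 (allowed because $L \le 1$), the objective $\tfrac12\Vert Ax-y\Vert_2^2 + \lambda\Vert x\Vert_1$ should decrease monotonically. All names and values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_nonzero, lambd_toy = 128, 5, 0.05

# Sparse ground truth and random k-space sampling of roughly 40% of the samples
x_true = np.zeros(n, dtype=complex)
x_true[rng.choice(n, n_nonzero, replace=False)] = (
    rng.standard_normal(n_nonzero) + 1j * rng.standard_normal(n_nonzero))
mask_toy = (rng.random(n) < 0.4).astype(float)

A_1d = lambda x: np.fft.fft(x, norm='ortho') * mask_toy
AH_1d = lambda k: np.fft.ifft(k * mask_toy, norm='ortho')
y_toy = A_1d(x_true)

def soft(x, tau):
    # Complex soft-thresholding: shrink the magnitude by tau, keep the phase
    return x / np.maximum(np.abs(x), 1e-9) * np.maximum(np.abs(x) - tau, 0)

def objective(x):
    return 0.5 * np.sum(np.abs(A_1d(x) - y_toy)**2) + lambd_toy * np.sum(np.abs(x))

alpha_toy = 1.0
x_ista = np.zeros(n, dtype=complex)
energies = [objective(x_ista)]
for _ in range(100):
    # Gradient step on the data term, then shrinkage (proximal step of the l1 term)
    x_ista = soft(x_ista - alpha_toy * AH_1d(A_1d(x_ista) - y_toy), alpha_toy * lambd_toy)
    energies.append(objective(x_ista))

print(f'objective: {energies[0]:.4f} -> {energies[-1]:.4f}')
print(f'relative error: {np.linalg.norm(x_ista - x_true) / np.linalg.norm(x_true):.3f}')
```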

Next, we choose a wavelet and the number of decomposition levels, and plot the decomposition.

In [None]:
# Plot Wavelet Decomposition
wavelet='bior2.8'
level=3
plot_wavedec(img_cc, wavelet, level)

Finally, we run the optimization with `lambd=1e-6`, `alpha=1` and 200 iterations.

In [None]:
lambd=1e-6
alpha=1.0
img_reg_wavelet = opt_reg_wavelet(kspace, max_iter=200, alpha=alpha, lambd=lambd, wavelet=wavelet, level=level)
img_reg_wavelet = medutils.visualization.center_crop(img_reg_wavelet, (nFE//2, nPE))
medutils.visualization.imshow(img_reg_wavelet, f'Wavelet reconstruction alpha={alpha} lambda={lambd}', figsize=(10,10))

# Connections to Machine Learning

You might wonder why there is no exercise on machine learning. The tasks in this workbook form the cornerstone of successful machine-learning-based MRI reconstruction 😁. We introduce you to machine learning for MRI reconstruction in separate IPython notebooks 🤓:
- [Image denoising (magnitude)](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_denoising_real.ipynb)
- [Image denoising (2-channel real-valued)](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_denoising_2chreal.ipynb)
- [Image denoising (complex-valued)](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_denoising_complex.ipynb)
- [Image reconstruction (complex-valued)](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_reconstruction_complex.ipynb)
- [Complex-valued activation functions](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_complex_activations.ipynb)
- [Complex layers](https://github.com/midas-tum/merlin/blob/master/notebooks/tutorial_complex_layers.ipynb)
