<a href="https://colab.research.google.com/github/kmjohnson3/Intro-to-MRI/blob/master/AdvancedNoteBooks/Constrained_Reconstruction_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MRI Constrained Reconstruction Exercise

This Jupyter notebook provides some hands-on experience with compressed-sensing-like constrained reconstruction. Each code cell can be run by clicking on its upper left corner. You can also run all cells using the "Runtime" menu on the top menu bar. When you modify one of the reconstruction parameters, think about what change you expect to see.

# Objectives
*   Reconstruct images using constrained reconstructions
*   Understand the tradeoffs and failure modes of constrained reconstructions
*   Explore different regularization methods

In Python you need to load libraries before you can use them. This first cell imports the key libraries used to reconstruct images.

In [None]:
# This is a comment, Python will ignore this line

# Import libraries (load libraries which provide some functions)
%matplotlib inline
import numpy as np # array library
import math
import cmath
import pickle
import pywt #wavelets

# For interactive plotting
from ipywidgets import interact, interactive, FloatSlider, IntSlider, FloatLogSlider, Dropdown
from IPython.display import clear_output, display, HTML

# for plotting modified style for better visualization
import matplotlib.pyplot as plt 
import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 4
mpl.rcParams['axes.titlesize'] = 24
mpl.rcParams['axes.labelsize'] = 20
mpl.rcParams['xtick.labelsize'] = 16
mpl.rcParams['ytick.labelsize'] = 16
mpl.rcParams['legend.fontsize'] = 16

# Download Raw Data
We are going to download raw data from scans collected on the scanner. This is just a nice single slice, single channel dataset.

In [None]:
# Get some data - the data is stored in an HDF5 file. This data is single channel 
import os
if not os.path.exists("multicoil8n.h5"):
  !wget https://www.dropbox.com/s/aihpudtdonm7dxd/multicoil8n.h5

import h5py as h5
with h5.File('multicoil8n.h5', 'r') as hf:
  kdata = np.array(hf['kspace'])
  smaps = np.array(hf['smaps'])

# Linear Operators 

## Subsampling of the data
*   Subselecting a set of lines by changing $\Delta k_y $ controlled by accY
*   Subselecting a set of lines by changing $k_{max} $ in x and y controlled by kmaxX, kmaxY
*   Subselecting a set of points by randomly removing points, controlled by random_undersampling_fraction (1 = remove none)
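
As a rough illustration of the three subsampling options above, the sketch below builds each kind of mask separately on a small grid and prints the fraction of k-space that is kept. The grid size and the names `n`, `stride`, `kmax` and `keep_fraction` are assumptions made for this example; they are not the variables used by the notebook's own functions.

```python
import numpy as np

n = 8                 # hypothetical k-space size (n x n)
stride = 2            # keep every 2nd line along the last axis (larger delta k)
kmax = 0.5            # keep only |k| < kmax, with k normalized to [-1, 1]
keep_fraction = 0.5   # probability of keeping each point

# Regular undersampling: skip lines along the last axis
mask_regular = np.zeros((n, n))
mask_regular[:, ::stride] = 1

# Reduced kmax: keep only the center of k-space
k = np.linspace(-1, 1, n)
ky, kx = np.meshgrid(k, k, indexing='ij')
mask_kmax = ((np.abs(kx) < kmax) & (np.abs(ky) < kmax)).astype(float)

# Random undersampling: keep each point independently with probability keep_fraction
rng = np.random.default_rng(0)
mask_random = (rng.uniform(size=(n, n)) < keep_fraction).astype(float)

for name, m in [('regular', mask_regular), ('kmax', mask_kmax), ('random', mask_random)]:
    print(f'{name}: fraction of k-space kept = {m.mean():.2f}')
```

Multiplying the masks together combines the patterns, which is one way to apply several of the options at once.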

## Forward and adjoint model
*  Forward: Fourier transform $E$
*  Adjoint: Inverse Fourier transform $E^H$
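
One common way to check that two operators form a forward/adjoint pair is the dot-product test, $\langle Ex, y\rangle = \langle x, E^Hy\rangle$. The sketch below applies it to a plain 2D FFT on random data. It assumes the unitary scaling (`norm='ortho'`), for which the inverse transform is exactly the adjoint; with NumPy's default scaling the two sides would agree only up to a constant factor. The helper names `forward` and `adjoint` and the array size are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (16, 16)  # hypothetical image size

def forward(x):
    # Unitary 2D Fourier transform
    return np.fft.fft2(x, norm='ortho')

def adjoint(y):
    # Unitary 2D inverse Fourier transform (the adjoint of forward)
    return np.fft.ifft2(y, norm='ortho')

# Random complex test vectors
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# np.vdot conjugates its first argument, giving the complex inner product
lhs = np.vdot(forward(x), y)   # <E x, y>
rhs = np.vdot(x, adjoint(y))   # <x, E^H y>
print('difference:', abs(lhs - rhs))
```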

## Gradient descent
For a linear system with $Loss(x)=||Ex-d||^2_2$, the gradient (dropping a constant factor of 2, which can be absorbed into the step size) is 

>$\frac{\partial{Loss}}{\partial{x}} \propto E^H(Ex-d) $

We descend the loss by taking steps in the negative direction of the gradient: 

>$x_{n+1} = x_{n} - \alpha E^H(Ex_{n}-d)$

where $\alpha$ is the step size, which for this problem is set to $1$.
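
The update can be tried on a toy problem before running it on real data. The sketch below is a minimal single-channel example in which $E$ is a random sampling mask applied to a unitary FFT; the test object, the mask and the helper names `E` and `EH` are assumptions for illustration and are not the functions defined in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32

# Hypothetical test object: a bright square on a dark background
truth = np.zeros((n, n), dtype=complex)
truth[8:24, 8:24] = 1.0

# Keep roughly half of the k-space points at random
mask = rng.uniform(size=(n, n)) < 0.5

def E(x):
    # Forward model: unitary FFT followed by sampling
    return mask * np.fft.fft2(x, norm='ortho')

def EH(y):
    # Adjoint model: sampling followed by unitary inverse FFT
    return np.fft.ifft2(mask * y, norm='ortho')

d = E(truth)               # simulated measurements
x = np.zeros_like(truth)   # initial guess of zeros
alpha = 1.0                # step size
losses = []
for it in range(5):
    residual = E(x) - d
    losses.append(np.sum(np.abs(residual) ** 2))
    x = x - alpha * EH(residual)

print(losses)
```

In this toy case the data error vanishes after the first step, yet the k-space points that were never sampled stay at zero, so the image is still aliased. Data consistency alone cannot fill in the missing samples, which is the motivation for adding a constraint.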





In [None]:
def undersample_data(kspace, acc_y, kmax_x, kmax_y, random_undersampling_fraction):
  # Subsampling
  kspace_us = np.zeros_like(kspace)
  mask = np.zeros_like(kspace)

  # Regular - change delta k 
  mask[:,:,::acc_y] = 1

  # Sub sections - change kmax
  ky,kx = np.meshgrid( np.linspace(-1,1,kspace.shape[-2]),np.linspace(-1,1,kspace.shape[-1]))
  mask *=  np.abs(kx/kmax_x) < 1
  mask *=  np.abs(ky/kmax_y) < 1

  # This ensures calls to the code are repeatable
  np.random.seed(0)
  
  # Random undersampling of individual k-space points
  if random_undersampling_fraction < 1:
    # Get the k-space radius ( we will fully sample the center)
    abs_kx = np.abs(np.linspace(-1,1,kspace_us.shape[-2]))
    abs_ky = np.abs(np.linspace(-1,1,kspace_us.shape[-1]))
    kx,ky = np.meshgrid(abs_ky, abs_kx)
    kr = np.sqrt( kx**2 + ky**2)

    # Randomly remove a fraction of the points
    mask_pe = np.random.uniform(size=kspace_us.shape)
    mask_pe = mask_pe < random_undersampling_fraction

    # Always keep the center of k-space (normalized radius < 0.05)
    mask_pe = np.maximum(mask_pe, kr < 0.05)

    # Combine with standard undersampling
    mask *= mask_pe
  
  kspace_us = kspace * mask

  return kspace_us, mask 

def adjoint_fourier_transform(kspace, smaps):
  # Multiplying by (-1)^(x+y) acts as an FFT shift, so that the centers of k-space and of the image are at the center of the arrays
  x,y = np.meshgrid( range(kspace.shape[-2]), range(kspace.shape[-1]))
  chop = (-1)**( x+y)

  # Inverse Fourier transform, with the shift applied before and after
  coil_images = chop*np.fft.ifftn(kspace*chop, axes=(-2,-1))

  # Coil Sum
  image = np.sum( coil_images*np.conj(smaps), axis=0)

  return image

def fourier_transform(image, smaps):
  # Multiplying by (-1)^(x+y) acts as an FFT shift, so that the centers of k-space and of the image are at the center of the arrays
  x,y = np.meshgrid( range(image.shape[-2]), range(image.shape[-1]))
  chop = (-1)**( x+y)

  # Fourier transform of the coil-weighted image, with the shift applied before and after
  kspace = chop*np.fft.fftn(image*smaps*chop, axes=(-2,-1))

  return kspace

def gradient_descent(image, mask, kspace, smaps, alpha=1.0):
  
  # Generate k-space data from Image (Ex)
  kspace_generated = fourier_transform(image, smaps)

  # Take the difference from the true data (Ex-d)
  diff = kspace_generated - kspace

  # We can only take the difference with k-space samples we actually have
  diff_masked = mask*diff
  error = np.sum(np.abs(diff_masked)**2)

  # Calculate the gradient by fourier transforming back
  gradient = adjoint_fourier_transform(diff_masked, smaps)

  # Update the image
  image = image - alpha*gradient

  return image, error

# Alternating minimization using Proximal operators
We will define some proximal operators, which we use to minimize functions. A proximal operator is a function that returns the $x$ minimizing:

>$g(x) + \frac{1}{2}||x-x_0||^2_2 $

where $g(x)$ is the function to minimize and $x_0$ is the input (the current solution). The quadratic term keeps the result from wandering far from $x_0$. For example, $||x||_1$ on its own, without this constraint, is minimized when $x$ is zero. 

We will use an alternating minimization with two steps:

* Take a gradient descent step on the linear (data consistency) part
* Apply the proximal operator of the regularization term

The cell below defines functions for these operations.
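
To make the definition concrete, the sketch below checks numerically that soft thresholding minimizes $\lambda|x| + \frac{1}{2}(x-x_0)^2$ for a real scalar, by comparing it against a brute-force search over a fine grid. Here the weight `lam` scales $g(x)=|x|$, playing the role that `lamda` plays in the code of this notebook. The function name `soft_threshold` and the test values are assumptions for this example.

```python
import numpy as np

def soft_threshold(x0, lam):
    # Shrink the magnitude by lam and clip at zero, keeping the sign
    return np.sign(x0) * np.maximum(np.abs(x0) - lam, 0.0)

lam = 0.5
grid = np.linspace(-3, 3, 60001)   # candidate solutions, spacing 1e-4

results = []
for x0 in [-2.0, -0.3, 0.0, 0.4, 1.5]:
    # Proximal objective evaluated at every candidate
    cost = lam * np.abs(grid) + 0.5 * (grid - x0) ** 2
    brute = grid[np.argmin(cost)]
    closed = soft_threshold(x0, lam)
    results.append((x0, closed, brute))
    print(f'x0={x0:+.1f}  soft threshold={closed:+.4f}  brute force={brute:+.4f}')
```

Inputs with magnitude below `lam` are set to zero and larger inputs are shrunk toward zero by `lam`, which is what promotes sparsity.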

In [None]:
def l1_prox(input, lamda):
    # The proximal operator of the L1 norm (sum of absolute values) is soft thresholding
    abs_input = np.abs(input)
    
    sign = input / (abs_input + 1e-9*np.max(abs_input))

    mag = abs_input - lamda
    mag = (abs(mag) + mag) / 2

    return mag * sign

def l2_prox(input, lamda):
    # The proximal operator of (lamda/2) * (sum of squared magnitudes) is a uniform scaling
    output = input / ( 1 + lamda)
    return output

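def iterative_sense(Mask, kspace, smaps, iterations=10):
    # Unregularized version; the definition below reuses this name and replaces it
    # Initial guess of zeros (one image, without the coil dimension)
    image = np.zeros(kspace.shape[-2:], dtype=kspace.dtype)

    for iter in range(iterations):
      # gradient_descent returns the updated image and the k-space error
      image, error = gradient_descent(image, Mask, kspace, smaps)

    return image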

def iterative_sense(Mask, kspace, smaps, iterations=10, lamda=1e-3, transform=None, proximal=None):
    
    # Initial guess of zeros
    image = np.zeros(kspace.shape[-2:], dtype=kspace.dtype)

    error_kspace = []
    error_constraint = []

    for iter in range(iterations):
      image, error = gradient_descent(image, Mask, kspace, smaps)
      error_kspace.append(error)

      if transform is not None:
        image = transform.forward(image)
      
      if proximal is not None:
        image = proximal(image, lamda)
      error_constraint.append(np.sum(np.abs(image)))

      if transform is not None:
        image = transform.backward(image)
      
    return image, error_kspace, error_constraint

# Transforms

## Wavelets
This uses multi-scale wavelets to compress the image

## Edge
This is an approximation of total variation. It uses an undecimated wavelet transform to make the construction similar to that of the wavelet transform.
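
The reason an edge-type transform helps is that piecewise-constant images have very few non-zero edge coefficients. The sketch below illustrates this with plain finite differences on a synthetic image. Finite differences are used here only as a simple stand-in; this is not the `pywt.swt2` call used in this notebook, and the test image is an assumption.

```python
import numpy as np

n = 64

# Hypothetical piecewise-constant image: two nested squares
image = np.zeros((n, n))
image[16:48, 16:48] = 1.0
image[24:40, 24:40] = 2.0

# Finite differences along each axis respond only at edges
dx = np.diff(image, axis=0)
dy = np.diff(image, axis=1)

nonzero_image = np.count_nonzero(image)
nonzero_edges = np.count_nonzero(dx) + np.count_nonzero(dy)

print('non-zero pixels in the image:', nonzero_image)
print('non-zero edge coefficients:  ', nonzero_edges)
```

An L1 penalty on such coefficients favors images with few edges, which is the idea behind total variation.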


In [None]:
class wavelet_transform:
  def __init__(self, levels=4):
    self.levels = levels

  def forward(self, x):
    coef= pywt.wavedec2(x, 'db4', level=self.levels)

    # Convert to array
    arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coef)
    self.coeff_slices = coeff_slices
    self.coeff_shapes = coeff_shapes

    return arr

  def backward(self,x):

    # Convert array to coef
    coef = pywt.unravel_coeffs(x, self.coeff_slices, self.coeff_shapes, output_format='wavedec2')

    arr = pywt.waverec2(coef, 'db4')
    return arr

class edge_transform:
  def __init__(self, levels=1):
    self.levels = levels

  def forward(self, x):
    coef= pywt.swt2(x, 'db1', level=self.levels, trim_approx=True)

    # Convert to array
    arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coef)
    self.coeff_slices = coeff_slices
    self.coeff_shapes = coeff_shapes

    return arr

  def backward(self,x):

    # Convert array to coef
    coef = pywt.unravel_coeffs(x, self.coeff_slices, self.coeff_shapes, output_format='swt2')

    arr = pywt.iswt2(coef, 'db1')
    return arr




# Simulation and plot

In [None]:
def sample_and_plot(acc_y, kmax_x, kmax_y, random_undersampling_fraction, iterations, lamda, norm, transform_name):

  # Grab the scan data
  kspace = kdata.copy()
  kspace /= np.max(np.abs(kspace))

  # Normalize the sensitivity maps (allows a step size of 1)
  maps = smaps.copy() / np.max(np.abs(smaps))

  # Subsample the data
  kspace_us, mask = undersample_data(kspace, acc_y=acc_y, kmax_y=kmax_y, kmax_x=kmax_x, random_undersampling_fraction=random_undersampling_fraction)

  # Pick a transform
  if transform_name == 'Wavelet':
    transform = wavelet_transform()
  elif transform_name == 'Edge':
    transform = edge_transform()
  else:
    transform = None

  # Pick a norm for the regularization
  if norm==1:
    prox = l1_prox
  else:
    prox = l2_prox

  # Actual reconstruction
  image, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations, lamda, transform=transform, proximal=prox)

  # Show the subsampled data
  fig = plt.figure(figsize=(15,10), constrained_layout=True)
  gs = fig.add_gridspec(3, 6)
  axs = []
  
  axs.append( fig.add_subplot(gs[1,5]) )
  plt.imshow(np.log(1e-7+np.abs(kspace_us[0])),cmap='gray')
  plt.grid(False)
  plt.title(f'Subsampled Kspace Scan')
  plt.ylabel(r'$K_x$ [index]')
  plt.xlabel(r'$K_y$ [index]')
  plt.xticks([], [])
  plt.yticks([], [])

  axs.append( fig.add_subplot(gs[:,0:4]) ) 
  plt.imshow(np.rot90(np.abs(image),-1),cmap='gray')
  plt.grid(False)
  plt.title(f'Reconstructed Image')
  plt.ylabel(r'$x$ [index]')
  plt.xlabel(r'$y$ [index]')
  plt.clim(0, 8*np.mean(np.abs(image)))
  plt.xticks([], [])
  plt.yticks([], [])

  
  axs.append( fig.add_subplot(gs[2,5]) ) 
  plt.semilogy(error_kspace)
  plt.semilogy(error_constraint)
  plt.ylabel(r'error')
  plt.xlabel(r'iteration')
  plt.legend(('Error Kspace','Error Constraint'))
  
  plt.show()


w = interactive(sample_and_plot, 
                iterations=IntSlider(min=1, max=200, step=1, value=1, description='Iterations', continuous_update=False),
                acc_y=IntSlider(min=1, max=4, step=1, value=1, description='Stride in Y', continuous_update=False),
                kmax_x=FloatSlider(min=0.1, max=1, step=0.1, value=1, description='Kmax X', continuous_update=False),
                kmax_y=FloatSlider(min=0.1, max=1, step=0.1, value=1, description='Kmax Y', continuous_update=False),
                lamda=FloatLogSlider(value=1e-3, base=10, min=-10, max=0, description='Lamda', continuous_update=False),
                random_undersampling_fraction=FloatSlider(min=0.1, max=1, step=0.1, value=1, description='Rand sample', continuous_update=False),
                norm=IntSlider(min=1, max=2, step=1, value=1, description='Norm', continuous_update=False),
                transform_name=Dropdown(options=['None', 'Wavelet', 'Edge'], value='None', description='Spatial Transform:'))

                     
display(w)

# Code to compare sampling

In [None]:
# Grab the scan data
kspace = kdata.copy()
kspace /= np.max(np.abs(kspace))

# Normalize the sensitivity maps (allows a step size of 1)
maps = smaps.copy() / np.max(np.abs(smaps))

for pattern in range(2):

  # Subsample the data
  if pattern == 0:
    kspace_us, mask = undersample_data(kspace, acc_y=4, kmax_y=1, kmax_x=1, random_undersampling_fraction=1)
  else:
    kspace_us, mask = undersample_data(kspace, acc_y=1, kmax_y=1, kmax_x=1, random_undersampling_fraction=0.25)

  # Actual reconstructions
  image_std, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations=1, lamda=0, transform=edge_transform(), proximal=l1_prox)
  image_sense, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations=200, lamda=0, transform=edge_transform(), proximal=l1_prox)
  image_cs, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations=200, lamda=1e-6, transform=edge_transform(), proximal=l1_prox)

  plt.figure(figsize=(20,10))
  plt.subplot(131)
  plt.imshow(np.rot90(np.abs(image_std[32:-32,32:-32]),-1),cmap='gray')
  plt.xticks([], [])
  plt.yticks([], [])
  plt.title('Standard')

  plt.subplot(132)
  plt.imshow(np.rot90(np.abs(image_sense[32:-32,32:-32]),-1),cmap='gray')
  plt.xticks([], [])
  plt.yticks([], [])
  plt.title('SENSE')

  plt.subplot(133)
  plt.imshow(np.rot90(np.abs(image_cs[32:-32,32:-32]),-1),cmap='gray')
  plt.xticks([], [])
  plt.yticks([], [])
  plt.title('L1 - SENSE')



# Code to compare lamda values

In [None]:
# Grab the scan data
kspace = kdata.copy()
kspace /= np.max(np.abs(kspace))

# Normalize the sensitivity maps (allows a step size of 1)
maps = smaps.copy() / np.max(np.abs(smaps))

kspace_us, mask = undersample_data(kspace, acc_y=1, kmax_y=1, kmax_x=1, random_undersampling_fraction=0.25)

images = []
error_k = []
error_c = []
for lp in np.linspace(-10,-5,20):

  l = 10**lp

  print(f'Working on {l}')

  # Actual reconstructions
  image_cs, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations=200, lamda=l, transform=edge_transform(), proximal=l1_prox)

  error_k.append(error_kspace[-1])
  error_c.append(error_constraint[-1])
  images.append(image_cs)
  

plt.figure()
plt.plot(error_k, error_c)
plt.xlabel('Error Kspace')
plt.ylabel('Error Constraint')
plt.show()



In [None]:
plt.figure(figsize=(20,10))
plt.subplot(131)
plt.imshow(np.rot90(np.abs(images[0][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('Low')

plt.subplot(132)
plt.imshow(np.rot90(np.abs(images[-4][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('Medium')

plt.subplot(133)
plt.imshow(np.rot90(np.abs(images[-1][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('High')


# Code to compare transforms



In [None]:
# Grab the scan data
kspace = kdata.copy()
kspace /= np.max(np.abs(kspace))

# Normalize the sensitivity maps (allows a step size of 1)
maps = smaps.copy() / np.max(np.abs(smaps))

# Subsample the data
kspace_us, mask = undersample_data(kspace, acc_y=1, kmax_y=1, kmax_x=1, random_undersampling_fraction=0.25)

images = []
for transform in [None, edge_transform(), wavelet_transform()]:
  print(transform)
  # Actual reconstructions
  image_cs, error_kspace, error_constraint = iterative_sense(mask, kspace_us, maps, iterations=200, lamda=1e-6, transform=transform, proximal=l1_prox)

  images.append(image_cs)


plt.figure(figsize=(20,10))
plt.subplot(131)
plt.imshow(np.rot90(np.abs(images[0][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('None')

plt.subplot(132)
plt.imshow(np.rot90(np.abs(images[1][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('Edge')

plt.subplot(133)
plt.imshow(np.rot90(np.abs(images[2][32:-32,32:-32]),-1),cmap='gray')
plt.xticks([], [])
plt.yticks([], [])
plt.title('Wavelet')
