<a href="https://colab.research.google.com/github/kmjohnson3/sigpy/blob/master/examples/CoLab_ImageReconstruction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code explores an image reconstruction problem using several types of regularization. Specifically, we aim to reconstruct data acquired in the Fourier domain (k-space) with incomplete sampling. It uses a Jupyter notebook, a web-based way to write and run code that is predominantly used in Python projects. Each cell can be executed by clicking on the cell and pressing the play button in its upper left-hand corner. The first cell loads the libraries we will use; it can be slow because it needs to install CuPy. **Please set your runtime to GPU (Runtime->Change runtime type->Hardware accelerator)**

In [0]:
# Utilities
import numpy as np # array library
import math  # math used to define pi
import h5py  #hdf5 interface, used to load data
from functools import partial #this lets us fix some of the variables in a fucntion call
import pywt #python based wavelets
from scipy.signal import convolve # convolution

# Plotting
import matplotlib.pyplot as plt
from IPython import display

#Cupy is CUDA software (slow to install)
!pip install cupy

# This gets SigPy from github and installs
!git clone https://github.com/mikgroup/sigpy.git 
!pip install ./sigpy/  
import sigpy as sp
import sigpy.mri as mri


We are going to download some very simple data, which has:
*   Maps : Sensitivity maps [18 x 256 x 256]
*   Image : Raw image [ 256 x 256 ]

Data is in the hdf5 format with just those two entries. 


In [0]:
# Download from a web location
!wget http://medphysics.wisc.edu/~kmjohnso/simple_MRI_data.h5
  
# Data is in hdf5 which we read in here
filename = 'simple_MRI_data.h5'
f = h5py.File(filename, 'r')
Image = np.array(f['Image'])
Maps = np.array(f['Maps'])

#Convert to float32 for faster processing
Image = Image.astype(np.complex64)
Maps = Maps.astype(np.complex64)


Let's make up a k-space trajectory. Here we build a simple multi-arm spiral (setting `Rot = 0` would instead give straight radial spokes). It is deliberately unrealistic.

In [0]:
# ---- Trajectory parameters ----
Narms = 20            # number of interleaved arms
Kmax = 128            # k-space radius reached at the end of each arm
Npts = 512            # samples per arm
Rot = 4 * math.pi     # phase wound along each arm: two full turns
VD = 1                # radius grows as kr**VD along the arm (1 = linear)

# ---- Create a k-space sampling (spiral) ----
# Normalised position along each arm, and the starting angle of every arm.
# endpoint=False yields Narms angles evenly spaced over [0, 2*pi).
radius_norm = np.linspace(0, 1, Npts)
arm_angles = np.linspace(0, 2 * math.pi, Narms, endpoint=False)

# 2-D grids of shape (Narms, Npts): each row is one arm.
kr, theta = np.meshgrid(radius_norm, arm_angles)

# Spiral formulation: radius Kmax*kr**VD, phase advancing by Rot along the arm,
# rotated by the arm's starting angle.
k_radius = Kmax * np.power(kr, VD)
k_complex = k_radius * np.exp(1j * kr * Rot) * np.exp(1j * theta)

# Split the complex representation into Cartesian kx, ky
kx, ky = k_complex.real, k_complex.imag

# Plot every arm of the trajectory. kx/ky are (Narms, Npts), so transpose to
# get one line per arm. Equal axis scaling keeps the spiral undistorted.
fig, ax = plt.subplots()
ax.plot(kx.transpose(), ky.transpose(), '-*')
ax.set_aspect('equal')
ax.set(xlabel='kx', ylabel='ky', title='k-space sampling trajectory')
plt.show()

# Get the coordinates into the layout passed to SigPy below: the last axis
# holds (kx, ky), so coord has shape (Narms, Npts, 2).
coord = np.stack((kx, ky), -1)
assert coord.shape == kx.shape + (2,), coord.shape

# Density compensation: weight each sample by its distance from the k-space
# centre (a ramp, since |k| = Kmax*kr**VD). All arms converge at the centre,
# so that region is sampled more densely and gets down-weighted.
dcf = np.abs(k_complex)

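The cell below simulates k-space data. It builds a SENSE forward operator (a multiply by the coil sensitivities followed by a NUFFT onto the trajectory), applies it to the image, and then adds complex Gaussian noise.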

In [0]:
# Device 0 is GPU:0
device = sp.Device(0)

# In SigPy we need to push the data to a device (GPU)
coord = sp.to_device(coord, device=device)
image_gpu = sp.to_device(Image, device=device)
dcf_gpu = sp.to_device(dcf, device=device)
smaps_gpu = sp.to_device(Maps, device=device)

print('DCF shape ', dcf_gpu.shape)
print('Coord shape ', coord.shape)
print('Image shape', image_gpu.shape)
print('Smaps shape', smaps_gpu.shape)

with device:
    xp = device.xp  # array module belonging to this device

    # Create A, a linear operator mapping the image to multi-coil k-space
    A = sp.mri.linop.Sense(smaps_gpu, coord=coord, coil_batch_size=None)

    print('A is a operator of shape:')
    print(A)
    print('A is a stack of NUFFT and multiply:')
    print(A.linops)

    # This now runs the operator set up above to get simulated k-space
    k_space = A.apply(image_gpu)

    # Add complex Gaussian noise since this is simulated. The standard
    # deviation is 5e-4 of the peak k-space magnitude. Seed the device RNG
    # first so the simulated noise is identical on every run of the notebook.
    xp.random.seed(0)
    nl = 5e-4 * xp.max(xp.abs(k_space))
    k_space += nl * xp.random.standard_normal(k_space.shape)
    k_space += nl * 1j * xp.random.standard_normal(k_space.shape)

In [0]:
# Transfer back to CPU so we can plot
k_space_cpu = sp.to_device(k_space, device=sp.cpu_device)

# Magnitude of the first coil's samples, one line per arm. This assumes
# k_space is (coils, arms, samples), matching the coord layout; the transpose
# puts samples along the x-axis.
fig, ax = plt.subplots()
ax.plot(np.abs(k_space_cpu[0, :, :]).transpose())
ax.set(xlabel='Sample index along arm', ylabel='|k-space| (coil 0)',
       title='Simulated k-space magnitude')
plt.show()

In [0]:
with device:
    # Options shared by every reconstruction in this cell
    recon_opts = dict(weights=dcf_gpu, coord=coord, device=device, coil_batch_size=None)

    # SENSE: no regularisation weight is passed
    sense = sp.mri.app.SenseRecon(k_space, mps=smaps_gpu, max_iter=60, **recon_opts)
    image_sense = sense.run()

    # SENSE + L1-wavelet penalty, weighted by lam
    lam = 0.1
    l1wavelet = sp.mri.app.L1WaveletRecon(k_space, mps=smaps_gpu, lamda=lam,
                                          accelerate=True, max_iter=200, **recon_opts)
    image_l1wavelet = l1wavelet.run()

    # SENSE + total-variation penalty, same weight lam
    tv = sp.mri.app.TotalVariationRecon(k_space, mps=smaps_gpu, lamda=lam,
                                        accelerate=True, max_iter=200, **recon_opts)
    image_tv = tv.run()

# Bring all three results back to host memory for plotting
image_tv, image_sense, image_l1wavelet = (
    sp.to_device(result, sp.cpu_device)
    for result in (image_tv, image_sense, image_l1wavelet)
)

In [0]:
# Show each reconstruction next to the ground truth in a 2x2 grid
panels = [
    ('Total Variation', image_tv),
    ('Sense', image_sense),
    ('L1 Wavelet', image_l1wavelet),
    ('Truth', Image),
]

plt.figure(figsize=(20, 20))
for panel_idx, (panel_title, panel_img) in enumerate(panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.imshow(np.abs(panel_img.transpose()), cmap='gray')
    plt.axis('off')
    plt.title(panel_title)

plt.show()

In [0]:
# How to get sensitivity maps: estimate them directly from the k-space data
# with JSENSE (joint estimation of image and coil sensitivities).
mps_ker_width = 16     # sensitivity-map kernel width
ksp_calib_width = 32   # k-space calibration width
lamda = 0.001          # regularisation weight

# Work-around: JsenseRecon wants complex weights, so cast the DCF to the
# k-space dtype.
dcf_gpu_complex = dcf_gpu.astype(k_space.dtype)

# Set up the JSENSE app. It runs on the same `device` object as every other
# cell, so switching `device` (e.g. to the CPU) also moves this computation.
app = sp.mri.app.JsenseRecon(k_space,
                             coord=coord, weights=dcf_gpu_complex,
                             mps_ker_width=mps_ker_width,
                             ksp_calib_width=ksp_calib_width,
                             lamda=lamda,
                             device=device,
                             max_iter=60,
                             max_inner_iter=10)

# Run JSENSE; mps holds one estimated map per coil
mps = app.run()

In [0]:
# Show the magnitude of every estimated coil sensitivity map in a grid with
# 4 columns. Rows are added as needed, so no coil is left out.
# mps may be on the GPU; matplotlib needs a host (NumPy) array.
mps_cpu = sp.to_device(mps, sp.cpu_device)
n_coils = mps_cpu.shape[0]
n_cols = 4
n_rows = max(1, math.ceil(n_coils / n_cols))

plt.figure(figsize=(10, 10 * n_rows / n_cols))
for coil in range(n_coils):
    plt.subplot(n_rows, n_cols, coil + 1)
    plt.imshow(np.abs(mps_cpu[coil, :, :]))
    plt.axis('off')
    plt.title(f'Coil {coil}')
plt.show()

In [0]:
with device:
    # Repeat the three reconstructions, now using the JSENSE-estimated maps
    # `mps` in place of the provided sensitivity maps.
    shared_args = {'weights': dcf_gpu, 'coord': coord, 'device': device, 'coil_batch_size': None}

    sense = sp.mri.app.SenseRecon(k_space, mps=mps, max_iter=60, **shared_args)
    image_sense = sense.run()

    lam = 0.1
    l1wavelet = sp.mri.app.L1WaveletRecon(k_space, mps=mps, lamda=lam,
                                          accelerate=True, max_iter=200, **shared_args)
    image_l1wavelet = l1wavelet.run()

    tv = sp.mri.app.TotalVariationRecon(k_space, mps=mps, lamda=lam,
                                        accelerate=True, max_iter=200, **shared_args)
    image_tv = tv.run()

# Move results to host memory for plotting
image_tv, image_sense, image_l1wavelet = [
    sp.to_device(result, sp.cpu_device)
    for result in (image_tv, image_sense, image_l1wavelet)
]

In [0]:
# Show the reconstructions made with the JSENSE-estimated sensitivity maps
# next to the ground truth. The figure-level title distinguishes this figure
# from the earlier one, which uses identical panel titles.
panels = [
    ('Total Variation', image_tv),
    ('Sense', image_sense),
    ('L1 Wavelet', image_l1wavelet),
    ('Truth', Image),
]

plt.figure(figsize=(20, 20))
plt.suptitle('Reconstructions using JSENSE-estimated sensitivity maps')
for panel_idx, (panel_title, panel_img) in enumerate(panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.imshow(np.abs(panel_img.transpose()), cmap='gray')
    plt.axis('off')
    plt.title(panel_title)

plt.show()