# MRI reconstruction from multicoil data

In [1]:
## EDITS: commented out unnecessary imports

# from __future__ import print_function
import matplotlib.pyplot as plt

import os
import sigpy.mri as mr

# import sigpy as sp
# from os import listdir
# from os.path import isfile, join

import warnings
warnings.filterwarnings('ignore')

from include import *

# from PIL import Image
# import PIL
import h5py
from common.evaluate import *
from pytorch_msssim import ms_ssim # only to evaluate/compare across convdecoder, dd, dip
# import pickle
from common.subsample import MaskFunc # only if we need to generate our own mask

from DIP_UNET_models.skip import * # only to evaluate/compare across convdecoder, dd, dip

import numpy as np
import torch
import torch.optim
from torch.autograd import Variable

from include import transforms as transform

# Device / reproducibility configuration.
# GPU = True requests CUDA; if CUDA is unavailable we fall back to CPU instead of
# crashing in torch.cuda.set_device().
GPU = True
GPU_ID = 0   # index of the CUDA device to use
SEED = 0     # seed for torch / numpy RNGs (network init, random network input)

torch.manual_seed(SEED)
np.random.seed(SEED)
# note: cudnn.benchmark = True (below) can still make GPU runs non-bit-reproducible

if GPU and not torch.cuda.is_available():
    print("CUDA not available -- falling back to CPU")
    GPU = False

if GPU:
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    gpu = GPU_ID
    torch.cuda.set_device(gpu)
    print("num GPUs", torch.cuda.device_count())
else:
    dtype = torch.FloatTensor

num GPUs 4


In [2]:
# Reference cell: documents how fit_untrained() is called. Raising SystemExit
# stops "Run All" here, so the call below is never executed from this cell.
raise SystemExit

net, ni, slice_ksp_cd = fit_untrained(net, num_channels, mask,
                                      in_size, slice_ksp, slice_ksp_torchtensor)
#   net: convdecoder
#   num_channels = 160
#   mask: binary torch tensor of size [1, 1, ~368, 1]
#   in_size: set to [8, 4] in the ConvDecoder setup cell
#   slice_ksp: mri measurements in npy as dtype=complex64, e.g. shape (15, 640, 368)
#   slice_ksp_torchtensor: mri measurements in torch tensor w 2 channels, e.g. shape [15, 640, 368, 2]

SystemExit: 

In [9]:
def fit_untrained(parnet, num_channels, mask, in_size, slice_ksp,
                  slice_ksp_torchtensor, num_layers=None, mask2d=None,
                  lr=0.008, num_iter=200):
    '''Fit an untrained generator to under-sampled multicoil k-space.

    The k-space is rescaled so that the randomly initialised network output and
    the data have comparable magnitude. It is then masked, and the network is
    fitted to the masked measurements with `fit()`.

    Args:
        parnet: untrained network (ConvDecoder, Deep Decoder or DIP).
        num_channels: channel count used for the random network input and for
            the per-layer channel list passed to `fit()`.
        mask: binary sampling mask as torch tensor, e.g. [1, 1, ~368, 1]. It is
            truncated along dim 2 to the k-space width.
        in_size: spatial size of the random network input, e.g. [8, 4].
        slice_ksp: complex numpy k-space, e.g. shape (15, 640, 368).
        slice_ksp_torchtensor: same data with real/imag in a trailing dim,
            e.g. shape [15, 640, 368, 2].
        num_layers: network depth. Defaults to the notebook-global `num_layers`,
            which was previously read implicitly.
        mask2d: mask forwarded to `fit()`. Defaults to the notebook-global
            `mask2d`, which was previously read implicitly.
        lr: Adam learning rate forwarded to `fit()` (default 0.008).
        num_iter: number of fitting iterations (default 200).

    Returns:
        (net, net_input, slice_ksp): the fitted network, its fixed input, and
        `slice_ksp` multiplied by the scaling factor. The scaled k-space is what
        the data-consistency step must use.
    '''
    # Keep the old implicit-global behaviour as the default so existing calls
    # are unaffected, but allow passing the values explicitly.
    if num_layers is None:
        num_layers = globals()['num_layers']
    if mask2d is None:
        mask2d = globals()['mask2d']

    # Fix the scaling between the k-space and the random output image
    # net(input tensor w values ~U[0,1]).
    # note: this could also be done with the under-sampled k-space; here the full k-space is used
    scale_out = 1
    scaling_factor, ni = get_scale_factor(parnet,
                                          num_channels,
                                          in_size,
                                          slice_ksp,
                                          scale_out=scale_out)
    # e.g. scaling_factor = 168813; ni: the network input tensor
    slice_ksp_torchtensor = slice_ksp_torchtensor * scaling_factor
    slice_ksp = slice_ksp * scaling_factor

    ### mask the k-space
    # note: apply_mask() only generates a new mask if no mask is given; otherwise it returns the given mask.
    # Truncate the mask to the k-space width to avoid a dimension error.
    mask = mask[:, :, :slice_ksp_torchtensor.shape[2], :]
    masked_kspace, mask = transform.apply_mask(slice_ksp_torchtensor, mask=mask)

    # under-sampled measurements as a torch tensor of type `dtype`
    unders_measurement = np_to_var(masked_kspace.data.cpu().numpy()).type(dtype)

    # zero-filled estimate: ifft of the masked k-space, [n_coils, x, y, 2]
    sampled_image2 = transform.ifft2(masked_kspace)
    # interleave real/imag per coil -> [1, 2*n_coils, x, y]
    # (channel 2i = real part of coil i, channel 2i+1 = imaginary part)
    lsest = sampled_image2.permute(0, 3, 1, 2).reshape(1, -1, *sampled_image2.shape[1:3])

    ### fit the network to the under-sampled measurement
    *_, net_input, net = fit(net=parnet,
                             img_noisy_var=unders_measurement,
                             img_clean_var=lsest.type(dtype),
                             num_channels=[num_channels] * (num_layers - 1),
                             net_input=ni,
                             mask=mask2d,
                             find_best=True,
                             LR=lr,
                             num_iter=num_iter,  # 200 by default (was 20000 in longer runs)
                             scale_out=scale_out)  # default 1

    return net, net_input, slice_ksp

In [59]:
def data_consistency(parnet, parni, mask1d, slice_ksp):
    ''' Compute the network output and enforce data consistency.

    Steps: image estimate x_hat = G(C_hat) -> k-space F(x_hat) -> replace the
    predicted coefficients by the measured ones wherever they were sampled ->
    inverse FFT -> root-sum-of-squares coil combination -> 320x320 center crop.

    Args:
        parnet: fitted network G.
        parni: its fixed input C_hat.
        mask1d: 1-D 0/1 sampling pattern over the last k-space axis
            (e.g. length 368).
        slice_ksp: complex numpy k-space of shape (n_coils, x, y). It must be
            scaled the same way as during fitting, e.g. the `slice_ksp`
            returned by fit_untrained().

    Returns:
        320x320 grayscale (RSS-combined) reconstruction.
    '''
    # Inference only: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        # estimate image x_hat = G(C_hat), e.g. torch.Size([1, 30, 640, 368])
        img = parnet(parni.type(dtype))
        num_coils = img.shape[1] // 2  # e.g. 30 // 2 = 15: real/imag channels per coil
        x, y = img.shape[2], img.shape[3]
        # channels (2i, 2i+1) hold the real/imag part of coil i -> [1, num_coils, x, y, 2]
        fimg = (img[:1, :2 * num_coils]
                .reshape(1, num_coils, 2, x, y)
                .permute(0, 1, 3, 4, 2)
                .contiguous()
                .type(dtype))

        # to k-space, e.g. torch.Size([1, 15, 640, 368, 2])
        Fimg = transform.fft2(fimg)
        # measured k-space in the same layout; slice_ksp is complex (n_coils, x, y)
        meas = ksp2measurement(slice_ksp)

        # Replace the predicted coefficients by the measured ones where sampled.
        # The boolean index on dim 3 selects whole k-space columns, so every row
        # of each column with mask1d == 1 is replaced (entire columns, not
        # single elements).
        sampled = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) == 1
        ksp_dc = Fimg.clone().detach().cpu()
        ksp_dc[:, :, :, sampled, :] = meas[:, :, :, sampled, :]
        # ksp_dc now holds the data-consistent k-space of G(C_hat)

        # back to image space: [num_coils, x, y, 2]
        img_dc = transform.ifft2(ksp_dc)[0]
        # interleave real/imag per coil -> numpy array [2*num_coils, x, y]
        par_out_chs = img_dc.permute(0, 3, 1, 2).reshape(-1, x, y).cpu().numpy()

    # combine real/imag channel pairs into complex coil images
    par_out_imgs = channels2imgs(par_out_chs)

    # combine the coil images via standard rss, then crop center for 320x320 grayscale output
    prec = crop_center2(root_sum_of_squares2(par_out_imgs), 320, 320)

    return prec

In [60]:
# (removed) A stale, out-of-order call `data_consistency(net, ni, mask1d, slice_ksp_cd)`
# stood here. It needs net / ni / slice_ksp_cd, which only exist after
# fit_untrained() runs in the "Setup and fit ConvDecoder" section, where
# rec_convD is computed.

# Loading MRI measurement

In [5]:
# Get the full stack of slices.
# Note: val-set files contain the k-space f['kspace'] and the RSS reconstruction
# f['reconstruction_rss']. Test-set files such as this one only contain
# 'ismrmrd_header', 'kspace' and 'mask'.
# Set the MRI_H5_FILE environment variable to load a different file.
filename = os.environ.get('MRI_H5_FILE',
                          '/bmrNAS/people/dvv/multicoil_test_v2/file1000781_v2.h5')
f = h5py.File(filename, 'r')  # kept open: later cells read from f
print("Kspace shape (number slices, number coils, x, y): ", f['kspace'].shape)

# Isolate the middle k-space slice. It is complex64, so it cannot be displayed
# without conversion.
slicenu = f["kspace"].shape[0]//2
slice_ksp = f['kspace'][slicenu]
slice_ksp_torchtensor = transform.to_tensor(slice_ksp)  # Convert from numpy array to pytorch tensor

# For val-set files the ground truth can be shown with
# plt.imshow(f["reconstruction_rss"][slicenu], "gray").
f.keys()

Kspace shape (number slices, number coils, x, y):  (37, 15, 640, 368)


<KeysViewHDF5 ['ismrmrd_header', 'kspace', 'mask']>

### Load or create mask, M

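- Load the mask if the .h5 file already has one; otherwise create one (the mask-generation code is currently not active)
    - The loaded mask is a 1-D binary vector of length ~368 (the last k-space dimension)
- Convert the mask to 0's and 1's, zero it outside the central data region, expand it to 2-D, zero-pad it to the k-space width, and convert it to a torch tensor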

In [6]:
# Load the sampling mask stored in the .h5 file: a 1-D boolean vector over the
# last k-space axis (e.g. length 368).
# If a file has no mask, one can be generated with
# MaskFunc(center_fractions=[0.07], accelerations=[4]) and
# transform.apply_mask(slice_ksp_torchtensor, mask_func=mask_func).
temp = np.array([1 if e else 0 for e in f["mask"]])
temp = temp[np.newaxis].T   # column vector (W, 1)
temp = np.array([[temp]])   # (1, 1, W, 1)
mask = transform.to_tensor(temp).type(dtype).detach().cpu()
mask1d = var_to_np(mask)[0, :, 0]

# The provided mask and data have last dim of 368, but the actual data is smaller.
# To prevent the network learning outside the data region, force the mask to 0
# outside the central 2 * DATA_HALF_WIDTH columns.
DATA_HALF_WIDTH = 160
center = mask1d.shape[-1] // 2
# Clamp at 0: a negative stop index would zero (almost) the whole mask when the
# mask is narrower than 2 * DATA_HALF_WIDTH.
mask1d[:max(center - DATA_HALF_WIDTH, 0)] = 0
mask1d[center + DATA_HALF_WIDTH:] = 0

# Turn the 1-D mask into 2-D (one row per k-space row) and zero-pad the columns
# to the k-space width. Split odd padding so the widths always match.
mask2d = np.repeat(mask1d[None, :], slice_ksp.shape[1], axis=0).astype(int)
pad_total = slice_ksp.shape[-1] - mask2d.shape[-1]
mask2d = np.pad(mask2d, ((0, 0), (pad_total // 2, pad_total - pad_total // 2)), mode='constant')
mask = transform.to_tensor(np.array([[mask2d[0][np.newaxis].T]])).type(dtype).detach().cpu()
print("under-sampling factor:", round(len(mask1d)/sum(mask1d), 2))

under-sampling factor: 8.98


### Setup and fit ConvDecoder

In [7]:
arch_name = "ConvDecoder"

# Output channels: real and imaginary part of every coil image (2*n_c, i.e. 30 for 15 coils)
out_depth = 2 * slice_ksp.shape[0]
# Spatial size (x, y) of one coil image, e.g. (640, 368)
out_size = slice_ksp.shape[1:]

num_channels = 160  # alternative: 256
num_layers = 8
strides = [1 for _ in range(num_layers - 1)]
in_size = [8, 4]
kernel_size = 3

net = convdecoder(
    in_size, out_size, out_depth, num_layers, strides, num_channels,
    act_fun=nn.ReLU(),
    skips=False,
    need_sigmoid=False,
    bias=False,
    need_last=True,
    kernel_size=kernel_size,
    upsample_mode="nearest",
).type(dtype)

print(f"# parameters of {arch_name}:", num_param(net))

[(15, 8), (28, 15), (53, 28), (98, 53), (183, 102), (343, 193), (640, 368)]
# parameters of ConvDecoder: 1850560


In [10]:
# Fit the ConvDecoder to the under-sampled measurements.
#   net: convdecoder, num_channels = 160
#   mask: binary torch tensor of size [1, 1, ~368, 1]
#   in_size: [8, 4] (set in the setup cell)
#   slice_ksp: mri measurements in npy as dtype=complex64, e.g. shape (15, 640, 368)
#   slice_ksp_torchtensor: the same measurements as a torch tensor w 2 channels, e.g. shape [15, 640, 368, 2]
net, ni, slice_ksp_cd = fit_untrained(
    parnet=net,
    num_channels=num_channels,
    mask=mask,
    in_size=in_size,
    slice_ksp=slice_ksp,
    slice_ksp_torchtensor=slice_ksp_torchtensor,
)

optimize with adam 0.008
Iteration 00100    Train loss 0.028007  Actual loss 0.032156 Actual loss orig 0.032156 

In [11]:
np.shape(slice_ksp_cd)  # scaled k-space returned by fit_untrained

(15, 640, 368)

In [12]:
rec_convD = data_consistency(parnet=net, parni=ni, mask1d=mask1d, slice_ksp=slice_ksp_cd)

In [13]:
np.shape(rec_convD)  # cropped reconstruction size

(320, 320)

In [None]:
raise SystemExit  # stop "Run All" here; cells below are not executed automatically

### Setup and fit Deep Decoder (DD)

In [None]:
### delete cache: free the previous network and its input, then release cached GPU memory
del net, ni
torch.cuda.empty_cache()

In [None]:
arch_name = "DD"

# Deep Decoder hyper-parameters. num_layers also sets the length of the channel
# list that fit_untrained() passes to fit().
num_channels = 368
num_layers = 10
in_size = [16, 16]

# out_depth (2*n_coils output channels) and out_size ((x, y) image size) come
# from the ConvDecoder setup cell; `output_depth` is not defined anywhere.
net = skipdecoder(out_size, in_size, out_depth,
                  num_layers, num_channels, skips=False, need_last=True,
                  need_sigmoid=False, upsample_mode="bilinear").type(dtype)
print("# parameters of {}:".format(arch_name), num_param(net))
#print(net)

In [None]:
net, ni, slice_ksp_cd = fit_untrained(parnet=net, num_channels=num_channels, mask=mask, in_size=in_size,
                                      slice_ksp=slice_ksp, slice_ksp_torchtensor=slice_ksp_torchtensor)

In [None]:
rec_DD = data_consistency(parnet=net, parni=ni, mask1d=mask1d, slice_ksp=slice_ksp_cd)

### Setup and fit Deep Image Prior (DIP) (encoder-decoder style architecture)

In [None]:
### delete cache: free the previous network and its input, then release cached GPU memory
del net, ni
torch.cuda.empty_cache()

In [None]:
arch_name = "DIP"

# For DIP the network input spans the full (x, y) image size.
in_size = slice_ksp.shape[-2:]
pad = "zero"  # alternative: 'reflection'
num_channels = 256
# out_depth (2*n_coils output channels) comes from the ConvDecoder setup cell;
# `output_depth` is not defined anywhere.
net = skip(in_size, num_channels, out_depth,
           num_channels_down=[num_channels] * 8,
           num_channels_up=[num_channels] * 8,
           num_channels_skip=[0] * 6 + [4, 4],
           filter_size_up=3, filter_size_down=5,
           upsample_mode='nearest', filter_skip_size=1,
           need_sigmoid=False, need_bias=True, pad=pad, act_fun='ReLU').type(dtype)
print("# parameters of {}:".format(arch_name), num_param(net))
#print(net)

In [None]:
net, ni, slice_ksp_cd = fit_untrained(parnet=net, num_channels=num_channels, mask=mask, in_size=in_size,
                                      slice_ksp=slice_ksp, slice_ksp_torchtensor=slice_ksp_torchtensor)

In [None]:
rec_DIP = data_consistency(parnet=net, parni=ni, mask1d=mask1d, slice_ksp=slice_ksp_cd)

### Evaluation

In [None]:
def scores(im1, im2):
    '''Return (VIF, MS-SSIM, SSIM, PSNR) comparing im1 with im2.

    Before scoring, im1 is standardised and then rescaled to the mean and
    standard deviation of im2, so both images share the same first two moments.
    '''
    # standardise im1, then match im2's standard deviation and mean
    matched = (im1 - im1.mean()) / im1.std()
    matched *= im2.std()
    matched += im2.mean()

    vif_ = vifp_mscale(matched, im2, sigma_nsq=matched.mean())

    # ssim / psnr are called with one extra leading axis on each image
    ssim_ = ssim(np.array([matched]), np.array([im2]))
    psnr_ = psnr(np.array([matched]), np.array([im2]))

    # MS-SSIM on 4-D (1, 1, ...) float32 CPU tensors
    def _as_tensor(arr):
        return torch.from_numpy(np.array([[arr]])).type(torch.FloatTensor)

    t1 = _as_tensor(matched)
    t2 = _as_tensor(im2)
    ms_ssim_ = ms_ssim(t1, t2, data_range=t2.max()).data.cpu().numpy()[None][0]
    return vif_, ms_ssim_, ssim_, psnr_

In [None]:
# The RSS ground truth is only stored in val-set files. Test-set files such as
# the one loaded above contain only 'ismrmrd_header', 'kspace' and 'mask'.
if "reconstruction_rss" not in f:
    raise KeyError(
        "'reconstruction_rss' not found in %s (test-set file?): no ground truth "
        "available for evaluation; load the corresponding val-set file instead." % filename
    )
gt = f["reconstruction_rss"][slicenu]

In [None]:
# Metrics of each reconstruction w.r.t. the ground truth
(vif_cd, ms_ssim_cd, ssim_cd, psnr_cd) = scores(im1=gt, im2=rec_convD)
(vif_dd, ms_ssim_dd, ssim_dd, psnr_dd) = scores(im1=gt, im2=rec_DD)
(vif_dip, ms_ssim_dip, ssim_dip, psnr_dip) = scores(im1=gt, im2=rec_DIP)

### Visualization

In [None]:
# 2 x 2 panel figure (16 x 14 inches): ground truth and the three reconstructions.
panels = [
    (gt, 'Ground Truth'),
    (rec_convD, "ConvDecoder"),
    (rec_DD, "Deep Decoder"),
    (rec_DIP, "Deep Image Prior"),
]
fig, axes = plt.subplots(2, 2, figsize=(16, 14))
for ax, (image, title) in zip(axes.flat, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

# Metrics, one line per method; labels are left-padded to a common width.
results = [
    ("ConvDecoder", (vif_cd, ms_ssim_cd, ssim_cd, psnr_cd)),
    ("Deep Decoder", (vif_dd, ms_ssim_dd, ssim_dd, psnr_dd)),
    ("Deep Image Prior", (vif_dip, ms_ssim_dip, ssim_dip, psnr_dip)),
]
for label, (vif_val, ms_ssim_val, ssim_val, psnr_val) in results:
    print("%-18s--> VIF: %.2f, MS-SSIM: %.2f, SSIM: %.2f, PSNR: %.2f "
          % (label, vif_val, ms_ssim_val, ssim_val, psnr_val))

plt.show()