## Notebook Purpose

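Demonstrate a single run of ConvDecoder: load one multi-coil k-space slice and its mask, fit the network to the masked measurements, apply the data-consistency step, and compare the result with the ground truth.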

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares, recon_ksp_to_img
from utils.helpers import num_params, load_h5, get_masks
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

from pytorch_msssim import ms_ssim
from utils.evaluate import calc_metrics

In [2]:
# Tensor type used by every `.type(dtype)` call in this notebook.
# Start from the CPU type and upgrade only when a CUDA device is visible.
dtype = torch.FloatTensor
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor

In [3]:
def get_masked_measurements(slice_ksp, mask):
    ''' Undersample a k-space slice and compute the matching image.

        parameters:
                slice_ksp: original, unmasked k-space measurements
                mask: mask used to downsample original k-space
        return:
                ksp_masked: masked measurements to fit
                img_masked: masked image, i.e. ifft(ksp_masked) '''

    # apply the mask in k-space, then round-trip through numpy so the result
    # is re-wrapped by np_to_var and cast to the notebook-wide tensor type
    masked_tt = apply_mask(np_to_tt(slice_ksp), mask=mask)
    masked_np = masked_tt.data.cpu().numpy()
    ksp_masked = np_to_var(masked_np).type(dtype)

    # inverse FFT of the first entry of the masked k-space, rearranged so the
    # complex channels sit next to each other, then wrapped the same way
    img_np = ifft_2d(ksp_masked[0]).cpu().numpy()
    img_adjacent = reshape_complex_channels_to_be_adj(img_np)
    img_masked = np_to_var(img_adjacent).type(dtype)

    return ksp_masked, img_masked

In [4]:
def data_consistency(img_out, slice_ksp, mask1d):
    ''' Enforce data consistency on a network output.

        The output is taken to k-space, the entries flagged as sampled by
        mask1d are overwritten with the measured values, and both the
        corrected and the uncorrected k-space are reconstructed into images.

        parameters:
                img_out: network output image, shape torch.Size([30, x, y])
                slice_ksp: original k-space measurements
                mask1d: 1D sampling mask; entries equal to 1 select the
                        positions along the third k-space axis to overwrite
        returns:
                img_dc: data-consistent output image
                img_est: output image without data consistency '''

    # F*G(\hat{C}): the network estimate expressed in k-space, e.g. ([15, 640, 368, 2])
    ksp_est = fft_2d(reshape_complex_channels_to_sep_dimn(img_out))

    # measured k-space in the same real/imag layout, e.g. ([15, 640, 368, 2])
    ksp_measured = np_to_tt(split_complex_vals(slice_ksp))

    # boolean selector of the sampled positions, e.g. torch.Size([368])
    sampled = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) == 1

    # work on a detached CPU copy so the estimate itself is left untouched
    ksp_dc = ksp_est.clone().detach().cpu()
    ksp_dc[:, :, sampled, :] = ksp_measured[:, :, sampled, :]

    img_dc = recon_ksp_to_img(ksp_dc)
    img_est = recon_ksp_to_img(ksp_est.detach().cpu())

    return img_dc, img_est

# TODO for new ipynb
- call `load_h5()` with the full path (`filename`) instead of `file_id`

### Load measurements y and mask M

In [5]:
# Identifier of the .h5 file used in this demo.
file_id = '1000411'
# NOTE(review): `filename` is built here but is not read by any cell shown in
# this notebook -- load_h5() below is still called with `file_id`. The path is
# absolute and machine-specific.
filename = '/bmrNAS/people/dvv/multicoil_test_v2/file{}_v2.h5'.format(file_id)

# load_h5 returns a pair, unpacked as (f, slice_ksp). Per the comments in
# data_consistency, slice_ksp is complex k-space of shape (coils, x, y),
# e.g. (15, 640, 368) -- TODO confirm against load_h5.
f, slice_ksp = load_h5(file_id)

# Three masks are returned: `mask` is applied to k-space in
# get_masked_measurements, `mask2d` is passed to fit(), and `mask1d` drives
# the data-consistency step.
mask, mask2d, mask1d = get_masks(f, slice_ksp)

file_id 1000411 w ksp shape (num_slices, num_coils, x, y): (40, 15, 640, 368)


### Initialize ConvDecoder

In [6]:
# --- architecture hyper-parameters ---
num_layers = 8
num_channels = 160
kernel_size = 3  # NOTE: not passed to convdecoder() below
strides = [1] * (num_layers - 1)
in_size = [8, 4]

# --- output geometry, derived from the loaded k-space slice ---
out_size = slice_ksp.shape[1:]      # shape of (x,y) image slice, e.g. (640, 368)
out_depth = 2 * slice_ksp.shape[0]  # 2*n_c, i.e. 2*15=30 if multi-coil

net = convdecoder(
    in_size, out_size, out_depth, num_layers, strides, num_channels
).type(dtype)
print('# parameters of ConvDecoder:', num_params(net))

# parameters of ConvDecoder: 1850560


### Miscellaneous pre-processing / data conversion

In [7]:
# Build the network input together with the factor that fixes the scaling
# between the original image and the network output.
scale_factor, net_input = get_scale_factor(net, num_channels, in_size, slice_ksp)

# NOTE: this rebinds slice_ksp to its scaled version, so re-running this cell
# without re-running the loading cell applies the scaling a second time.
slice_ksp = slice_ksp * scale_factor

# torch versions of the (scaled) measurements and the 1D mask, used by the
# data consistency step in fit()
ksp_orig = np_to_tt(split_complex_vals(slice_ksp))  # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex
mask1d_ = torch.from_numpy(np.array(mask1d, dtype=np.uint8))

In [8]:
NUM_ITER = 1 # default 1000

# undersample the (already scaled) k-space slice; these are the fitting targets
ksp_masked, img_masked = get_masked_measurements(slice_ksp, mask)

# Sanity-check shapes only: printing the tensors themselves floods the cell
# output without adding information.
print('ksp_masked:', tuple(ksp_masked.shape),
      'img_masked:', tuple(img_masked.shape),
      'net_input:', tuple(net_input.shape))

# fit the network to the masked measurements
net, mse_wrt_ksp, mse_wrt_img = fit(
    ksp_masked=ksp_masked, img_masked=img_masked,
    net=net, net_input=net_input, mask2d=mask2d,
    mask1d=mask1d_, ksp_orig=ksp_orig,
    img_ls=None, num_iter=NUM_ITER, dtype=dtype)

# estimate image \hat{x} = G(\hat{C})
# Inference only: data_consistency() detaches everything it uses, so there is
# no reason to record an autograd graph for this forward pass.
with torch.no_grad():
    img_out = net(net_input.type(dtype))[0]

# data consistency step
img_dc, img_est = data_consistency(img_out, slice_ksp, mask1d)

# create ground-truth from full k-space
# NOTE(review): elsewhere in this notebook recon_ksp_to_img() is given torch
# k-space tensors, whereas slice_ksp is passed here directly -- this may be
# the bug the message below refers to; confirm against recon_ksp_to_img.
print('bug alert! ground-truth likely not computed properly')
img_gt = recon_ksp_to_img(slice_ksp) # must do this after slice_ksp is scaled

calc_metrics(img_dc, img_gt)

RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR (createCuDNNHandle at /pytorch/aten/src/ATen/cudnn/Handle.cpp:9)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f0d0397f536 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10a0c28 (0x7f0d04e7bc28 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #2: at::native::getCudnnHandle() + 0xe54 (0x7f0d04e7d404 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0xf19f4c (0x7f0d04cf4f4c in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xf1afe1 (0x7f0d04cf5fe1 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xf1f01b (0x7f0d04cfa01b in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::native::cudnn_convolution_backward_input(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0xb2 (0x7f0d04cfa572 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xf86090 (0x7f0d04d61090 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0xfca928 (0x7f0d04da5928 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #9: at::native::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 2ul>) + 0x4fa (0x7f0d04cfbc0a in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #10: <unknown function> + 0xf863bb (0x7f0d04d613bb in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #11: <unknown function> + 0xfca984 (0x7f0d04da5984 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #12: <unknown function> + 0x2c80736 (0x7f0d3e511736 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x2ccff44 (0x7f0d3e560f44 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: torch::autograd::generated::CudnnConvolutionBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x378 (0x7f0d3e129908 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2d89705 (0x7f0d3e61a705 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x16f3 (0x7f0d3e617a03 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7f0d3e6187e2 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::Engine::thread_init(int) + 0x39 (0x7f0d3e610e59 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7f0d4af54488 in /home/vanveen/heck/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #20: <unknown function> + 0xd6cb4 (0x7f0de808dcb4 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #21: <unknown function> + 0x9609 (0x7f0deaad4609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #22: clone + 0x43 (0x7f0deac10103 in /lib/x86_64-linux-gnu/libc.so.6)


In [None]:
# Halts a top-to-bottom run at this point.
# NOTE(review): the plotting cell below reads file_id_list, img_gt_list,
# img_dc_list, img_est_list and NUM_ITER_LIST, none of which are defined in the
# cells above -- presumably that is why execution is stopped here.
sys.exit()

In [None]:
# One figure per file. The result lists hold two consecutive entries per file
# (indices i and i+1), hence the stride of 2.
# NOTE(review): file_id_list, img_gt_list, img_dc_list, img_est_list and
# NUM_ITER_LIST are not defined in the cells above.
for i in np.arange(0, 2*len(file_id_list), 2):

    fig = plt.figure(figsize=(16, 8))

    # (subplot position, image, title) for the four side-by-side panels
    panels = (
        (141, img_gt_list[i], 'ground-truth?'),
        (142, img_dc_list[i],
         'conv_decoder, iter {}'.format(np.array(NUM_ITER_LIST).min())),
        (143, img_dc_list[i+1], 'conv_decoder'),
        (144, img_est_list[i+1], 'conv_decoder w/o dc post'),
    )

    for position, image, title in panels:
        ax = fig.add_subplot(position)
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()