## Notebook Purpose

Recreate the deep decoder experiments run in `ConvDecoder_vs_DIP_vs_DD_multicoil.ipynb` (hereafter, the original notebook), which was extremely messy and unnecessarily complicated.

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares, recon_ksp_to_img
from utils.helpers import num_params, load_h5, get_masks
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

from pytorch_msssim import ms_ssim
from utils.evaluate import calc_metrics

In [2]:
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
    torch.cuda.set_device(1) # assumes at least two GPUs; use 0 on a single-GPU machine
#     print("num GPUs",torch.cuda.device_count())
else:
    dtype = torch.FloatTensor

### Set up ConvDecoder

##### TODOs
- write a separate function that returns net_input given the appropriate scale_factor, i.e. split mri_helpers.get_scale_factor() into two functions
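One possible shape for that split, as a minimal NumPy sketch. The names `make_net_input` and `compute_scale_factor` are hypothetical, and the sketch assumes the scale factor is the ratio of the norm of the random network output to the norm of the fully sampled least-squares image (both combined via root-sum-of-squares over coils), using an orthonormal FFT. The actual internals of `get_scale_factor()` may differ; the network output is assumed to be already converted to complex coil images.

```python
import numpy as np

def make_net_input(num_channels, in_size, seed=None):
    """Hypothetical helper: random network seed with values ~U[0, 1), shape (1, C, h, w)."""
    rng = np.random.default_rng(seed)
    return rng.uniform(0.0, 1.0, size=(1, num_channels, *in_size)).astype(np.float32)

def root_sum_of_squares_np(coil_imgs):
    """Combine complex coil images (n_c, x, y) into one magnitude image (x, y)."""
    return np.sqrt(np.sum(np.abs(coil_imgs) ** 2, axis=0))

def compute_scale_factor(out_coil_imgs, slice_ksp):
    """Hypothetical helper: norm(random output image) / norm(least-squares image).

    out_coil_imgs: (n_c, x, y) complex images from net(net_input)
    slice_ksp:     (n_c, x, y) complex, fully sampled k-space
    Multiplying slice_ksp by the result gives it the same overall scale as the network output.
    """
    out_img = root_sum_of_squares_np(out_coil_imgs)
    ls_img = root_sum_of_squares_np(np.fft.ifft2(slice_ksp, norm="ortho"))
    return np.linalg.norm(out_img) / np.linalg.norm(ls_img)
```

With the two steps separated, the seed can be drawn once and reused while the scale factor is recomputed (or vice versa).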

In [3]:
def init_convdecoder(slice_ksp, mask):
    ''' parameters: 
                slice_ksp: original, unmasked k-space measurements
                mask: mask used to downsample original k-space
        return:
                net: initialized convdecoder
                net_input: random, scaled input seed 
                ksp_masked: masked measurements to fit
                img_masked: masked image, i.e. ifft(ksp_masked)
                slice_ksp: original k-space multiplied by scale_factor '''

    in_size = [8,4]
    out_size = slice_ksp.shape[1:] # shape of (x,y) image slice, e.g. (640, 368)
    out_depth = slice_ksp.shape[0]*2 # 2*n_c, i.e. 2*15=30 if multi-coil
    num_layers = 8
    strides = [1]*(num_layers-1)
    num_channels = 160
    kernel_size = 3 # currently unused: not passed to convdecoder()

    net = convdecoder(in_size, out_size, out_depth, num_layers, \
                      strides, num_channels).type(dtype)
    print('# parameters of ConvDecoder:',num_params(net))
    
    # match the scale of the original image to that of the random output image net(net_input),
    # where net_input has values ~U[0,1]; e.g. scale_factor = 168813
    # note: this could be done using the under-sampled k-space, but we use the full k-space
    scale_factor, net_input = get_scale_factor(net,
                                       num_channels,
                                       in_size,
                                       slice_ksp)
    slice_ksp = slice_ksp * scale_factor # original fit_untrained() f'n returns this

    # mask the kspace
    ksp_masked = apply_mask(np_to_tt(slice_ksp), mask=mask)
    ksp_masked = np_to_var(ksp_masked.data.cpu().numpy()).type(dtype)

    # perform ifft of masked kspace
    img_masked = ifft_2d(ksp_masked[0]).cpu().numpy()
    img_masked = reshape_complex_channels_to_be_adj(img_masked)
    img_masked = np_to_var(img_masked).type(dtype)
        
    return net, net_input, ksp_masked, img_masked, slice_ksp
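The masking and channel-layout steps in `init_convdecoder()` can be sketched in plain NumPy as follows. This is an illustration, not the repo's `apply_mask()` / `reshape_complex_channels_to_be_adj()`: it assumes a 1d column mask broadcast over the last axis, an orthonormal non-centered FFT, and that "adjacent" means real and imaginary parts of each coil are interleaved as `[re_0, im_0, re_1, im_1, ...]` (the 30 = 2*15 output channels). Both function names are hypothetical.

```python
import numpy as np

def mask_and_zero_fill(slice_ksp, mask1d):
    """slice_ksp: (n_c, x, y) complex k-space; mask1d: (y,) 0/1 or bool column mask.

    Returns the masked k-space and the zero-filled coil images ifft(ksp_masked).
    """
    cols = np.asarray(mask1d, dtype=bool)
    ksp_masked = slice_ksp * cols                          # broadcast over the last (column) axis
    img_masked = np.fft.ifft2(ksp_masked, norm="ortho")    # zero-filled, aliased coil images
    return ksp_masked, img_masked

def complex_to_adjacent_channels(img):
    """(n_c, x, y) complex -> (2*n_c, x, y) real, laid out [re_0, im_0, re_1, im_1, ...]."""
    n_c, nx, ny = img.shape
    out = np.empty((2 * n_c, nx, ny), dtype=np.float32)
    out[0::2] = img.real
    out[1::2] = img.imag
    return out
```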

### Fit network via `fit(...)`

##### Returns
- net: the best network found during fitting. Its output is in image space, but fit() does not compute or return that output
- mse_wrt_ksp = mse(ksp_masked, fft(out) * mask)
- mse_wrt_img = mse(img_masked, out)

##### Args
- `ksp_masked`: masked k-space of a single slice
- `img_masked`: ifft(ksp_masked)
- `img_ls`: least-squares recon of the original (unmasked) k-space. This is used only to compute ssim, psnr, and norm_ratio across iterations. To see how it is created, refer to the definition of `lsimg` in the original notebook

Note: the original code has an `opt_input` argument (default False); if enabled, fit() also returns an updated version of net_input

##### TODOs (in fit.py)
- make the apply_f call less confusing; compare forwardm to utils.transform.apply_mask()
- understand why we backprop on loss_ksp and not loss_img
- what is the difference between "image loss" and "image loss orig"?
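On the loss_ksp vs. loss_img question, a toy NumPy check may help. It assumes an orthonormal FFT and a column mask; all names are illustrative and not taken from fit.py, and squared errors are summed rather than averaged (mse up to a constant). With an orthonormal FFT, the image loss against img_masked equals a k-space loss over *all* coefficients, including the unsampled ones whose target is zero, so minimizing it pulls the output toward the aliased zero-filled image. loss_ksp scores only the sampled coefficients and leaves the unsampled ones for the network prior to fill in.

```python
import numpy as np

rng = np.random.default_rng(0)
n_c, nx, ny = 2, 8, 6

# toy fully sampled coil images and their k-space (orthonormal FFT, so Parseval holds)
x_true = rng.standard_normal((n_c, nx, ny)) + 1j * rng.standard_normal((n_c, nx, ny))
ksp_full = np.fft.fft2(x_true, norm="ortho")

mask1d = np.zeros(ny, dtype=bool)
mask1d[::2] = True                                    # keep every other column
ksp_masked = ksp_full * mask1d                        # unsampled columns are zero
img_masked = np.fft.ifft2(ksp_masked, norm="ortho")   # zero-filled (aliased) image

def loss_ksp(x_est):
    """Squared error on the sampled k-space coefficients only."""
    return np.sum(np.abs(mask1d * np.fft.fft2(x_est, norm="ortho") - ksp_masked) ** 2)

def loss_img(x_est):
    """Squared error against the zero-filled image."""
    return np.sum(np.abs(x_est - img_masked) ** 2)

print(loss_ksp(x_true))      # zero: the true image explains every measurement
print(loss_img(x_true))      # positive: penalized for energy in the unsampled columns
print(loss_img(img_masked))  # zero: the aliased zero-filled image minimizes loss_img
```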

##### Findings:
- When evaluating on one image, 1000 iterations were sufficient; there was seemingly no benefit to running for 10000 iterations
- Runtime per iteration is ~ 0.125s --> ~125s per 1000 iterations. Most expensive steps:
    - backprop loss_ksp: ~0.085s
    - compute net output: ~0.025s
    - all the rest combined: ~0.015s
- Cut down runtime by ~15% by removing unnecessary data type conversions
- Can cut down runtime by ~20% using HalfTensor. In order to implement this, must do the following:
    - Uncomment casting lines below
    - Add apply_f(...).half() in loss_ksp calc of fit.py
    - After all this, gradients go to zero because their values are too small to be represented in fp16. To address this, follow the "Mixed-Precision Training Iteration" section of https://developer.nvidia.com/blog/mixed-precision-training-deep-neural-networks/ (fp32 master weights plus loss scaling). Note: cannot use autocast here, since `torch.cuda.amp.autocast` requires PyTorch 1.6 or later
    - Once that works, need to verify no loss in output image quality
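Why loss scaling fixes the vanishing fp16 gradients can be shown with a small NumPy simulation. It is a simplification, not the NVIDIA recipe itself: since multiplying the loss by a constant multiplies every gradient by the same constant (chain rule), the sketch applies the scale to the gradient directly, stores it in float16, then copies it to float32 and unscales it, as would happen for an fp32 master copy of the weights. `simulate_fp16_grads` is a hypothetical name.

```python
import numpy as np

def simulate_fp16_grads(grad_fp32, loss_scale):
    """Simulate fp16 gradient storage with static loss scaling.

    Returns the unscaled float32 gradients and whether an overflow (inf/nan) occurred.
    """
    with np.errstate(over="ignore"):
        g16 = (np.asarray(grad_fp32, dtype=np.float32) * loss_scale).astype(np.float16)
    overflow = not np.all(np.isfinite(g16))
    g32 = g16.astype(np.float32) / np.float32(loss_scale)   # unscale in fp32
    return g32, overflow

grads = np.array([1e-8, 3e-6, 1e-3], dtype=np.float32)
print(simulate_fp16_grads(grads, 1.0))      # smallest gradient underflows to 0 in fp16
print(simulate_fp16_grads(grads, 1024.0))   # scaled: all three survive the round trip
```

When `overflow` is True the optimizer step should be skipped and the scale reduced; dynamic schemes such as `torch.cuda.amp.GradScaler` automate this by halving the scale on overflow and growing it again after a run of good steps.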

In [4]:
## UNCOMMENT for half precision and INJECT before fitting network
# net = net.half()
# ksp_masked = ksp_masked.half()
# img_masked = img_masked.half()
# mask = mask.half()
# dtype=torch.cuda.HalfTensor

In [5]:
# plt.plot(mse_wrt_ksp, label='ksp')
# plt.plot(mse_wrt_img, label='img')
# plt.ylim(0, 0.05)

# TODO: fix method for computing gt

Currently, mse_wrt_ksp and mse_wrt_img decrease as NUM_ITER increases. However, ssim and psnr computed against img_gt actually get worse as NUM_ITER increases, which suggests img_gt is not being computed correctly.

Current method: perform an ifft of the original k-space ksp_orig, combine the complex values, then combine the coil channels via rss (the same method used to reconstruct ksp_dc and ksp_est).
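The current method can be sketched in NumPy as below, as a reference to compare `recon_ksp_to_img()` against. This assumes centered, orthonormal FFTs (the fastMRI convention); the helper names are hypothetical and the actual internals of `recon_ksp_to_img()` are not shown in this notebook.

```python
import numpy as np

AXES = (-2, -1)

def fft2c(img):
    """Centered, orthonormal 2D FFT over the last two axes."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img, axes=AXES), norm="ortho"), axes=AXES)

def ifft2c(ksp):
    """Centered, orthonormal 2D inverse FFT over the last two axes."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(ksp, axes=AXES), norm="ortho"), axes=AXES)

def rss_recon(ksp):
    """(n_c, x, y) complex k-space -> (x, y) magnitude image: ifft, then root-sum-of-squares."""
    coil_imgs = ifft2c(ksp)
    return np.sqrt(np.sum(np.abs(coil_imgs) ** 2, axis=0))
```

This recon scales linearly with its input (k-space multiplied by s gives an image multiplied by s), so img_gt and the network outputs must come from k-space with the same scale_factor before ssim/psnr are compared.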

### Perform data consistency step

Compute network output, convert to k-space and perform data-consistency step, then convert back to image space

What actually happens in this dc step?
- 41 of the 368 mask coefficients (k-space columns) are True
- for those 41 columns, 41 * 640 = 26240 locations per coil are overwritten, i.e. 41 * 640 * 15 coils * 2 (real, imag) = 787200 values in ksp
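The overwrite and the count of replaced values can be checked with a NumPy sketch that mirrors the indexing in `data_consistency()` below. It assumes k-space stored as real arrays of shape (n_c, x, y, 2) with real/imag in the last axis; `data_consistency_np` is a hypothetical name.

```python
import numpy as np

def data_consistency_np(ksp_est, ksp_orig, mask1d):
    """ksp_est, ksp_orig: (n_c, x, y, 2) real arrays; mask1d: (y,) 0/1 column mask.

    Sampled columns take the measured values; returns the result and the number of
    overwritten real values.
    """
    cols = np.asarray(mask1d, dtype=bool)
    ksp_dc = ksp_est.copy()                      # leave the network estimate untouched
    ksp_dc[:, :, cols, :] = ksp_orig[:, :, cols, :]
    n_overwritten = ksp_orig[:, :, cols, :].size
    return ksp_dc, int(n_overwritten)
```

For a (15, 640, 368, 2) array with 41 sampled columns, `n_overwritten` is 15 * 640 * 41 * 2 = 787200.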

##### TODO:
- check and reduce redundant computations, i.e. here we subsample k-space again

In [6]:
def data_consistency(img_out, slice_ksp, mask1d):
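    ''' perform data-consistency step: wherever a k-space column was sampled,
        replace the network's estimated coefficients with the original measurements
        parameters:
                img_out: network output image, shape torch.Size([30, x, y])
                slice_ksp: original k-space measurements (after scaling by scale_factor), complex array of shape (15, x, y)
                mask1d: 1d column mask, non-zero where a column was sampled
        returns:
                img_dc: data-consistent output image
                img_est: output image without data consistency '''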
    
    img_out = reshape_complex_channels_to_sep_dimn(img_out)

    # now get F*G(\hat{C}), i.e. estimated recon in k-space
    ksp_est = fft_2d(img_out) # ([15, 640, 368, 2])
    ksp_orig = np_to_tt(split_complex_vals(slice_ksp)) # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex

    # replace estimated coeffs in k-space by original coeffs if it has been sampled
    mask1d = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) # shape: torch.Size([368]) w 41 non-zero elements
    ksp_dc = ksp_est.clone().detach().cpu()
    ksp_dc[:,:,mask1d==1,:] = ksp_orig[:,:,mask1d==1,:]

    img_dc = recon_ksp_to_img(ksp_dc)
    img_est = recon_ksp_to_img(ksp_est.detach().cpu())
    
    return img_dc, img_est

# TODO for new ipynb
- call load_h5() with full path instead of file_id

In [7]:
#filename = '/bmrNAS/people/dvv/multicoil_test_v2/file{}_v2.h5'.format(file_id)

In [8]:
# file_id_list = '1000411' #'1000781'
file_id_list = ['1000186']#, '1000361', '1001524', '1000799', '1001152', '1001132']#, '1001826', '1000522']
  
img_dc_list, img_est_list, img_gt_list, metrics_dc = [], [], [], []

# NUM_ITER = 1000  
NUM_ITER_LIST = [3]#, 1000]

for idx, file_id in enumerate(file_id_list):  
    
    # load full mri measurements
    f, slice_ksp = load_h5(file_id)
    if f['kspace'].shape[3] == 320: # skip scans whose k-space width (y) is 320
        continue
    print('file_id: {}'.format(file_id))

    # load mask, M
    mask, mask2d, mask1d = get_masks(f, slice_ksp)
    
    # make torch version of the mask for the data consistency step in fit()
    mask1d_ = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) 
    
    for NUM_ITER in NUM_ITER_LIST:
    
        # init_convdecoder() returns the k-space multiplied by scale_factor; store it separately
        # so that each NUM_ITER run starts from the same unscaled measurements
        net, net_input, ksp_masked, img_masked, slice_ksp_scaled = \
                init_convdecoder(slice_ksp, mask)

        # build ksp_orig from the scaled k-space so it matches the scale of ksp_masked
        ksp_orig = np_to_tt(split_complex_vals(slice_ksp_scaled)) # ([15, 640, 372, 2]); slice_ksp (15,640,372) complex
        
        net, mse_wrt_ksp, mse_wrt_img = fit(
            ksp_masked=ksp_masked, img_masked=img_masked,
            net=net, net_input=net_input, mask2d=mask2d,
            mask1d=mask1d_, ksp_orig=ksp_orig,
            img_ls=None, num_iter=NUM_ITER, dtype=dtype)

        img_out = net(net_input.type(dtype))[0] # estimate image \hat{x} = G(\hat{C})

        img_dc, img_est = data_consistency(img_out, slice_ksp_scaled, mask1d)
        img_gt = recon_ksp_to_img(slice_ksp_scaled) # must use the scaled k-space to match img_dc, img_est

        # save images, metrics
        img_dc_list.append(img_dc)
        img_est_list.append(img_est)
        img_gt_list.append(img_gt) # could do this once per loop
#     metrics_dc.append(calc_metrics(img_dc, img_gt))

file_id 1000186 w ksp shape (num_slices, num_coils, x, y): (34, 15, 640, 372)
file_id: 1000186
# parameters of ConvDecoder: 1850560
torch.Size([1, 15, 640, 372, 2]) torch.Size([1, 30, 640, 372])


In [9]:
sys.exit()



In [None]:
KEY_WORD = 'last_slice' #'iter{}'.format(NUM_ITER)

n_iter = len(NUM_ITER_LIST) # images were appended once per (file, NUM_ITER) pair
os.makedirs('png_out', exist_ok=True)

for i in range(0, len(img_dc_list), n_iter):
    
    # panels: ground truth, one dc output per NUM_ITER, and the last run without dc
    n_panels = n_iter + 2
    fig, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 8))

    axes[0].imshow(img_gt_list[i], cmap='gray')
    axes[0].set_title('ground-truth?')

    for j, num_iter in enumerate(NUM_ITER_LIST):
        axes[1+j].imshow(img_dc_list[i+j], cmap='gray')
        axes[1+j].set_title('conv_decoder, iter {}'.format(num_iter))

    axes[-1].imshow(img_est_list[i+n_iter-1], cmap='gray')
    axes[-1].set_title('conv_decoder w/o dc post, iter {}'.format(NUM_ITER_LIST[-1]))

    for ax in axes:
        ax.axis('off')
    
    plt.savefig('png_out/sample{}_{}.png'.format(i//n_iter, KEY_WORD))
    plt.show()

# TODO 16 September    
- get baseline performance over e.g. 10 images
- implement data consistency in last layer; compare performance to the same 10 images
- review papers sent by Akshay



- data consistency in the loss
    - need to implement the dc step on torch tensors that stay in the autograd graph (the version above detaches and moves to the cpu)
    - then I can call it similarly to how apply_f()=forwardm() is done in the current `fit.py`
- next step: how to do layer-wise data consistency?
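A differentiable version of the dc step could look like the sketch below. It assumes PyTorch 1.8 or later (the `torch.fft` module and complex autograd), complex tensors instead of the repo's (..., 2) real layout, and an orthonormal, non-centered FFT; `dc_layer` is a hypothetical name. The key change from `data_consistency()` is an out-of-place blend instead of in-place indexing on a detached copy, so gradients still reach the unsampled coefficients.

```python
import torch

def dc_layer(img_out, ksp_orig, mask1d):
    """Differentiable data-consistency step (sketch).

    img_out:  (n_c, x, y) complex network output in image space
    ksp_orig: (n_c, x, y) complex measured k-space, same scaling as img_out
    mask1d:   (y,) bool column mask, True where a column was sampled
    """
    ksp_est = torch.fft.fft2(img_out, norm="ortho")
    m = mask1d.float()                          # broadcasts over the last (column) axis
    # out-of-place blend: measured values where sampled, network estimate elsewhere
    ksp_dc = m * ksp_orig + (1 - m) * ksp_est
    return torch.fft.ifft2(ksp_dc, norm="ortho")
```

Because only differentiable tensor ops are used, the output could feed a loss or serve as a final layer; whether it helps is what the baseline comparison above is meant to show.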

In [None]:
img_gt_shifted = img_gt * (img_dc.mean() / img_gt.mean())

# est is output image without data consistency step
plt.hist(img_est.flatten(), bins=100, alpha=0.5, label='est')
plt.hist(img_dc.flatten(), bins=100, alpha=0.5, label='dc')
plt.hist(img_gt.flatten(), bins=100, alpha=0.5, label='gt')
plt.legend()
plt.show()