## Notebook Purpose

Make the data-consistency (DC) step differentiable in PyTorch so that it can be included in the loss function during network fitting.

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares
from utils.helpers import num_params
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

from pytorch_msssim import ms_ssim
from common.evaluate import vifp_mscale, ssim, psnr

In [16]:
# Load the middle slice of one multicoil k-space volume for testing.
# NOTE(review): absolute path on the lab file system; adjust for other machines.
file_id = '1000411'
filename = '/bmrNAS/people/dvv/multicoil_test_v2/file{}_v2.h5'.format(file_id)
# NOTE(review): the h5 file handle `f` is never closed; it is also passed to get_masks below.
f = h5py.File(filename, 'r') 
slice_idx = f['kspace'].shape[0] // 2 # middle index along the first (slice) axis
slice_ksp = f['kspace'][slice_idx]



# duplicated from other ipynb to get variables for testing purposes
# NOTE(review): `get_masks` is defined in the In [15] cell at the bottom of this notebook;
# run that cell first, otherwise this cell raises NameError on a fresh kernel.
in_ksp = np_to_tt(reshape_complex_channels_to_be_adj(split_complex_vals(slice_ksp)))
_, _, mask1d = get_masks(f)
img_out = in_ksp # same shape as the network output, for testing only -- note: this is measured k-space, not an image

# pre-processing for new inputs into fit() -- added in main script
ksp_orig = np_to_tt(split_complex_vals(slice_ksp)) # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex
mask1d = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) # shape: torch.Size([368]); 41 non-zero elements for this file

### Make everything below differentiable, then wrap it in a new function for use in `fit.py`

##### TODO:
- Make `recon_ksp_to_img()` differentiable; the other steps may already be fine.
- At each step, test the original against the new version and verify that the resulting torch tensors have the same shape.

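##### Input shapes (tt = torch tensor)
- in_ksp (tt): [30,x,y] — 15 coils × real/imag combined into one channel axis
- ksp_orig (tt): [15,x,y,2] — 15 coils, real/imag in the last dimension
- mask1d (tt): [368] — 1D sampling mask, applied along the third dimension (y) of the 4D k-space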

In [None]:
## DO THIS INSIDE LOOP, once recon_ksp_to_img is differentiable
# NOTE(review): recon_ksp_to_img is defined in the next cell; run that cell first on a fresh kernel.

img_out = reshape_complex_channels_to_sep_dimn(in_ksp)
ksp_est = fft_2d(img_out)

# Data consistency: on sampled columns (mask1d == 1) use the measured k-space,
# elsewhere keep the network estimate. torch.where builds a new tensor instead of
# writing into ksp_est in place, so ksp_est is left untouched (the cell can be
# re-run) and autograd sees a plain out-of-place op.
# The 1D mask is reshaped to (1, 1, y, 1) so it broadcasts over (coils, x, y, real/imag),
# matching the original indexing ksp[:, :, mask1d == 1, :].
dc_mask = (mask1d == 1).view(1, 1, -1, 1).to(ksp_est.device)
ksp_meas = ksp_orig.to(device=ksp_est.device, dtype=ksp_est.dtype)
ksp_dc = torch.where(dc_mask, ksp_meas, ksp_est)

img_dc = recon_ksp_to_img(ksp_dc)

In [None]:
def recon_ksp_to_img(ksp, dim=320, differentiable=False):
    ''' reconstruct a centre-cropped root-sum-of-squares image from multicoil k-space

        ksp: torch tensor of k-space; in this notebook shape (coils, x, y, 2)
             with real/imag in the last dimension
        dim: side length of the square centre crop
        differentiable: if False (default), run the original numpy pipeline and
             return a numpy array (no gradients). if True, compute the
             reconstruction with torch ops so gradients can flow back to `ksp`
             (provided ifft_2d itself is built from torch ops) and return a
             torch tensor on ksp's device

        returns: (dim, dim) image (numpy array, or torch tensor if differentiable)
        raises: ValueError if differentiable=True and the ifft output is not (coils, x, y, 2) '''

    img = ifft_2d(ksp)

    if not differentiable:
        # detach() so this path also works when ksp is part of an autograd graph;
        # calling .numpy() on a tensor that requires grad raises RuntimeError
        arr = img.detach().cpu().numpy()
        arr = reshape_complex_channels_to_be_adj(arr)
        arr = combine_complex_channels(arr) # e.g. shape (30,x,y) --> (15,x,y)
        arr = root_sum_of_squares(arr) # e.g. (15,x,y) --> (x,y)
        arr = crop_center(arr, dim, dim) # e.g. (x,y) --> (dim,dim)
        return arr

    if img.dim() != 4 or img.shape[-1] != 2:
        raise ValueError('expected ifft output of shape (coils, x, y, 2), got {}'.format(tuple(img.shape)))

    # RSS over coils of the complex magnitude: sqrt(sum_c re_c^2 + im_c^2).
    # Summing squares over both the coil axis (0) and the real/imag axis (-1)
    # gives the same value regardless of how real/imag channels are paired.
    # NOTE(review): assumes root_sum_of_squares computes sqrt(sum |z|^2) over coils -- confirm.
    sum_sq = torch.sum(img ** 2, dim=(0, -1))
    # clamp keeps the sqrt gradient finite (zero) where every coil is exactly zero
    rss = torch.sqrt(torch.clamp(sum_sq, min=1e-12))

    # centre crop with offset n//2 - dim//2
    # NOTE(review): presumably the same convention as utils.transform.crop_center -- confirm for odd sizes.
    n_x, n_y = rss.shape[-2:]
    start_x = n_x // 2 - dim // 2
    start_y = n_y // 2 - dim // 2
    return rss[start_x:start_x + dim, start_y:start_y + dim]

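### Below: duplicated helper functions (delete later)

Note: `get_masks` is called in the data-loading cell near the top, so run this cell before that one on a fresh kernel.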

In [15]:
def get_masks(file_h5, crop_width=320):
    ''' given h5 file, return three different versions of masks:
            mask: used for masking k-space as network input, shape (1, 1, ny, 1), float
            mask2d: 2D mask used to fit network, shape (nx, ny)
            mask1d: 1D mask used for data consistency step, shape (ny,)

        file_h5: open h5 file (or any mapping) with a 1D "mask" dataset and a
            "kspace" dataset whose last two dimensions are (nx, ny)
        crop_width: width of the actual image data; mask entries outside the
            centred band of this width are zeroed (default 320)

        raises NotImplementedError if the file has no "mask" dataset '''
    try:
        mask_raw = file_h5["mask"]
    except KeyError as err:
        # only a missing dataset is expected here; other errors propagate unchanged
        raise NotImplementedError('Implement method for generating a mask') from err
    mask1d = np.array([1 if e else 0 for e in mask_raw]) # load 1D binary mask

    # zero out mask in outer regions e.g. mask and data have last dimn 368, but actual data is size 320
    # TODO: if actual data is size 320, then why do we have dimn 368?
    idxs_zero = (mask1d.shape[-1] - crop_width) // 2 # e.g. zero first/last (368-320)/2=24 indices
    if idxs_zero > 0:
        # guard needed: for idxs_zero == 0, mask1d[-0:] is the WHOLE array and would
        # zero the entire mask; a negative value would zero almost all of it
        mask1d[:idxs_zero] = 0
        mask1d[-idxs_zero:] = 0

    # k-space grid size read from the file itself rather than from a global slice
    n_rows, n_cols = file_h5["kspace"].shape[-2:]

    # create 2d mask. zero pad if dimensions don't line up; any odd extra column goes on the right
    mask2d = np.repeat(mask1d[None, :], n_rows, axis=0)
    pad_total = n_cols - mask2d.shape[-1]
    pad_left = pad_total // 2
    mask2d = np.pad(mask2d, ((0, 0), (pad_left, pad_total - pad_left)), mode='constant')

    # convert shape e.g. (368,) --> (1, 1, 368, 1)
    mask = np_to_tt(np.array([[mask2d[0][np.newaxis].T]])).type(torch.FloatTensor)

    return mask, mask2d, mask1d