## Notebook Purpose

Make the data-consistency step differentiable in PyTorch so that it can be applied inside the loss function during network fitting.

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares
from utils.helpers import num_params, get_masks, load_h5
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

In [3]:
file_id = '1000267'
# f is reused below by get_masks(); slice_ksp is the complex k-space slice
# (see the shape note further down)
f, slice_ksp = load_h5(file_id)


# duplicated from other ipynb to get variables for testing purposes
# split real/imag parts, interleave them along the channel axis, convert to a torch tensor
in_ksp = np_to_tt(reshape_complex_channels_to_be_adj(split_complex_vals(slice_ksp)))
# only the third return value (the 1D mask) is used in this notebook
_, _, mask1d = get_masks(f, slice_ksp)
# stand-in with the same shape as the network output, for testing only --
# it actually holds the measured k-space, not an image
img_out = in_ksp

# pre-processing for new inputs into fit() -- added in main script
ksp_orig = np_to_tt(split_complex_vals(slice_ksp)) # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex
# convert the mask to a uint8 torch tensor so it can be compared with `== 1` for indexing
mask1d = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) # shape: torch.Size([368]) w 41 non-zero elements

   ### make everything below differentiable, create a new function for use in `fit.py`

##### TODO: 
- Make `recon_ksp_to_img()` differentiable; everything else might already be fine.
- At each step, test the original against the new version and verify that the shapes of the torch tensors match.

##### input shapes
- in_ksp (tt): [30,x,y]
- ksp_orig (tt): [15,x,y,2]
- mask1d (tt): [368]

In [46]:
## DO THIS INSIDE LOOP, once recon_ksp_to_img is differentiable
# Data-consistency step: transform the estimate to k-space, overwrite the
# entries selected by the mask with the measured values, transform back.

img_out = reshape_complex_channels_to_sep_dimn(in_ksp) #[15,x,y,2]
ksp_est = fft_2d(img_out) #[15,x,y,2]
# in-place write: along the third axis, positions where mask1d == 1 are
# replaced by the corresponding entries of ksp_orig
ksp_est[:,:,mask1d==1,:] = ksp_orig[:,:,mask1d==1,:]
# ksp_dc is an alias of ksp_est (same tensor object), not a copy
ksp_dc = ksp_est #[15,x,y,2]

img_dc = ifft_2d(ksp_dc) #[15,x,y,2]

# don't want to recon img
# img_dc = reshape_complex_channels_to_be_adj_tt(img_dc) #[30,x,y]
# img_dc = combine_complex_channels_tt(img_dc) #[15,x,y]
# img_dc = root_sum_of_squares_tt(img_dc)
# img_dc = crop_center(img_dc, 320, 320)

#img_dc = recon_ksp_to_img(ksp_dc)

In [49]:
# indexing with None prepends a length-1 axis; this cell just displays the resulting shape
ksp_orig[None, :].shape

torch.Size([1, 15, 640, 400, 2])

In [None]:
# want output [1,15,x,y,2]

In [4]:
def recon_ksp_to_img(ksp, dim=320):
    ''' Reconstruct a center-cropped magnitude image from multi-coil k-space.

        This is the numpy reference path: the result leaves the autograd
        graph at the `.numpy()` conversion, so it is NOT differentiable.

        ksp: torch tensor of k-space data; assumed to be shaped like `ksp_dc`,
             i.e. (15,x,y,2) -- TODO confirm against callers
        dim: side length passed to crop_center for both output axes

        returns: whatever crop_center returns, e.g. an array of shape (dim,dim)
    '''

    # These two steps define `arr`. With them commented out, the first live
    # statement read an unassigned local and raised UnboundLocalError.
    # detach() lets the numpy conversion work even when ksp is part of a graph.
    arr = ifft_2d(ksp).detach().cpu().numpy()
    arr = reshape_complex_channels_to_be_adj(arr)
    arr = combine_complex_channels(arr) # e.g. shape (30,x,y) --> (15,x,y)
    arr = root_sum_of_squares(arr) # e.g. (15,x,y) --> (x,y)
    arr = crop_center(arr, dim, dim) # e.g. (x,y) --> (dim,dim)

    return arr

In [41]:
def reshape_complex_channels_to_be_adj_tt(arr):
    ''' (15,x,y,2) --> (30,x,y)

        Interleave the two trailing components along the channel axis:
        output[2*c] is arr[c,:,:,0] and output[2*c+1] is arr[c,:,:,1].

        arr: 4D torch tensor whose last axis holds the two components
        returns: 3D torch tensor with twice as many channels, keeping the
                 dtype and device of `arr`
    '''

    num_coils, height, width = arr.shape[0], arr.shape[1], arr.shape[2]

    # Only the first two trailing components are used, matching the indices
    # 0 and 1 that the per-coil loop version read.
    parts = arr[..., :2]

    # (c,x,y,2) -> (c,2,x,y) -> (2c,x,y). Out-of-place, so no preallocated
    # default-dtype CPU buffer and no in-place writes.
    return parts.permute(0, 3, 1, 2).reshape(2 * num_coils, height, width)

def combine_complex_channels_tt(arr):
    ''' (30,x,y) --> (15,x,y) via combining real/imaginary vals into single magnitude

        Channels are expected in interleaved order: channel 2*c is paired
        with channel 2*c+1, and each pair is reduced to sqrt(a^2 + b^2).

        arr: 3D torch tensor with interleaved channel pairs on axis 0
        returns: torch tensor with half as many channels, keeping the dtype
                 and device of `arr`
    '''

    num_coils = arr.shape[0] // 2

    # Strided slices pick the even and odd channels. The explicit stop bound
    # ignores a trailing unpaired channel, as the per-coil loop version did.
    first = arr[0:2 * num_coils:2]
    second = arr[1:2 * num_coils:2]

    # NOTE(review): sqrt has no finite gradient where the magnitude is exactly
    # zero -- consider adding an epsilon if NaNs show up during fitting.
    return torch.sqrt(torch.square(first) + torch.square(second))

def root_sum_of_squares_tt(arr):
    ''' root-sum-of-squares reduction over the leading axis of a torch tensor,
        e.g. a stack of 2D slices from multiple coils --> a single 2D tensor '''
    squared = arr.square()
    summed = squared.sum(dim=0)
    return summed.sqrt()