## Notebook Purpose

Make the data consistency (DC) step differentiable in PyTorch so that it can be applied in the loss function during network fitting.

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares
from utils.helpers import num_params, get_masks, load_h5
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

In [2]:
# Load k-space for one slice of `file_id` and build the test inputs used by later cells.
# NOTE(review): the shapes quoted in the comments below come from the original
# author's notes for this specific file -- confirm for other files.
file_id = '1000267'
f, slice_ksp = load_h5(file_id)


# duplicated from other ipynb to get variables for testing purposes
in_ksp = np_to_tt(reshape_complex_channels_to_be_adj(split_complex_vals(slice_ksp)))
_, _, mask1d = get_masks(f, slice_ksp)
# img_out is bound to the SAME tensor object as in_ksp (no copy). Later cells use it
# as a stand-in for the network output, but it actually holds the measured k-space.
img_out = in_ksp

# pre-processing for new inputs into fit() -- added in main script
ksp_orig = np_to_tt(split_complex_vals(slice_ksp)) # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex
# mask1d is rebound here: get_masks' third output -> uint8 torch tensor,
# so later cells can compare/index it against torch tensors (values presumably 0/1 -- confirm)
mask1d = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) # shape: torch.Size([368]) w 41 non-zero elements

### Make the following command differentiable by creating a new function for use in `fit.py`

`ksp_est[:,:,mask1d==1,:] = ksp_orig[:,:,mask1d==1,:]`

##### SOLVED: implementation issues in `data_consistency_iter()`
- given ksp, ksp_orig, mask1d:
    - (1) make a clone of ksp, called ksp_dc. ksp can't be modified directly, as this would break the differentiable chain
    - (2) replace values of ksp_dc with those from ksp_orig at the indices where mask1d == 1
    - (3) interpolate between ksp and ksp_dc

Old notes:
- if `index_select` doesn't work, consider other ways to perform the DC step using interpolation
    - possible torch.nn.functional.grid_sample? see examples:
        - https://pytorch.org/docs/stable/nn.functional.html
        - https://discuss.pytorch.org/t/differentiable-indexing/17647/4
        - https://www.programcreek.com/python/example/104458/torch.nn.functional.grid_sample


##### outstanding questions
- sampling may not produce a meaningful gradient --> perhaps use interpolation instead?
    - interpolation example: https://discuss.pytorch.org/t/indexing-a-variable-with-a-variable/2111/4

# OLD

### basic interpolation examples -- works

In [12]:
# Toy check that a convex combination of two tensors keeps the autograd graph.
# ksp stands in for the network output; zz is a random stand-in for ksp_orig.
# (The original cell also evaluated `ksp_dc.shape`, but ksp_dc is never defined
# in this notebook, so it raised NameError on a fresh kernel -- removed.)
ksp = img_out

from torch.autograd import Variable  # kept so any cell still using Variable keeps working
# Variable is deprecated: plain tensors carry requires_grad directly.
# randn_like matches ksp's shape, dtype and device.
zz = torch.randn_like(ksp, requires_grad=True)

# linearly combine network output (ksp) and the ksp_orig stand-in (zz)
alpha = 0.5
out = alpha*ksp + (1-alpha)*zz

assert out.requires_grad, "combination should stay in the autograd graph"
out.shape

### using index_select -- fails as written

https://pytorch.org/docs/stable/generated/torch.index_select.html

Attempt in `data_consistency_iter()`:

```python
#mask1d_idx = Variable(torch.LongTensor([1]*len(mask1d))).cuda()
#return torch.index_select(ksp, 3, mask1d_idx)
```

In [None]:
# must convert mask1d into indices at which it contains 1's
# e.g. mask1d=[0, 1, 1, 0] --> indices=torch.tensor([1, 2])
# then, call torch.index_select(arr_ksp, 2, indices)

# `arr_ksp` was never defined in this notebook (NameError on a fresh kernel).
# Assumption: it refers to ksp_orig, whose dim 2 (368) matches len(mask1d) -- TODO confirm.
arr_ksp = ksp_orig

# Variable is deprecated; torch.where already returns a LongTensor of indices.
# Compare against 1 to match the `mask1d==1` selection used in the DC step.
mask1d_idx = torch.where(mask1d == 1)[0]

arr_select = torch.index_select(arr_ksp, 2, mask1d_idx)
assert arr_select.shape[2] == len(mask1d_idx), "one selected column per sampled line"
print(arr_ksp.shape, arr_select.shape, len(mask1d_idx))

###
# template example: gradients flow to x; the integer indices carry no gradient
x = torch.randn(3, 3, requires_grad=True)
print(x)
idx = torch.tensor([0, 2])
print(x.index_select(0, idx))

### unused: functions converted from numpy to torch

In [None]:
def reshape_complex_channels_to_be_adj_tt(arr):
    ''' (15,x,y,2) --> (30,x,y)

    Interleave the real/imag parts of each coil along the channel axis:
    output channel 2*c holds arr[c,:,:,0] and channel 2*c+1 holds arr[c,:,:,1].

    Done as one stack + reshape rather than a loop writing into torch.empty,
    so the result keeps the input's dtype and device (torch.empty always used
    the default dtype on CPU) and remains a single differentiable op.

    arr: 4D torch tensor (num_coils, x, y, k) with k >= 2; only the first two
         entries of the last axis are used.
    returns: tensor of shape (2*num_coils, x, y)
    '''
    num_coils, nx, ny = arr.shape[0], arr.shape[1], arr.shape[2]

    # (num_coils, 2, x, y): flattening the first two axes interleaves re/im per coil
    stacked = torch.stack((arr[:, :, :, 0], arr[:, :, :, 1]), dim=1)

    # explicit sizes (no -1) so an empty coil axis reshapes cleanly
    return stacked.reshape(2 * num_coils, nx, ny)

def combine_complex_channels_tt(arr):
    ''' (30,x,y) --> (15,x,y) via combining real/complex vals into single magnitude

    Channels are assumed interleaved per coil: channel 2*c is the real part and
    channel 2*c+1 the imaginary part of coil c. With an odd channel count the
    trailing unpaired channel is ignored.

    Vectorized with strided slicing instead of a loop writing into torch.empty,
    so the result keeps the input's dtype and device (torch.empty always used
    the default dtype on CPU).

    arr: 3D torch tensor (2*num_coils, x, y)
    returns: tensor of shape (num_coils, x, y) holding sqrt(re^2 + im^2)
    '''
    num_coils = arr.shape[0] // 2

    real = arr[0:2 * num_coils:2]
    imag = arr[1:2 * num_coils:2]
    return torch.sqrt(torch.square(real) + torch.square(imag))

def root_sum_of_squares_tt(arr):
    ''' root-sum-of-squares along the leading axis of a torch tensor,
        e.g. collapse a stack of per-coil 2D slices into one 2D tensor '''
    squared_total = arr.square().sum(dim=0)
    return squared_total.sqrt()