## Notebook Purpose

Make the data-consistency step differentiable in PyTorch so that it can be included in the loss function during network fitting.

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                            reshape_complex_channels_to_sep_dimn, \
                            reshape_complex_channels_to_be_adj, \
                            split_complex_vals, combine_complex_channels, \
                            crop_center, root_sum_of_squares
from utils.helpers import num_params
from include.decoder_conv import convdecoder
from include.mri_helpers import get_scale_factor
from include.fit import fit

from pytorch_msssim import ms_ssim
from common.evaluate import vifp_mscale, ssim, psnr

In [4]:
# Load the central slice of one multicoil test volume and keep a torch copy
# of its k-space in the "adjacent real/imag channels" layout.
file_id = '1000411'
filename = os.path.join('/bmrNAS/people/dvv/multicoil_test_v2',
                        'file{}_v2.h5'.format(file_id))

# Opened without a `with` block on purpose: `f` stays usable in other cells.
f = h5py.File(filename, 'r')

slice_idx = f['kspace'].shape[0] // 2   # middle slice along the first axis
slice_ksp = f['kspace'][slice_idx]

# complex -> split real/imag -> interleave as adjacent channels -> torch tensor
in_ksp = np_to_tt(
    reshape_complex_channels_to_be_adj(split_complex_vals(slice_ksp))
)

### Make everything below differentiable

### TODO: make `recon_ksp_to_img()` differentiable; everything else may already be fine

In [None]:
## DO THIS OUTSIDE OF LOOP

# Acquired k-space with real/imag split into a trailing axis (per the shape note below).
ksp_orig = np_to_tt(split_complex_vals(slice_ksp)) # ([15, 640, 368, 2]); slice_ksp (15,640,368) complex

# FIXME: mask1d is never created in this notebook, so on a fresh kernel this cell
# cannot run. It has to be defined beforehand -- possibly from the 'mask' dataset
# of the h5 file opened above (TODO confirm where it should come from).
if 'mask1d' not in globals():
    raise NameError("mask1d must be defined before this cell: it should hold the "
                    "1D sampling mask over k-space columns")
mask1d = torch.from_numpy(np.array(mask1d, dtype=np.uint8)) # shape: torch.Size([368]) w 41 non-zero elements

# The data-consistency step indexes dim 2 of the k-space with this mask,
# so the mask must be 1D with one entry per k-space column.
assert mask1d.ndim == 1 and mask1d.shape[0] == ksp_orig.shape[2], \
    'mask1d shape {} does not match k-space column count {}'.format(
        tuple(mask1d.shape), ksp_orig.shape[2])

In [None]:
## DO THIS INSIDE LOOP, once recon_ksp_to_img is differentiable
# NOTE: recon_ksp_to_img is defined in a later cell; run that cell first.

# Presumably a placeholder for the network's image output: in_ksp stands in
# for it here so the pipeline and shapes can be exercised.
img_out = reshape_complex_channels_to_sep_dimn(in_ksp)
ksp_est = fft_2d(img_out) # ([15, 640, 368, 2])

# Data consistency: keep the estimated k-space except at the sampled columns
# (mask1d == 1), where the acquired k-space is restored. torch.where builds a
# new tensor instead of writing into ksp_est in place, so ksp_est is left
# untouched and gradients flow through the unsampled columns.
sampled_cols = (mask1d == 1).reshape(1, 1, -1, 1).to(ksp_est.device)
ksp_dc = torch.where(sampled_cols,
                     ksp_orig.to(device=ksp_est.device, dtype=ksp_est.dtype),
                     ksp_est)

img_dc = recon_ksp_to_img(ksp_dc)

In [None]:
def recon_ksp_to_img(ksp, dim=320, differentiable=False):
    """Reconstruct a centre-cropped root-sum-of-squares (RSS) image from multicoil k-space.

    Args:
        ksp: k-space tensor accepted by ``ifft_2d``. Shape comments in this
            notebook suggest (num_coils, H, W, 2) with real/imag in the last
            axis -- TODO confirm for other callers.
        dim: side length of the square centre crop.
        differentiable: if False (default), use the original NumPy pipeline and
            return an ``np.ndarray``. ``.numpy()`` cuts the autograd graph and
            raises if ``ksp`` requires grad, so that path cannot be used inside
            a loss. If True, stay in PyTorch end-to-end and return a
            ``torch.Tensor`` through which gradients flow back to ``ksp``.

    Returns:
        The (dim, dim) cropped RSS image.

    Raises:
        ValueError: (differentiable path only) if ``dim`` exceeds the image size.
    """
    img = ifft_2d(ksp)

    if not differentiable:
        arr = img.cpu().numpy()
        arr = reshape_complex_channels_to_be_adj(arr)
        arr = combine_complex_channels(arr) # e.g. shape (30,x,y) --> (15,x,y)
        arr = root_sum_of_squares(arr) # e.g. (15,x,y) --> (x,y)
        arr = crop_center(arr, dim, dim) # e.g. (x,y) --> (dim,dim)
        return arr

    # Torch-only counterpart of the NumPy path above.
    # NOTE(review): assumes the last axis of `img` holds (real, imag), axis 0
    # indexes coils, and root_sum_of_squares computes sqrt(sum_c |z_c|^2)
    # -- TODO confirm against utils.transform.
    magnitude_sq = (img ** 2).sum(dim=-1)       # |z|^2 per coil
    rss = torch.sqrt(magnitude_sq.sum(dim=0))   # combine coils
    # NOTE: the derivative of sqrt is infinite at 0; add a small epsilon inside
    # the sqrt if NaN gradients show up for all-zero pixels.

    height, width = rss.shape[-2], rss.shape[-1]
    if dim > height or dim > width:
        raise ValueError('crop size {} exceeds image size {}x{}'.format(
            dim, height, width))
    # NOTE(review): centring convention assumed to match crop_center -- TODO confirm.
    top = height // 2 - dim // 2
    left = width // 2 - dim // 2
    return rss[..., top:top + dim, left:left + dim]