In [4]:
import os, sys
import numpy as np
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import h5py
import sigpy
from sigpy.mri.samp import poisson
import torch

sys.path.append('/home/vanveen/ConvDecoder/')
from utils.data_io import load_h5_qdess, num_params
from include.decoder_conv import init_convdecoder
from include.fit import fit
from utils.evaluate import calc_metrics
from utils.transform import fft_2d, ifft_2d, root_sum_squares, \
                            reshape_complex_vals_to_adj_channels, \
                            reshape_adj_channels_to_complex_vals

from torch.autograd import Variable
import copy
import numpy as np
import time

dtype = torch.cuda.FloatTensor

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
    torch.cuda.set_device(0)

In [20]:
ACCEL_LIST = [4] # 4, 6, 8]
NUM_ITER = 1

def run_expmt(file_id_list):

    for file_id in file_id_list:

#         t0 = time.time()
#         ksp = load_h5_qdess(file_id)
#         t1 = time.time()
        
#         # load, concat both echo slices
#         idx_kx = ksp.shape[0] // 2 # want central slice in kx (axial) b/c we undersample in (ky,kz)
#         ksp_echo1 = ksp[:,:,:,0,:].permute(3,0,1,2)[:, idx_kx, :, :]
#         ksp_echo2 = ksp[:,:,:,1,:].permute(3,0,1,2)[:, idx_kx, :, :]
#         ksp_orig = torch.cat((ksp_echo1, ksp_echo2), 0)
        
#         np.save('delete_me_ksp_orig_mtr_005.npy', np.array(ksp_orig))

        ksp_orig = torch.from_numpy(np.load('delete_me_ksp_orig_mtr_005.npy'))

        for ACCEL in ACCEL_LIST:

            path_out = '/bmrNAS/people/dvv/out_qdess/accel_{}x/echo_joint/new_layers/'.format(ACCEL)
            if os.path.exists('{}{}_e1-joint-recon_dc.npy'.format(path_out, file_id)):
                continue

            # original masks created w central region 32x32 forced to 1's
            mask = torch.from_numpy(np.load('/home/vanveen/ConvDecoder/ipynb/masks/mask_poisson_disc_{}x.npy'.format(ACCEL)))

            # initialize network
            net, net_input, ksp_orig_ = init_convdecoder(ksp_orig, mask)

            # apply mask after rescaling k-space. want complex tensors dim (nc, ky, kz)
            ksp_masked = ksp_orig_ * mask
            img_masked = ifft_2d(ksp_masked)

            # fit network, get net output

            net, mse_wrt_ksp, mse_wrt_img = fit(
                ksp_masked=ksp_masked, img_masked=img_masked,
                net=net, net_input=net_input, mask2d=mask, num_iter=NUM_ITER)
            img_out = net(net_input.type(dtype)) # real tensor dim (2*nc, kx, ky)
            img_out = reshape_adj_channels_to_complex_vals(img_out[0]) # complex tensor dim (nc, kx, ky)

            # perform dc step
            ksp_est = fft_2d(img_out)
            ksp_dc = torch.where(mask, ksp_masked, ksp_est)
        
#         t2 = time.time()
#         print(t1-t0)
#         print(t2-t1)

In [41]:
def fit(ksp_masked, img_masked, net, net_input, mask2d,
        mask1d=None, ksp_orig=None, DC_STEP=False, alpha=0.5,
        num_iter=5000, lr=0.01, img_ls=None, dtype=torch.cuda.FloatTensor,
        c_wmse=None, LOSS_IN_KSP=True):
    
    # initialize variables
    if img_ls is not None or net_input is None:
        raise NotImplementedError('incorporate original code here')
    if alpha < 0 or alpha >= 1:
        raise ValueError('alpha must be non-negative and strictly less than 1')
    net_input = net_input.type(dtype)
    best_net = copy.deepcopy(net)
    best_mse = 10000.0
    mse_wrt_ksp, mse_wrt_img = np.zeros(num_iter), np.zeros(num_iter)

    p = [x for x in net.parameters()]
    optimizer = torch.optim.Adam(p, lr=lr,weight_decay=0)
    mse = torch.nn.MSELoss()

    # convert complex [nc,x,y] --> real [2*nc,x,y] to match w net output
    ksp_masked = reshape_complex_vals_to_adj_channels(ksp_masked).cuda()
    img_masked = reshape_complex_vals_to_adj_channels(img_masked)[None,:].cuda()
    mask2d = mask2d.cuda()

    for i in range(num_iter):
        def closure(): # execute this for each iteration (gradient step)

            optimizer.zero_grad()

            out = net(net_input) # out is in img space
            print(out.shape)
            
            if LOSS_IN_KSP:
                out_ksp_masked = forwardm(out, mask2d).cuda() # convert img to ksp, apply mask
                loss_ksp = mse(out_ksp_masked, ksp_masked)
                loss_ksp.backward(retain_graph=False)
            else:
                out_img_masked = forwardm_img(out, mask2d) # img-->ksp, apply mask, convert to img
                loss_img = mse(out_img_masked, img_masked)
                loss_img.backward(retain_graph=False)

#             mse_wrt_ksp[i] = loss_ksp.data.cpu().numpy() # store loss over each iteration
            if LOSS_IN_KSP:
                return loss_ksp
            else:
                return loss_img

        loss = optimizer.step(closure)

        # at each iteration, check if loss improves by 1%. if so, a new best net
        loss_val = loss.data
        if best_mse > 1.005*loss_val:
            best_mse = loss_val
            best_net = copy.deepcopy(net)

    return best_net, mse_wrt_ksp, mse_wrt_img

def forwardm_img(img, mask):
    ''' convert img --> ksp (must be complex for fft), apply mask
        convert back to img. input dim [2*nc,x,y], output dim [1,2*nc,x,y] '''

    img = reshape_adj_channels_to_complex_vals(img[0])
    ksp = fft_2d(img).cuda()
    ksp_masked_ = ksp * mask
    img_masked_ = ifft_2d(ksp_masked_)
    
    return reshape_complex_vals_to_adj_channels(img_masked_)[None, :]

def forwardm(img, mask):
    ''' convert img --> ksp (must be complex for fft), apply mask
        input, output should have dim [2*nc,x,y] '''

    img = reshape_adj_channels_to_complex_vals(img[0])
    ksp = fft_2d(img).cuda()
    ksp_masked_ = ksp * mask

    return reshape_complex_vals_to_adj_channels(ksp_masked_)

# TODO

- get baseline recon in im-space using l2. should be the same as k-space. verify this.
- next, add tv reg

In [42]:
file_id_list = ['005']

run_expmt(file_id_list)

torch.Size([1, 32, 512, 160])
