In [24]:
import os, sys
import numpy as np
import pydicom
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import h5py
import sigpy
from sigpy.mri.samp import poisson
import torch
from torch.fft import ifftn

sys.path.append('/home/vanveen/ConvDecoder/')
from utils.data_io import load_h5, load_output, save_output, \
                            expmt_already_generated
from utils.transform import np_to_tt, split_complex_vals, recon_ksp_to_img
from utils.helpers import num_params, get_masks
from include.decoder_conv import init_convdecoder
from include.mri_helpers import get_scale_factor, get_masked_measurements, \
                                data_consistency
from include.fit import fit
from utils.evaluate import calc_metrics

# Device configuration. `dtype` is the tensor type later cells cast to with
# `.type(dtype)`; it is now defined on every machine (CPU fallback) instead of
# only when CUDA is available.
GPU_INDEX = 1  # CUDA device to run on; adjust for the host machine
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
    torch.cuda.set_device(GPU_INDEX)
else:
    dtype = torch.FloatTensor
    
from utils.transform import np_to_tt, np_to_var, apply_mask, ifft_2d, fft_2d, \
                        reshape_complex_channels_to_sep_dimn, \
                        reshape_complex_channels_to_be_adj, \
                        split_complex_vals, recon_ksp_to_img, \
                        fftshift, ifftshift

In [2]:
path = '/bmrNAS/people/arjun/data/qdess_knee_2020/files_recon_calib-16/'

# regular files only (subdirectories skipped), in deterministic sorted order
files = sorted(name for name in listdir(path) if isfile(join(path, name)))

print(len(files))

164


##### data format
- 'kspace': Nx x Ny x Nz x # echoes x # coils
- 'maps': Nx x Ny x Nz x # coils x # maps
- 'target': Nx x Ny x Nz x # echoes x # maps

Take the k-space and run on a single echo. What to do with the multiple coils? Reconstruct all coils, then combine them with root-sum-of-squares (RSS) at the end.

### load data, make mask

In [3]:
# tensor type for the `.type(dtype)` casts below; fall back to CPU tensors
# instead of assuming a GPU is present
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [68]:
fn = files[0]
# read both datasets fully into memory ([()]) and close the HDF5 handle
# instead of leaving it open for the lifetime of the kernel
with h5py.File(join(path, fn), 'r') as f:
    ksp = torch.from_numpy(f['kspace'][()])
    targ = torch.from_numpy(f['target'][()])

# get echo1, reshape to be (nc, kx, ky, kz)
ksp_vol = ksp[:,:,:,0,:].permute(3,0,1,2)
ksp_vol.shape

torch.Size([16, 512, 512, 160])

In [69]:
# Undersampling mask for the (ky, kz) plane, loaded from disk. The commented
# lines record how mask_3d.npy was presumably generated (sigpy `poisson`,
# accel=4):
# mask = poisson(img_shape=(512, 160), accel=4)
# mask = abs(mask)
# np.save('mask_3d.npy', mask)
mask = torch.from_numpy(np.load('mask_3d.npy').astype('float32'))

# the mask must match the trailing (ky, kz) dims of the k-space volume
assert mask.shape == ksp_vol.shape[-2:], (mask.shape, ksp_vol.shape)

# add singleton (coil, kx) dims so the mask broadcasts over the whole volume
mask_ = mask[None, None, :, :]
print(mask.shape, mask_.shape)

torch.Size([512, 160]) torch.Size([1, 1, 512, 160])


### apply mask to entire volume --> create `ksp_masked`, `img_masked`

This is currently a hand-modified, volumetric version of the function call `ksp_masked, img_masked = get_masked_measurements(ksp_vol, mask_)`.

In [70]:
# apply mask
ksp_masked = ksp_vol * mask_

img_masked = recon_3d_ksp_to_img(ksp_masked)
print(ksp_masked.shape, img_masked.shape)

torch.Size([16, 512, 512, 160]) torch.Size([16, 512, 512, 160])


### Get the central kx slice of each volume
DD+ requires a 2D reconstruction, and we undersample only in ky and kz (the mask broadcasts over kx), so we work on a single slice along kx.

In [72]:
# central index along dim 1 (kx) of the k-space volume
idx_kx = ksp_vol.shape[1] // 2

# NOTE(review): the k-space tensors below are indexed at the kx = centre *plane*,
# while `img_masked_` is the x = centre *slice* of the image volume. If
# recon_3d_ksp_to_img inverse-transforms along kx (dim 1) and k-space is
# centred, the 2D inverse FFT of the central kx plane is a projection of the
# volume along x rather than that x-slice, so the k-space and image inputs would
# not describe the same 2D object. A per-slice 2D problem would normally
# inverse-FFT along kx only (hybrid x/ky/kz space) before indexing. TODO confirm.
ksp_masked_ = ksp_masked[:, idx_kx, :, :]
img_masked_ = img_masked[:, idx_kx, :, :]
ksp_slice = ksp_vol[:, idx_kx, :, :]

print(ksp_masked_.shape, img_masked_.shape, ksp_slice.shape)

torch.Size([16, 512, 160]) torch.Size([16, 512, 160]) torch.Size([16, 512, 160])


# TODO:
- Reshape `ksp_masked` and `img_masked` to the layout that `fit()` requires.
- First, rewrite `fit()` so that its data types and shapes make sense and it operates on tensors.
    - Do this step by step with the fastMRI dataset, since that pipeline already works? Or get it working with qDESS first, then merge it into the fastMRI processing?
    
### do all array processing in torch!

### Quick conversion to run fit() with the original data format [delete or reformat later]

In [81]:
from torch.autograd import Variable  # no longer used in this cell; kept in case later cells rely on it
# Quick conversion to the layout fit() expects from the original pipeline:
#   original: ksp [1, 15, 640, 372, 2], img [1, 30, 640, 372], mask (640, 372)
#   here:     ksp [1, 16, 512, 160, 2], img [1, 32, 512, 160], mask (512, 160)

ksp_masked__ = np_to_var(ksp_masked_).type(dtype)

# concatenate along dim 0: real parts of all coils, then imaginary parts of all
# coils, i.e. (nc, ky, kz) complex -> (2*nc, ky, kz) real.
# NOTE(review): fit() may instead expect each coil's real/imag channels to be
# adjacent (cf. reshape_complex_channels_to_be_adj imported above) - TODO confirm.
img_masked__ = torch.cat([torch.real(img_masked_), torch.imag(img_masked_)])
# add a batch dim and cast; torch.autograd.Variable is a deprecated no-op
# wrapper, so plain indexing is used instead
img_masked__ = img_masked__[None, :].type(dtype)

print(ksp_masked__.shape, img_masked__.shape)

torch.Size([1, 16, 512, 160, 2]) torch.Size([1, 32, 512, 160])


### initialize network

The network has the same number of parameters as the original network. The only difference is 32 = 2 · n_c output channels instead of 30. Hence, as currently written, the network is agnostic to the number of pixels in a slice: e.g. a 512×512 slice would give the same `num_params` as a 512×160 slice. Is this right?

In [73]:
# In the original pipeline the k-space slice was (nc, x, y) with the 2D mask in
# (x, y); here the slice is (nc, ky, kz) and the mask is (ky, kz).
# Pass the unmasked slice `ksp_slice` from the slicing cell: `slice_ksp` only
# comes into existence as this call's return value, so using it as the argument
# raised NameError on a fresh kernel.
# NOTE(review): ksp_slice and mask are torch tensors; confirm init_convdecoder
# accepts tensors rather than numpy arrays.
net, net_input, slice_ksp = init_convdecoder(ksp_slice, mask)

# from utils.helpers import num_params
# params = [p.shape for p in net.parameters()]
# params

### Later TODOs

In [165]:
# Later TODOs. Every line in this cell is commented out, so running it has no effect.
# NOTE(review): the output shown under this cell does not come from any live
# statement here; it appears to be stale output from an earlier version.
# img_gt = recon_ksp_to_img(slice_ksp, dim=???)

# only need this if doing dc step (presumably the data-consistency step)
# ksp_orig = np_to_tt(split_complex_vals(slice_ksp))[None, :].type(dtype)
# ksp_orig.shape

torch.Size([1, 16, 512, 160, 2])

### run network

In [83]:
# Original pipeline: out = net(net_input) was [1, 30, 640, 372] because the mask
# lived in the (x, y) plane; here the mask lives in the (ky, kz) plane.
#
# NOTE(review): the last run of this cell ended in
# `TypeError: 'module' object is not callable` after fit() had already printed
# its own shapes. One plausible (unconfirmed) cause is a helper calling the
# pre-1.8 function-style `torch.fft(x, 2)`; in the PyTorch version used here
# `torch.fft` is a module (this notebook imports `from torch.fft import ifftn`).
# TODO check the FFT helpers used by fit().

print(ksp_masked__.shape, img_masked__.shape, mask.shape)

net, mse_wrt_ksp, mse_wrt_img = fit(ksp_masked=ksp_masked__,
                                    img_masked=img_masked__,
                                    net=net,
                                    net_input=net_input,
                                    mask2d=np.array(mask),  # fit() gets a numpy mask
                                    num_iter=10)

torch.Size([1, 16, 512, 160, 2]) torch.Size([1, 32, 512, 160]) torch.Size([512, 160])
torch.Size([1, 32, 512, 160]) (512, 160)


TypeError: 'module' object is not callable