In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

import numpy as np

import torch
import torch.nn
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

from meld.recon import UnrolledNetwork
from meld.util import getAbs, getPhase
from meld.model import pytorch_proximal

import h5py
import mri
import model
import dataloader
import lib_complex as cp
import os

In [2]:
# Setup device.
# CUDA_DEVICE_ORDER only takes effect if it is set before CUDA is initialised;
# it makes device indices follow PCI bus order (as shown by nvidia-smi).
device_no = 2
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
if torch.cuda.is_available():
    # Only pin a CUDA device when one exists: calling set_device on a machine
    # without CUDA raises, which previously made the CPU fallback unreachable.
    torch.cuda.set_device(device_no)
    device = torch.device("cuda:" + str(device_no))
else:
    device = torch.device("cpu")

np_dtype = np.float32

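## Crop volumes into slabs
Each 320-slice volume is split into non-overlapping 50-slice slabs. Slices that do not fill a whole slab (320 − 6·50 = 20) are dropped.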

In [3]:
total_size = 320  # number of slices (axis 0 of imgs) used from each volume
slab_size = 50

# path to file containing imgs, k-space, and map data
datapath = "/mikQNAP/dataset_train_knees_3d.h5"
# datapath = "/mikQNAP/dataset_test_knees_3d.h5"

# Only whole slabs are kept: floor(320 / 50) = 6 slabs per volume.
num_slabs = total_size // slab_size

cropped_maps = []
cropped_imgs = []

with h5py.File(datapath, 'r') as F:
    print('keys:', list(F.keys()))
    length = len(F['imgs'])
    print('number of data points:', length)
    for ii in range(length):
        print(ii)
        # np.complex was a deprecated alias of the builtin complex (i.e. complex128)
        # and was removed in NumPy 1.24; np.complex128 is the same dtype.
        # The transpose swaps the first two axes of the stored maps so the
        # 8-long (presumably coil) axis comes first — TODO confirm file layout.
        maps = np.transpose(np.array(F['maps'][ii, ...], dtype=np.complex128), axes=(1, 0, 2, 3))
        imgs = np.array(F['imgs'][ii, ...], dtype=np.complex128)
        print(maps.shape, imgs.shape)
        if imgs.shape[0] < total_size or maps.shape[1] < total_size:
            # Short volumes would silently yield truncated slabs and a ragged stack later.
            raise ValueError(
                f"volume {ii} has fewer than {total_size} slices: maps {maps.shape}, imgs {imgs.shape}"
            )
        for jj in range(num_slabs):
            start_idx = jj * slab_size
            end_idx = start_idx + slab_size
            cropped_maps.append(maps[:, start_idx:end_idx, ...])
            cropped_imgs.append(imgs[start_idx:end_idx, ...])

keys: ['imgs', 'ksp', 'maps']
number of data points: 16
0
(8, 320, 256, 320) (320, 256, 320)
1
(8, 320, 256, 320) (320, 256, 320)
2
(8, 320, 256, 320) (320, 256, 320)
3
(8, 320, 256, 320) (320, 256, 320)
4
(8, 320, 256, 320) (320, 256, 320)
5
(8, 320, 256, 320) (320, 256, 320)
6
(8, 320, 256, 320) (320, 256, 320)
7
(8, 320, 256, 320) (320, 256, 320)
8
(8, 320, 256, 320) (320, 256, 320)
9
(8, 320, 256, 320) (320, 256, 320)
10
(8, 320, 256, 320) (320, 256, 320)
11
(8, 320, 256, 320) (320, 256, 320)
12
(8, 320, 256, 320) (320, 256, 320)
13
(8, 320, 256, 320) (320, 256, 320)
14
(8, 320, 256, 320) (320, 256, 320)
15
(8, 320, 256, 320) (320, 256, 320)


In [5]:
# Shape of the last volume read by the cropping loop.
np.shape(imgs)

(320, 256, 320)

In [6]:
# Number of slabs collected across all volumes.
print(len(cropped_imgs))
# Stack the per-slab views into single contiguous arrays (this copies the data).
new_imgs = np.asarray(cropped_imgs)
new_maps = np.asarray(cropped_maps)

96


In [7]:
# Write the stacked slab maps. The context manager closes (and flushes) the
# file even if the write fails, and the directory is created here so this
# cell does not depend on another cell having created it.
new_maps_path = "/tmp/kellman/3d_dataset_maps_50.h5"
os.makedirs(os.path.dirname(new_maps_path), exist_ok=True)
with h5py.File(new_maps_path, "w") as new_maps_file:
    new_maps_file['maps'] = new_maps

In [8]:
# Write the stacked slab images; `with` guarantees the file is closed (and
# flushed) even on error, and the output directory is created if missing.
new_imgs_path = "/tmp/kellman/3d_dataset_imgs_50.h5"
os.makedirs(os.path.dirname(new_imgs_path), exist_ok=True)
with h5py.File(new_imgs_path, "w") as new_imgs_file:
    new_imgs_file['imgs'] = new_imgs

## Loading sampling masks
Source: `/mikRAID/frank/data/cube_knees/fully_sampled` <br>
The code below splits the masks into separate HDF5 files for training and testing.

In [9]:
# Select one sampling mask per slab: the `.npy` files at positions
# [total, 2*total) in sorted order. The first `total` files are skipped —
# presumably reserved for the other split; TODO confirm.
masks_dir = '/mikRAID/frank/data/cube_knees/train_mask_slices/'
total = new_imgs.shape[0]

# os.listdir order is arbitrary, so sort to make the selection reproducible;
# endswith() also handles names with no dot or several dots.
mask_files = sorted(
    name for name in os.listdir(masks_dir) if name.endswith('.npy')
)[total:2 * total]
if len(mask_files) < total:
    raise ValueError(
        f"need {total} masks after skipping the first {total} in {masks_dir}, "
        f"found {len(mask_files)}"
    )
masks = [np.load(os.path.join(masks_dir, name)) for name in mask_files]

masks_array = np.array(masks)
# Centre each 2-D mask. Shift only the last two (spatial) axes; a bare
# fftshift would also roll the mask index axis and reorder the masks.
masks_array = np.fft.fftshift(masks_array, axes=(-2, -1))
# print(np.fft.fftshift(masks_array).shape)

In [10]:
print(masks_array.shape)
# Write the masks; `with` closes (and flushes) the file even on error.
new_masks_path = "/tmp/kellman/3d_dataset_masks_50.h5"
os.makedirs(os.path.dirname(new_masks_path), exist_ok=True)
with h5py.File(new_masks_path, "w") as new_masks_file:
    new_masks_file['masks'] = masks_array

(96, 256, 320)


## Cropping maps + k-space

In [27]:
# NOTE(review): `reduced_maps` and `reduced_ksp` are not created anywhere in
# this notebook (execution jumps from In[10] to In[27]). They come from a cell
# that has since been deleted, so this cell fails on a fresh kernel.
# Judging by the printed shapes they are per-volume lists cropped to 64
# slices — TODO restore the producing cell.
_missing = [name for name in ("reduced_maps", "reduced_ksp") if name not in globals()]
if _missing:
    raise NameError(
        "undefined: " + ", ".join(_missing)
        + " - re-create the (deleted) cell that builds the cropped maps / k-space lists"
    )
reduced_maps_arr = np.array(reduced_maps)
reduced_ksp_arr = np.array(reduced_ksp)

In [28]:
# One shape per line: maps first, then k-space.
print(reduced_maps_arr.shape, reduced_ksp_arr.shape, sep="\n")

(16, 8, 64, 256, 320)
(16, 8, 64, 256, 320)


In [32]:
# cropped_maps_filename = "/mikQNAP/kellman/mri_data_3d_64/maps_train_64.h5"
cropped_maps_filename = "/tmp/kellman/maps_train_64.h5"
# Pure-Python replacement for `!mkdir`, which reports an error if the
# directory already exists.
os.makedirs(os.path.dirname(cropped_maps_filename), exist_ok=True)
# Mode "a" keeps any other datasets in the file. Drop a stale 'maps' dataset
# first so re-running the cell does not fail because the name already exists.
with h5py.File(cropped_maps_filename, "a") as new_maps_file:
    if 'maps' in new_maps_file:
        del new_maps_file['maps']
    new_maps_file['maps'] = reduced_maps_arr

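## Coil compression
An SVD of a 24×24×24 calibration cube at the centre of each volume's k-space gives a coil-mixing matrix. That matrix is applied to both the sensitivity maps and the full k-space. All 8 virtual coils are kept, so this rotates the coil dimension rather than shrinking it.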

In [35]:
# SVD-based coil compression: the left singular vectors of a small calibration
# cube at the centre of each volume's k-space define a coil-mixing matrix,
# which is applied to both the sensitivity maps and the full k-space.
CALIB_HALF_WIDTH = 12     # calibration cube spans 2*12 = 24 samples per spatial axis
num_virtual_coils = None  # None keeps every virtual coil (a pure rotation)

cc_maps = []
cc_ksp = []

# Open read-only explicitly; h5py's default mode has changed between versions.
with h5py.File(cropped_maps_filename, 'r') as F:
    print(list(F.keys()))
    length = len(F['maps'])
    for i in range(length):
        print(i)
        maps = np.array(F['maps'][i, ...], dtype=np.complex128)
        # NOTE(review): a later cell rebinds `reduced_ksp_arr` to the compressed
        # k-space; re-running this cell after it would compress twice.
        ksp = reduced_ksp_arr[i, ...]
        num_coils = ksp.shape[0]
        if min(ksp.shape[1:]) < 2 * CALIB_HALF_WIDTH:
            raise ValueError(f"k-space {ksp.shape} too small for a {2 * CALIB_HALF_WIDTH}^3 calibration cube")

        # Calibration window centred on each spatial axis; for (64, 256, 320)
        # this is [20:44, 116:140, 148:172] (the previously hard-coded values).
        calib = tuple(
            slice(n // 2 - CALIB_HALF_WIDTH, n // 2 + CALIB_HALF_WIDTH) for n in ksp.shape[1:]
        )
        cent_k = ksp[(slice(None),) + calib].reshape(num_coils, -1)
        U, _, _ = np.linalg.svd(cent_k, full_matrices=False)
        mix = U[:, :num_virtual_coils].conj().T  # slicing with None keeps all columns

        maps_out = (mix @ maps.reshape(maps.shape[0], -1)).reshape(-1, *maps.shape[1:])
        ksp_out = (mix @ ksp.reshape(num_coils, -1)).reshape(-1, *ksp.shape[1:])
        cc_maps.append(maps_out)
        cc_ksp.append(ksp_out)
        

['maps']
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15


In [37]:
# Stack the per-volume compressed maps into one array (copies the data).
reduced_maps_array = np.asarray(cc_maps)

In [38]:
print(np.shape(reduced_maps_array))

(16, 8, 64, 256, 320)


In [39]:
# NOTE(review): the maps compressed above were read from maps_train_64.h5
# (built from the *train* dataset), yet this output is named
# "maps_test_reduced_64" — confirm the intended filename.
cc_maps_filename = "/tmp/kellman/maps_test_reduced_64.h5"
os.makedirs(os.path.dirname(cc_maps_filename), exist_ok=True)
# Mode "a" preserves other datasets; replace a stale 'maps' so re-runs succeed.
with h5py.File(cc_maps_filename, "a") as reduced_maps_file:
    if 'maps' in reduced_maps_file:
        del reduced_maps_file['maps']
    reduced_maps_file['maps'] = reduced_maps_array

## Generating ground truth

In [40]:
# NOTE(review): this rebinds `reduced_ksp_arr`, which the coil-compression
# cell reads as the *uncompressed* k-space. Re-running that cell after this
# one would compress twice; a distinct name would avoid the hidden state.
reduced_ksp_arr = np.asarray(cc_ksp)

In [41]:
print(np.shape(reduced_ksp_arr))

(16, 8, 64, 256, 320)


In [42]:
# Coil-combine each compressed volume into a ground-truth image. Take the
# inverse 3-D FFT over the spatial axes (1, 2, 3), shift, and sum over the
# coil axis weighted by the conjugate sensitivity maps.
# NOTE(review): ifftshift is applied only *after* ifftn. If k-space is stored
# centred (DC in the middle), the usual recon is
# fftshift(ifftn(ifftshift(ksp))) — confirm the convention.
gt_arr = np.array([
    np.sum(
        np.fft.ifftshift(np.fft.ifftn(ksp_vol, axes=(1, 2, 3)), axes=(1, 2, 3))
        * np.conj(cc_maps[idx]),
        axis=0,
    )
    for idx, ksp_vol in enumerate(cc_ksp)
])

In [43]:
# NOTE(review): unlike the other outputs this is a relative path with no
# ".h5" extension, so it is written to the current working directory — confirm.
gt_filename = "imgs_train_64"
# Mode "a" preserves other datasets; replace a stale 'imgs' so re-runs succeed.
with h5py.File(gt_filename, "a") as new_gt_file:
    if 'imgs' in new_gt_file:
        del new_gt_file['imgs']
    new_gt_file['imgs'] = gt_arr

In [None]:
def crop_map_size(maps, size):
    """Low-pass crop ``maps`` along axis 1 to ``size`` samples.

    The maps are Fourier transformed along axis 1, the central ``size``
    frequencies are kept, and the result is transformed back. Amplitudes are
    preserved: a constant input stays the same constant.

    Args:
        maps: array with at least 2 dims; axis 1 is cropped (presumably the
            slice axis of a (coil, slice, y, x) stack — TODO confirm).
        size: number of samples to keep along axis 1, ``1 <= size <= maps.shape[1]``.

    Returns:
        Complex array shaped like ``maps`` but with ``size`` samples on axis 1.

    Raises:
        ValueError: if ``size`` is outside ``[1, maps.shape[1]]``.
    """
    x_dim = maps.shape[1]
    if not 1 <= size <= x_dim:
        raise ValueError(f"size must be in [1, {x_dim}], got {size}")
    # After fftshift, DC sits at index x_dim // 2. Starting at
    # x_dim//2 - size//2 keeps exactly `size` bins and puts DC at size // 2,
    # where ifftshift expects it. Trimming (x_dim - size) / 2 from both ends
    # kept size + 1 bins when x_dim - size was odd.
    start = x_dim // 2 - size // 2
    maps_fft = np.fft.fftshift(np.fft.fft(maps, axis=1), axes=1)
    maps_fft_cropped = np.fft.ifftshift(maps_fft[:, start:start + size, ...], axes=1)
    # ifft normalises by 1/size rather than 1/x_dim; rescale so map
    # amplitudes are unchanged by the crop.
    return np.fft.ifft(maps_fft_cropped, axis=1) * (size / x_dim)

def crop_ksp_size(ksp, size):
    """Keep the central ``size`` k-space samples along axis 1.

    Assumes k-space is stored centred along axis 1, with DC at index
    ``ksp.shape[1] // 2`` — TODO confirm against the data.

    Args:
        ksp: array with at least 2 dims; axis 1 is cropped.
        size: number of samples to keep, ``1 <= size <= ksp.shape[1]``.

    Returns:
        A view of ``ksp`` with exactly ``size`` samples on axis 1.

    Raises:
        ValueError: if ``size`` is outside ``[1, ksp.shape[1]]``.
    """
    x_dim = ksp.shape[1]
    if not 1 <= size <= x_dim:
        raise ValueError(f"size must be in [1, {x_dim}], got {size}")
    # Start at x_dim//2 - size//2 so exactly `size` samples remain around DC.
    # Trimming (x_dim - size) / 2 from both ends kept size + 1 samples when
    # x_dim - size was odd.
    start = x_dim // 2 - size // 2
    return ksp[:, start:start + size, ...]

# this is actually unused delete later?
def reduce_size(maps, meas, num_vals, calib_half_width=12):
    """Coil-compress sensitivity maps and measurements to ``num_vals`` virtual coils.

    An SVD of a central k-space calibration cube of ``meas`` gives the
    coil-mixing matrix. Its first ``num_vals`` left singular vectors
    (conjugate-transposed) are applied to both ``maps`` and ``meas``.

    Args:
        maps, meas: torch tensors converted with ``cp.r2c``. After conversion
            they are assumed to be (coil, z, y, x) arrays with the coil axis
            first — TODO confirm.
        num_vals: number of virtual coils to keep (at most the coil count).
        calib_half_width: half-width of the calibration cube on each spatial
            axis; the default 12 reproduces the former hard-coded
            [20:44, 116:140, 148:172] window for a (64, 256, 320) volume.

    Returns:
        ``(maps_reduced, meas_reduced)``, each shaped ``(num_vals, *shape[1:])``.
    """
    maps_np = cp.r2c(maps.cpu().numpy())
    meas_np = cp.r2c(meas.cpu().numpy())
    num_coils = meas_np.shape[0]

    maps_shape = (num_vals, *maps_np.shape[1:])
    meas_shape = (num_vals, *meas_np.shape[1:])

    # Shift only the spatial axes. A bare fftshift also rolls the coil axis,
    # so the mixing matrix would be computed for a permuted coil order but
    # applied to the unpermuted data, giving a wrong compression.
    spatial_axes = tuple(range(1, meas_np.ndim))
    meas_shifted = np.fft.fftshift(meas_np, axes=spatial_axes)
    calib = tuple(
        slice(n // 2 - calib_half_width, n // 2 + calib_half_width) for n in meas_np.shape[1:]
    )
    cent_k = meas_shifted[(slice(None),) + calib].reshape(num_coils, -1)
    U, _, _ = np.linalg.svd(cent_k, full_matrices=False)
    mix = U[:, :num_vals].conj().T

    meas_reduced_reshape = mix @ meas_np.reshape(num_coils, -1)
    maps_reduced_reshape = mix @ maps_np.reshape(maps_np.shape[0], -1)

    return maps_reduced_reshape.reshape(*maps_shape), meas_reduced_reshape.reshape(*meas_shape)