Save one `.npz` file (undersampled input / fully sampled ground-truth pair) for each MRI volume

In [None]:
import os
import shutil
import glob
import re

# h5py can read hdf5 dataset
import h5py

# delete bad data files
from send2trash import send2trash

# fastmri has some k-space undersampling functions we can use
# git clone https://github.com/facebookresearch/fastMRI.git
# go to the fastmri directory
# pip install -e .
import fastmri

# We will use these functions to generate masks
from fastmri.data.subsample import RandomMaskFunc, EquispacedMaskFunc

# sigpy is apparently a good MRI viewing tool
# pip install sigpy
import sigpy as sp
import sigpy.plot as pl

import numpy as np

import tensorflow as tf
from keras import backend as K

import matplotlib.pyplot as plt
import matplotlib

%matplotlib notebook

In [4]:
# ---- configuration: dataset names and paths ----

# Dataset splits that master() iterates over; uncomment a split to include it.
DATASETS = [
    'singlecoil_train',
    # 'singlecoil_val',
    # 'singlecoil_test_v2',
]

# Root directory under which the per-dataset .npz files are written.
data_save_path = os.path.join(
    '/central/groups/BEBi_205_Spring_2021/vliu',
    'dataset_objects',
)

# Axes passed to the inverse FFT, keyed by dataset name (identical for all splits).
AXES = dict.fromkeys(
    ('singlecoil_train', 'singlecoil_val', 'singlecoil_test_v2'),
    (1, 2),
)

# Single-dataset setup, for debugging / demo purposes. cwd is home directory.
# DATASET = 'singlecoil_val' #singlecoil_val, singlecoil_test_v2
# data_path = os.path.join('/central/groups/BEBi_205_Spring_2021/vliu', DATASET)
# mri_paths = glob.glob(os.path.join(data_path, '*99.h5'))
# data_save_path = os.path.join('/central/groups/BEBi_205_Spring_2021/vliu/dataset_objects', DATASET)
print('bye')

bye


In [3]:
# helpers that turn one .h5 MRI file into an (undersampled, fully sampled) pair,
# with the real and imaginary parts of the undersampled image stored separately
def _get_kspace_and_reconstruction_rss(filename, DATASET):
    """
    reads the k-space data and the provided reconstruction from one MRI file
    @params filename: full path to .h5 mri file
    @params DATASET: name of the dataset split; not used by this function,
        kept so the signature stays the same for existing callers
    @return tuple (kspace, reconstruction_rss) read from the 'kspace' and
        'reconstruction_rss' entries of that particular file
    @raises OSError: if the file cannot be opened or read
    @raises KeyError: if one of the two entries is missing from the file
    """
    try:
        with h5py.File(filename, 'r') as hr:
            return hr['kspace'][:], hr['reconstruction_rss'][:]
    except (OSError, KeyError):
        # Say which file failed, then re-raise. Silently returning None here
        # only postpones the failure to the caller's tuple unpacking, where
        # the real cause is no longer visible. Catching specific exceptions
        # also keeps KeyboardInterrupt working during long batch runs.
        print(f'Error could not open {filename}')
        raise

def _get_kspace_undersampled(kspace, center_fractions = [0.04], accelerations = [4], seed = None):
    """
    applies a random undersampling mask to fully sampled k-space
    @params kspace: from _get_kspace_and_reconstruction_rss(filename, DATASET)
    @params center_fractions: for undersampling, 
        fraction of low-frequency columns in the center that are always kept
    @params accelerations: how much mri acquisition is sped up
    @params seed: optional seed forwarded to the mask function so the same
        mask can be regenerated; None (default) keeps the previous,
        non-reproducible behaviour
    @return undersampled k-space (kspace multiplied by the mask)
    """
    # The default lists are only read, never mutated, so sharing them
    # between calls is harmless.
    mask_func = RandomMaskFunc(
        center_fractions = center_fractions, 
        accelerations = accelerations
    )
    # NOTE(review): the mask is built from kspace.shape as-is. fastMRI's mask
    # functions are documented as sampling along the second-to-last dimension
    # of `shape` -- confirm this is the intended axis for a complex-valued
    # array that has no trailing real/imag dimension.
    mask = np.array(mask_func(kspace.shape, seed = seed))
    return kspace * mask



def _get_mri_im_separated(
    reconstruction_rss,
    kspace_undersampled, 
    DATASET,
    output_shape = (1, 32, 256, 256),
):
    """
    builds the (input, label) image pair and separates imaginary from real values
    @params reconstruction_rss: reconstructed MR image of fully sampled kspace, provided
    @params kspace_undersampled: mask-undersampled k-space from _get_kspace_undersampled
    @params DATASET: key into AXES that selects the inverse-FFT axes,
        i.e. 'singlecoil_train', 'singlecoil_val' or 'singlecoil_test_v2'
    @params output_shape: shape every image is resized to (sigpy crops or
        zero-pads as needed); defaults to the previously hard-coded
        (1, 32, 256, 256)
    @return (undersampled mri image with a trailing axis of size 2 holding
        [real, imag], fully sampled mri image (i.e. label for GAN))
    """
    output_shape = list(output_shape)

    undersampled_im = sp.ifft(kspace_undersampled, axes = AXES[DATASET])

    # resize so that all images have the same size
    undersampled_crop = sp.resize(undersampled_im, output_shape)

    # numpy can split the complex array directly; going through TensorFlow
    # gives the same values but adds a pointless tensor round trip
    undersampled_crop = np.stack(
        (np.real(undersampled_crop), np.imag(undersampled_crop)),
        axis = -1,  # new trailing axis: index 0 = real, index 1 = imag
    )

    fullysampled_crop = sp.resize(reconstruction_rss, output_shape)

    return (
        undersampled_crop,
        fullysampled_crop,
    )
    



def get_datum_from_single_file_separated(filename, DATASET):
    """
    user-facing entry point: turns one .h5 file into a training pair
    by chaining the three private helpers (load -> undersample -> image pair)
    @params filename: full path to .h5 mri file
    @params DATASET: dataset name, forwarded to the loading and image helpers
    @return whatever _get_mri_im_separated returns, i.e.
        (undersampled mri image, fully sampled mri image (i.e. label for GAN))
    """
    # step 1: read fully sampled k-space and the provided reconstruction
    loaded = _get_kspace_and_reconstruction_rss(filename, DATASET)
    full_kspace, ground_truth = loaded

    # step 2: mask the k-space with the helper's default settings
    masked_kspace = _get_kspace_undersampled(full_kspace)

    # step 3: back to image space, resized and split into real / imag
    datum = _get_mri_im_separated(ground_truth, masked_kspace, DATASET)
    return datum




def save_data(filenames, DATASET):  
    """
    user-facing function to save each MRI as
    undersampled / ground truth pair as npz
    (written to <data_save_path>/<DATASET>/<mri file name>.npz, with the
    undersampled array stored as 'arr_0' and the ground truth as 'arr_1')
    @params filenames: list of full paths to .h5 mri files
    @params DATASET: i.e. 'singlecoil_train' or 'multicoil_train'
    @return True if every file was saved successfully, False if at least
        one file could not be processed (failures are reported and skipped)
    """
    output_dir = os.path.join(data_save_path, DATASET)
    # np.savez does not create missing directories; without this every save
    # fails on a fresh run
    os.makedirs(output_dir, exist_ok = True)

    saved_count = 0
    failed_paths = []

    for mri_path in filenames:
        try:
            # undersampled_crop has real and imag components
            undersampled_crop, fullysampled_crop = get_datum_from_single_file_separated(
                mri_path, DATASET
            )

            # reshape for channels and imaginary / real numbers
            undersampled_crop = undersampled_crop.reshape(-1, 32, 256, 256, 2)
            fullysampled_crop = fullysampled_crop.reshape(-1, 32, 256, 256, 1)

            # save; the original file name (including '.h5') is kept as stem
            mri_filename = os.path.basename(mri_path)
            data_save_file = os.path.join(output_dir, f'{mri_filename}.npz')
            np.savez(data_save_file, undersampled_crop, fullysampled_crop)
        except Exception as err:
            # best effort: report the real cause and carry on with the next
            # file. The source file is left untouched (nothing is trashed).
            failed_paths.append(mri_path)
            print(f'could not process file {mri_path}: {err!r}')
            continue

        # print progress; counted only after a successful save so the
        # number reported is the number of files actually written
        saved_count += 1
        if saved_count % 50 == 0:
            print(f'undersampled and saved {saved_count} {DATASET} files thus far')

    if failed_paths:
        print(f'{len(failed_paths)} {DATASET} files could not be processed')
    print (f'check out files at {data_save_path}/{DATASET}/ directory')

    return not failed_paths

#





def master(data_root = '/central/groups/BEBi_205_Spring_2021/vliu', datasets = None):
    '''
    main program to execute save_data for all datasets
    including train, val, and test
    @params data_root: directory that contains one sub-directory of .h5
        files per dataset; defaults to the previously hard-coded location
    @params datasets: dataset names to process; None (default) means the
        module-level DATASETS list
    @return None
    '''
    if datasets is None:
        datasets = DATASETS

    for DATASET in datasets:
        print (f'> > > processing {DATASET} images')
        # define paths
        data_path = os.path.join(data_root, DATASET)
        # glob returns files in arbitrary order; sort for a repeatable run
        mri_paths = sorted(glob.glob(os.path.join(data_path, '*.h5')))

        if not mri_paths:
            # most likely a wrong path or a dataset that is not downloaded
            print(f'no .h5 files found in {data_path}')
            continue

        save_data(mri_paths, DATASET)
        

In [4]:
# clear the Keras backend session before starting the export
K.clear_session()
# undersample and save every dataset listed in DATASETS
master()

> > > processing singlecoil_train images
undersampled and saved 50 singlecoil_train files thus far
undersampled and saved 100 singlecoil_train files thus far
undersampled and saved 150 singlecoil_train files thus far
undersampled and saved 200 singlecoil_train files thus far
undersampled and saved 250 singlecoil_train files thus far
undersampled and saved 300 singlecoil_train files thus far
undersampled and saved 350 singlecoil_train files thus far
undersampled and saved 400 singlecoil_train files thus far
undersampled and saved 450 singlecoil_train files thus far
undersampled and saved 500 singlecoil_train files thus far
undersampled and saved 550 singlecoil_train files thus far
undersampled and saved 600 singlecoil_train files thus far
undersampled and saved 650 singlecoil_train files thus far
undersampled and saved 700 singlecoil_train files thus far
undersampled and saved 750 singlecoil_train files thus far
undersampled and saved 800 singlecoil_train files thus far
undersampled and

In [50]:
# Load one saved pair back for inspection.
# NOTE(review): save_data() writes to {data_save_path}/{DATASET}/<name>.npz,
# whereas this reads {data_save_path}.npz -- confirm that this file exists
# and is the one intended.
# np.savez stores positional arrays as 'arr_0', 'arr_1', ... in call order.
# The context manager closes the underlying file once both arrays are read.
with np.load(f'{data_save_path}.npz') as npzfile:
    under_sampled_separated = npzfile['arr_0']
    fully_sampled_separated = npzfile['arr_1']

# ds_separated = tf.data.Dataset.from_tensor_slices((under_sampled_separated, fully_sampled_separated))
# ds_separated = ds_separated.shuffle(1000, seed = 123, reshuffle_each_iteration = True)

# # check Dataset object was created properly
# ds_separated = ds_separated.shuffle(1000)
# for undersampled_im, fullysampled_im in ds_separated.take(10):
#     undersampled_im = tf.reshape(undersampled_im, (-1, 32, 256, 256, 2))
#     fullysampled_im = tf.reshape(fullysampled_im, (-1, 32, 256, 256, 1))
#     print(f'undersampled size {undersampled_im.shape} fullysampled size {fullysampled_im.shape}')
    

# for undersampled_im, fullysampled_im in ds_separated.take(1):
#     pl.ImagePlot(undersampled_im[:, :, :, 0]) #this is just real numbers. Don't use this in comparison
#     pl.ImagePlot(fullysampled_im[:, :, :, 0])

In [58]:
# sanity check: display the shape of the loaded undersampled array
under_sampled_separated.shape

(1, 32, 256, 256, 2)