In [3]:
from glob import glob
import nibabel as nb
from nilearn.regions import RegionExtractor
from skimage.feature.peak import peak_local_max
from nilearn.image import smooth_img, new_img_like
import numpy as np
import os
from scipy.misc import imsave
import matplotlib.pyplot as plt
import tensorflow as tf

def np_to_tfrecords(X, Y, file_path_prefix, verbose=True):
    """
    Converts a Numpy array (or two Numpy arrays) into a tfrecord file.
    For supervised learning, feed training inputs to X and training labels to Y.
    For unsupervised learning, only feed training inputs to X, and feed None to Y.
    The length of the first dimensions of X and Y should be the number of samples;
    each row becomes one ``tf.train.Example`` with feature 'X' (and 'Y').

    Parameters
    ----------
    X : numpy.ndarray of rank 2
        Numpy array for training inputs. Its dtype should be float32, float64,
        or an integer type that fits losslessly in int64 (stored as int64).
        If X has a higher rank, it should be reshaped before being passed in.
    Y : numpy.ndarray of rank 2 or None
        Numpy array for training labels, same dtype rules as X.
        None if there is no label array.
    file_path_prefix : str
        The path and name of the resulting tfrecord file to be generated, without '.tfrecords'
    verbose : bool
        If true, progress is reported.

    Returns
    -------
    None
        The records are written to ``file_path_prefix + '.tfrecords'``.

    Raises
    ------
    ValueError
        If an input dtype is not float32/float64 or an int64-compatible integer.
    AssertionError
        If X/Y are not rank-2 ndarrays or their first dimensions differ.
    """
    def _dtype_feature(ndarray):
        """Return a function that turns one row of ``ndarray`` into a tf.train.Feature."""
        assert isinstance(ndarray, np.ndarray)
        dtype_ = ndarray.dtype
        if dtype_ == np.float64 or dtype_ == np.float32:
            return lambda array: tf.train.Feature(
                float_list=tf.train.FloatList(value=array))
        if np.issubdtype(dtype_, np.integer) and np.can_cast(dtype_, np.int64):
            # Narrower integers (e.g. int16 image data) widen losslessly;
            # uint64 is excluded because it may not fit in int64.
            return lambda array: tf.train.Feature(
                int64_list=tf.train.Int64List(
                    value=array.astype(np.int64, copy=False)))
        raise ValueError("The input dtype should be float32, float64 or an "
                         "integer type that fits in int64. "
                         "Instead got {}".format(dtype_))

    assert isinstance(X, np.ndarray)
    # If X has a higher rank, it should be reshaped before being passed in.
    assert len(X.shape) == 2
    assert isinstance(Y, np.ndarray) or Y is None

    # Pick the tf.train.Feature builder matching each array's dtype; this
    # validates dtypes before any file is created.
    dtype_feature_x = _dtype_feature(X)
    if Y is not None:
        assert X.shape[0] == Y.shape[0]
        assert len(Y.shape) == 2
        dtype_feature_y = _dtype_feature(Y)

    result_tf_file = file_path_prefix + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(result_tf_file)
    if verbose:
        print("Serializing {:d} examples into {}".format(X.shape[0], result_tf_file))

    try:
        # One serialized Example per row (sample).
        for idx in range(X.shape[0]):
            d_feature = {'X': dtype_feature_x(X[idx])}
            if Y is not None:
                d_feature['Y'] = dtype_feature_y(Y[idx])
            example = tf.train.Example(features=tf.train.Features(feature=d_feature))
            writer.write(example.SerializeToString())
    finally:
        # Close even on error so buffered records are flushed and the
        # file handle is released.
        writer.close()

    if verbose:
        print("Writing {} done!".format(result_tf_file))

        
class Map(object):
    """A 3-D statistical map plus the local-maximum peaks extracted from it.

    Attributes
    ----------
    filename : str
        Path the volume was loaded from.
    img : nibabel image
        Result of ``nb.load(filename)``.
    data : numpy.ndarray
        Voxel data of ``img`` in its stored dtype.
    peaks : numpy.ndarray
        Peak voxel coordinates from ``peak_local_max``, one row per peak
        (columns are x, y, z indices).
    peak_sets : list of numpy.ndarray
        Peak arrays built by ``distort_peaks``; empty until it is called.
    """

    def __init__(self, filename, min_distance=10, max_peaks=20):

        self.filename = filename
        self.img = nb.load(filename)
        self.peak_sets = []
        # ``get_data()`` is deprecated/removed in nibabel; this is its
        # documented drop-in (keeps the on-disk dtype). Loaded once and reused.
        self.data = np.asanyarray(self.img.dataobj)
        self.peaks = peak_local_max(self.data, exclude_border=False,
                                    min_distance=min_distance,
                                    num_peaks=max_peaks)

    def distort_peaks(self, num_copies=1, max_peaks=None, jitter=None,
                      include_orig=True):
        """Rebuild ``self.peak_sets`` from the original and distorted peak copies.

        Parameters
        ----------
        num_copies : int
            Number of distorted copies to append.
        max_peaks : int or None
            If set, each copy is randomly subsampled to at most this many peaks.
        jitter : int or None
            If set, each copy's x/y coordinates get independent integer offsets
            drawn uniformly from [-jitter, jitter], clipped to the image bounds.
            z is left untouched so every peak stays on its own slice.
        include_orig : bool
            If True, ``self.peaks`` itself is the first entry of ``peak_sets``.
        """
        self.peak_sets = []

        if include_orig:
            self.peak_sets.append(self.peaks)

        if jitter is not None:
            # Largest valid x/y voxel index of this image.
            upper = np.asarray(self.img.shape[:2]) - 1

        for _ in range(num_copies):

            peaks = self.peaks.copy()

            if jitter is not None:
                offsets = np.random.randint(-jitter, high=jitter + 1,
                                            size=(len(peaks), 2))
                peaks[:, :2] = np.clip(peaks[:, :2] + offsets, 0, upper)

            if max_peaks is not None and len(peaks) > max_peaks:
                np.random.shuffle(peaks)
                peaks = peaks[:max_peaks, :]

            self.peak_sets.append(peaks)

    def to_tfrecords(self, output_dir, max_peaks=5, max_slices=3):
        """Write peak sets and slices from the current map out as TFRecords.

        For each peak set, up to ``max_slices`` randomly chosen z-slices that
        contain at least one peak are written. Each file pairs a binary x/y
        peak image (X) with the map's data on that slice (Y) and is named
        ``<basename>_var-<set#>_zslice-<z>.tfrecords`` in ``output_dir``.
        """
        basename = os.path.basename(self.filename).split('.')[0]

        for i, ps in enumerate(self.peak_sets):

            # Distinct z-slices holding peaks, visited in random order.
            avail_slices = np.unique(ps[:, 2])
            np.random.shuffle(avail_slices)

            for sl_ind in avail_slices[:max_slices]:
                # Peaks that lie on the selected slice.
                X_coords = ps[ps[:, 2] == sl_ind]
                if max_peaks is not None and len(X_coords) > max_peaks:
                    np.random.shuffle(X_coords)
                    X_coords = X_coords[:max_peaks]

                # Binary peak image (input) and the map slice (target).
                X_img = np.zeros(self.img.shape[:2])
                X_img[tuple(X_coords[:, :2].T)] = 1
                Y_img = self.data[:, :, sl_ind]

                # NOTE(review): np_to_tfrecords stores each row of X_img/Y_img
                # as a separate Example — confirm this row-wise layout is what
                # the reader expects.
                map_suff = '_var-%d_zslice-%d' % ((i+1), sl_ind)
                path = os.path.join(output_dir, basename + map_suff)
                np_to_tfrecords(X_img, Y_img, path, verbose=False)

                
def volumes_to_tfrecords(images, output_path, min_distance=10,
                         max_peaks_per_volume=100, max_peaks_per_slice=5,
                         max_slices=3, num_copies=4, jitter=3,
                         include_orig=True):
    ''' Reads in a set of nifti volumes and outputs TFRecords objects for 2D
    slices, optionally with some distortion/duplication.
    Args:
        images (str or list): Either a glob pattern matching a set of nifti
            files, or a list of nifti file paths.
        output_path (str): Directory to write tfrecords to (created if missing).
        min_distance (int): Minimum distance between peaks (in voxels)
        max_peaks_per_volume (int): Maximum number of peaks to extract for each volume
        max_peaks_per_slice (int): Maximum number of peaks to extract for each slice
        max_slices (int): Maximum number of slices to extract from each volume
        num_copies (int): Number of distorted copies of each peak set to create
            (0 or None for none)
        jitter (int): Jitter amplitude (in voxels) passed to Map.distort_peaks;
            None disables jitter
        include_orig (bool): if True, includes the original, unaltered peaks in
            the training set, even when num_copies is 0
    '''
    if isinstance(images, str):
        images = glob(images)

    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    for filename in images:
        # Process one volume at a time so only a single map is held in memory.
        m = Map(filename, min_distance, max_peaks_per_volume)
        # Always build the peak sets: with zero copies this still yields the
        # original peaks when include_orig is True.
        m.distort_peaks(num_copies or 0, max_peaks_per_volume, jitter=jitter,
                        include_orig=include_orig)
        m.to_tfrecords(output_path, max_peaks_per_slice, max_slices)

In [4]:
# Example run: convert the first ten HCP volumes (sorted so the selection is
# reproducible) into TFRecords. Guarded so importing this module does not
# start a conversion; it still runs when executed as a script or notebook cell.
if __name__ == "__main__":
    images = sorted(glob('/Dropbox/files/HCP/original/*.nii.gz'))
    volumes_to_tfrecords(images[:10], '/mnt/c/Users/tyark/Downloads/tmp')