# Preparing the train/test data

In [None]:
%matplotlib inline

from glob import glob
import os.path as op

import h5py
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from scipy.misc import imresize  # PIL or Pillow must be installed.

In [None]:
def plot(arr, title=None, **kwds):
    """Show a 2D array as a grayscale image.

    The array is transposed before drawing and the image origin is placed in
    the lower-left corner. When `title` is given it is used as the plot
    title. Extra `kwds` are forwarded to `plt.imshow()`.
    """
    transposed = arr.T
    has_title = title is not None
    plt.imshow(transposed, cmap='gray', origin='lower', **kwds)
    if has_title:
        plt.title(title)
    plt.show()

# Resizing to uniform shape

In [None]:
# Load the example anatomical volume for subj-001.
img = nib.load("../data/subj-001/anat/subj-001_gad-T1w.nii.gz")
# `img.get_data()` is deprecated (and removed in recent nibabel releases);
# reading `img.dataobj` through numpy returns the same array with the
# on-disk dtype preserved.
data = np.asanyarray(img.dataobj)

# Inspect a single slice (index 25 along the last axis) before resizing.
slice_ = data[:, :, 25]
print("shape", slice_.shape)
plot(slice_, 'original')

In [None]:
# Target size of the first two dimensions after resizing.
new_shape = (256, 256)

def resize_volume(arr, new_shape):
    """Resize the first two dimensions of a 3D array to `new_shape`.

    Every slice ``arr[:, :, z]`` is resized on its own with bilinear
    interpolation. As with the legacy ``scipy.misc.imresize``, a slice that is
    not already ``uint8`` is first rescaled to 0-255 using that slice's own
    minimum and maximum, so the result always holds 8-bit intensities.

    `arr` is a 3D array-like and `new_shape` a sequence of two ints. Returns
    an ``int16`` array of shape ``new_shape + (arr.shape[-1],)``.
    """
    arr = np.asarray(arr)
    new_shape = tuple(new_shape)  # allow lists as well as tuples

    try:
        # Legacy path: only exists in SciPy < 1.3 with PIL or Pillow installed.
        from scipy.misc import imresize as resize_slice
    except ImportError:
        from scipy import ndimage

        def resize_slice(img, size):
            """Approximate `imresize(img, size)` using scipy.ndimage only."""
            if img.dtype == np.uint8:
                # imresize leaves uint8 input intensities untouched.
                scaled = img.astype(np.float64)
            else:
                # Work in float64 so `img - low` cannot overflow.
                img = img.astype(np.float64)
                low = img.min()
                span = (img.max() - low) or 1.0  # constant slice: avoid 0/0
                scaled = (img - low) * (255.0 / span)
            factors = [float(n) / o for n, o in zip(size, img.shape)]
            # order=1 is bilinear; mode='nearest' keeps edge samples from
            # picking up the zero fill value through rounding error.
            out = ndimage.zoom(scaled, factors, order=1, mode='nearest')
            return np.clip(np.rint(out), 0, 255).astype(np.uint8)

    res = np.zeros(new_shape + arr.shape[-1:], dtype=np.int16)
    for z in range(arr.shape[-1]):
        res[:, :, z] = resize_slice(arr[:, :, z], new_shape)
    return res


# Resize the whole volume, then show slice 25 again so it can be compared
# with the 'original' plot of the same slice index.
data_resized = resize_volume(data, new_shape)
print("shape", data_resized.shape)
plot(data_resized[:, :, 25], "resized (volume)")

# Load NIfTI from the web

In [None]:
def load_url(url, suffix='.nii.gz', **kwds):
    """From URL, return image data, affine, header, and extra. `kwds` are for
    `nibabel.load()`.

    The response body is written to a named temporary file (created with
    `suffix`), which is then loaded by name with nibabel. The temporary file
    is removed when the function returns. Raises the `requests` HTTP error
    from `raise_for_status()` when the download fails.
    """
    import tempfile
    import requests

    # Download first so no temporary file is created for a failed request.
    r = requests.get(url)
    r.raise_for_status()

    # NOTE: reopening a NamedTemporaryFile by name while it is still open is
    # documented as not working on Windows.
    with tempfile.NamedTemporaryFile(suffix=suffix) as fp:
        fp.write(r.content)
        # nibabel opens the file by name, so buffered bytes must reach the
        # disk first; otherwise it may read a truncated file.
        fp.flush()
        img = nib.load(fp.name, **kwds)
        # `img.get_data()` is deprecated; this is the equivalent call.
        data = np.asanyarray(img.dataobj)
        return data, img.affine, img.header, img.extra


# Example download: fetch one volume and display slice 25.
url = ("https://dl.dropbox.com/sh/71jbelduefu41xs/AADysls57HwmJT0pdbbSVI4Na/"
       "case_001_2.nii.gz")
# Only the image data is kept; affine, header and extra are discarded.
data, _, _, _ = load_url(url)
plot(data[:, :, 25])

# Save arrays to HDF5

In [None]:
def get_filenames():
    """Return list of tuples, where each tuple consists of
    (anat_filename, seg_filename).

    Every entry under ``../data`` is treated as one subject. A subject is
    skipped unless it has both a ``anat/*gad-T1w.nii.gz`` file and a
    ``seg/*seg-uint8_gad-T1w.nii.gz`` file. Subjects are returned in sorted
    order, and when several files match a pattern the first one in sorted
    order is used, so the result is reproducible across runs and platforms.
    """
    import os.path as op

    fnames = []
    # glob() returns matches in arbitrary order, hence the explicit sorting.
    for subj_dir in sorted(glob('../data/*')):
        anat_files = sorted(glob(op.join(subj_dir, 'anat/*gad-T1w.nii.gz')))
        seg_files = sorted(
            glob(op.join(subj_dir, 'seg/*seg-uint8_gad-T1w.nii.gz')))
        if not anat_files or not seg_files:
            # Incomplete subject: no usable (anat, seg) pair.
            continue
        fnames.append((anat_files[0], seg_files[0]))
    return fnames

def _transform(arr):
    """Transform `arr` to range [0, 1].

    Min-max scaling over the whole array: the minimum maps to 0 and the
    maximum to 1. Non-float input is converted to float64 first so the
    subtraction cannot overflow a narrow integer type. A constant array
    (max == min) is returned as all zeros instead of NaN from 0/0, and an
    empty array is returned unchanged (as floats).
    """
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    if arr.size == 0:
        return arr

    low = arr.min()
    span = arr.max() - low
    if span == 0:
        # Constant input, e.g. an all-zero segmentation volume.
        return np.zeros_like(arr)
    return (arr - low) / span

def _resize_volume(arr, new_shape):
    """Resize the first two dimensions of a 3D array to `new_shape`.

    Every slice ``arr[:, :, z]`` is resized on its own with bilinear
    interpolation. As with the legacy ``scipy.misc.imresize``, a slice that is
    not already ``uint8`` is first rescaled to 0-255 using that slice's own
    minimum and maximum, so the result always holds 8-bit intensities.

    `arr` is a 3D array-like and `new_shape` a sequence of two ints. Returns
    an ``int16`` array of shape ``new_shape + (arr.shape[-1],)``.
    """
    arr = np.asarray(arr)
    new_shape = tuple(new_shape)  # allow lists as well as tuples

    try:
        # Legacy path: only exists in SciPy < 1.3; requires PIL or Pillow.
        from scipy.misc import imresize as resize_slice
    except ImportError:
        from scipy import ndimage

        def resize_slice(img, size):
            """Approximate `imresize(img, size)` using scipy.ndimage only."""
            if img.dtype == np.uint8:
                # imresize leaves uint8 input intensities untouched.
                scaled = img.astype(np.float64)
            else:
                # Work in float64 so `img - low` cannot overflow.
                img = img.astype(np.float64)
                low = img.min()
                span = (img.max() - low) or 1.0  # constant slice: avoid 0/0
                scaled = (img - low) * (255.0 / span)
            factors = [float(n) / o for n, o in zip(size, img.shape)]
            # order=1 is bilinear; mode='nearest' keeps edge samples from
            # picking up the zero fill value through rounding error.
            out = ndimage.zoom(scaled, factors, order=1, mode='nearest')
            return np.clip(np.rint(out), 0, 255).astype(np.uint8)

    res = np.zeros(new_shape + arr.shape[-1:], dtype=np.int16)
    for z in range(arr.shape[-1]):
        res[:, :, z] = resize_slice(arr[:, :, z], new_shape)
    return res

def _preprocess(arr, new_shape):
    """Apply the full preprocessing chain to one volume.

    The volume is first resized with `_resize_volume(arr, new_shape)` and the
    result is then passed through `_transform`.
    """
    return _transform(_resize_volume(arr, new_shape))

def _gen_slices(arr):
    """Yield `arr[:, :, k]` for every index `k` of the last axis, in order."""
    n_slices = arr.shape[-1]
    for k in range(n_slices):
        plane = arr[:, :, k]
        yield plane
        
def preprocess_all(fnames, **kwds):
    """Load and preprocess every file in `fnames` and stack the 2D slices.

    Each file is loaded with nibabel, passed through `_preprocess` (which
    receives `kwds`, e.g. `new_shape`), and split into slices along its last
    axis. The slices of all files are concatenated in file order and returned
    as one float32 array with the slice index on the first axis.
    """
    slices = []
    for f in fnames:
        # `img.get_data()` is deprecated (and removed in recent nibabel
        # releases); reading `dataobj` through numpy is the equivalent call.
        arr = np.asanyarray(nib.load(f).dataobj)
        arr = _preprocess(arr, **kwds)
        slices.extend(_gen_slices(arr))
    return np.array(slices, dtype=np.float32)

In [None]:
# Size that the first two dimensions of every volume are resized to.
slice_shape = (256, 256)

# Each entry of `fnames` is an (anat_filename, seg_filename) pair.
fnames = get_filenames()
anat_fnames = [f[0] for f in fnames]
seg_fnames = [f[1] for f in fnames]

# Both the images and the segmentations go through the same pipeline.
# NOTE(review): this applies the interpolating resize and [0, 1] rescaling to
# the segmentation volumes as well -- confirm that is intended for label data.
anat_data = preprocess_all(anat_fnames, new_shape=slice_shape)
seg_data = preprocess_all(seg_fnames, new_shape=slice_shape)

In [None]:
# Show the stacked array shapes; slices are indexed along the first axis.
print("anat shape:", anat_data.shape)
print("seg shape:", seg_data.shape)

In [None]:
# Write both arrays to one HDF5 file as gzip-compressed datasets. mode='w'
# creates the file, and the `with` block closes it on exit.
with h5py.File('test_comp.h5', mode='w') as fp:
    fp.create_dataset('/anat', data=anat_data, compression='gzip')
    fp.create_dataset('/seg', data=seg_data, compression='gzip')

In [None]:
# Open an HDF5 file read-only for inspection.
# NOTE(review): this opens 'test.h5', while the write cell above creates
# 'test_comp.h5' -- confirm which file is intended. The handle is also never
# closed in the visible cells.
fp = h5py.File('test.h5', mode='r')

In [2]:
from keras.models import Sequential
from keras.layers import Dense
from keras.utils.io_utils import HDF5Matrix

In [3]:
# Wrap the 'anat' and 'seg' datasets as Keras HDF5Matrix objects, restricted
# to start=0, end=150.
# NOTE(review): these read 'test.h5', while the visible write cell creates
# 'test_comp.h5' -- confirm the file name.
X_train = HDF5Matrix('test.h5', 'anat', start=0, end=150)
y_train = HDF5Matrix('test.h5', 'seg', start=0, end=150)