ENH: Saves data to file
ellisdg committed Apr 3, 2017
1 parent ac68ea6 commit 7193e19
Showing 5 changed files with 145 additions and 93 deletions.
107 changes: 20 additions & 87 deletions DataGenerator.py
@@ -1,41 +1,22 @@
import os
import glob
import pickle
from random import shuffle

import numpy as np
from nilearn.image import resample_img, reorder_img
import nibabel as nib

from utils.utils import pickle_dump

# TODO: Rescale images to integer
# TODO: include normalization script from raw BRATS data
# TODO: normalize data by subtracting mean and then dividing by standard deviation
# TODO: find the smallest shape image that contains all of the original data
# TODO: crop data to the smallest shape image that contains all of the original data
# TODO: set background to zero after resampling

def pickle_dump(item, out_file):
with open(out_file, "wb") as opened_file:
pickle.dump(item, opened_file)


def pickle_load(in_file):
with open(in_file, "rb") as opened_file:
return pickle.load(opened_file)


def get_training_and_testing_generators(data_dir, input_shape, batch_size=1, nb_channels=3, validation_split=0.8,
overwrite=False, saved_folders_file="training_and_testing_folders.pkl"):
if overwrite or not os.path.exists(saved_folders_file):
subject_folders = get_subject_folders(data_dir=data_dir)
training_list, testing_list = split_list(subject_folders, split=validation_split, shuffle_list=True)
pickle_dump((training_list, testing_list), saved_folders_file)
else:
training_list, testing_list = pickle_load(saved_folders_file)
training_generator = data_generator(training_list, batch_size=batch_size, nb_channels=nb_channels,
image_shape=input_shape)
testing_generator = data_generator(testing_list, batch_size=batch_size, nb_channels=nb_channels,
image_shape=input_shape)
def get_training_and_testing_generators(data_file, batch_size, data_split=0.8):
nb_samples = data_file.root.data.shape[0]
sample_list = range(nb_samples)
training_list, testing_list = split_list(sample_list, split=data_split)
pickle_dump(training_list, "training_list.pkl")
pickle_dump(testing_list, "testing_list.pkl")
training_generator = data_generator(data_file, training_list, batch_size=batch_size)
testing_generator = data_generator(data_file, testing_list, batch_size=batch_size)
# Set the number of training and testing samples per epoch correctly
nb_training_samples = len(training_list)/batch_size * batch_size
nb_testing_samples = len(testing_list)/batch_size * batch_size
@@ -51,65 +32,17 @@ def split_list(input_list, split=0.8, shuffle_list=True):
return training, testing
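The body of split_list is collapsed in this view. A minimal sketch of what it presumably does, given its signature, the shuffle_list flag, and the existing from random import shuffle import, would be the following (a reconstruction, not the committed code):

def split_list(input_list, split=0.8, shuffle_list=True):
    # Hypothetical reconstruction of the collapsed function body.
    if shuffle_list:
        shuffle(input_list)
    n_training = int(len(input_list) * split)
    training = input_list[:n_training]
    testing = input_list[n_training:]
    return training, testing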


def get_subject_folders(data_dir):
return glob.glob(os.path.join(data_dir, "*", "*"))


def data_generator(subject_folders, image_shape, batch_size=1, nb_channels=3):
nb_subjects = len(subject_folders)
def data_generator(data_file, index_list, batch_size=1, binary=True):
nb_subjects = len(index_list)
while True:
shuffle(subject_folders)
shuffle(index_list)
nb_batches = nb_subjects/batch_size
# TODO: Edge case? Currently this is handled by flooring the number of training/testing samples
for i in range(nb_batches):
batch_folders = subject_folders[i*batch_size:(i+1)*batch_size]
batch = read_batch(batch_folders, image_shape)
x_train, y_train = get_training_data(batch, nb_channels, truth_channel=3)
del batch, batch_folders
yield x_train, y_train


def read_batch(folders, input_shape):
batch = []
for folder in folders:
batch.append(read_subject_folder(folder, input_shape))
return np.asarray(batch)


def read_subject_folder(folder, image_shape):
flair_image = read_image(os.path.join(folder, "Flair.nii.gz"), image_shape=image_shape)
t1_image = read_image(os.path.join(folder, "T1.nii.gz"), image_shape=image_shape)
t1c_image = read_image(os.path.join(folder, "T1c.nii.gz"), image_shape=image_shape)
truth_image = read_image(os.path.join(folder, "truth.nii.gz"), image_shape=image_shape,
interpolation="nearest")
return np.asarray([t1_image.get_data(), t1c_image.get_data(), flair_image.get_data(), truth_image.get_data()])


def read_image(in_file, image_shape, interpolation='continuous'):
print("Reading: {0}".format(in_file))
image = nib.load(in_file)
return resize(image, new_shape=image_shape, interpolation=interpolation)


def resize(image, new_shape, interpolation="continuous"):
input_shape = np.asarray(image.shape, dtype=np.float16)
ras_image = reorder_img(image, resample=interpolation)
output_shape = np.asarray(new_shape)
new_spacing = input_shape/output_shape
new_affine = np.copy(ras_image.affine)
new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing)
return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation)


def get_truth(batch, truth_channel=3):
truth = np.array(batch)[:, truth_channel]
batch_list = []
for sample_number in range(truth.shape[0]):
array = np.zeros(truth[sample_number].shape)
array[truth[sample_number] > 0] = 1
batch_list.append([array])
return np.array(batch_list)


def get_training_data(batch, nb_channels, truth_channel):
return batch[:, :nb_channels], get_truth(batch, truth_channel)
            # Select the shuffled sample indices for this batch instead of reading positionally.
            batch_indices = index_list[i*batch_size:(i+1)*batch_size]
            x = np.asarray([data_file.root.data[index] for index in batch_indices])
            y = np.asarray([data_file.root.truth[index] for index in batch_indices])
if binary:
y[y > 0] = 1
else:
raise NotImplementedError("Multi-class labels are not yet implemented")
yield x, y
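As a rough usage sketch of the new HDF5-backed generators (assuming a data.hdf5 file with /data and /truth arrays as produced by normalize.py below; the path is hypothetical):

import tables
from DataGenerator import get_training_and_testing_generators

data_file = tables.open_file("data.hdf5", "r")  # hypothetical path
train_gen, test_gen, n_train, n_test = get_training_and_testing_generators(
    data_file, batch_size=1, data_split=0.8)
x, y = next(train_gen)
# x: (batch_size, nb_channels, dim_x, dim_y, dim_z) float32 image data
# y: (batch_size, 1, dim_x, dim_y, dim_z) truth labels, collapsed to {0, 1} when binary=True
data_file.close()

Because the generators loop forever, they are meant to be consumed by fit_generator together with the returned sample counts, as UnetTraining.py does below.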
17 changes: 11 additions & 6 deletions UnetTraining.py
@@ -3,18 +3,20 @@
from functools import partial

import numpy as np
import tables
from keras import backend as K
from keras.layers import (Conv3D, MaxPooling3D, Activation, UpSampling3D, merge, Input, Reshape)
from keras.models import Model, load_model
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint, CSVLogger, Callback, LearningRateScheduler

from DataGenerator import get_training_and_testing_generators, pickle_dump
from normalize import write_data_to_file

pool_size = (2, 2, 2)
image_shape = (240, 240, 144)
n_channels = 3
input_shape = tuple([n_channels] + list(image_shape))
nb_channels = 3
input_shape = tuple([nb_channels] + list(image_shape))
n_labels = 1 # not including background
batch_size = 1
n_epochs = 50
@@ -25,6 +27,7 @@
initial_learning_rate = 0.1
learning_rate_drop = 0.5
validation_split = 0.8
hdf5_file = "/home/neuro-user/PycharmProjects/BRATS/data.hdf5"


# learning rate schedule
@@ -161,19 +164,21 @@ def get_callbacks(model_file):


def main(overwrite=False):
write_data_to_file(data_dir, hdf5_file, image_shape=image_shape, nb_channels=nb_channels)
hdf5_file_opened = tables.open_file(hdf5_file, "r")
model_file = os.path.abspath("3d_unet_model.h5")
if not overwrite and os.path.exists(model_file):
print("Loading pre-trained model")
model = load_model(model_file, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
else:
model = unet_model()
train_model(model, model_file)
train_model(model, model_file, hdf5_file_opened)
hdf5_file_opened.close()


def train_model(model, model_file):
def train_model(model, model_file, data_file):
training_generator, testing_generator, nb_training_samples, nb_testing_samples = get_training_and_testing_generators(
data_dir=data_dir, batch_size=batch_size, nb_channels=n_channels, input_shape=image_shape,
validation_split=validation_split)
data_file, batch_size=batch_size, data_split=validation_split)

model.fit_generator(generator=training_generator,
samples_per_epoch=nb_training_samples,
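The learning-rate schedule and callback wiring are collapsed in this view. Given the initial_learning_rate and learning_rate_drop constants defined above, and the functools.partial import, a plausible step-decay schedule for LearningRateScheduler might look like the sketch below; the decay interval is an assumption, not taken from the file:

import math
from functools import partial
from keras.callbacks import LearningRateScheduler

def step_decay(epoch, initial_lrate=0.1, drop=0.5, epochs_drop=10):
    # Multiply the learning rate by drop once every epochs_drop epochs (assumed interval).
    return initial_lrate * math.pow(drop, math.floor(epoch / float(epochs_drop)))

lr_scheduler = LearningRateScheduler(partial(step_decay, initial_lrate=0.1, drop=0.5))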
103 changes: 103 additions & 0 deletions normalize.py
@@ -0,0 +1,103 @@
import glob
import os

import tables
import numpy as np
from nilearn.image import resample_img, reorder_img
import nibabel as nib


def normalize_data(data, mean, std):
data -= mean[:, np.newaxis, np.newaxis, np.newaxis]
data /= std[:, np.newaxis, np.newaxis, np.newaxis]
return data


def normalize_data_storage(data_storage):
means = list()
stds = list()
for index in range(data_storage.shape[0]):
data = data_storage[index]
means.append(data.mean(axis=(1, 2, 3)))
stds.append(data.std(axis=(1, 2, 3)))
mean = np.asarray(means).mean(axis=0)
    std = np.asarray(stds).mean(axis=0)
for index in range(data_storage.shape[0]):
data_storage[index] = normalize_data(data_storage[index], mean, std)
return data_storage
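The np.newaxis indexing in normalize_data broadcasts the per-channel mean and standard deviation over the three spatial axes. A small self-contained illustration with toy sizes:

import numpy as np

data = np.random.rand(3, 4, 4, 4).astype(np.float32)  # (channels, x, y, z), toy shape
mean = data.mean(axis=(1, 2, 3))                       # shape (3,), one value per channel
std = data.std(axis=(1, 2, 3))                         # shape (3,)
normalized = (data - mean[:, np.newaxis, np.newaxis, np.newaxis]) / std[:, np.newaxis, np.newaxis, np.newaxis]
print(normalized.mean(axis=(1, 2, 3)))  # approximately 0 for every channel
print(normalized.std(axis=(1, 2, 3)))   # approximately 1 for every channel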


def create_data_file(out_file, nb_channels, nb_samples, image_shape):
hdf5_file = tables.open_file(out_file, mode='w')
filters = tables.Filters(complevel=5, complib='blosc')
data_shape = tuple([0, nb_channels] + list(image_shape))
truth_shape = tuple([0, 1] + list(image_shape))
data_storage = hdf5_file.createEArray(hdf5_file.root, 'data',
tables.Float32Atom(),
shape=data_shape,
filters=filters,
expectedrows=nb_samples)
truth_storage = hdf5_file.createEArray(hdf5_file.root, 'truth',
tables.UInt8Atom(),
shape=truth_shape,
filters=filters,
expectedrows=nb_samples)
return hdf5_file, data_storage, truth_storage
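create_data_file relies on PyTables extendable arrays: the leading 0 in data_shape and truth_shape marks the dimension that grows as samples are appended, so the file can be built one subject at a time without knowing the final count up front (expectedrows is only a sizing hint). A minimal standalone example of the same pattern, with toy shapes and a hypothetical file name, using create_earray (the snake_case equivalent of the createEArray call above):

import numpy as np
import tables

h5 = tables.open_file("example.hdf5", mode="w")  # hypothetical file
earray = h5.create_earray(h5.root, "data", tables.Float32Atom(),
                          shape=(0, 3, 8, 8, 8),  # first axis is the extendable one
                          filters=tables.Filters(complevel=5, complib="blosc"))
earray.append(np.zeros((1, 3, 8, 8, 8), dtype=np.float32))  # append one sample
print(earray.shape)  # (1, 3, 8, 8, 8)
h5.close()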


def write_folders_to_file(subject_folders, data_storage, truth_storage, image_shape, truth_dtype=np.uint8):
for subject_folder in subject_folders:
subject_data = read_subject_folder(subject_folder, image_shape)
data_storage.append(subject_data[:3][np.newaxis])
truth_storage.append(np.asarray(subject_data[3][np.newaxis][np.newaxis], dtype=truth_dtype))
return data_storage, truth_storage


def write_data_to_file(data_folder, out_file, image_shape, truth_dtype=np.uint8, nb_channels=3):
subject_folders = get_subject_folders(data_folder)
nb_samples = len(subject_folders)
hdf5_file, data_storage, truth_storage = create_data_file(out_file, nb_channels=nb_channels, nb_samples=nb_samples,
image_shape=image_shape)
write_folders_to_file(subject_folders, data_storage, truth_storage, image_shape, truth_dtype=truth_dtype)
normalize_data_storage(data_storage)
hdf5_file.close()
return out_file


def get_subject_folders(data_dir):
return glob.glob(os.path.join(data_dir, "*", "*"))


def read_subject_folder(folder, image_shape):
flair_image = read_image(os.path.join(folder, "Flair.nii.gz"), image_shape=image_shape)
t1_image = read_image(os.path.join(folder, "T1.nii.gz"), image_shape=image_shape)
t1c_image = read_image(os.path.join(folder, "T1c.nii.gz"), image_shape=image_shape)
truth_image = read_image(os.path.join(folder, "truth.nii.gz"), image_shape=image_shape,
interpolation="nearest")
return np.asarray([t1_image.get_data(), t1c_image.get_data(), flair_image.get_data(), truth_image.get_data()])


def read_image(in_file, image_shape, interpolation='continuous'):
print("Reading: {0}".format(in_file))
image = nib.load(in_file)
return resize(image, new_shape=image_shape, interpolation=interpolation)


def resize(image, new_shape, interpolation="continuous"):
input_shape = np.asarray(image.shape, dtype=np.float16)
ras_image = reorder_img(image, resample=interpolation)
output_shape = np.asarray(new_shape)
new_spacing = input_shape/output_shape
new_affine = np.copy(ras_image.affine)
new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing)
return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation)
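resize rescales the affine so the resampled grid still spans the same physical extent: new_spacing is the per-axis ratio of old to new voxel counts, which resample_img then bakes into the target affine. For example, assuming a 1 mm isotropic BRATS volume of shape (240, 240, 155) resampled to the (240, 240, 144) target used in UnetTraining.py:

input_shape = np.asarray((240, 240, 155), dtype=np.float16)
output_shape = np.asarray((240, 240, 144))
new_spacing = input_shape / output_shape  # approximately [1.0, 1.0, 1.08], i.e. ~1.08 mm spacing along z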


def get_truth(batch, truth_channel=3):
truth = np.array(batch)[:, truth_channel]
batch_list = []
for sample_number in range(truth.shape[0]):
array = np.zeros(truth[sample_number].shape)
array[truth[sample_number] > 0] = 1
batch_list.append([array])
return np.array(batch_list)
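Putting the new normalize.py pieces together, a rough usage sketch (directory and file names are hypothetical) and the resulting HDF5 layout:

import tables
from normalize import write_data_to_file

# Expects data_folder/<group>/<subject>/{Flair,T1,T1c,truth}.nii.gz, per get_subject_folders.
write_data_to_file("/path/to/BRATS", "data.hdf5", image_shape=(240, 240, 144), nb_channels=3)

with tables.open_file("data.hdf5", "r") as data_file:
    print(data_file.root.data.shape)   # (nb_samples, 3, 240, 240, 144), float32, normalized per channel
    print(data_file.root.truth.shape)  # (nb_samples, 1, 240, 240, 144), uint8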
Empty file added utils/__init__.py
11 changes: 11 additions & 0 deletions utils/utils.py
@@ -0,0 +1,11 @@
import pickle


def pickle_dump(item, out_file):
with open(out_file, "wb") as opened_file:
pickle.dump(item, opened_file)


def pickle_load(in_file):
with open(in_file, "rb") as opened_file:
return pickle.load(opened_file)
