ENH: Replaces cropping with resampling
ellisdg committed Mar 31, 2017
1 parent 56521a4 commit 07b293d
Showing 2 changed files with 48 additions and 59 deletions.
83 changes: 45 additions & 38 deletions DataGenerator.py
@@ -3,10 +3,16 @@
import pickle
from random import shuffle

import SimpleITK as sitk
import numpy as np
from nilearn.image import resample_img, reorder_img
import nibabel as nib


# TODO: Rescale images to integer
# TODO: include normalization script from raw BRATS data
# TODO: find the smallest shape image that contains all of the original data
# TODO: set background to zero after resampling

def pickle_dump(item, out_file):
with open(out_file, "wb") as opened_file:
pickle.dump(item, opened_file)
@@ -17,22 +23,22 @@ def pickle_load(in_file):
return pickle.load(opened_file)


def get_training_and_testing_generators(data_dir, batch_size=1, nb_channels=3, truth_channel=3,
background_channel=4, z_crop=15, validation_split=0.8, overwrite=False,
saved_folders_file="training_and_testing_folders.pkl"):
def get_training_and_testing_generators(data_dir, input_shape, batch_size=1, nb_channels=3, validation_split=0.8,
overwrite=False, saved_folders_file="training_and_testing_folders.pkl"):
if overwrite or not os.path.exists(saved_folders_file):
subject_folders = get_subject_folders(data_dir=data_dir)
training_list, testing_list = split_list(subject_folders, split=validation_split, shuffle_list=True)
pickle_dump((training_list, testing_list), saved_folders_file)
else:
training_list, testing_list = pickle_load(saved_folders_file)
training_generator = data_generator(training_list, batch_size=batch_size, nb_channels=nb_channels,
truth_channel=truth_channel, background_channel=background_channel,
z_crop=z_crop)
input_shape=input_shape)
testing_generator = data_generator(testing_list, batch_size=batch_size, nb_channels=nb_channels,
truth_channel=truth_channel, background_channel=background_channel,
z_crop=z_crop)
return training_generator, testing_generator, len(training_list), len(testing_list)
input_shape=input_shape)
# Set the number of training and testing samples per epoch correctly
nb_training_samples = len(training_list)/batch_size * batch_size
nb_testing_samples = len(testing_list)/batch_size * batch_size
return training_generator, testing_generator, nb_training_samples, nb_testing_samples
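
The flooring above keeps fit_generator and the generator's batching in sync: both see only whole batches. A minimal sketch of the arithmetic under Python 2 integer-division semantics, with hypothetical counts:

    # Values assumed for illustration: 10 subjects, batch_size = 3.
    nb_training_samples = 10 / 3 * 3   # Python 2 floors: 3 * 3 == 9
    # data_generator likewise runs range(10 / 3) == range(3) batches,
    # so exactly nb_training_samples samples are yielded per epoch.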


def split_list(input_list, split=0.8, shuffle_list=True):
@@ -48,48 +54,49 @@ def get_subject_folders(data_dir):
return glob.glob(os.path.join(data_dir, "*", "*"))
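
The glob pattern os.path.join(data_dir, "*", "*") implies a two-level layout under data_dir. An assumed example (BRATS-style grade/subject nesting, names hypothetical):

    # data_dir/
    #     HGG/subject_001/   <- returned by get_subject_folders
    #     LGG/subject_002/   <- returned by get_subject_folders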


def data_generator(subject_folders, batch_size=1, nb_channels=3, truth_channel=3, background_channel=4, z_crop=15):
def data_generator(subject_folders, input_shape, batch_size=1, nb_channels=3):
nb_subjects = len(subject_folders)
while True:
shuffle(subject_folders)
# TODO: Edge case?
for i in range(nb_subjects/batch_size):
nb_batches = nb_subjects/batch_size
# TODO: Edge case? Currently this is handled by flooring the number of training/testing samples
for i in range(nb_batches):
batch_folders = subject_folders[i*batch_size:(i+1)*batch_size]
batch = read_batch(batch_folders, background_channel, z_crop)
x_train, y_train = get_training_data(batch, nb_channels, truth_channel)
batch = read_batch(batch_folders, input_shape)
x_train, y_train = get_training_data(batch, nb_channels, truth_channel=3)
del batch, batch_folders
yield x_train, y_train


def read_batch(folders, background_channel, z_crop):
def read_batch(folders, input_shape):
batch = []
for folder in folders:
batch.append(crop_data(read_subject_folder(folder), background_channel=background_channel, z_crop=z_crop))
batch.append(read_subject_folder(folder, input_shape))
return np.asarray(batch)


def read_subject_folder(folder):
flair_image = sitk.ReadImage(os.path.join(folder, "Flair.nii.gz"))
t1_image = sitk.ReadImage(os.path.join(folder, "T1.nii.gz"))
t1c_image = sitk.ReadImage(os.path.join(folder, "T1c.nii.gz"))
truth_image = sitk.ReadImage(os.path.join(folder, "truth.nii.gz"))
background_image = sitk.ReadImage(os.path.join(folder, "background.nii.gz"))
return np.array([sitk.GetArrayFromImage(t1_image),
sitk.GetArrayFromImage(t1c_image),
sitk.GetArrayFromImage(flair_image),
sitk.GetArrayFromImage(truth_image),
sitk.GetArrayFromImage(background_image)])


def crop_data(data, background_channel=4, z_crop=15):
if np.all(data[background_channel, :z_crop] == 1):
return data[:, z_crop:]
elif np.all(data[background_channel, data.shape[1] - z_crop:] == 1):
return data[:, :data.shape[1] - z_crop]
else:
upper = z_crop/2
lower = z_crop - upper
return data[:, lower:data.shape[1] - upper]
def read_subject_folder(folder, image_size):
flair_image = read_image(os.path.join(folder, "Flair.nii.gz"), image_size=image_size)
t1_image = read_image(os.path.join(folder, "T1.nii.gz"), image_size=image_size)
t1c_image = read_image(os.path.join(folder, "T1c.nii.gz"), image_size=image_size)
truth_image = read_image(os.path.join(folder, "truth.nii.gz"), image_size=image_size,
interpolation="nearest")
return np.asarray([t1_image.get_data(), t1c_image.get_data(), flair_image.get_data(), truth_image.get_data()])
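
With the background volume no longer read, each subject now loads as a four-channel stack. A quick shape check, assuming a hypothetical subject folder and the (240, 240, 144) image shape set in UnetTraining.py below:

    # Path and size are illustrative only.
    data = read_subject_folder("sample_data/subject_001", (240, 240, 144))
    print(data.shape)   # (4, 240, 240, 144): T1, T1c, FLAIR, truth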


def read_image(in_file, image_size, interpolation='continuous'):
image = nib.load(in_file)
return resize(image, new_shape=image_size, interpolation=interpolation)


def resize(image, new_shape, interpolation="continuous"):
ras_image = reorder_img(image, resample=interpolation)
input_shape = np.asarray(image.shape, dtype=np.float16)
output_shape = np.asarray(new_shape)
new_spacing = input_shape/output_shape
new_affine = np.copy(ras_image.affine)
new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing)
return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation)
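
This affine arithmetic is what replaces the old cropping: scaling the 3x3 block by input_shape/output_shape enlarges the voxels so the full field of view survives the shape change. A worked sketch with assumed sizes:

    # Assume a BRATS volume that is (240, 240, 155) once reordered to RAS,
    # resampled to new_shape = (240, 240, 144):
    #     new_spacing = (1.0, 1.0, 155/144.) ~= (1.0, 1.0, 1.076)
    # Output voxels are ~7.6% longer along z, so all 155 original slices
    # fit into 144 without discarding data, unlike the old z_crop.
    image = nib.load("sample_data/subject_001/Flair.nii.gz")  # path assumed
    resized = resize(image, new_shape=(240, 240, 144))
    print(resized.shape)   # (240, 240, 144)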


def get_truth(batch, truth_channel=3):
24 changes: 3 additions & 21 deletions UnetTraining.py
@@ -12,12 +12,11 @@
from DataGenerator import get_training_and_testing_generators, pickle_dump

pool_size = (2, 2, 2)
image_shape = (144, 240, 240)
image_shape = (240, 240, 144)
n_channels = 3
input_shape = tuple([n_channels] + list(image_shape))
n_labels = 1 # not including background
batch_size = 1
z_crop = 155 - image_shape[0]
n_epochs = 50
data_dir = "/home/neuro-user/PycharmProjects/BRATS/sample_data"
truth_channel = 3
@@ -123,23 +122,6 @@ def counts_to_weights(array):
return weights_list


def deleteme(array):
length = len(array)
if length > 2:
out_list = []
array_sum = array.sum()
for item in array:
background_count = array_sum - item
out_list.append({0:1, 1:float(background_count)/item})
return out_list
else:
out_dict = dict()
weights = []
for i, item in enumerate(weights):
out_dict[i] = item
return out_dict


def get_training_weights_from_data(y_train):
weights = []
for label in range(y_train.shape[1]):
@@ -190,8 +172,8 @@ def main(overwrite=False):

def train_model(model, model_file):
training_generator, testing_generator, nb_training_samples, nb_testing_samples = get_training_and_testing_generators(
data_dir=data_dir, batch_size=batch_size, nb_channels=n_channels, truth_channel=truth_channel, z_crop=z_crop,
background_channel=background_channel, validation_split=validation_split)
data_dir=data_dir, batch_size=batch_size, nb_channels=n_channels, input_shape=input_shape,
validation_split=validation_split)

model.fit_generator(generator=training_generator,
samples_per_epoch=nb_training_samples,
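The fit_generator call above is truncated by the diff; a hedged sketch of how it plausibly continues, using the Keras 1.x fit_generator keywords (the validation arguments are an assumption, not shown in this commit):

    model.fit_generator(generator=training_generator,
                        samples_per_epoch=nb_training_samples,
                        nb_epoch=n_epochs,
                        validation_data=testing_generator,
                        nb_val_samples=nb_testing_samples)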
