Skip to content

Commit

Permalink
Fix errors with augmenting data
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Oct 25, 2018
1 parent f24f9cc commit e20ae3f
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions factnn/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ def image_augmenter(images):
return images


def common_step(batch_images, positions, time_slice, total_slices, labels=None, proton_data=None,
type_training=None, augment=True, swap=True, shape=None):
def common_step(batch_images, positions, labels=None, proton_images=None, augment=True, swap=True, shape=None):
if augment:
batch_images = image_augmenter(batch_images)
if type_training == "Separation":
proton_images = proton_data[positions, time_slice:time_slice + total_slices, ::]
if proton_images:
if augment:
proton_images = image_augmenter(proton_images)
batch_images = batch_images.reshape(shape)
Expand Down Expand Up @@ -92,13 +90,12 @@ def get_random_hdf5_chunk(start, stop, size, time_slice, total_slices, gamma, pr
proton_data = images_two["Image"]
training_data = images_one["Image"]
batch_images = training_data[start_pos:int(start_pos + size), time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
proton_data=proton_data, type_training=type_training, augment=augment, swap=swap, shape=shape)
proton_images = proton_data[start_pos:int(start_pos + size), time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, labels=labels, proton_images=proton_images, augment=augment, swap=swap, shape=shape)
else:
training_data = images_one["Image"]
batch_images = training_data[start_pos:int(start_pos + size), time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
type_training=type_training, augment=augment, swap=swap, shape=shape)
return common_step(batch_images, positions, labels=labels, augment=augment, swap=swap, shape=shape)


def get_completely_random_hdf5(start, stop, size, time_slice, total_slices, gamma, proton_input=None, labels=None,
Expand Down Expand Up @@ -133,13 +130,12 @@ def get_completely_random_hdf5(start, stop, size, time_slice, total_slices, gamm
proton_data = images_two["Image"]
training_data = images_one["Image"]
batch_images = training_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
proton_data=proton_data, type_training=type_training, augment=augment, swap=swap, shape=shape)
proton_images = proton_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, labels=labels, proton_images=proton_images, augment=augment, swap=swap, shape=shape)
else:
training_data = images_one["Image"]
batch_images = training_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
type_training=type_training, augment=augment, swap=swap, shape=shape)
return common_step(batch_images, positions, labels=labels, augment=augment, swap=swap, shape=shape)


def get_random_from_list(indicies, size, time_slice, total_slices, gamma, proton_input=None, labels=None,
Expand Down Expand Up @@ -174,13 +170,12 @@ def get_random_from_list(indicies, size, time_slice, total_slices, gamma, proton
proton_data = images_two["Image"]
training_data = images_one["Image"]
batch_images = training_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
proton_data=proton_data, type_training=type_training, augment=augment, swap=swap, shape=shape)
proton_images = proton_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, labels=labels, proton_images=proton_images, augment=augment, swap=swap, shape=shape)
else:
training_data = images_one["Image"]
batch_images = training_data[positions, time_slice:time_slice + total_slices, ::]
return common_step(batch_images, positions, time_slice, total_slices, labels=labels,
type_training=type_training, augment=augment, swap=swap, shape=shape)
return common_step(batch_images, positions, labels=labels, augment=augment, swap=swap, shape=shape)


def get_random_from_paths(paths, size, time_slice, total_slices, preprocessor, labels=None,
Expand Down

0 comments on commit e20ae3f

Please sign in to comment.