Skip to content

Commit

Permalink
Fix issue with generating streaming files
Browse files Browse the repository at this point in the history
Previously the generators were calling the stream's single_processor every time, so only the first element of any file was being read and returned; now it properly works through the whole file before opening a new one.
  • Loading branch information
jacobbieker committed Nov 1, 2018
1 parent 792e2a4 commit ced429a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/flow_sep_outside.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@
from examples.open_crab_sample_constants import NUM_EVENTS_GAMMA, NUM_EVENTS_PROTON

event_totals = 0.8*NUM_EVENTS_PROTON
train_num = event_totals * 0.8
train_num = 2000#(event_totals * 0.8)
val_num = event_totals * 0.2

separation_model.fit_generator(
generator=separation_train,
steps_per_epoch=int(np.floor(train_num / separation_train.batch_size)),
epochs=50,
epochs=500,
verbose=1,
validation_data=separation_validate,
callbacks=[early_stop, model_checkpoint, tensorboard],
Expand Down
9 changes: 4 additions & 5 deletions factnn/data/augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from sklearn.utils import shuffle
import h5py
from keras.utils import to_categorical


def image_augmenter(images, as_channels=False):
Expand Down Expand Up @@ -302,9 +301,9 @@ def get_random_from_paths(preprocessor, size, time_slice, total_slices,
for i in range(size):
# Call processor size times to get the correct number for the batch
if as_channels:
processed_data, data_format = next(preprocessor.single_processor(final_slices=final_slices, as_channels=as_channels, collapse_time=True))
processed_data, data_format = next(preprocessor)
else:
processed_data, data_format = next(preprocessor.single_processor())
processed_data, data_format = next(preprocessor)
training_data.append(processed_data)
# Use the type of data to determine what to keep
if type_training == "Separation":
Expand Down Expand Up @@ -341,9 +340,9 @@ def get_random_from_paths(preprocessor, size, time_slice, total_slices,
for i in range(size):
# Call processor size times to get the correct number for the batch
if as_channels:
processed_data, data_format = next(proton_preprocessor.single_processor(final_slices=final_slices, as_channels=as_channels, collapse_time=True))
processed_data, data_format = next(proton_preprocessor)
else:
processed_data, data_format = next(proton_preprocessor.single_processor())
processed_data, data_format = next(proton_preprocessor)
proton_data.append(processed_data)
proton_data = [item[data_format["Image"]] for item in proton_data]
proton_data = np.array(proton_data)
Expand Down
31 changes: 27 additions & 4 deletions factnn/data/base_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def __init__(self, config):

self.init()

self.set_gens = False
self.training_gen = None
self.proton_gen = None
self.val_training_gen = None
self.val_proton_gen = False
# Now self.input_shape will be defined, so set to the correct value of 1 at the end

def init(self):
Expand Down Expand Up @@ -155,27 +160,45 @@ def __next__(self):
else:
# Now streaming from files, training, test, and validation need to be preprocessors set up for it.
if self.mode == "train":
batch_images, batch_image_label = get_random_from_paths(preprocessor=self.train_preprocessor,
if not self.set_gens:
if self.as_channels:
self.training_gen = self.train_preprocessor.single_processor(final_slices=5, as_channels=self.as_channels, collapse_time=True)
self.proton_gen = self.proton_train_preprocessor.single_processor(final_slices=5, as_channels=self.as_channels, collapse_time=True)
self.set_gens = True
else:
self.training_gen = self.train_preprocessor.single_processor()
self.proton_gen = self.train_preprocessor.single_processor()
self.set_gens = True
batch_images, batch_image_label = get_random_from_paths(preprocessor=self.training_gen,
size=self.batch_size,
time_slice=self.start_slice,
total_slices=self.number_slices,
augment=self.augment,
shape=self.input_shape,
type_training=self.type_gen,
proton_preprocessor=self.proton_train_preprocessor,
proton_preprocessor=self.proton_gen,
as_channels=self.as_channels)
return batch_images, batch_image_label

elif self.mode == "validate":
batch_images, batch_image_label = get_random_from_paths(preprocessor=self.validate_preprocessor,
if not self.set_gens:
if self.as_channels:
self.val_training_gen = self.validate_preprocessor.single_processor(final_slices=5, as_channels=self.as_channels, collapse_time=True)
self.val_proton_gen = self.proton_validate_preprocessor.single_processor(final_slices=5, as_channels=self.as_channels, collapse_time=True)
self.set_gens = True
else:
self.val_training_gen = self.validate_preprocessor.single_processor()
self.val_proton_gen = self.proton_validate_preprocessor.single_processor()
self.set_gens = True
batch_images, batch_image_label = get_random_from_paths(preprocessor=self.val_training_gen,
size=self.batch_size,
time_slice=self.start_slice,
total_slices=self.number_slices,
augment=False,
shape=self.input_shape,
swap=False,
type_training=self.type_gen,
proton_preprocessor=self.proton_validate_preprocessor,
proton_preprocessor=self.val_proton_gen,
as_channels=self.as_channels)
return batch_images, batch_image_label
elif self.mode == "test":
Expand Down

0 comments on commit ced429a

Please sign in to comment.