From 4b29b3d8ab79d7978a33146de59c0050be642f05 Mon Sep 17 00:00:00 2001 From: gabrielc Date: Tue, 7 Mar 2017 17:30:31 -0500 Subject: [PATCH] Made ordered access safer. Added additional convenience methods. --- tftables.py | 172 ++++++++++++++++++++++++++++++++++++++++++----- tftables_test.py | 98 +++++++++++++++++---------- 2 files changed, 221 insertions(+), 49 deletions(-) diff --git a/tftables.py b/tftables.py index 82b6d2c..32c35bf 100644 --- a/tftables.py +++ b/tftables.py @@ -13,8 +13,8 @@ def open_file(filename, batch_size, **kw_args): """ Open a HDF5 file for streaming with multitables. - Batches will be retrieved with size batch_size. - Additional keyword arguments will be passed to the multitables.Streamer object. + Batches will be retrieved with size ``batch_size``. + Additional keyword arguments will be passed to the ``multitables.Streamer`` object. :param filename: Filename for the HDF5 file to be read. :param batch_size: The size of the batches to be fetched by this reader. @@ -24,6 +24,102 @@ def open_file(filename, batch_size, **kw_args): return TableReader(filename, batch_size, **kw_args) +def load_dataset(filename, dataset_path, batch_size, queue_size=8, + input_transform=None, + ordered=False, + cyclic=True, + processes=None, + threads=None): + """ + Convenience function to quickly and easily load a dataset using best guess defaults. + If a table is loaded, then the ``input_transformation`` argument is required. + Returns an instance of ``FIFOQueueLoader`` that loads this dataset into a fifo queue. + + This function takes a single argument, which is either a tensorflow placeholder for the + requested array or a dictionary of tensorflow placeholders for the columns in the + requested table. The output of this function should be either a single tensorflow tensor, + a tuple of tensorflow tensors, or a list of tensorflow tensors. A subsequent call to + ``loader.dequeue()`` will return tensors in the same order as ``input_transform``. 
+ + For example, if an array is stored in uint8 format, but we want to cast + it to float32 format to do work on the GPU, the ``input_transform`` would be: + + :: + + def input_transform(ary_batch): + return tf.cast(ary_batch, tf.float32) + + If, instead, we were loading a table with column names ``label`` and ``data`` we + need to transform this into a list. We might use something like the following + to also do the one hot transform. + + :: + + def input_transform(tbl_batch): + labels = tbl_batch['label'] + data = tbl_batch['data'] + + truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) + data_float = tf.to_float(data) + + return truth, data_float + + Then the subsequent call to ``loader.dequeue()`` returns these in the same order: + + :: + + truth_batch, data_batch = loader.dequeue() + + By default, this function does not preserve on-disk ordering, and gives cyclic access. + + :param filename: The filename to the HDF5 file. + :param dataset_path: The internal HDF5 path to the dataset. + :param batch_size: The size of the batches to be loaded into tensorflow. + :param queue_size: The size of the tensorflow FIFO queue. + :param input_transform: A function that transforms the batch before being loaded into the queue. + :param processes: Number of concurrent processes that multitables should use to read data from disk. + :param threads: Number of threads to use to preprocess data and load the FIFO queue. + :return: a loader for the dataset + """ + if processes is None: + processes = (queue_size + 1) // 2 + if threads is None: + threads = 1 if ordered else processes + + reader = TableReader(filename, batch_size) + + batch = reader.get_batch(dataset_path, ordered=ordered, cyclic=cyclic, n_procs=processes) + + if input_transform is not None: + # Transform the input based on user specified function. + processed_batch = input_transform(batch) + elif isinstance(batch, dict): + # If the user tries to load a table, but no function is given, then we cannot go further.
+ # Tables return dictionaries and there is no good default on how to handle this. + raise ValueError("Table datasets must have an input transformation.") + else: + # User loaded an array, no processing requested or required. + processed_batch = batch + + if isinstance(processed_batch, list): + # If the user gave a list, we're good + pass + elif isinstance(processed_batch, tuple): + # If the user gave a tuple, turn it into a list + processed_batch = list(processed_batch) + else: + # If the user returned a single value, also turn it into a list + processed_batch = [processed_batch] + + loader = FIFOQueueLoader(reader, queue_size, processed_batch, threads=threads) + # The user never gets a reference to the reader, so we request the loader to close the + # reader for us when it is stopped. + loader.close_reader = True + + return loader + + + class TableReader: def __init__(self, filename, batch_size, **kw_args): """ @@ -37,6 +133,7 @@ def __init__(self, filename, batch_size, **kw_args): self.vars = [] self.batch_size = batch_size self.queues = [] + self.order_lock = None @staticmethod def __match_slices(slice1, len1, slice2): @@ -150,6 +247,9 @@ def get_batch(self, path, **kw_args): kw_args['cyclic'] = True if 'ordered' not in kw_args: kw_args['ordered'] = True + if kw_args['ordered']: + if self.order_lock is None: + self.order_lock = threading.Lock() queue = self.streamer.get_queue(path=path, **kw_args) block_size = queue.block_size # get an example for finding data types and row sizes. @@ -233,6 +333,19 @@ def read_batch(): return result + @contextlib.contextmanager + def __feed_lock(self): + """ + If ordered access was requested for any variables, then the feed method should + be locked to prevent accidental data races.
+ :return: + """ + if self.order_lock is not None: + with self.order_lock: + yield + else: + yield + @staticmethod def __feed_batch(feed_dict, batch, placeholders): """ @@ -257,19 +370,20 @@ def feed(self): :return: A generator which yields tensorflow feed_dicts """ - # The reader generator is initialised here to allow safe multi-threaded access to the reader. - generators = [(reader(), placeholders) for reader, placeholders in self.vars] - while True: - feed_dict = {} - for gen, placeholders in generators: - # Get the next batch - try: - batch = next(gen) - except StopIteration: - return - # Populate the feed_dict with the elements of this batch. - TableReader.__feed_batch(feed_dict, batch, placeholders) - yield feed_dict + with self.__feed_lock(): + # The reader generator is initialised here to allow safe multi-threaded access to the reader. + generators = [(reader(), placeholders) for reader, placeholders in self.vars] + while True: + feed_dict = {} + for gen, placeholders in generators: + # Get the next batch + try: + batch = next(gen) + except StopIteration: + return + # Populate the feed_dict with the elements of this batch. + TableReader.__feed_batch(feed_dict, batch, placeholders) + yield feed_dict def close(self): """ @@ -314,6 +428,7 @@ def __init__(self, reader, size, inputs, threads=1): self.n_threads = threads self.threads = [] self.monitor_thread = None + self.close_reader = False def __read_thread(self, sess): """ @@ -373,7 +488,34 @@ def stop(self, sess): self.coord.request_stop() sess.run(self.q_close_now_op) self.coord.join([self.monitor_thread]) + if self.close_reader: + self.reader.close() @staticmethod def catch_termination(): + """ + In non-cyclic access, once the end of the dataset is reached, an exception + is called to halt all access to the queue. + This context manager catches this exception for silent handling + of the termination condition. 
+ :return: + """ return contextlib.suppress(tf.errors.OutOfRangeError) + + @contextlib.contextmanager + def begin(self, tf_session, catch_termination=True): + """ + Convenience context manager for starting and stopping the loader. + :param tf_session: The current Tensorflow session. + :param catch_termination: Catch the termination of the loop for non-cyclic access. + :return: + """ + self.start(tf_session) + try: + if catch_termination: + with self.catch_termination(): + yield + else: + yield + finally: + self.stop(tf_session) diff --git a/tftables_test.py b/tftables_test.py index 874daa1..69d7b86 100644 --- a/tftables_test.py +++ b/tftables_test.py @@ -17,15 +17,22 @@ test_table_col_B_shape = (7,49) -def lcm(a,b): - import fractions - return abs(a * b) // fractions.gcd(a, b) if a and b else 0 - - class TestTableRow(tables.IsDescription): col_A = tables.UInt32Col(shape=test_table_col_A_shape) col_B = tables.Float64Col(shape=test_table_col_B_shape) +test_mock_data_shape = (100, 100) + + +class TestMockDataRow(tables.IsDescription): + label = tables.UInt32Col() + data = tables.Float64Col(shape=test_mock_data_shape) + + +def lcm(a,b): + import fractions + return abs(a * b) // fractions.gcd(a, b) if a and b else 0 + def get_batches(array, size, trim_remainder=False): result = [ array[i:i+size] for i in range(0, len(array), size)] @@ -82,6 +89,14 @@ def setUp(self): self.test_uint64_array_path = '/test_uint64' uint64_array = test_file.create_array(test_file.root, self.test_uint64_array_path[1:], self.test_uint64_array) + self.test_mock_data_ary = np.array([ ( + np.random.rand(*test_mock_data_shape), + np.random.randint(10, size=1)[0] ) for _ in range(1000) ], + dtype=tables.dtype_from_descr(TestMockDataRow)) + self.test_mock_data_path = '/mock_data' + mock = test_file.create_table(test_file.root, self.test_mock_data_path[1:], TestMockDataRow) + mock.append(self.test_mock_data_ary) + test_file.close() def tearDown(self): @@ -145,7 +160,7 @@ def set_up(path, array, 
batchsize, get_tensors): array_reader.close() table_reader.close() - def test_noncylic(self): + def test_shared_reader(self): batch_size = 8 reader = tftables.open_file(self.test_filename, batch_size) @@ -156,7 +171,7 @@ def test_noncylic(self): table_batches = get_batches(self.test_table_ary, batch_size, trim_remainder=True) total_batches = min(len(array_batches), len(table_batches)) - loader = reader.get_fifoloader(10, [array_batch, table_batch['col_A'], table_batch['col_B']]) + loader = reader.get_fifoloader(10, [array_batch, table_batch['col_A'], table_batch['col_B']], threads=4) deq = loader.dequeue() array_result = [] @@ -191,35 +206,50 @@ def test_uint64(self): batch = reader.get_batch("/test_uint64") reader.close() - def test_quick_start_A(self): - my_network = lambda x: x - N = 100 - - # Open the HDF5 file. The batch_size defined the length - # (in the outer dimension) of the elements (batches) returned - # by the reader. - reader = tftables.open_file(filename=self.test_filename, - batch_size=20) - # For simple arrays, the get_batch method returns a - # placeholder for one batch taken from the array. - array_batch_placeholder = reader.get_batch(self.test_array_path) - # We can then do a transform on the raw data. - array_float = tf.to_float(array_batch_placeholder) + def test_quick_start_A(self): + my_network = lambda x, y: x + num_iterations = 100 + num_labels = 10 + + with tf.device('/cpu:0'): + # This function preprocesses the batches before they + # are loaded into the internal queue. + # You can cast data, or do one-hot transforms. + # If the dataset is a table, this function is required. + def input_transform(tbl_batch): + labels = tbl_batch['label'] + data = tbl_batch['data'] + + truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0)) + data_float = tf.to_float(data) + + return truth, data_float + + # Open the HDF5 file and create a loader for a dataset.
+ # The batch_size defines the length (in the outer dimension) + # of the elements (batches) returned by the reader. + # Takes a function as input that pre-processes the data. + loader = tftables.load_dataset(filename=self.test_filename, + dataset_path=self.test_mock_data_path, + input_transform=input_transform, + batch_size=20) + + # To get the data, we dequeue it from the loader. + # Tensorflow tensors are returned in the same order as input_transformation + truth_batch, data_batch = loader.dequeue() # The placeholder can then be used in your network - result = my_network(array_float) + result = my_network(truth_batch, data_batch) with tf.Session() as sess: - # The feed method provides a generator that returns - # feed_dict's containing batches from your HDF5 file. - for i, feed_dict in enumerate(reader.feed()): - sess.run(result, feed_dict=feed_dict) - if i >= N: - break - - # Finally, the reader should be closed. - reader.close() + + # This context manager starts and stops the internal threads and + # processes used to read the data from disk and store it in the queue. + with loader.begin(sess): + for _ in range(num_iterations): + sess.run(result) + def test_quick_start_B(self): my_network = lambda x: x @@ -241,9 +271,9 @@ def test_quick_start_B(self): truth_batch = tf.to_float(labels_batch) # This class creates a Tensorflow FIFOQueue and populates it with data from the reader. - loader = tftables.FIFOQueueLoader(reader, size=2, - # The inputs are placeholders (or graphs derived thereof) from the reader. - inputs=[col_A_pl, col_B_pl, truth_batch]) + loader = reader.get_fifoloader(queue_size=2, + # The inputs are placeholders (or graphs derived thereof) from the reader. + inputs=[col_A_pl, col_B_pl, truth_batch]) # Batches are taken out of the queue using a dequeue operation. dequeue_op = loader.dequeue()