Made ordered access safer. Added additional convenience methods.
ghcollin committed Mar 7, 2017
1 parent 42b60d2 commit 4b29b3d
Showing 2 changed files with 221 additions and 49 deletions.
172 changes: 157 additions & 15 deletions tftables.py
@@ -13,8 +13,8 @@
def open_file(filename, batch_size, **kw_args):
"""
Open an HDF5 file for streaming with multitables.
Batches will be retrieved with size batch_size.
Additional keyword arguments will be passed to the multitables.Streamer object.
Batches will be retrieved with size ``batch_size``.
Additional keyword arguments will be passed to the ``multitables.Streamer`` object.
:param filename: Filename for the HDF5 file to be read.
:param batch_size: The size of the batches to be fetched by this reader.
@@ -24,6 +24,102 @@ def open_file(filename, batch_size, **kw_args):
return TableReader(filename, batch_size, **kw_args)
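A minimal sketch of the plain-reader path that ``open_file`` enables (the filename and
dataset path are hypothetical; the pattern mirrors the quick-start test further below):
::

    import tensorflow as tf
    import tftables

    reader = tftables.open_file(filename='test.h5', batch_size=20)
    # For simple arrays, get_batch returns a placeholder for one batch.
    array_batch_placeholder = reader.get_batch('/test_array')
    array_float = tf.to_float(array_batch_placeholder)

    with tf.Session() as sess:
        # feed() yields feed_dicts containing batches from the HDF5 file.
        for i, feed_dict in enumerate(reader.feed()):
            sess.run(array_float, feed_dict=feed_dict)
            if i >= 100:
                break

    reader.close()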


def load_dataset(filename, dataset_path, batch_size, queue_size=8,
input_transform=None,
ordered=False,
cyclic=True,
processes=None,
threads=None):
"""
Convenience function to quickly and easily load a dataset using best guess defaults.
If a table is loaded, then the ``input_transform`` argument is required.
Returns an instance of ``FIFOQueueLoader`` that loads this dataset into a FIFO queue.
The ``input_transform`` function takes a single argument: either a tensorflow
placeholder for the requested array, or a dictionary of tensorflow placeholders
for the columns of the requested table. It should return a single tensorflow
tensor, a tuple of tensorflow tensors, or a list of tensorflow tensors. A
subsequent call to ``loader.dequeue()`` will return these tensors in the same
order as ``input_transform`` produced them.
For example, if an array is stored in uint8 format, but we want to cast
it to float32 format to do work on the GPU, the ``input_transform`` would be:
::

    def input_transform(ary_batch):
        return tf.cast(ary_batch, tf.float32)
If, instead, we were loading a table with column names ``label`` and ``data``, we
need to transform the batch into a list of tensors. We might use something like the
following, which also performs a one-hot transform on the labels.
::

    def input_transform(tbl_batch):
        labels = tbl_batch['label']
        data = tbl_batch['data']

        truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0))
        data_float = tf.to_float(data)

        return truth, data_float
Then the subsequent call to ``loader.dequeue()`` returns these in the same order:
::

    truth_batch, data_batch = loader.dequeue()
By default, this function does not preserve on-disk ordering, and gives cyclic access.
:param filename: The filename of the HDF5 file.
:param dataset_path: The internal HDF5 path to the dataset.
:param batch_size: The size of the batches to be loaded into tensorflow.
:param queue_size: The size of the tensorflow FIFO queue.
:param input_transform: A function that transforms the batch before it is loaded into the queue.
:param ordered: Preserve the on-disk ordering of the dataset. Defaults to ``False``.
:param cyclic: Restart from the beginning of the dataset once the end is reached. Defaults to ``True``.
:param processes: Number of concurrent processes that multitables should use to read data from disk.
:param threads: Number of threads to use to preprocess data and load the FIFO queue.
:return: A ``FIFOQueueLoader`` for this dataset.
"""
if processes is None:
processes = (queue_size + 1) // 2
if threads is None:
threads = 1 if ordered else processes
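# For example, with the default queue_size of 8, these heuristics give
# processes = (8 + 1) // 2 == 4 and, for unordered access, threads = 4;
# ordered access forces a single feed thread so batches cannot interleave.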

reader = TableReader(filename, batch_size)

batch = reader.get_batch(dataset_path, ordered=ordered, cyclic=cyclic, n_procs=processes)

if input_transform is not None:
# Transform the input based on user specified function.
processed_batch = input_transform(batch)
elif isinstance(batch, dict):
# If the user tries to load a table, but no function is given, then we cannot go further.
# Tables return dictionaries, and there is no good default for handling them.
raise ValueError("Table datasets must have an input transformation.")
else:
# User loaded an array, no processing requested or required.
processed_batch = batch

if isinstance(processed_batch, list):
# If the user gave a list, we're good
pass
elif isinstance(processed_batch, tuple):
# If the user gave a tuple, turn it into a list
processed_batch = list(processed_batch)
else:
# If the user returned a single value, also turn it into a list
processed_batch = [processed_batch]

loader = FIFOQueueLoader(reader, queue_size, processed_batch, threads=threads)
# The user never gets a reference to the reader, so we request the loader to close the
# reader for us when it is stopped.
loader.close_reader = True

return loader
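A condensed usage sketch of ``load_dataset`` (the filename, dataset path, and label
count are hypothetical; the full pattern appears in the quick-start test below):
::

    import tensorflow as tf
    import tftables

    def input_transform(tbl_batch):
        truth = tf.to_float(tf.one_hot(tbl_batch['label'], 10, 1, 0))
        data_float = tf.to_float(tbl_batch['data'])
        return truth, data_float

    loader = tftables.load_dataset(filename='test.h5',
                                   dataset_path='/mock_data',
                                   batch_size=20,
                                   input_transform=input_transform)
    truth_batch, data_batch = loader.dequeue()

    with tf.Session() as sess:
        with loader.begin(sess):
            for _ in range(100):
                sess.run([truth_batch, data_batch])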



class TableReader:
def __init__(self, filename, batch_size, **kw_args):
"""
@@ -37,6 +133,7 @@ def __init__(self, filename, batch_size, **kw_args):
self.vars = []
self.batch_size = batch_size
self.queues = []
self.order_lock = None

@staticmethod
def __match_slices(slice1, len1, slice2):
@@ -150,6 +247,9 @@ def get_batch(self, path, **kw_args):
kw_args['cyclic'] = True
if 'ordered' not in kw_args:
kw_args['ordered'] = True
if kw_args['ordered']:
if self.order_lock is None:
self.order_lock = threading.Lock()
queue = self.streamer.get_queue(path=path, **kw_args)
block_size = queue.block_size
# get an example for finding data types and row sizes.
@@ -233,6 +333,19 @@ def read_batch():

return result
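Because the defaults above make every ``get_batch`` call cyclic and ordered,
callers that can tolerate out-of-order batches may pass the flags explicitly.
A sketch with a hypothetical table path:
::

    # Unordered, non-cyclic access: batches may arrive out of on-disk order,
    # and iteration stops after a single pass through the dataset.
    table_batch = reader.get_batch('/test_table', ordered=False, cyclic=False)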

@contextlib.contextmanager
def __feed_lock(self):
"""
If ordered access was requested for any variables, then the feed method should
be locked to prevent accidental data races.
:return:
"""
if self.order_lock is not None:
with self.order_lock:
yield
else:
yield
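The shape of this conditional lock, as a self-contained sketch (the names here
are illustrative and not part of tftables):
::

    import contextlib
    import threading

    @contextlib.contextmanager
    def maybe_locked(lock):
        # Acquire the lock only when one was provided; otherwise do nothing.
        if lock is not None:
            with lock:
                yield
        else:
            yield

    guard = threading.Lock()
    with maybe_locked(guard):
        pass  # critical section runs while holding the lock
    with maybe_locked(None):
        pass  # same code path, but without any locking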

@staticmethod
def __feed_batch(feed_dict, batch, placeholders):
"""
@@ -257,19 +370,20 @@ def feed(self):
:return: A generator which yields tensorflow feed_dicts
"""
# The reader generator is initialised here to allow safe multi-threaded access to the reader.
generators = [(reader(), placeholders) for reader, placeholders in self.vars]
while True:
feed_dict = {}
for gen, placeholders in generators:
# Get the next batch
try:
batch = next(gen)
except StopIteration:
return
# Populate the feed_dict with the elements of this batch.
TableReader.__feed_batch(feed_dict, batch, placeholders)
yield feed_dict
with self.__feed_lock():
# The reader generator is initialised here to allow safe multi-threaded access to the reader.
generators = [(reader(), placeholders) for reader, placeholders in self.vars]
while True:
feed_dict = {}
for gen, placeholders in generators:
# Get the next batch
try:
batch = next(gen)
except StopIteration:
return
# Populate the feed_dict with the elements of this batch.
TableReader.__feed_batch(feed_dict, batch, placeholders)
yield feed_dict

def close(self):
"""
@@ -314,6 +428,7 @@ def __init__(self, reader, size, inputs, threads=1):
self.n_threads = threads
self.threads = []
self.monitor_thread = None
self.close_reader = False

def __read_thread(self, sess):
"""
@@ -373,7 +488,34 @@ def stop(self, sess):
self.coord.request_stop()
sess.run(self.q_close_now_op)
self.coord.join([self.monitor_thread])
if self.close_reader:
self.reader.close()

@staticmethod
def catch_termination():
"""
In non-cyclic access, once the end of the dataset is reached, an exception
is raised to halt all access to the queue.
This context manager catches this exception for silent handling
of the termination condition.
:return:
"""
return contextlib.suppress(tf.errors.OutOfRangeError)
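A sketch of guarding a training loop during non-cyclic access (the session and
ops are assumed to exist; this mirrors what ``begin`` does internally):
::

    with loader.catch_termination():
        while True:
            # Raises tf.errors.OutOfRangeError once the dataset is exhausted;
            # the context manager suppresses it and the loop exits silently.
            sess.run(train_op)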

@contextlib.contextmanager
def begin(self, tf_session, catch_termination=True):
"""
Convenience context manager for starting and stopping the loader.
:param tf_session: The current Tensorflow session.
:param catch_termination: Catch the termination of the loop for non-cyclic access.
:return:
"""
self.start(tf_session)
try:
if catch_termination:
with self.catch_termination():
yield
else:
yield
finally:
self.stop(tf_session)
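A usage sketch (this mirrors the quick-start test below; the session and graph
construction are assumed):
::

    with tf.Session() as sess:
        # begin() starts the loader's threads and guarantees stop() runs,
        # even if the loop raises.
        with loader.begin(sess):
            for _ in range(num_iterations):
                sess.run(result)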
98 changes: 64 additions & 34 deletions tftables_test.py
@@ -17,15 +17,22 @@
test_table_col_B_shape = (7,49)


def lcm(a,b):
import fractions
return abs(a * b) // fractions.gcd(a, b) if a and b else 0


class TestTableRow(tables.IsDescription):
col_A = tables.UInt32Col(shape=test_table_col_A_shape)
col_B = tables.Float64Col(shape=test_table_col_B_shape)

test_mock_data_shape = (100, 100)


class TestMockDataRow(tables.IsDescription):
label = tables.UInt32Col()
data = tables.Float64Col(shape=test_mock_data_shape)


def lcm(a,b):
import fractions
return abs(a * b) // fractions.gcd(a, b) if a and b else 0
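# For example, lcm(4, 6) == 12 and lcm(20, 7) == 140; the guard returns 0
# when either argument is zero.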


def get_batches(array, size, trim_remainder=False):
result = [ array[i:i+size] for i in range(0, len(array), size)]
@@ -82,6 +89,14 @@ def setUp(self):
self.test_uint64_array_path = '/test_uint64'
uint64_array = test_file.create_array(test_file.root, self.test_uint64_array_path[1:], self.test_uint64_array)

self.test_mock_data_ary = np.array([ (
np.random.rand(*test_mock_data_shape),
np.random.randint(10, size=1)[0] ) for _ in range(1000) ],
dtype=tables.dtype_from_descr(TestMockDataRow))
self.test_mock_data_path = '/mock_data'
mock = test_file.create_table(test_file.root, self.test_mock_data_path[1:], TestMockDataRow)
mock.append(self.test_mock_data_ary)

test_file.close()

def tearDown(self):
@@ -145,7 +160,7 @@ def set_up(path, array, batchsize, get_tensors):
array_reader.close()
table_reader.close()

def test_noncylic(self):
def test_shared_reader(self):
batch_size = 8
reader = tftables.open_file(self.test_filename, batch_size)

@@ -156,7 +171,7 @@ def test_noncylic(self):
table_batches = get_batches(self.test_table_ary, batch_size, trim_remainder=True)
total_batches = min(len(array_batches), len(table_batches))

loader = reader.get_fifoloader(10, [array_batch, table_batch['col_A'], table_batch['col_B']])
loader = reader.get_fifoloader(10, [array_batch, table_batch['col_A'], table_batch['col_B']], threads=4)

deq = loader.dequeue()
array_result = []
@@ -191,35 +206,50 @@ def test_uint64(self):
batch = reader.get_batch("/test_uint64")
reader.close()

def test_quick_start_A(self):
my_network = lambda x: x
N = 100

# Open the HDF5 file. The batch_size defined the length
# (in the outer dimension) of the elements (batches) returned
# by the reader.
reader = tftables.open_file(filename=self.test_filename,
batch_size=20)

# For simple arrays, the get_batch method returns a
# placeholder for one batch taken from the array.
array_batch_placeholder = reader.get_batch(self.test_array_path)
# We can then do a transform on the raw data.
array_float = tf.to_float(array_batch_placeholder)
def test_quick_start_A(self):
my_network = lambda x, y: x
num_iterations = 100
num_labels = 10

with tf.device('/cpu:0'):
# This function preprocesses the batches before they
# are loaded into the internal queue.
# You can cast data, or do one-hot transforms.
# If the dataset is a table, this function is required.
def input_transform(tbl_batch):
labels = tbl_batch['label']
data = tbl_batch['data']

truth = tf.to_float(tf.one_hot(labels, num_labels, 1, 0))
data_float = tf.to_float(data)

return truth, data_float

# Open the HDF5 file and create a loader for a dataset.
# The batch_size defines the length (in the outer dimension)
# of the elements (batches) returned by the reader.
# Takes a function as input that pre-processes the data.
loader = tftables.load_dataset(filename=self.test_filename,
dataset_path=self.test_mock_data_path,
input_transform=input_transform,
batch_size=20)

# To get the data, we dequeue it from the loader.
# Tensorflow tensors are returned in the same order as input_transform
truth_batch, data_batch = loader.dequeue()

# The placeholder can then be used in your network
result = my_network(array_float)
result = my_network(truth_batch, data_batch)

with tf.Session() as sess:
# The feed method provides a generator that returns
# feed_dict's containing batches from your HDF5 file.
for i, feed_dict in enumerate(reader.feed()):
sess.run(result, feed_dict=feed_dict)
if i >= N:
break

# Finally, the reader should be closed.
reader.close()

# This context manager starts and stops the internal threads and
# processes used to read the data from disk and store it in the queue.
with loader.begin(sess):
for _ in range(num_iterations):
sess.run(result)


def test_quick_start_B(self):
my_network = lambda x: x
@@ -241,9 +271,9 @@ def test_quick_start_B(self):
truth_batch = tf.to_float(labels_batch)

# This class creates a Tensorflow FIFOQueue and populates it with data from the reader.
loader = tftables.FIFOQueueLoader(reader, size=2,
# The inputs are placeholders (or graphs derived thereof) from the reader.
inputs=[col_A_pl, col_B_pl, truth_batch])
loader = reader.get_fifoloader(queue_size=2,
# The inputs are placeholders (or graphs derived thereof) from the reader.
inputs=[col_A_pl, col_B_pl, truth_batch])
# Batches are taken out of the queue using a dequeue operation.
dequeue_op = loader.dequeue()

