
Commit

Refined unit tests and fixed possible issue with numpy array reference sharing in tensorflow.
ghcollin committed Feb 15, 2018
1 parent d832044 commit 5b0ced3
Showing 2 changed files with 12 additions and 8 deletions.
tftables.py: 4 changes (3 additions and 1 deletion)
@@ -388,7 +388,9 @@ def feed(self):
for gen, placeholders in generators:
# Get the next batch
try:
batch = next(gen)
# Unfortunately Tensorflow seems to keep references to these arrays around somewhere,
# so a copy is required to prevent data corruption.
batch = next(gen).copy()
except StopIteration:
return
# Populate the feed_dict with the elements of this batch.
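An illustrative aside, not part of the commit: the copy added in feed() guards against the buffer-reuse hazard described in the comment above. Below is a minimal sketch under the assumption that a loader's generator yields views of one reused numpy buffer; batch_generator here is hypothetical, not tftables code. The same reasoning applies to the sess.run(...).copy() added in the test file below.

    import numpy as np

    def batch_generator(data, batch_size):
        # Hypothetical loader generator that reuses one buffer between yields.
        buf = np.empty((batch_size,) + data.shape[1:], dtype=data.dtype)
        for start in range(0, len(data) - batch_size + 1, batch_size):
            buf[:] = data[start:start + batch_size]  # overwrite the shared buffer in place
            yield buf                                # every yield hands back the same array object

    data = np.arange(12).reshape(6, 2)

    # Holding plain references: every stored batch aliases the single buffer,
    # so earlier entries are silently overwritten by later ones.
    aliased = list(batch_generator(data, 2))
    print(all(b is aliased[0] for b in aliased))   # True -- the same object each time
    print(aliased[0])                              # shows the values of the *last* batch

    # Copying on receipt, as feed() now does with next(gen).copy(), keeps each batch intact.
    copied = [b.copy() for b in batch_generator(data, 2)]
    print(copied[0])                               # shows the values of the first batch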
tftables_test.py: 16 changes (9 additions and 7 deletions)
@@ -42,23 +42,25 @@ def get_batches(array, size, trim_remainder=False):


def assert_array_equal(self, a, b):
self.assertTrue(np.all(a == b),
self.assertTrue(np.array_equal(a, b),
msg="LHS: \n" + str(a) + "\n RHS: \n" + str(b))


def assert_items_equal(self, a, b, key, epsilon=0):
a = [item for sublist in a for item in sublist]
b = [item for sublist in b for item in sublist]
self.assertEqual(len(a), len(b))
a_sorted, b_sorted = (a, b) if key is None else (sorted(a, key=key), sorted(b, key=key))
#a_sorted, b_sorted = (a, b) if key is None else (sorted(a, key=key), sorted(b, key=key))

unique_a, counts_a = np.unique(a, return_counts=True)
unique_b, counts_b = np.unique(b, return_counts=True)

assert_array_equal(self, unique_a, unique_b)
self.assertAllEqual(unique_a, unique_b)

epsilon *= np.prod(a[0].shape)
delta = counts_a - counts_b
self.assertLessEqual(np.max(np.abs(delta)), 1, msg="More than one extra copy of an element.\n" + str(delta)
+ "\n" + str(np.unique(delta, return_counts=True)))
non_zero = np.abs(delta) > 0
n_non_zero = np.sum(non_zero)
self.assertLessEqual(n_non_zero, epsilon, msg="Num. zero deltas=" + str(n_non_zero) + " epsilon=" + str(epsilon)
@@ -115,16 +117,16 @@ def set_up(path, array, batchsize, get_tensors):
batch = reader.get_batch(path, block_size=blocksize, ordered=False)
batches = get_batches(array, batchsize)*cycles*N_threads
loader = reader.get_fifoloader(N, get_tensors(batch), threads=N_threads)
return reader, loader, batches
return reader, loader, batches, batch

array_batchsize = 10
array_reader, array_loader, array_batches = set_up(self.test_array_path, self.test_array,
array_reader, array_loader, array_batches, array_batch_pl = set_up(self.test_array_path, self.test_array,
array_batchsize, lambda x: [x])
array_data = array_loader.dequeue()
array_result = []

table_batchsize = 5
table_reader, table_loader, table_batches = set_up(self.test_table_path, self.test_table_ary,
table_reader, table_loader, table_batches, table_batch_pl = set_up(self.test_table_path, self.test_table_ary,
table_batchsize, lambda x: [x['col_A'], x['col_B']])
table_A_data, table_B_data = table_loader.dequeue()
table_result = []
@@ -136,7 +138,7 @@ def set_up(path, array, batchsize, get_tensors):
table_loader.start(sess)

for i in tqdm.tqdm(range(len(array_batches))):
array_result.append(sess.run(array_data))
array_result.append(sess.run(array_data).copy())
self.assertEqual(len(array_result[-1]), array_batchsize)

assert_items_equal(self, array_batches, array_result,
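Another aside, not part of the commit: the switch from np.all(a == b) to np.array_equal(a, b) in assert_array_equal above is stricter, because broadcasting can make an elementwise comparison pass even when the shapes differ. A small sketch:

    import numpy as np

    a = np.ones((3, 1))
    b = np.ones(3)

    # The elementwise comparison broadcasts to shape (3, 3) and is all True...
    print(np.all(a == b))        # True
    # ...while np.array_equal also requires identical shapes before comparing values.
    print(np.array_equal(a, b))  # False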
