Skip to content

Commit

Permalink
Fixing issues in PR tensorflow#37400
Browse files Browse the repository at this point in the history
  • Loading branch information
lithuak committed Aug 2, 2020
1 parent 3bda04d commit cc9faf5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
16 changes: 9 additions & 7 deletions tensorflow/python/data/kernel_tests/iterator_test.py
Expand Up @@ -72,7 +72,7 @@ def testCapturingStateInOneShotRaisesException(self):
dataset = (
dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
.map(lambda x: x + var))
with self.assertRaisesRegex(
with self.assertRaisesRegexp(
ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
"datasets that capture stateful objects.+myvar"):
dataset_ops.make_one_shot_iterator(dataset)
Expand Down Expand Up @@ -213,17 +213,17 @@ def testOneShotIteratorInitializerFails(self):
next_element = iterator.get_next()

with self.cached_session() as sess:
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
sess.run(next_element)

# Test that subsequent attempts to use the iterator also fail.
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
sess.run(next_element)

with self.cached_session() as sess:

def consumer_thread():
with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
with self.assertRaisesRegexp(errors.InvalidArgumentError, ""):
sess.run(next_element)

num_threads = 8
Expand Down Expand Up @@ -293,8 +293,8 @@ def testNotInitializedError(self):
get_next = iterator.get_next()

with self.cached_session() as sess:
with self.assertRaisesRegex(errors.FailedPreconditionError,
"iterator has not been initialized"):
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"iterator has not been initialized"):
sess.run(get_next)

@combinations.generate(test_base.graph_only_combinations())
Expand Down Expand Up @@ -946,7 +946,9 @@ def finalize_fn(n):

@def_function.function
def fn():
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn)
output_spec = tensor_spec.TensorSpec((), dtypes.int64)
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
output_spec)
iterator = iter(dataset)
next(iterator)

Expand Down
5 changes: 4 additions & 1 deletion tensorflow/python/data/util/structure.py
Expand Up @@ -91,9 +91,11 @@ def normalize_element(element, element_signature=None):
if element_signature is None:
components = nest.flatten(element)
flattened_signature = [None] * len(components)
pack_as = element
else:
flattened_signature = nest.flatten(element_signature)
components = nest.flatten_up_to(element_signature, element)
pack_as = element_signature
with ops.name_scope("normalize_element"):
# Imported here to avoid circular dependency.
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -125,7 +127,8 @@ def normalize_element(element, element_signature=None):
else:
normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
return nest.pack_sequence_as(element_signature, normalized_components)
return nest.pack_sequence_as(pack_as, normalized_components)



def convert_legacy_structure(output_types, output_shapes, output_classes):
Expand Down

0 comments on commit cc9faf5

Please sign in to comment.