Skip to content

Commit

Permalink
Switch PrefetchWithSlackTest to use TF combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
feihugis committed Dec 3, 2019
1 parent bfb3142 commit efd4441
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):

@test_util.run_v1_only("b/121264236")
# TODO(b/121264236)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testPrefetchWithSlackOption(self):
"""Determines slack_period based on num devices attached to iterator."""
dataset = dataset_ops.Dataset.range(10)
Expand All @@ -60,6 +61,7 @@ def testPrefetchWithSlackOption(self):
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)

@combinations.generate(test_base.default_test_combinations())
def testPrefetchWithSlackOptionWithoutIterator(self):
"""Defaults to slack period of 1 without iterator."""
dataset = dataset_ops.Dataset.range(10)
Expand All @@ -72,6 +74,7 @@ def testPrefetchWithSlackOptionWithoutIterator(self):
dataset.options()._static_optimization_configs())
self.assertDatasetProduces(dataset, range(10))

@combinations.generate(test_base.default_test_combinations())
def testWithPassthroughDataset(self):
"""Should still work with a passthrough dataset after prefetch()."""
dataset = dataset_ops.Dataset.range(10)
Expand All @@ -82,6 +85,7 @@ def testWithPassthroughDataset(self):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, range(1, 11))

@combinations.generate(test_base.default_test_combinations())
def testErrorWithoutPrefetch(self):
"""The rewrite fails if there is no prefetch() in the pipeline."""
dataset = dataset_ops.Dataset.range(10)
Expand All @@ -92,6 +96,7 @@ def testErrorWithoutPrefetch(self):
get_next = self.getNext(dataset)
self.evaluate(get_next())

@combinations.generate(test_base.default_test_combinations())
def testErrorWithInvalidDataset(self):
"""With a nested dataset op after prefetch, the rewrite should fail."""
dataset = dataset_ops.Dataset.range(10)
Expand Down

0 comments on commit efd4441

Please sign in to comment.