Skip to content

Commit

Permalink
Switch ParseExampleDatasetTest to use TF combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
feihugis committed Dec 3, 2019
1 parent 6ff609a commit 28248c5
Showing 1 changed file with 27 additions and 6 deletions.
Expand Up @@ -20,6 +20,7 @@

import copy

from absl.testing import parameterized
import numpy as np

from tensorflow.core.example import example_pb2
Expand All @@ -28,11 +29,11 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
Expand All @@ -50,8 +51,8 @@
sequence_example = example_pb2.SequenceExample


@test_util.run_all_in_graph_and_eager_modes
class ParseExampleDatasetTest(test_base.DatasetTestBase):
class ParseExampleDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):

def _compare_output_to_expected(self, dict_tensors, expected_tensors):
self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
Expand Down Expand Up @@ -107,6 +108,7 @@ def _test(self,
self.assertEqual(
dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)

@combinations.generate(test_base.default_test_combinations())
def testEmptySerializedWithAllDefaults(self):
sparse_name = "st_a"
a_name = "a"
Expand Down Expand Up @@ -145,7 +147,8 @@ def testEmptySerializedWithAllDefaults(self):
expected_values=expected_output,
create_iterator_twice=True)

@test_util.run_deprecated_v1
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = {
"st_a":
Expand Down Expand Up @@ -179,7 +182,8 @@ def testEmptySerializedWithoutDefaultsShouldFail(self):
expected_err=(errors_impl.InvalidArgumentError,
"Feature: c \\(data type: float\\) is required"))

@test_util.run_deprecated_v1
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testDenseNotMatchingShapeShouldFail(self):
original = [
example(features=features({
Expand All @@ -197,6 +201,7 @@ def testDenseNotMatchingShapeShouldFail(self):
expected_err=(errors_impl.InvalidArgumentError,
"Key: a, Index: 1. Number of float values"))

@combinations.generate(test_base.default_test_combinations())
def testDenseDefaultNoShapeShouldFail(self):
original = [example(features=features({"a": float_feature([1, 1, 3]),})),]

Expand All @@ -207,6 +212,7 @@ def testDenseDefaultNoShapeShouldFail(self):
{"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
expected_err=(ValueError, "Missing shape for feature a"))

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparse(self):
original = [
example(features=features({
Expand Down Expand Up @@ -248,6 +254,7 @@ def testSerializedContainingSparse(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeature(self):
original = [
example(features=features({
Expand Down Expand Up @@ -284,6 +291,7 @@ def testSerializedContainingSparseFeature(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingSparseFeatureReuse(self):
original = [
example(features=features({
Expand Down Expand Up @@ -325,6 +333,7 @@ def testSerializedContainingSparseFeatureReuse(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContaining3DSparseFeature(self):
original = [
example(features=features({
Expand Down Expand Up @@ -370,6 +379,7 @@ def testSerializedContaining3DSparseFeature(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDense(self):
aname = "a"
bname = "b*has+a:tricky_name"
Expand Down Expand Up @@ -407,6 +417,7 @@ def testSerializedContainingDense(self):

# This test is identical as the previous one except
# for the creation of 'serialized'.
@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithConcat(self):
aname = "a"
bname = "b*has+a:tricky_name"
Expand Down Expand Up @@ -452,6 +463,7 @@ def testSerializedContainingDenseWithConcat(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
Expand All @@ -476,6 +488,7 @@ def testSerializedContainingDenseScalar(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingDenseWithDefaults(self):
original = [
example(features=features({
Expand Down Expand Up @@ -514,6 +527,7 @@ def testSerializedContainingDenseWithDefaults(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape
np.empty((0, 2), dtype=np.int64), # indices
Expand Down Expand Up @@ -569,6 +583,7 @@ def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape
np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
Expand Down Expand Up @@ -667,11 +682,13 @@ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingVarLenDenseLargerBatch(self):
np.random.seed(3456)
for batch_size in (1, 10, 20, 100, 256):
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)

@combinations.generate(test_base.default_test_combinations())
def testSerializedShapeMismatch(self):
aname = "a"
bname = "b"
Expand Down Expand Up @@ -724,7 +741,8 @@ def testSerializedShapeMismatch(self):
expected_err=(ValueError,
"Cannot reshape a tensor with 0 elements to shape"))

@test_util.run_deprecated_v1
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testSerializedContainingVarLenDense(self):
aname = "a"
bname = "b"
Expand Down Expand Up @@ -877,6 +895,7 @@ def testSerializedContainingVarLenDense(self):
"Unsupported: FixedLenSequenceFeature requires "
"allow_missing to be True."))

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithNoPartitions(self):
original = [
example(
Expand Down Expand Up @@ -922,6 +941,7 @@ def testSerializedContainingRaggedFeatureWithNoPartitions(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithOnePartition(self):
original = [
example(
Expand Down Expand Up @@ -1040,6 +1060,7 @@ def testSerializedContainingRaggedFeatureWithOnePartition(self):
expected_values=expected_output,
create_iterator_twice=True)

@combinations.generate(test_base.default_test_combinations())
def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
original = [
# rt shape: [(batch), 2, None, None]
Expand Down

0 comments on commit 28248c5

Please sign in to comment.