Skip to content

Commit

Permalink
Switch DenseToSparseBatchTest to use TF combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
feihugis committed Dec 3, 2019
1 parent 876279a commit 1067f41
Showing 1 changed file with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class DenseToSparseBatchTest(test_base.DatasetTestBase):
class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(
Expand All @@ -53,6 +54,7 @@ def testDenseToSparseBatchDataset(self):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())

@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(
Expand Down Expand Up @@ -80,12 +82,14 @@ def testDenseToSparseBatchDatasetWithUnknownShape(self):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())

@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [-2]))

@combinations.generate(test_base.default_test_combinations())
def testDenseToSparseBatchDatasetShapeErrors(self):

def dataset_fn(input_tensor):
Expand Down

0 comments on commit 1067f41

Please sign in to comment.