Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,6 @@ def standardize_shape(shape):
# `tf.TensorShape` may contain `Dimension` objects.
# We need to convert the items in it to either int or `None`
shape = shape.as_list()
shape = tuple(shape)

if config.backend() == "jax":
# Replace `_DimExpr` (dimension expression) with None
Expand All @@ -609,25 +608,37 @@ def standardize_shape(shape):
None if jax_export.is_symbolic_dim(d) else d for d in shape
)

if config.backend() == "torch":
# `shape` might be `torch.Size`. We need to convert the items in it to
# either int or `None`
shape = tuple(map(lambda x: int(x) if x is not None else None, shape))

for e in shape:
if e is None:
# Handle dimensions that are not ints and not None, verify they're >= 0.
standardized_shape = []
for d in shape:
if d is None:
standardized_shape.append(d)
continue
if not is_int_dtype(type(e)):

# Reject these even if they can be cast to int successfully.
if isinstance(d, (str, float)):
raise ValueError(
f"Cannot convert '{shape}' to a shape. "
f"Found invalid entry '{e}' of type '{type(e)}'. "
f"Found invalid dimension '{d}' of type '{type(d)}'. "
)
if e < 0:

try:
# Cast numpy scalars, tf constant tensors, etc.
d = int(d)
except Exception as e:
raise ValueError(
f"Cannot convert '{shape}' to a shape. "
f"Found invalid dimension '{d}' of type '{type(d)}'. "
) from e
if d < 0:
raise ValueError(
f"Cannot convert '{shape}' to a shape. "
"Negative dimensions are not allowed."
)
return shape
standardized_shape.append(d)

# This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
return tuple(standardized_shape)


def shape_equal(a_shape, b_shape):
Expand Down
188 changes: 58 additions & 130 deletions keras/src/backend/common/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,32 +310,69 @@ def test_name_validation(self):
)

def test_standardize_shape_with_none(self):
"""Tests standardizing shape with None."""
with self.assertRaisesRegex(
ValueError, "Undefined shapes are not supported."
):
standardize_shape(None)

def test_standardize_shape_with_non_iterable(self):
"""Tests shape standardization with non-iterables."""
with self.assertRaisesRegex(
ValueError, "Cannot convert '42' to a shape."
):
standardize_shape(42)

def test_standardize_shape_with_valid_input(self):
"""Tests standardizing shape with valid input."""
shape = (3, 4, 5)
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_valid_input_with_none(self):
shape = (3, None, 5)
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, None, 5))

def test_standardize_shape_with_valid_not_tuple_input(self):
shape = [3, 4, 5]
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_negative_entry(self):
"""Tests standardizing shape with negative entries."""
def test_standardize_shape_with_numpy(self):
shape = [3, np.int32(4), np.int64(5)]
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, 4, 5))
for d in standardized_shape:
self.assertIsInstance(d, int)

def test_standardize_shape_with_string(self):
shape_with_string = (3, 4, "5")
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Found invalid dimension '5'.",
):
standardize_shape(shape_with_string)

def test_standardize_shape_with_float(self):
shape_with_float = (3, 4, 5.0)
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Found invalid dimension '5.0'.",
):
standardize_shape(shape_with_float)

def test_standardize_shape_with_object(self):
shape_with_object = (3, 4, object())
with self.assertRaisesRegex(
ValueError,
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
"Cannot convert .* to a shape. Found invalid dimension .*object",
):
standardize_shape([3, 4, -5])
standardize_shape(shape_with_object)

def test_standardize_shape_with_negative_dimension(self):
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Negative dimensions",
):
standardize_shape((3, 4, -5))

def test_shape_equal_length_mismatch(self):
"""Test mismatch in lengths of shapes."""
Expand Down Expand Up @@ -1138,138 +1175,29 @@ def test_xor(self, dtypes):
reason="Tests for standardize_shape with Torch backend",
)
class TestStandardizeShapeWithTorch(test_case.TestCase):
"""Tests for standardize_shape with Torch backend."""

def test_standardize_shape_with_torch_size_containing_negative_value(self):
"""Tests shape with a negative value."""
shape_with_negative_value = (3, 4, -5)
with self.assertRaisesRegex(
ValueError,
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
):
_ = standardize_shape(shape_with_negative_value)

def test_standardize_shape_with_torch_size_valid(self):
"""Tests a valid shape."""
shape_valid = (3, 4, 5)
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_torch_size_multidimensional(self):
"""Tests shape of a multi-dimensional tensor."""
def test_standardize_shape_with_torch_size(self):
import torch

tensor = torch.randn(3, 4, 5)
shape = tensor.size()
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_torch_size_single_dimension(self):
"""Tests shape of a single-dimensional tensor."""
import torch

tensor = torch.randn(10)
shape = tensor.size()
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (10,))

def test_standardize_shape_with_torch_size_with_valid_1_dimension(self):
"""Tests a valid shape."""
shape_valid = [3]
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3,))

def test_standardize_shape_with_torch_size_with_valid_2_dimension(self):
"""Tests a valid shape."""
shape_valid = [3, 4]
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3, 4))

def test_standardize_shape_with_torch_size_with_valid_3_dimension(self):
"""Tests a valid shape."""
shape_valid = [3, 4, 5]
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_torch_size_with_negative_value(self):
"""Tests shape with a negative value appended."""
import torch

tensor = torch.randn(3, 4, 5)
shape = tuple(tensor.size())
shape_with_negative = shape + (-1,)
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Negative dimensions are not",
):
_ = standardize_shape(shape_with_negative)

def test_standardize_shape_with_non_integer_entry(self):
"""Tests shape with a non-integer value."""
with self.assertRaisesRegex(
# different error message for torch
ValueError,
r"invalid literal for int\(\) with base 10: 'a'",
):
standardize_shape([3, 4, "a"])

def test_standardize_shape_with_negative_entry(self):
"""Tests shape with a negative value."""
with self.assertRaisesRegex(
ValueError,
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
):
standardize_shape([3, 4, -5])

def test_standardize_shape_with_valid_not_tuple(self):
"""Tests a valid shape."""
shape_valid = [3, 4, 5]
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3, 4, 5))
self.assertIs(type(standardized_shape), tuple)
for d in standardized_shape:
self.assertIsInstance(d, int)


@pytest.mark.skipif(
backend.backend() == "torch",
reason="Tests for standardize_shape with others backend",
backend.backend() != "tensorflow",
reason="Tests for standardize_shape with TensorFlow backend",
)
class TestStandardizeShapeWithOutTorch(test_case.TestCase):
"""Tests for standardize_shape with others backend."""
class TestStandardizeShapeWithTensorflow(test_case.TestCase):
def test_standardize_shape_with_tensor_size(self):
import tensorflow as tf

def test_standardize_shape_with_out_torch_negative_value(self):
"""Tests shape with a negative value."""
shape_with_negative_value = (3, 4, -5)
with self.assertRaisesRegex(
ValueError,
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
):
_ = standardize_shape(shape_with_negative_value)

def test_standardize_shape_with_out_torch_string(self):
"""Tests shape with a string value."""
shape_with_string = (3, 4, "5")
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Found invalid entry '5'.",
):
_ = standardize_shape(shape_with_string)

def test_standardize_shape_with_out_torch_float(self):
"""Tests shape with a float value."""
shape_with_float = (3, 4, 5.0)
with self.assertRaisesRegex(
ValueError,
"Cannot convert .* to a shape. Found invalid entry '5.0'.",
):
_ = standardize_shape(shape_with_float)

def test_standardize_shape_with_out_torch_valid(self):
"""Tests a valid shape."""
shape_valid = (3, 4, 5)
standardized_shape = standardize_shape(shape_valid)
self.assertEqual(standardized_shape, (3, 4, 5))

def test_standardize_shape_with_out_torch_valid_not_tuple(self):
"""Tests a valid shape."""
shape_valid = [3, 4, 5]
standardized_shape = standardize_shape(shape_valid)
shape = (3, tf.constant(4, dtype=tf.int64), 5)
standardized_shape = standardize_shape(shape)
self.assertEqual(standardized_shape, (3, 4, 5))
self.assertIs(type(standardized_shape), tuple)
for d in standardized_shape:
self.assertIsInstance(d, int)
Loading