Fix logic for determining if a value can be cast to a specific dtype.
Note:

* `np.can_cast` does not operate on non-scalar arrays.
* `np.can_cast` interprets strings as dtype-like specifications.

Therefore, it's required to handle these cases differently.
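
For illustration, a minimal sketch of the workaround this change adopts (not part of the commit; the `can_hold_str` helper is hypothetical, exact `np.can_cast` behavior varies across NumPy versions, and the element-wise check assumes the value-based casting semantics this code targets):

import numpy as np

# Strings: np.can_cast('int32', np.int64) treats 'int32' as a dtype
# specifier rather than a value, so string values are matched against
# np.str_ / np.bytes_ directly instead of going through np.can_cast.
def can_hold_str(dtype):
  return dtype is np.str_ or dtype is np.bytes_

# Non-scalar arrays: np.can_cast is documented for scalars and dtype
# specifiers, so each element is checked individually after flattening.
values = np.array([[1, 2], [3, np.iinfo(np.int64).max]], dtype=np.int64)
print(all(np.can_cast(x, np.int32) for x in values.flatten()))  # False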

Additionally:

1. Moved the logic for determining if an array is compatible with a dtype or a shape to `array.py`.
2. Fixed a lint error in `building_blocks` module.

PiperOrigin-RevId: 619218713
michaelreneer authored and tensorflow-copybara committed Mar 26, 2024
1 parent 6a6263f commit 7bcaa5c
Showing 5 changed files with 214 additions and 82 deletions.
1 change: 0 additions & 1 deletion tensorflow_federated/python/core/impl/compiler/BUILD
@@ -151,7 +151,6 @@ py_library(
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:array_shape",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
83 changes: 80 additions & 3 deletions tensorflow_federated/python/core/impl/compiler/array.py
@@ -106,8 +106,10 @@ def to_proto(
"""Returns a `Array` for the `array`."""

if dtype_hint is not None:
if not np.can_cast(array, dtype_hint):
raise ValueError(f'Expected that {array} can be cast to {dtype_hint}.')
if not is_compatible_dtype(array, dtype_hint):
raise ValueError(
f"Expected '{array}' to be compatible with '{dtype_hint}'."
)

if isinstance(array, (np.ndarray, np.generic)):
if dtype_hint is not None:
@@ -128,7 +130,7 @@ def to_proto(
array = array.encode()
value = [array]
else:
raise NotImplementedError(f'Unexpected value found: {array}.')
raise NotImplementedError(f'Unexpected array found: {array}.')

dtype_pb = dtype_utils.to_proto(dtype)
shape_pb = array_shape.to_proto(shape)
@@ -240,3 +242,78 @@ def to_proto(
)
else:
raise NotImplementedError(f'Unexpected dtype found: {dtype}.')


def _can_cast(array: Array, dtype: type[np.generic]) -> bool:
"""Returns `True` if `array` can be cast to the `dtype`."""
if isinstance(array, np.ndarray):
# `np.can_cast` does not operate on non-scalar arrays. See
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html for
# more information.
return all(np.can_cast(x, dtype) for x in array.flatten())
elif isinstance(array, (np.generic, bool, int, float, complex)):
return np.can_cast(array, dtype)
elif isinstance(array, (str, bytes)):
# `np.can_cast` interprets strings as dtype-like specifications.
return dtype is np.str_ or dtype is np.bytes_
else:
return False


def is_compatible_dtype(value: Array, dtype: type[np.generic]) -> bool:
"""Returns `True` if `value` is compatible with `dtype`, otherwise `False`.
This function checks that the `value` has the same scalar kind as `dtype` and
has a compatible size.
See https://numpy.org/doc/stable/reference/arrays.scalars.html for more
information.
Args:
value: The value to check.
dtype: The scalar `np.generic` to check against.
"""
if isinstance(value, (np.ndarray, np.generic)):
value_dtype = value.dtype
else:
value_dtype = type(value)

# Check dtype kind and skip checking dtype size because `np.bool_` does not
# have a size and values with a dtype `np.str_` and `np.bytes_` have a
# variable length.
if np.issubdtype(value_dtype, np.bool_):
return dtype is np.bool_
elif np.issubdtype(value_dtype, np.character):
return dtype is np.str_ or dtype is np.bytes_

# Check dtype kind.
if np.issubdtype(value_dtype, np.integer):
if not np.issubdtype(dtype, np.integer):
return False
elif np.issubdtype(value_dtype, np.floating):
if not np.issubdtype(dtype, np.floating):
return False
elif np.issubdtype(value_dtype, np.complexfloating):
if not np.issubdtype(dtype, np.complexfloating):
return False
else:
return False

# Check dtype size.
if not _can_cast(value, dtype):
return False

return True


def is_compatible_shape(value: Array, shape: array_shape.ArrayShape) -> bool:
"""Returns `True` if `value` is compatible with `shape`, otherwise `False`.
Args:
value: The value to check.
shape: The `tff.types.ArrayShape` to check against.
"""
if isinstance(value, np.ndarray):
return array_shape.is_compatible_with(value.shape, shape)
else:
return array_shape.is_shape_scalar(shape)
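
A quick usage sketch of the helpers added above (not part of the diff; the results shown assume the NumPy 1.x value-based casting semantics this change targets):

import numpy as np
from tensorflow_federated.python.core.impl.compiler import array

# Kind is checked first (integer vs. floating vs. character), then size.
array.is_compatible_dtype(1, np.int32)                       # True
array.is_compatible_dtype(1, np.float32)                     # False: wrong kind
array.is_compatible_dtype(np.iinfo(np.int64).max, np.int32)  # False: overflows
array.is_compatible_dtype('abc', np.str_)                    # True

# np.ndarray values are matched against the shape; everything else must be
# scalar-shaped.
array.is_compatible_shape(np.array([[1, 2, 3]], np.int32), [1, 3])  # True
array.is_compatible_shape(1, [])                                    # True
array.is_compatible_shape(1, [3])                                   # False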
122 changes: 121 additions & 1 deletion tensorflow_federated/python/core/impl/compiler/array_test.py
@@ -470,6 +470,36 @@ def test_to_proto_returns_value_with_no_dtype_hint(
int32_list=array_pb2.Array.IntList(value=[1, 2, 3, 4, 5, 6]),
),
),
(
'scalar_different_dtype',
1,
np.int64,
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_INT64,
shape=array_pb2.ArrayShape(dim=[]),
int64_list=array_pb2.Array.Int64List(value=[1]),
),
),
(
'generic_different_dtype',
np.int32(1),
np.int64,
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_INT64,
shape=array_pb2.ArrayShape(dim=[]),
int64_list=array_pb2.Array.Int64List(value=[1]),
),
),
(
'array_different_dtype',
np.array([[1, 2, 3], [4, 5, 6]], np.int32),
np.int64,
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_INT64,
shape=array_pb2.ArrayShape(dim=[2, 3]),
int64_list=array_pb2.Array.Int64List(value=[1, 2, 3, 4, 5, 6]),
),
),
)
def test_to_proto_returns_value_with_dtype_hint(
self, value, dtype, expected_value
@@ -478,7 +508,9 @@ def test_to_proto_returns_value_with_dtype_hint(
self.assertEqual(actual_value, expected_value)

@parameterized.named_parameters(
('int', np.iinfo(np.int64).max, np.int32),
('scalar', np.iinfo(np.int64).max, np.int32),
('generic', np.int64(np.iinfo(np.int64).max), np.int32),
('array', np.array([np.iinfo(np.int64).max] * 3, np.int64), np.int32),
)
def test_to_proto_raises_value_error_with_incompatible_dtype_hint(
self, value, dtype_hint
@@ -502,6 +534,94 @@ def test_to_proto_raises_not_implemented_error(self, value):
with self.assertRaises(NotImplementedError):
array.to_proto(value)

@parameterized.named_parameters(
('scalar', 1, np.int64),
('str', 'abc', np.str_),
('generic', np.int32(1), np.int64),
('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), np.int64),
)
def test_can_cast_returns_true(self, value, dtype):
result = array._can_cast(value, dtype)
self.assertTrue(result)

@parameterized.named_parameters(
('scalar', np.iinfo(np.int64).max, np.int32),
('str', 'abc', np.int32),
('generic', np.int64(np.iinfo(np.int64).max), np.int32),
('array', np.array([np.iinfo(np.int64).max] * 3, np.int64), np.int32),
)
def test_can_cast_returns_false(self, value, dtype):
result = array._can_cast(value, dtype)
self.assertFalse(result)

@parameterized.named_parameters(
('bool', True, np.bool_),
('int', 1, np.int32),
('float', 1.0, np.float32),
('complex', (1.0 + 1.0j), np.complex64),
('str', 'abc', np.str_),
('bytes', b'abc', np.bytes_),
('generic', np.int32(1), np.int32),
('generic_smaller_size', np.int32(1), np.int16),
('generic_larger_size', np.int32(1), np.int64),
('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), np.int32),
(
'array_smaller_size',
np.array([[1, 2, 3], [4, 5, 6]], np.int32),
np.int16,
),
(
'array_larger_size',
np.array([[1, 2, 3], [4, 5, 6]], np.int32),
np.int32,
),
)
def test_is_compatible_dtype_returns_true(self, value, dtype):
result = array.is_compatible_dtype(value, dtype)
self.assertTrue(result)

@parameterized.named_parameters(
('scalar_and_incompatible_dtype_kind', 1, np.float32),
('scalar_and_incompatible_dtype_size', np.iinfo(np.int64).max, np.int32),
('generic_and_incompatible_dtype_kind', np.int32(1), np.float32),
(
'generic_and_incompatible_dtype_size',
np.int64(np.iinfo(np.int64).max),
np.int32,
),
(
'array_and_incompatible_dtype_kind',
np.array([1, 2, 3], np.int32),
np.float32,
),
(
'array_and_incompatible_dtype_size',
np.array([np.iinfo(np.int64).max] * 3, np.int64),
np.float32,
),
)
def test_is_compatible_dtype_returns_false(self, value, dtype):
result = array.is_compatible_dtype(value, dtype)
self.assertFalse(result)

@parameterized.named_parameters(
('scalar', 1, []),
('generic', np.int32(1), []),
('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), [2, 3]),
)
def test_is_compatible_shape_returns_true(self, value, shape):
result = array.is_compatible_shape(value, shape)
self.assertTrue(result)

@parameterized.named_parameters(
('scalar', 1, [3]),
('generic', np.int32(1), [3]),
('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), [3]),
)
def test_is_compatible_shape_returns_false(self, value, shape):
result = array.is_compatible_shape(value, shape)
self.assertFalse(result)


if __name__ == '__main__':
absltest.main()
72 changes: 4 additions & 68 deletions tensorflow_federated/python/core/impl/compiler/building_blocks.py
@@ -27,7 +27,6 @@
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.compiler import array
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.types import array_shape
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis
@@ -1110,69 +1109,6 @@ def __repr__(self) -> str:
return "Placement('{}')".format(self.uri)


def _is_compatible_dtype(value: array.Array, dtype: type[np.generic]) -> bool:
"""Returns `True` if `value` is compatible with `dtype`, otherwise `False`.
This functions checks that the `value` has the same scalar kind as `dtype` and
has a compatible size.
See https://numpy.org/doc/stable/reference/arrays.scalars.html for more
information.
Args:
value: The value to check.
dtype: The scalar `np.generic` to check against.
"""
if isinstance(value, (np.ndarray, np.generic)):
value_dtype = value.dtype
else:
value_dtype = type(value)

# Check dtype kind and skip checking dtype size because `np.bool_` does not
# have a size and values with a dtype `np.str_` and `np.bytes_` have a
# variable length.
if np.issubdtype(value_dtype, np.bool_):
return dtype is np.bool_
elif np.issubdtype(value_dtype, np.str_):
return dtype is np.str_
elif np.issubdtype(value_dtype, np.bytes_):
return dtype is np.str_

# Check dtype kind.
if np.issubdtype(value_dtype, np.integer):
if not np.issubdtype(dtype, np.integer):
return False
elif np.issubdtype(value_dtype, np.floating):
if not np.issubdtype(dtype, np.floating):
return False
elif np.issubdtype(value_dtype, np.complexfloating):
if not np.issubdtype(dtype, np.complexfloating):
return False
else:
return False

# Check dtype size.
if not np.can_cast(value, dtype):
return False

return True


def _is_compatible_shape(
value: array.Array, shape: array_shape.ArrayShape
) -> bool:
"""Returns `True` if `value` is compatible with `shape`, otherwise `False`.
Args:
value: The value to check.
shape: The `tff.types.ArrayShape` to check against.
"""
if isinstance(value, np.ndarray):
return array_shape.is_compatible_with(value.shape, shape)
else:
return array_shape.is_shape_scalar(shape)


class Literal(ComputationBuildingBlock):
"""A representation of a literal in TFF's internal language."""

@@ -1196,13 +1132,13 @@ def __init__(
elif isinstance(value, str):
value = value.encode()

if not _is_compatible_dtype(value, type_signature.dtype.type):
if not array.is_compatible_dtype(value, type_signature.dtype.type):
raise ValueError(
f"Expected '{value}' to be compatible with"
f" '{type_signature.dtype.type}'."
)

if not _is_compatible_shape(value, type_signature.shape):
if not array.is_compatible_shape(value, type_signature.shape):
raise ValueError(
f"Expected '{value}' to be compatible with '{type_signature.shape}'."
)
@@ -1229,9 +1165,9 @@ def from_proto(cls, computation_proto: pb.Computation) -> 'Literal':
return cls(value, type_signature)

def _proto(self) -> pb.Computation:
type_pb = type_serialization.serialize_type(self._type_signature)
type_pb = type_serialization.serialize_type(self.type_signature)
value_pb = array.to_proto(
self._value, dtype_hint=self._type_signature.dtype.type
self._value, dtype_hint=self.type_signature.dtype.type
)
literal_pb = pb.Literal(value=value_pb)
return pb.Computation(type=type_pb, literal=literal_pb)
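
To show how the relocated helpers back the `Literal` validation above, a minimal standalone sketch (the `_validate_literal` helper is hypothetical and written here only for illustration; the expected outcomes assume the NumPy semantics this change targets):

import numpy as np
from tensorflow_federated.python.core.impl.compiler import array

def _validate_literal(value, dtype, shape):
  # Mirrors the checks in Literal.__init__: dtype compatibility first, then
  # shape compatibility, raising ValueError on the first mismatch.
  if not array.is_compatible_dtype(value, dtype):
    raise ValueError(f"Expected '{value}' to be compatible with '{dtype}'.")
  if not array.is_compatible_shape(value, shape):
    raise ValueError(f"Expected '{value}' to be compatible with '{shape}'.")

_validate_literal(np.array([[1, 2, 3]], np.int32), np.int32, [1, 3])  # passes
_validate_literal(np.iinfo(np.int64).max, np.int32, [])  # raises ValueError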
