Skip to content

Commit

Permalink
Fix call sites invoking to_type and TensorType using a `tensorflo…
Browse files Browse the repository at this point in the history
…w` type.

PiperOrigin-RevId: 597676178
  • Loading branch information
eglanz authored and tensorflow-copybara committed Jan 11, 2024
1 parent 27f14f6 commit 1d01650
Showing 1 changed file with 4 additions and 24 deletions.
28 changes: 4 additions & 24 deletions tensorflow_federated/python/core/impl/types/computation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,12 @@ def _clear_intern_pool() -> None:
]


def _is_dtype_like(
obj: object,
) -> TypeGuard[Union[_DtypeLike, tf.dtypes.DType]]:
def _is_dtype_like(obj: object) -> TypeGuard[_DtypeLike]:
"""Returns `True` if `obj` is dtype like, otherwise `False`."""
if isinstance(obj, type) and issubclass(obj, np.generic):
return True
else:
return isinstance(obj, (tf.dtypes.DType, np.dtype))
return isinstance(obj, np.dtype)


_ALLOWED_NP_DTYPES = [
Expand Down Expand Up @@ -379,27 +377,18 @@ class TensorType(Type, metaclass=_Intern):
@classmethod
def _hashable_from_init_args(
cls,
dtype: Union[_DtypeLike, tf.dtypes.DType],
dtype: _DtypeLike,
shape: array_shape._ArrayShapeLike = (),
) -> Hashable:
"""Returns hashable `TensorType.__init__` args."""
# TODO: b/305743962 - This is only required to convert a `tf.dtypes.DType`
# to a `np.dtype`. It should be when `tf.dtypes.DType` can not be passed
# into the constructor of the `tff.TensorType`.
if isinstance(dtype, tf.dtypes.DType):
if dtype.base_dtype == tf.string:
dtype = np.str_
else:
dtype = dtype.base_dtype.as_numpy_dtype

dtype = _to_dtype(dtype)
if shape is not None:
shape = tuple(shape)
return (dtype, shape)

def __init__(
self,
dtype: Union[_DtypeLike, tf.dtypes.DType],
dtype: _DtypeLike,
shape: array_shape._ArrayShapeLike = (),
):
"""Constructs a new instance from the given `dtype` and `shape`.
Expand All @@ -411,15 +400,6 @@ def __init__(
Raises:
TypeError: if arguments are of the wrong types.
"""
# TODO: b/305743962 - This is only required to convert a `tf.dtypes.DType`
# to a `np.dtype`. It should be when `tf.dtypes.DType` can not be passed
# into the constructor of the `tff.TensorType`.
if isinstance(dtype, tf.dtypes.DType):
if dtype.base_dtype == tf.string:
dtype = np.str_
else:
dtype = dtype.base_dtype.as_numpy_dtype

self._dtype = _to_dtype(dtype)
if shape is not None:
shape = tuple(shape)
Expand Down

0 comments on commit 1d01650

Please sign in to comment.