diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py index 0f526e731..73aa3c973 100644 --- a/mindnlp/core/_dtype.py +++ b/mindnlp/core/_dtype.py @@ -22,15 +22,15 @@ dtype = Type @property -def is_floating_point(self): +def _is_floating_point(self): return isinstance(self, (typing.Float, typing.BFloat)) @property -def is_complex(self): +def _is_complex(self): return isinstance(self, typing.Complex) -Type.is_floating_point = is_floating_point -Type.is_complex = is_complex +Type.is_floating_point = _is_floating_point +Type.is_complex = _is_complex Type.__str__ = Type.__repr__ diff --git a/mindnlp/core/ops/tensor.py b/mindnlp/core/ops/tensor.py index 0b44f0150..eefcf3f16 100644 --- a/mindnlp/core/ops/tensor.py +++ b/mindnlp/core/ops/tensor.py @@ -16,4 +16,7 @@ def numel(input): def as_tensor(data, dtype=None, **kwarg): return core.tensor(data, dtype=dtype) -__all__ = ['as_tensor', 'is_floating_point', 'is_tensor', 'numel'] \ No newline at end of file +def is_complex(input): + return input.dtype.is_complex + +__all__ = ['as_tensor', 'is_floating_point', 'is_tensor', 'numel', 'is_complex'] \ No newline at end of file