diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 00757723b..27c62eac8 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -6,7 +6,7 @@ from mindspore.common.tensor import _TensorMeta
 from mindspore._c_expression.typing import Type
 
 try:
-    from mindspore.common._stub_tensor import StubTensor
+    from mindspore.common._stub_tensor import StubTensor, _stub_method
 except:
     class StubTensor: pass
 
@@ -17,7 +17,7 @@ class StubTensor: pass
 
 from . import ops, _dtype
 from ._dtype import dtype2np
-from ._bind import get_default_device, device_
+from ._bind import get_default_device, device_, get_default_dtype
 from .configs import use_pyboost, ON_A1
 from .storage import UntypedStorage
 from ._utils import _rebuild_tensor_v2
@@ -98,6 +98,16 @@ def is_tensor(x):
     return isinstance(x, Tensor)
 
 def enable_mindspore_patch():
+    old_init = Tensor.__init__
+    def __init__(self, *args, **kwargs):
+        if len(args) > 1 and all([isinstance(arg, int) for arg in args]):
+            tensor = Tensor_(shape=args, dtype=get_default_dtype())
+            old_init(self, tensor, internal=True)
+        else:
+            old_init(self, *args, **kwargs)
+
+    Tensor.__init__ = __init__
+
     def __reduce_ex__(self, protocol):
         if isinstance(self, StubTensor):
             data = Tensor_(self.stub_sync())
@@ -280,6 +290,8 @@ def __setitem__(self, slices, value):
        #         s = list(s)
        #     new_slices += (s,)
        #     slices = new_slices
+        if not isinstance(value, Tensor):
+            value = tensor(value, dtype=self.dtype)
         return origin_setitem(self, slices, value)
 
     Tensor.__setitem__ = __setitem__
@@ -469,6 +481,36 @@ def pin_memory(self, *args, **kwargs):
     Tensor.pin_memory = pin_memory
     StubTensor.pin_memory = pin_memory
 
+    def __deepcopy__(self, memodict):
+        new_obj = Tensor(self)
+        return new_obj
+
+    Tensor.__deepcopy__ = __deepcopy__
+    StubTensor.__deepcopy__ = __deepcopy__
+
+    def asnumpy(self):
+        return Tensor_.asnumpy(self)
+
+    Tensor.asnumpy = asnumpy
+    StubTensor.asnumpy = _stub_method(asnumpy)
+
+    def backward(self, *args, **kwargs):
+        pass
+
+    Tensor.backward = backward
+    StubTensor.backward = backward
+
+    def __repr__(self):
+        Tensor_.data_sync(self, True)
+        return Tensor_.__repr__(self)
+
+    Tensor.__repr__ = __repr__
+    StubTensor.__repr__ = _stub_method(__repr__)
+
+
+    def detach_(self):
+        return ops.stop_gradient(self)
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/backends/__init__.py b/mindnlp/core/backends/__init__.py
index 0cf8794fa..706464362 100644
--- a/mindnlp/core/backends/__init__.py
+++ b/mindnlp/core/backends/__init__.py
@@ -1 +1 @@
-from . import cuda, mps
+from . import cuda, mps, cudnn
diff --git a/mindnlp/core/backends/cudnn/__init__.py b/mindnlp/core/backends/cudnn/__init__.py
new file mode 100644
index 000000000..ca131a876
--- /dev/null
+++ b/mindnlp/core/backends/cudnn/__init__.py
@@ -0,0 +1,15 @@
+from contextlib import contextmanager
+
+@contextmanager
+def flags(
+    enabled=False,
+    benchmark=False,
+    benchmark_limit=10,
+    deterministic=False,
+    allow_tf32=True,
+    fp32_precision="none",
+):
+    try:
+        yield
+    finally:
+        pass
\ No newline at end of file
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index dda7fc233..fa66b06af 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -6,6 +6,16 @@ import mindspore
 from mindspore import ops, mint
 from mindspore.ops._primitive_cache import _get_cache_prim
+from mindspore.ops.auto_generate import (reflection_pad_1d_op, reflection_pad_2d_op, add_layernorm_v2_op,
+                                         reflection_pad_3d_op,  # pylint: disable=W0611
+                                         replication_pad_1d_op, replication_pad_2d_op, replication_pad_3d_op,
+                                         constant_pad_nd_op, dropout_ext_op, reverse_v2_impl, avg_pool2d_op,
+                                         upsample_nearest1d_op, upsample_nearest2d_op, upsample_nearest3d_op,
+                                         upsample_linear1d_op, upsample_bilinear2d_op, upsample_bicubic2d_op,
+                                         upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op,
+                                         masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d)
+
+
 from mindnlp import core
 
 from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
@@ -243,7 +253,11 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, sca
         return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq)
     return ops.gather(weight, input, 0)
 
-def rms_norm(input, normalized_shape, weight, eps=1e-5):
+def rms_norm(input, normalized_shape, weight, eps=None):
+    if eps is None:
+        eps = core.finfo(input.dtype).eps
+    if weight is None:
+        weight = core.ones(normalized_shape)
     return ops.rms_norm(input, weight, eps)[0]
 
 def fast_gelu(x):
@@ -463,7 +477,161 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
     return _layer_norm(input, weight, bias)[0]
 
 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
-    return ops.interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
+    if mode in ("nearest", "area", "nearest-exact"):
+        if align_corners is not None:
+            raise ValueError(
+                "align_corners option can only be set with the "
+                "interpolating modes: linear | bilinear | bicubic | trilinear"
+            )
+    else:
+        if align_corners is None:
+            align_corners = False
+
+    dim = input.dim() - 2  # Number of spatial dimensions.
+
+    # Process size and scale_factor. Validate that exactly one is set.
+    # Validate its length if it is a list, or expand it if it is a scalar.
+    # After this block, exactly one of output_size and scale_factors will
+    # be non-None, and it will be a list (or tuple).
+    if size is not None and scale_factor is not None:
+        raise ValueError("only one of size or scale_factor should be defined")
+    elif size is not None:
+        assert scale_factor is None
+        scale_factors = None
+        if isinstance(size, (list, tuple)):
+            if len(size) != dim:
+                raise ValueError(
+                    "Input and output must have the same number of spatial dimensions, but got "
+                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
" + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." + ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode in ["area", "bilinear"] and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. + assert scale_factors is not None + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + ( + math.floor( + float(input.size(i + 2) * scale_factors[i]) + ) + ) + for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError( + "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" + ) + + if input.dim() == 3 and mode == "nearest": + return upsample_nearest1d_op(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return upsample_nearest2d_op(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return upsample_nearest3d_op(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return upsample_linear1d_op( + input, output_size, scale_factors, align_corners + ) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) + return upsample_bilinear2d_op( + input, output_size, scale_factors, align_corners + ) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is 
+        return upsample_trilinear3d_impl(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 4 and mode == "bicubic":
+        assert align_corners is not None
+        if antialias:
+            return torch._C._nn._upsample_bicubic2d_aa(
+                input, output_size, align_corners, scale_factors
+            )
+        return upsample_bicubic2d_op(
+            input, output_size, scale_factors, align_corners
+        )
+
+    if input.dim() == 3 and mode == "bilinear":
+        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+    if input.dim() == 3 and mode == "trilinear":
+        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+    if input.dim() == 4 and mode == "linear":
+        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
+    if input.dim() == 4 and mode == "trilinear":
+        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+    if input.dim() == 5 and mode == "linear":
+        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
+    if input.dim() == 5 and mode == "bilinear":
+        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
+
+    raise NotImplementedError(
+        "Input Error: Only 3D, 4D and 5D input Tensors supported"
+        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
+        f" (got {mode})"
+    )
 
 def normalize(input, p=2.0, dim=1, eps=1e-6):
     r"""
@@ -599,8 +767,24 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")
 
 def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
-
+    x_2d = input.unsqueeze(2)  # (batch, in_channels, 1, L_in)
+
+    # 2. Add a height dimension to the kernel as well.
+    weight_2d = weight.unsqueeze(2)  # (in_channels, out_channels, 1, kernel_size)
+
+    # 3. Run a 2-D transposed convolution.
+    output_2d = conv_transpose2d(
+        x_2d,
+        weight_2d,
+        bias,
+        stride=(1,) + stride,
+        padding=(0,) + padding,
+        output_padding=(0,) + output_padding,
+        dilation=(1,) + dilation, groups=groups
+    )  # output shape: (batch, out_channels, 1, L_out)
+
+    # 4. Remove the height dimension to restore the 1-D output.
+    return output_2d.squeeze(2)
 
 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
     return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
@@ -1221,7 +1405,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return ops.fold(input, output_size, kernel_size, dilation, padding, stride)
 
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    ctc_loss_op = _get_cache_prim(nn_ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    if targets.ndim == 1:
+        targets = targets.unsqueeze(-1)
     loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
     if zero_infinity:
         loss = ops.where(ops.isinf(loss), 0., loss)
diff --git a/mindnlp/core/nn/modules/__init__.py b/mindnlp/core/nn/modules/__init__.py
index 6a37fb9a4..862358c19 100644
--- a/mindnlp/core/nn/modules/__init__.py
+++ b/mindnlp/core/nn/modules/__init__.py
@@ -3,7 +3,7 @@ from .container import ModuleList, ParameterList, Sequential, ParameterDict, ModuleDict
 from .linear import Linear, Identity
 from .sparse import Embedding
-from .normalization import LayerNorm, GroupNorm
+from .normalization import LayerNorm, GroupNorm, RMSNorm
 from .dropout import Dropout, Dropout2d
 from .activation import *
 from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d
diff --git a/mindnlp/core/nn/modules/batchnorm.py b/mindnlp/core/nn/modules/batchnorm.py
index 17b88d945..92fa50e8b 100644
--- a/mindnlp/core/nn/modules/batchnorm.py
+++ b/mindnlp/core/nn/modules/batchnorm.py
@@ -45,8 +45,14 @@ def __init__(
             self.register_buffer('running_var', ops.ones(num_features,))
             self.running_mean: Optional[Tensor]
             self.running_var: Optional[Tensor]
-            self.register_buffer('num_batches_tracked',
-                                 Tensor(0, dtype=core.int64))
+            self.register_buffer(
+                "num_batches_tracked",
+                core.tensor(
+                    0,
+                    dtype=core.long,
+                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
+                ),
+            )
             self.num_batches_tracked: Optional[Tensor]
         else:
             self.register_buffer("running_mean", None)
diff --git a/mindnlp/core/nn/modules/normalization.py b/mindnlp/core/nn/modules/normalization.py
index 3fcc64418..ed3928f9d 100644
--- a/mindnlp/core/nn/modules/normalization.py
+++ b/mindnlp/core/nn/modules/normalization.py
@@ -1,8 +1,9 @@
 """normalization"""
+from typing import Optional
 import numbers
 from ..parameter import Parameter
 from .module import Module
-from ..functional import group_norm, layer_norm
+from .. import functional as F
 from .. import init
 from ... import ops
@@ -90,7 +91,7 @@ def reset_parameters(self) -> None:
         init.zeros_(self.bias)
 
     def forward(self, input):
-        return layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
+        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
 
     def extra_repr(self):
         return '{normalized_shape}, eps={eps}, ' \
@@ -155,7 +156,7 @@ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, dtype=None):
         self.reset_parameters()
 
     def forward(self, input):
-        return group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
+        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
 
 
     def reset_parameters(self) -> None:
@@ -166,3 +167,94 @@ def reset_parameters(self) -> None:
     def extra_repr(self):
         return '{num_groups}, {num_channels}, eps={eps}, ' \
             'affine={affine}'.format(**self.__dict__)
+
+class RMSNorm(Module):
+    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
+
+    This layer implements the operation as described in
+    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
+
+    .. math::
+        y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad
+        \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}
+
+    The RMS is taken over the last ``D`` dimensions, where ``D``
+    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
+    is ``(3, 5)`` (a 2-dimensional shape), the RMS is computed over
+    the last 2 dimensions of the input.
+
+    Args:
+        normalized_shape (int or list or torch.Size): input shape from an expected input
+            of size
+
+            .. math::
+                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
+                    \times \ldots \times \text{normalized\_shape}[-1]]
+
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
+        elementwise_affine: a boolean value that when set to ``True``, this module
+            has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
+
+    Shape:
+        - Input: :math:`(N, *)`
+        - Output: :math:`(N, *)` (same shape as input)
+
+    Examples::
+
+        >>> rms_norm = nn.RMSNorm([2, 3])
+        >>> input = torch.randn(2, 2, 3)
+        >>> rms_norm(input)
+
+    """
+    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
+    normalized_shape: tuple[int, ...]
+    eps: Optional[float]
+    elementwise_affine: bool
+
+    def __init__(
+        self,
+        normalized_shape,
+        eps: Optional[float] = None,
+        elementwise_affine: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = Parameter(
+                ops.empty(self.normalized_shape, **factory_kwargs)
+            )
+        else:
+            self.register_parameter("weight", None)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """
+        Resets parameters based on their initialization used in __init__.
+        """
+        if self.elementwise_affine:
+            init.ones_(self.weight)
+
+    def forward(self, x):
+        """
+        Runs forward pass.
+ """ + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + + def extra_repr(self) -> str: + """ + Extra information about the module. + """ + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) diff --git a/mindnlp/core/nn/parameter.py b/mindnlp/core/nn/parameter.py index 438098fcd..1b11b3b0e 100644 --- a/mindnlp/core/nn/parameter.py +++ b/mindnlp/core/nn/parameter.py @@ -8,6 +8,7 @@ class Parameter(Tensor): grad = None requires_grad = False + _grad_fn = None def __init__(self, input_data=None, requires_grad=True, **kwargs): super().__init__(input_data) @@ -17,8 +18,6 @@ def __init__(self, input_data=None, requires_grad=True, **kwargs): self.param_info.parameter_shape = self._shape self.param_info.requires_grad = requires_grad self._requires_grad = requires_grad - if self._requires_grad: - self.retain_grad() def __deepcopy__(self, memodict): new_obj = Parameter(self) diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index 4acdf95da..2e981fc43 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -235,6 +235,24 @@ def split(tensor, split_size_or_sections, dim=0): return mindspore.mint.split(tensor, split_size_or_sections, dim) return ops.split(tensor, split_size_or_sections, dim) +def split_with_sizes(input, split_sizes, dim=0): + assert input.dim() != 0, "split expects at least a 1-dimensional tensor" + dim_size = input.size(dim) + num_splits = len(split_sizes) + start_idx = 0 + + splits = [] + for i in range(num_splits): + length = split_sizes[i] + assert length >= 0, f"split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes={split_sizes}" + splits.append( + narrow(input, dim, start_idx, length) + ) + start_idx += length + + return splits + + # squeeze has_squeeze = hasattr(mindspore.mint, 'squeeze') def squeeze(input, *dim, **kwargs): @@ -769,6 +787,7 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m 'scatter_nd_update', 'scatter_update', 'split', + 'split_with_sizes', 'squeeze', 'stack', 'swapaxes', diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py index f32346b91..17ec23bab 100644 --- a/mindnlp/core/ops/creation.py +++ b/mindnlp/core/ops/creation.py @@ -52,8 +52,8 @@ def zeros(*size, dtype=None, device=None, requires_grad=False, **kwargs): if isinstance(size[0], (tuple, list)): size = size[0] if use_pyboost() and has_zeros: - if device == 'cpu': - return mindspore.Tensor(np.zeros(size), dtype=dtype) + # if device == 'cpu': + # return mindspore.Tensor(np.zeros(size), dtype=dtype) return mindspore.mint.zeros(size, dtype=dtype) size = tuple(size) return _zeros(size, dtype) @@ -126,7 +126,7 @@ def range(start=0, end=None, step=1, dtype=None): # linspace has_linspace = hasattr(mindspore.mint, 'linspace') -def linspace(start, end, steps, *, dtype=None): +def linspace(start, end, steps, *, dtype=None, **kwargs): if dtype is None: dtype = mindspore.float32 if use_pyboost() and has_linspace: diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index 72e7b0549..aee4c6eb3 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -7,6 +7,7 @@ from mindspore.common.initializer import initializer from mindspore.ops._primitive_cache import _get_cache_prim +from mindnlp import core from ..configs import use_pyboost, ON_ORANGE_PI, ON_A1 from .reduction import any from .comparison import eq @@ -868,7 +869,7 @@ def masked_fill(input, mask, value): if has_masked_fill: 
         return mindspore.mint.masked_fill(input, mask, value)
     masked_fill_ = _get_cache_prim(ops.MaskedFill)()
-    return masked_fill_(input, mask, mindspore.tensor(value, dtype=input.dtype))
+    return masked_fill_(input, mask, core.tensor(value, dtype=input.dtype))
 
 
 class finfo:
diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py
index 4b09277df..9452b9d32 100644
--- a/mindnlp/core/ops/random.py
+++ b/mindnlp/core/ops/random.py
@@ -54,7 +54,7 @@ def multinomial(input, num_samples, replacement=False, *, generator=None):
 has_normal = hasattr(mindspore.mint, 'normal')
 def normal(mean=0.0, std=1.0, size=None, *, generator=None, out=None):
     if use_pyboost() and has_normal:
-        return call_ms_func(mindspore.mint.normal, mean, std, size, generator, out=out)
+        return call_ms_func(mindspore.mint.normal, float(mean), float(std), size, generator, out=out)
     if size is None:
         if isinstance(mean, mindspore.Tensor):
             size = mean.shape
diff --git a/mindnlp/core/ops/tensor.py b/mindnlp/core/ops/tensor.py
index f204a22e7..0b44f0150 100644
--- a/mindnlp/core/ops/tensor.py
+++ b/mindnlp/core/ops/tensor.py
@@ -2,6 +2,8 @@ import mindspore
 from mindspore._c_expression import typing # pylint: disable=no-name-in-module, import-error
 
+from mindnlp import core
+
 def is_floating_point(input):
     return isinstance(input.dtype, typing.Float)
 
@@ -12,6 +14,6 @@ def numel(input):
     return input.numel()
 
 def as_tensor(data, dtype=None, **kwarg):
-    return mindspore.Tensor(data, dtype)
+    return core.tensor(data, dtype=dtype)
 
 __all__ = ['as_tensor', 'is_floating_point', 'is_tensor', 'numel']
\ No newline at end of file