diff --git a/mindnlp/core/_C/__init__.py b/mindnlp/core/_C/__init__.py
index 079ddaece..5ffd108e1 100644
--- a/mindnlp/core/_C/__init__.py
+++ b/mindnlp/core/_C/__init__.py
@@ -1,6 +1,7 @@
 from mindspore import default_generator, Generator as msGenerator
 
 from . import _nn
+from ..types import device as device_
 
 def _jit_set_profiling_executor(mode):
     pass
@@ -31,3 +32,10 @@ def _debug_set_autodiff_subgraph_inlining(mode):
 class Generator(msGenerator):
     def __init__(self, device='cpu'):
         super().__init__()
+        self._device = device_(device) if isinstance(device, str) else device
+
+    @property
+    def device(self):
+        return self._device
+
+class Tag: pass
diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index 3b9ae5178..2702835a8 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -14,7 +14,7 @@
 # ============================================================================
 """core module"""
 import os
-import platform
+import math
 from typing import (
     Any as _Any,
     Callable as _Callable,
@@ -26,8 +26,11 @@
     Union as _Union,
 )
 
+import mindspore
 from mindspore.runtime import Stream
+from mindspore.common.api import _pynative_executor
 
+pi = math.pi
 strided = None
 contiguous_format = None
 preserve_format = None
@@ -105,4 +108,13 @@ def get_autocast_gpu_dtype():
 
 def is_autocast_enabled():
     return True
+def use_deterministic_algorithms(mode, *, warn_only=False):
+    mindspore.set_context(deterministic='ON' if mode else 'OFF')
+
+def is_grad_enabled():
+    return _pynative_executor.enable_grad()
+
+def set_grad_enabled(enable_grad):
+    return _pynative_executor.set_enable_grad(enable_grad)
+
 __version__ = 'test_version_no_value'
\ No newline at end of file
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 096a0032f..68b6639ff 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -137,6 +137,9 @@ def to_(self, *args, **kwargs):
     else:
         dtype_to = kwargs.get("dtype", None)
         if dtype_to is not None:
+            if ON_A1 and dtype_to == _dtype.bfloat16:
+                warnings.warn('910A does not support bfloat16, casting to float16 instead.')
+                return mindspore.ops.cast(self, _dtype.float16)
             return mindspore.ops.cast(self, dtype_to)
     return self
 
@@ -408,7 +411,7 @@ def view(self, *args):
     Tensor.view = view
     StubTensor.view = view
 
-    def cpu(self):
+    def cpu(self, *args, **kwargs):
         return self
 
     Tensor.cpu = cpu
@@ -627,6 +630,15 @@ def __contains__(self, item):
     Tensor.as_strided = ops.as_strided
     StubTensor.as_strided = ops.as_strided
 
+    Tensor.split = ops.split
+    StubTensor.split = ops.split
+
+    Tensor.flip = ops.flip
+    StubTensor.flip = ops.flip
+
+    Tensor.unflatten = ops.unflatten
+    StubTensor.unflatten = ops.unflatten
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
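[Reviewer note] A minimal usage sketch (not part of the patch) for the grad-toggle helpers added to core/__init__.py above; it assumes the torch-style calling convention and that the pynative executor flag is process-global:

    from mindnlp import core

    prev = core.is_grad_enabled()     # query the executor's grad flag
    core.set_grad_enabled(False)      # disable gradient recording
    # ... run inference-only code ...
    core.set_grad_enabled(prev)       # restore the previous setting
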
diff --git a/mindnlp/core/backends/cuda/__init__.py b/mindnlp/core/backends/cuda/__init__.py
index e69de29bb..841444400 100644
--- a/mindnlp/core/backends/cuda/__init__.py
+++ b/mindnlp/core/backends/cuda/__init__.py
@@ -0,0 +1,29 @@
+class cuBLASModule:
+    # def __getattr__(self, name):
+    #     if name == "allow_tf32":
+    #         return torch._C._get_cublas_allow_tf32()
+    #     elif name == "allow_fp16_reduced_precision_reduction":
+    #         return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+    #     elif name == "allow_bf16_reduced_precision_reduction":
+    #         return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+    #     elif name == "allow_fp16_accumulation":
+    #         return torch._C._get_cublas_allow_fp16_accumulation()
+    #     elif name == "fp32_precision":
+    #         return torch._C._get_fp32_precision_getter("cuda", "matmul")
+    #     raise AttributeError("Unknown attribute " + name)
+
+    # def __setattr__(self, name, value):
+    #     if name == "allow_tf32":
+    #         return torch._C._set_cublas_allow_tf32(value)
+    #     elif name == "allow_fp16_reduced_precision_reduction":
+    #         return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+    #     elif name == "allow_bf16_reduced_precision_reduction":
+    #         return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
+    #     elif name == "allow_fp16_accumulation":
+    #         return torch._C._set_cublas_allow_fp16_accumulation(value)
+    #     elif name == "fp32_precision":
+    #         return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
+    #     raise AttributeError("Unknown attribute " + name)
+    pass
+
+matmul = cuBLASModule()
\ No newline at end of file
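[Reviewer note] With the torch._C bindings commented out, cuBLASModule is a plain object: flags such as allow_tf32 are stored as ordinary Python attributes and are not forwarded to any backend. A sketch of the assumed observable behaviour:

    from mindnlp.core.backends import cuda

    cuda.matmul.allow_tf32 = True     # accepted, but currently a no-op
    print(cuda.matmul.allow_tf32)     # True (just a Python attribute)
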
diff --git a/mindnlp/core/library.py b/mindnlp/core/library.py
index 3d9b81195..152b39001 100644
--- a/mindnlp/core/library.py
+++ b/mindnlp/core/library.py
@@ -1,4 +1,140 @@
+from typing import Any, Callable, Literal, Optional, overload, Union
+from collections.abc import Iterable, Sequence
+
+from mindnlp.core import _C
+
+device_types_t = Optional[Union[str, Sequence[str]]]
+
 def register_fake(*args, **kwargs):
     def register(func):
         return func
-    return register
\ No newline at end of file
+    return register
+
+def custom_op(
+    name: str,
+    fn: Optional[Callable] = None,
+    /,
+    *,
+    mutates_args: Union[str, Iterable[str]],
+    device_types: device_types_t = None,
+    schema: Optional[str] = None,
+    tags: Optional[Sequence[_C.Tag]] = None,
+) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]:
+    """Wraps a function into custom operator.
+
+    Reasons why you may want to create a custom op include:
+    - Wrapping a third-party library or custom kernel to work with PyTorch
+      subsystems like Autograd.
+    - Preventing torch.compile/export/FX tracing from peeking inside your function.
+
+    This API is used as a decorator around a function (please see examples).
+    The provided function must have type hints; these are needed to interface
+    with PyTorch's various subsystems.
+
+    Args:
+        name (str): A name for the custom op that looks like "{namespace}::{name}",
+            e.g. "mylib::my_linear". The name is used as the op's stable identifier
+            in PyTorch subsystems (e.g. torch.export, FX graphs).
+            To avoid name collisions, please use your project name as the namespace;
+            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
+        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
+            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
+            it pessimistically assumes that all inputs to the operator are being mutated.
+        device_types (None | str | Sequence[str]): The device type(s) the function
+            is valid for. If no device type is provided, then the function
+            is used as the default implementation for all device types.
+            Examples: "cpu", "cuda".
+            When registering a device-specific implementation for an operator that accepts no Tensors,
+            we require the operator to have a "device: torch.device argument".
+        schema (None | str): A schema string for the operator. If None
+            (recommended) we'll infer a schema for the operator from its type
+            annotations. We recommend letting us infer a schema unless you
+            have a specific reason not to.
+            Example: "(Tensor x, int y) -> (Tensor, Tensor)".
+
+    .. note::
+        We recommend not passing in a ``schema`` arg and instead letting us infer
+        it from the type annotations. It is error-prone to write your own schema.
+        You may wish to provide your own schema if our interpretation of
+        the type annotation is not what you want.
+        For more info on how to write a schema string, see
+        `here `_
+
+    Examples::
+        >>> import torch
+        >>> from torch import Tensor
+        >>> from torch.library import custom_op
+        >>> import numpy as np
+        >>>
+        >>> @custom_op("mylib::numpy_sin", mutates_args=())
+        >>> def numpy_sin(x: Tensor) -> Tensor:
+        >>>     x_np = x.cpu().numpy()
+        >>>     y_np = np.sin(x_np)
+        >>>     return torch.from_numpy(y_np).to(device=x.device)
+        >>>
+        >>> x = torch.randn(3)
+        >>> y = numpy_sin(x)
+        >>> assert torch.allclose(y, x.sin())
+        >>>
+        >>> # Example of a custom op that only works for one device type.
+        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
+        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
+        >>>     x_np = x.numpy()
+        >>>     y_np = np.sin(x_np)
+        >>>     return torch.from_numpy(y_np)
+        >>>
+        >>> x = torch.randn(3)
+        >>> y = numpy_sin_cpu(x)
+        >>> assert torch.allclose(y, x.sin())
+        >>>
+        >>> # Example of a custom op that mutates an input
+        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
+        >>> def numpy_sin_inplace(x: Tensor) -> None:
+        >>>     x_np = x.numpy()
+        >>>     np.sin(x_np, out=x_np)
+        >>>
+        >>> x = torch.randn(3)
+        >>> expected = x.sin()
+        >>> numpy_sin_inplace(x)
+        >>> assert torch.allclose(x, expected)
+        >>>
+        >>> # Example of a factory function
+        >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
+        >>> def bar(device: torch.device) -> Tensor:
+        >>>     return torch.ones(3)
+        >>>
+        >>> bar("cpu")
+
+    """
+
+    def inner(fn: Callable[..., object]):
+        import torch
+
+        if schema is None:
+            # schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
+            schema_str = None
+        else:
+            schema_str = schema
+
+        namespace, opname = name.split("::")
+        # result = CustomOpDef(namespace, opname, schema_str, fn, tags)
+        # if schema is not None:
+        #     # Check that schema's alias annotations match those of `mutates_args`.
+        #     expected = set()
+        #     for arg in result._opoverload._schema.arguments:
+        #         if arg.alias_info is not None and arg.alias_info.is_write:
+        #             expected.add(arg.name)
+        #     if expected != set(mutates_args):
+        #         raise ValueError(
+        #             f"Attempted to create a custom op with `mutates_args={mutates_args}` "
+        #             f"and `schema={schema}. The schema suggests that the op mutates {expected}"
+        #             f"which is different from what was provided to us in `mutates_args`. "
+        #             f"Please make these consistent."
+        #         )
+        # result.register_kernel(device_types)(fn)
+        # return result
+        return None
+
+    if fn is None:
+        return inner
+    return inner(fn)
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index ece112046..9dfc8c8c9 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -15,6 +15,8 @@
     upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op,
     masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d)
+from mindspore.ops.auto_generate.pyboost_inner_prim import nllloss_impl
+
 
 from mindnlp import core
 
 
@@ -304,14 +306,16 @@ def custom_circular_pad(x, pad):
 
     return x
 
-def pad(input, pad, mode='constant', value=0.0):
+def pad(input, pad, mode='constant', value=None):
     if sum(pad) == 0:
         return input
     if isinstance(pad, tuple):
         pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
     if use_pyboost() and not ON_A1:
         return mint.nn.functional.pad(input, pad, mode, value)
-    if mode == 'reflect':
+    if mode in ['reflect', 'replicate']:
+        if mode == 'reflect' and input.ndim > 4:
+            return reflection_pad_3d_op(input, pad)
         return ops.pad(input, pad, mode)
     if mode == 'circular':
         return custom_circular_pad(input, pad)
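[Reviewer note] A small worked example (not part of the patch) of the two modes now routed through ops.pad above. For a 1-D row x = [1, 2, 3] padded by (2, 2):

    # replicate -> [1, 1, 1, 2, 3, 3, 3]   (edge values repeated)
    # reflect   -> [3, 2, 1, 2, 3, 2, 1]   (mirrored, edge not repeated)
    y = pad(x.view(1, 1, 3), (2, 2), mode='replicate')   # assumed torch-style call
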
@@ -329,105 +333,126 @@
         return ops.pad(input, new_pad, mode, value).to(mindspore.bool_)
     return ops.pad(input, new_pad, mode, value)
 
-def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
-    return _inner_nll_loss(input, target, weight, ignore_index, reduction, label_smoothing)
-
-def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
-    input = input.to(core.float32)
-    class_dim = 0 if input.ndim == 1 else 1
-    if target.dtype in [core.float32, core.float16]:
-        return _cross_entropy(input, target, class_dim, weight, reduction, label_smoothing)
-    return nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing)
+def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'):
+    return _nllloss_nd(input, target, weight, ignore_index, reduction)
 
-def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
-    """cross entropy inner function"""
-    class_dim = 0 if inputs.ndim == 1 else 1
-    n_classes = inputs.shape[class_dim]
-    inputs = log_softmax(inputs, class_dim)
-    if label_smoothing > 0.0:
-        target = target * (1 - label_smoothing) + label_smoothing / n_classes
-
+def _nllloss_nd(input, target, weight=None, ignore_index=-100, reduction='mean'):
+    """nllloss_nd inner function"""
+    input_dim = input.ndim
+    class_dim = 0 if input_dim == 1 else 1
+    n_classes = input.shape[class_dim]
     if weight is None:
-        weight = core.ones_like(inputs)
-    elif inputs.ndim != 1:
-        broadcast_shape = [1 for _ in range(inputs.ndim)]
-        broadcast_shape[1] = weight.shape[0]
-        weight = weight.reshape(broadcast_shape)
-
-    if reduction == 'mean':
-        return -(inputs * target * weight).sum() / (inputs.size / n_classes)
-    if reduction == 'sum':
-        return -(inputs * target * weight).sum()
-    return -(inputs * target * weight).sum(class_dim)
-
-
-def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
-    ndim = inputs.ndim
-    if ndim == 2:
-        ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
-    elif ndim == 4:
-        ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
-    elif ndim == 1:
-        ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing)
+        weight = ones(n_classes, input.dtype)
+    if input_dim < 1:
+        raise ValueError(f"input dim should be at least 1, but got {input_dim}")
+    if input_dim != 1 and input.shape[0] != target.shape[0]:
+        raise ValueError(f"input batch_size should be equal to target batch_size, but got {input.shape[0]} and "
+                         f"{target.shape[0]}")
+    if input_dim == 1 or input_dim == 2:
+        return nllloss_impl(input, target, weight, reduction, ignore_index)[0]
+    if input_dim == 4:
+        return nllloss_2d_op(input, target, weight, reduction, ignore_index)[0]
+    # input_dim == 3 or input_dim > 4
+    n = input.shape[0]
+    c = input.shape[1]
+    out_size = (n,) + input.shape[2:]
+    if input.numel() > 0:
+        input = input.view((n, c, 1, -1))
     else:
-        n = inputs.shape[0]
-        c = inputs.shape[1]
-        out_size = (n,) + inputs.shape[2:]
-        inputs = inputs.view((n, c, 1, -1))
+        input = input.view((n, c, 0, 0))
+    if target.numel() > 0:
         target = target.view((n, 1, -1))
-        if reduction != 'none':
-            ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
-        else:
-            ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing)
-            ret = ret.view(out_size)
-    return ret
-
-
-def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
-    """nll loss inner function"""
-    if target.ndim == inputs.ndim - 1:
-        target = target.unsqueeze(target_dim)
-    if ignore_index is not None:
-        non_pad_mask = core.eq(target, ignore_index)
-        target = target.masked_fill(non_pad_mask, 0)
-    else:
-        non_pad_mask = target
-    if weight is not None:
-        loss_weights = core.gather(weight, 0, target)
-        orig_shape = inputs.shape
-        if inputs.ndim != 2:
-            inputs = inputs.view(orig_shape[:2] + (-1,))
-            weight = weight.view(weight.shape + (1,))
-            weighted_inputs = inputs * weight
-        weighted_inputs = weighted_inputs.view(orig_shape)
-        loss = core.neg(core.gather(weighted_inputs, target_dim, target))
-        smooth_loss = core.neg(weighted_inputs.sum(dim=target_dim, keepdim=True))
     else:
-        loss = core.neg(core.gather(inputs, target_dim, target))
-        smooth_loss = core.neg(inputs.sum(dim=target_dim, keepdim=True))
-        loss_weights = core.ones_like(loss)
-
-    if ignore_index is not None:
-        loss = loss.masked_fill(non_pad_mask, 0.)
-        loss_weights = loss_weights.masked_fill(non_pad_mask, 0.)
-        smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)
-
-    loss = loss.squeeze(target_dim)
-    smooth_loss = smooth_loss.squeeze(target_dim)
+        target = target.view((n, 0, 0))
+    if reduction != 'none':
+        return nllloss_2d_op(input, target, weight, reduction, ignore_index)[0]
+    ret = nllloss_2d_op(input, target, weight, reduction, ignore_index)[0]
+    return ret.view(out_size)
 
-    if reduction == 'sum':
-        loss = loss.sum()
-        smooth_loss = smooth_loss.sum()
-    if reduction == 'mean':
-        loss = loss.sum() / loss_weights.sum()
-        smooth_loss = smooth_loss.sum() / loss_weights.sum()
-
-    eps_i = label_smoothing / inputs.shape[target_dim]
-    if label_smoothing != 0:
-        loss = (1. - label_smoothing) * loss + eps_i * smooth_loss
+def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
+    if label_smoothing < 0.0 or label_smoothing > 1.0:
+        raise ValueError("For cross_entropy, label_smoothing must be in [0, 1].")
+    if input.ndim == 0 or input.shape[0] == 0:
+        raise ValueError("For cross_entropy, input must not be 0-dimensional or have an empty first dimension.")
+    class_dim = 0 if input.ndim == 1 else 1
+    n_classes = input.shape[class_dim]
+    input = log_softmax(input, class_dim, dtype=input.dtype)
+    # for probabilities
+    target_dtype = target.dtype
+    if target_dtype in [mindspore.float32, mindspore.float16, mindspore.bfloat16]:
+        return _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim,
+                                                n_classes)
+    # for class indices
+    return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing,
+                                            class_dim, n_classes)
+
+def _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes):
+    """cross_entropy inner function for class probabilities"""
+    if input.shape != target.shape:
+        raise ValueError("For cross_entropy with probability targets, input shape should be equal to target shape.")
+    if label_smoothing > 0.0:
+        target = target * (1 - label_smoothing) + label_smoothing / n_classes
+    loss = input * target
+    if weight is not None:
+        weight_ = weight
+        ori_shape = loss.shape
+        if input.ndim > 2:
+            loss = loss.view(ori_shape[:2] + (-1,))
+            weight_ = weight_.view(1, -1, 1)
+        loss = loss * weight_
+        loss = loss.view(ori_shape)
+    if reduction == "mean":
+        return -mint.div(loss.sum(), (input.size / n_classes))
+    if reduction == "sum":
+        return -loss.sum()
+    if reduction == "none":
+        return -loss.sum(class_dim)
+    raise ValueError(f"reduction value {reduction} is not valid.")
+
+
+def _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing, class_dim,
+                                     n_classes):
+    """cross_entropy inner function for class indices"""
+    nllloss = _nllloss_nd(input, target, weight, ignore_index, reduction)
+    if label_smoothing > 0.0:
+        if weight is not None:
+            weight_ = weight
+            input_ = input
+            ori_shape = input.shape
+            if input.ndim > 2:
+                input_ = input.view(ori_shape[:2] + (-1,))
+                weight_ = weight_.view(1, -1, 1)
+            loss = input_ * weight_
+            loss = loss.view(ori_shape)
+            smooth_loss = -loss.sum(class_dim)
+        else:
+            smooth_loss = -input.sum(class_dim)
+        ignore_mask = ops.eq(target, ignore_index)
+        smooth_loss = masked_fill_op(smooth_loss, ignore_mask, 0)
+        if reduction == "mean":
+            true_mask = ~ignore_mask
+            if weight is not None:
+                weight_sum = mint.gather(weight, 0, masked_select(target, true_mask)).sum()
+                if weight_sum == 0:
+                    ret = smooth_loss.sum()
+                else:
+                    ret = smooth_loss.sum() / weight_sum
+            else:
+                weight_sum = true_mask.sum()
+                if weight_sum == 0:
+                    ret = smooth_loss.sum()
+                else:
+                    ret = smooth_loss.sum() / weight_sum
+        elif reduction == "sum":
+            ret = smooth_loss.sum()
+        elif reduction == "none":
+            ret = smooth_loss
+        else:
+            raise ValueError(f"reduction value {reduction} is not valid.")
+        return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes)
+    return nllloss
-
-    return loss
 
 def mse_loss(input, target, reduction='mean'):
     return ops.mse_loss(input, target, reduction)
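[Reviewer note] A reference formula (not part of the patch) for the class-index path above. With smoothing factor eps, C classes and log-probabilities log_p, the intended result is

    loss = (1 - eps) * nll_loss(log_p, target)
           + (eps / C) * sum_over_classes(-log_p)   # ignore_index positions masked out

which follows torch.nn.functional.cross_entropy's definition of label_smoothing.
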
@@ -462,12 +487,12 @@ def softmax(input, dim=-1, *, dtype=None):
     return softmax_(input)
 
 def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
+    if use_pyboost():
+        return mint.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps)
     if weight is None:
         weight = ops.ones(normalized_shape, dtype=input.dtype)
     if bias is None:
         bias = ops.zeros(normalized_shape, dtype=input.dtype)
-    if use_pyboost():
-        return mint.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps)
     if weight is not None:
         begin_axis = input.ndim - weight.ndim
     else:
@@ -779,15 +804,19 @@ def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_paddi
         stride=(1,) + stride,
         padding=(0,) + padding,
         output_padding=(0,) + output_padding,
+        groups=groups,
         dilation=(1,) + dilation
     )
     # output shape: (batch, out_channels, 1, L_out)
-    # 4. Remove the height dimension to restore the 1-D output
     return output_2d.squeeze(2)
 
 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
     return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
 
+def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+    return mint.nn.functional.conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+
+
 def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     if use_pyboost():
         return mint.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices)
diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py
index 7b90980d3..628f13a33 100644
--- a/mindnlp/core/nn/modules/module.py
+++ b/mindnlp/core/nn/modules/module.py
@@ -1185,6 +1185,10 @@ def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
     def npu(self: T, device: Optional[Union[int, device]] = None) -> T:
         return self._apply(lambda t: t.npu(device))
 
+    def cpu(self: T, device: Optional[Union[int, device]] = None) -> T:
+        return self._apply(lambda t: t.cpu(device))
+
+
     def _load_from_state_dict(
         self,
         state_dict,
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index 98b8febc7..7ff0bddd7 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -148,6 +148,8 @@ def clone(input):
 # cummin
 
 # cumprod
+def cumprod(input, dim, *, dtype=None, out=None):
+    return ops.cumprod(input, dim, dtype=dtype)
 
 # cumsum
 has_cumsum = hasattr(mindspore.mint, "cumsum")
@@ -664,6 +666,8 @@ def flatten(input, start_dim=0, end_dim=-1):
 
 
 def flip(input, dims):
+    if not isinstance(dims, (list, tuple)):
+        dims = (dims,)
     if use_pyboost() and has_flip:
         return mindspore.mint.flip(input, dims)
     return ops.flip(input, dims)
@@ -849,7 +853,9 @@ def triu(input, diagonal=0, *, out=None):
 
 # unflatten
 def unflatten(x, dim, sizes):
-    new_shape = x.shape[:dim] + sizes
+    if dim < 0:
+        dim = dim + x.ndim
+    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1:]
     return ops.reshape(x, new_shape)
 
 
@@ -1023,6 +1027,7 @@ def unfold(input, dimension, size, step):
     "clone",
     "contains",
     "cumsum",
+    "cumprod",
     "diag",
     "dim_list_to_bitset",
     "einsum",
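[Reviewer note] A quick shape check (not part of the patch) for the unflatten fix in ops/other.py above, assuming torch-style semantics:

    x = core.ones((2, 6, 5))
    y = unflatten(x, 1, (2, 3))   # y.shape == (2, 2, 3, 5)
    # the previous one-liner dropped every dimension after `dim`
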
diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py
index 530df82d4..501826639 100644
--- a/mindnlp/core/ops/pointwise.py
+++ b/mindnlp/core/ops/pointwise.py
@@ -1,9 +1,11 @@
 """pointwise op"""
 import mindspore
 from mindspore import ops
-from ..configs import use_pyboost
+from ..configs import use_pyboost, ON_A1
 from ._inner import call_ms_func
 
+from mindnlp import core
+
 # abs
 has_abs = hasattr(mindspore.mint, 'abs')
 def abs(input, *, out=None):
@@ -431,9 +433,38 @@ def mvlgamma(input, p):
 # nan_to_num
 has_nan_to_num = hasattr(mindspore.mint, 'nan_to_num')
 def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
-    if use_pyboost() and has_nan_to_num:
+    if use_pyboost() and has_nan_to_num and not ON_A1:
         return call_ms_func(mindspore.mint.nan_to_num, input, nan, posinf, neginf, out=out)
-    return call_ms_func(ops.nan_to_num, input, nan, posinf, neginf, out=out)
+
+    # Fallback path: operate on a copy of the input tensor
+    output = input.clone()
+    dtype = output.dtype
+    # Pick the default replacement values for +/-inf from the dtype's range
+    if output.is_floating_point():
+        f_info = core.finfo(dtype)
+        default_posinf = f_info.max if posinf is None else posinf
+        default_neginf = f_info.min if neginf is None else neginf
+    else:
+        # Integer tensors cannot hold NaN/inf; use the given values or the dtype limits
+        default_posinf = core.iinfo(dtype).max if posinf is None else posinf
+        default_neginf = core.iinfo(dtype).min if neginf is None else neginf
+
+    # Replace NaN
+    if core.isnan(output).any():
+        output = core.where(core.isnan(output), core.tensor(nan, dtype=output.dtype, device=output.device), output)
+
+    # Replace positive infinity
+    if core.isinf(output).any() and (posinf is not None or output.is_floating_point()):
+        output = core.where((output == float('inf')) & core.isinf(output),
+                            core.tensor(default_posinf, dtype=output.dtype, device=output.device), output)
+
+    # Replace negative infinity
+    if core.isinf(output).any() and (neginf is not None or output.is_floating_point()):
+        output = core.where((output == float('-inf')) & core.isinf(output),
+                            core.tensor(default_neginf, dtype=output.dtype, device=output.device), output)
+
+    return output
 
 # neg
 has_neg = hasattr(mindspore.mint, 'neg')
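[Reviewer note] Expected behaviour of the nan_to_num fallback above (not part of the patch), assuming torch-compatible semantics for a float32 input:

    x = core.tensor([float('nan'), float('inf'), -float('inf'), 1.0])
    nan_to_num(x)                        # [0.0, finfo(float32).max, finfo(float32).min, 1.0]
    nan_to_num(x, nan=2.0, posinf=9.0)   # [2.0, 9.0, finfo(float32).min, 1.0]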