diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index 4fd98f2ee..ea086e3c7 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -31,7 +31,10 @@ import mindspore from mindspore import context from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error - +try: + from mindspore._c_expression import disable_multi_thread +except: + disable_multi_thread = None # for different ascend devices if platform.system().lower() == 'linux': SOC = MSContext.get_instance().get_ascend_soc_version() @@ -41,6 +44,9 @@ if SOC in ('ascend910', 'ascend310b'): context.set_context(ascend_config={"precision_mode": "allow_mix_precision"}) + if SOC == 'ascend310b' and disable_multi_thread is not None: + disable_multi_thread() + # set mindnlp.core to torch from .utils.torch_proxy import initialize_torch_proxy, setup_metadata_patch from .utils.safetensors_patch import setup_safetensors_patch diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index 4409a7fa1..feadb4240 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -763,7 +763,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin if bias is None: bias = ops.zeros(input.shape[1]) - if use_pyboost(): + if use_pyboost() and not ON_ORANGE_PI: return mint.nn.functional.batch_norm( input, running_mean, @@ -787,7 +787,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin has_conv1d = hasattr(mint.nn.functional, 'conv1d') def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if use_pyboost() and has_conv1d: + if use_pyboost() and has_conv1d and not ON_ORANGE_PI: return mint.nn.functional.conv1d(input, weight, bias, stride, padding, dilation, groups) pad_mode = 'pad' pad = padding @@ -819,7 +819,7 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if use_pyboost(): + if use_pyboost() and not ON_ORANGE_PI: return mint.nn.functional.conv2d(input, weight, bias, stride, padding, dilation, groups) pad_mode = 'pad' @@ -829,7 +829,7 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): return ops.conv2d(input, weight, bias=bias, stride=stride, pad_mode=pad_mode, padding=padding, dilation=dilation, groups=groups) def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if use_pyboost(): + if use_pyboost() and not ON_ORANGE_PI: return mint.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups) pad_mode = 'pad' @@ -1014,7 +1014,7 @@ def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): - if use_pyboost(): + if use_pyboost() and not ON_ORANGE_PI: return mint.nn.functional.group_norm(input, num_groups, weight, bias, eps) input_shape = input.shape @@ -1027,7 +1027,11 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): affine_param_shape[1] = C affine_param_shape = tuple(affine_param_shape) if weight is not None and bias is not None: - out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + if not ON_ORANGE_PI: + out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + else: + out = core.addcmul(bias.view(affine_param_shape), out, weight.view(affine_param_shape), value=1) + elif weight is not None: out = out.mul(weight.view(affine_param_shape)) elif bias is not None: diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py index 748d4559d..feeac2919 100644 --- a/mindnlp/core/ops/creation.py +++ b/mindnlp/core/ops/creation.py @@ -150,7 +150,7 @@ def linspace(start, end, steps, *, dtype=None, **kwargs): start = start.item() if isinstance(start, mindspore.Tensor) else start end = end.item() if isinstance(end, mindspore.Tensor) else end steps = steps.item() if isinstance(steps, mindspore.Tensor) else steps - if use_pyboost() and has_linspace: + if use_pyboost() and has_linspace and not ON_ORANGE_PI: return mindspore.mint.linspace(start, end, steps, dtype=dtype) return ops.linspace(start, end, steps).to(dtype) diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py index 98651c991..d8c73626a 100644 --- a/mindnlp/core/ops/inplace.py +++ b/mindnlp/core/ops/inplace.py @@ -11,7 +11,7 @@ inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op, inplace_erfinv_op from mindnlp import core -from ..configs import use_pyboost +from ..configs import use_pyboost, ON_ORANGE_PI from ._inner import assign generator_step_ = 12 @@ -230,7 +230,7 @@ def inplace_clamp(self, min=None, max=None): return self def inplace_erfinv(self): - if self.device.type == 'npu': + if self.device.type == 'npu' and not ON_ORANGE_PI: inplace_erfinv_op(self) else: self.data = core.erfinv(self) diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index afce9a17e..fb5df41f2 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -60,7 +60,7 @@ def arccosh(input): def add(input, other, *, alpha=1, out=None): - if use_pyboost() and has_add: + if use_pyboost() and has_add and not ON_ORANGE_PI: return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out) if alpha != 1: other = mul(alpha, other) @@ -335,9 +335,44 @@ def erfc(input, *, out=None): def erfinv(input, *, out=None): if use_pyboost() and has_erfinv: + if ON_ORANGE_PI: + return erfinv_torch(input) return call_ms_func(mindspore.mint.erfinv, input, out=out) return call_ms_func(ops.erfinv, input, out=out) +def erfinv_torch(x): + """ + 使用有理函数近似实现erfinv,适用于PyTorch张量 + """ + # # 检查输入范围 + # if core.any((x < -1) | (x > 1)): + # raise ValueError("erfinv(x) is only defined for x in [-1, 1]") + + # 处理边界情况 + sign = core.where(x > 0, 1.0, -1.0) + x = core.abs(x) + + # Cody的有理函数近似 + mask = x <= 0.7 + x_sq = x * x + + # 对于x <= 0.7的情况 + p1 = 0.426170613044 + x_sq * (-0.304570194263 + x_sq * 0.152645863430) + q1 = 1.0 + x_sq * (-0.733058978416 + x_sq * 0.546875000000) + result1 = x * (p1 / q1) + + # 对于x > 0.7的情况 + t = core.sqrt(-core.log((1.0 - x)/2.0)) + p2 = -0.322232431088 + t * (-1.00002368368 + t * (-0.342242088547 + + t * (-0.0204231210245 + t * (-0.0000453642210148)))) + q2 = 0.460398842078 + t * (0.588581570495 + t * (0.531103462366 + + t * (0.103537752850 + t * 0.0038560700634))) + result2 = p2 / q2 + + # 合并结果 + result = core.where(mask, result1, result2) + + return sign * result # exp has_exp = hasattr(mindspore.mint, "exp")