Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mindnlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
import mindspore
from mindspore import context
from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error

# `disable_multi_thread` may be missing from the installed MindSpore build.
# Fall back to None so later code can test for availability before calling it.
# Only ImportError is caught: a bare `except:` would also swallow
# KeyboardInterrupt / SystemExit and hide unrelated failures.
try:
    from mindspore._c_expression import disable_multi_thread
except ImportError:
    disable_multi_thread = None
# for different ascend devices
if platform.system().lower() == 'linux':
SOC = MSContext.get_instance().get_ascend_soc_version()
Expand All @@ -41,6 +44,9 @@
if SOC in ('ascend910', 'ascend310b'):
context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})

if SOC == 'ascend310b' and disable_multi_thread is not None:
disable_multi_thread()

# set mindnlp.core to torch
from .utils.torch_proxy import initialize_torch_proxy, setup_metadata_patch
from .utils.safetensors_patch import setup_safetensors_patch
Expand Down
16 changes: 10 additions & 6 deletions mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin
if bias is None:
bias = ops.zeros(input.shape[1])

if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.batch_norm(
input,
running_mean,
Expand All @@ -787,7 +787,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin

has_conv1d = hasattr(mint.nn.functional, 'conv1d')
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if use_pyboost() and has_conv1d:
if use_pyboost() and has_conv1d and not ON_ORANGE_PI:
return mint.nn.functional.conv1d(input, weight, bias, stride, padding, dilation, groups)
pad_mode = 'pad'
pad = padding
Expand Down Expand Up @@ -819,7 +819,7 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):


def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.conv2d(input, weight, bias, stride, padding, dilation, groups)

pad_mode = 'pad'
Expand All @@ -829,7 +829,7 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
return ops.conv2d(input, weight, bias=bias, stride=stride, pad_mode=pad_mode, padding=padding, dilation=dilation, groups=groups)

def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups)

pad_mode = 'pad'
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode


def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.group_norm(input, num_groups, weight, bias, eps)

input_shape = input.shape
Expand All @@ -1027,7 +1027,11 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
affine_param_shape[1] = C
affine_param_shape = tuple(affine_param_shape)
if weight is not None and bias is not None:
out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1)
if not ON_ORANGE_PI:
out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1)
else:
out = core.addcmul(bias.view(affine_param_shape), out, weight.view(affine_param_shape), value=1)

elif weight is not None:
out = out.mul(weight.view(affine_param_shape))
elif bias is not None:
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/core/ops/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def linspace(start, end, steps, *, dtype=None, **kwargs):
start = start.item() if isinstance(start, mindspore.Tensor) else start
end = end.item() if isinstance(end, mindspore.Tensor) else end
steps = steps.item() if isinstance(steps, mindspore.Tensor) else steps
if use_pyboost() and has_linspace:
if use_pyboost() and has_linspace and not ON_ORANGE_PI:
return mindspore.mint.linspace(start, end, steps, dtype=dtype)
return ops.linspace(start, end, steps).to(dtype)

Expand Down
4 changes: 2 additions & 2 deletions mindnlp/core/ops/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op, inplace_erfinv_op

from mindnlp import core
from ..configs import use_pyboost
from ..configs import use_pyboost, ON_ORANGE_PI
from ._inner import assign

generator_step_ = 12
Expand Down Expand Up @@ -230,7 +230,7 @@ def inplace_clamp(self, min=None, max=None):
return self

def inplace_erfinv(self):
if self.device.type == 'npu':
if self.device.type == 'npu' and not ON_ORANGE_PI:
inplace_erfinv_op(self)
else:
self.data = core.erfinv(self)
Expand Down
37 changes: 36 additions & 1 deletion mindnlp/core/ops/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def arccosh(input):


def add(input, other, *, alpha=1, out=None):
if use_pyboost() and has_add:
if use_pyboost() and has_add and not ON_ORANGE_PI:
return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out)
if alpha != 1:
other = mul(alpha, other)
Expand Down Expand Up @@ -335,9 +335,44 @@ def erfc(input, *, out=None):

def erfinv(input, *, out=None):
    """Compute the inverse error function of ``input`` element-wise.

    Dispatch order:
      1. If pyboost is disabled or ``mindspore.mint`` lacks ``erfinv``,
         use ``ops.erfinv`` through ``call_ms_func``.
      2. On Orange Pi, use the pure-arithmetic ``erfinv_torch`` fallback.
         Note that ``out`` is not forwarded on this path.
      3. Otherwise use ``mindspore.mint.erfinv`` through ``call_ms_func``.
    """
    pyboost_available = use_pyboost() and has_erfinv
    if not pyboost_available:
        return call_ms_func(ops.erfinv, input, out=out)

    if ON_ORANGE_PI:
        # The native kernel is bypassed on this platform.
        return erfinv_torch(input)

    return call_ms_func(mindspore.mint.erfinv, input, out=out)

def erfinv_torch(x):
    """Approximate the inverse error function element-wise.

    Uses only ``core.where``, ``core.abs``, ``core.log`` and ``core.sqrt``
    plus basic arithmetic, so it can serve as a fallback where a native
    ``erfinv`` kernel is not usable.

    The implementation follows the single-precision polynomial approximation
    from M. Giles, "Approximating the erfinv function": with
    ``w = -log(1 - x^2)``, one polynomial in ``w - 2.5`` covers ``w < 5``
    and another in ``sqrt(w) - 3`` covers the tails; the result is
    ``polynomial * x``. Accuracy is roughly float32 level.

    Args:
        x: Tensor of values in ``[-1, 1]``.

    Returns:
        Tensor with ``erfinv(x)``. ``x == 1`` / ``x == -1`` map to
        ``+inf`` / ``-inf``. Inputs outside ``[-1, 1]`` are not valid; the
        logarithm of a negative number propagates to the result (NaN under
        IEEE semantics).
    """
    # Polynomial coefficients, highest degree first (Horner evaluation).
    central_coeffs = (
        2.81022636e-08,
        3.43273939e-07,
        -3.5233877e-06,
        -4.39150654e-06,
        0.00021858087,
        -0.00125372503,
        -0.00417768164,
        0.246640727,
        1.50140941,
    )
    tail_coeffs = (
        -0.000200214257,
        0.000100950558,
        0.00134934322,
        -0.00367342844,
        0.00573950773,
        -0.0076224613,
        0.00943887047,
        1.00167406,
        2.83297682,
    )

    abs_x = core.abs(x)
    sign = core.where(x > 0, 1.0, -1.0)

    # w = -log(1 - x^2); the factored form limits cancellation near |x| = 1.
    w = -core.log((1.0 - x) * (1.0 + x))
    is_central = w < 5.0

    # Central region: polynomial in (w - 2.5).
    s = w - 2.5
    p_central = central_coeffs[0]
    for coeff in central_coeffs[1:]:
        p_central = p_central * s + coeff

    # Tail region: polynomial in (sqrt(w) - 3).
    t = core.sqrt(w) - 3.0
    p_tail = tail_coeffs[0]
    for coeff in tail_coeffs[1:]:
        p_tail = p_tail * t + coeff

    # Multiplying by x (not |x|) makes the result odd without extra sign logic.
    result = core.where(is_central, p_central, p_tail) * x

    # At |x| == 1, w is infinite and the tail polynomial diverges with the
    # wrong sign; substitute the exact limit instead.
    return core.where(abs_x == 1.0, sign * float("inf"), result)

# exp
has_exp = hasattr(mindspore.mint, "exp")
Expand Down
Loading