Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mindnlp/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ._bind import get_default_dtype, set_default_dtype, get_default_device
from .amp import autocast, GradScaler
from .func import vmap
from .configs import set_pyboost

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
mode=1,
pad_mode=pad_mode,
pad=pad,
stride=stride,
stride=tuple(stride),
dilation=dilation,
group=groups)
output = conv3d_op(input, weight)
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/core/ops/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def sort(input, *, dim=-1, descending=False, stable=False):
# topk
has_topk = hasattr(mindspore.mint, 'topk')
def topk(input, k, dim=-1, largest=True, sorted=True):
if use_pyboost() and has_topk:
if use_pyboost() and has_topk and not ON_ORANGE_PI:
out = mindspore.mint.topk(input, int(k), dim, largest, sorted)
else:
out = ops.topk(input, k, dim, largest, sorted)
Expand Down
4 changes: 2 additions & 2 deletions mindnlp/core/ops/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,9 @@ def erfc(input, *, out=None):


def erfinv(input, *, out=None):
if ON_ORANGE_PI:
return erfinv_torch(input)
if use_pyboost() and has_erfinv:
if ON_ORANGE_PI:
return erfinv_torch(input)
return call_ms_func(mindspore.mint.erfinv, input, out=out)
return call_ms_func(ops.erfinv, input, out=out)

Expand Down
Loading