Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions mindtorch/_apis/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,10 +909,11 @@ def narrow(input, dim, start, length):
return legacy.slice(input, begin, size)

def std(input, dim, correction, keepdim):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.std_op(input, dim, correction, keepdim)
return legacy.reduce_std(input, dim, keepdim)

if dim is None:
dim = ()
return legacy.reduce_std(input, dim, bool(correction), keepdim)[0]

def log(input):
if use_pyboost():
Expand Down Expand Up @@ -1083,9 +1084,9 @@ def sin(input):
return legacy.sin(input)

def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, training=False, momentum=0.1, epsilon=1e-5):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.batch_norm_ext_op(input, weight, bias, running_mean, runnning_var, training, momentum, epsilon)
return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, momentum, epsilon, 'NHWC')
return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')

def silu(input):
if use_pyboost():
Expand Down Expand Up @@ -1448,9 +1449,11 @@ def reciprocal(input):
return legacy.reciprocal(input)

def index_add_ext(input, dim, index, source, alpha):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.index_add_ext_op(input, dim, index, source, alpha)
return legacy.index_add(input, dim, index, source, alpha)
if alpha != 1:
source = mul(alpha, source)
return legacy.index_add(input, cast(index, mindspore.int32), source, dim, True, True)

def polar(abs, angle):
if use_pyboost():
Expand Down
6 changes: 1 addition & 5 deletions mindtorch/nn/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,11 +1347,7 @@ def forward(
check_other=False,
)

is_fastpath_enabled = mindtorch.backends.mha.get_fastpath_enabled()

if not is_fastpath_enabled:
why_not_fast_path = "mindtorch.backends.mha.get_fastpath_enabled() was not True"
elif not is_batched:
if not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
Expand Down
4 changes: 2 additions & 2 deletions mindtorch/ops/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mindtorch._C import default_generator
from mindtorch.executor import execute
from .._bind import get_default_dtype, get_device_in_context
from ..configs import ON_A1
from ..configs import ON_A1, ON_ORANGE_PI

generator_step_ = 12

Expand All @@ -26,7 +26,7 @@ def multinomial(input, num_samples, replacement=False, *, generator=None, out=No
num_samples = num_samples.item()
if generator is None:
generator = default_generator
if input.device.type == 'npu':
if input.device.type == 'npu' and not ON_ORANGE_PI:
output = execute("multinomial", input, num_samples, replacement, generator)

else:
Expand Down