diff --git a/mindtorch/_apis/npu.py b/mindtorch/_apis/npu.py
index d360d198a..30e242e4c 100644
--- a/mindtorch/_apis/npu.py
+++ b/mindtorch/_apis/npu.py
@@ -117,7 +117,7 @@ def add(input, other, alpha=1.0): # pylint: disable=unused-argument
     Returns:
         Tensor: The result of the addition.
     """
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.add_ext_op(input, other, alpha)
     if alpha == 1.0:
         return legacy.add(input, other)
@@ -724,7 +724,7 @@ def less_equal(input, other):
 
 def select(condition, input, other):
     if ON_ORANGE_PI:
-        return add(mul(condition, input), mul(bitwise_not(condition), other))
+        return legacy.add(mul(condition, input), mul(bitwise_not(condition), other))
     if use_pyboost():
         return pyboost.select_op(condition, input, other)
     return legacy.select(condition, input, other)
@@ -975,8 +975,23 @@ def inplace_zero(input):
     return input
 
 def mse_loss(input, target, reduction):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.mse_loss_ext_op(input, target, reduction)
+    x = square(input - target)
+    average_flag = True
+    reduce_flag = True
+    if reduction == 'sum':
+        average_flag = False
+    if reduction == 'none':
+        reduce_flag = False
+
+    if reduce_flag and average_flag:
+        x = mean(x, tuple(range(x.ndim)), False, None)
+
+    if reduce_flag and not average_flag:
+        x = sum(x, tuple(range(x.ndim)), False, None)
+
+    return x
 
 def abs(input):
     if use_pyboost():
@@ -1126,7 +1141,7 @@ def pow_scalar_tensor(input, scalar):
     return legacy.pow(input, scalar)
 
 def adaptive_avg_pool2d(input, output_size):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.adaptive_avg_pool2d_ext_op(input, output_size)
     return legacy.adaptive_avg_pool2_d(input, output_size)
 
@@ -1362,6 +1377,21 @@ def _check_maxpool_padding(padding, nd):
         return (0,) * (3 - nd) + tuple(padding)
     return padding
 
+def _cal_dilation(dilation, nd):
+    """check the dilation"""
+    if isinstance(dilation, int):
+        return dilation
+    if isinstance(dilation, tuple):
+        if len(dilation) == 1:
+            return dilation[0]
+        if len(dilation) == nd:
+            return (3 - nd) * (1,) + dilation
+        if nd == 1:
+            raise ValueError(f"the length of 'dilation' must be 1, but got {len(dilation)}.")
+        raise ValueError(f"the length of 'dilation' must be 1 or {nd}, but got {len(dilation)}.")
+    raise ValueError(f"the 'dilation' must be int or tuple, but got {type(dilation)}.")
+
+
 def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
     if not ON_ORANGE_PI:
@@ -1379,6 +1409,7 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
     elif isinstance(stride, int):
         stride = (1, stride, stride)
     padding = _check_maxpool_padding(padding, 2)
+    dilation = _cal_dilation(dilation, 2)
     input = expand_dims(input, 2)
     out, indices = legacy.max_pool3_d_with_argmax(input, kernel_size, stride, padding,
@@ -1550,9 +1581,9 @@ def outer(input, other):
     return legacy.outer(input, other)
 
 def addcmul(input, tensor1, tensor2, value=1.0):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.addcmul_op(input, tensor1, tensor2, value)
-    return legacy.addcmul(input, tensor1, tensor2, value)
+    return legacy.add(mul(mul(tensor1, tensor2), value), input)
 
 def prelu(input, weight):
     if use_pyboost():
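Note: on Orange Pi the patch skips the pyboost kernel for mse_loss and decomposes the loss by hand. As a sanity reference, the reduction semantics it implements ('none' keeps the elementwise result, 'sum' and 'mean' reduce over every dimension) can be sketched in plain NumPy; the function below is illustrative only and not part of the patch.

import numpy as np

def mse_loss_reference(input, target, reduction="mean"):
    # Elementwise squared error, mirroring square(input - target) in the fallback.
    x = np.square(np.asarray(input) - np.asarray(target))
    if reduction == "none":
        return x            # no reduction
    if reduction == "sum":
        return x.sum()      # reduce over all dimensions
    return x.mean()         # any other value falls through to the mean path, as in the fallback
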
diff --git a/mindtorch/_tensor.py b/mindtorch/_tensor.py
index b64692fbd..6e55977f1 100644
--- a/mindtorch/_tensor.py
+++ b/mindtorch/_tensor.py
@@ -287,6 +287,8 @@ def __setitem__(self, slices, value):
     def __add__(self, other):
         # if 0 in self.shape:
         #     return self
+        if self.dtype == mindtorch.bool:
+            return ops.bitwise_or(self, other)
         return ops.add(self, other)
 
     def __iadd__(self, other):
diff --git a/mindtorch/nn/functional.py b/mindtorch/nn/functional.py
index 59cd180fb..361fafd49 100644
--- a/mindtorch/nn/functional.py
+++ b/mindtorch/nn/functional.py
@@ -9,7 +9,7 @@
 from mindtorch._C import default_generator
 from mindtorch.nn.modules.utils import _pair
 
-from ..configs import ON_A2, ON_A1, FLASH_ATTN_MASK_VALID
+from ..configs import ON_A2, ON_A1, ON_ORANGE_PI, FLASH_ATTN_MASK_VALID
 
 generator_step_ = 12
 
@@ -991,7 +991,7 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
         weight = mindtorch.ones([input.shape[1]], dtype=input.dtype, device=input.device)
     if bias is None:
         bias = mindtorch.zeros([input.shape[1]], dtype=input.dtype, device=input.device)
-    if input.device.type == 'npu':
+    if input.device.type == 'npu' and not ON_ORANGE_PI:
         return execute('group_norm', input, num_groups, weight, bias, eps)[0]
     input_shape = input.shape
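Note: the bool branch added to Tensor.__add__ routes boolean addition through bitwise_or, which appears to follow the usual convention that adding boolean tensors behaves as an elementwise logical OR. A small NumPy illustration (not part of the patch):

import numpy as np

a = np.array([True, False, True])
b = np.array([False, False, True])
# Elementwise OR is what __add__ now returns for bool tensors on this path.
assert np.array_equal(a | b, np.logical_or(a, b))
print(a | b)  # [ True False  True]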