Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions mindtorch/_apis/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def add(input, other, alpha=1.0): # pylint: disable=unused-argument
Returns:
Tensor: The result of the addition.
"""
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.add_ext_op(input, other, alpha)
if alpha == 1.0:
return legacy.add(input, other)
Expand Down Expand Up @@ -724,7 +724,7 @@ def less_equal(input, other):

def select(condition, input, other):
if ON_ORANGE_PI:
return add(mul(condition, input), mul(bitwise_not(condition), other))
return legacy.add(mul(condition, input), mul(bitwise_not(condition), other))
if use_pyboost():
return pyboost.select_op(condition, input, other)
return legacy.select(condition, input, other)
Expand Down Expand Up @@ -975,8 +975,23 @@ def inplace_zero(input):
return input

def mse_loss(input, target, reduction):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.mse_loss_ext_op(input, target, reduction)
x = square(input - target)
average_flag = True
reduce_flag = True
if reduction == 'sum':
average_flag = False
if reduction == 'none':
reduce_flag = False

if reduce_flag and average_flag:
x = mean(x, tuple(range(x.ndim)), False, None)

if reduce_flag and not average_flag:
x = sum(x, tuple(range(x.ndim)), False, None)

return x

def abs(input):
if use_pyboost():
Expand Down Expand Up @@ -1126,7 +1141,7 @@ def pow_scalar_tensor(input, scalar):
return legacy.pow(input, scalar)

def adaptive_avg_pool2d(input, output_size):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.adaptive_avg_pool2d_ext_op(input, output_size)
return legacy.adaptive_avg_pool2_d(input, output_size)

Expand Down Expand Up @@ -1362,6 +1377,21 @@ def _check_maxpool_padding(padding, nd):
return (0,) * (3 - nd) + tuple(padding)
return padding

def _cal_dilation(dilation, nd):
"""check the dilation"""
if isinstance(dilation, int):
return dilation
if isinstance(dilation, tuple):
if len(dilation) == 1:
return dilation[0]
if len(dilation) == nd:
return (3 - nd) * (1,) + dilation
if nd == 1:
raise ValueError(f"the length of 'dilation' must be 1, but got {len(dilation)}.")
raise ValueError(f"the length of 'dilation' must be 1 or {nd}, but got {len(dilation)}.")
raise ValueError(f"the 'dilation' must be int or tuple, but got {type(dilation)}.")


def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=False, return_indices=False):
# out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
if not ON_ORANGE_PI:
Expand All @@ -1379,6 +1409,7 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
elif isinstance(stride, int):
stride = (1, stride, stride)
padding = _check_maxpool_padding(padding, 2)
dilation = _cal_dilation(dilation, 2)

input = expand_dims(input, 2)
out, indices = legacy.max_pool3_d_with_argmax(input, kernel_size, stride, padding,
Expand Down Expand Up @@ -1550,9 +1581,9 @@ def outer(input, other):
return legacy.outer(input, other)

def addcmul(input, tensor1, tensor2, value=1.0):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.addcmul_op(input, tensor1, tensor2, value)
return legacy.addcmul(input, tensor1, tensor2, value)
return legacy.add(mul(mul(tensor1, tensor2), value), input)

def prelu(input, weight):
if use_pyboost():
Expand Down
2 changes: 2 additions & 0 deletions mindtorch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def __setitem__(self, slices, value):
def __add__(self, other):
# if 0 in self.shape:
# return self
if self.dtype == mindtorch.bool:
return ops.bitwise_or(self, other)
return ops.add(self, other)

def __iadd__(self, other):
Expand Down
4 changes: 2 additions & 2 deletions mindtorch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mindtorch._C import default_generator
from mindtorch.nn.modules.utils import _pair

from ..configs import ON_A2, ON_A1, FLASH_ATTN_MASK_VALID
from ..configs import ON_A2, ON_A1, ON_ORANGE_PI, FLASH_ATTN_MASK_VALID

generator_step_ = 12

Expand Down Expand Up @@ -991,7 +991,7 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
weight = mindtorch.ones([input.shape[1]], dtype=input.dtype, device=input.device)
if bias is None:
bias = mindtorch.zeros([input.shape[1]], dtype=input.dtype, device=input.device)
if input.device.type == 'npu':
if input.device.type == 'npu' and not ON_ORANGE_PI:
return execute('group_norm', input, num_groups, weight, bias, eps)[0]

input_shape = input.shape
Expand Down