diff --git a/mindtorch/_apis/npu.py b/mindtorch/_apis/npu.py
index fe5ce7064..d360d198a 100644
--- a/mindtorch/_apis/npu.py
+++ b/mindtorch/_apis/npu.py
@@ -217,13 +217,13 @@ def dense(input, weight, bias=None):
         Tensor: The result of the dense operation.
     """
     if ON_ORANGE_PI:
+        dtype = input.dtype
         input = cast(input, mindspore.float16)
         weight = cast(weight, mindspore.float16)
-        if bias is None:
-            return pyboost.dense_op(input, weight)
-
-        bias = cast(bias, mindspore.float16)
-        return add(pyboost.dense_op(input, weight), bias)
+        out = cast(pyboost.dense_op(input, weight), dtype)
+        if bias is not None:
+            out = add(out, bias)
+        return out
     if use_pyboost():
         return pyboost.dense_op(input, weight, bias)
@@ -820,9 +820,13 @@ def argmin(input, axis, keepdims):
 
 def bmm(input, other):
+    if ON_ORANGE_PI:
+        dtype = input.dtype
+        out = pyboost.bmm_ext_op(cast(input, mindspore.float16), cast(other, mindspore.float16))
+        return cast(out, dtype)
     if use_pyboost():
         return pyboost.bmm_ext_op(input, other)
-    return legacy.batch_mat_mul(input, other)
+    return legacy.batch_mat_mul(input, other, False, False)  # transpose_a=False, transpose_b=False
 
 def topk(input, k, dim, largest, sorted):
     if use_pyboost():
@@ -1198,14 +1202,47 @@ def roll(input, shifts, axis):
     return legacy.roll(input, shifts, axis)
 
 def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv1d_ext_op(input, weight, bias, stride, padding, dilation, groups)
-    return legacy.conv1d(input, weight, bias, pad, stride, dilation)
+    return conv1d_legacy(input, weight, bias, stride, padding, dilation, groups)
+
+def conv1d_legacy(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, tuple):
+        pad = (0, 0, padding[0], padding[0])
+    elif isinstance(padding, int):
+        pad = (0, 0) + (padding,) * 2
+    if not isinstance(padding, (int, tuple)):
+        # padding is a string mode such as 'same' or 'valid'
+        pad_mode = padding
+        pad = (0,) * 4
+
+    # Run the 1-D convolution as a 2-D convolution over a dummy height axis.
+    input = expand_dims(input, 2)
+    weight = expand_dims(weight, 2)
+
+    output = legacy.conv2_d(
+        input, weight,
+        weight.shape[0],                                                  # out_channel
+        (1, weight.shape[-1]),                                            # kernel_size
+        1,                                                                # mode=1
+        pad_mode,                                                         # pad_mode
+        pad,                                                              # pad
+        (1, stride) if isinstance(stride, int) else (1, *stride),         # stride
+        (1, dilation) if isinstance(dilation, int) else (1, *dilation),   # dilation
+        groups,                                                           # group
+        "NCHW",                                                           # data_format
+    )
+
+    if bias is not None:
+        output = legacy.bias_add(output, bias, "NCHW")
+
+    output = squeeze(output, 2)
+    return output
 
 def conv1d_padding(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv1d_padding_op(input, weight, bias, stride, padding, dilation, groups)
-    return legacy.conv1d(input, weight, bias, pad, stride, dilation)
+    return conv1d_legacy(input, weight, bias, stride, padding, dilation, groups)
 
 def square(input):
     if use_pyboost():
@@ -1233,14 +1270,14 @@ def split_with_size(input, size, dim=0):
     return legacy.split_with_size(input, size, dim)
 
 def softplus(input, beta=1, threshold=20):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.softplus_ext_op(input, beta, threshold)
-    return legacy.softplus(input, beta, threshold)
+    return legacy.softplus(input)
 
 def remainder_tensor_scalar(input, other):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.remainder_tensor_scalar_op(input, other)
-    out = input - floor_div(input, other) * other
+    out = sub(input, mul(floor_div(input, other), other), 1)  # input - floor_div(input, other) * other
     return out
 
 def baddbmm(input, batch1, batch2, alpha=1, beta=1):
@@ -1253,28 +1290,107 @@ def floor(input):
         return pyboost.floor_op(input)
     return legacy.floor(input)
 
+
+def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
+    """Calculate the width and height of output."""
+    length = 0
+    filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)  # effective (dilated) kernel size
+    if pad_mode == 'valid':
+        if filter_size - stride_size > 0:
+            length = filter_size - stride_size
+    elif pad_mode == 'pad':
+        length = -padding + filter_size - stride_size
+
+    return length
+
 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv_transpose2d_op(input, weight, bias, stride, padding, output_padding, groups, dilation)
-    return legacy.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, tuple):
+        pad = (padding[0], padding[0], padding[1], padding[1])
+    elif isinstance(padding, int):
+        pad = (padding,) * 4
+    if not isinstance(padding, (int, tuple)):
+        pad_mode = padding
+        pad = (0,) * 4
+
+    # Normalize int stride/dilation to 2-tuples before indexing them below.
+    if isinstance(stride, int):
+        stride = (stride, stride)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+
+    out_channels = weight.shape[1] * groups
+    kernel_size = weight.shape[2:]
+    n, _, h, w = input.shape
+    h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
+    w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])
+
+    out = legacy.conv2_d_transpose(
+        input, weight,
+        (n, out_channels, h * stride[0] + h_add, w * stride[1] + w_add),  # output shape
+        out_channels,
+        kernel_size,
+        pad_mode,
+        pad,
+        None,
+        1,            # mode=1
+        stride,
+        dilation,
+        groups,
+        'NCHW'
+    )
+    if bias is not None:
+        out = legacy.bias_add(out, bias, 'NCHW')
+    return out
+
+
 def relu(input):
     if use_pyboost():
         return pyboost.relu_op(input)
     return legacy.re_lu(input)
 
+def _check_maxpool_padding(padding, nd):
+    """Calculate maxpool padding before calling the primitive."""
+    if isinstance(padding, int):
+        return (0,) * (3 - nd) + (padding,) * nd
+    if isinstance(padding, (tuple, list)):
+        if len(padding) == 1:
+            return (0,) * (3 - nd) + tuple(padding * nd)
+        if len(padding) != nd:
+            raise ValueError(f"For max_pool2d, the length of padding must equal to {nd}, but got {len(padding)}.")
+        return (0,) * (3 - nd) + tuple(padding)
+    return padding
+
 def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
-
-    out, indices = legacy.max_pool_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
+    if not ON_ORANGE_PI:
+        out, indices = legacy.max_pool_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
+        if return_indices:
+            return out, indices
+        return out
+
+    # Fall back to 3-D pooling over a dummy depth axis of size 1.
+    if isinstance(kernel_size, tuple):
+        kernel_size = (1,) + kernel_size
+    elif isinstance(kernel_size, int):
+        kernel_size = (1, kernel_size, kernel_size)
+    if isinstance(stride, tuple):
+        stride = (1,) + stride
+    elif isinstance(stride, int):
+        stride = (1, stride, stride)
+    padding = _check_maxpool_padding(padding, 2)
+
+    input = expand_dims(input, 2)
+    out, indices = legacy.max_pool3_d_with_argmax(input, kernel_size, stride, padding,
+                                                  dilation, ceil_mode, 'NCDHW', mindspore.int64)
     if return_indices:
-        return out, indices
-    return out
+        return squeeze(out, 2), squeeze(indices, 2)
+    return squeeze(out, 2)
 
 def upsample_bilinear2d(input, size=None, scale_factor=None, align_corners=False):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.upsample_bilinear2d_op(input, size, scale_factor, align_corners)
-    return legacy.resize_bilinear_v2(input, size, scale_factor, align_corners)
+    return legacy.resize_bilinear_v2(input, size, align_corners, not align_corners)  # (align_corners, half_pixel_centers)
 
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
     if use_pyboost():
diff --git a/mindtorch/_op_prim/ascend/legacy.py b/mindtorch/_op_prim/ascend/legacy.py
index a63d92379..4c96f0144 100644
--- a/mindtorch/_op_prim/ascend/legacy.py
+++ b/mindtorch/_op_prim/ascend/legacy.py
@@ -2294,8 +2294,8 @@ def max_pool_with_argmax(*args):
 
 def max_pool_with_argmax_v2(*args):
-    op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-5:]).set_device('Ascend')
-    return op(*args[:-5])
+    op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-6:]).set_device('Ascend')
+    return op(*args[:-6])
 
 def max_unpool2_d(*args):
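
Notes on the fallback implementations above (illustrative sketches):

The ON_ORANGE_PI branches of dense() and bmm() now record the caller's dtype,
compute in float16, and cast the result back, so callers no longer silently
receive float16 outputs. A minimal NumPy sketch of the same round-trip pattern
(the function name here is illustrative, not a mindtorch API):

    import numpy as np

    def dense_fp16_roundtrip(x, w, b=None):
        dtype = x.dtype
        # Compute the matmul in float16, then restore the caller's dtype.
        out = (x.astype(np.float16) @ w.astype(np.float16).T).astype(dtype)
        if b is not None:
            out = out + b  # bias is applied in the original dtype
        return out

    x = np.random.randn(2, 4).astype(np.float32)
    w = np.random.randn(8, 4).astype(np.float32)
    assert dense_fp16_roundtrip(x, w).dtype == np.float32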
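
conv1d_legacy() emulates a 1-D convolution with a 2-D convolution by inserting
a dummy height axis of size 1 and using a (1, k) kernel. A small PyTorch sketch
of the same trick (PyTorch stands in for the MindSpore legacy primitives here):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 16)   # (N, C_in, W)
    w = torch.randn(8, 4, 3)    # (C_out, C_in, k)
    b = torch.randn(8)

    ref = F.conv1d(x, w, b, stride=2, padding=1)

    out = F.conv2d(
        x.unsqueeze(2),          # (N, C_in, 1, W)
        w.unsqueeze(2),          # (C_out, C_in, 1, k)
        b,
        stride=(1, 2),           # no movement along the dummy height axis
        padding=(0, 1),
    ).squeeze(2)

    torch.testing.assert_close(ref, out)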
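
The conv_transpose2d fallback computes the output spatial size itself because
legacy.conv2_d_transpose needs it up front. For pad_mode='pad',
_deconv_output_length reduces to d*(k - 1) + 1 - s - 2*p per spatial dimension,
which, added to H_in*s, matches the standard transposed-convolution shape
formula (H_in - 1)*s - 2*p + d*(k - 1) + 1 (output_padding aside, which this
fallback does not handle). A quick check against PyTorch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8, 8)
    w = torch.randn(4, 6, 3, 3)  # (C_in, C_out, kH, kW)
    ref = F.conv_transpose2d(x, w, stride=2, padding=1, dilation=1)

    k, s, d, p = 3, 2, 1, 1
    h_add = d * (k - 1) + 1 - s - 2 * p     # _deconv_output_length('pad', k, s, d, 2*p)
    assert ref.shape[-2] == 8 * s + h_add   # == (8 - 1)*s - 2*p + d*(k - 1) + 1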
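
Similarly, the max_pool2d fallback runs a 3-D max-pool over a dummy depth axis
of size 1 and squeezes it back out. The same idea in PyTorch (again as a
stand-in for legacy.max_pool3_d_with_argmax):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 14, 14)
    ref = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

    out = F.max_pool3d(
        x.unsqueeze(2),            # (N, C, 1, H, W)
        kernel_size=(1, 3, 3),     # never pool across the dummy depth axis
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    ).squeeze(2)

    torch.testing.assert_close(ref, out)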