160 changes: 138 additions & 22 deletions mindtorch/_apis/npu.py
@@ -217,13 +217,13 @@ def dense(input, weight, bias=None):
         Tensor: The result of the dense operation.
     """
     if ON_ORANGE_PI:
+        dtype = input.dtype
         input = cast(input, mindspore.float16)
         weight = cast(weight, mindspore.float16)
-        if bias is None:
-            return pyboost.dense_op(input, weight)
-
-        bias = cast(bias, mindspore.float16)
-        return add(pyboost.dense_op(input, weight), bias)
+        out = cast(pyboost.dense_op(input, weight), dtype)
+        if bias is not None:
+            out = add(out, bias)
+        return out
 
     if use_pyboost():
         return pyboost.dense_op(input, weight, bias)
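Review note: the Orange Pi fast path still runs the matmul in float16, but it now records the caller's dtype up front, casts the product back, and only then adds the bias, instead of returning a float16 result. The same cast/compute/cast-back shape is applied to `bmm` below. A minimal NumPy sketch of the pattern (illustrative only, not the PR's code):

```python
import numpy as np

def dense_fp16_roundtrip(x, w, b=None):
    dtype = x.dtype                    # remember the caller's dtype
    out = (x.astype(np.float16) @ w.astype(np.float16).T).astype(dtype)
    if b is not None:
        out = out + b                  # bias applied after casting back
    return out

x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(16, 8).astype(np.float32)
assert dense_fp16_roundtrip(x, w).dtype == np.float32
```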
@@ -820,9 +820,13 @@ def argmin(input, axis, keepdims):


 def bmm(input, other):
+    if ON_ORANGE_PI:
+        dtype = input.dtype
+        out = pyboost.bmm_ext_op(cast(input, mindspore.float16), cast(other, mindspore.float16))
+        return cast(out, dtype)
     if use_pyboost():
         return pyboost.bmm_ext_op(input, other)
-    return legacy.batch_mat_mul(input, other)
+    return legacy.batch_mat_mul(input, other, False, False)
 
 def topk(input, k, dim, largest, sorted):
     if use_pyboost():
@@ -1198,14 +1202,47 @@ def roll(input, shifts, axis):
     return legacy.roll(input, shifts, axis)
 
 def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv1d_ext_op(input, weight, bias, stride, padding, dilation, groups)
-    return legacy.conv1d(input, weight, bias, pad, stride, dilation)
+    return conv1d_legacy(input, weight, bias, stride, padding, dilation, groups)
+
+def conv1d_legacy(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, tuple):
+        pad = (0, 0, padding[0], padding[0])
+    elif isinstance(padding, int):
+        pad = (0, 0) + (padding,) * 2
+    if not isinstance(padding, (int, tuple)):
+        pad_mode = padding
+        pad = (0,) * 4
+
+    input = expand_dims(input, 2)
+    weight = expand_dims(weight, 2)
+
+    output = legacy.conv2_d(
+        input, weight,
+        weight.shape[0],                 # out_channel
+        (1, weight.shape[-1]),           # kernel_size
+        1,                               # mode=1
+        pad_mode,                        # pad_mode=pad_mode
+        pad,                             # pad=pad
+        (1, stride) if isinstance(stride, int) else (1, *stride),        # stride
+        (1, dilation) if isinstance(dilation, int) else (1, *dilation),  # dilation
+        groups,                          # group=groups
+        "NCHW",                          # data_format="NCHW"
+    )
+
+    if bias is not None:
+        output = legacy.bias_add(output, bias, "NCHW")
+
+    output = squeeze(output, 2)
+    return output
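The new `conv1d_legacy` lowers the 1-D convolution to a 2-D one: input and weight gain a dummy height axis via `expand_dims(..., 2)`, the kernel becomes `(1, K)`, stride/dilation get a leading 1, padding a leading `(0, 0)`, and the result is squeezed back. A NumPy sketch of why the lowering is exact (naive reference implementations, not the PR's code):

```python
import numpy as np

def conv1d_naive(x, w, stride=1, padding=0):
    # x: (N, C_in, L), w: (C_out, C_in, K)
    n, c_in, l = x.shape
    c_out, _, k = w.shape
    x = np.pad(x, ((0, 0), (0, 0), (padding, padding)))
    l_out = (l + 2 * padding - k) // stride + 1
    out = np.zeros((n, c_out, l_out))
    for i in range(l_out):
        seg = x[:, :, i * stride:i * stride + k]        # (N, C_in, K)
        out[:, :, i] = np.einsum('nck,ock->no', seg, w)
    return out

def conv2d_naive(x, w, stride=(1, 1), padding=(0, 0)):
    # x: (N, C_in, H, W), w: (C_out, C_in, KH, KW)
    n, c_in, h, wd = x.shape
    c_out, _, kh, kw = w.shape
    x = np.pad(x, ((0, 0), (0, 0), (padding[0],) * 2, (padding[1],) * 2))
    h_out = (h + 2 * padding[0] - kh) // stride[0] + 1
    w_out = (wd + 2 * padding[1] - kw) // stride[1] + 1
    out = np.zeros((n, c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            seg = x[:, :, i * stride[0]:i * stride[0] + kh,
                    j * stride[1]:j * stride[1] + kw]
            out[:, :, i, j] = np.einsum('nchw,ochw->no', seg, w)
    return out

x = np.random.randn(2, 3, 16)
w = np.random.randn(4, 3, 5)
ref = conv1d_naive(x, w, stride=2, padding=1)
# dummy height axis of size 1; the 2-D conv never moves along it
out = conv2d_naive(x[:, :, None, :], w[:, :, None, :], stride=(1, 2), padding=(0, 1))
assert np.allclose(ref, out.squeeze(2))
```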

 def conv1d_padding(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv1d_padding_op(input, weight, bias, stride, padding, dilation, groups)
-    return legacy.conv1d(input, weight, bias, pad, stride, dilation)
+    return conv1d_legacy(input, weight, bias, stride, padding, dilation, groups)
 
 def square(input):
     if use_pyboost():
@@ -1233,14 +1270,14 @@ def split_with_size(input, size, dim=0):
     return legacy.split_with_size(input, size, dim)
 
 def softplus(input, beta=1, threshold=20):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.softplus_ext_op(input, beta, threshold)
-    return legacy.softplus(input, beta, threshold)
+    return legacy.softplus(input)
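Review note: `legacy.softplus(input)` takes no `beta`/`threshold`, so on the Orange Pi path non-default values are silently ignored. For reference, the ext semantics being bypassed are the usual PyTorch-style ones (sketch, assuming that definition):

```python
import numpy as np

def softplus_ext_ref(x, beta=1.0, threshold=20.0):
    # softplus(x) = (1/beta) * log(1 + exp(beta * x)),
    # switching to the identity where beta * x > threshold for stability
    scaled = beta * x
    return np.where(scaled > threshold, x,
                    np.log1p(np.exp(np.minimum(scaled, threshold))) / beta)

print(softplus_ext_ref(np.array([-2.0, 0.0, 2.0, 50.0])))
# approx [0.127, 0.693, 2.127, 50.0]
```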

 def remainder_tensor_scalar(input, other):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.remainder_tensor_scalar_op(input, other)
-    out = input - floor_div(input, other) * other
+    out = sub(input, mul(floor_div(input, other), other), 1)
     return out
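The fallback computes floored modulo, `input - floor(input / other) * other`, so the result carries the divisor's sign, matching Python's `%`. A quick NumPy check (illustrative only):

```python
import numpy as np

x = np.array([5.0, -5.0, 7.5])
out = x - np.floor(x / 2.0) * 2.0        # floored modulo
assert np.allclose(out, np.mod(x, 2.0))  # [1.0, 1.0, 1.5]
```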

def baddbmm(input, batch1, batch2, alpha=1, beta=1):
@@ -1253,28 +1290,107 @@ def floor(input):
         return pyboost.floor_op(input)
     return legacy.floor(input)
 
+
+def _deconv_output_length(pad_mode, filter_size, stride_size, dilation_size, padding):
+    """Calculate the width and height of the output."""
+    length = 0
+    filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
+    if pad_mode == 'valid':
+        if filter_size - stride_size > 0:
+            length = filter_size - stride_size
+    elif pad_mode == 'pad':
+        length = -padding + filter_size - stride_size
+
+    return length
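For the 'pad' branch this reproduces the standard transposed-convolution output size (with `output_padding = 0`), `H_out = (H_in - 1) * stride - 2 * pad + dilation * (k - 1) + 1`, once the caller adds it to `H_in * stride`. A quick check (illustrative only):

```python
def deconv_extra(filter_size, stride, dilation, total_pad):
    # mirrors the 'pad' branch above: dilated kernel extent, minus stride and padding
    eff = filter_size + (filter_size - 1) * (dilation - 1)
    return eff - stride - total_pad

h_in, k, s, d, p = 5, 3, 2, 1, 1
h_out = h_in * s + deconv_extra(k, s, d, 2 * p)            # used this way below
assert h_out == (h_in - 1) * s - 2 * p + d * (k - 1) + 1   # == 9
```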

 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.conv_transpose2d_op(input, weight, bias, stride, padding, output_padding, groups, dilation)
-    return legacy.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
+    pad_mode = 'pad'
+    pad = padding
+    if isinstance(padding, tuple):
+        pad = (padding[0], padding[0], padding[1], padding[1])
+    elif isinstance(padding, int):
+        pad = (padding,) * 4
+    if not isinstance(padding, (int, tuple)):
+        pad_mode = padding
+        pad = (0,) * 4
+
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+
+    in_channel, out_channels = weight.shape[0], weight.shape[1] * groups
+    kernel_size = weight.shape[2:]
+    n, _, h, w = input.shape
+    h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
+    w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])
+
+    out = legacy.conv2_d_transpose(
+        input, weight,
+        (n, out_channels, h * stride[0] + h_add, w * stride[1] + w_add),
+        out_channels,
+        kernel_size,
+        pad_mode,
+        pad,
+        None,
+        1,
+        stride,
+        dilation,
+        groups,
+        'NCHW'
+    )
+    if bias is not None:
+        out = legacy.bias_add(out, bias, 'NCHW')
+    return out
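Review note: this fallback indexes `stride[0]`/`stride[1]` directly, so an int stride, unlike an int dilation, is not normalized and would raise a `TypeError` on this path; `output_padding` is also never applied. A guard along these lines (hypothetical, mirroring the dilation handling above) would cover the int case:

```python
if isinstance(stride, int):
    stride = (stride, stride)
```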



 def relu(input):
     if use_pyboost():
         return pyboost.relu_op(input)
     return legacy.re_lu(input)
 
+def _check_maxpool_padding(padding, nd):
+    """Normalize maxpool padding before calling the primitive."""
+    if isinstance(padding, int):
+        return (0,) * (3 - nd) + (padding,) * nd
+    if isinstance(padding, (tuple, list)):
+        if len(padding) == 1:
+            return (0,) * (3 - nd) + tuple(padding * nd)
+        if len(padding) != nd:
+            raise ValueError(f"For max_pool2d, the length of padding must equal {nd}, but got {len(padding)}.")
+        return (0,) * (3 - nd) + tuple(padding)
+    return padding
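What the helper produces for the 2-D case used below (assuming the function is in scope):

```python
assert _check_maxpool_padding(1, 2) == (0, 1, 1)       # int broadcast to both dims
assert _check_maxpool_padding((1,), 2) == (0, 1, 1)    # length-1 tuple likewise
assert _check_maxpool_padding((1, 2), 2) == (0, 1, 2)  # per-dim padding, 0 prepended
```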

 def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
 
-    out, indices = legacy.max_pool_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
-    if return_indices:
-        return out, indices
-    return out
+    if not ON_ORANGE_PI:
+        out, indices = legacy.max_pool_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
+        if return_indices:
+            return out, indices
+        return out
+
+    if isinstance(kernel_size, tuple):
+        kernel_size = (1,) + kernel_size
+    elif isinstance(kernel_size, int):
+        kernel_size = (1, kernel_size, kernel_size)
+    if isinstance(stride, tuple):
+        stride = (1,) + stride
+    elif isinstance(stride, int):
+        stride = (1, stride, stride)
+    padding = _check_maxpool_padding(padding, 2)
+
+    input = expand_dims(input, 2)
+    out, indices = legacy.max_pool3_d_with_argmax(input, kernel_size, stride, padding,
+                                                  dilation, ceil_mode, 'NCDHW', mindspore.int64)
+    if return_indices:
+        return squeeze(out, 2), squeeze(indices, 2)
+    return squeeze(out, 2)
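On the Orange Pi path the 2-D pooling is lowered to 3-D over a dummy depth axis, the same trick `conv1d_legacy` uses for conv2d. The shapes involved (illustrative walk-through):

```python
# input (N, C, H, W)  --expand_dims(input, 2)-->  (N, C, 1, H, W)
# kernel (1, kh, kw), stride (1, sh, sw), padding (0, ph, pw):
#   the leading 1/0 leave the dummy depth axis untouched
# max_pool3_d_with_argmax -> out, indices of shape (N, C, 1, H', W')
# squeeze(..., 2)         -> (N, C, H', W')
```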

 def upsample_bilinear2d(input, size=None, scale_factor=None, align_corners=False):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.upsample_bilinear2d_op(input, size, scale_factor, align_corners)
-    return legacy.resize_bilinear_v2(input, size, scale_factor, align_corners)
+    return legacy.resize_bilinear_v2(input, size, align_corners, not align_corners)
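Review note: reading the last two `resize_bilinear_v2` arguments as `align_corners` and `half_pixel_centers` (an assumption about the legacy wrapper's argument order), `not align_corners` selects half-pixel sampling, the usual `align_corners=False` convention; the fallback also ignores `scale_factor` and relies on `size` alone. The two coordinate mappings, for reference:

```python
# source coordinate for output index i when resizing length L_in -> L_out:
#   align_corners=True : src = i * (L_in - 1) / (L_out - 1)
#   half-pixel centers : src = (i + 0.5) * L_in / L_out - 0.5
```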

 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
     if use_pyboost():
4 changes: 2 additions & 2 deletions mindtorch/_op_prim/ascend/legacy.py
@@ -2294,8 +2294,8 @@ def max_pool_with_argmax(*args):


 def max_pool_with_argmax_v2(*args):
-    op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-5:]).set_device('Ascend')
-    return op(*args[:-5])
+    op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-6:]).set_device('Ascend')
+    return op(*args[:-6])
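If the primitive's constructor takes six arguments, `(kernel_size, strides, pads, dilation, ceil_mode, argmax_type)`, as in recent MindSpore releases, the old `-5` split shifted every init argument into the wrong slot and leaked the kernel size into the runtime call; splitting at `-6` routes them correctly:

```python
# max_pool_with_argmax_v2(x, kernel, strides, pads, dilation, ceil_mode, argmax_type)
# old: init <- last 5 args (shifted by one), call <- (x, kernel)   # kernel reaches the op call
# new: init <- last 6 args,                  call <- (x,)          # only the tensor reaches the op
```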


def max_unpool2_d(*args):