5 changes: 4 additions & 1 deletion mindtorch/_apis/cpu.py
@@ -1255,4 +1255,7 @@ def cumprod(input, dim, dtype):
     return out
 
 def lerp(input, end, weight):
-    return legacy.lerp(input, end, weight)
+    return legacy.lerp(input, end, weight)
+
+def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
+    return legacy.smooth_l1_loss(input, target, beta, reduction)
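
Note: the new `smooth_l1_loss` wrappers presumably follow the standard PyTorch definition, quadratic within `beta` of the target and linear outside it. A minimal NumPy sketch of that assumed contract, not code from this repository:

import numpy as np

def smooth_l1_reference(input, target, beta=1.0, reduction='none'):
    # Quadratic within beta of the target, linear outside it.
    diff = np.abs(input - target)
    loss = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # reduction == 'none'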
26 changes: 19 additions & 7 deletions mindtorch/_apis/npu.py
@@ -804,7 +804,7 @@ def index(input, index):
 def scatter(input, dim, index, src):
     if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.scatter_op(input, dim, index, src)
-    return legacy.tensor_scatter_elements(input, index, src, dim, "none")
+    return legacy.tensor_scatter_elements(input, index, cast(src, input.dtype), dim, "none")
 
 def tril(input, diagonal=0):
     if use_pyboost():
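
Note: the added `cast(src, input.dtype)` presumably works around `tensor_scatter_elements` being strict about `src` and `input` sharing a dtype. A rough sketch of the scatter semantics involved (pure NumPy, 2-D, dim=0 only; illustrative, not the mindtorch API):

import numpy as np

def scatter_dim0_reference(input, index, src):
    # out[index[i, j], j] = src[i, j]; src is normalized to input's
    # dtype up front, mirroring the cast added in this diff.
    out = input.copy()
    src = src.astype(input.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            out[index[i, j], j] = src[i, j]
    return out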
@@ -858,7 +858,8 @@ def isinf(input):
 def sort(input, dim, descending, stable):
     if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.sort_ext_op(input, dim, descending, stable)
-    return legacy.sort(input, dim, descending)
+    out = legacy.sort(input, dim, descending)
+    return out[0], cast(out[1], mindspore.int64)
 
 def prod(input, axis, keepdims, dtype):
     if use_pyboost():
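
Note: the added `cast(out[1], mindspore.int64)` presumably aligns the legacy path with PyTorch's `torch.sort` contract, where returned indices are always int64. A quick NumPy illustration of that assumed contract:

import numpy as np

x = np.array([3.0, 1.0, 2.0])
indices = np.argsort(x).astype(np.int64)  # indices as int64, PyTorch-style
values = x[indices]
print(values, indices.dtype)  # [1. 2. 3.] int64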
@@ -1612,9 +1613,15 @@ def inplace_add(input, other, alpha):
     return legacy.inplace_add(input, other)
 
 def logsumexp(input, dim, keepdim):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.logsumexp_op(input, dim, keepdim)
-    return legacy.logsumexp(input, dim, keepdim)
+    input_max = legacy.reduce_max(input, dim, True)
+    input_exp = exp(sub(input, input_max))
+    input_sumexp = sum(input_exp, dim, keepdim, None)
+    input_logsumexp = log(input_sumexp)
+    if not keepdim:
+        input_max = squeeze(input_max, dim)
+    return add(input_logsumexp, input_max)
 
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity):
     loss, log_alpha = legacy.ctc_loss_v2(log_probs, targets, input_lengths, target_lengths, blank, 'none', zero_infinity)
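
Note: the new fallback is the standard max-shift trick, logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), which keeps exp from overflowing for large inputs. A self-contained NumPy equivalent of the same computation:

import numpy as np

def logsumexp_reference(x, dim, keepdim=False):
    # Shift by the max before exponentiating so exp never overflows.
    x_max = np.max(x, axis=dim, keepdims=True)
    out = np.log(np.sum(np.exp(x - x_max), axis=dim, keepdims=keepdim))
    if not keepdim:
        x_max = np.squeeze(x_max, axis=dim)
    return out + x_max

x = np.array([1000.0, 1000.0])
print(logsumexp_reference(x, 0))  # ~1000.693; the naive form overflows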
@@ -1922,9 +1929,11 @@ def linalg_qr(input_x, mode):

 def bernoulli(input, generator):
     seed, offset = generator._step(12)
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return pyboost.bernoulli_ext_op(input, seed, offset)
-    return legacy.bernoulli(input, seed, offset)
+    uniform = rand_like(input, generator, input.dtype)
+    result = cast(less(uniform, input), input.dtype)
+    return result
 
 def multinomial(input, num_samples, replacement, generator):
     seed, offset = generator._step(12)  # pylint: disable=protected-access
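
Note: the fallback samples a Bernoulli by inversion: draw u ~ Uniform[0, 1) and emit 1 where u < p, which holds with probability exactly p. A NumPy sketch of the same idea:

import numpy as np

def bernoulli_reference(p, rng):
    # (u < p) is True with probability p for u ~ Uniform[0, 1).
    u = rng.random(size=p.shape)
    return (u < p).astype(p.dtype)

rng = np.random.default_rng(0)
print(bernoulli_reference(np.array([0.1, 0.5, 0.9]), rng))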
@@ -1998,4 +2007,7 @@ def replication_pad_1d(input, padding):
     return pyboost.reflection_pad_1d_op(input, padding)
 
 def hardtanh(input, min_val, max_val):
-    return pyboost.hardtanh_op(input, min_val, max_val)
+    return pyboost.hardtanh_op(input, min_val, max_val)
+
+def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
+    return pyboost.smooth_l1_loss_impl(input, target, beta, reduction)
4 changes: 1 addition & 3 deletions mindtorch/nn/functional.py
@@ -556,9 +556,7 @@ def l1_loss(input, target, reduction='mean'):
     return execute('l1_loss', input, target, reduction)
 
 def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
-    input = input.to(mindtorch.float32)
-    target = target.to(mindtorch.float32)
-    return ops.smooth_l1_loss(input, target, beta, reduction)
+    return execute('smooth_l1_loss', input, target, beta, reduction)
 
 def kl_div(input, target, reduction='mean', log_target=False):
     if reduction == 'batchmean':
4 changes: 2 additions & 2 deletions mindtorch/ops/array.py
@@ -548,7 +548,7 @@ def _record_tensor_index(index, remain_indexes, dim, device):

     while dim > len(remain_indexes):
         # use empty_tensor with dim_num 9 to indicate unused dim
-        if device.type == 'npu':
+        if device.type == 'npu' and not ON_ORANGE_PI:
             remain_indexes.append(empty_tensor_9d)
         else:
             remain_indexes.append(slice(None, None, None))
@@ -650,7 +650,7 @@ def tensor_getitem(self, index):
     if not remain_indexes:
         return self_viewed
 
-    if self.device.type == 'npu':
+    if self.device.type == 'npu' and not ON_ORANGE_PI:
         return execute('index', self_viewed, remain_indexes)
 
     return getitem(self_viewed, tuple(remain_indexes) if len(remain_indexes) > 1 else remain_indexes[0])
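
Note on the fallback the Orange Pi branch now takes: unused dims are padded with `slice(None, None, None)`, which is just `:` in subscript form, so plain `getitem` recovers full-dimension selection without the aclnn `index` op. A tiny NumPy illustration (illustrative only):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
index = (1,)  # caller indexed only the first dim
# Pad the remaining dims with full slices, as the fallback path does.
padded = index + (slice(None, None, None),) * (x.ndim - len(index))
assert (x[padded] == x[1]).all()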