Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions mindtorch/_apis/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def inplace_copy(self, value):
Args:
value (Tensor): The tensor from which to copy the data.
"""
if use_pyboost:
if use_pyboost():
return pyboost.inplace_copy_op(self, value)
else:
self.assign_value(value)
legacy.assign(self, value)
return self

def slice(input, dim, start, end, step):
Expand All @@ -85,7 +85,15 @@ def slice(input, dim, start, end, step):
if use_pyboost():
return pyboost.slice_ext_op(input, dim, start, end, step)
else:
return legacy.slice(input, dim, start, end, step)
ndim = input.ndim
begins = [0] * ndim
ends = [i for i in input.shape]
strides = [1] * ndim
begins[dim] = start
ends[dim] = end
strides[dim] = step
return legacy.strided_slice(input, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0)


def embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq):
"""
Expand Down Expand Up @@ -829,7 +837,7 @@ def bmm(input, other):
return legacy.batch_mat_mul(input, other, False, False)

def topk(input, k, dim, largest, sorted):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.topk_ext_op(input, k, dim, largest, sorted)

if not largest:
Expand Down Expand Up @@ -1296,9 +1304,9 @@ def remainder_tensor_scalar(input, other):
return out

def baddbmm(input, batch1, batch2, alpha=1, beta=1):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.baddbmm_op(input, batch1, batch2, alpha, beta)
return legacy.baddbmm(input, batch1, batch2, alpha, beta)
return add(mul(input, beta), mul(bmm(batch1, batch2), alpha))

def floor(input):
if use_pyboost():
Expand Down Expand Up @@ -1844,4 +1852,16 @@ def cumprod(input, dim, dtype):
out = legacy.cum_prod(input, dim, False, False)
if dtype is not None:
out = cast(out, dtype)
return out
return out

def scatter_nd_update(input, indices, updates):
return legacy.scatter_nd_update(input, indices, updates, True)

def assign(input, value):
return inplace_copy(input, value)

def strided_slice(input, begin, end, strides, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
return legacy.strided_slice(input, tuple(begin), tuple(end), tuple(strides), begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)

def tensor_scatter_update(input, indices, updates):
return legacy.tensor_scatter_update(input, indices, updates)
6 changes: 3 additions & 3 deletions mindtorch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class StubTensor: pass
from ._bind import get_device_in_context, device_, get_default_dtype
from ._utils import _rebuild_tensor_v2
from ._C.size import Size
from .configs import DEVICE_TARGET, cpu_use_numpy
from .configs import DEVICE_TARGET, cpu_use_numpy, ON_ORANGE_PI

device_map = {
'cpu': 'CPU',
Expand Down Expand Up @@ -282,7 +282,7 @@ def __setitem__(self, slices, value):
if value.device != self.device:
value._device = self.device

if self.device.type == 'npu':
if self.device.type == 'npu' and not ON_ORANGE_PI:
if value.device != self.device:
value._device = self.device
out = ops.tensor_setitem(self, slices, value)
Expand All @@ -301,7 +301,7 @@ def __iadd__(self, other):
return self.copy_(ops.add(self, other))

def __radd__(self, other):
return Tensor.__add__(other, self)
return ops.add(other, self)

def __div__(self, other):
# if 0 in self.shape:
Expand Down
2 changes: 1 addition & 1 deletion mindtorch/nn/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,7 @@ def forward(self, input: Tensor) -> Tensor:
"""
Runs the forward pass.
"""
return F.softmax(input, self.dim, _stacklevel=5)
return F.softmax(input, self.dim)

def extra_repr(self) -> str:
"""
Expand Down
3 changes: 2 additions & 1 deletion mindtorch/ops/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
raise TypeError(f"Index {index} contain unsupported elements")
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)

return self_viewed, remain_indexes


Expand Down Expand Up @@ -1162,7 +1163,7 @@ def strided_slice_update(x, begin, end, strides, updates,
# for i in range(ndim-1, -1, -1):
# if (shrink_axis_mask >> i) & 1:
# x_updated = mindtorch.squeeze(x_updated, dim=i)

return x_updated


Expand Down