Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mindnlp/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
strided = None
contiguous_format = None
preserve_format = None
legacy_contiguous_format = None

inf = float("inf")
nan = float("nan")
Expand Down
12 changes: 12 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,18 @@ def bfloat16(self):
Tensor.bfloat16 = bfloat16
StubTensor.bfloat16 = bfloat16

def sort(self, dim=-1, descending=False):
return ops.sort(self, dim=dim, descending=descending)

Tensor.sort = sort
StubTensor.sort = sort

Tensor.cumsum = ops.cumsum
StubTensor.cumsum = ops.cumsum

Tensor.scatter_ = ops.inplace_scatter
StubTensor.scatter_ = ops.inplace_scatter

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
5 changes: 4 additions & 1 deletion mindnlp/core/ops/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def range(start=0, end=None, step=1, dtype=None):
def linspace(start, end, steps, *, dtype=None, **kwargs):
if dtype is None:
dtype = mindspore.float32
start = start.item() if isinstance(start, mindspore.Tensor) else start
end = end.item() if isinstance(end, mindspore.Tensor) else end
steps = steps.item() if isinstance(steps, mindspore.Tensor) else steps
if use_pyboost() and has_linspace:
return mindspore.mint.linspace(start, end, steps, dtype=dtype)
return ops.linspace(start, end, steps).to(dtype)
Expand All @@ -139,7 +142,7 @@ def logspace(start, end, steps, base=10.0, *, dtype=None):

# eye
has_eye = hasattr(mindspore.mint, 'eye')
def eye(n, m=None, *, dtype=None):
def eye(n, m=None, *, dtype=None, **kwargs):
if use_pyboost() and has_eye:
return mindspore.mint.eye(n, m, dtype)
return ops.eye(n, m, dtype)
Expand Down
6 changes: 2 additions & 4 deletions mindnlp/core/ops/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.generator import default_generator
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op

from mindnlp import core
from ..configs import use_pyboost
Expand Down Expand Up @@ -85,9 +85,7 @@ def inplace_add(input, other, alpha):
return input

def inplace_scatter(input, dim, index, src):
if not isinstance(src, core.Tensor):
return execute('inplace_scatter_value', input, dim, index, src)
return execute('inplace_scatter', input, dim, index, src)
return inplace_scatter_value_op(input, dim, index, src)

def inplace_index_copy(input, dim, index, tensor):
selected = input.index_select(dim, index)
Expand Down
5 changes: 5 additions & 0 deletions mindnlp/core/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def clone(input):


def cumsum(input, dim, dtype=None, out=None):
input_dtype = input.dtype
if input_dtype == mindspore.int64:
input = input.to(mindspore.int32)
if (
use_pyboost() and has_cumsum and not ON_ORANGE_PI
): # since cann8.0 community remove aclnn cumsum
Expand All @@ -161,6 +164,8 @@ def cumsum(input, dim, dtype=None, out=None):
output = ops.cumsum(input, dim, dtype)
if out is not None:
out.assign_value(output)
return out
output = output.to(input_dtype)
return output


Expand Down
4 changes: 4 additions & 0 deletions mindnlp/core/special/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mindspore import ops

def logit(input, eps=None, *, out=None):
return ops.logit(input, eps)
Loading