Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ def __setitem__(self, slices, value):
value = tensor(value, dtype=self.dtype)
else:
value = value.to(self.dtype)

if 1 in value.shape and self[slices].ndim != value.ndim:
value = value.squeeze()

return origin_setitem(self, slices, value)

Tensor.__setitem__ = __setitem__
Expand Down Expand Up @@ -658,6 +660,21 @@ def __contains__(self, item):
Tensor.exponential_ = ops.inplace_exponential
StubTensor.exponential_ = ops.inplace_exponential

Tensor.log_ = ops.inplace_log
StubTensor.log_ = ops.inplace_log

Tensor.mul_ = ops.inplace_mul
StubTensor.mul_ = ops.inplace_mul

Tensor.neg_ = ops.inplace_neg
StubTensor.neg_ = ops.inplace_neg

Tensor.exp_ = ops.inplace_exp
StubTensor.exp_ = ops.inplace_exp

Tensor.sub_ = ops.inplace_sub
StubTensor.sub_ = ops.inplace_sub

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
33 changes: 31 additions & 2 deletions mindnlp/core/ops/inplace.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numbers
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.generator import default_generator
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \
inplace_scatter_src_op

from mindnlp import core
from ..configs import use_pyboost
Expand Down Expand Up @@ -85,6 +87,8 @@ def inplace_add(input, other, alpha):
return input

def inplace_scatter(input, dim, index, src):
if not isinstance(src, numbers.Number):
return inplace_scatter_src_op(input, dim, index, src)
return inplace_scatter_value_op(input, dim, index, src)

def inplace_index_copy(input, dim, index, tensor):
Expand Down Expand Up @@ -157,6 +161,26 @@ def inplace_exponential(tensor, lambd=1.0):

return tensor

def inplace_log(self):
self.data = core.log(self)
return self

def inplace_mul(self, other):
self.data = core.mul(self, other)
return self

def inplace_neg(self):
self.data = core.neg(self)
return self

def inplace_exp(self):
self.data = core.exp(self)
return self

def inplace_sub(self, other):
self.data = core.sub(self, other)
return self

__all__ = [
'inplace_copy',
'inplace_zero',
Expand All @@ -173,5 +197,10 @@ def inplace_exponential(tensor, lambd=1.0):
'inplace_triu',
'inplace_round',
'inplace_scatter_reduce',
'inplace_exponential'
'inplace_exponential',
'inplace_log',
'inplace_mul',
'inplace_neg',
'inplace_exp',
'inplace_sub'
]
20 changes: 17 additions & 3 deletions mindnlp/core/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,10 @@ def cumsum(input, dim=None, dtype=None, out=None, **kwargs):

# diag
has_diag = hasattr(mindspore.mint, "diag")


def diag(input, diagonal=0):
if use_pyboost() and has_diag:
return mindspore.mint.diag(input, diagonal)
return ops.diag(input)
return mindspore.numpy.diag(input, diagonal)


# diag_embed
Expand Down Expand Up @@ -806,6 +804,8 @@ def searchsorted(
sorter=None,
):
if use_pyboost() and has_searchsorted:
if not isinstance(values, core.Tensor):
values = core.tensor(values)
return call_ms_func(
mindspore.mint.searchsorted,
sorted_sequence,
Expand Down Expand Up @@ -1030,12 +1030,26 @@ def unfold(input, dimension, size, step):
return output


def cartesian_prod(*tensors):
"""
手动实现 torch.cartesian_prod
:param tensors: 一个或多个一维张量
:return: 笛卡尔积结果的二维张量 (每行一个组合)
"""
# 生成网格坐标
grids = core.meshgrid(*tensors, indexing='ij')

# 展平每个网格张量并堆叠
return core.stack([g.reshape(-1) for g in grids], dim=1)


__all__ = [
"bincount",
"broadcast_shapes",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cartesian_prod",
"cdist",
"clone",
"contains",
Expand Down
Loading
Loading