30 changes: 29 additions & 1 deletion mindnlp/core/ops/array.py
@@ -77,14 +77,42 @@ def gather(input, dim, index):
    index = core.where(index < input.shape[dim], index, index - input.shape[dim])
    if not ON_ORANGE_PI:
        return ops.gather_elements(input, dim, index)
-    return tf_gather(input, index, dim, batch_dims=dim)

+    return torch_gather(input, index, dim)

def gather_nd(input, indices):
    return ops.gather_nd(input, indices)

def tf_gather(input, indices, axis, batch_dims=0):
    return ops.gather(input, indices, axis, batch_dims)

+def torch_gather(x, indices, axis=1):
+    # This implementation mimics the behavior of torch.gather
+    if axis < 0:
+        axis = len(x.shape) + axis
+
+    # Build the index arrays; every dimension other than `axis` keeps its original indices
+    all_indices = []
+    for dim in range(len(x.shape)):
+        if dim == axis:
+            # Use the provided indices
+            all_indices.append(indices.to(mindspore.int32))
+        else:
+            # Create the original indices for this dimension
+            shape = [1] * len(x.shape)
+            shape[dim] = x.shape[dim]
+            dim_indices = core.arange(x.shape[dim], dtype=mindspore.int32)
+            dim_indices = core.reshape(dim_indices, shape)
+            # Broadcast to the shape of `indices`
+            dim_indices = core.broadcast_to(dim_indices, indices.shape)
+            all_indices.append(dim_indices)
+
+    # Combine the indices from all dimensions
+    multi_indices = core.stack(all_indices, axis=-1)
+
+    # Gather the elements with gather_nd
+    return gather_nd(x, multi_indices)

# hsplit


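For reference, a minimal usage sketch of the new `torch_gather` fallback (assuming it is run inside `mindnlp.core.ops.array`, where `core`, `mindspore`, and `gather_nd` are already in scope); it illustrates the `torch.gather`-style indexing the helper is meant to reproduce:

```python
import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
idx = Tensor(np.array([[0, 0], [1, 0]], dtype=np.int32))

# torch.gather(x, 1, idx) picks x[i, idx[i, j]], i.e. [[1., 1.], [4., 3.]].
# torch_gather builds the same (row, column) pairs and feeds them to gather_nd,
# so torch_gather(x, idx, axis=1) is expected to return the same values.
out = torch_gather(x, idx, axis=1)
print(out)
```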
4 changes: 2 additions & 2 deletions mindnlp/core/ops/comparison.py
@@ -14,7 +14,7 @@
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
    rtol = rtol.item() if isinstance(rtol, mindspore.Tensor) else rtol
    atol = atol.item() if isinstance(atol, mindspore.Tensor) else atol
-    if use_pyboost() and has_allclose:
+    if use_pyboost() and has_allclose and not ON_ORANGE_PI:
        return mindspore.mint.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
    return np.allclose(input.numpy(), other.numpy(), rtol, atol, equal_nan)
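The numpy fallback keeps the usual `allclose` tolerance rule. A minimal sketch (numpy only; values chosen for illustration) of the elementwise check it performs:

```python
import numpy as np

a = np.array([1.0, 1.0e-8])
b = np.array([1.0 + 1.0e-9, 2.0e-8])
rtol, atol = 1e-05, 1e-08

# allclose(a, b) requires |a - b| <= atol + rtol * |b| for every element.
manual = np.all(np.abs(a - b) <= atol + rtol * np.abs(b))
print(manual, np.allclose(a, b, rtol=rtol, atol=atol))  # both True here
```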

@@ -37,7 +37,7 @@ def eq(input, other, *, out=None):
# equal
has_equal = hasattr(mindspore.mint, 'equal')
def equal(input, other):
-    if use_pyboost() and has_equal:
+    if use_pyboost() and has_equal and not ON_ORANGE_PI:
        return mindspore.mint.equal(input, other)
    if input.shape != other.shape:
        return False
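When the `mint.equal` path is skipped (no pyboost, or on the Orange Pi), `equal` falls back to a shape check followed by a whole-tensor comparison. A self-contained sketch of that semantics (numpy only; `equal_fallback` is a hypothetical stand-in for the fallback branch):

```python
import numpy as np

def equal_fallback(a, b):
    # torch.equal-style semantics: one bool for the whole pair of tensors.
    if a.shape != b.shape:
        return False
    return bool((a == b).all())

print(equal_fallback(np.array([1, 2]), np.array([1, 2])))    # True
print(equal_fallback(np.array([1, 2]), np.array([[1, 2]])))  # False: shapes differ
```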
5 changes: 4 additions & 1 deletion mindnlp/core/ops/pointwise.py
@@ -587,7 +587,10 @@ def mul(input, other, *, out=None):
    else:
        if input.dtype == mindspore.bool_:
            if isinstance(other, bool):
-                out = ops.bitwise_and(input, other)
+                if ON_ORANGE_PI:
+                    out = ops.bitwise_and(input.int(), other).bool()
+                else:
+                    out = ops.bitwise_and(input, other)
            else:
                out = ops.mul(input.int(), other)
        else:
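The Orange Pi branch avoids calling `bitwise_and` directly on a bool tensor by round-tripping through int. A minimal sketch of that cast pattern (assuming a MindSpore install where `ops.bitwise_and` accepts the Python bool scalar, as the patched code does):

```python
import numpy as np
import mindspore
from mindspore import ops, Tensor

x = Tensor(np.array([True, False, True]))
other = True

# Cast to int, do the bitwise AND, cast back to bool; the result matches x AND other.
out = ops.bitwise_and(x.int(), other).bool()
print(out)  # [True, False, True]
```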
6 changes: 3 additions & 3 deletions mindnlp/core/ops/reduction.py
@@ -4,7 +4,7 @@
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, DEVICE_TARGET
+from ..configs import use_pyboost, DEVICE_TARGET, ON_ORANGE_PI

from ._inner import call_ms_func
from mindnlp import core
@@ -137,7 +137,7 @@ def nanmedian(input, dim=-1, keepdim=False):
# norm
has_norm = hasattr(mindspore.mint, 'norm')
def norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None):
-    if use_pyboost() and has_norm:
+    if use_pyboost() and has_norm and not ON_ORANGE_PI:
        return call_ms_func(mindspore.mint.norm, input, p, dim, keepdim, out=out, dtype=dtype)
    if p == 'fro':
        p = None
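Mapping `p='fro'` to `p=None` works because the Frobenius norm is just the 2-norm over all elements, which is what the default computes. A quick numpy check of that identity:

```python
import numpy as np

x = np.array([[3.0, 4.0], [0.0, 0.0]])
print(np.sqrt((x ** 2).sum()))       # 5.0
print(np.linalg.norm(x, ord='fro'))  # 5.0, same value
```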
@@ -339,7 +339,7 @@ def std(input, dim=None, *, correction=1, keepdim=False, **kwargs):
    axis = kwargs.get('axis', None)
    if axis is not None:
        dim = axis
-    if use_pyboost() and has_std:
+    if use_pyboost() and has_std and not ON_ORANGE_PI:
        return mindspore.mint.std(input, dim=dim, correction=correction, keepdim=keepdim)
    if DEVICE_TARGET == 'GPU':
        unbiased = bool(correction)
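For the `std` fallback, `correction` has the usual meaning: `correction=1` divides by N-1 (Bessel's correction, the `unbiased` case mapped on GPU), `correction=0` divides by N. A numpy illustration of the difference:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
print(np.std(x, ddof=1))  # correction=1: sqrt(sum((x - mean)**2) / (N - 1))
print(np.std(x, ddof=0))  # correction=0: sqrt(sum((x - mean)**2) / N)
```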