From 1b92c91f0203aa32deb41403a79aaebd6bd3b5d1 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 4 Aug 2025 22:27:24 +0800 Subject: [PATCH] use gather_nd to implement gather for Orange Pi --- mindnlp/core/ops/array.py | 30 +++++++++++++++++++++++++++++- mindnlp/core/ops/comparison.py | 4 ++-- mindnlp/core/ops/pointwise.py | 5 ++++- mindnlp/core/ops/reduction.py | 6 +++--- 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index f01a9ab90..9c5c978c3 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -77,7 +77,8 @@ def gather(input, dim, index): index = core.where(index < input.shape[dim], index, index - input.shape[dim]) if not ON_ORANGE_PI: return ops.gather_elements(input, dim, index) - return tf_gather(input, index, dim, batch_dims=dim) + + return torch_gather(input, index, dim) def gather_nd(input, indices): return ops.gather_nd(input, indices) @@ -85,6 +86,33 @@ def gather_nd(input, indices): def tf_gather(input, indices, axis, batch_dims=0): return ops.gather(input, indices, axis, batch_dims) +def torch_gather(x, indices, axis=1): + # Emulates torch.gather on top of gather_nd. + if axis < 0: + axis = len(x.shape) + axis + + # Build one coordinate tensor per dimension of x. + all_indices = [] + for dim in range(len(x.shape)): + if dim == axis: + # Along the gather axis the coordinates are the caller's indices. + all_indices.append(indices.to(mindspore.int32)) + else: + # Other dimensions contribute their own position; size it by indices, which torch.gather allows to be smaller than x here. + shape = [1] * len(x.shape) + shape[dim] = indices.shape[dim] + dim_indices = core.arange(indices.shape[dim], dtype=mindspore.int32) + dim_indices = core.reshape(dim_indices, shape) + # Broadcast the positions to the shape of indices. + dim_indices = core.broadcast_to(dim_indices, indices.shape) + all_indices.append(dim_indices) + + # Stack the per-dimension coordinates along a new last axis. + multi_indices = core.stack(all_indices, axis=-1) + + # Collect the elements with gather_nd. + return gather_nd(x, multi_indices) + # hsplit diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py index 18000f740..f402223b0 100644 --- a/mindnlp/core/ops/comparison.py +++ 
b/mindnlp/core/ops/comparison.py @@ -14,7 +14,7 @@ def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): rtol = rtol.item() if isinstance(rtol, mindspore.Tensor) else rtol atol = atol.item() if isinstance(atol, mindspore.Tensor) else atol - if use_pyboost() and has_allclose: + if use_pyboost() and has_allclose and not ON_ORANGE_PI: return mindspore.mint.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) return np.allclose(input.numpy(), other.numpy(), rtol, atol, equal_nan) @@ -37,7 +37,7 @@ def eq(input, other, *, out=None): # equal has_equal = hasattr(mindspore.mint, 'equal') def equal(input, other): - if use_pyboost() and has_equal: + if use_pyboost() and has_equal and not ON_ORANGE_PI: return mindspore.mint.equal(input, other) if input.shape != other.shape: return False diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index 98465b043..afce9a17e 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -587,7 +587,10 @@ def mul(input, other, *, out=None): else: if input.dtype == mindspore.bool_: if isinstance(other, bool): - out = ops.bitwise_and(input, other) + if ON_ORANGE_PI: + out = ops.bitwise_and(input.int(), other).bool() + else: + out = ops.bitwise_and(input, other) else: out = ops.mul(input.int(), other) else: diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py index 0ad1b7064..3a142ffc7 100644 --- a/mindnlp/core/ops/reduction.py +++ b/mindnlp/core/ops/reduction.py @@ -4,7 +4,7 @@ import mindspore from mindspore import ops from mindspore.ops._primitive_cache import _get_cache_prim -from ..configs import use_pyboost, DEVICE_TARGET +from ..configs import use_pyboost, DEVICE_TARGET, ON_ORANGE_PI from ._inner import call_ms_func from mindnlp import core @@ -137,7 +137,7 @@ def nanmedian(input, dim=-1, keepdim=False): # norm has_norm = hasattr(mindspore.mint, 'norm') def norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None): - if use_pyboost() 
and has_norm: + if use_pyboost() and has_norm and not ON_ORANGE_PI: return call_ms_func(mindspore.mint.norm, input, p, dim, keepdim, out=out, dtype=dtype) if p == 'fro': p = None @@ -339,7 +339,7 @@ def std(input, dim=None, *, correction=1, keepdim=False, **kwargs): axis = kwargs.get('axis', None) if axis is not None: dim = axis - if use_pyboost() and has_std: + if use_pyboost() and has_std and not ON_ORANGE_PI: return mindspore.mint.std(input, dim=dim, correction=correction, keepdim=keepdim) if DEVICE_TARGET == 'GPU': unbiased = bool(correction)