diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py index 8ecdbdf5e..1b4dfe125 100644 --- a/mindnlp/core/__init__.py +++ b/mindnlp/core/__init__.py @@ -47,7 +47,7 @@ from .amp import autocast, GradScaler from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \ - return_types, linalg, fx, backends, testing, nn + return_types, linalg, fx, backends, testing, nn, fft from ._lowrank import svd_lowrank from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index a88d65fb1..174f1e3fc 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -1,5 +1,6 @@ import math import numpy as np +from functools import partial import mindspore from mindspore import Tensor from mindspore.common.tensor import _TensorMeta @@ -226,8 +227,42 @@ def __getitem__(self, slices): Tensor.__getitem__ = __getitem__ StubTensor.__getitem__ = __getitem__ + def _convert_numpy_slices(self, key): + """递归转换 key 中的 NumPy 整数为内置 int""" + # 处理元组:遍历所有元素并递归转换 + if isinstance(key, tuple): + return tuple(self._convert_numpy_slices(k) for k in key) + + # 处理 slice 对象:转换 start/stop/step + elif isinstance(key, slice): + start = key.start + stop = key.stop + step = key.step + + # 转换 NumPy 整数为 Python int + if isinstance(start, np.integer): + start = int(start) + if isinstance(stop, np.integer): + stop = int(stop) + if isinstance(step, np.integer): + step = int(step) + + return slice(start, stop, step) + + # 转换单个 NumPy 索引值 + elif isinstance(key, np.integer): + return int(key) + + # 其他类型(如 int、None)直接返回 + else: + return key + + Tensor._convert_numpy_slices = _convert_numpy_slices + StubTensor._convert_numpy_slices = _convert_numpy_slices + origin_setitem = Tensor.__setitem__ def __setitem__(self, slices, value): + slices = self._convert_numpy_slices(slices) if isinstance(value, float): if value == float('inf'): value = ops.finfo(self.dtype).max @@ -399,6 +434,14 @@ def __rmul__(self, other): Tensor.__rmul__ = __rmul__ StubTensor.__rmul__ = __rmul__ + Tensor.norm = ops.norm + StubTensor.norm = ops.norm + + def clamp_min(self, value): + return ops.clamp(self, value) + Tensor.clamp_min = clamp_min + StubTensor.clamp_min = clamp_min + def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) return ret \ No newline at end of file diff --git a/mindnlp/core/fft/__init__.py b/mindnlp/core/fft/__init__.py new file mode 100644 index 000000000..c208ed94d --- /dev/null +++ b/mindnlp/core/fft/__init__.py @@ -0,0 +1,38 @@ +"""fft""" +from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim +from ..configs import use_pyboost +from ..ops import narrow +from ..nn import functional as F + +def rfft(input, n=None, dim=-1, norm="backward"): + if use_pyboost(): + return ops.rfft(input, n, dim, norm) + if input.shape[dim] < n: + pad_inf = (0, n - input.shape[dim]) + pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + input = F.pad(input, pad_dims) + else: + input = narrow(input, dim, 0, n) + _rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm) + return _rfft(input) + +def irfft(input, n=None, dim=-1, norm="backward"): + if use_pyboost(): + return ops.irfft(input, n, dim, norm) + if input.shape[dim] < n: + pad_inf = (0, n - input.shape[dim]) + pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + input = pad(input, pad_dims) + else: + input = narrow(input, dim, 0, n) + _irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm) + return _irfft(input) + +def fftn(input, s=None, dim=None, norm=None): + return ops.fftn(input, s, dim, norm) + +def fft(input, s=None, dim=-1, norm=None): + return ops.fft(input, s, dim, norm) + +__all__ = ['fft', 'fftn', 'irfft', 'rfft'] \ No newline at end of file diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index be9bd2c60..60a5bef53 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -246,6 +246,40 @@ def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0): query, key, cos, sin, position_ids, cos_format ) +def custom_circular_pad(x, pad): + """手动实现 torch.nn.functional.pad 的 circular 模式。 + + 参数: + x: 输入张量,形状为 (B, C, D1, D2, ...) + pad: 填充参数,格式为 (left_N, right_N, left_{N-1}, right_{N-1}, ..., left_1, right_1) + 表示从最后维度开始向前定义填充大小 + + 返回: + 循环填充后的张量 + """ + ndim = x.dim() + n_pad_dims = len(pad) // 2 + assert n_pad_dims <= ndim, "填充参数超过了张量的维度" + + # 按从最后维度向前处理填充 + for dim in range(ndim-1, ndim-1-n_pad_dims, -1): + # 当前维度的左右填充量 + idx = 2 * (ndim - 1 - dim) # 在pad元组中的起始位置 + left_pad = pad[idx] + right_pad = pad[idx + 1] + + if left_pad == 0 and right_pad == 0: + continue # 跳过该维度 + + size = x.shape[dim] # 当前维度的原始长度 + new_size = left_pad + size + right_pad + + # 生成循环索引: (index - left_pad) mod size + index = (core.arange(new_size) - left_pad) % size + x = core.index_select(x, dim, index) + + return x + def pad(input, pad, mode='constant', value=0.0): if sum(pad) == 0: return input @@ -253,8 +287,10 @@ def pad(input, pad, mode='constant', value=0.0): pad = tuple(p if isinstance(p, int) else p.item() for p in pad) if use_pyboost() and not ON_A1: return mint.nn.functional.pad(input, pad, mode, value) - if mode in ['reflect', 'circular']: + if mode == 'reflect': return ops.pad(input, pad, mode) + if mode == 'circular': + return custom_circular_pad(input, pad) new_pad = () for idx, pad_v in enumerate(pad): if pad_v < 0: diff --git a/mindnlp/core/ops/__init__.py b/mindnlp/core/ops/__init__.py index e4b029cbd..63ed67c5b 100644 --- a/mindnlp/core/ops/__init__.py +++ b/mindnlp/core/ops/__init__.py @@ -10,7 +10,6 @@ from .reduction import * from .other import * from .tensor import * -# from .fft_op import * # from .spectral import * from ._inner import * from .optim import * @@ -27,7 +26,6 @@ def load_library(lib_path): __all__.extend(blas.__all__) __all__.extend(comparison.__all__) __all__.extend(creation.__all__) -# __all__.extend(fft_op.__all__) __all__.extend(pointwise.__all__) __all__.extend(random.__all__) __all__.extend(reduction.__all__) diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py index 5bb550a32..502a907a7 100644 --- a/mindnlp/core/ops/creation.py +++ b/mindnlp/core/ops/creation.py @@ -142,7 +142,7 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False, if device is None: device= get_default_device() - if isinstance(size[0], (tuple, list)): + if len(size) > 0 and isinstance(size[0], (tuple, list)): size = size[0] if dtype is None: