Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindnlp/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from .amp import autocast, GradScaler

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
return_types, linalg, fx, backends, testing, nn
return_types, linalg, fx, backends, testing, nn, fft

from ._lowrank import svd_lowrank
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
Expand Down
43 changes: 43 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import numpy as np
from functools import partial
import mindspore
from mindspore import Tensor
from mindspore.common.tensor import _TensorMeta
Expand Down Expand Up @@ -226,8 +227,42 @@ def __getitem__(self, slices):
Tensor.__getitem__ = __getitem__
StubTensor.__getitem__ = __getitem__

def _convert_numpy_slices(self, key):
"""递归转换 key 中的 NumPy 整数为内置 int"""
# 处理元组:遍历所有元素并递归转换
if isinstance(key, tuple):
return tuple(self._convert_numpy_slices(k) for k in key)

# 处理 slice 对象:转换 start/stop/step
elif isinstance(key, slice):
start = key.start
stop = key.stop
step = key.step

# 转换 NumPy 整数为 Python int
if isinstance(start, np.integer):
start = int(start)
if isinstance(stop, np.integer):
stop = int(stop)
if isinstance(step, np.integer):
step = int(step)

return slice(start, stop, step)

# 转换单个 NumPy 索引值
elif isinstance(key, np.integer):
return int(key)

# 其他类型(如 int、None)直接返回
else:
return key

Tensor._convert_numpy_slices = _convert_numpy_slices
StubTensor._convert_numpy_slices = _convert_numpy_slices

origin_setitem = Tensor.__setitem__
def __setitem__(self, slices, value):
slices = self._convert_numpy_slices(slices)
if isinstance(value, float):
if value == float('inf'):
value = ops.finfo(self.dtype).max
Expand Down Expand Up @@ -399,6 +434,14 @@ def __rmul__(self, other):
Tensor.__rmul__ = __rmul__
StubTensor.__rmul__ = __rmul__

Tensor.norm = ops.norm
StubTensor.norm = ops.norm

def clamp_min(self, value):
return ops.clamp(self, value)
Tensor.clamp_min = clamp_min
StubTensor.clamp_min = clamp_min

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
38 changes: 38 additions & 0 deletions mindnlp/core/fft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""fft"""
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from ..configs import use_pyboost
from ..ops import narrow
from ..nn import functional as F

def rfft(input, n=None, dim=-1, norm="backward"):
if use_pyboost():
return ops.rfft(input, n, dim, norm)
if input.shape[dim] < n:
pad_inf = (0, n - input.shape[dim])
pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
input = F.pad(input, pad_dims)
else:
input = narrow(input, dim, 0, n)
_rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm)
return _rfft(input)

def irfft(input, n=None, dim=-1, norm="backward"):
if use_pyboost():
return ops.irfft(input, n, dim, norm)
if input.shape[dim] < n:
pad_inf = (0, n - input.shape[dim])
pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
input = pad(input, pad_dims)
else:
input = narrow(input, dim, 0, n)
_irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm)
return _irfft(input)

def fftn(input, s=None, dim=None, norm=None):
return ops.fftn(input, s, dim, norm)

def fft(input, s=None, dim=-1, norm=None):
return ops.fft(input, s, dim, norm)

__all__ = ['fft', 'fftn', 'irfft', 'rfft']
38 changes: 37 additions & 1 deletion mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,51 @@ def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0):
query, key, cos, sin, position_ids, cos_format
)

def custom_circular_pad(x, pad):
"""手动实现 torch.nn.functional.pad 的 circular 模式。

参数:
x: 输入张量,形状为 (B, C, D1, D2, ...)
pad: 填充参数,格式为 (left_N, right_N, left_{N-1}, right_{N-1}, ..., left_1, right_1)
表示从最后维度开始向前定义填充大小

返回:
循环填充后的张量
"""
ndim = x.dim()
n_pad_dims = len(pad) // 2
assert n_pad_dims <= ndim, "填充参数超过了张量的维度"

# 按从最后维度向前处理填充
for dim in range(ndim-1, ndim-1-n_pad_dims, -1):
# 当前维度的左右填充量
idx = 2 * (ndim - 1 - dim) # 在pad元组中的起始位置
left_pad = pad[idx]
right_pad = pad[idx + 1]

if left_pad == 0 and right_pad == 0:
continue # 跳过该维度

size = x.shape[dim] # 当前维度的原始长度
new_size = left_pad + size + right_pad

# 生成循环索引: (index - left_pad) mod size
index = (core.arange(new_size) - left_pad) % size
x = core.index_select(x, dim, index)

return x

def pad(input, pad, mode='constant', value=0.0):
if sum(pad) == 0:
return input
if isinstance(pad, tuple):
pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
if use_pyboost() and not ON_A1:
return mint.nn.functional.pad(input, pad, mode, value)
if mode in ['reflect', 'circular']:
if mode == 'reflect':
return ops.pad(input, pad, mode)
if mode == 'circular':
return custom_circular_pad(input, pad)
new_pad = ()
for idx, pad_v in enumerate(pad):
if pad_v < 0:
Expand Down
2 changes: 0 additions & 2 deletions mindnlp/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .reduction import *
from .other import *
from .tensor import *
# from .fft_op import *
# from .spectral import *
from ._inner import *
from .optim import *
Expand All @@ -27,7 +26,6 @@ def load_library(lib_path):
__all__.extend(blas.__all__)
__all__.extend(comparison.__all__)
__all__.extend(creation.__all__)
# __all__.extend(fft_op.__all__)
__all__.extend(pointwise.__all__)
__all__.extend(random.__all__)
__all__.extend(reduction.__all__)
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/core/ops/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
if device is None:
device= get_default_device()

if isinstance(size[0], (tuple, list)):
if len(size) > 0 and isinstance(size[0], (tuple, list)):
size = size[0]

if dtype is None:
Expand Down
Loading