From a4aa5a08502b3bfad25d06145ddcbcd56e608f09 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Sun, 28 Sep 2025 08:02:03 +0000
Subject: [PATCH] fix e-g class on GPU

---
 mindtorch/_apis/gpu.py        | 20 +++++++++++++++++++-
 mindtorch/_dynamo/__init__.py |  2 +-
 mindtorch/fft/__init__.py     | 25 +++++++++++++++++++++++--
 mindtorch/ops/_inner.py       | 14 +++++++++-----
 mindtorch/ops/array.py        | 10 +++++-----
 mindtorch/ops/other.py        |  3 ++-
 6 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/mindtorch/_apis/gpu.py b/mindtorch/_apis/gpu.py
index 92c3e6764..f940d011b 100644
--- a/mindtorch/_apis/gpu.py
+++ b/mindtorch/_apis/gpu.py
@@ -1133,6 +1133,8 @@ def bernoulli(input, generator):
     return legacy.bernoulli(input, seed, offset)
 
 def arange(start, end, step, dtype):
+    end = type(start)(end)
+    step = type(start)(step)
     if dtype is not None:
         return cast(legacy.range(start, end, step, 1000000), dtype)
     return legacy.range(start, end, step, 1000000)
@@ -1228,4 +1230,20 @@ def as_strided(self, size, stride, storage_offset=None):
     else:
         input_indices = mindspore.tensor(index.astype(np.int32))
     out = gather(reshape(self, (-1,)), input_indices, 0, 0)
-    return out
\ No newline at end of file
+    return out
+
+def fft(input, n=None, dim=-1, norm="backward"):
+    if norm is None:
+        norm = "backward"
+    if dim < 0:
+        dim += input.ndim
+    if n is None:
+        n = input.shape[dim]
+    if input.shape[dim] < n:
+        pad_inf = (0, n - input.shape[dim])
+        pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
+        input = pad_v3(input, pad_dims, 'constant', 0, True)
+    else:
+        input = narrow(input, dim, 0, n)
+    return legacy.fft_with_size(input, input.ndim, False, False, norm, True, ())
+
diff --git a/mindtorch/_dynamo/__init__.py b/mindtorch/_dynamo/__init__.py
index 62e799a95..5feb456ad 100644
--- a/mindtorch/_dynamo/__init__.py
+++ b/mindtorch/_dynamo/__init__.py
@@ -1,7 +1,7 @@
 from .decorators import (
     allow_in_graph,
     # assume_constant_result,
-    # disable,
+    disable,
     # disallow_in_graph,
     # dont_skip_tracing,
     # forbid_in_graph,
diff --git a/mindtorch/fft/__init__.py b/mindtorch/fft/__init__.py
index 8df9b0e15..7810159da 100644
--- a/mindtorch/fft/__init__.py
+++ b/mindtorch/fft/__init__.py
@@ -1,4 +1,5 @@
 """fft"""
+import mindtorch
 from ..executor import execute
 
 
@@ -29,10 +30,30 @@ def irfft(input, n=None, dim=-1, norm="backward"):
 #     return _irfft(input)
 
 def fftn(input, s=None, dim=None, norm=None):
-    return execute('fftn', input, s, dim, norm)
+    if input.device.type == 'npu':
+        return execute('fftn', input, s, dim, norm)
+    if dim is None:
+        dim = tuple(range(input.dim()))
+    if s is None:
+        s = [input.size(d) for d in dim]
+
+    # make sure s and dim are sequences of the same length
+    if not isinstance(s, (list, tuple)):
+        s = (s,)
+    if not isinstance(dim, (list, tuple)):
+        dim = (dim,)
+    if len(s) != len(dim):
+        raise ValueError("Arguments 's' and 'dim' must have the same length.")
+
+    output = input.to(mindtorch.complex64) if input.is_floating_point() else input.clone()
+
+    # apply the FFT one dimension at a time
+    for d, n in zip(dim, s):
+        output = fft(output, s=n, dim=d, norm=norm)
+    return output
 
 def fft(input, s=None, dim=-1, norm=None):
-    return ops.fft(input, s, dim, norm)
+    return execute('fft', input, s, dim, norm)
 
 def fftshift(x, dim=None):
     return ops.fftshift(x, dim)
diff --git a/mindtorch/ops/_inner.py b/mindtorch/ops/_inner.py
index 1cf80e84d..4fa5a1d6a 100644
--- a/mindtorch/ops/_inner.py
+++ b/mindtorch/ops/_inner.py
@@ -1,4 +1,5 @@
 """inner ops"""
+import mindtorch
 from mindtorch.executor import execute
 
 def cast(input, dtype):
@@ -17,11 +18,14 @@ def all_finite(inputs):
     return execute('all_finite', inputs)
 
 def custom_masked_scatter_vec(input, mask, source):
-    output = input.clone()
-    if mask.sum() == 0:
-        return output
-    output[mask] = source.flatten()  # the key line: vectorized assignment
-    return output
+    indices = mindtorch.nonzero(mask)
+
+    # if source is 1-D, take its values in order
+    updates = source.reshape(-1)[:indices.shape[0]]
+
+    # write the updates back into the tensor
+    out = mindtorch.scatter_nd_update(input, indices, updates)
+    return out
 
 def masked_scatter(input, mask, source):
     if input.device.type == 'cuda':
diff --git a/mindtorch/ops/array.py b/mindtorch/ops/array.py
index 4cc10d553..1173485d2 100644
--- a/mindtorch/ops/array.py
+++ b/mindtorch/ops/array.py
@@ -36,7 +36,7 @@ def infer_dtype(dtypes):
 def cat(tensors, dim=0, **kwargs):
     dim = kwargs.pop('axis', dim)
     dtype = infer_dtype([t.dtype for t in tensors])
-    tensors = [t.to(dtype) for t in tensors if 0 not in t.shape]
+    tensors = [t.to(dtype) for t in tensors if t.shape != (0,)]
     return execute("concat", tensors, dim)
 
 
@@ -1108,12 +1108,11 @@ def strided_slice_update(x, begin, end, strides, updates,
         e = end[i] if i < len(end) else x_shape[dim]
         s = strides[i] if i < len(strides) else 1
         if b < 0:
-            b += x_shape[dim]
+            b %= x_shape[dim]
         if e == 0:
             e += x_shape[dim]
         if e < 0:
-            e += x_shape[dim]
-
+            e %= x_shape[dim]
         # begin_mask / end_mask
         if i < len(begin) and ((begin_mask >> i) & 1):
             b = 0 if s > 0 else x_shape[dim]-1
@@ -1229,5 +1228,6 @@ def setitem_np(input, slice, value):
     'setitem',
     'getitem_np',
     'setitem_np',
-    'split_with_sizes'
+    'split_with_sizes',
+    'scatter_nd_update'
 ]
diff --git a/mindtorch/ops/other.py b/mindtorch/ops/other.py
index b0c20fd2d..dfcd722de 100644
--- a/mindtorch/ops/other.py
+++ b/mindtorch/ops/other.py
@@ -843,7 +843,8 @@ def efficient_repeat_interleave(input_tensor, repeats, dim=None):
     current_pos = 0
     for i in range(dim_size):
         repeat_count = repeats_tensor[i].item()
-        index[current_pos:current_pos + repeat_count] = i
+        if repeat_count > 0:
+            index[current_pos:current_pos + repeat_count] = i
         current_pos += repeat_count
 
     output = input_tensor.index_select(dim, index)
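
For reference, a standalone sketch that is not part of the diff above: the per-axis loop added to fftn relies on the separability of the multidimensional DFT, i.e. an N-D transform equals successive 1-D FFTs applied along each transformed axis. The sketch assumes only NumPy and checks that equivalence numerically.

    import numpy as np

    x = np.random.rand(4, 6, 8)

    # Apply 1-D FFTs one axis at a time, mirroring the new fftn fallback loop.
    out = x.astype(np.complex128)
    for axis in range(x.ndim):
        out = np.fft.fft(out, n=x.shape[axis], axis=axis)

    # The chained 1-D transforms match a direct N-D FFT.
    assert np.allclose(out, np.fft.fftn(x))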
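
Similarly illustrative only: the rewritten custom_masked_scatter_vec gathers the coordinates of the True mask entries with nonzero and scatters the leading source values into those positions, consuming source values in row-major order. A minimal NumPy sketch of the same idea, using a hypothetical helper name:

    import numpy as np

    def masked_scatter_sketch(inp, mask, source):
        # Coordinates of the True entries, in row-major order.
        indices = np.argwhere(mask)
        # Take exactly as many source values as there are masked positions.
        updates = source.reshape(-1)[:indices.shape[0]]
        out = inp.copy()
        out[tuple(indices.T)] = updates
        return out

    inp = np.zeros((2, 3))
    mask = np.array([[True, False, True], [False, True, False]])
    src = np.array([1.0, 2.0, 3.0, 4.0])
    print(masked_scatter_sketch(inp, mask, src))
    # [[1. 0. 2.]
    #  [0. 3. 0.]]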