Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion mindtorch/_apis/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,8 @@ def bernoulli(input, generator):
return legacy.bernoulli(input, seed, offset)

def arange(start, end, step, dtype):
end = type(start)(end)
step = type(start)(step)
if dtype is not None:
return cast(legacy.range(start, end, step, 1000000), dtype)
return legacy.range(start, end, step, 1000000)
Expand Down Expand Up @@ -1228,4 +1230,16 @@ def as_strided(self, size, stride, storage_offset=None):
else:
input_indices = mindspore.tensor(index.astype(np.int32))
out = gather(reshape(self, (-1,)), input_indices, 0, 0)
return out
return out

def fft(input, n=None, dim=-1, norm="backward"):
if norm is None:
norm="backward"
if input.shape[dim] < n:
pad_inf = (0, n - input.shape[dim])
pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
input = pad_v3(input, pad_dims, 'constant', 0, True)
else:
input = narrow(input, dim, 0, n)
return legacy.fft_with_size(input, input.ndim, False, False, norm, True, ())

2 changes: 1 addition & 1 deletion mindtorch/_dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .decorators import (
allow_in_graph,
# assume_constant_result,
# disable,
disable,
# disallow_in_graph,
# dont_skip_tracing,
# forbid_in_graph,
Expand Down
25 changes: 23 additions & 2 deletions mindtorch/fft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""fft"""
import mindtorch
from ..executor import execute


Expand Down Expand Up @@ -29,10 +30,30 @@ def irfft(input, n=None, dim=-1, norm="backward"):
# return _irfft(input)

def fftn(input, s=None, dim=None, norm=None):
return execute('fftn', input, s, dim, norm)
if input.device.type == 'npu':
return execute('fftn', input, s, dim, norm)
if dim is None:
dim = tuple(range(input.dim()))
if s is None:
s = [input.size(d) for d in dim]

# 确保s和dim是序列且长度相同
if not isinstance(s, (list, tuple)):
s = (s,)
if not isinstance(dim, (list, tuple)):
dim = (dim,)
if len(s) != len(dim):
raise ValueError("参数 's' 和 'dim' 必须具有相同的长度。")

output = input.to(mindtorch.complex64) if input.is_floating_point() else input.clone()

# 逐个维度进行FFT
for d, n in zip(dim, s):
output = fft(output, s=n, dim=d, norm=norm)
return output

def fft(input, s=None, dim=-1, norm=None):
return ops.fft(input, s, dim, norm)
return execute('fft', input, s, dim, norm)

def fftshift(x, dim=None):
return ops.fftshift(x, dim)
Expand Down
14 changes: 9 additions & 5 deletions mindtorch/ops/_inner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""inner ops"""
import mindtorch
from mindtorch.executor import execute

def cast(input, dtype):
Expand All @@ -17,11 +18,14 @@ def all_finite(inputs):
return execute('all_finite', inputs)

def custom_masked_scatter_vec(input, mask, source):
output = input.clone()
if mask.sum() == 0:
return output
output[mask] = source.flatten() # 关键的一行:向量化赋值
return output
indices = mindtorch.nonzero(mask)

# 如果 src 是 1D,按顺序取值
updates = source.reshape(-1)[:indices.shape[0]]

# 更新 tensor
out = mindtorch.scatter_nd_update(input, indices, updates)
return out

def masked_scatter(input, mask, source):
if input.device.type == 'cuda':
Expand Down
10 changes: 5 additions & 5 deletions mindtorch/ops/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def infer_dtype(dtypes):
def cat(tensors, dim=0, **kwargs):
dim = kwargs.pop('axis', dim)
dtype = infer_dtype([t.dtype for t in tensors])
tensors = [t.to(dtype) for t in tensors if 0 not in t.shape]
tensors = [t.to(dtype) for t in tensors if t.shape != (0,)]
return execute("concat", tensors, dim)


Expand Down Expand Up @@ -1108,12 +1108,11 @@ def strided_slice_update(x, begin, end, strides, updates,
e = end[i] if i < len(end) else x_shape[dim]
s = strides[i] if i < len(strides) else 1
if b < 0:
b += x_shape[dim]
b %= x_shape[dim]
if e == 0:
e += x_shape[dim]
if e < 0:
e += x_shape[dim]

e %= x_shape[dim]
# begin_mask / end_mask
if i < len(begin) and ((begin_mask >> i) & 1):
b = 0 if s > 0 else x_shape[dim]-1
Expand Down Expand Up @@ -1229,5 +1228,6 @@ def setitem_np(input, slice, value):
'setitem',
'getitem_np',
'setitem_np',
'split_with_sizes'
'split_with_sizes',
'scatter_nd_update'
]
3 changes: 2 additions & 1 deletion mindtorch/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,8 @@ def efficient_repeat_interleave(input_tensor, repeats, dim=None):
current_pos = 0
for i in range(dim_size):
repeat_count = repeats_tensor[i].item()
index[current_pos:current_pos + repeat_count] = i
if repeat_count > 0:
index[current_pos:current_pos + repeat_count] = i
current_pos += repeat_count

output = input_tensor.index_select(dim, index)
Expand Down