From e5992a80fdef3fabd2ec47922c437a243faa5ce5 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Wed, 19 Nov 2025 16:53:06 +0800
Subject: [PATCH] refactor cpu setitem/getitem

---
 mindtorch/_apis/cpu.py      | 747 ++++++++++++++++++++++++++++++++++--
 mindtorch/_apis/meta.py     |  25 +-
 mindtorch/_apis/npu_910a.py |  27 +-
 mindtorch/_apis/npu_910b.py |   5 +-
 mindtorch/ops/creation.py   |   4 +-
 5 files changed, 765 insertions(+), 43 deletions(-)

diff --git a/mindtorch/_apis/cpu.py b/mindtorch/_apis/cpu.py
index 14ef3daab..2dd909c8d 100644
--- a/mindtorch/_apis/cpu.py
+++ b/mindtorch/_apis/cpu.py
@@ -21,18 +21,11 @@ def inplace_normal(input, mean, std, generator_):
     return input
 
 def select_ext_view(input, dim, index):
-    return legacy.select_view(input, index, dim)
+    return pyboost.select_ext_view_op(input, dim, index)
 
 def inplace_copy(input, value):
-    if value.shape != input.shape:
-        value = legacy.fill_v2(input.shape, value)
-    # inplace_copy(input, value)
-    # t2t_overwrite(input, value)
-    # legacy.assign(input, value)
-    if hasattr(input, '_base'):
-        input._base.assign_value(value)
-    input.assign_value(value)
-    return input
+    return pyboost.inplace_copy_op(input, value)
+
 
 def fill_scalar(size, fill_value, dtype):
     if dtype is None:
@@ -82,6 +75,7 @@ def inplace_zero(input):
     inplace_copy(input, legacy.zeros_like(input))
     return input
 
+py_abs = abs
 def abs(input):
     return legacy.abs(input)
 
@@ -91,6 +85,7 @@ def identity(input):
 def clone(input):
     return cast(legacy.mul(input, 1), input.dtype)
 
+py_max = max
 def max(input):
     return legacy.reduce_max(input, (), False)
 
@@ -153,21 +148,14 @@ def tile(input, dims):
 py_slice = slice
 def slice(self, dim, start, end, step):
-    ndim = self.ndim
-    begins = [0] * ndim
-    ends = [i for i in self.shape]
-    strides = [1] * ndim
-    begins[dim] = start
-    ends[dim] = end
-    strides[dim] = step
-    return legacy.strided_slice(self, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0)
+    return pyboost.slice_ext_view_op(self, dim, start, end, step)
 
 def pad_v3(input, new_pad, mode, value=None, contiguous=True):
     return legacy.pad_v3(input, new_pad, value, mode, contiguous)
 
 def cumsum(self, dim, dtype):
     if self.shape[dim] == 0:
-        return mindtorch.tensor([], dtype=self.dtype, device=self.device)
+        return mindspore.tensor([], dtype=self.dtype)
     return legacy.cum_sum(self, dim, False, False)
 
 def reduce_any(input, axis, keepdims):
@@ -187,7 +175,6 @@ def numpy_to_tensor_overwrite(np_array, tensor):
     return tensor
 
 def t2t_overwrite(input, other):
-    other._device = input.device
     ctypes.memmove(input.data_ptr(), other.data_ptr(), input.nbytes)
     return input
 
@@ -399,6 +386,7 @@ def masked_fill(input, mask, value):
         value = float(value)
     return legacy.masked_fill(input, mask, value)
 
+py_sum = sum
 def sum(input, dim, keepdim, dtype):
     if dim is None:
         dim = ()
@@ -512,7 +500,7 @@ def split_with_size(tensor, split_sizes, dim=0):
         end = start + chunk_size
         slice_obj = [py_slice(None)] * tensor.dim()
         slice_obj[dim] = py_slice(start, end)
-        chunks.append(tensor[tuple(slice_obj)])
+        chunks.append(getitem(tensor, tuple(slice_obj)))
         start = end
     return tuple(chunks)
 
@@ -594,9 +582,8 @@ def dropout(input, p, training=True):
     return legacy.dropout(input, 1-p, 0, 0)
 
 def split_tensor(input, split_size_or_sections, dim):
-    if isinstance(split_size_or_sections, int):
-        num = input.shape[dim] // split_size_or_sections
-        return legacy.split(input, dim, num)
+    num = input.shape[dim] // split_size_or_sections
+    return legacy.split(input, dim, num)
 
 def bmm(input_x, input_y):
     return legacy.batch_mat_mul(input_x, input_y, False, False)
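The view-op rewrites above (select_ext_view, slice) are expected to line up with ordinary basic-indexing semantics. A minimal sketch of the intended equivalences, using NumPy as a stand-in reference; the shapes and the view/no-copy behaviour shown here are illustrative assumptions, not taken from the patch's test suite:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)

    # select_ext_view(x, dim=1, index=2) should act like x[:, 2, :] (rank drops by one)
    assert x[:, 2, :].shape == (2, 4)

    # slice(x, dim=2, start=1, end=4, step=2) should act like x[:, :, 1:4:2], as a view
    assert x[:, :, 1:4:2].shape == (2, 3, 2)
    assert x[:, :, 1:4:2].base is not None   # a view, not a copy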
@@ -1300,7 +1287,7 @@ def pad(input, pad, mode='constant', value=None):
             input = narrow(input, dim, 0, input.shape[dim] + pad_v)
             pad_v = 0
         new_pad += (pad_v,)
-    if sum(new_pad) == 0:
+    if py_sum(new_pad) == 0:
         return input
     if mode == 'circular':
         return custom_circular_pad(input, pad)
@@ -1318,4 +1305,712 @@ def pad(input, pad, mode='constant', value=None):
     elif input.dtype in [mindtorch.int32, mindtorch.int64]:
         value = int(value)
 
-    return pad_v3(input, new_pad, mode, value)
\ No newline at end of file
+    return pad_v3(input, new_pad, mode, value)
+
+tensor_1d = mindspore.Tensor([0], dtype=mindtorch.int64)
+empty_tensor_1d = mindspore.Tensor(shape=(0,), dtype=mindtorch.int64)
+empty_tensor_9d = mindspore.Tensor(shape=(0,)*9, dtype=mindtorch.int64)
+
+def _do_select(self, dim: int, index: int, dim_index: int, self_shape: list):
+    """Call the select view operator."""
+    if not self_shape:
+        raise TypeError("Invalid index of a 0-dim tensor.")
+    dim_size = self_shape[dim]
+    if index >= dim_size or index < -dim_size:
+        raise IndexError(f"Index {index} is out of bounds for dimension {dim_index} with size {dim_size}")
+    index = index + dim_size if index < 0 else index
+    return select_ext_view(self, dim, index)
+
+
+def _do_slice(self, dim: int, index: py_slice, self_shape: list):
+    """Call the slice view operator."""
+    def _get_index(index, default):
+        if index is None:
+            return default
+        if mindtorch.is_tensor(index):
+            index = int(index)
+        return index
+
+    if not self_shape:
+        raise TypeError("Invalid index of a 0-dim tensor.")
+    step = _get_index(index.step, 1)
+    if step <= 0:
+        raise ValueError("slice step must be positive")
+    start = _get_index(index.start, 0)
+    end = _get_index(index.stop, self_shape[dim])
+    if start == 0 and end == self_shape[dim] and step == 1:
+        return self
+    return slice(self, dim, start, end, step)
+
+def _wrap_index_to_tuple(index):
+    """Wrap an index into a tuple."""
+    if isinstance(index, tuple):
+        return index
+    if isinstance(index, list):
+        if len(index) < 32 and any(isinstance(i, (mindtorch.Tensor, list, tuple, py_slice, type(None), type(...))) for i in index):
+            return tuple(index)
+    return (index,)
+
+
+def _count_indexed_dims(indexes):
+    """Count indexed dims."""
+    count = 0
+    for index in indexes:
+        if isinstance(index, mindtorch.Tensor):
+            if index.dtype == mindtorch.bool:
+                count += index.ndim
+            else:
+                count += 1
+        elif not isinstance(index, (type(None), type(...), bool)):
+            count += 1
+    return count
+
+def _record_tensor_index(index, remain_indexes, dim):
+    """Record indexes that remain to be consumed by aclnnIndex/aclnnIndexPut."""
+    if len(remain_indexes) > dim:
+        remain_indexes[dim] = index
+        return remain_indexes
+
+    while dim > len(remain_indexes):
+        # pad dims already consumed by views with full slices so they pass through untouched
+        remain_indexes.append(py_slice(None, None, None))
+
+    remain_indexes.append(index)
+    return remain_indexes
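The padding behaviour of _record_tensor_index is easiest to see in isolation. A pure-Python trace of the same loop (the string 'IDX' stands for any recorded tensor index and is purely illustrative):

    remain = []
    dim = 2

    # dims 0 and 1 were already consumed by views, so they are padded with
    # full slices before the tensor index is recorded at position 2
    while dim > len(remain):
        remain.append(slice(None, None, None))
    remain.append('IDX')

    assert remain == [slice(None, None, None), slice(None, None, None), 'IDX']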
+def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexed_dims, dim_index, remain_indexes,
+                                    prev_shape):
+    """Process one dim in a multi-dim index."""
+    if isinstance(index, bool):
+        result = expand_dims(prev_result, dim)
+        index_for_bool = tensor_1d if index else empty_tensor_1d
+        _record_tensor_index(index_for_bool, remain_indexes, dim)
+        prev_shape.insert(dim, 1)
+        dim += 1
+        return result, dim, remain_indexes, prev_shape
+    if isinstance(index, int):
+        result = _do_select(prev_result, dim, index, dim_index, prev_shape)
+        del prev_shape[dim]
+        return result, dim, remain_indexes, prev_shape
+    if isinstance(index, py_slice):
+        result = _do_slice(prev_result, dim, index, prev_shape)
+        # current dim in prev_shape will not be used later, ignore it
+        dim += 1
+        return result, dim, remain_indexes, prev_shape
+    if isinstance(index, type(...)):
+        dim += (orig_tensor.ndim - indexed_dims)
+        return prev_result, dim, remain_indexes, prev_shape
+    if index is None:
+        result = expand_dims(prev_result, dim)
+        prev_shape.insert(dim, 1)
+        dim += 1
+        return result, dim, remain_indexes, prev_shape
+    if isinstance(index, mindtorch.Tensor):
+        result = prev_result
+        if index.ndim == 0 and index.dtype in (mindtorch.int, mindtorch.long, mindtorch.short, mindtorch.bool):
+            if index.dtype in (mindtorch.int, mindtorch.long, mindtorch.short):
+                result = _do_select(prev_result, dim, index.item(), dim_index, prev_shape)
+                del prev_shape[dim]
+                return result, dim, remain_indexes, prev_shape
+            # process an index with Tensor bool type
+            result = expand_dims(prev_result, dim)
+            index_for_bool = tensor_1d if index else empty_tensor_1d
+            _record_tensor_index(index_for_bool, remain_indexes, dim)
+            prev_shape.insert(dim, 1)
+            dim += 1
+            return result, dim, remain_indexes, prev_shape
+        _record_tensor_index(index, remain_indexes, dim)
+        dim += 1
+        return result, dim, remain_indexes, prev_shape
+    raise IndexError(f"Invalid tensor index type {index}")
+
+
+def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
+    """Process indexes in a tuple."""
+    self_viewed = self
+    self_viewed_shape = list(self.shape)
+    dim = 0
+    # if ON_ORANGE_PI:
+    #     if all([isinstance(index, slice) for index in indexes]):
+    #         return getitem(self_viewed, tuple(indexes)), remain_indexes
+    for i, index in enumerate(indexes):
+        if isinstance(index, (list, tuple, np.ndarray)):
+            index_np = np.array(index) if isinstance(index, (list, tuple)) else index
+            if index_np.dtype in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
+                                  np.float16, np.float32, np.float64):
+                index = mindspore.tensor(index_np, dtype=mindtorch.int64)
+            elif index_np.dtype == np.bool_:
+                index = mindspore.tensor(index_np, dtype=mindtorch.int64)
+            else:
+                raise TypeError(f"Index {index} contains unsupported elements")
+        self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
+            self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
+    return self_viewed, remain_indexes
+
+
+def getitem(self, index):
+    """Handle tensor getitem."""
+    if isinstance(index, bool):
+        self_viewed = expand_dims(self, 0)
+        index_for_bool = tensor_1d if index else empty_tensor_1d
+        # hand the mask to the advanced-index path
+        return legacy_getitem(self_viewed, (index_for_bool,))
+    if isinstance(index, int):
+        return _do_select(self, 0, index, 0, list(self.shape))
+    if isinstance(index, py_slice):
+        result = _do_slice(self, 0, index, list(self.shape))
+        return result
+    if index is None:
+        return expand_dims(self, 0)
+    if isinstance(index, type(...)):
+        return self
+    indexes = _wrap_index_to_tuple(index)
+    indexed_dims = _count_indexed_dims(indexes)
+    if self.ndim < indexed_dims:
+        raise IndexError(f"too many indices for tensor with dimension size {self.ndim}")
+    remain_indexes = []
+    self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
+    if not remain_indexes:
+        return self_viewed
+
+    out = legacy_getitem(self_viewed, tuple(remain_indexes) if len(remain_indexes) > 1 else remain_indexes[0])
+    return out
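For orientation, how getitem dispatches each index form, with NumPy standing in for the expected semantics (shapes illustrative):

    import numpy as np

    t = np.arange(120).reshape(4, 5, 6)

    assert t[1].shape == (5, 6)             # int      -> _do_select view
    assert t[1:3].shape == (2, 5, 6)        # slice    -> _do_slice view
    assert t[None].shape == (1, 4, 5, 6)    # None     -> expand_dims
    assert t[...].shape == (4, 5, 6)        # Ellipsis -> returned as-is

    # a list/tensor index is recorded and resolved by legacy_getitem afterwards
    assert t[:, :, [0, 3]].shape == (4, 5, 2)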
+
+
+_SLICE_ERROR = (
+    'only integers, slices (`:`), ellipsis (`...`), '
+    'newaxis (`None`) and integer or boolean arrays are valid indices'
+)
+
+
+def _as_index(idx, need_scalar=True):
+    """Helper function to parse idx as an index.
+    """
+    if isinstance(idx, numbers.Integral):
+        return idx, True
+
+    if not isinstance(idx, mindtorch.Tensor):
+        idx = mindspore.tensor(idx, dtype=mindtorch.int64)
+
+    if idx.dtype == mindtorch.bool:
+        if idx.ndim > 1:
+            raise NotImplementedError('Need rank 1 for bool index %s' % idx)
+        idx = non_zero_ext(idx)
+        idx = idx.reshape(-1)
+
+    if need_scalar and idx.ndim not in (None, 0):
+        raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
+
+    if idx.ndim == 0:
+        return idx.item(), True
+
+    return idx, False
+
+def moveaxis(a, source, destination):
+    """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
+    if not source and not destination:
+        return a
+
+    if isinstance(source, int):
+        source = (source,)
+    if isinstance(destination, int):
+        destination = (destination,)
+    if len(source) != len(destination):
+        raise ValueError('The lengths of source and destination must be equal')
+
+    a_rank = a.ndim
+
+    def _correct_axis(axis, rank):
+        if axis < 0:
+            return axis + rank
+        return axis
+
+    source = tuple(_correct_axis(axis, a_rank) for axis in source)
+    destination = tuple(_correct_axis(axis, a_rank) for axis in destination)
+
+    if a.ndim is not None:
+        perm = [i for i in range(a_rank) if i not in source]
+        for dest, src in sorted(zip(destination, source)):
+            assert dest <= len(perm)
+            perm.insert(dest, src)
+    else:
+        r = range(0, a_rank, 1)
+
+        def _remove_indices(a, b):
+            """Remove indices (`b`) from `a`."""
+            items = unstack_view(
+                sort(stack(b), -1, False, False), 0
+            )
+
+            i = 0
+            result = []
+
+            for item in items:
+                result.append(a[i:item])
+                i = item + 1
+
+            result.append(a[i:])
+
+            return concat(result, 0)
+
+        minus_sources = _remove_indices(r, source)
+        minus_dest = _remove_indices(r, destination)
+
+        perm = scatter_nd(expand_dims(minus_dest, 1), minus_sources, [a_rank])
+        perm = tensor_scatter_update(perm, expand_dims(destination, 1), source)
+
+    a = mindtorch.permute(a, tuple(perm))
+
+    return a
+
+def cumprod(x, axis=0, exclusive=False, reverse=False):
+    x = np.array(x)
+    if reverse:
+        x = np.flip(x, axis=axis)
+
+    if exclusive:
+        shifted_x = np.ones_like(x)
+        if axis == 0:
+            shifted_x[1:] = x[:-1]
+        else:
+            shifted_x[:, 1:] = x[:, :-1]
+        result = np.cumprod(shifted_x, axis=axis)
+    else:
+        result = np.cumprod(x, axis=axis)
+
+    if reverse:
+        result = np.flip(result, axis=axis)
+
+    return result
+
+def broadcast_shapes(*shapes):
+    reversed_shapes = [list(reversed(shape)) for shape in shapes]
+
+    max_dim = py_max(len(shape) for shape in reversed_shapes)
+
+    result_shape = [1] * max_dim
+
+    for i in range(max_dim):
+        current_dim_size = 1
+        for shape in reversed_shapes:
+            if i < len(shape):
+                if shape[i] == 1:
+                    continue
+                if current_dim_size == 1:
+                    current_dim_size = shape[i]
+                elif current_dim_size != shape[i]:
+                    raise ValueError(f"Shapes {shapes} are not broadcastable.")
+        result_shape[i] = current_dim_size
+
+    return tuple(reversed(result_shape))
+
+def broadcast_tensors(*tensors):
+    target_shape = broadcast_shapes(*[t.shape for t in tensors])
+    broadcasted_tensors = [broadcast_to(t, target_shape) for t in tensors]
+    return broadcasted_tensors
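Two sanity checks that pin down the helper semantics _slice_helper relies on, worked out by hand and runnable against the definitions above:

    # exclusive+reverse cumprod yields row-major strides for a (4, 5, 6) tensor,
    # the scaling used when flattening a multi-dim advanced index
    assert list(cumprod([4, 5, 6], reverse=True, exclusive=True)) == [30, 6, 1]

    # broadcast_shapes follows the usual right-aligned broadcasting rule
    assert broadcast_shapes((4, 1, 6), (5, 1)) == (4, 5, 6)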
+
+
+def _slice_helper(tensor, slice_spec, do_update=False, updates=None):
+    """Helper function for __getitem__ and _with_index_update_helper.
+    """
+    begin, end, strides = [], [], []
+    new_axis_mask, shrink_axis_mask = 0, 0
+    begin_mask, end_mask = 0, 0
+    ellipsis_mask = 0
+    advanced_indices = []
+    shrink_indices = []
+    for index, s in enumerate(slice_spec):
+        if isinstance(s, py_slice):
+            if s.start is not None:
+                begin.append(s.start)
+            else:
+                begin.append(0)
+                begin_mask |= (1 << index)
+            if s.stop is not None:
+                stop = s.stop
+                if stop == -1:
+                    stop = tensor.shape[index] - 1
+                end.append(stop)
+            else:
+                end.append(0)
+                end_mask |= (1 << index)
+            if s.step is not None:
+                strides.append(s.step)
+            else:
+                strides.append(1)
+        elif s is Ellipsis:
+            begin.append(0)
+            end.append(0)
+            strides.append(1)
+            ellipsis_mask |= (1 << index)
+        elif s is None:
+            begin.append(0)
+            end.append(0)
+            strides.append(1)
+            new_axis_mask |= (1 << index)
+        else:
+            s, is_scalar = _as_index(s, False)
+            if is_scalar:
+                begin.append(s)
+                end.append(s + 1)
+                strides.append(1)
+                shrink_axis_mask |= (1 << index)
+                shrink_indices.append(index)
+            else:
+                begin.append(0)
+                end.append(0)
+                strides.append(1)
+                begin_mask |= (1 << index)
+                end_mask |= (1 << index)
+                advanced_indices.append((index, s, ellipsis_mask != 0))
+
+    if do_update and not advanced_indices:
+        if 0 in updates.shape:
+            return tensor
+        return strided_slice_update(
+            tensor,
+            begin,
+            end,
+            strides,
+            updates,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+            shrink_axis_mask=shrink_axis_mask,
+            new_axis_mask=new_axis_mask,
+            ellipsis_mask=ellipsis_mask,
+        )
+    else:
+        if updates is not None:
+            original_tensor = tensor
+        if new_axis_mask != 0:
+            tensor = strided_slice_manual(
+                tensor,
+                begin,
+                end,
+                strides,
+                begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
+            )
+        else:
+            tensor = strided_slice(
+                tensor,
+                begin,
+                end,
+                strides,
+                begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
+            )
+
+        if not advanced_indices:
+            return tensor
+        advanced_indices_map = {}
+        for index, data, had_ellipsis in advanced_indices:
+            if had_ellipsis:
+                num_shrink = len([x for x in shrink_indices if x > index])
+                dim = index - len(slice_spec) + num_shrink
+            else:
+                num_shrink = len([x for x in shrink_indices if x < index])
+                dim = index - num_shrink
+            advanced_indices_map[dim] = data
+        dims = sorted(advanced_indices_map.keys())
+        dims_contiguous = True
+        if len(dims) > 1:
+            if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
+                dims_contiguous = False
+            else:
+                for i in range(len(dims) - 1):
+                    if dims[i] + 1 != dims[i + 1]:
+                        dims_contiguous = False
+                        break
+        indices = [advanced_indices_map[x] for x in dims]
+        indices = broadcast_tensors(*indices)
+        stacked_indices = stack(indices, -1)
+        # Skip the contiguous-dims optimization for update because there is no
+        # tf.*scatter* op that supports the `axis` argument.
+        if not dims_contiguous or updates is not None:
+            if list(range(len(dims))) != dims:
+                tensor = moveaxis(tensor, dims, range(len(dims)))
+            tensor_shape_prefix = mindspore.tensor(tensor.shape[:len(dims)])
+            stacked_indices = select(
+                less(stacked_indices, 0),
+                add(stacked_indices, tensor_shape_prefix),
+                stacked_indices,
+            )
+            if updates is None:
+                return gather_nd(tensor, stacked_indices)
+            else:
+                # We only need to move-axis `updates` in the contiguous case because
+                # only in this case the result dimensions of advanced indexing are in
+                # the middle of `updates`. In the non-contiguous case, those dimensions
+                # are always at the front.
+                if dims_contiguous and updates.ndim > 1:
+                    batch_size = stacked_indices.ndim - 1
+                    batch_start = dims[0]
+                    if batch_start < 0:
+                        batch_start += len(dims) - batch_size
+
+                    def range_(start, length):
+                        return range(start, start + length)
+
+                    updates = moveaxis(
+                        updates, range_(batch_start, batch_size), range(batch_size)
+                    )
+                updates = updates.broadcast_to(stacked_indices.shape[:-1] + tensor.shape[stacked_indices.shape[-1]:])
+                tensor = tensor_scatter_update(tensor, stacked_indices, updates)
+                if list(range(len(dims))) != dims:
+                    tensor = moveaxis(tensor, range(len(dims)), dims)
+                return strided_slice_update(
+                    original_tensor,
+                    begin,
+                    end,
+                    strides,
+                    tensor,
+                    begin_mask=begin_mask,
+                    end_mask=end_mask,
+                    shrink_axis_mask=shrink_axis_mask,
+                    new_axis_mask=new_axis_mask,
+                    ellipsis_mask=ellipsis_mask,
+                )
+
+        # Note that gather_nd does not support gathering from inside the array.
+        # To avoid shuffling data back and forth, we transform the indices and
+        # do a gather instead.
+        rank = tensor.ndim
+        dims = [(x + rank if x < 0 else x) for x in dims]
+        shape_tensor = tensor.shape
+        dim_sizes = np.take_along_axis(np.array(shape_tensor), np.array(dims), axis=0)
+        if len(dims) == 1:
+            stacked_indices = indices[0]
+        stacked_indices = stacked_indices.to(mindtorch.int32)
+        stacked_indices = select(
+            less(stacked_indices, 0), add(stacked_indices, mindspore.tensor(dim_sizes, dtype=stacked_indices.dtype)), stacked_indices
+        )
+        axis = dims[0]
+        if len(dims) > 1:
+            index_scaling = cumprod(dim_sizes, reverse=True, exclusive=True)
+
+            def _tensordot(a, b):
+                # TODO(b/168657656): This function should be replaced by
+                # tensordot(axis=1) once MatMul has int32 XLA kernel.
+                b = broadcast_to(b, a.shape)
+                return sum(mul(a, b), -1, False, None)
+
+            stacked_indices = _tensordot(stacked_indices, mindspore.tensor(index_scaling))
+            flat_shape = shape_tensor[:axis] + (-1,) + shape_tensor[axis + len(dims):]
+            tensor = reshape(tensor, flat_shape)
+
+        return gather(tensor, stacked_indices, axis, 0)
+
+def _as_spec_tuple(slice_spec):
+    """Convert slice_spec to tuple."""
+    if isinstance(slice_spec, (list, tuple)):
+        is_index = True
+        for s in slice_spec:
+            if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
+                is_index = False
+                break
+        if not is_index:
+            return tuple(slice_spec)
+    return (slice_spec,)
+
+def legacy_getitem(self, slice_spec):
+    if (
+        isinstance(slice_spec, bool)
+        or (
+            isinstance(slice_spec, mindtorch.Tensor)
+            and slice_spec.dtype == mindtorch.bool
+        )
+    ):
+        if self.shape == slice_spec.shape:
+            return masked_select(self, slice_spec)
+        slice_spec = non_zero_ext(slice_spec)
+
+    if not isinstance(slice_spec, tuple):
+        slice_spec = _as_spec_tuple(slice_spec)
+
+    result_t = _slice_helper(self, slice_spec)
+    return result_t
+
+def setitem(a, slice_spec, updates):
+    """Implementation of ndarray._with_index_*."""
+    if isinstance(updates, numbers.Number):
+        updates = mindspore.tensor(updates)
+    if 0 in updates.shape:
+        return a
+    if (
+        isinstance(slice_spec, bool)
+        or (
+            isinstance(slice_spec, mindtorch.Tensor)
+            and slice_spec.dtype == mindtorch.bool
+        )
+    ):
+        if slice_spec.shape == a.shape and (isinstance(updates, numbers.Number) or updates.ndim == 0):
+            inplace_copy(a, masked_fill(a, slice_spec, updates))
+            return a
+        slice_spec = non_zero_ext(slice_spec)
+
+    if not isinstance(slice_spec, tuple):
+        slice_spec = _as_spec_tuple(slice_spec)
+
+    a_dtype = a.dtype
+    result_t = _slice_helper(a, slice_spec, True, updates)
+    return cast(result_t, a_dtype)
+
+
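End-to-end expectations for setitem, again with NumPy as the reference semantics. The first assignment exercises the purely strided path (strided_slice_update), the second the advanced-index scatter path:

    import numpy as np

    a = np.zeros((3, 4))
    a[1:, ::2] = 7.0        # setitem(a, (slice(1, None), slice(None, None, 2)), 7.0)
    assert a[1, 0] == 7.0 and a[0, 0] == 0.0

    b = np.arange(12).reshape(3, 4)
    b[[0, 2], 1] = -1       # tensor index -> tensor_scatter_update path
    assert b[0, 1] == -1 and b[2, 1] == -1 and b[1, 1] == 5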
+def strided_slice_manual(x, begin, end, strides, begin_mask=0, end_mask=0,
+                         ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
+
+    x_shape = x.shape
+    ndim = len(x_shape)
+
+    full_begin, full_end, full_strides = [], [], []
+    dim = 0  # current dimension of x
+    i = 0    # current position in begin/end
+
+    while dim < ndim:
+        # ellipsis_mask
+        if i < len(begin) and ((ellipsis_mask >> i) & 1):
+            remaining_dims = ndim - dim - (len(begin) - i - 1)
+            shrink_axis_mask = shrink_axis_mask << (remaining_dims - 1)
+            for _ in range(remaining_dims):
+                full_begin.append(0)
+                full_end.append(x_shape[dim])
+                full_strides.append(1)
+                dim += 1
+            i += 1
+            continue
+
+        # new_axis_mask
+        elif i < len(begin) and ((new_axis_mask >> i) & 1):
+            full_begin.append(0)
+            full_end.append(1)
+            full_strides.append(1)
+            i += 1
+            continue
+
+        else:
+            # fill in missing begin/end/strides entries
+            b = begin[i] if i < len(begin) else 0
+            e = end[i] if i < len(end) else x_shape[dim]
+            s = strides[i] if i < len(strides) else 1
+            if b < 0:
+                b += x_shape[dim]
+            if e == 0:
+                e += x_shape[dim]
+            if e < 0:
+                e += x_shape[dim]
+
+            # begin_mask / end_mask
+            if i < len(begin) and ((begin_mask >> i) & 1):
+                b = 0 if s > 0 else x_shape[dim] - 1
+            if i < len(end) and ((end_mask >> i) & 1):
+                e = x_shape[dim] if s > 0 else -1
+
+            full_begin.append(b)
+            full_end.append(e)
+            full_strides.append(s)
+
+            dim += 1
+            i += 1
+
+    # generate indices for the strided gather
+    ranges = [arange(b, e, s) for b, e, s in zip(full_begin, full_end, full_strides)]
+    mesh = meshgrid(*ranges, indexing='ij')
+    indices = stack(mesh, dim=-1)
+    indices = reshape(indices, [-1, ndim])
+
+    x_updated = gather_nd(x, indices)
+
+    # optionally squeeze shrunk axes
+    for i in range(ndim - 1, -1, -1):
+        if (shrink_axis_mask >> i) & 1:
+            x_updated = squeeze(x_updated, dim=i)
+
+    return x_updated
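The mask encoding consumed above is the one _slice_helper produces. A small self-contained trace for the spec x[1:, None, ..., 2] on a rank-3 x, mirroring only the mask-setting logic:

    spec = (slice(1, None), None, Ellipsis, 2)
    end_mask = new_axis_mask = ellipsis_mask = shrink_axis_mask = 0
    for i, s in enumerate(spec):
        if isinstance(s, slice) and s.stop is None:
            end_mask |= 1 << i          # open-ended stop
        elif s is None:
            new_axis_mask |= 1 << i     # inserted axis
        elif s is Ellipsis:
            ellipsis_mask |= 1 << i
        elif isinstance(s, int):
            shrink_axis_mask |= 1 << i  # scalar index drops the axis

    assert (end_mask, new_axis_mask, ellipsis_mask, shrink_axis_mask) == (0b0001, 0b0010, 0b0100, 0b1000)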
+
+def strided_slice_update(x, begin, end, strides, updates,
+                         begin_mask=0, end_mask=0,
+                         ellipsis_mask=0, new_axis_mask=0,
+                         shrink_axis_mask=0):
+    x_shape = x.shape
+    ndim = len(x_shape)
+
+    full_begin, full_end, full_strides = [], [], []
+    dim = 0  # current dimension of x
+    i = 0    # current position in begin/end
+
+    while dim < ndim:
+        # ellipsis_mask
+        if i < len(begin) and ((ellipsis_mask >> i) & 1):
+            remaining_dims = ndim - dim - (len(begin) - i - 1)
+            shrink_axis_mask = shrink_axis_mask << (remaining_dims - 1)
+            for _ in range(remaining_dims):
+                full_begin.append(0)
+                full_end.append(x_shape[dim])
+                full_strides.append(1)
+                dim += 1
+            i += 1
+            continue
+
+        # new_axis_mask
+        elif i < len(begin) and ((new_axis_mask >> i) & 1):
+            full_begin.append(0)
+            full_end.append(1)
+            full_strides.append(1)
+            i += 1
+            continue
+
+        else:
+            # fill in missing begin/end/strides entries
+            b = begin[i] if i < len(begin) else 0
+            e = end[i] if i < len(end) else x_shape[dim]
+            s = strides[i] if i < len(strides) else 1
+            if b < 0:
+                b %= x_shape[dim]
+            if e == 0:
+                e += x_shape[dim]
+            if e < 0:
+                e %= x_shape[dim]
+            # begin_mask / end_mask
+            if i < len(begin) and ((begin_mask >> i) & 1):
+                b = 0 if s > 0 else x_shape[dim] - 1
+            if i < len(end) and ((end_mask >> i) & 1):
+                e = x_shape[dim] if s > 0 else -1
+
+            full_begin.append(b)
+            full_end.append(e)
+            full_strides.append(s)
+
+            dim += 1
+            i += 1
+
+    # compute the target slice shape (honouring shrink_axis_mask)
+    target_shape = []
+
+    for d, (b, e, s) in enumerate(zip(full_begin, full_end, full_strides)):
+        if (shrink_axis_mask >> d) & 1:
+            continue
+        length = py_max(0, (py_abs(e - b) + py_abs(s) - 1) // py_abs(s))
+        target_shape.append(length)
+
+    # broadcast updates if scalar
+    updates = broadcast_to(updates, target_shape)
+
+    # generate indices for the scatter update
+    ranges = [arange(b, e, s, mindspore.int64) for b, e, s in zip(full_begin, full_end, full_strides)]
+    mesh = meshgrid(ranges, 'ij')
+    indices = stack(mesh, -1)
+    indices = reshape(indices, [-1, ndim])
+
+    # flatten updates
+    updates_flat = reshape(updates, [-1])
+    # if updates.shape[0] == 1 and updates.shape[0] != indices.shape[0]:
+    #     updates = updates.broadcast_to((indices.shape[0],))
+    # apply the scatter update
+    if x.dtype == mindtorch.bool:
+        x_updated = cast(scatter_nd_update(cast(x, mindspore.int32), indices, cast(updates_flat, mindspore.int32)), mindspore.bool_)
+    else:
+        x_updated = scatter_nd_update(x, indices, updates_flat)
+
+    assign(x, x_updated)
+    # optionally squeeze shrunk axes
+    # for i in range(ndim-1, -1, -1):
+    #     if (shrink_axis_mask >> i) & 1:
+    #         x_updated = mindtorch.squeeze(x_updated, dim=i)
+    return x_updated
\ No newline at end of file
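A reference check for the scatter formulation of strided_slice_update: the meshgrid over the resolved begin/end/strides must enumerate exactly the cells a direct strided assignment would touch. NumPy stands in for gather_nd/scatter_nd_update here:

    import numpy as np

    x = np.zeros((4, 3), dtype=np.int64)
    upd = np.arange(4).reshape(2, 2)
    x[0:4:2, 1:3] = upd                     # direct strided assignment

    rows, cols = np.meshgrid(np.arange(0, 4, 2), np.arange(1, 3), indexing='ij')
    indices = np.stack([rows, cols], -1).reshape(-1, 2)
    y = np.zeros((4, 3), dtype=np.int64)
    y[indices[:, 0], indices[:, 1]] = upd.reshape(-1)   # scatter formulation

    assert (x == y).all()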
diff --git a/mindtorch/_apis/meta.py b/mindtorch/_apis/meta.py
index d7b815cc5..3a1693f36 100644
--- a/mindtorch/_apis/meta.py
+++ b/mindtorch/_apis/meta.py
@@ -185,7 +185,14 @@ def log(input):
 __all__.append('log')
 
 def mul(input, other):
-    out = Tensor_(init='none', shape=input.shape, dtype=input.dtype)
+    if isinstance(input, mindtorch.Tensor):
+        shape = input.shape
+        dtype = input.dtype
+    else:
+        shape = other.shape
+        dtype = other.dtype
+
+    out = Tensor_(init='none', shape=shape, dtype=dtype)
     return mindtorch.Tensor(out)
 __all__.append('mul')
 
@@ -393,4 +400,18 @@ def select(condition, input, other):
     return input
 
 def logical_not(input):
-    return input
\ No newline at end of file
+    return input
+
+def pad(input, pad, mode='constant', value=None):
+    size = input.shape
+    if len(pad) == 2:
+        new_size = size[:-1] + (size[-1] + sum(pad),)
+    elif len(pad) == 4:
+        new_size = size[:-2] + (size[-2] + pad[2] + pad[3], size[-1] + pad[0] + pad[1])
+    elif len(pad) == 6:
+        new_size = size[:-3] + (size[-3] + pad[4] + pad[5], size[-2] + pad[2] + pad[3], size[-1] + pad[0] + pad[1])
+    else:
+        raise ValueError('pad size must be 2, 4 or 6')
+
+    out = Tensor_(init='none', shape=new_size, dtype=input.dtype)
+    return mindtorch.Tensor(out)
\ No newline at end of file
diff --git a/mindtorch/_apis/npu_910a.py b/mindtorch/_apis/npu_910a.py
index a07b9fc00..6cbf301fb 100644
--- a/mindtorch/_apis/npu_910a.py
+++ b/mindtorch/_apis/npu_910a.py
@@ -433,7 +433,7 @@ def eq(input, other):
         return pyboost.equal_op(input, other)
     return legacy.equal(input, other)
 
-
+py_sum = sum
 def sum(input, dim, keepdim, dtype):
     """
     Returns the sum of elements over a specified dimension.
@@ -765,6 +765,11 @@ def arange(start, end, step, dtype):
         out = cast(out, dtype)
     return out
 
+def full_like(input, fill_value, dtype=None):
+    if fill_value == -math.inf:
+        fill_value = mindtorch.finfo(input.dtype).min
+    return pyboost.full_like_op(input, fill_value, dtype)
+
 def fill_scalar(input, value, dtype):
     if ENABLE_PYBOOST:
         return pyboost.fill_scalar_op(input, value, dtype)
@@ -1090,7 +1095,7 @@ def sin(input):
         return pyboost.sin_op(input)
     return legacy.sin(input)
 
-def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, training=False, momentum=0.1, epsilon=1e-5):
+def batch_norm(input, weight, bias, running_mean=None, running_var=None, training=False, momentum=0.1, epsilon=1e-5):
     if running_mean is None:
         running_mean = ones(input.shape[1], dtype=input.dtype)
     if running_var is None:
@@ -1100,8 +1105,8 @@ def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, traini
     if bias is None:
         bias = zeros(input.shape[1], dtype=input.dtype)
     if ENABLE_PYBOOST:
-        return pyboost.batch_norm_ext_op(input, weight, bias, running_mean, runnning_var, training, momentum, epsilon)
-    return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
+        return pyboost.batch_norm_ext_op(input, weight, bias, running_mean, running_var, training, momentum, epsilon)
+    return legacy.batch_norm(input, weight, bias, running_mean, running_var, training, epsilon, momentum, 'NCHW')
 
 def silu(input):
     if ENABLE_PYBOOST:
@@ -2239,7 +2244,7 @@ def pad(input, pad, mode='constant', value=None):
             input = narrow(input, dim, 0, input.shape[dim] + pad_v)
             pad_v = 0
         new_pad += (pad_v,)
-    if sum(new_pad) == 0:
+    if py_sum(new_pad) == 0:
         return input
     if mode == 'circular':
         return custom_circular_pad(input, pad)
@@ -2250,11 +2255,11 @@ def pad(input, pad, mode='constant', value=None):
         if mode == "replicate":
             mode = "edge"
        return pad_v3(input, new_pad, mode)
-    if input.dtype.is_floating_point:
-        value = float(value)
-    elif input.dtype == mindtorch.bool:
-        value = bool(value)
-    elif input.dtype in [mindtorch.int32, mindtorch.int64]:
-        value = int(value)
+    # if input.dtype.is_floating_point:
+    #     value = float(value)
+    # elif input.dtype == mindtorch.bool:
+    #     value = bool(value)
+    # elif input.dtype in [mindtorch.int32, mindtorch.int64]:
+    #     value = int(value)
 
     return pad_v3(input, new_pad, mode, value)
\ No newline at end of file
diff --git a/mindtorch/_apis/npu_910b.py b/mindtorch/_apis/npu_910b.py
index c0a6405b0..8670438ce 100644
--- a/mindtorch/_apis/npu_910b.py
+++ b/mindtorch/_apis/npu_910b.py
@@ -2272,4 +2272,7 @@ def pad(input, pad, mode='constant', value=None):
         out = _replication_pad(input, pad)
     else:
         raise ValueError(f"Pad filling mode must be 'constant' 'circular' 'reflect' or 'replicate'.")
-    return out
\ No newline at end of file
+    return out
+
+def full_like(input, fill_value, dtype=None):
+    return pyboost.full_like_op(input, fill_value, dtype)
diff --git a/mindtorch/ops/creation.py b/mindtorch/ops/creation.py
index 4cbe5d4d3..bc50b3874 100644
--- a/mindtorch/ops/creation.py
+++ b/mindtorch/ops/creation.py
@@ -200,9 +200,7 @@ def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, re
 
 # full_like
 def full_like(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None):
-    if dtype is None:
-        dtype = input.dtype
-    return full(input.shape, fill_value, dtype=dtype, layout=layout, device=input.device)
+    return execute('full_like', input, fill_value, dtype=dtype)
 
 
 # quantize_per_tensor
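The new meta.pad follows torch.nn.functional.pad's (left, right, top, bottom, front, back) ordering when inferring the output shape; a worked check of the 4-element case:

    size = (2, 3, 4)
    pad = (1, 2, 3, 4)   # left=1, right=2 on the last dim; top=3, bottom=4 on the one before
    new_size = size[:-2] + (size[-2] + pad[2] + pad[3], size[-1] + pad[0] + pad[1])
    assert new_size == (2, 10, 7)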