diff --git a/mindnlp/core/_C/_nn.py b/mindnlp/core/_C/_nn.py
index a868c046e..3bf10046c 100644
--- a/mindnlp/core/_C/_nn.py
+++ b/mindnlp/core/_C/_nn.py
@@ -21,7 +21,7 @@ def _parse_to(*args, **kwargs):
         elif isinstance(args[0], core.device): # device only
             device = args[0]
             dtype = None
-        elif isinstance(args[0], str):
+        elif isinstance(args[0], (str, int)):
            device = device_(args[0])
            dtype = None
        else:
diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index eafcd7230..a75b3e39a 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -52,6 +52,7 @@ from .serialization import load, save
 from ._bind import get_default_dtype, set_default_dtype, get_default_device
 from .amp import autocast, GradScaler
+from .func import vmap
 
 from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
     return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils
 
diff --git a/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py b/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
index 08a078b7c..6be149291 100644
--- a/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
+++ b/mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
@@ -1 +1,6 @@
-class TransformGetItemToIndex: pass
+class TransformGetItemToIndex:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
\ No newline at end of file
diff --git a/mindnlp/core/_functorch/__init__.py b/mindnlp/core/_functorch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/core/_functorch/apis.py b/mindnlp/core/_functorch/apis.py
new file mode 100644
index 000000000..fa46cd67f
--- /dev/null
+++ b/mindnlp/core/_functorch/apis.py
@@ -0,0 +1,51 @@
+from typing import Callable
+from mindnlp import core
+
+def vmap(
+    func: Callable,
+    in_dims = 0,
+    out_dims = 0,
+    randomness: str = "error",
+    *,
+    chunk_size=None,
+) -> Callable:
+    def batched_func(*args):
+        # Normalize in_dims into a tuple with one entry per argument
+        if not isinstance(in_dims, tuple):
+            in_dims_tuple = (in_dims,) * len(args)
+        else:
+            in_dims_tuple = in_dims
+
+        # Check that all batched inputs share the same batch size
+        batch_sizes = set()
+        for i, (arg, dim) in enumerate(zip(args, in_dims_tuple)):
+            if dim is not None:
+                batch_sizes.add(arg.shape[dim])
+
+        if len(batch_sizes) > 1:
+            raise ValueError(f"Inconsistent batch sizes: {batch_sizes}")
+        batch_size = next(iter(batch_sizes)) if batch_sizes else 1
+
+        # Collect the per-sample results
+        results = []
+        for b in range(batch_size):
+            # Build the inputs for the current batch element
+            single_args = []
+            for arg, dim in zip(args, in_dims_tuple):
+                if dim is None:
+                    single_args.append(arg)
+                else:
+                    # Slice out the sample at index b along the batched dim
+                    slices = [slice(None)] * arg.ndim
+                    slices[dim] = b
+                    single_args.append(arg[tuple(slices)])
+
+            # Call the original (per-sample) function
+            result = func(*single_args)
+            results.append(result)
+
+        # Stack the results and place the batch dimension at out_dims
+        stacked = core.stack(results, dim=out_dims)
+        return stacked
+
+    return batched_func
\ No newline at end of file
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 02ffdbbf1..155d75041 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -607,6 +607,8 @@ def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=N
         if isinstance(s, Tensor):
             s = s.item()
         new_size += (s,)
+    if new_size == new_size:
+        new_size = (new_size,)
     return ops.ones(*new_size, dtype=dtype if dtype is not None else self.dtype)
 
 Tensor.new_ones = new_ones
@@ -792,6 +794,9 @@ def tobytes(self):
 Tensor.copy_ = ops.inplace_copy
 StubTensor.copy_ = ops.inplace_copy
 
+Tensor.index_add_ = ops.inplace_index_add
+StubTensor.index_add_ = ops.inplace_index_add
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
\ No newline at end of file
diff --git a/mindnlp/core/func/__init__.py b/mindnlp/core/func/__init__.py
new file mode 100644
index 000000000..7c1fe9b52
--- /dev/null
+++ b/mindnlp/core/func/__init__.py
@@ -0,0 +1 @@
+from .._functorch.apis import vmap
\ No newline at end of file
diff --git a/mindnlp/core/nn/attention/flex_attention.py b/mindnlp/core/nn/attention/flex_attention.py
index c2db2c1e4..26d6a905a 100644
--- a/mindnlp/core/nn/attention/flex_attention.py
+++ b/mindnlp/core/nn/attention/flex_attention.py
@@ -1,4 +1,5 @@
-BlockMask = None
+from mindnlp import core
+BlockMask = core.Tensor
 flex_attention = None
 create_block_mask = None
 _DEFAULT_SPARSE_BLOCK_SIZE = None
\ No newline at end of file
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index fa53de7e1..02f15d0a1 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -7,7 +7,7 @@
 from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \
     inplace_scatter_src_op, inplace_fill_tensor_op, inplace_fill_scalar_op, inplace_zero_op, inplace_uniform_op, \
     inplace_masked_fill_scalar_op, inplace_masked_fill_tensor_op, inplace_random_op, inplace_clamp_scalar_op, \
-    inplace_clamp_tensor_op, inplace_copy_op
+    inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op
 from mindnlp import core
 
 from ..configs import use_pyboost
@@ -96,8 +96,12 @@ def inplace_index_copy(input, dim, index, tensor):
     return input
 
 def inplace_index_add(input, dim, index, source):
-    _inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim)
-    return _inplace(input, index, source)
+    if input.device == 'npu':
+        inplace_index_add_op(input, dim, index, source)
+    else:
+        _inplace = _get_cache_prim(ops.IndexAdd)(dim)
+        input.data = _inplace(input, index.int(), source)
+    return input
 
 has_squeeze = hasattr(mindspore.mint, 'squeeze')
 def inplace_squeeze(input, *dim, **kwargs):
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index 57189f5bb..e7b280401 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -202,7 +202,10 @@ def diag(input, diagonal=0):
 
 # diagonal
 
 # diff
-
+def diff(input, n=1, dim=-1, prepend=None, append=None):
+    if use_pyboost():
+        return mindspore.mint.diff(input, n, dim, prepend, append)
+    return ops.diff(input, n, dim, prepend, append)
 
 # einsum
@@ -1081,6 +1084,7 @@ def cosine_similarity(*args, **kwargs):
     "cumsum",
     "cumprod",
     "diag",
+    "diff",
     "dim_list_to_bitset",
     "einsum",
     "einsum_label_to_index",
diff --git a/tests/run_test.py b/tests/run_test.py
index f4b6e6d93..6891cf5da 100644
--- a/tests/run_test.py
+++ b/tests/run_test.py
@@ -33,7 +33,8 @@ def run_tests():
        "and not torchscript " \
        "and not torch_fx " \
        "and not test_wrong_device_map " \
-        "and not test_layerwise_casting"
+        "and not test_layerwise_casting " \
+        "and not test_flex_attention"
 
    pytest_args.extend(["--ignore-glob=test_modeling_flax_*.py"])
    pytest_args.extend(['-k', skip_ut])
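
Usage sketch (not part of the patch): the new core.func.vmap batches a per-sample function by looping over the batch dimension and stacking the results, so it is a Python-level compatibility shim rather than a fused transform. The example below assumes core.randn and elementwise tensor arithmetic behave like their torch counterparts in mindnlp.core.

from mindnlp import core
from mindnlp.core.func import vmap

def scale_and_shift(x, w):
    # Per-sample function: expects a single 1-D tensor x.
    return x * w + 1.0

x = core.randn(8, 5)   # batch of 8 samples, each of length 5
w = core.randn(5)      # shared weights, not batched (in_dims=None)

batched = vmap(scale_and_shift, in_dims=(0, None), out_dims=0)
out = batched(x, w)    # shape (8, 5): scale_and_shift applied row by row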
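Similarly, a hedged sketch of the intended call shapes for the other two additions, Tensor.index_add_ and diff. It assumes the core.zeros/core.ones/core.tensor factories exist with torch-like semantics and that diff is re-exported at the top level via the ops __all__ list updated above.

from mindnlp import core

t = core.zeros(5, 3)
index = core.tensor([0, 2, 4])
src = core.ones(3, 3)
t.index_add_(0, index, src)   # adds the rows of src to rows 0, 2 and 4 of t in place

x = core.tensor([1.0, 3.0, 6.0, 10.0])
dx = core.diff(x)             # first-order difference along the last dim: [2., 3., 4.]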