2 changes: 1 addition & 1 deletion mindnlp/core/_C/_nn.py
@@ -21,7 +21,7 @@ def _parse_to(*args, **kwargs):
elif isinstance(args[0], core.device): # device only
device = args[0]
dtype = None
elif isinstance(args[0], str):
elif isinstance(args[0], (str, int)):
device = device_(args[0])
dtype = None
else:
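With the `(str, int)` change above, `_parse_to` also accepts a bare integer device index. A minimal sketch of the behaviour this enables, assuming `_parse_to` backs the torch-style `Tensor.to` that mindnlp.core mirrors (the factory call is illustrative):

```python
from mindnlp.core import ops

x = ops.ones(2, 3)
y = x.to(0)       # an int is now parsed as a device index instead of raising
z = x.to("cpu")   # the string form keeps working as before
```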
1 change: 1 addition & 0 deletions mindnlp/core/__init__.py
@@ -52,6 +52,7 @@
from .serialization import load, save
from ._bind import get_default_dtype, set_default_dtype, get_default_device
from .amp import autocast, GradScaler
from .func import vmap

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils
7 changes: 6 additions & 1 deletion mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py
@@ -1 +1,6 @@
class TransformGetItemToIndex: pass
class TransformGetItemToIndex:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
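The stub is now a usable no-op context manager rather than a bare placeholder, presumably so downstream code that enters it as a context manager (as the torch class of the same name is used in flex-attention paths) no longer fails. A trivial check of the new protocol:

```python
from mindnlp.core._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

# Entering and exiting the stub is a no-op instead of an AttributeError.
with TransformGetItemToIndex():
    pass
```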
Empty file.
51 changes: 51 additions & 0 deletions mindnlp/core/_functorch/apis.py
@@ -0,0 +1,51 @@
from typing import Callable
from mindnlp import core

def vmap(
    func: Callable,
    in_dims=0,
    out_dims=0,
    randomness: str = "error",
    *,
    chunk_size=None,
) -> Callable:
    def batched_func(*args):
        # Normalize in_dims into a per-argument tuple
        if not isinstance(in_dims, tuple):
            in_dims_tuple = (in_dims,) * len(args)
        else:
            in_dims_tuple = in_dims

        # Check that all batched inputs agree on the batch size
        batch_sizes = set()
        for arg, dim in zip(args, in_dims_tuple):
            if dim is not None:
                batch_sizes.add(arg.shape[dim])

        if len(batch_sizes) > 1:
            raise ValueError(f"Inconsistent batch sizes: {batch_sizes}")
        batch_size = next(iter(batch_sizes)) if batch_sizes else 1

        # Collect the per-sample results
        results = []
        for b in range(batch_size):
            # Build the inputs for the current batch element
            single_args = []
            for arg, dim in zip(args, in_dims_tuple):
                if dim is None:
                    single_args.append(arg)
                else:
                    # Slice out the current batch element along its batched dim
                    slices = [slice(None)] * arg.ndim
                    slices[dim] = b
                    single_args.append(arg[tuple(slices)])

            # Call the original function on the single sample
            result = func(*single_args)
            results.append(result)

        # Stack the per-sample results along the requested output dim
        stacked = core.stack(results, dim=out_dims)
        return stacked

    return batched_func
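A short usage sketch of the new loop-and-stack `vmap` (exported as `core.func.vmap` in this PR). It mirrors the common `torch.func.vmap` case of one batched dimension per input; `randomness` and `chunk_size` are accepted but not acted on in this fallback. The factory calls are illustrative:

```python
from mindnlp.core import ops
from mindnlp.core.func import vmap

def dot(x, y):
    return (x * y).sum()

xs = ops.ones(8, 5)   # batch of 8 vectors
ys = ops.ones(8, 5)

batched_dot = vmap(dot, in_dims=(0, 0), out_dims=0)
out = batched_dot(xs, ys)   # shape (8,): one dot product per batch element
```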
5 changes: 5 additions & 0 deletions mindnlp/core/_tensor.py
@@ -607,6 +607,8 @@ def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=N
if isinstance(s, Tensor):
s = s.item()
new_size += (s,)
if new_size == new_size:
new_size = (new_size,)
return ops.ones(*new_size, dtype=dtype if dtype is not None else self.dtype)

Tensor.new_ones = new_ones
@@ -792,6 +794,9 @@ def tobytes(self):
Tensor.copy_ = ops.inplace_copy
StubTensor.copy_ = ops.inplace_copy

Tensor.index_add_ = ops.inplace_index_add
StubTensor.index_add_ = ops.inplace_index_add

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
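With `index_add_` now bound on both `Tensor` and `StubTensor` above, the torch-style in-place accumulation is available as a method. A hedged sketch of the expected call pattern (the factory calls are illustrative and assume torch-style signatures):

```python
from mindnlp import core
from mindnlp.core import ops

x = ops.zeros(5, 3)
index = core.tensor([0, 2, 4])
src = ops.ones(3, 3)

# Adds the rows of `src` into rows 0, 2 and 4 of `x`, in place.
x.index_add_(0, index, src)
```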
1 change: 1 addition & 0 deletions mindnlp/core/func/__init__.py
@@ -0,0 +1 @@
from .._functorch.apis import vmap
3 changes: 2 additions & 1 deletion mindnlp/core/nn/attention/flex_attention.py
@@ -1,4 +1,5 @@
BlockMask = None
from mindnlp import core
BlockMask = core.Tensor
flex_attention = None
create_block_mask = None
_DEFAULT_SPARSE_BLOCK_SIZE = None
10 changes: 7 additions & 3 deletions mindnlp/core/ops/inplace.py
@@ -7,7 +7,7 @@
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \
inplace_scatter_src_op, inplace_fill_tensor_op, inplace_fill_scalar_op, inplace_zero_op, inplace_uniform_op, \
inplace_masked_fill_scalar_op, inplace_masked_fill_tensor_op, inplace_random_op, inplace_clamp_scalar_op, \
inplace_clamp_tensor_op, inplace_copy_op
inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op

from mindnlp import core
from ..configs import use_pyboost
@@ -96,8 +96,12 @@ def inplace_index_copy(input, dim, index, tensor):
return input

def inplace_index_add(input, dim, index, source):
_inplace = _get_cache_prim(ops.InplaceIndexAdd)(dim)
return _inplace(input, index, source)
if input.device == 'npu':
inplace_index_add_op(input, dim, index, source)
else:
_inplace = _get_cache_prim(ops.IndexAdd)(dim)
input.data = _inplace(input, index.int(), source)
return input

has_squeeze = hasattr(mindspore.mint, 'squeeze')
def inplace_squeeze(input, *dim, **kwargs):
6 changes: 5 additions & 1 deletion mindnlp/core/ops/other.py
@@ -202,7 +202,10 @@ def diag(input, diagonal=0):
# diagonal

# diff

def diff(input, n=1, dim=-1, prepend=None, append=None):
    if use_pyboost():
        return mindspore.mint.diff(input, n, dim, prepend, append)
    return ops.diff(input, n, dim, prepend, append)

# einsum

@@ -1081,6 +1084,7 @@ def cosine_similarity(*args, **kwargs):
"cumsum",
"cumprod",
"diag",
"diff",
"dim_list_to_bitset",
"einsum",
"einsum_label_to_index",
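The new `diff` wrapper routes to `mindspore.mint.diff` when pyboost is available and to `ops.diff` otherwise, and is now listed in `__all__`. A minimal sketch of the call it exposes:

```python
import mindspore
from mindnlp.core.ops.other import diff

x = mindspore.Tensor([1, 3, 6, 10])
print(diff(x))   # adjacent differences along the last dim -> [2, 3, 4]
```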
3 changes: 2 additions & 1 deletion tests/run_test.py
@@ -33,7 +33,8 @@ def run_tests():
"and not torchscript " \
"and not torch_fx " \
"and not test_wrong_device_map " \
"and not test_layerwise_casting"
"and not test_layerwise_casting " \
"and not test_flex_attention"

pytest_args.extend(["--ignore-glob=test_modeling_flax_*.py"])
pytest_args.extend(['-k', skip_ut])