From 6b4eed71f17647ca88fd7929a8bc63184300e269 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Wed, 6 Aug 2025 23:05:36 +0800 Subject: [PATCH 1/2] support gpt-oss --- mindnlp/core/_tensor.py | 13 +++++++++++++ mindnlp/core/npu/__init__.py | 3 +++ mindnlp/core/ops/pointwise.py | 8 +++++++- mindnlp/utils/safetensors_patch.py | 3 ++- 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 6bb44dd55..8857647c2 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -20,6 +20,7 @@ class StubTensor: pass from ._utils import _rebuild_tensor_v2 from ._C.size import Size from .types import DEVICE_MAP +from .configs import DEVICE_TARGET DTYPE_ELEMENT_SIZE_MAP = { mindspore.float64: 8, @@ -824,6 +825,18 @@ def record_stream(self, stream): Tensor.gather = ops.gather StubTensor.gather = ops.gather + def is_cuda(self): + device_type = 'cuda' + if DEVICE_TARGET == 'Ascend': + device_type = 'npu' + return self.device.type == device_type + + Tensor.is_cuda = is_cuda + StubTensor.is_cuda = is_cuda + + Tensor.__rshift__ = ops.bitwise_right_shift + StubTensor.__rshift__ = ops.bitwise_right_shift + def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) return ret \ No newline at end of file diff --git a/mindnlp/core/npu/__init__.py b/mindnlp/core/npu/__init__.py index 0f862261b..167042e83 100644 --- a/mindnlp/core/npu/__init__.py +++ b/mindnlp/core/npu/__init__.py @@ -64,3 +64,6 @@ def mem_get_info(index): def current_device(): return core.device('npu', 0) + +def get_device_capability(device=None): + return 10, 0 \ No newline at end of file diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index 322ef7e65..f053ce18f 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -494,7 +494,12 @@ def imag(input): # ldexp - +def ldexp(input, other, out=None): + output = ops.ldexp(input, other) + if out is not None: + out.data = output + return out + return output # lerp has_lerp = hasattr(mindspore.mint, "lerp") @@ -1005,6 +1010,7 @@ def relu(input): "igamma", "igammac", "imag", + "ldexp", "lerp", "lgamma", "log", diff --git a/mindnlp/utils/safetensors_patch.py b/mindnlp/utils/safetensors_patch.py index bdc0b38d4..22310e26b 100644 --- a/mindnlp/utils/safetensors_patch.py +++ b/mindnlp/utils/safetensors_patch.py @@ -196,7 +196,8 @@ def __exit__(self, *args): def metadata(self): meta = self.__metadata__ - meta['format'] = 'pt' + if meta is not None: + meta['format'] = 'pt' return meta def keys(self): From 1a2d3ce8830070b03e67aced9795a2207fa044ef Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Wed, 13 Aug 2025 12:32:14 +0800 Subject: [PATCH 2/2] support multiprocess inference on 910A --- mindnlp/__init__.py | 2 + mindnlp/accelerate/__init__.py | 18 - .../utils/__init__.py} | 0 mindnlp/accelerate/utils/modeling.py | 0 mindnlp/core/_bind.py | 12 + mindnlp/core/_prims/ascend.py | 159 + mindnlp/core/_prims/ascend/__init__.py | 7 - mindnlp/core/_prims/ascend/aclop.py | 87 - mindnlp/core/_prims/ascend/pyboost.py | 202 -- mindnlp/core/_prims/cpu.py | 196 ++ mindnlp/core/_prims/cpu/__init__.py | 249 -- mindnlp/core/_prims/meta.py | 264 ++ mindnlp/core/_prims/numpy.py | 470 +++ mindnlp/core/_tensor.py | 2770 ++++++++++++--- mindnlp/core/configs.py | 1 + mindnlp/core/dispatcher.py | 77 + .../core/distributed/c10d/process_group.py | 4 +- mindnlp/core/distributed/distributed_c10d.py | 2 + mindnlp/core/executor.py | 13 + mindnlp/core/linalg/__init__.py | 4 +- mindnlp/core/nn/functional.py | 225 
+- mindnlp/core/nn/modules/module.py | 44 +- mindnlp/core/nn/modules/sparse.py | 8 +- mindnlp/core/nn/parameter.py | 12 +- mindnlp/core/npu/__init__.py | 44 +- mindnlp/core/ops/_inner.py | 17 +- mindnlp/core/ops/array.py | 743 ++-- mindnlp/core/ops/blas.py | 65 +- mindnlp/core/ops/comparison.py | 174 +- mindnlp/core/ops/complex.py | 2 + mindnlp/core/ops/creation.py | 388 +-- mindnlp/core/ops/inplace.py | 228 +- mindnlp/core/ops/optim.py | 36 +- mindnlp/core/ops/other.py | 1295 +++---- mindnlp/core/ops/pointwise.py | 845 ++--- mindnlp/core/ops/random.py | 463 ++- mindnlp/core/ops/reduction.py | 453 +-- mindnlp/core/types.py | 12 +- mindnlp/transformers/__init__.py | 26 +- mindnlp/transformers/modeling_utils.py | 213 ++ mindnlp/transformers/models/__init__.py | 0 mindnlp/utils/__init__.py | 8 +- mindnlp/utils/decorators.py | 8 +- mindnlp/utils/generic.py | 2 +- mindnlp/utils/import_utils.py | 3083 ++++++++++++++--- mindnlp/utils/safetensors_patch.py | 6 +- mindnlp/utils/torch_proxy.py | 3 + 47 files changed, 8371 insertions(+), 4569 deletions(-) rename mindnlp/{core/ops/fft_op.py => accelerate/utils/__init__.py} (100%) create mode 100644 mindnlp/accelerate/utils/modeling.py create mode 100644 mindnlp/core/_prims/ascend.py delete mode 100644 mindnlp/core/_prims/ascend/__init__.py delete mode 100644 mindnlp/core/_prims/ascend/aclop.py delete mode 100644 mindnlp/core/_prims/ascend/pyboost.py create mode 100644 mindnlp/core/_prims/cpu.py delete mode 100644 mindnlp/core/_prims/cpu/__init__.py create mode 100644 mindnlp/core/_prims/meta.py create mode 100644 mindnlp/core/_prims/numpy.py create mode 100644 mindnlp/core/dispatcher.py create mode 100644 mindnlp/core/executor.py create mode 100644 mindnlp/transformers/modeling_utils.py create mode 100644 mindnlp/transformers/models/__init__.py diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index ea086e3c7..174037b12 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -36,6 +36,8 @@ except: disable_multi_thread = None # for different ascend devices + context.set_context(device_target='CPU') + if platform.system().lower() == 'linux': SOC = MSContext.get_instance().get_ascend_soc_version() if ('910b' not in SOC and '310' not in SOC) or version.parse(mindspore.__version__) < version.parse('2.4.0'): diff --git a/mindnlp/accelerate/__init__.py b/mindnlp/accelerate/__init__.py index 51c7f5deb..e69de29bb 100644 --- a/mindnlp/accelerate/__init__.py +++ b/mindnlp/accelerate/__init__.py @@ -1,18 +0,0 @@ -import sys -import accelerate -from transformers.utils import _LazyModule - -_import_structure = { - "utils": [ - 'DistributedType', - - ] -} - -sys.modules[__name__] = _LazyModule( - 'accelerate', - accelerate.__file__, - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": accelerate.__version__}, -) diff --git a/mindnlp/core/ops/fft_op.py b/mindnlp/accelerate/utils/__init__.py similarity index 100% rename from mindnlp/core/ops/fft_op.py rename to mindnlp/accelerate/utils/__init__.py diff --git a/mindnlp/accelerate/utils/modeling.py b/mindnlp/accelerate/utils/modeling.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/_bind.py b/mindnlp/core/_bind.py index 9a95467c6..60358145a 100644 --- a/mindnlp/core/_bind.py +++ b/mindnlp/core/_bind.py @@ -6,6 +6,8 @@ DEFAULT_DTYPE, DEFAULT_DEVICE = float32, device_('cpu') +DEVICE_IN_CONTEXT = None + AUTO_CAST_DTYE = { 'cuda': float16, 'cpu': bfloat16, @@ -41,6 +43,16 @@ def get_default_device(): """get default dtype""" return DEFAULT_DEVICE +def 
set_device_in_context(device):
+    """set the device used by the current device context"""
+    global DEVICE_IN_CONTEXT
+    DEVICE_IN_CONTEXT = device
+
+def get_device_in_context():
+    """get the device of the current context, falling back to the default device"""
+    if DEVICE_IN_CONTEXT is None:
+        return get_default_device()
+    return DEVICE_IN_CONTEXT
+
 bits_map = {
 }
 
diff --git a/mindnlp/core/_prims/ascend.py b/mindnlp/core/_prims/ascend.py
new file mode 100644
index 000000000..e9911115d
--- /dev/null
+++ b/mindnlp/core/_prims/ascend.py
@@ -0,0 +1,159 @@
+import numbers
+from mindspore import ops
+from mindspore.ops.auto_generate import gen_ops_prim
+from mindspore.ops.auto_generate import pyboost_inner_prim
+from mindspore._c_expression import _empty_instance
+
+from mindnlp import core
+from mindnlp.core._C import default_generator
+
+op_list = list(filter(lambda s: s.endswith("_op"), dir(gen_ops_prim)))
+
+__all__ = []
+
+for op_name in op_list:
+    func_name = op_name.replace('_op', '')
+    __all__.append(func_name)
+    globals()[func_name] = getattr(gen_ops_prim, op_name).__class__().set_device('Ascend')
+
+def empty(*args, **kwargs):
+    return _empty_instance(*args, **kwargs, device='Ascend')
+
+def reduce_any(input, dim, keepdim):
+    if dim is None:
+        dim = ()
+    return pyboost_inner_prim.reduce_any_impl(input, dim, keepdim)
+
+__all__.append('reduce_any')
+
+def reduce_all(input, dim, keepdim):
+    if dim is None:
+        dim = ()
+    return pyboost_inner_prim.reduce_all_impl(input, dim, keepdim)
+
+__all__.append('reduce_all')
+
+broadcast_to_op = ops.Primitive('BroadcastTo').set_device('Ascend')
+def broadcast_to(*args):
+    return broadcast_to_op(*args)
+
+__all__.append('broadcast_to')
+
+cast_op = ops.Cast().set_device('Ascend')
+def cast(*args):
+    return cast_op(*args)
+
+__all__.append('cast')
+
+zeros_op = ops.Zeros().set_device('Ascend')
+def zeros(*args):
+    return zeros_op(*args)
+
+__all__.append('zeros')
+
+def softmax(*args):
+    return pyboost_inner_prim.softmax_impl(*args)
+
+__all__.append('softmax')
+
+def dropout_ext(input, p):
+    seed, offset = default_generator._step(12)  # pylint: disable=protected-access
+    return gen_ops_prim.dropout_ext_op(input, p, seed, offset)
+
+def squeeze(*args):
+    return pyboost_inner_prim.squeeze_impl(*args)
+
+__all__.append('squeeze')
+
+ones_op = ops.Ones().set_device('Ascend')
+def ones(*args):
+    return ones_op(*args)
+
+__all__.append('ones')
+
+def nllloss(*args):
+    return pyboost_inner_prim.nllloss_impl(*args)
+
+__all__.append('nllloss')
+
+def repeat_elements(*args):
+    return ops.repeat_elements(*args)
+
+__all__.append('repeat_elements')
+
+def concat(*args):
+    return pyboost_inner_prim.concat_impl(*args)
+
+__all__.append('concat')
+
+def multinomial_ext(input, num_samples, replacement, generator):
+    seed, offset = generator._step(12)  # pylint: disable=protected-access
+    return gen_ops_prim.multinomial_ext_op(input, num_samples, replacement, seed, offset)
+
+def isclose(*args):
+    return pyboost_inner_prim.isclose_impl(*args)
+
+__all__.append('isclose')
+
+tile_op = ops.Primitive('Tile').set_device('Ascend')
+def tile(*args):
+    return tile_op(*args)
+
+__all__.append('tile')
+
+def pad_v3(input_x, padding, mode='constant', value=None):
+    # NOTE: PadV3 is pinned to CPU even on this Ascend backend,
+    # presumably as a host-side fallback.
+    pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('CPU')
+    if isinstance(value, (float, int)):
+        value = core.tensor(value, dtype=input_x.dtype)
+    return pad_op(input_x, padding, value)
+
+__all__.append('pad_v3')
+
+def inplace_uniform(input, from_, to_, generator_):
+    seed, offset = generator_._step(12)
+    return gen_ops_prim.inplace_uniform_op(input, from_, to_, seed, offset)
+
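+# The stateful random ops in this file share one convention: draw an explicit
+# (seed, offset) pair from a Generator via `_step(12)` and hand both to the
+# primitive instead of relying on global RNG state, e.g. (a sketch):
+#
+#     seed, offset = default_generator._step(12)
+#     out = gen_ops_prim.randint_op(0, 10, (2, 3), seed, offset, mindspore.int64)
+
+def 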
binary_cross_entropy_with_logits(*args): + return pyboost_inner_prim.binary_cross_entropy_with_logits_impl(*args) + +__all__.append('binary_cross_entropy_with_logits') + +def gather(input_params, input_indices, axis, batch_dims=0): + return ops.gather(input_params, input_indices, axis, batch_dims) + +__all__.append('gather') + +def randint(low, high, shape, dtype, generator): + seed, offset = generator._step(12) # pylint: disable=protected-access + return gen_ops_prim.randint_op(low, high, shape, seed, offset, dtype) + +def stack_ext(*args): + return pyboost_inner_prim.stack_ext_impl(*args) + +__all__.append('stack_ext') + +def argmax_with_value(*args): + return pyboost_inner_prim.argmax_with_value_impl(*args) + +__all__.append('argmax_with_value') + +right_shift_op = ops.RightShift().set_device('Ascend') +def right_shift(input, other): + if isinstance(other, numbers.Number): + other = core.Tensor(other, input.dtype) + return right_shift_op(input, other) + +tensor_mul = ops.Mul().set_device('Ascend') +tensor_pow = ops.Pow().set_device('Ascend') +def ldexp(input, other): + out = tensor_mul(input, tensor_pow(2.0, other)) + return out + +__all__.append('ldexp') + +def reverse_v2(input, dims): + if isinstance(dims, int): + dims = (dims,) + return pyboost_inner_prim.reverse_v2_impl(input, dims) + +__all__.append('reverse_v2') diff --git a/mindnlp/core/_prims/ascend/__init__.py b/mindnlp/core/_prims/ascend/__init__.py deleted file mode 100644 index 95e4a4e45..000000000 --- a/mindnlp/core/_prims/ascend/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from . import aclop, pyboost -from .aclop import * -from .pyboost import * - -__all__ = [] -__all__.extend(aclop.__all__) -__all__.extend(pyboost.__all__) diff --git a/mindnlp/core/_prims/ascend/aclop.py b/mindnlp/core/_prims/ascend/aclop.py deleted file mode 100644 index c59c0beaa..000000000 --- a/mindnlp/core/_prims/ascend/aclop.py +++ /dev/null @@ -1,87 +0,0 @@ -from mindspore.ops.auto_generate import gen_ops_prim -from mindspore.common.api import _pynative_executor -from mindspore.ops._primitive_cache import _get_cache_prim -from mindspore.ops.auto_generate.gen_ops_prim import Range, Cdist -from mindspore.ops import StopGradient, Primitive, ApplyAdadelta, Adam, ApplyAdamWithAmsgradV2, SGD, Imag - -pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) -pyboost_op_list = [op.replace('pyboost_', '') + '_op' for op in pyboost_list] -aclop_list = list(filter(lambda s: s.endswith("_op") and not s in pyboost_op_list, dir(gen_ops_prim))) - -aclop_func = ''' -def {name}(*args): - return _pynative_executor.run_op_async({obj}, {obj}.name, args) -''' - -__all__ = [] - -for op_name in aclop_list: - func_name = op_name.replace('_op', '_npu') - __all__.append(func_name) - prim_op = func_name + '_prim' - globals()[prim_op] = getattr(gen_ops_prim, op_name).__class__().set_device('Ascend') - exec(aclop_func.format(name=func_name, obj=prim_op), globals()) - -imag_op = Imag().set_device('Ascend') -def imag_npu(*args): - return _pynative_executor.run_op_async(imag_op, range_op.name, args) - -__all__.append('imag_npu') - -range_op = Range().set_device('Ascend') -def range_npu(*args): - return _pynative_executor.run_op_async(range_op, range_op.name, args) - -__all__.append('range_npu') - -cdist_op = Cdist().set_device('Ascend') -def cdist_npu(*args): - return _pynative_executor.run_op_async(cdist_op, cdist_op.name, args) - -__all__.append('cdist_npu') - - -stop_gradient_op = StopGradient().set_device('Ascend') -def stop_gradient_npu(*args): - 
return _pynative_executor.run_op_async(stop_gradient_op, stop_gradient_op.name, args) - -__all__.append('stop_gradient_npu') - -diagonal_op = Primitive('Diagonal').set_device('Ascend') -def diagonal_npu(*args): - return _pynative_executor.run_op_async(diagonal_op, diagonal_op.name, args) - -__all__.append('diagonal_npu') - -adadelta_op = ApplyAdadelta().set_device('Ascend') -def raw_adadelta_npu(param, square_avg, acc_delta, lr, rho, eps, grad): - args = (param, square_avg, acc_delta, lr, rho, eps, grad) - return _pynative_executor.run_op_async(adadelta_op, adadelta_op.name, args) - -adam_op = Adam().set_device('Ascend') -def raw_adam_npu(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): - # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - args = (param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) - return _pynative_executor.run_op_async(adam_op, adam_op.name, args) - -adam_amsgrad_op = ApplyAdamWithAmsgradV2().set_device('Ascend') -def raw_adam_amsgrad_npu(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): - # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - args = (param, exp_avg, exp_avg_sq, max_exp_avg_sq, - beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) - return _pynative_executor.run_op_async(adam_amsgrad_op, adam_amsgrad_op.name, args) - - -def raw_sgd_npu(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): - sgd_op = _get_cache_prim(SGD)(dampening, weight_decay, nesterov).set_device('Ascend') - args = (param, grad, lr, accum, momentum, stat) - return _pynative_executor.run_op_async(sgd_op, sgd_op.name, args) - -__all__.extend( - [ - 'raw_adadelta_npu', - 'raw_adam_npu', - 'raw_adam_amsgrad_npu', - 'raw_sgd_npu' - ] -) diff --git a/mindnlp/core/_prims/ascend/pyboost.py b/mindnlp/core/_prims/ascend/pyboost.py deleted file mode 100644 index 31bddfc26..000000000 --- a/mindnlp/core/_prims/ascend/pyboost.py +++ /dev/null @@ -1,202 +0,0 @@ -from mindspore.ops import Primitive -from mindspore.ops.auto_generate import gen_ops_prim -from mindspore.ops.auto_generate.gen_ops_prim import * -from mindspore._c_expression import pyboost_cast, pyboost_zeros, pyboost_ones, pyboost_empty, \ - pyboost_reduce_max, pyboost_reduce_min, pyboost_reduce_all, pyboost_reduce_all -from mindspore.ops.operations.manually_defined.ops_def import Cast, Zeros, Ones -from mindspore.common.api import _pynative_executor - -pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) - - -pyboost_func = ''' -def {name}(*args): - return {pyboost}({op}, args) -''' - -__all__ = [] - -for op_name in pyboost_list: - op = getattr(gen_ops_prim, op_name) - func_name = op_name.replace('pyboost_', '') + '_npu' - prim_op = func_name.replace('_npu', '_op') - if not hasattr(gen_ops_prim, prim_op): - continue - __all__.append(func_name) - globals()[prim_op] = getattr(gen_ops_prim, prim_op).__class__().set_device('Ascend') - exec(pyboost_func.format(name=func_name, pyboost=op_name, op=prim_op), globals()) - -cast_op = Cast().set_device('Ascend') -def cast_npu(*args): - return pyboost_cast(cast_op, args) - -__all__.append('cast_npu') - -def empty_npu(size, dtype): - return pyboost_empty([size, dtype, 'Ascend']) - -__all__.append('empty_npu') - -zeros_op = Zeros().set_device('Ascend') -def zeros_npu(*args): - return pyboost_zeros(zeros_op, args) - -__all__.append('zeros_npu') - -ones_op = 
Ones().set_device('Ascend') -def ones_npu(*args): - return pyboost_ones(ones_op, args) - -__all__.append('ones_npu') - - -squeeze_op = Squeeze().set_device('Ascend') -def squeeze_npu(*args): - return pyboost_squeeze(squeeze_op, args) - -__all__.append('squeeze_npu') - -stack_ext_op = StackExt().set_device('Ascend') -def stack_ext_npu(*args): - return pyboost_stack_ext(stack_ext_op, args) - -__all__.append('stack_ext_npu') - -tile_op = Primitive('Tile').set_device('Ascend') -def tile_npu(*args): - return pyboost_tile(tile_op, args) - -__all__.append('tile_npu') - -greater_equal_op = GreaterEqual().set_device('Ascend') -def greater_equal_npu(*args): - return pyboost_greater_equal(greater_equal_op, args) - -__all__.append('greater_equal_npu') - -isclose_op = IsClose().set_device('Ascend') -def isclose_npu(*args): - return pyboost_isclose(isclose_op, args) - -__all__.append('isclose_npu') - -reduce_max_op = ReduceMax().set_device('Ascend') -def reduce_max_npu(*args): - return pyboost_reduce_max(reduce_max_op, args) - -__all__.append('reduce_max_npu') - -reduce_min_op = ReduceMin().set_device('Ascend') -def reduce_min_npu(*args): - return pyboost_reduce_min(reduce_min_op, args) - -__all__.append('reduce_min_npu') - -reduce_all_op = ReduceAll().set_device('Ascend') -def reduce_all_npu(*args): - return pyboost_reduce_all(reduce_all_op, args) - -__all__.append('reduce_all_npu') - -reduce_any_op = ReduceAny().set_device('Ascend') -def reduce_any_npu(*args): - return pyboost_reduce_any(reduce_any_op, args) - -__all__.append('reduce_any_npu') - -unique_consecutive_op = UniqueConsecutive().set_device('Ascend') -def unique_consecutive_npu(*args): - return pyboost_unique_consecutive(unique_consecutive_op, args) - -__all__.append('unique_consecutive_npu') - -nan_to_num_op = NanToNum().set_device('Ascend') -def nan_to_num_npu(*args): - return pyboost_nan_to_num(nan_to_num_op, args) - -__all__.append('nan_to_num_npu') - - -softmax_op = Softmax().set_device('Ascend') -def softmax_npu(*args): - return pyboost_softmax(softmax_op, args) - -__all__.append('softmax_npu') - -broadcast_to_op = Primitive('BroadcastTo').set_device('Ascend') -def broadcast_to_npu(*args): - return pyboost_broadcast_to(broadcast_to_op, args) - -__all__.append('broadcast_to_npu') - -triu_op = Triu().set_device('Ascend') -def triu_npu(*args): - return pyboost_triu(triu_op, args) - -__all__.append('triu_npu') - -tril_ext_op = TrilExt().set_device('Ascend') -def tril_ext_npu(*args): - return pyboost_tril_ext(triu_op, args) - -__all__.append('tril_ext_npu') - -search_sorted_op = SearchSorted().set_device('Ascend') -def search_sorted_npu(*args): - return pyboost_searchsorted(search_sorted_op, args) - -__all__.append('search_sorted_npu') - -roll_op = Primitive('Roll').set_device('Ascend') -def roll_npu(*args): - return pyboost_roll(roll_op, args) - -__all__.append('roll_npu') - -meshgrid_op = Meshgrid().set_device('Ascend') -def meshgrid_npu(*args): - return pyboost_meshgrid(meshgrid_op, args) - -__all__.append('meshgrid_npu') - -reverse_v2_op = Primitive('ReverseV2').set_device('Ascend') -def reverse_v2_npu(*args): - return pyboost_reverse_v2(reverse_v2_op, args) - -__all__.append('reverse_v2_npu') - -hard_shrink_op = HShrink().set_device('Ascend') -def hard_shrink_npu(*args): - return pyboost_hshrink(hard_shrink_op, args) - -__all__.append('hard_shrink_npu') - -concat_op = Concat().set_device('Ascend') -def concat_npu(*args): - return pyboost_concat(concat_op, args) - -__all__.append('concat_npu') - -rms_norm_op = 
RmsNorm().set_device('Ascend') -def rms_norm_npu(*args): - return pyboost_rms_norm(rms_norm_op, args) - -__all__.append('rms_norm_npu') - -flash_attention_score_op = Primitive('FlashAttentionScore').set_device('Ascend') -def flash_attention_score_npu(*args): - return pyboost_flash_attention_score(flash_attention_score_op, args) - -__all__.append('flash_attention_score_npu') - -argmax_with_value_op = ArgMaxWithValue().set_device('Ascend') -def argmax_with_value_npu(*args): - return pyboost_argmax_with_value(argmax_with_value_op, args) - -__all__.append('argmax_with_value_npu') - -argmin_with_value_op = ArgMinWithValue().set_device('Ascend') -def argmin_with_value_npu(*args): - return pyboost_argmin_with_value(argmin_with_value_op, args) - -__all__.append('argmin_with_value_npu') diff --git a/mindnlp/core/_prims/cpu.py b/mindnlp/core/_prims/cpu.py new file mode 100644 index 000000000..247fb1663 --- /dev/null +++ b/mindnlp/core/_prims/cpu.py @@ -0,0 +1,196 @@ +import numbers +from mindspore.ops.auto_generate import gen_ops_prim +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore._c_expression import _empty_instance +from mindspore.ops.operations._grad_ops import StridedSliceGrad + +import mindspore +from mindspore import ops + +from mindnlp import core + +__all__ = [] +op_list = list(filter(lambda s: s.endswith("_op"), dir(gen_ops_prim))) + +for op_name in op_list: + func_name = op_name.replace('_op', '') + __all__.append(func_name) + globals()[func_name] = getattr(gen_ops_prim, op_name).__class__().set_device('CPU') + +def empty(*args, **kwargs): + return _empty_instance(*args, **kwargs, device='CPU') + +normal_op = ops.StandardNormal().set_device('CPU') +def normal(*args, **kwargs): + return normal_op(*args, **kwargs) + +__all__.append('normal') + +full_op = ops.FillV2().set_device('CPU') +def full(*args): + return full_op(*args) + +__all__.append('full') + +range_op = ops.Range().set_device('CPU') +def arange(start, end, step, dtype): + return cast(range_op(start, end, step), dtype) + +__all__.append('arange') + + +broadcast_to_op = ops.Primitive('BroadcastTo').set_device('CPU') +def broadcast_to(*args): + return broadcast_to_op(*args) + +__all__.append('broadcast_to') + +def concat(tensors, dim): + concat_op = ops.Concat(dim).set_device('CPU') + return concat_op(tensors) + +__all__.append('concat') + +zeros_op = ops.Zeros().set_device('CPU') +def zeros(*args): + return zeros_op(*args) + +__all__.append('zeros') + +ones_op = ops.Ones().set_device('CPU') +def ones(*args): + return ones_op(*args) + +__all__.append('ones') + +uniform_real_op = ops.UniformReal().set_device('CPU') +def uniform_real(*args): + return uniform_real_op(*args) + +__all__.append('uniform_real') + +def pad_v3(input_x, padding, mode='constant', value=None): + pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('CPU') + if isinstance(value, (float, int)): + value = core.tensor(value, dtype=input_x.dtype) + return pad_op(input_x, padding, value) + +__all__.append('pad_v3') + +reduce_any_op = ops.ReduceAny().set_device('CPU') +reduce_any_keepdim_op = ops.ReduceAny(True).set_device('CPU') +def reduce_any(input, dim, keepdim): + if keepdim: + return reduce_any_keepdim_op(input, dim) + return reduce_any_op(input, dim) + +__all__.append('reduce_any') + +reduce_all_op = ops.ReduceAll().set_device('CPU') +reduce_all_keepdim_op = ops.ReduceAll(True).set_device('CPU') +def reduce_all(input, dim, keepdim): + if keepdim: + return reduce_all_keepdim_op(input, dim) + return reduce_all_op(input, 
dim)
+
+__all__.append('reduce_all')
+
+def isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+    is_close = _get_cache_prim(ops.IsClose)(rtol=rtol, atol=atol, equal_nan=equal_nan).set_device('CPU')
+    return is_close(input, other)
+
+__all__.append('isclose')
+
+tile_op = ops.Primitive('Tile').set_device('CPU')
+def tile(*args):
+    return tile_op(*args)
+
+__all__.append('tile')
+
+def randint(low, high, shape, dtype, generator):
+    # NOTE: `generator` is ignored here; UniformInt draws from global state.
+    rand_op = ops.UniformInt().set_device('CPU')
+    output = rand_op(shape, mindspore.Tensor(low, mindspore.int32), mindspore.Tensor(high, mindspore.int32))
+    return cast(output, dtype)
+    # return mindspore.Tensor(np.random.randint(low, high, shape))
+
+cast_op = ops.Cast().set_device('CPU')
+def cast(input, dtype):
+    return cast_op(input, dtype)
+
+__all__.append('cast')
+
+def tril_ext(input, diagonal):
+    tril_op = ops.Tril(diagonal).set_device('CPU')
+    return tril_op(input)
+
+def strided_slice(input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
+    strided_slice_op = _get_cache_prim(ops.StridedSlice)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask).set_device('CPU')
+    return strided_slice_op(input, begin, end, strides)
+
+__all__.append('strided_slice')
+
+def strided_slice_grad(input, begin, end, strides, update, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
+    strided_slice_grad = _get_cache_prim(StridedSliceGrad)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask).set_device('CPU')
+    return strided_slice_grad(update, input.shape, begin, end, strides)
+
+__all__.append('strided_slice_grad')
+
+def squeeze(input, dim):
+    squeeze_op = ops.Squeeze(dim).set_device('CPU')
+    return squeeze_op(input)
+
+__all__.append('squeeze')
+
+def sort_ext(input, dim, descending, stable):
+    # ops.Sort exposes no stability flag; `stable` is accepted for API parity
+    # and ignored.
+    sort_op = ops.Sort(dim, descending).set_device('CPU')
+    return sort_op(input)
+
+__all__.append('sort_ext')
+
+def stack(tensors, dim):
+    stack_op = ops.Stack(dim).set_device('CPU')
+    return stack_op(tensors)
+
+__all__.append('stack')
+
+def gather(input_params, input_indices, axis, batch_dims=0):
+    gather_op = _get_cache_prim(ops.Gather)(batch_dims).set_device('CPU')
+    return gather_op(input_params, input_indices, axis)
+
+__all__.append('gather')
+
+def softmax(input, dim):
+    softmax_op = ops.Softmax(dim).set_device('CPU')
+    return softmax_op(input)
+
+__all__.append('softmax')
+
+def topk(input, k, sorted=True):
+    topk_op = ops.TopK(sorted).set_device('CPU')
+    return topk_op(input, k)
+
+__all__.append('topk')
+
+dyn_shape_op = ops.TensorShape().set_device('CPU')
+def dyn_shape(self):
+    return dyn_shape_op(self)
+
+__all__.append('dyn_shape')
+
+bitwise_and_op = ops.BitwiseAnd().set_device('CPU')
+def bitwise_and_scalar(input, other):
+    return bitwise_and_op(input, other)
+
+bitwise_right_shift_op = ops.RightShift().set_device('CPU')
+def bitwise_right_shift(input, other):
+    if isinstance(input, numbers.Number):
+        if not isinstance(input, int):
+            raise TypeError(f"For 'bitwise_right_shift', 'input' must be an integer, but got input:{type(input)}.")
+        input = cast(input, other.dtype)
+    elif isinstance(other, numbers.Number):
+        if not isinstance(other, int):
+            raise TypeError(f"For 'bitwise_right_shift', 'other' must be an integer, but got other:{type(other)}.")
+        other = cast(other, input.dtype)
+    return bitwise_right_shift_op(input, other)
+
+__all__.append('bitwise_right_shift')
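+# Minimal usage sketch for these host-pinned wrappers (assuming a working
+# MindSpore CPU build; every name mirrors a definition above):
+#
+#     from mindnlp.core._prims import cpu
+#     x = cpu.ones((2, 3), mindspore.float32)
+#     y = cpu.cast(x, mindspore.float16)
+#     z = cpu.concat((y, y), 0)
+#
+# Each Primitive here is created with set_device('CPU'), so these kernels run
+# on the host even when the process-wide device target is Ascend.
diff --git a/mindnlp/core/_prims/cpu/__init__.py b/mindnlp/core/_prims/cpu/__init__.py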
deleted file mode 100644 index 4192374af..000000000 --- a/mindnlp/core/_prims/cpu/__init__.py +++ /dev/null @@ -1,249 +0,0 @@ -from mindspore.common.api import _pynative_executor -from mindspore.ops.auto_generate import gen_ops_prim -from mindspore.ops.auto_generate.gen_ops_prim import * -from mindspore._c_expression import Tensor as MSTensor -from mindspore._c_expression import pyboost_cast, pyboost_empty, pyboost_zeros, pyboost_ones -from mindspore.ops.operations.manually_defined.ops_def import Cast, Zeros, Ones -from mindspore.ops._primitive_cache import _get_cache_prim -from mindspore.ops import StopGradient, Primitive, ApplyAdadelta, Adam, ApplyAdamWithAmsgradV2, SGD -from mindspore.ops import FillV2, UniformReal, Stack, StandardNormal, TensorScatterUpdate -from mindspore.ops.operations import identity, TensorShape -from mindspore.ops.operations._grad_ops import StridedSliceGrad - - -pyboost_list = list(filter(lambda s: s.startswith("pyboost"), dir(gen_ops_prim))) -pyboost_op_list = [op.replace('pyboost_', '') + '_op' for op in pyboost_list] -aclop_list = list(filter(lambda s: s.endswith("_op") and not s in pyboost_op_list, dir(gen_ops_prim))) - - -pyboost_func = ''' -def {name}(*args): - return {pyboost}({op}, args) -''' - -aclop_func = ''' -def {name}(*args): - return _pynative_executor.run_op_async({obj}, {obj}.name, args) -''' - -__all__ = [] - -for op_name in pyboost_list: - op = getattr(gen_ops_prim, op_name) - func_name = op_name.replace('pyboost_', '') + '_cpu' - prim_op = func_name.replace('_cpu', '_op') - if not hasattr(gen_ops_prim, prim_op): - continue - __all__.append(func_name) - globals()[prim_op] = getattr(gen_ops_prim, prim_op).__class__().set_device('CPU') - exec(pyboost_func.format(name=func_name, pyboost=op_name, op=prim_op), globals()) - - -for op_name in aclop_list: - func_name = op_name.replace('_op', '_cpu') - __all__.append(func_name) - prim_op = func_name + '_prim' - globals()[prim_op] = getattr(gen_ops_prim, op_name).__class__().set_device('CPU') - exec(aclop_func.format(name=func_name, obj=prim_op), globals()) - -cast_op = Cast().set_device('CPU') -def cast_cpu(*args): - return pyboost_cast(cast_op, args) - -__all__.append('cast_cpu') - -def empty_cpu(size, dtype): - return pyboost_empty([size, dtype, 'CPU']) - -__all__.append('empty_cpu') - -zeros_op = Zeros().set_device('CPU') -def zeros_cpu(*args): - return pyboost_zeros(zeros_op, args) - -__all__.append('zeros_cpu') - -ones_op = Ones().set_device('CPU') -def ones_cpu(*args): - return pyboost_ones(ones_op, args) - -__all__.append('ones_cpu') - - -squeeze_op = Squeeze().set_device('CPU') -def squeeze_cpu(*args): - return pyboost_squeeze(squeeze_op, args) - -__all__.append('squeeze_cpu') - -stack_ext_op = StackExt().set_device('CPU') -def stack_ext_cpu(*args): - return pyboost_stack_ext(stack_ext_op, args) - -__all__.append('stack_ext_cpu') - -tile_op = Primitive('Tile').set_device('CPU') -def tile_cpu(*args): - return pyboost_tile(tile_op, args) - -__all__.append('tile_cpu') - -greater_equal_op = GreaterEqual().set_device('CPU') -def greater_equal_cpu(*args): - return pyboost_greater_equal(greater_equal_op, args) - -__all__.append('greater_equal_cpu') - -isclose_op = IsClose().set_device('CPU') -def isclose_cpu(*args): - return pyboost_isclose(isclose_op, args) - -__all__.append('isclose_cpu') - -range_op = Range().set_device('CPU') -def range_cpu(*args): - return _pynative_executor.run_op_async(range_op, range_op.name, args) - -__all__.append('range_cpu') - -linspace_op = LinSpace().set_device('CPU') 
-def linspace_cpu(*args): - return _pynative_executor.run_op_async(linspace_op, linspace_op.name, args) - -__all__.append('linspace_cpu') - -full_op = FillV2().set_device('CPU') -def full_cpu(shape, value): - return _pynative_executor.run_op_async(full_op, full_op.name, [shape, MSTensor(value)]) - -__all__.append('full_cpu') - -stop_gradient_op = StopGradient().set_device('CPU') -def stop_gradient_cpu(*args): - return _pynative_executor.run_op_async(stop_gradient_op, stop_gradient_op.name, args) - -__all__.append('stop_gradient_cpu') - -identity_op = identity().set_device('CPU') -def identity_cpu(*args): - return _pynative_executor.run_op_async(identity_op, identity_op.name, args) - -__all__.append('identity_cpu') - - -tensor_shape_op = TensorShape().set_device('CPU') -def tensor_shape_cpu(*args): - return _pynative_executor.run_op_async(tensor_shape_op, tensor_shape_op.name, args) - -__all__.append('stop_gradient_cpu') - -adadelta_op = ApplyAdadelta().set_device('CPU') -def raw_adadelta_cpu(param, square_avg, acc_delta, lr, rho, eps, grad): - args = (param, square_avg, acc_delta, lr, rho, eps, grad) - return _pynative_executor.run_op_async(adadelta_op, adadelta_op.name, args) - -adam_op = Adam().set_device('CPU') -def raw_adam_cpu(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): - # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - args = (param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) - return _pynative_executor.run_op_async(adam_op, adam_op.name, args) - -adam_amsgrad_op = ApplyAdamWithAmsgradV2().set_device('CPU') -def raw_adam_amsgrad_cpu(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): - # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - args = (param, exp_avg, exp_avg_sq, max_exp_avg_sq, - beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) - return _pynative_executor.run_op_async(adam_amsgrad_op, adam_amsgrad_op.name, args) - - -def raw_sgd_cpu(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): - sgd_op = _get_cache_prim(SGD)(dampening, weight_decay, nesterov).set_device('CPU') - args = (param, grad, lr, accum, momentum, stat) - return _pynative_executor.run_op_async(sgd_op, sgd_op.name, args) - -__all__.extend( - [ - 'raw_adadelta_cpu', - 'raw_adam_cpu', - 'raw_adam_amsgrad_cpu', - 'raw_sgd_cpu' - ] -) - -uniform_real_op = UniformReal().set_device('CPU') -def uniform_real_cpu(*args): - return _pynative_executor.run_op_async(uniform_real_op, uniform_real_op.name, args) - -__all__.append('uniform_real_cpu') - -def stack_cpu(tensors, dim): - stack_op = _get_cache_prim(Stack)(dim).set_device('CPU') - return _pynative_executor.run_op_async(stack_op, stack_op.name, tensors) - -__all__.append('stack_cpu') - -argmax_with_value_op = ArgMaxWithValue().set_device('CPU') -def argmax_with_value_cpu(*args): - return pyboost_argmax_with_value(argmax_with_value_op, args) - -__all__.append('argmax_with_value_cpu') - -argmin_with_value_op = ArgMinWithValue().set_device('CPU') -def argmin_with_value_cpu(*args): - return pyboost_argmin_with_value(argmin_with_value_op, args) - -__all__.append('argmin_with_value_cpu') - -log_softmax_op = LogSoftmax().set_device('CPU') -def log_softmax_cpu(*args): - return pyboost_log_softmax(log_softmax_op, args) - -__all__.append('log_softmax_cpu') - -strided_slice_op = StridedSlice().set_device('CPU') -def strided_slice_cpu(*args): - return 
_pynative_executor.run_op_async(strided_slice_op, strided_slice_op.name, args)
-
-__all__.append('strided_slice_cpu')
-
-hard_shrink_op = HShrink().set_device('CPU')
-def hard_shrink_cpu(*args):
-    return pyboost_hshrink(hard_shrink_op, args)
-
-__all__.append('hard_shrink_cpu')
-
-normal_op = StandardNormal().set_device('CPU')
-def normal_cpu(*args):
-    return _pynative_executor.run_op_async(normal_op, normal_op.name, args)
-
-__all__.append('normal_cpu')
-
-reduce_any_op = ReduceAny().set_device('CPU')
-def reduce_any_cpu(*args):
-    return pyboost_reduce_any(reduce_any_op, args)
-
-__all__.append('reduce_any_cpu')
-
-def strided_slice_grad_cpu(input, begin, end, strides, update, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
-    strided_slice_grad = _get_cache_prim(StridedSliceGrad)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask).set_device('CPU')
-    return _pynative_executor.run_op_async(strided_slice_grad, strided_slice_grad.name, [update, input.shape, begin, end, strides])
-
-__all__.append('strided_slice_grad_cpu')
-
-tensor_scatter_update_op = TensorScatterUpdate().set_device('CPU')
-def tensor_scatter_update_cpu(*args):
-    return _pynative_executor.run_op_async(tensor_scatter_update_op, tensor_scatter_update_op.name, args)
-
-__all__.append('tensor_scatter_update_cpu')
-
-broadcast_to_op = Primitive('BroadcastTo').set_device('CPU')
-def broadcast_to_cpu(*args):
-    return pyboost_broadcast_to(broadcast_to_op, args)
-
-__all__.append('broadcast_to_cpu')
-
-concat_op = Concat().set_device('CPU')
-def concat_cpu(*args):
-    return pyboost_concat(concat_op, args)
-
-__all__.append('concat_cpu')
diff --git a/mindnlp/core/_prims/meta.py b/mindnlp/core/_prims/meta.py
new file mode 100644
index 000000000..1f97a8b5f
--- /dev/null
+++ b/mindnlp/core/_prims/meta.py
@@ -0,0 +1,264 @@
+try:
+    from mindspore._c_expression import TensorPy as Tensor_
+except ImportError:
+    from mindspore._c_expression import Tensor as Tensor_
+
+import math
+import numpy as np
+from mindnlp import core
+
+__all__ = []
+
+def arange(start, end, step, dtype):
+    out = Tensor_(shape=(math.ceil((end - start) / step), ), dtype=dtype)
+    return core.Tensor(out)
+
+__all__.append('arange')
+
+from typing import Tuple, Union
+
+def infer_broadcast_shape(input_shape: Tuple[int, ...],
+                          target_shape: Tuple[Union[int, None], ...]) -> Tuple[int, ...]:
+    """
+    Infer the output shape of torch.broadcast_to.
+
+    Args:
+        input_shape: shape tuple of the input tensor, e.g. (3, 1)
+        target_shape: target broadcast shape tuple (may contain None for
+            dimensions to infer automatically)
+
+    Returns:
+        The broadcast output shape tuple.
+
+    Raises:
+        ValueError: if the shapes are not broadcast-compatible.
+    """
+    # Handle None values (automatic dimension inference)
+    final_target_shape = []
+    for i, dim in enumerate(target_shape):
+        if dim is None:
+            # Find the positions of dimensions that could be inferred
+            candidates = [j for j, d in enumerate(target_shape) if d is None]
+            if len(candidates) > 1:
+                raise ValueError(f"multiple None dimensions {candidates}; inference is ambiguous")
+            final_target_shape.append(None)
+        elif dim < -1:
+            raise ValueError(f"dimension sizes must be non-negative (except -1), got {dim}")
+        else:
+            final_target_shape.append(dim)
+
+    # Count the total number of elements that must be accounted for
+    def count_product(shape, exclude_none=True):
+        prod = 1
+        for dim in shape:
+            if dim == 0:
+                return 0  # any zero dimension makes the product zero
+            if dim is not None and not (exclude_none and dim == -1):
+                prod *= max(1, dim)  # treat -1 as 1 for counting
+        return prod
+
+    # Validate rank compatibility
+    ndim_input = len(input_shape)
+    ndim_target = len(final_target_shape)
+
+    if ndim_input > ndim_target:
+        raise ValueError(
+            f"input has more dimensions ({ndim_input}) than the target ({ndim_target}); "
+            f"cannot broadcast: {input_shape} -> {final_target_shape}"
+        )
+
+    # Left-pad the input shape with 1s to align ranks
+    aligned_input_shape = (1,) * (ndim_target - ndim_input) + input_shape
+    inferred_target_shape = list(final_target_shape)
+    known_product = 1
+
+    # First pass: collect the known dimensions
+    for i in range(ndim_target):
+        target_dim = inferred_target_shape[i]
+        input_dim = aligned_input_shape[i]
+
+        if target_dim == -1:
+            # Mark dimensions that still need inference
+            inferred_target_shape[i] = None
+        elif target_dim is not None:
+            # Check per-dimension compatibility
+            if target_dim == 0:
+                if input_dim not in (0, 1):
+                    raise ValueError(
+                        f"dim {i}: when the target dimension is 0 the input dimension must be 0 or 1, "
+                        f"got {input_dim} -> {target_dim}"
+                    )
+            else:  # positive dimension
+                if input_dim != 1 and input_dim != target_dim:
+                    raise ValueError(
+                        f"dim {i}: size {input_dim} cannot broadcast to {target_dim}"
+                    )
+                known_product *= target_dim
+
+    # Second pass: infer the remaining dimensions
+    total_elements = math.prod([d for d in input_shape if d != 0])
+    inferred_product = known_product
+
+    # Count dimensions that need inference
+    none_indices = [i for i, d in enumerate(inferred_target_shape) if d is None]
+    num_infer = len(none_indices)
+
+    if num_infer > 0:
+        # Total elements the inferred dimensions must account for
+        required_total = total_elements
+
+        # Special case: the input contains a zero dimension
+        if 0 in input_shape:
+            if required_total != 0:
+                raise ValueError("cannot infer non-zero dimensions when broadcasting an input with a zero dimension")
+            # every inferred dimension must be 0
+            for i in none_indices:
+                inferred_target_shape[i] = 0
+        else:
+            if inferred_product == 0 and required_total > 0:
+                raise ValueError(
+                    "cannot broadcast a non-empty input to a target shape containing 0: "
+                    f"{input_shape} -> {inferred_target_shape}"
+                )
+
+            # Product that the inferred dimensions must supply
+            infer_product = required_total // inferred_product if inferred_product != 0 else 0
+
+            if infer_product * inferred_product != required_total:
+                raise ValueError(
+                    f"incompatible element counts: the input has {total_elements} elements, "
+                    f"but the target shape only holds {inferred_product * infer_product}"
+                )
+
+            # Check whether the split is well determined
+            for i in none_indices:
+                # inference is only possible with a single -1
+                if num_infer == 1:
+                    inferred_target_shape[i] = infer_product
+                else:
+                    # multiple dimensions cannot be inferred automatically
+                    raise ValueError(
+                        f"{len(none_indices)} dimensions need inference ({none_indices}) "
+                        "but there are not enough constraints"
+                    )
+
+    # Convert to a concrete shape tuple
+    result_shape = tuple(
+        d if d is not None else -1  # keep -1 to mean unspecified
+        for d in inferred_target_shape
+    )
+
+    return result_shape
+
+def broadcast_to(input, shape):
+    out_shape = infer_broadcast_shape(input.shape, shape)
+    out = Tensor_(shape=out_shape, dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('broadcast_to')
+
+def zeros(size, dtype):
+    out = Tensor_(shape=size, dtype=dtype)
+    return core.Tensor(out)
+
+__all__.append('zeros')
+
+def ones(size, dtype):
+    out = Tensor_(shape=size, dtype=dtype)
+    return core.Tensor(out)
+
+__all__.append('ones')
+
+def inplace_uniform(input, *args):
+    return input
+
+__all__.append('inplace_uniform')
+
+def inplace_fill_scalar(input, value):
+    return input
+
+__all__.append('inplace_fill_scalar')
+
+def inplace_normal(input, *args):
+    return input
+
+__all__.append('inplace_normal')
+
+def getitem(input, slice):
+    out = input.asnumpy()[slice]
+    out = Tensor_(shape=out.shape, dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('getitem')
+
+def sub_ext(input, other, alpha):
+    return input
+
+__all__.append('sub_ext')
+
+def pad_v3(input, pad, mode, value):
+    out = np.pad(input.asnumpy(), pad, mode, constant_values=value)
+    out = Tensor_(shape=out.shape, dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('pad_v3')
+
+def abs(input):
+    return input
+
+__all__.append('abs')
+
+def cast(input, dtype):
+    out = Tensor_(shape=input.shape, dtype=dtype)
+    return core.Tensor(out)
+
+__all__.append('cast')
+
+def index_select(input, dim, index):
+    out = np.take(input.asnumpy(), index.asnumpy(), dim)
+    out = Tensor_(shape=out.shape, dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('index_select')
+
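+# Meta kernels materialize shape and dtype only: each op builds an empty
+# Tensor_ with the inferred output shape and never touches data, so tensors
+# on device='meta' cost no storage. A sketch:
+#
+#     x = zeros((2, 1), mindspore.float32)   # no data behind it
+#     y = broadcast_to(x, (4, 2, 3))         # pure shape inference
+
+def identity(input):
+    out 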
= Tensor_(shape=input.shape, dtype=input.dtype) + return core.Tensor(out) + +__all__.append('identity') + +def contiguous(input): + return input + +__all__.append('contiguous') + +def inplace_copy(input, other): + return input + +__all__.append('inplace_copy') + +def div(input, other): + if isinstance(input, core.Tensor): + shape = input.shape + dtype = input.dtype + else: + shape = other.shape + dtype = other.dtype + out = Tensor_(shape=shape, dtype=dtype) + return core.Tensor(out) + +__all__.append('div') + +def pow_scalar_tensor(input, other): + out = Tensor_(shape=other.shape, dtype=other.dtype) + return core.Tensor(out) + +__all__.append('pow_scalar_tensor') + +def concat(tensors, dim): + shape = list(tensors[0].shape) + shape[dim] = sum([t.shape[dim] for t in tensors]) + out = Tensor_(shape=tuple(shape), dtype=tensors[0].dtype) + return core.Tensor(out) + +__all__.append('concat') diff --git a/mindnlp/core/_prims/numpy.py b/mindnlp/core/_prims/numpy.py new file mode 100644 index 000000000..12705ea98 --- /dev/null +++ b/mindnlp/core/_prims/numpy.py @@ -0,0 +1,470 @@ +import numbers +import numpy as np +from mindspore import ops +from mindnlp import core + +__all__ = [] + +def empty(size, dtype): + return core.Tensor.from_numpy(np.empty(size, core.dtype2np[dtype])) + +__all__.append('empty') + +def ones(size, dtype): + return core.Tensor.from_numpy(np.ones(size, core.dtype2np[dtype])) + +__all__.append('ones') + +def zeros(size, dtype): + return core.Tensor.from_numpy(np.zeros(size, core.dtype2np[dtype])) + +__all__.append('zeros') + +def arange(start, end, step, dtype): + return core.Tensor.from_numpy(np.arange(start, end, step, core.dtype2np[dtype])) + +__all__.append('arange') + +def div(input, other): + if not isinstance(input, numbers.Number): + input = input.numpy() + elif not isinstance(other, numbers.Number): + other = other.numpy() + out = np.divide(input, other) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('div') + +def pow_scalar_tensor(input, other): + out = np.power(input, other.numpy()) + return core.Tensor.from_numpy(out) + +__all__.append('pow_scalar_tensor') + +def mul(input, other): + if not isinstance(input, numbers.Number): + input = input.asnumpy() + elif not isinstance(other, numbers.Number): + other = other.asnumpy() + out = np.multiply(input, other) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('mul') + +def sub_ext(input, other, alpha): + if not isinstance(input, numbers.Number): + input = input.numpy() + elif not isinstance(other, numbers.Number): + other = other.numpy() + out = np.subtract(input, other * alpha) + return core.Tensor.from_numpy(out) + +__all__.append('sub_ext') + +def clamp_scalar(input, min, max): + out = np.clip(input.numpy(), min, max) + return core.Tensor.from_numpy(out) + +__all__.append('clamp_scalar') + +def add(input, other): + if not isinstance(input, numbers.Number): + input = input.numpy() + elif not isinstance(other, numbers.Number): + other = other.numpy() + out = np.add(input, other) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('add') + +dyn_shape_op = ops.TensorShape().set_device('CPU') +def dyn_shape(self): + return dyn_shape_op(self) + +__all__.append('dyn_shape') + +def cast(input, dtype): + out = input.asnumpy().astype(core.dtype2np[dtype]) + return core.Tensor.from_numpy(out) + +__all__.append('cast') + +def 
getitem(input, slice): + out = input.asnumpy()[slice] + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('getitem') + +def setitem(input, slice, value): + out = input.asnumpy() + out[slice] = value + out = core.Tensor.from_numpy(out) + input.assign_value(out) + return input + +__all__.append('setitem') + +def contiguous(input): + return input + +__all__.append('contiguous') + +def reshape(input, shape): + out = np.reshape(input.asnumpy(), shape) + return core.Tensor.from_numpy(out) + +__all__.append('reshape') + +def bitwise_and_scalar(input, other): + out = np.bitwise_and(input.numpy(), other) + return core.Tensor.from_numpy(out) + +__all__.append('bitwise_and_scalar') + +def right_shift(input, other): + out = np.right_shift(input.numpy(), other) + return core.Tensor.from_numpy(out) + +__all__.append('right_shift') + +def transpose_ext_view(input, dim0, dim1): + out = np.swapaxes(input.numpy(), dim0, dim1) + return core.Tensor.from_numpy(out) + +__all__.append('transpose_ext_view') + +def expand_dims_view(input, dim): + out = np.expand_dims(input.numpy(), dim) + return core.Tensor.from_numpy(out) + +__all__.append('expand_dims_view') + +def equal(input, other): + if not isinstance(input, numbers.Number): + input = input.numpy() + elif not isinstance(other, numbers.Number): + other = other.numpy() + out = np.equal(input, other) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('equal') + +def reduce_all(input, dim, keepdim): + out = np.all(input.numpy(), dim, keepdims=keepdim) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('reduce_all') + +def reduce_any(input, dim, keepdim): + out = np.any(input.numpy(), dim, keepdims=keepdim) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('reduce_any') + + +def sum_ext(input, dim, keepdim, dtype): + if dtype is not None: + dtype = core.dtype2np[dtype] + out = np.sum(input.numpy(), dim, dtype, keepdims=keepdim) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('sum_ext') + +def full(size, fill_value): + out = np.full(size, fill_value) + return core.Tensor.from_numpy(out) + +__all__.append('full') + +def zeros_like(input): + out = np.zeros_like(input.numpy()) + + return core.Tensor.from_numpy(out) + +__all__.append('zeros_like') + +broadcast_to_op = ops.Primitive('BroadcastTo').set_device('CPU') +def broadcast_to(input, shape): + return broadcast_to_op(input, shape) + +__all__.append('broadcast_to') + +def uniform_real(size): + out = np.random.rand(*size).astype(np.float32) + return core.Tensor.from_numpy(out) + +__all__.append('uniform_real') + +def normal(shape): + out = np.random.normal(0., 1., shape).astype(np.float32) + return core.Tensor.from_numpy(out) + +__all__.append('normal') + +def pad_v3(input, pad, mode, value): + out = np.pad(input.asnumpy(), pad, mode, constant_values=value) + return core.Tensor.from_numpy(out) + +__all__.append('pad_v3') + +def concat(tensors, dim): + out = np.concatenate([t.numpy() for t in tensors], dim) + return core.Tensor.from_numpy(out) + +__all__.append('concat') + +def abs(input): + out = np.abs(input.numpy()) + return core.Tensor.from_numpy(out) + +__all__.append('abs') + +def mean_ext(input, dim, keepdim, dtype): + out = np.mean(input.numpy(), dim, keepdims=keepdim) + if not 
isinstance(out, np.ndarray):
+        out = np.array(out)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('mean_ext')
+
+def matmul_ext(input, other):
+    out = np.matmul(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('matmul_ext')
+
+def max(input):
+    out = np.max(input.numpy())
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('max')
+
+def randint(from_, to, shape, dtype, generator):
+    out = np.random.randint(from_, to, shape, dtype=core.dtype2np[dtype])
+    return core.Tensor.from_numpy(out)
+
+__all__.append('randint')
+
+def identity(input):
+    out = np.copy(input.asnumpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('identity')
+
+def isclose(input, other, rtol, atol, equal_nan):
+    out = np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('isclose')
+
+def non_zero(input):
+    out = np.nonzero(input.numpy())
+    out = np.stack(out, 1)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('non_zero')
+
+def tile(input, dims):
+    # np.tile already returns the full tiled array
+    out = np.tile(input.numpy(), dims)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('tile')
+
+def squeeze(input, dim):
+    out = np.squeeze(input.numpy(), dim)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('squeeze')
+
+def index_select(input, dim, index):
+    out = np.take(input.asnumpy(), index.asnumpy(), dim)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('index_select')
+
+def rand_ext(size, seed, offset, dtype):
+    # NOTE: np.random.randn samples a standard normal; a torch.rand-style
+    # uniform draw may be intended here.
+    out = np.random.randn(*size).astype(core.dtype2np[dtype])
+    return core.Tensor.from_numpy(out)
+
+__all__.append('rand_ext')
+
+def inplace_uniform(input, from_, to_, generator_):
+    out = np.random.uniform(from_, to_, input.shape).astype(core.dtype2np[input.dtype])
+    input.assign_value(core.Tensor.from_numpy(out))
+    return input
+
+__all__.append('inplace_uniform')
+
+def inplace_fill_scalar(input, value):
+    out = np.full_like(input.numpy(), value)
+    input.assign_value(core.Tensor.from_numpy(out))
+    return input
+
+__all__.append('inplace_fill_scalar')
+
+def inplace_normal(input, mean, std, seed, offset):
+    out = np.random.normal(mean, std, input.shape).astype(core.dtype2np[input.dtype])
+    input.assign_value(core.Tensor.from_numpy(out))
+    return input
+
+__all__.append('inplace_normal')
+
+def inplace_random(input, from_val=0, to_val=None, seed=None, offset=None):
+    # Choose the random number generator
+    rng = np.random
+    arr = input.numpy()
+    if np.issubdtype(arr.dtype, np.floating):
+        # Floating-point handling
+        if to_val is None:
+            # default: uniform over [0, 1)
+            rnd = rng.random(size=arr.shape).astype(arr.dtype)
+        else:
+            rnd = (from_val + (to_val - from_val) * rng.random(size=arr.shape)).astype(arr.dtype)
+
+    elif np.issubdtype(arr.dtype, np.integer):
+        # Integer handling
+        from_int = int(from_val)
+
+        if to_val is None:
+            # default range [0, dtype.max]
+            max_val = np.iinfo(arr.dtype).max
+            rnd = rng.randint(0, max_val + 1, size=arr.shape).astype(arr.dtype)
+        else:
+            # explicit range [from_int, to_int)
+            to_int = int(to_val)
+
+            # Validate the requested range
+            if from_int >= to_int:
+                raise ValueError(f"Empty range for integers: from={from_int} >= to={to_int}")
+
+            # Clamp the bounds to the dtype's representable range
+            dtype_min = np.iinfo(arr.dtype).min
+            dtype_max = np.iinfo(arr.dtype).max
+            from_int = np.clip(from_int, dtype_min, dtype_max)
+            to_int = np.clip(to_int, dtype_min + 1, dtype_max + 1)
+
+            rnd = rng.randint(from_int, to_int, size=arr.shape).astype(arr.dtype)
+
+    elif arr.dtype == bool:
+        # Boolean handling (from_val/to_val are ignored)
+        rnd = 
rng.random(size=arr.shape) > 0.5 + + else: + raise TypeError(f"Unsupported data type: {arr.dtype}") + + input.assign_value(core.Tensor.from_numpy(rnd)) + return input + +__all__.append('inplace_random') + +def inplace_copy(input, other): + input.assign_value(other) + return input + +__all__.append('inplace_copy') + +def softmax(input, dim): + softmax_op = ops.Softmax(dim).set_device('CPU') + return softmax_op(input) + + +__all__.append('softmax') + +def topk(input, k, sorted=True): + topk_op = ops.TopK(sorted).set_device('CPU') + return topk_op(input, k) + +__all__.append('topk') + +def sort_ext(input, dim, descending, stable): + sort_op = ops.Sort(dim, descending).set_device('CPU') + return sort_op(input) + +__all__.append('sort_ext') + +def round(input): + out = np.round(input.numpy()) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('round') + +def isin(elements, test_elements): + out = np.isin(elements, test_elements) + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) + +__all__.append('isin') + +def ldexp(input, other): + if not isinstance(other, numbers.Number): + other = other.numpy() + out = np.ldexp(input.numpy(), other) + return core.Tensor.from_numpy(out) + +__all__.append('ldexp') + +def less(input, other): + if not isinstance(input, numbers.Number): + input = input.numpy() + if not isinstance(other, numbers.Number): + other = other.numpy() + + out = input < other + return core.Tensor.from_numpy(out) + +__all__.append('less') + +def cumsum_ext(input, dim, dtype): + if dtype is not None: + dtype = core.dtype2np[dtype] + out = np.cumsum(input.numpy(), dim, dtype) + + return core.Tensor.from_numpy(out) + +__all__.append('cumsum_ext') + +def greater_equal(input, other): + if not isinstance(input, numbers.Number): + input = input.numpy() + if not isinstance(other, numbers.Number): + other = other.numpy() + + out = input >= other + if not isinstance(out, np.ndarray): + out = np.array(out) + + return core.Tensor.from_numpy(out) + +__all__.append('greater_equal') + +def masked_fill(input, mask, value): + out = np.where(mask.numpy(), value, input.numpy()) + return core.Tensor.from_numpy(out) + +__all__.append('masked_fill') diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 8857647c2..1d0a79bf6 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -1,9 +1,11 @@ import math +import ctypes import numpy as np import mindspore from mindspore import Tensor from mindspore.common.tensor import _TensorMeta from mindspore._c_expression.typing import Type +from mindspore._c_expression import typing try: from mindspore.common._stub_tensor import StubTensor, _stub_method except: @@ -15,12 +17,17 @@ class StubTensor: pass from mindspore._c_expression import Tensor as Tensor_ from . 
import ops, _dtype -from ._bind import get_default_device, device_, get_default_dtype +from ._bind import get_device_in_context, device_, get_default_dtype from .storage import UntypedStorage from ._utils import _rebuild_tensor_v2 from ._C.size import Size -from .types import DEVICE_MAP -from .configs import DEVICE_TARGET +from .configs import DEVICE_TARGET, CPU_USE_NUMPY_OP +from .dispatcher import device_map + +if DEVICE_TARGET == 'Ascend': + import acl +else: + acl = None DTYPE_ELEMENT_SIZE_MAP = { mindspore.float64: 8, @@ -70,50 +77,78 @@ def __init__(self, *args, **kwargs): super().__init__(*args, dtype=_dtype.bool, **kwargs) +def tensor_meta_str(self): + return "" + +_TensorMeta.__str__ = tensor_meta_str + +old_init = Tensor.__init__ +def __init__(self, *args, **kwargs): + if len(args) > 1 and all([isinstance(arg, int) for arg in args]): + tensor = Tensor_(shape=args, dtype=get_default_dtype()) + old_init(self, tensor, internal=True) + else: + old_init(self, *args, **kwargs) + +Tensor.__init__ = __init__ +origin_setitem = Tensor.__setitem__ + def tensor(data, *, dtype=None, device=None, requires_grad=False): if isinstance(data, Tensor): UserWarning("To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than core.tensor(sourceTensor).") - return Tensor(data) - - if isinstance(data, list): - new_data = [] - for d in data: - if isinstance(d, Tensor): - d = d.item() - new_data.append(d) - data = new_data + out = Tensor(data) + out._device = data.device + return out + + # if isinstance(data, list): + # new_data = [] + # for d in data: + # if isinstance(d, Tensor): + # d = d.item() + # new_data.append(d) + # data = new_data if device is None: - device = get_default_device() + device = get_device_in_context() + + if isinstance(device, (str, int)): + device = device_(device) if dtype is not None: tensor = Tensor(data, dtype=dtype) else: tensor = Tensor(data) - tensor = tensor.to(device) + tensor._device = device + if DEVICE_TARGET == 'Ascend' and device.type == 'cuda': + device.type = 'npu' + if device.type not in ['meta', 'cpu']: + tensor = tensor.to(device) tensor.requires_grad_(requires_grad) return tensor def is_tensor(x): return isinstance(x, Tensor) -def enable_mindspore_patch(): +class TensorPlaceHolder: - def tensor_meta_str(self): - return "" + def cpu(self): + return self.to(device_('cpu')) - _TensorMeta.__str__ = tensor_meta_str + def npu(self, device=None, non_blocking=False): + if device is None: + device = device_('npu', 0) + return self.to(device, non_blocking=non_blocking) - old_init = Tensor.__init__ - def __init__(self, *args, **kwargs): - if len(args) > 1 and all([isinstance(arg, int) for arg in args]): - tensor = Tensor_(shape=args, dtype=get_default_dtype()) - old_init(self, tensor, internal=True) - else: - old_init(self, *args, **kwargs) + def cuda(self, device=None, non_blocking=False): + if DEVICE_TARGET == 'Ascend': + return self.npu(device, non_blocking) + if device is None: + device = device_('gpu', 0) + return self.to(device, non_blocking=non_blocking) - Tensor.__init__ = __init__ + def requires_grad_(self, requires_grad=True): + self.requires_grad = requires_grad def __reduce_ex__(self, protocol): if isinstance(self, StubTensor): @@ -128,128 +163,28 @@ def __reduce_ex__(self, protocol): return ( _rebuild_from_type_v2, (_rebuild_tensor_v2, type(self), args, None)) - Tensor.__reduce_ex__ = __reduce_ex__ - StubTensor.__reduce_ex__ = __reduce_ex__ - - def to_(self, *args, 
**kwargs): - dtype_to = kwargs.get("dtype", None) - if len(args) == 1: - if isinstance(args[0], Type): - dtype_to = args[0] - elif isinstance(args[0], Tensor): - dtype_to = args[0].dtype - elif len(args) == 2: - _, dtype_to = args - else: - dtype_to = kwargs.get("dtype", None) - if dtype_to is not None: - return mindspore.ops.cast(self, dtype_to) - return self - - Tensor.to = to_ - StubTensor.to = to_ - - def size(self, dim=None): - if dim is None: - return self.shape - assert isinstance(dim, int), f'`dim` must be int but got {type(dim)}' - return self.shape[dim] - - Tensor.size = size - StubTensor.size = size - - @property - def shape(self): - if isinstance(self, StubTensor): - if self.stub is not None: - stub_shape = self.stub.get_shape() - else: - stub_shape = self.tensor.shape - return Size(stub_shape) - return Size(self._shape) - - Tensor.shape = shape - StubTensor.shape = shape - - @property - def is_meta(self): - return False - - Tensor.is_meta = is_meta - StubTensor.is_meta = is_meta - - def data_ptr(self): - ptr = self._data_ptr() - if ptr != 0: - return ptr - self + 1 - return self._data_ptr() - - Tensor.data_ptr = data_ptr - StubTensor.data_ptr = data_ptr - - Tensor.device = device_(DEVICE_MAP[mindspore.get_context('device_target')]) - StubTensor.device = device_(DEVICE_MAP[mindspore.get_context('device_target')]) - - def _expand(self, *size): - if len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - new_size = () - for s in size: - if isinstance(s, Tensor): - s = s.item() - new_size += (s,) - return self.broadcast_to(new_size) - - Tensor.expand = _expand - StubTensor.expand = _expand - - Tensor.broadcast_to = ops.broadcast_to - StubTensor.broadcast_to = ops.broadcast_to - - def clone(self, *args, **kwargs): - return self.copy() - - Tensor.clone = clone - StubTensor.clone = clone - - def _repeat(self, *sizes): - if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)): - sizes = sizes[0] - new_sizes = () - for s in sizes: - if not isinstance(s, int): - s = s.item() - new_sizes += (s,) - - return ops.tile(self, new_sizes) - - Tensor.repeat = _repeat - StubTensor.repeat = _repeat - - def __or__(self, other): - if isinstance(other, (int, bool, float, Tensor)): - return ops.bitwise_or(self, other) - raise TypeError("Unsupported operand type(s) for |: 'Tensor' and '{}'".format(type(other))) - - Tensor.__or__ = __or__ - StubTensor.__or__ = __or__ + def __hash__(self): + return hash(id(self)) - def __and__(self, other): - if isinstance(other, (int, bool, float, Tensor)): - return ops.bitwise_and(self, other) - raise TypeError("Unsupported operand type(s) for &: 'Tensor' and '{}'".format(type(other))) + def __len__(self): + if self.shape == (): + return 1 + return self.shape[0] - Tensor.__and__ = __and__ - StubTensor.__and__ = __and__ + def __repr__(self) -> str: + self.data_sync(True) + return Tensor_.__repr__(self)[:-1] + f', device={self.device})' - def detach(self): - return ops.stop_gradient(self) + def __format__(self, format_spec): + return np.ndarray.__format__(self.asnumpy(), format_spec) - Tensor.detach = detach - StubTensor.detach = detach + def __iter__(self): + if self.ndim == 0: + yield self + else: + for i in range(len(self)): + yield self[i] - origin_getitem = Tensor.__getitem__ def __getitem__(self, slices): slices = self._convert_numpy_slices(slices) # if 0 in self.shape: @@ -263,45 +198,17 @@ def __getitem__(self, slices): s = tensor(s) new_slices += (s,) slices = new_slices - return origin_getitem(self, slices) - - Tensor.__getitem__ = 
__getitem__ - StubTensor.__getitem__ = _stub_method(__getitem__) - - def _convert_numpy_slices(self, key): - """递归转换 key 中的 NumPy 整数为内置 int""" - # 处理元组:遍历所有元素并递归转换 - if isinstance(key, tuple): - return tuple(self._convert_numpy_slices(k) for k in key) - - # 处理 slice 对象:转换 start/stop/step - elif isinstance(key, slice): - start = key.start - stop = key.stop - step = key.step - - # 转换 NumPy 整数为 Python int - if isinstance(start, (np.integer, Tensor)): - start = int(start) - if isinstance(stop, (np.integer, Tensor)): - stop = int(stop) - if isinstance(step, (np.integer, Tensor)): - step = int(step) - - return slice(start, stop, step) - - # 转换单个 NumPy 索引值 - elif isinstance(key, np.integer): - return int(key) - - # 其他类型(如 int、None)直接返回 + if self.device.type == 'npu': + out = ops.tensor_getitem(self, slices) else: - return key + if CPU_USE_NUMPY_OP: + out = ops.getitem_np(self, slices) + else: + out = ops.getitem(self, slices) - Tensor._convert_numpy_slices = _convert_numpy_slices - StubTensor._convert_numpy_slices = _convert_numpy_slices + out._device = self.device + return out - origin_setitem = Tensor.__setitem__ def __setitem__(self, slices, value): slices = self._convert_numpy_slices(slices) if isinstance(value, float): @@ -317,274 +224,164 @@ def __setitem__(self, slices, value): new_slices += (s,) slices = new_slices if not isinstance(value, Tensor): - value = tensor(value, dtype=self.dtype) + value = tensor(value, dtype=self.dtype, device=self.device) else: value = value.to(self.dtype) if 1 in value.shape and self[slices].ndim != value.ndim: value = value.squeeze() - return origin_setitem(self, slices, value) - - Tensor.__setitem__ = __setitem__ - StubTensor.__setitem__ = __setitem__ - - def numel(self): - return math.prod(self.shape) - - Tensor.numel = numel - StubTensor.numel = numel - Tensor.nelement = numel - StubTensor.nelement = numel - - @property - def nbytes(self): - return self.numel() * self.element_size() - - Tensor.nbytes = nbytes - StubTensor.nbytes = nbytes - - Tensor.normal_ = ops.inplace_normal - StubTensor.normal_ = ops.inplace_normal - - - Tensor.softmax = ops.softmax - StubTensor.softmax = ops.softmax - - Tensor.squeeze = ops.squeeze - StubTensor.squeeze = ops.squeeze - - Tensor.unsqueeze = ops.unsqueeze - StubTensor.unsqueeze = ops.unsqueeze - - Tensor.log_softmax = ops.log_softmax - StubTensor.log_softmax = ops.log_softmax - - def untyped_storage(self): - return UntypedStorage(self) - - Tensor.untyped_storage = untyped_storage - StubTensor.untyped_storage = untyped_storage - - def element_size(self,): - return DTYPE_ELEMENT_SIZE_MAP[self.dtype] - - Tensor.element_size = element_size - StubTensor.element_size = element_size - - @property - def layout(self): - return None - - Tensor.layout = layout - StubTensor.layout = layout + if self.device.type == 'meta': + return self + elif self.device.type == 'npu': + if value.device != self.device: + value._device = self.device + out = ops.tensor_setitem(self, slices, value) + else: + if CPU_USE_NUMPY_OP: + out = ops.setitem_np(self, slices, value) + else: + out = ops.setitem(self, slices, value) + return self def __add__(self, other): # if 0 in self.shape: # return self return ops.add(self, other) - - Tensor.__add__ = __add__ - StubTensor.__add__ = __add__ def __iadd__(self, other): - self.data = ops.add(self, other) - return self + return self.copy_(ops.add(self, other)) - Tensor.__iadd__ = __iadd__ - StubTensor.__iadd__ = __iadd__ + def __radd__(self, other): + return Tensor.__add__(other, self) - def __sub__(self, other): + 
def __div__(self, other): # if 0 in self.shape: # return self if isinstance(other, (np.ndarray, np.integer)): other = tensor(other) - return ops.sub(self, other) - - Tensor.__sub__ = __sub__ - StubTensor.__sub__ = __sub__ + return ops.div(self, other) + def __rshift__(self, other): + return ops.bitwise_right_shift(self, other) - def __mul__(self, other): - # if 0 in self.shape: - # return self - if isinstance(other, (np.ndarray, np.integer)): - other = tensor(other) - return ops.mul(self, other) - - Tensor.__mul__ = __mul__ - StubTensor.__mul__ = __mul__ + def __rtruediv__ (self, other): + return ops.div(other, self) + def __ne__(self, other): + return ops.ne(self, other) - def __div__(self, other): + def __neg__(self): + return ops.neg(self) + + def __mul__(self, other): # if 0 in self.shape: # return self if isinstance(other, (np.ndarray, np.integer)): other = tensor(other) - return ops.div(self, other) - - Tensor.__truediv__ = __div__ - StubTensor.__truediv__ = __div__ - - Tensor.repeat_interleave = ops.repeat_interleave - StubTensor.repeat_interleave = ops.repeat_interleave - - def dim(self): - return self.ndim - - Tensor.dim = dim - StubTensor.dim = dim - - def unfold(self, dimension, size, step): - return ops.unfold(self, dimension, size, step) - - Tensor.unfold = unfold - StubTensor.unfold = unfold - - def new(self, *shape): - if not isinstance(shape[0], int): - return tensor(shape[0], dtype=self.dtype) - return ops.empty(*shape, dtype=self.dtype, device=self.device) - - Tensor.new = new - StubTensor.new = new - - def view(self, *args): - return self.reshape(*args) - - Tensor.view = view - StubTensor.view = view - - def cpu(self, *args, **kwargs): - return self - - Tensor.cpu = cpu - StubTensor.cpu = cpu - - Tensor.take = ops.take - StubTensor.take = ops.take - - Tensor.sort = ops.sort - StubTensor.sort = ops.sort - - def requires_grad_(self, requires_grad=True): - self.requires_grad = requires_grad - return self - - Tensor.requires_grad_ = requires_grad_ - StubTensor.requires_grad_ = requires_grad_ - - @property - def data(self): - return Tensor(self) - - @data.setter - def data(self, new_value): - if isinstance(self, StubTensor) and isinstance(new_value, StubTensor): - self.stub = new_value.stub - else: - self.assign_value(new_value) - - Tensor.data = data - StubTensor.data = data - - Tensor.narrow = ops.narrow - StubTensor.narrow = ops.narrow - - def bitwise_or_(self, other): - out = ops.bitwise_or(self, other) - self.copy_(out) - return self - - Tensor.bitwise_or_ = bitwise_or_ - StubTensor.bitwise_or_ = bitwise_or_ - - # fix TypeError: unhashable type: 'StubTensor' - StubTensor.__hash__ = Tensor.__hash__ - - Tensor.masked_fill = ops.masked_fill - StubTensor.masked_fill = ops.masked_fill - - Tensor.reshape = ops.reshape - StubTensor.reshape = ops.reshape + return ops.mul(self, other) def __rmul__(self, other): if isinstance(other, (str, list)): return self.item() * other return self.__mul__(other) - Tensor.__rmul__ = __rmul__ - StubTensor.__rmul__ = __rmul__ - Tensor.norm = ops.norm - StubTensor.norm = ops.norm + def __imul__(self, other): + return self.copy_(ops.mul(self, other)) - def clamp_min(self, value): - return ops.clamp(self, value) + def __itruediv__(self, other): + return self.copy_(ops.div(self, other)) - Tensor.clamp_min = clamp_min - StubTensor.clamp_min = clamp_min + def __pow__(self, other): + return ops.pow(self, other) - Tensor.index_copy_ = ops.inplace_index_copy - StubTensor.index_copy_ = ops.inplace_index_copy + def __rpow__(self, other): + return 
ops.pow(other, self) - Tensor.max = ops.max - StubTensor.max = ops.max + def __sub__(self, other): + # if 0 in self.shape: + # return self + if isinstance(other, (np.ndarray, np.integer)): + other = tensor(other) + return ops.sub(self, other) - Tensor.min = ops.min - StubTensor.min = ops.min + def __isub__(self, other): + return self.copy_(ops.sub(self, other)) - Tensor.squeeze_ = ops.inplace_squeeze - StubTensor.squeeze_ = ops.inplace_squeeze + def __rsub__(self, other): + return ops.sub(other, self) - Tensor.unsqueeze_ = ops.inplace_unsqueeze - StubTensor.unsqueeze_ = ops.inplace_unsqueeze + def __eq__(self, other): + return ops.eq(self, other) - def pin_memory(self, *args, **kwargs): - return self - - Tensor.pin_memory = pin_memory - StubTensor.pin_memory = pin_memory + def __gt__(self, other): + return ops.gt(self, other) - def __deepcopy__(self, memodict): - new_obj = Tensor(self) - return new_obj + def __ge__(self, other): + return ops.ge(self, other) - Tensor.__deepcopy__ = __deepcopy__ - StubTensor.__deepcopy__ = __deepcopy__ + def __lt__(self, other): + return ops.lt(self, other) - def asnumpy(self): - return Tensor_.asnumpy(self) + def __le__(self, other): + return ops.le(self, other) - Tensor.asnumpy = asnumpy - StubTensor.asnumpy = _stub_method(asnumpy) + def __int__(self): + return int(self.item()) - def backward(self, *args, **kwargs): - pass + def __bool__(self): + return bool(self.item()) + + def __index__(self): + if self.ndim > 0: + return self.tolist() + return int(self.item()) - Tensor.backward = backward - StubTensor.backward = backward + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) - def __repr__(self): - Tensor_.data_sync(self, True) - return Tensor_.__repr__(self) + def __or__(self, other): + return ops.bitwise_or(self, other) - Tensor.__repr__ = __repr__ - StubTensor.__repr__ = _stub_method(__repr__) + def __invert__(self): + return ops.logical_not(self) - def detach_(self): - return ops.stop_gradient(self) + def __round__(self): + return ops.round(self) - Tensor.detach_ = detach_ - StubTensor.detach_ = detach_ + # def __del__(self): + # # self._offload() + # # Tensor_.__del__(self) + # mindspore.runtime.synchronize() + + def new(self, *shape): + if not isinstance(shape[0], int): + return tensor(shape[0], dtype=self.dtype) + return ops.empty(*shape, dtype=self.dtype, device=self.device) + + # Tensor.new_tensor + def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False): + return tensor(data, dtype=dtype if dtype is not None else self.dtype) + # Tensor.new_full def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False): return ops.full(size, fill_value, dtype=dtype if dtype is not None else self.dtype) - Tensor.new_full = new_full - StubTensor.new_full = new_full + # Tensor.new_empty + def new_empty(self, size, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False): + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + return ops.empty(*size, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) - def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False): + # Tensor.new_ones + def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False, **kwargs): + size = kwargs.get('size', size) if isinstance(size[0], 
(tuple, list)):
         size = size[0]
@@ -593,13 +390,13 @@ def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=
         if isinstance(s, Tensor):
             s = s.item()
         new_size += (s,)
-    return ops.zeros(*new_size, dtype=dtype if dtype is not None else self.dtype)
+        return ops.ones(new_size, dtype=dtype if dtype is not None else self.dtype, device=self.device)
 
-    Tensor.new_zeros = new_zeros
-    StubTensor.new_zeros = new_zeros
 
-    def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False, **kwargs):
-        size = kwargs.get('size', size)
+    # Tensor.new_zeros
+    def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
         if isinstance(size[0], (tuple, list)):
             size = size[0]
@@ -608,234 +405,2119 @@ def new_ones(self, *size, dtype=None, device=None, requires_grad=False, layout=N
         if isinstance(s, Tensor):
             s = s.item()
         new_size += (s,)
-    if new_size == new_size:
-        new_size = (new_size,)
-    return ops.ones(*new_size, dtype=dtype if dtype is not None else self.dtype)
+        return ops.zeros(*new_size, dtype=dtype if dtype is not None else self.dtype)
 
-    Tensor.new_ones = new_ones
-    StubTensor.new_ones = new_ones
+    # Tensor.ndim
+    @property
+    def ndim(self):
+        return len(self.shape)
 
-    Tensor.sum = ops.sum
-    StubTensor.sum = ops.sum
+    def dim(self):
+        return self.ndim
 
-    def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
-        return tensor(data, dtype=dtype if dtype is not None else self.dtype)
+    def ndimension(self):
+        return self.ndim
 
-    Tensor.new_tensor = new_tensor
-    StubTensor.new_tensor = new_tensor
+    # Tensor.real
+    @property
+    def real(self):
+        return ops.real(self)
 
-    Tensor.fill_diagonal_ = ops.inplace_fill_diagonal
-    StubTensor.fill_diagonal_ = ops.inplace_fill_diagonal
+    # Tensor.imag
+    @property
+    def imag(self):
+        return ops.imag(self)
 
-    Tensor.fill_ = ops.inplace_fill
-    StubTensor.fill_ = ops.inplace_fill
+    # Tensor.nbytes
+    @property
+    def nbytes(self):
+        return self.numel() * self.element_size()
 
-    Tensor.zero_ = ops.inplace_zero
-    StubTensor.zero_ = ops.inplace_zero
+    # Tensor.itemsize
+    @property
+    def itemsize(self):
+        return self._data._itemsize
 
-    Tensor.uniform_ = ops.inplace_uniform
-    StubTensor.uniform_ = ops.inplace_uniform
+    # Tensor.abs
+    def abs(self):
+        return ops.abs(self)
 
-    Tensor.random_ = ops.inplace_random
-    StubTensor.random_ = ops.inplace_random
+    # Tensor.abs_
+    def abs_(self):
+        return self.copy_(ops.abs(self))
+
+    # Tensor.absolute
+    absolute = abs
 
-    Tensor.triu_ = ops.inplace_triu
-    StubTensor.triu_ = ops.inplace_triu
+    # Tensor.absolute_
+    absolute_ = abs_
 
-    Tensor.masked_fill_ = ops.inplace_masked_fill
-    StubTensor.masked_fill_ = ops.inplace_masked_fill
+    # Tensor.acos
+    def acos(self):
+        return ops.acos(self)
+
+    # Tensor.acos_
+    def acos_(self):
+        return self.copy_(ops.acos(self))
 
-    @property
-    def real(self):
-        return ops.real(self)
-
-    Tensor.real = real
-    StubTensor.real = real
+    # Tensor.arccos
+    arccos = acos
 
+    # Tensor.arccos_
+    arccos_ = acos_
+
+    # Tensor.add
+    def add(self, other, *, alpha=1):
+        return ops.add(self, other, alpha=alpha)
+
+    # Tensor.add_
+    def add_(self, other, *, alpha=1):
+        return ops.inplace_add(self, other, alpha=alpha)
+
+    # Tensor.addbmm
+    def addbmm(self, batch1, batch2, *, beta=1, alpha=1):
+        return ops.addbmm(self, batch1, batch2, beta=beta, alpha=alpha)
+
+    # Tensor.addbmm_
+    def addbmm_(self, batch1, batch2, *, beta=1, alpha=1):
+        return 
self.copy_(ops.addbmm(self, batch1, batch2, beta=beta, alpha=alpha)) + + # Tensor.addcdiv + def addcdiv(self, tensor1, tensor2, *, value=1): + return ops.addcdiv(self, tensor1, tensor2, value=value) + + # Tensor.addcdiv_ + def addcdiv_(self, tensor1, tensor2, *, value=1): + return self.copy_(ops.addcdiv(self, tensor1, tensor2, value=value)) + + # Tensor.addcmul + def addcmul(self, tensor1, tensor2, *, value=1): + return ops.addcmul(self, tensor1, tensor2, value=value) + + # Tensor.addcmul_ + def addcmul_(self, tensor1, tensor2, *, value=1): + return self.copy_(ops.addcmul(self, tensor1, tensor2, value=value)) + + # Tensor.addmm + def addmm(self, mat1, mat2, *, beta=1, alpha=1): + return ops.addmm(self, mat1, mat2, beta=beta, alpha=alpha) + + # Tensor.addmm_ + def addmm_(self, mat1, mat2, *, beta=1, alpha=1): + return self.copy_(ops.addmm(self, mat1, mat2, beta=beta, alpha=alpha)) + + # Tensor.sspaddmm + + + # Tensor.addmv + def addmv(self, mat, vec, *, beta=1, alpha=1): + return ops.addmv(self, mat, vec, beta=beta, alpha=alpha) + + # Tensor.addmv_ + def addmv_(self, mat, vec, *, beta=1, alpha=1): + return self.copy_(ops.addmv(self, mat, vec, beta=beta, alpha=alpha)) + + # Tensor.addr + + # Tensor.addr_ + + + # Tensor.adjoint + + # Tensor.allclose + def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return ops.allclose(self, other, rtol, atol, equal_nan) + + # Tensor.amax + def amax(self, dim=None, keepdim=False): + return ops.amax(self, dim, keepdim) + + # Tensor.amin + def amin(self, dim=None, keepdim=False): + return ops.amin(self, dim, keepdim) + + # Tensor.aminmax + def aminmax(self, dim=None, keepdim=False): + return ops.aminmax(self, dim=dim, keepdim=keepdim) + + # Tensor.angle + + + # Tensor.apply_ + def apply_(self, callable): + return self.copy_(callable(self)) + + # Tensor.argmax + def argmax(self, dim=None, keepdim=False): + out = ops.argmax(self, dim, keepdim) + return out + + # Tensor.argmin + def argmin(self, dim=None, keepdim=False): + out = ops.argmin(self, dim, keepdim) + return out + + # Tensor.argsort + def argsort(self, dim=-1, descending=False): + return ops.argsort(self, dim=-1, descending=False) + + # Tensor.argwhere + def argwhere(self): + return ops.argwhere(self) + + # Tensor.asin + def asin(self): + return ops.asin(self) + + # Tensor.asin_ + def asin_(self): + return self.copy_(ops.asin(self)) + + # Tensor.arcsin + arcsin = asin + + # Tensor.arcsin_ + arcsin_ = asin_ + + # Tensor.as_strided + def as_strided(self, size, stride, storage_offset=None): + return ops.as_strided(self, size, stride, storage_offset) + + # Tensor.atan + def atan(self): + return ops.atan(self) + + # Tensor.atan_ + def atan_(self): + return self.copy_(ops.atan(self)) + + # Tensor.arctan + arctan = atan + + # Tensor.arctan_ + arctan_ = atan_ + + # Tensor.atan2 + def atan2(self, other): + return ops.atan2(self, other) + + # Tensor.atan2_ + def atan2_(self, other): + return self.copy_(ops.atan2(self, other)) + + # Tensor.arctan2 + arctan2 = atan2 + + # Tensor.arctan2_ + arctan2_ = atan2_ + + # Tensor.all + def all(self, dim=None, keepdim=False): + return ops.all(self, dim, keepdim) + + # Tensor.any + def any(self, dim=None, keepdim=False): + return ops.any(self, dim, keepdim) + + # Tensor.baddbmm + def baddbmm(self, batch1, batch2, *, beta=1, alpha=1): + return ops.baddbmm(self, batch1, batch2, beta=beta, alpha=alpha) + + # Tensor.baddbmm_ + def baddbmm_(self, batch1, batch2, *, beta=1, alpha=1): + return self.copy_(ops.baddbmm(self, batch1, batch2, beta=beta, alpha=alpha)) + 
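All of the in-place variants above follow one convention: compute the out-of-place result, then write it back into the original storage via `copy_`. A short usage sketch of what this buys callers (assuming the patched `mindnlp.core` API defined in this file):

    from mindnlp import core

    a = core.tensor([[1.0, 2.0], [3.0, 4.0]])
    b = core.tensor([[10.0, 20.0], [30.0, 40.0]])

    alias = a             # aliases must observe in-place updates
    a.add_(b, alpha=0.5)  # a <- a + 0.5 * b, written back through copy_
    assert alias is a     # identity preserved, unlike `a = a + 0.5 * b`

The emulation costs one extra temporary per in-place call, but it keeps torch-style aliasing semantics without requiring a true in-place kernel on every backend.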
+
+    # Tensor.bernoulli
+    def bernoulli(self, *, generator=None):
+        return ops.bernoulli(self, generator=generator)
+
+    # Tensor.bernoulli_
+    def bernoulli_(self, *, generator=None):
+        return self.copy_(ops.bernoulli(self, generator=generator))
+
+    # Tensor.bfloat16
+    def bfloat16(self):
+        return self.to(_dtype.bfloat16)
+
+    # Tensor.bincount
+    def bincount(self, weight=None, minlength=0):
+        return ops.bincount(self, weight, minlength)
+
+    # Tensor.bitwise_not
+    def bitwise_not(self):
+        return ops.bitwise_not(self)
+
+    # Tensor.bitwise_not_
+    def bitwise_not_(self):
+        return self.copy_(ops.bitwise_not(self))
+
+    # Tensor.bitwise_and
+    def bitwise_and(self, other):
+        return ops.bitwise_and(self, other)
+
+    # Tensor.bitwise_and_
+    def bitwise_and_(self, other):
+        return self.copy_(ops.bitwise_and(self, other))
+
+    # Tensor.bitwise_or
+    def bitwise_or(self, other):
+        return ops.bitwise_or(self, other)
+
+    # Tensor.bitwise_or_
+    def bitwise_or_(self, other):
+        return self.copy_(ops.bitwise_or(self, other))
+
+    # Tensor.bitwise_xor
+    def bitwise_xor(self, other):
+        return ops.bitwise_xor(self, other)
+
+    # Tensor.bitwise_xor_
+    def bitwise_xor_(self, other):
+        return self.copy_(ops.bitwise_xor(self, other))
+
+    # Tensor.bitwise_left_shift
+
+
+    # Tensor.bitwise_left_shift_
+
+
+    # Tensor.bitwise_right_shift
+
+
+    # Tensor.bitwise_right_shift_
+
+
+    # Tensor.bmm
+    def bmm(self, batch2):
+        return ops.bmm(self, batch2)
+
+    # Tensor.bool
+    def bool(self):
+        return self.to(mindspore.bool_)
+
+    # Tensor.byte
+    def byte(self):
+        return self.to(mindspore.uint8)
+
+    # Tensor.broadcast_to
+    def broadcast_to(self, shape):
+        return ops.broadcast_to(self, shape)
+
+    # Tensor.cauchy_
+
+
+    # Tensor.ceil
+    def ceil(self):
+        return ops.ceil(self)
+
+    # Tensor.ceil_
+    def ceil_(self):
+        return self.copy_(ops.ceil(self))
+
+    # Tensor.char
+    def char(self):
+        return self.to(mindspore.int8)
+
+    # Tensor.cholesky
+
+
+    # Tensor.cholesky_inverse
+
+
+    # Tensor.cholesky_solve
+
+
+    # Tensor.chunk
+    def chunk(self, chunks, dim=0):
+        return ops.chunk(self, chunks, dim)
+
+    # Tensor.clamp
+    def clamp(self, min=None, max=None):
+        return ops.clamp(self, min, max)
+
+    # Tensor.clamp_
+    def clamp_(self, min=None, max=None):
+        return self.copy_(ops.clamp(self, min, max))
+
+    # Tensor.clip
+    def clip(self, min=None, max=None):
+        return ops.clip(self, min, max)
+
+    # Tensor.clip_
+    def clip_(self, min=None, max=None):
+        return self.copy_(ops.clip(self, min, max))
+
+    # Tensor.clone
+    def clone(self, memory_format=None):
+        return ops.clone(self)
+
+    # Tensor.contiguous
+    def contiguous(self):
+        return ops.contiguous(self)
+
+    # Tensor.copy_
+    def copy_(self, value):
+        if self.dtype != value.dtype:
+            value = value.to(self.dtype)
+        return ops.inplace_copy(self, value)
+
+    # Tensor.conj
+    def conj(self):
+        return ops.conj(self)
+
+    # Tensor.conj_physical
+
+
+    # Tensor.conj_physical_
+
+
+    # Tensor.resolve_conj
+
+
+    # Tensor.resolve_neg
+
+
+    # Tensor.copysign
+
+
+    # Tensor.copysign_
+
+
+    # Tensor.cos
+    def cos(self):
+        return ops.cos(self)
+
+    # Tensor.cos_
+    def cos_(self):
+        return self.copy_(ops.cos(self))
+
+    # Tensor.cosh
+    def cosh(self):
+        return ops.cosh(self)
+
+    # Tensor.cosh_
+    def cosh_(self):
+        return self.copy_(ops.cosh(self))
+
+    # Tensor.corrcoef
+
+
+    # Tensor.count_nonzero
+    def count_nonzero(self, dim=None):
+        return ops.count_nonzero(self, dim)
+
+    # Tensor.cov
+
+
+    # Tensor.acosh
+    def acosh(self):
+        return ops.acosh(self)
+
+    # Tensor.acosh_
+    def acosh_(self):
+        return self.copy_(ops.acosh(self))
+
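The dtype helpers in this block (`bool`, `byte`, `char`, `bfloat16`, with `half`/`float`/`double` defined further down) are all thin wrappers over `to`. A usage sketch, assuming the methods have been installed by `enable_mindspore_patch()`:

    from mindnlp import core

    x = core.tensor([0, 1, 2])
    print(x.bool().dtype)  # mindspore.bool_
    print(x.byte().dtype)  # mindspore.uint8
    print(x.char().dtype)  # mindspore.int8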
+    # Tensor.arccosh
+    arccosh = acosh
+
+    # Tensor.arccosh_
+    arccosh_ = acosh_
+
+    # Tensor.cross
+
+
+    # Tensor.logcumsumexp
+
+
+    # Tensor.cummax
+
+
+    # Tensor.cummin
+
+
+    # Tensor.cumprod
+
+
+    # Tensor.cumprod_
+
+
+    # Tensor.cumsum
+    def cumsum(self, dim, dtype=None):
+        return ops.cumsum(self, dim, dtype)
+
+    # Tensor.cumsum_
+    def cumsum_(self, dim, dtype=None):
+        return self.copy_(ops.cumsum(self, dim, dtype))
+
+    # Tensor.chalf
+
+
+    # Tensor.cfloat
+
+
+    # Tensor.cdouble
+
+
+    @property
+    def data(self):
+        out = Tensor(self)
+        out._device = self.device
+        return out
+
+    @data.setter
+    def data(self, new_value):
+        if isinstance(self, StubTensor) and isinstance(new_value, StubTensor):
+            self.stub = new_value.stub
+        else:
+            if self.device.type == 'cpu' and new_value.device.type == 'cpu':
+                src_ct = ctypes.c_void_p(new_value.data_ptr())
+                dst_ct = ctypes.c_void_p(self.data_ptr())
+                ctypes.memmove(dst_ct, src_ct, self.nbytes)
+            else:
+                self.assign_value(new_value)
+            self._device = new_value.device
+
+    # Tensor.data_ptr
+    def data_ptr(self):
+        if self.device.type in ['cpu']:
+            self.dyn_shape()
+        return self._data_ptr()
+
+    def dyn_shape(self):
+        return ops.dyn_shape(self)
+
+    # Tensor.deg2rad
+    def deg2rad(self):
+        return ops.deg2rad(self)
+
+    # Tensor.dequantize
+
+
+    # Tensor.det
+
+
+    # Tensor.dense_dim
+
+
+    # Tensor.diag
+    def diag(self, diagonal=0):
+        return ops.diag(self, diagonal)
+
+    # Tensor.diag_embed
+
+
+    # Tensor.diagflat
+
+
+    # Tensor.diagonal
+    def diagonal(self, offset=0, dim1=0, dim2=1):
+        return ops.diagonal(self, offset, dim1, dim2)
+
+
+    # Tensor.diagonal_scatter
+
+    # Tensor.fill_diagonal_
+
+
+    # Tensor.fmax
+
+
+    # Tensor.fmin
+
+
+    # Tensor.diff
+
+
+    # Tensor.digamma
+
+
+    # Tensor.digamma_
+
+
+    # Tensor.dim_order
+
+
+    # Tensor.dist
+
+
+    # Tensor.div
+    def div(self, other):
+        return ops.div(self, other)
+
+    # Tensor.div_
+    def div_(self, other):
+        return self.copy_(ops.div(self, other))
+
+    # Tensor.divide
+    divide = div
+
+    # Tensor.divide_
+    divide_ = div_
+
+    # Tensor.dot
+    def dot(self, other):
+        return ops.dot(self, other)
+
+    # Tensor.double
+    def double(self):
+        return self.to(mindspore.float64)
+
+    # Tensor.dsplit
+
+
+    # Tensor.element_size
+    def element_size(self):
+        return DTYPE_ELEMENT_SIZE_MAP[self.dtype]
+
+    # Tensor.eq
+    def eq(self, other):
+        return ops.eq(self, other)
+
+    # Tensor.eq_
+    def eq_(self, other):
+        return self.copy_(ops.eq(self, other))
+
+    # Tensor.equal
+    def equal(self, other):
+        return ops.eq(self, other)
+
+    # Tensor.erf
+    def erf(self):
+        return ops.erf(self)
+
+    # Tensor.erf_
+    def erf_(self):
+        return self.copy_(ops.erf(self))
+
+    # Tensor.erfc
+    def erfc(self):
+        return ops.erfc(self)
+
+    # Tensor.erfc_
+    def erfc_(self):
+        return self.copy_(ops.erfc(self))
+
+    # Tensor.erfinv
+    def erfinv(self):
+        return ops.erfinv(self)
+
+
+    # Tensor.erfinv_
+    def erfinv_(self):
+        return self.copy_(ops.erfinv(self))
+
+    # Tensor.exp
+    def exp(self):
+        return ops.exp(self)
+
+    # Tensor.exp_
+    def exp_(self):
+        return self.copy_(ops.exp(self))
+
+
+    # Tensor.expm1
+    def expm1(self):
+        return ops.expm1(self)
+
+
+    # Tensor.expm1_
+    def expm1_(self):
+        return self.copy_(ops.expm1(self))
+
+
+    # Tensor.expand
+    def expand(self, *size):
+        if len(size) == 1:
+            size = size[0]
+        return self.broadcast_to(size)
+
+    # Tensor.expand_as
+    def expand_as(self, other):
+        return self.expand(other.size())
+
+    # Tensor.exponential_
+
+
+    # Tensor.fix
+
+
+    # Tensor.fix_
+
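The CPU-to-CPU branch of the `data` setter above bypasses `assign_value` and copies raw bytes between the two storages. The mechanism is plain `ctypes.memmove`; the self-contained sketch below illustrates it with NumPy arrays standing in for tensor storage (array names are illustrative only):

    import ctypes
    import numpy as np

    src = np.arange(4, dtype=np.float32)
    dst = np.zeros(4, dtype=np.float32)

    # Same pattern as the setter: wrap both data pointers, copy nbytes.
    ctypes.memmove(ctypes.c_void_p(dst.ctypes.data),
                   ctypes.c_void_p(src.ctypes.data),
                   dst.nbytes)

    assert (dst == src).all()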
+    # Tensor.fill_
+    def fill_(self, value):
+        ops.inplace_fill(self, value)
+        return self
+
+    # Tensor.flatten
+    def flatten(self, start_dim=0, end_dim=-1):
+        return ops.flatten(self, start_dim, end_dim)
+
+    # Tensor.flip
+    def flip(self, dims):
+        return ops.flip(self, dims)
+
+    # Tensor.fliplr
+
+
+    # Tensor.flipud
+
+
+    # Tensor.float
+    def float(self):
+        return self.to(mindspore.float32)
+
+    # Tensor.float_power
+    def float_power(self, exponent):
+        return ops.float_power(self, exponent)
+
+    # Tensor.float_power_
+    def float_power_(self, exponent):
+        return self.copy_(ops.float_power(self, exponent))
+
+    # Tensor.floor
+    def floor(self):
+        return ops.floor(self)
+
+    # Tensor.floor_
+    def floor_(self):
+        return self.copy_(ops.floor(self))
+
+    # Tensor.floor_divide
+    def floor_divide(self, other):
+        return ops.floor_divide(self, other)
+
+    # Tensor.floor_divide_
+    def floor_divide_(self, other):
+        return self.copy_(ops.floor_divide(self, other))
+
+
+    # Tensor.fmod
+    def fmod(self, other):
+        return ops.fmod(self, other)
+
+    # Tensor.fmod_
+    def fmod_(self, other):
+        return self.copy_(ops.fmod(self, other))
+
+    # Tensor.frac
+    def frac(self):
+        return ops.frac(self)
+
+    # Tensor.frac_
+    def frac_(self):
+        return self.copy_(ops.frac(self))
+
+
+    # Tensor.frexp
+
+
+    # Tensor.gather
+    def gather(self, dim, index):
+        return ops.gather(self, dim, index)
+
+    # Tensor.gcd
+
+
+    # Tensor.gcd_
+
+
+    # Tensor.ge
+    def ge(self, other):
+        return ops.ge(self, other)
+
+    # Tensor.ge_
+    def ge_(self, other):
+        return self.copy_(ops.ge(self, other))
+
+    # Tensor.greater_equal
+    greater_equal = ge
+
+    # Tensor.greater_equal_
+    greater_equal_ = ge_
+
+
+    # Tensor.geometric_
+
+
+    # Tensor.geqrf
+
+
+    # Tensor.ger
+
+
+    # Tensor.get_device
+    def get_device(self):
+        return self.device.index
+
+    # Tensor.gt
+    def gt(self, other):
+        return ops.gt(self, other)
+
+    # Tensor.gt_
+    def gt_(self, other):
+        return self.copy_(ops.gt(self, other))
+
+    # Tensor.greater
+    greater = gt
+
+    # Tensor.greater_
+    greater_ = gt_
+
+
+    # Tensor.half
+    def half(self):
+        return self.to(mindspore.float16)
+
+    # Tensor.hardshrink
+    def hardshrink(self, lambd=0.5):
+        return ops.nn.functional.hardshrink(self, lambd)
+
+    # Tensor.heaviside
+
+
+    # Tensor.histc
+
+
+    # Tensor.histogram
+
+
+    # Tensor.hsplit
+
+
+    # Tensor.hypot
+
+
+    # Tensor.hypot_
+
+
+    # Tensor.i0
+
+
+    # Tensor.i0_
+
+
+    # Tensor.igamma
+
+
+    # Tensor.igamma_
+
+
+    # Tensor.igammac
+
+
+    # Tensor.igammac_
+
+
+    # Tensor.index_add_
+    def index_add_(self, dim, index, source, *, alpha=1):
+        return self.copy_(ops.index_add(self, dim, index, source, alpha=alpha))
+
+    # Tensor.index_add
+    def index_add(self, dim, index, source, *, alpha=1):
+        return ops.index_add(self, dim, index, source, alpha=alpha)
+
+    # Tensor.index_copy_
+
+
+    # Tensor.index_copy
+
+
+    # Tensor.index_fill_
+
+
+    # Tensor.index_fill
+
+
+    # Tensor.index_put_
+
+
+    # Tensor.index_put
+
+
+    # Tensor.index_reduce_
+
+
+    # Tensor.index_reduce
+
+    # Tensor.index_select
+    def index_select(self, dim, index):
+        return ops.index_select(self, dim, index)
+
+    # Tensor.indices
+
+
+    # Tensor.inner
+
+
+    # Tensor.int
+    def int(self):
+        return self.to(mindspore.int32)
+
+    # Tensor.int_repr
+
+
+    # Tensor.inverse
+
+
+    # Tensor.isclose
+    def isclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+        return ops.isclose(self, other, rtol, atol, equal_nan)
+
+    # Tensor.isfinite
+    def isfinite(self):
+        return ops.isfinite(self)
+
+    # Tensor.isinf
+    def isinf(self):
+        return ops.isinf(self)
+
+    # Tensor.isposinf
+
+
+    # Tensor.isneginf
+
+
+    # Tensor.isnan
+    def 
isnan(self): + return ops.isnan(self) + + # Tensor.is_contiguous + # def is_contiguous(self): + # return self.is_contiguous() + + # Tensor.is_complex + def is_complex(self): + return False + + # Tensor.is_conj + + + # Tensor.is_floating_point + def is_floating_point(self): + return isinstance(self.dtype, typing.Float) + + # Tensor.is_inference + + + # Tensor.is_leaf + @property + def is_leaf(self): + if not self.requires_grad: + return True + if self.requires_grad and self._user_created: + return True + return False + + # Tensor.is_pinned + def is_pinned(self): + return False + + # Tensor.is_set_to + + + # Tensor.is_shared + + + # Tensor.is_signed + + + # Tensor.is_sparse + @property + def is_sparse(self): + return False + + # Tensor.istft + + + # Tensor.isreal + + + # Tensor.item + def item(self): + return self._item() + + # Tensor.kthvalue + + @property + def layout(self): + return None + + # Tensor.lcm + + + # Tensor.lcm_ + + + # Tensor.ldexp + + + # Tensor.ldexp_ + + + # Tensor.le + def le(self, other): + return ops.le(self, other) + + # Tensor.le_ + def le_(self, other): + return self.copy_(ops.le(self, other)) + + # Tensor.less_equal + less_equal = le + + # Tensor.less_equal_ + less_equal_ = le_ + + + # Tensor.lerp + def lerp(self, end, weight): + return ops.lerp(self, end, weight) + + # Tensor.lerp_ + def lerp_(self, end, weight): + return self.copy_(ops.lerp(self, end, weight)) + + + # Tensor.lgamma + + + # Tensor.lgamma_ + + + # Tensor.log + def log(self): + return ops.log(self) + + # Tensor.log_ + def log_(self): + return self.copy_(ops.log(self)) + + # Tensor.logdet + + + # Tensor.log10 + def log10(self): + return ops.log10(self) + + + # Tensor.log10_ + def log10_(self): + return self.copy_(ops.log10(self)) + + # Tensor.log1p + def log1p(self): + return ops.log1p(self) + + + # Tensor.log1p_ + def log1p_(self): + return self.copy_(ops.log1p(self)) + + + # Tensor.log2 + def log2(self): + return ops.log2(self) + + + # Tensor.log2_ + def log2_(self): + return self.copy_(ops.log2(self)) + + + # Tensor.log_normal_ + + + # Tensor.logaddexp + + + # Tensor.logaddexp2 + + + # Tensor.logsumexp + def logsumexp(self, dim, keepdim=False): + return ops.logsumexp(self, dim, keepdim) + + # Tensor.logical_and + def logical_and(self, other): + return ops.logical_and(self, other) + + # Tensor.logical_and_ + def logical_and_(self, other): + return self.copy_(ops.logical_and(self, other)) + + + # Tensor.logical_not + def logical_not(self): + return ops.logical_not(self) + + + # Tensor.logical_not_ + def logical_not_(self): + return self.copy_(ops.logical_not(self)) + + + # Tensor.logical_or + def logical_or(self, other): + return ops.logical_or(self, other) + + + # Tensor.logical_or_ + def logical_or_(self, other): + return self.copy_(ops.logical_or(self, other)) + + + # Tensor.logical_xor + def logical_xor(self, other): + return ops.logical_xor(self, other) + + # Tensor.logical_xor_ + def logical_xor_(self, other): + return self.copy_(ops.logical_xor(self, other)) + + # Tensor.logit + + + # Tensor.logit_ + + + # Tensor.long + def long(self): + return self.to(mindspore.int64) + + # Tensor.lt + def lt(self, other): + return ops.lt(self, other) + + # Tensor.lt_ + def lt_(self, other): + return self.copy_(ops.lt(self, other)) + + # Tensor.less + less = lt + + # Tensor.less_ + less_ = lt_ + + # Tensor.lu + + + # Tensor.lu_solve + + + # Tensor.as_subclass + + + # Tensor.map_ + + + # Tensor.masked_scatter_ + + + # Tensor.masked_scatter + + + # Tensor.masked_fill_ + def masked_fill_(self, mask, value): + 
return self.copy_(ops.masked_fill(self, mask, value))
+
+    # Tensor.masked_fill
+    def masked_fill(self, mask, value):
+        return ops.masked_fill(self, mask, value)
+
+    # Tensor.masked_select
+    def masked_select(self, mask):
+        return ops.masked_select(self, mask)
+
+    # Tensor.matmul
+    def matmul(self, other):
+        return ops.matmul(self, other)
+
+    # Tensor.matrix_power
+
+
+    # Tensor.matrix_exp
+
+
+    # Tensor.max
+    def max(self, dim=None, keepdim=False):
+        return ops.max(self, dim, keepdim)
+
+    # Tensor.maximum
+    def maximum(self, other):
+        return ops.maximum(self, other)
+
+    # Tensor.mean
+    def mean(self, dim=None, keepdim=False, *, dtype=None, **kwargs):
+        dim = kwargs.pop('axis', dim)
+        return ops.mean(self, dim, keepdim, dtype=dtype)
+
+    # Tensor.module_load
+
+
+    # Tensor.nanmean
+
+
+    # Tensor.median
+    def median(self, dim=-1, keepdim=False):
+        return ops.median(self, dim, keepdim)
+
+    # Tensor.nanmedian
+
+
+    # Tensor.min
+    def min(self, dim=None, keepdim=False):
+        return ops.min(self, dim, keepdim)
+
+    # Tensor.minimum
+    def minimum(self, other):
+        return ops.minimum(self, other)
+
+    # Tensor.mm
+    mm = matmul
+
+    # Tensor.smm
+
+
+    # Tensor.mode
+    def mode(self, dim=None, keepdim=False):
+        return ops.mode(self, dim, keepdim)
+
+    # Tensor.movedim
+    def movedim(self, source, destination):
+        return ops.movedim(self, source, destination)
+
+    # Tensor.moveaxis
+    moveaxis = movedim
+
+    # Tensor.msort
+    def msort(self):
+        return ops.msort(self)
+
+    # Tensor.mul
+    def mul(self, other):
+        return ops.mul(self, other)
+
+    # Tensor.mul_
+    def mul_(self, other):
+        return self.copy_(ops.mul(self, other))
+
+    # Tensor.multiply
+    multiply = mul
+
+    # Tensor.multiply_
+    multiply_ = mul_
+
+
+    # Tensor.multinomial
+    def multinomial(self, num_samples, replacement=False, *, generator=None):
+        return ops.multinomial(self, num_samples, replacement, generator=generator)
+
+    # Tensor.mv
+
+
+    # Tensor.mvlgamma
+
+
+    # Tensor.mvlgamma_
+
+
+    # Tensor.nansum
+    def nansum(self, dim=None, keepdim=False, *, dtype=None):
+        return ops.nansum(self, dim, keepdim, dtype=dtype)
+
+    # Tensor.narrow
+    def narrow(self, dim, start, length):
+        return ops.narrow(self, dim, start, length)
+
+    # Tensor.narrow_copy
+    def narrow_copy(self, dimension, start, length):
+        return ops.narrow(self, dimension, start, length).clone()
+
+    # Tensor.nan_to_num
+    def nan_to_num(self, nan=0.0, posinf=None, neginf=None):
+        return ops.nan_to_num(self, nan, posinf, neginf)
+
+    # Tensor.nan_to_num_
+    def nan_to_num_(self, nan=0.0, posinf=None, neginf=None):
+        return self.copy_(ops.nan_to_num(self, nan, posinf, neginf))
+
+    # Tensor.ne
+    def ne(self, other):
+        return ops.ne(self, other)
+
+    # Tensor.ne_
+    def ne_(self, other):
+        return self.copy_(ops.ne(self, other))
+
+    # Tensor.not_equal
+    not_equal = ne
+
+    # Tensor.not_equal_
+    not_equal_ = ne_
+
+
+    # Tensor.neg
+    def neg(self):
+        return ops.neg(self)
+
+    # Tensor.neg_
+    def neg_(self):
+        return self.copy_(ops.neg(self))
+
+    # Tensor.negative
+    negative = neg
+
+    # Tensor.negative_
+    negative_ = neg_
+
+
+    # Tensor.numel
+    def numel(self):
+        return math.prod(self.shape)
+
+    # Tensor.nelement
+    nelement = numel
+
+    # Tensor.nextafter
+
+
+    # Tensor.nextafter_
+
+
+    # Tensor.nonzero
+    def nonzero(self, as_tuple=False):
+        return ops.nonzero(self, as_tuple=as_tuple)
+
+    # Tensor.norm
+    def norm(self, p='fro', dim=None, keepdim=False, dtype=None):
+        return ops.norm(self, p, dim, keepdim, dtype)
+
+    # Tensor.normal_
+    def normal_(self, mean=0, std=1, *, generator=None):
+        return ops.inplace_normal(self, 
mean, std, generator=generator) + + # Tensor.numpy + def numpy(self): + assert self.device.type == 'cpu' + return self.asnumpy() + + def mindspore(self): + return mindspore.Tensor(self._data) + + # Tensor.orgqr + + + # Tensor.ormqr + + + # Tensor.outer + def outer(self, vec2): + return ops.outer(self, vec2) + + # Tensor.permute + def permute(self, *dims): + return ops.permute(self, dims) + + # Tensor.pin_memory + + + # Tensor.pinverse + + + # Tensor.polygamma + + + # Tensor.polygamma_ + + + # Tensor.positive + def positive(self): + return self + + # Tensor.pow + def pow(self, exponent): + return ops.pow(self, exponent) + + # Tensor.pow_ + def pow_(self, exponent): + return self.copy_(ops.pow(self, exponent)) + + + # Tensor.prod + def prod(self, dim=None, keepdim=False, dtype=None): + return ops.prod(self, dim, keepdim, dtype=dtype) + + # Tensor.put_ + + + # Tensor.qr + + + # Tensor.qscheme + + + # Tensor.quantile + + + # Tensor.nanquantile + + + # Tensor.q_scale + + + # Tensor.q_zero_point + + + # Tensor.q_per_channel_scales + + + # Tensor.q_per_channel_zero_points + + + # Tensor.q_per_channel_axis + + + # Tensor.rad2deg + + + # Tensor.ravel + def ravel(self): + return ops.ravel(self) + + # Tensor.reciprocal + def reciprocal(self): + return ops.reciprocal(self) + + # Tensor.reciprocal_ + def reciprocal_(self): + return self.copy_(ops.reciprocal(self)) + + + # Tensor.record_stream + def record_stream(self, stream): + pass + + # Tensor.register_hook + def register_hook(self, hook): + return self._data.register_hook(hook) + + # Tensor.register_post_accumulate_grad_hook + + + # Tensor.remainder + def remainder(self, other): + return ops.remainder(self, other) + + # Tensor.remainder_ + def remainder_(self, other): + return self.copy_(ops.remainder(self, other)) + + # Tensor.renorm + + + # Tensor.renorm_ + + + # Tensor.repeat + def repeat(self, *repeats): + return ops.tile(self, repeats) + + # Tensor.repeat_interleave + def repeat_interleave(self, repeats, dim=None): + return ops.repeat_interleave(self, repeats, dim) + + # Tensor.reshape + def reshape(self, *shape): + return ops.reshape(self, *shape) + + # Tensor.reshape_as + def reshape_as(self, other): + return self.reshape(*other.shape) + + # Tensor.resize_ + def resize_(self, *shape): + self.data = ops.reshape(self, *shape) + return self + + # Tensor.resize_as_ + def resize_as_(self, other): + self.data = ops.reshape(self, *other.shape) + return self + + # Tensor.retains_grad @property - def imag(self): - return ops.imag(self) + def retains_grad(self): + return not self.is_leaf and self._retain_grad - Tensor.imag = imag - StubTensor.imag = imag + # Tensor.roll + def roll(self, shifts, dims=None): + return ops.roll(self, shifts, dims) - def bfloat16(self): - return self.to(_dtype.bfloat16) + # Tensor.rot90 + + + # Tensor.round + def round(self): + return ops.round(self) + + # Tensor.round_ + def round_(self): + return self.copy_(ops.round(self)) + + + # Tensor.rsqrt + def rsqrt(self): + return ops.rsqrt(self) + + # Tensor.rsqrt_ + def rsqrt_(self): + return self.copy_(ops.rsqrt(self)) + + + # Tensor.scatter + def scatter(self, dim, index, src): + return ops.scatter(self, dim, index, src) + + # Tensor.scatter_ + def scatter_(self, dim, index, src): + return self.copy_(ops.scatter(self, dim, index, src)) + + # Tensor.scatter_add_ + def scatter_add_(self, dim, index, src): + return self.copy_(ops.scatter_add(self, dim, index, src)) + + # Tensor.scatter_add + def scatter_add(self, dim, index, src): + return ops.scatter_add(self, dim, index, src) 
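`scatter_add` has semantics that are easy to get wrong, so the expected behavior for the 1-D case is spelled out below in plain NumPy (a reference sketch only, not the kernel `ops.scatter_add` actually uses): `out[index[i]] += src[i]`, with repeated indices accumulating.

    import numpy as np

    dest = np.zeros(4, dtype=np.int64)
    index = np.array([0, 1, 1, 3])
    src = np.array([1, 2, 3, 4])

    out = dest.copy()
    np.add.at(out, index, src)  # unbuffered +=, so duplicate indices accumulate
    print(out)                  # [1 5 0 4]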
+
+
+    # Tensor.scatter_reduce_
+    def scatter_reduce_(self, dim, index, src):
+        return self.copy_(ops.scatter_reduce(self, dim, index, src))
+
+
+    # Tensor.scatter_reduce
+    def scatter_reduce(self, dim, index, src):
+        return ops.scatter_reduce(self, dim, index, src)
+
+
+    # Tensor.select
+    def select(self, dim, index):
+        return ops.select(self, dim, index)
+
+    # Tensor.select_scatter
+
+
+    # Tensor.set_
+
+
+    # Tensor.share_memory_
+
+
+    # Tensor.short
+    def short(self):
+        return self.to(mindspore.int16)
+
+    # Tensor.sigmoid
+    def sigmoid(self):
+        return ops.sigmoid(self)
+
+    # Tensor.sigmoid_
+    def sigmoid_(self):
+        return self.copy_(ops.sigmoid(self))
+
+    # Tensor.sign
+    def sign(self):
+        return ops.sign(self)
+
+    # Tensor.sign_
+    def sign_(self):
+        return self.copy_(ops.sign(self))
+
+
+    # Tensor.signbit
+
+
+    # Tensor.sgn
+
+
+    # Tensor.sgn_
+
+
+    # Tensor.sin
+    def sin(self):
+        return ops.sin(self)
+
+    # Tensor.sin_
+    def sin_(self):
+        return self.copy_(ops.sin(self))
+
+
+    # Tensor.sinc
+    def sinc(self):
+        return ops.sinc(self)
+
+
+    # Tensor.sinc_
+    def sinc_(self):
+        return self.copy_(ops.sinc(self))
+
+    # Tensor.sinh
+    def sinh(self):
+        return ops.sinh(self)
+
+
+    # Tensor.sinh_
+    def sinh_(self):
+        return self.copy_(ops.sinh(self))
+
+
+    # Tensor.asinh
+    def asinh(self):
+        return ops.asinh(self)
+
+
+    # Tensor.asinh_
+    def asinh_(self):
+        return self.copy_(ops.asinh(self))
+
+
+    # Tensor.arcsinh
+    arcsinh = asinh
+
+    # Tensor.arcsinh_
+    arcsinh_ = asinh_
 
-    Tensor.bfloat16 = bfloat16
-    StubTensor.bfloat16 = bfloat16
+    # Tensor.size
+    def size(self, dim=None):
+        if dim is None:
+            return self.shape
+        assert isinstance(dim, int), f'`dim` must be int but got {type(dim)}'
+        return self.shape[dim]
+
+    # Tensor.slogdet
+
+
+    # Tensor.slice_scatter
+
+
+    # Tensor.softmax
+    def softmax(self, dim):
+        return ops.softmax(self, dim)
+
+    # Tensor.sort
     def sort(self, dim=-1, descending=False):
         return ops.sort(self, dim=dim, descending=descending)
 
-    Tensor.sort = sort
-    StubTensor.sort = sort
+    # Tensor.split
+    def split(self, split_size, dim=0):
+        return ops.split(self, split_size, dim)
 
-    Tensor.cumsum = ops.cumsum
-    StubTensor.cumsum = ops.cumsum
+    # Tensor.sparse_mask
 
-    Tensor.scatter_ = ops.inplace_scatter
-    StubTensor.scatter_ = ops.inplace_scatter
 
-    def __contains__(self, item):
-        return ops.eq(self, item).any()
+    # Tensor.sparse_dim
+
+
+    # Tensor.sqrt
+    def sqrt(self):
+        return ops.sqrt(self)
+
+    # Tensor.sqrt_
+    def sqrt_(self):
+        return self.copy_(ops.sqrt(self))
 
-    Tensor.__contains__ = __contains__
-    StubTensor.__contains__ = __contains__
 
-    Tensor.tile = ops.tile
-    StubTensor.tile = ops.tile
+    # Tensor.square
+    def square(self):
+        return ops.square(self)
 
-    Tensor.mean = ops.mean
-    StubTensor.mean = ops.mean
 
-    Tensor.amax = ops.amax
-    StubTensor.amax = ops.amax
+    # Tensor.square_
+    def square_(self):
+        return self.copy_(ops.square(self))
 
-    Tensor.as_strided = ops.as_strided
-    StubTensor.as_strided = ops.as_strided
+    # Tensor.squeeze
+    def squeeze(self, *args, **kwargs):
+        return ops.squeeze(self, *args, **kwargs)
 
-    Tensor.split = ops.split
-    StubTensor.split = ops.split
+    # Tensor.squeeze_
+    def squeeze_(self, dim=None):
+        return self.copy_(ops.squeeze(self, dim))
 
-    Tensor.flip = ops.flip
-    StubTensor.flip = ops.flip
 
-    Tensor.unflatten = ops.unflatten
-    StubTensor.unflatten = ops.unflatten
+    # Tensor.std
+    def std(self, dim=None, *, correction=1, keepdim=False):
+        return ops.std(self, dim, correction=correction, keepdim=keepdim)
 
-    Tensor.round_ = ops.inplace_round
-    StubTensor.round_ = 
ops.inplace_round + # Tensor.stft - Tensor.split_with_sizes = ops.split_with_sizes - StubTensor.split_with_sizes = ops.split_with_sizes - Tensor.scatter_reduce_ = ops.inplace_scatter_reduce - StubTensor.scatter_reduce_ = ops.inplace_scatter_reduce + # Tensor.storage - Tensor.exponential_ = ops.inplace_exponential - StubTensor.exponential_ = ops.inplace_exponential - Tensor.log_ = ops.inplace_log - StubTensor.log_ = ops.inplace_log + # Tensor.untyped_storage + def untyped_storage(self): + return UntypedStorage(self) - Tensor.mul_ = ops.inplace_mul - StubTensor.mul_ = ops.inplace_mul + # Tensor.storage_offset - Tensor.neg_ = ops.inplace_neg - StubTensor.neg_ = ops.inplace_neg - Tensor.exp_ = ops.inplace_exp - StubTensor.exp_ = ops.inplace_exp + # Tensor.storage_type - Tensor.sub_ = ops.inplace_sub - StubTensor.sub_ = ops.inplace_sub - Tensor.roll = ops.roll - StubTensor.roll = ops.roll + # Tensor.stride + def stride(self, dim=None): + if dim is None: + return self._data.stride() + return self._data.stride()[dim] - Tensor.bernoulli_ = ops.inplace_bernoulli - StubTensor.bernoulli_ = ops.inplace_bernoulli - Tensor.scatter_reduce = ops.scatter_reduce - StubTensor.scatter_reduce = ops.scatter_reduce + # Tensor.sub + def sub(self, other, *, alpha=1): + return ops.sub(self, other, alpha=alpha) - Tensor.tril_ = ops.inplace_tril - StubTensor.tril_ = ops.inplace_tril + # Tensor.sub_ + def sub_(self, other, *, alpha=1): + return self.copy_(ops.sub(self, other, alpha=alpha)) - Tensor.var = ops.var - StubTensor.var = ops.var - Tensor.logsumexp = ops.logsumexp - StubTensor.logsumexp = ops.logsumexp + # Tensor.subtract + subtract = sub - def __iter__(self): - if self.ndim == 0: - yield self + # Tensor.subtract_ + subtract_ = sub_ + + # Tensor.sum + def sum(self, dim=None, keepdim=False, dtype=None): + return ops.sum(self, dim, keepdim, dtype=dtype) + + # Tensor.sum_to_size + + + # Tensor.svd + + + # Tensor.swapaxes + def swapaxes(self, dim0, dim1): + return ops.swapaxes(self, dim0, dim1) + + # Tensor.swapdims + swapdims = swapaxes + + @property + def T(self): + return self.t() + + # Tensor.t + def t(self): + return ops.t(self) + + # Tensor.t_ + def t_(self): + self.data = ops.t(self) + return self + + # Tensor.tensor_split + + + # Tensor.tile + def tile(self, *dims): + return ops.tile(self, dims) + + # Tensor.to + def _move_to(self, device, non_blocking=False): + if device.type == 'meta': + out = Tensor(Tensor_(shape=self.shape, dtype=self.dtype)) + out._device = device + return out + if self.device == device: + return self else: - for i in range(len(self)): - yield self[i] + if DEVICE_TARGET == 'Ascend' and device.type == 'cuda': + device.type = 'npu' + device_str = device_map[device.type] + # if device_str == 'Ascend': + # out = ops.empty_like(self, device=device) + # ACL_MEMCPY_HOST_TO_DEVICE = 1 + # ret = acl.rt.memcpy(out.data_ptr(), self.nbytes, self.data_ptr(), self.nbytes, ACL_MEMCPY_HOST_TO_DEVICE) + # else: + # self.data_sync(True) + if self.device.type == 'cpu': + self.data_ptr() + data = self.move_to(device_str, blocking=not non_blocking) - Tensor.__iter__ = __iter__ - StubTensor.__iter__ = __iter__ + out = Tensor(data) + out._device = device + return out - def __float__(self): - out = self.item() - return round(float(out), 5) + def to(self, *args, **kwargs): + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + out = self + device = kwargs.pop('device', None) + dtype = kwargs.pop('dtype', None) + if device: + args += (device,) + if dtype: + args += (dtype,) - 
Tensor.__float__ = __float__
-    StubTensor.__float__ = __float__
 
+        for arg in args:
+            if isinstance(arg, device_):
+                out = Tensor._move_to(out, arg, non_blocking)
+            elif isinstance(arg, int):
+                device = device_(arg)
+                out = Tensor._move_to(out, device, non_blocking)
+            elif isinstance(arg, str):
+                device = device_(arg)
+                out = Tensor._move_to(out, device, non_blocking)
+            elif isinstance(arg, mindspore.common.dtype.Type):
+                if out.dtype == arg:
+                    return out
+                else:
+                    out = ops.cast(out, arg)
+            elif isinstance(arg, Tensor):
+                out = Tensor._move_to(out, arg.device, non_blocking)
+                if out.dtype == arg.dtype:
+                    return out
+                else:
+                    out = ops.cast(out, arg.dtype)
+        return out
 
-    Tensor.__matmul__ = ops.matmul
-    StubTensor.__matmul__ = ops.matmul
+    # Tensor.take
+    def take(self, index):
+        return ops.take(self, index)
 
-    Tensor.expm1 = ops.expm1
-    StubTensor.expm1 = ops.expm1
+    # Tensor.take_along_dim
 
-    Tensor.__eq__ = ops.eq
-    StubTensor.__eq__ = ops.eq
 
-    def tobytes(self):
-        return self.get_bytes()
-    
-    Tensor.tobytes = tobytes
-    StubTensor.tobytes = tobytes
+    # Tensor.tan
+    def tan(self):
+        return ops.tan(self)
 
-    Tensor.cuda = cpu
-    StubTensor.cuda = cpu
+    # Tensor.tan_
+    def tan_(self):
+        return self.copy_(ops.tan(self))
 
-    Tensor.nonzero = ops.nonzero
-    StubTensor.nonzero = ops.nonzero
 
-    Tensor.clamp_ = ops.inplace_clamp
-    StubTensor.clamp_ = ops.inplace_clamp
+    # Tensor.tanh
+    def tanh(self):
+        return ops.tanh(self)
 
-    Tensor.copy_ = ops.inplace_copy
-    StubTensor.copy_ = ops.inplace_copy
 
-    Tensor.index_add_ = ops.inplace_index_add
-    StubTensor.index_add_ = ops.inplace_index_add
+    # Tensor.tanh_
+    def tanh_(self):
+        return self.copy_(ops.tanh(self))
 
-    Tensor.erfinv_ = ops.inplace_erfinv
-    StubTensor.erfinv_ = ops.inplace_erfinv
 
-    def is_pinned(self):
-        return False
-    
-    Tensor.is_pinned = is_pinned
-    StubTensor.is_pinned = is_pinned
+    # Tensor.atanh
+
+    def atanh(self):
+        return ops.atanh(self)
+
+
+    # Tensor.atanh_
+    def atanh_(self):
+        return self.copy_(ops.atanh(self))
+
+
+    # Tensor.arctanh
+    arctanh = atanh
+
+    # Tensor.arctanh_
+    arctanh_ = atanh_
+
+    # Tensor.tolist
+    # def tolist(self):
+    #     return self.numpy().tolist()
+
+    # Tensor.topk
+    def topk(self, k, dim=-1, largest=True, sorted=True):
+        return ops.topk(self, k, dim, largest, sorted)
+
+    # Tensor.to_dense
+
+
+    # Tensor.to_sparse
+
+
+    # Tensor.to_sparse_csr
 
-    def record_stream(self, stream):
-        pass
-    Tensor.record_stream = record_stream
-    StubTensor.record_stream = record_stream
+    # Tensor.to_sparse_csc
 
-    Tensor.scatter = ops.scatter
-    StubTensor.scatter = ops.scatter
 
-    Tensor.mul = ops.mul
-    StubTensor.mul = ops.mul
+    # Tensor.to_sparse_bsr
 
-    Tensor.index_select = ops.index_select
-    StubTensor.index_select = ops.index_select
 
-    Tensor.gather = ops.gather
-    StubTensor.gather = ops.gather
+    # Tensor.to_sparse_bsc
+
+    # Tensor.trace
+
+
+    # Tensor.transpose
+    def transpose(self, dim0, dim1):
+        return ops.transpose(self, dim0, dim1)
+
+    # Tensor.transpose_
+    def transpose_(self, dim0, dim1):
+        self.data = ops.transpose(self, dim0, dim1)
+        return self
+
+    # Tensor.triangular_solve
+
+
+    # Tensor.tril
+    def tril(self, diagonal=0):
+        return ops.tril(self, diagonal)
+
+    # Tensor.tril_
+    def tril_(self, diagonal=0):
+        return self.copy_(ops.tril(self, diagonal))
+
+
+    # Tensor.triu
+    def triu(self, diagonal=0):
+        return ops.triu(self, diagonal)
+
+
+    # Tensor.triu_
+    def triu_(self, diagonal=0):
+        return self.copy_(ops.triu(self, diagonal))
+
+
+    # Tensor.true_divide
+    def true_divide(self, other):
+        return ops.true_divide(self, other)
+
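The `to` implementation above accepts the same argument shapes as its PyTorch counterpart: a device (object, string, or index), a dtype, another tensor (adopt its device and dtype), or a mix via keywords, applied left to right. A usage sketch, assuming an Ascend build where `'cuda'` is rewritten to `'npu'` by `_move_to`:

    from mindnlp import core
    import mindspore

    x = core.tensor([1.0, 2.0])

    a = x.to('npu')                           # device string
    b = x.to(mindspore.float16)               # dtype only
    c = x.to('npu', dtype=mindspore.float16)  # both, via keyword
    d = x.to(a)                               # adopt a's device and dtype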
+    # Tensor.true_divide_
+    def true_divide_(self, other):
+        return self.copy_(ops.true_divide(self, other))
+
+
+    # Tensor.trunc
+    def trunc(self):
+        return ops.trunc(self)
+
+    # Tensor.trunc_
+    def trunc_(self):
+        return self.copy_(ops.trunc(self))
+
+
+    # Tensor.type
+    def type(self, dtype=None, non_blocking=False):
+        if dtype is None:
+            dtype_str = str(dtype_class_map[self.dtype])[8:-2]
+            dtype_str = dtype_str.replace('_tensor', self.device.type) \
+                if self.device.type != 'cpu' else dtype_str.replace('._tensor', '')
+            return dtype_str
+        return self.to(dtype, non_blocking=non_blocking)
+
+    # Tensor.type_as
+    def type_as(self, tensor):
+        return self.type(tensor.dtype)
+
+    # Tensor.unbind
+    def unbind(self, dim=0):
+        return ops.unbind(self, dim)
+
+    # Tensor.unflatten
+    def unflatten(self, dim, sizes):
+        return ops.unflatten(self, dim, sizes)
+
+    # Tensor.unfold
+    def unfold(self, dimension, size, step):
+        return ops.unfold(self, dimension, size, step)
+
+    # Tensor.uniform_
+    def uniform_(self, *args, **kwargs):
+        return ops.inplace_uniform(self, *args, **kwargs)
+
+    # Tensor.random_
+    def random_(self, *args, **kwargs):
+        return ops.inplace_random(self, *args, **kwargs)
+
+    # Tensor.unique
+    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+        return ops.unique(self, sorted, return_inverse, return_counts, dim)
+
+    # Tensor.unique_consecutive
+    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+        return ops.unique_consecutive(self, return_inverse, return_counts, dim)
+
+    # Tensor.unsqueeze
+    def unsqueeze(self, dim):
+        return ops.unsqueeze(self, dim)
+
+    # Tensor.unsqueeze_
+    def unsqueeze_(self, dim):
+        return self.copy_(ops.unsqueeze(self, dim))
+
+
+    # Tensor.values
+
+
+    # Tensor.var
+    def var(self, dim=None, *, correction=1, keepdim=False):
+        return ops.var(self, dim, correction=correction, keepdim=keepdim)
+
+    # Tensor.vdot
+
+
+    # Tensor.view
+    def view(self, *shape):
+        return self.reshape(*shape)
+
+    # Tensor.view_as
+    def view_as(self, other):
+        return self.reshape(*other.shape)
+
+    # Tensor.vsplit
+
+
+    # Tensor.where
+    def where(self, condition, y):
+        return ops.where(condition, self, y)
+
+    # Tensor.xlogy
+    def xlogy(self, other):
+        return ops.xlogy(self, other)
+
+    # Tensor.xlogy_
+    def xlogy_(self, other):
+        return self.copy_(ops.xlogy(self, other))
+
+    # Tensor.zero_
+    def zero_(self):
+        return ops.inplace_zero(self)
+
+    # Tensor.detach
+    def detach(self):
+        out = self.data
+        out._requires_grad = False
+        return out
+
+    # Tensor.detach_
+    def detach_(self):
+        self.requires_grad_(False)
+        return self
+
+    def stub_sync(self):
+        if self.stub:
+            self.tensor = self.stub.get_value()
+            self.stub = None
+        return self.tensor
+
+    @property
     def is_cuda(self):
         device_type = 'cuda'
         if DEVICE_TARGET == 'Ascend':
             device_type = 'npu'
         return self.device.type == device_type
 
-    Tensor.is_cuda = is_cuda
-    StubTensor.is_cuda = is_cuda
+    def tobytes(self):
+        return self.get_bytes()
+
+    def __contains__(self, item):
+        return ops.eq(self, item).any()
+
+    def __float__(self):
+        out = self.item()
+        return round(float(out), 5)
+
+    def pin_memory(self, *args, **kwargs):
+        return self
+
+
+    @property
+    def shape(self):
+        if isinstance(self, StubTensor):
+            if self.stub is not None:
+                stub_shape = self.stub.get_shape()
+            else:
+                stub_shape = self.tensor.shape
+            return Size(stub_shape)
+        return Size(self._shape)
+
+    @property
+    def is_meta(self):
+        return False
+
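Because MindSpore tensors carry no per-tensor placement, the patch tracks it in a `_device` attribute that every factory and move path must set; the `device` property just below reads it back and fails loudly when a code path forgot to tag its output. A usage sketch of the resulting torch-style surface:

    from mindnlp import core

    x = core.tensor([1, 2, 3], device='cpu')
    print(x.device)       # cpu
    print(x.is_cuda)      # False

    y = x.to('npu')       # moves storage and re-tags _device
    print(y.device.type)  # 'npu'
    print(y.is_cuda)      # True on Ascend builds, where 'cuda' maps to 'npu'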
+    @property
+    def device(self):
+        if not hasattr(self, '_device'):
+            raise ValueError('Tensor must have device')
+        return self._device
+
+    def _convert_numpy_slices(self, key):
+        """Recursively convert NumPy integers inside `key` to built-in ints."""
+        # Tuples: convert every element recursively
+        if isinstance(key, tuple):
+            return tuple(self._convert_numpy_slices(k) for k in key)
+
+        # slice objects: convert start/stop/step
+        elif isinstance(key, slice):
+            start = key.start
+            stop = key.stop
+            step = key.step
+
+            # Convert NumPy integers to Python ints
+            if isinstance(start, (np.integer, Tensor)):
+                start = int(start)
+            if isinstance(stop, (np.integer, Tensor)):
+                stop = int(stop)
+            if isinstance(step, (np.integer, Tensor)):
+                step = int(step)
+
+            return slice(start, stop, step)
+
+        # A single NumPy integer index
+        elif isinstance(key, np.integer):
+            return int(key)
+
+        # Other types (e.g. int, None) pass through unchanged
+        else:
+            return key
+
+    def __deepcopy__(self, memodict):
+        new_obj = Tensor(self)
+        new_obj._device = self.device
+        return new_obj
+
+    def __matmul__(self, other):
+        return ops.matmul(self, other)
+
+    def __truediv__(self, other):
+        return ops.true_divide(self, other)
+
+    def __floordiv__(self, other):
+        return ops.floor_divide(self, other)
+
+    def __mod__(self, other):
+        return ops.fmod(self, other)
+
+    def backward(self, *args, **kwargs):
+        pass
+
+    def log_softmax(self, dim):
+        return ops.log_softmax(self, dim)
+
+def enable_mindspore_patch():
+    fn_keys = list(TensorPlaceHolder.__dict__)
+    fn_keys.remove('__doc__')
+    fn_keys.remove('__dict__')
+    fn_keys.remove('__weakref__')
+    fn_keys.remove('__module__')
+
+    for fn in fn_keys:
+        setattr(Tensor, fn, getattr(TensorPlaceHolder, fn))
+        if StubTensor is not None:
+            setattr(StubTensor, fn, getattr(TensorPlaceHolder, fn))
 
-    Tensor.__rshift__ = ops.bitwise_right_shift
-    StubTensor.__rshift__ = ops.bitwise_right_shift
 
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
diff --git a/mindnlp/core/configs.py b/mindnlp/core/configs.py
index 4711dfcc8..7342cc22b 100644
--- a/mindnlp/core/configs.py
+++ b/mindnlp/core/configs.py
@@ -11,6 +11,7 @@ DEFAULT_DTYPE = mindspore.float32
 
 MS27 = '.'.join(mindspore.__version__.split('.')[:2]) >= '2.7'
 
+CPU_USE_NUMPY_OP = DEVICE_TARGET != 'CPU'
 
 def set_pyboost(mode: bool):
     """set global pyboost"""
diff --git a/mindnlp/core/dispatcher.py b/mindnlp/core/dispatcher.py
new file mode 100644
index 000000000..1efe099a2
--- /dev/null
+++ b/mindnlp/core/dispatcher.py
@@ -0,0 +1,77 @@
+from mindnlp import core
+from .types import device as device_
+from ._prims import ascend, cpu, numpy, meta
+from .configs import DEVICE_TARGET, CPU_USE_NUMPY_OP
+
+device_map = {"cpu": "CPU", "npu": "Ascend", "cuda": "GPU"}
+
+
+class SingletonMeta(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
+
+
+class Dispatcher(metaclass=SingletonMeta):
+    def __init__(self):
+        self._registry = {"cpu": {}, "npu": {}, "gpu": {}, 'numpy': {}, 'meta': {}}
+
+    def register(self, func_name, device, func):
+        self._registry[device][func_name] = func
+
+    def dispatch(self, func_name, *args, **kwargs):
+        device = kwargs.pop("device", None)
+        if isinstance(device, str):
+            device = device_(device)
+
+        if device is None:
+            tensors = (
+                [arg for arg in args[0] if core.is_tensor(arg)]
+                if isinstance(args[0], (tuple, list))
+                else [arg for arg in args if core.is_tensor(arg)]
+            )
+
+            if len(tensors) == 1:
+                device = tensors[0].device
+
+            else:
+                devices = {tensor.device for tensor in tensors}
+
+                if len(devices) > 1:
+                    raise ValueError("All tensor arguments must be on the same 
device.") + + device = next(iter(devices), device_("cpu")) + + if DEVICE_TARGET == 'Ascend' and device.type == 'cuda': + device.type = 'npu' + + device_type = device.type + + if CPU_USE_NUMPY_OP and device_type == 'cpu': + device_type = 'numpy' + + func = self._registry[device_type].get(func_name, None) + if func is None: + raise RuntimeError( + f"No implementation for function: {func_name} on {device_type}." + ) + return func(*args), device + + +dispatcher = Dispatcher() +for func_name in ascend.__all__: + dispatcher.register(func_name, "npu", getattr(ascend, func_name)) + +for func_name in cpu.__all__: + dispatcher.register(func_name, "cpu", getattr(cpu, func_name)) + +for func_name in numpy.__all__: + dispatcher.register(func_name, "numpy", getattr(numpy, func_name)) + +for func_name in meta.__all__: + dispatcher.register(func_name, "meta", getattr(meta, func_name)) + diff --git a/mindnlp/core/distributed/c10d/process_group.py b/mindnlp/core/distributed/c10d/process_group.py index b21c6c336..b1743298a 100644 --- a/mindnlp/core/distributed/c10d/process_group.py +++ b/mindnlp/core/distributed/c10d/process_group.py @@ -1,8 +1,10 @@ from mindnlp import core from mindnlp.core import Tensor +from mindnlp.core.executor import execute from typing import List, Optional, Dict, Any from enum import Enum + class BackendType(Enum): UNDEFINED = 0 GLOO = 1 @@ -108,7 +110,7 @@ def end_coalescing(self, device_type): def broadcast(self, tensors: List[Tensor], opts: Any) -> Any: tensor = tensors[0] - _, work = execute('dist_comm_broadcast', tensor, opts.rootRank, self._name, device=self.device) + _, work = execute('dist_comm_broadcast', tensor, opts.rootRank, self._rank, self._name, device=self.device) return work def allreduce(self, tensors: List[Tensor], opts: Any) -> Any: diff --git a/mindnlp/core/distributed/distributed_c10d.py b/mindnlp/core/distributed/distributed_c10d.py index a505e81ca..2bd9f2caf 100644 --- a/mindnlp/core/distributed/distributed_c10d.py +++ b/mindnlp/core/distributed/distributed_c10d.py @@ -23,6 +23,7 @@ import mindspore from mindspore.communication import init, GlobalComm, get_group_size, get_process_group_ranks as _get_group_ranks, \ create_group, get_rank as _get_rank +from mindspore.common.api import _pynative_executor from mindnlp import core # from core._C import _DistStoreError as DistStoreError @@ -4156,6 +4157,7 @@ def barrier( return work else: work.wait() + _pynative_executor.sync() def monitored_barrier( diff --git a/mindnlp/core/executor.py b/mindnlp/core/executor.py new file mode 100644 index 000000000..2762a7394 --- /dev/null +++ b/mindnlp/core/executor.py @@ -0,0 +1,13 @@ +from mindnlp import core +from .dispatcher import dispatcher + +def execute(func_name, *args, **kwargs): + out, device = dispatcher.dispatch(func_name, *args, **kwargs) + if not isinstance(out, (tuple, list)): + out._device = device + else: + for i in out: + if isinstance(i, core.Tensor): + i._device = device + return out + diff --git a/mindnlp/core/linalg/__init__.py b/mindnlp/core/linalg/__init__.py index 3bcc1b3b4..e73154a41 100644 --- a/mindnlp/core/linalg/__init__.py +++ b/mindnlp/core/linalg/__init__.py @@ -16,9 +16,9 @@ def cholesky_ex(A, *, upper=False, check_errors=False, out=None): try: out = cholesky(A, upper=upper, out=out) out + 1 - info = core.Tensor(0) + info = core.tensor(0, device=A.device) except: - info = core.Tensor(1) + info = core.tensor(1, device=A.device) out = A return linalg_cholesky_ex(out, info) diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py 
index 2cef6b6b0..2cb847e2b 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -3,80 +3,60 @@ import numbers import warnings from typing import Optional, Tuple, List -import numpy as np -import mindspore -from mindspore import ops, mint -from mindspore.ops._primitive_cache import _get_cache_prim -from mindspore.ops.auto_generate import (reflection_pad_1d_op, reflection_pad_2d_op, add_layernorm_v2_op, - reflection_pad_3d_op, # pylint: disable=W0611 - replication_pad_1d_op, replication_pad_2d_op, replication_pad_3d_op, - constant_pad_nd_op, dropout_ext_op, reverse_v2_impl, avg_pool2d_op, - upsample_nearest1d_op, upsample_nearest2d_op, upsample_nearest3d_op, - upsample_linear1d_op, upsample_bilinear2d_op, upsample_bicubic2d_op, - upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op, - masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d) - -from mindspore.ops.auto_generate.pyboost_inner_prim import nllloss_impl - - from mindnlp import core +from mindnlp.core.executor import execute + from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1 generator_step_ = 12 def gelu(input, *, approximate='none'): - if use_pyboost(): - return mint.nn.functional.gelu(input, approximate=approximate) - return ops.gelu(input, approximate) + if input.device.type == 'npu': + return execute('gelu_ext', input, approximate) + if approximate == 'tanh': + return execute('gelu', input) + return input * 0.5 * (1.0 + core.erf(input / core.sqrt(2.0))) + def relu(input, inplace=False): - if use_pyboost(): - return mint.nn.functional.relu(input) - return ops.relu(input) + if inplace: + execute('inplace_relu', input) + return input + return execute('relu', input) def tanh(input, inplace=False): - if use_pyboost(): - return mint.nn.functional.tanh(input) - return ops.tanh(input) - + if inplace: + execute('inplace_tanh', input) + return input + return execute('tanh', input) def sigmoid(input): - if use_pyboost() and not ON_ORANGE_PI: - return mint.nn.functional.sigmoid(input) - return ops.sigmoid(input) + return execute('sigmoid', input) def silu(input, inplace=False): - if DEVICE_TARGET == 'CPU' or ON_ORANGE_PI: - return input * sigmoid(input) - if use_pyboost(): - return mint.nn.functional.silu(input) - return ops.silu(input) + if inplace: + execute('inplace_silu', input) + return input + return execute('silu', input) def mish(input): - return ops.mish(input) + return execute('mish', input) def relu6(input): - return ops.relu6(input) + return execute('relu6', input) def elu(input, alpha=1.0): - if use_pyboost(): - return mint.nn.functional.elu(input, alpha) - return ops.elu(input, alpha) + return execute('elu', input, alpha) def glu(input, dim=-1): - return ops.glu(input, dim) + return execute('glu', input, dim) def softplus(input, beta=1, threshold=20): - if use_pyboost(): - return mint.nn.functional.softplus(input, beta, threshold) - return ops.softplus(input, beta, threshold) + return execute('softplus', input, beta, threshold) def logsigmoid(input): - if use_pyboost(): - return mint.nn.functional.logsigmoid(input) - return ops.logsigmoid(input) - + return execute('logsigmoid', input) def leaky_relu(input, alpha=0.2): if use_pyboost(): @@ -143,7 +123,6 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun divisor_override = 0 return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) -has_avg_pool3d = hasattr(mint.nn.functional, 'avg_pool3d') def avg_pool3d(input, 
kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None): if use_pyboost() and has_avg_pool3d: return mint.nn.functional.avg_pool3d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) @@ -163,12 +142,14 @@ def adaptive_avg_pool2d(input, output_size): return mint.nn.functional.adaptive_avg_pool2d(input, output_size) return ops.adaptive_avg_pool2d(input, output_size) -def dropout(input, p=0.5, training=True): - if not training or p == 0 or 0 in input.shape: +def dropout(input, p=0.5, training=True, inplace=False): + if not training: return input - if use_pyboost() and not ON_ORANGE_PI: - return mint.nn.functional.dropout(input, p, training) - return ops.dropout(input, p, training) + out, _ = execute('dropout_ext', input, p) + if inplace: + input.copy_(out) + return input + return out def dropout2d(input, p=0.5, training=False): return ops.dropout2d(input, p, training) @@ -180,24 +161,21 @@ def drop_and_mask(keep_prob, seed=None): out, mask = dropout_op(input) return out, mask -dense_ = ops.Dense() def linear(input, weight, bias=None): if ON_ORANGE_PI: input = input.to(core.float16) weight = weight.to(core.float16) if bias is not None: bias = bias.to(core.float16) - return dense_(input, weight) + bias - return dense_(input, weight) - if use_pyboost(): - return mint.nn.functional.linear(input, weight, bias) - return dense_(input, weight, bias) + return execute('dense', input, weight) + bias + return execute('dense', input, weight) + return execute('dense', input, weight, bias) def binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean', pos_weight=None): if input.shape != target.shape: target = target.unsqueeze(1).expand_as(input).to(input.dtype) - if use_pyboost(): - return mint.nn.functional.binary_cross_entropy_with_logits(input, target, weight, reduction, pos_weight) + + return execute('binary_cross_entropy_with_logits', input, target, weight, pos_weight, reduction) return ops.binary_cross_entropy_with_logits(input, target.astype(input.dtype), weight, pos_weight, reduction) def gumbel_softmax(logits: core.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> core.Tensor: @@ -220,19 +198,12 @@ def gumbel_softmax(logits: core.Tensor, tau: float = 1, hard: bool = False, eps: return ret def log_softmax(input, dim=None, dtype=None): - if use_pyboost(): - return mint.nn.functional.log_softmax(input, dim=dim, dtype=dtype) - if dim is None: - dim = -1 - out = ops.log_softmax(input, dim) - if dtype is not None: - out = out.to(dtype) - return out + if input.device.type == 'cpu': + return execute('log_softmax', input, dim if dim is not None else -1) + return execute('log_softmax_ext', input, dim, dtype) -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, *args, **kwargs): - if use_pyboost(): - return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq) - return ops.gather(weight, input, 0) +def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False): + return execute('embedding', input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq) def rms_norm(input, normalized_shape, weight, eps=None): if eps is None: @@ -286,6 +257,28 @@ def custom_circular_pad(x, pad): return x + +def _reflection_pad(input, pad): + """reflection pad""" + out = input + if len(pad) == 2: + out = execute('reflection_pad_1d', input, pad) + elif len(pad) == 4: + out = execute('reflection_pad_2d', input, pad) + else: + out = execute('reflection_pad_3d', input, pad) + return out +
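+# Usage sketch (comments only): `pad` is assumed to follow the torch layout + # (left, right[, top, bottom[, front, back]]), so the tuple length selects + # the 1d/2d/3d kernel, e.g. for a 4-d NCHW tensor x: + # _reflection_pad(x, (1, 1)) # pad the last dim by 1 on each side + # _reflection_pad(x, (1, 1, 2, 2)) # additionally pad dim -2 by 2 +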
+def _replication_pad(input, pad): + """replication pad""" + out = input + if len(pad) == 2: + out = execute('replication_pad_1d', input, pad) + elif len(pad) == 4: + out = execute('replication_pad_2d', input, pad) + else: + out = execute('replication_pad_3d', input, pad) + return out + def pad(input, pad, mode='constant', value=None): if input.device.type != 'npu': if mode == 'reflect' and input.ndim > 4: @@ -294,19 +287,38 @@ paddings.append([pad[i], pad[i+1]]) old_shape = input.shape shape = (-1, *old_shape[-3:]) - out = ops.MirrorPad()(input.reshape(shape), mindspore.Tensor(paddings)) + out = execute('mirror_pad', input.reshape(shape), core.tensor(paddings, device=input.device)) return out.reshape(*old_shape[:-3], *out.shape[-3:]) - return ops.pad(input, pad, mode, value) + return execute('pad_v3', input, pad, mode, value) if sum(pad) == 0: return input if isinstance(pad, tuple): pad = tuple(p if isinstance(p, int) else p.item() for p in pad) - if use_pyboost() and not ON_A1: - return mint.nn.functional.pad(input, pad, mode, value) + if not ON_A1: + out = input + if (isinstance(pad, tuple) and not pad): + return out + if mode == "constant": + value = 0 if value is None else value + out = execute('constant_pad_nd', input, pad, value) + else: + if value is not None and value != 0: + raise ValueError(f"Padding mode {mode} doesn't take a value argument.") + if mode == "circular": + out = _circular_pad(input, pad) + elif mode == "reflect": + out = _reflection_pad(input, pad) + elif mode == "replicate": + out = _replication_pad(input, pad) + else: + raise ValueError(f"Pad filling mode must be 'constant', 'circular', 'reflect' or 'replicate'.") + return out + + if mode in ['reflect', 'replicate']: if mode == 'reflect' and input.ndim > 4: - return reflection_pad_3d_op(input, pad) - return ops.pad(input, pad, mode) + return execute('reflection_pad_3d', input, pad) + return execute('pad_v3', input, pad, mode) if mode == 'circular': return custom_circular_pad(input, pad) new_pad = () @@ -318,17 +330,17 @@ new_pad += (pad_v,) if sum(new_pad) == 0: return input - if input.dtype == mindspore.bool_: - input = input.to(mindspore.int32) - return ops.pad(input, new_pad, mode, value).to(mindspore.bool_) + if input.dtype == core.bool_: + input = input.to(core.int32) + return execute('pad_v3', input, new_pad, mode, value).to(core.bool_) if input.ndim > 5 and mode == 'constant': paddings = () for i in range(0, len(new_pad), 2): paddings += (new_pad[i: i+2],) paddings = ((0, 0),) * (input.ndim - len(paddings)) + tuple(reversed(paddings)) - return _get_cache_prim(ops.Pad)(paddings)(input) - return ops.pad(input, new_pad, mode, value) + return execute('pad', paddings, input) + return execute('pad_v3', input, new_pad, mode, value) def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'): if input.device.type == 'npu': @@ -407,16 +419,16 @@ def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean') class_dim = 0 if input_dim == 1 else 1 n_classes = input.shape[class_dim] if weight is None: - weight = ones(n_classes, input.dtype) + weight = core.ones(n_classes, dtype=input.dtype, device=input.device) if input_dim < 1: raise ValueError(f"input dim should be at least 1, but got {input_dim}") if input_dim != 1 and input.shape[0] != target.shape[0]: raise 
ValueError(f"input bacth_size should be equal to target batch_size, but got {input.shape[0]} and " f"{target.shape[0]}") if input_dim == 1 or input_dim == 2: - return nllloss_impl(input.float(), target, weight.float(), reduction, ingore_index)[0] + return execute('nllloss', input.float(), target, weight.float(), reduction, ingore_index)[0] if input_dim == 4: - return nllloss_2d_op(input, target, weight, reduction, ingore_index)[0] + return execute('nllloss_2d', input, target, weight, reduction, ingore_index)[0] # input_dim==3 or input_dim>4 n = input.shape[0] c = input.shape[1] @@ -430,8 +442,8 @@ def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean') else: target = target.view((n, 0, 0)) if reduction != 'none': - return nllloss_2d_op(input, target, weight, reduction, ingore_index)[0] - ret = nllloss_2d_op(input, target, weight, reduction, ingore_index)[0] + return execute('nllloss_2d', input, target, weight, reduction, ingore_index)[0] + ret = execute('nllloss_2d', input, target, weight, reduction, ingore_index)[0] return ret.view(out_size) def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): @@ -444,7 +456,7 @@ def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean input = log_softmax(input, class_dim, dtype=input.dtype) # for probabilities target_dtype = target.dtype - if target_dtype in [mindspore.float32, mindspore.float16, mindspore.bfloat16]: + if target_dtype in [core.float32, core.float16, core.bfloat16]: return _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes) # for class indices @@ -470,7 +482,7 @@ def _cross_entropy_for_probabilities(input, target, weight, reduction, label_smo loss = loss * weight_ loss = loss.view(ori_shape) if reduction == "mean": - return -mint.div(loss.sum(), (input.size / n_classes)) + return -core.div(loss.sum(), (input.size / n_classes)) if reduction == "sum": return -loss.sum() if reduction == "none": @@ -522,7 +534,7 @@ def _cross_entropy_for_class_indices(input, target, weight, ingore_index, reduct def mse_loss(input, target, reduction='mean'): - return ops.mse_loss(input, target, reduction) + return execute('mse_loss_ext', input, target, reduction) def l1_loss(input, target, reduction='mean'): return ops.l1_loss(input, target, reduction) @@ -542,30 +554,14 @@ def manual_softmax(x, dim=-1): return exp_x / ops.sum(exp_x, dim=dim, keepdim=True) def softmax(input, dim=-1, *, dtype=None): - if use_pyboost(): - return mint.nn.functional.softmax(input, dim, dtype=dtype) + out = execute('softmax', input, dim) if dtype is not None: - input = input.to(dtype) - if dim is None: - dim = -1 - if ON_ORANGE_PI: - return manual_softmax(input, dim) - softmax_ = _get_cache_prim(ops.Softmax)(dim) - return softmax_(input) + out = out.to(dtype) + return out def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): - if use_pyboost(): - return mint.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps) - if weight is None: - weight = ops.ones(normalized_shape, dtype=input.dtype) - if bias is None: - bias = ops.zeros(normalized_shape, dtype=input.dtype) - if weight is not None: - begin_axis = input.ndim - weight.ndim - else: - begin_axis = -1 - _layer_norm = _get_cache_prim(ops.LayerNorm)(begin_axis, begin_axis, epsilon=eps) - return _layer_norm(input, weight, bias)[0] + return execute('layer_norm_ext', input, normalized_shape, weight, bias, eps)[0] + def interpolate(input, size=None, 
scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): if mode in ("nearest", "area", "nearest-exact"): @@ -785,7 +781,6 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin eps ) -has_conv1d = hasattr(mint.nn.functional, 'conv1d') def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if use_pyboost() and has_conv1d and not ON_ORANGE_PI: return mint.nn.functional.conv1d(input, weight, bias, stride, padding, dilation, groups) @@ -1167,7 +1162,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. attn_bias = core.zeros(attn_bias_shape, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = core.ones(L, S, dtype=core.bool).tril(diagonal=0) + temp_mask = core.ones(L, S, dtype=core.bool, device=query.device).tril(diagonal=0) attn_bias = attn_bias.masked_fill_(temp_mask.logical_not(), core.finfo(attn_bias.dtype).min) attn_bias.to(query.dtype) diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py index 32a8c88cf..1d8cfe9e4 100644 --- a/mindnlp/core/nn/modules/module.py +++ b/mindnlp/core/nn/modules/module.py @@ -2,6 +2,7 @@ import warnings import weakref import functools +import inspect from typing import Dict, Optional, Callable, Set, overload, TypeVar, Any, Iterator, Tuple, Union, \ Mapping, List import itertools @@ -555,10 +556,10 @@ def add_module(self, name: str, module: Optional["Module"]) -> None: module (Module): child module to be added to the module. """ if not isinstance(module, Module) and module is not None: - raise TypeError(f"{torch.typename(module)} is not a Module subclass") + raise TypeError(f"{core.typename(module)} is not a Module subclass") elif not isinstance(name, str): raise TypeError( - f"module name should be a string. Got {torch.typename(name)}" + f"module name should be a string. Got {core.typename(name)}" ) elif hasattr(self, name) and name not in self._modules: raise KeyError(f"attribute '{name}' already exists") @@ -589,7 +590,7 @@ def get_parameter(self, target: str) -> "Parameter": fully-qualified string.) Returns: - torch.nn.Parameter: The Parameter referenced by ``target`` + core.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid @@ -625,7 +626,7 @@ def get_buffer(self, target: str) -> "Tensor": fully-qualified string.) Returns: - torch.Tensor: The buffer referenced by ``target`` + core.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid @@ -701,11 +702,13 @@ def compute_should_use_set_data(tensor, tensor_applied): # `core.__future__.get_overwrite_module_params_on_conversion()` # global flag to let the user control whether they want the future # behavior of overwriting the existing tensor or not. 
- return True + return not core.__future__.get_overwrite_module_params_on_conversion() else: return False - should_use_swap_tensors = False + should_use_swap_tensors = ( + core.__future__.get_swap_module_params_on_conversion() + ) for key, param in self._parameters.items(): if param is None: @@ -718,10 +721,7 @@ def compute_should_use_set_data(tensor, tensor_applied): p_should_use_set_data = compute_should_use_set_data(param, param_applied) # subclasses may have multiple child tensors so we need to use swap_tensors - # p_should_use_swap_tensors = ( - # should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) - # ) - p_should_use_swap_tensors = False + p_should_use_swap_tensors = should_use_swap_tensors param_grad = param.grad if p_should_use_swap_tensors: @@ -730,7 +730,7 @@ def compute_should_use_set_data(tensor, tensor_applied): # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. # Decrement use count of the gradient by setting to None param.grad = None - param_applied = core.nn.Parameter( + param_applied = Parameter( param_applied, requires_grad=param.requires_grad ) core.utils.swap_tensors(param, param_applied) @@ -1026,7 +1026,7 @@ def remove_from(*dicts_or_sets): if value is not None: raise TypeError( f"cannot assign '{core.typename(value)}' as parameter '{name}' " - "(torch.nn.Parameter or None expected)" + "(core.nn.Parameter or None expected)" ) self.register_parameter(name, value) else: @@ -1052,7 +1052,7 @@ def remove_from(*dicts_or_sets): if value is not None: raise TypeError( f"cannot assign '{core.typename(value)}' as child module '{name}' " - "(torch.nn.Module or None expected)" + "(core.nn.Module or None expected)" ) for hook in _global_module_registration_hooks.values(): output = hook(self, name, value) @@ -1065,7 +1065,7 @@ def remove_from(*dicts_or_sets): if value is not None and not isinstance(value, core.Tensor): raise TypeError( f"cannot assign '{core.typename(value)}' as buffer '{name}' " - "(torch.nn.Buffer, torch.Tensor or None expected)" + "(core.nn.Buffer, core.Tensor or None expected)" ) if isinstance(value, Buffer): persistent = value.persistent @@ -1115,7 +1115,7 @@ def __delattr__(self, name): super().__delattr__(name) def _register_state_dict_hook(self, hook): - r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + r"""Register a post-hook for the :meth:`~core.nn.Module.state_dict` method. It should have the following signature:: hook(module, state_dict, prefix, local_metadata) -> None or state_dict @@ -1204,7 +1204,7 @@ def npu(self: T, device: Optional[Union[int, device]] = None) -> T: return self._apply(lambda t: t.npu(device)) def cpu(self: T, device: Optional[Union[int, device]] = None) -> T: - return self._apply(lambda t: t.cpu(device)) + return self._apply(lambda t: t.cpu()) def _load_from_state_dict( @@ -1216,11 +1216,11 @@ def _load_from_state_dict( missing_keys, unexpected_keys, error_msgs, - ): + ) -> None: r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule - in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + in :meth:`~core.nn.Module.load_state_dict`. Metadata saved for this module in input :attr:`state_dict` is provided as :attr:`local_metadata`. For state dicts without metadata, :attr:`local_metadata` is empty. Subclasses can achieve class-specific backward compatible loading using @@ -1231,7 +1231,7 @@ def _load_from_state_dict( .. 
note:: :attr:`state_dict` is not the same object as the input - :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + :attr:`state_dict` to :meth:`~core.nn.Module.load_state_dict`. So it can be modified. Args: @@ -1250,7 +1250,7 @@ keys to this list error_msgs (list of str): error messages should be added to this list, and will be reported together in - :meth:`~torch.nn.Module.load_state_dict` + :meth:`~core.nn.Module.load_state_dict` """ for hook in self._load_state_dict_pre_hooks.values(): hook( @@ -1282,7 +1282,7 @@ if not core.overrides.is_tensor_like(input_param): error_msgs.append( f'While copying the parameter named "{key}", ' - "expected torch.Tensor or Tensor-like object from checkpoint but " + "expected core.Tensor or Tensor-like object from checkpoint but " f"received {type(input_param)}" ) continue @@ -2008,7 +2008,7 @@ def to_empty( r"""Move the parameters and buffers to the specified device without copying storage. Args: - device (:class:`torch.device`): The desired device of the parameters + device (:class:`core.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. diff --git a/mindnlp/core/nn/modules/sparse.py b/mindnlp/core/nn/modules/sparse.py index 88642422d..b71e98230 100644 --- a/mindnlp/core/nn/modules/sparse.py +++ b/mindnlp/core/nn/modules/sparse.py @@ -1,5 +1,6 @@ """sparse""" from typing import Optional +from mindnlp import core from mindnlp.core import Tensor from ..parameter import Parameter from .module import Module @@ -33,7 +34,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False, dtype=None, device=None) -> None: - factory_kwargs = {'dtype': dtype} + factory_kwargs = {'dtype': dtype, 'device': device} super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -62,6 +63,8 @@ def reset_parameters(self) -> None: init.normal_(self.weight) self._fill_padding_idx_with_zero() + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with core.no_grad(): + self.weight[self.padding_idx].fill_(0) + - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - self.weight[self.padding_idx] = 0 diff --git a/mindnlp/core/nn/parameter.py b/mindnlp/core/nn/parameter.py index 1fb06de52..1d2e69eb7 100644 --- a/mindnlp/core/nn/parameter.py +++ b/mindnlp/core/nn/parameter.py @@ -12,6 +12,7 @@ class Parameter(Tensor): def __init__(self, input_data=None, requires_grad=True, **kwargs): super().__init__(input_data) + self._device = getattr(input_data, '_device', None) self.meta = False self.param_info = ParamInfo() self.param_info.name = str(uuid.uuid4()) @@ -21,6 +22,7 @@ def __init__(self, input_data=None, requires_grad=True, **kwargs): def __deepcopy__(self, memodict): new_obj = Parameter(self) + new_obj._device = self.device return new_obj def clone(self): @@ -44,16 +46,6 @@ def name(self): # only for O2 """ return self.param_info.name - @property - def data(self): - return Tensor(self) - - @data.setter - def data(self, new_value): - if isinstance(new_value, StubTensor) and new_value.stub is not None: - new_value = new_value.stub.get_value() - self.assign_value(new_value) - @property def requires_grad(self): return 
self._requires_grad diff --git a/mindnlp/core/npu/__init__.py b/mindnlp/core/npu/__init__.py index 167042e83..940f6d07b 100644 --- a/mindnlp/core/npu/__init__.py +++ b/mindnlp/core/npu/__init__.py @@ -1,14 +1,19 @@ +import os from typing import Any import mindspore +from mindspore._c_expression import _ms_memory_recycle from mindspore import get_rng_state, set_rng_state, manual_seed from mindspore.runtime import memory_reserved as ms_memory_reserved, \ - memory_allocated as ms_memory_allocated, StreamCtx as StreamContext, Stream, empty_cache, \ + memory_allocated as ms_memory_allocated, StreamCtx as StreamContext, Stream, empty_cache as ms_empty_cache, \ reset_peak_memory_stats, reset_max_memory_allocated, max_memory_allocated, synchronize, \ current_stream -from mindspore.device_context.ascend import device_count + +from mindspore.device_context.ascend import device_count as ms_device_count +from mindspore.communication import GlobalComm, get_group_size from mindnlp import core +from mindnlp.core.executor import execute from ..configs import SUPPORT_BF16 FloatTensor = core.FloatTensor @@ -21,6 +26,11 @@ def set_compile_mode(*args, **kwargs): def manual_seed_all(seed: int): manual_seed(seed) +def device_count(): + if GlobalComm.INITED: + return get_group_size() + return ms_device_count() + def current_device(): return core.device('npu', 0) @@ -40,6 +50,10 @@ def memory_allocated(device=None): return ms_memory_allocated() def memory_reserved(device=None): + if os.environ.get("MS_ALLOC_CONF", None) is not None: + # round the allocated size up to the next 2 GB (1024 * 1024 * 2048 bytes) increment + out = ((ms_memory_allocated() // (1024 * 1024 * 2048)) + 1) * (1024 * 1024 * 2048) + return out return ms_memory_reserved() class device: @@ -59,11 +73,31 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, traceback: Any): return False -def mem_get_info(index): - return (1024, 1024) +def _try_initial_ascend(): + x = core.tensor(1, device='npu') + _ = x + 0 + +def mem_get_info(device=None): + if not isinstance(device, int): + device = mindspore.context.get_context("device_id") + + res = mindspore.hal.get_device_properties(device) + if res.total_memory == 0: + _try_initial_ascend() + res = mindspore.hal.get_device_properties(device) + + return (res.free_memory, res.total_memory) def current_device(): return core.device('npu', 0) def get_device_capability(device=None): - return 10, 0 \ No newline at end of file + return 10, 0 + + +def npu_rotary_mul(x, cos, sin): + return execute('rotary_position_embedding', x, cos, sin, 0) + +def empty_cache(): + ms_empty_cache() + _ms_memory_recycle() \ No newline at end of file diff --git a/mindnlp/core/ops/_inner.py b/mindnlp/core/ops/_inner.py index 0b292ee8d..b8bd9af09 100644 --- a/mindnlp/core/ops/_inner.py +++ b/mindnlp/core/ops/_inner.py @@ -1,20 +1,9 @@ """inner ops""" -import mindspore -from mindspore import ops -from ..configs import use_pyboost +from mindnlp.core.executor import execute def cast(input, dtype): - return ops.cast(input, dtype) + return execute('cast', input, dtype) def assign(input, other): - return ops.assign(input, other) - -def call_ms_func(func_name, *args, **kwargs): - out = kwargs.pop('out', None) - if out is None: - return func_name(*args, **kwargs) - else: - tmp = func_name(*args, **kwargs) - return out.copy_(tmp) - + return execute('assign', input, other) __all__ = ['cast', 'assign'] diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index 9c5c978c3..e7fdbd113 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -1,56 +1,54 @@ """array op"""
import numbers import numpy as np + import mindspore -from mindspore import ops -from mindspore.ops._primitive_cache import _get_cache_prim -from mindspore.ops.operations._grad_ops import StridedSliceGrad -from mindspore.ops.auto_generate.gen_ops_prim import inplace_scatter_src_reduce_op - -from ..configs import use_pyboost, ON_ORANGE_PI -from .other import broadcast_tensors, finfo -from ._inner import call_ms_func from mindnlp import core +from mindnlp.core.executor import execute +from .other import broadcast_tensors, broadcast_to + + +def t(input): + assert input.ndim <= 2 + if input.ndim == 2: + return transpose(input, 0, 1) + return input + # adjoint + # argwhere def argwhere(input): - if use_pyboost(): - return mindspore.mint.nonzero(input) - return ops.argwhere(input) + return execute("nonzero", input) + # cat -has_cat = hasattr(mindspore.mint, 'cat') -def cat(tensors, dim=0, *, out=None, **kwargs): - axis = kwargs.get('axis', None) - if axis is not None: - dim = axis - max_dtype = max([x.dtype for x in tensors]) - tensors = [x.to(max_dtype) for x in tensors] - if use_pyboost() and has_cat: - return call_ms_func(mindspore.mint.cat, tensors, dim, out=out) - return call_ms_func(ops.cat, tensors, dim, out=out) +def cat(tensors, dim=0, **kwargs): + dim = kwargs.pop('axis', dim) + return execute("concat", tensors, dim) + # concat -has_concat = hasattr(mindspore.mint, 'concat') -def concat(tensors, dim=0, *, out=None, **kwargs): - return cat(tensors, dim, out=out, **kwargs) +def concat(tensors, dim=0, **kwargs): + dim = kwargs.pop('axis', dim) + return cat(tensors, dim) + # concatenate -def concatenate(tensors, dim=0, out=None, **kwargs): - return cat(tensors, dim, out=out, **kwargs) +def concatenate(tensors, dim=0): + return cat(tensors, dim) + # conj def conj(input): - return ops.conj(input) + return execute("conj", input) + # chunk -has_chunk = hasattr(mindspore.mint, 'chunk') def chunk(input, chunks, dim=0): - if use_pyboost() and has_chunk: - return mindspore.mint.chunk(input, chunks, dim) - return ops.chunk(input, chunks, dim) + return execute("chunk", input, chunks, dim) + # dsplit @@ -62,74 +60,26 @@ def chunk(input, chunks, dim=0): # gather -has_gather = hasattr(mindspore.mint, 'gather') def gather(input, dim, index): - is_complex = input.dtype == mindspore.complex64 - if is_complex: - real_part = mindspore.mint.gather(input.real, dim, index) - imag_part = mindspore.mint.gather(input.imag, dim, index) - _complex = _get_cache_prim(ops.Complex)() - return _complex(real_part, imag_part) - - if use_pyboost() and has_gather and not ON_ORANGE_PI: - return mindspore.mint.gather(input, dim, index) + return execute("gather_d", input, dim, index) - index = core.where(index < input.shape[dim], index, index - input.shape[dim]) - if not ON_ORANGE_PI: - return ops.gather_elements(input, dim, index) - - return torch_gather(input, index, dim) def gather_nd(input, indices): - return ops.gather_nd(input, indices) - -def tf_gather(input, indices, axis, batch_dims=0): - return ops.gather(input, indices, axis, batch_dims) - -def torch_gather(x, indices, axis=1): - # 这个实现模拟了 torch.gather 的行为 - if axis < 0: - axis = len(x.shape) + axis - - # 创建索引数组,其他维度保持原样 - all_indices = [] - for dim in range(len(x.shape)): - if dim == axis: - # 使用提供的索引 - all_indices.append(indices.to(mindspore.int32)) - else: - # 创建该维度的原始索引 - shape = [1] * len(x.shape) - shape[dim] = x.shape[dim] - dim_indices = core.arange(x.shape[dim], dtype=mindspore.int32) - dim_indices = core.reshape(dim_indices, shape) - # 广播到 indices 的形状 - 
dim_indices = core.broadcast_to(dim_indices, indices.shape) - all_indices.append(dim_indices) - - # 组合所有维度的索引 - multi_indices = core.stack(all_indices, axis=-1) - - # 使用 tf.gather_nd 收集元素 - return gather_nd(x, multi_indices) + return execute("gather_nd", input, indices) + # hsplit # hstack -def hstack(tensors): - return ops.hstack(tensors) - # index_fill -def index_fill(input, dim, index, value): - return ops.index_fill(input, dim, index, value) + # index_add def index_add(input, dim, index, source, *, alpha=1): - if use_pyboost(): - return mindspore.mint.index_add(input, dim, index, source, alpha=alpha) - return ops.index_add(input, index, source, dim) + return execute("index_add_ext", input, index, source, dim, alpha) + # index_copy @@ -138,106 +88,65 @@ def index_add(input, dim, index, source, *, alpha=1): # index_select -has_index_select = hasattr(mindspore.mint, 'index_select') -def index_select(input, dim, index, *, out=None): - if use_pyboost() and has_index_select: - return call_ms_func(mindspore.mint.index_select, input, dim, index, out=out) - return call_ms_func(ops.index_select, input, dim, index, out=out) +def index_select(input, dim, index): + return execute("index_select", input, dim, index) # masked_select -has_masked_select = hasattr(mindspore.mint, 'masked_select') -def masked_select(input, mask, *, out=None): - if use_pyboost() and has_masked_select: - return call_ms_func(mindspore.mint.masked_select, input, mask, out=out) - return call_ms_func(ops.masked_select, input, mask, out=out) +def masked_select(input, mask): + return execute("masked_select", input, mask) + # movedim -def movedim(input, source, destination): - return ops.movedim(input, source, destination) + # moveaxis # narrow -has_narrow = hasattr(mindspore.mint, 'narrow') def narrow(input, dim, start, length): - length = length.item() if isinstance(length, mindspore.Tensor) else length - if use_pyboost() and has_narrow: - return mindspore.mint.narrow(input, dim, start, length) - return ops.narrow(input, dim, start, length) + return execute("narrow", input, dim, start, length) + # narrow_copy # nonzero -has_nonzero = hasattr(mindspore.mint, 'nonzero') def nonzero(input, *, as_tuple=False): - if use_pyboost() and has_nonzero: - return mindspore.mint.nonzero(input, as_tuple=as_tuple) - _nonzero = _get_cache_prim(ops.NonZero)() - out = _nonzero(input) if as_tuple: - if 0 in out.shape: - return (out, out) - return unbind(out, 1) - return out + return execute("non_zero_ext", input) + return execute("non_zero", input) + # permute -has_permute = hasattr(mindspore.mint, 'permute') def permute(input, dims): - if use_pyboost() and has_permute: - return mindspore.mint.permute(input, dims) - return ops.permute(input, dims) + assert isinstance(dims, tuple) + return execute("transpose_view", input, dims) + # reshape -has_reshape = hasattr(mindspore.mint, 'reshape') -def reshape(input, *shape, **kwargs): - shape = kwargs.pop('shape', shape) +def reshape(input, *shape): if isinstance(shape[0], (tuple, list)): shape = shape[0] - new_shape = () - for s in shape: - if not isinstance(s, int): - s = s.item() - new_shape += (s,) - if use_pyboost() and has_reshape: - return mindspore.mint.reshape(input, new_shape) - return ops.reshape(input, new_shape) + return execute("reshape", input, shape) + def view(input, *shape): - # if use_pyboost(): - # return mindspore.ops.auto_generate.gen_ops_prim.view_op(input, shape) return reshape(input, shape) + # row_stack + # select -has_select = hasattr(mindspore.mint, 'select') def select(input, 
dim, index): - if use_pyboost() and has_select: - return mindspore.mint.select(input, dim, index) - slices = () - for _ in range(dim): - slices += (slice(None, None, None),) - slices += (index,) - return input[slices] + return execute("select_ext", input, dim, index) + # scatter -has_scatter = hasattr(mindspore.mint, 'scatter') def scatter(input, dim, index, src): - if use_pyboost() and has_scatter and not ON_ORANGE_PI: - return mindspore.mint.scatter(input, dim, index, src) - if not isinstance(src, mindspore.Tensor): - src = ops.full(index.shape, src, dtype=input.dtype) - if input.dtype == mindspore.bool_: - return ops.tensor_scatter_elements(input.int(), index, src.int(), dim).bool() - return ops.tensor_scatter_elements(input, index, src, dim) - -def tf_scatter_nd_update(input, indices, updates): - return ops.scatter_nd_update(input, indices, updates) + return execute( + "scatter", input, dim, index, src) -def tf_scatter_nd(indices, updates, shape): - return ops.scatter_nd(indices, updates, shape) # diagonal_scatter @@ -249,94 +158,61 @@ def tf_scatter_nd(indices, updates, shape): # scatter_add -has_scatter_add = hasattr(mindspore.mint, 'scatter_add') def scatter_add(input, dim, index, src): - if use_pyboost() and has_scatter_add: - return mindspore.mint.scatter_add(input, dim, index, src) - return ops.tensor_scatter_elements(input, index, src, dim, 'add') - -def scatter_reduce(input, dim, index, src, reduce, *, include_self=True): - if reduce == 'sum': - return scatter_add(input, dim, index, src) - else: - raise ValueError(f'do not support reduce: {reduce}') + return execute("scatter_add_ext", input, dim, index, src) -# scatter_nd_update -def scatter_nd_update(input, indices, update): - return ops.scatter_nd_update(input, indices, update) +# scatter_reduce -def scatter_update(input, indices, updates): - return ops.scatter_update(input, indices, updates) # split -has_split = hasattr(mindspore.mint, 'split') def split(tensor, split_size_or_sections, dim=0): - if isinstance(split_size_or_sections, (tuple, list)): - new_split_size_or_sections = () - for s in split_size_or_sections: - if not isinstance(s, int): - s = s.item() - new_split_size_or_sections += (s,) - split_size_or_sections = new_split_size_or_sections - if use_pyboost() and has_split: - return mindspore.mint.split(tensor, split_size_or_sections, dim) - return ops.split(tensor, split_size_or_sections, dim) - -def split_with_sizes(input, split_sizes, dim=0): - assert input.dim() != 0, "split expects at least a 1-dimensional tensor" - dim_size = input.size(dim) - num_splits = len(split_sizes) - start_idx = 0 - - splits = [] - for i in range(num_splits): - length = split_sizes[i] - assert length >= 0, f"split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes={split_sizes}" - splits.append( - narrow(input, dim, start_idx, length) + if isinstance(split_size_or_sections, int): + res = execute("split_tensor", tensor, split_size_or_sections, dim) + elif isinstance(split_size_or_sections, (list, tuple)): + res = execute("split_with_size", tensor, split_size_or_sections, dim) + else: + raise TypeError( + f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " + f"but got {type(split_size_or_sections)}" ) - start_idx += length - - return splits + return res # squeeze -has_squeeze = hasattr(mindspore.mint, 'squeeze') def squeeze(input, *dim, **kwargs): dim = kwargs.get('dim', dim) - if use_pyboost() and has_squeeze: - return mindspore.mint.squeeze(input, dim) - return 
ops.squeeze(input, dim) + return execute("squeeze", input, dim) + # stack -has_stack = hasattr(mindspore.mint, 'stack') -def stack(tensors, dim=0, *, out=None, **kwargs): - dim = kwargs.pop('axis', dim) - if use_pyboost() and has_stack: - return call_ms_func(mindspore.mint.stack, tensors, dim, out=out) - return call_ms_func(ops.stack, tensors, dim, out=out) + + +def stack(tensors, dim=0): + if tensors[0].device.type == "npu": + return execute("stack_ext", tensors, dim) + return execute("stack", tensors, dim) + # swapaxes -has_swapaxes = hasattr(mindspore.mint, 'swapaxes') def swapaxes(input, dim0, dim1): return transpose(input, dim0, dim1) + # swapdims def swapdims(input, dim0, dim1): return transpose(input, dim0, dim1) + # take def take(input, index): input = input.view(-1) index_shape = index.shape index = index.view(-1) - if ON_ORANGE_PI: - return tf_gather(input, index, 0).view(index_shape) - if index_shape == (): - return gather(input, 0, index)[0] return gather(input, 0, index).view(index_shape) + +# take_along_dim def infer_size_impl(a, b): lenA = len(a) lenB = len(b) @@ -363,7 +239,6 @@ def infer_size_impl(a, b): return expanded_sizes - def _take_along_dim_helper(self, indices, dim): assert self.dim() == indices.dim(), f"torch.take_along_dim(): input and indices should have the same number of dimensions, " \ f"but got {self.dim()} dimensions for input, and {indices.dim()} dimensions for indices" @@ -388,93 +263,254 @@ def take_along_dim(input, indices, dim=None, *, out=None): return input.view(-1).gather(0, indices.view(-1)) # tensor_split -def tensor_split(input, indices_or_sections, dim=0): - return ops.tensor_split(input, indices_or_sections, dim) + # tile -has_tile = hasattr(mindspore.mint, 'tile') -def tile(input, *dims): - if isinstance(dims[0], (tuple, list)): - dims = tuple(dims[0]) - if use_pyboost() and has_tile: - return mindspore.mint.tile(input, dims) - return ops.tile(input, dims) +def tile(input, dims): + return execute("tile", input, dims) + # transpose -has_transpose = hasattr(mindspore.mint, 'transpose') def transpose(input, dim0, dim1): - if use_pyboost() and has_transpose: - return mindspore.mint.transpose(input, dim0, dim1) - ranks = list(range(input.ndim)) - rank0 = ranks[dim0] - rank1 = ranks[dim1] - ranks[dim0] = rank1 - ranks[dim1] = rank0 - return permute(input, tuple(ranks)) + return execute("transpose_ext_view", input, dim0, dim1) -def t(input): - assert input.ndim <= 2, 'Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.' 
- if input.ndim == 1: - return input - return transpose(input, 0, 1) # unbind -has_unbind = hasattr(mindspore.mint, 'unbind') def unbind(input, dim=0): - if use_pyboost() and has_unbind: - return mindspore.mint.unbind(input, dim) - return ops.unbind(input, dim) + return execute("unstack_ext", input, dim) + # unravel_index + # unsqueeze -has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze') -def unsqueeze(input, dim=None): - if use_pyboost() and has_unsqueeze: - return mindspore.mint.unsqueeze(input, dim) - return ops.expand_dims(input, dim) +def unsqueeze(input, dim): + return execute("expand_dims_view", input, dim) + # vsplit # vstack -def vstack(input): - return ops.vstack(input) + + +# where +def where(condition, input, other): + return execute("select", condition, input, other) + + +tensor_1d = mindspore.Tensor([0], dtype=core.int64) +empty_tensor_1d = mindspore.Tensor(shape=(0,), dtype=core.int64) +empty_tensor_9d = mindspore.Tensor(shape=(0,)*9, dtype=core.int64) + +def _do_select(self, dim: int, index: int, dim_index: int, self_shape: list): + """call select view operator""" + if not self_shape: + raise TypeError("Invalid index of a 0-dim tensor.") + dim_size = self_shape[dim] + if index >= dim_size or index < -dim_size: + raise IndexError(f"Index {index} is out of bounds for dimension {dim_index} with size {dim_size}") + index = index + dim_size if index < 0 else index + return execute('select_ext_view', self, dim, index) + + +def _do_slice(self, dim: int, index: slice, self_shape: list): + """call slice view operator""" + def _get_index(index, default): + if index is None: + return default + if core.is_tensor(index): + index = int(index) + return index + + if not self_shape: + raise TypeError("Invalid index of a 0-dim tensor.") + step = _get_index(index.step, 1) + if step <= 0: + raise ValueError("slice step must be positive") + start = _get_index(index.start, 0) + end = _get_index(index.stop, self_shape[dim]) + if start == 0 and end == self_shape[dim] and step == 1: + return self + return execute('slice_ext', self, dim, start, end, step) + +def _wrap_index_to_tuple(index): + """Wrap index to tuple""" + if isinstance(index, tuple): + return index + if isinstance(index, list): + if len(index) < 32 and any(isinstance(i, (core.Tensor, list, tuple, slice, type(None), type(...))) for i in index): + return tuple(index) + return (index,) + + +def _count_indexed_dims(indexes): + """Count indexed dims""" + count = 0 + for index in indexes: + if isinstance(index, core.Tensor): + if index.dtype == core.bool: + count += index.ndim + else: + count += 1 + elif not isinstance(index, (type(None), type(...), bool)): + count += 1 + return count + +def _record_tensor_index(index, remain_indexes, dim): + """Record indexes remained to be used by aclnnIndex/aclnnIndexPut""" + if len(remain_indexes) > dim: + remain_indexes[dim] = index + return remain_indexes + + while dim > len(remain_indexes): + # use empty_tensor with dim_num 9 to indicate unused dim + remain_indexes.append(empty_tensor_9d) + + remain_indexes.append(index) + return remain_indexes + +def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexed_dims, dim_index, remain_indexes, + prev_shape): + """Process dim in multi dim index""" + if isinstance(index, bool): + result = unsqueeze(prev_result, dim) + index_for_bool = tensor_1d if index else empty_tensor_1d + _record_tensor_index(index_for_bool, remain_indexes, dim) + prev_shape.insert(dim, 1) + dim += 1 + return result, dim, remain_indexes, prev_shape + if 
isinstance(index, int): + result = _do_select(prev_result, dim, index, dim_index, prev_shape) + del prev_shape[dim] + return result, dim, remain_indexes, prev_shape + if isinstance(index, slice): + result = _do_slice(prev_result, dim, index, prev_shape) + # current dim in prev_shape will not be used later, ignore it + dim += 1 + return result, dim, remain_indexes, prev_shape + if isinstance(index, type(...)): + dim += (orig_tensor.ndim - indexed_dims) + return prev_result, dim, remain_indexes, prev_shape + if index is None: + result = unsqueeze(prev_result, dim) + prev_shape.insert(dim, 1) + dim += 1 + return result, dim, remain_indexes, prev_shape + if isinstance(index, core.Tensor): + result = prev_result + if index.ndim == 0 and index.dtype in (core.int, core.long, core.short, core.bool): + if index.dtype in (core.int, core.long, core.short): + result = _do_select(prev_result, dim, index.item(), dim_index, prev_shape) + del prev_shape[dim] + return result, dim, remain_indexes, prev_shape + # process index with Tensor bool type + result = unsqueeze(prev_result, dim) + index_for_bool = tensor_1d if index else empty_tensor_1d + _record_tensor_index(index_for_bool, remain_indexes, dim) + prev_shape.insert(dim, 1) + dim += 1 + return result, dim, remain_indexes, prev_shape + _record_tensor_index(index, remain_indexes, dim) + dim += 1 + return result, dim, remain_indexes, prev_shape + raise IndexError(f"Invalid tensor index type {index}") + + +def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims): + """Process indexes in tuple""" + self_viewed = self + self_viewed_shape = list(self.shape) + dim = 0 + for i, index in enumerate(indexes): + if isinstance(index, (list, tuple, np.ndarray)): + index_np = np.array(index) if isinstance(index, (list, tuple)) else index + if index_np.dtype in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64): + index = core.tensor(index_np, device=self.device, dtype=core.int64) + elif index_np.dtype == np.bool_: + index = core.tensor(index_np, device=self.device, dtype=core.int64) + else: + raise TypeError(f"Index {index} contain unsupported elements") + self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index( + self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape) + return self_viewed, remain_indexes + + +def tensor_getitem(self, index): + """Handle tensor getitem""" + if isinstance(index, bool): + self_viewed = unsqueeze(self, 0) + index_for_bool = tensor_1d if index else empty_tensor_1d + return execute('index', self_viewed, [index_for_bool]) + if isinstance(index, int): + return _do_select(self, 0, index, 0, list(self.shape)) + if isinstance(index, slice): + result = _do_slice(self, 0, index, list(self.shape)) + return result + if index is None: + return unsqueeze(self, 0) + if isinstance(index, type(...)): + return self + indexes = _wrap_index_to_tuple(index) + indexed_dims = _count_indexed_dims(indexes) + if self.ndim < indexed_dims: + raise IndexError(f"too many indices for tensor with dimension size {self.ndim}") + remain_indexes = [] + self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims) + if not remain_indexes: + return self_viewed + return execute('index', self_viewed, remain_indexes) + + +def tensor_setitem(self, index, value): + """Handle tensor setitem""" + if not isinstance(value, core.Tensor): + if isinstance(value, (bool, int, float)): + value = 
core.tensor(value, dtype=self.dtype, device=self.device) + else: + raise TypeError(f"Can't assign a {type(value)} to a {self.dtype}.") + + if isinstance(index, bool) and index is False: + return self + if isinstance(index, type(...)): + execute('inplace_copy', self, value) + return self + if index is None or (isinstance(index, bool) and index is True): + self_viewed = unsqueeze(self, 0) + execute('inplace_copy', self_viewed, value) + return self + if isinstance(index, int): + self_viewed = _do_select(self, 0, index, 0, list(self.shape)) + execute('inplace_copy', self_viewed, value) + return self + if isinstance(index, slice): + self_viewed = _do_slice(self, 0, index, list(self.shape)) + execute('inplace_copy', self_viewed, value) + return self + indexes = _wrap_index_to_tuple(index) + indexed_dims = _count_indexed_dims(indexes) + if self.ndim < indexed_dims: + raise IndexError(f"too many indices for tensor with dimension size {self.ndim}") + remain_indexes = [] + self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims) + if not remain_indexes: + execute('inplace_copy', self_viewed, value) + return self + execute('inplace_index_put', self_viewed, remain_indexes, value, False) # accumulate=False + return self _SLICE_ERROR = ( 'only integers, slices (`:`), ellipsis (`...`), ' 'newaxis (`None`) and integer or boolean arrays are valid indices' ) -# where -def where(condition, *args, out=None): - if len(args) == 0: - return nonzero(condition, as_tuple=True) - assert len(args) == 2 - input, other = args - - if isinstance(input, float) and input == -float("inf"): - input = finfo(other.dtype).min - if isinstance(other, float) and other == -float("inf"): - if isinstance(input, numbers.Number): - input = mindspore.tensor(input, dtype=mindspore.float32) - other = finfo(input.dtype).min - - if use_pyboost() and not ON_ORANGE_PI: - output = mindspore.mint.where(condition, input, other) - else: - output = condition * input + (~condition) * other - - if out is not None: - out.assign_value(output) - return output - def _as_index(idx, need_scalar=True): """Helper function to parse idx as an index. 
""" if isinstance(idx, numbers.Integral): return idx, True - idx = mindspore.Tensor(idx) + idx = core.tensor(idx) if need_scalar and idx.ndim not in (None, 0): raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx)) @@ -482,7 +518,6 @@ def _as_index(idx, need_scalar=True): return idx.item(), True return idx, False - def cumprod(x, axis=0, exclusive=False, reverse=False): x = np.array(x) if reverse: @@ -531,12 +566,12 @@ def _correct_axis(axis, rank): assert dest <= len(perm) perm.insert(dest, src) else: - r = ops.range(0, a_rank, 1) + r = core.range(0, a_rank, 1) def _remove_indices(a, b): """Remove indices (`b`) from `a`.""" - items = ops.unstack( - ops.sort(ops.stack(b)) + items = core.unbind( + core.sort(core.stack(b)) ) i = 0 @@ -548,18 +583,18 @@ def _remove_indices(a, b): result.append(a[i:]) - return ops.concat(result, 0) + return core.concat(result, 0) minus_sources = _remove_indices(r, source) minus_dest = _remove_indices(r, destination) - perm = ops.scatter_nd( - ops.expand_dims(minus_dest, 1), minus_sources, [a_rank] + perm = execute('scatter_nd', + core.unsqueeze(minus_dest, 1), minus_sources, [a_rank] ) - perm = ops.tensor_scatter_update( - perm, ops.expand_dims(destination, 1), source + perm = execute('tensor_scatter_update', + perm, core.unsqueeze(destination, 1), source ) - a = ops.transpose(a, tuple(perm)) + a = core.permute(a, tuple(perm)) return a @@ -630,21 +665,17 @@ def _slice_helper(tensor, slice_spec, do_update=False, updates=None): else: if updates is not None: original_tensor = tensor - tensor = ops.strided_slice( + tensor = execute( + 'strided_slice', tensor, begin, end, strides, - begin_mask=begin_mask, - end_mask=end_mask, - shrink_axis_mask=shrink_axis_mask, - new_axis_mask=new_axis_mask, - ellipsis_mask=ellipsis_mask, + begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask ) if not advanced_indices: return tensor - advanced_indices_map = {} for index, data, had_ellipsis in advanced_indices: if had_ellipsis: @@ -666,20 +697,20 @@ def _slice_helper(tensor, slice_spec, do_update=False, updates=None): break indices = [advanced_indices_map[x] for x in dims] indices = broadcast_tensors(*indices) - stacked_indices = ops.stack(indices, axis=-1) + stacked_indices = stack(indices, dim=-1) # Skip the contiguous-dims optimization for update because there is no # tf.*scatter* op that supports the `axis` argument. 
if not dims_contiguous or updates is not None: if range(len(dims)) != dims: tensor = moveaxis(tensor, dims, range(len(dims))) - tensor_shape_prefix = mindspore.Tensor(tensor.shape[: len(dims)]) + tensor_shape_prefix = core.tensor(tensor.shape[: len(dims)], device=stacked_indices.device) stacked_indices = where( stacked_indices < 0, stacked_indices + tensor_shape_prefix, stacked_indices, ) if updates is None: - return ops.gather_nd(tensor, stacked_indices) + return execute('gather_nd', tensor, stacked_indices) else: # We only need to move-axis `updates` in the contiguous case because # only in this case the result dimensions of advanced indexing are in @@ -697,7 +728,7 @@ def range_(start, length): updates = moveaxis( updates, range_(batch_start, batch_size), range(batch_size) ) - tensor = ops.tensor_scatter_update(tensor, stacked_indices, updates) + tensor = execute('tensor_scatter_update', tensor, stacked_indices, updates) if range(len(dims)) != dims: tensor = moveaxis(tensor, range(len(dims)), dims) return strided_slice_update( @@ -722,9 +753,9 @@ def range_(start, length): dim_sizes = np.take_along_axis(np.array(shape_tensor), np.array(dims), axis=0) if len(dims) == 1: stacked_indices = indices[0] - stacked_indices = ops.cast(stacked_indices, mindspore.int32) + stacked_indices = stacked_indices.to(core.int32) stacked_indices = where( - stacked_indices < 0, stacked_indices + mindspore.Tensor(dim_sizes), stacked_indices + stacked_indices < 0, stacked_indices + core.tensor(dim_sizes, device=stacked_indices.device), stacked_indices ) axis = dims[0] if len(dims) > 1: @@ -733,14 +764,14 @@ def range_(start, length): def _tensordot(a, b): # TODO(b/168657656): This function should be replaced by # tensordot(axis=1) once MatMul has int32 XLA kernel. - b = ops.broadcast_to(b, a.shape) - return ops.sum(a * b, dim=-1) + b = broadcast_to(b, a.shape) + return core.sum(a * b, dim=-1) - stacked_indices = _tensordot(stacked_indices, mindspore.Tensor(index_scaling)) + stacked_indices = _tensordot(stacked_indices, core.tensor(index_scaling, device=stacked_indices.device)) flat_shape = shape_tensor[:axis] + (-1,) + shape_tensor[axis + len(dims) :] - tensor = ops.reshape(tensor, flat_shape) + tensor = tensor.reshape(flat_shape) - return ops.gather(tensor, stacked_indices, axis=axis) + return execute('gather', tensor, stacked_indices, axis) def _as_spec_tuple(slice_spec): """Convert slice_spec to tuple.""" @@ -758,11 +789,11 @@ def getitem(self, slice_spec): if ( isinstance(slice_spec, bool) or ( - isinstance(slice_spec, mindspore.Tensor) - and slice_spec.dtype == mindspore.bool_ + isinstance(slice_spec, core.Tensor) + and slice_spec.dtype == core.bool ) ): - return ops.boolean_mask(tensor=self, mask=slice_spec) + return masked_select(self, slice_spec) if not isinstance(slice_spec, tuple): slice_spec = _as_spec_tuple(slice_spec) @@ -775,8 +806,8 @@ def setitem(a, slice_spec, updates): if ( isinstance(slice_spec, bool) or ( - isinstance(slice_spec, mindspore.Tensor) - and slice_spec.dtype == mindspore.bool_ + isinstance(slice_spec, core.Tensor) + and slice_spec.dtype == core.bool ) ): slice_spec = nonzero(slice_spec) @@ -786,86 +817,80 @@ a_dtype = a.dtype result_t = _slice_helper(a, slice_spec, True, updates) - return result_t.astype(a_dtype) - -def tensor_scatter_add(input, indeices, updates): - return ops.tensor_scatter_add(input, indeices, updates) - -def tensor_scatter_max(input, indeices, updates): - return ops.tensor_scatter_max(input, indeices, updates) - -def tensor_scatter_min(input, indeices, updates): - return 
ops.tensor_scatter_min(input, indeices, updates) + return result_t.to(a_dtype) def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0): - strided_slice_grad = _get_cache_prim(StridedSliceGrad)(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) - updated_tensor = strided_slice_grad(update, input.shape, begin, end, strides) - return ops.assign(input, where(updated_tensor != 0, updated_tensor, input)) + if isinstance(update, (int, float, bool)): + update = core.tensor(update, device=input.device, dtype=input.dtype) + sliced_tensor = execute('strided_slice', input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) + if update.shape != sliced_tensor.shape: + update = update.broadcast_to(sliced_tensor.shape) + update = update - sliced_tensor + updated_tensor = execute('strided_slice_grad', input, begin, end, strides, update, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) + input.data = input + updated_tensor + return input + +def getitem_np(input, slice): + return execute('getitem', input, slice) + +def setitem_np(input, slice, value): + return execute('setitem', input, slice, value) __all__ = [ # adjoint, - 'argwhere', - 'cat', - 'concat', - 'concatenate', - 'conj', - 'chunk', + "argwhere", + "cat", + "concat", + "concatenate", + "conj", + "chunk", # dsplit, # column_stack # dstack - 'gather', - 'gather_nd', - 'tf_gather', + "gather", + "gather_nd", # hsplit - 'hstack', - 'index_fill', - 'index_add', + "index_add", # index_copy # index_reduce - 'index_select', - 'masked_select', - 'movedim', + "index_select", + "masked_select", + # movedim # moveaxis - 'narrow', + "narrow", # narrow_copy - 'nonzero', - 'permute', - 'reshape', - 'view', + "nonzero", + "permute", + "reshape", + "view", # row_stack - 'select', - 'scatter', - 'tf_scatter_nd_update', - 'tf_scatter_nd', + "select", + "scatter", # diagonal_scatter # select_scatter # slice_scatter - 'scatter_add', - 'scatter_reduce', - 'scatter_nd_update', - 'scatter_update', - 'split', - 'split_with_sizes', - 'squeeze', - 'stack', - 'swapaxes', - 'swapdims', - 'take', - 'take_along_dim', - 'tensor_split', - 'tile', - 'transpose', - 't', - 'unbind', + "scatter_add", + # scatter_reduce + "split", + "squeeze", + "stack", + "swapaxes", + "swapdims", + "take", + "take_along_dim", + # tensor_split + "tile", + "transpose", + "unbind", # unravel_index - 'unsqueeze', + "unsqueeze", # vsplit - 'vstack', - 'where', + "where", + 'tensor_getitem', + 'tensor_setitem', + 't', 'getitem', 'setitem', - 'tensor_scatter_add', - 'tensor_scatter_max', - 'tensor_scatter_min', - 'strided_slice_update' + 'getitem_np', + 'setitem_np' ] diff --git a/mindnlp/core/ops/blas.py b/mindnlp/core/ops/blas.py index 6506a2277..9a0ebe17e 100644 --- a/mindnlp/core/ops/blas.py +++ b/mindnlp/core/ops/blas.py @@ -1,62 +1,43 @@ """blas op""" -import mindspore - -from mindspore import ops -from ..configs import use_pyboost, ON_ORANGE_PI -from ._inner import call_ms_func +from mindnlp.core.executor import execute # addbmm -has_addbmm = hasattr(mindspore.mint, 'addbmm') -def addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): - if use_pyboost() and has_addbmm: - return call_ms_func(mindspore.mint.addbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) - return call_ms_func(ops.addbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out) +def addbmm(input, batch1, batch2, *, beta=1, alpha=1): + return 
execute('addbmm', input, batch1, batch2, beta, alpha)
 
 # addmm
-has_addmm = hasattr(mindspore.mint, 'addmm')
 def addmm(input, mat1, mat2, *, beta=1, alpha=1):
-    if use_pyboost() and has_addmm:
-        return mindspore.mint.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
-    return ops.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
+    return execute('addmm', input, mat1, mat2, beta, alpha)
 
 # addmv
-
+def addmv(input, mat, vec, *, beta=1, alpha=1, out=None):
+    # honor the out= argument instead of silently ignoring it
+    output = execute('addmv', input, mat, vec, beta, alpha)
+    if out is None:
+        return output
+    out.data = output
+    return out
 
 # addr
 
 # baddbmm
-has_baddbmm = hasattr(mindspore.mint, 'baddbmm')
-def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
-    if use_pyboost() and has_baddbmm:
-        return call_ms_func(mindspore.mint.baddbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out)
-    return call_ms_func(ops.baddbmm, input, batch1, batch2, beta=beta, alpha=alpha, out=out)
+def baddbmm(input, batch1, batch2, *, beta=1, alpha=1):
+    return execute('baddbmm', input, batch1, batch2, beta, alpha)
 
 # bmm
-has_bmm = hasattr(mindspore.mint, 'bmm')
-def bmm(input, other, *, out=None):
-    if ON_ORANGE_PI:
-        input = input.to(mindspore.float16)
-        other = input.to(mindspore.float16)
-    if use_pyboost() and has_bmm:
-        return call_ms_func(mindspore.mint.bmm, input, other, out=out)
-    return call_ms_func(ops.bmm, input, other, out=out)
+def bmm(input, other):
+    return execute('bmm_ext', input, other)
 
 # chain_matmul
 
 # cholesky
+def cholesky(input, upper=False, *, out=None):
+    # honor the out= argument instead of silently ignoring it
+    output = execute('cholesky', input, upper)
+    if out is None:
+        return output
+    out.data = output
+    return out
 
 # cholesky_inverse
 
 # cholesky_solve
 
 # dot
-has_dot = hasattr(mindspore.mint, 'dot')
 def dot(input, other):
-    if use_pyboost() and has_dot:
-        return mindspore.mint.dot(input, other)
-    return (input * other).sum()
+    return execute('dot', input, other)
 
 # geqrf
@@ -76,25 +57,17 @@ def dot(input, other):
 
 # lu_solve
-
 # lu_unpack
 
 # matmul
-has_matmul = hasattr(mindspore.mint, 'matmul')
-def matmul(input, other, *, out=None):
-    if ON_ORANGE_PI:
-        input = input.to(mindspore.float16)
-        other = other.to(mindspore.float16)
-    if use_pyboost() and has_matmul:
-        return call_ms_func(mindspore.mint.matmul, input, other, out=out)
-    return call_ms_func(ops.matmul, input, other, out=out)
+def matmul(input, other):
+    return execute('matmul_ext', input, other)
 
 # matrix_power
 
 # matrix_exp
 
 # mm
-has_mm = hasattr(mindspore.mint, 'mm')
 def mm(input, other):
     return matmul(input, other)
@@ -106,15 +79,11 @@ def mm(input, other):
 # ormqr
 
 # outer
-has_outer = hasattr(mindspore.mint, 'outer')
-def outer(input, vec2, *, out=None):
-    if use_pyboost() and has_outer:
-        return call_ms_func(mindspore.mint.outer, input, vec2, out=out)
-    return call_ms_func(ops.outer, input, vec2, out=out)
+def outer(input, vec2):
+    return execute('outer', input, vec2)
 
 # pinverse
-
 # qr
 
 # svd
diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py
index 6025cd78e..a0db4c15b 100644
--- a/mindnlp/core/ops/comparison.py
+++ b/mindnlp/core/ops/comparison.py
@@ -1,194 +1,144 @@
 """comparison op"""
 from collections import namedtuple
-import numpy as np
-import mindspore
-from mindspore import ops
-from ..configs import use_pyboost, ON_ORANGE_PI
-
-from ._inner import call_ms_func
+from mindnlp import core
+from mindnlp.core.executor import execute
 
 sort_out = namedtuple('sort_out', ['values', 'indices'])
 topk_out = namedtuple('topk_out', ['values', 'indices'])
+
 # allclose
-has_allclose = hasattr(mindspore.mint, 'allclose')
 def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
-    rtol = rtol.item() if isinstance(rtol, mindspore.Tensor)
else rtol - atol = atol.item() if isinstance(atol, mindspore.Tensor) else atol - if use_pyboost() and has_allclose and not ON_ORANGE_PI: - return mindspore.mint.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) - return np.allclose(input.numpy(), other.numpy(), rtol, atol, equal_nan) + return isclose(input, other, rtol, atol, equal_nan).all().item() # argsort -has_argsort = hasattr(mindspore.mint, 'argsort') def argsort(input, dim=-1, descending=False, stable=False): - if use_pyboost() and has_argsort: - return mindspore.mint.argsort(input, dim=dim, descending=descending) return sort(input, dim=dim, descending=descending, stable=stable)[1] # eq -has_eq = hasattr(mindspore.mint, 'eq') -def eq(input, other, *, out=None): - if use_pyboost() and has_eq: - return call_ms_func(mindspore.mint.eq, input, other, out=out) - if isinstance(other, str): - return False - return call_ms_func(ops.eq, input, other, out=out) +def eq(input, other): + return execute('equal', input, other) # equal -has_equal = hasattr(mindspore.mint, 'equal') def equal(input, other): - if use_pyboost() and has_equal and not ON_ORANGE_PI: - return mindspore.mint.equal(input, other) - if input.shape != other.shape: - return False + if input.device.type == 'npu': + return execute('equal_ext', input, other) + # if input.shape != other.shape: + # return False out = eq(input, other) return out.all() # ge def ge(input, other): - return ops.ge(input, other) + return execute('greater_equal', input, other) # gt -has_gt = hasattr(mindspore.mint, 'gt') -def gt(input, other, *, out=None): - if use_pyboost() and has_gt: - return call_ms_func(mindspore.mint.gt, input, other, out=out) - return call_ms_func(ops.gt, input, other, out=out) - +def gt(input, other): + return execute('greater', input, other) # greater -has_greater = hasattr(mindspore.mint, 'greater') -def greater(input, other, *, out=None): - return gt(input, other, out=out) +def greater(input, other): + return gt(input, other) # isclose -has_isclose = hasattr(mindspore.mint, 'isclose') def isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - if use_pyboost() and has_isclose and not ON_ORANGE_PI: - return mindspore.mint.isclose(input, other, rtol, atol, equal_nan) - return mindspore.tensor(np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan)) + return execute('isclose', input, other, rtol, atol, equal_nan) # isfinite -has_isfinite = hasattr(mindspore.mint, 'isfinite') def isfinite(input): - if use_pyboost() and has_isfinite: - return mindspore.mint.isfinite(input) - return ops.isfinite(input) + return execute('isfinite', input) # isin def isin(elements, test_elements): - elements = elements.ravel().expand_dims(-1) - if isinstance(test_elements, mindspore.Tensor): - test_elements = test_elements.ravel() - included = ops.equal(elements, test_elements) - # F.reduce_sum only supports float - res = ops.sum(included.int(), -1).astype(mindspore.bool_) + if elements.device.type != 'cpu': + test_elements = core.tensor(test_elements) + if test_elements.ndim == 0: + test_elements = test_elements.unsqueeze(0) + return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() - return res + return execute('isin', elements, test_elements) # isinf -has_isinf = hasattr(mindspore.mint, 'isinf') def isinf(input): - if use_pyboost() and has_isinf: - return mindspore.mint.isinf(input) - if input.dtype in (mindspore.int32, mindspore.int64): - input = input.to(mindspore.float32) - return ops.isinf(input) + return 
execute('isinf', input)
 
 # isposinf
 
 # isneginf
 
 # isnan
-has_isnan = hasattr(mindspore.mint, 'isnan')
 def isnan(input):
-    if use_pyboost() and has_isnan:
-        return mindspore.mint.isnan(input)
-    if input.dtype in (mindspore.int32, mindspore.int64):
-        input = input.to(mindspore.float32)
-    return ops.isnan(input)
+    return execute('not_equal', input, input)
 
 # isreal
 
 # kthvalue
 
 # le
-has_le = hasattr(mindspore.mint, 'le')
-def le(input, other, *, out=None):
-    if use_pyboost() and has_le:
-        return call_ms_func(mindspore.mint.le, input, other, out=out)
-    return call_ms_func(ops.le, input, other, out=out)
+def le(input, other):
+    return execute('less_equal', input, other)
 
 # less_equal
-has_less_equal = hasattr(mindspore.mint, 'less_equal')
-def less_equal(input, other, *, out=None):
-    return le(input, other, out=out)
+def less_equal(input, other):
+    return le(input, other)
 
 # lt
-has_lt = hasattr(mindspore.mint, 'lt')
-def lt(input, other, *, out=None):
-    if use_pyboost() and has_lt:
-        return call_ms_func(mindspore.mint.lt, input, other, out=out)
-    return call_ms_func(ops.lt, input, other, out=out)
+def lt(input, other):
+    return execute('less', input, other)
 
 # less
-has_less = hasattr(mindspore.mint, 'less')
-def less(input, other, *, out=None):
-    return lt(input, other, out=out)
+def less(input, other):
+    return lt(input, other)
 
 # maximum
-has_maximum = hasattr(mindspore.mint, 'maximum')
-def maximum(input, other, *, out=None):
-    if use_pyboost() and has_maximum:
-        return call_ms_func(mindspore.mint.maximum, input, other, out=out)
-    return call_ms_func(ops.maximum, input, other, out=out)
+def maximum(input, other):
+    return execute('maximum', input, other)
 
 # minimum
-has_minimum = hasattr(mindspore.mint, 'minimum')
-def minimum(input, other, *, out=None):
-    if use_pyboost() and has_minimum:
-        return call_ms_func(mindspore.mint.minimum, input, other, out=out)
-    return call_ms_func(ops.minimum, input, other, out=out)
-
+def minimum(input, other):
+    return execute('minimum', input, other)
 
 # fmax
-def fmax(input, other):
-    return ops.fmax(input, other)
 
 # fmin
-def fmin(input, other):
-    return ops.fmin(input, other)
 
 # ne
-has_ne = hasattr(mindspore.mint, 'ne')
-def ne(input, other, *, out=None):
-    if use_pyboost() and has_ne:
-        return call_ms_func(mindspore.mint.ne, input, other, out=out)
-    return call_ms_func(ops.ne, input, other, out=out)
+def ne(input, other):
+    return execute('not_equal', input, other)
 
 # not_equal
-has_not_equal = hasattr(mindspore.mint, 'not_equal')
 def not_equal(input, other):
     return ne(input, other)
 
 # sort
-has_sort = hasattr(mindspore.mint, 'sort')
 def sort(input, *, dim=-1, descending=False, stable=False):
-    if use_pyboost() and has_sort and not ON_ORANGE_PI:
-        out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
-    else:
-        out = ops.sort(input, dim, descending)
+    out = execute('sort_ext', input, dim, descending, stable)
     return sort_out(values=out[0], indices=out[1])
 
 # topk
-has_topk = hasattr(mindspore.mint, 'topk')
 def topk(input, k, dim=-1, largest=True, sorted=True):
-    if use_pyboost() and has_topk and not ON_ORANGE_PI:
-        out = mindspore.mint.topk(input, int(k), dim, largest, sorted)
+    if input.device.type == 'npu':
+        out = execute('topk_ext', input, k, dim, largest, sorted)
     else:
-        out = ops.topk(input, k, dim, largest, sorted)
+        if not largest:
+            input = -input
+        if dim is None or dim == input.ndim - 1:
+            res = execute('topk', input, k, sorted)
+            if not largest:
+                # negate the values back and keep the named-tuple return type
+                return topk_out(values=-res[0], indices=res[1])
+            return topk_out(values=res[0], indices=res[1])
+        input = input.swapaxes(dim, input.ndim - 1)
+        output = execute('topk', input, k, sorted)
+        values = output[0].swapaxes(dim, input.ndim - 1)
+        indices = output[1].swapaxes(dim, input.ndim - 1)
+        if not largest:
+            res = (-values, indices)
+        else:
+            res = (values, indices)
+        out = res
     return topk_out(values=out[0], indices=out[1])
 
+
 # msort
 def msort(input):
     return sort(input, dim=0)
@@ -203,8 +153,8 @@ def msort(input):
     'greater',
     'isclose',
     'isfinite',
-    'isin',
     'isinf',
+    'isin',
     # isposinf,
     # isneginf,
     'isnan',
@@ -216,8 +166,6 @@
     'less',
     'maximum',
     'minimum',
-    'fmax',
-    'fmin',
     'ne',
     'not_equal',
     'sort',
diff --git a/mindnlp/core/ops/complex.py b/mindnlp/core/ops/complex.py
index 11e063edc..a15543d65 100644
--- a/mindnlp/core/ops/complex.py
+++ b/mindnlp/core/ops/complex.py
@@ -1,3 +1,5 @@
+from mindnlp.core.executor import execute
+
 def real(input):
     return execute('real', input)
 
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index feeac2919..6db903bc4 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -1,246 +1,209 @@
 """creation ops"""
+import numbers
 import numpy as np
-from ml_dtypes import bfloat16 as np_bfloat16
-import mindspore
+
 try:
-    from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module, import-error
+    from mindspore._c_expression import TensorPy as Tensor_
 except:
-    from mindspore._c_expression import TensorPy as CTensor # pylint: disable=no-name-in-module, import-error
+    from mindspore._c_expression import Tensor as Tensor_
 
-from mindspore._c_expression.typing import Type
-from mindspore import ops
-from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, ON_ORANGE_PI
-from .._bind import get_default_dtype, get_default_device
-from .._dtype import dtype2np
-from .utils import py2dtype
-from .other import finfo
+from mindnlp import core
+from mindnlp.core.executor import execute
+from .._bind import get_default_dtype, get_device_in_context
 
 def as_strided(self, size, stride, storage_offset=None):
-    if len(size) != len(stride):
-        raise RuntimeError("mismatch in length of strides and shape.")
-    index = np.arange(0, size[0]*stride[0], stride[0])
-    for i in np.arange(1, len(size)):
-        tmp = np.arange(0, size[i]*stride[i], stride[i])
-        index = np.expand_dims(index, -1)
-        index = index + tmp
-    if storage_offset is not None:
-        index = index + storage_offset
-
-    if index.size == 0:
-        input_indices = mindspore.numpy.empty(index.shape, dtype=mindspore.int32)
-    else:
-        input_indices = mindspore.tensor(index.astype(np.int32))
-    out = ops.gather(self.reshape(-1), input_indices, 0)
-    return out
+    return execute('as_strided', self, size, stride, storage_offset)
 
 # from_numpy
 def from_numpy(ndarray):
-    return mindspore.Tensor(ndarray)
+    out = core.Tensor.from_numpy(ndarray)
+    out._device = core.device('cpu')
+    out._from_numpy = True
+    return out
 
 # frombuffer
+def frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False):
+    arr = np.frombuffer(buffer=buffer, dtype=core.dtype_to_nptype(dtype), count=count, offset=offset)
+    tensor = core.Tensor(arr)
+    tensor.requires_grad_(requires_grad)
+    return tensor
+
+
 # zeros
-_zeros = ops.Zeros()
-has_zeros = hasattr(mindspore.mint, 'zeros')
-def zeros(*size, dtype=None, device=None, requires_grad=False, **kwargs):
+def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=False):
     if dtype is None:
         dtype = get_default_dtype()
-    if not isinstance(dtype, Type):
-
dtype = py2dtype[dtype] - if len(size) == 0: - size = kwargs.get('size', None) - if size == () or size == []: - size = ((),) + if device is None: + device = get_device_in_context() + + if isinstance(device, str): + device = core.device(device) if isinstance(size[0], (tuple, list)): size = size[0] - - new_size = () - for s in size: - if not isinstance(s, int): - s = s.item() - new_size += (s,) - if use_pyboost() and has_zeros: - # if device == 'cpu': - # return mindspore.Tensor(np.zeros(size), dtype=dtype) - return mindspore.mint.zeros(new_size, dtype=dtype) - size = tuple(size) - return _zeros(new_size, dtype) + + output = execute('zeros', size, dtype, device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output + return out # zeros_like -has_zeros_like = hasattr(mindspore.mint, 'zeros_like') -def zeros_like(input, *, dtype=None, memory_format=None, **kwargs): +def zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): if dtype is None: dtype = input.dtype - if use_pyboost() and has_zeros_like: - return mindspore.mint.zeros_like(input, dtype=dtype) - return ops.zeros_like(input, dtype=dtype) + if device is None: + device = input.device + if device.type == 'cpu': + return execute('zeros_like', input, device=device, requires_grad=requires_grad, user_created=True) + return execute('zeros_like_ext', input, dtype, + device=device, requires_grad=requires_grad, user_created=True) # ones -_ones = ops.Ones() -has_ones = hasattr(mindspore.mint, 'ones') -def ones(*size, dtype=None, device=None, **kwargs): - if len(size) == 0: - size = kwargs.get('size', None) - if size == () or size == []: - size = ((),) - - if isinstance(size[0], (tuple, list)): - size = size[0] +def ones(*size, out=None, dtype=None, layout=None, device=None, requires_grad=False): if dtype is None: dtype = get_default_dtype() - if not isinstance(dtype, Type): - dtype = py2dtype[dtype] - - new_size = () - for s in size: - if not isinstance(s, int): - s = s.item() - new_size += (s,) - if use_pyboost() and has_ones: - return mindspore.mint.ones(new_size, dtype=dtype) - return _ones(new_size, dtype) + if device is None: + device = get_device_in_context() + if isinstance(size[0], (tuple, list)): + size = size[0] + output = execute('ones', size, dtype, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output + return out # ones_like -has_ones_like = hasattr(mindspore.mint, 'ones_like') -def ones_like(input, *, dtype=None, device=None, **kwargs): +def ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): if dtype is None: dtype = input.dtype - if use_pyboost() and has_ones_like: - return mindspore.mint.ones_like(input, dtype=dtype) - return ops.ones_like(input, dtype=dtype) + if device is None: + device = input.device + if device.type == 'cpu': + return execute('ones_like', input, device=device, requires_grad=requires_grad, user_created=True) + return execute('ones_like_ext', input, dtype, + device=device, requires_grad=requires_grad, user_created=True) # arange -range_op = ops.Range() -has_arange = hasattr(mindspore.mint, 'arange') -def arange(start=0, end=None, step=1, *, dtype=None, device=None): - if ON_ORANGE_PI and dtype in (None, mindspore.int64): - dtype = mindspore.int32 - if use_pyboost() and has_arange: - start = start.item() if isinstance(start, (mindspore.Tensor, np.integer)) else start - end = end.item() if 
isinstance(end, (mindspore.Tensor, np.integer)) else end - step = step.item() if isinstance(step, (mindspore.Tensor, np.integer)) else step - return mindspore.mint.arange(start, end, step, dtype=dtype) - +def arange(start=0, end=None, step=1, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): if end is None: - end = start - start = 0 - start = mindspore.Tensor(start) if not isinstance(start, mindspore.Tensor) else start - end = mindspore.Tensor(end) if not isinstance(start, mindspore.Tensor) else end - step = mindspore.Tensor(step) if not isinstance(start, mindspore.Tensor) else step - out = range_op(start, end, step) - if dtype: - out = out.to(dtype) + start, end = 0, start + if dtype is None: + dtype = core.int64 + if device is None: + device = get_device_in_context() + if isinstance(device, str): + device = core.device(device) + output = execute('arange', start, end, step, dtype, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output return out # range -def range(start=0, end=None, step=1, dtype=None): +def range(start=0, end=None, step=1, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): if end is None: - start, end = 0, start - out = ops.range(start, end+1, step) - if dtype is not None: - out = out.to(dtype) + raise TypeError('range() missing 1 required positional arguments: "end"') + if dtype is None: + dtype = core.int64 + if device is None: + device = get_device_in_context() + output = execute('range', start, end + 1, step, 1000000, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output return out # linspace -has_linspace = hasattr(mindspore.mint, 'linspace') -def linspace(start, end, steps, *, dtype=None, **kwargs): +def linspace(start, end, steps, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): if dtype is None: - dtype = mindspore.float32 - start = start.item() if isinstance(start, mindspore.Tensor) else start - end = end.item() if isinstance(end, mindspore.Tensor) else end - steps = steps.item() if isinstance(steps, mindspore.Tensor) else steps - if use_pyboost() and has_linspace and not ON_ORANGE_PI: - return mindspore.mint.linspace(start, end, steps, dtype=dtype) - return ops.linspace(start, end, steps).to(dtype) + dtype = get_default_dtype() + if device is None: + device = get_device_in_context() + if device.type == 'cpu': + start = core.tensor(start, device=device, dtype=dtype) + end = core.tensor(end, device=device, dtype=dtype) + output = execute('linspace', start, end, steps, + device=device, requires_grad=requires_grad, user_created=True) + else: + output = execute('lin_space_ext', start, end, steps, dtype, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output + return out # logspace -has_logspace = hasattr(mindspore.mint, 'logspace') -def logspace(start, end, steps, base=10.0, *, dtype=None, **kwargs): - if dtype is None: - dtype = get_default_dtype() - if use_pyboost() and has_logspace: - return mindspore.mint.logspace(start, end, steps, base, dtype=dtype) - return ops.logspace(float(start), float(end), steps, int(base), dtype=dtype) # eye -has_eye = hasattr(mindspore.mint, 'eye') -def eye(n, m=None, *, dtype=None, **kwargs): - if use_pyboost() and has_eye: - return mindspore.mint.eye(n, m, dtype) - return ops.eye(n, m, dtype) - -# empty -has_empty = hasattr(mindspore.mint, 'empty') -def empty(*size, 
dtype=None, device=None, requires_grad=False, pin_memory=False, **kwargs): - size = size or kwargs.get('size', None) +def eye(n, m=None, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): if device is None: - device= get_default_device() - - if len(size) > 0 and isinstance(size[0], (tuple, list)): - size = size[0] + device = get_device_in_context() + if dtype is None: + dtype = get_default_dtype() + output = execute('eye', n, m, dtype, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + out.data = output + return out +# empty +def empty(*size, out=None, dtype=None, layout=None, device=None, + requires_grad=False, pin_memory=False, memory_format=None, **kwargs): + size = kwargs.pop('size', size) if dtype is None: dtype = get_default_dtype() + if device is None: + device = get_device_in_context() + if isinstance(device, str): + device = core.device(device) + if isinstance(size[0], (tuple, list)): + size = size[0] - # if device: - # if not isinstance(device, str) and hasattr(device, "type"): - # device = device.type - # if device.lower() == 'cpu': - # device = 'CPU' - # elif device.lower() == 'npu': - # device = 'Ascend' - # elif device.lower() == 'cuda': - # device = 'GPU' - # else: - # device = 'meta' - - # # To avoid the problem in irecv and recv of using empty. - # if device not in ['meta', 'GPU']: - # out = mindspore.mint.empty(size, dtype=dtype, device=device) - # else: - out = CTensor(dtype=dtype, shape=size) - out = mindspore.Tensor(out) - # else: - # out = np.empty(size, dtype=dtype2np[dtype]) - # out = mindspore.Tensor(out) - - if requires_grad: - out.requires_grad = True + if device.type == 'meta': + output = core.tensor(Tensor_(shape=size, dtype=dtype), device=device) + else: + output = execute('empty', size, dtype, device=device) + if out is None: + return output + out.data = output return out # empty_like -has_empty_like = hasattr(mindspore.mint, 'empty_like') def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): - if use_pyboost(): - return mindspore.mint.empty_like(input, dtype=dtype, device=device) - return mindspore.Tensor(np.empty(input.shape, dtype=dtype2np[input.dtype])) + if device is None: + device = input.device + return empty(input.shape, dtype=input.dtype, layout=layout, device=device, requires_grad=requires_grad) # empty_strided # full -has_full = hasattr(mindspore.mint, 'full') -def full(size, fill_value, *, dtype=None, device=None, **kwargs): - new_size = () - for s in size: - if isinstance(s, mindspore.Tensor): - s = s.item() - new_size += (s,) - if isinstance(fill_value, np.generic): - fill_value = fill_value.item() - if use_pyboost() and has_full: - return mindspore.mint.full(new_size, fill_value, dtype=dtype) - return ops.full(new_size, fill_value, dtype=dtype) +def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): + if dtype is None: + dtype = get_default_dtype() + if device is None: + device = get_device_in_context() + if device.type == 'cpu': + output = execute('full', size, fill_value, device=device, requires_grad=requires_grad, user_created=True) + else: + if isinstance(fill_value, numbers.Number): + output = execute('fill_scalar', size, fill_value, dtype, + device=device, requires_grad=requires_grad, user_created=True) + else: + output = execute('fill_tensor', size, fill_value, dtype, + device=device, requires_grad=requires_grad, user_created=True) + if out is None: + return output + 
out.data = output + return out # full_like -has_full_like = hasattr(mindspore.mint, 'full_like') -def full_like(input, fill_value, *, dtype=None, device=None): - if use_pyboost() and has_full_like: - return mindspore.mint.full_like(input, fill_value, dtype=dtype) - if dtype is None: - dtype = input.dtype - return full(input.shape, fill_value, dtype=dtype) +def full_like(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): + return full(input.shape, fill_value, dtype=dtype, layout=layout, device=input.device, requires_grad=requires_grad) # quantize_per_tensor @@ -252,54 +215,21 @@ def full_like(input, fill_value, *, dtype=None, device=None): # complex -def complex(real, imag): - _complex = _get_cache_prim(ops.Complex)() - return _complex(real, imag) + # polar -has_polar = hasattr(mindspore.mint, 'polar') -def polar(abs, angle): - if use_pyboost() and has_polar: - return mindspore.mint.polar(abs, angle) - return ops.polar(abs, angle) +def polar(abs, angle, *, out=None): + output = execute('polar', abs, angle) + if out is None: + return output + out.data = output + return out + # heaviside -def heaviside(input, values): - return ops.heaviside(input, values) - -_TypeDict = { - mindspore.float16: np.float16, - mindspore.float32: np.float32, - mindspore.float64: np.float64, - mindspore.bfloat16: np_bfloat16, - mindspore.int8: np.int8, - mindspore.int16: np.int16, - mindspore.int32: np.int32, - mindspore.int64: np.int64, - mindspore.uint8: np.uint8, - mindspore.bool_: np.bool_, - mindspore.complex64: np.complex64, - mindspore.complex128: np.complex128, -} - - -def frombuffer(buffer, *, dtype=None, count=-1, offset=0, requires_grad=False): - np_dtype = _TypeDict[dtype] - output = np.frombuffer(buffer=buffer, dtype=np_dtype, count=count, offset=offset) - if dtype == mindspore.bfloat16: - return mindspore.Tensor(output.astype(np.float32), dtype=dtype) - return mindspore.Tensor(output, dtype=dtype) - - -def scalar_tensor(value, dtype, device=None): - if value == float("-inf"): - value = finfo(dtype).min - if value == float("inf"): - value = finfo(dtype).max - return mindspore.Tensor(value, dtype=dtype) - -__all__ = ['arange', 'as_strided', 'complex', 'empty', 'empty_like', - 'eye', 'from_numpy', 'full', 'full_like', 'frombuffer', - 'heaviside', 'linspace', 'logspace', 'ones', 'ones_like', - 'polar', 'range', 'zeros', 'zeros_like', 'scalar_tensor' -] \ No newline at end of file + +__all__ = ['arange', 'as_strided', 'empty', 'empty_like', + 'eye', 'from_numpy', 'frombuffer', 'full', 'full_like', + 'linspace', 'ones', 'ones_like', + 'polar', 'range', 'zeros', 'zeros_like' +] diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py index d8c73626a..ab6f40a5b 100644 --- a/mindnlp/core/ops/inplace.py +++ b/mindnlp/core/ops/inplace.py @@ -1,57 +1,43 @@ -import numbers -import numpy as np -import mindspore -from mindspore import ops -from mindspore._c_expression import typing -from mindspore.ops._primitive_cache import _get_cache_prim -from mindspore.common.generator import default_generator -from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \ - inplace_scatter_src_op, inplace_fill_tensor_op, inplace_fill_scalar_op, inplace_zero_op, inplace_uniform_op, \ - inplace_masked_fill_scalar_op, inplace_masked_fill_tensor_op, inplace_random_op, inplace_clamp_scalar_op, \ - inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op, inplace_erfinv_op - from mindnlp import 
core -from ..configs import use_pyboost, ON_ORANGE_PI -from ._inner import assign +from mindnlp.core._C import default_generator +from mindnlp.core.executor import execute generator_step_ = 12 def inplace_copy(self, other): - if self.device.type == 'npu': - inplace_copy_op(self, other) - else: - self.data = other + if self.device != other.device: + other = other.to(self.device) + execute('inplace_copy', self, other) return self def inplace_zero(input): - if input.device == 'npu': - inplace_zero_op(input) + if input.device.type == 'npu': + execute('inplace_zero', input) + elif input.device.type == 'meta': + pass else: - input.data = ops.zeros(input.shape, dtype=input.dtype) + input.data = core.zeros_like(input) return input def inplace_fill(input, value): - if input.device.type == 'npu': - if isinstance(value, (int, float, bool)): - inplace_fill_scalar_op(input, value) - else: - inplace_fill_tensor_op(input, value) + if isinstance(value, (int, float, bool)): + execute('inplace_fill_scalar', input, value) else: - input.data = ops.full(input.shape, value, dtype=input.dtype) + execute('inplace_fill_tensor', input, value) return input def inplace_normal(input, mean=0, std=1, *, generator=None): if generator is None: generator = default_generator seed, offset = generator._step(generator_step_) + if isinstance(mean, core.Tensor): mean = mean.item() if isinstance(std, core.Tensor): std = std.item() - if input.device.type == 'npu': - inplace_normal_op(input, mean, std, seed, offset) - else: - input.data = core.tensor(np.random.normal(mean, std, input.shape), dtype=input.dtype) + + execute('inplace_normal', input, mean, std, seed, offset, device=input.device) + return input # uniform_ @@ -74,167 +60,21 @@ def inplace_uniform(input, *args, **kwargs): generator_ = kwargs.get("generator", None) if generator_ is None: generator_ = default_generator - seed, offset = generator_._step(generator_step_) - if input.device.type == 'npu': - inplace_uniform_op(input, from_, to_, seed, offset) - else: - input.data = core.tensor(np.random.uniform(from_, to_, input.shape), dtype=input.dtype) - # core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_ + execute("inplace_uniform", input, from_, to_, generator_) return input def inplace_add(input, other, alpha): execute('inplace_add_ext', input, other, alpha) return input -def inplace_scatter(input, dim, index, src): - if not isinstance(src, numbers.Number): - return inplace_scatter_src_op(input, dim, index, src) - return inplace_scatter_value_op(input, dim, index, src) - -def inplace_index_copy(input, dim, index, tensor): - selected = input.index_select(dim, index) - input.index_add_(dim, index, -selected) - input.index_add_(dim, index, tensor) - return input - -def inplace_index_add(input, dim, index, source): - if input.device == 'npu': - inplace_index_add_op(input, dim, index, source) - else: - _inplace = _get_cache_prim(ops.IndexAdd)(dim) - input.data = _inplace(input, index.int(), source) - return input - -has_squeeze = hasattr(mindspore.mint, 'squeeze') -def inplace_squeeze(input, *dim, **kwargs): - dim = kwargs.get('dim', dim) - if use_pyboost() and has_squeeze: - out = mindspore.mint.squeeze(input, dim) - else: - out = ops.squeeze(input, dim) - input.assign_value(out) - return input - - -has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze') -def inplace_unsqueeze(input, dim=None): - if use_pyboost() and has_unsqueeze: - out = mindspore.mint.unsqueeze(input, dim) - out = ops.expand_dims(input, dim) - input.assign_value(out) - 
return input - -def inplace_fill_diagonal(input, fill_value, wrap=False): - fill_diagnoal_ = _get_cache_prim(ops.FillDiagonal)(float(fill_value), wrap) - out = fill_diagnoal_(input) - input.assign_value(out) - return input - -def inplace_triu(input, diagonal=0): - out = ops.triu(input, diagonal) - input.assign_value(out) - return input - -def inplace_round(input, decimals=0): - out = ops.round(input, decimals=decimals) - input.assign_value(out) - return input - -def inplace_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): - if reduce == 'sum': - reduce = "add" - return inplace_scatter_src_reduce_op(input, dim, index, src, reduce) - -def inplace_exponential(tensor, lambd=1.0): - """ - 原地操作的指数分布采样 (类似Tensor.exponential_) - :param tensor: 要填充的目标张量 - :param lambd: 率参数 (λ > 0) - :return: 修改后的张量 (原张量被覆盖) - """ - assert lambd > 0, "lambd 必须大于0" - - # 生成与目标张量形状相同的均匀分布随机数 - u = core.rand_like(tensor) - - # 数值保护 - u = u.clamp(min=core.finfo(u.dtype).eps, max=1.0) - - # 逆变换法赋值 - tensor.data = -core.log(1 - u) / lambd - - return tensor - -def inplace_log(self): - self.data = core.log(self) - return self - -def inplace_mul(self, other): - self.data = core.mul(self, other) - return self - -def inplace_neg(self): - self.data = core.neg(self) - return self - -def inplace_exp(self): - self.data = core.exp(self) - return self - -def inplace_sub(self, other): - self.data = core.sub(self, other) - return self - -def inplace_bernoulli(self, p=0.5, *, generator=None): - self.data = core.bernoulli(self, generator=generator, p=p) - return self - -def inplace_tril(self, diagonal=0): - self.data = core.tril(self, diagonal) - return self - -def inplace_masked_fill(self, mask, value): - if self.device.type == 'npu': - if isinstance(value, (int, float, bool)): - inplace_masked_fill_scalar_op(self, mask, value) - else: - inplace_masked_fill_tensor_op(self, mask, value) - else: - self.data = ops.masked_fill(self, mask, value) - return self - def inplace_random(self, from_=0, to=None, *, generator=None): - if self.device.type == 'npu': - if not generator: - generator = default_generator - seed, offset = generator._step( # pylint: disable=protected-access - generator_step_) - return inplace_random_op(self, from_, to, seed, offset) - else: - if isinstance(self.dtype, typing.Float): - self.uniform_(from_, to, generator=generator) - elif isinstance(self.dtype, typing.Int): - if to is None: - to = core.iinfo(mindspore.int32).max - self.data = core.randint(from_, to, size=self.shape, dtype=self.dtype) - return self - -def inplace_clamp(self, min=None, max=None): - if self.device.type == 'npu': - if isinstance(min, (int, float, bool)) or isinstance(max, (int, float, bool)): - inplace_clamp_scalar_op(self, min, max) - else: - inplace_clamp_tensor_op(self, min, max) - else: - self.data = ops.clamp(self, min, max) + if not generator: + generator = default_generator + seed, offset = generator._step( # pylint: disable=protected-access + generator_step_) + execute('inplace_random', self, from_, to, seed, offset, device=self.device) return self -def inplace_erfinv(self): - if self.device.type == 'npu' and not ON_ORANGE_PI: - inplace_erfinv_op(self) - else: - self.data = core.erfinv(self) - return self __all__ = [ 'inplace_copy', @@ -243,25 +83,5 @@ def inplace_erfinv(self): 'inplace_fill', 'inplace_uniform', 'inplace_add', - 'inplace_scatter', - 'inplace_index_copy', - 'inplace_index_add', - 'inplace_squeeze', - 'inplace_unsqueeze', - 'inplace_fill_diagonal', - 'inplace_triu', - 'inplace_round', - 
'inplace_scatter_reduce', - 'inplace_exponential', - 'inplace_log', - 'inplace_mul', - 'inplace_neg', - 'inplace_exp', - 'inplace_sub', - 'inplace_bernoulli', - 'inplace_tril', - 'inplace_masked_fill', - 'inplace_random', - 'inplace_clamp', - 'inplace_erfinv' + 'inplace_random' ] diff --git a/mindnlp/core/ops/optim.py b/mindnlp/core/ops/optim.py index 7265adcbf..ba7190796 100644 --- a/mindnlp/core/ops/optim.py +++ b/mindnlp/core/ops/optim.py @@ -1,44 +1,18 @@ """optim op""" -import mindspore -from mindspore import ops -from mindspore.ops._primitive_cache import _get_cache_prim +from mindnlp.core.executor import execute -DEVICE_TARGET = mindspore.get_context('device_target') - -_adadelta = ops.ApplyAdadelta() def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad): - return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad) + return execute('raw_adadelta', param, square_avg, acc_delta, lr, rho, eps, grad) -_adam = ops.Adam() def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32: - beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \ - mindspore.tensor(beta2_power, dtype=param.dtype), \ - mindspore.tensor(lr, dtype=param.dtype), \ - mindspore.tensor(beta1, dtype=param.dtype), \ - mindspore.tensor(beta2, dtype=param.dtype), \ - mindspore.tensor(epsilon, dtype=param.dtype) - return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + return execute('raw_adam', param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) -_adam_amsgrad = ops.ApplyAdamWithAmsgradV2() def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad - - if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32: - beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \ - mindspore.tensor(beta2_power, dtype=param.dtype), \ - mindspore.tensor(lr, dtype=param.dtype), \ - mindspore.tensor(beta1, dtype=param.dtype), \ - mindspore.tensor(beta2, dtype=param.dtype), \ - mindspore.tensor(epsilon, dtype=param.dtype) - - return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, - beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) - + return execute('raw_adam_amsgrad', param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) def raw_sgd(param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat): - _sgd = _get_cache_prim(ops.SGD)(dampening, weight_decay, nesterov) - return _sgd(param, grad, lr, accum, momentum, stat) + return execute('raw_sgd', param, grad, lr, dampening, weight_decay, nesterov, accum, momentum, stat) __all__ = ['raw_adadelta', 'raw_adam', 'raw_adam_amsgrad', 'raw_sgd'] diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index e7b280401..f9f0f6fbb 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -1,17 +1,10 @@ """other op""" - -import copy import numpy as np import mindspore -from mindspore import ops -from mindspore.common.initializer import initializer -from mindspore.ops._primitive_cache import _get_cache_prim - +from mindspore.ops import gather from mindnlp import core -from ..configs import 
use_pyboost, ON_ORANGE_PI, ON_A1 -from .reduction import any -from .comparison import eq -from ._inner import call_ms_func +from mindnlp.core.executor import execute +from ..configs import ON_A1 # atleast_2d @@ -20,23 +13,8 @@ # bincount -has_bincount = hasattr(mindspore.mint, "bincount") - - def bincount(input, weights=None, minlength=0): - if use_pyboost() and has_bincount: - return mindspore.mint.bincount(input, weights, minlength) - if input.max() > minlength - 1: - length = (input.max() + 1) - else: - length = core.tensor(minlength) - idx = core.arange(length).unsqueeze(-1) - idx_mapping = core.eq(input, idx) - if weights is not None: - if input.shape != weights.shape: - raise ValueError('for bincount `input` and `weights` must have the same length') - idx_mapping = weights * idx_mapping - return core.sum(idx_mapping, 1).ravel() + return execute('bincount_ext', input, weights, minlength) # block_diag @@ -44,43 +22,13 @@ def bincount(input, weights=None, minlength=0): # broadcast_tensors def broadcast_tensors(*tensors): target_shape = broadcast_shapes(*[t.shape for t in tensors]) - broadcasted_tensors = [t.broadcast_to(target_shape) for t in tensors] - return broadcasted_tensors -def manual_expand(tensor, shape): - assert ( - len(shape) >= tensor.dim() - ), "Target shape must have equal or more dimensions than the tensor." - - for _ in range(len(shape) - tensor.dim()): - tensor = tensor.unsqueeze(0) - - repeats = [] - for i, (tensor_dim, target_dim) in enumerate(zip(tensor.shape, shape)): - if target_dim == -1: - repeats.append(1) - else: - repeats.append(target_dim // tensor_dim if tensor_dim == 1 else 1) - - return tensor.tile(tuple(repeats)) - - # broadcast_to -has_broadcast_to = hasattr(mindspore.mint, "broadcast_to") - - -def broadcast_to(input, *shape): - if isinstance(shape[0], (list, tuple)): - shape = shape[0] - if ON_ORANGE_PI and not use_pyboost(): - # return input.expand(mindspore.tensor(shape)) - return manual_expand(input, shape) - if use_pyboost() and has_broadcast_to: - return mindspore.mint.broadcast_to(input, shape) - return ops.broadcast_to(input, shape) +def broadcast_to(input, shape): + return execute('broadcast_to', input, shape) # broadcast_shapes @@ -105,42 +53,20 @@ def broadcast_shapes(*shapes): return tuple(reversed(result_shape)) - # bucketize -def bucketize(input, boundaries, *, out_int32=False, right=False, out=None): - if isinstance(boundaries, mindspore.Tensor): - boundaries = boundaries.tolist() - - if not boundaries: - return input - out = ops.bucketize(input, boundaries, right=right) - if not out_int32: - out = out.to(mindspore.int64) - return out # cartesian_prod # cdist -has_cdist = hasattr(mindspore.mint, "cdist") - - def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): - if isinstance(p, int): - p = float(p) - if use_pyboost() and has_cdist: - return mindspore.mint.cdist(x1, x2, p, compute_mode) - return ops.cdist(x1, x2, float(p)) - + return execute('cdist', x1, x2, p) # clone -has_clone = hasattr(mindspore.mint, "clone") - - -def clone(input): - if use_pyboost() and has_clone: - return mindspore.mint.clone(input) - return copy.deepcopy(input) +def clone(input, *, memory_format=core.preserve_format): + if input.device.type == 'npu': + return execute('clone', input) + return execute('identity', input) # combinations @@ -159,39 +85,15 @@ def clone(input): # cummin # cumprod -def cumprod(input, dim, *, dtype=None, out=None): - return ops.cumprod(input, dim, dtype=dtype) # cumsum -has_cumsum = hasattr(mindspore.mint, 
"cumsum") - -def cumsum(input, dim=None, dtype=None, out=None, **kwargs): - dim = kwargs.pop('axis', dim) - input_dtype = input.dtype - if input_dtype == mindspore.int64: - input = input.to(mindspore.int32) - if ( - use_pyboost() and has_cumsum and not ON_ORANGE_PI - ): # since cann8.0 community remove aclnn cumsum - output = mindspore.mint.cumsum(input, dim, dtype) - else: - if input.dtype == mindspore.bool_: - input = input.to(mindspore.int32) - output = ops.cumsum(input, dim, dtype) - if out is not None: - out.assign_value(output) - return out - output = output.to(input_dtype) - return output - +def cumsum(input, dim, dtype=None): + return execute('cumsum_ext', input, dim, + dtype if dtype is None else dtype_to_type_id('CumsumExt', 'dtype', dtype)) # diag -has_diag = hasattr(mindspore.mint, "diag") -def diag(input, diagonal=0): - if use_pyboost() and has_diag: - return mindspore.mint.diag(input, diagonal) - return mindspore.numpy.diag(input, diagonal) - +def diag(input, diagonal=0, *, out=None): + return execute('diag', input, diagonal) # diag_embed @@ -200,489 +102,542 @@ def diag(input, diagonal=0): # diagonal +def diagonal(input, offset=0, dim1=0, dim2=1): + return execute('diagonal', input, offset, dim1, dim2) # diff -def diff(input, n=1, dim=-1, prepend=None, append=None): - if use_pyboost(): - return mindspore.mint.diff(input, n, dim, prepend, append) - return ops.diff(input, n, dim, prepend, append) - -# einsum +def _diff_is_scalar_or_scalar_tensor(value): + """judge the value""" + if isinstance(value, int): + return True + if isinstance(value, core.Tensor) and value.shape == (): + return True -def einsum_label_to_index(label): - """ - Args: - label (str): The label representing a dimension in an Einstein sum. - It should be a single character from the alphabet (upper or lower case) or '.'. - - Returns: - NoneType: This function returns None. + return False - Raises: - None. - """ - if label == ".": - return 52 - NUM_OF_LETTERS = ord("z") - ord("a") + 1 - return ( - (ord(label) - ord("A")) - if (label.isupper()) - else (NUM_OF_LETTERS + (ord(label) - ord("a"))) - ) +def _diff_helper(input, n, dim): + """calculate the forward difference""" + out_len = input.shape[dim] - 1 + is_bool = (input.dtype == core.bool) + result = input + for _ in range(n): # pylint: disable=unused-variable + if is_bool: + result = core.logical_xor(core.narrow(result, dim, 1, out_len), core.narrow(result, dim, 0, out_len)) + else: + result = core.sub(core.narrow(result, dim, 1, out_len), core.narrow(result, dim, 0, out_len)) -def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): - r""" - This function takes three parameters: dim, dim_post_expr, and wrap_scalar. + if out_len == 0: + break + out_len -= 1 - Args: - - dim (int): Represents the dimension to be wrapped. - - dim_post_expr (int): Represents the value used to wrap the dimension. - - wrap_scalar (bool, optional): Specifies whether a scalar value should be wrapped. Default is True. + return result - Returns: - None: This function does not return a value directly. - Raises: - AssertionError: Raised if the value of dim_post_expr is less than or equal to 0 and wrap_scalar is False. - AssertionError: Raised if the value of dim is less than the minimum or greater than the maximum allowed range. - AssertionError: Raised if the value of dim is negative and cannot be wrapped due to invalid dim_post_expr. 
+def _diff_prepend_append_on_dim(input, prepend, append, dim): + """append tensor on dim""" + if prepend is not None and append is None: + return core.cat((prepend, input), dim) - """ - if dim_post_expr <= 0: - assert wrap_scalar - dim_post_expr = 1 - min = -dim_post_expr - max = dim_post_expr - 1 - assert not (dim < min or dim > max) - if dim < 0: - dim += dim_post_expr - return dim - - -def dim_list_to_bitset(opt_dims, ndims): - r""" - Converts a list of optional dimensions to a bitset representation. + if prepend is None and append is not None: + return core.cat((input, append), dim) - Args: - opt_dims (List[int]): The list of optional dimensions to be converted to a bitset representation. - ndims (int): The total number of dimensions. + return core.cat((prepend, input, append), dim) - Returns: - List[bool]: A list representing the bitset, where True indicates the presence of the dimension and False indicates its absence. - Raises: - None - """ - if opt_dims: - seen = [False] * (max(opt_dims) + 1) - for dim in opt_dims: - dim = maybe_wrap_dim(dim, ndims) - seen[dim] = True +def diff(input, n=1, dim=-1, prepend=None, append=None): + if (prepend is None and append is None) or n == 0: + return _diff_helper(input, n, dim) + + input = _diff_prepend_append_on_dim(input, prepend, append, dim) + return _diff_helper(input, n, dim) + +def _einsum_convert_sublist_to_label(num, ell_num=False): + """Convert sublist to label.""" + if num == Ellipsis or ell_num and num == 52: + return '...' + if 0 <= num < 26: + return chr(num + ord('A')) + if 26 <= num < 52: + return chr(num + ord('a') - 26) + raise ValueError( + f'For einsum, the number in sublist must be in range [0, 52), but got {num}') + + +def _einsum_convert_label_to_index(label): + """Convert label to index.""" + label_num = ord(label) + if ord('A') <= label_num <= ord('Z'): + return label_num - ord('A') + if ord('a') <= label_num <= ord('z'): + return label_num - ord('a') + 26 + if label_num == ord('.'): + return 52 + raise ValueError( + f'For einsum, the label in equation must be in [a-zA-Z] or ., but got {label}') + + +def _einsum_convert_sublist(equation, *operands): + """Convert the sublist to an equation operand if the received input is a sublist format.""" + if isinstance(equation, core.Tensor): + equation_tmp = '' + for i, lst in enumerate(operands): + if i % 2 == 0: + for _, num in enumerate(lst): + equation_tmp += _einsum_convert_sublist_to_label(num) + if i in (len(operands) - 1, len(operands) - 2): + continue + equation_tmp += ',' + if len(operands) % 2 == 0: + equation_tmp += '->' + for _, num in enumerate(operands[-1]): + equation_tmp += _einsum_convert_sublist_to_label(num) + operands_tmp = list([equation]) + list(operands[1:-1:2]) + else: + operands_tmp = list([equation]) + list(operands[1::2]) + equation = equation_tmp + operands = tuple(operands_tmp) + if len(operands) == 0: # pylint: disable=len-as-condition + raise ValueError( + "For einsum, the 'operands' must have at least one operand.") + return equation, operands + + +def _einsum_check_inputargs(equation, operands): + """Check equation and operands.""" + if not isinstance(equation, str): + raise TypeError( + f"For einsum, 'equation' must be a str, but got {type(equation)}.") + for operand in operands: + if not isinstance(operand, core.Tensor): + raise TypeError( + f"For einsum, members of 'operands' must be Tensor, but got {type(operand)}.") + + +def _einsum_parse_equation(equation): + """Parse equation.""" + l_equation = '' + r_equation = '' + equation = 
equation.replace(' ', '')
+
+    if '->' in equation:
+        l_equation, r_equation = equation.split('->', 1)
+        if l_equation == '':
+            raise ValueError(
+                'For einsum, equation must contain characters to the left of the arrow.')
+    else:
+        l_equation = equation
+
+    l_equationlst = []
+
+    for subequation in l_equation.split(','):
+        if '.' in subequation and ('...' not in subequation or subequation.count('.') != 3):
+            raise ValueError(f"For einsum, an ellipsis in the equation must include three continuous \'.\', "
+                             f"and can only be found once.")
+        subequation_lst = [_einsum_convert_label_to_index(label) for label in subequation.replace('...', '.')]
+        l_equationlst.append(subequation_lst)
+
+    if "." in r_equation and ('...' not in r_equation or r_equation.count('.') != 3):
+        raise ValueError(f"For einsum, an ellipsis in the equation must include three continuous \'.\', "
+                         f"and can only be found once.")
+    r_equationlst = [_einsum_convert_label_to_index(label) for label in r_equation.replace('...', '.')]
+
+    return l_equationlst, r_equationlst, ('->' in equation)
+
+
+def _einsum_parse_labels(l_equationlst, operands):
+    """Parse left script of equation."""
+    align_rank = 0
+    max_labels = 53
+    ellipsis_dimnum = 0
+    labels_count = [0] * max_labels
+
+    if len(operands) != len(l_equationlst):
+        raise ValueError(f"For einsum, the number of 'operands' does not match the number specified in the "
+                         f"'equation', but got {len(operands)} and {len(l_equationlst)}.")
+
+    for idx, sub_equ in enumerate(l_equationlst):
+        start_dim = 0
+        label_num = 0
+        operand_shape = list(operands[idx].shape)
+        for label in sub_equ:
+            dim_num = 1
+            label_num += 1
+            end_dim = start_dim + 1
+
+            # Label is ellipsis
+            if label == 52:
+                end_dim = len(operand_shape) - len(sub_equ) + label_num
+                dim_num = end_dim - start_dim
+                if ellipsis_dimnum != 0 and ellipsis_dimnum != dim_num:
+                    raise ValueError(f"For einsum, an ellipsis in 'equation' can only represent the same number of "
+                                     f"dimensions in 'operands'.")
+                ellipsis_dimnum = dim_num
+            if labels_count[label] == 0:
+                align_rank += dim_num
+            labels_count[label] += 1
+            start_dim += dim_num
+        if label_num != len(sub_equ) or start_dim != len(operand_shape):
+            raise ValueError(f"For einsum, the number of labels specified in the 'equation' does not match "
+                             f"'operands[{idx}]'.")
+    return ellipsis_dimnum, labels_count, align_rank
+
+
+def _einsum_infer_output(r_equationlst, arrow_exist, ellipsis_dimnum, labels_count):
+    """Parse right script of equation and infer output shape."""
+    idx = 0
+    idle_idx = -1
+    output_rank = 0
+    labels_perm_idx = [idle_idx] * 53
+
+    if arrow_exist:
+        for label in r_equationlst:
+            if labels_count[label] != 0:
+                if labels_perm_idx[label] != idle_idx:
+                    raise ValueError(f"For einsum, '{_einsum_convert_sublist_to_label(label, True)}' or {label} in "
+                                     f"sublist format appears more than once in the output subscript.")
+                dimnum = 1
+                if label == 52:
+                    dimnum = ellipsis_dimnum
+                labels_perm_idx[label] = idx
+                output_rank += dimnum
+                idx += dimnum
+            else:
+                raise ValueError(f"For einsum, the label to the right of the arrow in the 'equation' must appear on "
+                                 f"the left, but '{_einsum_convert_sublist_to_label(label, True)}' does not.")
+    else:
+        if labels_count[52] != 0:
+            output_rank += ellipsis_dimnum
+            labels_perm_idx[52] = idx
+            idx += ellipsis_dimnum
+        for label, count in enumerate(labels_count):
+            if count == 1:
+                output_rank += 1
+
labels_perm_idx[label] = idx + idx += 1 + + for label, count in enumerate(labels_count): + if count != 0 and labels_perm_idx[label] == idle_idx: + labels_perm_idx[label] = idx + idx += 1 + + return output_rank, labels_perm_idx + + +def _einsum_adjust_operands(operands, l_equationlst, ellipsis_dimnum, labels_perm_idx, align_rank): + """Align operands to output as possible.""" + # Unsqueeze miss dimensions to make all operands has same rank, compute diagonal if operand has same label. + # Then use _labels_perm_idx to transpose all operands to align dimensions with output. + adjust_operands = [] + for idx, operand in enumerate(operands): + idle_dim = -1 + align_axis = [idle_dim] * align_rank + label_dims = [idle_dim] * 53 + dim = 0 -def sumproduct_pair(left_, right_, sum_dims_, keep_dim_): - """ - Calculate the sum-product pair of two arrays along specified dimensions. - - Args: - left_ (array): The left input array. - right_ (array): The right input array. - sum_dims_ (list): A list of dimensions along which to calculate the sum-product pair. - keep_dim_ (bool): A flag indicating whether to keep the dimensions in the result. + for label in l_equationlst[idx]: + if label_dims[label] != idle_dim: + operand = core.diagonal(operand, 0, label_dims[label], dim) + diag_perm = [] + diag_dim = 0 + for i in range(len(operand.shape)): + if i == label_dims[label]: + diag_perm.append(len(operand.shape) - 1) + else: + diag_perm.append(diag_dim) + diag_dim += 1 + operand = core.permute(operand, tuple(diag_perm)) + else: + label_dims[label] = dim + if label == 52: + for ell_idx in range(ellipsis_dimnum): + align_axis[labels_perm_idx[label] + ell_idx] = dim + dim += 1 + else: + align_axis[labels_perm_idx[label]] = dim + dim += 1 + if len(operand.shape) < align_rank: + for i, axis in enumerate(align_axis): + if axis == idle_dim: + align_axis[i] = dim + dim += 1 + missing_dims = [1] * (align_rank - len(operand.shape)) + operand_shape = list(operand.shape) + missing_dims + operand = core.reshape(operand, operand_shape) + operand = core.permute(operand, tuple(align_axis)) + adjust_operands.append(operand) + return adjust_operands + + +def _einsum_find_dimlastop(align_rank, operands, adjust_operands): + """Find dim last operand.""" + dim_last_op = [0] * align_rank + has_zero_dim = False + for dim in range(align_rank): + broadcast_dim = adjust_operands[0].shape[dim] + for idx in range(1, len(adjust_operands)): + other_dim = adjust_operands[idx].shape[dim] + if broadcast_dim != other_dim and broadcast_dim != 1 and other_dim != 1: + err_msg = "For einsum, operands do not broadcast after align to output [shapes :origin -> adjust]:" + for i in range(len(operands)): + err_msg += f" {operands[i].shape} -> {adjust_operands[i].shape}" + raise ValueError(err_msg) + if other_dim != 1: + dim_last_op[dim] = idx + broadcast_dim = other_dim + has_zero_dim = has_zero_dim or broadcast_dim == 0 + return dim_last_op, has_zero_dim + + +def _einsum_multiplication(sum_dims, l_tensor, r_tensor): + """Compute bmm for einsum.""" + batch_dims = [] + lonly_dims = [] + ronly_dims = [] + batch_size = 1 + lonly_size = 1 + ronly_size = 1 + sum_size = 1 + + l_shape = l_tensor.shape + r_shape = r_tensor.shape + + # Compute sum if dim is in sum_dims and get shapes for bmm + for i in range(len(l_shape)): + sum_l = l_shape[i] > 1 + sum_r = r_shape[i] > 1 + if i in sum_dims: + if sum_l and sum_r: + sum_size *= l_shape[i] + elif sum_l: + l_tensor = core.sum(l_tensor, i, True) + elif sum_r: + r_tensor = core.sum(r_tensor, i, True) + elif sum_l and 
+            batch_dims.append(i)
+            batch_size *= l_shape[i]
+        elif sum_l:
+            lonly_dims.append(i)
+            lonly_size *= l_shape[i]
+        else:
+            ronly_dims.append(i)
+            ronly_size *= r_shape[i]
+
+    # Build the einsum bmm operator pipeline.
+    # The whole pipeline is transpose(in) -> reshape(in) -> bmm(in) -> reshape(out) -> transpose(out).
+    l_reshape_shape = (batch_size, lonly_size, sum_size)
+    r_reshape_shape = (batch_size, sum_size, ronly_size)
+
+    out_reshape_shape = [l_shape[dim] for dim in batch_dims]
+    out_reshape_shape += [l_shape[dim] for dim in lonly_dims]
+    out_reshape_shape += [1 for _ in sum_dims]
+    out_reshape_shape += [r_shape[dim] for dim in ronly_dims]
+
+    l_perm_axis = batch_dims + lonly_dims + sum_dims + ronly_dims
+    r_perm_axis = batch_dims + sum_dims + ronly_dims + lonly_dims
+    out_perm_axis = [-1] * len(out_reshape_shape)
+
+    out_dim = 0
+    for idx in range(len(l_perm_axis)):
+        out_perm_axis[l_perm_axis[idx]] = out_dim
+        out_dim += 1
+
+    l_tensor = core.permute(l_tensor, tuple(l_perm_axis))
+    l_tensor = core.reshape(l_tensor, l_reshape_shape)
+
+    r_tensor = core.permute(r_tensor, tuple(r_perm_axis))
+    r_tensor = core.reshape(r_tensor, r_reshape_shape)
+
+    output = core.bmm(l_tensor, r_tensor)
+    output = core.reshape(output, out_reshape_shape)
+    output = core.permute(output, tuple(out_perm_axis))
+
+    output_origin_shape = output.shape
+    output_squeeze_shape = []
+    for dim in range(len(output_origin_shape)):
+        if dim not in sum_dims:
+            output_squeeze_shape.append(output_origin_shape[dim])
+
+    return core.reshape(output, output_squeeze_shape)
+
+
+def _einsum(equation, operands):
+    '''Einsum main process'''
+    _l_equationlst, _r_equationlst, _arrow_exist = _einsum_parse_equation(
+        equation)
+    _ellipsis_dimnum, _labels_count, _align_rank = _einsum_parse_labels(
+        _l_equationlst, operands)
+    _output_rank, _labels_perm_idx = _einsum_infer_output(
+        _r_equationlst, _arrow_exist, _ellipsis_dimnum, _labels_count)
+    _adjust_operands = _einsum_adjust_operands(operands, _l_equationlst, _ellipsis_dimnum, _labels_perm_idx,
+                                               _align_rank)
+    _dim_last_op, _has_zero_dim = _einsum_find_dimlastop(
+        _align_rank, operands, _adjust_operands)
+    _result = _adjust_operands[0]
+
+    # Fast path if any operand has a zero dim.
+    if _has_zero_dim:
+        output_shape = []
+        for dim in range(_output_rank):
+            output_shape.append(_adjust_operands[_dim_last_op[dim]].shape[dim])
+        return core.zeros(output_shape, dtype=_result.dtype)
+
+    # Sum out or squeeze dimensions that are 1 for all remaining operands.
+    _reduce_dim = _output_rank
+    for dim in range(_output_rank, _align_rank):
+        if _dim_last_op[dim] == 0:
+            if _result.shape[_reduce_dim] == 1:
+                _result = core.squeeze(_result, _reduce_dim)
+            else:
+                _result = core.sum(_result, _reduce_dim)
+        else:
+            _reduce_dim += 1
-
-    Returns:
-        None. The function performs the sum-product pair calculation and returns None.
+    # Multiply in the remaining operands one at a time.
+    for i in range(1, len(_adjust_operands)):
+        operand = _adjust_operands[i]
+        dim = _output_rank
+        sum_dims = []
+        for j in range(_output_rank, _align_rank):
+            if _dim_last_op[j] < i:
+                operand = core.squeeze(operand, dim)
+            elif _dim_last_op[j] == i:
+                if _result.shape[dim] == 1:
+                    operand = core.sum(operand, dim)
+                    _result = core.squeeze(_result, dim)
+                else:
+                    sum_dims.append(dim)
+                    dim += 1
+            else:
+                dim += 1
-
-    Raises:
-        AssertionError: If the number of dimensions of the input arrays do not match,
-        or if non-broadcast dimensions do not match. 
- """ - assert left_.ndim == right_.ndim, "number of dimensions must match" - if len(sum_dims_) == 0: - return ops.mul(left_, right_) - - dim = left_.ndim - sum_dims = dim_list_to_bitset(sum_dims_, dim) - - lro, lo, ro = [], [], [] - lro_size, lo_size, ro_size, sum_size = 1, 1, 1, 1 - left = left_ - right = right_ - - for i in range(dim): - sl = left.shape[i] > 1 - sr = right.shape[i] > 1 - if sum_dims[i]: - if sl and sr: - assert ( - left.shape[i] == right.shape[i] - ), "non-broadcast dimensions must match" - sum_size *= left.shape[i] - elif sl: - left = ops.sum(left, i, keepdim=True) - elif sr: - right = ops.sum(right, i, keepdim=True) - elif sl and sr: - assert ( - left.shape[i] == right.shape[i] - ), "non-broadcast dimensions must match" - lro.append(i) - lro_size *= left.shape[i] - elif sl: - lo.append(i) - lo_size *= left.shape[i] + if sum_dims == []: + _result = core.mul(_result, operand) + elif len(sum_dims) == len(_result.shape): + _result = core.dot(core.flatten(_result), core.flatten(operand)) else: - ro.append(i) - ro_size *= right.shape[i] - - out_size = [] - for d in lro: - out_size.append(left.shape[d]) - for d in lo: - out_size.append(left.shape[d]) - for d in sum_dims_: - out_size.append(1) - for d in ro: - out_size.append(right.shape[d]) - - lpermutation = lro.copy() - lpermutation += lo - lpermutation += sum_dims_ - lpermutation += ro - - rpermutation = lro.copy() - rpermutation += sum_dims_ - rpermutation += ro - rpermutation += lo - - opermutation = [-1] * (len(lro) + len(lo) + len(sum_dims_) + len(ro)) - i = 0 - for it in lro: - opermutation[it] = i - i += 1 - for it in lo: - opermutation[it] = i - i += 1 - for it in sum_dims_: - opermutation[it] = i - i += 1 - for it in ro: - opermutation[it] = i - i += 1 - - left = ops.transpose(left, tuple(lpermutation)).reshape(lro_size, lo_size, sum_size) - right = ops.transpose(right, tuple(rpermutation)).view(lro_size, sum_size, ro_size) - - result = ops.bmm(left, right) - result = result.view(*out_size).transpose(*opermutation) - - if not keep_dim_: - sizes = list(result.shape) - for i in range(dim - 1, 0, -1): - if sum_dims[i]: - sizes.pop(i) - result = result.view(*sizes) + _result = _einsum_multiplication(sum_dims, _result, operand) - return result + return _result -ELLIPSIS = 52 +def einsum(equation, *operands): + r""" + According to the Einstein summation Convention (Einsum), + the product of the input tensor elements is summed along the specified dimension. + You can use this operator to perform diagonal, reducesum, transpose, matmul, mul, inner product operations, etc. -has_einsum = hasattr(mindspore.mint, "einsum") + Note: + The sublist format is also supported. For example, einsum_ext(op1, sublist1, op2, sublist2, ..., sublist_out). + In this format, equation can be derived by the sublists which are made up of Python's Ellipsis and list of + integers in [0, 52). Each operand is followed by a sublist and an output sublist is at the end. + Dynamic shape, dynamic rank input is not supported in `graph mode (mode=mindspore.GRAPH_MODE) + `_. + .. warning:: + This is an experimental API that is subject to change or deletion. -def einsum(equation, *operands): - """ Args: - equation (str): A string representing the Einstein summation equation to be computed. - The equation should follow the Einstein summation convention with subscripts in [a-zA-Z], - commas separating operands, and '->' indicating the output structure. - It must include at least one operand. An ellipsis '...' can be used to represent multiple dimensions. 
+ equation (str): Notation based on the Einstein summation convention, represent the operation you want to do. + the value can contain only letters, commas, ellipsis and arrow. The letters(must be in [a-zA-Z]) represent + input tensor dimension, commas(,) represent separate tensors, ellipsis indicates the tensor dimension that + you do not care about, the left of the arrow indicates the input tensors, and the right of it indicates the + desired output dimension. If there are no arrows in the equation, the letters that appear exactly once in + the equation will be part of the output, sorted in increasing alphabetical order. The output is computed by + multiplying the input operands element-wise, with their dimensions aligned based on the letters, and then + summing out the dimensions whose letters are not part of the output. If there is one arrow in the equation, + the output letters must appear at least once for some input operand and at most once for the output. + operands (Tensor): Input tensor used for calculation. The dtype of the tensor must be the same. Returns: - None: This function does not return a value. + Tensor, the shape of it can be obtained from the `equation` , and the dtype is the same as input tensors. Raises: - AssertionError: If the function is called without providing at least one operand. - AssertionError: If an invalid subscript is given in the equation string. - AssertionError: If the number of subscripts in the equation does not match the number of dimensions for an operand. - AssertionError: If fewer operands are provided than specified in the equation. - AssertionError: If more operands are provided than specified in the equation. - RuntimeError: If operands do not broadcast with remapped shapes [original->remapped]. + TypeError: If `equation` is invalid, or the `equation` does not match the input tensor. + ValueError: If the number in sublist is not in [0, 52) in sublist format. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32) + >>> equation = "i->" + >>> output = ops.einsum_ext(equation, x) + >>> print(output) + 7.0 + >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32) + >>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32) + >>> equation = "i,i->i" + >>> output = ops.einsum_ext(equation, x, y) + >>> print(output) + [ 2. 8. 12.] + >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) + >>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32) + >>> equation = "ij,jk->ik" + >>> output = ops.einsum_ext(equation, x, y) + >>> print(output) + [[16. 22.] + [37. 52.]] + >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) + >>> equation = "ij->ji" + >>> output = ops.einsum_ext(equation, x) + >>> print(output) + [[1. 4.] + [2. 5.] + [3. 6.]] + >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) + >>> equation = "ij->j" + >>> output = ops.einsum_ext(equation, x) + >>> print(output) + [5. 7. 9.] + >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) + >>> equation = "...->" + >>> output = ops.einsum_ext(equation, x) + >>> print(output) + 21.0 + >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) + >>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32) + >>> equation = "j,i->ji" + >>> output = ops.einsum_ext(equation, x, y) + >>> print(output) + [[ 2. 4. 1.] 
+ [ 4. 8. 2.] + [ 6. 12. 3.]] + >>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32) + >>> y = mindspore.Tensor([1, 2], mindspore.float32) + >>> output = ops.einsum_ext(x, [..., 1], y, [..., 2], [..., 1, 2]) + >>> print(output) + [[1. 2.] + [2. 4.] + [3. 6.] + [4. 8.]] """ - if isinstance(operands[0], (tuple, list)): - operands = operands[0] - if use_pyboost() and has_einsum: - return mindspore.mint.einsum(equation, *operands) - assert operands, "einsum(): must provide at least one operand" - if isinstance(operands[0], tuple): - operands = operands[0] - - arrow_pos = equation.find("->") - num_ops = len(operands) - op_labels = [[] for _ in range(num_ops)] - lhs = equation[0:arrow_pos] - - curr_op = 0 - found_ell = False - ell_skip = 0 - for i, label in enumerate(lhs): - if label == " ": - continue - if label == ".": - if ell_skip != 0: - ell_skip -= 1 - continue - assert ( - not found_ell - ), f"einsum(): found {curr_op} for operand for which an ellipsis was already found" - assert ( - i + 2 < len(lhs) and lhs[i + 1] == "." - ), f"einsum(): found {curr_op} for operand that is not part of any ellipsis" - ell_skip = 2 - op_labels[curr_op].append(ELLIPSIS) - found_ell = True - elif label == ",": - curr_op += 1 - assert ( - curr_op < num_ops - ), "einsum(): fewer operands were provided than specified in the equation" - found_ell = False - else: - assert str.isalpha( - label - ), f"einsum(): invalid subscript given at index {i} in the equation string, subscripts must be in [a-zA-Z]" - op_labels[curr_op].append(einsum_label_to_index(label)) - - assert ( - curr_op == num_ops - 1 - ), "einsum(): more operands were provided than specified in the equation" - # Labels must be within [a-zA-Z]. - TOTAL_LABELS = 52 - label_count = [0] * TOTAL_LABELS - # The maximum number of dimensions covered by any ellipsis, needed when - # unsqueezing missing dimensions from operands to permute and broadcast - ell_num_dim = 0 - - # Compute label frequency and number of dimensions covered by ellipsis - # We do this after parsing labels to make it more readable and simpler - # to compute the number of dimensions covered by ellipsis. - for i, operand in enumerate(operands): - labels = op_labels[i] - ndims = operand.ndim - nlabels = len(labels) - has_ellipsis = False - - for label in labels: - if label == ELLIPSIS: - nlabels -= 1 - has_ellipsis = True - ell_num_dim = max(ell_num_dim, ndims - nlabels) - else: - label_count[label] += 1 - if has_ellipsis: - assert nlabels <= ndims, ( - f"einsum(): the number of subscripts in the equation ({nlabels}" - f") is more than the number of dimensions ({ndims}) for operand {i}" - ) - else: - assert nlabels == ndims, ( - f"einsum(): the number of subscripts in the equation ({nlabels}" - f") does not match the number of dimensions (" - f"{ndims}) for operand {i} and no ellipsis was given" - ) - - # We want to align the dimensions of every input tensor to have - # shape out_dims + sum_dims. For this, we create a mapping of label - # to index into the permuted shape. - label_perm_index = [-1] * TOTAL_LABELS - # Current index in the permuted shape - perm_index = 0 - # Start index of ellipsis dimensions in the permuted shape - ell_index = 0 - found_ell = False - - if arrow_pos == -1: - # Implicit output is ellipsis (...) 
+ labels seen only once - perm_index = ell_num_dim - found_ell = True - for label, _label_count in enumerate(label_count): - if _label_count == 1: - label_perm_index[label] = perm_index - perm_index += 1 - else: - rhs = equation[arrow_pos + 2 :] - ell_skip = 0 - for i, label in enumerate(rhs): - if label == " ": - continue - if label == ".": - if ell_skip != 0: - ell_skip -= 1 - continue - assert ( - not found_ell - ), "einsum(): found '.' for output but an ellipsis (...) was already found" - assert ( - i + 2 < len(rhs) and rhs[i + 1] == "." - ), "einsum(): found '.' for output that is not part of any ellipsis (...)" - ell_skip = 2 - ell_index = perm_index - perm_index += ell_num_dim - found_ell = True - else: - assert str.isalpha(label), ( - f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " - f"in the equation string, subscripts must be in [a-zA-Z]" - ) - - index = einsum_label_to_index(label) - label_perm_index[index] = perm_index - perm_index += 1 - - out_size = perm_index - if not found_ell: - ell_index = perm_index - perm_index += ell_num_dim - - for label in range(TOTAL_LABELS): - if label_count[label] > 0 and label_perm_index[label] == -1: - label_perm_index[label] = perm_index - perm_index += 1 - - # Here we unsqueeze missing dimensions to make all operands have the same - # number of dimensions. We take diagonals for repeated labels within the - # same operand. Finally we permute the operands to align dimensions as - # per the perm_out_index we computed above. - permuted_operands = [] - for i, operand in enumerate(operands): - perm_shape = [-1] * perm_index - label_dim = [-1] * TOTAL_LABELS - operand = operands[i] - labels = op_labels[i] - original_sizes = operand.shape - - j = 0 - for label in labels: - if label == ELLIPSIS: - # Add missing dimensions covered by the ellipsis - num_missing_dim = ell_num_dim - (len(original_sizes) - len(labels) + 1) - for k in range(num_missing_dim): - operand = ops.unsqueeze(operand, j) - for k in range(ell_num_dim): - perm_shape[ell_index + k] = j - j += 1 - elif label_dim[label] != -1: - dim = label_dim[label] - operand = ops.diagonal(operand, offset=0, dim1=dim, dim2=j) - operand = ops.moveaxis(operand, -1, dim) - else: - label_dim[label] = j - perm_shape[label_perm_index[label]] = j - j += 1 - - # Add dimensions for missing labels - for idx, index in enumerate(perm_shape): - if index == -1: - operand = ops.unsqueeze(operand, -1) - perm_shape[idx] = j - j += 1 - - operand = ops.transpose(operand, tuple(perm_shape)) - permuted_operands.append(operand) - - # Check if operands broadcast and keep track of last operand with - # dimension size != 1 for optimizing reductions - dim_last_op = [0] * perm_index - has_zero_size_dim = False - for dim in range(perm_index): - broadcast_size = permuted_operands[0].shape[dim] - for i in range(1, len(operands)): - dim_size = permuted_operands[i].shape[dim] - if broadcast_size != dim_size and broadcast_size != 1 and dim_size != 1: - raise RuntimeError( - "einsum(): operands do not broadcast with remapped shapes [original->remapped]" - ) - if dim_size != 1: - broadcast_size = dim_size - dim_last_op[dim] = i - has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0) - - # Compute result - result = permuted_operands[0] - if has_zero_size_dim: - out_shape = [-1] * out_size - for i in range(out_size): - out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] - return ops.zeros(out_shape) - - # Sum out or squeeze dimensions that are size 1 for all later operands - dim = out_size - for i in 
range(dim, perm_index):
-        if dim_last_op[i] == 0:
-            if result.shape[dim] == 1:
-                result = ops.squeeze(result, dim)
-                dim -= 1
-            else:
-                result = ops.sum(result, dim)
-                dim -= 1
-        dim += 1
-
-    for i in range(1, num_ops):
-        operand = permuted_operands[i]
-        sum_dims = []
-
-        # Sum out or squeeze dimensions that are size 1 for all later operands
-        dim = out_size
-        for j in range(dim, perm_index):
-            if dim_last_op[j] < i:
-                operand = ops.squeeze(operand, dim)
-                dim -= 1
-            elif dim_last_op[j] == i:
-                if result.shape[dim] == 1:
-                    operand = ops.sum(operand, dim)
-                    result = ops.squeeze(result, dim)
-                    dim -= 1
-                else:
-                    sum_dims.append(dim)
-            dim += 1
-        if len(sum_dims) == 0:
-            result = result.mul(operand)
-        elif len(sum_dims) == len(result.shape):
-            result = result.flatten().dot(operand.flatten())
-        else:
-            result = sumproduct_pair(result, operand, sum_dims, False)
-    return result
-
+    _equation, _operands = _einsum_convert_sublist(equation, *operands)
+    _einsum_check_inputargs(_equation, _operands)
+    return _einsum(_equation, _operands)
 
 # flatten
-has_flatten = hasattr(mindspore.mint, "flatten")
-
-
 def flatten(input, start_dim=0, end_dim=-1):
-    if use_pyboost() and has_flatten:
-        return mindspore.mint.flatten(input, start_dim, end_dim)
-    if end_dim < 0:
-        end_dim = input.ndim + end_dim
-    new_shape = input.shape[:start_dim] + (-1,) + input.shape[end_dim + 1 :]
-    return ops.reshape(input, new_shape)
+    if input.device.type == 'cpu':
+        if end_dim < 0:
+            end_dim = input.ndim + end_dim
+        new_shape = input.shape[:start_dim] + (-1,) + input.shape[end_dim + 1:]
+        return input.reshape(new_shape)
+    return execute('flatten_ext', input, start_dim, end_dim)
 
 # flip
-has_flip = hasattr(mindspore.mint, "flip")
-
-
 def flip(input, dims):
-    if not isinstance(dims, (list, tuple)):
-        dims = (dims,)
-    if use_pyboost() and has_flip:
-        return mindspore.mint.flip(input, dims)
-    return ops.flip(input, dims)
+    if not isinstance(dims, (list, tuple)):
+        dims = (dims,)
+    return execute('reverse_v2', input, dims)
 
 
 # fliplr
@@ -701,15 +656,6 @@ def flip(input, dims):
 
 
 # histc
-has_histc = hasattr(mindspore.mint, "histc")
-
-
-def histc(input, bins, min, max, *, out=None):
-    if use_pyboost() and has_histc:
-        return call_ms_func(
-            mindspore.mint.histc, input, bins=bins, min=min, max=max, out=out
-        )
-    return call_ms_func(ops.histc, input, bins=bins, min=min, max=max, out=out)
 
 
 # histogram
@@ -719,21 +665,10 @@ def histc(input, bins, min, max, *, out=None):
 
 
 # meshgrid
-has_meshgrid = hasattr(mindspore.mint, "meshgrid")
-
-
 def meshgrid(*tensors, indexing=None):
-    if isinstance(tensors[0], (tuple, list)):
-        tensors = tensors[0]
-    if use_pyboost() and has_meshgrid:
-        return mindspore.mint.meshgrid(*tensors, indexing=indexing)
-    if isinstance(tensors[0], (list, tuple)):
-        tensors = tensors[0]
-    if len(tensors) == 1:
-        return tensors
+    if isinstance(tensors[0], (list, tuple)):
+        # keep accepting a single list/tuple of tensors, as the old implementation did
+        tensors = tuple(tensors[0])
     if indexing is None:
-        indexing = "ij"
-    return ops.meshgrid(*tensors, indexing=indexing)
+        indexing = 'ij'
+    return execute('meshgrid', tensors, indexing)
 
 
 # lcm
@@ -742,18 +677,20 @@ def meshgrid(*tensors, indexing=None):
 
 # logcumsumexp
 
 # ravel
-
+def ravel(input):
+    return input.reshape(-1)
 
 # renorm
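+
+# Example (illustrative only): ravel reshapes to 1-D, so for a tensor t of shape
+# (2, 3), ravel(t).shape == (6,), equivalent to t.reshape(-1). With the 'ij'
+# default above, meshgrid(a, b) for a of shape (3,) and b of shape (2,) is
+# expected to return two tensors of shape (3, 2), matching torch's matrix
+# ('ij') indexing rather than numpy's 'xy' default.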
 
 # repeat_interleave
-has_repeat_interleave = hasattr(mindspore.mint, 'repeat_interleave')
-def repeat_interleave(input, repeats, dim=None, *, output_size=None):
-    if use_pyboost() and has_repeat_interleave and not ON_A1:
-        return mindspore.mint.repeat_interleave(input, repeats, dim=dim)
-
-    if isinstance(repeats, mindspore.Tensor):
+def repeat_interleave(input, repeats, dim=None):
+    if input.device.type == 'npu' and not ON_A1:
+        if isinstance(repeats, int):
+            return execute('repeat_interleave_int', input, repeats, dim, None)
+        return execute('repeat_interleave_tensor', input, repeats, dim, None)
+
+    if isinstance(repeats, core.Tensor):
         repeats = repeats.tolist()
     if not isinstance(repeats, (tuple, list)):
         repeats = (repeats,)
@@ -773,119 +710,69 @@ def repeat_interleave(input, repeats, dim=None, *, output_size=None):
         return Tensor_(input.dtype, (0,))
     if input.dtype == mindspore.bool_:
         input = input.to(mindspore.int32)
-        out = ops.repeat_elements(input, repeats, dim)
+        out = execute('repeat_elements', input, repeats, dim)
         return out.to(mindspore.bool_)
-    return ops.repeat_elements(input, repeats, dim)
+    return execute('repeat_elements', input, repeats, dim)
 
     size = input.shape[dim]
     if len(repeats) != size:
         raise ValueError(f"For 'Tensor.repeat', the length of 'repeats' must be the same as the shape of the "
                          f"original tensor in the 'axis' dimension, but got the length of 'repeats' "
                          f"{len(repeats)}, the shape of the original tensor in the 'axis' dimension {size}.")
-    subs = ops.tensor_split(input, size, dim)
+    subs = core.tensor_split(input, size, dim)
     repeated_subs = []
     for sub, rep in zip(subs, repeats):
         if rep != 0:
-            repeated_subs.append(ops.repeat_elements(sub, rep, dim))
-    return ops.concat(repeated_subs, dim)
-
-# roll
-DEVICE_TARGET = mindspore.get_context("device_target")
-has_roll = hasattr(mindspore.mint, "roll")
+            repeated_subs.append(execute('repeat_elements', sub, rep, dim))
+    return core.concat(repeated_subs, dim)
 
+# roll
 def roll(input, shifts, dims=None):
-    if use_pyboost() and has_roll:
-        return mindspore.mint.roll(input, shifts, dims)
-    if DEVICE_TARGET == "CPU":
-        return mindspore.numpy.roll(input, shifts, dims)
-    return ops.roll(input, shifts, dims)
+    return execute('roll', input, shifts, dims)
 
 
 # searchsorted
-has_searchsorted = hasattr(mindspore.mint, "searchsorted")
 def searchsorted(
-    sorted_sequence,
-    values,
-    *,
-    out_int32=False,
-    right=False,
-    side=None,
-    out=None,
-    sorter=None,
+    sorted_sequence, values, *, out_int32=False, right=False, side=None, sorter=None
 ):
-    if use_pyboost() and has_searchsorted:
-        if not isinstance(values, core.Tensor):
-            values = core.tensor(values)
-        return call_ms_func(
-            mindspore.mint.searchsorted,
-            sorted_sequence,
-            values,
-            out_int32=out_int32,
-            right=right,
-            side=side,
-            out=out,
-            sorter=sorter,
-        )
-    return call_ms_func(
-        ops.searchsorted,
-        sorted_sequence,
-        values,
-        out_int32=out_int32,
-        right=right,
-        out=out,
-    )
-
+    dtype = core.int32 if bool(out_int32) else core.int64
+    if side == "left" and right is True:
+        raise ValueError("For 'searchsorted', 'side' and 'right' can't be set to opposites: "
+                         "got side of 'left' while right was True.")
+    if side == "right":
+        right = True
+    return execute('search_sorted', sorted_sequence, values, sorter,
+                   dtype_to_type_id('SearchSorted', 'dtype', dtype), right)
 
 # tensordot
 
 # trace
 
 # tril
-has_tril = hasattr(mindspore.mint, "tril")
-
-
-def tril(input, diagonal=0, *, out=None):
-    if use_pyboost() and has_tril:
-        return call_ms_func(mindspore.mint.tril, input, diagonal, out=out)
-    return call_ms_func(ops.tril, input, diagonal, out=out)
+def tril(input, diagonal=0):
+    return execute('tril_ext', input, diagonal)
 
 
 # tril_indices
 
 # triu
-has_triu = hasattr(mindspore.mint, "triu")
 
-def triu(input, diagonal=0, *, out=None):
-    if use_pyboost() and has_triu:
-        return call_ms_func(mindspore.mint.triu, input, diagonal, out=out)
-    return call_ms_func(ops.triu, input, diagonal, out=out)
-
+def triu(input, diagonal=0):
+    return execute('triu', input, diagonal)
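+
+# Example (illustrative only): for a 3x3 matrix m, tril(m) is expected to keep
+# the entries on and below the main diagonal and zero the rest, while
+# triu(m, diagonal=1) keeps the entries strictly above it; `diagonal` shifts
+# the boundary, following the torch.tril/torch.triu convention.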
 
 
 # triu_indices
 
 
 # unflatten
 def unflatten(x, dim, sizes):
-    if dim < 0:
-        dim = x.ndim + dim
-    front_part = x.shape[:dim] if dim != 0 else ()
-    new_shape = front_part + sizes + x.shape[dim+1:]
-    return ops.reshape(x, new_shape)
-
+    if dim < 0:
+        dim = x.ndim + dim
+    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1:]
+    return x.reshape(new_shape)
 
 # vander
 
 # view_as_real
-def view_as_real(input):
-    real_part = input.real.expand_dims(-1)
-    imag_part = input.imag.expand_dims(-1)
-    return core.concat((real_part, imag_part), -1)
 
 
 # view_as_complex
-def view_as_complex(input):
-    _complex = _get_cache_prim(ops.Complex)()
-    real_part, imag_part = input.tensor_split(2, -1)
-    return _complex(real_part.squeeze(-1), imag_part.squeeze(-1))
 
 
 # resolve_conj
@@ -893,7 +780,7 @@ def view_as_complex(input):
 
 # resolve_neg
-has_masked_fill = hasattr(mindspore.mint, "masked_fill")
+
 def masked_fill(input, mask, value):
     if isinstance(value, float):
         if value == -float('inf'):
@@ -901,11 +788,7 @@ def masked_fill(input, mask, value):
         if value == float('inf'):
             value = finfo(input.dtype).max
 
-    if has_masked_fill:
-        return mindspore.mint.masked_fill(input, mask, value)
-    masked_fill_ = _get_cache_prim(ops.MaskedFill)()
-    return masked_fill_(input, mask, core.tensor(value, dtype=input.dtype))
-
+    return execute('masked_fill', input, mask, value)
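+
+# Example (illustrative only): masked_fill(x, mask, float('-inf')) substitutes
+# finfo(x.dtype).min for the literal -inf before dispatching, which keeps later
+# reductions (e.g. a softmax over a fully masked row) finite instead of NaN.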
- """ - self.assign_value(initializer(init_method, self.shape, self.dtype)) - - -_stop_gradient = ops.StopGradient() - - def stop_gradient(input): - return _stop_gradient(input) + return execute('stop_gradient', input) def _get_unfold_indices(input_shape, dimension, size, step): @@ -1045,70 +902,46 @@ def _get_unfold_indices(input_shape, dimension, size, step): def unfold(input, dimension, size, step): _indices, _dimension = _get_unfold_indices(input.shape, dimension, size, step) - indices = mindspore.Tensor(_indices).astype(mindspore.int32) - output = ops.gather(input, indices, axis=_dimension) - output = ops.moveaxis(output, _dimension + 1, -1) - + indices = core.tensor(_indices, device=input.device) + output = execute('gather', input, indices, _dimension) + output = core.swapaxes(output, _dimension + 1, -1) return output -def cartesian_prod(*tensors): - """ - 手动实现 torch.cartesian_prod - :param tensors: 一个或多个一维张量 - :return: 笛卡尔积结果的二维张量 (每行一个组合) - """ - # 生成网格坐标 - grids = core.meshgrid(*tensors, indexing='ij') - - # 展平每个网格张量并堆叠 - return core.stack([g.reshape(-1) for g in grids], dim=1) - - -def detach(input): - return ops.stop_gradient(input) +def contiguous(input): + return execute('contiguous', input) -def cosine_similarity(*args, **kwargs): - return core.nn.functional.cosine_similarity(*args, **kwargs) +def dyn_shape(input): + return execute('dyn_shape', input) __all__ = [ "bincount", "broadcast_shapes", "broadcast_tensors", "broadcast_to", - "bucketize", - "cartesian_prod", "cdist", "clone", "contains", "cumsum", - "cumprod", "diag", - "diff", - "dim_list_to_bitset", + "diagonal", "einsum", - "einsum_label_to_index", "finfo", "flatten", "flip", "iinfo", - "initialize", - "manual_expand", "masked_fill", - "maybe_wrap_dim", "meshgrid", "repeat_interleave", "roll", "searchsorted", "stop_gradient", - "sumproduct_pair", "tril", "triu", "unflatten", "unfold", - "histc", - "view_as_complex", - "view_as_real", - "detach", - "cosine_similarity" + "contiguous", + "ravel", + "dyn_shape", + "diff" ] diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index f053ce18f..9dfb61b53 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -1,248 +1,162 @@ """pointwise op""" -import numpy as np -import mindspore -from mindspore import ops -from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI -from ._inner import call_ms_func +import math +import numbers from mindnlp import core - -# abs -has_abs = hasattr(mindspore.mint, "abs") +from mindnlp.core.executor import execute -def abs(input, *, out=None): - if use_pyboost() and has_abs: - return call_ms_func(mindspore.mint.abs, input, out=out) - return call_ms_func(ops.abs, input, out=out) +# abs +def abs(input): + return execute("abs", input) # absolute -def absolute(input, *, out=None): - return abs(input, out=out) +def absolute(input): + return abs(input) # acos -has_acos = hasattr(mindspore.mint, "acos") - - -def acos(input, *, out=None): - if use_pyboost() and has_acos: - return call_ms_func(mindspore.mint.acos, input, out=out) - return call_ms_func(ops.acos, input, out=out) +def acos(input): + return execute("acos", input) # arccos -def arrcos(input, out=None): - return acos(input, out=out) +def arrcos(input): + return acos(input) # acosh -has_acosh = hasattr(mindspore.mint, "acosh") - - -def acosh(input, *, out=None): - if use_pyboost and has_acosh: - return call_ms_func(mindspore.mint.acosh, input, out=out) - return call_ms_func(ops.acosh, input, out=out) +def acosh(input): + return 
execute("acosh_ext", input) # arccosh -has_arccosh = hasattr(mindspore.mint, "arccosh") - - def arccosh(input): return acosh(input) # add -has_add = hasattr(mindspore.mint, "add") - - -def add(input, other, *, alpha=1, out=None): - if use_pyboost() and has_add and not ON_ORANGE_PI: - return call_ms_func(mindspore.mint.add, input, other, alpha=alpha, out=out) +def add(input, other, *, alpha=1): if alpha != 1: - other = mul(alpha, other) - if input.dtype == mindspore.bool_: - return ops.add(input.int(), other.int()).bool() - return call_ms_func(ops.add, input, other, out=out) + return execute("add_ext", input, other, alpha) + return execute('add', input, other) # addcdiv def addcdiv(input, tensor1, tensor2, *, value=1): - return ops.addcdiv(input, tensor1, tensor2, value) + return execute("addcdiv", input, tensor1, tensor2, value) # addcmul def addcmul(input, tensor1, tensor2, *, value=1): - return ops.addcmul(input, tensor1, tensor2, value) + return execute("addcmul", input, tensor1, tensor2, value) # angle def angle(input): - return ops.angle(input) + return execute("angle", input) # asin -has_asin = hasattr(mindspore.mint, "asin") - - -def asin(input, *, out=None): - if use_pyboost and has_asin: - return call_ms_func(mindspore.mint.asin, input, out=out) - return call_ms_func(ops.asin, input, out=out) +def asin(input): + return execute("asin_ext", input) # arcsin -has_arcsin = hasattr(mindspore.mint, "arcsin") - - -def arcsin(input, *, out=None): - return asin(input, out=out) +def arcsin(input): + return asin(input) # asinh -has_asinh = hasattr(mindspore.mint, "asinh") - - -def asinh(input, *, out=None): - if use_pyboost and has_asinh: - return call_ms_func(mindspore.mint.asinh, input, out=out) - return call_ms_func(ops.asinh, input, out=out) +def asinh(input): + return execute("asinh_ext", input) # arcsinh -has_arcsinh = hasattr(mindspore.mint, "arcsinh") - - -def arcsinh(input, *, out=None): - return asinh(input, out=out) +def arcsinh(input): + return asinh(input) # atan -has_atan = hasattr(mindspore.mint, "atan") - - -def atan(input, *, out=None): - if use_pyboost and has_atan: - return call_ms_func(mindspore.mint.atan, input, out=out) - return call_ms_func(ops.atan, input, out=out) +def atan(input): + return execute("atan_ext", input) # arctan -has_arctan = hasattr(mindspore.mint, "arctan") - - -def arctan(input, *, out=None): - return atan(input, out=out) +def arctan(input): + return atan(input) # atanh -has_atanh = hasattr(mindspore.mint, "atanh") - - -def atanh(input, *, out=None): - if use_pyboost and has_atanh: - return call_ms_func(mindspore.mint.atanh, input, out=out) - return call_ms_func(ops.atanh, input, out=out) +def atanh(input): + return execute("atanh", input) # arctanh -has_arctanh = hasattr(mindspore.mint, "arctanh") - - -def arctanh(input, *, out=None): - return atanh(input, out=out) +def arctanh(input): + return atanh(input) # atan2 -has_atan2 = hasattr(mindspore.mint, "atan2") - - -def atan2(input, other, *, out=None): - if use_pyboost() and has_atan2: - return call_ms_func(mindspore.mint.atan2, input, other, out=out) - return call_ms_func(ops.atan2, input, other, out=out) +def atan2(input, other): + return execute("atan2_ext", input, other) # arctan2 -has_arctan2 = hasattr(mindspore.mint, "arctan2") - - -def arctan2(input, other, out=None): - return atan2(input, other, out=out) +def arctan2(input, other): + return atan2(input, other) # bitwise_not +def bitwise_not(input, *, out=None): + output = execute("bitwise_not", input) + if out is None: + return output + 
out.data = output + return out # bitwise_and -has_bitwise_and = hasattr(mindspore.mint, "bitwise_and") - - -def bitwise_and(input, other, *, out=None): - if use_pyboost() and has_bitwise_and: - return call_ms_func(mindspore.mint.bitwise_and, input, other, out=out) - return call_ms_func(ops.bitwise_and, input, other, out=out) +def bitwise_and(input, other): + if not isinstance(other, numbers.Number): + return execute("bitwise_and_tensor", input, other) + return execute("bitwise_and_scalar", input, other) # bitwise_or -has_bitwise_or = hasattr(mindspore.mint, "bitwise_or") - - -def bitwise_or(input, other, *, out=None): - if use_pyboost() and has_bitwise_or: - return call_ms_func(mindspore.mint.bitwise_or, input, other, out=out) - return call_ms_func(ops.bitwise_or, input, other, out=out) +def bitwise_or(input, other): + if not isinstance(other, numbers.Number): + return execute("bitwise_or_tensor", input, other) + return execute("bitwise_or_scalar", input, other) # bitwise_xor -has_bitwise_xor = hasattr(mindspore.mint, "bitwise_xor") - - -def bitwise_xor(input, other, *, out=None): - if use_pyboost() and has_bitwise_xor: - return call_ms_func(mindspore.mint.bitwise_xor, input, other, out=out) - return call_ms_func(ops.bitwise_xor, input, other, out=out) +def bitwise_xor(input, other): + if not isinstance(other, numbers.Number): + return execute("bitwise_xor_tensor", input, other) + return execute("bitwise_xor_scalar", input, other) # bitwise_left_shift -def bitwise_left_shift(input, other): - return ops.bitwise_left_shift(input, other) # bitwise_right_shift def bitwise_right_shift(input, other): - return ops.bitwise_right_shift(input, other) - + return execute('right_shift', input, other) # ceil -has_ceil = hasattr(mindspore.mint, "ceil") - - -def ceil(input, *, out=None): - if use_pyboost() and has_ceil: - return call_ms_func(mindspore.mint.ceil, input, out=out) - return call_ms_func(ops.ceil, input, out=out) +def ceil(input): + return execute("ceil", input) # clamp -has_clamp = hasattr(mindspore.mint, "clamp") - - -def clamp(input, min=None, max=None, *, out=None): - if use_pyboost() and has_clamp: - return call_ms_func(mindspore.mint.clamp, input, min, max, out=out) - return call_ms_func(ops.clamp, input, min, max, out=out) - +def clamp(input, min=None, max=None): + if isinstance(min, numbers.Number) or isinstance(max, numbers.Number): + return execute("clamp_scalar", input, min, max) + return execute("clamp_tensor", input, min, max) -def clamp_min(input, min): - return clamp(input, min) # clip -has_clip = hasattr(mindspore.mint, "clip") - - def clip(input, min=None, max=None): return clamp(input, min, max) @@ -254,181 +168,78 @@ def clip(input, min=None, max=None): # cos -has_cos = hasattr(mindspore.mint, "cos") - - -def cos(input, *, out=None): - if use_pyboost() and has_cos: - return call_ms_func(mindspore.mint.cos, input, out=out) - return call_ms_func(ops.cos, input, out=out) +def cos(input): + return execute("cos", input) # cosh -has_cosh = hasattr(mindspore.mint, "cosh") - - -def cosh(input, *, out=None): - if use_pyboost() and has_cosh: - return call_ms_func(mindspore.mint.cosh, input, out=out) - return call_ms_func(ops.cosh, input, out=out) +def cosh(input): + return execute("cosh", input) # deg2rad def deg2rad(input): - return ops.deg2rad(input) + return input * math.pi / 180.0 # div -has_div = hasattr(mindspore.mint, "div") - - -def div(input, other, *, rounding_mode=None, out=None): - if isinstance(other, mindspore.Tensor): - other = other.to(input.dtype) - - if 
isinstance(other, np.number):
-        other = other.item()
-
-    if use_pyboost() and has_div:
-        return call_ms_func(
-            mindspore.mint.div, input, other, rounding_mode=rounding_mode, out=out
+def div(input, other, *, rounding_mode=None):
+    if rounding_mode is not None and rounding_mode not in ["floor", "trunc"]:
+        raise ValueError(
+            "For ops.div, rounding_mode value should be None, 'floor' or 'trunc'."
+        )
+    if rounding_mode:
+        output = execute(
+            "divmod",
+            input,
+            other,
+            rounding_mode
         )
-    return call_ms_func(ops.div, input, other, rounding_mode=rounding_mode, out=out)
+    else:
+        output = execute("div", input, other)
+    return output
 
 
 # divide
-has_divide = hasattr(mindspore.mint, "divide")
-
-
-def divide(input, other, rounding_mode=None):
-    return div(input, other, rounding_mode=rounding_mode)
+def divide(input, other):
+    return div(input, other)
 
 
 # digamma
-def digamma(input):
-    return ops.digamma(input)
 
 
 # erf
-has_erf = hasattr(mindspore.mint, "erf")
-
-
-def erf(input, *, out=None):
-    if use_pyboost() and has_erf:
-        return call_ms_func(mindspore.mint.erf, input, out=out)
-    return call_ms_func(ops.erf, input, out=out)
+def erf(input):
+    return execute("erf", input)
 
 
 # erfc
-has_erfc = hasattr(mindspore.mint, "erfc")
-
-
-def erfc(input, *, out=None):
-    if use_pyboost() and has_erfc:
-        return call_ms_func(mindspore.mint.erfc, input, out=out)
-    return call_ms_func(ops.erfc, input, out=out)
+def erfc(input):
+    return execute("erfc", input)
 
 
 # erfinv
-has_erfinv = hasattr(mindspore.mint, "erfinv")
-
-
-def erfinv(input, *, out=None):
-    if ON_ORANGE_PI:
-        return erfinv_torch(input)
-    if use_pyboost() and has_erfinv:
-        return call_ms_func(mindspore.mint.erfinv, input, out=out)
-    return call_ms_func(ops.erfinv, input, out=out)
-
-def erfinv_torch(x):
-    """
-    使用有理函数近似实现erfinv,适用于PyTorch张量
-    """
-    # # 检查输入范围
-    # if core.any((x < -1) | (x > 1)):
-    #     raise ValueError("erfinv(x) is only defined for x in [-1, 1]")
-
-    # 处理边界情况
-    sign = core.where(x > 0, 1.0, -1.0)
-    x = core.abs(x)
-
-    # Cody的有理函数近似
-    mask = x <= 0.7
-    x_sq = x * x
-
-    # 对于x <= 0.7的情况
-    p1 = 0.426170613044 + x_sq * (-0.304570194263 + x_sq * 0.152645863430)
-    q1 = 1.0 + x_sq * (-0.733058978416 + x_sq * 0.546875000000)
-    result1 = x * (p1 / q1)
-
-    # 对于x > 0.7的情况
-    t = core.sqrt(-core.log((1.0 - x)/2.0))
-    p2 = -0.322232431088 + t * (-1.00002368368 + t * (-0.342242088547 +
-        t * (-0.0204231210245 + t * (-0.0000453642210148))))
-    q2 = 0.460398842078 + t * (0.588581570495 + t * (0.531103462366 +
-        t * (0.103537752850 + t * 0.0038560700634)))
-    result2 = p2 / q2
-
-    # 合并结果
-    result = core.where(mask, result1, result2)
-
-    return sign * result
-
-# exp
-has_exp = hasattr(mindspore.mint, "exp")
-has_inplace_exp = hasattr(mindspore.Tensor, "exp_")
+def erfinv(input):
+    return execute("erfinv", input)
 
+# exp
 def exp(input, out=None):
-    if has_inplace_exp:
-        return inplace_exp(input, out)
-
-    if use_pyboost() and has_exp:
-        output = mindspore.mint.exp(input)
-    else:
-        output = ops.exp(input)
+    output = execute("exp", input)
     if out is not None:
-        # out.data = output
-        out.assign_value(output)
+        out.data = output
+        return out
     else:
         return output
 
-def inplace_exp(input, out=None):
-    if out is None:
-        if use_pyboost() and has_exp:
-            output = mindspore.mint.exp(input)
-        else:
-            output = ops.exp(input)
-        return output
-
-    if out is input:
-        return out.exp_()
-    else:
-        out.copy_(input)
-        return out.exp_()
-
-
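+# Example (illustrative only): with the explicit `return out` above, both call
+# styles now yield the result tensor, so `y = exp(x, out=buf)` fills the
+# preallocated `buf` and no longer leaves `y` as None.
+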
 # exp2
-has_exp2 = hasattr(mindspore.mint, "exp2")
-
-
 def exp2(input):
-    if use_pyboost() and has_exp2:
-        return mindspore.mint.exp2(input)
-    return pow(2, input)
+    return execute("exp2", input)
 
 
 # expm1
-has_expm1 = hasattr(mindspore.mint, "expm1")
-
-
-def expm1(input, *, out=None):
-    if input.dtype == mindspore.float64:
-        return expm1(input.float(), out=out).double()
-    if use_pyboost() and has_expm1:
-        return call_ms_func(mindspore.mint.expm1, input, out=out)
-    return call_ms_func(ops.expm1, input, out=out)
+def expm1(input):
+    return execute("expm1", input)
 
 
 # fake_quantize_per_channel_affine
@@ -441,206 +252,129 @@ def expm1(input, *, out=None):
 
 
 # float_power
-has_float_power = hasattr(mindspore.mint, "float_power")
-
-
 def float_power(input, exponent):
-    if use_pyboost() and has_float_power:
-        return mindspore.mint.float_power(input, exponent)
-    return ops.float_power(input, exponent)
+    if isinstance(input, core.Tensor) and isinstance(exponent, numbers.Number):
+        return execute("pow_tensor_scalar", input, exponent)
+    if isinstance(input, numbers.Number) and isinstance(exponent, core.Tensor):
+        return execute("pow_scalar_tensor", input, exponent)
 
-
-# floor
-has_floor = hasattr(mindspore.mint, "floor")
+    return pow(input, exponent)
 
 
-def floor(input, *, out=None):
-    if use_pyboost() and has_floor:
-        return call_ms_func(mindspore.mint.floor, input, out=out)
-    return call_ms_func(ops.floor, input, out=out)
+# floor
+def floor(input):
+    return execute("floor", input)
 
 
 # floor_divide
 def floor_divide(input, other):
-    return ops.floor_divide(input, other)
+    return execute("floor_div", input, other)
 
 
 # fmod
-has_fmod = hasattr(mindspore.mint, "fmod")
-
-
 def fmod(input, other):
-    if use_pyboost() and has_fmod:
-        return mindspore.mint.fmod(input, other)
-    return ops.fmod(input, other)
+    if isinstance(input, core.Tensor) and isinstance(other, numbers.Number):
+        return execute("fmod_scalar", input, other)
+    if isinstance(input, numbers.Number) and isinstance(other, core.Tensor):
+        # scalar dividend: lift it to a tensor so the tensor-tensor kernel applies
+        input = core.tensor(input, dtype=other.dtype, device=other.device)
+    return execute("fmod_tensor", input, other)
 
 
 # frac
-has_frac = hasattr(mindspore.mint, "frac")
-
-
 def frac(input):
-    if use_pyboost() and has_frac:
-        return mindspore.mint.frac(input)
-    return fmod(input, 1)
+    return execute("frac", input)
 
 
 # frexp
 
 
 # imag
-def imag(input):
-    return ops.imag(input)
 
 
 # ldexp
 def ldexp(input, other, out=None):
-    output = ops.ldexp(input, other)
-    if out is not None:
-        out.data = output
-        return out
-    return output
+    # ldexp(input, other) computes input * 2**other
+    output = execute('ldexp', input, other)
    if out is None:
        return output
    out.copy_(output)
    return out
 
 
 # lerp
-has_lerp = hasattr(mindspore.mint, "lerp")
-
-
 def lerp(input, end, weight):
-    if use_pyboost() and has_lerp:
-        return mindspore.mint.lerp(input, end, weight)
-    return ops.lerp(input, end, weight)
+    return execute("lerp", input, end, weight)
 
 
 # lgamma
-def lgamma(input):
-    return ops.lgamma(input)
 
 
 # log
-has_log = hasattr(mindspore.mint, "log")
-
-
-def log(input, *, out=None):
-    if use_pyboost() and has_log:
-        return call_ms_func(mindspore.mint.log, input, out=out)
-    return call_ms_func(ops.log, input, out=out)
+def log(input):
+    return execute("log", input)
 
 
 # log10
 
-# log1p
-has_log1p = hasattr(mindspore.mint, "log1p")
-
-def log1p(input, *, out=None):
-    if use_pyboost() and has_log1p:
-        return call_ms_func(mindspore.mint.log1p, input, out=out)
-    return call_ms_func(ops.log1p, input, out=out)
+# log1p
+def log1p(input):
+    return execute("log1p", input)
 
 
 # log2
-has_log2 = hasattr(mindspore.mint, "log2")
-
-
 def log2(input):
-    if use_pyboost() and has_log2:
-        return mindspore.mint.log2(input)
-    return ops.log2(input)
+    return 
execute("log2", input) # logaddexp +def logaddexp(input, other): + return execute("logaddexp", input, other) # logaddexp2 # logical_and -has_logical_and = hasattr(mindspore.mint, "logical_and") - - -def logical_and(input, other, *, out=None): - if use_pyboost() and has_logical_and: - return call_ms_func(mindspore.mint.logical_and, input, other, out=out) - return call_ms_func(ops.logical_and, input, other, out=out) +def logical_and(input, other): + return execute("logical_and", input, other) # logical_not -has_logical_not = hasattr(mindspore.mint, "logical_not") - - -def logical_not(input, *, out=None): - if use_pyboost() and has_logical_not: - return call_ms_func(mindspore.mint.logical_not, input, out=out) - return call_ms_func(ops.logical_not, input, out=out) +def logical_not(input): + return execute("logical_not", input) # logical_or -has_logical_or = hasattr(mindspore.mint, "logical_or") - - -def logical_or(input, other, *, out=None): - if use_pyboost() and has_logical_or: - return call_ms_func(mindspore.mint.logical_or, input, other, out=out) - return call_ms_func(ops.logical_or, input, other, out=out) +def logical_or(input, other): + return execute("logical_or", input, other) # logical_xor -has_logical_xor = hasattr(mindspore.mint, "logical_xor") - - -def logical_xor(input, other, *, out=None): - if use_pyboost() and has_logical_xor: - return call_ms_func(mindspore.mint.logical_xor, input, other, out=out) - return call_ms_func(ops.logical_xor, input, other, out=out) +def logical_xor(input, other): + return execute("logical_xor", input, other) # logit -def logit(input, eps=None): - return ops.logit(input, eps) # hypot -def hypot(input, other): - return ops.hypot(input, other) # i0 # igamma -def igamma(input, other): - return ops.igamma(input, other) # igammac -def igammac(input, other): - return ops.igammac(input, other) # mul -has_mul = hasattr(mindspore.mint, "mul") - +def mul(input, other): + # if isinstance(other, (float, int, bool)) and isinstance(input, torch.Tensor): + # return execute("muls", input, other) + return execute("mul", input, other) -def mul(input, other, *, out=None): - if use_pyboost() and has_mul and not ON_ORANGE_PI: - out = mindspore.mint.mul(input, other) - else: - if input.dtype == mindspore.bool_: - if isinstance(other, bool): - if ON_ORANGE_PI: - out = ops.bitwise_and(input.int(), other).bool() - else: - out = ops.bitwise_and(input, other) - else: - out = ops.mul(input.int(), other) - else: - out = ops.mul(input, other) - return out - - if isinstance(other, mindspore.Tensor): - out_dtype = min(input.dtype, other.dtype) - return out.to(out_dtype) - return out # multiply def multiply(input, other): @@ -648,87 +382,29 @@ def multiply(input, other): # mvlgamma -def mvlgamma(input, p): - return ops.mvlgamma(input, p) # nan_to_num -has_nan_to_num = hasattr(mindspore.mint, "nan_to_num") - - -def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): - if use_pyboost() and has_nan_to_num and not ON_A1: - return call_ms_func( - mindspore.mint.nan_to_num, input, nan, posinf, neginf, out=out - ) - - # 创建输入张量的副本 - output = input.clone() - # 获取数据类型信息 - if output.is_floating_point(): - dtype = output.dtype - # 获取默认替换值 - f_info = core.finfo(dtype) - default_posinf = f_info.max if posinf is None else posinf - default_neginf = f_info.min if neginf is None else neginf - else: - # 对于整数类型,使用给定值或默认值 - default_posinf = core.iinfo(dtype).max if posinf is None else posinf - default_neginf = core.iinfo(dtype).min if neginf is None else neginf - - # 替换 NaN - if 
core.isnan(output).any(): - output = core.where( - core.isnan(output), - core.tensor(nan, dtype=output.dtype, device=output.device), - output, - ) - - # 替换正无穷大 - if core.isinf(output).any() and (posinf is not None or output.is_floating_point()): - output = core.where( - (output == float("inf")) & core.isinf(output), - core.tensor(default_posinf, dtype=output.dtype, device=output.device), - output, - ) - - # 替换负无穷大 - if core.isinf(output).any() and (neginf is not None or output.is_floating_point()): - output = core.where( - (output == float("-inf")) & core.isinf(output), - core.tensor(default_neginf, dtype=output.dtype, device=output.device), - output, - ) - - return output +def nan_to_num(input, nan=0.0, posinf=None, neginf=None): + return execute("nan_to_num", input, nan, posinf, neginf) # neg -has_neg = hasattr(mindspore.mint, "neg") - - -def neg(input, *, out=None): - if use_pyboost() and has_neg: - return call_ms_func(mindspore.mint.neg, input, out=out) - return call_ms_func(ops.neg, input, out=out) +def neg(input): + return execute("neg", input) # negative -has_negative = hasattr(mindspore.mint, "negative") - - def negative(input): return neg(input) # nextafter def nextafter(input, other): - return ops.nextafter(input, other) + return execute("next_after", input, other) # polygamma -def polygamma(n, input): - return ops.polygamma(n, input) # positive @@ -737,13 +413,12 @@ def positive(input): # pow -has_pow = hasattr(mindspore.mint, "pow") - - -def pow(input, exponent, *, out=None): - if use_pyboost() and has_pow: - return call_ms_func(mindspore.mint.pow, input, exponent, out=out) - return call_ms_func(ops.pow, input, exponent, out=out) +def pow(input, exponent): + if isinstance(input, core.Tensor) and isinstance(exponent, numbers.Number): + return execute("pow_tensor_scalar", input, exponent) + if isinstance(input, numbers.Number) and isinstance(exponent, core.Tensor): + return execute("pow_scalar_tensor", input, exponent) + return execute("pow", input, exponent) # quantized_batch_norm @@ -756,175 +431,112 @@ def pow(input, exponent, *, out=None): # rad2deg -def rad2deg(input): - return ops.rad2deg(input) # real -def real(input): - return ops.real(input) # reciprocal -has_reciprocal = hasattr(mindspore.mint, "reciprocal") - - -def reciprocal(input, *, out=None): - if use_pyboost() and has_reciprocal: - return call_ms_func(mindspore.mint.reciprocal, input, out=out) - return call_ms_func(ops.reciprocal, input, out=out) +def reciprocal(input): + return execute("reciprocal", input) # remainder -has_remainder = hasattr(mindspore.mint, "remainder") - - -def remainder(input, other, *, out=None): - if use_pyboost() and has_remainder: - return call_ms_func(mindspore.mint.remainder, input, other, out=out) - return call_ms_func(ops.remainder, input, other, out=out) +def remainder(input, other): + if isinstance(input, core.Tensor) and isinstance(other, numbers.Number): + return execute("remainder_tensor_scalar", input, other) + if isinstance(input, numbers.Number) and isinstance(other, core.Tensor): + return execute("remainder_scalar_tensor", input, other) + return execute("remainder_tensor_tensor", input, other) # round -has_round = hasattr(mindspore.mint, "round") - - -def round(input, *, decimals=0): - if use_pyboost() and has_round: - return mindspore.mint.round(input, decimals=decimals) - return ops.round(input, decimals=decimals) +def round(input): + return execute("round", input) # rsqrt -has_rsqrt = hasattr(mindspore.mint, "rsqrt") - - -def rsqrt(input, *, out=None): - if use_pyboost() 
 and has_rsqrt:
-        return call_ms_func(mindspore.mint.rsqrt, input, out=out)
-    return call_ms_func(ops.rsqrt, input, out=out)
+def rsqrt(input):
+    return execute("rsqrt", input)

 # sigmoid
-has_sigmoid = hasattr(mindspore.mint, "sigmoid")
-
-
-def sigmoid(input, *, out=None):
-    if use_pyboost() and has_sigmoid:
-        return call_ms_func(mindspore.mint.sigmoid, input, out=out)
-    return call_ms_func(ops.sigmoid, input, out=out)
+def sigmoid(input):
+    return execute("sigmoid", input)

 # sign
-has_sign = hasattr(mindspore.mint, "sign")
-
-
-def sign(input, *, out=None):
-    if use_pyboost() and has_sign:
-        return call_ms_func(mindspore.mint.sign, input, out=out)
-    return call_ms_func(ops.sign, input, out=out)
+def sign(input):
+    return execute("sign", input)

 # sgn

 # signbit

-# sin
-has_sin = hasattr(mindspore.mint, "sin")
-
-def sin(input, *, out=None):
-    if use_pyboost() and has_sin:
-        return call_ms_func(mindspore.mint.sin, input, out=out)
-    return call_ms_func(ops.sin, input, out=out)
+# sin
+def sin(input):
+    return execute("sin", input)

 # sinc
-has_sinc = hasattr(mindspore.mint, "sinc")
-
-
-def sinc(input, *, out=None):
-    if use_pyboost() and has_sinc:
-        return call_ms_func(mindspore.mint.sinc, input, out=out)
-    return call_ms_func(ops.sinc, input, out=out)
+def sinc(input):
+    return execute("sinc", input)

 # sinh
-has_sinh = hasattr(mindspore.mint, "sinh")
-
-
-def sinh(input, *, out=None):
-    if use_pyboost() and has_sinh:
-        return call_ms_func(mindspore.mint.sinh, input, out=out)
-    return call_ms_func(ops.sinh, input, out=out)
+def sinh(input):
+    return execute("sinh", input)

 # softmax
-def softmax(input, dim, *, dtype=None):
-    if use_pyboost():
-        return mindspore.mint.nn.functional.softmax(input, dim, dtype=dtype)
-    return ops.softmax(input, dim, dtype=dtype)
-
-
-def log_softmax(input, dim=None, dtype=None):
-    return core.nn.functional.log_softmax(input, dim, dtype)
+def softmax(input, dim=-1, *, dtype=None):
+    # cast first so the kernel computes in the requested dtype, matching torch semantics
+    if dtype is not None:
+        input = input.to(dtype)
+    return execute("softmax", input, dim)

 # sqrt
-has_sqrt = hasattr(mindspore.mint, "sqrt")
-
-
-def sqrt(input, *, out=None):
-    if use_pyboost() and has_sqrt:
-        return call_ms_func(mindspore.mint.sqrt, input, out=out)
-    return call_ms_func(ops.sqrt, input, out=out)
+def sqrt(input):
+    return execute("sqrt", input)

 # square
-has_square = hasattr(mindspore.mint, "square")
-
-
-def square(input, *, out=None):
-    if use_pyboost() and has_square:
-        return call_ms_func(mindspore.mint.square, input, out=out)
-    return call_ms_func(ops.square, input, out=out)
+def square(input):
+    return execute("square", input)

 # sub
-has_sub = hasattr(mindspore.mint, "sub")
-
-
 def sub(input, other, *, alpha=1, out=None):
-    if isinstance(other, mindspore.Tensor):
-        other = other.to(input.dtype)
-    if use_pyboost() and has_sub:
-        return call_ms_func(mindspore.mint.sub, input, other, alpha=alpha, out=out)
-    return call_ms_func(ops.sub, input, other, out=out)
-
+    if not isinstance(input, numbers.Number) and not isinstance(other, numbers.Number):
+        device = max(input.device, other.device)
+        input = input.to(device)
+        other = other.to(device)
+    elif isinstance(input, numbers.Number):
+        device = other.device
+    else:
+        device = input.device
+    if device.type == 'cpu':
+        output = execute("sub", input, alpha * other)
+    else:
+        output = execute("sub_ext", input, other, alpha)
+    if out is None:
+        return output
+    out.copy_(output)
+    return out

 # subtract
-def subtract(input, other):
-    return sub(input, other)
+def subtract(input, other, *, alpha=1, out=None):
+    return sub(input, other, alpha=alpha, out=out)

 # tan
-has_tan = hasattr(mindspore.mint, "tan")
-
-
-def tan(input, *, out=None):
-    if use_pyboost() and has_tan:
-        return call_ms_func(mindspore.mint.tan, input, out=out)
-    return call_ms_func(ops.tan, input, out=out)
+def tan(input):
+    return execute("tan", input)

 # tanh
-has_tanh = hasattr(mindspore.mint, "tanh")
-
-
-def tanh(input, *, out=None):
-    if use_pyboost() and has_tanh:
-        return call_ms_func(mindspore.mint.tanh, input, out=out)
-    return call_ms_func(ops.tanh, input, out=out)
+def tanh(input):
+    return execute("tanh", input)


 # true_divide
@@ -933,30 +545,29 @@ def true_divide(input, other):


 # trunc
-has_trunc = hasattr(mindspore.mint, "trunc")
-
-
-def trunc(input, *, out=None):
-    if use_pyboost() and has_trunc:
-        return call_ms_func(mindspore.mint.trunc, input, out=out)
-    return call_ms_func(ops.trunc, input, out=out)
+def trunc(input):
+    return execute("trunc", input)


 # xlogy
-has_xlogy = hasattr(mindspore.mint, "xlogy")
-
-
-def xlogy(input, other, *, out=None):
-    if use_pyboost() and has_xlogy:
-        return call_ms_func(mindspore.mint.xlogy, input, other, out=out)
-    return call_ms_func(ops.xlogy, input, other, out=out)
-
+def xlogy(input, other):
+    if isinstance(input, core.Tensor) and isinstance(other, core.Tensor):
+        return execute("xlogy", input, other)
+    if isinstance(input, core.Tensor) and isinstance(other, (float, int, bool)):
+        return execute("xlogy_scalar_other", input, other)
+    if isinstance(input, (float, int, bool)) and isinstance(other, core.Tensor):
+        return execute("xlogy_scalar_self", input, other)
+    raise TypeError("For 'xlogy', at least one of input and other should be Tensor.")

 # relu
 def relu(input):
-    if use_pyboost():
-        return mindspore.mint.nn.functional.relu(input)
-    return ops.relu(input)
+    return execute('relu', input)
+
+
+def log_softmax(input, dim=None, dtype=None):
+    if input.device.type == 'cpu':
+        return execute('log_softmax', input, dim)
+    return execute('log_softmax_ext', input, dim, dtype)


 __all__ = [
@@ -980,19 +591,17 @@ def relu(input):
     "atan",
     "atan2",
     "atanh",
+    "bitwise_not",
     "bitwise_and",
-    "bitwise_left_shift",
     "bitwise_or",
-    "bitwise_right_shift",
     "bitwise_xor",
+    "bitwise_right_shift",
     "ceil",
     "clamp",
-    "clamp_min",
     "clip",
     "cos",
     "cosh",
     "deg2rad",
-    "digamma",
     "div",
     "divide",
     "erf",
@@ -1006,13 +615,8 @@ def relu(input):
     "floor_divide",
     "fmod",
     "frac",
-    "hypot",
-    "igamma",
-    "igammac",
-    "imag",
     "ldexp",
     "lerp",
-    "lgamma",
     "log",
     "log1p",
     "log2",
@@ -1020,20 +624,14 @@ def relu(input):
     "logical_not",
     "logical_or",
     "logical_xor",
-    "logit",
-    "log_softmax",
     "mul",
     "multiply",
-    "mvlgamma",
    "nan_to_num",
     "neg",
     "negative",
     "nextafter",
-    "polygamma",
     "positive",
     "pow",
-    "rad2deg",
-    "real",
     "reciprocal",
     "remainder",
     "round",
@@ -1054,4 +652,5 @@ def relu(input):
     "trunc",
     "xlogy",
     "relu",
+    "log_softmax",
 ]
diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py
index 2e8d993bb..2bd3ed578 100644
--- a/mindnlp/core/ops/random.py
+++ b/mindnlp/core/ops/random.py
@@ -1,166 +1,367 @@
 """random op"""
-import numpy as np
-import mindspore
-from mindspore import ops
-from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, DEVICE_TARGET, ON_A1
-from .other import cumsum, searchsorted
-from .comparison import topk
-from .pointwise import div, log
-from .._bind import get_default_dtype
-from ._inner import call_ms_func
-from .._C import default_generator
+from mindnlp import core
+from mindnlp.core._C import default_generator
+from mindnlp.core.executor import execute
+from .._bind import get_default_dtype, get_device_in_context
+from ..configs import ON_A1
+# NOTE: assumed import path for dtype_to_type_id, which randint_like/randn_like/randperm
+# below rely on; adjust if it lives elsewhere in this MindSpore version.
+from mindspore.ops.auto_generate.gen_arg_handler import dtype_to_type_id
+
+generator_step_ = 12
+

 # bernoulli
-has_bernoulli = hasattr(mindspore.mint, 'bernoulli')
-def bernoulli(input, *, generator=None, out=None, **kwargs):
-    p = kwargs.pop('p', 0.5)
-    if use_pyboost() and has_bernoulli:
-        return call_ms_func(mindspore.mint.bernoulli, input, generator=generator, out=out)
-    random_numbers = rand(*input.shape, dtype=mindspore.float32)
-    samples = random_numbers < p
-    samples = samples.int()
+def bernoulli(input, *, generator=None, out=None):
+    if generator is None:
+        generator = default_generator
+    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
+    output = execute("bernoulli_ext", input, seed, offset)
     if out is None:
-        return samples
-    else:
-        return out.copy_(samples)
+        return output
+    out.data = output
+    return out
+

 # multinomial
-has_multinomial = hasattr(mindspore.mint, 'multinomial')
-def multinomial(input, num_samples, replacement=False, *, generator=None):
+def multinomial(input, num_samples, replacement=False, *, generator=None, out=None):
     """custom multinomial"""
-    if use_pyboost() and has_multinomial and not ON_A1:
-        return mindspore.mint.multinomial(input, num_samples, replacement=replacement, generator=generator)
-    if replacement:
-        # with replacement
-        cumulative_probs = cumsum(input, dim=-1)
-        uniform_samples = rand(*input.shape[:-1] + (num_samples,))
-        if cumulative_probs.dtype == mindspore.float16:
-            cumulative_probs = cumulative_probs.astype(mindspore.float32)
-        samples = searchsorted(cumulative_probs, uniform_samples, right=True)
+    if generator is None:
+        generator = default_generator
+    if not ON_A1:
+        output = execute("multinomial_ext", input, num_samples, replacement, generator)
+
     else:
-        # without replacement
-        n_dist = 1
-        if input.ndim > 1:
-            n_dist = input.shape[-2]
-        random_uniform = rand(*(n_dist * input.shape[-1],))
-        if n_dist != 1:
-            random_uniform = random_uniform.reshape(n_dist, input.shape[-1])
-
-        vals = div(log(random_uniform), input + 1e-10)
-        _, samples = topk(vals, num_samples)
-
-    return samples.astype(mindspore.int64)
+        if replacement:
+            # with replacement
+            cumulative_probs = core.cumsum(input, dim=-1)
+            uniform_samples = rand(*input.shape[:-1] + (num_samples,), device=input.device)
+            if cumulative_probs.dtype == core.float16:
+                cumulative_probs = cumulative_probs.astype(core.float32)
+            samples = core.searchsorted(cumulative_probs, uniform_samples, right=True)
+        else:
+            # without replacement
+            n_dist = 1
+            if input.ndim > 1:
+                n_dist = input.shape[-2]
+            random_uniform = rand(*(n_dist * input.shape[-1],), device=input.device)
+            if n_dist != 1:
+                random_uniform = random_uniform.reshape(n_dist, input.shape[-1])
+
+            vals = core.div(core.log(random_uniform), input + 1e-10)
+            _, samples = core.topk(vals, num_samples)
+
+        output = samples.astype(core.int64)
+
+    if out is None:
+        return output
+    out.data = output
+    return out
+

 # normal
-has_normal = hasattr(mindspore.mint, 'normal')
-def normal(mean=0.0, std=1.0, size=None, *, generator=None, out=None):
-    if use_pyboost() and has_normal:
-        mean = float(mean) if isinstance(mean, int) else mean
-        mean = float(std) if isinstance(std, int) else mean
-        return call_ms_func(mindspore.mint.normal, mean, std, size, generator, out=out)
-    if size is None:
-        if isinstance(mean, mindspore.Tensor):
-            size = mean.shape
+def normal(mean=0.0, std=1.0, *, size=None, generator=None, out=None,
+           dtype=None, layout=None, device=None, pin_memory=None, requires_grad=False):
+    if generator is None:
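+        # no generator supplied: fall back to the process-wide default generator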
+        generator = default_generator
+    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
+    if device is None:
+        if out is None:
+            device = get_device_in_context()
         else:
-            size = ()
-    return call_ms_func(ops.normal, size, mean, std, out=out)
+            device = out.device
+
+    is_mean_tensor = isinstance(mean, core.Tensor)
+    is_std_tensor = isinstance(std, core.Tensor)
+
+    if device.type == 'cpu':
+        if is_mean_tensor and is_std_tensor:
+            size = (mean * std).shape
+        if is_mean_tensor and not is_std_tensor:
+            size = mean.shape
+        if not is_mean_tensor and is_std_tensor:
+            size = std.shape
+        if out is not None:
+            size = out.shape
+        # sample from N(0, 1), then shift and scale to N(mean, std)
+        output = execute('normal', size)
+        output = output * std + mean
+
+    else:
+        if is_mean_tensor and is_std_tensor:
+            output = execute("normal_tensor_tensor", mean, std, seed, offset, device=device)
+        if is_mean_tensor and not is_std_tensor:
+            output = execute("normal_tensor_float", mean, std, seed, offset, device=device)
+        if not is_mean_tensor and is_std_tensor:
+            output = execute("normal_float_tensor", mean, std, seed, offset, device=device)
+        # only fall back to the scalar kernel when neither mean nor std is a tensor
+        if not is_mean_tensor and not is_std_tensor:
+            if out is not None:
+                size = out.shape
+            output = execute("normal_float_float", float(mean), float(std), size, seed, offset, device=device)
+
+    if out is None:
+        return output
+    out.data = output
+    return out

 # poisson

 # rand
-has_rand = hasattr(mindspore.mint, 'rand')
-def rand(*size, generator=None, out=None, dtype=None, device=None, pin_memory=False, **kwargs):
-    size = kwargs.pop('size', size)
-    if size[0] == []:
-        size = ()
-    elif isinstance(size[0], (tuple, list)):
-        size = size[0]
+def rand(
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False
+):
+    if device is None:
+        device = get_device_in_context()
+    if isinstance(device, str):
+        device = core.device(device)
     if dtype is None:
         dtype = get_default_dtype()
-    if use_pyboost() and has_rand:
-        return call_ms_func(mindspore.mint.rand, *size, generator=generator, dtype=dtype, out=out)
-    return call_ms_func(ops.rand, *size, dtype=dtype, out=out)
+    if not generator:
+        generator = default_generator
+    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
+    if size and isinstance(size[0], (tuple, list)):
+        size = size[0]
+    if device.type == 'cpu':
+        output = execute('uniform_real', size,
+                         device=device, requires_grad=requires_grad, user_created=True).to(dtype)
+    else:
+        output = execute(
+            "rand_ext",
+            size,
+            seed,
+            offset,
+            dtype,
+            device=device,
+            requires_grad=requires_grad,
+            user_created=True,
+        )
+    if out is None:
+        return output
+    out.data = output
+    return out
+

 # rand_like
-has_rand_like = hasattr(mindspore.mint, 'rand_like')
-def rand_like(input, *, dtype=None):
-    if use_pyboost() and has_rand_like:
-        return mindspore.mint.rand_like(input, dtype=dtype)
-    return ops.rand_like(input, dtype=dtype)
+def rand_like(
+    input,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    memory_format=None
+):
+    if device is None:
+        device = input.device
+    if isinstance(device, str):
+        device = core.device(device)
+
+    if dtype is None:
+        dtype = input.dtype
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_
+    )
+    return execute(
+        "rand_like_ext",
+        input,
+        seed,
+        offset,
+        dtype,
+        device=device,
+        requires_grad=requires_grad,
+    )
+

 # randint
-has_randint = hasattr(mindspore.mint, 'randint')
-def randint(*args, **kwargs):
-    device = kwargs.pop('device', None)
-    low = kwargs.pop('low', None)
-    high = kwargs.pop('high', None)
-    size = kwargs.pop('size', None)
-    if low is not None:
-        args += (low,)
-    if high is not None:
-        args += (high,)
-
-    if size is not None:
-        args += (size,)
+def randint(
+    *args,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    **kwargs
+):
+    if dtype is None:
+        dtype = core.int64
+    if device is None:
+        device = get_device_in_context()
+    if isinstance(device, str):
+        device = core.device(device)
+
+    if not generator:
+        generator = default_generator
+    args = list(args)
+    if len(args) == 2:
+        # only (high, size) given: default the lower bound to 0
+        args = [0] + args
+    output = execute(
+        "randint",
+        *args,
+        dtype,
+        generator,
+        device=device,
+    )
+    if out is None:
+        return output
+    out.data = output
+    return out

-    if use_pyboost() and has_randint:
-        return mindspore.mint.randint(*args, **kwargs)
-    return ops.randint(*args, **kwargs)

 # randint_like
-def randint_like(*args, **kwargs):
-    if use_pyboost() and has_randint:
-        return mindspore.mint.randint_like(*args, **kwargs)
-    return ops.randint_like(*args, **kwargs)
+def randint_like(
+    input,
+    low,
+    high=0,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    memory_format=None
+):
+    if high == 0:
+        low, high = 0, low
+    if device is None:
+        device = input.device
+    if isinstance(device, str):
+        device = core.device(device)
+
+    if dtype is None:
+        dtype = input.dtype
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_
+    )
+    return execute(
+        "randint_like_ext",
+        input,
+        low,
+        high,
+        seed,
+        offset,
+        dtype_to_type_id("RandIntLike", "dtype", dtype),
+        device=device,
+        requires_grad=requires_grad,
+    )
+

 # randn
-has_randn = hasattr(mindspore.mint, 'randn')
-def randn(*size, generator=None, dtype=None, **kwargs):
-    if isinstance(size[0], tuple):
-        size = size[0]
-    size = kwargs.pop('size', size)
-    new_size = ()
-    for s in size:
-        if isinstance(s, np.integer):
-            s = s.item()
-        new_size += (s,)
+def randn(
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False
+):
+    if device is None:
+        device = get_device_in_context()
+    if isinstance(device, str):
+        device = core.device(device)
+
     if dtype is None:
         dtype = get_default_dtype()
-    if use_pyboost() and has_randn:
-        return mindspore.mint.randn(*new_size, generator=generator, dtype=dtype)
-    # return ops.randn(*new_size, dtype=dtype)
     if not generator:
         generator = default_generator
-    seed, _ = generator._step(12)
-    rng = np.random.default_rng(seed.item())
-    return mindspore.Tensor(rng.standard_normal(new_size), dtype=dtype)
+    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
+    if size and isinstance(size[0], (tuple, list)):
+        size = size[0]
+    output = execute(
+        "randn",
+        size,
+        seed,
+        offset,
+        dtype,
+        device=device,
+        requires_grad=requires_grad,
+        user_created=True,
+    )
+    if out is None:
+        return output
+    out.data = output
+    return out
+

 # randn_like
-has_randn_like = hasattr(mindspore.mint, 'randn_like')
-def randn_like(input, *, dtype=None):
-    if use_pyboost() and has_randn_like:
-        return mindspore.mint.randn_like(input, dtype=dtype)
-    return ops.randn_like(input, dtype=dtype)
+def randn_like(
+    input,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    memory_format=None
+):
+    if device is None:
+        device = input.device
+    if isinstance(device, str):
+        device = core.device(device)
+
+    if dtype is None:
+        dtype = input.dtype
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_
+    )
+    return execute(
+        "randn_like_ext",  # draws normal (not uniform) samples, matching torch.randn_like
+        input,
+        seed,
+        offset,
+        dtype_to_type_id("RandnLike", "dtype", dtype),
+        device=device,
+        requires_grad=requires_grad,
+    )
+

 # randperm
-has_randperm = hasattr(mindspore.mint, 'randperm')
-def randperm(n, *, generator=None, dtype=mindspore.int64):
-    """randperm"""
-    if use_pyboost() and has_randperm:
-        return mindspore.mint.randperm(n, generator=generator, dtype=dtype)
-    if DEVICE_TARGET == 'CPU':
-        seed, offset = 0, 0
-        randperm_v2_op = _get_cache_prim(ops.RandpermV2)(seed, offset, dtype)
-        return randperm_v2_op(n)
-
-    randperm_op = _get_cache_prim(ops.Randperm)(max_length=n, dtype=dtype)
-    return randperm_op(mindspore.tensor([n]))
-
-def gamma(shape, alpha, beta):
-    if DEVICE_TARGET != 'Ascend':
-        return mindspore.tensor(np.random.gamma(alpha, 1/beta, shape))
-    return ops.gamma(shape, alpha, beta)
-
-__all__ = ['bernoulli', 'gamma', 'multinomial', 'normal', 'rand', 'rand_like', 'randint', 'randn', 'randn_like', 'randperm', 'randint_like']
+def randperm(
+    n,
+    *,
+    generator=None,
+    out=None,
+    dtype=core.int64,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False
+):
+    if device is None:
+        device = get_device_in_context()
+    if isinstance(device, str):
+        device = core.device(device)
+
+    if not generator:
+        generator = default_generator
+    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
+    output = execute(
+        "randperm_ext",
+        n,
+        seed,
+        offset,
+        dtype_to_type_id("RandpermExt", "dtype", dtype),
+        device=device,
+        requires_grad=requires_grad,
+    )
+    if out is None:
+        return output
+    out.data = output
+    return out
+
+
+__all__ = [
+    "bernoulli",
+    "multinomial",
+    "normal",
+    "rand",
+    "rand_like",
+    "randint",
+    "randn",
+    "randn_like",
+    "randperm",
+    "randint_like",
+]
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 3a142ffc7..b5be90bb4 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -1,44 +1,27 @@
 """reduction op"""
-import numbers
 from collections import namedtuple
-import mindspore
-from mindspore import ops
-from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, DEVICE_TARGET, ON_ORANGE_PI
-from ._inner import call_ms_func
 from mindnlp import core
+from mindnlp.core.executor import execute

 max_out = namedtuple('max_out', ['values', 'indices'])
 min_out = namedtuple('min_out', ['values', 'indices'])

 # argmax
-has_argmax = hasattr(mindspore.mint, 'argmax')
 def argmax(input, dim=None, keepdim=False):
-    if use_pyboost() and has_argmax:
-        return mindspore.mint.argmax(input, dim, keepdim)
-    return ops.argmax(input, dim, keepdim)
+    return execute('argmax_ext', input, dim, keepdim)

 # argmin
-has_argmin = hasattr(mindspore.mint, 'argmin')
 def argmin(input, dim=None, keepdim=False):
-    if use_pyboost() and has_argmin:
-        return mindspore.mint.argmin(input, dim, keepdim)
-    return ops.argmin(input, dim, keepdim)
+    return execute('argmin_ext', input, dim, keepdim)

 # amax
-has_amax = hasattr(mindspore.mint, 'amax')
-def amax(input, dim=(), keepdim=False):
-    if use_pyboost() and has_amax:
-        return mindspore.mint.amax(input, dim, keepdim)
-    return ops.amax(input, dim, keepdim)
+def amax(input, dim, keepdim=False):
+    return execute('reduce_max', input, dim, keepdim)

 # amin
-has_amin = hasattr(mindspore.mint, 'amin')
 def amin(input, dim, keepdim=False):
-    if use_pyboost() and has_amin:
-        return mindspore.mint.amin(input, dim, keepdim)
+    return execute('reduce_min', input, dim, keepdim)

 # aminmax
 def aminmax(input, *, dim=None, keepdim=False):
@@ -47,377 +30,189 @@ def aminmax(input, *, dim=None, keepdim=False):
     return amin(input, dim, keepdim), amax(input, dim, keepdim)

 # all
-has_all = hasattr(mindspore.mint, 'all')
 def all(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
-    axis = kwargs.get('axis', None)
-    keepdims = kwargs.get('keepdims', None)
-    if axis is not None:
-        dim = axis
-    if keepdims:
-        keepdim = keepdims
-
-    if use_pyboost() and has_all:
-        return mindspore.mint.all(input, dim, keepdim).to(input.dtype)
-    return ops.all(input, dim, keepdim).to(input.dtype)
+    dim = kwargs.pop('axis', dim)
+    keepdim = kwargs.pop('keepdims', keepdim)
+    return execute('reduce_all', input, dim, keepdim)

 # any
-has_any = hasattr(mindspore.mint, 'any')
-def any(input, dim=None, keepdim=False, *, out=None):
-    if use_pyboost() and has_any:
-        if dim is None:
-            return call_ms_func(mindspore.mint.any, input, out=out)
-        else:
-            return call_ms_func(mindspore.mint.any, input, dim, keepdim, out=out)
-    return ops.any(input, dim, keepdim)
+def any(input, dim=None, keepdim=False):
+    return execute('reduce_any', input, dim, keepdim)

 # max
-has_max = hasattr(mindspore.mint, 'max')
-def max(*args, **kwargs):
-    out = kwargs.pop('out', None)
-    if 'dim' in kwargs and 'keepdim' not in kwargs:
-        kwargs['keepdim'] = False
-    if 'axis' in kwargs:
-        kwargs['dim'] = kwargs.pop('axis')
-    out = mindspore.mint.max(*args, **kwargs)
-    if isinstance(out, tuple):
-        return max_out(values=out[0], indices=out[1])
+def max(input, dim=None, keepdim=False, *, out=None):
+    if dim is None and not keepdim:
+        return execute('max', input)
+    if core.is_tensor(dim):
+        return core.maximum(input, dim)
+    output = execute('argmax_with_value', input, dim, keepdim)
+    if out is None:
+        return max_out(values=output[1], indices=output[0])
+
+    out[0].data = output[0]
+    out[1].data = output[1]
     return out

 # min
-has_min = hasattr(mindspore.mint, 'min')
-def min(*args, **kwargs):
-    out = kwargs.pop('out', None)
-    if 'dim' in kwargs and 'keepdim' not in kwargs:
-        kwargs['keepdim'] = False
-    out = mindspore.mint.min(*args, **kwargs)
-    if isinstance(out, tuple):
-        return min_out(values=out[0], indices=out[1])
+def min(input, dim=None, keepdim=False, *, out=None):
+    if dim is None and not keepdim:
+        return execute('min', input)
+    if core.is_tensor(dim):
+        return core.minimum(input, dim)
+    # argmin_with_value returns (indices, values), mirroring argmax_with_value above
+    output = execute('argmin_with_value', input, dim, keepdim)
+    if out is None:
+        return min_out(values=output[1], indices=output[0])
+
+    out[0].data = output[0]
+    out[1].data = output[1]
     return out

 # dist

 # logsumexp
-has_logsumexp = hasattr(mindspore.mint, 'logsumexp')
 def logsumexp(input, dim, keepdim=False):
-    if use_pyboost() and has_logsumexp:
-        return mindspore.mint.logsumexp(input, dim, keepdim)
-    return ops.logsumexp(input, dim, keepdim)
+    return execute('logsumexp', input, dim, keepdim)

 # mean
-has_mean = hasattr(mindspore.mint, 'mean')
 def mean(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
-    axis = kwargs.get('axis', None)
-    if axis is not None:
-        dim = axis
-    if use_pyboost() and has_mean:
-        return mindspore.mint.mean(input, dim, keepdim, dtype=dtype)
-    out = ops.mean(input, dim, keepdim)
-    if dtype is not None:
-        out = out.astype(dtype)
-    return out
+    dim = kwargs.pop('axis', dim)
+    return execute('mean_ext', input, dim, keepdim, dtype)

 # nanmean

 # median
-has_median = hasattr(mindspore.mint, 'median')
 def median(input, dim=-1, keepdim=False):
-    if use_pyboost() and has_median:
-        return mindspore.mint.median(input, dim, keepdim)
-    return ops.median(input, dim, keepdim)
+    if dim is None:
+        return execute('median_ext', input)
+    return execute('median_dim', input, dim, keepdim)

 # nanmedian
-def nanmedian(input, dim=-1, keepdim=False):
-    return ops.nanmedian(input, dim, keepdim)
+

 # mode

 # norm
-has_norm = hasattr(mindspore.mint, 'norm')
-def norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None):
-    if use_pyboost() and has_norm and not ON_ORANGE_PI:
-        return call_ms_func(mindspore.mint.norm, input, p, dim, keepdim, out=out, dtype=dtype)
+def vector_norm_ext(input, p=2, dim=None, keepdim=False, *, dtype=None):
+    if float(p) in [0.0, 1.0, 2.0, 3.0]:
+        return execute('linalg_vector_norm', input, float(p), dim, keepdim, dtype)
+    if input.dtype in [core.bfloat16, core.float16, core.float32]:
+        if dtype is None:
+            return execute('lp_norm_v2', input, p, dim, keepdim, 0.0)
+        return execute('lp_norm_v2', input, p, dim, keepdim, 0.0).to(dtype)
+
+    cast_dtype = input.dtype if dtype is None else dtype
+    input = input.to(core.float32)
+    return execute('lp_norm_v2', input, p, dim, keepdim, 0.0).to(cast_dtype)
+
+def matrix_norm_ext(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None):
+    ndim = A.ndim
+    row_axis, col_axis = _check_matrix_norm_axis(dim, ndim)
+    _check_matrix_norm_ord(ord)
+    if ord == 'fro':
+        return vector_norm_ext(A, 2, dim, keepdim, dtype=dtype)
+    if ord == 'nuc':
+        res = _multi_svd_norm(A, row_axis, col_axis, 'sum')
+        return _reshape_matrix_norm(A, res, dim, keepdim)
+    if ord == 2:
+        res = _multi_svd_norm(A, row_axis, col_axis, 'amax')
+        return _reshape_matrix_norm(A, res, dim, keepdim)
+    if ord == -2:
+        res = _multi_svd_norm(A, row_axis, col_axis, 'amin')
+        return _reshape_matrix_norm(A, res, dim, keepdim)
+    if ord in [float('inf'), -float('inf')]:
+        row_axis, col_axis = col_axis, row_axis
+        if not keepdim and col_axis > row_axis:
+            col_axis -= 1
+        if ord < 0:
+            return amin(vector_norm_ext(A, 1, row_axis, keepdim, dtype=dtype), col_axis, keepdim)
+        return amax(vector_norm_ext(A, 1, row_axis, keepdim, dtype=dtype), col_axis, keepdim)
+
+def norm(input, p='fro', dim=None, keepdim=False, dtype=None):
+    if not isinstance(input, core.Tensor):
+        raise TypeError(f"For `norm_ext`, the `input` must be Tensor!, but get {type(input)}.")
+    if isinstance(p, (bool, int, float)):
+        return vector_norm_ext(input, p, dim, keepdim, dtype=dtype)
     if p == 'fro':
-        p = None
-    return ops.norm(input, p, dim, keepdim, dtype=dtype)
+        if isinstance(dim, (list, tuple)) and len(dim) > 2:
+            raise ValueError(f"For `norm_ext`, the size of `dim` cannot be greater than 2 "
+                             f"when the norm mode is `fro`.")
+        return execute('linalg_vector_norm', input, 2.0, dim, keepdim, dtype)
+    if p == 'nuc':
+        dim = tuple(range(input.ndim)) if dim is None else dim
+        return matrix_norm_ext(input, p, dim, keepdim, dtype=dtype)
+    raise ValueError(f"For `norm_ext`, the value of `p` must be one of [int, float, inf, -inf, 'fro', 'nuc',] "
+                     f"but got `{p}`.")

 # nansum
-has_nansum = hasattr(mindspore.mint, 'nansum')
 def nansum(input, dim=None, keepdim=False, *, dtype=None):
-    if use_pyboost() and has_nansum:
-        return mindspore.mint.nansum(input, dim, keepdim, dtype=dtype)
-    return ops.nansum(input, dim, keepdim, dtype=dtype)
+    return execute('nansum', input, dim, keepdim, dtype)

 # prod
-has_prod = hasattr(mindspore.mint, 'prod')
 def prod(input, dim=None, keepdim=False, *, dtype=None):
-    if use_pyboost() and has_prod:
-        return mindspore.mint.prod(input, dim, keepdim, dtype=dtype)
-    return ops.prod(input, dim, keepdim).to(dtype)
+    return execute('prod_ext', input, dim, keepdim, dtype)

 # quantile
-def quantile_output_shape(
-    original_dim,
-    input_tensor,
-    q,
-    keepdim,
-    wrapped_dim
-):
-    """
-    计算分位数函数的输出形状
-
-    参数:
-    original_dim: 原始维度(None表示展平)
-    input_tensor: 输入张量
-    q: 分位数张量
-    keepdim: 是否保留维度
-    wrapped_dim: 处理后的维度索引
-    """
-    # 计算输出形状: q大小 + 缩减维度后的大小
-    out_shape = []
-
-    if original_dim is not None and input_tensor.dim() > 0:
-        # 保留原始维度结构
-        out_shape = list(input_tensor.shape)
-        if keepdim:
-            out_shape[wrapped_dim] = 1
-        else:
-            del out_shape[wrapped_dim]
-    elif keepdim:
-        # 当展平但需保留维度时创建全1形状
-        out_shape = [1] * input_tensor.dim()
-
-    if q.dim() > 0:
-        # 添加分位数维度到最前面
-        out_shape.insert(0, q.numel())
-
-    return out_shape
-
-
-def quantile(
-    input_tensor,
-    q,
-    dim = None,
-    keepdim: bool = False,
-    interpolation: str = 'linear',
-    ignore_nan: bool = False
-):
-    """
-    PyTorch分位数函数的完整实现
-
-    参数:
-    input_tensor: 输入数据
-    q: 分位数(0-1之间)
-    dim: 计算维度
-    keepdim: 是否保留维度
-    interpolation: 插值模式 ('linear', 'lower', 'higher', 'nearest', 'midpoint')
-    ignore_nan: 是否忽略NaN值
-
-    返回:
-    计算得到的分位数
-    """
-    if isinstance(q, numbers.Number):
-        q = core.tensor(q, dtype=input_tensor.dtype)
-    # ===== 1. 输入验证 =====
-    device = input_tensor.device
-    dtype = input_tensor.dtype
-
-    # 验证分位数范围
-    if device.type == 'cpu':
-        if not core.all((q >= 0) & (q <= 1)):
-            raise ValueError("quantile() q values must be in the range [0, 1]")
-
-    # ===== 2. 维度处理 =====
-    wrapped_dim = dim if dim is not None else 0
-    original_dim = dim
-
-    if dim is not None:
-        # 验证维度有效性
-        if dim < 0:
-            dim = input_tensor.dim() + dim
-        if dim < 0 or dim >= input_tensor.dim():
-            raise ValueError(f"Dimension out of range (expected to be in range [{-input_tensor.dim()}, {input_tensor.dim()-1}])")
-        wrapped_dim = dim
-
-    # 计算输出形状
-    out_shape = quantile_output_shape(original_dim, input_tensor, q, keepdim, wrapped_dim)
-
-    # ===== 3. 数据预处理 =====
-    # 处理标量分位数
-    q_scalar = q.dim() == 0
-    q = q.reshape(-1)  # 确保q是1D
-
-    # 展平或重排维度
-    if dim is None:
-        # 展平整个张量
-        sorted_x, _ = input_tensor.flatten().sort()
-    elif wrapped_dim == input_tensor.dim() - 1:
-        # 当目标维度已是最后一维时直接排序
-        sorted_x, _ = input_tensor.sort(dim=wrapped_dim)
-    else:
-        # 将目标维度移到末尾再排序
-        transposed = input_tensor.transpose(wrapped_dim, -1).unsqueeze(-1)
-        sorted_x, _ = transposed.sort(dim=-2)
-        sorted_x = sorted_x.squeeze(-1)
-
-    # ===== 4. 分位数计算核心 =====
-    n = sorted_x.shape[-1]
-
-    # 处理空输入
-    if n == 0:
-        result = core.full(out_shape, float('nan'), device=device, dtype=dtype)
-        return result
-
-    # 计算排名位置 (考虑NaN处理)
-    if ignore_nan:
-        # 计算非NaN数量
-        non_nan_count = (~sorted_x.isnan()).sum(dim=-1, keepdim=True)
-        ranks = q * (non_nan_count - 1)
-        ranks = core.clamp(ranks, min=0)  # 防止负索引
-    else:
-        last_index = n - 1
-        # 广播处理NaN标记
-        nan_mask = sorted_x.isnan().any(dim=-1, keepdim=True)
-        # 扩展q和nan_mask到相同形状
-        expanded_q = q.view(1, -1).expand(*sorted_x.shape[:-1], q.numel())
-        nan_mask = nan_mask.expand_as(expanded_q)
-        # 计算基础排名
-        ranks = expanded_q * last_index
-        # 对包含NaN的行使用最后索引
-        ranks = core.where(nan_mask, core.tensor(last_index, device=device), ranks)
-
-    # 根据插值模式调整排名
-    if interpolation == 'lower':
-        ranks = core.floor(ranks)
-    elif interpolation == 'higher':
-        ranks = core.ceil(ranks)
-    elif interpolation == 'nearest':
-        ranks = core.round(ranks)
-
-    # 确保排名在有效范围内
-    ranks = core.clamp(ranks, 0, n - 1)
-
-    # 获取下界索引和值
-    ranks_below = ranks.to(core.int64)
-    values_below = sorted_x.gather(-1, ranks_below)
-
-    # ===== 5. 插值处理 =====
-    if interpolation in ['linear', 'midpoint']:
-        # 计算插值权重
-        weights = core.full_like(ranks, 0.5) if interpolation == 'midpoint' else ranks - ranks_below
-
-        # 获取上界值
-        ranks_above = core.ceil(ranks).to(core.int64)
-        values_above = sorted_x.gather(-1, ranks_above)
-
-        # 线性插值: result = (1 - weight)*below + weight*above
-        values_below = values_below.lerp(values_above, weights)
-
-    # ===== 6. 形状调整 =====
-    if q_scalar:
-        # 标量分位数:移除分位数维度
-        values_below = values_below.squeeze(-1)
-    else:
-        # 多分位数:移动分位数维度到最前面
-        values_below = values_below.movedim(-1, 0)
-
-    # 恢复原始输出形状
-    if values_below.shape != tuple(out_shape):
-        values_below = values_below.reshape(out_shape)
-
-    return values_below

 # nanquantile
-def nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
-    return ops.nanquantile(input, q, dim, keepdim)

 # std
-has_std = hasattr(mindspore.mint, 'std')
 def std(input, dim=None, *, correction=1, keepdim=False, **kwargs):
-    axis = kwargs.get('axis', None)
-    if axis is not None:
-        dim = axis
-    if use_pyboost() and has_std and not ON_ORANGE_PI:
-        return mindspore.mint.std(input, dim=dim, correction=correction, keepdim=keepdim)
-    if DEVICE_TARGET == 'GPU':
-        unbiased = bool(correction)
-        if dim is None:
-            dim = ()
-        if isinstance(dim, int):
-            dim = (dim,)
-        _std = _get_cache_prim(ops.ReduceStd)(dim, unbiased, keepdim)
-        _std.set_device('CPU')
-        return _std(input)[0]
-    return ops.std(input, dim, correction, keepdim)
+    dim = kwargs.pop('axis', dim)
+    return execute('std', input, dim, correction, keepdim)

 # std_mean
-has_std_mean = hasattr(mindspore.mint, 'std_mean')
 def std_mean(input, dim=None, *, correction=1, keepdim=False):
-    if use_pyboost and has_std_mean:
-        return mindspore.mint.std_mean(input, dim=dim, correction=correction, keepdim=keepdim)
-    return std(input, dim, correction=correction, keepdim=keepdim), \
-           mean(input, dim, keepdim)
+    return execute('std_mean', input, dim, correction, keepdim)

 # sum
-has_sum = hasattr(mindspore.mint, 'sum')
-def sum(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
-    keepdims = kwargs.pop('keepdims', None)
-    if keepdims is not None:
-        keepdim = keepdims
+def sum(input, dim=None, keepdim=False, *, dtype=None):
     if 0 in input.shape:
-        return mindspore.tensor(0, dtype=dtype)
-    if use_pyboost() and has_sum:
-        return mindspore.mint.sum(input, dim, keepdim, dtype=dtype)
-    return ops.sum(input, dim, keepdim, dtype=dtype)
+        return core.tensor(0, dtype=dtype, device=input.device)
+    return execute('sum_ext', input, dim, keepdim, dtype)

 # unique
-has_unique = hasattr(mindspore.mint, 'unique')
 def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
-    if use_pyboost() and has_unique:
-        return mindspore.mint.unique(input, sorted, return_inverse, return_counts, dim)
-
-    out, inverse = ops.unique(input)
-    outs = (out,)
+    if dim is None:
+        y, inverse, counts = execute('unique2',
+                                     input, sorted, return_inverse, return_counts)
+    else:
+        y, inverse, counts = execute('unique_dim', input, sorted, return_inverse, dim)
+    if return_inverse and return_counts:
+        return y, inverse, counts
     if return_inverse:
-        outs += (inverse,)
+        return y, inverse
     if return_counts:
-        counts = (out == input).sum(0, keepdims=True)
-        outs += (counts,)
-    return outs if len(outs) > 1 else outs[0]
+        return y, counts
+    return y

 # unique_consecutive
-has_unique_consecutive = hasattr(mindspore.mint, 'unique_consecutive')
 def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
-    if use_pyboost() and has_unique_consecutive:
-        return mindspore.mint.unique_consecutive(input, return_inverse, return_counts, dim)
-    return ops.unique_consecutive(input, return_inverse, return_counts, dim)
+    output, idx, counts = execute('unique_consecutive', input, return_inverse, return_counts, dim)
+    if return_inverse and return_counts:
+        return output, idx, counts
+    if return_inverse:
+        return output, idx
+    if return_counts:
+        return output, counts
+    return output

 # var
-has_var = hasattr(mindspore.mint, 'var')
-def var(input, dim=None, *, correction=1, keepdim=False, **kwargs):
-    correction = int(kwargs.pop('unbiased', correction))
-    if use_pyboost and has_var:
-        return mindspore.mint.var(input, dim=dim, correction=correction, keepdim=keepdim)
-    return pow(std(input, dim, correction=correction, keepdim=keepdim), 2)
+def var(input, dim=None, *, correction=1, keepdim=False):
+    return execute('var', input, dim, correction, keepdim)

 # var_mean
-has_var_mean = hasattr(mindspore.mint, 'var_mean')
 def var_mean(input, dim=None, *, correction=1, keepdim=False):
-    if use_pyboost and has_var_mean:
-        return mindspore.mint.var_mean(input, dim=dim, correction=correction, keepdim=keepdim)
-    return pow(std(input, dim, correction=correction, keepdim=keepdim), 2), \
-           mean(input, dim, keepdim)
+    return execute('var_mean', input, dim, correction, keepdim)

 # count_nonzero
-has_count_nonzero = hasattr(mindspore.mint, 'count_nonzero')
 def count_nonzero(input, dim=None):
-    if use_pyboost() and has_count_nonzero:
-        return mindspore.mint.count_nonzero(input, dim)
-    if dim is None:
-        dim = ()
-    return ops.count_nonzero(input, dim)
+    return execute('count_nonzero', input, dim)

-__all__ = ['all', 'amax', 'amin', 'aminmax', 'any', 'argmax', 'argmin', 'count_nonzero', 'logsumexp', 'max', 'mean', 'median', 'min', 'nanmedian', 'nanquantile', 'nansum', 'norm', 'prod', 'quantile', 'std', 'std_mean', 'sum', 'unique', 'unique_consecutive', 'var', 'var_mean']
+__all__ = ['all', 'amax', 'amin', 'aminmax', 'any', 'argmax', 'argmin', 'count_nonzero',
+           'logsumexp', 'max', 'mean', 'median', 'min', 'nansum',
+           'norm', 'prod', 'std', 'std_mean', 'sum', 'unique', 'unique_consecutive',
+           'var', 'var_mean']
\ No newline at end of file
diff --git a/mindnlp/core/types.py b/mindnlp/core/types.py
index 166e3fa78..8b0a47672 100644
--- a/mindnlp/core/types.py
+++ b/mindnlp/core/types.py
@@ -10,7 +10,9 @@ from typing import Any, IO, TYPE_CHECKING, Union, Dict

 from typing_extensions import Self, TypeAlias

+from mindnlp import core
 from ._dtype import dtype
+from .configs import DEVICE_TARGET

 DEVICE_MAP = {
     'GPU': 'cuda',
@@ -51,6 +53,8 @@ def __init__(self, type=None, index=None):
             self.type = _target
             self.index = _id

+        if DEVICE_TARGET == 'Ascend' and self.type == 'cuda':
+            self.type = 'npu'

     def __repr__(self):
         if self.index is None:
@@ -65,12 +69,18 @@ def __eq__(self, __value):

     def __hash__(self):
         return hash(self.type) ^ hash(self.index)

+    def __gt__(self, other):
+        if self.type == 'cpu':
+            return False
+        return True
+
     def __enter__(self):
         # self.prev_idx = torch.cuda._exchange_device(self.idx)
-        pass
+        core._bind.set_device_in_context(self)

     def __exit__(self, type: Any, value: Any, traceback: Any):
         # self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
+        core._bind.set_device_in_context(None)
         return False
diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py
index 28030869c..9ae4f2b9d 100644
--- a/mindnlp/transformers/__init__.py
+++ b/mindnlp/transformers/__init__.py
@@ -1,8 +1,7 @@
 import sys

-import transformers
-from transformers.utils import OptionalDependencyNotAvailable, _LazyModule
-from transformers.utils.import_utils import *
+from mindnlp.utils.import_utils import *
+from mindnlp.utils.import_utils import _LazyModule

 # Base objects, independent of any specific backend
 _import_structure = {
@@ -4103,8 +4102,11 @@

 from . import ms_utils
-from .masking_utils import create_causal_mask
+from .masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from .modeling_utils import construct_pipeline_parallel_model, _load_pretrained_model_wrapper

+# redirect mindnlp.transformers to transformers
+import transformers
 sys.modules[__name__] = _LazyModule(
     'transformers',
     transformers.__file__,
@@ -4113,19 +4115,31 @@
     extra_objects={"__version__": transformers.__version__},
 )

+
+# patch transformers
 def not_supported():
     return False

+def empty_fn(*args, **kwargs):
+    pass
+
 transformers.utils.import_utils._torch_fx_available = False
 transformers.utils.import_utils.is_torch_sdpa_available = not_supported

-from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper
+from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper, patch_wrappers

 patch_dtype_wrapper(transformers.AutoModel, 'from_pretrained')
 patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
                     [transformers.modeling_utils.restore_default_torch_dtype]
                     )
+patch_wrappers(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model',
+               [_load_pretrained_model_wrapper])
+
 transformers.pipelines.pipeline = dtype_wrapper(transformers.pipelines.pipeline)
+transformers.modeling_utils.caching_allocator_warmup = empty_fn
+transformers.masking_utils.create_causal_mask = create_causal_mask
+transformers.masking_utils.create_sliding_window_causal_mask = create_sliding_window_causal_mask

-transformers.masking_utils.create_causal_mask = create_causal_mask
\ No newline at end of file
+# add mindnlp.transformers modules/attrs to lazymodule
+# setattr(sys.modules[__name__], 'test_ms_model', test_ms_model)
diff --git a/mindnlp/transformers/modeling_utils.py b/mindnlp/transformers/modeling_utils.py
new file mode 100644
index 000000000..f5df8bc3c
--- /dev/null
+++ b/mindnlp/transformers/modeling_utils.py
@@ -0,0 +1,213 @@
+
+import types
+
+from mindspore.communication import GlobalComm
+from ..core import nn, ops, distributed as dist
+from ..utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_submodule(model, submodule_path, new_module):
+    parent_path, _, child_name = submodule_path.rpartition('.')
+
+    parent_module = model.get_submodule(parent_path) if parent_path else model
+
+    setattr(parent_module, child_name, new_module)
+
+def send_forward(self, *args, **kwargs):
+    output = self._forward(*args, **kwargs)
+    dist.isend(output[0], self.dist)
+    return output
+
+def receive_forward(self, *args, **kwargs):
+    hidden_states = args[0]
+    dist.irecv(hidden_states, src=self.src)
+    output = self._forward(*((hidden_states,) + args[1:]), **kwargs)
+    return output
+
+def broadcast_forward(self, *args, **kwargs):
+    output = self._forward(*args, **kwargs)
+    dist.broadcast(output, src=self.src)
+    return output
+
+class DecoderLayerIdentity(nn.Module):
+    def __init__(self, layer_idx, config):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.num_key_value_heads = config.num_key_value_heads
+        self.attention_type = config.layer_types[layer_idx]
+
+    def forward(self, *args, **kwargs):
+        past_key_value = kwargs.get('past_key_value', None)
+        hidden_states = args[0]
+        bs, seq_len, _ = hidden_states.shape
+
+        if past_key_value is not None:
+            past_key_value.update(
+                ops.empty(bs, self.num_key_value_heads, seq_len, 0, dtype=hidden_states.dtype, device='meta'),
+                ops.empty(bs, self.num_key_value_heads, seq_len, 0, dtype=hidden_states.dtype, device='meta'),
+                self.layer_idx)
+
+        return hidden_states
+
+
+class EmbeddingIdentity(nn.Module):
+    def __init__(self, num_embeddings: int, embedding_dim: int, dtype=None):
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.dtype = dtype
+
+    def forward(self, input):
+        return ops.empty(input.shape + (self.embedding_dim,), dtype=self.dtype, device='meta')
+
+class LinearIdentity(nn.Module):
+    def __init__(self, in_features, out_features, dtype=None):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dtype = dtype
+
+    def forward(self, input):
+        return ops.empty(input.shape[:-1] + (self.out_features,), dtype=self.dtype, device='meta')
+
+def construct_pipeline_parallel_model(model, device_map):
+    current_device = dist.get_rank()
+    last_device = dist.get_world_size() - 1
+    no_split_modules = model._get_no_split_modules(device_map)
+    reversed_device_map = {}
+    for scope_name, device in device_map.items():
+        if device not in reversed_device_map:
+            reversed_device_map[device] = [scope_name]
+        else:
+            reversed_device_map[device].append(scope_name)
+
+        if device != current_device:
+            submodule = model.get_submodule(scope_name)
+            if isinstance(submodule, nn.Embedding):
+                new_embedding = EmbeddingIdentity(submodule.num_embeddings, submodule.embedding_dim, model.dtype)
+                replace_submodule(model, scope_name, new_embedding)
+            elif isinstance(submodule, nn.Linear):
+                new_linear = LinearIdentity(submodule.in_features, submodule.out_features, model.dtype)
+                replace_submodule(model, scope_name, new_linear)
+            elif submodule.__class__.__name__ in no_split_modules:
+                new_layer = DecoderLayerIdentity(submodule.self_attn.layer_idx, submodule.self_attn.config)
+                replace_submodule(model, scope_name, new_layer)
+            else:
+                # new_layer = nn.Identity()
+                # replace_submodule(model, scope_name, new_layer)
+                pass
+
+    if current_device < last_device:
+        current_last_layer = model.get_submodule(reversed_device_map[current_device][-1])
+        current_last_layer._forward = current_last_layer.forward
+        current_last_layer.forward = types.MethodType(send_forward, current_last_layer)
+        current_last_layer.dist = current_device + 1
+
+    if current_device > 0:
+        current_first_layer = model.get_submodule(reversed_device_map[current_device][0])
+        current_first_layer._forward = current_first_layer.forward
+        current_first_layer.forward = types.MethodType(receive_forward, current_first_layer)
+        current_first_layer.src = current_device - 1
+
+    model_last_layer = model.get_submodule(next(reversed(device_map)))
+    model_last_layer._forward = model_last_layer.forward
+    model_last_layer.forward = types.MethodType(broadcast_forward, model_last_layer)
+    model_last_layer.src = last_device
+
+    return model
+
+def find_useful_files(shared_files, shared_meta, model_params):
+    files_path = '/'.join(shared_files[0].split('/')[:-1])
+    useful_files = set()
+
+    for param_name, file_name in shared_meta['weight_map'].items():
+        if param_name in model_params:
+            useful_files.add(file_name)
+        # else:
+        #     shared_meta['all_checkpoint_keys'].remove(param_name)
+
+    useful_files = [files_path + '/' + file for file in useful_files]
+
+    return useful_files, shared_meta
+
+
+def _load_pretrained_model_wrapper(fn):
+    def wrapper(
+        cls,
+        model,
+        state_dict,
+        checkpoint_files,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes,
+        sharded_metadata,
+        device_map,
+        disk_offload_folder,
+        offload_state_dict = None,
+        dtype = None,
+        hf_quantizer = None,
+        keep_in_fp32_regex = None,
+        device_mesh = None,
+        key_mapping = None,
+        weights_only = True,
+    ):
+        # if device_map is not None and not initialize distribute module, raise Error.
+        if device_map is not None:
+            if all([isinstance(d, int) for d in device_map.values()]):
+                if len(set(device_map.values())) > 1 and not GlobalComm.INITED:
+                    raise RuntimeError(f'to use transformers with multi-gpu/npu, please use `msrun/mpirun` ' \
+                                       f'with {len(set(device_map.values()))} devices to launch multiprocess.')
+
+                model = construct_pipeline_parallel_model(model, device_map)
+                checkpoint_files, sharded_metadata = find_useful_files(checkpoint_files, sharded_metadata, model.state_dict().keys())
+
+                rank = dist.get_rank()
+                world_size = dist.get_world_size()
+
+                dist.barrier()
+
+                for target_rank in range(world_size):
+                    if rank == target_rank:
+                        print(f'rebuild model on rank {rank}')
+                        model = fn(
+                            cls,
+                            model,
+                            state_dict,
+                            checkpoint_files,
+                            pretrained_model_name_or_path,
+                            ignore_mismatched_sizes,
+                            sharded_metadata,
+                            device_map,
+                            disk_offload_folder,
+                            offload_state_dict,
+                            dtype,
+                            hf_quantizer,
+                            keep_in_fp32_regex,
+                            device_mesh,
+                            key_mapping,
+                            weights_only,
+                        )
+
+                    dist.barrier()
+                return model
+
+        return fn(
+            cls,
+            model,
+            state_dict,
+            checkpoint_files,
+            pretrained_model_name_or_path,
+            ignore_mismatched_sizes,
+            sharded_metadata,
+            device_map,
+            disk_offload_folder,
+            offload_state_dict,
+            dtype,
+            hf_quantizer,
+            keep_in_fp32_regex,
+            device_mesh,
+            key_mapping,
+            weights_only,
+        )
+    return wrapper
\ No newline at end of file
diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/utils/__init__.py b/mindnlp/utils/__init__.py
index 7fca6fc2b..fac1eefd4 100644
--- a/mindnlp/utils/__init__.py
+++ b/mindnlp/utils/__init__.py
@@ -21,13 +21,7 @@
 # from .download import *
 # from .compatibility import *
 # from .chat_template_utils import *
-from .import_utils import requires_backends, is_mindspore_available, OptionalDependencyNotAvailable, is_sentencepiece_available, \
-is_tokenizers_available, direct_transformers_import, is_protobuf_available, is_safetensors_available, \
-is_cython_available, is_pretty_midi_available, is_essentia_available, is_librosa_available, is_scipy_available, is_pyctcdecode_available, \
-is_jieba_available, is_vision_available, is_sudachi_projection_available, is_g2p_en_available, is_levenshtein_available, is_nltk_available, \
-is_bs4_available, is_pytesseract_available, is_tiktoken_available, is_einops_available, is_faiss_available, is_datasets_available, \
-is_sacremoses_available, is_phonemizer_available,is_speech_available, is_kenlm_available, is_triton_available
-
+from .import_utils import *
 # from .testing_utils import require_mindspore
 # from .save import convert_file_size_to_int
 # from .peft_utils import find_adapter_config_file
diff --git a/mindnlp/utils/decorators.py b/mindnlp/utils/decorators.py
index 59afd838f..16e1b8f19 100644
--- a/mindnlp/utils/decorators.py
+++ b/mindnlp/utils/decorators.py
@@ -16,13 +16,15 @@ def wrapper(*args, **kwargs):
     return wrapper

 def patch_dtype_wrapper(cls, method_name, other_decorators=None):
+    # keep any extra decorators the caller passed alongside dtype_wrapper
+    patch_wrappers(cls, method_name, [dtype_wrapper] + (other_decorators or []))
+
+def patch_wrappers(cls, method_name, other_decorators=None):
     original_method = getattr(cls, method_name)
-
-    wrapped_func = dtype_wrapper(original_method.__func__)
+    wrapped_func = original_method.__func__

     if other_decorators is not None:
         for dec in other_decorators:
             wrapped_func = dec(wrapped_func)

     # 重新创建类方法并赋值回类
-    setattr(cls, method_name, classmethod(wrapped_func))
\ No newline at end of file
+    setattr(cls, method_name, classmethod(wrapped_func))
diff --git a/mindnlp/utils/generic.py b/mindnlp/utils/generic.py
index 2de3a7a36..13f8bb953 100644
--- a/mindnlp/utils/generic.py
+++ b/mindnlp/utils/generic.py
@@ -25,7 +25,6 @@ from functools import wraps

 import numpy as np
 import mindspore
-from .import_utils import is_mindspore_available


 def is_tensor(x):
@@ -58,6 +57,7 @@ def is_mindspore_tensor(x):
     Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
     """
     return False if not is_mindspore_available() else _is_mindspore(x)
+
 def set_attribute_for_modules(module, key: str, value: Any):
     """
     Set a value to a module and all submodules.
diff --git a/mindnlp/utils/import_utils.py b/mindnlp/utils/import_utils.py
index 1a8cecae0..a858c94c6 100644
--- a/mindnlp/utils/import_utils.py
+++ b/mindnlp/utils/import_utils.py
@@ -1,5 +1,4 @@
 # Copyright 2022 The HuggingFace Team. All rights reserved.
-# Copyright 2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,501 +11,1666 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================
 """
 Import utilities: Utilities related to imports and our lazy inits.
 """
+import importlib.machinery
+import importlib.metadata
+import importlib.util
+import json
+import operator
 import os
+import re
+import shutil
+import subprocess
 import sys
 import warnings
-from types import ModuleType
 from collections import OrderedDict
-from functools import wraps, lru_cache
-from typing import Tuple, Union
-import importlib.util
+from enum import Enum
+from functools import lru_cache
+from itertools import chain
+from types import ModuleType
+from typing import Any, Optional, Union
+
 from packaging import version

 from . import logging

-if sys.version_info >= (3, 8):
-    # For Python 3.8 and later
-    from importlib import metadata as importlib_metadata
-else:
-    # For Python versions earlier than 3.8
-    import importlib_metadata
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-logger = logging.get_logger(__name__)

-def _is_package_available(
-    pkg_name: str, return_version: bool = False
-) -> Union[Tuple[bool, str], bool]:
-    """
-    Checks if a specified package is available and optionally returns its version.
-
-    Args:
-        pkg_name (str): The name of the package to check for availability.
-        return_version (bool, optional): Indicates whether to return the package version along with availability status. Defaults to False.
-
-    Returns:
-        Union[Tuple[bool, str], bool]: If return_version is True, returns a tuple containing a boolean indicating package availability and a string representing the package version.
-            If return_version is False, returns a boolean indicating package availability.
-
-    Raises:
-        No specific exceptions are raised within this function.
- """ - # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version +# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[tuple[bool, str], bool]: + # Check if the package spec exists and grab its version to avoid importing a local directory package_exists = importlib.util.find_spec(pkg_name) is not None package_version = "N/A" if package_exists: try: - package_version = importlib_metadata.version(pkg_name) - package_exists = True - except importlib_metadata.PackageNotFoundError: - package_exists = False - logger.debug(f"Detected {pkg_name} version {package_version}") + # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()` + # should be used here to map from package name to distribution names + # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu. + # `importlib.metadata.packages_distributions()` is not available in Python 3.9. + + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + elif pkg_name == "quark": + # TODO: remove once `importlib.metadata.packages_distributions()` is supported. + try: + package_version = importlib.metadata.version("amd-quark") + except Exception: + package_exists = False + elif pkg_name == "triton": + try: + package_version = importlib.metadata.version("pytorch-triton") + except Exception: + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logger.debug(f"Detected {pkg_name} version: {package_version}") if return_version: return package_exists, package_version - return package_exists - - -_ftfy_available = _is_package_available("ftfy") -_einops_available = _is_package_available('einops') -_tiktoken_available = _is_package_available('tiktoken') + else: + return package_exists + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0. +USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() + +FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() + +# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it. +# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. 
+TORCH_FX_REQUIRED_VERSION = version.parse("1.10") + +ACCELERATE_MIN_VERSION = "0.26.0" +SCHEDULEFREE_MIN_VERSION = "1.2.6" +FSDP_MIN_VERSION = "1.12.0" +GGUF_MIN_VERSION = "0.10.0" +XLA_FSDPV2_MIN_VERSION = "2.2.0" +HQQ_MIN_VERSION = "0.2.1" +VPTQ_MIN_VERSION = "0.0.4" +TORCHAO_MIN_VERSION = "0.4.0" +AUTOROUND_MIN_VERSION = "0.5.0" +TRITON_MIN_VERSION = "1.0.0" + +_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) +_apex_available = _is_package_available("apex") +_apollo_torch_available = _is_package_available("apollo_torch") +_aqlm_available = _is_package_available("aqlm") +_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True) +_av_available = importlib.util.find_spec("av") is not None +_decord_available = importlib.util.find_spec("decord") is not None +_torchcodec_available = importlib.util.find_spec("torchcodec") is not None +_libcst_available = _is_package_available("libcst") +_bitsandbytes_available = _is_package_available("bitsandbytes") +_eetq_available = _is_package_available("eetq") +_fbgemm_gpu_available = _is_package_available("fbgemm_gpu") +_galore_torch_available = _is_package_available("galore_torch") +_lomo_available = _is_package_available("lomo_optim") +_grokadamw_available = _is_package_available("grokadamw") +_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True) +_torch_optimi_available = importlib.util.find_spec("optimi") is not None +# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None -_pytest_available = _is_package_available("pytest") +_coloredlogs_available = _is_package_available("coloredlogs") +# `importlib.metadata.util` doesn't work with `opencv-python-headless`. +_cv2_available = importlib.util.find_spec("cv2") is not None +_yt_dlp_available = importlib.util.find_spec("yt_dlp") is not None _datasets_available = _is_package_available("datasets") -_sentencepiece_available = _is_package_available("sentencepiece") -_soundfile_available = _is_package_available("soundfile") -_tokenizers_available = _is_package_available("tokenizers") -_pyctcdecode_available = _is_package_available("pyctcdecode") -_safetensors_available = _is_package_available("safetensors") -_modelscope_available = _is_package_available("modelscope") -_jieba_available = _is_package_available("jieba") -_pytesseract_available = _is_package_available("pytesseract") +_detectron2_available = _is_package_available("detectron2") +# We need to check `faiss`, `faiss-cpu` and `faiss-gpu`. 
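+# All three distributions expose the same `faiss` import package, so find_spec()
+# detects any of them; the version lookup below then probes each distribution name in turn.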
+_faiss_available = importlib.util.find_spec("faiss") is not None
+try:
+    _faiss_version = importlib.metadata.version("faiss")
+    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib.metadata.PackageNotFoundError:
+    try:
+        _faiss_version = importlib.metadata.version("faiss-cpu")
+        logger.debug(f"Successfully imported faiss version {_faiss_version}")
+    except importlib.metadata.PackageNotFoundError:
+        try:
+            _faiss_version = importlib.metadata.version("faiss-gpu")
+            logger.debug(f"Successfully imported faiss version {_faiss_version}")
+        except importlib.metadata.PackageNotFoundError:
+            _faiss_available = False
+_ftfy_available = _is_package_available("ftfy")
 _g2p_en_available = _is_package_available("g2p_en")
+_hadamard_available = _is_package_available("fast_hadamard_transform")
+_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
+_jieba_available = _is_package_available("jieba")
+_jinja_available = _is_package_available("jinja2")
+_kenlm_available = _is_package_available("kenlm")
+_keras_nlp_available = _is_package_available("keras_nlp")
+_levenshtein_available = _is_package_available("Levenshtein")
+_librosa_available = _is_package_available("librosa")
+_natten_available = _is_package_available("natten")
+_nltk_available = _is_package_available("nltk")
+_onnx_available = _is_package_available("onnx")
+_openai_available = _is_package_available("openai")
+_optimum_available = _is_package_available("optimum")
+_auto_gptq_available = _is_package_available("auto_gptq")
+_gptqmodel_available = _is_package_available("gptqmodel")
+_auto_round_available, _auto_round_version = _is_package_available("auto_round", return_version=True)
+# `importlib.metadata.version` doesn't work with `awq`
+_auto_awq_available = importlib.util.find_spec("awq") is not None
+_quark_available = _is_package_available("quark")
+_fp_quant_available, _fp_quant_version = _is_package_available("fp_quant", return_version=True)
+_qutlass_available = _is_package_available("qutlass")
+_is_optimum_quanto_available = False
+try:
+    importlib.metadata.version("optimum_quanto")
+    _is_optimum_quanto_available = True
+except importlib.metadata.PackageNotFoundError:
+    _is_optimum_quanto_available = False
+# For compressed_tensors, only check spec to allow compressed_tensors-nightly package
+_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
+_pandas_available = _is_package_available("pandas")
+_peft_available = _is_package_available("peft")
 _phonemizer_available = _is_package_available("phonemizer")
-_mindspore_version, _mindspore_available = _is_package_available(
-    "mindspore", return_version=True
-)
+_uroman_available = _is_package_available("uroman")
+_psutil_available = _is_package_available("psutil")
+_py3nvml_available = _is_package_available("py3nvml")
+_pyctcdecode_available = _is_package_available("pyctcdecode")
+_pygments_available = _is_package_available("pygments")
+_pytesseract_available = _is_package_available("pytesseract")
+_pytest_available = _is_package_available("pytest")
+_pytorch_quantization_available = _is_package_available("pytorch_quantization")
+_rjieba_available = _is_package_available("rjieba")
+_sacremoses_available = _is_package_available("sacremoses")
+_safetensors_available = _is_package_available("safetensors")
+_scipy_available = _is_package_available("scipy")
+_sentencepiece_available = _is_package_available("sentencepiece")
+_is_seqio_available = _is_package_available("seqio")
+_is_gguf_available, _gguf_version = _is_package_available("gguf", return_version=True)
+_sklearn_available = importlib.util.find_spec("sklearn") is not None
+if _sklearn_available:
+    try:
+        importlib.metadata.version("scikit-learn")
+    except importlib.metadata.PackageNotFoundError:
+        _sklearn_available = False
+_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
+_soundfile_available = _is_package_available("soundfile")
+_spacy_available = _is_package_available("spacy")
 _sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)
+_tensorflow_probability_available = _is_package_available("tensorflow_probability")
+_tensorflow_text_available = _is_package_available("tensorflow_text")
+_tf2onnx_available = _is_package_available("tf2onnx")
+_timm_available = _is_package_available("timm")
+_tokenizers_available = _is_package_available("tokenizers")
+_torchaudio_available = _is_package_available("torchaudio")
+_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
+_torchdistx_available = _is_package_available("torchdistx")
+_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
+_mlx_available = _is_package_available("mlx")
+_num2words_available = _is_package_available("num2words")
+_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
+_tiktoken_available = _is_package_available("tiktoken")
+_blobfile_available = _is_package_available("blobfile")
+_liger_kernel_available = _is_package_available("liger_kernel")
+_spqr_available = _is_package_available("spqr_quant")
+_rich_available = _is_package_available("rich")
+_kernels_available = _is_package_available("kernels")
+_matplotlib_available = _is_package_available("matplotlib")
+_mistral_common_available = _is_package_available("mistral_common")
+_triton_available, _triton_version = _is_package_available("triton", return_version=True)
+_triton_kernels_available = _is_package_available("triton_kernels")
+
+_torch_version = "N/A"
+_torch_available = False
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+    _torch_available, _torch_version = _is_package_available("torch", return_version=True)
+    if _torch_available:
+        _torch_available = version.parse(_torch_version) >= version.parse("2.1.0")
+        if not _torch_available:
+            logger.warning(f"Disabling PyTorch because PyTorch >= 2.1 is required but found {_torch_version}")
+else:
+    logger.info("Disabling PyTorch because USE_TF is set")
+    _torch_available = False
+
+
+_tf_version = "N/A"
+_tf_available = False
+if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
+    _tf_available = True
+else:
+    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+        # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
+        # with tensorflow-cpu to make sure it still works!
+        _tf_available = importlib.util.find_spec("tensorflow") is not None
+        if _tf_available:
+            candidates = (
+                "tensorflow",
+                "tensorflow-cpu",
+                "tensorflow-gpu",
+                "tf-nightly",
+                "tf-nightly-cpu",
+                "tf-nightly-gpu",
+                "tf-nightly-rocm",
+                "intel-tensorflow",
+                "intel-tensorflow-avx512",
+                "tensorflow-rocm",
+                "tensorflow-macos",
+                "tensorflow-aarch64",
+            )
+            _tf_version = None
+            # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+            for pkg in candidates:
+                try:
+                    _tf_version = importlib.metadata.version(pkg)
+                    break
+                except importlib.metadata.PackageNotFoundError:
+                    pass
+            _tf_available = _tf_version is not None
+        if _tf_available:
+            if version.parse(_tf_version) < version.parse("2"):
+                logger.info(
+                    f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
+                )
+                _tf_available = False
+    else:
+        logger.info("Disabling TensorFlow because USE_TORCH is set")

-_librosa_available = _is_package_available("librosa")
-_scipy_available = _is_package_available("scipy")
-_triton_available = _is_package_available("triton")
-_sacremoses_available = _is_package_available("sacremoses")
-_torchaudio_available = _is_package_available("pykaldi")
-_kenlm_available = _is_package_available("kenlm")
-_datamodel_code_generator_availabel = _is_package_available('datamodel_code_generator')
-_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
-try:
-    _pretty_midi_version = importlib_metadata.version("pretty_midi")
-    logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
-except importlib_metadata.PackageNotFoundError:
-    _pretty_midi_available = False
 _essentia_available = importlib.util.find_spec("essentia") is not None
 try:
-    _essentia_version = importlib_metadata.version("essentia")
+    _essentia_version = importlib.metadata.version("essentia")
     logger.debug(f"Successfully imported essentia version {_essentia_version}")
-except importlib_metadata.PackageNotFoundError:
+except importlib.metadata.PackageNotFoundError:
     _essentia_version = False
-_levenshtein_available = _is_package_available("Levenshtein")
-_nltk_available = _is_package_available("nltk")
+
+_pydantic_available = importlib.util.find_spec("pydantic") is not None
+try:
+    _pydantic_version = importlib.metadata.version("pydantic")
+    logger.debug(f"Successfully imported pydantic version {_pydantic_version}")
+except importlib.metadata.PackageNotFoundError:
+    _pydantic_available = False

-_faiss_available = importlib.util.find_spec("faiss") is not None
+_fastapi_available = importlib.util.find_spec("fastapi") is not None
 try:
-    _faiss_version = importlib.metadata.version("faiss")
-    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+    _fastapi_version = importlib.metadata.version("fastapi")
+    logger.debug(f"Successfully imported fastapi version {_fastapi_version}")
 except importlib.metadata.PackageNotFoundError:
-    try:
-        _faiss_version = importlib.metadata.version("faiss-cpu")
-        logger.debug(f"Successfully imported faiss version {_faiss_version}")
-    except importlib.metadata.PackageNotFoundError:
-        _faiss_available = False
+    _fastapi_available = False

-def is_triton_available():
-    return _triton_available

-def is_datamodel_code_generator_availabel():
-    return _datamodel_code_generator_availabel
+_uvicorn_available = importlib.util.find_spec("uvicorn") is not None
+try:
+    _uvicorn_version = importlib.metadata.version("uvicorn")
+    logger.debug(f"Successfully imported uvicorn version {_uvicorn_version}")
+except 
importlib.metadata.PackageNotFoundError: + _uvicorn_available = False -def is_faiss_available(): - return _faiss_available -def is_levenshtein_available(): - return _levenshtein_available +_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None +try: + _pretty_midi_version = importlib.metadata.version("pretty_midi") + logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}") +except importlib.metadata.PackageNotFoundError: + _pretty_midi_available = False -def is_nltk_available(): - return _nltk_available +ccl_version = "N/A" +_is_ccl_available = ( + importlib.util.find_spec("torch_ccl") is not None + or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None +) +try: + ccl_version = importlib.metadata.version("oneccl_bind_pt") + logger.debug(f"Detected oneccl_bind_pt version {ccl_version}") +except importlib.metadata.PackageNotFoundError: + _is_ccl_available = False -def is_einops_available(): - return _einops_available +_flax_available = False +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available, _flax_version = _is_package_available("flax", return_version=True) + if _flax_available: + _jax_available, _jax_version = _is_package_available("jax", return_version=True) + if _jax_available: + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + else: + _flax_available = _jax_available = False + _jax_version = _flax_version = "N/A" -def is_sudachi_available(): - """ - Checks if SudachiPy is available for use. - - Returns: - None: Indicates whether SudachiPy is available or not. - - """ - return _sudachipy_available +_torch_xla_available = False +if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES: + _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True) + if _torch_xla_available: + logger.info(f"Torch XLA version {_torch_xla_version} available.") -def get_sudachi_version(): - ''' - Returns the version of SudachiPy. - - Returns: - None: This function does not take any parameters. - - Raises: - None - ''' - return _sudachipy_version +def is_kenlm_available(): + return _kenlm_available -def is_bs4_available(): - return _bs4_available +def is_kernels_available(): + return _kernels_available -def is_sudachi_projection_available(): - """ - Checks if Sudachi projection is available. - - This function checks if Sudachi is available and if the Sudachi version is equal to or greater than 0.6.8. - - Returns: - None - - Raises: - None - """ - if not is_sudachi_available(): - return False - # NOTE: We require sudachipy>=0.6.8 to use projection option in sudachi_kwargs for the forwardor of BertJapaneseTokenizer. - # - `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230 - return version.parse(_sudachipy_version) >= version.parse("0.6.8") +def is_cv2_available(): + return _cv2_available -def is_sacremoses_available(): - """ - Checks if the sacremoses library is available in the current environment. - - Returns: - None: Indicates whether the sacremoses library is available or not. - - Raises: - None. - """ - return _sacremoses_available +def is_yt_dlp_available(): + return _yt_dlp_available -def is_mindspore_available(): - ''' - Checks if MindSpore is available. - - Args: - None - - Returns: - None: Indicates that the function does not return any value. - - Raises: - None: No exceptions are raised by this function. 
-    '''
-    return _mindspore_available
+def is_torch_available():
+    return _torch_available

-def get_mindspore_version():
-    """
-    Returns the current version of MindSpore.
-
-    Args:
-
-    Returns:
-        None: This function does not take any parameters.
-
-    Raises:
-        None: This function does not raise any exceptions.
-    """
-    return _mindspore_version
+def is_libcst_available():
+    return _libcst_available

-def is_ftfy_available():
-    return _ftfy_available
+def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
+    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)

-def is_datasets_available():
-    """
-    Checks if datasets are available.
-
-    Returns:
-        None: This function does not return any value.
-
-    Raises:
-        None: This function does not raise any exceptions.
-    """
-    return _datasets_available
+def is_torch_accelerator_available():
+    if is_torch_available():
+        import torch
+
+        return hasattr(torch, "accelerator")

-def is_sentencepiece_available():
+    return False
+
+
+def is_torch_deterministic():
     """
-    Checks if SentencePiece library is available.
-
-    Returns:
-        None: Indicates whether the SentencePiece library is available or not.
-
-    Raises:
-        None.
+    Check whether pytorch uses deterministic algorithms by checking whether torch.set_deterministic_debug_mode() is set to 1 or 2.
     """
-    return _sentencepiece_available
+    if is_torch_available():
+        import torch
+        if torch.get_deterministic_debug_mode() == 0:
+            return False
+        else:
+            return True

-def is_tokenizers_available():
-    """Check if tokenizers are available.
-
-    This function checks if tokenizers are available for use. It does not take any parameters.
-
-    Returns:
-        None: This function does not return any value.
-
-    Raises:
-        None: This function does not raise any exceptions.
-    """
-    return _tokenizers_available
+    return False

-def is_safetensors_available():
-    """
-    Checks if SafeTensors is available in the current environment.
-
-    Returns:
-        None: Indicates whether SafeTensors is available or not.
-
-    """
-    return _safetensors_available
+def is_triton_available(min_version: str = TRITON_MIN_VERSION):
+    return _triton_available and version.parse(_triton_version) >= version.parse(min_version)

-def is_modelscope_available():
-    '''
-    Checks if the model scope is available.
-
-    Returns:
-        None: Indicates whether the model scope is available or not.
-    '''
-    return _modelscope_available
+def is_triton_kernels_availalble():
+    return _triton_kernels_available

-def is_cython_available():
-    """
-    Checks if Cython is available in the current environment.
-
-    Returns:
-        None: Indicates whether Cython is available or not.
-
-    Raises:
-        None
-    """
-    return importlib.util.find_spec("pyximport") is not None
+def is_hadamard_available():
+    return _hadamard_available

-def is_protobuf_available():
-    """
-    Checks if the Google Protocol Buffers (protobuf) library is available.
-
-    Returns:
-        bool: True if the protobuf library is available, False otherwise.
-
-    Raises:
-        No specific exceptions are raised by this function.
- """ - if importlib.util.find_spec("google") is None: +def is_hqq_available(min_version: str = HQQ_MIN_VERSION): + return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version) + + +def is_pygments_available(): + return _pygments_available + + +def get_torch_version(): + return _torch_version + + +def get_torch_major_and_minor_version() -> str: + if _torch_version == "N/A": + return "N/A" + parsed_version = version.parse(_torch_version) + return str(parsed_version.major) + "." + str(parsed_version.minor) + + +def is_torch_sdpa_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": return False - return importlib.util.find_spec("google.protobuf") is not None + # NOTE: MLU is OK with non-contiguous inputs. + if is_torch_mlu_available(): + return True + # NOTE: NPU can use SDPA in Transformers with torch>=2.1.0. + if is_torch_npu_available(): + return True + # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577 + return version.parse(_torch_version) >= version.parse("2.1.1") -def is_pytest_available(): - """ - Check if the pytest library is available. - - Returns: - None: This function does not return any value. - - """ - return _pytest_available +def is_torch_flex_attn_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False -def is_pretty_midi_available(): - """ - Checks if the 'pretty_midi' library is available. - - Returns: - None - - Raises: - None - """ - return _pretty_midi_available + # TODO check if some bugs cause push backs on the exact version + # NOTE: We require torch>=2.5.0 as it is the first release + return version.parse(_torch_version) >= version.parse("2.5.0") + + +def is_torchvision_available(): + return _torchvision_available + + +def is_torchvision_v2_available(): + if not is_torchvision_available(): + return False + + # NOTE: We require torchvision>=0.15 as v2 transforms are available from this version: https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use + return version.parse(_torchvision_version) >= version.parse("0.15") + + +def is_galore_torch_available(): + return _galore_torch_available + + +def is_apollo_torch_available(): + return _apollo_torch_available + + +def is_torch_optimi_available(): + return _torch_optimi_available + + +def is_lomo_available(): + return _lomo_available + + +def is_grokadamw_available(): + return _grokadamw_available + + +def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION): + return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version) + + +def is_pyctcdecode_available(): + return _pyctcdecode_available def is_librosa_available(): - """ - Checks if the 'librosa' library is available. - - Returns: - None - - Raises: - None - """ return _librosa_available def is_essentia_available(): - """ - Checks if the 'essentia' library is available. - - Returns: - None. - - Raises: - None. - """ return _essentia_available -def is_pyctcdecode_available(): - """ - Check if the PyCTCDecode library is available. - - Returns: - None: This function does not return any value. - - Raises: - None - """ - return _pyctcdecode_available +def is_pydantic_available(): + return _pydantic_available -def is_scipy_available(): - """ - Checks if the SciPy library is available. - - Returns: - None: This function does not return any value. 
- - Raises: - None: This function does not raise any exceptions. - """ - return _scipy_available +def is_fastapi_available(): + return _fastapi_available -def is_jieba_available(): - ''' - Checks if the Jieba library is available. - - Returns: - None: The function does not return any value. - - ''' - return _jieba_available +def is_uvicorn_available(): + return _uvicorn_available -def is_pytesseract_available(): - """ - Check if pytesseract library is available. - - Returns: - None: This function does not return any value. - - Raises: - None: This function does not raise any exceptions. - """ - return _pytesseract_available +def is_openai_available(): + return _openai_available -def is_g2p_en_available(): - return _g2p_en_available +def is_pretty_midi_available(): + return _pretty_midi_available -def is_tiktoken_available(): - return _tiktoken_available +def is_torch_cuda_available(): + if is_torch_available(): + import torch + return torch.cuda.is_available() + else: + return False -def is_phonemizer_available(): - return _phonemizer_available +def is_cuda_platform(): + if is_torch_available(): + import torch -@lru_cache() -def is_vision_available(): - """ - Checks if the Pillow library is available for image processing. - - Returns: - bool: True if Pillow library is available, False otherwise. - - Raises: - PackageNotFoundError: If Pillow or Pillow-SIMD package is not found. - """ - _pil_available = importlib.util.find_spec("PIL") is not None - if _pil_available: - try: - package_version = importlib_metadata.version("Pillow") - except importlib_metadata.PackageNotFoundError: - try: - package_version = importlib_metadata.version("Pillow-SIMD") - except importlib_metadata.PackageNotFoundError: - return False - logger.debug(f"Detected PIL version {package_version}") - return _pil_available + return torch.version.cuda is not None + else: + return False -def is_in_notebook(): - """ - This function checks if the code is running in a Jupyter notebook environment by examining the current execution environment and relevant environment variables. - - Returns: - bool: Returns True if the code is running in a Jupyter notebook environment, otherwise False. - - Raises: - AttributeError: If an attribute error occurs during the execution of the function. - ImportError: If the code is running in the console, VS Code, or Databricks environment, respective ImportError with the environment name is raised. - KeyError: If a key error occurs during the execution of the function. 
- """ - try: - # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py - get_ipython = sys.modules["IPython"].get_ipython - if "IPKernelApp" not in get_ipython().config: - raise ImportError("console") - if "VSCODE_PID" in os.environ: - raise ImportError("vscode") - if ( - "DATABRICKS_RUNTIME_VERSION" in os.environ - and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0" - ): - # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook - # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel - raise ImportError("databricks") +def is_rocm_platform(): + if is_torch_available(): + import torch - return importlib.util.find_spec("IPython") is not None - except (AttributeError, ImportError, KeyError): + return torch.version.hip is not None + else: return False +def is_mamba_ssm_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + else: + return _is_package_available("mamba_ssm") + return False + + +def is_mamba_2_ssm_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + else: + if _is_package_available("mamba_ssm"): + import mamba_ssm + + if version.parse(mamba_ssm.__version__) >= version.parse("2.0.4"): + return True + return False + + +def is_causal_conv1d_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + return _is_package_available("causal_conv1d") + return False + + +def is_xlstm_available(): + if is_torch_available(): + return _is_package_available("xlstm") + return False + + +def is_mambapy_available(): + if is_torch_available(): + return _is_package_available("mambapy") + return False + + +def is_torch_mps_available(min_version: Optional[str] = None): + if is_torch_available(): + import torch + + if hasattr(torch.backends, "mps"): + backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built() + if min_version is not None: + flag = version.parse(_torch_version) >= version.parse(min_version) + backend_available = backend_available and flag + return backend_available + return False + + +def is_torch_bf16_gpu_available() -> bool: + if not is_torch_available(): + return False + + import torch + + if torch.cuda.is_available(): + return torch.cuda.is_bf16_supported() + if is_torch_xpu_available(): + return torch.xpu.is_bf16_supported() + if is_torch_hpu_available(): + return True + if is_torch_npu_available(): + return torch.npu.is_bf16_supported() + return False + + +def is_torch_bf16_cpu_available() -> bool: + return is_torch_available() + + +def is_torch_bf16_available(): + # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util + # has become ambiguous and therefore deprecated + warnings.warn( + "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available " + "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu", + FutureWarning, + ) + return is_torch_bf16_gpu_available() + + +@lru_cache +def is_torch_fp16_available_on_device(device): + if not is_torch_available(): + return False + + if is_torch_hpu_available(): + if is_habana_gaudi1(): + return False + else: + return True + + import torch + + try: + x = torch.zeros(2, 2, dtype=torch.float16, device=device) + _ = x @ x + + # At this moment, let's be strict of the check: check if `LayerNorm` is also supported on device, 
because many + # models use this layer. + batch, sentence_length, embedding_dim = 3, 4, 5 + embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device) + layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device) + _ = layer_norm(embedding) + + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. + return False + + return True + + +@lru_cache +def is_torch_bf16_available_on_device(device): + if not is_torch_available(): + return False + + import torch + + if device == "cuda": + return is_torch_bf16_gpu_available() + + if device == "hpu": + return True + + try: + x = torch.zeros(2, 2, dtype=torch.bfloat16, device=device) + _ = x @ x + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. + return False + + return True + + +def is_torch_tf32_available(): + if not is_torch_available(): + return False + + import torch + + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + return True + + +def is_torch_fx_available(): + return is_torch_available() + + +def is_peft_available(): + return _peft_available + + +def is_bs4_available(): + return _bs4_available + + +def is_tf_available(): + return _tf_available + + +def is_coloredlogs_available(): + return _coloredlogs_available + + +def is_tf2onnx_available(): + return _tf2onnx_available + + +def is_onnx_available(): + return _onnx_available + + +def is_flax_available(): + return _flax_available + + +def is_flute_available(): + try: + return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1" + except importlib.metadata.PackageNotFoundError: + return False + + +def is_ftfy_available(): + return _ftfy_available + + +def is_g2p_en_available(): + return _g2p_en_available + + +@lru_cache +def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): + """ + Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set + the USE_TORCH_XLA to false. + """ + assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." 
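+    # Illustrative call patterns (summarizing the branches below, not additional API):
+    #   is_torch_xla_available()                  -> True iff torch_xla imported at module load
+    #   is_torch_xla_available(check_is_tpu=True) -> True only when the XLA runtime reports "TPU"
+    #   is_torch_xla_available(check_is_gpu=True) -> True only for a "GPU"/"CUDA" XLA runtime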
+ + if not _torch_xla_available: + return False + + import torch_xla + + if check_is_gpu: + return torch_xla.runtime.device_type() in ["GPU", "CUDA"] + elif check_is_tpu: + return torch_xla.runtime.device_type() == "TPU" + + return True + + +@lru_cache +def is_torch_neuroncore_available(check_device=True): + if importlib.util.find_spec("torch_neuronx") is not None: + return is_torch_xla_available() + return False + + +@lru_cache +def is_torch_npu_available(check_device=False): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if not _torch_available or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + +@lru_cache +def is_torch_mlu_available(check_device=False): + """ + Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu + uninitialized. + """ + if not _torch_available or importlib.util.find_spec("torch_mlu") is None: + return False + + import torch + import torch_mlu # noqa: F401 + + pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK") + try: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1) + available = torch.mlu.is_available() + finally: + if pytorch_cndev_based_mlu_check_previous_value: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value + else: + os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None) + + return available + + +@lru_cache +def is_torch_musa_available(check_device=False): + "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" + if not _torch_available or importlib.util.find_spec("torch_musa") is None: + return False + + import torch + import torch_musa # noqa: F401 + + torch_musa_min_version = "0.33.0" + if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_musa_min_version): + return False + + if check_device: + try: + # Will raise a RuntimeError if no MUSA is found + _ = torch.musa.device_count() + return torch.musa.is_available() + except RuntimeError: + return False + return hasattr(torch, "musa") and torch.musa.is_available() + + +@lru_cache +def is_torch_hpu_available(): + "Checks if `torch.hpu` is available and potentially if a HPU is in the environment" + if ( + not _torch_available + or importlib.util.find_spec("habana_frameworks") is None + or importlib.util.find_spec("habana_frameworks.torch") is None + ): + return False + + torch_hpu_min_accelerate_version = "1.5.0" + if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_hpu_min_accelerate_version): + return False + + import torch + + if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1": + # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu + import habana_frameworks.torch # noqa: F401 + + if not hasattr(torch, "hpu") or not torch.hpu.is_available(): + return False + + # We patch torch.gather for int64 tensors to avoid a bug on Gaudi + # Graph compile failed with synStatus 26 [Generic failure] + # This can be removed once bug is fixed but for now we need it. 
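+    # In sketch form, the patch below reroutes int64 HPU gathers through int32:
+    #   original_gather(input.to(torch.int32), dim, index).to(torch.int64)
+    # so callers keep int64 results while the kernel only ever sees int32 inputs.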
+ original_gather = torch.gather + + def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor: + if input.dtype == torch.int64 and input.device.type == "hpu": + return original_gather(input.to(torch.int32), dim, index).to(torch.int64) + else: + return original_gather(input, dim, index) + + torch.gather = patched_gather + torch.Tensor.gather = patched_gather + + original_take_along_dim = torch.take_along_dim + + def patched_take_along_dim( + input: torch.Tensor, indices: torch.LongTensor, dim: Optional[int] = None + ) -> torch.Tensor: + if input.dtype == torch.int64 and input.device.type == "hpu": + return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64) + else: + return original_take_along_dim(input, indices, dim) + + torch.take_along_dim = patched_take_along_dim + + original_cholesky = torch.linalg.cholesky + + def safe_cholesky(A, *args, **kwargs): + output = original_cholesky(A, *args, **kwargs) + + if torch.isnan(output).any(): + jitter_value = 1e-9 + diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value + output = original_cholesky(A + diag_jitter, *args, **kwargs) + + return output + + torch.linalg.cholesky = safe_cholesky + + original_scatter = torch.scatter + + def patched_scatter( + input: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + if input.device.type == "hpu" and input is src: + return original_scatter(input, dim, index, src.clone(), *args, **kwargs) + else: + return original_scatter(input, dim, index, src, *args, **kwargs) + + torch.scatter = patched_scatter + torch.Tensor.scatter = patched_scatter + + # IlyasMoutawwakil: we patch torch.compile to use the HPU backend by default + # https://github.com/huggingface/transformers/pull/38790#discussion_r2157043944 + # This is necessary for cases where torch.compile is used as a decorator (defaulting to inductor) + # https://github.com/huggingface/transformers/blob/af6120b3eb2470b994c21421bb6eaa76576128b0/src/transformers/models/modernbert/modeling_modernbert.py#L204 + original_compile = torch.compile + + def hpu_backend_compile(*args, **kwargs): + if kwargs.get("backend") not in ["hpu_backend", "eager"]: + logger.warning( + f"Calling torch.compile with backend={kwargs.get('backend')} on a Gaudi device is not supported. " + "We will override the backend with 'hpu_backend' to avoid errors." 
+ ) + kwargs["backend"] = "hpu_backend" + + return original_compile(*args, **kwargs) + + torch.compile = hpu_backend_compile + + return True + + +@lru_cache +def is_habana_gaudi1(): + if not is_torch_hpu_available(): + return False + + import habana_frameworks.torch.utils.experimental as htexp # noqa: F401 + + # Check if the device is Gaudi1 (vs Gaudi2, Gaudi3) + return htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi + + +def is_torchdynamo_available(): + return is_torch_available() + + +def is_torch_compile_available(): + return is_torch_available() + + +def is_torchdynamo_compiling(): + if not is_torch_available(): + return False + + # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) + # hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3) + try: + import torch + + return torch.compiler.is_compiling() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_compiling() + except Exception: + return False + + +def is_torchdynamo_exporting(): + if not is_torch_available(): + return False + + try: + import torch + + return torch.compiler.is_exporting() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_exporting() + except Exception: + return False + + +def is_torch_tensorrt_fx_available(): + if importlib.util.find_spec("torch_tensorrt") is None: + return False + return importlib.util.find_spec("torch_tensorrt.fx") is not None + + +def is_datasets_available(): + return _datasets_available + + +def is_detectron2_available(): + return _detectron2_available + + +def is_rjieba_available(): + return _rjieba_available + + +def is_psutil_available(): + return _psutil_available + + +def is_py3nvml_available(): + return _py3nvml_available + + +def is_sacremoses_available(): + return _sacremoses_available + + +def is_apex_available(): + return _apex_available + + +def is_aqlm_available(): + return _aqlm_available + + +def is_vptq_available(min_version: str = VPTQ_MIN_VERSION): + return _vptq_available and version.parse(_vptq_version) >= version.parse(min_version) + + +def is_av_available(): + return _av_available + + +def is_decord_available(): + return _decord_available + + +def is_torchcodec_available(): + return _torchcodec_available + + +def is_ninja_available(): + r""" + Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the + [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise. + """ + try: + subprocess.check_output("ninja --version".split()) + except Exception: + return False + else: + return True + + +def is_ipex_available(min_version: str = ""): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + if not is_torch_available() or not _ipex_available: + return False + + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + logger.warning( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
+ ) + return False + if min_version: + return version.parse(_ipex_version) >= version.parse(min_version) + return True + + +@lru_cache +def is_torch_xpu_available(check_device=False): + """ + Checks if XPU acceleration is available either via native PyTorch (>=2.6), + `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and potentially + if a XPU is in the environment. + """ + if not is_torch_available(): + return False + + torch_version = version.parse(_torch_version) + if torch_version.major == 2 and torch_version.minor < 6: + if is_ipex_available(): + import intel_extension_for_pytorch # noqa: F401 + elif torch_version.major == 2 and torch_version.minor < 4: + return False + + import torch + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache +def is_bitsandbytes_available(check_library_only=False) -> bool: + if not _bitsandbytes_available: + return False + + if check_library_only: + return True + + if not is_torch_available(): + return False + + import torch + + # `bitsandbytes` versions older than 0.43.1 eagerly require CUDA at import time, + # so those versions of the library are practically only available when CUDA is too. + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.1"): + return torch.cuda.is_available() + + # Newer versions of `bitsandbytes` can be imported on systems without CUDA. + return True + + +def is_bitsandbytes_multi_backend_available() -> bool: + if not is_bitsandbytes_available(): + return False + + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + + +def is_flash_attn_2_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn"): + return False + + # Let's add an extra check to see if cuda is available + import torch + + if not (torch.cuda.is_available() or is_torch_mlu_available()): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4") + elif is_torch_mlu_available(): + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3") + else: + return False + + +@lru_cache +def is_flash_attn_3_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn_3"): + return False + + import torch + + if not torch.cuda.is_available(): + return False + + # TODO: Check for a minimum version when FA3 is stable + # return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0") + + return True + + +@lru_cache +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + + +@lru_cache +def is_flash_attn_greater_or_equal(library_version: str): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) + + +@lru_cache +def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False): + """ + Accepts a library 
version and returns True if the current version of the library is greater than or equal to the + given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches + 2.7.0). + """ + if not _is_package_available("torch"): + return False + + if accept_dev: + return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse( + library_version + ) + else: + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + + +@lru_cache +def is_torch_less_or_equal(library_version: str, accept_dev: bool = False): + """ + Accepts a library version and returns True if the current version of the library is less than or equal to the + given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches + 2.7.0). + """ + if not _is_package_available("torch"): + return False + + if accept_dev: + return version.parse(version.parse(importlib.metadata.version("torch")).base_version) <= version.parse( + library_version + ) + else: + return version.parse(importlib.metadata.version("torch")) <= version.parse(library_version) + + +@lru_cache +def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False): + if not _is_package_available("huggingface_hub"): + return False + + if accept_dev: + return version.parse( + version.parse(importlib.metadata.version("huggingface_hub")).base_version + ) >= version.parse(library_version) + else: + return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version) + + +def is_torchdistx_available(): + return _torchdistx_available + + +def is_faiss_available(): + return _faiss_available + + +def is_scipy_available(): + return _scipy_available + + +def is_sklearn_available(): + return _sklearn_available + + +def is_sentencepiece_available(): + return _sentencepiece_available + + +def is_seqio_available(): + return _is_seqio_available + + +def is_gguf_available(min_version: str = GGUF_MIN_VERSION): + return _is_gguf_available and version.parse(_gguf_version) >= version.parse(min_version) + + +def is_protobuf_available(): + if importlib.util.find_spec("google") is None: + return False + return importlib.util.find_spec("google.protobuf") is not None + + +def is_fsdp_available(min_version: str = FSDP_MIN_VERSION): + return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version) + + +def is_optimum_available(): + return _optimum_available + + +def is_auto_awq_available(): + return _auto_awq_available + + +def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION): + return _auto_round_available and version.parse(_auto_round_version) >= version.parse(min_version) + + +def is_optimum_quanto_available(): + # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto` + return _is_optimum_quanto_available + + +def is_quark_available(): + return _quark_available + + +def is_fp_quant_available(): + return _fp_quant_available and version.parse(_fp_quant_version) >= version.parse("0.1.6") + + +def is_qutlass_available(): + return _qutlass_available + + +def is_compressed_tensors_available(): + return _compressed_tensors_available + + +def is_auto_gptq_available(): + return _auto_gptq_available + + +def is_gptqmodel_available(): + return _gptqmodel_available + + +def is_eetq_available(): + return _eetq_available + + +def is_fbgemm_gpu_available(): + return _fbgemm_gpu_available + + +def 
is_levenshtein_available(): + return _levenshtein_available + + +def is_optimum_neuron_available(): + return _optimum_available and _is_package_available("optimum.neuron") + + +def is_safetensors_available(): + return _safetensors_available + + +def is_tokenizers_available(): + return _tokenizers_available + + +@lru_cache +def is_vision_available(): + _pil_available = importlib.util.find_spec("PIL") is not None + if _pil_available: + try: + package_version = importlib.metadata.version("Pillow") + except importlib.metadata.PackageNotFoundError: + try: + package_version = importlib.metadata.version("Pillow-SIMD") + except importlib.metadata.PackageNotFoundError: + return False + logger.debug(f"Detected PIL version {package_version}") + return _pil_available + + +def is_pytesseract_available(): + return _pytesseract_available + + +def is_pytest_available(): + return _pytest_available + + +def is_spacy_available(): + return _spacy_available + + +def is_tensorflow_text_available(): + return is_tf_available() and _tensorflow_text_available + + +def is_keras_nlp_available(): + return is_tensorflow_text_available() and _keras_nlp_available + + +def is_in_notebook(): + try: + # Check if we are running inside Marimo + if "marimo" in sys.modules: + return True + # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py + get_ipython = sys.modules["IPython"].get_ipython + if "IPKernelApp" not in get_ipython().config: + raise ImportError("console") + # Removed the lines to include VSCode + if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0": + # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook + # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel + raise ImportError("databricks") + + return importlib.util.find_spec("IPython") is not None + except (AttributeError, ImportError, KeyError): + return False + + +def is_pytorch_quantization_available(): + return _pytorch_quantization_available + + +def is_tensorflow_probability_available(): + return _tensorflow_probability_available + + +def is_pandas_available(): + return _pandas_available + + +def is_sagemaker_dp_enabled(): + # Get the sagemaker specific env variable. + sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}") + try: + # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". + sagemaker_params = json.loads(sagemaker_params) + if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False): + return False + except json.JSONDecodeError: + return False + # Lastly, check if the `smdistributed` module is present. + return _smdistributed_available + + +def is_sagemaker_mp_enabled(): + # Get the sagemaker specific mp parameters from smp_options variable. + smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}") + try: + # Parse it and check the field "partitions" is included, it is required for model parallel. + smp_options = json.loads(smp_options) + if "partitions" not in smp_options: + return False + except json.JSONDecodeError: + return False + + # Get the sagemaker specific framework parameters from mpi_options variable. + mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}") + try: + # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". 
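+        # A value such as SM_FRAMEWORK_PARAMS='{"sagemaker_mpi_enabled": true}'
+        # (hypothetical example value) is what makes the check below succeed.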
+ mpi_options = json.loads(mpi_options) + if not mpi_options.get("sagemaker_mpi_enabled", False): + return False + except json.JSONDecodeError: + return False + # Lastly, check if the `smdistributed` module is present. + return _smdistributed_available + + +def is_training_run_on_sagemaker(): + return "SAGEMAKER_JOB_NAME" in os.environ + + +def is_soundfile_available(): + return _soundfile_available + + +def is_timm_available(): + return _timm_available + + +def is_natten_available(): + return _natten_available + + +def is_nltk_available(): + return _nltk_available + + +def is_torchaudio_available(): + return _torchaudio_available + + +def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION): + return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version) + + +def is_speech_available(): + # For now this depends on torchaudio but the exact dependency might evolve in the future. + return _torchaudio_available + + +def is_spqr_available(): + return _spqr_available + + +def is_phonemizer_available(): + return _phonemizer_available + + +def is_uroman_available(): + return _uroman_available + + +def torch_only_method(fn): + def wrapper(*args, **kwargs): + if not _torch_available: + raise ImportError( + "You need to install pytorch to use this method or class, " + "or activate it with environment variables USE_TORCH=1 and USE_TF=0." + ) + else: + return fn(*args, **kwargs) + + return wrapper + + +def is_ccl_available(): + return _is_ccl_available + + +def is_sudachi_available(): + return _sudachipy_available + + +def get_sudachi_version(): + return _sudachipy_version + + +def is_sudachi_projection_available(): + if not is_sudachi_available(): + return False + + # NOTE: We require sudachipy>=0.6.8 to use projection option in sudachi_kwargs for the constructor of BertJapaneseTokenizer. + # - `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230 + return version.parse(_sudachipy_version) >= version.parse("0.6.8") + + +def is_jumanpp_available(): + return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None) + + +def is_cython_available(): + return importlib.util.find_spec("pyximport") is not None + + +def is_jieba_available(): + return _jieba_available + + +def is_jinja_available(): + return _jinja_available + + +def is_mlx_available(): + return _mlx_available + + +def is_num2words_available(): + return _num2words_available + + +def is_tiktoken_available(): + return _tiktoken_available and _blobfile_available + + +def is_liger_kernel_available(): + if not _liger_kernel_available: + return False + + return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") + + +def is_rich_available(): + return _rich_available + + +def is_matplotlib_available(): + return _matplotlib_available + + +def is_mistral_common_available(): + return _mistral_common_available + + +def check_torch_load_is_safe(): + if not is_torch_greater_or_equal("2.6"): + raise ValueError( + "Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users " + "to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply " + "when loading files with safetensors." + "\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434" + ) + + # docstyle-ignore -CYTHON_IMPORT_ERROR = """ -{0} requires the Cython library but it was not found in your environment. 
You can install it with pip: `pip install
-Cython`. Please note that you may need to restart your runtime after installation.
+AV_IMPORT_ERROR = """
+{0} requires the PyAv library but it was not found in your environment. You can install it with:
+```
+pip install av
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+YT_DLP_IMPORT_ERROR = """
+{0} requires the YT-DLP library but it was not found in your environment. You can install it with:
+```
+pip install yt-dlp
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+DECORD_IMPORT_ERROR = """
+{0} requires the Decord library but it was not found in your environment. You can install it with:
+```
+pip install decord
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+TORCHCODEC_IMPORT_ERROR = """
+{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
+```
+pip install torchcodec
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+CV2_IMPORT_ERROR = """
+{0} requires the OpenCV library but it was not found in your environment. You can install it with:
+```
+pip install opencv-python
+```
+Please note that you may need to restart your runtime after installation.
 """
+
 # docstyle-ignore
 DATASETS_IMPORT_ERROR = """
 {0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
@@ -524,6 +1688,7 @@ def is_in_notebook():
 that python file if that's the case. Please note that you may need to restart your runtime after installation.
 """
 
+
 # docstyle-ignore
 TOKENIZERS_IMPORT_ERROR = """
 {0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
@@ -537,113 +1702,418 @@ def is_in_notebook():
 Please note that you may need to restart your runtime after installation.
 """
 
+
 # docstyle-ignore
 SENTENCEPIECE_IMPORT_ERROR = """
-{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
+{0} requires the SentencePiece library but it was not found in your environment. Check out the instructions on the
 installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
 that match your environment. Please note that you may need to restart your runtime after installation.
 """
 
+
 # docstyle-ignore
 PROTOBUF_IMPORT_ERROR = """
-{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
+{0} requires the protobuf library but it was not found in your environment. Check out the instructions on the
 installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
 that match your environment. Please note that you may need to restart your runtime after installation.
 """
 
+
+# docstyle-ignore
+FAISS_IMPORT_ERROR = """
+{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
+installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
+that match your environment. Please note that you may need to restart your runtime after installation.
+"""
+
+
 # docstyle-ignore
-MINDSPORE_IMPORT_ERROR = """
-{0} requires the MindSpore library but it was not found in your environment. 
Checkout the instructions on the -installation page: https://www.mindspore.cn/install/ and follow the ones that match your environment. +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. """ -LIBROSA_IMPORT_ERROR = """ -{0} requires thes librosa library. But that was not found in your environment. You can install them with pip: -`pip install librosa` + +# docstyle-ignore +TORCHVISION_IMPORT_ERROR = """ +{0} requires the Torchvision library but it was not found in your environment. Check out the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. """ -ESSENTIA_IMPORT_ERROR = """ -{0} requires essentia library. But that was not found in your environment. You can install them with pip: -`pip install essentia==2.1b6.dev1034` +# docstyle-ignore +PYTORCH_IMPORT_ERROR_WITH_TF = """ +{0} requires the PyTorch library but it was not found in your environment. +However, we were able to find a TensorFlow installation. TensorFlow classes begin +with "TF", but are otherwise identically named to our PyTorch classes. This +means that the TF equivalent of the class you tried to import would be "TF{0}". +If you want to use TensorFlow, please use TF classes instead! + +If you really do want to use PyTorch please go to +https://pytorch.org/get-started/locally/ and follow the instructions that +match your environment. +""" + +# docstyle-ignore +TF_IMPORT_ERROR_WITH_PYTORCH = """ +{0} requires the TensorFlow library but it was not found in your environment. +However, we were able to find a PyTorch installation. PyTorch classes do not begin +with "TF", but are otherwise identically named to our TF classes. +If you want to use PyTorch, please use those classes instead! + +If you really do want to use TensorFlow, please follow the instructions on the +installation page https://www.tensorflow.org/install that match your environment. +""" + +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +SKLEARN_IMPORT_ERROR = """ +{0} requires the scikit-learn library but it was not found in your environment. You can install it with: +``` +pip install -U scikit-learn +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install -U scikit-learn +``` +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +DETECTRON2_IMPORT_ERROR = """ +{0} requires the detectron2 library but it was not found in your environment. 
Check out the instructions on the +installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Check out the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + +LEVENSHTEIN_IMPORT_ERROR = """ +{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip +install python-Levenshtein`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +G2P_EN_IMPORT_ERROR = """ +{0} requires the g2p-en library but it was not found in your environment. You can install it with pip: +`pip install g2p-en`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PYTORCH_QUANTIZATION_IMPORT_ERROR = """ +{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip: +`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TENSORFLOW_PROBABILITY_IMPORT_ERROR = """ +{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as +explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TENSORFLOW_TEXT_IMPORT_ERROR = """ +{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as +explained here: https://www.tensorflow.org/text/guide/tf_text_intro. +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TORCHAUDIO_IMPORT_ERROR = """ +{0} requires the torchaudio library but it was not found in your environment. Please install it and restart your +runtime. +""" + +# docstyle-ignore +PANDAS_IMPORT_ERROR = """ +{0} requires the pandas library but it was not found in your environment. You can install it with pip as +explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html. Please note that you may need to restart your runtime after installation. """ + +# docstyle-ignore +PHONEMIZER_IMPORT_ERROR = """ +{0} requires the phonemizer library but it was not found in your environment. You can install it with pip: +`pip install phonemizer`. Please note that you may need to restart your runtime after installation. +""" +# docstyle-ignore +UROMAN_IMPORT_ERROR = """ +{0} requires the uroman library but it was not found in your environment. You can install it with pip: +`pip install uroman`. Please note that you may need to restart your runtime after installation. 
+""" + + +# docstyle-ignore +SACREMOSES_IMPORT_ERROR = """ +{0} requires the sacremoses library but it was not found in your environment. You can install it with pip: +`pip install sacremoses`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install scipy`. Please note that you may need to restart your runtime after installation. """ -PRETTY_MIDI_IMPORT_ERROR = """ -{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip: -`pip install pretty_midi` +# docstyle-ignore +KERAS_NLP_IMPORT_ERROR = """ +{0} requires the keras_nlp library but it was not found in your environment. You can install it with pip. Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +SPEECH_IMPORT_ERROR = """ +{0} requires the torchaudio library but it was not found in your environment. You can install it with pip: +`pip install torchaudio`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TIMM_IMPORT_ERROR = """ +{0} requires the timm library but it was not found in your environment. You can install it with pip: +`pip install timm`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +NATTEN_IMPORT_ERROR = """ +{0} requires the natten library but it was not found in your environment. You can install it by referring to: +shi-labs.com/natten . You can also install it with pip (may take longer to build): +`pip install natten`. Please note that you may need to restart your runtime after installation. +""" + +NUMEXPR_IMPORT_ERROR = """ +{0} requires the numexpr library but it was not found in your environment. You can install it by referring to: +https://numexpr.readthedocs.io/en/latest/index.html. +""" + + +# docstyle-ignore +NLTK_IMPORT_ERROR = """ +{0} requires the NLTK library but it was not found in your environment. You can install it by referring to: +https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +VISION_IMPORT_ERROR = """ +{0} requires the PIL library but it was not found in your environment. You can install it with pip: +`pip install pillow`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PYDANTIC_IMPORT_ERROR = """ +{0} requires the pydantic library but it was not found in your environment. You can install it with pip: +`pip install pydantic`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FASTAPI_IMPORT_ERROR = """ +{0} requires the fastapi library but it was not found in your environment. You can install it with pip: +`pip install fastapi`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +UVICORN_IMPORT_ERROR = """ +{0} requires the uvicorn library but it was not found in your environment. You can install it with pip: +`pip install uvicorn`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +OPENAI_IMPORT_ERROR = """ +{0} requires the openai library but it was not found in your environment. You can install it with pip: +`pip install openai`. Please note that you may need to restart your runtime after installation. 
+""" + +# docstyle-ignore +PYTESSERACT_IMPORT_ERROR = """ +{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip: +`pip install pytesseract`. Please note that you may need to restart your runtime after installation. +""" + # docstyle-ignore PYCTCDECODE_IMPORT_ERROR = """ {0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip: `pip install pyctcdecode`. Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +ACCELERATE_IMPORT_ERROR = """ +{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment. +You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your +runtime after installation. +""" + +# docstyle-ignore +CCL_IMPORT_ERROR = """ +{0} requires the torch ccl library but it was not found in your environment. You can install it with pip: +`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +ESSENTIA_IMPORT_ERROR = """ +{0} requires essentia library. But that was not found in your environment. You can install them with pip: +`pip install essentia==2.1b6.dev1034` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +LIBROSA_IMPORT_ERROR = """ +{0} requires the librosa library. But that was not found in your environment. You can install them with pip: +`pip install librosa` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PRETTY_MIDI_IMPORT_ERROR = """ +{0} requires the pretty_midi library. But that was not found in your environment. You can install them with pip: +`pip install pretty_midi` +Please note that you may need to restart your runtime after installation. +""" + + +CYTHON_IMPORT_ERROR = """ +{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install +Cython`. Please note that you may need to restart your runtime after installation. +""" + JIEBA_IMPORT_ERROR = """ {0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install jieba`. Please note that you may need to restart your runtime after installation. """ -VISION_IMPORT_ERROR = """ -{0} requires the PIL library but it was not found in your environment. You can install it with pip: -`pip install pillow`. Please note that you may need to restart your runtime after installation. +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install +peft`. Please note that you may need to restart your runtime after installation. """ -# docstyle-ignore -G2P_EN_IMPORT_ERROR = """ -{0} requires the g2p-en library but it was not found in your environment. You can install it with pip: -`pip install g2p-en`. Please note that you may need to restart your runtime after installation. +JINJA_IMPORT_ERROR = """ +{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install +jinja2`. Please note that you may need to restart your runtime after installation. +""" + +RICH_IMPORT_ERROR = """ +{0} requires the rich library but it was not found in your environment. You can install it with pip: `pip install +rich`. 
Please note that you may need to restart your runtime after installation. +""" + +MISTRAL_COMMON_IMPORT_ERROR = """ +{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation. """ + BACKENDS_MAPPING = OrderedDict( [ - ("mindspore", (is_mindspore_available, MINDSPORE_IMPORT_ERROR)), - ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)), + ("av", (is_av_available, AV_IMPORT_ERROR)), + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), + ("decord", (is_decord_available, DECORD_IMPORT_ERROR)), + ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)), + ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)), + ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)), + ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), + ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)), + ("uroman", (is_uroman_available, UROMAN_IMPORT_ERROR)), + ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)), + ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), + ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), + ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), + ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)), + ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), + ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), + ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)), + ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)), + ("timm", (is_timm_available, TIMM_IMPORT_ERROR)), + ("torchaudio", (is_torchaudio_available, TORCHAUDIO_IMPORT_ERROR)), + ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)), + ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)), + ("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)), + ("vision", (is_vision_available, VISION_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)), - ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), + ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), + ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)), + ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)), ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)), - ("vision", (is_vision_available, VISION_IMPORT_ERROR)), - ("g2p_en", (is_g2p_en_available, 
G2P_EN_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)), + ("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)), + ("rich", (is_rich_available, RICH_IMPORT_ERROR)), + ("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)), + ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)), + ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)), + ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)), + ("openai", (is_openai_available, OPENAI_IMPORT_ERROR)), + ("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)), ] ) def requires_backends(obj, backends): - """ - Function to check if the specified backends are available for the given object. - - Args: - obj (object): The object for which backends availability needs to be checked. - backends (list or tuple or str): The backend(s) to be checked for availability. Can be a single backend as a string or a list/tuple of backends. - - Returns: - None. This function does not return any value. - - Raises: - ImportError: If any of the specified backends are not available for the object. - """ if not isinstance(backends, (list, tuple)): backends = [backends] name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] + # Raise an error for users who might not realize that classes without "TF" are torch-only + if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): + raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) + + # Raise the inverse error for PyTorch users trying to load TF classes + if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): + raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) + + failed = [] + for backend in backends: + if isinstance(backend, Backend): + available, msg = backend.is_satisfied, backend.error_message + else: + available, msg = BACKENDS_MAPPING[backend] + + if not available(): + failed.append(msg.format(name)) + if failed: raise ImportError("".join(failed)) @@ -653,84 +2123,751 @@ class DummyObject(type): Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by `requires_backend` each time a user tries to access any method of that class. """ + + is_dummy = True + def __getattribute__(cls, key): - """ - This method is called automatically when an attribute is accessed on the 'DummyObject' class or any of its subclasses. - - Args: - cls (type): The class object that the method was called on. - key (str): The name of the attribute being accessed. - - Returns: - None: This method does not return any value. - - Raises: - None: This method does not raise any exceptions. - """ - if key.startswith("_") and key != "_from_config": + if (key.startswith("_") and key != "_from_config") or key == "is_dummy" or key == "mro" or key == "call": return super().__getattribute__(key) requires_backends(cls, cls._backends) -def mindspore_required(func): +def is_torch_fx_proxy(x): + if is_torch_fx_available(): + import torch.fx + + return isinstance(x, torch.fx.Proxy) + return False + + +BACKENDS_T = frozenset[str] +IMPORT_STRUCTURE_T = dict[BACKENDS_T, dict[str, set[str]]] + + +class _LazyModule(ModuleType): """ - This function decorates another function to require the presence of MindSpore framework. 
- - Args: - func (function): The function to be decorated. - - Returns: - None. The function returns None. - - Raises: - FutureWarning: If the method `torch_required` is deprecated. - ImportError: If the decorated function requires MindSpore but MindSpore is not available. + Module class that surfaces all objects but only performs associated imports when the objects are requested. """ - warnings.warn( - "The method `torch_required` is deprecated. Use `requires_backends` instead.", - FutureWarning, - ) - # Chose a different decorator name than in tests so it's clear they are not the same. - @wraps(func) - def wrapper(*args, **kwargs): - if is_mindspore_available(): - return func(*args, **kwargs) - raise ImportError(f"Method `{func.__name__}` requires MindSpore.") + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__( + self, + name: str, + module_file: str, + import_structure: IMPORT_STRUCTURE_T, + module_spec: Optional[importlib.machinery.ModuleSpec] = None, + extra_objects: Optional[dict[str, object]] = None, + explicit_import_shortcut: Optional[dict[str, list[str]]] = None, + ): + super().__init__(name) + + self._object_missing_backend = {} + self._explicit_import_shortcut = explicit_import_shortcut if explicit_import_shortcut else {} + + if any(isinstance(key, frozenset) for key in import_structure): + self._modules = set() + self._class_to_module = {} + self.__all__ = [] + + _import_structure = {} + + for backends, module in import_structure.items(): + missing_backends = [] + + # This ensures that if a module is importable, then all other keys of the module are importable. + # As an example, in module.keys() we might have the following: + # + # dict_keys(['models.nllb_moe.configuration_nllb_moe', 'models.sew_d.configuration_sew_d']) + # + # with this, we don't only want to be able to import these explicitly, we want to be able to import + # every intermediate module as well. Therefore, this is what is returned: + # + # { + # 'models.nllb_moe.configuration_nllb_moe', + # 'models.sew_d.configuration_sew_d', + # 'models', + # 'models.sew_d', 'models.nllb_moe' + # } + + module_keys = set( + chain(*[[k.rsplit(".", i)[0] for i in range(k.count(".") + 1)] for k in list(module.keys())]) + ) + + for backend in backends: + if backend in BACKENDS_MAPPING: + callable, _ = BACKENDS_MAPPING[backend] + else: + if any(key in backend for key in ["=", "<", ">"]): + backend = Backend(backend) + callable = backend.is_satisfied + else: + raise ValueError( + f"Backend should be defined in the BACKENDS_MAPPING. 
Offending backend: {backend}" + ) + + try: + if not callable(): + missing_backends.append(backend) + except (importlib.metadata.PackageNotFoundError, ModuleNotFoundError, RuntimeError): + missing_backends.append(backend) + + self._modules = self._modules.union(module_keys) + + for key, values in module.items(): + if missing_backends: + self._object_missing_backend[key] = missing_backends + + for value in values: + self._class_to_module[value] = key + if missing_backends: + self._object_missing_backend[value] = missing_backends + _import_structure.setdefault(key, []).extend(values) + + # Needed for autocompletion in an IDE + self.__all__.extend(module_keys | set(chain(*module.values()))) + + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = _import_structure + + # This can be removed once every exportable object has a `require()` require. + else: + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._object_missing_backend: + missing_backends = self._object_missing_backend[name] + + class Placeholder(metaclass=DummyObject): + _backends = missing_backends + + def __init__(self, *args, **kwargs): + requires_backends(self, missing_backends) + + def call(self, *args, **kwargs): + pass + + Placeholder.__name__ = name + + if name not in self._class_to_module: + module_name = f"transformers.{name}" + else: + module_name = self._class_to_module[name] + if not module_name.startswith("transformers."): + module_name = f"transformers.{module_name}" + + Placeholder.__module__ = module_name + + value = Placeholder + elif name in self._class_to_module: + try: + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + except (ModuleNotFoundError, RuntimeError) as e: + raise ModuleNotFoundError( + f"Could not import module '{name}'. Are this object's requirements defined correctly?" + ) from e + + elif name in self._modules: + try: + value = self._get_module(name) + except (ModuleNotFoundError, RuntimeError) as e: + raise ModuleNotFoundError( + f"Could not import module '{name}'. Are this object's requirements defined correctly?" 
+                ) from e
+        else:
+            value = None
+            for key, values in self._explicit_import_shortcut.items():
+                if name in values:
+                    value = self._get_module(key)
+
+            if value is None:
+                raise AttributeError(f"module {self.__name__} has no attribute {name}")
+
+        setattr(self, name, value)
+        return value
+
+    def _get_module(self, module_name: str):
+        try:
+            return importlib.import_module("." + module_name, self.__name__)
+        except Exception as e:
+            raise e
 
-    return wrapper
+    def __reduce__(self):
+        return (self.__class__, (self._name, self.__file__, self._import_structure))
 
 
 class OptionalDependencyNotAvailable(BaseException):
     """Internally used error class for signalling an optional dependency was not found."""
+
+
 def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
     """Imports transformers directly
 
     Args:
         path (`str`): The path to the source file
-        file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
+        file (`str`, *optional*): The file to join with the path. Defaults to "__init__.py".
 
     Returns:
         `ModuleType`: The resulting imported module
     """
-    name = "mindnlp.transformers"
+    name = "transformers"
     location = os.path.join(path, file)
-    spec = importlib.util.spec_from_file_location(
-        name, location, submodule_search_locations=[path]
-    )
+    spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
     module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(module)
     module = sys.modules[name]
     return module
 
 
-def is_soundfile_availble():
-    return _soundfile_available
+class VersionComparison(Enum):
+    EQUAL = operator.eq
+    NOT_EQUAL = operator.ne
+    GREATER_THAN = operator.gt
+    LESS_THAN = operator.lt
+    GREATER_THAN_OR_EQUAL = operator.ge
+    LESS_THAN_OR_EQUAL = operator.le
+
+    @staticmethod
+    def from_string(version_string: str) -> "VersionComparison":
+        string_to_operator = {
+            "=": VersionComparison.EQUAL.value,
+            "==": VersionComparison.EQUAL.value,
+            "!=": VersionComparison.NOT_EQUAL.value,
+            ">": VersionComparison.GREATER_THAN.value,
+            "<": VersionComparison.LESS_THAN.value,
+            ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
+            "<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
+        }
+
+        return string_to_operator[version_string]
+
+
+@lru_cache
+def split_package_version(package_version_str) -> tuple[str, str, str]:
+    pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
+    match = re.match(pattern, package_version_str)
+    if match:
+        return (match.group(1), match.group(2), match.group(3))
+    else:
+        raise ValueError(f"Invalid package version string: {package_version_str}")
+
+
+class Backend:
+    def __init__(self, backend_requirement: str):
+        self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)
+
+        if self.package_name not in BACKENDS_MAPPING:
+            raise ValueError(
+                f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
+            )
+
+    def is_satisfied(self) -> bool:
+        return VersionComparison.from_string(self.version_comparison)(
+            version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version)
+        )
+
+    def __repr__(self) -> str:
+        # Use the raw comparison string: indexing the enum as VersionComparison[">="] would raise a KeyError,
+        # since enum lookup by name only accepts member names such as GREATER_THAN_OR_EQUAL.
+        return f'Backend("{self.package_name}", "{self.version_comparison}", "{self.version}")'
+
+    @property
+    def error_message(self):
+        return (
+            f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
+            f" library was not found with this version in your environment."
+        )
+
+
+def requires(*, backends=()):
+    """
+    This decorator enables two things:
+    - Attaching a `__backends` tuple to an object to see what the necessary backends are for it
+      to execute correctly, without instantiating it
+    - The '@requires' string is used to dynamically import objects
+    """
+    if not isinstance(backends, tuple):
+        raise TypeError("Backends should be a tuple.")
 
-def is_speech_available():
-    return _torchaudio_available
+    applied_backends = []
+    for backend in backends:
+        if backend in BACKENDS_MAPPING:
+            applied_backends.append(backend)
+        else:
+            if any(key in backend for key in ["=", "<", ">"]):
+                applied_backends.append(Backend(backend))
+            else:
+                raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}")
 
+    def inner_fn(fun):
+        fun.__backends = applied_backends
+        return fun
 
-def is_kenlm_available():
-    return _kenlm_available
\ No newline at end of file
+    return inner_fn
+
+
+BASE_FILE_REQUIREMENTS = {
+    lambda e: "modeling_tf_" in e: ("tf",),
+    lambda e: "modeling_flax_" in e: ("flax",),
+    lambda e: "modeling_" in e: ("torch",),
+    lambda e: e.startswith("tokenization_") and e.endswith("_fast"): ("tokenizers",),
+    lambda e: e.startswith("image_processing_") and e.endswith("_fast"): ("vision", "torch", "torchvision"),
+    lambda e: e.startswith("image_processing_"): ("vision",),
+}
+
+
+def fetch__all__(file_content):
+    """
+    Returns the content of the __all__ variable in the file content.
+    Returns an empty list if __all__ is not defined, otherwise returns a list of strings.
+    """
+
+    if "__all__" not in file_content:
+        return []
+
+    start_index = None
+    lines = file_content.splitlines()
+    for index, line in enumerate(lines):
+        if line.startswith("__all__"):
+            start_index = index
+
+    # There is no line starting with `__all__`
+    if start_index is None:
+        return []
+
+    lines = lines[start_index:]
+
+    if not lines[0].startswith("__all__"):
+        raise ValueError(
+            "fetch__all__ accepts a list of lines, with the first line being the __all__ variable declaration"
+        )
+
+    # __all__ is defined on a single line
+    if lines[0].endswith("]"):
+        return [obj.strip("\"' ") for obj in lines[0].split("=")[1].strip(" []").split(",")]
+
+    # __all__ is defined on multiple lines
+    else:
+        _all = []
+        for __all__line_index in range(1, len(lines)):
+            if lines[__all__line_index].strip() == "]":
+                return _all
+            else:
+                _all.append(lines[__all__line_index].strip("\"', "))
+
+        return _all
+
+
+@lru_cache
+def create_import_structure_from_path(module_path):
+    """
+    This method takes the path to a file/a folder and returns the import structure.
+    If a file is given, it will return the import structure of the parent folder.
+
+    Import structures are designed to be digestible by `_LazyModule` objects. They are
+    created from the __all__ definitions in each file as well as the `@requires` decorators
+    above methods and objects.
+
+    The import structure allows explicit display of the required backends for a given object.
+    These backends are specified in two ways:
+
+    1. Through their `@requires`, if they are exported with that decorator. This `@requires` decorator
+    accepts a `backends` tuple kwarg mentioning which backends are required to run this object.
+
+    2. If an object is defined in a file with "default" backends, it will have, at a minimum, this
+    backend specified.
The default backends are defined according to the filename: + + - If a file is named like `modeling_*.py`, it will have a `torch` backend + - If a file is named like `modeling_tf_*.py`, it will have a `tf` backend + - If a file is named like `modeling_flax_*.py`, it will have a `flax` backend + - If a file is named like `tokenization_*_fast.py`, it will have a `tokenizers` backend + - If a file is named like `image_processing*_fast.py`, it will have a `torchvision` + `torch` backend + + Backends serve the purpose of displaying a clear error message to the user in case the backends are not installed. + Should an object be imported without its required backends being in the environment, any attempt to use the + object will raise an error mentioning which backend(s) should be added to the environment in order to use + that object. + + Here's an example of an input import structure at the src.transformers.models level: + + { + 'albert': { + frozenset(): { + 'configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'} + }, + frozenset({'tokenizers'}): { + 'tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + }, + 'align': { + frozenset(): { + 'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'processing_align': {'AlignProcessor'} + }, + }, + 'altclip': { + frozenset(): { + 'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'processing_altclip': {'AltCLIPProcessor'}, + } + } + } + """ + import_structure = {} + + if os.path.isfile(module_path): + module_path = os.path.dirname(module_path) + + directory = module_path + adjacent_modules = [] + + for f in os.listdir(module_path): + if f != "__pycache__" and os.path.isdir(os.path.join(module_path, f)): + import_structure[f] = create_import_structure_from_path(os.path.join(module_path, f)) + + elif not os.path.isdir(os.path.join(directory, f)): + adjacent_modules.append(f) + + # We're only taking a look at files different from __init__.py + # We could theoretically require things directly from the __init__.py + # files, but this is not supported at this time. + if "__init__.py" in adjacent_modules: + adjacent_modules.remove("__init__.py") + + # Modular files should not be imported + def find_substring(substring, list_): + return any(substring in x for x in list_) + + if find_substring("modular_", adjacent_modules) and find_substring("modeling_", adjacent_modules): + adjacent_modules = [module for module in adjacent_modules if "modular_" not in module] + + module_requirements = {} + for module_name in adjacent_modules: + # Only modules ending in `.py` are accepted here. + if not module_name.endswith(".py"): + continue + + with open(os.path.join(directory, module_name), encoding="utf-8") as f: + file_content = f.read() + + # Remove the .py suffix + module_name = module_name[:-3] + + previous_line = "" + previous_index = 0 + + # Some files have some requirements by default. + # For example, any file named `modeling_tf_xxx.py` + # should have TensorFlow as a required backend. + base_requirements = () + for string_check, requirements in BASE_FILE_REQUIREMENTS.items(): + if string_check(module_name): + base_requirements = requirements + break + + # Objects that have a `@require` assigned to them will get exported + # with the backends specified in the decorator as well as the file backends. + exported_objects = set() + if "@requires" in file_content: + lines = file_content.split("\n") + for index, line in enumerate(lines): + # This allows exporting items with other decorators. 
We'll take a look + # at the line that follows at the same indentation level. + if line.startswith((" ", "\t", "@", ")")) and not line.startswith("@requires"): + continue + + # Skipping line enables putting whatever we want between the + # export() call and the actual class/method definition. + # This is what enables having # Copied from statements, docs, etc. + skip_line = False + + if "@requires" in previous_line: + skip_line = False + + # Backends are defined on the same line as export + if "backends" in previous_line: + backends_string = previous_line.split("backends=")[1].split("(")[1].split(")")[0] + backends = tuple(sorted([b.strip("'\",") for b in backends_string.split(", ") if b])) + + # Backends are defined in the lines following export, for example such as: + # @export( + # backends=( + # "sentencepiece", + # "torch", + # "tf", + # ) + # ) + # + # or + # + # @export( + # backends=( + # "sentencepiece", "tf" + # ) + # ) + elif "backends" in lines[previous_index + 1]: + backends = [] + for backend_line in lines[previous_index:index]: + if "backends" in backend_line: + backend_line = backend_line.split("=")[1] + if '"' in backend_line or "'" in backend_line: + if ", " in backend_line: + backends.extend(backend.strip("()\"', ") for backend in backend_line.split(", ")) + else: + backends.append(backend_line.strip("()\"', ")) + + # If the line is only a ')', then we reached the end of the backends and we break. + if backend_line.strip() == ")": + break + backends = tuple(backends) + + # No backends are registered for export + else: + backends = () + + backends = frozenset(backends + base_requirements) + if backends not in module_requirements: + module_requirements[backends] = {} + if module_name not in module_requirements[backends]: + module_requirements[backends][module_name] = set() + + if not line.startswith("class") and not line.startswith("def"): + skip_line = True + else: + start_index = 6 if line.startswith("class") else 4 + object_name = line[start_index:].split("(")[0].strip(":") + module_requirements[backends][module_name].add(object_name) + exported_objects.add(object_name) + + if not skip_line: + previous_line = line + previous_index = index + + # All objects that are in __all__ should be exported by default. + # These objects are exported with the file backends. + if "__all__" in file_content: + for _all_object in fetch__all__(file_content): + if _all_object not in exported_objects: + backends = frozenset(base_requirements) + if backends not in module_requirements: + module_requirements[backends] = {} + if module_name not in module_requirements[backends]: + module_requirements[backends][module_name] = set() + + module_requirements[backends][module_name].add(_all_object) + + import_structure = {**module_requirements, **import_structure} + return import_structure + + +def spread_import_structure(nested_import_structure): + """ + This method takes as input an unordered import structure and brings the required backends at the top-level, + aggregating modules and objects under their required backends. 
+
+    Here's an example of an input import structure at the src.transformers.models level:
+
+    {
+        'albert': {
+            frozenset(): {
+                'configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'}
+            },
+            frozenset({'tokenizers'}): {
+                'tokenization_albert_fast': {'AlbertTokenizerFast'}
+            },
+        },
+        'align': {
+            frozenset(): {
+                'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
+                'processing_align': {'AlignProcessor'}
+            },
+        },
+        'altclip': {
+            frozenset(): {
+                'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
+                'processing_altclip': {'AltCLIPProcessor'},
+            }
+        }
+    }
+
+    Here's an example of an output import structure at the src.transformers.models level:
+
+    {
+        frozenset({'tokenizers'}): {
+            'albert.tokenization_albert_fast': {'AlbertTokenizerFast'}
+        },
+        frozenset(): {
+            'albert.configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'},
+            'align.processing_align': {'AlignProcessor'},
+            'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'},
+            'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'},
+            'altclip.processing_altclip': {'AltCLIPProcessor'}
+        }
+    }
+
+    """
+
+    def propagate_frozenset(unordered_import_structure):
+        frozenset_first_import_structure = {}
+        for _key, _value in unordered_import_structure.items():
+            # If the value is not a dict but a string, no need for custom manipulation
+            if not isinstance(_value, dict):
+                frozenset_first_import_structure[_key] = _value
+
+            elif any(isinstance(v, frozenset) for v in _value):
+                for k, v in _value.items():
+                    if isinstance(k, frozenset):
+                        # Here we want to switch around _key and k to propagate k upstream if it is a frozenset
+                        if k not in frozenset_first_import_structure:
+                            frozenset_first_import_structure[k] = {}
+                        if _key not in frozenset_first_import_structure[k]:
+                            frozenset_first_import_structure[k][_key] = {}
+
+                        frozenset_first_import_structure[k][_key].update(v)
+
+                    else:
+                        # If k is not a frozenset, it means that the dictionary is not "level": some keys (top-level)
+                        # are frozensets, whereas some are not -> frozenset keys are at an unknown depth level of the
+                        # dictionary.
+                        #
+                        # We recursively propagate the frozenset for this specific dictionary so that the frozensets
+                        # are at the top-level when we handle them.
+ propagated_frozenset = propagate_frozenset({k: v}) + for r_k, r_v in propagated_frozenset.items(): + if isinstance(_key, frozenset): + if r_k not in frozenset_first_import_structure: + frozenset_first_import_structure[r_k] = {} + if _key not in frozenset_first_import_structure[r_k]: + frozenset_first_import_structure[r_k][_key] = {} + + # _key is a frozenset -> we switch around the r_k and _key + frozenset_first_import_structure[r_k][_key].update(r_v) + else: + if _key not in frozenset_first_import_structure: + frozenset_first_import_structure[_key] = {} + if r_k not in frozenset_first_import_structure[_key]: + frozenset_first_import_structure[_key][r_k] = {} + + # _key is not a frozenset -> we keep the order of r_k and _key + frozenset_first_import_structure[_key][r_k].update(r_v) + + else: + frozenset_first_import_structure[_key] = propagate_frozenset(_value) + + return frozenset_first_import_structure + + def flatten_dict(_dict, previous_key=None): + items = [] + for _key, _value in _dict.items(): + _key = f"{previous_key}.{_key}" if previous_key is not None else _key + if isinstance(_value, dict): + items.extend(flatten_dict(_value, _key).items()) + else: + items.append((_key, _value)) + return dict(items) + + # The tuples contain the necessary backends. We want these first, so we propagate them up the + # import structure. + ordered_import_structure = nested_import_structure + + # 6 is a number that gives us sufficient depth to go through all files and foreseeable folder depths + # while not taking too long to parse. + for i in range(6): + ordered_import_structure = propagate_frozenset(ordered_import_structure) + + # We then flatten the dict so that it references a module path. + flattened_import_structure = {} + for key, value in ordered_import_structure.copy().items(): + if isinstance(key, str): + del ordered_import_structure[key] + else: + flattened_import_structure[key] = flatten_dict(value) + + return flattened_import_structure + + +@lru_cache +def define_import_structure(module_path: str, prefix: Optional[str] = None) -> IMPORT_STRUCTURE_T: + """ + This method takes a module_path as input and creates an import structure digestible by a _LazyModule. + + Here's an example of an output import structure at the src.transformers.models level: + + { + frozenset({'tokenizers'}): { + 'albert.tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + frozenset(): { + 'albert.configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'}, + 'align.processing_align': {'AlignProcessor'}, + 'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'altclip.processing_altclip': {'AltCLIPProcessor'} + } + } + + The import structure is a dict defined with frozensets as keys, and dicts of strings to sets of objects. + + If `prefix` is not None, it will add that prefix to all keys in the returned dict. + """ + import_structure = create_import_structure_from_path(module_path) + spread_dict = spread_import_structure(import_structure) + + if prefix is None: + return spread_dict + else: + spread_dict = {k: {f"{prefix}.{kk}": vv for kk, vv in v.items()} for k, v in spread_dict.items()} + return spread_dict + + +def clear_import_cache(): + """ + Clear cached Transformers modules to allow reloading modified code. + + This is useful when actively developing/modifying Transformers code. 
+ """ + # Get all transformers modules + transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")] + + # Remove them from sys.modules + for mod_name in transformers_modules: + module = sys.modules[mod_name] + # Clear _LazyModule caches if applicable + if isinstance(module, _LazyModule): + module._objects = {} # Clear cached objects + del sys.modules[mod_name] + + # Force reload main transformers module + if "transformers" in sys.modules: + main_module = sys.modules["transformers"] + if isinstance(main_module, _LazyModule): + main_module._objects = {} # Clear cached objects + importlib.reload(main_module) diff --git a/mindnlp/utils/safetensors_patch.py b/mindnlp/utils/safetensors_patch.py index 22310e26b..068c15db7 100644 --- a/mindnlp/utils/safetensors_patch.py +++ b/mindnlp/utils/safetensors_patch.py @@ -3,8 +3,8 @@ from typing import OrderedDict import numpy as np import mindspore -from mindspore import Tensor +from mindnlp import core from mindnlp.core.configs import SUPPORT_BF16 import safetensors from safetensors import SafetensorError @@ -92,8 +92,8 @@ def get(self, *args, **kwargs): array = np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape) array = array.reshape(self.shape) if not SUPPORT_BF16 and self.info["dtype"] == 'BF16': - array = array.astype(np.float16) - tensor = Tensor.from_numpy(array) + array = array.view(np.float16) + tensor = core.from_numpy(array) tensor._ptr = array.ctypes.data return tensor diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py index 1882dc0b2..0cf1ae32d 100644 --- a/mindnlp/utils/torch_proxy.py +++ b/mindnlp/utils/torch_proxy.py @@ -101,6 +101,8 @@ def __setattr__(_, name, value): REDIRECT_MAP = { "torch": "mindnlp.core", } +if DEVICE_TARGET == 'Ascend': + REDIRECT_MAP["torch_npu"] = 'mindnlp.core.npu' def initialize_torch_proxy(): sys.meta_path.insert(0, RedirectFinder(REDIRECT_MAP)) @@ -108,6 +110,7 @@ def initialize_torch_proxy(): torch.__version__ = TORCH_VERSION + def setup_metadata_patch(): """解决 importlib.metadata 找不到 torch 的问题""" # 保存原始函数