diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index 695f0b34c..60bf86e17 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -39,8 +39,8 @@ # for different ascend devices if platform.system().lower() == 'linux': SOC = MSContext.get_instance().get_ascend_soc_version() - if ('910b' not in SOC and '310' not in SOC) or version.parse(mindspore.__version__) < version.parse('2.4.0'): - os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB' + # enable vmm since only vmm can release device memory when del tensor. + os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB' if SOC in ('ascend910', 'ascend310b'): # context.set_context(ascend_config={"precision_mode": "allow_mix_precision"}) diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py index d60f3cdef..7d9c27ba3 100644 --- a/mindnlp/core/_dtype.py +++ b/mindnlp/core/_dtype.py @@ -12,6 +12,8 @@ from ml_dtypes import bfloat16 as np_bfloat16 bool_alias = bool +float_alias = float +int_alias = int if ON_A1: warnings.warn('MindSpore on GPU/910A do not support bfloat16, use float16 instead.') @@ -116,5 +118,7 @@ def __gt__(self, other): dtype2np[bfloat16] = np_bfloat16 py2dtype = { - bool_alias: bool + bool_alias: bool, + float_alias: float, + int_alias: int64 } diff --git a/mindnlp/core/_prims/ascend.py b/mindnlp/core/_prims/ascend.py index dbf0d34e1..0acef1149 100644 --- a/mindnlp/core/_prims/ascend.py +++ b/mindnlp/core/_prims/ascend.py @@ -1,12 +1,12 @@ import numbers import mindspore from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim from mindspore.ops.auto_generate import gen_ops_prim from mindspore.ops.auto_generate import pyboost_inner_prim from mindspore._c_expression import _empty_instance from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2 -from mindspore.ops.operations.nn_ops import AllFinite - +from mindspore.ops.auto_generate.gen_ops_prim import MaxPoolWithIndices, MaxPoolWithMask from mindnlp import core from mindnlp.core._C import default_generator @@ -105,7 +105,12 @@ def tile(*args): __all__.append('tile') def pad_v3(input_x, padding, mode='constant', value=None): - pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('CPU') + pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('Ascend') + if input_x.dtype == core.bool: + input_x = input_x.to(core.int32) + out = pad_op(input_x, padding, value) + return cast(out, core.bool) + if isinstance(value, (float, int)): value = core.tensor(value, dtype=input_x.dtype) return pad_op(input_x, padding, value) @@ -248,3 +253,39 @@ def triu(input, diagonal): return pyboost_inner_prim.triu_impl(input, diagonal) __all__.append('triu') + +masked_scatter_op = ops.MaskedScatter().set_device('Ascend') +def masked_scatter(input, mask, source): + return masked_scatter_op(input, mask, source) + +__all__.append('masked_scatter') + +def roll(*args): + return pyboost_inner_prim.roll_impl(*args) + +__all__.append('roll') + +lgamma_op = ops.Lgamma().set_device('Ascend') +def lgamma(input): + return lgamma_op(input) + +__all__.append('lgamma') + +def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices): + strides = stride if (stride is not None) else kernel_size + if return_indices: + max_pool_func_ = _get_cache_prim(MaxPoolWithIndices)(kernel_size, strides, padding, dilation, ceil_mode) + out, indices = max_pool_func_(input) + else: + max_pool_func_ = _get_cache_prim(MaxPoolWithMask)(kernel_size, strides, padding, dilation, ceil_mode) 
+ out, indices = max_pool_func_(input) + if return_indices: + return out, indices + return out + +__all__.append('max_pool2d') + +def unique_consecutive(*args): + return pyboost_inner_prim.unique_consecutive_impl(*args) + +__all__.append('unique_consecutive') diff --git a/mindnlp/core/_prims/meta.py b/mindnlp/core/_prims/meta.py index 08bf032a9..33b9b8446 100644 --- a/mindnlp/core/_prims/meta.py +++ b/mindnlp/core/_prims/meta.py @@ -290,4 +290,43 @@ def linalg_vector_norm(input, p, dim, keepdim, dtype): out = Tensor_(shape=tuple(new_shape), dtype=dtype) return core.Tensor(out) -__all__.append('linalg_vector_norm') \ No newline at end of file +__all__.append('linalg_vector_norm') + +def erfinv(input): + return input +__all__.append('erfinv') + + +def stop_gradient(input): + out = Tensor_(shape=input.shape, dtype=input.dtype) + return core.Tensor(out) + +__all__.append('stop_gradient') + +def log(input): + return input +__all__.append('log') + +def mul(input, other): + out = Tensor_(shape=input.shape, dtype=input.dtype) + return core.Tensor(out) +__all__.append('mul') + +def randn(size, seed, offset, dtype): + out = Tensor_(shape=size, dtype=dtype) + return core.Tensor(out) + +__all__.append('randn') + +def zeros_like_ext(input, *args, **kwargs): + out = Tensor_(shape=input.shape, dtype=input.dtype) + return core.Tensor(out) +__all__.append('zeros_like_ext') + +def inplace_add_ext(input, other, alpha): + return input +__all__.append('inplace_add_ext') + +def clamp_scalar(input, *args): + return input +__all__.append('clamp_scalar') diff --git a/mindnlp/core/_prims/numpy.py b/mindnlp/core/_prims/numpy.py index 0599c42d6..a27096677 100644 --- a/mindnlp/core/_prims/numpy.py +++ b/mindnlp/core/_prims/numpy.py @@ -1,6 +1,8 @@ import numbers import numpy as np +import scipy from mindspore import ops +from mindspore.ops._primitive_cache import _get_cache_prim from mindnlp import core __all__ = [] @@ -28,6 +30,8 @@ def arange(start, end, step, dtype): def div(input, other): if not isinstance(input, numbers.Number): input = input.numpy() + if input.dtype == np.int64: + input = input.astype(np.int32) elif not isinstance(other, numbers.Number): other = other.numpy() out = np.divide(input, other) @@ -98,7 +102,15 @@ def cast(input, dtype): __all__.append('cast') def getitem(input, slice): - out = input.asnumpy()[slice] + if isinstance(slice, tuple): + new_slice = () + for s in slice: + if isinstance(s, core.Tensor): + s = s.numpy() + new_slice += (s,) + else: + new_slice = slice + out = input.asnumpy()[new_slice] if not isinstance(out, np.ndarray): out = np.array(out) return core.Tensor.from_numpy(out) @@ -233,6 +245,8 @@ def concat(tensors, dim): def abs(input): out = np.abs(input.numpy()) + if not isinstance(out, np.ndarray): + out = np.array(out) return core.Tensor.from_numpy(out) __all__.append('abs') @@ -277,6 +291,8 @@ def identity(input): # def non_zero() def isclose(input, other, rtol, atol, equal_nan): out = np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan) + if not isinstance(out, np.ndarray): + out = np.array(out) return core.Tensor.from_numpy(out) __all__.append('isclose') @@ -308,8 +324,11 @@ def index_select(input, dim, index): __all__.append('index_select') def rand_ext(size, seed, offset, dtype): - out = np.random.randn(*size).astype(core.dtype2np[dtype]) - return core.Tensor.from_numpy(out[0]) + out = np.random.randn(*size) + if not isinstance(out, np.ndarray): + out = np.array(out) + out = out.astype(core.dtype2np[dtype]) + return core.Tensor.from_numpy(out) 
__all__.append('rand_ext') @@ -438,6 +457,9 @@ def less(input, other): other = other.numpy() out = input < other + if not isinstance(out, np.ndarray): + out = np.array(out) + return core.Tensor.from_numpy(out) __all__.append('less') @@ -529,3 +551,87 @@ def randn(size, seed, offset, dtype): return core.Tensor.from_numpy(out) __all__.append('randn') + +def erfinv(input): + out = scipy.special.erfinv(input) + return core.Tensor.from_numpy(out) + +__all__.append('erfinv') + +def inplace_add_ext(input, other, alpha): + if not isinstance(other, numbers.Number): + other = other.numpy() + out = input.numpy() + other * alpha + input.assign_value(core.Tensor.from_numpy(out)) + return input + +__all__.append('inplace_add_ext') + +def pow_tensor_scalar(input, other): + out = np.power(input.numpy(), other) + return core.Tensor.from_numpy(out) + +__all__.append('pow_tensor_scalar') + +stop_gradient_op = ops.StopGradient().set_device('CPU') +def stop_gradient(*args): + return stop_gradient_op(*args) + +__all__.append('stop_gradient') + +def fmod_scalar(input, other): + out = np.fmod(input.numpy(), other) + return core.Tensor.from_numpy(out) + +__all__.append('fmod_scalar') + +def argmax_with_value(input, dim, keepdim): + indices = np.argmax(input.numpy(), dim, keepdims=keepdim) + values = np.max(input.numpy(), dim, keepdims=keepdim) + + if not isinstance(indices, np.ndarray): + indices = np.array(indices) + if not isinstance(values, np.ndarray): + values = np.array(values) + return core.Tensor.from_numpy(indices), core.Tensor.from_numpy(values) + +__all__.append('argmax_with_value') + +def argmax_ext(input, dim, keepdim): + indices = np.argmax(input.numpy(), dim, keepdims=keepdim) + if not isinstance(indices, np.ndarray): + indices = np.array(indices) + return core.Tensor.from_numpy(indices) +__all__.append('argmax_ext') + + +def log(input): + out = np.log(input.numpy()) + return core.Tensor.from_numpy(out) + +__all__.append('log') + +def eye(n, m, dtype): + out = np.eye(n, m, dtype=core.dtype2np[dtype]) + return core.Tensor.from_numpy(out) + +__all__.append('eye') + +def lin_space_ext(start, end, steps, dtype): + out = np.linspace(start, end, steps, dtype=core.dtype2np[dtype]) + return core.Tensor.from_numpy(out) + +__all__.append('lin_space_ext') + +def upsample_bilinear2d(input, output_size, scale_factors, align_corners): + resize = _get_cache_prim(ops.ResizeBilinearV2)(align_corners, not align_corners).set_device('CPU') + return resize(input, output_size) + +__all__.append('upsample_bilinear2d') + +def split_with_size(tensor, split_size_or_sections, dim): + out = np.array_split(tensor.numpy(), np.cumsum(split_size_or_sections[:-1]), dim) + out = [core.Tensor.from_numpy(o) for o in out] + return out + +__all__.append('split_with_size') diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 49f8544a1..ac7e63e0c 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -1,3 +1,4 @@ +import gc import math import ctypes import numpy as np @@ -50,31 +51,37 @@ def __isinstancecheck__(self, instance): class IntTensor(Tensor, metaclass=TypedTensorMeta): dtype = _dtype.int def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.int, **kwargs) class LongTensor(Tensor, metaclass=TypedTensorMeta): dtype = _dtype.long def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.long, **kwargs) class FloatTensor(Tensor, metaclass=TypedTensorMeta): 
dtype = _dtype.float32 def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.float32, **kwargs) class HalfTensor(Tensor, metaclass=TypedTensorMeta): dtype = _dtype.float16 def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.float16, **kwargs) class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta): dtype = _dtype.float16 def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.bfloat16, **kwargs) class BoolTensor(Tensor, metaclass=TypedTensorMeta): dtype = _dtype.bool def __init__(self, *args, **kwargs): + self._device = kwargs.pop('device', device_('cpu')) super().__init__(*args, dtype=_dtype.bool, **kwargs) @@ -86,17 +93,17 @@ def tensor_meta_str(self): old_init = Tensor.__init__ def __init__(self, *args, **kwargs): requires_grad = kwargs.pop('requires_grad', False) + device = kwargs.pop('device', core.get_default_device()) if len(args) > 1 and all([isinstance(arg, int) for arg in args]): tensor = Tensor_(shape=args, dtype=get_default_dtype()) old_init(self, tensor, internal=True) else: old_init(self, *args, **kwargs) self.requires_grad_(requires_grad) + self._device = device Tensor.__init__ = __init__ origin_setitem = Tensor.__setitem__ - -Tensor._device = device_('cpu') Tensor._requires_grad = False def tensor(data, *, dtype=None, device=None, requires_grad=False): @@ -363,11 +370,6 @@ def __invert__(self): def __round__(self): return ops.round(self) - # def __del__(self): - # # self._offload() - # # Tensor_.__del__(self) - # mindspore.runtime.synchronize() - def new(self, *shape): if not isinstance(shape[0], int): return tensor(shape[0], dtype=self.dtype) @@ -415,7 +417,11 @@ def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout= if isinstance(s, Tensor): s = s.item() new_size += (s,) - return ops.zeros(*new_size, dtype=dtype if dtype is not None else self.dtype) + return ops.zeros( + *new_size, + dtype=dtype if dtype is not None else self.dtype, + device=device if device is not None else self.device + ) # Tensor.ndim @property @@ -745,6 +751,12 @@ def chunk(self, chunks, dim=0): def clamp(self, min=None, max=None): return ops.clamp(self, min, max) + def clamp_min(self, min): + return ops.clamp(self, min, None) + + def clamp_max(self, min): + return ops.clamp(self, None, max) + # Tensor.clamp_ def clamp_(self, min=None, max=None): return self.copy_(ops.clamp(self, min, max)) @@ -1523,10 +1535,12 @@ def lt_(self, other): # Tensor.masked_scatter_ - + def masked_scatter_(self, mask, tensor): + return self.copy_(ops.masked_scatter(self, mask, tensor)) # Tensor.masked_scatter - + def masked_scatter(self, mask, tensor): + return ops.masked_scatter(self, mask, tensor) # Tensor.masked_fill_ def masked_fill_(self, mask, value): @@ -1728,6 +1742,8 @@ def outer(self, vec2): # Tensor.permute def permute(self, *dims): + if isinstance(dims[0], (list, tuple)): + dims = tuple(dims[0]) return ops.permute(self, dims) # Tensor.pin_memory @@ -1835,8 +1851,8 @@ def repeat(self, *repeats): return ops.tile(self, repeats) # Tensor.repeat_interleave - def repeat_interleave(self, repeats, dim=None): - return ops.repeat_interleave(self, repeats, dim) + def repeat_interleave(self, repeats, dim=None, output_size=None): + return ops.repeat_interleave(self, repeats, dim, output_size=output_size) # Tensor.reshape def reshape(self, *shape): @@ -2241,8 +2257,8 @@ def 
atanh_(self): arctanh_ = atanh_ # Tensor.tolist - # def tolist(self): - # return self.numpy().tolist() + def tolist(self): + return self.numpy().tolist() # Tensor.topk def topk(self, k, dim=-1, largest=True, sorted=True): @@ -2329,7 +2345,10 @@ def type(self, dtype=None, non_blocking=False): # Tensor.type_as def type_as(self, tensor): - return self.type(tensor.dtype) + out = self.type(tensor.dtype) + if self.device != tensor.device: + out = out.to(tensor.device) + return out # Tensor.unbind def unbind(self, dim=0): @@ -2365,7 +2384,8 @@ def unsqueeze(self, dim): # Tensor.unsqueeze_ def unsqueeze_(self, dim): - return self.copy_(ops.unsqueeze(self, dim)) + self.data = ops.unsqueeze(self, dim) + return self # Tensor.values diff --git a/mindnlp/core/distributions/utils.py b/mindnlp/core/distributions/utils.py index cdb6e3b8d..ea71fa3db 100644 --- a/mindnlp/core/distributions/utils.py +++ b/mindnlp/core/distributions/utils.py @@ -1,17 +1,15 @@ -"""distribution utils""" -# mypy: allow-untyped-defs +from collections.abc import Sequence from functools import update_wrapper -from numbers import Number -from typing import Any, Dict +from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union -import mindspore -from .. import ops -from ..autograd import enable_grad -from .._bind import get_default_dtype -from ..nn import functional as F +from mindnlp import core +import mindnlp.core.nn.functional as F +from mindnlp.core import Tensor +from mindnlp.core.overrides import is_tensor_like +from mindnlp.core.types import _dtype, _Number, Device, Number -euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant +euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant __all__ = [ "broadcast_all", @@ -24,45 +22,57 @@ ] -def broadcast_all(*values): +# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added. +# See https://github.com/python/typing/issues/1216#issuecomment-2126153831 +def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: r""" Given a list of values (possibly containing numbers), returns a list where each value is broadcasted based on the following rules: - `core.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`. - - numbers.Number instances (scalars) are upcast to tensors having + - Number instances (scalars) are upcast to tensors having the same size and type as the first tensor passed to `values`. If all the values are scalars, then they are upcasted to scalar Tensors. Args: - values (list of `numbers.Number`, `core.*Tensor` or objects implementing __torch_function__) + values (list of `Number`, `core.*Tensor` or objects implementing __torch_function__) Raises: - ValueError: if any of the values is not a `numbers.Number` instance, + ValueError: if any of the values is not a `Number` instance, a `core.*Tensor` instance, or an instance implementing __torch_function__ """ - if not all(isinstance(v, (mindspore.Tensor, Number)) for v in values): + if not all(is_tensor_like(v) or isinstance(v, _Number) for v in values): raise ValueError( - "Input arguments must all be instances of numbers.Number, " - "mindspore.Tensor or objects implementing __torch_function__." + "Input arguments must all be instances of Number, " + "core.Tensor or objects implementing __torch_function__." 
) - if not all(isinstance(v, mindspore.Tensor) for v in values): - options: Dict[str, Any] = {"dtype": get_default_dtype()} + if not all(is_tensor_like(v) for v in values): + options: dict[str, Any] = dict(dtype=core.get_default_dtype()) for value in values: - if isinstance(value, mindspore.Tensor): - options = {"dtype": value.dtype} + if isinstance(value, core.Tensor): + options = dict(dtype=value.dtype, device=value.device) break new_values = [ - v if isinstance(v, mindspore.Tensor) else mindspore.tensor(v, **options) for v in values + v if is_tensor_like(v) else core.tensor(v, **options) for v in values ] - return ops.broadcast_tensors(*new_values) - return ops.broadcast_tensors(*values) - - -def _standard_normal(shape, dtype, device=None): - return ops.normal(size = shape).to(dtype) + return core.broadcast_tensors(*new_values) + return core.broadcast_tensors(*values) + + +def _standard_normal( + shape: Sequence[int], + dtype: Optional[_dtype], + device: Optional[Device], +) -> Tensor: + if core._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .normal_() + return core.normal( + core.zeros(shape, dtype=dtype, device=device), + core.ones(shape, dtype=dtype, device=device), + ) + return core.empty(shape, dtype=dtype, device=device).normal_() -def _sum_rightmost(value, dim): +def _sum_rightmost(value: Tensor, dim: int) -> Tensor: r""" Sum out ``dim`` many rightmost dimensions of a given tensor. @@ -76,7 +86,7 @@ def _sum_rightmost(value, dim): return value.reshape(required_shape).sum(-1) -def logits_to_probs(logits, is_binary=False): +def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor: r""" Converts a tensor of logits into probabilities. Note that for the binary case, each value denotes log odds, whereas for the @@ -84,11 +94,11 @@ def logits_to_probs(logits, is_binary=False): the log probabilities (possibly unnormalized) of the events. """ if is_binary: - return ops.sigmoid(logits) + return core.sigmoid(logits) return F.softmax(logits, dim=-1) -def clamp_probs(probs): +def clamp_probs(probs: Tensor) -> Tensor: """Clamps the probabilities to be in the open interval `(0, 1)`. The probabilities would be clamped between `eps` and `1 - eps`, @@ -101,20 +111,20 @@ def clamp_probs(probs): Tensor: The clamped probabilities. Examples: - >>> probs = mindspore.tensor([0.0, 0.5, 1.0]) + >>> probs = core.tensor([0.0, 0.5, 1.0]) >>> clamp_probs(probs) tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) - >>> probs = mindspore.tensor([0.0, 0.5, 1.0], dtype=mindspore.float64) + >>> probs = core.tensor([0.0, 0.5, 1.0], dtype=core.float64) >>> clamp_probs(probs) - tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=mindspore.float64) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=core.float64) """ - eps = ops.finfo(probs.dtype).eps + eps = core.finfo(probs.dtype).eps return probs.clamp(min=eps, max=1 - eps) -def probs_to_logits(probs, is_binary=False): +def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor: r""" Converts a tensor of probabilities into logits. For the binary case, this denotes the probability of occurrence of the event indexed by `1`. 
@@ -123,11 +133,15 @@ def probs_to_logits(probs, is_binary=False): """ ps_clamped = clamp_probs(probs) if is_binary: - return ops.log(ps_clamped) - ops.log1p(-ps_clamped) - return ops.log(ps_clamped) + return core.log(ps_clamped) - core.log1p(-ps_clamped) + return core.log(ps_clamped) + + +T = TypeVar("T", contravariant=True) +R = TypeVar("R", covariant=True) -class lazy_property: +class lazy_property(Generic[T, R]): r""" Used as a decorator for lazy loading of class attributes. This uses a non-data descriptor that calls the wrapped method to compute the property on @@ -135,42 +149,55 @@ class lazy_property: attribute. """ - def __init__(self, wrapped): - self.wrapped = wrapped + def __init__(self, wrapped: Callable[[T], R]) -> None: + self.wrapped: Callable[[T], R] = wrapped update_wrapper(self, wrapped) # type:ignore[arg-type] - def __get__(self, instance, obj_type=None): + @overload + def __get__( + self, instance: None, obj_type: Any = None + ) -> "_lazy_property_and_property[T, R]": ... + + @overload + def __get__(self, instance: T, obj_type: Any = None) -> R: ... + + def __get__( + self, instance: Union[T, None], obj_type: Any = None + ) -> "R | _lazy_property_and_property[T, R]": if instance is None: return _lazy_property_and_property(self.wrapped) - with enable_grad(): + with core.enable_grad(): value = self.wrapped(instance) setattr(instance, self.wrapped.__name__, value) return value -class _lazy_property_and_property(lazy_property, property): +class _lazy_property_and_property(lazy_property[T, R], property): """We want lazy properties to look like multiple things. * property when Sphinx autodoc looks * lazy_property when Distribution validate_args looks """ + def __init__(self, wrapped: Callable[[T], R]) -> None: + property.__init__(self, wrapped) + -def tril_matrix_to_vec(mat: mindspore.Tensor, diag: int = 0) -> mindspore.Tensor: +def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor: r""" Convert a `D x D` matrix or a batch of matrices into a (batched) vector which comprises of lower triangular elements from the matrix in row order. """ n = mat.shape[-1] - # if not core._C._get_tracing_state() and (diag < -n or diag >= n): - # raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].") - arange = ops.arange(n) + if not core._C._get_tracing_state() and (diag < -n or diag >= n): + raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n - 1}].") + arange = core.arange(n, device=mat.device) tril_mask = arange < arange.view(-1, 1) + (diag + 1) vec = mat[..., tril_mask] return vec -def vec_to_tril_matrix(vec: mindspore.Tensor, diag: int = 0) -> mindspore.Tensor: +def vec_to_tril_matrix(vec: Tensor, diag: int = 0) -> Tensor: r""" Convert a vector or a batch of vectors into a batched `D x D` lower triangular matrix containing elements from the vector in row order. @@ -180,15 +207,15 @@ def vec_to_tril_matrix(vec: mindspore.Tensor, diag: int = 0) -> mindspore.Tensor -(1 + 2 * diag) + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5 ) / 2 - eps = ops.finfo(vec.dtype).eps - # if not core._C._get_tracing_state() and (round(n) - n > eps): - # raise ValueError( - # f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " - # + "the lower triangular part of a square D x D matrix." 
- # ) - n = round(n.item()) if isinstance(n, mindspore.Tensor) else round(n) - mat = vec.new_zeros(vec.shape[:-1] + (n, n)) - arange = ops.arange(n) + eps = core.finfo(vec.dtype).eps + if not core._C._get_tracing_state() and (round(n) - n > eps): + raise ValueError( + f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " + + "the lower triangular part of a square D x D matrix." + ) + n = round(n.item()) if isinstance(n, core.Tensor) else round(n) + mat = vec.new_zeros(vec.shape[:-1] + core.Size((n, n))) + arange = core.arange(n, device=vec.device) tril_mask = arange < arange.view(-1, 1) + (diag + 1) mat[..., tril_mask] = vec - return mat + return mat \ No newline at end of file diff --git a/mindnlp/core/fft/__init__.py b/mindnlp/core/fft/__init__.py index ccbf5ad90..88269425f 100644 --- a/mindnlp/core/fft/__init__.py +++ b/mindnlp/core/fft/__init__.py @@ -1,33 +1,32 @@ """fft""" -from mindspore import ops -from mindspore.ops._primitive_cache import _get_cache_prim -from ..configs import use_pyboost -from ..ops import narrow, roll -from ..nn import functional as F +from ..executor import execute + def rfft(input, n=None, dim=-1, norm="backward"): - if use_pyboost(): - return ops.rfft(input, n, dim, norm) - if input.shape[dim] < n: - pad_inf = (0, n - input.shape[dim]) - pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf - input = F.pad(input, pad_dims) - else: - input = narrow(input, dim, 0, n) - _rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm) - return _rfft(input) + return execute('rfft', input, n, dim, norm) + # if use_pyboost(): + # return ops.rfft(input, n, dim, norm) + # if input.shape[dim] < n: + # pad_inf = (0, n - input.shape[dim]) + # pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + # input = F.pad(input, pad_dims) + # else: + # input = narrow(input, dim, 0, n) + # _rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm) + # return _rfft(input) def irfft(input, n=None, dim=-1, norm="backward"): - if use_pyboost(): - return ops.irfft(input, n, dim, norm) - if input.shape[dim] < n: - pad_inf = (0, n - input.shape[dim]) - pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf - input = pad(input, pad_dims) - else: - input = narrow(input, dim, 0, n) - _irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm) - return _irfft(input) + return execute('irfft', input, n, dim, norm) + # if use_pyboost(): + # return ops.irfft(input, n, dim, norm) + # if input.shape[dim] < n: + # pad_inf = (0, n - input.shape[dim]) + # pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf + # input = pad(input, pad_dims) + # else: + # input = narrow(input, dim, 0, n) + # _irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm) + # return _irfft(input) def fftn(input, s=None, dim=None, norm=None): return ops.fftn(input, s, dim, norm) diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index 97db15b3b..ff6ddc3dd 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -53,7 +53,7 @@ def glu(input, dim=-1): return execute('glu', input, dim) def softplus(input, beta=1, threshold=20): - return execute('softplus', input, beta, threshold) + return execute('softplus_ext', input, beta, threshold) def logsigmoid(input): return execute('logsigmoid', input) @@ -96,10 +96,7 @@ def avg_pool1d(input, kernel_size, stride, padding=0, ceil_mode=False, count_inc Returns: - numpy array: The result of the average pooling operation. 
""" - if use_pyboost(): - return mint.nn.functional.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad) - - return ops.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad) + return execute('avg_pool1d', input, kernel_size, stride, padding, ceil_mode, count_include_pad) def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None): """ @@ -116,12 +113,7 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun Returns: - numpy array: The result of the average pooling operation. """ - if use_pyboost(): - return mint.nn.functional.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) - - if divisor_override is None: - divisor_override = 0 - return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + return execute('avg_pool2d', input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None): if use_pyboost() and has_avg_pool3d: @@ -138,9 +130,7 @@ def adaptive_avg_pool1d(input, output_size): return ops.adaptive_avg_pool1d(input, output_size) def adaptive_avg_pool2d(input, output_size): - if use_pyboost(): - return mint.nn.functional.adaptive_avg_pool2d(input, output_size) - return ops.adaptive_avg_pool2d(input, output_size) + return execute('adaptive_avg_pool2d_ext', input, output_size) def dropout(input, p=0.5, training=True, inplace=False): if not training: @@ -252,7 +242,7 @@ def custom_circular_pad(x, pad): new_size = left_pad + size + right_pad # 生成循环索引: (index - left_pad) mod size - index = (core.arange(new_size) - left_pad) % size + index = (core.arange(new_size, device=x.device) - left_pad) % size x = core.index_select(x, dim, index) return x @@ -279,68 +269,53 @@ def _replication_pad(input, pad): out = execute('replication_pad_3d', input, pad) return out +def _circular_pad(input_x, padding): + """circular pad""" + if isinstance(padding, tuple): + padding = core.tensor(padding, dtype=core.int64, device=input_x.device) + elif isinstance(padding, list): + padding = core.tensor(padding, dtype=core.int64, device=input_x.device) + is_expand = False + if padding.shape[0] // 2 + 1 == input_x.ndim: + input_x = input_x.expand_dims(0) + is_expand = True + out = execute('pad_v3', input_x, padding, "circular", None) + if is_expand: + out = out.squeeze(0) + return out + def pad(input, pad, mode='constant', value=None): - if input.device.type != 'npu': - if mode == 'reflect' and input.ndim > 4: - paddings = [[0, 0]] - for i in range(0, len(pad), 2): - paddings.append([pad[i], pad[i+1]]) - old_shape = input.shape - shape = (-1, *old_shape[-3:]) - out = execute('mirror_pad', input.reshape(shape), core.tensor(paddings, device=input.device)) - return out.reshape(*old_shape[:-3], *out.shape[-3:]) - return execute('pad_v3', input, pad, mode, value) + if input.device.type in ['cpu', 'meta'] or ON_A1: + new_pad = () + for idx, pad_v in enumerate(pad): + if pad_v < 0: + dim = input.ndim - 1 - idx // 2 + input = input.narrow(dim, 0, input.shape[dim] + pad_v) + pad_v = 0 + new_pad += (pad_v,) + if sum(new_pad) == 0: + return input + return execute('pad_v3', input, new_pad, mode, value) + out = input + if (isinstance(pad, tuple) and not pad): + return out if sum(pad) == 0: - return input - if isinstance(pad, tuple): - pad = tuple(p if 
isinstance(p, int) else p.item() for p in pad) - if not ON_A1: - out = input - if (isinstance(pad, tuple) and not pad): - return out - if mode == "constant": - value = 0 if value is None else value - out = execute('constant_pad_nd', input, pad, value) - else: - if value is not None and value != 0: - raise ValueError(f"Padding mode {mode} doesn\'t take in value argument.") - if mode == "circular": - out = _circular_pad(input, pad) - elif mode == "reflect": - out = _reflection_pad(input, pad) - elif mode == "replicate": - out = _replication_pad(input, pad) - else: - raise ValueError(f"Pad filling mode must be 'constant' 'circular' 'reflect' or 'replicate'.") return out - - - if mode in ['reflect', 'replicate']: - if mode == 'reflect' and input.ndim > 4: - return execute('reflection_pad_3d', input, pad) - return execute('pad_v3', input, pad, mode) - if mode == 'circular': - return custom_circular_pad(input, pad) - new_pad = () - for idx, pad_v in enumerate(pad): - if pad_v < 0: - dim = input.ndim - 1 - idx // 2 - input = input.narrow(dim, 0, input.shape[dim] + pad_v) - pad_v = 0 - new_pad += (pad_v,) - if sum(new_pad) == 0: - return input - if input.dtype == core.bool_: - input = input.to(core.int32) - return execute('pad_v3', input, pad, mode, value).to(core.bool_) - if input.ndim > 5 and mode == 'constant': - paddings = () - for i in range(0, len(new_pad), 2): - paddings += (new_pad[i: i+2],) - - paddings = ((0, 0),) * (input.ndim - len(paddings)) + tuple(reversed(paddings)) - return execute('pad', paddings, input) - return execute('pad_v3', input, pad, mode, value) + if mode == "constant": + value = 0 if value is None else value + out = execute('constant_pad_nd', input, pad, value) + else: + if value is not None and value != 0: + raise ValueError(f"Padding mode {mode} doesn\'t take in value argument.") + if mode == "circular": + out = _circular_pad(input, pad) + elif mode == "reflect": + out = _reflection_pad(input, pad) + elif mode == "replicate": + out = _replication_pad(input, pad) + else: + raise ValueError(f"Pad filling mode must be 'constant' 'circular' 'reflect' or 'replicate'.") + return out def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'): if input.device.type == 'npu': @@ -459,12 +434,8 @@ def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean if target_dtype in [core.float32, core.float16, core.bfloat16]: return _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes) - # for class indices - if input.device.type == 'npu': - return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing, - class_dim, n_classes) - return _inner_nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing) - + return _cross_entropy_for_class_indices(input, target, weight, ignore_index, reduction, label_smoothing, + class_dim, n_classes) def _cross_entropy_for_probabilities(input, target, weight, reduction, label_smoothing, class_dim, n_classes): """cross_entropy inner function for class probabilities""" @@ -507,12 +478,12 @@ def _cross_entropy_for_class_indices(input, target, weight, ingore_index, reduct smooth_loss = -loss.sum(class_dim) else: smooth_loss = -input.sum(class_dim) - ignore_mask = ops.eq(target, ingore_index) - smooth_loss = masked_fill_op(smooth_loss, ignore_mask, 0) + ignore_mask = core.eq(target, ingore_index) + smooth_loss = core.masked_fill(smooth_loss, ignore_mask, 0) if reduction == "mean": 
true_mask = ~ignore_mask if weight is not None: - weight_sum = mint.gather(weight, 0, mint.masked_select(masked_select(target, true_mask))).sum() + weight_sum = core.gather(weight, 0, core.masked_select(core.masked_select(target, true_mask))).sum() if weight_sum == 0: ret = smooth_loss.sum() else: @@ -692,8 +663,8 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # return torch._C._nn._upsample_bilinear2d_aa( # input, output_size, align_corners, scale_factors # ) - return upsample_bilinear2d_op( - input, output_size, scale_factors, align_corners + return execute( + 'upsample_bilinear2d', input, output_size, scale_factors, align_corners ) if input.dim() == 5 and mode == "trilinear": assert align_corners is not None @@ -750,31 +721,21 @@ def normalize(input, p=2.0, dim=1, eps=1e-6): The Lp norm is defined as the p-th root of the sum of the absolute values raised to the power of 'p'. The resulting tensor will have the same shape as the input tensor. """ - return input / ops.norm(input, ord=p, dim=dim, keepdim=True) + return input / core.norm(input, p=p, dim=dim, keepdim=True) def batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05): if running_mean is None: - running_mean = ops.ones(input.shape[1]) + running_mean = core.ones(input.shape[1], dtype=input.dtype, device=input.device) if running_var is None: - running_var = ops.zeros(input.shape[1]) + running_var = core.zeros(input.shape[1], dtype=input.dtype, device=input.device) if weight is None: - weight = ops.ones(input.shape[1]) + weight = core.ones(input.shape[1], dtype=input.dtype, device=input.device) if bias is None: - bias = ops.zeros(input.shape[1]) + bias = core.zeros(input.shape[1], dtype=input.dtype, device=input.device) - if use_pyboost() and not ON_ORANGE_PI: - return mint.nn.functional.batch_norm( - input, - running_mean, - running_var, - weight, - bias, - training, - momentum, - eps - ) - return ops.batch_norm( + return execute( + 'batch_norm_ext', input, running_mean, running_var, @@ -783,49 +744,17 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin training, momentum, eps - ) + )[0] def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if use_pyboost() and has_conv1d and not ON_ORANGE_PI: - return mint.nn.functional.conv1d(input, weight, bias, stride, padding, dilation, groups) - pad_mode = 'pad' - pad = padding - if isinstance(padding, tuple): - pad = (0, 0, padding[0], padding[0]) - elif isinstance(padding, int): - pad = (0, 0) + (padding,) * 2 - if not isinstance(padding, (int, tuple)): - pad_mode = padding - pad = (0,) * 4 - - _conv2d = _get_cache_prim(ops.Conv2D)(out_channel=weight.shape[0], - kernel_size=(1, weight.shape[-1]), - mode=1, - pad_mode=pad_mode, - pad=pad, - stride=(1, stride) if isinstance(stride, int) else (1, *stride), - dilation=(1, dilation) if isinstance(dilation, int) else (1, *dilation), - group=groups) - - input = input.expand_dims(2) - output = _conv2d(input, weight.expand_dims(2)) - - if bias is not None: - output = ops.bias_add(output, bias) - - output = output.squeeze(2) - return output - + if isinstance(padding, str): + return execute('conv1d_padding', input, weight, bias, stride, padding, dilation, groups) + return execute('conv1d_ext', input, weight, bias, stride, padding, dilation, groups) def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if use_pyboost() and not ON_ORANGE_PI: - return execute('conv2d_ext', 
input, weight, bias, stride, padding, dilation, groups) - - # pad_mode = 'pad' - # if not isinstance(padding, (int, tuple)): - # pad_mode = padding - - # return ops.conv2d(input, weight, bias=bias, stride=stride, pad_mode=pad_mode, padding=padding, dilation=dilation, groups=groups) + if isinstance(padding, str): + return execute('conv2d_padding', input, weight, bias, stride, padding, dilation, groups) + return execute('conv2d_ext', input, weight, bias, stride, padding, dilation, groups) def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if use_pyboost() and not ON_ORANGE_PI: @@ -980,15 +909,13 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_paddi def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): - if use_pyboost(): - input_ndim = input.ndim - if input_ndim == 3: - input = input.unsqueeze(1) - out = mint.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices) - if input_ndim == 3: - out = out.squeeze(1) - return out - return ops.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices) + input_ndim = input.ndim + if input_ndim == 3: + input = input.unsqueeze(1) + out = execute('max_pool2d', input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices) + if input_ndim == 3: + out = out.squeeze(1) + return out def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): if stride is None: @@ -1013,29 +940,32 @@ def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): - if use_pyboost() and not ON_ORANGE_PI: - return mint.nn.functional.group_norm(input, num_groups, weight, bias, eps) - - input_shape = input.shape - N = input_shape[0] - C = input_shape[1] - input_reshaped = input.view(1, N * num_groups, -1 if N!=0 else 1) - outputs = batch_norm(input_reshaped, None, None, None, None, True, 0., eps) - out = outputs.view(input_shape) - affine_param_shape = [1] * input.ndim - affine_param_shape[1] = C - affine_param_shape = tuple(affine_param_shape) - if weight is not None and bias is not None: - if not ON_ORANGE_PI: - out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) - else: - out = core.addcmul(bias.view(affine_param_shape), out, weight.view(affine_param_shape), value=1) - - elif weight is not None: - out = out.mul(weight.view(affine_param_shape)) - elif bias is not None: - out = out.add(bias.view(affine_param_shape)) - return out + if weight is None: + weight = core.ones([input.shape[1]], dtype=input.dtype, device=input.device) + if bias is None: + bias = core.zeros([input.shape[1]], dtype=input.dtype, device=input.device) + return execute('group_norm', input, num_groups, weight, bias, eps)[0] + + # input_shape = input.shape + # N = input_shape[0] + # C = input_shape[1] + # input_reshaped = input.view(1, N * num_groups, -1 if N!=0 else 1) + # outputs = batch_norm(input_reshaped, None, None, None, None, True, 0., eps) + # out = outputs.view(input_shape) + # affine_param_shape = [1] * input.ndim + # affine_param_shape[1] = C + # affine_param_shape = tuple(affine_param_shape) + # if weight is not None and bias is not None: + # if not ON_ORANGE_PI: + # out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + # else: + # 
out = core.addcmul(bias.view(affine_param_shape), out, weight.view(affine_param_shape), value=1) + + # elif weight is not None: + # out = out.mul(weight.view(affine_param_shape)) + # elif bias is not None: + # out = out.add(bias.view(affine_param_shape)) + # return out def _in_projection( @@ -1135,7 +1065,7 @@ def _in_projection_packed( # # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() # proj = proj.unflatten(-1, (3, E)).unsqueeze(0).swapaxes(0, -2).squeeze(-2) # return proj[0], proj[1], proj[2] - return linear(q, w, b).chunk(3, axis=-1) + return linear(q, w, b).chunk(3, dim=-1) else: # encoder-decoder attention w_q, w_kv = w.split([E, E * 2]) @@ -1148,7 +1078,7 @@ def _in_projection_packed( # # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() # kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).swapaxes(0, -2).squeeze(-2) # return (q_proj, kv_proj[0], kv_proj[1]) - return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, axis=-1) + return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1) else: w_q, w_k, w_v = w.chunk(3) if b is None: diff --git a/mindnlp/core/nn/modules/batchnorm.py b/mindnlp/core/nn/modules/batchnorm.py index 92fa50e8b..c20be7bf4 100644 --- a/mindnlp/core/nn/modules/batchnorm.py +++ b/mindnlp/core/nn/modules/batchnorm.py @@ -58,6 +58,7 @@ def __init__( self.register_buffer("running_mean", None) self.register_buffer("running_var", None) self.register_buffer("num_batches_tracked", None) + self.reset_parameters() def reset_running_stats(self) -> None: if self.track_running_stats: diff --git a/mindnlp/core/ops/_inner.py b/mindnlp/core/ops/_inner.py index ed3c39726..2f7c35915 100644 --- a/mindnlp/core/ops/_inner.py +++ b/mindnlp/core/ops/_inner.py @@ -16,8 +16,11 @@ def npu_clear_float_status_v2(status): def all_finite(inputs): return execute('all_finite', inputs) +def masked_scatter(input, mask, source): + return execute('masked_scatter', input, mask, source) + __all__ = [ - 'cast', 'depend', + 'cast', 'depend', 'masked_scatter', 'npu_get_float_status_v2', 'npu_clear_float_status_v2', 'all_finite' ] diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index 5ff97f758..f7479a38a 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -22,10 +22,19 @@ def t(input): def argwhere(input): return execute("nonzero", input) +def infer_dtype(dtypes): + is_float_dtypes = [d.is_floating_point for d in dtypes] + float_dtypes = [d for d in dtypes if d.is_floating_point] + if any(is_float_dtypes): + return max(float_dtypes) + else: + return max(dtypes) # cat def cat(tensors, dim=0, **kwargs): dim = kwargs.pop('axis', dim) + dtype = infer_dtype([t.dtype for t in tensors]) + tensors = [t.to(dtype) for t in tensors] return execute("concat", tensors, dim) diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py index f382e4857..1d8141834 100644 --- a/mindnlp/core/ops/creation.py +++ b/mindnlp/core/ops/creation.py @@ -91,13 +91,12 @@ def arange(start=0, end=None, step=1, *, out=None, dtype=None, layout=None, devi if end is None: start, end = 0, start if dtype is None: - dtype = core.int64 + dtype = core.py2dtype[type(start)] if device is None: device = get_device_in_context() if isinstance(device, str): device = core.device(device) - output = execute('arange', start, end, step, dtype, - device=device, requires_grad=requires_grad, user_created=True) + output = execute('arange', start, end, step, dtype, 
device=device, requires_grad=requires_grad, user_created=True) if out is None: return output out.data = output @@ -124,14 +123,11 @@ def linspace(start, end, steps, *, out=None, dtype=None, layout=None, device=Non dtype = get_default_dtype() if device is None: device = get_device_in_context() - if device.type == 'cpu': - start = core.tensor(start, device=device, dtype=dtype) - end = core.tensor(end, device=device, dtype=dtype) - output = execute('linspace', start, end, steps, - device=device, requires_grad=requires_grad, user_created=True) - else: - output = execute('lin_space_ext', start, end, steps, dtype, - device=device, requires_grad=requires_grad, user_created=True) + if isinstance(device, str): + device = core.device(device) + + output = execute('lin_space_ext', start, end, steps, dtype, + device=device, requires_grad=requires_grad, user_created=True) if out is None: return output out.data = output @@ -162,7 +158,7 @@ def empty(*size, out=None, dtype=None, layout=None, device=None, device = get_device_in_context() if isinstance(device, str): device = core.device(device) - if isinstance(size[0], (tuple, list)): + if len(size) > 0 and isinstance(size[0], (tuple, list)): size = size[0] if device.type == 'meta': @@ -205,6 +201,8 @@ def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, re # full_like def full_like(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None): + if dtype is None: + dtype = input.dtype return full(input.shape, fill_value, dtype=dtype, layout=layout, device=input.device, requires_grad=requires_grad) # quantize_per_tensor diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py index ec1920926..99373d289 100644 --- a/mindnlp/core/ops/other.py +++ b/mindnlp/core/ops/other.py @@ -621,6 +621,8 @@ def einsum(equation, *operands): [3. 6.] [4. 
8.]] """ + if isinstance(operands[0], (list, tuple)): + operands = operands[0] _equation, _operands = _einsum_convert_sublist(equation, *operands) _einsum_check_inputargs(_equation, _operands) return _einsum(_equation, _operands) @@ -684,7 +686,7 @@ def ravel(input): # repeat_interleave -def repeat_interleave(input, repeats, dim=None): +def repeat_interleave(input, repeats, dim=None, *, output_size=None): if input.device.type == 'npu' and not ON_A1: if isinstance(repeats, int): return execute('repeat_interleave_int', input, repeats, dim, None) diff --git a/mindnlp/core/ops/pointwise.py b/mindnlp/core/ops/pointwise.py index d85c57e36..508318eb5 100644 --- a/mindnlp/core/ops/pointwise.py +++ b/mindnlp/core/ops/pointwise.py @@ -306,6 +306,8 @@ def lerp(input, end, weight): # lgamma +def lgamma(input): + return execute('lgamma', input) # log @@ -617,6 +619,7 @@ def log_softmax(input, dim=None, dtype=None): "frac", "ldexp", "lerp", + "lgamma", "log", "log1p", "log2", diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py index 50b66959d..51dc798cc 100644 --- a/mindnlp/core/ops/random.py +++ b/mindnlp/core/ops/random.py @@ -124,20 +124,16 @@ def rand( seed, offset = generator._step(generator_step_) # pylint: disable=protected-access if size and isinstance(size[0], (tuple, list)): size = size[0] - if device.type == 'cpu': - output = execute('uniform_real', size, - device=device, requires_grad=requires_grad, user_created=True).to(dtype) - else: - output = execute( - "rand_ext", - size, - seed, - offset, - dtype, - device=device, - requires_grad=requires_grad, - user_created=True, - ) + output = execute( + "rand_ext", + size, + seed, + offset, + dtype, + device=device, + requires_grad=requires_grad, + user_created=True, + ) if out is None: return output out.data = output diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py index dec11cbf9..5b0501bb4 100644 --- a/mindnlp/core/ops/reduction.py +++ b/mindnlp/core/ops/reduction.py @@ -49,8 +49,8 @@ def max(input, dim=None, keepdim=False, *, out=None): if out is None: return max_out(values=output[1], indices=output[0]) - out[0].data = output[0] - out[1].data = output[1] + out[0].data = output[1] + out[1].data = output[0] return out # min diff --git a/mindnlp/core/types.py b/mindnlp/core/types.py index 3a659cd96..e8be49035 100644 --- a/mindnlp/core/types.py +++ b/mindnlp/core/types.py @@ -17,6 +17,8 @@ _TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +Device: TypeAlias = Union[_device, str, int, None] + # Meta-type for "numeric" things; matches our docs Number: TypeAlias = Union[int, float, bool] # tuple for isinstance(x, Number) checks. 
diff --git a/setup.py b/setup.py
index c4e2a8fe8..f78c05984 100644
--- a/setup.py
+++ b/setup.py
@@ -159,9 +159,9 @@ def run(self):
         'mindspore>=2.5.0',
         'tqdm',
         'requests',
-        'accelerate', # hf dependency
+        'accelerate>=1.6.0', # hf dependency
         'transformers>=4.55.0', # hf dependency
-        'peft', # hf dependency
+        'peft>=0.15.2', # hf dependency
         'datasets', # hf dependency
         'evaluate', # hf dependency
         'tokenizers', # hf dependency
@@ -174,6 +174,7 @@ def run(self):
         'pyctcdecode',
         'pytest',
         'pillow>=10.0.0',
+        'ftfy'
     ],
     classifiers=[
         'License :: OSI Approved :: Apache Software License'
diff --git a/tests/run_test.py b/tests/run_test.py
index 08f2a46f9..ac2ea6639 100644
--- a/tests/run_test.py
+++ b/tests/run_test.py
@@ -33,10 +33,12 @@ def run_tests():
               "and not torch_fx " \
               "and not test_wrong_device_map " \
               "and not test_layerwise_casting " \
-              "and not test_flex_attention"
+              "and not test_flex_attention " \
+              "and not offload " \
+              "and not global_device"
 
     pytest_args.extend(["--ignore-glob=test_modeling_flax_*.py"])
-    # pytest_args.extend(['-k', skip_ut])
+    pytest_args.extend(['-k', skip_ut])
     if not pytest_args:
         print("未提供参数,默认运行当前目录下所有测试")
         print("使用示例: python run_test.py -v tests/")
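To close out, a short smoke-test sketch of a few tensor APIs this patch adds or extends (`Tensor.masked_scatter`, `Tensor.tolist`, and the `output_size` keyword on `repeat_interleave`). It is a hedged example, not part of the patch, and assumes the torch-like `mindnlp.core` API used above.

# Illustrative sketch, not part of the patch.
from mindnlp import core

x = core.zeros(2, 3)
mask = core.tensor([[True, False, True], [False, True, False]])
src = core.tensor([1.0, 2.0, 3.0])
y = x.masked_scatter(mask, src)   # Tensor.masked_scatter added in this patch
print(y.tolist())                 # Tensor.tolist re-enabled in this patch

r = core.tensor([1, 2, 3])
# repeat_interleave now forwards the torch-style output_size keyword
print(r.repeat_interleave(2, dim=0, output_size=6).tolist())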