diff --git a/mindnlp/accelerate/__init__.py b/mindnlp/accelerate/__init__.py
new file mode 100644
index 000000000..51c7f5deb
--- /dev/null
+++ b/mindnlp/accelerate/__init__.py
@@ -0,0 +1,18 @@
+import sys
+import accelerate
+from transformers.utils import _LazyModule
+
+_import_structure = {
+ "utils": [
+ 'DistributedType',
+ ]
+}
+
+sys.modules[__name__] = _LazyModule(
+ 'accelerate',
+ accelerate.__file__,
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={"__version__": accelerate.__version__},
+)
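+
+# Usage sketch: the lazy module resolves listed names on first access, e.g.
+#   from mindnlp.accelerate import DistributedType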
diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index ec74c77c0..1a5ce8335 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -46,5 +46,28 @@
from ._bind import get_default_dtype, set_default_dtype
from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
- return_types
+ return_types, linalg
+from ._lowrank import svd_lowrank
+
+def _has_compatible_shallow_copy_type(tensor, other):
+ """
+ Mimics the behavior of mindtorch._has_compatible_shallow_copy_type.
+
+ Args:
+ tensor (mindtorch.Tensor): The source tensor.
+ other (mindtorch.Tensor): The target tensor to check compatibility.
+
+ Returns:
+ bool: True if `tensor` and `other` have compatible types for shallow copy.
+ """
+ # Both arguments must be tensors
+ if not is_tensor(tensor) or not is_tensor(other):
+ return False
+
+ # Shallow copy requires matching shapes
+ if tensor.shape != other.shape:
+ return False
+
+ # Compatibility confirmed
+ return True
\ No newline at end of file
diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py
index 5bb777f92..b145865c5 100644
--- a/mindnlp/core/_dtype.py
+++ b/mindnlp/core/_dtype.py
@@ -9,6 +9,7 @@ def is_floating_point(self):
return isinstance(self, (typing.Float, typing.BFloat16))
Type.is_floating_point = is_floating_point
+Type.__str__ = Type.__repr__
half = float16
float = float32
@@ -19,6 +20,7 @@ def is_floating_point(self):
bool = bool_
float8_e4m3fn = None # TODO: not support fp8 for now
+float8_e5m2 = None
np2dtype = {
np.bool_: bool,
diff --git a/mindnlp/core/_dynamo/utils.py b/mindnlp/core/_dynamo/utils.py
deleted file mode 100644
index 1f7544cae..000000000
--- a/mindnlp/core/_dynamo/utils.py
+++ /dev/null
@@ -1,3 +0,0 @@
-def is_compile_supported(device_type):
- compile_supported = False
- return compile_supported
\ No newline at end of file
diff --git a/mindnlp/core/_lowrank.py b/mindnlp/core/_lowrank.py
index c03d4f468..819e42f00 100644
--- a/mindnlp/core/_lowrank.py
+++ b/mindnlp/core/_lowrank.py
@@ -5,8 +5,8 @@
from typing import Optional
from mindnlp import core
-from mindnlp.core import _linalg_utils as _utils, Tensor
-from core.overrides import handle_torch_function, has_torch_function
+from . import _linalg_utils as _utils, Tensor
+from .overrides import handle_torch_function, has_torch_function
def get_approximate_basis(
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index f4251f01b..ef287c4ba 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -1,17 +1,29 @@
+import math
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.common.tensor import _TensorMeta
+from mindspore._c_expression.typing import Type
try:
from mindspore.common._stub_tensor import StubTensor
except:
class StubTensor: pass
+from . import ops, _dtype
from ._dtype import dtype2np
-from ._bind import get_default_device
-
-from ._dtype import *
+from ._bind import get_default_device, device_
+from .configs import use_pyboost, ON_A1
+from .storage import UntypedStorage
+DTYPE_ELEMENT_SIZE_MAP = {
+ mindspore.float64: 8,
+ mindspore.int64: 8,
+ mindspore.int32: 4,
+ mindspore.float32: 4,
+ mindspore.int16: 2,
+ mindspore.bfloat16: 2,
+ mindspore.float16: 2,
+ mindspore.int8: 1,
+ mindspore.uint8: 1,
+ mindspore.bool_: 1,
+}
class TypedTensorMeta(_TensorMeta):
def __isinstancecheck__(self, instance):
@@ -20,31 +32,31 @@ def __isinstancecheck__(self, instance):
return instance.dtype == self.dtype
class LongTensor(Tensor, metaclass=TypedTensorMeta):
- dtype = long
+ dtype = _dtype.long
def __init__(self, data, device=None):
- super().__init__(data, dtype=long)
+ super().__init__(data, dtype=_dtype.long)
class FloatTensor(Tensor, metaclass=TypedTensorMeta):
- dtype = float32
+ dtype = _dtype.float32
def __init__(self, data, device=None):
- super().__init__(data, dtype=float32)
+ super().__init__(data, dtype=_dtype.float32)
class HalfTensor(Tensor, metaclass=TypedTensorMeta):
- dtype = float16
+ dtype = _dtype.float16
def __init__(self, data, device=None):
- super().__init__(data, dtype=float16)
+ super().__init__(data, dtype=_dtype.float16)
class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta):
- dtype = float16
+ dtype = _dtype.bfloat16
def __init__(self, data, device=None):
- super().__init__(data, dtype=bfloat16)
+ super().__init__(data, dtype=_dtype.bfloat16)
class BoolTensor(Tensor, metaclass=TypedTensorMeta):
- dtype = bool
+ dtype = _dtype.bool
def __init__(self, data, device=None):
- super().__init__(data, dtype=bool)
+ super().__init__(data, dtype=_dtype.bool)
def tensor(data, *, dtype=None, device=None, requires_grad=False):
if isinstance(data, Tensor):
@@ -105,8 +117,8 @@ def data_ptr(self):
Tensor.data_ptr = data_ptr
StubTensor.data_ptr = data_ptr
- Tensor.device = None
- StubTensor.device = None
+ Tensor.device = device_('not support yet.')
+ StubTensor.device = device_('not support yet.')
def _expand(self, *size):
if len(size) == 1:
@@ -115,3 +127,136 @@ def _expand(self, *size):
Tensor.expand = _expand
StubTensor.expand = _expand
+
+ def clone(self, *args, **kwargs):
+ return self.copy()
+
+ Tensor.clone = clone
+ StubTensor.clone = clone
+
+ def _repeat(self, *sizes):
+ if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
+ sizes = sizes[0]
+ return ops.tile(self, tuple(sizes))
+
+ Tensor.repeat = _repeat
+ StubTensor.repeat = _repeat
+
+ def __or__(self, other):
+ if isinstance(other, (int, bool, float, Tensor)):
+ return ops.bitwise_or(self, other)
+ raise TypeError("Unsupported operand type(s) for |: 'Tensor' and '{}'".format(type(other)))
+
+ Tensor.__or__ = __or__
+ StubTensor.__or__ = __or__
+
+ def __and__(self, other):
+ if isinstance(other, (int, bool, float, Tensor)):
+ return ops.bitwise_and(self, other)
+ raise TypeError("Unsupported operand type(s) for &: 'Tensor' and '{}'".format(type(other)))
+
+ Tensor.__and__ = __and__
+ StubTensor.__and__ = __and__
+
+ def detach(self):
+ return ops.stop_gradient(self)
+
+ Tensor.detach = detach
+ StubTensor.detach = detach
+
+ origin_getitem = Tensor.__getitem__
+ def __getitem__(self, slices):
+ # if 0 in self.shape:
+ # return self
+ if isinstance(slices, tuple):
+ new_slices = ()
+ for s in slices:
+ if isinstance(s, range):
+ s = slice(s.start, s.stop, s.step)
+ new_slices += (s,)
+ slices = new_slices
+ return origin_getitem(self, slices)
+
+ Tensor.__getitem__ = __getitem__
+ StubTensor.__getitem__ = __getitem__
+
+ def numel(self):
+ return math.prod(self.shape)
+
+ Tensor.numel = numel
+ StubTensor.numel = numel
+ Tensor.nelement = numel
+ StubTensor.nelement = numel
+
+ @property
+ def nbytes(self):
+ return self.numel() * self.element_size()
+
+ Tensor.nbytes = nbytes
+ StubTensor.nbytes = nbytes
+
+ Tensor.normal_ = ops.inplace_normal
+ StubTensor.normal_ = ops.inplace_normal
+
+
+ Tensor.softmax = ops.softmax
+ StubTensor.softmax = ops.softmax
+
+ Tensor.squeeze = ops.squeeze
+ StubTensor.squeeze = ops.squeeze
+
+ Tensor.unsqueeze = ops.unsqueeze
+ StubTensor.unsqueeze = ops.unsqueeze
+
+ def log_softmax(self, dim=-1, dtype=None):
+ if use_pyboost():
+ return mindspore.mint.nn.functional.log_softmax(self, dim=dim, dtype=dtype)
+ out = mindspore.ops.log_softmax(self, dim)
+ if dtype is not None:
+ out = out.to(dtype)
+ return out
+
+ Tensor.log_softmax = log_softmax
+ StubTensor.log_softmax = log_softmax
+
+ def untyped_storage(self):
+ return UntypedStorage(self)
+
+ Tensor.untyped_storage = untyped_storage
+ StubTensor.untyped_storage = untyped_storage
+
+ def element_size(self):
+ return DTYPE_ELEMENT_SIZE_MAP[self.dtype]
+
+ Tensor.element_size = element_size
+ StubTensor.element_size = element_size
+
+ @property
+ def layout(self):
+ return None
+
+ Tensor.layout = layout
+ StubTensor.layout = layout
+
+ def __add__(self, other):
+ # if 0 in self.shape:
+ # return self
+ return ops.add(self, other)
+
+ Tensor.__add__ = __add__
+ StubTensor.__add__ = __add__
+
+ Tensor.repeat_interleave = ops.repeat_interleave
+ StubTensor.repeat_interleave = ops.repeat_interleave
+
+ def dim(self):
+ return self.ndim
+
+ Tensor.dim = dim
+ StubTensor.dim = dim
+
+ def unfold(self, dimension, size, step):
+ return ops.unfold(self, dimension, size, step)
+
+ Tensor.unfold = unfold
+ StubTensor.unfold = unfold
diff --git a/mindnlp/core/autograd/__init__.py b/mindnlp/core/autograd/__init__.py
index f579cc99b..0f2fd5d47 100644
--- a/mindnlp/core/autograd/__init__.py
+++ b/mindnlp/core/autograd/__init__.py
@@ -1,4 +1,4 @@
"""autograd"""
from .node import Node
-from .function import Function
+from .function import Function, value_and_grad
from .grad_mode import no_grad, enable_grad, inference_mode
diff --git a/mindnlp/core/autograd/function.py b/mindnlp/core/autograd/function.py
index e7e35a925..02a1e9855 100644
--- a/mindnlp/core/autograd/function.py
+++ b/mindnlp/core/autograd/function.py
@@ -1,57 +1,110 @@
"""functional autograd"""
from collections.abc import Generator
-from dataclasses import dataclass
-from typing import Tuple, Any, Optional, Type, Sequence
-import functools
-
-@dataclass(unsafe_hash=True)
-class Context:
- """
- Context class is used by `Function` to store information during the forward pass.
- """
-
- no_grad: bool = False
- saved_values: Tuple[Any, ...] = ()
-
- def save_for_backward(self, *values: Any) -> None:
- "Store the given `values` if they need to be used during backpropagation."
- if self.no_grad:
- return
- self.saved_values = values
-
- @property
- def saved_tensors(self) -> Tuple[Any, ...]:
- return self.saved_values
-
-# Constructors
-class Function:
- @classmethod
- def _backward(cls, ctx: Context, *grad_out):
- return cls.backward(ctx, *grad_out) # type: ignore
-
- @classmethod
- def _forward(cls, ctx: Context, *inps, **kwargs):
- return cls.forward(ctx, *inps, **kwargs) # type: ignore
-
- @classmethod
- def apply(cls, *vals, **kwargs):
- # Create the context.
- ctx = Context(not requires_grad)
- # Call forward with the variables.
- results = cls._forward(ctx, *vals, **kwargs)
- requires_grad = any([x.requires_grad for x in vals])
-
- if requires_grad: # cut useless nodes
- generation = max([x.generation for x in vals])
- ctx.outputs = [weakref.ref(output) for output in outputs]
- back = History(cls, ctx, generation)
- for output in outputs:
- output.set_creator(back)
-
- return outputs if len(outputs) > 1 else outputs[0]
-
- def forward(self, xs):
- raise NotImplementedError()
-
- def backward(self, gys):
- raise NotImplementedError()
+
+import mindspore
+from mindspore.ops.composite import GradOperation
+from mindspore.ops import stop_gradient
+from mindspore.common.api import _pynative_executor
+from mindspore._c_expression import Cell_
+from .grad_mode import no_grad
+
+try:
+ from mindspore import _Function as Function
+except:
+ Function = None
+
+grad_ = GradOperation(False, True, False)
+grad_sens_ = GradOperation(False, True, True)
+grad_input_sens_ = GradOperation(True, True, True)
+
+def value_and_grad(fn, params_or_argnums, has_aux=False, attach_grads=True):
+ use_argnums = False
+ if isinstance(params_or_argnums, Generator):
+ params_or_argnums = tuple(params_or_argnums)
+
+ if params_or_argnums and isinstance(params_or_argnums[0], int):
+ use_argnums = True
+ # positional argnums cannot carry a .grad attribute
+ attach_grads = False
+
+ def fn_aux(*args):
+ outputs = fn(*args)
+ no_grad_outputs = ()
+ for out in outputs[1:]:
+ no_grad_outputs += (stop_gradient(out),)
+ return outputs[0], no_grad_outputs
+
+ if has_aux:
+ fn_ = fn_aux
+ else:
+ fn_ = fn
+
+ def value_and_grad_f(*args, **kwargs):
+ _pynative_executor.set_grad_flag(True)
+ _pynative_executor.new_graph(fn, *args, **kwargs)
+ values = fn_(*args, **kwargs)
+ _pynative_executor.end_graph(fn, values, *args, **kwargs)
+
+ run_args = args
+ if kwargs:
+ run_args = args + tuple(kwargs.values())
+
+ grads = _pynative_executor.grad(fn_, grad_, params_or_argnums, None, *run_args)
+ grads = tuple(mindspore.Tensor(grad) for grad in grads)
+ if attach_grads:
+ for param, grad in zip(params_or_argnums, grads):
+ if param.grad is None:
+ param.grad = grad
+ else:
+ param.grad += grad
+ return values
+ return values, grads
+ return value_and_grad_f
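+
+# Usage sketch (illustrative names): accumulate gradients onto parameters.
+#   grad_fn = value_and_grad(loss_fn, model.trainable_params(), attach_grads=True)
+#   loss = grad_fn(inputs, labels)  # each param.grad then holds d(loss)/d(param)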
+
+def grad(fn, params_or_argnums=None, has_aux=False):
+ value_and_grad_f = value_and_grad(fn, params_or_argnums, has_aux)
+ def grad_f(*args):
+ _, g = value_and_grad_f(*args)
+ return g
+ return grad_f
+
+
+if Function is None:
+ class Function(Cell_):
+ def __init__(self):
+ super().__init__(str(self.__class__)[8:-2])
+ self.saved_tensors = []
+ self.used_bprop_inputs = []
+
+ def save_for_backward(self, *args):
+ self.saved_tensors.extend(args)
+
+ @staticmethod
+ def forward(ctx, *args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def backward(ctx, *args, **kwargs):
+ raise NotImplementedError
+
+ def construct(self, *args, **kwargs):
+ self.needs_input_grad = [input_.requires_grad if hasattr(input_, 'requires_grad') else False for input_ in args]
+ args = (self,) + args
+ return self.forward(*args, **kwargs)
+
+ def bprop(self, *args, **kwargs):
+ args = (args[-1],)
+ args = (self,) + args
+ return self.backward(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ with no_grad():
+ output = self.construct(*args, **kwargs)
+ if _pynative_executor.requires_grad():
+ _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
+ return output
+
+ @classmethod
+ def apply(cls, *args, **kwargs):
+ return cls()(*args, **kwargs)
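+
+# Usage sketch for the fallback Function (illustrative subclass):
+#   class Square(Function):
+#       @staticmethod
+#       def forward(ctx, x):
+#           ctx.save_for_backward(x)
+#           return x * x
+#       @staticmethod
+#       def backward(ctx, grad_output):
+#           x = ctx.saved_tensors[0]
+#           return 2 * x * grad_output
+#   y = Square.apply(x)  # construct() records the custom bprop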
diff --git a/mindnlp/core/autograd/functions/custom.py b/mindnlp/core/autograd/functions/custom.py
deleted file mode 100644
index 5f53b3bd6..000000000
--- a/mindnlp/core/autograd/functions/custom.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from mindspore._c_expression import TensorPy as MSTensor
-
-from mindnlp import core
-from core._prims.ascend import cast_npu
-from core._prims.cpu import cast_cpu
-from ..node import Node
-
-
-class AccumulateGrad(Node):
- def __init__(self):
- super().__init__('AccumulateGrad')
- self._post_hook = None
-
- def construct(self, input):
- return input
-
- def bprop(self, input, output, grad):
- if input.grad is None:
- input.grad = grad
- else:
- input.grad += grad
-
- if self._post_hook is not None:
- self._post_hook(input)
- return grad
-
- def register_post_hook(self, hook):
- self._post_hook = hook
-
-
-class Cast(Node):
- def __init__(self):
- super().__init__('Cast')
- self.used_bprop_inputs = []
-
- def construct(self, input, dtype, device):
- self.device = input.device
- self.dtype = input.dtype
- if device.type == 'cpu':
- out = cast_cpu(input, dtype).get_value()
- else:
- out = cast_npu(input, dtype).get_value()
-
- output = core.Tensor.__new__(core.Tensor)
- MSTensor.__init__(output, out)
- output.device = device
- return output
-
- def bprop(self, *args):
- grad = args[-1]
- if self.device.type == 'cpu':
- out = cast_cpu(grad, self.dtype).get_value()
- else:
- out = cast_npu(grad, self.dtype).get_value()
-
- output = core.Tensor.__new__(core.Tensor)
- MSTensor.__init__(output, out)
- output.device = self.device
- return output, None, None
diff --git a/mindnlp/core/configs.py b/mindnlp/core/configs.py
index c046d3ec7..6a65f786a 100644
--- a/mindnlp/core/configs.py
+++ b/mindnlp/core/configs.py
@@ -5,6 +5,7 @@
SOC = MSContext.get_instance().get_ascend_soc_version()
DEVICE_TARGET = mindspore.get_context('device_target')
SUPPORT_BF16 = SOC in ["ascend910b", "ascend910_93"]
+ON_A1 = not SUPPORT_BF16
ON_ORANGE_PI = '310b' in SOC
USE_PYBOOST = DEVICE_TARGET == 'Ascend'
DEFAULT_DTYPE = mindspore.float32
diff --git a/mindnlp/core/distributed/fsdp/__init__.py b/mindnlp/core/distributed/fsdp/__init__.py
index fa3888cbd..3b6767333 100644
--- a/mindnlp/core/distributed/fsdp/__init__.py
+++ b/mindnlp/core/distributed/fsdp/__init__.py
@@ -1 +1,2 @@
-FullyShardedDataParallel = None
+class FullyShardedDataParallel:
+ pass
diff --git a/mindnlp/core/distributed/tensor/__init__.py b/mindnlp/core/distributed/tensor/__init__.py
index 82f67afae..9790c03e8 100644
--- a/mindnlp/core/distributed/tensor/__init__.py
+++ b/mindnlp/core/distributed/tensor/__init__.py
@@ -1,4 +1,4 @@
Replicate = None
-DTensor = None
+class DTensor: pass
Placement = None
Shard = None
diff --git a/mindnlp/core/distributions/__init__.py b/mindnlp/core/distributions/__init__.py
index f012e1020..3ca74fcd0 100644
--- a/mindnlp/core/distributions/__init__.py
+++ b/mindnlp/core/distributions/__init__.py
@@ -10,3 +10,4 @@
from .transforms import *
from .relaxed_categorical import *
from .relaxed_bernoulli import *
+from .multivariate_normal import *
\ No newline at end of file
diff --git a/mindnlp/core/distributions/constraints.py b/mindnlp/core/distributions/constraints.py
index 460e83182..ef0311ac8 100644
--- a/mindnlp/core/distributions/constraints.py
+++ b/mindnlp/core/distributions/constraints.py
@@ -33,7 +33,7 @@
"""
import mindspore
-from .. import ops
+from mindnlp import core
__all__ = [
@@ -452,7 +452,7 @@ class _Simplex(Constraint):
event_dim = 1
def check(self, value):
- return ops.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
+ return core.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
class _Multinomial(Constraint):
@@ -528,7 +528,7 @@ class _Square(Constraint):
event_dim = 2
def check(self, value):
- return ops.full(
+ return core.full(
size=value.shape[:-2],
fill_value=(value.shape[-2] == value.shape[-1]),
dtype=mindspore.bool_,
@@ -544,7 +544,7 @@ def check(self, value):
square_check = super().check(value)
if not square_check.all():
return square_check
- return ops.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
+ return core.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
class _PositiveSemidefinite(_Symmetric):
@@ -568,7 +568,7 @@ def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
return sym_check
- return ops.linalg.cholesky_ex(value).info.eq(0)
+ return core.linalg.cholesky_ex(value).info.eq(0)
class _Cat(Constraint):
diff --git a/mindnlp/core/distributions/multivariate_normal.py b/mindnlp/core/distributions/multivariate_normal.py
new file mode 100644
index 000000000..a4f25920e
--- /dev/null
+++ b/mindnlp/core/distributions/multivariate_normal.py
@@ -0,0 +1,269 @@
+# mypy: allow-untyped-defs
+import math
+from typing import Optional
+
+from mindnlp import core
+from mindnlp.core import Tensor
+from mindnlp.core.distributions import constraints
+from mindnlp.core.distributions.distribution import Distribution
+from mindnlp.core.distributions.utils import _standard_normal, lazy_property
+from mindnlp.core.types import _size
+
+
+__all__ = ["MultivariateNormal"]
+
+
+def _batch_mv(bmat, bvec):
+ r"""
+ Performs a batched matrix-vector product, with compatible but different batch shapes.
+
+ This function takes as input `bmat`, containing :math:`n \times n` matrices, and
+ `bvec`, containing length :math:`n` vectors.
+
+ Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
+ to a batch shape. They are not necessarily assumed to have the same batch shape,
+ just ones which can be broadcasted.
+ """
+ return core.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
+
+
+def _batch_mahalanobis(bL, bx):
+ r"""
+ Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
+ for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
+
+ Accepts batches for both bL and bx. They are not required to have the same
+ batch shape, but the batch shape of `bL` must broadcast to that of `bx`.
+ """
+ n = bx.size(-1)
+ bx_batch_shape = bx.shape[:-1]
+
+ # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
+ # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
+ bx_batch_dims = len(bx_batch_shape)
+ bL_batch_dims = bL.dim() - 2
+ outer_batch_dims = bx_batch_dims - bL_batch_dims
+ old_batch_dims = outer_batch_dims + bL_batch_dims
+ new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
+ # Reshape bx with the shape (..., 1, i, j, 1, n)
+ bx_new_shape = bx.shape[:outer_batch_dims]
+ for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
+ bx_new_shape += (sx // sL, sL)
+ bx_new_shape += (n,)
+ bx = bx.reshape(bx_new_shape)
+ # Permute bx to make it have shape (..., 1, j, i, 1, n)
+ permute_dims = (
+ list(range(outer_batch_dims))
+ + list(range(outer_batch_dims, new_batch_dims, 2))
+ + list(range(outer_batch_dims + 1, new_batch_dims, 2))
+ + [new_batch_dims]
+ )
+ bx = bx.permute(permute_dims)
+
+ flat_L = bL.reshape(-1, n, n) # shape = b x n x n
+ flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
+ flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
+ M_swap = (
+ core.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
+ ) # shape = b x c
+ M = M_swap.t() # shape = c x b
+
+ # Now we revert the above reshape and permute operators.
+ permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
+ permute_inv_dims = list(range(outer_batch_dims))
+ for i in range(bL_batch_dims):
+ permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
+ reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
+ return reshaped_M.reshape(bx_batch_shape)
+
+
+def _precision_to_scale_tril(P):
+ # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
+ Lf = core.linalg.cholesky(core.flip(P, (-2, -1)))
+ L_inv = core.transpose(core.flip(Lf, (-2, -1)), -2, -1)
+ Id = core.eye(P.shape[-1], dtype=P.dtype, device=P.device)
+ L = core.linalg.solve_triangular(L_inv, Id, upper=False)
+ return L
+
+
+class MultivariateNormal(Distribution):
+ r"""
+ Creates a multivariate normal (also called Gaussian) distribution
+ parameterized by a mean vector and a covariance matrix.
+
+ The multivariate normal distribution can be parameterized either
+ in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
+ or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
+ or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
+ diagonal entries, such that
+ :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
+ can be obtained via e.g. Cholesky decomposition of the covariance.
+
+ Example:
+
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
+ >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+ >>> m = MultivariateNormal(core.zeros(2), core.eye(2))
+ >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
+ tensor([-0.2102, -0.5429])
+
+ Args:
+ loc (Tensor): mean of the distribution
+ covariance_matrix (Tensor): positive-definite covariance matrix
+ precision_matrix (Tensor): positive-definite precision matrix
+ scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
+
+ Note:
+ Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
+ :attr:`scale_tril` can be specified.
+
+ Using :attr:`scale_tril` will be more efficient: all computations internally
+ are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
+ :attr:`precision_matrix` is passed instead, it is only used to compute
+ the corresponding lower triangular matrices using a Cholesky decomposition.
+ """
+
+ arg_constraints = {
+ "loc": constraints.real_vector,
+ "covariance_matrix": constraints.positive_definite,
+ "precision_matrix": constraints.positive_definite,
+ "scale_tril": constraints.lower_cholesky,
+ }
+ support = constraints.real_vector
+ has_rsample = True
+
+ def __init__(
+ self,
+ loc: Tensor,
+ covariance_matrix: Optional[Tensor] = None,
+ precision_matrix: Optional[Tensor] = None,
+ scale_tril: Optional[Tensor] = None,
+ validate_args: Optional[bool] = None,
+ ) -> None:
+ if loc.dim() < 1:
+ raise ValueError("loc must be at least one-dimensional.")
+ if (covariance_matrix is not None) + (scale_tril is not None) + (
+ precision_matrix is not None
+ ) != 1:
+ raise ValueError(
+ "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
+ )
+
+ if scale_tril is not None:
+ if scale_tril.dim() < 2:
+ raise ValueError(
+ "scale_tril matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions"
+ )
+ batch_shape = core.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
+ self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
+ elif covariance_matrix is not None:
+ if covariance_matrix.dim() < 2:
+ raise ValueError(
+ "covariance_matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions"
+ )
+ batch_shape = core.broadcast_shapes(
+ covariance_matrix.shape[:-2], loc.shape[:-1]
+ )
+ self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
+ else:
+ assert precision_matrix is not None # helps mypy
+ if precision_matrix.dim() < 2:
+ raise ValueError(
+ "precision_matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions"
+ )
+ batch_shape = core.broadcast_shapes(
+ precision_matrix.shape[:-2], loc.shape[:-1]
+ )
+ self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
+ self.loc = loc.expand(batch_shape + (-1,))
+
+ event_shape = self.loc.shape[-1:]
+ super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+ if scale_tril is not None:
+ self._unbroadcasted_scale_tril = scale_tril
+ elif covariance_matrix is not None:
+ self._unbroadcasted_scale_tril = core.linalg.cholesky(covariance_matrix)
+ else: # precision_matrix is not None
+ self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
+
+ def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(MultivariateNormal, _instance)
+ batch_shape = core.Size(batch_shape)
+ loc_shape = batch_shape + self.event_shape
+ cov_shape = batch_shape + self.event_shape + self.event_shape
+ new.loc = self.loc.expand(loc_shape)
+ new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
+ if "covariance_matrix" in self.__dict__:
+ new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
+ if "scale_tril" in self.__dict__:
+ new.scale_tril = self.scale_tril.expand(cov_shape)
+ if "precision_matrix" in self.__dict__:
+ new.precision_matrix = self.precision_matrix.expand(cov_shape)
+ super(MultivariateNormal, new).__init__(
+ batch_shape, self.event_shape, validate_args=False
+ )
+ new._validate_args = self._validate_args
+ return new
+
+ @lazy_property
+ def scale_tril(self) -> Tensor:
+ return self._unbroadcasted_scale_tril.expand(
+ self._batch_shape + self._event_shape + self._event_shape
+ )
+
+ @lazy_property
+ def covariance_matrix(self) -> Tensor:
+ return core.matmul(
+ self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
+ ).expand(self._batch_shape + self._event_shape + self._event_shape)
+
+ @lazy_property
+ def precision_matrix(self) -> Tensor:
+ return core.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
+ self._batch_shape + self._event_shape + self._event_shape
+ )
+
+ @property
+ def mean(self) -> Tensor:
+ return self.loc
+
+ @property
+ def mode(self) -> Tensor:
+ return self.loc
+
+ @property
+ def variance(self) -> Tensor:
+ return (
+ self._unbroadcasted_scale_tril.pow(2)
+ .sum(-1)
+ .expand(self._batch_shape + self._event_shape)
+ )
+
+ def rsample(self, sample_shape: _size = core.Size()) -> Tensor:
+ shape = self._extended_shape(sample_shape)
+ eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+ return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
+
+ def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ diff = value - self.loc
+ M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
+ half_log_det = (
+ self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ )
+ return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
+
+ def entropy(self):
+ half_log_det = (
+ self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ )
+ H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
+ if len(self._batch_shape) == 0:
+ return H
+ else:
+ return H.expand(self._batch_shape)
\ No newline at end of file
diff --git a/mindnlp/core/distributions/utils.py b/mindnlp/core/distributions/utils.py
index 596f1d68b..cdb6e3b8d 100644
--- a/mindnlp/core/distributions/utils.py
+++ b/mindnlp/core/distributions/utils.py
@@ -58,7 +58,7 @@ def broadcast_all(*values):
return ops.broadcast_tensors(*values)
-def _standard_normal(shape, dtype):
+def _standard_normal(shape, dtype, device=None):
return ops.normal(size = shape).to(dtype)
diff --git a/mindnlp/core/linalg/__init__.py b/mindnlp/core/linalg/__init__.py
new file mode 100644
index 000000000..25891956c
--- /dev/null
+++ b/mindnlp/core/linalg/__init__.py
@@ -0,0 +1,23 @@
+from collections import namedtuple
+from mindspore import ops
+from mindspore.ops._primitive_cache import _get_cache_prim
+
+from mindnlp import core
+
+linalg_cholesky_ex = namedtuple('linalg_cholesky_ex', ['L', 'info'])
+
+def cholesky(A, *, upper=False, out=None):
+ cholesky_op = _get_cache_prim(ops.Cholesky)(upper=upper).set_device('CPU')
+ return cholesky_op(A)
+
+def cholesky_ex(A, *, upper=False, check_errors=False, out=None):
+ try:
+ out = cholesky(A, upper=upper, out=out)
+ out + 1 # force kernel execution so a failed factorization raises here
+ info = core.Tensor(0)
+ except Exception:
+ info = core.Tensor(1)
+ out = A
+ return linalg_cholesky_ex(out, info)
+
+
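+# Usage sketch, mirroring torch.linalg.cholesky_ex's (L, info) contract:
+#   L, info = cholesky_ex(A)
+#   ok = bool(info == 0)  # nonzero info flags a failed factorization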
diff --git a/mindnlp/core/nn/__init__.py b/mindnlp/core/nn/__init__.py
index c3eecf6d8..0de29ba94 100644
--- a/mindnlp/core/nn/__init__.py
+++ b/mindnlp/core/nn/__init__.py
@@ -16,3 +16,4 @@
from . import utils, functional, init
from .modules import *
from .parameter import Parameter
+from .parallel import DataParallel as DataParallel
\ No newline at end of file
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index e7db4975d..241df33ae 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -10,7 +10,7 @@
from mindnlp import core
from mindnlp.core.executor import execute
-from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost
+from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
generator_step_ = 12
@@ -152,15 +152,12 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, coun
return ops.avg_pool2d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
-def dropout(input, p=0.5, training=True, inplace=False):
- if not training:
- return input
- seed, offset = default_generator._step(generator_step_) # pylint: disable=protected-access
- out, _ = execute('dropout_ext', input, p, seed, offset)
- if inplace:
- input.copy_(out)
+def dropout(input, p=0.5, training=True, inplace=False):
+ if not training or p == 0:
 return input
- return out
+ if use_pyboost() and not ON_ORANGE_PI:
+ out = mint.nn.functional.dropout(input, p, training)
+ else:
+ out = ops.dropout(input, p, training)
+ if inplace:
+ input.copy_(out)
+ return input
+ return out
def dropout2d(input, p=0.5, training=False):
return ops.dropout2d(input, p, training)
@@ -211,11 +208,15 @@ def gumbel_softmax(logits: core.Tensor, tau: float = 1, hard: bool = False, eps:
ret = y_soft
return ret
-def log_softmax(input, dim=-1, dtype=None):
- if input.device.type == 'cpu':
- return execute('log_softmax', input, dim)
- return execute('log_softmax_ext', input, dim,
- dtype if dtype is None else dtype_to_type_id('LogSoftmaxExt', 'dtype', dtype))
+def log_softmax(input, dim=None, dtype=None):
+ if use_pyboost():
+ return mint.nn.functional.log_softmax(input, dim=dim, dtype=dtype)
+ if dim is None:
+ dim = -1
+ out = ops.log_softmax(input, dim)
+ if dtype is not None:
+ out = out.to(dtype)
+ return out
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False):
if use_pyboost():
@@ -259,24 +260,15 @@ def _replication_pad(input, pad):
return out
def pad(input, pad, mode='constant', value=0.0):
- out = input
- if (isinstance(pad, tuple) and not pad):
- return out
- if mode == "constant":
- value = 0 if value is None else value
- out = execute('constant_pad_nd', input, pad, value)
- else:
- if value is not None and value != 0:
- raise ValueError(f"Padding mode {mode} doesn\'t take in value argument.")
- if mode == "circular":
- out = _circular_pad(input, pad)
- elif mode == "reflect":
- out = _reflection_pad(input, pad)
- elif mode == "replicate":
- out = _replication_pad(input, pad)
- else:
- raise ValueError(f"Pad filling mode must be 'constant' 'circular' 'reflect' or 'replicate'.")
- return out
+ if isinstance(pad, (tuple, list)):
+ pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
+ # only skip when every pad width is zero; sum(pad) == 0 would also
+ # match mixed negative/positive pads, which must still be applied
+ if all(p == 0 for p in pad):
+ return input
+ if use_pyboost() and not ON_A1:
+ return mint.nn.functional.pad(input, pad, mode, value)
+ if mode in ['reflect', 'circular']:
+ return ops.pad(input, pad, mode)
+ return ops.pad(input, pad, mode, value)
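+
+# Example: pad widths pair up from the last dim backwards, e.g.
+#   pad(x, (1, 2)) pads the last dim with 1 before and 2 after;
+#   pad(x, (1, 1, 2, 2)) also pads the second-to-last dim by 2 on each side.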
def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
return _inner_nll_loss(input, target, weight, ignore_index, reduction, label_smoothing)
@@ -338,7 +330,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
if target.ndim == inputs.ndim - 1:
target = target.unsqueeze(target_dim)
if ignore_index is not None:
- non_pad_mask = core.equal(target, ignore_index)
+ non_pad_mask = core.eq(target, ignore_index)
target = target.masked_fill(non_pad_mask, 0)
else:
non_pad_mask = target
diff --git a/mindnlp/core/nn/modules/linear.py b/mindnlp/core/nn/modules/linear.py
index 43ee37f9b..23d0e5b7a 100644
--- a/mindnlp/core/nn/modules/linear.py
+++ b/mindnlp/core/nn/modules/linear.py
@@ -29,13 +29,13 @@ class Linear(Module):
Examples::
>>> m = nn.Linear(20, 30)
- >>> input = autograd.Variable(core.randn(128, 20))
+ >>> input = autograd.Variable(torch.randn(128, 20))
>>> output = m(input)
>>> print(output.size())
"""
- def __init__(self, in_features, out_features, bias=True, dtype=None) -> None:
- factory_kwargs = {'dtype': dtype}
+ def __init__(self, in_features, out_features, bias=True, dtype=None, device=None) -> None:
+ factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -80,10 +80,10 @@ class Identity(Module):
Examples::
>>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
- >>> input = core.randn(128, 20)
+ >>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
- core.Size([128, 20])
+ torch.Size([128, 20])
"""
diff --git a/mindnlp/core/nn/modules/sparse.py b/mindnlp/core/nn/modules/sparse.py
index 5712d68a4..88642422d 100644
--- a/mindnlp/core/nn/modules/sparse.py
+++ b/mindnlp/core/nn/modules/sparse.py
@@ -32,7 +32,7 @@ class Embedding(Module):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False,
- dtype=None) -> None:
+ dtype=None, device=None) -> None:
- factory_kwargs = {'dtype': dtype}
+ factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__()
self.num_embeddings = num_embeddings
@@ -107,10 +107,10 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
Examples::
>>> # FloatTensor containing pretrained weights
- >>> weight = core.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+ >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
- >>> input = core.LongTensor([1])
+ >>> input = torch.LongTensor([1])
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> embedding(input)
tensor([[ 4.0000, 5.1000, 6.3000]])
diff --git a/mindnlp/core/nn/parallel/__init__.py b/mindnlp/core/nn/parallel/__init__.py
index e69de29bb..04d8b0c5e 100644
--- a/mindnlp/core/nn/parallel/__init__.py
+++ b/mindnlp/core/nn/parallel/__init__.py
@@ -0,0 +1 @@
+from .distributed import DistributedDataParallel, DataParallel
\ No newline at end of file
diff --git a/mindnlp/core/nn/parallel/distributed.py b/mindnlp/core/nn/parallel/distributed.py
index aefe06666..5961ad131 100644
--- a/mindnlp/core/nn/parallel/distributed.py
+++ b/mindnlp/core/nn/parallel/distributed.py
@@ -1,4 +1,7 @@
from ..modules import Module
class DistributedDataParallel(Module):
- pass
\ No newline at end of file
+ pass
+
+class DataParallel(Module):
+ pass
diff --git a/mindnlp/core/ops/__init__.py b/mindnlp/core/ops/__init__.py
index 5f14d9d90..e4b029cbd 100644
--- a/mindnlp/core/ops/__init__.py
+++ b/mindnlp/core/ops/__init__.py
@@ -1,6 +1,6 @@
"""core ops like torch funcional api"""
from . import array, blas, comparison, pointwise, creation, random, reduction, other, \
- tensor, _inner, optim
+ tensor, _inner, optim, inplace
from .array import *
from .blas import *
from .comparison import *
@@ -14,6 +14,7 @@
# from .spectral import *
from ._inner import *
from .optim import *
+from .inplace import *
def load_library(lib_path):
raise ImportError('not support import any ops for now.')
@@ -34,3 +35,4 @@ def load_library(lib_path):
__all__.extend(tensor.__all__)
__all__.extend(other.__all__)
__all__.extend(optim.__all__)
+__all__.extend(inplace.__all__)
diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py
index ba54bf936..571293abb 100644
--- a/mindnlp/core/ops/array.py
+++ b/mindnlp/core/ops/array.py
@@ -9,6 +9,7 @@
from ..configs import use_pyboost, ON_ORANGE_PI
from .other import broadcast_tensors
from ._inner import call_ms_func
+from .creation import arange
# adjoint
@@ -20,7 +21,10 @@ def argwhere(input):
# cat
has_cat = hasattr(mindspore.mint, 'cat')
-def cat(tensors, dim=0, *, out=None):
+def cat(tensors, dim=0, *, out=None, **kwargs):
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ dim = axis
if use_pyboost() and has_cat:
return call_ms_func(mindspore.mint.cat, tensors, dim, out=out)
return call_ms_func(ops.cat, tensors, dim, out=out)
@@ -223,7 +227,7 @@ def split(tensor, split_size_or_sections, dim=0):
# squeeze
has_squeeze = hasattr(mindspore.mint, 'squeeze')
-def squeeze(input, dim=None):
+def squeeze(input, *dim):
+ # accept squeeze(x), squeeze(x, d) and squeeze(x, d0, d1, ...); None squeezes all
+ dim = dim[0] if len(dim) == 1 else (dim if dim else None)
if use_pyboost() and has_squeeze:
return mindspore.mint.squeeze(input, dim)
return ops.squeeze(input, dim)
@@ -253,15 +257,58 @@ def take(input, index):
return tf_gather(input, index, 0).view(index_shape)
return gather(input, 0, index).view(index_shape)
-# take_along_dim
+def infer_size_impl(a, b):
+ lenA = len(a)
+ lenB = len(b)
+ ndim = max(lenA, lenB)
+ expanded_sizes = [0] * ndim
+
+ for i in range(ndim - 1, -1, -1):
+ offset = ndim - 1 - i
+ dimA = lenA - 1 - offset
+ dimB = lenB - 1 - offset
+
+ sizeA = a[dimA] if dimA >= 0 else 1
+ sizeB = b[dimB] if dimB >= 0 else 1
+
+ # check dimension compatibility
+ if not (sizeA == sizeB or sizeA == 1 or sizeB == 1):
+ raise RuntimeError(
+ f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) "
+ f"at non-singleton dimension {i}"
+ )
+
+ # apply the broadcasting rule: prefer the non-1 dimension size
+ expanded_sizes[i] = sizeB if sizeA == 1 else sizeA
+
+ return expanded_sizes
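+
+# Example: infer_size_impl([3, 1, 5], [4, 5]) -> [3, 4, 5]
+# (trailing dims are aligned; size-1 dims broadcast, as in NumPy/PyTorch).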
+
+def _take_along_dim_helper(self, indices, dim):
+ assert self.dim() == indices.dim(), f"torch.take_along_dim(): input and indices should have the same number of dimensions, " \
+ f"but got {self.dim()} dimensions for input, and {indices.dim()} dimensions for indices"
+ dim = self.dim() + dim if dim < 0 else dim
+ self_sizes = list(self.shape)
+ self_sizes[dim] = indices.size(dim)
+ broadcast_shape = infer_size_impl(self_sizes, indices.shape)
+ indices_broadcasted = indices.broadcast_to(broadcast_shape)
+
+ indices_sizes = list(indices.shape)
+ indices_sizes[dim] = self.size(dim)
+ broadcast_shape = infer_size_impl(indices_sizes, self.shape)
+ self_broadcasted = self.broadcast_to(broadcast_shape)
+
+ return self_broadcasted, indices_broadcasted, dim
+
+# take_along_dim
+def take_along_dim(input, indices, dim=None, *, out=None):
+ if dim is not None: # a truthiness check would wrongly flatten for dim == 0
+ self_broadcasted, indices_broadcasted, dim = _take_along_dim_helper(input, indices, dim)
+ return self_broadcasted.gather(dim, indices_broadcasted)
+ return input.view(-1).gather(0, indices.view(-1))
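+
+# Usage sketch, assuming core.argsort mirrors torch.argsort:
+#   idx = core.argsort(x, dim=1)
+#   x_sorted = take_along_dim(x, idx, dim=1)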
# tensor_split
def tensor_split(input, indices_or_sections, dim=0):
- if isinstance(indices_or_sections, mindspore.Tensor):
- indices_or_sections = indices_or_sections.tolist()
- else:
- indices_or_sections = tuple([get_item(t) for t in indices_or_sections])
return ops.tensor_split(input, indices_or_sections, dim)
# tile
@@ -710,7 +757,7 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m
'swapaxes',
'swapdims',
'take',
- # take_along_dim
+ 'take_along_dim',
'tensor_split',
'tile',
'transpose',
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index 657c3e84f..5bb550a32 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -87,13 +87,24 @@ def ones_like(input, *, dtype=None, device=None):
return ops.ones_like(input, dtype=dtype)
# arange
+range_op = ops.Range()
has_arange = hasattr(mindspore.mint, 'arange')
def arange(start=0, end=None, step=1, *, dtype=None, device=None):
if ON_ORANGE_PI and dtype in (None, mindspore.int64):
dtype = mindspore.int32
if use_pyboost() and has_arange:
+ start = start.item() if isinstance(start, mindspore.Tensor) else start
+ end = end.item() if isinstance(end, mindspore.Tensor) else end
+ step = step.item() if isinstance(step, mindspore.Tensor) else step
return mindspore.mint.arange(start, end, step, dtype=dtype)
- return ops.arange(start, end, step, dtype=dtype)
+
+ start = mindspore.Tensor(start) if not isinstance(start, mindspore.Tensor) else start
+ end = mindspore.Tensor(end) if not isinstance(end, mindspore.Tensor) else end
+ step = mindspore.Tensor(step) if not isinstance(step, mindspore.Tensor) else step
+ out = range_op(start, end, step)
+ if dtype:
+ out = out.to(dtype)
+ return out
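+
+# Example: arange(0, 5, 2) -> tensor([0, 2, 4]) on either code path.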
# range
def range(start=0, end=None, step=1, dtype=None):
diff --git a/mindnlp/core/ops/inplace.py b/mindnlp/core/ops/inplace.py
index 107a82dce..45655f62f 100644
--- a/mindnlp/core/ops/inplace.py
+++ b/mindnlp/core/ops/inplace.py
@@ -1,6 +1,7 @@
-
from mindspore.common.generator import default_generator
+from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op
+from mindnlp import core
generator_step_ = 12
@@ -39,10 +40,11 @@ def inplace_normal(input, mean=0, std=1, *, generator=None):
if generator is None:
generator = default_generator
seed, offset = generator._step(generator_step_)
- if input.device.type == 'npu':
- execute('inplace_normal', input, mean, std, seed, offset)
- elif input.device.type == 'cpu':
- core.normal(mean, std, size=input.size, generator=generator, out=input)
+ if isinstance(mean, core.Tensor):
+ mean = mean.item()
+ if isinstance(std, core.Tensor):
+ std = std.item()
+ inplace_normal_op(input, mean, std, seed, offset)
return input
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index d786a96cd..e57faaba5 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -7,7 +7,7 @@
from mindspore.common.initializer import initializer
from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, ON_ORANGE_PI
+from ..configs import use_pyboost, ON_ORANGE_PI, ON_A1
from .reduction import any
from .comparison import eq
from ._inner import call_ms_func
@@ -710,19 +710,20 @@ def meshgrid(*tensors, indexing=None):
# repeat_interleave
-has_repeat_interleave = hasattr(mindspore.mint, "repeat_interleave")
-
-
-def repeat_interleave(*args, **kwargs):
- if use_pyboost() and has_repeat_interleave:
- return mindspore.mint.repeat_interleave(*args, **kwargs)
-
- input, repeats, dim = args.get("input"), args.get("repeats"), args.get("dim")
+has_repeat_interleave = hasattr(mindspore.mint, 'repeat_interleave')
+def repeat_interleave(input, repeats, dim=None):
+ if use_pyboost() and has_repeat_interleave and not ON_A1:
+ return mindspore.mint.repeat_interleave(input, repeats, dim=dim)
- if input.dtype == mindspore.bool_:
- input = input.int()
- return input.repeat(repeats, dim).bool()
- return input.repeat(repeats, dim)
-
+ is_bool = input.dtype == mindspore.bool_
+ if is_bool:
+ input = input.int()
+ if dim is None:
+ # torch flattens the input when dim is not given
+ input = input.reshape(-1)
+ dim = 0
+ if dim < 0:
+ dim += input.ndim
+ new_shape = list(input.shape)
+ new_shape.insert(dim+1, repeats)
+ expanded = input.unsqueeze(dim+1).expand(new_shape)
+
+ final_shape = new_shape.copy()
+ final_shape[dim] *= final_shape[dim+1]
+ del final_shape[dim+1]
+ out = expanded.reshape(final_shape)
+ return out.bool() if is_bool else out
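+
+# Example: repeat_interleave(tensor([[1, 2]]), 2, dim=1) -> [[1, 1, 2, 2]],
+# matching torch.repeat_interleave's element-wise repetition semantics.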
# roll
DEVICE_TARGET = mindspore.get_context("device_target")
diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py
index f66f58f30..4aed3df98 100644
--- a/mindnlp/core/ops/random.py
+++ b/mindnlp/core/ops/random.py
@@ -3,7 +3,7 @@
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
-from ..configs import use_pyboost, DEVICE_TARGET
+from ..configs import use_pyboost, DEVICE_TARGET, ON_A1
from .other import cumsum, searchsorted
from .comparison import topk
from .pointwise import div, log
@@ -27,7 +27,7 @@ def bernoulli(input, *, generator=None, out=None):
has_multinomial = hasattr(mindspore.mint, 'multinomial')
def multinomial(input, num_samples, replacement=False, *, generator=None):
"""custom multinomial"""
- if use_pyboost() and has_multinomial:
+ if use_pyboost() and has_multinomial and not ON_A1:
return mindspore.mint.multinomial(input, num_samples, replacement=replacement, generator=generator)
if replacement:
# with replacement
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 1789c991a..1da14ab25 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -42,7 +42,14 @@ def aminmax(input, *, dim=None, keepdim=False):
# all
has_all = hasattr(mindspore.mint, 'all')
-def all(input, dim=None, keepdim=False, *, dtype=None):
+def all(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
+ axis = kwargs.get('axis', None)
+ keepdims = kwargs.get('keepdims', None)
+ if axis is not None:
+ dim = axis
+ if keepdims is not None:
+ keepdim = keepdims
+
if use_pyboost() and has_all:
return mindspore.mint.all(input, dim, keepdim).to(input.dtype)
return ops.all(input, dim, keepdim).to(input.dtype)
@@ -79,7 +86,10 @@ def logsumexp(input, dim, keepdim=False):
# mean
has_mean = hasattr(mindspore.mint, 'mean')
-def mean(input, dim=None, keepdim=False, *, dtype=None):
+def mean(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ dim = axis
if use_pyboost() and has_mean:
return mindspore.mint.mean(input, dim, keepdim, dtype=dtype)
out = ops.mean(input, dim, keepdim)
@@ -135,7 +145,10 @@ def nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
# std
has_std = hasattr(mindspore.mint, 'std')
-def std(input, dim=None, *, correction=1, keepdim=False):
+def std(input, dim=None, *, correction=1, keepdim=False, **kwargs):
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ dim = axis
if use_pyboost() and has_std:
return mindspore.mint.std(input, dim=dim, correction=correction, keepdim=keepdim)
if DEVICE_TARGET == 'GPU':
diff --git a/mindnlp/core/serialization.py b/mindnlp/core/serialization.py
index ac9dce42d..049a81e42 100644
--- a/mindnlp/core/serialization.py
+++ b/mindnlp/core/serialization.py
@@ -1235,11 +1235,11 @@ class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
# Lets us override the imports that pickle uses when unpickling an object.
# This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
def find_class(self, mod_name, name):
- if mod_name == 'core._utils':
+ if mod_name == 'torch._utils':
return eval(name)
if mod_name == 'torch':
return str(name)
- if mod_name == 'core._tensor':
+ if mod_name == 'torch._tensor':
return eval(name)
mod_name = load_module_mapping.get(mod_name, mod_name)
return super().find_class(mod_name, name)
diff --git a/mindnlp/core/storage.py b/mindnlp/core/storage.py
index 47f6a5770..45cb25527 100644
--- a/mindnlp/core/storage.py
+++ b/mindnlp/core/storage.py
@@ -12,12 +12,12 @@
from typing_extensions import Self
from mindnlp import core
-from core._utils import _to, _type
-from core.types import _bool, _int, Storage
+from ._utils import _to, _type
+from .types import _bool, _int, Storage
if TYPE_CHECKING:
- from core._prims_common import DeviceLikeType
+ from ._prims_common import DeviceLikeType
__all__ = ["TypedStorage", "UntypedStorage"]
@@ -525,7 +525,10 @@ def _share_filename_cpu_(self, *args, **kwargs):
return super()._share_filename_cpu_(*args, **kwargs)
def data_ptr(self):
- return self.data.ctypes.data
+ if isinstance(self.data, np.ndarray):
+ return self.data.ctypes.data
+ else:
+ return self.data.data_ptr()
def nbytes(self):
return self.data.nbytes
diff --git a/mindnlp/dataset/load.py b/mindnlp/dataset/load.py
index 6a8c9424d..6ae103cc1 100644
--- a/mindnlp/dataset/load.py
+++ b/mindnlp/dataset/load.py
@@ -16,13 +16,11 @@
"""
load
"""
-import os
from typing import Union, Optional, Dict, Sequence, Mapping
from datasets import load_dataset as hf_load
from datasets import Dataset, IterableDataset, Split, Features, \
DownloadConfig, DownloadMode, VerificationMode, Version
from mindspore.dataset import GeneratorDataset
-from mindspore.communication import get_rank, get_group_size
class TransferIterableDataset():
diff --git a/mindnlp/core/_dynamo/__init__.py b/mindnlp/diffusers/__init__.py
similarity index 100%
rename from mindnlp/core/_dynamo/__init__.py
rename to mindnlp/diffusers/__init__.py
diff --git a/mindnlp/engine/trainer/base.py b/mindnlp/engine/trainer/base.py
index bb76fadab..07a8d16a7 100644
--- a/mindnlp/engine/trainer/base.py
+++ b/mindnlp/engine/trainer/base.py
@@ -41,13 +41,12 @@
from ...core.autograd import value_and_grad
from ...core.serialization import safe_load_file, safe_save_file, save, save_checkpoint, load, load_checkpoint
from ...peft import PeftModel
-from ...configs import WEIGHTS_NAME, CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME, \
- WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from ...dataset import BaseMapFunction
from ...utils import logging, find_labels, can_return_loss
from ...accelerate.utils import DistributedType
-from ...accelerate.utils import accelerate_distributed_type
from ...utils.import_utils import is_safetensors_available
+from ...transformers.utils import WEIGHTS_NAME, CONFIG_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, \
+ ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME
from ...transformers.modeling_utils import PreTrainedModel
from ...transformers.configuration_utils import PretrainedConfig
from ...transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -285,7 +284,6 @@ def __init__(
# Internal variables to help with automatic batch size reduction
self._train_batch_size = args.train_batch_size
self._created_lr_scheduler = False
- self.actual_distributed_type = accelerate_distributed_type
def _get_learning_rate(self):
@@ -1377,20 +1375,6 @@ def _prepare_inputs(self, inputs: Dict[str, Union[mindspore.Tensor, Any]]) -> Di
return inputs
-
- def update_gradient_by_distributed_type(self, model: nn.Module) -> None:
- """update gradient by distributed_type"""
- if accelerate_distributed_type == DistributedType.NO:
- return
- if accelerate_distributed_type == DistributedType.MULTI_NPU:
- from mindspore.communication import get_group_size
- from mindspore.communication.comm_func import all_reduce
- rank_size = get_group_size()
- for parameter in model.parameters():
- new_grads_mean = all_reduce(parameter.grad) / rank_size
- parameter.grad = new_grads_mean
-
-
def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tensor, Any]]) -> Tuple[List[mindspore.Tensor], mindspore.Tensor]:
"""
Perform a training step on a batch of inputs.
@@ -1422,7 +1406,6 @@ def forward(inputs):
self.grad_fn = value_and_grad(forward, weights, attach_grads=True)
loss = self.grad_fn(inputs)
- self.update_gradient_by_distributed_type(model)
return loss / self.args.gradient_accumulation_steps
def compute_loss(self, model, inputs, return_outputs=False):
diff --git a/mindnlp/engine/utils.py b/mindnlp/engine/utils.py
index 0ed550cb8..45e36c078 100644
--- a/mindnlp/engine/utils.py
+++ b/mindnlp/engine/utils.py
@@ -29,8 +29,7 @@
import mindspore
from mindnlp.core import ops
-from core.nn import functional as F
-from mindnlp.configs import GENERATOR_SEED
+from mindnlp.core.nn import functional as F
from mindnlp.utils import is_mindspore_available, ExplicitEnum
@@ -215,8 +214,7 @@ def set_seed(seed: int):
np.random.seed(seed)
if is_mindspore_available():
mindspore.set_seed(seed)
- if GENERATOR_SEED:
- mindspore.manual_seed(seed)
+ mindspore.manual_seed(seed)
def enable_full_determinism(seed: int, warn_only: bool = False):
"""
diff --git a/mindnlp/sentence/__init__.py b/mindnlp/sentence/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/timm/__init__.py b/mindnlp/timm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py
index e02e58ecb..a5b1c4060 100644
--- a/mindnlp/transformers/__init__.py
+++ b/mindnlp/transformers/__init__.py
@@ -3,7 +3,6 @@
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule
from transformers.utils.import_utils import *
-
# Base objects, independent of any specific backend
_import_structure = {
# "agents": [
@@ -4102,342 +4101,7 @@
_import_structure["models.musicgen_melody"].append("MusicgenMelodyProcessor")
-# FLAX-backed objects
-try:
- if not is_flax_available():
- raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
- from transformers.utils import dummy_flax_objects
-
- _import_structure["utils.dummy_flax_objects"] = [
- name for name in dir(dummy_flax_objects) if not name.startswith("_")
- ]
-else:
- _import_structure["generation"].extend(
- [
- "FlaxForcedBOSTokenLogitsProcessor",
- "FlaxForcedEOSTokenLogitsProcessor",
- "FlaxForceTokensLogitsProcessor",
- "FlaxGenerationMixin",
- "FlaxLogitsProcessor",
- "FlaxLogitsProcessorList",
- "FlaxLogitsWarper",
- "FlaxMinLengthLogitsProcessor",
- "FlaxTemperatureLogitsWarper",
- "FlaxSuppressTokensAtBeginLogitsProcessor",
- "FlaxSuppressTokensLogitsProcessor",
- "FlaxTopKLogitsWarper",
- "FlaxTopPLogitsWarper",
- "FlaxWhisperTimeStampLogitsProcessor",
- ]
- )
- _import_structure["modeling_flax_outputs"] = []
- _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
- _import_structure["models.albert"].extend(
- [
- "FlaxAlbertForMaskedLM",
- "FlaxAlbertForMultipleChoice",
- "FlaxAlbertForPreTraining",
- "FlaxAlbertForQuestionAnswering",
- "FlaxAlbertForSequenceClassification",
- "FlaxAlbertForTokenClassification",
- "FlaxAlbertModel",
- "FlaxAlbertPreTrainedModel",
- ]
- )
- _import_structure["models.auto"].extend(
- [
- "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
- "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
- "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
- "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
- "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
- "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
- "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
- "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
- "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
- "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
- "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
- "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
- "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
- "FLAX_MODEL_MAPPING",
- "FlaxAutoModel",
- "FlaxAutoModelForCausalLM",
- "FlaxAutoModelForImageClassification",
- "FlaxAutoModelForMaskedLM",
- "FlaxAutoModelForMultipleChoice",
- "FlaxAutoModelForNextSentencePrediction",
- "FlaxAutoModelForPreTraining",
- "FlaxAutoModelForQuestionAnswering",
- "FlaxAutoModelForSeq2SeqLM",
- "FlaxAutoModelForSequenceClassification",
- "FlaxAutoModelForSpeechSeq2Seq",
- "FlaxAutoModelForTokenClassification",
- "FlaxAutoModelForVision2Seq",
- ]
- )
-
- # Flax models structure
-
- _import_structure["models.bart"].extend(
- [
- "FlaxBartDecoderPreTrainedModel",
- "FlaxBartForCausalLM",
- "FlaxBartForConditionalGeneration",
- "FlaxBartForQuestionAnswering",
- "FlaxBartForSequenceClassification",
- "FlaxBartModel",
- "FlaxBartPreTrainedModel",
- ]
- )
- _import_structure["models.beit"].extend(
- [
- "FlaxBeitForImageClassification",
- "FlaxBeitForMaskedImageModeling",
- "FlaxBeitModel",
- "FlaxBeitPreTrainedModel",
- ]
- )
-
- _import_structure["models.bert"].extend(
- [
- "FlaxBertForCausalLM",
- "FlaxBertForMaskedLM",
- "FlaxBertForMultipleChoice",
- "FlaxBertForNextSentencePrediction",
- "FlaxBertForPreTraining",
- "FlaxBertForQuestionAnswering",
- "FlaxBertForSequenceClassification",
- "FlaxBertForTokenClassification",
- "FlaxBertModel",
- "FlaxBertPreTrainedModel",
- ]
- )
- _import_structure["models.big_bird"].extend(
- [
- "FlaxBigBirdForCausalLM",
- "FlaxBigBirdForMaskedLM",
- "FlaxBigBirdForMultipleChoice",
- "FlaxBigBirdForPreTraining",
- "FlaxBigBirdForQuestionAnswering",
- "FlaxBigBirdForSequenceClassification",
- "FlaxBigBirdForTokenClassification",
- "FlaxBigBirdModel",
- "FlaxBigBirdPreTrainedModel",
- ]
- )
- _import_structure["models.blenderbot"].extend(
- [
- "FlaxBlenderbotForConditionalGeneration",
- "FlaxBlenderbotModel",
- "FlaxBlenderbotPreTrainedModel",
- ]
- )
- _import_structure["models.blenderbot_small"].extend(
- [
- "FlaxBlenderbotSmallForConditionalGeneration",
- "FlaxBlenderbotSmallModel",
- "FlaxBlenderbotSmallPreTrainedModel",
- ]
- )
- _import_structure["models.bloom"].extend(
- [
- "FlaxBloomForCausalLM",
- "FlaxBloomModel",
- "FlaxBloomPreTrainedModel",
- ]
- )
- _import_structure["models.clip"].extend(
- [
- "FlaxCLIPModel",
- "FlaxCLIPPreTrainedModel",
- "FlaxCLIPTextModel",
- "FlaxCLIPTextPreTrainedModel",
- "FlaxCLIPTextModelWithProjection",
- "FlaxCLIPVisionModel",
- "FlaxCLIPVisionPreTrainedModel",
- ]
- )
- _import_structure["models.dinov2"].extend(
- [
- "FlaxDinov2Model",
- "FlaxDinov2ForImageClassification",
- "FlaxDinov2PreTrainedModel",
- ]
- )
- _import_structure["models.distilbert"].extend(
- [
- "FlaxDistilBertForMaskedLM",
- "FlaxDistilBertForMultipleChoice",
- "FlaxDistilBertForQuestionAnswering",
- "FlaxDistilBertForSequenceClassification",
- "FlaxDistilBertForTokenClassification",
- "FlaxDistilBertModel",
- "FlaxDistilBertPreTrainedModel",
- ]
- )
- _import_structure["models.electra"].extend(
- [
- "FlaxElectraForCausalLM",
- "FlaxElectraForMaskedLM",
- "FlaxElectraForMultipleChoice",
- "FlaxElectraForPreTraining",
- "FlaxElectraForQuestionAnswering",
- "FlaxElectraForSequenceClassification",
- "FlaxElectraForTokenClassification",
- "FlaxElectraModel",
- "FlaxElectraPreTrainedModel",
- ]
- )
- _import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
- _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
- _import_structure["models.gpt_neo"].extend(
- ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
- )
- _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
- _import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
- _import_structure["models.gemma"].extend(["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"])
- _import_structure["models.longt5"].extend(
- [
- "FlaxLongT5ForConditionalGeneration",
- "FlaxLongT5Model",
- "FlaxLongT5PreTrainedModel",
- ]
- )
- _import_structure["models.marian"].extend(
- [
- "FlaxMarianModel",
- "FlaxMarianMTModel",
- "FlaxMarianPreTrainedModel",
- ]
- )
- _import_structure["models.mbart"].extend(
- [
- "FlaxMBartForConditionalGeneration",
- "FlaxMBartForQuestionAnswering",
- "FlaxMBartForSequenceClassification",
- "FlaxMBartModel",
- "FlaxMBartPreTrainedModel",
- ]
- )
- _import_structure["models.mistral"].extend(
- [
- "FlaxMistralForCausalLM",
- "FlaxMistralModel",
- "FlaxMistralPreTrainedModel",
- ]
- )
- _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
- _import_structure["models.opt"].extend(
- [
- "FlaxOPTForCausalLM",
- "FlaxOPTModel",
- "FlaxOPTPreTrainedModel",
- ]
- )
- _import_structure["models.pegasus"].extend(
- [
- "FlaxPegasusForConditionalGeneration",
- "FlaxPegasusModel",
- "FlaxPegasusPreTrainedModel",
- ]
- )
- _import_structure["models.regnet"].extend(
- [
- "FlaxRegNetForImageClassification",
- "FlaxRegNetModel",
- "FlaxRegNetPreTrainedModel",
- ]
- )
- _import_structure["models.resnet"].extend(
- [
- "FlaxResNetForImageClassification",
- "FlaxResNetModel",
- "FlaxResNetPreTrainedModel",
- ]
- )
- _import_structure["models.roberta"].extend(
- [
- "FlaxRobertaForCausalLM",
- "FlaxRobertaForMaskedLM",
- "FlaxRobertaForMultipleChoice",
- "FlaxRobertaForQuestionAnswering",
- "FlaxRobertaForSequenceClassification",
- "FlaxRobertaForTokenClassification",
- "FlaxRobertaModel",
- "FlaxRobertaPreTrainedModel",
- ]
- )
- _import_structure["models.roberta_prelayernorm"].extend(
- [
- "FlaxRobertaPreLayerNormForCausalLM",
- "FlaxRobertaPreLayerNormForMaskedLM",
- "FlaxRobertaPreLayerNormForMultipleChoice",
- "FlaxRobertaPreLayerNormForQuestionAnswering",
- "FlaxRobertaPreLayerNormForSequenceClassification",
- "FlaxRobertaPreLayerNormForTokenClassification",
- "FlaxRobertaPreLayerNormModel",
- "FlaxRobertaPreLayerNormPreTrainedModel",
- ]
- )
- _import_structure["models.roformer"].extend(
- [
- "FlaxRoFormerForMaskedLM",
- "FlaxRoFormerForMultipleChoice",
- "FlaxRoFormerForQuestionAnswering",
- "FlaxRoFormerForSequenceClassification",
- "FlaxRoFormerForTokenClassification",
- "FlaxRoFormerModel",
- "FlaxRoFormerPreTrainedModel",
- ]
- )
- _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
- _import_structure["models.t5"].extend(
- [
- "FlaxT5EncoderModel",
- "FlaxT5ForConditionalGeneration",
- "FlaxT5Model",
- "FlaxT5PreTrainedModel",
- ]
- )
- _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
- _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
- _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
- _import_structure["models.wav2vec2"].extend(
- [
- "FlaxWav2Vec2ForCTC",
- "FlaxWav2Vec2ForPreTraining",
- "FlaxWav2Vec2Model",
- "FlaxWav2Vec2PreTrainedModel",
- ]
- )
- _import_structure["models.whisper"].extend(
- [
- "FlaxWhisperForConditionalGeneration",
- "FlaxWhisperModel",
- "FlaxWhisperPreTrainedModel",
- "FlaxWhisperForAudioClassification",
- ]
- )
- _import_structure["models.xglm"].extend(
- [
- "FlaxXGLMForCausalLM",
- "FlaxXGLMModel",
- "FlaxXGLMPreTrainedModel",
- ]
- )
- _import_structure["models.xlm_roberta"].extend(
- [
- "FlaxXLMRobertaForMaskedLM",
- "FlaxXLMRobertaForMultipleChoice",
- "FlaxXLMRobertaForQuestionAnswering",
- "FlaxXLMRobertaForSequenceClassification",
- "FlaxXLMRobertaForTokenClassification",
- "FlaxXLMRobertaModel",
- "FlaxXLMRobertaForCausalLM",
- "FlaxXLMRobertaPreTrainedModel",
- ]
- )
+from . import ms_utils
sys.modules[__name__] = _LazyModule(
'transformers',
@@ -4446,3 +4110,4 @@
module_spec=__spec__,
extra_objects={"__version__": transformers.__version__},
)
+
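With the Flax registration block removed, `mindnlp.transformers` serves only MindSpore-backed objects through the `_LazyModule` shim. A minimal sketch of the lazy resolution from user code (assuming mindnlp is installed; which names resolve lazily depends on the full `_import_structure`, not shown here):

```python
# Sketch of the _LazyModule behavior; assumes mindnlp is installed.
import mindnlp.transformers as tfm

print(tfm.__version__)  # served from extra_objects, no submodule import needed
# Names registered in _import_structure are imported only on first attribute
# access, so start-up pays only for the pieces a program actually touches.
```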
diff --git a/mindnlp/transformers/ms_utils.py b/mindnlp/transformers/ms_utils.py
new file mode 100644
index 000000000..592c366c1
--- /dev/null
+++ b/mindnlp/transformers/ms_utils.py
@@ -0,0 +1,258 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MindNLP MindSpore Utils"""
+
+import inspect
+from typing import Union, Optional, List, Tuple
+
+import mindspore
+from mindspore.common.initializer import initializer, Normal
+
+from mindnlp.core import nn, ops
+from mindnlp.core.nn import Parameter
+
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+
+class Conv1D(nn.Module):
+ """
+ 1D-convolutional layer. Basically works like a linear layer, but the weights are transposed.
+
+ Args:
+ n_out (`int`): The number of output features.
+ n_in (`int`): The number of input features.
+ """
+ def __init__(self, n_out, n_in):
+ """
+ Initialize the Conv1D class with the specified number of output channels and input channels.
+
+ Args:
+ self (object): The instance of the Conv1D class.
+ n_out (int): The number of output channels for the convolution operation.
+ n_in (int): The number of input channels for the convolution operation.
+
+ Returns:
+ None.
+
+ Raises:
+ None.
+ """
+ super().__init__()
+ self.n_out = n_out
+ # The weight is stored transposed, shape (n_in, n_out), to match the matmul in `forward`.
+ self.weight = nn.Parameter(initializer(Normal(sigma=0.02), (n_in, n_out), mindspore.float32))
+ self.bias = nn.Parameter(ops.zeros(n_out))
+
+ def forward(self, x):
+ """
+ Constructs the 1D convolutional operation on the input tensor x.
+
+ Args:
+ self (Conv1D): An instance of the Conv1D class.
+ x (mindspore.Tensor): The input tensor on which the convolution operation is applied.
+ Should have a shape of (batch_size, sequence_length, input_channels).
+
+ Returns:
+ mindspore.Tensor: The output tensor of shape (*x.shape[:-1], n_out).
+ """
+ size_out = x.shape[:-1] + (self.n_out,)
+ x = ops.matmul(x.view(-1, x.shape[-1]), self.weight) + self.bias
+ x = x.view(size_out)
+ return x
+
+
+def prune_conv1d_layer(layer, index, dim=1):
+ """
+ Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the weights
+ are transposed.
+
+ Used to remove heads.
+
+ Args:
+ layer ([`~mindspore_utils.Conv1D`]): The layer to prune.
+ index (`mindspore.Tensor[int64]`): The indices to keep in the layer.
+ dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+ Returns:
+ [`~mindspore_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+ """
+ W = layer.weight.index_select(dim, index)
+ if dim == 0:
+ b = layer.bias
+ else:
+ b = layer.bias[index]
+ new_size = list(layer.weight.shape)
+ new_size[dim] = len(index)
+ new_layer = Conv1D(new_size[1], new_size[0])
+ new_layer.weight.requires_grad = False
+ new_layer.weight.assign_value(W)
+ new_layer.weight.requires_grad = True
+ new_layer.bias.requires_grad = False
+ new_layer.bias.assign_value(b)
+ new_layer.bias.requires_grad = True
+ return new_layer
+
+
+def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
+ """
+ Finds the heads and their indices taking `already_pruned_heads` into account.
+
+ Args:
+ heads (`List[int]`): List of the indices of heads to prune.
+ n_heads (`int`): The number of heads in the model.
+ head_size (`int`): The size of each head.
+ already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+ Returns:
+ `Tuple[Set[int], mindspore.Tensor[int64]]`: A tuple with the remaining heads and their corresponding indices.
+ """
+ mask = ops.ones((n_heads, head_size))
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
+ for head in heads:
+ # Compute how many pruned heads are before the head and move the index accordingly
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+ mask[head] = 0
+ mask = mask.view(-1).eq(1)
+ index = ops.arange(len(mask), dtype=mindspore.int64)[mask]
+ return heads, index
+
+def prune_linear_layer(layer, index, dim=0):
+ """
+ Prune a linear layer to keep only entries in index.
+ Used to remove heads.
+
+ Args:
+ layer (`mindspore.nn.Linear`): The layer to prune.
+ index (`mindspore.Tensor[int64]`): The indices to keep in the layer.
+ dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
+
+ Returns:
+ `mindspore.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
+ """
+ W = layer.weight.index_select(dim, index).copy()
+ if layer.bias is not None:
+ if dim == 1:
+ b = layer.bias.copy()
+ else:
+ b = layer.bias[index].copy()
+ new_size = list(layer.weight.shape)
+ new_size[dim] = len(index)
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.assign_value(W)
+ new_layer.weight.requires_grad = True
+ if layer.bias is not None:
+ new_layer.bias.requires_grad = False
+ new_layer.bias.assign_value(b)
+ new_layer.bias.requires_grad = True
+ return new_layer
+
+
+def apply_chunking_to_forward(forward_fn, chunk_size, chunk_axis, *input_tensors):
+ """
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+ `chunk_axis`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+ If the `forward_fn` is independent across the `chunk_axis` this function will yield the same result as directly
+ applying `forward_fn` to `input_tensors`.
+
+ Args:
+ forward_fn (`Callable[..., mindspore.Tensor]`):
+ The forward function of the model.
+ chunk_size (`int`):
+ The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+ chunk_axis (`int`):
+ The dimension over which the `input_tensors` should be chunked.
+ input_tensors (`Tuple[mindspore.Tensor]`):
+ The input tensors of `forward_fn` which will be chunked
+
+ Returns:
+ `mindspore.Tensor`: A tensor with the same shape as the one `forward_fn` would have returned if applied directly.
+ """
+ assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+ # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+ if num_args_in_forward_chunk_fn != len(input_tensors):
+ raise ValueError(
+ f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+ "tensors are given"
+ )
+
+ if chunk_size > 0:
+ tensor_shape = input_tensors[0].shape[chunk_axis]
+ for input_tensor in input_tensors:
+ if input_tensor.shape[chunk_axis] != tensor_shape:
+ raise ValueError(
+ f"All input tenors have to be of the same shape: {tensor_shape}, "
+ f"found shape {input_tensor.shape[chunk_axis]}"
+ )
+
+ if input_tensors[0].shape[chunk_axis] % chunk_size != 0:
+ raise ValueError(
+ f"The dimension to be chunked {input_tensors[0].shape[chunk_axis]} has to be a multiple of the chunk "
+ f"size {chunk_size}"
+ )
+
+ num_chunks = input_tensors[0].shape[chunk_axis] // chunk_size
+
+ # chunk input tensor into tuples
+ input_tensors_chunks = tuple(ops.chunk(input_tensor, num_chunks, dim=chunk_axis) for input_tensor in input_tensors)
+ # apply forward fn to every tuple
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+ # concatenate output at same dimension
+ return ops.cat(output_chunks, dim=chunk_axis)
+
+ return forward_fn(*input_tensors)
+
+def zero_init(cls, *args, **kwargs):
+ """init zeros to speed up initialize stage."""
+ # Iterate over a snapshot of the keys, since `kwargs` is mutated inside the loop.
+ for k in list(kwargs.keys()):
+ if 'init' in k:
+ kwargs.pop(k)
+ init_signature = inspect.signature(cls.__init__)
+ init_params = init_signature.parameters
+ for param_name in init_params.keys():
+ if 'init' in param_name:
+ kwargs[param_name] = 'zeros'
+ def _reset_parameters(self):
+ pass
+ cls.reset_parameters = _reset_parameters
+ return cls(*args, **kwargs)
+
+def meshgrid(
+ *tensors: Union[mindspore.Tensor, List[mindspore.Tensor]], indexing: Optional[str] = None
+) -> Tuple[mindspore.Tensor, ...]:
+ """
+ Wrapper around ops.meshgrid to avoid warning messages about the introduced `indexing` argument.
+
+ Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
+ """
+ return ops.meshgrid(*tensors, indexing=indexing)
+
+def isin_friendly(elements: mindspore.Tensor, test_elements: mindspore.Tensor) -> mindspore.Tensor:
+ """
+ Same as `ops.isin` without flags. Ported from the MPS-friendly implementation in transformers.
+
+ Args:
+ elements (`mindspore.Tensor`): Input elements
+ test_elements (`mindspore.Tensor`): The elements to check against.
+
+ Returns:
+ `mindspore.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
+ and False otherwise
+ """
+ return elements.tile((test_elements.shape[0], 1)).eq(test_elements.unsqueeze(1)).sum(0).bool().squeeze()
\ No newline at end of file
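These helpers mirror the PyTorch-side utilities from transformers on top of mindnlp.core. A short usage sketch of `Conv1D` and `apply_chunking_to_forward` (shapes and values are illustrative, assuming a working MindSpore/mindnlp install):

```python
# Usage sketch for the utilities above; illustrative shapes,
# assuming a working MindSpore/mindnlp install.
from mindnlp.core import ops
from mindnlp.transformers.ms_utils import Conv1D, apply_chunking_to_forward

layer = Conv1D(n_out=8, n_in=4)          # weight stored transposed: (4, 8)
x = ops.ones((2, 16, 4))                 # (batch, seq_len, n_in)

def ff(chunk):
    return layer(chunk)                  # applied independently per chunk

# Split the 16-step sequence axis into 4 chunks of 4 steps to bound peak memory.
y = apply_chunking_to_forward(ff, 4, 1, x)
print(y.shape)                           # (2, 16, 8)
```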
diff --git a/mindnlp/trl/__init__.py b/mindnlp/trl/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/utils/__init__.py b/mindnlp/utils/__init__.py
index e69de29bb..7fca6fc2b 100644
--- a/mindnlp/utils/__init__.py
+++ b/mindnlp/utils/__init__.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Common utils
+"""
+from .generic import *
+# from .decompress import unzip, untar, ungz
+# from .download import *
+# from .compatibility import *
+# from .chat_template_utils import *
+from .import_utils import requires_backends, is_mindspore_available, OptionalDependencyNotAvailable, is_sentencepiece_available, \
+is_tokenizers_available, direct_transformers_import, is_protobuf_available, is_safetensors_available, \
+is_cython_available, is_pretty_midi_available, is_essentia_available, is_librosa_available, is_scipy_available, is_pyctcdecode_available, \
+is_jieba_available, is_vision_available, is_sudachi_projection_available, is_g2p_en_available, is_levenshtein_available, is_nltk_available, \
+is_bs4_available, is_pytesseract_available, is_tiktoken_available, is_einops_available, is_faiss_available, is_datasets_available, \
+is_sacremoses_available, is_phonemizer_available, is_speech_available, is_kenlm_available, is_triton_available
+
+# from .testing_utils import require_mindspore
+# from .save import convert_file_size_to_int
+# from .peft_utils import find_adapter_config_file
+
+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
+SENTENCEPIECE_UNDERLINE = "▁"
\ No newline at end of file
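The `generic` helpers re-exported above include the `ModelOutput` base class defined in the file below; a sketch of its hybrid dict/tuple access (the `SampleOutput` dataclass here is hypothetical):

```python
# Sketch of ModelOutput's hybrid dict/tuple access; SampleOutput is hypothetical.
from dataclasses import dataclass
from typing import Optional
import mindspore
from mindnlp.utils.generic import ModelOutput

@dataclass
class SampleOutput(ModelOutput):
    logits: Optional[mindspore.Tensor] = None
    hidden_states: Optional[mindspore.Tensor] = None

out = SampleOutput(logits=mindspore.ops.ones((2, 3)))
print(out["logits"].shape)   # dict-style access by key
print(out[0].shape)          # tuple-style access; None fields are skipped
print(len(out.to_tuple()))   # 1 -- hidden_states is None, so it is dropped
```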
diff --git a/mindnlp/utils/generic.py b/mindnlp/utils/generic.py
new file mode 100644
index 000000000..2de3a7a36
--- /dev/null
+++ b/mindnlp/utils/generic.py
@@ -0,0 +1,612 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Generic utils.
+"""
+import inspect
+from enum import Enum
+from collections import OrderedDict, UserDict
+from dataclasses import fields
+from typing import Any, Tuple, ContextManager, List, TypedDict, Optional
+from contextlib import ExitStack
+from functools import wraps
+import numpy as np
+import mindspore
+from .import_utils import is_mindspore_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `mindspore.Tensor` or `np.ndarray`.
+ """
+ if isinstance(x, mindspore.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+def _is_mindspore(x):
+ """
+ Checks if the input x is a MindSpore tensor.
+
+ Args:
+ x (object): The input object to be checked.
+
+ Returns:
+ bool: True if `x` is a `mindspore.Tensor`, False otherwise.
+ """
+ return isinstance(x, mindspore.Tensor)
+
+
+def is_mindspore_tensor(x):
+ """
+ Tests if `x` is a MindSpore tensor or not. Safe to call even if MindSpore is not installed.
+ """
+ return False if not is_mindspore_available() else _is_mindspore(x)
+
+def set_attribute_for_modules(module, key: str, value: Any):
+ """
+ Set a value on a module and all of its submodules.
+ """
+ setattr(module, key, value)
+ for submodule in module.children():
+ set_attribute_for_modules(submodule, key, value)
+
+def del_attribute_from_modules(module, key: str):
+ """
+ Delete an attribute from a module and all of its submodules. Needed by `can_return_tuple` below.
+ """
+ delattr(module, key)
+ for submodule in module.children():
+ del_attribute_from_modules(submodule, key)
+
+def can_return_tuple(func):
+ """
+ Decorator to wrap model method, to call output.to_tuple() if return_dict=False passed as a kwarg or
+ use_return_dict=False is set in the config.
+
+ Note:
+ output.to_tuple() convert output to tuple skipping all `None` values.
+ """
+
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ is_requested_to_return_tuple = kwargs.pop("return_dict", True) is False
+ is_configured_to_return_tuple = self.config.use_return_dict is False if hasattr(self, "config") else False
+
+ # The following allows to convert output to tuple ONLY on top level forward call,
+ # while internal modules of the model will return Output objects
+ # to be able to use name-based attribute access in modeling code.
+
+ # We will check if we are on the top-level module; if so, turn off the to-tuple conversion for all
+ # underlying calls.
+ is_top_level_module = getattr(self, "_is_top_level_module", True)
+ if is_configured_to_return_tuple and is_top_level_module:
+ set_attribute_for_modules(self, "_is_top_level_module", False)
+
+ try:
+ output = func(self, *args, **kwargs)
+ if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
+ output = output.to_tuple()
+ finally:
+ # Remove the flag after the model forward call is finished.
+ if is_configured_to_return_tuple and is_top_level_module:
+ del_attribute_from_modules(self, "_is_top_level_module")
+
+ return output
+
+ return wrapper
+
+class ExplicitEnum(str, Enum):
+ """
+ Enum with more explicit error message for missing values.
+ """
+ @classmethod
+ def _missing_(cls, value):
+ """
+ This method `_missing_` in the class `ExplicitEnum` is a class method used to handle missing values in the ExplicitEnum class.
+
+ Args:
+ cls (class): The class itself, used for referring to the class instance inside the method.
+ value (any): The value that was not found in the ExplicitEnum class.
+
+ Returns:
+ None: This method does not return any value as it raises an exception when called.
+
+ Raises:
+ ValueError: If the value provided is not a valid member of the Enum class, a ValueError is raised with a message listing the valid options to choose from.
+ """
+ raise ValueError(
+ f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
+ )
+
+class TensorType(ExplicitEnum):
+ """
+ Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
+ tab-completion in an IDE.
+ """
+ MINDSPORE = "ms"
+ NUMPY = "np"
+
+class PaddingStrategy(ExplicitEnum):
+ """
+ Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
+ IDE.
+ """
+ LONGEST = "longest"
+ MAX_LENGTH = "max_length"
+ DO_NOT_PAD = "do_not_pad"
+
+class LossKwargs(TypedDict, total=False):
+ """
+ Keyword arguments to be passed to the loss function
+
+ Attributes:
+ num_items_in_batch (`int`, *optional*):
+ Number of items in the batch. It is recommended to pass it when
+ you are doing gradient accumulation.
+ """
+
+ num_items_in_batch: Optional[int]
+
+class ModelOutput(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ python dictionary.
+
+ Note:
+ You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it
+ to a tuple first.
+ """
+
+ def __post_init__(self):
+ """Perform post-initialization actions for the ModelOutput class.
+
+ This method is automatically called after the initialization of a ModelOutput object.
+
+ Args:
+ self: An instance of the ModelOutput class.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If the ModelOutput object has no fields or more than one required field.
+ ValueError: If a key/value pair in the first field is not a tuple or if it does not follow the format (key, value).
+ ValueError: If the key/value pair cannot be set for a given element in the first field.
+ """
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if len(class_fields) == 0:
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+ if not all(field.default is None for field in class_fields[1:]):
+ raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+ if other_fields_are_none and not is_tensor(first_field):
+ if isinstance(first_field, dict):
+ iterator = first_field.items()
+ first_field_iterator = True
+ else:
+ try:
+ iterator = iter(first_field)
+ first_field_iterator = True
+ except TypeError:
+ first_field_iterator = False
+
+ # if we provided an iterator as first field and the iterator is a (key, value) iterator
+ # set the associated fields
+ if first_field_iterator:
+ for idx, element in enumerate(iterator):
+ if (
+ not isinstance(element, (list, tuple))
+ or not len(element) == 2
+ or not isinstance(element[0], str)
+ ):
+ if idx == 0:
+ # If we do not have an iterator of key/values, set it as attribute
+ self[class_fields[0].name] = first_field
+ else:
+ # If we have a mixed iterator, raise an error
+ raise ValueError(
+ f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
+ )
+ break
+ setattr(self, element[0], element[1])
+ if element[1] is not None:
+ self[element[0]] = element[1]
+ elif first_field is not None:
+ self[class_fields[0].name] = first_field
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ """
+ __delitem__
+
+ Deletes an item from the ModelOutput instance.
+
+ Args:
+ self (ModelOutput): The ModelOutput instance from which the item will be deleted.
+
+ Returns:
+ None. This method does not return a value.
+
+ Raises:
+ RuntimeError: If the '__delitem__' method is attempted to be used on a ModelOutput instance, a RuntimeError is raised with a message indicating that this method cannot be used on the instance.
+ """
+ raise RuntimeError(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ """
+ Sets a default value in the ModelOutput instance.
+
+ Args:
+ self: The ModelOutput instance itself.
+
+ Returns:
+ None. This method does not return any value.
+
+ Raises:
+ RuntimeError: Always raised; ``setdefault`` is not supported on a ModelOutput instance.
+ """
+ raise RuntimeError(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ """
+ Method that raises a RuntimeError to prevent the use of 'pop' on a ModelOutput instance.
+
+ Args:
+ self (object): The ModelOutput instance on which 'pop' is being called.
+ This parameter is required and represents the current instance of the class.
+
+ Returns:
+ None. This method does not return any value.
+
+ Raises:
+ RuntimeError: Raised when attempting to use 'pop' method on a ModelOutput instance. The exception message
+ specifies that 'pop' cannot be used on a ModelOutput instance to prevent unintended behavior.
+ """
+ raise RuntimeError(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ """
+ Updates the current instance of the ModelOutput class.
+
+ Args:
+ self (ModelOutput): The instance of the ModelOutput class.
+
+ Returns:
+ None: This method does not return any value.
+
+ Raises:
+ RuntimeError: If the method is called on an instance of the ModelOutput class. This is to prevent using the 'update' method on a ModelOutput instance, as it is not allowed.
+
+ Note:
+ The 'update' method is not allowed to be used on a ModelOutput instance. If called, it will raise a RuntimeError.
+ """
+ raise RuntimeError(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ """
+ This method allows accessing the elements of the ModelOutput object using the square bracket notation.
+
+ Args:
+ self (ModelOutput): The instance of the ModelOutput class.
+ k (str or int): The key or index for accessing the element. If k is a string, it is used as a key to retrieve the corresponding value. If k is an integer, it is used as an index to retrieve the
+element.
+
+ Returns:
+ None: This method does not return any value directly. The retrieved value is returned based on the input key or index.
+
+ Raises:
+ TypeError: If the input parameter k is not a string or an integer.
+ KeyError: If the input key k is not found in the internal dictionary when k is a string.
+ IndexError: If the input index k is out of range when k is an integer.
+ """
+ if isinstance(k, str):
+ inner_dict = dict(self.items())
+ return inner_dict[k]
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ """
+ Method __setattr__ in the class ModelOutput sets the value for the specified attribute name.
+
+ Args:
+ self (object): The instance of the ModelOutput class.
+ name (str): The name of the attribute to be set.
+ value (any): The value to be assigned to the attribute. It can be of any type.
+
+ Returns:
+ None. This method does not return any value.
+
+ Note:
+ No exceptions are raised. If the attribute name is not among the existing keys it is simply added as a
+ new attribute; the key/value pair is synchronized only for non-None values of existing keys.
+ """
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ """
+ This method '__setitem__' in the class 'ModelOutput' allows setting key-value pairs in the model output object.
+
+ Args:
+ self (ModelOutput): The instance of the ModelOutput class.
+ key (Any): The key to be set in the model output object.
+ value (Any): The value corresponding to the key to be set in the model output object.
+
+ Returns:
+ None.
+
+ Raises:
+ Whatever `OrderedDict.__setitem__` raises for invalid (e.g. unhashable) keys.
+ """
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(v for _, v in self.items())
+
+# vendored from distutils.util
+def strtobool(val):
+ """Convert a string representation of truth to true (1) or false (0).
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
+ Raises ValueError if 'val' is anything else.
+ """
+ val = val.lower()
+ if val in {"y", "yes", "t", "true", "on", "1"}:
+ return 1
+ if val in {"n", "no", "f", "false", "off", "0"}:
+ return 0
+ raise ValueError(f"invalid truth value {val!r}")
+
+class cached_property(property):
+ """
+ Descriptor that mimics @property but caches output in member variable.
+
+ From tensorflow_datasets
+
+ Built-in in functools from Python 3.8.
+ """
+ def __get__(self, obj, objtype=None):
+ """
+ Method '__get__' in the class 'cached_property'.
+
+ Args:
+ self (object): The current instance of the class.
+ obj (object): The object on which the method is being called.
+ objtype (object): The type of the object, if available. Defaults to None.
+
+ Returns:
+ Any: The cached value of the property, computed on first access and memoized on the instance.
+
+ Raises:
+ AttributeError: If the attribute is unreadable, this exception is raised.
+ """
+ # See docs.python.org/3/howto/descriptor.html#properties
+ if obj is None:
+ return self
+ if self.fget is None:
+ raise AttributeError("unreadable attribute")
+ attr = "__cached_" + self.fget.__name__
+ cached = getattr(obj, attr, None)
+ if cached is None:
+ cached = self.fget(obj)
+ setattr(obj, attr, cached)
+ return cached
+
+def _is_numpy(x):
+ """
+ This function checks if the input is a NumPy array.
+
+ Args:
+ x (any): The input to be checked for being a NumPy array.
+
+ Returns:
+ bool: True if `x` is a `np.ndarray`, False otherwise.
+ """
+ return isinstance(x, np.ndarray)
+
+
+def is_numpy_array(x):
+ """
+ Tests if `x` is a numpy array or not.
+ """
+ return _is_numpy(x)
+
+def infer_framework_from_repr(x):
+ """
+ Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
+ frameworks in a smart order, without the need to import the frameworks).
+ """
+ representation = str(type(x))
+ if representation.startswith("= (3, 8):
+ # For Python 3.8 and later
+ from importlib import metadata as importlib_metadata
+else:
+ # For Python versions earlier than 3.8
+ import importlib_metadata
+
+
+logger = logging.get_logger(__name__)
+
+def _is_package_available(
+ pkg_name: str, return_version: bool = False
+) -> Union[Tuple[bool, str], bool]:
+ """
+ Checks if a specified package is available and optionally returns its version.
+
+ Args:
+ pkg_name (str): The name of the package to check for availability.
+ return_version (bool, optional): Indicates whether to return the package version along with availability status. Defaults to False.
+
+ Returns:
+ Union[Tuple[bool, str], bool]: If return_version is True, returns a tuple containing a boolean indicating package availability and a string representing the package version.
+ If return_version is False, returns a boolean indicating package availability.
+
+ Raises:
+ No specific exceptions are raised within this function.
+ """
+ # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+ package_exists = importlib.util.find_spec(pkg_name) is not None
+ package_version = "N/A"
+ if package_exists:
+ try:
+ package_version = importlib_metadata.version(pkg_name)
+ package_exists = True
+ except importlib_metadata.PackageNotFoundError:
+ package_exists = False
+ logger.debug(f"Detected {pkg_name} version {package_version}")
+ if return_version:
+ return package_exists, package_version
+ return package_exists
+
+
+_ftfy_available = _is_package_available("ftfy")
+_einops_available = _is_package_available('einops')
+_tiktoken_available = _is_package_available('tiktoken')
+_bs4_available = importlib.util.find_spec("bs4") is not None
+_pytest_available = _is_package_available("pytest")
+_datasets_available = _is_package_available("datasets")
+_sentencepiece_available = _is_package_available("sentencepiece")
+_soundfile_available = _is_package_available("soundfile")
+_tokenizers_available = _is_package_available("tokenizers")
+_pyctcdecode_available = _is_package_available("pyctcdecode")
+_safetensors_available = _is_package_available("safetensors")
+_modelscope_available = _is_package_available("modelscope")
+_jieba_available = _is_package_available("jieba")
+_pytesseract_available = _is_package_available("pytesseract")
+_g2p_en_available = _is_package_available("g2p_en")
+_phonemizer_available = _is_package_available("phonemizer")
+# `_is_package_available` returns (available, version) when return_version=True
+_mindspore_available, _mindspore_version = _is_package_available(
+ "mindspore", return_version=True
+)
+_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)
+
+_librosa_available = _is_package_available("librosa")
+_scipy_available = _is_package_available("scipy")
+_triton_available = _is_package_available("triton")
+_sacremoses_available = _is_package_available("sacremoses")
+_torchaudio_available = _is_package_available("pykaldi")
+_kenlm_available = _is_package_available("kenlm")
+_datamodel_code_generator_available = _is_package_available('datamodel_code_generator')
+_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
+try:
+ _pretty_midi_version = importlib_metadata.version("pretty_midi")
+ logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
+except importlib_metadata.PackageNotFoundError:
+ _pretty_midi_available = False
+
+_essentia_available = importlib.util.find_spec("essentia") is not None
+try:
+ _essentia_version = importlib_metadata.version("essentia")
+ logger.debug(f"Successfully imported essentia version {_essentia_version}")
+except importlib_metadata.PackageNotFoundError:
+ _essentia_available = False
+
+_levenshtein_available = _is_package_available("Levenshtein")
+_nltk_available = _is_package_available("nltk")
+
+
+_faiss_available = importlib.util.find_spec("faiss") is not None
+try:
+ _faiss_version = importlib.metadata.version("faiss")
+ logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib.metadata.PackageNotFoundError:
+ try:
+ _faiss_version = importlib.metadata.version("faiss-cpu")
+ logger.debug(f"Successfully imported faiss version {_faiss_version}")
+ except importlib.metadata.PackageNotFoundError:
+ _faiss_available = False
+
+def is_triton_available():
+ return _triton_available
+
+def is_datamodel_code_generator_available():
+ return _datamodel_code_generator_available
+
+def is_faiss_available():
+ return _faiss_available
+
+def is_levenshtein_available():
+ return _levenshtein_available
+
+
+def is_nltk_available():
+ return _nltk_available
+
+
+def is_einops_available():
+ return _einops_available
+
+
+def is_sudachi_available():
+ """
+ Checks if SudachiPy is available for use.
+
+ Returns:
+ bool: True if SudachiPy is available, False otherwise.
+ """
+ return _sudachipy_available
+
+
+def get_sudachi_version():
+ '''
+ Returns the version of SudachiPy.
+
+ Returns:
+ str: The installed SudachiPy version, or "N/A" if it could not be determined.
+ '''
+ return _sudachipy_version
+
+
+def is_bs4_available():
+ return _bs4_available
+
+def is_sudachi_projection_available():
+ """
+ Checks if Sudachi projection is available.
+
+ This function checks if Sudachi is available and if the Sudachi version is equal to or greater than 0.6.8.
+
+ Returns:
+ bool: True if SudachiPy >= 0.6.8 is installed, False otherwise.
+ """
+ if not is_sudachi_available():
+ return False
+
+ # NOTE: We require sudachipy>=0.6.8 to use projection option in sudachi_kwargs for the forwardor of BertJapaneseTokenizer.
+ # - `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230
+ return version.parse(_sudachipy_version) >= version.parse("0.6.8")
+
+def is_sacremoses_available():
+ """
+ Checks if the sacremoses library is available in the current environment.
+
+ Returns:
+ bool: True if sacremoses is available, False otherwise.
+ """
+ return _sacremoses_available
+
+
+def is_mindspore_available():
+ '''
+ Checks if MindSpore is available.
+
+ Returns:
+ bool: True if MindSpore is available, False otherwise.
+ '''
+ return _mindspore_available
+
+
+def get_mindspore_version():
+ """
+ Returns the current version of MindSpore.
+
+ Returns:
+ str: The installed MindSpore version, or "N/A" if it could not be determined.
+ """
+ return _mindspore_version
+
+
+
+def is_ftfy_available():
+ return _ftfy_available
+
+
+def is_datasets_available():
+ """
+ Checks if datasets are available.
+
+ Returns:
+ bool: True if 🤗 Datasets is available, False otherwise.
+ """
+ return _datasets_available
+
+
+def is_sentencepiece_available():
+ """
+ Checks if SentencePiece library is available.
+
+ Returns:
+ bool: True if SentencePiece is available, False otherwise.
+ """
+ return _sentencepiece_available
+
+
+def is_tokenizers_available():
+ """Check if tokenizers are available.
+
+ This function checks if tokenizers are available for use. It does not take any parameters.
+
+ Returns:
+ bool: True if 🤗 Tokenizers is available, False otherwise.
+ """
+ return _tokenizers_available
+
+
+def is_safetensors_available():
+ """
+ Checks if SafeTensors is available in the current environment.
+
+ Returns:
+ bool: True if SafeTensors is available, False otherwise.
+
+ """
+ return _safetensors_available
+
+
+def is_modelscope_available():
+ '''
+ Checks if the ModelScope library is available.
+
+ Returns:
+ bool: True if ModelScope is available, False otherwise.
+ '''
+ return _modelscope_available
+
+
+def is_cython_available():
+ """
+ Checks if Cython is available in the current environment.
+
+ Returns:
+ bool: True if Cython (via pyximport) is available, False otherwise.
+ """
+ return importlib.util.find_spec("pyximport") is not None
+
+
+def is_protobuf_available():
+ """
+ Checks if the Google Protocol Buffers (protobuf) library is available.
+
+ Returns:
+ bool: True if the protobuf library is available, False otherwise.
+
+ Raises:
+ No specific exceptions are raised by this function.
+ """
+ if importlib.util.find_spec("google") is None:
+ return False
+ return importlib.util.find_spec("google.protobuf") is not None
+
+
+def is_pytest_available():
+ """
+ Check if the pytest library is available.
+
+ Returns:
+ bool: True if pytest is available, False otherwise.
+
+ """
+ return _pytest_available
+
+
+def is_pretty_midi_available():
+ """
+ Checks if the 'pretty_midi' library is available.
+
+ Returns:
+ bool: True if pretty_midi is available, False otherwise.
+ """
+ return _pretty_midi_available
+
+
+def is_librosa_available():
+ """
+ Checks if the 'librosa' library is available.
+
+ Returns:
+ bool: True if librosa is available, False otherwise.
+ """
+ return _librosa_available
+
+
+def is_essentia_available():
+ """
+ Checks if the 'essentia' library is available.
+
+ Returns:
+ bool: True if essentia is available, False otherwise.
+ """
+ return _essentia_available
+
+
+def is_pyctcdecode_available():
+ """
+ Check if the PyCTCDecode library is available.
+
+ Returns:
+ bool: True if pyctcdecode is available, False otherwise.
+ """
+ return _pyctcdecode_available
+
+
+def is_scipy_available():
+ """
+ Checks if the SciPy library is available.
+
+ Returns:
+ bool: True if SciPy is available, False otherwise.
+ """
+ return _scipy_available
+
+
+def is_jieba_available():
+ '''
+ Checks if the Jieba library is available.
+
+ Returns:
+ bool: True if Jieba is available, False otherwise.
+
+ '''
+ return _jieba_available
+
+
+def is_pytesseract_available():
+ """
+ Check if pytesseract library is available.
+
+ Returns:
+ bool: True if pytesseract is available, False otherwise.
+ """
+ return _pytesseract_available
+
+
+def is_g2p_en_available():
+ return _g2p_en_available
+
+
+def is_tiktoken_available():
+ return _tiktoken_available
+
+
+def is_phonemizer_available():
+ return _phonemizer_available
+
+
+@lru_cache()
+def is_vision_available():
+ """
+ Checks if the Pillow library is available for image processing.
+
+ Returns:
+ bool: True if Pillow (or Pillow-SIMD) is available, False otherwise.
+ """
+ _pil_available = importlib.util.find_spec("PIL") is not None
+ if _pil_available:
+ try:
+ package_version = importlib_metadata.version("Pillow")
+ except importlib_metadata.PackageNotFoundError:
+ try:
+ package_version = importlib_metadata.version("Pillow-SIMD")
+ except importlib_metadata.PackageNotFoundError:
+ return False
+ logger.debug(f"Detected PIL version {package_version}")
+ return _pil_available
+
+
+def is_in_notebook():
+ """
+ This function checks if the code is running in a Jupyter notebook environment by examining the current execution environment and relevant environment variables.
+
+ Returns:
+ bool: Returns True if the code is running in a Jupyter notebook environment, otherwise False.
+
+ Note:
+ AttributeError, ImportError, and KeyError raised while probing the environment (console, VS Code,
+ Databricks) are caught internally and simply result in a False return value.
+ """
+ try:
+ # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
+ get_ipython = sys.modules["IPython"].get_ipython
+ if "IPKernelApp" not in get_ipython().config:
+ raise ImportError("console")
+ if "VSCODE_PID" in os.environ:
+ raise ImportError("vscode")
+ if (
+ "DATABRICKS_RUNTIME_VERSION" in os.environ
+ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0"
+ ):
+ # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook
+ # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
+ raise ImportError("databricks")
+
+ return importlib.util.find_spec("IPython") is not None
+ except (AttributeError, ImportError, KeyError):
+ return False
+
+
+# docstyle-ignore
+CYTHON_IMPORT_ERROR = """
+{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
+Cython`. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+DATASETS_IMPORT_ERROR = """
+{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
+```
+pip install datasets
+```
+In a notebook or a colab, you can install it by executing a cell with
+```
+!pip install datasets
+```
+then restarting your kernel.
+
+Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
+working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
+that python file if that's the case. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+TOKENIZERS_IMPORT_ERROR = """
+{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
+```
+pip install tokenizers
+```
+In a notebook or a colab, you can install it by executing a cell with
+```
+!pip install tokenizers
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+SENTENCEPIECE_IMPORT_ERROR = """
+{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
+installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
+that match your environment. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+PROTOBUF_IMPORT_ERROR = """
+{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
+installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
+that match your environment. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+MINDSPORE_IMPORT_ERROR = """
+{0} requires the MindSpore library but it was not found in your environment. Checkout the instructions on the
+installation page: https://www.mindspore.cn/install/ and follow the ones that match your environment.
+Please note that you may need to restart your runtime after installation.
+"""
+
+LIBROSA_IMPORT_ERROR = """
+{0} requires the librosa library but it was not found in your environment. You can install it with pip:
+`pip install librosa`
+Please note that you may need to restart your runtime after installation.
+"""
+
+ESSENTIA_IMPORT_ERROR = """
+{0} requires the essentia library but it was not found in your environment. You can install it with pip:
+`pip install essentia==2.1b6.dev1034`
+Please note that you may need to restart your runtime after installation.
+"""
+
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip:
+`pip install scipy`. Please note that you may need to restart your runtime after installation.
+"""
+
+PRETTY_MIDI_IMPORT_ERROR = """
+{0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
+`pip install pretty_midi`
+Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+PYCTCDECODE_IMPORT_ERROR = """
+{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
+`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
+"""
+
+JIEBA_IMPORT_ERROR = """
+{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
+jieba`. Please note that you may need to restart your runtime after installation.
+"""
+
+VISION_IMPORT_ERROR = """
+{0} requires the PIL library but it was not found in your environment. You can install it with pip:
+`pip install pillow`. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+G2P_EN_IMPORT_ERROR = """
+{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
+`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
+"""
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("mindspore", (is_mindspore_available, MINDSPORE_IMPORT_ERROR)),
+ ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
+ ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
+ ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
+ ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
+ ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
+ ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
+ ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
+ ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
+ ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
+ ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
+ ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ """
+ Function to check if the specified backends are available for the given object.
+
+ Args:
+ obj (object): The object for which backends availability needs to be checked.
+ backends (list or tuple or str): The backend(s) to be checked for availability. Can be a single backend as a string or a list/tuple of backends.
+
+ Returns:
+ None. This function does not return any value.
+
+ Raises:
+ ImportError: If any of the specified backends are not available for the object.
+ """
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+ `requires_backend` each time a user tries to access any method of that class.
+ """
+ def __getattribute__(cls, key):
+ """
+ This method is called automatically when an attribute is accessed on the 'DummyObject' class or any of its subclasses.
+
+ Args:
+ cls (type): The class object that the method was called on.
+ key (str): The name of the attribute being accessed.
+
+ Returns:
+ Any: The attribute value, for private names and `_from_config`; other accesses do not return normally.
+
+ Raises:
+ ImportError: Via `requires_backends`, when a public attribute is accessed and a backend is missing.
+ """
+ if key.startswith("_") and key != "_from_config":
+ return super().__getattribute__(key)
+ requires_backends(cls, cls._backends)
+
+
+def mindspore_required(func):
+ """
+ This function decorates another function to require the presence of MindSpore framework.
+
+ Args:
+ func (function): The function to be decorated.
+
+ Returns:
+ Callable: The wrapped function, which raises an ImportError when MindSpore is unavailable.
+
+ Raises:
+ FutureWarning: Always emitted; `mindspore_required` is deprecated in favor of `requires_backends`.
+ ImportError: If the decorated function requires MindSpore but MindSpore is not available.
+ """
+ warnings.warn(
+ "The method `torch_required` is deprecated. Use `requires_backends` instead.",
+ FutureWarning,
+ )
+
+ # Chose a different decorator name than in tests so it's clear they are not the same.
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if is_mindspore_available():
+ return func(*args, **kwargs)
+ raise ImportError(f"Method `{func.__name__}` requires MindSpore.")
+
+ return wrapper
+
+
+class OptionalDependencyNotAvailable(BaseException):
+ """Internally used error class for signalling an optional dependency was not found."""
+
+
+def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
+ """Imports transformers directly
+
+ Args:
+ path (`str`): The path to the source file
+ file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
+
+ Returns:
+ `ModuleType`: The resulting imported module
+ """
+ name = "mindnlp.transformers"
+ location = os.path.join(path, file)
+ spec = importlib.util.spec_from_file_location(
+ name, location, submodule_search_locations=[path]
+ )
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ module = sys.modules[name]
+ return module
+
+
+def is_soundfile_availble():
+ return _soundfile_available
+
+
+def is_speech_available():
+ return _torchaudio_available
+
+
+def is_kenlm_available():
+ return _kenlm_available
\ No newline at end of file
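The `BACKENDS_MAPPING` / `requires_backends` / `DummyObject` trio gates optional dependencies. A sketch of how downstream code typically uses it (the class and function names here are hypothetical):

```python
# Sketch of the optional-dependency gating pattern; MidiTokenizer is hypothetical.
from mindnlp.utils.import_utils import DummyObject, requires_backends

class MidiTokenizer(metaclass=DummyObject):
    _backends = ["pretty_midi"]

def tokenize_midi(path):
    # Raises ImportError with the PRETTY_MIDI_IMPORT_ERROR text if pretty_midi is missing.
    requires_backends(tokenize_midi, ["pretty_midi"])
    ...

# Any public attribute access on MidiTokenizer re-runs the backend check:
# MidiTokenizer.from_pretrained(...)  # -> ImportError when pretty_midi is absent
```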
diff --git a/mindnlp/utils/logging.py b/mindnlp/utils/logging.py
new file mode 100644
index 000000000..5cd14e8b2
--- /dev/null
+++ b/mindnlp/utils/logging.py
@@ -0,0 +1,527 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# pylint: disable=unused-import
+""" Logging utilities."""
+
+import functools
+import logging
+import os
+import sys
+import threading
+from logging import (
+ CRITICAL, # NOQA
+ DEBUG, # NOQA
+ ERROR, # NOQA
+ FATAL, # NOQA
+ INFO, # NOQA
+ NOTSET, # NOQA
+ WARN, # NOQA
+ WARNING, # NOQA
+)
+from logging import captureWarnings as _captureWarnings
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "detail": logging.DEBUG, # will also print filename and line number
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If the TRANSFORMERS_VERBOSITY env var is set to one of the valid choices, return that as the
+ new default level; otherwise fall back to `_default_log_level`.
+ """
+ env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ logging.getLogger().warning(
+ f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
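+# For instance, `TRANSFORMERS_VERBOSITY=detail python script.py` selects DEBUG here, and
+# `_configure_library_root_logger` below additionally installs a pathname/lineno formatter
+# for the "detail" setting.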
+
+def _get_library_name() -> str:
+ """
+ Returns the name of the library based on the module name.
+
+ Returns:
+ str: The name of the library extracted from the module name.
+
+ """
+ return __name__.split(".")[0] # pylint: disable=use-maxsplit-arg
+
+
+def _get_library_root_logger() -> logging.Logger:
+ """
+ Retrieves the root logger for the library.
+
+ Returns:
+ A logging.Logger object representing the root logger for the library.
+
+ Raises:
+ None.
+ """
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ """
+ This function configures the root logger for the library.
+
+ Returns:
+ None: This function does not return any value.
+
+ Raises:
+ None
+ """
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ # set defaults based on https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
+ if sys.stderr is None:
+ sys.stderr = open(os.devnull, "w")
+
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ # if logging level is debug, we add pathname and lineno to formatter for easy debugging
+ if os.getenv("TRANSFORMERS_VERBOSITY", None) == "detail":
+ formatter = logging.Formatter("[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s")
+ _default_handler.setFormatter(formatter)
+
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ """
+ Resets the root logger of the library to its default state.
+
+ Args:
+ None
+
+ Returns:
+ None. The function does not return any value.
+
+ Raises:
+ None
+ """
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ """
+ Returns a dictionary of log levels.
+
+ Returns:
+ dict: A dictionary containing log levels and their corresponding values.
+ """
+ return log_levels
+
+
+def captureWarnings(capture):
+ """
+ Calls the `captureWarnings` method from the logging library to enable management of the warnings emitted by the
+ `warnings` library.
+
+ Read more about this method here:
+ https://docs.python.org/3/library/logging.html#integration-with-the-warnings-module
+
+ All warnings will be logged through the `py.warnings` logger.
+
+ Careful: this method also adds a handler to this logger if it does not already have one, and updates the
+ logging level of that logger to match the library's root logger.
+ """
+ logger = get_logger("py.warnings")
+
+ if not logger.handlers:
+ logger.addHandler(_default_handler)
+
+ logger.setLevel(_get_library_root_logger().level)
+
+ _captureWarnings(capture)
+
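+# For example, after `captureWarnings(True)`, a call to `warnings.warn(...)` is routed through
+# the "py.warnings" logger configured above instead of being printed directly to stderr.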
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom transformers module.
+ """
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Transformers's root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+ 🤗 Transformers has the following logging levels:
+
+ - 50: `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
+ - 40: `transformers.logging.ERROR`
+ - 30: `transformers.logging.WARNING` or `transformers.logging.WARN`
+ - 20: `transformers.logging.INFO`
+ - 10: `transformers.logging.DEBUG`
+
+ """
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Transformers's root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `transformers.logging.CRITICAL` or `transformers.logging.FATAL`
+ - `transformers.logging.ERROR`
+ - `transformers.logging.WARNING` or `transformers.logging.WARN`
+ - `transformers.logging.INFO`
+ - `transformers.logging.DEBUG`
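+
+ Example:
+
+ ```python
+ >>> from mindnlp.utils import logging
+ >>> logging.set_verbosity(logging.DEBUG)  # same effect as logging.set_verbosity_debug()
+ ```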
+ """
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace Transformers's root logger."""
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace Transformers's root logger."""
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Transformers's root logger."""
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Transformers's root logger."""
+ _configure_library_root_logger()
+
+ assert handler is not None and handler in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
+ prevent double logging if the root logger has been configured.
+ """
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace Transformers's loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("NO_ADVISORY_WARNINGS", False) # pylint: disable=invalid-envvar-default
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+@functools.lru_cache(None)
+def warning_once(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
+
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
+ The assumption here is that all warning messages are unique across the code. If they aren't, we would need to
+ switch to another type of cache that includes the caller frame information in the hashing function.
+ """
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
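+# Minimal usage sketch: because `warning_once` is wrapped in `functools.lru_cache`, the cache key
+# is (logger, *args), so an identical message logged repeatedly is emitted only once per process:
+#
+#     logger = get_logger(__name__)
+#     for _ in range(3):
+#         logger.warning_once("resuming from checkpoint")  # printed a single time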
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+ def __init__(self, *args, **kwargs):
+ """
+ Initializes an instance of the EmptyTqdm class.
+
+ Args:
+ self: The instance of the EmptyTqdm class.
+
+ Returns:
+ None. This method does not return any value.
+
+ Raises:
+ None.
+ """
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ """
+ This method implements the iterator protocol for the EmptyTqdm class.
+
+ Args:
+ self: EmptyTqdm object. The instance of the EmptyTqdm class for which the iterator is being created.
+
+ Returns:
+ None. This method returns an iterator object that iterates over the _iterator attribute of the EmptyTqdm instance.
+
+ Raises:
+ No specific exceptions are raised by this method.
+ """
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+ def empty_fn(*args, **kwargs):
+ return
+ return empty_fn
+
+ def __enter__(self):
+ """
+ __enter__
+
+ Args:
+ self: EmptyTqdm
+ The self parameter refers to the current instance of the EmptyTqdm class.
+
+ Returns:
+ None
+ This method returns None.
+
+ Raises:
+ No exceptions are raised by this method.
+ """
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ """
+ __exit__ method in the EmptyTqdm class.
+
+ Args:
+ self: EmptyTqdm object
+ The instance of the EmptyTqdm class.
+ type_: type
+ The type of the exception. It represents the type of the exception being handled.
+ value: exception
+ The exception that was raised. It represents the actual exception object.
+ traceback: traceback
+ The traceback object. It represents the traceback information associated with the exception.
+
+ Returns:
+ None
+ This method does not return any value.
+
+ Raises:
+ This method does not raise any exceptions explicitly.
+ """
+ return
+
+
+class _tqdm_cls:
+ """Callable factory that mimics `tqdm`.
+
+ When `_tqdm_active` is True, calls are forwarded to `tqdm_lib.tqdm`; otherwise a no-op
+ `EmptyTqdm` instance is returned. `set_lock` and `get_lock` proxy tqdm's class-level
+ lock, which is used for thread-safe output of multiple bars."""
+ def __call__(self, *args, **kwargs):
+ """
+ This method __call__ in the class _tqdm_cls is used to conditionally return either a tqdm object or an EmptyTqdm object based on the _tqdm_active flag.
+
+ Args:
+ self (object): The instance of the _tqdm_cls class. It is used to access the attributes and methods of the class.
+
+ Returns:
+ None: This method does not explicitly return any value. It returns either a tqdm object or an EmptyTqdm object based on the _tqdm_active flag.
+
+ Raises:
+ No specific exceptions are raised by this method under normal circumstances. However, if there are issues related to the instantiation of tqdm objects or EmptyTqdm objects, standard Python
+exceptions may be raised.
+ """
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ """
+ Method to set the lock for the _tqdm_cls instance.
+
+ Args:
+ self (_tqdm_cls): The instance of the _tqdm_cls class.
+ This parameter is required to access the instance and set the lock.
+ It is of type _tqdm_cls and represents the instance on which the lock is being set.
+
+ Returns:
+ None: This method does not return any value. The lock is set within the instance itself.
+
+ Raises:
+ No specific exceptions are raised by this method.
+ However, if _tqdm_active is False, the method will not set the lock and will return without any further action.
+ """
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ """
+ This method is used to retrieve the lock used by the _tqdm_cls class.
+
+ Args:
+ self (object): The instance of the _tqdm_cls class.
+
+ Returns:
+ None: This method does not return any value.
+
+ Raises:
+ N/A
+ """
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
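+# Usage sketch: this module-level `tqdm` is a drop-in replacement for `tqdm.auto.tqdm`
+# that honours enable_progress_bar()/disable_progress_bar():
+#
+#     from mindnlp.utils.logging import tqdm, disable_progress_bar
+#     disable_progress_bar()
+#     for _ in tqdm(range(10)):  # iterates silently via EmptyTqdm
+#         pass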
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active # pylint: disable=global-variable-not-assigned
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
\ No newline at end of file
diff --git a/mindnlp/utils/safetensors_patch.py b/mindnlp/utils/safetensors_patch.py
index 740b7ba5c..2d32c56e5 100644
--- a/mindnlp/utils/safetensors_patch.py
+++ b/mindnlp/utils/safetensors_patch.py
@@ -6,6 +6,7 @@
from mindspore import Tensor
from mindnlp.core.configs import SUPPORT_BF16
+import safetensors
if SUPPORT_BF16:
from mindspore.common.np_dtype import bfloat16 # pylint: disable=import-error
@@ -190,7 +192,9 @@ def __exit__(self, *args):
self.file.close()
def metadata(self):
- return self.__metadata__
+ # Copy so the parsed header metadata is not mutated in place.
+ meta = dict(self.__metadata__ or {})
+ meta['format'] = 'pt'
+ return meta
def keys(self):
return list(self.tensors.keys())
@@ -201,6 +205,27 @@ def get_tensor(self, name):
def get_slice(self, name):
return self.tensors[name]
+def safe_save_file(tensor_dict, filename, metadata=None):
+ """
+ Function to safely save a dictionary of tensors to a file.
+
+ Args:
+ tensor_dict (dict): A dictionary where keys are strings and values are numpy arrays representing tensors.
+ filename (str): The name of the file where the tensor data will be saved.
+ metadata (optional): Additional metadata to be saved along with the tensor data. Default is None.
+
+ Returns:
+ None. The function does not return any value explicitly.
+
+ Raises:
+ ValueError: If the input tensor_dict is not in the expected format.
+ IOError: If there are issues with writing the data to the specified file.
+ Exception: Any other unexpected error that may occur during the process.
+ """
+ tensor_dict = {k: v.asnumpy() for k, v in tensor_dict.items()}
+ return safetensors.numpy.save_file(tensor_dict, filename, metadata)
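+# Usage sketch (tensor names and values are illustrative):
+#
+#     import mindspore
+#     from mindspore import ops
+#     safe_save_file({"weight": ops.ones((2, 2), mindspore.float32)},
+#                    "model.safetensors", metadata={"format": "pt"})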
+
def setup_safetensors_patch():
- import safetensors
safetensors.safe_open = fast_safe_open
+ # Patch safetensors.torch.save_file (not the torch proxy) to route through MindSpore tensors.
+ from safetensors import torch as safetensors_torch
+ safetensors_torch.save_file = safe_save_file
diff --git a/mindnlp/utils/testing_utils.py b/mindnlp/utils/testing_utils.py
index d819c9a1a..4a9167541 100644
--- a/mindnlp/utils/testing_utils.py
+++ b/mindnlp/utils/testing_utils.py
@@ -47,7 +47,6 @@
from transformers.utils.import_utils import (
is_pytest_available,
- is_mindspore_available,
is_essentia_available,
is_librosa_available,
is_pretty_midi_available,
@@ -55,7 +54,6 @@
is_pyctcdecode_available,
is_safetensors_available,
is_sentencepiece_available,
- is_soundfile_availble,
is_tokenizers_available,
is_pytesseract_available,
is_vision_available,
@@ -65,6 +63,10 @@
is_ftfy_available
)
from transformers.utils.generic import strtobool
+from .import_utils import (
+ is_mindspore_available,
+ is_soundfile_availble,
+)
if is_pytest_available():
from _pytest.doctest import (
diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py
index 22e0c410a..7309b4d62 100644
--- a/mindnlp/utils/torch_proxy.py
+++ b/mindnlp/utils/torch_proxy.py
@@ -57,7 +57,7 @@ def initialize_torch_proxy():
sys.modules["torch"] = torch_proxy
# Set the required metadata
- torch_proxy.__version__ = "1.13.1+mindnlp"
+ torch_proxy.__version__ = "2.1.1"
return torch_proxy
@@ -71,9 +71,9 @@ def setup_metadata_patch():
def patched_distribution(dist_name):
if dist_name == "torch":
return types.SimpleNamespace(
- version="1.13.1+mindnlp",
- metadata={"Name": "torch", "Version": "1.13.1+mindnlp"},
- read_text=lambda f: f"Name: torch\nVersion: 1.13.1+mindnlp" if f == "METADATA" else None
+ version="2.1.1",
+ metadata={"Name": "torch", "Version": "2.1.1"},
+ read_text=lambda f: "Name: torch\nVersion: 2.1.1" if f == "METADATA" else None
)
return orig_distribution(dist_name)
@@ -82,10 +82,12 @@ def patched_distributions(**kwargs):
dists = list(orig_distributions(**kwargs))
dists.append(types.SimpleNamespace(
name="torch",
- version="1.13.1+mindnlp",
- metadata={"Name": "torch", "Version": "1.13.1+mindnlp"},
+ version="2.1.1",
+ metadata={"Name": "torch", "Version": "2.1.1"},
files=[],
- locate_file=lambda p: None
+ locate_file=lambda p: None,
+ _normalized_name='torch',
+ entry_points=[]
))
return dists
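+# Note: reporting a plain "2.1.1" (rather than "1.13.1+mindnlp") is presumably intended to let
+# importlib.metadata-based version checks in downstream libraries parse and pass cleanly.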
diff --git a/tests/transformers/models/bert/test_modeling_bert.py b/tests/transformers/models/bert/test_modeling_bert.py
index 1eb1876dc..b93b9c108 100644
--- a/tests/transformers/models/bert/test_modeling_bert.py
+++ b/tests/transformers/models/bert/test_modeling_bert.py
@@ -49,6 +49,7 @@
logging,
)
+mindspore.set_context(pynative_synchronize=True)
class BertModelTester:
def __init__(
self,
@@ -660,7 +661,7 @@ def test_sdpa_ignored_mask(self):
pkv.append([ops.rand(1, num_heads, 3, head_dim), ops.rand(1, num_heads, 3, head_dim)])
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
- inp = tokenizer("I am in Paris and", return_tensors="ms")
+ inp = tokenizer("I am in Paris and", return_tensors="pt")
del inp["attention_mask"]
diff --git a/tests/transformers/test_modeling_common.py b/tests/transformers/test_modeling_common.py
index 5b094368d..504e41d1e 100644
--- a/tests/transformers/test_modeling_common.py
+++ b/tests/transformers/test_modeling_common.py
@@ -81,9 +81,10 @@
require_mindspore,
slow,
)
-from mindnlp.configs import CONFIG_NAME, GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME, ON_ORANGE_PI
+from mindnlp.transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME
+from mindnlp.core.configs import ON_ORANGE_PI
from mindnlp.utils.generic import ContextManagers, ModelOutput
-
+from transformers.pytorch_utils import id_tensor_storage
# if is_accelerate_available():
# from accelerate.utils import compute_module_sizes
@@ -1690,8 +1691,8 @@ def test_tied_weights_keys(self):
)
def test_model_weights_reload_no_missing_tied_weights(self):
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
@@ -1699,13 +1700,12 @@ def test_model_weights_reload_no_missing_tied_weights(self):
# We are nuking ALL weights on file, so every parameter should
# yell on load. We're going to detect if we yell too much, or too little.
placeholder_dict = {"tensor": mindspore.tensor([1, 2])}
- safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "np"})
+ safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
- prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
- param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
+ param_names = set(params.keys())
missing_keys = set(infos["missing_keys"])
@@ -1714,12 +1714,11 @@ def test_model_weights_reload_no_missing_tied_weights(self):
# counterpart is present but here there are no weights at all so we do get the warning.
ptrs = collections.defaultdict(list)
for name, tensor in model_reloaded.state_dict().items():
- ptrs[id(tensor)].append(name)
+ ptrs[id_tensor_storage(tensor)].append(name)
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
- group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
# We remove the group from extra_missing if not all weights from group are in it
- if len(group - extra_missing) > 0:
+ if len(set(group) - extra_missing) > 0:
extra_missing = extra_missing - set(group)
self.assertEqual(
@@ -1733,15 +1732,14 @@ def test_model_weights_reload_no_missing_tied_weights(self):
# Remove nonpersistent buffers from missed_missing
buffers = [n for n, _ in model_reloaded.named_buffers()]
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
- nonpersistent_buffers = {
- k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
- }
missed_missing = missed_missing - nonpersistent_buffers
if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set()
else:
- expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
+ expected_missing = set()
+ for pattern in model_reloaded._keys_to_ignore_on_load_missing:
+ expected_missing.update({k for k in param_names if re.search(pattern, k) is not None})
self.assertEqual(
missed_missing,
expected_missing,