diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py index a4b91704a..80ad6ae56 100644 --- a/mindnlp/core/__init__.py +++ b/mindnlp/core/__init__.py @@ -46,7 +46,7 @@ from ._bind import get_default_dtype, set_default_dtype from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \ - return_types, linalg, fx + return_types, linalg, fx, backends, testing from ._lowrank import svd_lowrank from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py index b145865c5..b1c57716c 100644 --- a/mindnlp/core/_dtype.py +++ b/mindnlp/core/_dtype.py @@ -11,6 +11,12 @@ def is_floating_point(self): Type.is_floating_point = is_floating_point Type.__str__ = Type.__repr__ +@property +def itemsize(self): + return ITEM_SIZE[self] + +Type.itemsize = itemsize + half = float16 float = float32 double = float64 @@ -22,6 +28,22 @@ def is_floating_point(self): float8_e4m3fn = None # TODO: not support fp8 for now float8_e5m2 = None +ITEM_SIZE = { + bool : 1, + int8 : 1, + int16 : 2, + int32 : 4, + int64 : 8, + uint8 : 1, + uint16 : 2, + uint32 : 4, + uint64 : 8, + float16 : 2, + bfloat16 : 2, + float32 : 4, + float64 : 8, +} + np2dtype = { np.bool_: bool, np.int8: int8, diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 2722ea67c..f9c116e63 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -199,7 +199,7 @@ def __getitem__(self, slices): new_slices = () for s in slices: if isinstance(s, range): - s = slice(s.start, s.stop, s.step) + s = list(s) new_slices += (s,) slices = new_slices return origin_getitem(self, slices) @@ -288,10 +288,8 @@ def unfold(self, dimension, size, step): Tensor.unfold = unfold StubTensor.unfold = unfold - def new(self, data=None): - if data is None: - return Tensor([], dtype=self.dtype) - return Tensor(data, dtype=self.dtype) + def new(self, *shape): + return ops.empty(*shape, dtype=self.dtype) Tensor.new = 
new StubTensor.new = new @@ -310,6 +308,33 @@ def cpu(self): Tensor.cpu = cpu StubTensor.cpu = cpu + Tensor.take = ops.take + StubTensor.take = ops.take + + Tensor.sort = ops.sort + StubTensor.sort = ops.sort + + def requires_grad_(self, requires_grad=True): + self.requires_grad = requires_grad + return self + + Tensor.requires_grad_ = requires_grad_ + StubTensor.requires_grad_ = requires_grad_ + + @property + def data(self): + return Tensor(self) + + @data.setter + def data(self, new_value): + if isinstance(self, StubTensor) and isinstance(new_value, StubTensor): + self.stub = new_value.stub + else: + self.assign_value(new_value) + + Tensor.data = data + StubTensor.data = data + def _rebuild_from_type_v2(func, new_type, args, state): ret = func(*args) return ret \ No newline at end of file diff --git a/mindnlp/core/backends/__init__.py b/mindnlp/core/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/backends/cuda/__init__.py b/mindnlp/core/backends/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py index 9a18c7e02..7b90980d3 100644 --- a/mindnlp/core/nn/modules/module.py +++ b/mindnlp/core/nn/modules/module.py @@ -1031,7 +1031,7 @@ def remove_from(*dicts_or_sets): self.register_parameter(name, value) else: modules = self.__dict__.get("_modules") - if isinstance(value, core.nn.Module): + if isinstance(value, Module): if modules is None: raise AttributeError( "cannot assign module before Module.__init__() call" @@ -1565,7 +1565,7 @@ def get_submodule(self, target: str) -> "Module": mod = getattr(mod, item) - if not isinstance(mod, core.nn.Module): + if not isinstance(mod, Module): raise AttributeError("`" + item + "` is not " "an nn.Module") diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py index 571293abb..f54341f93 100644 --- a/mindnlp/core/ops/array.py +++ b/mindnlp/core/ops/array.py @@ -227,7 +227,8 @@ def 
split(tensor, split_size_or_sections, dim=0): # squeeze has_squeeze = hasattr(mindspore.mint, 'squeeze') -def squeeze(input, *dim): +def squeeze(input, *dim, **kwargs): + dim = kwargs.get('dim', dim) if use_pyboost() and has_squeeze: return mindspore.mint.squeeze(input, dim) return ops.squeeze(input, dim) @@ -255,6 +256,8 @@ def take(input, index): index = index.view(-1) if ON_ORANGE_PI: return tf_gather(input, index, 0).view(index_shape) + if index_shape == (): + return gather(input, 0, index)[0] return gather(input, 0, index).view(index_shape) def infer_size_impl(a, b): diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py index 1be22e8c5..1baac8820 100644 --- a/mindnlp/core/ops/comparison.py +++ b/mindnlp/core/ops/comparison.py @@ -1,4 +1,5 @@ """comparison op""" +from collections import namedtuple import numpy as np import mindspore from mindspore import ops @@ -6,6 +7,7 @@ from ._inner import call_ms_func +sort_out = namedtuple('sort_out', ['sorted', 'indices']) # allclose has_allclose = hasattr(mindspore.mint, 'allclose') def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): @@ -167,8 +169,10 @@ def not_equal(input, other): has_sort = hasattr(mindspore.mint, 'sort') def sort(input, *, dim=-1, descending=False, stable=False): if use_pyboost() and has_sort: - return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable) - return ops.sort(input, dim, descending) + out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable) + else: + out = ops.sort(input, dim, descending) + return sort_out(sorted=out[0], indices=out[1]) # topk has_topk = hasattr(mindspore.mint, 'topk') diff --git a/mindnlp/core/ops/random.py b/mindnlp/core/ops/random.py index 4aed3df98..b5d594e31 100644 --- a/mindnlp/core/ops/random.py +++ b/mindnlp/core/ops/random.py @@ -88,6 +88,7 @@ def rand_like(input, *, dtype=None): # randint has_randint = hasattr(mindspore.mint, 'randint') def randint(*args, **kwargs): + device = 
kwargs.pop('device', None) if use_pyboost() and has_randint: return mindspore.mint.randint(*args, **kwargs) return ops.randint(*args, **kwargs) diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py index 1da14ab25..4f8c3c629 100644 --- a/mindnlp/core/ops/reduction.py +++ b/mindnlp/core/ops/reduction.py @@ -1,4 +1,5 @@ """reduction op""" +from collections import namedtuple import mindspore from mindspore import ops from mindspore.ops._primitive_cache import _get_cache_prim @@ -6,6 +7,7 @@ from ._inner import call_ms_func +max_out = namedtuple('max_out', ['values', 'indices']) # argmax has_argmax = hasattr(mindspore.mint, 'argmax') def argmax(input, dim=None, keepdim=False): @@ -67,7 +69,10 @@ def any(input, dim=None, keepdim=False, *, out=None): # max has_max = hasattr(mindspore.mint, 'max') def max(*args, **kwargs): - return mindspore.mint.max(*args, **kwargs) + out = mindspore.mint.max(*args, **kwargs) + if isinstance(out, tuple): + return max_out(values=out[0], indices=out[1]) + return out # min has_min = hasattr(mindspore.mint, 'min') diff --git a/mindnlp/core/overrides.py b/mindnlp/core/overrides.py index 047e190a5..7808818a1 100644 --- a/mindnlp/core/overrides.py +++ b/mindnlp/core/overrides.py @@ -35,7 +35,7 @@ def is_tensor_like(inp): >>> is_tensor_like(TensorLike()) True """ - return type(inp) is core.Tensor or hasattr(inp, "__torch_function__") + return isinstance(inp, core.Tensor) or hasattr(inp, "__torch_function__") def handle_torch_function( public_api: Callable, diff --git a/mindnlp/core/testing/__init__.py b/mindnlp/core/testing/__init__.py index e69de29bb..4e4ec332f 100644 --- a/mindnlp/core/testing/__init__.py +++ b/mindnlp/core/testing/__init__.py @@ -0,0 +1 @@ +from ._comparison import assert_allclose, assert_close as assert_close diff --git a/mindnlp/core/testing/_comparison.py b/mindnlp/core/testing/_comparison.py new file mode 100644 index 000000000..139ef0e2d --- /dev/null +++ b/mindnlp/core/testing/_comparison.py @@ -0,0 
+1,1604 @@ +# mypy: allow-untyped-defs +import abc +import cmath +import collections.abc +import contextlib +from collections.abc import Collection, Sequence +from typing import Any, Callable, NoReturn, Optional, Union +from typing_extensions import deprecated + +from mindnlp import core + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + + +class ErrorMeta(Exception): + """Internal testing exception that makes that carries error metadata.""" + + def __init__( + self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = () + ) -> None: + super().__init__( + "If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` " + "for user facing errors." + ) + self.type = type + self.msg = msg + self.id = id + + def to_error( + self, msg: Optional[Union[str, Callable[[str], str]]] = None + ) -> Exception: + if not isinstance(msg, str): + generated_msg = self.msg + if self.id: + generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}" + + msg = msg(generated_msg) if callable(msg) else generated_msg + + return self.type(msg) + + +# Some analysis of tolerance by logging tests from test_core.py can be found in +# https://github.com/pytorch/pytorch/pull/32538. +# {dtype: (rtol, atol)} +_DTYPE_PRECISIONS = { + core.float16: (0.001, 1e-5), + core.bfloat16: (0.016, 1e-5), + core.float32: (1.3e-6, 1e-5), + core.float64: (1e-7, 1e-7), + # core.complex32: (0.001, 1e-5), + # core.complex64: (1.3e-6, 1e-5), + # core.complex128: (1e-7, 1e-7), +} +# # The default tolerances of core.float32 are used for quantized dtypes, because quantized tensors are compared in +# # their dequantized and floating point representation. 
For more details see `TensorLikePair._compare_quantized_values` +# _DTYPE_PRECISIONS.update( +# dict.fromkeys( +# (core.quint8, core.quint2x4, core.quint4x2, core.qint8, core.qint32), +# _DTYPE_PRECISIONS[core.float32], +# ) +# ) + + +def default_tolerances( + *inputs: Union[core.Tensor, core.dtype], + dtype_precisions: Optional[dict[core.dtype, tuple[float, float]]] = None, +) -> tuple[float, float]: + """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype. + + See :func:`assert_close` for a table of the default tolerance for each dtype. + + Returns: + (Tuple[float, float]): Loosest tolerances of all input dtypes. + """ + dtypes = [] + for input in inputs: + if isinstance(input, core.Tensor): + dtypes.append(input.dtype) + elif isinstance(input, core.dtype): + dtypes.append(input) + else: + raise TypeError( + f"Expected a core.Tensor or a core.dtype, but got {type(input)} instead." + ) + dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS + rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]) + return max(rtols), max(atols) + + +def get_tolerances( + *inputs: Union[core.Tensor, core.dtype], + rtol: Optional[float], + atol: Optional[float], + id: tuple[Any, ...] = (), +) -> tuple[float, float]: + """Gets absolute and relative to be used for numeric comparisons. + + If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of + :func:`default_tolerances` is used. + + Raises: + ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified. + + Returns: + (Tuple[float, float]): Valid absolute and relative tolerances. + """ + if (rtol is None) ^ (atol is None): + # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising + # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0. 
+ raise ErrorMeta( + ValueError, + f"Both 'rtol' and 'atol' must be either specified or omitted, " + f"but got no {'rtol' if rtol is None else 'atol'}.", + id=id, + ) + elif rtol is not None and atol is not None: + return rtol, atol + else: + return default_tolerances(*inputs) + + +def _make_bitwise_mismatch_msg( + *, + default_identifier: str, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + extra: Optional[str] = None, + first_mismatch_idx: Optional[tuple[int]] = None, +): + """Makes a mismatch error message for bitwise values. + + Args: + default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". + identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides + ``default_identifier``. Can be passed as callable in which case it will be called with + ``default_identifier`` to create the description at runtime. + extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. + first_mismatch_idx (Optional[tuple[int]]): the index of the first mismatch, for each dimension. + """ + if identifier is None: + identifier = default_identifier + elif callable(identifier): + identifier = identifier(default_identifier) + + msg = f"{identifier} are not 'equal'!\n\n" + + if extra: + msg += f"{extra.strip()}\n" + if first_mismatch_idx is not None: + msg += f"The first mismatched element is at index {first_mismatch_idx}.\n" + return msg.strip() + + +def _make_mismatch_msg( + *, + default_identifier: str, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + extra: Optional[str] = None, + abs_diff: float, + abs_diff_idx: Optional[Union[int, tuple[int, ...]]] = None, + atol: float, + rel_diff: float, + rel_diff_idx: Optional[Union[int, tuple[int, ...]]] = None, + rtol: float, +) -> str: + """Makes a mismatch error message for numeric values. + + Args: + default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". 
+ identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides + ``default_identifier``. Can be passed as callable in which case it will be called with + ``default_identifier`` to create the description at runtime. + extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. + abs_diff (float): Absolute difference. + abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference. + atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are + ``> 0``. + rel_diff (float): Relative difference. + rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference. + rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are + ``> 0``. + """ + equality = rtol == 0 and atol == 0 + + def make_diff_msg( + *, + type: str, + diff: float, + idx: Optional[Union[int, tuple[int, ...]]], + tol: float, + ) -> str: + if idx is None: + msg = f"{type.title()} difference: {diff}" + else: + msg = f"Greatest {type} difference: {diff} at index {idx}" + if not equality: + msg += f" (up to {tol} allowed)" + return msg + "\n" + + if identifier is None: + identifier = default_identifier + elif callable(identifier): + identifier = identifier(default_identifier) + + msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n" + + if extra: + msg += f"{extra.strip()}\n" + + msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol) + msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol) + + return msg.strip() + + +def make_scalar_mismatch_msg( + actual: Union[bool, int, float, complex], + expected: Union[bool, int, float, complex], + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +) -> str: + """Makes a mismatch error message for scalars. 
+ + Args: + actual (Union[bool, int, float, complex]): Actual scalar. + expected (Union[bool, int, float, complex]): Expected scalar. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Scalars". + """ + abs_diff = abs(actual - expected) + rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected) + return _make_mismatch_msg( + default_identifier="Scalars", + identifier=identifier, + extra=f"Expected {expected} but got {actual}.", + abs_diff=abs_diff, + atol=atol, + rel_diff=rel_diff, + rtol=rtol, + ) + + +def make_tensor_mismatch_msg( + actual: core.Tensor, + expected: core.Tensor, + matches: core.Tensor, + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +): + """Makes a mismatch error message for tensors. + + Args: + actual (core.Tensor): Actual tensor. + expected (core.Tensor): Expected tensor. + matches (core.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the + location of matches. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Tensor-likes". 
+ """ + + def unravel_flat_index(flat_index: int) -> tuple[int, ...]: + if not matches.shape: + return () + + inverse_index = [] + for size in matches.shape[::-1]: + div, mod = divmod(flat_index, size) + flat_index = div + inverse_index.append(mod) + + return tuple(inverse_index[::-1]) + + number_of_elements = matches.numel() + total_mismatches = number_of_elements - int(core.sum(matches)) + extra = ( + f"Mismatched elements: {total_mismatches} / {number_of_elements} " + f"({total_mismatches / number_of_elements:.1%})" + ) + if actual.dtype.is_floating_point and actual.dtype.itemsize == 1: + # skip checking for max_abs_diff and max_rel_diff for float8-like values + first_mismatch_idx = tuple(core.nonzero(~matches, as_tuple=False)[0].tolist()) + return _make_bitwise_mismatch_msg( + default_identifier="Tensor-likes", + identifier=identifier, + extra=extra, + first_mismatch_idx=first_mismatch_idx, + ) + + actual_flat = actual.flatten() + expected_flat = expected.flatten() + matches_flat = matches.flatten() + + if not actual.dtype.is_floating_point and not actual.dtype.is_complex: + # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid + # overflow + actual_flat = actual_flat.to(core.int64) + expected_flat = expected_flat.to(core.int64) + + abs_diff = core.abs(actual_flat - expected_flat) + # Ensure that only mismatches are used for the max_abs_diff computation + abs_diff[matches_flat] = 0 + max_abs_diff, max_abs_diff_flat_idx = core.max(abs_diff, 0) + + rel_diff = abs_diff / core.abs(expected_flat) + # Ensure that only mismatches are used for the max_rel_diff computation + rel_diff[matches_flat] = 0 + max_rel_diff, max_rel_diff_flat_idx = core.max(rel_diff, 0) + return _make_mismatch_msg( + default_identifier="Tensor-likes", + identifier=identifier, + extra=extra, + abs_diff=max_abs_diff.item(), + abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)), + atol=atol, + rel_diff=max_rel_diff.item(), + 
rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)), + rtol=rtol, + ) + + +class UnsupportedInputs(Exception): # noqa: B903 + """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs.""" + + +class Pair(abc.ABC): + """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`. + + Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison. + + Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the + super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to + handle the inputs and the next pair type will be tried. + + All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can + be used to automatically handle overwriting the message with a user supplied one and id handling. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + **unknown_parameters: Any, + ) -> None: + self.actual = actual + self.expected = expected + self.id = id + self._unknown_parameters = unknown_parameters + + @staticmethod + def _inputs_not_supported() -> NoReturn: + raise UnsupportedInputs + + @staticmethod + def _check_inputs_isinstance(*inputs: Any, cls: Union[type, tuple[type, ...]]): + """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise.""" + if not all(isinstance(input, cls) for input in inputs): + Pair._inputs_not_supported() + + def _fail( + self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = () + ) -> NoReturn: + """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id. + + .. warning:: + + If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id`` + explicitly. 
+ """ + raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id) + + @abc.abstractmethod + def compare(self) -> None: + """Compares the inputs and raises an :class`ErrorMeta` in case they mismatch.""" + + def extra_repr(self) -> Sequence[Union[str, tuple[str, Any]]]: + """Returns extra information that will be included in the representation. + + Should be overwritten by all subclasses that use additional options. The representation of the object will only + be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of + key-value-pairs or attribute names. + """ + return [] + + def __repr__(self) -> str: + head = f"{type(self).__name__}(" + tail = ")" + body = [ + f" {name}={value!s}," + for name, value in [ + ("id", self.id), + ("actual", self.actual), + ("expected", self.expected), + *[ + (extra, getattr(self, extra)) if isinstance(extra, str) else extra + for extra in self.extra_repr() + ], + ] + ] + return "\n".join((head, *body, *tail)) + + +class ObjectPair(Pair): + """Pair for any type of inputs that will be compared with the `==` operator. + + .. note:: + + Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs + couldn't handle the inputs. 
+ + """ + + def compare(self) -> None: + try: + equal = self.actual == self.expected + except Exception as error: + # We are not using `self._raise_error_meta` here since we need the exception chaining + raise ErrorMeta( + ValueError, + f"{self.actual} == {self.expected} failed with:\n{error}.", + id=self.id, + ) from error + + if not equal: + self._fail(AssertionError, f"{self.actual} != {self.expected}") + + +class NonePair(Pair): + """Pair for ``None`` inputs.""" + + def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None: + if not (actual is None or expected is None): + self._inputs_not_supported() + + super().__init__(actual, expected, **other_parameters) + + def compare(self) -> None: + if not (self.actual is None and self.expected is None): + self._fail( + AssertionError, f"None mismatch: {self.actual} is not {self.expected}" + ) + + +class BooleanPair(Pair): + """Pair for :class:`bool` inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.bool_` inputs. + + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...], + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, **other_parameters) + + @property + def _supported_types(self) -> tuple[type, ...]: + cls: list[type] = [bool] + if HAS_NUMPY: + cls.append(np.bool_) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...] 
+ ) -> tuple[bool, bool]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_bool(bool_like, id=id) for bool_like in (actual, expected) + ) + return actual, expected + + def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool: + if isinstance(bool_like, bool): + return bool_like + elif isinstance(bool_like, np.bool_): + return bool_like.item() + else: + raise ErrorMeta( + TypeError, f"Unknown boolean type {type(bool_like)}.", id=id + ) + + def compare(self) -> None: + if self.actual is not self.expected: + self._fail( + AssertionError, + f"Booleans mismatch: {self.actual} is not {self.expected}", + ) + + +class NumberPair(Pair): + """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.number` inputs. + + Kwargs: + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``. + + The following table displays correspondence between Python number type and the ``core.dtype``'s. See + :func:`assert_close` for the corresponding tolerances. 
+ + +------------------+-------------------------------+ + | ``type`` | corresponding ``core.dtype`` | + +==================+===============================+ + | :class:`int` | :attr:`~core.int64` | + +------------------+-------------------------------+ + | :class:`float` | :attr:`~core.float64` | + +------------------+-------------------------------+ + | :class:`complex` | :attr:`~core.complex64` | + +------------------+-------------------------------+ + """ + + _TYPE_TO_DTYPE = { + int: core.int64, + float: core.float64, + complex: core.complex128, + } + _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys()) + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_dtype: bool = False, + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)], + rtol=rtol, + atol=atol, + id=id, + ) + self.equal_nan = equal_nan + self.check_dtype = check_dtype + + @property + def _supported_types(self) -> tuple[type, ...]: + cls = list(self._NUMBER_TYPES) + if HAS_NUMPY: + cls.append(np.number) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...] + ) -> tuple[Union[int, float, complex], Union[int, float, complex]]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_number(number_like, id=id) for number_like in (actual, expected) + ) + return actual, expected + + def _to_number( + self, number_like: Any, *, id: tuple[Any, ...] 
+ ) -> Union[int, float, complex]: + if HAS_NUMPY and isinstance(number_like, np.number): + return number_like.item() + elif isinstance(number_like, self._NUMBER_TYPES): + return number_like # type: ignore[return-value] + else: + raise ErrorMeta( + TypeError, f"Unknown number type {type(number_like)}.", id=id + ) + + def compare(self) -> None: + if self.check_dtype and type(self.actual) is not type(self.expected): + self._fail( + AssertionError, + f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.", + ) + + if self.actual == self.expected: + return + + if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected): + return + + abs_diff = abs(self.actual - self.expected) + tolerance = self.atol + self.rtol * abs(self.expected) + + if cmath.isfinite(abs_diff) and abs_diff <= tolerance: + return + + self._fail( + AssertionError, + make_scalar_mismatch_msg( + self.actual, self.expected, rtol=self.rtol, atol=self.atol + ), + ) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_dtype", + ) + + +class TensorLikePair(Pair): + """Pair for :class:`core.Tensor`-like inputs. + + Kwargs: + allow_subclasses (bool): + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~core.Tensor.device`. If this check is disabled, tensors on different + :attr:`~core.Tensor.device`'s are moved to the CPU before being compared. 
+ check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`core.promote_types`) before being compared. + check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + **other_parameters: Any, + ): + actual, expected = self._process_inputs( + actual, expected, id=id, allow_subclasses=allow_subclasses + ) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + actual, expected, rtol=rtol, atol=atol, id=self.id + ) + self.equal_nan = equal_nan + self.check_device = check_device + self.check_dtype = check_dtype + self.check_layout = check_layout + self.check_stride = check_stride + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...], allow_subclasses: bool + ) -> tuple[core.Tensor, core.Tensor]: + directly_related = isinstance(actual, type(expected)) or isinstance( + expected, type(actual) + ) + if not directly_related: + self._inputs_not_supported() + + if not allow_subclasses and type(actual) is not type(expected): + self._inputs_not_supported() + + actual, expected = (self._to_tensor(input) for input in (actual, expected)) + return actual, expected + + def _to_tensor(self, tensor_like: Any) -> core.Tensor: + if 
isinstance(tensor_like, core.Tensor): + return tensor_like + + try: + return core.as_tensor(tensor_like) + except Exception: + self._inputs_not_supported() + + def compare(self) -> None: + actual, expected = self.actual, self.expected + + self._compare_attributes(actual, expected) + if any(input.device.type == "meta" for input in (actual, expected)): + return + + actual, expected = self._equalize_attributes(actual, expected) + self._compare_values(actual, expected) + + def _compare_attributes( + self, + actual: core.Tensor, + expected: core.Tensor, + ) -> None: + """Checks if the attributes of two tensors match. + + Always checks + + - the :attr:`~core.Tensor.shape`, + - whether both inputs are quantized or not, + - and if they use the same quantization scheme. + + Checks for + + - :attr:`~core.Tensor.layout`, + - :meth:`~core.Tensor.stride`, + - :attr:`~core.Tensor.device`, and + - :attr:`~core.Tensor.dtype` + + are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair. 
+ """ + + def raise_mismatch_error( + attribute_name: str, actual_value: Any, expected_value: Any + ) -> NoReturn: + self._fail( + AssertionError, + f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.", + ) + + if actual.shape != expected.shape: + raise_mismatch_error("shape", actual.shape, expected.shape) + + # if actual.is_quantized != expected.is_quantized: + # raise_mismatch_error( + # "is_quantized", actual.is_quantized, expected.is_quantized + # ) + # elif actual.is_quantized and actual.qscheme() != expected.qscheme(): + # raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme()) + + # if actual.layout != expected.layout: + # if self.check_layout: + # raise_mismatch_error("layout", actual.layout, expected.layout) + # elif ( + # actual.layout == core.strided + # and self.check_stride + # and actual.stride() != expected.stride() + # ): + # raise_mismatch_error("stride()", actual.stride(), expected.stride()) + + if self.check_device and actual.device != expected.device: + raise_mismatch_error("device", actual.device, expected.device) + + if self.check_dtype and actual.dtype != expected.dtype: + raise_mismatch_error("dtype", actual.dtype, expected.dtype) + + def _equalize_attributes( + self, actual: core.Tensor, expected: core.Tensor + ) -> tuple[core.Tensor, core.Tensor]: + """Equalizes some attributes of two tensors for value comparison. + + If ``actual`` and ``expected`` are ... + + - ... not on the same :attr:`~core.Tensor.device`, they are moved CPU memory. + - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to + :func:`core.promote_types`). + - ... not of the same ``layout``, they are converted to strided tensors. + + Args: + actual (Tensor): Actual tensor. + expected (Tensor): Expected tensor. + + Returns: + (Tuple[Tensor, Tensor]): Equalized tensors. + """ + # The comparison logic uses operators currently not supported by the MPS backends. 
+ # See https://github.com/pytorch/pytorch/issues/77144 for details. + # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend + + if actual.device != expected.device: + actual = actual.cpu() + expected = expected.cpu() + + if actual.dtype != expected.dtype: + actual_dtype = actual.dtype + expected_dtype = expected.dtype + # For uint64, this is not sound in general, which is why promote_types doesn't + # allow it, but for easy testing, we're unlikely to get confused + # by large uint64 overflowing into negative int64 + if actual_dtype in [core.uint64, core.uint32, core.uint16]: + actual_dtype = core.int64 + if expected_dtype in [core.uint64, core.uint32, core.uint16]: + expected_dtype = core.int64 + dtype = core.promote_types(actual_dtype, expected_dtype) + actual = actual.to(dtype) + expected = expected.to(dtype) + + if actual.layout != expected.layout: + # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided + actual = actual.to_dense() if actual.layout != core.strided else actual + expected = ( + expected.to_dense() if expected.layout != core.strided else expected + ) + + return actual, expected + + def _compare_values(self, actual: core.Tensor, expected: core.Tensor) -> None: + if actual.dtype.is_floating_point and actual.dtype.itemsize == 1: + + def bitwise_comp( + actual: core.Tensor, + expected: core.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + if rtol != 0.0 or atol != 0.0: + raise ErrorMeta( + AssertionError, + f"Rtol={rtol} and atol={atol} are not supported for bitwise comparison of low" + " dimensional floats. 
Please use rtol=0.0 and atol=0.0.", + ) + + return self._compare_regular_values_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=identifier, + ) + + compare_fn = bitwise_comp + else: + compare_fn = self._compare_regular_values_close + + compare_fn( + actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan + ) + + def _compare_quantized_values( + self, + actual: core.Tensor, + expected: core.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares quantized tensors by comparing the :meth:`~core.Tensor.dequantize`'d variants for closeness. + + .. note:: + + A detailed discussion about why only the dequantized variant is checked for closeness rather than checking + the individual quantization parameters for closeness and the integer representation for equality can be + found in https://github.com/pytorch/pytorch/issues/68548. + """ + return self._compare_regular_values_close( + actual.dequantize(), + expected.dequantize(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}", + ) + + def _compare_sparse_coo_values( + self, + actual: core.Tensor, + expected: core.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse COO tensors by comparing + + - the number of sparse dimensions, + - the number of non-zero elements (nnz) for equality, + - the indices for equality, and + - the values for closeness. 
+ """ + if actual.sparse_dim() != expected.sparse_dim(): + self._fail( + AssertionError, + ( + f"The number of sparse dimensions in sparse COO tensors does not match: " + f"{actual.sparse_dim()} != {expected.sparse_dim()}" + ), + ) + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse COO tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + self._compare_regular_values_equal( + actual._indices(), + expected._indices(), + identifier="Sparse COO indices", + ) + self._compare_regular_values_close( + actual._values(), + expected._values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier="Sparse COO values", + ) + + def _compare_sparse_compressed_values( + self, + actual: core.Tensor, + expected: core.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse compressed tensors by comparing + + - the number of non-zero elements (nnz) for equality, + - the plain indices for equality, + - the compressed indices for equality, and + - the values for closeness. + """ + format_name, compressed_indices_method, plain_indices_method = { + core.sparse_csr: ( + "CSR", + core.Tensor.crow_indices, + core.Tensor.col_indices, + ), + core.sparse_csc: ( + "CSC", + core.Tensor.ccol_indices, + core.Tensor.row_indices, + ), + core.sparse_bsr: ( + "BSR", + core.Tensor.crow_indices, + core.Tensor.col_indices, + ), + core.sparse_bsc: ( + "BSC", + core.Tensor.ccol_indices, + core.Tensor.row_indices, + ), + }[actual.layout] + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse {format_name} tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + # Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formats can be `core.int32` _or_ + # `core.int64`. 
While the same dtype is enforced for the compressed and plain indices of a single tensor, it + # can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will + # fail. + actual_compressed_indices = compressed_indices_method(actual) + expected_compressed_indices = compressed_indices_method(expected) + indices_dtype = core.promote_types( + actual_compressed_indices.dtype, expected_compressed_indices.dtype + ) + + self._compare_regular_values_equal( + actual_compressed_indices.to(indices_dtype), + expected_compressed_indices.to(indices_dtype), + identifier=f"Sparse {format_name} {compressed_indices_method.__name__}", + ) + self._compare_regular_values_equal( + plain_indices_method(actual).to(indices_dtype), + plain_indices_method(expected).to(indices_dtype), + identifier=f"Sparse {format_name} {plain_indices_method.__name__}", + ) + self._compare_regular_values_close( + actual.values(), + expected.values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=f"Sparse {format_name} values", + ) + + def _compare_regular_values_equal( + self, + actual: core.Tensor, + expected: core.Tensor, + *, + equal_nan: bool = False, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are equal.""" + self._compare_regular_values_close( + actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier + ) + + def _compare_regular_values_close( + self, + actual: core.Tensor, + expected: core.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are close up to a desired tolerance.""" + matches = core.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + if core.all(matches): + return + + if actual.shape == core.Size([]): + msg = make_scalar_mismatch_msg( + actual.item(), + expected.item(), + rtol=rtol, 
+ atol=atol, + identifier=identifier, + ) + else: + msg = make_tensor_mismatch_msg( + actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier + ) + self._fail(AssertionError, msg) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_device", + "check_dtype", + "check_layout", + "check_stride", + ) + + +def originate_pairs( + actual: Any, + expected: Any, + *, + pair_types: Sequence[type[Pair]], + sequence_types: tuple[type, ...] = (collections.abc.Sequence,), + mapping_types: tuple[type, ...] = (collections.abc.Mapping,), + id: tuple[Any, ...] = (), + **options: Any, +) -> list[Pair]: + """Originates pairs from the individual inputs. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs. + First successful pair will be used. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message. + **options (Any): Options passed to each pair during construction. + + Raises: + ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their + length does not match. + ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of + keys do not match. + ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs. + ErrorMeta: With any expected exception that happens during the construction of a pair. + + Returns: + (List[Pair]): Originated pairs. 
+ """ + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... + if ( + isinstance(actual, sequence_types) + and not isinstance(actual, str) + and isinstance(expected, sequence_types) + and not isinstance(expected, str) + ): + actual_len = len(actual) # type: ignore[arg-type] + expected_len = len(expected) # type: ignore[arg-type] + if actual_len != expected_len: + raise ErrorMeta( + AssertionError, + f"The length of the sequences mismatch: {actual_len} != {expected_len}", + id=id, + ) + + pairs = [] + for idx in range(actual_len): + pairs.extend( + originate_pairs( + actual[idx], # type: ignore[index] + expected[idx], # type: ignore[index] + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, idx), + **options, + ) + ) + return pairs + + elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types): + actual_keys = set(actual.keys()) # type: ignore[attr-defined] + expected_keys = set(expected.keys()) # type: ignore[attr-defined] + if actual_keys != expected_keys: + missing_keys = expected_keys - actual_keys + additional_keys = actual_keys - expected_keys + raise ErrorMeta( + AssertionError, + ( + f"The keys of the mappings do not match:\n" + f"Missing keys in the actual mapping: {sorted(missing_keys)}\n" + f"Additional keys in the actual mapping: {sorted(additional_keys)}" + ), + id=id, + ) + + keys: Collection = actual_keys + # Since the origination aborts after the first failure, we try to be deterministic + with contextlib.suppress(Exception): + keys = sorted(keys) + + pairs = [] + for key in keys: + pairs.extend( + originate_pairs( + actual[key], # type: ignore[index] + expected[key], # type: ignore[index] + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, key), + **options, + ) + ) + return pairs + + else: + for pair_type in pair_types: + try: + return [pair_type(actual, 
expected, id=id, **options)] + # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the + # inputs. Thus, we try the next pair type. + except UnsupportedInputs: + continue + # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This + # is only in a separate branch, because the one below would also except it. + except ErrorMeta: + raise + # Raising any other exception during origination is unexpected and will give some extra information about + # what happened. If applicable, the exception should be expected in the future. + except Exception as error: + print(error) + raise RuntimeError( + f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n" + f"{type(actual).__name__}(): {actual}\n\n" + f"and\n\n" + f"{type(expected).__name__}(): {expected}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." + ) from error + else: + raise ErrorMeta( + TypeError, + f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.", + id=id, + ) + + +def not_close_error_metas( + actual: Any, + expected: Any, + *, + pair_types: Sequence[type[Pair]] = (ObjectPair,), + sequence_types: tuple[type, ...] = (collections.abc.Sequence,), + mapping_types: tuple[type, ...] = (collections.abc.Mapping,), + **options: Any, +) -> list[ErrorMeta]: + """Asserts that inputs are equal. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them. + + Args: + actual (Any): Actual input. 
+ expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the + inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + **options (Any): Options passed to each pair during construction. + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + try: + pairs = originate_pairs( + actual, + expected, + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + **options, + ) + except ErrorMeta as error_meta: + # Explicitly raising from None to hide the internal traceback + raise error_meta.to_error() from None # noqa: RSE102 + + error_metas: list[ErrorMeta] = [] + for pair in pairs: + try: + pair.compare() + except ErrorMeta as error_meta: + error_metas.append(error_meta) + # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information + # about what happened. If applicable, the exception should be expected in the future. + except Exception as error: + raise RuntimeError( + f"Comparing\n\n" + f"{pair}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." + ) from error + + # [ErrorMeta Cycles] + # ErrorMeta objects in this list capture + # tracebacks that refer to the frame of this function. + # The local variable `error_metas` refers to the error meta + # objects, creating a reference cycle. 
Frames in the traceback + # would not get freed until cycle collection, leaking cuda memory in tests. + # We break the cycle by removing the reference to the error_meta objects + # from this frame as it returns. + error_metas = [error_metas] + return error_metas.pop() + + +def assert_close( + actual: Any, + expected: Any, + *, + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + msg: Optional[Union[str, Callable[[str], str]]] = None, +): + r"""Asserts that ``actual`` and ``expected`` are close. + + If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if + + .. math:: + + \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert + + Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are + only considered equal to each other if ``equal_nan`` is ``True``. + + In addition, they are only considered close if they have the same + + - :attr:`~core.Tensor.device` (if ``check_device`` is ``True``), + - ``dtype`` (if ``check_dtype`` is ``True``), + - ``layout`` (if ``check_layout`` is ``True``), and + - stride (if ``check_stride`` is ``True``). + + If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed. + + If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are + checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR, + or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively, + are always checked for equality whereas the values are checked for closeness according to the definition above. 
+ + If ``actual`` and ``expected`` are quantized, they are considered close if they have the same + :meth:`~core.Tensor.qscheme` and the result of :meth:`~core.Tensor.dequantize` is close according to the + definition above. + + ``actual`` and ``expected`` can be :class:`~core.Tensor`'s or any tensor-or-scalar-likes from which + :class:`core.Tensor`'s can be constructed with :func:`core.as_tensor`. Except for Python scalars the input types + have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s + or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all + their elements are considered close according to the above definition. + + .. note:: + + Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e. + :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus, + Python scalars of different types can be checked, but require ``check_dtype=False``. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types + are allowed. Otherwise type equality is required. + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the :attr:`~core.Tensor.dtype` are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the :attr:`~core.Tensor.dtype` are selected with the below table. + equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~core.Tensor.device`. 
If this check is disabled, tensors on different + :attr:`~core.Tensor.device`'s are moved to the CPU before being compared. + check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`core.promote_types`) before being compared. + check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during + the comparison. Can also passed as callable in which case it will be called with the generated message and + should return the new message. + + Raises: + ValueError: If no :class:`core.Tensor` can be constructed from an input. + ValueError: If only ``rtol`` or ``atol`` is specified. + AssertionError: If corresponding inputs are not Python scalars and are not directly related. + AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have + different types. + AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match. + AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match. + AssertionError: If corresponding tensors do not have the same :attr:`~core.Tensor.shape`. + AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same + :attr:`~core.Tensor.layout`. + AssertionError: If only one of corresponding tensors is quantized. + AssertionError: If corresponding tensors are quantized, but have different :meth:`~core.Tensor.qscheme`'s. 
+ AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same + :attr:`~core.Tensor.device`. + AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``. + AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride. + AssertionError: If the values of corresponding tensors are not close according to the definition above. + + The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching + ``dtype``'s, the maximum of both tolerances is used. + + +---------------------------+------------+----------+ + | ``dtype`` | ``rtol`` | ``atol`` | + +===========================+============+==========+ + | :attr:`~core.float16` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.bfloat16` | ``1.6e-2`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.float32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.float64` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~core.complex32` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.complex64` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.complex128` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~core.quint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.quint2x4` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.quint4x2` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.qint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~core.qint32` | 
``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | other | ``0.0`` | ``0.0`` | + +---------------------------+------------+----------+ + + .. note:: + + :func:`~core.testing.assert_close` is highly configurable with strict default settings. Users are encouraged + to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might + define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default: + + >>> import functools + >>> assert_equal = functools.partial(core.testing.assert_close, rtol=0, atol=0) + >>> assert_equal(1e-9, 1e-10) + Traceback (most recent call last): + ... + AssertionError: Scalars are not equal! + + Expected 1e-10 but got 1e-09. + Absolute difference: 9.000000000000001e-10 + Relative difference: 9.0 + + Examples: + >>> # tensor to tensor comparison + >>> expected = core.tensor([1e0, 1e-1, 1e-2]) + >>> actual = core.acos(core.cos(expected)) + >>> core.testing.assert_close(actual, expected) + + >>> # scalar to scalar comparison + >>> import math + >>> expected = math.sqrt(2.0) + >>> actual = 2.0 / math.sqrt(2.0) + >>> core.testing.assert_close(actual, expected) + + >>> # numpy array to numpy array comparison + >>> import numpy as np + >>> expected = np.array([1e0, 1e-1, 1e-2]) + >>> actual = np.arccos(np.cos(expected)) + >>> core.testing.assert_close(actual, expected) + + >>> # sequence to sequence comparison + >>> import numpy as np + >>> # The types of the sequences do not have to match. They only have to have the same + >>> # length and their elements have to match. + >>> expected = [core.tensor([1.0]), 2.0, np.array(3.0)] + >>> actual = tuple(expected) + >>> core.testing.assert_close(actual, expected) + + >>> # mapping to mapping comparison + >>> from collections import OrderedDict + >>> import numpy as np + >>> foo = core.tensor(1.0) + >>> bar = 2.0 + >>> baz = np.array(3.0) + >>> # The types and a possible ordering of mappings do not have to match. 
They only + >>> # have to have the same set of keys and their elements have to match. + >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) + >>> actual = {"baz": baz, "bar": bar, "foo": foo} + >>> core.testing.assert_close(actual, expected) + + >>> expected = core.tensor([1.0, 2.0, 3.0]) + >>> actual = expected.clone() + >>> # By default, directly related instances can be compared + >>> core.testing.assert_close(core.nn.Parameter(actual), expected) + >>> # This check can be made more strict with allow_subclasses=False + >>> core.testing.assert_close( + ... core.nn.Parameter(actual), expected, allow_subclasses=False + ... ) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + and . + >>> # If the inputs are not directly related, they are never considered close + >>> core.testing.assert_close(actual.numpy(), expected) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + and . + >>> # Exceptions to these rules are Python scalars. They can be checked regardless of + >>> # their type if check_dtype=False. + >>> core.testing.assert_close(1.0, 1, check_dtype=False) + + >>> # NaN != NaN by default. + >>> expected = core.tensor(float("Nan")) + >>> actual = expected.clone() + >>> core.testing.assert_close(actual, expected) + Traceback (most recent call last): + ... + AssertionError: Scalars are not close! + + Expected nan but got nan. + Absolute difference: nan (up to 1e-05 allowed) + Relative difference: nan (up to 1.3e-06 allowed) + >>> core.testing.assert_close(actual, expected, equal_nan=True) + + >>> expected = core.tensor([1.0, 2.0, 3.0]) + >>> actual = core.tensor([1.0, 4.0, 5.0]) + >>> # The default error message can be overwritten. + >>> core.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + Traceback (most recent call last): + ... + AssertionError: Argh, the tensors are not close! 
+ >>> # If msg is a callable, it can be used to augment the generated message with + >>> # extra information + >>> core.testing.assert_close( + ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" + ... ) + Traceback (most recent call last): + ... + AssertionError: Header + + Tensor-likes are not close! + + Mismatched elements: 2 / 3 (66.7%) + Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) + Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) + + Footer + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + error_metas = not_close_error_metas( + actual, + expected, + pair_types=( + NonePair, + BooleanPair, + NumberPair, + TensorLikePair, + ), + allow_subclasses=allow_subclasses, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=check_device, + check_dtype=check_dtype, + check_layout=check_layout, + check_stride=check_stride, + msg=msg, + ) + + if error_metas: + # TODO: compose all metas into one AssertionError + raise error_metas[0].to_error(msg) + + +@deprecated( + "`core.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `core.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + category=FutureWarning, +) +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + """ + .. warning:: + + :func:`core.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release. + Please use :func:`core.testing.assert_close` instead. You can find detailed upgrade instructions + `here `_. 
+ """ + if not isinstance(actual, core.Tensor): + actual = core.tensor(actual) + if not isinstance(expected, core.Tensor): + expected = core.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = default_tolerances( + actual, + expected, + dtype_precisions={ + core.float16: (1e-3, 1e-3), + core.float32: (1e-4, 1e-5), + core.float64: (1e-5, 1e-8), + }, + ) + + core.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + msg=msg or None, + ) \ No newline at end of file diff --git a/mindnlp/core/types.py b/mindnlp/core/types.py index 2ab5afc11..1fb332d12 100644 --- a/mindnlp/core/types.py +++ b/mindnlp/core/types.py @@ -50,6 +50,15 @@ def __eq__(self, __value): def __hash__(self): return hash(self.type) ^ hash(self.index) + def __enter__(self): + # self.prev_idx = torch.cuda._exchange_device(self.idx) + pass + + def __exit__(self, type: Any, value: Any, traceback: Any): + # self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + return False + + # Meta-type for "numeric" things; matches our docs Number: TypeAlias = Union[int, float, bool] # tuple for isinstance(x, Number) checks. 
diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py index c4c6fc82d..f43b09198 100644 --- a/mindnlp/transformers/__init__.py +++ b/mindnlp/transformers/__init__.py @@ -4111,4 +4111,8 @@ extra_objects={"__version__": transformers.__version__}, ) -transformers.utils.import_utils._torch_fx_available = False \ No newline at end of file +def not_supported(): + return False + +transformers.utils.import_utils._torch_fx_available = False +transformers.utils.import_utils.is_torch_sdpa_available = not_supported \ No newline at end of file diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py index cca325d9e..52de6bf54 100644 --- a/mindnlp/utils/torch_proxy.py +++ b/mindnlp/utils/torch_proxy.py @@ -2,44 +2,85 @@ import types import importlib import importlib.metadata +from collections import defaultdict class TorchProxyModule(types.ModuleType): - def __init__(self): - super().__init__("torch") - # 保存真实模块的引用 - self._real_module = None + """递归代理模块,支持任意深度的模块路径""" - def _load_real_module(self): - """按需加载真实模块""" - if self._real_module is None: - # 尝试直接导入mindnlp.core作为torch - self._real_module = importlib.import_module("mindnlp.core") - # 添加必要的元数据属性 - self._real_module.__name__ = "torch" - self._real_module.__package__ = "torch" - self._real_module.__file__ = "" + # 缓存已创建的代理模块 + _proxy_cache = defaultdict(dict) + + def __new__(cls, real_module, proxy_name): + """使用缓存避免重复创建代理""" + # 生成缓存键:真实模块ID + 代理名称 + cache_key = (id(real_module), proxy_name) + + # 如果已存在缓存,直接返回 + if cache_key in cls._proxy_cache[real_module]: + return cls._proxy_cache[real_module][cache_key] - return self._real_module + # 创建新实例并缓存 + instance = super().__new__(cls, proxy_name) + cls._proxy_cache[real_module][cache_key] = instance + return instance + def __init__(self, real_module, proxy_name): + """初始化代理模块""" + super().__init__(proxy_name) + self._real_module = real_module + self._proxy_name = proxy_name + self._submodule_proxies = {} + + # 设置关键元数据 + self.__name__ = 
proxy_name + self.__package__ = proxy_name + self.__file__ = "" + def __getattr__(self, name): - """任何属性访问都重定向到真实模块""" - # 处理特殊元数据属性 - if name in {"__name__", "__package__", "__file__"}: - return getattr(self._load_real_module(), name) + """动态获取属性并创建子模块代理""" + # 1. 尝试从真实模块获取属性 + try: + real_attr = getattr(self._real_module, name) + except AttributeError: + raise AttributeError( + f"module '{self._proxy_name}' has no attribute '{name}'" + ) + + # 2. 如果是模块类型,创建递归代理 + if isinstance(real_attr, types.ModuleType): + # 构建子模块的代理名称 + sub_proxy_name = f"{self._proxy_name}.{name}" - return getattr(self._load_real_module(), name) + # 创建或获取子模块代理 + if name not in self._submodule_proxies: + self._submodule_proxies[name] = TorchProxyModule( + real_attr, + sub_proxy_name + ) + + return self._submodule_proxies[name] + + # 4. 其他类型直接返回 + return real_attr def __setattr__(self, name, value): - """属性设置也重定向到真实模块""" - # 跳过自身内部属性 - if name in {"_real_module", "__name__", "__package__", "__file__"}: + """处理属性设置""" + # 内部属性直接设置 + if name in {"_real_module", "_proxy_name", "_submodule_proxies"}: super().__setattr__(name, value) - else: - setattr(self._load_real_module(), name, value) - + return + + # 其他属性设置到真实模块 + if name not in self._submodule_proxies: + setattr(self._real_module, name, value) + def __dir__(self): """返回真实模块的属性列表""" - return dir(self._load_real_module()) + return dir(self._real_module) + + def __repr__(self): + """友好的代理模块表示""" + return f"" def __getattribute__(self, name): """特殊处理元数据相关属性""" @@ -52,8 +93,8 @@ def __getattribute__(self, name): return super().__getattribute__(name) def initialize_torch_proxy(): - - torch_proxy = TorchProxyModule() + import mindnlp + torch_proxy = TorchProxyModule(mindnlp.core, 'torch') sys.modules["torch"] = torch_proxy # 设置必要的元数据