diff --git a/nncf/experimental/tensor/README.md b/nncf/experimental/tensor/README.md
index ea8ae2168c9..3eac8d575d9 100644
--- a/nncf/experimental/tensor/README.md
+++ b/nncf/experimental/tensor/README.md
@@ -32,8 +32,6 @@ tenor_b = Tensor(np.array([1,2]))
 tensor_a + tenor_b  # Tensor(array([2, 4]))
 ```
 
-**NOTE** Division operations for the numpy backend are performed with warnings disabled for the same for all backends.
-
 ### Comparison operators
 
 All math operations are overrided to operated with wrapped object and return `Tensor`
@@ -108,7 +106,7 @@ tensor_a[0:2]  # Tensor(array([[1],[2]]))
 2. Add function to [function.py](function.py)
 
     ```python
-    @functools.singledispatch
+    @tensor_dispatch()
     def foo(a: TTensor, arg1: Type) -> TTensor:
        """
        __description__

        :param a: The input tensor.
        :param arg1: __description__
        :return: __description__
        """
-        if isinstance(a, tensor.Tensor):
-            return tensor.Tensor(foo(a.data, axis))
-        return NotImplemented(f"Function `foo` is not implemented for {type(a)}")
    ```

-    **NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches function by first element in the first argument.
+    **NOTE** To control how `Tensor` arguments and results are handled, select a wrapper type via
+    `@tensor_dispatch(wrapper_type=...)`:
 
-    ```python
-    @functools.singledispatch
-    def foo(x: List[Tensor], axis: int = 0) -> Tensor:
-        if isinstance(x, List):
-            unwrapped_x = [i.data for i in x]
-            return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis))
-        raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}")
-    ```
+    - `WrapperType.TensorToTensor` (default): expects a `Tensor` as the first argument; the result is wrapped into a `Tensor`.
+    - `WrapperType.TensorToAny`: expects a `Tensor` as the first argument; the result is returned as-is.
+    - `WrapperType.TensorToList`: expects a `Tensor` as the first argument; each element of the resulting list is wrapped into a `Tensor`.
+    - `WrapperType.ListToTensor`: expects a `List[Tensor]` as the first argument; the result is wrapped into a `Tensor`.
 
 3. Add backend specific implementation of method to:
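For illustration, here is a minimal sketch of how a dispatched function and a backend implementation fit together under the new decorator. The `split_halves` function and its numpy implementation are hypothetical examples, not part of this PR:

```python
from typing import List

import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.dispatcher import WrapperType
from nncf.experimental.tensor.dispatcher import tensor_dispatch


@tensor_dispatch(wrapper_type=WrapperType.TensorToList)
def split_halves(a: Tensor) -> List[Tensor]:
    """Split a tensor into two halves along the first axis."""


@split_halves.register(np.ndarray)
def _(a: np.ndarray) -> List[np.ndarray]:
    # The backend implementation works on raw arrays; the TensorToList
    # wrapper wraps each element of the returned list back into a Tensor.
    mid = a.shape[0] // 2
    return [a[:mid], a[mid:]]


halves = split_halves(Tensor(np.arange(4)))
# [Tensor(array([0, 1])), Tensor(array([2, 3]))]
```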
diff --git a/nncf/experimental/tensor/dispatcher.py b/nncf/experimental/tensor/dispatcher.py
new file mode 100644
index 00000000000..234223e38d9
--- /dev/null
+++ b/nncf/experimental/tensor/dispatcher.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+import weakref
+from abc import get_cache_token
+from enum import Enum
+from enum import auto
+from functools import _find_impl
+from functools import update_wrapper
+from typing import Callable, List, Optional, Type, Union
+
+from nncf.experimental.tensor import Tensor
+
+
+class WrapperType(Enum):
+    TensorToTensor = auto()
+    TensorToAny = auto()
+    TensorToList = auto()
+    ListToTensor = auto()
+
+
+def tensor_dispatch(wrapper_type: WrapperType = WrapperType.TensorToTensor) -> Callable:
+    """Custom implementation of the functools.singledispatch decorator.
+
+    Transforms a function into a generic function, which can have different
+    behaviours depending upon the type of its first argument. The decorated
+    function acts as the default implementation, and additional
+    implementations can be registered using the register() attribute of the
+    generic function.
+
+    The wrapper type controls how Tensor arguments and results are handled:
+        TensorToTensor - expects a Tensor as the first argument; the result is wrapped into a Tensor.
+        TensorToAny - expects a Tensor as the first argument; the result is returned as-is.
+        TensorToList - expects a Tensor as the first argument; each element of the resulting list is wrapped into a Tensor.
+        ListToTensor - expects a List of Tensors as the first argument; the result is wrapped into a Tensor.
+
+    NotImplementedError is raised for types with no registered implementation.
+
+    If the first argument is not wrapped into a Tensor, the backend-specific function is called directly.
+
+    :param wrapper_type: Type of the wrapper function, defaults to WrapperType.TensorToTensor.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        registry = {}
+        dispatch_cache = weakref.WeakKeyDictionary()
+        cache_token = None
+
+        def dispatch(cls: Type) -> Callable:
+            """generic_func.dispatch(cls) -> <impl>
+
+            Runs the dispatch algorithm to return the best available implementation
+            for the given *cls* registered on *generic_func*.
+            """
+            nonlocal cache_token
+            if cache_token is not None:
+                current_token = get_cache_token()
+                if cache_token != current_token:
+                    dispatch_cache.clear()
+                    cache_token = current_token
+            try:
+                impl = dispatch_cache[cls]
+            except KeyError:
+                try:
+                    impl = registry[cls]
+                except KeyError:
+                    impl = _find_impl(cls, registry)
+                dispatch_cache[cls] = impl
+            return impl
+
+        def register(cls: Type, func: Optional[Callable] = None):
+            """generic_func.register(cls, func) -> func
+
+            Registers a new implementation for the given *cls* on a *generic_func*.
+            """
+            nonlocal cache_token
+            if func is None:
+                if isinstance(cls, type):
+                    return lambda f: register(cls, f)
+                ann = getattr(cls, "__annotations__", {})
+                if not ann:
+                    raise TypeError(
+                        f"Invalid first argument to `register()`: {cls!r}. "
+                        f"Use either `@register(some_class)` or plain `@register` "
+                        f"on an annotated function."
+                    )
+                func = cls
+
+                # only import typing if annotation parsing is necessary
+                from typing import get_type_hints
+
+                argname, cls = next(iter(get_type_hints(func).items()))
+                if not isinstance(cls, type):
+                    raise TypeError(f"Invalid annotation for {argname!r}. " f"{cls!r} is not a class.")
+            registry[cls] = func
+            if cache_token is None and hasattr(cls, "__abstractmethods__"):
+                cache_token = get_cache_token()
+            dispatch_cache.clear()
+            return func
+
+        def wrapper_tensor_to_tensor(tensor: Tensor, *args, **kw):
+            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+            return Tensor(dispatch(tensor.data.__class__)(tensor.data, *args, **kw))
+
+        def wrapper_tensor_to_any(tensor: Tensor, *args, **kw):
+            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+            return dispatch(tensor.data.__class__)(tensor.data, *args, **kw)
+
+        def wrapper_tensor_to_list(tensor: Tensor, *args, **kw):
+            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+            return [Tensor(x) for x in dispatch(tensor.data.__class__)(tensor.data, *args, **kw)]
+
+        def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw):
+            list_of_tensors = [x.data for x in list_of_tensors]
+            return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw))
+
+        wrappers_map = {
+            WrapperType.TensorToTensor: wrapper_tensor_to_tensor,
+            WrapperType.TensorToAny: wrapper_tensor_to_any,
+            WrapperType.TensorToList: wrapper_tensor_to_list,
+            WrapperType.ListToTensor: wrapper_list_to_tensor,
+        }
+
+        def raise_not_implemented(data: Union[Tensor, List[Tensor]], *args, **kw):
+            """
+            Raise NotImplementedError for a type with no registered implementation.
+            """
+            if wrapper_type == WrapperType.ListToTensor:
+                arg_type = type(data[0].data) if isinstance(data[0], Tensor) else type(data[0])
+            else:
+                arg_type = type(data.data) if isinstance(data, Tensor) else type(data)
+
+            raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}")
+
+        registry[object] = raise_not_implemented
+        wrapper = wrappers_map[wrapper_type]
+        wrapper.register = register
+        wrapper.dispatch = dispatch
+        wrapper.registry = types.MappingProxyType(registry)
+        wrapper._clear_cache = dispatch_cache.clear
+        update_wrapper(wrapper, func)
+        return wrapper
+
+    return decorator
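To make the dispatch flow concrete, here is a small usage sketch; the `scale` function below is hypothetical and not part of this PR:

```python
import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.dispatcher import tensor_dispatch


@tensor_dispatch()  # WrapperType.TensorToTensor by default
def scale(a: Tensor, factor: float) -> Tensor:
    """Multiply a tensor by a scalar."""


@scale.register
def _(a: np.ndarray, factor: float) -> np.ndarray:
    # Registered via the annotation of the first parameter.
    return a * factor


scale(Tensor(np.array([1.0, 2.0])), 2.0)  # Tensor(array([2., 4.]))
scale(Tensor("not-an-array"), 2.0)        # raises NotImplementedError
```

Note that, unlike `functools.singledispatch`, the decorated body is never invoked: `registry[object]` points to `raise_not_implemented`, so a docstring-only stub is enough.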
" f"{cls!r} is not a class.") + registry[cls] = func + if cache_token is None and hasattr(cls, "__abstractmethods__"): + cache_token = get_cache_token() + dispatch_cache.clear() + return func + + def wrapper_tensor_to_tensor(tensor: Tensor, *args, **kw): + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + return Tensor(dispatch(tensor.data.__class__)(tensor.data, *args, **kw)) + + def wrapper_tensor_to_any(tensor: Tensor, *args, **kw): + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + return dispatch(tensor.data.__class__)(tensor.data, *args, **kw) + + def wrapper_tensor_to_list(tensor: Tensor, *args, **kw): + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + return [Tensor(x) for x in dispatch(tensor.data.__class__)(tensor.data, *args, **kw)] + + def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw): + list_of_tensors = [x.data for x in list_of_tensors] + return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw)) + + wrappers_map = { + WrapperType.TensorToTensor: wrapper_tensor_to_tensor, + WrapperType.TensorToAny: wrapper_tensor_to_any, + WrapperType.TensorToList: wrapper_tensor_to_list, + WrapperType.ListToTensor: wrapper_list_to_tensor, + } + + def raise_not_implemented(data: Union[Tensor, List[Tensor]], *args, **kw): + """ + Raising NotImplementedError for not registered type. + """ + if wrapper_type == WrapperType.ListToTensor: + arg_type = type(data[0].data) if isinstance(data[0], Tensor) else type(data[0]) + else: + arg_type = type(data.data) if isinstance(data, Tensor) else type(data) + + raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}") + + registry[object] = raise_not_implemented + wrapper = wrappers_map[wrapper_type] + wrapper.register = register + wrapper.dispatch = dispatch + wrapper.registry = types.MappingProxyType(registry) + wrapper._clear_cache = dispatch_cache.clear + update_wrapper(wrapper, func) + return wrapper + + return decorator diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index c434de3b1bf..197b7a8f6d2 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -9,31 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union +from nncf.experimental.tensor.dispatcher import WrapperType +from nncf.experimental.tensor.dispatcher import tensor_dispatch from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType from nncf.experimental.tensor.tensor import Tensor -from nncf.experimental.tensor.tensor import unwrap_tensor_data -def _tensor_guard(func: callable): - """ - A decorator that ensures that the first argument to the decorated function is a Tensor. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - if isinstance(args[0], Tensor): - return func(*args, **kwargs) - raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {type(args[0])}") - - return wrapper - - -@functools.singledispatch -@_tensor_guard +@tensor_dispatch(wrapper_type=WrapperType.TensorToAny) def device(a: Tensor) -> TensorDeviceType: """ Return the device of the tensor. @@ -41,11 +26,9 @@ def device(a: Tensor) -> TensorDeviceType: :param a: The input tensor. :return: The device of the tensor. 
""" - return device(a.data) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Remove axes of length one from a. @@ -56,11 +39,9 @@ def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Te This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0d array and not a scalar. """ - return Tensor(squeeze(a.data, axis=axis)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def flatten(a: Tensor) -> Tensor: """ Return a copy of the tensor collapsed into one dimension. @@ -68,41 +49,35 @@ def flatten(a: Tensor) -> Tensor: :param a: The input tensor. :return: A copy of the input tensor, flattened to one dimension. """ - return Tensor(flatten(a.data)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: Optional[bool] = False) -> Tensor: """ Return the maximum of an array or maximum along an axis. :param a: The input tensor. :param axis: Axis or axes along which to operate. By default, flattened input is used. - :param keepdim: If this is set to True, the axes which are reduced are left in the result as dimensions with size + :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. False, by default. :return: Maximum of a. """ - return Tensor(max(a.data, axis, keepdims)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: Optional[bool] = False) -> Tensor: """ Return the minimum of an array or minimum along an axis. :param a: The input tensor. :param axis: Axis or axes along which to operate. By default, flattened input is used. - :param keepdim: If this is set to True, the axes which are reduced are left in the result as dimensions with size + :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. False, by default. :return: Minimum of a. """ - return Tensor(min(a.data, axis, keepdims)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def abs(a: Tensor) -> Tensor: """ Calculate the absolute value element-wise. @@ -110,12 +85,10 @@ def abs(a: Tensor) -> Tensor: :param a: The input tensor. :return: A tensor containing the absolute value of each element in x. """ - return Tensor(abs(a.data)) -@functools.singledispatch -@_tensor_guard -def astype(a: Tensor, data_type: TensorDataType) -> Tensor: +@tensor_dispatch() +def astype(a: Tensor, dtype: TensorDataType) -> Tensor: """ Copy of the tensor, cast to a specified type. @@ -124,11 +97,9 @@ def astype(a: Tensor, data_type: TensorDataType) -> Tensor: :return: Copy of the tensor in specified type. """ - return Tensor(astype(a.data, data_type)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def dtype(a: Tensor) -> TensorDataType: """ Return data type of the tensor. @@ -136,11 +107,9 @@ def dtype(a: Tensor) -> TensorDataType: :param a: The input tensor. :return: The data type of the tensor. """ - return dtype(a.data) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: """ Gives a new shape to a tensor without changing its data. 
@@ -149,11 +118,9 @@ def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: :param shape: The new shape should be compatible with the original shape. :return: Reshaped tensor. """ - return Tensor(reshape(a.data, shape)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Test whether all tensor elements along a given axis evaluate to True. @@ -162,11 +129,9 @@ def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor :param axis: Axis or axes along which a logical AND reduction is performed. :return: A new boolean or tensor. """ - return Tensor(all(a.data, axis=axis)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def allclose( a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> Tensor: @@ -182,19 +147,9 @@ def allclose( Defaults to False. :return: True if the two arrays are equal within the given tolerance, otherwise False. """ - return Tensor( - allclose( - a.data, - unwrap_tensor_data(b), - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - ) - ) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Test whether any tensor elements along a given axis evaluate to True. @@ -203,11 +158,9 @@ def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor :param axis: Axis or axes along which a logical OR reduction is performed. :return: A new boolean or tensor. """ - return Tensor(any(a.data, axis)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Counts the number of non-zero values in the tensor input. @@ -217,11 +170,9 @@ def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) :return: Number of non-zero values in the tensor along a given axis. Otherwise, the total number of non-zero values in the tensor is returned. """ - return Tensor(count_nonzero(a.data, axis)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch(wrapper_type=WrapperType.TensorToAny) def isempty(a: Tensor) -> bool: """ Return True if input tensor is empty. @@ -229,11 +180,9 @@ def isempty(a: Tensor) -> bool: :param a: The input tensor. :return: True if tensor is empty, otherwise False. """ - return isempty(a.data) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def isclose( a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> Tensor: @@ -249,19 +198,9 @@ def isclose( Defaults to False. :return: Returns a boolean tensor of where a and b are equal within the given tolerance. """ - return Tensor( - isclose( - a.data, - unwrap_tensor_data(b), - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - ) - ) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise maximum of tensor elements. @@ -270,11 +209,9 @@ def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor. :return: Output tensor. """ - return Tensor(maximum(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise minimum of tensor elements. 
@@ -283,11 +220,9 @@ def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor. :return: Output tensor. """ - return Tensor(minimum(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def ones_like(a: Tensor) -> Tensor: """ Return a tensor of ones with the same shape and type as a given tensor. @@ -295,11 +230,9 @@ def ones_like(a: Tensor) -> Tensor: :param a: The shape and data-type of a define these same attributes of the returned tensor. :return: Tensor of ones with the same shape and type as a. """ - return Tensor(ones_like(a.data)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) -> Tensor: """ Return elements chosen from x or y depending on condition. @@ -309,17 +242,9 @@ def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) - :param y: Value at indices where condition is False. :return: A tensor with elements from x where condition is True, and elements from y elsewhere. """ - return Tensor( - where( - condition.data, - unwrap_tensor_data(x), - unwrap_tensor_data(y), - ) - ) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def zeros_like(a: Tensor) -> Tensor: """ Return an tensor of zeros with the same shape and type as a given tensor. @@ -327,10 +252,9 @@ def zeros_like(a: Tensor) -> Tensor: :param input: The shape and data-type of a define these same attributes of the returned tensor. :return: tensor of zeros with the same shape and type as a. """ - return Tensor(zeros_like(a.data)) -@functools.singledispatch +@tensor_dispatch(wrapper_type=WrapperType.ListToTensor) def stack(x: List[Tensor], axis: int = 0) -> Tensor: """ Stacks a list of Tensors rank-R tensors into one Tensor rank-(R+1) tensor. @@ -339,27 +263,20 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor: :param axis: The axis to stack along. :return: Stacked Tensor. """ - if isinstance(x, List): - return Tensor(_dispatch_list(stack, x, axis=axis)) - raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") -@functools.singledispatch -@_tensor_guard -def unstack(a: Tensor, axis: int = 0) -> List[Tensor]: +@tensor_dispatch(wrapper_type=WrapperType.TensorToList) +def unstack(x: Tensor, axis: int = 0) -> List[Tensor]: """ Unstack a Tensor into list. - :param a: Tensor to unstack. + :param x: Tensor to unstack. :param axis: The axis to unstack along. :return: List of Tensor. """ - res = unstack(a.data, axis=axis) - return [Tensor(i) for i in res] -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor: """ Move axes of an array to new positions. @@ -369,11 +286,9 @@ def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[ :param destination: Destination positions for each of the original axes. These must also be unique. :return: Array with moved axes. """ - return Tensor(moveaxis(a.data, source, destination)) -@functools.singledispatch -@_tensor_guard +@tensor_dispatch() def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Compute the arithmetic mean along the specified axis. @@ -383,12 +298,10 @@ def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims :param keepdims: Destination positions for each of the original axes. These must also be unique. 
:return: Array with moved axes. """ - return Tensor(mean(a.data, axis, keepdims)) -@functools.singledispatch -@_tensor_guard -def round(a: Tensor, decimals=0) -> Tensor: +@tensor_dispatch() +def round(a: Tensor, decimals: int = 0) -> Tensor: """ Evenly round to the given number of decimals. @@ -397,47 +310,6 @@ def round(a: Tensor, decimals=0) -> Tensor: it specifies the number of positions to the left of the decimal point. :return: An array of the same type as a, containing the rounded values. """ - return Tensor(round(a.data, decimals)) - - -@functools.singledispatch -@_tensor_guard -def _binary_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: - """ - Applies a binary operation with disable warnings. - - :param a: The first tensor. - :param b: The second tensor. - :param operator_fn: The binary operation function. - :return: The result of the binary operation. - """ - return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) - - -@functools.singledispatch -@_tensor_guard -def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: - """ - Applies a binary reverse operation with disable warnings. - - :param a: The first tensor. - :param b: The second tensor. - :param operator_fn: The binary operation function. - :return: The result of the binary operation. - """ - return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) - - -def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs): - """ - Dispatches the function to the type of the wrapped data of the first element in tensor_list. - - :param fn: A function wrapped by `functools.singledispatch`. - :param tensor_list: List of Tensors. - :return: The result value of the function call. - """ - unwrapped_list = [i.data for i in tensor_list] - return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs) def _initialize_backends(): diff --git a/nncf/experimental/tensor/numpy_functions.py b/nncf/experimental/tensor/numpy_functions.py index 7899c2807e4..6be4f82f3ae 100644 --- a/nncf/experimental/tensor/numpy_functions.py +++ b/nncf/experimental/tensor/numpy_functions.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -188,21 +188,3 @@ def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None
 @_register_numpy_types(fns.round)
 def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray:
     return np.round(a, decimals=decimals)
-
-
-@_register_numpy_types(fns._binary_op_nowarn)
-def _(
-    a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable
-) -> Union[np.ndarray, np.generic]:
-    # Run operator with disabled warning
-    with np.errstate(invalid="ignore", divide="ignore"):
-        return operator_fn(a, b)
-
-
-@_register_numpy_types(fns._binary_reverse_op_nowarn)
-def _(
-    a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable
-) -> Union[np.ndarray, np.generic]:
-    # Run operator with disabled warning
-    with np.errstate(invalid="ignore", divide="ignore"):
-        return operator_fn(b, a)
diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py
index 7ec2646ecbc..67db261f8a7 100644
--- a/nncf/experimental/tensor/tensor.py
+++ b/nncf/experimental/tensor/tensor.py
@@ -10,7 +10,6 @@
 # limitations under the License.
 from __future__ import annotations
 
-import operator
 from typing import Any, Optional, Tuple, TypeVar, Union
 
 from nncf.experimental.tensor.enums import TensorDataType
@@ -24,7 +23,7 @@ class Tensor:
     An interface to framework specific tensors for common NNCF algorithms.
     """
 
-    def __init__(self, data: Optional[TTensor]):
+    def __init__(self, data: TTensor):
         self._data = data.data if isinstance(data, Tensor) else data
 
     @property
@@ -86,16 +85,16 @@ def __pow__(self, other: Union[Tensor, float]) -> Tensor:
         return Tensor(self.data ** unwrap_tensor_data(other))
 
     def __truediv__(self, other: Union[Tensor, float]) -> Tensor:
-        return _call_function("_binary_op_nowarn", self, other, operator.truediv)
+        return Tensor(self.data / unwrap_tensor_data(other))
 
     def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor:
-        return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)
+        return Tensor(other / self.data)
 
    def __floordiv__(self, other: Union[Tensor, float]) -> Tensor:
-        return _call_function("_binary_op_nowarn", self, other, operator.floordiv)
+        return Tensor(self.data // unwrap_tensor_data(other))
 
     def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor:
-        return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)
+        return Tensor(other // self.data)
 
     def __neg__(self) -> Tensor:
         return Tensor(-self.data)
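Since `__truediv__`/`__floordiv__` now call the backend operator directly instead of routing through `_binary_op_nowarn`, numpy's warnings propagate from `Tensor` division. A usage-level sketch of the resulting behavior, assuming numpy's default warning settings:

```python
import numpy as np

from nncf.experimental.tensor import Tensor

a = Tensor(np.array([1.0, 2.0]))
b = Tensor(np.array([0.0, 2.0]))

# Division by zero now emits numpy's usual RuntimeWarning; suppress it
# explicitly where needed, as tune_range() does below.
with np.errstate(invalid="ignore", divide="ignore"):
    res = a / b  # Tensor(array([inf, 1.]))
```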
diff --git a/nncf/experimental/tensor/torch_functions.py b/nncf/experimental/tensor/torch_functions.py
index fb13f3caa2d..b5e5a90aec8 100644
--- a/nncf/experimental/tensor/torch_functions.py
+++ b/nncf/experimental/tensor/torch_functions.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -54,22 +54,16 @@ def _(a: torch.Tensor) -> torch.Tensor:
 
 @fns.max.register(torch.Tensor)
 def _(
-    a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: Optional[bool] = False
+    a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: Optional[bool] = False
 ) -> torch.Tensor:
-    # Analog of numpy.max is torch.amax
-    if axis is None:
-        return torch.amax(a)
-    return torch.amax(a, dim=axis, keepdim=keepdim)
+    return torch.amax(a, dim=axis, keepdim=keepdims)
 
 
 @fns.min.register(torch.Tensor)
 def _(
-    a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: Optional[bool] = False
+    a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: Optional[bool] = False
 ) -> torch.Tensor:
-    # Analog of numpy.min is torch.amin
-    if axis is None:
-        return torch.amin(a)
-    return torch.amin(a, dim=axis, keepdim=keepdim)
+    return torch.amin(a, dim=axis, keepdim=keepdims)
 
 
 @fns.abs.register(torch.Tensor)
@@ -190,13 +184,3 @@ def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool
 @fns.round.register(torch.Tensor)
 def _(a: torch.Tensor, decimals=0) -> torch.Tensor:
     return torch.round(a, decimals=decimals)
-
-
-@fns._binary_op_nowarn.register(torch.Tensor)
-def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor:
-    return operator_fn(a, b)
-
-
-@fns._binary_reverse_op_nowarn.register(torch.Tensor)
-def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor:
-    return operator_fn(b, a)
diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py
index 38b56c97019..636e1f64919 100644
--- a/nncf/quantization/fake_quantize.py
+++ b/nncf/quantization/fake_quantize.py
@@ -106,8 +106,10 @@ def tune_range(
     fval = -left_border * s
     qval = fns.round(fval)
 
-    ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border)
-    rb = fns.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border)
+    # TODO(AlexanderDokuchaev): rework the function to avoid division by zero and inf
+    with np.errstate(invalid="ignore", divide="ignore"):
+        ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border)
+        rb = fns.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border)
 
     range_a = right_border - ra
     range_b = rb - left_border
diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py
index ca6104c7f62..4f5501a5313 100644
--- a/tests/shared/test_templates/template_test_nncf_tensor.py
+++ b/tests/shared/test_templates/template_test_nncf_tensor.py
@@ -371,7 +371,7 @@ def test_fn_count_nonzero(self, axis, ref):
         res = fns.count_nonzero(nncf_tensor, axis=axis)
 
         assert isinstance(res, Tensor)
-        assert fns.allclose(res.data, ref_tensor)
+        assert fns.allclose(res, ref_tensor)
         assert res.device == nncf_tensor.device
 
     def test_fn_zeros_like(self):
@@ -445,7 +445,7 @@ def test_fn_all(self, val, axis, ref):
         tensor = Tensor(self.to_tensor(val))
         res = fns.all(tensor, axis=axis)
         assert isinstance(res, Tensor)
-        assert fns.allclose(res.data, self.to_tensor(ref))
+        assert fns.allclose(res, self.to_tensor(ref))
         assert res.device == tensor.device
 
     @pytest.mark.parametrize(
@@ -462,7 +462,7 @@ def test_fn_any(self, val, axis, ref):
         res = fns.any(tensor, axis=axis)
 
         assert isinstance(res, Tensor)
-        assert fns.allclose(res.data, self.to_tensor(ref))
+        assert fns.allclose(res, self.to_tensor(ref))
         assert res.device == tensor.device
 
     def test_fn_where(self):
@@ -578,7 +578,7 @@ def test_fn_reshape(self):
 
     def test_not_implemented(self):
         with pytest.raises(NotImplementedError, match="is not implemented for"):
-            fns.device({}, [1, 2])
+            fns.device(Tensor(None))
 
     @pytest.mark.parametrize(
         "x, axis, ref",
@@ -628,7 +628,7 @@ def test_fn_stack(self, x, axis, ref):
 
         res = fns.stack(list_tensor, axis=axis)
         assert isinstance(res, Tensor)
-        assert fns.all(res.data == ref)
+        assert fns.all(res == ref)
         assert res.device == list_tensor[0].device
 
     def test_fn_moveaxis(self):
@@ -675,7 +675,7 @@ def test_fn_mean(self, x, axis, keepdims, ref):
 
         res = fns.mean(tensor, axis, keepdims)
         assert isinstance(res, Tensor)
-        assert fns.allclose(res.data, ref_tensor)
+        assert fns.allclose(res, ref_tensor)
         assert res.device == tensor.device
 
     @pytest.mark.parametrize(
@@ -693,7 +693,7 @@ def test_fn_round(self, val, decimals, ref):
 
         res = fns.round(tensor, decimals)
         assert isinstance(res, Tensor)
-        assert fns.allclose(res.data, ref_tensor)
+        assert fns.allclose(res, ref_tensor)
         assert res.device == tensor.device
 
     @pytest.mark.parametrize(
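As a closing illustration of the list-aware wrappers exercised by these tests, a hypothetical round trip; the import path for `fns` is assumed from the tests' usage, and the numpy backend implementations are those registered in numpy_functions.py:

```python
import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import functions as fns

a = Tensor(np.array([[1, 2], [3, 4]]))

rows = fns.unstack(a)       # TensorToList: [Tensor(array([1, 2])), Tensor(array([3, 4]))]
restored = fns.stack(rows)  # ListToTensor: Tensor(array([[1, 2], [3, 4]]))
assert fns.all(restored == a)
```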