Tensor dispatcher #2660

Draft
wants to merge 7 commits into base: develop
27 changes: 11 additions & 16 deletions nncf/experimental/tensor/README.md
@@ -108,45 +108,40 @@ tensor_a[0:2] # Tensor(array([[1],[2]]))
2. Add function to functions module

```python
-@functools.singledispatch
-def foo(a: TTensor, arg1: Type) -> TTensor:
+@tensor_dispatch
+def foo(a: Tensor, arg1: Type) -> Tensor:
    """
    __description__

-    :param a: The input tensor.
+    :param a: __description__
    :param arg1: __description__
    :return: __description__
    """
-    if isinstance(a, tensor.Tensor):
-        return tensor.Tensor(foo(a.data, axis))
-    return NotImplemented(f"Function `foo` is not implemented for {type(a)}")
```

-**NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches function by first element in the first argument.
+**NOTE** The type of the wrapper is selected from the function's type hints; the supported function signatures are:

```python
-@functools.singledispatch
-def foo(x: List[Tensor], axis: int = 0) -> Tensor:
-    if isinstance(x, List):
-        unwrapped_x = [i.data for i in x]
-        return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis))
-    raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}")
+def foo(a: Tensor, ...) -> Tensor:
+def foo(a: Tensor, ...) -> Any:
+def foo(a: Tensor, ...) -> List[Tensor]:
+def foo(a: List[Tensor], ...) -> Tensor:
```

-3. Add backend specific implementation of method to correcponding module:
+3. Add backend specific implementation of method to corresponding module:

- `functions/numpy_*.py`

```python
-@_register_numpy_types(fns.foo)
+@fns.foo.register
def _(a: TType, arg1: Type) -> np.ndarray:
    return np.foo(a, arg1)
```

- `functions/torch_*.py`

```python
-@fns.foo.register(torch.Tensor)
+@fns.foo.register
def _(a: torch.Tensor, arg1: Type) -> torch.Tensor:
    return torch.foo(a, arg1)
```
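
Putting the documented steps together, the workflow looks roughly like the sketch below. This is an illustration rather than part of the change: `absmax` is a hypothetical function, and it assumes `Tensor` simply wraps the backend array passed to its constructor.

```python
import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.functions.dispatcher import tensor_dispatch


@tensor_dispatch
def absmax(a: Tensor) -> Tensor:
    """Returns the maximum of the absolute values of the input tensor."""


@absmax.register
def _(a: np.ndarray) -> np.ndarray:
    # Backend-specific implementation, registered for numpy arrays.
    return np.abs(a).max()


t = Tensor(np.array([-3.0, 1.0, 2.0]))
absmax(t)  # expected to return a Tensor wrapping 3.0
```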
162 changes: 131 additions & 31 deletions nncf/experimental/tensor/functions/dispatcher.py
@@ -8,50 +8,150 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
-from typing import List
-
-import numpy as np
+from functools import _find_impl
+from inspect import getfullargspec
+from inspect import isclass
+from inspect import isfunction
+from types import MappingProxyType
+from typing import List, get_type_hints

from nncf.experimental.tensor import Tensor


-def tensor_guard(func: callable):
-    """
-    A decorator that ensures that the first argument to the decorated function is a Tensor.
-
-    :param func: The function to be decorated.
-    """
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        if isinstance(args[0], Tensor):
-            return func(*args, **kwargs)
-        raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {type(args[0])}")
-
-    return wrapper
-
-
-def dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs):
-    """
-    Dispatches the function to the type of the wrapped data of the first element in tensor_list.
-
-    :param fn: A function wrapped by `functools.singledispatch`.
-    :param tensor_list: List of Tensors.
-    :return: The result value of the function call.
-    """
-    unwrapped_list = [i.data for i in tensor_list]
-    return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs)
-
-
-def register_numpy_types(singledispatch_fn):
-    """
-    Decorator to register function to singledispatch for numpy classes.
-
-    :param singledispatch_fn: singledispatch function.
-    """
-
-    def inner(func):
-        singledispatch_fn.register(np.ndarray)(func)
-        singledispatch_fn.register(np.generic)(func)
-        return func
-
-    return inner
+def _get_target_types(type_alias):
+    if isclass(type_alias):
+        return [type_alias]
+    ret = []
+    for t in type_alias.__args__:
+        ret.extend(_get_target_types(t))
+    return ret
+
+
+def tensor_dispatch(func):
+    """
+    This decorator creates a registry of functions for different types and provides a wrapper
+    that calls the appropriate function based on the type of the first argument.
+    It's particularly designed to handle Tensor inputs and outputs effectively.
+
+    :param func: The function to be decorated.
+    :return: The decorated function with type-based dispatching functionality.
+    """
+    registry = {}
+
+    def dispatch(cls):
+        """
+        Retrieves the registered function for a given type.
+
+        :param cls: The type to retrieve the function for.
+        :return: The registered function for the given type.
+        """
+        try:
+            return registry[cls]
+        except KeyError:
+            return _find_impl(cls, registry)
+
+    def register(rfunc):
+        """Registers a function for a specific type or types.
+
+        :param rfunc: The function to register.
+        :return: The registered function.
+        """
+        assert isfunction(rfunc), "Register object should be a function."
+        assert getfullargspec(func)[0] == getfullargspec(rfunc)[0], "Differ names of arguments of function"
+
+        target_type_hint = get_type_hints(rfunc).get(getfullargspec(rfunc)[0][0])
+        assert target_type_hint is not None, "No type hint for first argument of function"
+
+        types_to_registry = set(_get_target_types(target_type_hint))
+        for t in types_to_registry:
+            assert t not in registry, f"{t} already registered for function"
+            registry[t] = rfunc
+        return rfunc
+
+    def wrapper_tensor_to_tensor(*args, **kw):
+        """
+        Wrapper for functions that take and return a Tensor.
+        This wrapper unwraps Tensor arguments and wraps the returned value in a Tensor if necessary.
+        """
+        is_wrapped = any(isinstance(x, Tensor) for x in args)
+        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+        kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()}
+        ret = dispatch(args[0].__class__)(*args, **kw)
+        return Tensor(ret) if is_wrapped else ret
+
+    def wrapper_tensor_to_any(*args, **kw):
+        """
+        Wrapper for functions that take a Tensor and return any type.
+        This wrapper unwraps Tensor arguments but doesn't specifically wrap the returned value.
+        """
+        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+        kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()}
+        return dispatch(args[0].__class__)(*args, **kw)
+
+    def wrapper_tensor_to_list(*args, **kw):
+        """
+        Wrapper for functions that take a Tensor and return a list.
+        This wrapper unwraps Tensor arguments and wraps the list elements as Tensors if necessary.
+        """
+        is_wrapped = any(isinstance(x, Tensor) for x in args)
+        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+        kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()}
+        ret = dispatch(args[0].__class__)(*args, **kw)
+        if is_wrapped:
+            return [Tensor(x) for x in ret]
+        return ret
+
+    def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw):
+        """
+        Wrapper for functions that take a list of Tensors and return a Tensor.
+        This wrapper handles lists containing Tensors appropriately.
+        """
+        if any(isinstance(x, Tensor) for x in list_of_tensors):
+            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
+            kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()}
+            list_of_tensors = [x.data if isinstance(x, Tensor) else x for x in list_of_tensors]
+            return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw))
+        return dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw)
+
+    def raise_not_implemented(*args, **kw):
+        """
+        Raises a NotImplementedError for types that are not registered.
+        """
+        if isinstance(args[0], list):
+            arg_type = type(args[0][0].data) if isinstance(args[0][0], Tensor) else type(args[0][0])
+        else:
+            arg_type = type(args[0].data) if isinstance(args[0], Tensor) else type(args[0])
+
+        raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}")
+
+    # Select wrapper by signature of function
+    type_hints = get_type_hints(func)
+    first_type_hint = type_hints.get(getfullargspec(func)[0][0])
+    return_type_hint = type_hints.get("return")
+    wrapper = None
+    if first_type_hint is Tensor:
+        if return_type_hint is Tensor:
+            wrapper = wrapper_tensor_to_tensor
+        elif not isclass(return_type_hint) and return_type_hint._name == "List":
+            wrapper = wrapper_tensor_to_list
+        else:
+            wrapper = wrapper_tensor_to_any
+    elif not isclass(first_type_hint) and first_type_hint._name == "List" and return_type_hint is Tensor:
+        wrapper = wrapper_list_to_tensor

+    assert wrapper is not None, (
+        "Not supported signature of dispatch function, supported:\n"
+        " def foo(a: Tensor, ...) -> Tensor\n"
+        " def foo(a: Tensor, ...) -> Any\n"
+        " def foo(a: Tensor, ...) -> List[Tensor]\n"
+        " def foo(a: List[Tensor], ...) -> Tensor\n"
+    )
+
+    registry[object] = raise_not_implemented
+    wrapper.register = register
+    wrapper.dispatch = dispatch
+    wrapper.registry = MappingProxyType(registry)
+
+    return wrapper
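
As an illustration of the wrapper selection above (not part of the diff), the sketch below exercises the `List[Tensor] -> Tensor` signature with a hypothetical `stack_mean` function; it assumes a numpy-backed `Tensor`.

```python
from typing import List

import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.functions.dispatcher import tensor_dispatch


@tensor_dispatch
def stack_mean(tensors: List[Tensor]) -> Tensor:
    """Stacks the inputs and returns their element-wise mean."""


@stack_mean.register
def _(tensors: List[np.ndarray]) -> np.ndarray:
    # _get_target_types(List[np.ndarray]) resolves to np.ndarray, so this lands in registry[np.ndarray].
    return np.mean(np.stack(tensors), axis=0)


# The stub's hints (List[Tensor] -> Tensor) select wrapper_list_to_tensor, which unwraps the list,
# dispatches on the type of the first element's data, and wraps the result back into a Tensor.
stack_mean([Tensor(np.ones(2)), Tensor(np.zeros(2))])  # expected: Tensor wrapping array([0.5, 0.5])
```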
19 changes: 5 additions & 14 deletions nncf/experimental/tensor/functions/linalg.py
@@ -9,15 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import functools
from typing import Optional, Tuple, Union

from nncf.experimental.tensor import Tensor
-from nncf.experimental.tensor.functions.dispatcher import tensor_guard
+from nncf.experimental.tensor.functions.dispatcher import tensor_dispatch


-@functools.singledispatch
-@tensor_guard
+@tensor_dispatch
def norm(
a: Tensor,
ord: Optional[Union[str, float, int]] = None,
@@ -61,11 +59,9 @@ def norm(
as dimensions with size one. Default: False.
:return: Norm of the matrix or vector.
"""
-    return Tensor(norm(a.data, ord, axis, keepdims))


-@functools.singledispatch
-@tensor_guard
+@tensor_dispatch
def cholesky(a: Tensor, upper: bool = False) -> Tensor:
"""
Computes the Cholesky decomposition of a complex Hermitian or real symmetric
@@ -81,11 +77,9 @@ def cholesky(a: Tensor, upper: bool = False) -> Tensor:
Default is lower-triangular.
:return: Upper- or lower-triangular Cholesky factor of `a`.
"""
-    return Tensor(cholesky(a.data, upper))


-@functools.singledispatch
-@tensor_guard
+@tensor_dispatch
def cholesky_inverse(a: Tensor, upper: bool = False) -> Tensor:
"""
Computes the inverse of a complex Hermitian or real symmetric positive-definite matrix given
@@ -97,11 +91,9 @@ def cholesky_inverse(a: Tensor, upper: bool = False) -> Tensor:
upper triangular. Default: False.
:return: The inverse of matrix given its Cholesky decomposition.
"""
-    return Tensor(cholesky_inverse(a.data, upper))


-@functools.singledispatch
-@tensor_guard
+@tensor_dispatch
def inv(a: Tensor) -> Tensor:
"""
Computes the inverse of a matrix.
@@ -110,4 +102,3 @@ def inv(a: Tensor) -> Tensor:
consisting of invertible matrices.
:return: The inverse of the input tensor.
"""
-    return Tensor(inv(a.data))
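
For context, a small usage sketch of the rewritten `norm` (not part of the diff); it assumes numpy implementations for the `linalg` functions are registered by the backend modules (`functions/numpy_*.py`) when the package is imported:

```python
import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.functions import linalg

a = Tensor(np.array([[3.0, 4.0], [0.0, 0.0]]))

# The docstring-only body of `norm` above is never executed: the wrapper dispatches
# on the type of `a.data` to the registered numpy implementation.
linalg.norm(a)         # expected: Tensor wrapping 5.0 (Frobenius norm)
linalg.norm(a, ord=1)  # expected: Tensor wrapping 4.0 (max column sum)
```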