From 7dc54b146d3c006018853cc54ed22f455d41d50b Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Mon, 30 May 2022 10:37:55 +0300 Subject: [PATCH] use ParamSpec to annotate memoization decorators --- pytools/__init__.py | 70 ++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 6a3ed65a..c2ced5f0 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -33,7 +33,7 @@ import sys import logging from typing import ( - cast, Any, Callable, Dict, Hashable, Iterable, + Any, Callable, Dict, Generic, Hashable, Iterable, List, Optional, Set, Tuple, TypeVar, Union) import builtins @@ -41,9 +41,9 @@ from sys import intern try: - from typing import SupportsIndex + from typing import SupportsIndex, ParamSpec, Concatenate except ImportError: - from typing_extensions import SupportsIndex + from typing_extensions import SupportsIndex, ParamSpec, Concatenate # These are deprecated and will go away in 2022. all = builtins.all @@ -183,18 +183,25 @@ ------------------- .. class:: T +.. class:: R - Any type. + Generic unbound invariant :class:`typing.TypeVar`. .. class:: F - Any callable. + Generic invariant :class:`typing.TypeVar` bound to a :class:`typing.Callable`. + +.. class:: P + + Generic unbound invariant :class:`typing.ParamSpec`. """ # {{{ type variables T = TypeVar("T") +R = TypeVar("R") F = TypeVar("F", bound=Callable[..., Any]) +P = ParamSpec("P") # }}} @@ -727,7 +734,9 @@ class _HasKwargs: pass -def memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) -> F: +def memoize_on_first_arg( + function: Callable[Concatenate[T, P], R], *, + cache_dict_name: Optional[str] = None) -> Callable[Concatenate[T, P], R]: """Like :func:`memoize_method`, but for functions that take the object in which do memoization information is stored as first argument. @@ -739,12 +748,13 @@ def memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) -> f"_memoize_dic_{function.__module__}{function.__name__}" ) - def wrapper(obj, *args, **kwargs): + def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: if kwargs: key = (_HasKwargs, frozenset(kwargs.items())) + args else: key = args + assert cache_dict_name is not None try: return getattr(obj, cache_dict_name)[key] except AttributeError: @@ -770,10 +780,12 @@ def clear_cache(obj): # into the function's dict is moderately sketchy. new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] - return cast(F, new_wrapper) + return new_wrapper -def memoize_method(method: F) -> F: +def memoize_method( + method: Callable[Concatenate[T, P], R] + ) -> Callable[Concatenate[T, P], R]: """Supports cache deletion via ``method_name.clear_cache(self)``. .. versionchanged:: 2021.2 @@ -783,10 +795,10 @@ def memoize_method(method: F) -> F: """ return memoize_on_first_arg(method, - intern(f"_memoize_dic_{method.__name__}")) + cache_dict_name=intern(f"_memoize_dic_{method.__name__}")) -class keyed_memoize_on_first_arg: # noqa: N801 +class keyed_memoize_on_first_arg(Generic[T, P, R]): # noqa: N801 """Like :func:`memoize_method`, but for functions that take the object in which memoization information is stored as first argument. @@ -800,23 +812,29 @@ class keyed_memoize_on_first_arg: # noqa: N801 .. versionadded :: 2020.3 """ - def __init__(self, key, cache_dict_name=None): + def __init__(self, + key: Callable[P, Hashable], *, + cache_dict_name: Optional[str] = None) -> None: self.key = key self.cache_dict_name = cache_dict_name - def _default_cache_dict_name(self, function): + def _default_cache_dict_name(self, + function: Callable[Concatenate[T, P], R]) -> str: return intern(f"_memoize_dic_{function.__module__}{function.__name__}") - def __call__(self, function): + def __call__( + self, function: Callable[Concatenate[T, P], R] + ) -> Callable[Concatenate[T, P], R]: cache_dict_name = self.cache_dict_name key = self.key if cache_dict_name is None: cache_dict_name = self._default_cache_dict_name(function) - def wrapper(obj, *args, **kwargs): + def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R: cache_key = key(*args, **kwargs) + assert cache_dict_name is not None try: return getattr(obj, cache_dict_name)[cache_key] except AttributeError: @@ -833,7 +851,7 @@ def clear_cache(obj): from functools import update_wrapper new_wrapper = update_wrapper(wrapper, function) - new_wrapper.clear_cache = clear_cache + new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined] return new_wrapper @@ -858,7 +876,7 @@ def _default_cache_dict_name(self, function): return intern(f"_memoize_dic_{function.__name__}") -class memoize_in: # noqa +class memoize_in(Generic[P, R]): # noqa """Adds a cache to the function it decorates. The cache is attached to *container* and must be uniquely specified by *identifier* (i.e. all functions using the same *container* and *identifier* will be using @@ -892,9 +910,11 @@ def __init__(self, container: Any, identifier: Hashable) -> None: self.cache_dict = memoize_in_dict.setdefault(identifier, {}) - def __call__(self, inner: F) -> F: + def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @wraps(inner) - def new_inner(*args): + def new_inner(*args: P.args, **kwargs: P.kwargs) -> R: + assert not kwargs + try: return self.cache_dict[args] except KeyError: @@ -908,7 +928,7 @@ def new_inner(*args): return new_inner # type: ignore[return-value] -class keyed_memoize_in: # noqa +class keyed_memoize_in(Generic[P, R]): # noqa """Like :class:`memoize_in`, but additionally uses a function *key* to compute the key under which the function result is memoized. @@ -918,7 +938,9 @@ class keyed_memoize_in: # noqa .. versionadded :: 2021.2.1 """ - def __init__(self, container, identifier, key): + def __init__(self, + container: Any, identifier: Hashable, + key: Callable[P, Hashable]) -> None: try: memoize_in_dict = container._pytools_keyed_memoize_in_dict except AttributeError: @@ -929,10 +951,12 @@ def __init__(self, container, identifier, key): self.cache_dict = memoize_in_dict.setdefault(identifier, {}) self.key = key - def __call__(self, inner): + def __call__(self, inner: Callable[P, R]) -> Callable[P, R]: @wraps(inner) - def new_inner(*args): + def new_inner(*args: P.args, **kwargs: P.kwargs) -> R: + assert not kwargs key = self.key(*args) + try: return self.cache_dict[key] except KeyError: