Skip to content

Commit

Permalink
use ParamSpec to annotate memoization decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Jun 15, 2022
1 parent 0a06f30 commit 7dc54b1
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions pytools/__init__.py
Expand Up @@ -33,17 +33,17 @@
import sys
import logging
from typing import (
cast, Any, Callable, Dict, Hashable, Iterable,
Any, Callable, Dict, Generic, Hashable, Iterable,
List, Optional, Set, Tuple, TypeVar, Union)
import builtins

import math
from sys import intern

try:
from typing import SupportsIndex
from typing import SupportsIndex, ParamSpec, Concatenate
except ImportError:
from typing_extensions import SupportsIndex
from typing_extensions import SupportsIndex, ParamSpec, Concatenate

# These are deprecated and will go away in 2022.
all = builtins.all
Expand Down Expand Up @@ -183,18 +183,25 @@
-------------------
.. class:: T
.. class:: R
Any type.
Generic unbound invariant :class:`typing.TypeVar`.
.. class:: F
Any callable.
Generic invariant :class:`typing.TypeVar` bound to a :class:`typing.Callable`.
.. class:: P
Generic unbound invariant :class:`typing.ParamSpec`.
"""

# {{{ type variables

T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")

# }}}

Expand Down Expand Up @@ -727,7 +734,9 @@ class _HasKwargs:
pass


def memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) -> F:
def memoize_on_first_arg(
function: Callable[Concatenate[T, P], R], *,
cache_dict_name: Optional[str] = None) -> Callable[Concatenate[T, P], R]:
"""Like :func:`memoize_method`, but for functions that take the object
in which do memoization information is stored as first argument.
Expand All @@ -739,12 +748,13 @@ def memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) ->
f"_memoize_dic_{function.__module__}{function.__name__}"
)

def wrapper(obj, *args, **kwargs):
def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
if kwargs:
key = (_HasKwargs, frozenset(kwargs.items())) + args
else:
key = args

assert cache_dict_name is not None
try:
return getattr(obj, cache_dict_name)[key]
except AttributeError:
Expand All @@ -770,10 +780,12 @@ def clear_cache(obj):
# into the function's dict is moderately sketchy.
new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined]

return cast(F, new_wrapper)
return new_wrapper


def memoize_method(method: F) -> F:
def memoize_method(
method: Callable[Concatenate[T, P], R]
) -> Callable[Concatenate[T, P], R]:
"""Supports cache deletion via ``method_name.clear_cache(self)``.
.. versionchanged:: 2021.2
Expand All @@ -783,10 +795,10 @@ def memoize_method(method: F) -> F:
"""

return memoize_on_first_arg(method,
intern(f"_memoize_dic_{method.__name__}"))
cache_dict_name=intern(f"_memoize_dic_{method.__name__}"))


class keyed_memoize_on_first_arg: # noqa: N801
class keyed_memoize_on_first_arg(Generic[T, P, R]): # noqa: N801
"""Like :func:`memoize_method`, but for functions that take the object
in which memoization information is stored as first argument.
Expand All @@ -800,23 +812,29 @@ class keyed_memoize_on_first_arg: # noqa: N801
.. versionadded :: 2020.3
"""

def __init__(self, key, cache_dict_name=None):
def __init__(self,
key: Callable[P, Hashable], *,
cache_dict_name: Optional[str] = None) -> None:
self.key = key
self.cache_dict_name = cache_dict_name

def _default_cache_dict_name(self, function):
def _default_cache_dict_name(self,
function: Callable[Concatenate[T, P], R]) -> str:
return intern(f"_memoize_dic_{function.__module__}{function.__name__}")

def __call__(self, function):
def __call__(
self, function: Callable[Concatenate[T, P], R]
) -> Callable[Concatenate[T, P], R]:
cache_dict_name = self.cache_dict_name
key = self.key

if cache_dict_name is None:
cache_dict_name = self._default_cache_dict_name(function)

def wrapper(obj, *args, **kwargs):
def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
cache_key = key(*args, **kwargs)

assert cache_dict_name is not None
try:
return getattr(obj, cache_dict_name)[cache_key]
except AttributeError:
Expand All @@ -833,7 +851,7 @@ def clear_cache(obj):

from functools import update_wrapper
new_wrapper = update_wrapper(wrapper, function)
new_wrapper.clear_cache = clear_cache
new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined]

return new_wrapper

Expand All @@ -858,7 +876,7 @@ def _default_cache_dict_name(self, function):
return intern(f"_memoize_dic_{function.__name__}")


class memoize_in: # noqa
class memoize_in(Generic[P, R]): # noqa
"""Adds a cache to the function it decorates. The cache is attached
to *container* and must be uniquely specified by *identifier* (i.e.
all functions using the same *container* and *identifier* will be using
Expand Down Expand Up @@ -892,9 +910,11 @@ def __init__(self, container: Any, identifier: Hashable) -> None:

self.cache_dict = memoize_in_dict.setdefault(identifier, {})

def __call__(self, inner: F) -> F:
def __call__(self, inner: Callable[P, R]) -> Callable[P, R]:
@wraps(inner)
def new_inner(*args):
def new_inner(*args: P.args, **kwargs: P.kwargs) -> R:
assert not kwargs

try:
return self.cache_dict[args]
except KeyError:
Expand All @@ -908,7 +928,7 @@ def new_inner(*args):
return new_inner # type: ignore[return-value]


class keyed_memoize_in: # noqa
class keyed_memoize_in(Generic[P, R]): # noqa
"""Like :class:`memoize_in`, but additionally uses a function *key* to
compute the key under which the function result is memoized.
Expand All @@ -918,7 +938,9 @@ class keyed_memoize_in: # noqa
.. versionadded :: 2021.2.1
"""

def __init__(self, container, identifier, key):
def __init__(self,
container: Any, identifier: Hashable,
key: Callable[P, Hashable]) -> None:
try:
memoize_in_dict = container._pytools_keyed_memoize_in_dict
except AttributeError:
Expand All @@ -929,10 +951,12 @@ def __init__(self, container, identifier, key):
self.cache_dict = memoize_in_dict.setdefault(identifier, {})
self.key = key

def __call__(self, inner):
def __call__(self, inner: Callable[P, R]) -> Callable[P, R]:
@wraps(inner)
def new_inner(*args):
def new_inner(*args: P.args, **kwargs: P.kwargs) -> R:
assert not kwargs
key = self.key(*args)

try:
return self.cache_dict[key]
except KeyError:
Expand Down

0 comments on commit 7dc54b1

Please sign in to comment.