diff --git a/atools/_memoize_decorator.py b/atools/_memoize_decorator.py index ac7e47b..ffe8028 100644 --- a/atools/_memoize_decorator.py +++ b/atools/_memoize_decorator.py @@ -123,11 +123,14 @@ def table_name(self) -> str: def default_keygen(self, *args, **kwargs) -> Tuple[Hashable, ...]: """Returns all params (args, kwargs, and missing default kwargs) for function as kwargs.""" + + return tuple(self.get_args_as_kwargs(*args, **kwargs).values()) + + def get_args_as_kwargs(self, *args, **kwargs) -> Mapping[str, Any]: args_as_kwargs = {} for k, v in zip(self.default_kwargs, args): args_as_kwargs[k] = v - - return tuple(ChainMap(args_as_kwargs, kwargs, self.default_kwargs).values()) + return ChainMap(args_as_kwargs, kwargs, self.default_kwargs) def get_memo(self, key: Union[int, str]) -> _Memo: try: @@ -221,7 +224,7 @@ async def get_key(self, *args, **kwargs) -> Union[int, str]: if self.keygen is None: key = self.default_keygen(*args, **kwargs) else: - key = self.keygen(*args, **kwargs) + key = self.keygen(**self.get_args_as_kwargs(*args, **kwargs)) if isinstance(key, tuple): key = list(key) else: @@ -277,7 +280,7 @@ def get_key(self, *args, **kwargs) -> Union[int, str]: if self.keygen is None: key = self.default_keygen(*args, **kwargs) else: - key = self.keygen(*args, **kwargs) + key = self.keygen(**self.get_args_as_kwargs(*args, **kwargs)) key = self.get_hashed_key(key) diff --git a/setup.py b/setup.py index ee2987c..4e6d25d 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='atools', - version='0.10.4', + version='0.10.5', packages=find_packages(), python_requires='>=3.6', url='https://github.com/cevans87/atools', diff --git a/test/test_memoize_decorator.py b/test/test_memoize_decorator.py index d072da2..d4cba68 100644 --- a/test/test_memoize_decorator.py +++ b/test/test_memoize_decorator.py @@ -919,3 +919,11 @@ def _foo() -> None: del foo # FIXME there's a race condition here. Garbage collector may not have cleaned up foo yet assert r() is None + + +def test_keygen_works_with_default_kwargs() -> None: + @memoize(keygen=lambda bar: bar) + def foo(bar=1) -> None: + ... + + foo()