From 84813d73106ad59ab4def34e7b337f50df8bddb0 Mon Sep 17 00:00:00 2001 From: cevans87 Date: Thu, 12 Dec 2019 12:11:19 -0800 Subject: [PATCH] Adds memoize.reset_call --- atools/_memoize_decorator.py | 156 ++++++++++++++++++++------------- setup.py | 2 +- test/test_memoize_decorator.py | 83 +++++++++++++++++- 3 files changed, 175 insertions(+), 66 deletions(-) diff --git a/atools/_memoize_decorator.py b/atools/_memoize_decorator.py index 0af482c..9aeaa97 100644 --- a/atools/_memoize_decorator.py +++ b/atools/_memoize_decorator.py @@ -14,7 +14,7 @@ Decoratee = Union[Callable, Type] -GetKey = Callable[..., Tuple[Any]] +Keygen = Callable[..., Tuple[Any]] _default_db_path = Path.home() / '.memoize' @@ -33,7 +33,7 @@ class _MemoReturnState: @dataclass(frozen=True) class _MemoBase: fn: Callable - expire_time: Optional[float] + t0: Optional[float] memo_return_state: _MemoReturnState = field(init=False, default_factory=_MemoReturnState) @@ -54,9 +54,9 @@ class _SyncMemo(_MemoBase): class _MemoizeBase: db: Optional[Connection] default_kwargs: Mapping[str, Any] - fn: Callable - get_key: Optional[GetKey] duration: Optional[timedelta] + fn: Callable + keygen: Optional[Keygen] size: Optional[int] expire_order: OrderedDict = field(init=False, default_factory=OrderedDict, hash=False) @@ -67,13 +67,16 @@ def __post_init__(self) -> None: self.db.execute(dedent(f''' CREATE TABLE IF NOT EXISTS `{self.table_name}` ( k TEXT PRIMARY KEY, + t0 FLOAT, t FLOAT, - e FLOAT, v TEXT NOT NULL ) ''')) if self.duration: - self.db.execute(f"DELETE FROM `{self.table_name}` WHERE e < {time()}") + self.db.execute(dedent(f''' + DELETE FROM `{self.table_name}` + WHERE t0 < {time() - self.duration.total_seconds()} + ''')) if self.size: res = self.db.execute( @@ -82,23 +85,16 @@ def __post_init__(self) -> None: if res: (min_t,) = res[-1] self.db.execute(f"DELETE FROM `{self.table_name}` WHERE t < {min_t}") - for k, t, v in self.db.execute( - f"SELECT k, t, v FROM `{self.table_name}` ORDER BY t" + for k, t0, t, v in self.db.execute( + f"SELECT k, t0, t, v FROM `{self.table_name}` ORDER BY t" ).fetchall(): - memo = self.make_memo( - fn=self.fn, - expire_time=( - t + self.duration.total_seconds() - if self.duration is not None - else None - ) - ) + memo = self.make_memo(fn=self.fn, t0=t0) memo.memo_return_state.called = True (memo.memo_return_state.value,) = eval(v, __builtins__) self.memos[k] = memo if self.duration: - for k, e in self.db.execute( - f"SELECT k, e FROM `{self.table_name}` ORDER BY e" + for k, t0 in self.db.execute( + f"SELECT k, t0 FROM `{self.table_name}` ORDER BY t0" ).fetchall(): self.expire_order[k] = ... self.db.commit() @@ -115,7 +111,7 @@ def table_name(self) -> str: f':{self.fn.__code__.co_firstlineno}' ) - def get_default_key(self, *args, **kwargs) -> Tuple[Hashable, ...]: + def default_keygen(self, *args, **kwargs) -> Tuple[Hashable, ...]: """Returns all params (args, kwargs, and missing default kwargs) for function as kwargs.""" args_as_kwargs = {} for k, v in zip(self.default_kwargs, args): @@ -123,21 +119,21 @@ def get_default_key(self, *args, **kwargs) -> Tuple[Hashable, ...]: return tuple(ChainMap(args_as_kwargs, kwargs, self.default_kwargs).values()) - def get_memo(self, key: int) -> _Memo: + def get_memo(self, key: Union[int, str]) -> _Memo: try: memo = self.memos[key] = self.memos.pop(key) - if self.duration is not None and memo.expire_time < time(): + if self.duration is not None and memo.t0 < time() - self.duration.total_seconds(): self.expire_order.pop(key) raise ValueError('value expired') except (KeyError, ValueError): if self.duration is None: - expire_time = None + t0 = None else: - expire_time = time() + self.duration.total_seconds() + t0 = time() # The value has no significance. We're using the dict entirely for ordering keys. self.expire_order[key] = ... - memo = self.memos[key] = self.make_memo(self.fn, expire_time=expire_time) + memo = self.memos[key] = self.make_memo(self.fn, t0=t0) return memo @@ -146,7 +142,10 @@ def expire_one_memo(self) -> None: if ( (self.expire_order is not None) and (len(self.expire_order) > 0) and - (self.memos[next(iter(self.expire_order))].expire_time < time()) + ( + self.memos[next(iter(self.expire_order))].t0 < + time() - self.duration.total_seconds() + ) ): (k, _) = self.expire_order.popitem(last=False) self.memos.pop(k) @@ -156,32 +155,40 @@ def expire_one_memo(self) -> None: self.db.execute(f"DELETE FROM `{self.table_name}` WHERE k = '{k}'") self.db.commit() - def finalize_memo(self, memo: _Memo, key: int) -> Any: + def finalize_memo(self, memo: _Memo, key: Union[int, str]) -> Any: if memo.memo_return_state.raised: raise memo.memo_return_state.value else: - if self.db is not None: + if (self.db is not None) and (self.memos[key] is memo): value = str((memo.memo_return_state.value,)) assert (memo.memo_return_state.value,) == eval(value, __builtins__) self.db.execute( dedent(f''' INSERT OR REPLACE INTO `{self.table_name}` - (k, t, e, v) + (k, t0, t, v) VALUES (?, ?, ?, ?) '''), ( key, + memo.t0, time(), - memo.expire_time, value ) ) self.db.commit() return memo.memo_return_state.value + def get_hashed_key(self, key: Tuple[Hashable]) -> Union[int, str]: + if self.db is None: + key = hash(key) + else: + key = sha256(str(key).encode()).hexdigest() + + return key + @staticmethod - def make_memo(fn, expire_time: Optional[float]) -> _Memo: # pragma: no cover + def make_memo(fn, t0: Optional[float]) -> _Memo: # pragma: no cover raise NotImplemented def reset(self) -> None: @@ -191,23 +198,36 @@ def reset(self) -> None: self.db.execute(f"DELETE FROM `{self.table_name}`") self.db.commit() + def reset_key(self, key: Union[int, str]) -> None: + if key in self.memos: + self.memos.pop(key) + if self.duration is not None: + self.expire_order.pop(key) + if self.db is not None: + self.db.execute(f"DELETE FROM `{self.table_name}` WHERE k == '{key}'") + @dataclass(frozen=True) class _AsyncMemoize(_MemoizeBase): + async def get_key(self, *args, **kwargs) -> Union[int, str]: + if self.keygen is None: + key = self.default_keygen(*args, **kwargs) + else: + key = self.keygen(*args, **kwargs) + key = list(key) + for i, v in enumerate(key): + if inspect.isawaitable(v): + key[i] = await v + key = tuple(key) + + key = self.get_hashed_key(key) + + return key + def get_decorator(self) -> Callable: async def decorator(*args, **kwargs) -> Any: - if self.get_key is None: - key = self.get_default_key(*args, **kwargs) - else: - key = self.get_key(*args, **kwargs) - key = list(key) - for i, v in enumerate(key): - if inspect.isawaitable(v): - key[i] = await v - key = tuple(key) - - key = sha256(str(key).encode()).hexdigest() + key = await self.get_key(*args, **kwargs) memo: _AsyncMemo = self.get_memo(key) @@ -229,8 +249,12 @@ async def decorator(*args, **kwargs) -> Any: return decorator @staticmethod - def make_memo(fn, expire_time: Optional[float]) -> _AsyncMemo: - return _AsyncMemo(fn=fn, expire_time=expire_time) + def make_memo(fn, t0: Optional[float]) -> _AsyncMemo: + return _AsyncMemo(fn=fn, t0=t0) + + async def reset_call(self, *args, **kwargs) -> None: + key = await self.get_key(*args, **kwargs) + self.reset_key(key) @dataclass(frozen=True) @@ -238,14 +262,19 @@ class _SyncMemoize(_MemoizeBase): _sync_lock: SyncLock = field(init=False, default_factory=lambda: SyncLock()) + def get_key(self, *args, **kwargs) -> Union[int, str]: + if self.keygen is None: + key = self.default_keygen(*args, **kwargs) + else: + key = self.keygen(*args, **kwargs) + + key = self.get_hashed_key(key) + + return key + def get_decorator(self) -> Callable: def decorator(*args, **kwargs): - if self.get_key is None: - key = self.get_default_key(*args, **kwargs) - else: - key = self.get_key(*args, **kwargs) - - key = sha256(str(key).encode()).hexdigest() + key = self.get_key(*args, **kwargs) with self._sync_lock: memo: _SyncMemo = self.get_memo(key) @@ -268,13 +297,18 @@ def decorator(*args, **kwargs): return decorator @staticmethod - def make_memo(fn, expire_time: Optional[float]) -> _SyncMemo: - return _SyncMemo(fn=fn, expire_time=expire_time) + def make_memo(fn, t0: Optional[float]) -> _SyncMemo: + return _SyncMemo(fn=fn, t0=t0) def reset(self) -> None: with self._sync_lock: super().reset() + def reset_call(self, *args, **kwargs) -> None: + key = self.get_key(*args, **kwargs) + with self._sync_lock: + self.reset_key(key) + _Memoize = Union[_AsyncMemoize, _SyncMemoize] @@ -286,7 +320,7 @@ def memoize( *, db: Union[bool, Path, str] = False, duration: Optional[Union[int, float, timedelta]] = None, - get_key: Optional[GetKey] = None, + keygen: Optional[Keygen] = None, size: Optional[int] = None, ): """Decorates a function call and caches return value for given inputs. @@ -295,7 +329,7 @@ def memoize( If 'duration' is provided, memoize will only retain return values for up to given 'duration'. - If 'get_key' is provided, memoize will use the function to calculate the memoize hash key. + If 'keygen' is provided, memoize will use the function to calculate the memoize hash key. If 'size' is provided, memoize will only retain up to 'size' return values. @@ -366,13 +400,13 @@ def foo(bar) -> Any: ... len(foo.memoize) # returns 2 - Memoization hash keys can be generated from a non-default function: - @memoize(get_key=lambda a, b, c: (a, b, c)) + @memoize(keygen=lambda a, b, c: (a, b, c)) def foo(a, b, c) -> Any: ... - - If part of the returned key from get_key is awaitable, it will be awaited. + - If part of the returned key from keygen is awaitable, it will be awaited. async def await_something() -> Hashable: ... - @memoize(get_key=lambda bar: (bar, await_something())) + @memoize(keygen=lambda bar: (bar, await_something())) async def foo(bar) -> Any: ... - Properties can be memoized @@ -401,7 +435,7 @@ def bar(self, baz): -> Any: ... - The default memoize key generator can be overridden. The inputs must match the function's. Class Foo: - @memoize(get_key=lambda self, a, b, c: (a, b, c)) + @memoize(keygen=lambda self, a, b, c: (a, b, c)) def bar(self, a, b, c) -> Any: ... a, b = Foo(), Foo() @@ -416,7 +450,7 @@ def bar(self, a, b, c) -> Any: ... - If the memoized function is async and any part of the key is awaitable, it is awaited. async def morph_a(a: int) -> int: ... - @memoize(get_key=lambda a, b, c: (morph_a(a), b, c)) + @memoize(keygen=lambda a, b, c: (morph_a(a), b, c)) def foo(a, b, c) -> Any: ... - Values can persist to disk and be reloaded when memoize is initialized again. @@ -448,7 +482,7 @@ def bar(cls, a) -> Any: ... # You can create a consistent hash key to avoid this. class Foo: @classmethod - @memoize(db=True, get_key=lambda cls: (f'{cls.__package__}:{cls.__name__}', a)) + @memoize(db=True, keygen=lambda cls: (f'{cls.__package__}:{cls.__name__}', a)) def bar(cls, a) -> Any: ... - Alternative location of 'db' can also be given as pathlib.Path or str. @@ -459,7 +493,7 @@ def foo() -> Any: ... def bar() -> Any: ... """ if _decoratee is None: - return partial(memoize, db=db, duration=duration, get_key=get_key, size=size) + return partial(memoize, db=db, duration=duration, keygen=keygen, size=size) if inspect.isclass(_decoratee): assert not db, 'Class memoization not allowed with db.' @@ -502,7 +536,7 @@ class Wrapped(_decoratee, metaclass=WrappedMeta): default_kwargs=default_kwargs, duration=duration, fn=fn, - get_key=get_key, + keygen=keygen, size=size, ).get_decorator() diff --git a/setup.py b/setup.py index e34ee02..b65347c 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='atools', - version='0.9.2', + version='0.10.0', packages=find_packages(), python_requires='>=3.6', url='https://github.com/cevans87/atools', diff --git a/test/test_memoize_decorator.py b/test/test_memoize_decorator.py index 914567b..5f13cc8 100644 --- a/test/test_memoize_decorator.py +++ b/test/test_memoize_decorator.py @@ -573,10 +573,10 @@ class Foo: assert Foo.__doc__ == "Foo doc" -def test_get_key_overrides_default() -> None: +def test_keygen_overrides_default() -> None: body = MagicMock() - @memoize(get_key=lambda bar, baz: (bar,)) + @memoize(keygen=lambda bar, baz: (bar,)) def foo(bar: int, baz: int) -> int: body(bar, baz) @@ -589,7 +589,7 @@ def foo(bar: int, baz: int) -> int: @pytest.mark.asyncio -async def test_get_key_awaits_awaitable_parts() -> None: +async def test_keygen_awaits_awaitable_parts() -> None: key_part_body = MagicMock() @@ -600,7 +600,7 @@ async def key_part(bar: int, baz: int) -> Tuple[Hashable, ...]: body = MagicMock() - @memoize(get_key=lambda bar, baz: (key_part(bar, baz),)) + @memoize(keygen=lambda bar, baz: (key_part(bar, baz),)) async def foo(bar: int, baz: int) -> int: body(bar, baz) @@ -773,3 +773,78 @@ def foo_inner() -> FrozenSet[int]: assert foo() == frozenset({1, 2, 3}) assert body.call_count == 1 + + +def test_sync_reset_call_resets_one() -> None: + body = MagicMock() + + @memoize + def foo(bar: int) -> None: + body(bar) + + for i in range(10): + foo(i) + assert body.call_count == 10 + + foo.memoize.reset_call(5) + for i in range(10): + foo(i) + assert body.call_count == 11 + + +def test_reset_call_before_expire_resets_one(time: MagicMock) -> None: + body = MagicMock() + + @memoize(duration=timedelta(days=1)) + def foo(bar: int) -> None: + body(bar) + + time.return_value = 0.0 + foo(0) + foo(0) + assert body.call_count == 1 + + foo.memoize.reset_call(0) + foo(0) + foo(0) + assert body.call_count == 2 + + +@pytest.mark.asyncio +async def test_async_reset_call_resets_call() -> None: + body = MagicMock() + + @memoize + async def foo(bar: int) -> None: + body(bar) + + for i in range(10): + await foo(i) + assert body.call_count == 10 + + await foo.memoize.reset_call(5) + for i in range(10): + await foo(i) + assert body.call_count == 11 + + +def test_reset_call_with_db_resets_call(db: Union[bool, Connection, Path, str]) -> None: + body = MagicMock() + + def get_foo() -> Callable[[int], None]: + @memoize(db=db) + def foo(_i: int) -> None: + body(_i) + + return foo + + foo = get_foo() + for i in range(10): + foo(i) + assert body.call_count == 10 + + foo = get_foo() + foo.memoize.reset_call(5) + for i in range(10): + foo(i) + assert body.call_count == 11