Skip to content

Commit

Permalink
Adds memoize.reset_call
Browse files Browse the repository at this point in the history
  • Loading branch information
cevans87 committed Dec 12, 2019
1 parent e2bd385 commit 84813d7
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 66 deletions.
156 changes: 95 additions & 61 deletions atools/_memoize_decorator.py
Expand Up @@ -14,7 +14,7 @@


Decoratee = Union[Callable, Type]
GetKey = Callable[..., Tuple[Any]]
Keygen = Callable[..., Tuple[Any]]

_default_db_path = Path.home() / '.memoize'

Expand All @@ -33,7 +33,7 @@ class _MemoReturnState:
@dataclass(frozen=True)
class _MemoBase:
fn: Callable
expire_time: Optional[float]
t0: Optional[float]
memo_return_state: _MemoReturnState = field(init=False, default_factory=_MemoReturnState)


Expand All @@ -54,9 +54,9 @@ class _SyncMemo(_MemoBase):
class _MemoizeBase:
db: Optional[Connection]
default_kwargs: Mapping[str, Any]
fn: Callable
get_key: Optional[GetKey]
duration: Optional[timedelta]
fn: Callable
keygen: Optional[Keygen]
size: Optional[int]

expire_order: OrderedDict = field(init=False, default_factory=OrderedDict, hash=False)
Expand All @@ -67,13 +67,16 @@ def __post_init__(self) -> None:
self.db.execute(dedent(f'''
CREATE TABLE IF NOT EXISTS `{self.table_name}` (
k TEXT PRIMARY KEY,
t0 FLOAT,
t FLOAT,
e FLOAT,
v TEXT NOT NULL
)
'''))
if self.duration:
self.db.execute(f"DELETE FROM `{self.table_name}` WHERE e < {time()}")
self.db.execute(dedent(f'''
DELETE FROM `{self.table_name}`
WHERE t0 < {time() - self.duration.total_seconds()}
'''))

if self.size:
res = self.db.execute(
Expand All @@ -82,23 +85,16 @@ def __post_init__(self) -> None:
if res:
(min_t,) = res[-1]
self.db.execute(f"DELETE FROM `{self.table_name}` WHERE t < {min_t}")
for k, t, v in self.db.execute(
f"SELECT k, t, v FROM `{self.table_name}` ORDER BY t"
for k, t0, t, v in self.db.execute(
f"SELECT k, t0, t, v FROM `{self.table_name}` ORDER BY t"
).fetchall():
memo = self.make_memo(
fn=self.fn,
expire_time=(
t + self.duration.total_seconds()
if self.duration is not None
else None
)
)
memo = self.make_memo(fn=self.fn, t0=t0)
memo.memo_return_state.called = True
(memo.memo_return_state.value,) = eval(v, __builtins__)
self.memos[k] = memo
if self.duration:
for k, e in self.db.execute(
f"SELECT k, e FROM `{self.table_name}` ORDER BY e"
for k, t0 in self.db.execute(
f"SELECT k, t0 FROM `{self.table_name}` ORDER BY t0"
).fetchall():
self.expire_order[k] = ...
self.db.commit()
Expand All @@ -115,29 +111,29 @@ def table_name(self) -> str:
f':{self.fn.__code__.co_firstlineno}'
)

def get_default_key(self, *args, **kwargs) -> Tuple[Hashable, ...]:
def default_keygen(self, *args, **kwargs) -> Tuple[Hashable, ...]:
"""Returns all params (args, kwargs, and missing default kwargs) for function as kwargs."""
args_as_kwargs = {}
for k, v in zip(self.default_kwargs, args):
args_as_kwargs[k] = v

return tuple(ChainMap(args_as_kwargs, kwargs, self.default_kwargs).values())

def get_memo(self, key: int) -> _Memo:
def get_memo(self, key: Union[int, str]) -> _Memo:
try:
memo = self.memos[key] = self.memos.pop(key)
if self.duration is not None and memo.expire_time < time():
if self.duration is not None and memo.t0 < time() - self.duration.total_seconds():
self.expire_order.pop(key)
raise ValueError('value expired')
except (KeyError, ValueError):
if self.duration is None:
expire_time = None
t0 = None
else:
expire_time = time() + self.duration.total_seconds()
t0 = time()
# The value has no significance. We're using the dict entirely for ordering keys.
self.expire_order[key] = ...

memo = self.memos[key] = self.make_memo(self.fn, expire_time=expire_time)
memo = self.memos[key] = self.make_memo(self.fn, t0=t0)

return memo

Expand All @@ -146,7 +142,10 @@ def expire_one_memo(self) -> None:
if (
(self.expire_order is not None) and
(len(self.expire_order) > 0) and
(self.memos[next(iter(self.expire_order))].expire_time < time())
(
self.memos[next(iter(self.expire_order))].t0 <
time() - self.duration.total_seconds()
)
):
(k, _) = self.expire_order.popitem(last=False)
self.memos.pop(k)
Expand All @@ -156,32 +155,40 @@ def expire_one_memo(self) -> None:
self.db.execute(f"DELETE FROM `{self.table_name}` WHERE k = '{k}'")
self.db.commit()

def finalize_memo(self, memo: _Memo, key: int) -> Any:
def finalize_memo(self, memo: _Memo, key: Union[int, str]) -> Any:
if memo.memo_return_state.raised:
raise memo.memo_return_state.value
else:
if self.db is not None:
if (self.db is not None) and (self.memos[key] is memo):
value = str((memo.memo_return_state.value,))
assert (memo.memo_return_state.value,) == eval(value, __builtins__)
self.db.execute(
dedent(f'''
INSERT OR REPLACE INTO `{self.table_name}`
(k, t, e, v)
(k, t0, t, v)
VALUES
(?, ?, ?, ?)
'''),
(
key,
memo.t0,
time(),
memo.expire_time,
value
)
)
self.db.commit()
return memo.memo_return_state.value

def get_hashed_key(self, key: Tuple[Hashable]) -> Union[int, str]:
if self.db is None:
key = hash(key)
else:
key = sha256(str(key).encode()).hexdigest()

return key

@staticmethod
def make_memo(fn, expire_time: Optional[float]) -> _Memo: # pragma: no cover
def make_memo(fn, t0: Optional[float]) -> _Memo: # pragma: no cover
raise NotImplemented

def reset(self) -> None:
Expand All @@ -191,23 +198,36 @@ def reset(self) -> None:
self.db.execute(f"DELETE FROM `{self.table_name}`")
self.db.commit()

def reset_key(self, key: Union[int, str]) -> None:
if key in self.memos:
self.memos.pop(key)
if self.duration is not None:
self.expire_order.pop(key)
if self.db is not None:
self.db.execute(f"DELETE FROM `{self.table_name}` WHERE k == '{key}'")


@dataclass(frozen=True)
class _AsyncMemoize(_MemoizeBase):

async def get_key(self, *args, **kwargs) -> Union[int, str]:
if self.keygen is None:
key = self.default_keygen(*args, **kwargs)
else:
key = self.keygen(*args, **kwargs)
key = list(key)
for i, v in enumerate(key):
if inspect.isawaitable(v):
key[i] = await v
key = tuple(key)

key = self.get_hashed_key(key)

return key

def get_decorator(self) -> Callable:
async def decorator(*args, **kwargs) -> Any:
if self.get_key is None:
key = self.get_default_key(*args, **kwargs)
else:
key = self.get_key(*args, **kwargs)
key = list(key)
for i, v in enumerate(key):
if inspect.isawaitable(v):
key[i] = await v
key = tuple(key)

key = sha256(str(key).encode()).hexdigest()
key = await self.get_key(*args, **kwargs)

memo: _AsyncMemo = self.get_memo(key)

Expand All @@ -229,23 +249,32 @@ async def decorator(*args, **kwargs) -> Any:
return decorator

@staticmethod
def make_memo(fn, expire_time: Optional[float]) -> _AsyncMemo:
return _AsyncMemo(fn=fn, expire_time=expire_time)
def make_memo(fn, t0: Optional[float]) -> _AsyncMemo:
return _AsyncMemo(fn=fn, t0=t0)

async def reset_call(self, *args, **kwargs) -> None:
key = await self.get_key(*args, **kwargs)
self.reset_key(key)


@dataclass(frozen=True)
class _SyncMemoize(_MemoizeBase):

_sync_lock: SyncLock = field(init=False, default_factory=lambda: SyncLock())

def get_key(self, *args, **kwargs) -> Union[int, str]:
if self.keygen is None:
key = self.default_keygen(*args, **kwargs)
else:
key = self.keygen(*args, **kwargs)

key = self.get_hashed_key(key)

return key

def get_decorator(self) -> Callable:
def decorator(*args, **kwargs):
if self.get_key is None:
key = self.get_default_key(*args, **kwargs)
else:
key = self.get_key(*args, **kwargs)

key = sha256(str(key).encode()).hexdigest()
key = self.get_key(*args, **kwargs)

with self._sync_lock:
memo: _SyncMemo = self.get_memo(key)
Expand All @@ -268,13 +297,18 @@ def decorator(*args, **kwargs):
return decorator

@staticmethod
def make_memo(fn, expire_time: Optional[float]) -> _SyncMemo:
return _SyncMemo(fn=fn, expire_time=expire_time)
def make_memo(fn, t0: Optional[float]) -> _SyncMemo:
return _SyncMemo(fn=fn, t0=t0)

def reset(self) -> None:
with self._sync_lock:
super().reset()

def reset_call(self, *args, **kwargs) -> None:
key = self.get_key(*args, **kwargs)
with self._sync_lock:
self.reset_key(key)


_Memoize = Union[_AsyncMemoize, _SyncMemoize]

Expand All @@ -286,7 +320,7 @@ def memoize(
*,
db: Union[bool, Path, str] = False,
duration: Optional[Union[int, float, timedelta]] = None,
get_key: Optional[GetKey] = None,
keygen: Optional[Keygen] = None,
size: Optional[int] = None,
):
"""Decorates a function call and caches return value for given inputs.
Expand All @@ -295,7 +329,7 @@ def memoize(
If 'duration' is provided, memoize will only retain return values for up to given 'duration'.
If 'get_key' is provided, memoize will use the function to calculate the memoize hash key.
If 'keygen' is provided, memoize will use the function to calculate the memoize hash key.
If 'size' is provided, memoize will only retain up to 'size' return values.
Expand Down Expand Up @@ -366,13 +400,13 @@ def foo(bar) -> Any: ...
len(foo.memoize) # returns 2
- Memoization hash keys can be generated from a non-default function:
@memoize(get_key=lambda a, b, c: (a, b, c))
@memoize(keygen=lambda a, b, c: (a, b, c))
def foo(a, b, c) -> Any: ...
- If part of the returned key from get_key is awaitable, it will be awaited.
- If part of the returned key from keygen is awaitable, it will be awaited.
async def await_something() -> Hashable: ...
@memoize(get_key=lambda bar: (bar, await_something()))
@memoize(keygen=lambda bar: (bar, await_something()))
async def foo(bar) -> Any: ...
- Properties can be memoized
Expand Down Expand Up @@ -401,7 +435,7 @@ def bar(self, baz): -> Any: ...
- The default memoize key generator can be overridden. The inputs must match the function's.
Class Foo:
@memoize(get_key=lambda self, a, b, c: (a, b, c))
@memoize(keygen=lambda self, a, b, c: (a, b, c))
def bar(self, a, b, c) -> Any: ...
a, b = Foo(), Foo()
Expand All @@ -416,7 +450,7 @@ def bar(self, a, b, c) -> Any: ...
- If the memoized function is async and any part of the key is awaitable, it is awaited.
async def morph_a(a: int) -> int: ...
@memoize(get_key=lambda a, b, c: (morph_a(a), b, c))
@memoize(keygen=lambda a, b, c: (morph_a(a), b, c))
def foo(a, b, c) -> Any: ...
- Values can persist to disk and be reloaded when memoize is initialized again.
Expand Down Expand Up @@ -448,7 +482,7 @@ def bar(cls, a) -> Any: ...
# You can create a consistent hash key to avoid this.
class Foo:
@classmethod
@memoize(db=True, get_key=lambda cls: (f'{cls.__package__}:{cls.__name__}', a))
@memoize(db=True, keygen=lambda cls: (f'{cls.__package__}:{cls.__name__}', a))
def bar(cls, a) -> Any: ...
- Alternative location of 'db' can also be given as pathlib.Path or str.
Expand All @@ -459,7 +493,7 @@ def foo() -> Any: ...
def bar() -> Any: ...
"""
if _decoratee is None:
return partial(memoize, db=db, duration=duration, get_key=get_key, size=size)
return partial(memoize, db=db, duration=duration, keygen=keygen, size=size)

if inspect.isclass(_decoratee):
assert not db, 'Class memoization not allowed with db.'
Expand Down Expand Up @@ -502,7 +536,7 @@ class Wrapped(_decoratee, metaclass=WrappedMeta):
default_kwargs=default_kwargs,
duration=duration,
fn=fn,
get_key=get_key,
keygen=keygen,
size=size,
).get_decorator()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -2,7 +2,7 @@

setup(
name='atools',
version='0.9.2',
version='0.10.0',
packages=find_packages(),
python_requires='>=3.6',
url='https://github.com/cevans87/atools',
Expand Down

0 comments on commit 84813d7

Please sign in to comment.