diff --git a/alexBot/tools.py b/alexBot/tools.py index 82435a8..195c86a 100644 --- a/alexBot/tools.py +++ b/alexBot/tools.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Callable, Generator, Iterable, Sequence, Tuple, TypeVar, Union from urllib.parse import urlparse +import functools + import discord from jishaku.paginators import PaginatorInterface from pytz import timezone @@ -24,6 +26,16 @@ _T = TypeVar("_T") +def convert_to_bool(argument: str) -> bool: + lowered = argument.lower() + if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + return True + elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + return False + else: + raise commands.BadBoolArgument(lowered) + + class InteractionPaginator(PaginatorInterface): # send_interaction takes an interaction and uses that to send the paginator async def send_interaction(self, interaction: discord.Interaction): @@ -116,9 +128,29 @@ def grouper(iterable: Sequence[_T], n: int) -> Generator[Sequence[_T], None, Non yield iterable[i * n : i * n + n] -def transform_neosdb(url: str) -> str: - url = urlparse(url) - return f"https://cloudxstorage.blob.core.windows.net/assets{posixpath.splitext(url.path)[0]}" +def time_cache(max_age: int, maxsize=128, typed=False): + """Least-recently-used cache decorator with time-based cache invalidation. + + Args: + max_age: Time to live for cached results (in seconds). + maxsize: Maximum cache size (see `functools.lru_cache`). + typed: Cache on distinct input types (see `functools.lru_cache`). + + copied from stackoverflow: https://stackoverflow.com/a/63674816 + """ + + def _decorator(fn): + @functools.lru_cache(maxsize=maxsize, typed=typed) + def _new(*args, __time_salt, **kwargs): + return fn(*args, **kwargs) + + @functools.wraps(fn) + def _wrapped(*args, **kwargs): + return _new(*args, **kwargs, __time_salt=int(time.time() / max_age)) + + return _wrapped + + return _decorator timeUnits = { @@ -127,10 +159,11 @@ def transform_neosdb(url: str) -> str: 'h': lambda v: v * 60 * 60, 'd': lambda v: v * 60 * 60 * 24, 'w': lambda v: v * 60 * 60 * 24 * 7, + 'M': lambda v: v * 60 * 60 * 24 * 30, } -def resolve_duration(data) -> datetime.datetime: +def resolve_duration(data: str, tz: datetime.tzinfo = timezone('UTC')) -> datetime.timedelta: """ Takes a raw input string formatted 1w1d1h1m1s (any order) and converts to timedelta @@ -151,4 +184,4 @@ def resolve_duration(data) -> datetime.datetime: value += timeUnits[char](int(digits)) digits = '' - return datetime.datetime.now(tz=timezone('UTC')) + datetime.timedelta(seconds=value + 1) + return datetime.timedelta(seconds=value)