From 497c7a5de399ca5f088fc26da6a141ca96d26cc4 Mon Sep 17 00:00:00 2001 From: Toricane <73972068+Toricane@users.noreply.github.com> Date: Tue, 6 Sep 2022 00:42:31 -0700 Subject: [PATCH] fix: autodefer with args and kwargs in commands (#1074) * fix: args and kwargs in commands * fix: autodefer errors --- interactions/client/models/command.py | 103 ++++++++++++++++++-------- interactions/utils/utils.py | 7 +- 2 files changed, 77 insertions(+), 33 deletions(-) diff --git a/interactions/client/models/command.py b/interactions/client/models/command.py index c326dc641..0b36aaff5 100644 --- a/interactions/client/models/command.py +++ b/interactions/client/models/command.py @@ -816,59 +816,83 @@ async def command_error(ctx, error): message=f"Your command needs at least {'three parameters to return self, context, and the' if self.extension else 'two parameter to return context and'} error.", ) - self.error_callback = self.__wrap_coro(coro) + self.error_callback = self.__wrap_coro(coro, error_callback=True) return coro async def __call( self, coro: Callable[..., Awaitable], ctx: "CommandContext", - *args, + *args, # empty for now since all parameters are dispatched as kwargs _name: Optional[str] = None, _res: Optional[Union[BaseResult, GroupResult]] = None, **kwargs, ) -> Optional[Any]: """Handles calling the coroutine based on parameter count.""" - param_len = len(signature(coro).parameters) - opt_len = self.num_options.get(_name, len(args) + len(kwargs)) + params = signature(coro).parameters + param_len = len(params) + opt_len = self.num_options.get(_name, len(args) + len(kwargs)) # options of slash command + last = params[list(params)[-1]] # last parameter + has_args = any(param.kind == param.VAR_POSITIONAL for param in params.values()) # any *args + index_of_var_pos = next( + (i for i, param in enumerate(params.values()) if param.kind == param.VAR_POSITIONAL), + param_len, + ) # index of *args + par_opts = list(params.keys())[ + (num := 2 if self.extension else 1) : ( + -1 if last.kind in (last.VAR_POSITIONAL, last.VAR_KEYWORD) else index_of_var_pos + ) + ] # parameters that are before *args and **kwargs + keyword_only_args = list(params.keys())[index_of_var_pos:] # parameters after *args try: _coro = coro if hasattr(coro, "_wrapped") else self.__wrap_coro(coro) - if param_len < (2 if self.extension else 1): + if last.kind == last.VAR_KEYWORD: # foo(ctx, ..., **kwargs) + return await _coro(ctx, *args, **kwargs) + if last.kind == last.VAR_POSITIONAL: # foo(ctx, ..., *args) + return await _coro( + ctx, + *(kwargs[opt] for opt in par_opts if opt in kwargs), + *args, + ) + if has_args: # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...) + return await _coro( + ctx, + *(kwargs[opt] for opt in par_opts if opt in kwargs), # pos before *args + *args, + *( + kwargs[opt] + for opt in kwargs + if opt not in par_opts and opt not in keyword_only_args + ), # additional args + **{ + opt: kwargs[opt] + for opt in kwargs + if opt not in par_opts and opt in keyword_only_args + }, # kwargs after *args + ) + + if param_len < num: + inner_msg: str = f"{num} parameter{'s' if num > 1 else ''} to return" + ( + " self and" if self.extension else "" + ) raise LibraryException( - code=11, - message=f"Your command needs at least {'two parameters to return self and' if self.extension else 'one parameter to return'} context.", + code=11, message=f"Your command needs at least {inner_msg} context." ) - if param_len == (2 if self.extension else 1): + if param_len == num: return await _coro(ctx) if _res: - if param_len - opt_len == (2 if self.extension else 1): + if param_len - opt_len == num: return await _coro(ctx, *args, **kwargs) - elif param_len - opt_len == (3 if self.extension else 2): + elif param_len - opt_len == num + 1: return await _coro(ctx, _res, *args, **kwargs) return await _coro(ctx, *args, **kwargs) except CancelledError: pass - except Exception as e: - if self.error_callback: - num_params = len(signature(self.error_callback).parameters) - - if num_params == (3 if self.extension else 2): - await self.error_callback(ctx, e) - elif num_params == (4 if self.extension else 3): - await self.error_callback(ctx, e, _res) - else: - await self.error_callback(ctx, e, _res, *args, **kwargs) - elif self.listener and "on_command_error" in self.listener.events: - self.listener.dispatch("on_command_error", ctx, e) - else: - raise e - - return StopCommand def __check_command(self, command_type: str) -> None: """Checks if subcommands, groups, or autocompletions are created on context menus.""" @@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None: """This is the coroutine used when no group coroutine is provided.""" pass - def __wrap_coro(self, coro: Callable[..., Awaitable]) -> Callable[..., Awaitable]: + def __wrap_coro( + self, coro: Callable[..., Awaitable], /, *, error_callback: bool = False + ) -> Callable[..., Awaitable]: """Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any.""" @wraps(coro) @@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs): except CancelledError: pass except Exception as e: + if error_callback: + raise e if self.error_callback: - num_params = len(signature(self.error_callback).parameters) - - if num_params == (3 if self.extension else 2): + params = signature(self.error_callback).parameters + num_params = len(params) + last = params[list(params)[-1]] + num = 2 if self.extension else 1 + + if num_params == num: + await self.error_callback(ctx) + elif num_params == num + 1: await self.error_callback(ctx, e) + elif last.kind == last.VAR_KEYWORD: + if num_params == num + 2: + await self.error_callback(ctx, e, **kwargs) + elif num_params >= num + 3: + await self.error_callback(ctx, e, *args, **kwargs) + elif last.kind == last.VAR_POSITIONAL: + if num_params == num + 2: + await self.error_callback(ctx, e, *args) + elif num_params >= num + 3: + await self.error_callback(ctx, e, *args, **kwargs) else: await self.error_callback(ctx, e, *args, **kwargs) elif self.listener and "on_command_error" in self.listener.events: diff --git a/interactions/utils/utils.py b/interactions/utils/utils.py index 070ad1c02..756b67d08 100644 --- a/interactions/utils/utils.py +++ b/interactions/utils/utils.py @@ -25,7 +25,7 @@ from ..api.models.message import Message from ..api.models.misc import Snowflake from ..client.bot import Client, Extension - from ..client.context import CommandContext + from ..client.context import CommandContext # noqa F401 __all__ = ( "autodefer", @@ -67,7 +67,7 @@ async def command(ctx): """ def decorator(coro: Callable[..., Union[Awaitable, Coroutine]]) -> Callable[..., Awaitable]: - from ..client.context import ComponentContext + from ..client.context import CommandContext, ComponentContext # noqa F811 @wraps(coro) async def deferring_func( @@ -80,7 +80,8 @@ async def deferring_func( if isinstance(args[0], (ComponentContext, CommandContext)): self = ctx - ctx = list(args).pop(0) + args = list(args) + ctx = args.pop(0) task: Task = loop.create_task(coro(self, ctx, *args, **kwargs))