Skip to content

Commit

Permalink
fix: autodefer with args and kwargs in commands (#1074)
Browse files Browse the repository at this point in the history
* fix: args and kwargs in commands

* fix: autodefer errors
  • Loading branch information
Toricane committed Sep 6, 2022
1 parent 7d5a369 commit 497c7a5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 33 deletions.
103 changes: 73 additions & 30 deletions interactions/client/models/command.py
Expand Up @@ -816,59 +816,83 @@ async def command_error(ctx, error):
message=f"Your command needs at least {'three parameters to return self, context, and the' if self.extension else 'two parameter to return context and'} error.",
)

self.error_callback = self.__wrap_coro(coro)
self.error_callback = self.__wrap_coro(coro, error_callback=True)
return coro

async def __call(
self,
coro: Callable[..., Awaitable],
ctx: "CommandContext",
*args,
*args, # empty for now since all parameters are dispatched as kwargs
_name: Optional[str] = None,
_res: Optional[Union[BaseResult, GroupResult]] = None,
**kwargs,
) -> Optional[Any]:
"""Handles calling the coroutine based on parameter count."""
param_len = len(signature(coro).parameters)
opt_len = self.num_options.get(_name, len(args) + len(kwargs))
params = signature(coro).parameters
param_len = len(params)
opt_len = self.num_options.get(_name, len(args) + len(kwargs)) # options of slash command
last = params[list(params)[-1]] # last parameter
has_args = any(param.kind == param.VAR_POSITIONAL for param in params.values()) # any *args
index_of_var_pos = next(
(i for i, param in enumerate(params.values()) if param.kind == param.VAR_POSITIONAL),
param_len,
) # index of *args
par_opts = list(params.keys())[
(num := 2 if self.extension else 1) : (
-1 if last.kind in (last.VAR_POSITIONAL, last.VAR_KEYWORD) else index_of_var_pos
)
] # parameters that are before *args and **kwargs
keyword_only_args = list(params.keys())[index_of_var_pos:] # parameters after *args

try:
_coro = coro if hasattr(coro, "_wrapped") else self.__wrap_coro(coro)

if param_len < (2 if self.extension else 1):
if last.kind == last.VAR_KEYWORD: # foo(ctx, ..., **kwargs)
return await _coro(ctx, *args, **kwargs)
if last.kind == last.VAR_POSITIONAL: # foo(ctx, ..., *args)
return await _coro(
ctx,
*(kwargs[opt] for opt in par_opts if opt in kwargs),
*args,
)
if has_args: # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...)
return await _coro(
ctx,
*(kwargs[opt] for opt in par_opts if opt in kwargs), # pos before *args
*args,
*(
kwargs[opt]
for opt in kwargs
if opt not in par_opts and opt not in keyword_only_args
), # additional args
**{
opt: kwargs[opt]
for opt in kwargs
if opt not in par_opts and opt in keyword_only_args
}, # kwargs after *args
)

if param_len < num:
inner_msg: str = f"{num} parameter{'s' if num > 1 else ''} to return" + (
" self and" if self.extension else ""
)
raise LibraryException(
code=11,
message=f"Your command needs at least {'two parameters to return self and' if self.extension else 'one parameter to return'} context.",
code=11, message=f"Your command needs at least {inner_msg} context."
)

if param_len == (2 if self.extension else 1):
if param_len == num:
return await _coro(ctx)

if _res:
if param_len - opt_len == (2 if self.extension else 1):
if param_len - opt_len == num:
return await _coro(ctx, *args, **kwargs)
elif param_len - opt_len == (3 if self.extension else 2):
elif param_len - opt_len == num + 1:
return await _coro(ctx, _res, *args, **kwargs)

return await _coro(ctx, *args, **kwargs)
except CancelledError:
pass
except Exception as e:
if self.error_callback:
num_params = len(signature(self.error_callback).parameters)

if num_params == (3 if self.extension else 2):
await self.error_callback(ctx, e)
elif num_params == (4 if self.extension else 3):
await self.error_callback(ctx, e, _res)
else:
await self.error_callback(ctx, e, _res, *args, **kwargs)
elif self.listener and "on_command_error" in self.listener.events:
self.listener.dispatch("on_command_error", ctx, e)
else:
raise e

return StopCommand

def __check_command(self, command_type: str) -> None:
"""Checks if subcommands, groups, or autocompletions are created on context menus."""
Expand All @@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None:
"""This is the coroutine used when no group coroutine is provided."""
pass

def __wrap_coro(self, coro: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
def __wrap_coro(
self, coro: Callable[..., Awaitable], /, *, error_callback: bool = False
) -> Callable[..., Awaitable]:
"""Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any."""

@wraps(coro)
Expand All @@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs):
except CancelledError:
pass
except Exception as e:
if error_callback:
raise e
if self.error_callback:
num_params = len(signature(self.error_callback).parameters)

if num_params == (3 if self.extension else 2):
params = signature(self.error_callback).parameters
num_params = len(params)
last = params[list(params)[-1]]
num = 2 if self.extension else 1

if num_params == num:
await self.error_callback(ctx)
elif num_params == num + 1:
await self.error_callback(ctx, e)
elif last.kind == last.VAR_KEYWORD:
if num_params == num + 2:
await self.error_callback(ctx, e, **kwargs)
elif num_params >= num + 3:
await self.error_callback(ctx, e, *args, **kwargs)
elif last.kind == last.VAR_POSITIONAL:
if num_params == num + 2:
await self.error_callback(ctx, e, *args)
elif num_params >= num + 3:
await self.error_callback(ctx, e, *args, **kwargs)
else:
await self.error_callback(ctx, e, *args, **kwargs)
elif self.listener and "on_command_error" in self.listener.events:
Expand Down
7 changes: 4 additions & 3 deletions interactions/utils/utils.py
Expand Up @@ -25,7 +25,7 @@
from ..api.models.message import Message
from ..api.models.misc import Snowflake
from ..client.bot import Client, Extension
from ..client.context import CommandContext
from ..client.context import CommandContext # noqa F401

__all__ = (
"autodefer",
Expand Down Expand Up @@ -67,7 +67,7 @@ async def command(ctx):
"""

def decorator(coro: Callable[..., Union[Awaitable, Coroutine]]) -> Callable[..., Awaitable]:
from ..client.context import ComponentContext
from ..client.context import CommandContext, ComponentContext # noqa F811

@wraps(coro)
async def deferring_func(
Expand All @@ -80,7 +80,8 @@ async def deferring_func(

if isinstance(args[0], (ComponentContext, CommandContext)):
self = ctx
ctx = list(args).pop(0)
args = list(args)
ctx = args.pop(0)

task: Task = loop.create_task(coro(self, ctx, *args, **kwargs))

Expand Down

0 comments on commit 497c7a5

Please sign in to comment.