Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the client, command, and extension to CommandContext instances #1093

Merged
merged 2 commits into from Sep 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions interactions/client/bot.py
Expand Up @@ -1038,6 +1038,7 @@ def decorator(coro: Callable[..., Coroutine]) -> Command:
description_localizations=description_localizations,
default_scope=default_scope,
)
cmd.client = self
self._commands.append(cmd)
return cmd

Expand Down Expand Up @@ -1590,6 +1591,7 @@ def __new__(cls, client: Client, *args, **kwargs) -> "Extension":
continue

cmd.extension = self
cmd.client = self.client
self.client._commands.append(cmd)

commands = self._commands.get(cmd.name, [])
Expand Down
18 changes: 12 additions & 6 deletions interactions/client/context.py
@@ -1,8 +1,7 @@
from logging import Logger
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

from ..api.error import LibraryException
from ..api.http.client import HTTPClient
from ..api.models.channel import Channel
from ..api.models.flags import Permissions
from ..api.models.guild import Guild
Expand All @@ -18,6 +17,10 @@
from .models.component import ActionRow, Button, Modal, SelectMenu, _build_components
from .models.misc import InteractionData

if TYPE_CHECKING:
from .bot import Client, Extension
from .models.command import Command

log: Logger = get_logger("context")

__all__ = (
Expand All @@ -43,7 +46,6 @@ class _Context(ClientSerializerMixin):
:ivar Optional[Guild] guild: The guild data model.
"""

client: HTTPClient = field(default=None)
message: Optional[Message] = field(converter=Message, default=None, add_client=True)
author: Member = field(converter=Member, default=None, add_client=True)
member: Member = field(converter=Member, add_client=True)
Expand All @@ -66,9 +68,6 @@ class _Context(ClientSerializerMixin):
app_permissions: Permissions = field(converter=convert_int(Permissions), default=None)

def __attrs_post_init__(self) -> None:
# backwards compatibility
self.client = self._client

if self.member:
if self.guild_id:
self.member._extras["guild_id"] = self.guild_id
Expand Down Expand Up @@ -367,10 +366,17 @@ class CommandContext(_Context):
:ivar str locale?: The selected language of the user invoking the interaction.
:ivar str guild_locale?: The guild's preferred language, if invoked in a guild.
:ivar str app_permissions?: Bitwise set of permissions the bot has within the channel the interaction was sent from.
:ivar Client client: The client instance that the command belongs to.
:ivar Command command: The command object that is being invoked.
:ivar Extension extension: The extension the command belongs to.
"""

target: Optional[Union[Message, Member, User]] = field(default=None)

client: "Client" = field(default=None, init=False)
command: "Command" = field(default=None, init=False)
extension: "Extension" = field(default=None, init=False)

def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

Expand Down
8 changes: 7 additions & 1 deletion interactions/client/models/command.py
Expand Up @@ -17,7 +17,7 @@

if TYPE_CHECKING:
from ...api.dispatch import Listener
from ..bot import Extension
from ..bot import Client, Extension
from ..context import CommandContext

__all__ = (
Expand Down Expand Up @@ -413,6 +413,7 @@ class Command(DictSerializerMixin):
:ivar Optional[str] recent_group: The name of the group most recently utilized.
:ivar bool resolved: Whether the command is synced. Defaults to ``False``.
:ivar Optional[Extension] extension: The extension that the command belongs to, if any.
:ivar Client client: The client that the command belongs to.
:ivar Optional[Listener] listener: The listener, used for dispatching command errors.
"""

Expand All @@ -437,6 +438,7 @@ class Command(DictSerializerMixin):
error_callback: Optional[Callable[..., Awaitable]] = field(default=None, init=False)
resolved: bool = field(default=False, init=False)
extension: Optional["Extension"] = field(default=None, init=False)
client: "Client" = field(default=None, init=False)
listener: Optional["Listener"] = field(default=None, init=False)

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -927,6 +929,10 @@ def __wrap_coro(

@wraps(coro)
async def wrapper(ctx: "CommandContext", *args, **kwargs):
ctx.client = self.client
ctx.command = self
ctx.extension = self.extension

try:
if self.extension:
return await coro(self.extension, ctx, *args, **kwargs)
Expand Down