diff --git a/discord_slash/__init__.py b/discord_slash/__init__.py index 7fa481236..75b207e76 100644 --- a/discord_slash/__init__.py +++ b/discord_slash/__init__.py @@ -13,6 +13,8 @@ from .context import ComponentContext # noqa: F401 from .context import SlashContext # noqa: F401 from .dpy_overrides import ComponentMessage # noqa: F401 +from .model import ButtonStyle # noqa: F401 +from .model import ComponentType # noqa: F401 from .model import SlashCommandOptionType # noqa: F401 from .utils import manage_commands # noqa: F401 from .utils import manage_components # noqa: F401 diff --git a/discord_slash/context.py b/discord_slash/context.py index 5972ccad5..29d2501fc 100644 --- a/discord_slash/context.py +++ b/discord_slash/context.py @@ -13,7 +13,7 @@ class InteractionContext: """ Base context for interactions.\n - Kinda similar with discord.ext.commands.Context. + In some ways similar with discord.ext.commands.Context. .. warning:: Do not manually init this model. @@ -139,7 +139,7 @@ async def send( components: typing.List[dict] = None, ) -> model.SlashMessage: """ - Sends response of the slash command. + Sends response of the interaction. .. warning:: - Since Release 1.0.9, this is completely changed. If you are migrating from older version, please make sure to fix the usage. @@ -297,6 +297,8 @@ def __init__( self.origin_message = None self.origin_message_id = int(_json["message"]["id"]) if "message" in _json.keys() else None + self._deferred_edit_origin = False + if self.origin_message_id and (_json["message"]["flags"] & 64) != 64: self.origin_message = ComponentMessage( state=self.bot._connection, channel=self.channel, data=_json["message"] @@ -307,17 +309,58 @@ async def defer(self, hidden: bool = False, edit_origin: bool = False): 'Defers' the response, showing a loading state to the user :param hidden: Whether the deferred response should be ephemeral . Default ``False``. - :param edit_origin: Whether the response is editing the origin message. If ``False``, the deferred response will be for a follow up message. Defaults ``False``. + :param edit_origin: Whether the type is editing the origin message. If ``False``, the deferred response will be for a follow up message. Defaults ``False``. """ if self.deferred or self.responded: raise error.AlreadyResponded("You have already responded to this command!") + base = {"type": 6 if edit_origin else 5} - if hidden and not edit_origin: + + if hidden: + if edit_origin: + raise error.IncorrectFormat( + "'hidden' and 'edit_origin' flags are mutually exclusive" + ) base["data"] = {"flags": 64} self._deferred_hidden = True + + self._deferred_edit_origin = edit_origin + await self._http.post_initial_response(base, self.interaction_id, self._token) self.deferred = True + async def send( + self, + content: str = "", + *, + embed: discord.Embed = None, + embeds: typing.List[discord.Embed] = None, + tts: bool = False, + file: discord.File = None, + files: typing.List[discord.File] = None, + allowed_mentions: discord.AllowedMentions = None, + hidden: bool = False, + delete_after: float = None, + components: typing.List[dict] = None, + ) -> model.SlashMessage: + if self.deferred and self._deferred_edit_origin: + self._logger.warning( + "Deferred response might not be what you set it to! (edit origin / send response message) " + "This is because it was deferred with different response type." + ) + return await super().send( + content, + embed=embed, + embeds=embeds, + tts=tts, + file=file, + files=files, + allowed_mentions=allowed_mentions, + hidden=hidden, + delete_after=delete_after, + components=components, + ) + async def edit_origin(self, **fields): """ Edits the origin message of the component. @@ -366,13 +409,16 @@ async def edit_origin(self, **fields): if files and not self.deferred: await self.defer(edit_origin=True) if self.deferred: - _json = await self._http.edit(_resp, self._token, files=files) + if not self._deferred_edit_origin: + self._logger.warning( + "Deferred response might not be what you set it to! (edit origin / send response message) " + "This is because it was deferred with different response type." + ) + await self._http.edit(_resp, self._token, files=files) self.deferred = False - else: # noqa: F841 + else: json_data = {"type": 7, "data": _resp} - _json = await self._http.post_initial_response( # noqa: F841 - json_data, self.interaction_id, self._token - ) + await self._http.post_initial_response(json_data, self.interaction_id, self._token) self.responded = True else: raise error.IncorrectFormat("Already responded") diff --git a/discord_slash/model.py b/discord_slash/model.py index c8b4f2a34..f8959aff8 100644 --- a/discord_slash/model.py +++ b/discord_slash/model.py @@ -556,3 +556,24 @@ def from_type(cls, t: type): return cls.ROLE if issubclass(t, discord.abc.User): return cls.USER + + +class ComponentType(IntEnum): + actionrow = 1 + button = 2 + select = 3 + + +class ButtonStyle(IntEnum): + blue = 1 + blurple = 1 + gray = 2 + grey = 2 + green = 3 + red = 4 + URL = 5 + + primary = 1 + secondary = 2 + success = 3 + danger = 4 diff --git a/discord_slash/utils/manage_components.py b/discord_slash/utils/manage_components.py index 6d64cc9a9..60f0bfbd1 100644 --- a/discord_slash/utils/manage_components.py +++ b/discord_slash/utils/manage_components.py @@ -1,17 +1,11 @@ -import enum import typing import uuid import discord from ..context import ComponentContext -from ..error import IncorrectFormat - - -class ComponentsType(enum.IntEnum): - actionrow = 1 - button = 2 - select = 3 +from ..error import IncorrectFormat, IncorrectType +from ..model import ButtonStyle, ComponentType def create_actionrow(*components: dict) -> dict: @@ -24,27 +18,12 @@ def create_actionrow(*components: dict) -> dict: if not components or len(components) > 5: raise IncorrectFormat("Number of components in one row should be between 1 and 5.") if ( - ComponentsType.select in [component["type"] for component in components] + ComponentType.select in [component["type"] for component in components] and len(components) > 1 ): raise IncorrectFormat("Action row must have only one select component and nothing else") - return {"type": ComponentsType.actionrow, "components": components} - - -class ButtonStyle(enum.IntEnum): - blue = 1 - blurple = 1 - gray = 2 - grey = 2 - green = 3 - red = 4 - URL = 5 - - primary = 1 - secondary = 2 - success = 3 - danger = 4 + return {"type": ComponentType.actionrow, "components": components} def emoji_to_dict(emoji: typing.Union[discord.Emoji, discord.PartialEmoji, str]) -> dict: @@ -103,7 +82,7 @@ def create_button( emoji = emoji_to_dict(emoji) data = { - "type": ComponentsType.button, + "type": ComponentType.button, "style": style, } @@ -146,7 +125,11 @@ def create_select_option( def create_select( - options: typing.List[dict], custom_id=None, placeholder=None, min_values=None, max_values=None + options: typing.List[dict], + custom_id=None, + placeholder=None, + min_values=None, + max_values=None, ): """ Creates a select (dropdown) component for use with the ``components`` field. Must be inside an ActionRow to be used (see :meth:`create_actionrow`). @@ -158,7 +141,7 @@ def create_select( raise IncorrectFormat("Options length should be between 1 and 25.") return { - "type": ComponentsType.select, + "type": ComponentType.select, "options": options, "custom_id": custom_id or str(uuid.uuid4()), "placeholder": placeholder or "", @@ -167,51 +150,77 @@ def create_select( } -async def wait_for_component( - client: discord.Client, component: typing.Union[dict, str], check=None, timeout=None -) -> ComponentContext: +def get_components_ids(component: typing.Union[str, dict, list]) -> typing.Iterator[str]: """ - Waits for a component interaction. Only accepts interactions based on the custom ID of the component, and optionally a check function. + Returns generator with 'custom_id' of component or components. - :param client: The client/bot object. - :type client: :class:`discord.Client` - :param component: The component dict or custom ID. - :type component: Union[dict, str] - :param check: Optional check function. Must take a `ComponentContext` as the first parameter. - :param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`. - :raises: :exc:`asyncio.TimeoutError` + :param component: Custom ID or component dict (actionrow or button) or list of previous two. """ - def _check(ctx): - if check and not check(ctx): - return False - return ( - component["custom_id"] if isinstance(component, dict) else component - ) == ctx.custom_id - - return await client.wait_for("component", check=_check, timeout=timeout) + if isinstance(component, str): + yield component + elif isinstance(component, dict): + if component["type"] == ComponentType.actionrow: + yield from (comp["custom_id"] for comp in component["components"]) + else: + yield component["custom_id"] + elif isinstance(component, list): + # Either list of components (actionrows or buttons) or list of ids + yield from (comp_id for comp in component for comp_id in get_components_ids(comp)) + else: + raise IncorrectType( + f"Unknown component type of {component} ({type(component)}). " + f"Expected str, dict or list" + ) + + +def _get_messages_ids(message: typing.Union[discord.Message, int, list]) -> typing.Iterator[int]: + if isinstance(message, int): + yield message + elif isinstance(message, discord.Message): + yield message.id + elif isinstance(message, list): + yield from (msg_id for msg in message for msg_id in _get_messages_ids(msg)) + else: + raise IncorrectType( + f"Unknown component type of {message} ({type(message)}). " + f"Expected discord.Message, int or list" + ) -async def wait_for_any_component( - client: discord.Client, message: typing.Union[discord.Message, int], check=None, timeout=None +async def wait_for_component( + client: discord.Client, + component: typing.Union[str, dict, list] = None, + message: typing.Union[discord.Message, int, list] = None, + check=None, + timeout=None, ) -> ComponentContext: """ - Waits for any component interaction. Only accepts interactions based on the message ID given and optionally a check function. + Helper function - wrapper around 'client.wait_for("component", ...)' + Waits for a component interaction. Only accepts interactions based on the custom ID of the component or/and message ID, and optionally a check function. :param client: The client/bot object. :type client: :class:`discord.Client` - :param message: The message object to check for, or the message ID. - :type message: Union[discord.Message, int] + :param component: Custom ID or component dict (actionrow or button) or list of previous two. + :param message: The message object to check for, or the message ID or list of previous two. + :type component: Union[dict, str] :param check: Optional check function. Must take a `ComponentContext` as the first parameter. :param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`. :raises: :exc:`asyncio.TimeoutError` """ - def _check(ctx): + if not (component or message): + raise IncorrectFormat("You must specify component or message (or both)") + + components_ids = list(get_components_ids(component)) if component else None + message_ids = list(_get_messages_ids(message)) if message else None + + def _check(ctx: ComponentContext): if check and not check(ctx): return False - return ( - message.id if isinstance(message, discord.Message) else message - ) == ctx.origin_message_id + # if components_ids is empty or there is a match + wanted_component = not components_ids or ctx.custom_id in components_ids + wanted_message = not message_ids or ctx.origin_message_id in message_ids + return wanted_component and wanted_message return await client.wait_for("component", check=_check, timeout=timeout)