diff --git a/README.md b/README.md index fc90f7167..2640437ed 100644 --- a/README.md +++ b/README.md @@ -83,4 +83,4 @@ This library is based on gateway event. If you are looking for webserver based, [dispike](https://github.com/ms7m/dispike) [discord-interactions-python](https://github.com/discord/discord-interactions-python) Or for other languages: -[discord-api-docs Community Resources: Interactions](https://discord.com/developers/docs/topics/community-resources#interactions) +[discord-api-docs Community Resources: Interactions](https://discord.com/developers/docs/topics/community-resources#interactions) \ No newline at end of file diff --git a/discord_slash/__init__.py b/discord_slash/__init__.py index 332c97666..662495e71 100644 --- a/discord_slash/__init__.py +++ b/discord_slash/__init__.py @@ -11,6 +11,9 @@ from .client import SlashCommand from .model import SlashCommandOptionType from .context import SlashContext +from .context import ComponentContext +from .dpy_overrides import ComponentMessage from .utils import manage_commands +from .utils import manage_components __version__ = "1.2.2" diff --git a/discord_slash/client.py b/discord_slash/client.py index bb054d2ef..82bc521f8 100644 --- a/discord_slash/client.py +++ b/discord_slash/client.py @@ -9,6 +9,7 @@ from . import model from . import error from . import context +from . import dpy_overrides from .utils import manage_commands @@ -869,8 +870,21 @@ async def invoke_command(self, func, ctx, args): :param args: Args. Can be list or dict. """ try: - await func.invoke(ctx, args) + if isinstance(args, dict): + await func.invoke(ctx, **args) + else: + await func.invoke(ctx, *args) except Exception as ex: + if hasattr(func, "on_error"): + if func.on_error is not None: + try: + if hasattr(func, "cog"): + await func.on_error(func.cog, ctx, ex) + else: + await func.on_error(ctx, ex) + return + except Exception as e: + self.logger.error(f"{ctx.command}:: Error using error decorator: {e}") await self.on_slash_command_error(ctx, ex) async def on_socket_response(self, msg): @@ -886,10 +900,19 @@ async def on_socket_response(self, msg): return to_use = msg["d"] + interaction_type = to_use["type"] + if interaction_type in (1, 2): + return await self._on_slash(to_use) + if interaction_type == 3: + return await self._on_component(to_use) + + raise NotImplementedError - if to_use["type"] not in (1, 2): - return # to only process ack and slash-commands and exclude other interactions like buttons + async def _on_component(self, to_use): + ctx = context.ComponentContext(self.req, to_use, self._discord, self.logger) + self._discord.dispatch("component", ctx) + async def _on_slash(self, to_use): if to_use["data"]["name"] in self.commands: ctx = context.SlashContext(self.req, to_use, self._discord, self.logger) diff --git a/discord_slash/context.py b/discord_slash/context.py index 54dade599..fa56a10cb 100644 --- a/discord_slash/context.py +++ b/discord_slash/context.py @@ -1,3 +1,4 @@ +import datetime import typing import asyncio from warnings import warn @@ -5,27 +6,24 @@ import discord from contextlib import suppress from discord.ext import commands +from discord.utils import snowflake_time + from . import http from . import error from . import model +from . dpy_overrides import ComponentMessage -class SlashContext: +class InteractionContext: """ - Context of the slash command.\n + Base context for interactions.\n Kinda similar with discord.ext.commands.Context. .. warning:: Do not manually init this model. :ivar message: Message that invoked the slash command. - :ivar name: Name of the command. - :ivar args: List of processed arguments invoked with the command. - :ivar kwargs: Dictionary of processed arguments invoked with the command. - :ivar subcommand_name: Subcommand of the command. - :ivar subcommand_group: Subcommand group of the command. :ivar interaction_id: Interaction ID of the command message. - :ivar command_id: ID of the command. :ivar bot: discord.py client. :ivar _http: :class:`.http.SlashCommandRequest` of the client. :ivar _logger: Logger instance. @@ -43,15 +41,9 @@ def __init__(self, _json: dict, _discord: typing.Union[discord.Client, commands.Bot], logger): - self.__token = _json["token"] + self._token = _json["token"] self.message = None # Should be set later. - self.name = self.command = self.invoked_with = _json["data"]["name"] - self.args = [] - self.kwargs = {} - self.subcommand_name = self.invoked_subcommand = self.subcommand_passed = None - self.subcommand_group = self.invoked_subcommand_group = self.subcommand_group_passed = None self.interaction_id = _json["id"] - self.command_id = _json["data"]["id"] self._http = _http self.bot = _discord self._logger = logger @@ -67,6 +59,7 @@ def __init__(self, self.author = discord.User(data=_json["member"]["user"], state=self.bot._connection) else: self.author = discord.User(data=_json["user"], state=self.bot._connection) + self.created_at: datetime.datetime = snowflake_time(int(self.interaction_id)) @property def _deffered_hidden(self): @@ -118,7 +111,7 @@ async def defer(self, hidden: bool = False): if hidden: base["data"] = {"flags": 64} self._deferred_hidden = True - await self._http.post_initial_response(base, self.interaction_id, self.__token) + await self._http.post_initial_response(base, self.interaction_id, self._token) self.deferred = True async def send(self, @@ -130,7 +123,9 @@ async def send(self, files: typing.List[discord.File] = None, allowed_mentions: discord.AllowedMentions = None, hidden: bool = False, - delete_after: float = None) -> model.SlashMessage: + delete_after: float = None, + components: typing.List[dict] = None, + ) -> model.SlashMessage: """ Sends response of the slash command. @@ -157,6 +152,8 @@ async def send(self, :type hidden: bool :param delete_after: If provided, the number of seconds to wait in the background before deleting the message we just sent. If the deletion fails, then it is silently ignored. :type delete_after: float + :param components: Message components in the response. The top level must be made of ActionRows. + :type components: List[dict] :return: Union[discord.Message, dict] """ if embed and embeds: @@ -174,13 +171,16 @@ async def send(self, files = [file] if delete_after and hidden: raise error.IncorrectFormat("You can't delete a hidden message!") + if components and not all(comp.get("type") == 1 for comp in components): + raise error.IncorrectFormat("The top level of the components list must be made of ActionRows!") base = { "content": content, "tts": tts, "embeds": [x.to_dict() for x in embeds] if embeds else [], "allowed_mentions": allowed_mentions.to_dict() if allowed_mentions - else self.bot.allowed_mentions.to_dict() if self.bot.allowed_mentions else {} + else self.bot.allowed_mentions.to_dict() if self.bot.allowed_mentions else {}, + "components": components or [], } if hidden: base["flags"] = 64 @@ -196,21 +196,21 @@ async def send(self, "Deferred response might not be what you set it to! (hidden / visible) " "This is because it was deferred in a different state." ) - resp = await self._http.edit(base, self.__token, files=files) + resp = await self._http.edit(base, self._token, files=files) self.deferred = False else: json_data = { "type": 4, "data": base } - await self._http.post_initial_response(json_data, self.interaction_id, self.__token) + await self._http.post_initial_response(json_data, self.interaction_id, self._token) if not hidden: - resp = await self._http.edit({}, self.__token) + resp = await self._http.edit({}, self._token) else: resp = {} self.responded = True else: - resp = await self._http.post_followup(base, self.__token, files=files) + resp = await self._http.post_followup(base, self._token, files=files) if files: for file in files: file.close() @@ -219,7 +219,7 @@ async def send(self, data=resp, channel=self.channel or discord.Object(id=self.channel_id), _http=self._http, - interaction_token=self.__token) + interaction_token=self._token) if delete_after: self.bot.loop.create_task(smsg.delete(delay=delete_after)) if initial_message: @@ -227,3 +227,133 @@ async def send(self, return smsg else: return resp + + +class SlashContext(InteractionContext): + """ + Context of a slash command. Has all attributes from :class:`InteractionContext`, plus the slash-command-specific ones below. + + :ivar name: Name of the command. + :ivar args: List of processed arguments invoked with the command. + :ivar kwargs: Dictionary of processed arguments invoked with the command. + :ivar subcommand_name: Subcommand of the command. + :ivar subcommand_group: Subcommand group of the command. + :ivar command_id: ID of the command. + """ + def __init__(self, + _http: http.SlashCommandRequest, + _json: dict, + _discord: typing.Union[discord.Client, commands.Bot], + logger): + self.name = self.command = self.invoked_with = _json["data"]["name"] + self.args = [] + self.kwargs = {} + self.subcommand_name = self.invoked_subcommand = self.subcommand_passed = None + self.subcommand_group = self.invoked_subcommand_group = self.subcommand_group_passed = None + self.command_id = _json["data"]["id"] + + super().__init__(_http=_http, _json=_json, _discord=_discord, logger=logger) + + +class ComponentContext(InteractionContext): + """ + Context of a component interaction. Has all attributes from :class:`InteractionContext`, plus the component-specific ones below. + + :ivar custom_id: The custom ID of the component. + :ivar component_type: The type of the component. + :ivar origin_message: The origin message of the component. Not available if the origin message was ephemeral. + :ivar origin_message_id: The ID of the origin message. + """ + def __init__(self, + _http: http.SlashCommandRequest, + _json: dict, + _discord: typing.Union[discord.Client, commands.Bot], + logger): + self.custom_id = self.component_id = _json["data"]["custom_id"] + self.component_type = _json["data"]["component_type"] + super().__init__(_http=_http, _json=_json, _discord=_discord, logger=logger) + self.origin_message = None + self.origin_message_id = int(_json["message"]["id"]) if "message" in _json.keys() else None + + if self.origin_message_id and (_json["message"]["flags"] & 64) != 64: + self.origin_message = ComponentMessage(state=self.bot._connection, channel=self.channel, + data=_json["message"]) + + async def defer(self, hidden: bool = False, edit_origin: bool = False): + """ + 'Defers' the response, showing a loading state to the user + + :param hidden: Whether the deferred response should be ephemeral . Default ``False``. + :param edit_origin: Whether the response is editing the origin message. If ``False``, the deferred response will be for a follow up message. Defaults ``False``. + """ + if self.deferred or self.responded: + raise error.AlreadyResponded("You have already responded to this command!") + base = {"type": 6 if edit_origin else 5} + if hidden and not edit_origin: + base["data"] = {"flags": 64} + self._deferred_hidden = True + await self._http.post_initial_response(base, self.interaction_id, self._token) + self.deferred = True + + async def edit_origin(self, **fields): + """ + Edits the origin message of the component. + Refer to :meth:`discord.Message.edit` and :meth:`InteractionContext.send` for fields. + """ + _resp = {} + + content = fields.get("content") + if content: + _resp["content"] = str(content) + + embed = fields.get("embed") + embeds = fields.get("embeds") + file = fields.get("file") + files = fields.get("files") + components = fields.get("components") + + if components: + _resp["components"] = components + + if embed and embeds: + raise error.IncorrectFormat("You can't use both `embed` and `embeds`!") + if file and files: + raise error.IncorrectFormat("You can't use both `file` and `files`!") + if file: + files = [file] + if embed: + embeds = [embed] + if embeds: + if not isinstance(embeds, list): + raise error.IncorrectFormat("Provide a list of embeds.") + elif len(embeds) > 10: + raise error.IncorrectFormat("Do not provide more than 10 embeds.") + _resp["embeds"] = [x.to_dict() for x in embeds] + + allowed_mentions = fields.get("allowed_mentions") + _resp["allowed_mentions"] = allowed_mentions.to_dict() if allowed_mentions else \ + self.bot.allowed_mentions.to_dict() if self.bot.allowed_mentions else {} + + if not self.responded: + if files and not self.deferred: + await self.defer(edit_origin=True) + if self.deferred: + _json = await self._http.edit(_resp, self._token, files=files) + self.deferred = False + else: + json_data = { + "type": 7, + "data": _resp + } + _json = await self._http.post_initial_response(json_data, self.interaction_id, self._token) + self.responded = True + else: + raise error.IncorrectFormat("Already responded") + + if files: + for file in files: + file.close() + + # Commented out for now as sometimes (or at least, when not deferred) _json is an empty string? + # self.origin_message = ComponentMessage(state=self.bot._connection, channel=self.channel, + # data=_json) diff --git a/discord_slash/dpy_overrides.py b/discord_slash/dpy_overrides.py new file mode 100644 index 000000000..58d63c732 --- /dev/null +++ b/discord_slash/dpy_overrides.py @@ -0,0 +1,254 @@ +import discord +from discord.ext import commands +from discord import AllowedMentions, InvalidArgument, File +from discord.http import Route +from discord import http +from discord import abc +from discord import utils + + +class ComponentMessage(discord.Message): + __slots__ = tuple(list(discord.Message.__slots__) + ["components"]) + + def __init__(self, *, state, channel, data): + super().__init__(state=state, channel=channel, data=data) + self.components = data['components'] + + +def new_override(cls, *args, **kwargs): + if cls is discord.Message: + return object.__new__(ComponentMessage) + else: + return object.__new__(cls) + + +discord.message.Message.__new__ = new_override + + +def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, components=None, + nonce=None, allowed_mentions=None, message_reference=None): + r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + form = [] + + payload = {'tts': tts} + if content: + payload['content'] = content + if embed: + payload['embed'] = embed + if components: + payload['components'] = components + if nonce: + payload['nonce'] = nonce + if allowed_mentions: + payload['allowed_mentions'] = allowed_mentions + if message_reference: + payload['message_reference'] = message_reference + + form.append({'name': 'payload_json', 'value': utils.to_json(payload)}) + if len(files) == 1: + file = files[0] + form.append({ + 'name': 'file', + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream' + }) + else: + for index, file in enumerate(files): + form.append({ + 'name': 'file%s' % index, + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream' + }) + + return self.request(r, form=form, files=files) + + +def send_message(self, channel_id, content, *, tts=False, embed=None, components=None, + nonce=None, allowed_mentions=None, message_reference=None): + r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + payload = {} + + if content: + payload['content'] = content + + if tts: + payload['tts'] = True + + if embed: + payload['embed'] = embed + + if components: + payload['components'] = components + + if nonce: + payload['nonce'] = nonce + + if allowed_mentions: + payload['allowed_mentions'] = allowed_mentions + + if message_reference: + payload['message_reference'] = message_reference + + return self.request(r, json=payload) + + +http.HTTPClient.send_files = send_files +http.HTTPClient.send_message = send_message + + +async def send(self, content=None, *, tts=False, embed=None, file=None, components=None, + files=None, delete_after=None, nonce=None, + allowed_mentions=None, reference=None, + mention_author=None): + """|coro| + + Sends a message to the destination with the content given. + + The content must be a type that can convert to a string through ``str(content)``. + If the content is set to ``None`` (the default), then the ``embed`` parameter must + be provided. + + To upload a single file, the ``file`` parameter should be used with a + single :class:`~discord.File` object. To upload multiple files, the ``files`` + parameter should be used with a :class:`list` of :class:`~discord.File` objects. + **Specifying both parameters will lead to an exception**. + + If the ``embed`` parameter is provided, it must be of type :class:`~discord.Embed` and + it must be a rich embed type. + + Parameters + ------------ + content: :class:`str` + The content of the message to send. + tts: :class:`bool` + Indicates if the message should be sent using text-to-speech. + embed: :class:`~discord.Embed` + The rich embed for the content. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + nonce: :class:`int` + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + delete_after: :class:`float` + If provided, the number of seconds to wait in the background + before deleting the message we just sent. If the deletion fails, + then it is silently ignored. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + + .. versionadded:: 1.4 + + reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`] + A reference to the :class:`~discord.Message` to which you are replying, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control + whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user` + attribute of ``allowed_mentions`` or by setting ``mention_author``. + + .. versionadded:: 1.6 + + mention_author: Optional[:class:`bool`] + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + + .. versionadded:: 1.6 + + Raises + -------- + ~discord.HTTPException + Sending the message failed. + ~discord.Forbidden + You do not have the proper permissions to send the message. + ~discord.InvalidArgument + The ``files`` list is not of the appropriate size, + you specified both ``file`` and ``files``, + or the ``reference`` object is not a :class:`~discord.Message` + or :class:`~discord.MessageReference`. + + Returns + --------- + :class:`~discord.Message` + The message that was sent. + """ + + channel = await self._get_channel() + state = self._state + content = str(content) if content is not None else None + components = components or [] + if embed is not None: + embed = embed.to_dict() + + if allowed_mentions is not None: + if state.allowed_mentions is not None: + allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + else: + allowed_mentions = allowed_mentions.to_dict() + else: + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() + + if mention_author is not None: + allowed_mentions = allowed_mentions or AllowedMentions().to_dict() + allowed_mentions['replied_user'] = bool(mention_author) + + if reference is not None: + try: + reference = reference.to_message_reference_dict() + except AttributeError: + raise InvalidArgument('reference parameter must be Message or MessageReference') from None + + if file is not None and files is not None: + raise InvalidArgument('cannot pass both file and files parameter to send()') + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument('file parameter must be File') + + try: + data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, + content=content, tts=tts, embed=embed, nonce=nonce, + components=components, + message_reference=reference) + finally: + file.close() + + elif files is not None: + if len(files) > 10: + raise InvalidArgument('files parameter must be a list of up to 10 elements') + elif not all(isinstance(file, File) for file in files): + raise InvalidArgument('files parameter must be a list of File') + + try: + data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, + embed=embed, nonce=nonce, allowed_mentions=allowed_mentions, + components=components, + message_reference=reference) + finally: + for f in files: + f.close() + else: + data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, components=components, + nonce=nonce, allowed_mentions=allowed_mentions, + message_reference=reference) + + ret = state.create_message(channel=channel, data=data) + if delete_after is not None: + await ret.delete(delay=delete_after) + return ret + + +async def send_override(context_or_channel, *args, **kwargs): + if isinstance(context_or_channel, commands.Context): + channel = context_or_channel.channel + else: + channel = context_or_channel + + return await send(channel, *args, **kwargs) + +abc.Messageable.send = send_override diff --git a/discord_slash/model.py b/discord_slash/model.py index 5d4e96194..df9f7021b 100644 --- a/discord_slash/model.py +++ b/discord_slash/model.py @@ -1,10 +1,16 @@ import asyncio +import datetime + import discord from enum import IntEnum from contextlib import suppress from inspect import iscoroutinefunction + +from discord.ext.commands import CooldownMapping, CommandOnCooldown + from . import http from . import error +from . dpy_overrides import ComponentMessage class ChoiceData: @@ -134,38 +140,119 @@ def __init__(self, name, cmd): # Let's reuse old command formatting. if hasattr(self.func, '__commands_checks__'): self.__commands_checks__ = self.func.__commands_checks__ - async def invoke(self, *args): + cooldown = None + if hasattr(self.func, "__commands_cooldown__"): + cooldown = self.func.__commands_cooldown__ + self._buckets = CooldownMapping(cooldown) + + self._max_concurrency = None + if hasattr(self.func, "__commands_max_concurrency__"): + self._max_concurrency = self.func.__commands_max_concurrency__ + + self.on_error = None + + def error(self, coro): + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + self.on_error = coro + return coro + + def _prepare_cooldowns(self, ctx): + """ + Ref https://github.com/Rapptz/discord.py/blob/master/discord/ext/commands/core.py#L765 + """ + if self._buckets.valid: + dt = ctx.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + bucket = self._buckets.get_bucket(ctx, current) + retry_after = bucket.update_rate_limit(current) + if retry_after: + raise CommandOnCooldown(bucket, retry_after) + + async def _concurrency_checks(self, ctx): + """The checks required for cooldown and max concurrency.""" + # max concurrency checks + if self._max_concurrency is not None: + await self._max_concurrency.acquire(ctx) + try: + # cooldown checks + self._prepare_cooldowns(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) + raise + + async def invoke(self, *args, **kwargs): """ Invokes the command. :param args: Args for the command. :raises: .error.CheckFailure """ - args = list(args) - ctx = args.pop(0) - can_run = await self.can_run(ctx) + can_run = await self.can_run(args[0]) if not can_run: raise error.CheckFailure - coro = None # Get rid of annoying IDE complainings. + await self._concurrency_checks(args[0]) + + # to preventing needing different functions per object, + # this function simply handles cogs + if hasattr(self, "cog"): + return await self.func(self.cog, *args, **kwargs) + return await self.func(*args, **kwargs) + + def is_on_cooldown(self, ctx): + """Checks whether the command is currently on cooldown. + Ref https://github.com/Rapptz/discord.py/blob/master/discord/ext/commands/core.py#L797 + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to use when checking the commands cooldown status. + Returns + -------- + :class:`bool` + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False - not_kwargs = False - if args and isinstance(args[0], dict): - kwargs = args[0] - ctx.kwargs = kwargs - ctx.args = list(kwargs.values()) - try: - coro = self.func(ctx, **kwargs) - except TypeError: - args = list(kwargs.values()) - not_kwargs = True - else: - ctx.args = args - not_kwargs = True - if not_kwargs: - coro = self.func(ctx, *args) + bucket = self._buckets.get_bucket(ctx.message) + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_tokens(current) == 0 + + def reset_cooldown(self, ctx): + """Resets the cooldown on this command. + Ref https://github.com/Rapptz/discord.py/blob/master/discord/ext/commands/core.py#L818 + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + bucket.reset() + + def get_cooldown_retry_after(self, ctx): + """Retrieves the amount of seconds before this command can be tried again. + Ref https://github.com/Rapptz/discord.py/blob/master/discord/ext/commands/core.py#L830 + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to retrieve the cooldown from. + Returns + -------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_retry_after(current) - return await coro + return 0.0 def add_check(self, func): """ @@ -220,7 +307,6 @@ def __init__(self, name, cmd): # Let's reuse old command formatting. self.default_permission = cmd["default_permission"] self.permissions = cmd["api_permissions"] or {} - class SubcommandObject(CommandObject): """ Subcommand object of this extension. @@ -257,39 +343,6 @@ def __init__(self, *args): super().__init__(*args) self.cog = None # Manually set this later. - async def invoke(self, *args, **kwargs): - """ - Invokes the command. - - :param args: Args for the command. - :raises: .error.CheckFailure - """ - args = list(args) - ctx = args.pop(0) - can_run = await self.can_run(ctx) - if not can_run: - raise error.CheckFailure - - coro = None # Get rid of annoying IDE complainings. - - not_kwargs = False - if args and isinstance(args[0], dict): - kwargs = args[0] - ctx.kwargs = kwargs - ctx.args = list(kwargs.values()) - try: - coro = self.func(self.cog, ctx, **kwargs) - except TypeError: - args = list(kwargs.values()) - not_kwargs = True - else: - ctx.args = args - not_kwargs = True - if not_kwargs: - coro = self.func(self.cog, ctx, *args) - - return await coro - class CogSubcommandObject(SubcommandObject): """ @@ -304,39 +357,6 @@ def __init__(self, base, cmd, sub_group, name, sub): self.base_command_data = cmd self.cog = None # Manually set this later. - async def invoke(self, *args, **kwargs): - """ - Invokes the command. - - :param args: Args for the command. - :raises: .error.CheckFailure - """ - args = list(args) - ctx = args.pop(0) - can_run = await self.can_run(ctx) - if not can_run: - raise error.CheckFailure - - coro = None # Get rid of annoying IDE complainings. - - not_kwargs = False - if args and isinstance(args[0], dict): - kwargs = args[0] - ctx.kwargs = kwargs - ctx.args = list(kwargs.values()) - try: - coro = self.func(self.cog, ctx, **kwargs) - except TypeError: - args = list(kwargs.values()) - not_kwargs = True - else: - ctx.args = args - not_kwargs = True - if not_kwargs: - coro = self.func(self.cog, ctx, *args) - - return await coro - class SlashCommandOptionType(IntEnum): """ @@ -368,7 +388,7 @@ def from_type(cls, t: type): if issubclass(t, discord.abc.Role): return cls.ROLE -class SlashMessage(discord.Message): +class SlashMessage(ComponentMessage): """discord.py's :class:`discord.Message` but overridden ``edit`` and ``delete`` to work for slash command.""" def __init__(self, *, state, channel, data, _http: http.SlashCommandRequest, interaction_token): @@ -391,6 +411,10 @@ async def _slash_edit(self, **fields): embeds = fields.get("embeds") file = fields.get("file") files = fields.get("files") + components = fields.get("components") + + if components: + _resp["components"] = components if embed and embeds: raise error.IncorrectFormat("You can't use both `embed` and `embeds`!") diff --git a/discord_slash/utils/manage_components.py b/discord_slash/utils/manage_components.py new file mode 100644 index 000000000..add1a104e --- /dev/null +++ b/discord_slash/utils/manage_components.py @@ -0,0 +1,201 @@ +import uuid +import enum +import typing +import discord +from ..context import ComponentContext +from ..error import IncorrectFormat + + +class ComponentsType(enum.IntEnum): + actionrow = 1 + button = 2 + select = 3 + + +def create_actionrow(*components: dict) -> dict: + """ + Creates an ActionRow for message components. + + :param components: Components to go within the ActionRow. + :return: dict + """ + if not components or len(components) > 5: + raise IncorrectFormat("Number of components in one row should be between 1 and 5.") + if ComponentsType.select in [component["type"] for component in components] and len(components) > 1: + raise IncorrectFormat("Action row must have only one select component and nothing else") + + return { + "type": ComponentsType.actionrow, + "components": components + } + + +class ButtonStyle(enum.IntEnum): + blue = 1 + blurple = 1 + gray = 2 + grey = 2 + green = 3 + red = 4 + URL = 5 + + primary = 1 + secondary = 2 + success = 3 + danger = 4 + + +def emoji_to_dict(emoji: typing.Union[discord.Emoji, discord.PartialEmoji, str]) -> dict: + """ + Converts a default or custom emoji into a partial emoji dict. + + :param emoji: The emoji to convert. + :type emoji: Union[discord.Emoji, discord.PartialEmoji, str] + """ + if isinstance(emoji, discord.Emoji): + emoji = {"name": emoji.name, "id": emoji.id, "animated": emoji.animated} + elif isinstance(emoji, str): + emoji = {"name": emoji, "id": None} + return emoji if emoji else {} + + +def create_button(style: typing.Union[ButtonStyle, int], + label: str = None, + emoji: typing.Union[discord.Emoji, discord.PartialEmoji, str] = None, + custom_id: str = None, + url: str = None, + disabled: bool = False) -> dict: + """ + Creates a button component for use with the ``components`` field. Must be inside an ActionRow to be used (see :meth:`create_actionrow`). + + .. note:: + At least a label or emoji is required for a button. You can have both, but not neither of them. + + :param style: Style of the button. Refer to :class:`ButtonStyle`. + :type style: Union[ButtonStyle, int] + :param label: The label of the button. + :type label: Optional[str] + :param emoji: The emoji of the button. + :type emoji: Union[discord.Emoji, discord.PartialEmoji, dict] + :param custom_id: The custom_id of the button. Needed for non-link buttons. + :type custom_id: Optional[str] + :param url: The URL of the button. Needed for link buttons. + :type url: Optional[str] + :param disabled: Whether the button is disabled or not. Defaults to `False`. + :type disabled: bool + :returns: :class:`dict` + """ + if style == ButtonStyle.URL: + if custom_id: + raise IncorrectFormat("A link button cannot have a `custom_id`!") + if not url: + raise IncorrectFormat("A link button must have a `url`!") + elif url: + raise IncorrectFormat("You can't have a URL on a non-link button!") + + if not label and not emoji: + raise IncorrectFormat("You must have at least a label or emoji on a button.") + + emoji = emoji_to_dict(emoji) + + data = { + "type": ComponentsType.button, + "style": style, + } + + if label: + data["label"] = label + if emoji: + data["emoji"] = emoji + if disabled: + data["disabled"] = disabled + + if style == ButtonStyle.URL: + data["url"] = url + else: + data["custom_id"] = custom_id or str(uuid.uuid4()) + + return data + + +def create_select_option(label: str, value: str, emoji=None, description: str = None, default: bool = False): + """ + Creates an option for select components. + + :param label: The label of the option. + :param value: The value that the bot will recieve when this option is selected. + :param emoji: The emoji of the option. + :param description: A description of the option. + :param default: Whether or not this is the default option. + """ + emoji = emoji_to_dict(emoji) + + return { + "label": label, + "value": value, + "description": description, + "default": default, + "emoji": emoji + } + + +def create_select(options: typing.List[dict], custom_id=None, placeholder=None, min_values=None, max_values=None): + """ + Creates a select (dropdown) component for use with the ``components`` field. Must be inside an ActionRow to be used (see :meth:`create_actionrow`). + + .. warning:: + Currently, select components are not available for public use, nor have official documentation. The parameters will not be documented at this time. + """ + if not len(options) or len(options) > 25: + raise IncorrectFormat("Options length should be between 1 and 25.") + + return { + "type": ComponentsType.select, + "options": options, + "custom_id": custom_id or str(uuid.uuid4()), + "placeholder": placeholder or "", + "min_values": min_values, + "max_values": max_values, + } + + +async def wait_for_component(client: discord.Client, component: typing.Union[dict, str], check=None, timeout=None) \ + -> ComponentContext: + """ + Waits for a component interaction. Only accepts interactions based on the custom ID of the component, and optionally a check function. + + :param client: The client/bot object. + :type client: :class:`discord.Client` + :param component: The component dict or custom ID. + :type component: Union[dict, str] + :param check: Optional check function. Must take a `ComponentContext` as the first parameter. + :param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`. + :raises: :exc:`asyncio.TimeoutError` + """ + def _check(ctx): + if check and not check(ctx): + return False + return (component["custom_id"] if isinstance(component, dict) else component) == ctx.custom_id + + return await client.wait_for("component", check=_check, timeout=timeout) + + +async def wait_for_any_component(client: discord.Client, message: typing.Union[discord.Message, int], + check=None, timeout=None) -> ComponentContext: + """ + Waits for any component interaction. Only accepts interactions based on the message ID given and optionally a check function. + + :param client: The client/bot object. + :type client: :class:`discord.Client` + :param message: The message object to check for, or the message ID. + :type message: Union[discord.Message, int] + :param check: Optional check function. Must take a `ComponentContext` as the first parameter. + :param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`. + :raises: :exc:`asyncio.TimeoutError` + """ + def _check(ctx): + if check and not check(ctx): + return False + return (message.id if isinstance(message, discord.Message) else message) == ctx.origin_message_id + + return await client.wait_for("component", check=_check, timeout=timeout) diff --git a/docs/conf.py b/docs/conf.py index 4fd63b95f..8cb9de81e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,4 +68,4 @@ intersphinx_mapping = { 'py': ('https://docs.python.org/3', None), 'discord': ("https://discordpy.readthedocs.io/en/latest/", None) -} +} \ No newline at end of file diff --git a/docs/discord_slash.utils.manage_components.rst b/docs/discord_slash.utils.manage_components.rst new file mode 100644 index 000000000..b2103f5f4 --- /dev/null +++ b/docs/discord_slash.utils.manage_components.rst @@ -0,0 +1,7 @@ +discord\_slash.utils.manage\_components module +============================================== + +.. automodule:: discord_slash.utils.manage_components + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/discord_slash.utils.rst b/docs/discord_slash.utils.rst index b9abe14c6..dd460bd73 100644 --- a/docs/discord_slash.utils.rst +++ b/docs/discord_slash.utils.rst @@ -8,6 +8,7 @@ Submodules :maxdepth: 4 discord_slash.utils.manage_commands + discord_slash.utils.manage_components Module contents --------------- diff --git a/docs/events.rst b/docs/events.rst index d8cbf671f..f695fd5e4 100644 --- a/docs/events.rst +++ b/docs/events.rst @@ -20,3 +20,10 @@ These events can be registered to discord.py's listener or :param ex: Exception that raised. :type ex: Exception +.. function:: on_component(ctx) + + Called when a component is triggered. + + :param ctx: ComponentContext of the triggered component. + :type ctx: :class:`.model.ComponentContext` +