diff --git a/interactions/__init__.py b/interactions/__init__.py index 42bf6933e..efcac1b65 100644 --- a/interactions/__init__.py +++ b/interactions/__init__.py @@ -133,6 +133,8 @@ FlatUIColors, FlatUIColours, get_components_ids, + global_autocomplete, + GlobalAutoComplete, Greedy, Guild, guild_only, @@ -444,7 +446,9 @@ "FlatUIColours", "get_components_ids", "get_logger", + "global_autocomplete", "GLOBAL_SCOPE", + "GlobalAutoComplete", "GlobalScope", "Greedy", "Guild", diff --git a/interactions/client/client.py b/interactions/client/client.py index db51d45db..75cea2268 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -26,6 +26,7 @@ import interactions.api.events as events import interactions.client.const as constants +from interactions.models.internal.callback import CallbackObject from interactions.api.events import BaseEvent, RawGatewayEvent, processors from interactions.api.events.internal import CallbackAdded from interactions.api.gateway.gateway import GatewayClient @@ -90,7 +91,7 @@ from interactions.models.discord.file import UPLOADABLE_TYPE from interactions.models.discord.snowflake import Snowflake, to_snowflake_list from interactions.models.internal.active_voice_state import ActiveVoiceState -from interactions.models.internal.application_commands import ContextMenu, ModalCommand +from interactions.models.internal.application_commands import ContextMenu, ModalCommand, GlobalAutoComplete from interactions.models.internal.auto_defer import AutoDefer from interactions.models.internal.command import BaseCommand from interactions.models.internal.context import ( @@ -378,6 +379,7 @@ def __init__( """A dictionary of registered application commands in a tree""" self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {} self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {} + self._global_autocompletes: Dict[str, GlobalAutoComplete] = {} self.processors: Dict[str, Callable[..., Coroutine]] = {} self.__modules = {} self.ext: Dict[str, Extension] = {} @@ -1256,6 +1258,15 @@ def add_modal_callback(self, command: ModalCommand) -> None: self._modal_callbacks[listener] = command continue + def add_global_autocomplete(self, callback: GlobalAutoComplete) -> None: + """ + Add a global autocomplete to the client. + + Args: + callback: The autocomplete to add + """ + self._global_autocompletes[callback.option_name] = callback + def add_command(self, func: Callable) -> None: """ Add a command to the client. @@ -1271,6 +1282,8 @@ def add_command(self, func: Callable) -> None: self.add_interaction(func) elif isinstance(func, Listener): self.add_listener(func) + elif isinstance(func, GlobalAutoComplete): + self.add_global_autocomplete(func) elif not isinstance(func, BaseCommand): raise TypeError("Invalid command type") @@ -1302,12 +1315,10 @@ def process(callables, location: str) -> None: self.logger.debug(f"{added} callbacks have been loaded from {location}.") main_commands = [ - obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, (BaseCommand, Listener)) + obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, CallbackObject) ] client_commands = [ - obj.copy_with_binding(self) - for _, obj in inspect.getmembers(self) - if isinstance(obj, (BaseCommand, Listener)) + obj.copy_with_binding(self) for _, obj in inspect.getmembers(self) if isinstance(obj, CallbackObject) ] process(main_commands, "__main__") process(client_commands, self.__class__.__name__) @@ -1597,7 +1608,6 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: elif autocomplete := self._global_autocompletes.get(str(auto_opt.name)): callback = autocomplete else: - breakpoint() raise ValueError(f"Autocomplete callback for {str(auto_opt.name)} not found") await self.__dispatch_interaction( diff --git a/interactions/models/__init__.py b/interactions/models/__init__.py index 1c7e93878..b7880056d 100644 --- a/interactions/models/__init__.py +++ b/interactions/models/__init__.py @@ -219,6 +219,8 @@ DMConverter, DMGroupConverter, Extension, + global_autocomplete, + GlobalAutoComplete, Greedy, guild_only, GuildCategoryConverter, @@ -382,6 +384,8 @@ "FlatUIColors", "FlatUIColours", "get_components_ids", + "global_autocomplete", + "GlobalAutoComplete", "Greedy", "Guild", "guild_only", diff --git a/interactions/models/internal/__init__.py b/interactions/models/internal/__init__.py index d68e13a98..1c9b0431f 100644 --- a/interactions/models/internal/__init__.py +++ b/interactions/models/internal/__init__.py @@ -22,6 +22,8 @@ ComponentCommand, context_menu, ContextMenu, + global_autocomplete, + GlobalAutoComplete, InteractionCommand, LocalisedDesc, LocalisedName, @@ -126,6 +128,8 @@ "DMConverter", "DMGroupConverter", "Extension", + "global_autocomplete", + "GlobalAutoComplete", "Greedy", "guild_only", "GuildCategoryConverter", diff --git a/interactions/models/internal/application_commands.py b/interactions/models/internal/application_commands.py index 820a34d99..86708cc89 100644 --- a/interactions/models/internal/application_commands.py +++ b/interactions/models/internal/application_commands.py @@ -39,6 +39,7 @@ from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake from interactions.models.discord.user import BaseUser from interactions.models.internal.auto_defer import AutoDefer +from interactions.models.internal.callback import CallbackObject from interactions.models.internal.command import BaseCommand from interactions.models.internal.localisation import LocalisedField @@ -48,28 +49,30 @@ from interactions import Client __all__ = ( - "OptionType", + "application_commands_to_dict", + "auto_defer", "CallbackType", - "InteractionCommand", - "ContextMenu", - "SlashCommandChoice", - "SlashCommandOption", - "SlashCommand", + "component_callback", "ComponentCommand", + "context_menu", + "ContextMenu", + "global_autocomplete", + "GlobalAutoComplete", + "InteractionCommand", + "LocalisedDesc", + "LocalisedName", + "LocalizedDesc", + "LocalizedName", "ModalCommand", + "OptionType", "slash_command", - "subcommand", - "context_menu", - "component_callback", - "slash_option", "slash_default_member_permission", - "auto_defer", - "application_commands_to_dict", + "slash_option", + "SlashCommand", + "SlashCommandChoice", + "SlashCommandOption", + "subcommand", "sync_needed", - "LocalisedName", - "LocalizedName", - "LocalizedDesc", - "LocalisedDesc", ) @@ -674,11 +677,36 @@ def _unpack_helper(iterable: typing.Iterable[str]) -> list[str]: return unpack +class GlobalAutoComplete(CallbackObject): + def __init__(self, option_name: str, callback: Callable) -> None: + self.callback = callback + self.option_name = option_name + + ############## # Decorators # ############## +def global_autocomplete(option_name: str) -> Callable[[AsyncCallable], GlobalAutoComplete]: + """ + Decorator for global autocomplete functions + + Args: + option_name: The name of the option to register the autocomplete function for + + Returns: + The decorator + """ + + def decorator(func: Callable) -> GlobalAutoComplete: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Autocomplete functions must be coroutines") + return GlobalAutoComplete(option_name, func) + + return decorator + + def slash_command( name: str | LocalisedName, *, diff --git a/interactions/models/internal/extension.py b/interactions/models/internal/extension.py index dd987de9d..326484f97 100644 --- a/interactions/models/internal/extension.py +++ b/interactions/models/internal/extension.py @@ -5,6 +5,7 @@ import interactions.models.internal as models import interactions.api.events as events +from interactions.models.internal.callback import CallbackObject from interactions.client.const import MISSING from interactions.client.utils.misc_utils import wrap_partial from interactions.models.internal.tasks import Task @@ -94,7 +95,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": instance._listeners = [] callables: list[tuple[str, typing.Callable]] = inspect.getmembers( - instance, predicate=lambda x: isinstance(x, (models.BaseCommand, models.Listener, Task)) + instance, predicate=lambda x: isinstance(x, (CallbackObject, Task)) ) for _name, val in callables: @@ -112,6 +113,10 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": val = wrap_partial(val, instance) bot.add_listener(val) # type: ignore instance._listeners.append(val) + elif isinstance(val, models.GlobalAutoComplete): + val.extension = instance + val = wrap_partial(val, instance) + bot.add_global_autocomplete(val) bot.dispatch(events.ExtensionCommandParse(extension=instance, callables=callables)) instance.extension_name = inspect.getmodule(instance).__name__ diff --git a/main.py b/main.py index 429782dcc..9b48667fa 100644 --- a/main.py +++ b/main.py @@ -2,8 +2,11 @@ import os import uuid +from thefuzz import process + import interactions -from interactions import Client, listen, slash_command, BrandColours +from interactions import Client, listen, slash_command, BrandColours, FlatUIColours, MaterialColours +from interactions.models.internal.application_commands import global_autocomplete, slash_option logging.basicConfig() logging.getLogger("interactions").setLevel(logging.DEBUG) @@ -104,4 +107,42 @@ async def multi_image_embed_test(ctx: interactions.SlashContext): await ctx.send(embeds=embed) +def get_colour(colour: str): + if colour in interactions.MaterialColors.__members__: + return interactions.MaterialColors[colour] + elif colour in interactions.BrandColors.__members__: + return interactions.BrandColors[colour] + elif colour in interactions.FlatUIColours.__members__: + return interactions.FlatUIColours[colour] + else: + return interactions.BrandColors.BLURPLE + + +@slash_command("test") +@slash_option("colour", "The colour to use", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True) +@slash_option("text", "some text", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True) +async def test(ctx: interactions.SlashContext, colour: str, text: str): + embed = interactions.Embed(f"{text} {colour.title()}", color=get_colour(colour)) + await ctx.send(embeds=embed) + + +@global_autocomplete("colour") +async def colour_autocomplete(ctx: interactions.AutocompleteContext): + colours = list((BrandColours.__members__ | FlatUIColours.__members__ | MaterialColours.__members__).keys()) + + if not ctx.input_text: + colours = colours[:25] + else: + results = process.extract(ctx.input_text, colours, limit=25) + colour_match = sorted([result for result in results if result[1] > 50], key=lambda x: x[1], reverse=True) + colours = [colour[0] for colour in colour_match] + + await ctx.send([{"name": colour.title(), "value": colour} for colour in colours]) + + +@test.autocomplete("text") +async def text_autocomplete(ctx: interactions.AutocompleteContext): + await ctx.send([{"name": c, "value": c} for c in ["colour", "color", "shade", "hue"]]) + + bot.start(os.environ["TOKEN"])