Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
FlatUIColors,
FlatUIColours,
get_components_ids,
global_autocomplete,
GlobalAutoComplete,
Greedy,
Guild,
guild_only,
Expand Down Expand Up @@ -444,7 +446,9 @@
"FlatUIColours",
"get_components_ids",
"get_logger",
"global_autocomplete",
"GLOBAL_SCOPE",
"GlobalAutoComplete",
"GlobalScope",
"Greedy",
"Guild",
Expand Down
22 changes: 16 additions & 6 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import interactions.api.events as events
import interactions.client.const as constants
from interactions.models.internal.callback import CallbackObject
from interactions.api.events import BaseEvent, RawGatewayEvent, processors
from interactions.api.events.internal import CallbackAdded
from interactions.api.gateway.gateway import GatewayClient
Expand Down Expand Up @@ -90,7 +91,7 @@
from interactions.models.discord.file import UPLOADABLE_TYPE
from interactions.models.discord.snowflake import Snowflake, to_snowflake_list
from interactions.models.internal.active_voice_state import ActiveVoiceState
from interactions.models.internal.application_commands import ContextMenu, ModalCommand
from interactions.models.internal.application_commands import ContextMenu, ModalCommand, GlobalAutoComplete
from interactions.models.internal.auto_defer import AutoDefer
from interactions.models.internal.command import BaseCommand
from interactions.models.internal.context import (
Expand Down Expand Up @@ -378,6 +379,7 @@ def __init__(
"""A dictionary of registered application commands in a tree"""
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._global_autocompletes: Dict[str, GlobalAutoComplete] = {}
self.processors: Dict[str, Callable[..., Coroutine]] = {}
self.__modules = {}
self.ext: Dict[str, Extension] = {}
Expand Down Expand Up @@ -1256,6 +1258,15 @@ def add_modal_callback(self, command: ModalCommand) -> None:
self._modal_callbacks[listener] = command
continue

def add_global_autocomplete(self, callback: GlobalAutoComplete) -> None:
"""
Add a global autocomplete to the client.

Args:
callback: The autocomplete to add
"""
self._global_autocompletes[callback.option_name] = callback

def add_command(self, func: Callable) -> None:
"""
Add a command to the client.
Expand All @@ -1271,6 +1282,8 @@ def add_command(self, func: Callable) -> None:
self.add_interaction(func)
elif isinstance(func, Listener):
self.add_listener(func)
elif isinstance(func, GlobalAutoComplete):
self.add_global_autocomplete(func)
elif not isinstance(func, BaseCommand):
raise TypeError("Invalid command type")

Expand Down Expand Up @@ -1302,12 +1315,10 @@ def process(callables, location: str) -> None:
self.logger.debug(f"{added} callbacks have been loaded from {location}.")

main_commands = [
obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, (BaseCommand, Listener))
obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, CallbackObject)
]
client_commands = [
obj.copy_with_binding(self)
for _, obj in inspect.getmembers(self)
if isinstance(obj, (BaseCommand, Listener))
obj.copy_with_binding(self) for _, obj in inspect.getmembers(self) if isinstance(obj, CallbackObject)
]
process(main_commands, "__main__")
process(client_commands, self.__class__.__name__)
Expand Down Expand Up @@ -1597,7 +1608,6 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None:
elif autocomplete := self._global_autocompletes.get(str(auto_opt.name)):
callback = autocomplete
else:
breakpoint()
raise ValueError(f"Autocomplete callback for {str(auto_opt.name)} not found")

await self.__dispatch_interaction(
Expand Down
4 changes: 4 additions & 0 deletions interactions/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@
DMConverter,
DMGroupConverter,
Extension,
global_autocomplete,
GlobalAutoComplete,
Greedy,
guild_only,
GuildCategoryConverter,
Expand Down Expand Up @@ -382,6 +384,8 @@
"FlatUIColors",
"FlatUIColours",
"get_components_ids",
"global_autocomplete",
"GlobalAutoComplete",
"Greedy",
"Guild",
"guild_only",
Expand Down
4 changes: 4 additions & 0 deletions interactions/models/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
ComponentCommand,
context_menu,
ContextMenu,
global_autocomplete,
GlobalAutoComplete,
InteractionCommand,
LocalisedDesc,
LocalisedName,
Expand Down Expand Up @@ -126,6 +128,8 @@
"DMConverter",
"DMGroupConverter",
"Extension",
"global_autocomplete",
"GlobalAutoComplete",
"Greedy",
"guild_only",
"GuildCategoryConverter",
Expand Down
60 changes: 44 additions & 16 deletions interactions/models/internal/application_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake
from interactions.models.discord.user import BaseUser
from interactions.models.internal.auto_defer import AutoDefer
from interactions.models.internal.callback import CallbackObject
from interactions.models.internal.command import BaseCommand
from interactions.models.internal.localisation import LocalisedField

Expand All @@ -48,28 +49,30 @@
from interactions import Client

__all__ = (
"OptionType",
"application_commands_to_dict",
"auto_defer",
"CallbackType",
"InteractionCommand",
"ContextMenu",
"SlashCommandChoice",
"SlashCommandOption",
"SlashCommand",
"component_callback",
"ComponentCommand",
"context_menu",
"ContextMenu",
"global_autocomplete",
"GlobalAutoComplete",
"InteractionCommand",
"LocalisedDesc",
"LocalisedName",
"LocalizedDesc",
"LocalizedName",
"ModalCommand",
"OptionType",
"slash_command",
"subcommand",
"context_menu",
"component_callback",
"slash_option",
"slash_default_member_permission",
"auto_defer",
"application_commands_to_dict",
"slash_option",
"SlashCommand",
"SlashCommandChoice",
"SlashCommandOption",
"subcommand",
"sync_needed",
"LocalisedName",
"LocalizedName",
"LocalizedDesc",
"LocalisedDesc",
)


Expand Down Expand Up @@ -674,11 +677,36 @@ def _unpack_helper(iterable: typing.Iterable[str]) -> list[str]:
return unpack


class GlobalAutoComplete(CallbackObject):
def __init__(self, option_name: str, callback: Callable) -> None:
self.callback = callback
self.option_name = option_name


##############
# Decorators #
##############


def global_autocomplete(option_name: str) -> Callable[[AsyncCallable], GlobalAutoComplete]:
"""
Decorator for global autocomplete functions

Args:
option_name: The name of the option to register the autocomplete function for

Returns:
The decorator
"""

def decorator(func: Callable) -> GlobalAutoComplete:
if not asyncio.iscoroutinefunction(func):
raise TypeError("Autocomplete functions must be coroutines")
return GlobalAutoComplete(option_name, func)

return decorator


def slash_command(
name: str | LocalisedName,
*,
Expand Down
7 changes: 6 additions & 1 deletion interactions/models/internal/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import interactions.models.internal as models
import interactions.api.events as events
from interactions.models.internal.callback import CallbackObject
from interactions.client.const import MISSING
from interactions.client.utils.misc_utils import wrap_partial
from interactions.models.internal.tasks import Task
Expand Down Expand Up @@ -94,7 +95,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
instance._listeners = []

callables: list[tuple[str, typing.Callable]] = inspect.getmembers(
instance, predicate=lambda x: isinstance(x, (models.BaseCommand, models.Listener, Task))
instance, predicate=lambda x: isinstance(x, (CallbackObject, Task))
)

for _name, val in callables:
Expand All @@ -112,6 +113,10 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
val = wrap_partial(val, instance)
bot.add_listener(val) # type: ignore
instance._listeners.append(val)
elif isinstance(val, models.GlobalAutoComplete):
val.extension = instance
val = wrap_partial(val, instance)
bot.add_global_autocomplete(val)
bot.dispatch(events.ExtensionCommandParse(extension=instance, callables=callables))

instance.extension_name = inspect.getmodule(instance).__name__
Expand Down
43 changes: 42 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import os
import uuid

from thefuzz import process

import interactions
from interactions import Client, listen, slash_command, BrandColours
from interactions import Client, listen, slash_command, BrandColours, FlatUIColours, MaterialColours
from interactions.models.internal.application_commands import global_autocomplete, slash_option

logging.basicConfig()
logging.getLogger("interactions").setLevel(logging.DEBUG)
Expand Down Expand Up @@ -104,4 +107,42 @@ async def multi_image_embed_test(ctx: interactions.SlashContext):
await ctx.send(embeds=embed)


def get_colour(colour: str):
if colour in interactions.MaterialColors.__members__:
return interactions.MaterialColors[colour]
elif colour in interactions.BrandColors.__members__:
return interactions.BrandColors[colour]
elif colour in interactions.FlatUIColours.__members__:
return interactions.FlatUIColours[colour]
else:
return interactions.BrandColors.BLURPLE


@slash_command("test")
@slash_option("colour", "The colour to use", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True)
@slash_option("text", "some text", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True)
async def test(ctx: interactions.SlashContext, colour: str, text: str):
embed = interactions.Embed(f"{text} {colour.title()}", color=get_colour(colour))
await ctx.send(embeds=embed)


@global_autocomplete("colour")
async def colour_autocomplete(ctx: interactions.AutocompleteContext):
colours = list((BrandColours.__members__ | FlatUIColours.__members__ | MaterialColours.__members__).keys())

if not ctx.input_text:
colours = colours[:25]
else:
results = process.extract(ctx.input_text, colours, limit=25)
colour_match = sorted([result for result in results if result[1] > 50], key=lambda x: x[1], reverse=True)
colours = [colour[0] for colour in colour_match]

await ctx.send([{"name": colour.title(), "value": colour} for colour in colours])


@test.autocomplete("text")
async def text_autocomplete(ctx: interactions.AutocompleteContext):
await ctx.send([{"name": c, "value": c} for c in ["colour", "color", "shade", "hue"]])


bot.start(os.environ["TOKEN"])