From 895f0aefdc13ce799485ef03f1e774a7a01c5077 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Thu, 1 Dec 2022 21:35:31 -0800 Subject: [PATCH 01/13] refactor!: rewrite dispatcher & event --- discatcore/utils/dispatcher.py | 298 ++++++++++++++++++++------------- discatcore/utils/event.py | 290 ++++---------------------------- discatcore/utils/json.py | 7 +- 3 files changed, 219 insertions(+), 376 deletions(-) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 67d1c1a..08996db 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -1,162 +1,228 @@ # SPDX-License-Identifier: MIT +from __future__ import annotations + import asyncio import inspect import logging +import sys import traceback import typing as t -from collections.abc import Callable, Coroutine - -from .event import Event - -_log = logging.getLogger(__name__) - -__all__ = ("Dispatcher",) - -T = t.TypeVar("T") -Func = Callable[..., T] -CoroFunc = Func[Coroutine[t.Any, t.Any, t.Any]] +from collections import defaultdict +import attr +from typing_extensions import Self, TypeGuard -class Dispatcher: - """A class that helps manage events. +from .event import Event, EventT, ExceptionEvent +from .json import JSONObject - Attributes: - events (dict[str, Event]): The callbacks for each event. - """ +if t.TYPE_CHECKING: + from ..gateway import GatewayClient - __slots__ = ("events",) +if sys.version_info >= (3, 10): + from types import UnionType - def __init__(self) -> None: - self.events: dict[str, Event] = {} + _union_types = {t.Union, UnionType} +else: + _union_types = {t.Union} - def get_event(self, name: str) -> t.Optional[Event]: - """Returns an event with the name provided. +_log = logging.getLogger(__name__) - Args: - name (str): The name of the event that will be returned. +__all__ = ( + "Consumer", + "consumer_for", + "Dispatcher", +) - Returns: - The event, none if not found. - """ - return self.events.get(name) +T = t.TypeVar("T") +DispatcherT = t.TypeVar("DispatcherT", bound="Dispatcher") +ListenerCallbackT = t.TypeVar("ListenerCallbackT", bound="ListenerCallback[Event]") +Coro = t.Coroutine[T, t.Any, t.Any] - def new_event(self, name: str) -> Event: - """Creates a new event. Returns this new event after creation. +ListenerCallback = t.Callable[[EventT], Coro[None]] +ConsumerCallback = t.Callable[[DispatcherT, GatewayClient, JSONObject], Coro[None]] - Args: - name (str): The name of the new event. - Returns: - The new event created. - """ - new_event = Event(name, self) - self.events[name] = new_event - return new_event +@attr.define +class Consumer(t.Generic[DispatcherT]): + """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" - def add_event(self, event: Event, *, override: bool = True) -> None: - """Adds a new pre-existing event. + callback: ConsumerCallback[DispatcherT] + events: tuple[type[Event], ...] - Args: - event (Event): The event to add. - """ - if self.has_event(event.name) and not override: - return - self.events[event.name] = event +def consumer_for( + *event_types: type[Event], +) -> t.Callable[[ConsumerCallback[DispatcherT]], Consumer[DispatcherT]]: + event_types = tuple({event for event_type in event_types for event in event_type.dispatches}) - def remove_event(self, name: str) -> None: - """Removes an event. + def wrapper(func: ConsumerCallback[DispatcherT]) -> Consumer[DispatcherT]: + return Consumer(func, event_types) - Args: - name (str): The name of the event to remove. - """ - if name not in self.events: - raise ValueError(f"There is no event with name {name}!") + return wrapper - del self.events[name] - _log.debug("Removed event with name %s", name) - def has_event(self, name: str) -> bool: - """Check if this dispatcher already has a event. +def _is_exception_event(e: EventT) -> TypeGuard[ExceptionEvent[EventT]]: + return isinstance(e, ExceptionEvent) - Args: - name (str): The name of the event to find. - Returns: - A bool correlating to if there is a event with that name or not. - """ - return name in self.events +class Dispatcher: + """A class that helps manage events.""" - def callback_for( - self, event: str, *, one_shot: bool = False, force_parent: bool = False - ) -> Callable[[CoroFunc], Event]: - """A shortcut decorator to add a callback to an event. - If the event does not exist already, then a new one will be created. + __slots__ = ("_listeners", "_consumers") - Args: - event: The name of the event to get or create. - one_shot: Whether or not the callback should be a one shot (which means the callback will be removed after running). Defaults to False. - force_parent: Whether or not this callback contains a self parameter. Defaults to False. + def __init__(self) -> None: + self._listeners: defaultdict[type[Event], list[ListenerCallback[Event]]] = defaultdict( + lambda: [] + ) + self._consumers: dict[str, Consumer[Self]] = {} + + for name, value in inspect.getmembers(self): + if not isinstance(value, Consumer): + continue + + self._consumers[name.lower()] = value + + async def _run_listener(self, event: EventT, listener: ListenerCallback[EventT]) -> None: + try: + await listener(event) + except asyncio.CancelledError: + pass + except Exception as e: + if _is_exception_event(event): + _log.error( + "There was an error while running the listener callback (%s%s) under exception event %s.%s: %s", + listener.__name__, + inspect.signature(listener), + type(event).__module__, + type(event).__qualname__, + traceback.format_exception(type(e), e, e.__traceback__), + ) + else: + exec_event = ExceptionEvent( + exception=e, failed_event=event, failed_listener=listener + ) + + _log.info( + "An exception occured while handling %s.%s.", + type(event).__module__, + type(event).__qualname__, + ) + await self.dispatch(exec_event) + + async def _handle_consumer( + self, consumer: ConsumerCallback[Self], gateway: GatewayClient, payload: JSONObject + ): + try: + await consumer(self, gateway, payload) + except asyncio.CancelledError: + pass + except Exception as e: + asyncio.get_running_loop().call_exception_handler( + { + "message": "An exception occured while consuming a raw event.", + "exception": e, + "task": asyncio.current_task(), + } + ) - Returns: - A wrapper function that acts as the actual decorator. - """ + def subscribe(self, event: type[EventT], func: ListenerCallback[EventT]) -> None: + if not asyncio.iscoroutinefunction(func): + raise TypeError(f"listener callback {func.__name__!r} has to be a coroutine function!") + + _log.debug( + "Subscribing listener callback (%s%s) to event %s.%s", + func.__name__, + inspect.signature(func), + event.__module__, + event.__qualname__, + ) + self._listeners[event].append(func) # pyright: ignore + + def unsubscribe(self, event: type[EventT], func: ListenerCallback[EventT]) -> None: + listeners = self._listeners.get(event) + if not listeners: + return - def wrapper(coro: CoroFunc): - if not self.has_event(event): - event_cls = self.new_event(event) + _log.debug( + "Unsubscribing listener callback (%s%s) from event %s.%s", + func.__name__, + inspect.signature(func), + event.__module__, + event.__qualname__, + ) + listeners.remove(func) # pyright: ignore + + if not listeners: + del self._listeners[event] + + def listen_to(self, *events: type[Event]) -> t.Callable[[ListenerCallbackT], ListenerCallbackT]: + def wrapper(func: ListenerCallbackT) -> ListenerCallbackT: + func_sig = inspect.signature(func) + event_arg = next(iter(func_sig.parameters.values())) + event_arg_anno = event_arg.annotation + + resolved_events: set[type[Event]] + if event_arg_anno is inspect.Parameter.empty: + if events: + resolved_events = set(events) + else: + raise TypeError( + "No event type was provided! Please provide it as an argument or a type hint." + ) else: - event_cls = self.events[event] - event_cls.add_callback(coro, one_shot=one_shot, force_parent=force_parent) - return event_cls + def event_check(arg: t.Any) -> None: + if not isinstance(arg, type) or not issubclass(arg, Event): + raise TypeError(f"Expected an event, got {arg!r}.") - return wrapper + if t.get_origin(event_arg_anno) in _union_types: + union_args = t.get_args(event_arg_anno) - # global error handler + for arg in union_args: + event_check(arg) - async def error_handler(self, exception: Exception) -> None: - """Basic error handler for dispatched events. + resolved_events = t.cast(set[type[Event]], set(union_args)) + else: + event_check(event_arg_anno) + resolved_events = {t.cast(type[Event], event_arg_anno)} - Args: - exception (Exception): The exception from the dispatched event. - """ - traceback.print_exception(type(exception), exception, exception.__traceback__) + for event in resolved_events: + self.subscribe(event, func) - def override_error_handler(self, func: CoroFunc) -> None: - """Overrides a new error handler for dispatched events. + return func - Args: - func (CoroFunc): The new error handler. - """ - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback provided is not a coroutine.") + return wrapper - orig_handler_sig = inspect.signature(self.error_handler) - new_handler_sig = inspect.signature(func) + def dispatch(self, event: Event) -> asyncio.Future[t.Any]: + _log.debug( + "Dispatching event %s.%s (which dispatches event(s) %r).", + type(event).__module__, + type(event).__qualname__, + [f"{e.__module__}.{e.__qualname__}" for e in event.dispatches], + ) + dispatched: t.List[Coro[None]] = [] - if orig_handler_sig.parameters != new_handler_sig.parameters: - raise TypeError( - "Overloaded error handler does not have the same parameters as original error handler." - ) + for event_type in event.dispatches: + for listener in self._listeners.get(event_type, []): + dispatched.append(self._run_listener(event, listener)) - setattr(self, "error_handler", func) - _log.debug("Registered new error handler") + def _completed_future() -> asyncio.Future[None]: + future = asyncio.get_running_loop().create_future() + future.set_result(None) + return future - # dispatch + return asyncio.gather(*dispatched) if dispatched else _completed_future() - def dispatch(self, name: str, *args: t.Any, **kwargs: t.Any) -> None: - """Dispatches a event. This will trigger the all of the event's - callbacks. + def consume(self, event: str, gateway: GatewayClient, payload: JSONObject): + consumer = self._consumers.get(event) - Args: - name (str): The name of the event to dispatch. - *args (t.Any): Arguments to pass into the event. - **kwargs (t.Any): Keyword arguments to pass into the event. - """ - _log.debug("Dispatching event %s", name) - event = self.events.get(name) + if not consumer: + _log.info("Consumer %s does not exist. Skipping consumption.", event) + return - if event is not None: - event.dispatch(*args, **kwargs) + _log.debug("Consuming raw event %s.", event) + asyncio.create_task( + self._handle_consumer(consumer.callback, gateway, payload), + name=f"DisCatCore Consumer {event}", + ) diff --git a/discatcore/utils/event.py b/discatcore/utils/event.py index d57a45b..bf34347 100644 --- a/discatcore/utils/event.py +++ b/discatcore/utils/event.py @@ -2,282 +2,54 @@ from __future__ import annotations -import asyncio -import inspect -import logging import typing as t -from collections.abc import Callable, Coroutine -from dataclasses import dataclass -if t.TYPE_CHECKING: - from .dispatcher import Dispatcher +import attr +from typing_extensions import Self -_log = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from .dispatcher import ListenerCallback -__all__ = ("Event",) +__all__ = ("Event", "ExceptionEvent") T = t.TypeVar("T") -Func = Callable[..., T] -CoroFunc = Func[Coroutine[t.Any, t.Any, t.Any]] - - -@dataclass -class _EventCallbackMetadata: - one_shot: bool = False - parent: bool = False - - -class Event: - """Represents an event for a dispatcher. - - Args: - name (str): The name of this event. - parent (Dispatcher): The parent dispatcher of this event. - - Attributes: - name (str): The name of this event. - parent (Dispatcher): The parent dispatcher of this event. - callbacks (list[Callable[..., Coroutine[t.Any, t.Any, t.Any]]]): The callbacks for this event. - metadata (dict[Callable[..., Coroutine[t.Any, t.Any, t.Any]], _EventCallbackMetadata]): The metadata for the callbacks for this event. - _proto (t.Optional[inspect.Signature]): The prototype of this event. - This will define what signature all of the callbacks will have. - _error_handler (Callable[..., Coroutine[t.Any, t.Any, t.Any]]): The error handler of this event. - The error handler will be run whenever an event dispatched raises an error. - Defaults to the error handler from the parent dispatcher. - """ - - def __init__(self, name: str, parent: Dispatcher) -> None: - self.name: str = name - self.parent: Dispatcher = parent - self.callbacks: list[CoroFunc] = [] - self.metadata: dict[CoroFunc, _EventCallbackMetadata] = {} - self._proto: t.Optional[inspect.Signature] = None - self._error_handler: CoroFunc = self.parent.error_handler - - # setters/decorators - - def set_proto( - self, proto_func: t.Union[Func[t.Any], staticmethod[t.Any]], *, force_parent: bool = False - ) -> None: - """Sets the prototype for this event. - - Args: - proto_func (Callable[..., t.Any]): The prototype for this event. - force_parent (bool): Whether or not this callback contains a self parameter. Defaults to ``False``. - """ - is_static = isinstance(proto_func, staticmethod) - if is_static: - proto_func = proto_func.__func__ - - if not self._proto: - sig = inspect.signature(proto_func) - if force_parent and not is_static: - new_params = list(sig.parameters.values()) - new_params.pop(0) - sig = sig.replace(parameters=new_params) - self._proto = sig - - _log.debug("Registered new event prototype under event %s", self.name) - else: - raise ValueError(f"Event prototype for event {self.name} has already been set!") - - @t.overload - def proto(self, func: CoroFunc, *, force_parent: bool = ...) -> Event: - pass - - @t.overload - def proto( - self, func: None = ..., *, force_parent: bool = ... - ) -> Callable[[Func[t.Any]], Event]: - pass - - def proto( - self, - func: t.Optional[t.Union[Func[t.Any], staticmethod[t.Any]]] = None, - *, - force_parent: bool = False, - ) -> t.Union[Event, Callable[[Func[t.Any]], Event]]: - """A decorator to set the prototype of this event. - - Args: - func (t.Optional[Callable[..., t.Any]]): The prototype to pass into this decorator. Defaults to ``None``. - force_parent (bool): Whether or not this callback contains a self parameter. Defaults to ``False``. - - Returns: - Either this event object or a wrapper function that acts as the actual decorator. - This depends on if the ``func`` arg was passed in. - """ - - def wrapper(func: t.Union[Func[t.Any], staticmethod[t.Any]]): - self.set_proto(func, force_parent=force_parent) - return self - - if func: - return wrapper(func) - return wrapper +EventT = t.TypeVar("EventT", bound="Event") - def set_error_handler(self, func: CoroFunc) -> None: - """Overrides the error handler of this event. - Args: - func (Callable[..., Coroutine[t.Any, t.Any, t.Any]]): The new error handler for this event. - """ - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback provided is not a coroutine.") +class _classproperty(t.Generic[T]): + def __init__(self, fget: t.Callable[[t.Any], T], /) -> None: + self.fget: "classmethod[T]" = t.cast("classmethod[T]", fget) - orig_handler_sig = inspect.signature(self._error_handler) - new_handler_sig = inspect.signature(func) + def getter(self, fget: t.Callable[[t.Any], T], /) -> Self: + self.fget = t.cast("classmethod[T]", fget) + return self - if len(orig_handler_sig.parameters) != len(new_handler_sig.parameters): - raise TypeError( - "Overloaded error handler does not have the same parameters as original error handler." - ) + def __get__(self, obj: t.Optional[t.Any], type: t.Optional[type]) -> T: + return self.fget.__func__(type) - self._error_handler = func - _log.debug("Registered new error handler under event %s", self.name) - def error_handler(self) -> Callable[[Func[t.Any]], Event]: - """A decorator to override the error handler of this event. - - Returns: - A wrapper function that acts as the actual decorator. - """ - - def wrapper(func: CoroFunc): - self.set_error_handler(func) - return self - - return wrapper - - def add_callback( - self, func: CoroFunc, *, one_shot: bool = False, force_parent: bool = False - ) -> None: - """Adds a new callback to this event. - - Args: - func (Callable[..., Coroutine[t.Any, t.Any, t.Any]]): The callback to add to this event. - one_shot (bool): Whether or not the callback should be a one shot (which means the callback will be removed after running). Defaults to False. - force_parent (bool): Whether or not this callback contains a self parameter. Defaults to False. - """ - if not self._proto: - self.set_proto(func, force_parent=force_parent) - # this is to prevent static type checkers from inferring that self._proto is - # still None after setting it indirectly via a different function - # (it should never go here tho because exceptions stop the flow of this code - # and it should be set if we don't reach t.Any exceptions) - if not self._proto: - return - - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback provided is not a coroutine.") - - callback_sig = inspect.signature(func) - if force_parent: - new_params = list(callback_sig.parameters.values()) - new_params.pop(0) - callback_sig = callback_sig.replace(parameters=new_params) - - if len(self._proto.parameters) != len(callback_sig.parameters): - raise TypeError( - "Event callback parameters do not match up with the event prototype parameters." - ) - - metadat = _EventCallbackMetadata(one_shot) - self.metadata[func] = metadat - self.callbacks.append(func) - - _log.debug("Registered new event callback under event %s", self.name) - - def remove_callback(self, index: int) -> None: - """Removes a callback located at a certain index. - - Args: - index (int): The index where the callback is located. - """ - if len(self.callbacks) - 1 < index: - raise ValueError(f"Event {self.name} has less callbacks than the index provided!") - - del self.callbacks[index] - _log.debug("Removed event callback with index %d under event %s", index, self.name) - - @t.overload - def callback(self, func: CoroFunc, *, one_shot: bool = ..., force_parent: bool = ...) -> Event: - pass - - @t.overload - def callback( - self, func: None = ..., *, one_shot: bool = ..., force_parent: bool = ... - ) -> Callable[[Func[t.Any]], Event]: - pass - - def callback( - self, - func: t.Optional[CoroFunc] = None, - *, - one_shot: bool = False, - force_parent: bool = False, - ) -> t.Union[Event, Callable[[Func[t.Any]], Event]]: - """A decorator to add a callback to this event. - - Args: - func (t.Optional[Callable[..., Coroutine[t.Any, t.Any, t.Any]]]): The function to pass into this decorator. Defaults to None. - one_shot (bool): Whether or not the callback should be a one shot (which means the callback will be removed after running). Defaults to False. - force_parent (bool): Whether or not this callback contains a self parameter. Defaults to False. - - Returns: - Either this event object or a wrapper function that acts as the actual decorator. - This depends on if the ``func`` arg was passed in. - """ - - def wrapper(func: CoroFunc): - self.add_callback(func, one_shot=one_shot, force_parent=force_parent) - return self - - if func: - return wrapper(func) - return wrapper +class Event: + """Represents a dispatcher event. An event class contains information about an event for use in listeners.""" - # dispatch + __slots__ = () - async def _run(self, coro: CoroFunc, *args: t.Any, **kwargs: t.Any) -> None: - try: - await coro(*args, **kwargs) - except asyncio.CancelledError: - pass - except Exception as e: - try: - await self._error_handler(e) - except asyncio.CancelledError: - pass + __dispatches: tuple[type[Event], ...] - def _schedule_task( - self, - coro: CoroFunc, - index: t.Optional[int], - *args: t.Any, - **kwargs: t.Any, - ) -> asyncio.Task[t.Any]: - task_name = f"DisCatCore Event:{self.name}" - if index: - task_name += f" Index:{index}" - task_name = task_name.rstrip() + def __init_subclass__(cls) -> None: + super().__init_subclass__() - wrapped = self._run(coro, *args, **kwargs) - return asyncio.create_task(wrapped, name=task_name) + cls.__dispatches = tuple(base for base in cls.__mro__ if issubclass(base, Event)) - def dispatch(self, *args: t.Any, **kwargs: t.Any) -> None: - """Runs all event callbacks with arguments. + @_classproperty + @classmethod + def dispatches(cls): + return cls.__dispatches - Args: - *args (t.Any): Arguments to pass into the event callbacks. - **kwargs (t.Any): Keyword arguments to pass into the event callbacks. - """ - for i, callback in enumerate(self.callbacks): - metadata = self.metadata.get(callback, _EventCallbackMetadata()) - _log.debug("Running event callback under event %s with index %s", self.name, i) - self._schedule_task(callback, i, *args, **kwargs) +@attr.define(kw_only=True) +class ExceptionEvent(Event, t.Generic[EventT]): + """An event that is dispatched whenever a dispatched event raises an exception.""" - if metadata.one_shot: - _log.debug("Removing event callback under event %s with index %s", self.name, i) - self.remove_callback(i) + exception: BaseException + failed_event: EventT + failed_listener: ListenerCallback[EventT] diff --git a/discatcore/utils/json.py b/discatcore/utils/json.py index 0235e67..5ca5dfc 100644 --- a/discatcore/utils/json.py +++ b/discatcore/utils/json.py @@ -10,7 +10,12 @@ except ImportError: import json -__all__ = ("dumps", "loads") +__all__ = ("JSONObject", "dumps", "loads") + + +JSONObject = t.Union[ + str, int, float, bool, None, t.Sequence["JSONObject"], t.Mapping[str, "JSONObject"] +] def dumps(obj: t.Any) -> str: From 77624681e25403336ed3302678162ea2388cf3ca Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Sun, 4 Dec 2022 02:03:38 -0800 Subject: [PATCH 02/13] refactor: implement dispatcher changes in GatewayClient --- discatcore/gateway/__init__.py | 2 + discatcore/gateway/client.py | 11 +- discatcore/gateway/events.py | 452 +++++++++++++++++++++++++++++++++ discatcore/utils/dispatcher.py | 8 +- discatcore/utils/event.py | 4 +- 5 files changed, 465 insertions(+), 12 deletions(-) create mode 100644 discatcore/gateway/events.py diff --git a/discatcore/gateway/__init__.py b/discatcore/gateway/__init__.py index 513714b..0cfc1b7 100644 --- a/discatcore/gateway/__init__.py +++ b/discatcore/gateway/__init__.py @@ -6,8 +6,10 @@ """ from .client import * +from .events import * from .ratelimiter import * __all__ = () __all__ += client.__all__ +__all__ += events.__all__ __all__ += ratelimiter.__all__ diff --git a/discatcore/gateway/client.py b/discatcore/gateway/client.py index c6a86a0..1e2e52a 100644 --- a/discatcore/gateway/client.py +++ b/discatcore/gateway/client.py @@ -18,6 +18,7 @@ from ..http import HTTPClient from ..utils.dispatcher import Dispatcher from ..utils.json import dumps, loads +from .events import InvalidSessionEvent, ReconnectEvent, name_to_class from .ratelimiter import Ratelimiter from .types import BaseTypedWSMessage, is_binary, is_text @@ -288,23 +289,21 @@ async def connection_loop(self) -> None: self.session_id = ready_data["session_id"] self.resume_url = ready_data["resume_gateway_url"] - args = (data,) - if data is None: - args = () - self._dispatcher.dispatch(event_name, *args) + event = name_to_class[event_name](data) + await self._dispatcher.dispatch(event) # these should be rare, but it's better to be safe than sorry elif op == HEARTBEAT: await self.heartbeat() elif op == RECONNECT: - self._dispatcher.dispatch("reconnect") + await self._dispatcher.dispatch(ReconnectEvent()) await self.close(code=1012) return elif op == INVALID_SESSION: self.can_resume = bool(self.recent_payload.get("d")) - self._dispatcher.dispatch("invalid_session", self.can_resume) + await self._dispatcher.dispatch(InvalidSessionEvent(self.can_resume)) await self.close(code=1012) return diff --git a/discatcore/gateway/events.py b/discatcore/gateway/events.py new file mode 100644 index 0000000..82d4fd1 --- /dev/null +++ b/discatcore/gateway/events.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: MIT +from __future__ import annotations + +import typing as t +from dataclasses import dataclass + +import discord_typings as dt + +from ..utils.event import Event + +__all__ = ( + "GatewayEvent", + "UnknownEvent", + "ApplicationCommandPermissionsUpdateEvent", + "AutoModerationRuleCreateEvent", + "AutoModerationRuleDeleteEvent", + "AutoModerationRuleUpdateEvent", + "ChannelCreateEvent", + "ChannelDeleteEvent", + "ChannelPinsUpdateEvent", + "ChannelUpdateEvent", + "GuildBanAddEvent", + "GuildBanRemoveEvent", + "GuildCreateEvent", + "GuildDeleteEvent", + "GuildEmojisUpdateEvent", + "GuildIntegrationsUpdateEvent", + "GuildMemberAddEvent", + "GuildMemberRemoveEvent", + "GuildMemberUpdateEvent", + "GuildMembersChunkEvent", + "GuildRoleCreateEvent", + "GuildRoleDeleteEvent", + "GuildRoleUpdateEvent", + "GuildScheduledEventCreateEvent", + "GuildScheduledEventDeleteEvent", + "GuildScheduledEventUpdateEvent", + "GuildScheduledEventUserAddEvent", + "GuildScheduledEventUserRemoveEvent", + "GuildStickersUpdateEvent", + "GuildUpdateEvent", + "IntegrationCreateEvent", + "IntegrationDeleteEvent", + "IntegrationUpdateEvent", + "InteractionCreateEvent", + "InvalidSessionEvent", + "InviteCreateEvent", + "InviteDeleteEvent", + "MessageCreateEvent", + "MessageDeleteEvent", + "MessageDeleteBulkEvent", + "MessageReactionAddEvent", + "MessageReactionRemoveEvent", + "MessageReactionRemoveAllEvent", + "MessageReactionRemoveEmojiEvent", + "MessageUpdateEvent", + "PresenceUpdateEvent", + "ReadyEvent", + "ReconnectEvent", + "ResumedEvent", + "StageInstanceCreateEvent", + "StageInstanceDeleteEvent", + "StageInstanceUpdateEvent", + "ThreadCreateEvent", + "ThreadDeleteEvent", + "ThreadListSyncEvent", + "ThreadMemberUpdateEvent", + "ThreadMembersUpdateEvent", + "ThreadUpdateEvent", + "TypingStartEvent", + "UserUpdateEvent", + "VoiceServerUpdateEvent", + "VoiceStateUpdateEvent", + "WebhooksUpdateEvent", +) + + +@dataclass +class GatewayEvent(Event): + pass + + +@dataclass +class UnknownEvent(GatewayEvent): + data: dt.DispatchEvent + + +@dataclass +class ReadyEvent(GatewayEvent): + data: dt.ReadyData + + +@dataclass +class ResumedEvent(GatewayEvent): + pass + + +@dataclass +class ReconnectEvent(GatewayEvent): + pass + + +@dataclass +class InvalidSessionEvent(GatewayEvent): + resumable: bool + + +@dataclass +class ApplicationCommandPermissionsUpdateEvent(GatewayEvent): + data: dt.ApplicationCommandPermissionsUpdateData + + +@dataclass +class AutoModerationRuleCreateEvent(GatewayEvent): + data: dt.AutoModerationRuleData + + +@dataclass +class AutoModerationRuleDeleteEvent(GatewayEvent): + data: dt.AutoModerationRuleData + + +@dataclass +class AutoModerationRuleUpdateEvent(GatewayEvent): + data: dt.AutoModerationRuleData + + +@dataclass +class ChannelCreateEvent(GatewayEvent): + data: dt.ChannelCreateData + + +@dataclass +class ChannelDeleteEvent(GatewayEvent): + data: dt.ChannelDeleteData + + +@dataclass +class ChannelPinsUpdateEvent(GatewayEvent): + data: dt.ChannelPinsUpdateData + + +@dataclass +class ChannelUpdateEvent(GatewayEvent): + data: dt.ChannelUpdateData + + +@dataclass +class GuildBanAddEvent(GatewayEvent): + data: dt.GuildBanAddData + + +@dataclass +class GuildBanRemoveEvent(GatewayEvent): + data: dt.GuildBanRemoveData + + +@dataclass +class GuildCreateEvent(GatewayEvent): + data: dt.GuildCreateData + + +@dataclass +class GuildDeleteEvent(GatewayEvent): + data: dt.GuildDeleteData + + +@dataclass +class GuildEmojisUpdateEvent(GatewayEvent): + data: dt.GuildEmojisUpdateData + + +@dataclass +class GuildIntegrationsUpdateEvent(GatewayEvent): + data: dt.GuildIntergrationsUpdateData + + +@dataclass +class GuildMemberAddEvent(GatewayEvent): + data: dt.GuildMemberAddData + + +@dataclass +class GuildMemberRemoveEvent(GatewayEvent): + data: dt.GuildMemberRemoveData + + +@dataclass +class GuildMemberUpdateEvent(GatewayEvent): + data: dt.GuildMemberUpdateData + + +@dataclass +class GuildMembersChunkEvent(GatewayEvent): + data: dt.GuildMembersChunkData + + +@dataclass +class GuildRoleCreateEvent(GatewayEvent): + data: dt.GuildRoleCreateData + + +@dataclass +class GuildRoleDeleteEvent(GatewayEvent): + data: dt.GuildRoleDeleteData + + +@dataclass +class GuildRoleUpdateEvent(GatewayEvent): + data: dt.GuildRoleUpdateData + + +@dataclass +class GuildScheduledEventCreateEvent(GatewayEvent): + data: dt.GuildScheduledEventCreateData + + +@dataclass +class GuildScheduledEventDeleteEvent(GatewayEvent): + data: dt.GuildScheduledEventDeleteData + + +@dataclass +class GuildScheduledEventUpdateEvent(GatewayEvent): + data: dt.GuildScheduledEventUpdateData + + +@dataclass +class GuildScheduledEventUserAddEvent(GatewayEvent): + data: dt.GuildScheduledEventUserAddData + + +@dataclass +class GuildScheduledEventUserRemoveEvent(GatewayEvent): + data: dt.GuildScheduledEventUserRemoveData + + +@dataclass +class GuildStickersUpdateEvent(GatewayEvent): + data: dt.GuildStickersUpdateData + + +@dataclass +class GuildUpdateEvent(GatewayEvent): + data: dt.GuildUpdateData + + +@dataclass +class IntegrationCreateEvent(GatewayEvent): + data: dt.IntegrationCreateData + + +@dataclass +class IntegrationDeleteEvent(GatewayEvent): + data: dt.IntegrationDeleteData + + +@dataclass +class IntegrationUpdateEvent(GatewayEvent): + data: dt.IntegrationUpdateData + + +@dataclass +class InteractionCreateEvent(GatewayEvent): + data: dt.InteractionCreateData + + +@dataclass +class InviteCreateEvent(GatewayEvent): + data: dt.InviteCreateData + + +@dataclass +class InviteDeleteEvent(GatewayEvent): + data: dt.InviteDeleteData + + +@dataclass +class MessageCreateEvent(GatewayEvent): + data: dt.MessageCreateData + + +@dataclass +class MessageDeleteEvent(GatewayEvent): + data: dt.MessageDeleteData + + +@dataclass +class MessageDeleteBulkEvent(GatewayEvent): + data: dt.MessageDeleteBulkData + + +@dataclass +class MessageReactionAddEvent(GatewayEvent): + data: dt.MessageReactionAddData + + +@dataclass +class MessageReactionRemoveEvent(GatewayEvent): + data: dt.MessageReactionRemoveData + + +@dataclass +class MessageReactionRemoveAllEvent(GatewayEvent): + data: dt.MessageReactionRemoveAllData + + +@dataclass +class MessageReactionRemoveEmojiEvent(GatewayEvent): + data: dt.MessageReactionRemoveEmojiData + + +@dataclass +class MessageUpdateEvent(GatewayEvent): + data: dt.MessageUpdateData + + +@dataclass +class PresenceUpdateEvent(GatewayEvent): + data: dt.PresenceUpdateData + + +@dataclass +class StageInstanceCreateEvent(GatewayEvent): + data: dt.StageInstanceCreateData + + +@dataclass +class StageInstanceDeleteEvent(GatewayEvent): + data: dt.StageInstanceDeleteData + + +@dataclass +class StageInstanceUpdateEvent(GatewayEvent): + data: dt.StageInstanceUpdateData + + +@dataclass +class ThreadCreateEvent(GatewayEvent): + data: dt.ThreadCreateData + + +@dataclass +class ThreadDeleteEvent(GatewayEvent): + data: dt.ThreadDeleteData + + +@dataclass +class ThreadListSyncEvent(GatewayEvent): + data: dt.ThreadListSyncData + + +@dataclass +class ThreadMemberUpdateEvent(GatewayEvent): + data: dt.ThreadMemberUpdateData + + +@dataclass +class ThreadMembersUpdateEvent(GatewayEvent): + data: dt.ThreadMembersUpdateData + + +@dataclass +class ThreadUpdateEvent(GatewayEvent): + data: dt.ThreadUpdateData + + +@dataclass +class TypingStartEvent(GatewayEvent): + data: dt.TypingStartData + + +@dataclass +class UserUpdateEvent(GatewayEvent): + data: dt.UserUpdateData + + +@dataclass +class VoiceServerUpdateEvent(GatewayEvent): + data: dt.VoiceServerUpdateData + + +@dataclass +class VoiceStateUpdateEvent(GatewayEvent): + data: dt.VoiceStateData + + +@dataclass +class WebhooksUpdateEvent(GatewayEvent): + data: dt.WebhooksUpdateData + + +name_to_class: dict[str, t.Any] = { + "application_command_permissions_update": ApplicationCommandPermissionsUpdateEvent, + "auto_moderation_rule_create": AutoModerationRuleCreateEvent, + "auto_moderation_rule_delete": AutoModerationRuleDeleteEvent, + "auto_moderation_rule_update": AutoModerationRuleUpdateEvent, + "channel_create": ChannelCreateEvent, + "channel_delete": ChannelDeleteEvent, + "channel_pins_update": ChannelPinsUpdateEvent, + "channel_update": ChannelUpdateEvent, + "guild_ban_add": GuildBanAddEvent, + "guild_ban_remove": GuildBanRemoveEvent, + "guild_create": GuildCreateEvent, + "guild_delete": GuildDeleteEvent, + "guild_emojis_update": GuildEmojisUpdateEvent, + "guild_integrations_update": GuildIntegrationsUpdateEvent, + "guild_member_add": GuildMemberAddEvent, + "guild_member_remove": GuildMemberRemoveEvent, + "guild_member_update": GuildMemberUpdateEvent, + "guild_members_chunk": GuildMembersChunkEvent, + "guild_role_create": GuildRoleCreateEvent, + "guild_role_delete": GuildRoleDeleteEvent, + "guild_role_update": GuildRoleUpdateEvent, + "guild_scheduled_event_create": GuildScheduledEventCreateEvent, + "guild_scheduled_event_delete": GuildScheduledEventDeleteEvent, + "guild_scheduled_event_update": GuildScheduledEventUpdateEvent, + "guild_scheduled_event_user_add": GuildScheduledEventUserAddEvent, + "guild_scheduled_event_user_remove": GuildScheduledEventUserRemoveEvent, + "guild_stickers_update": GuildStickersUpdateEvent, + "guild_update": GuildUpdateEvent, + "integration_create": IntegrationCreateEvent, + "integration_delete": IntegrationDeleteEvent, + "integration_update": IntegrationUpdateEvent, + "interaction_create": InteractionCreateEvent, + "invite_create": InviteCreateEvent, + "invite_delete": InviteDeleteEvent, + "message_create": MessageCreateEvent, + "message_delete": MessageDeleteEvent, + "message_delete_bulk": MessageDeleteBulkEvent, + "message_reaction_add": MessageReactionAddEvent, + "message_reaction_remove": MessageReactionRemoveEvent, + "message_reaction_remove_all": MessageReactionRemoveAllEvent, + "message_reaction_remove_emoji": MessageReactionRemoveEmojiEvent, + "message_update": MessageUpdateEvent, + "presence_update": PresenceUpdateEvent, + "ready": ReadyEvent, + "stage_instance_create": StageInstanceCreateEvent, + "stage_instance_delete": StageInstanceDeleteEvent, + "stage_instance_update": StageInstanceUpdateEvent, + "thread_create": ThreadCreateEvent, + "thread_delete": ThreadDeleteEvent, + "thread_list_sync": ThreadListSyncEvent, + "thread_member_update": ThreadMemberUpdateEvent, + "thread_members_update": ThreadMembersUpdateEvent, + "thread_update": ThreadUpdateEvent, + "typing_start": TypingStartEvent, + "user_update": UserUpdateEvent, + "voice_server_update": VoiceServerUpdateEvent, + "voice_state_update": VoiceStateUpdateEvent, + "webhooks_update": WebhooksUpdateEvent, +} diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 08996db..d857994 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -9,8 +9,8 @@ import traceback import typing as t from collections import defaultdict +from dataclasses import dataclass -import attr from typing_extensions import Self, TypeGuard from .event import Event, EventT, ExceptionEvent @@ -40,10 +40,10 @@ Coro = t.Coroutine[T, t.Any, t.Any] ListenerCallback = t.Callable[[EventT], Coro[None]] -ConsumerCallback = t.Callable[[DispatcherT, GatewayClient, JSONObject], Coro[None]] +ConsumerCallback = t.Callable[[DispatcherT, "GatewayClient", JSONObject], Coro[None]] -@attr.define +@dataclass class Consumer(t.Generic[DispatcherT]): """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" @@ -173,7 +173,7 @@ def wrapper(func: ListenerCallbackT) -> ListenerCallbackT: else: def event_check(arg: t.Any) -> None: - if not isinstance(arg, type) or not issubclass(arg, Event): + if not isinstance(arg, type) and not issubclass(arg, Event): raise TypeError(f"Expected an event, got {arg!r}.") if t.get_origin(event_arg_anno) in _union_types: diff --git a/discatcore/utils/event.py b/discatcore/utils/event.py index bf34347..7ddce8c 100644 --- a/discatcore/utils/event.py +++ b/discatcore/utils/event.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing as t +from dataclasses import dataclass -import attr from typing_extensions import Self if t.TYPE_CHECKING: @@ -46,7 +46,7 @@ def dispatches(cls): return cls.__dispatches -@attr.define(kw_only=True) +@dataclass(kw_only=True) class ExceptionEvent(Event, t.Generic[EventT]): """An event that is dispatched whenever a dispatched event raises an exception.""" From 17512b2431d830b762b0854ead3618fbc06d7a02 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Sun, 4 Dec 2022 02:11:00 -0800 Subject: [PATCH 03/13] fix: use attrs instead of dataclasses for 3.9 compatibility --- discatcore/utils/dispatcher.py | 4 ++-- discatcore/utils/event.py | 4 ++-- pyproject.toml | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index d857994..6a38b23 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -9,8 +9,8 @@ import traceback import typing as t from collections import defaultdict -from dataclasses import dataclass +import attr from typing_extensions import Self, TypeGuard from .event import Event, EventT, ExceptionEvent @@ -43,7 +43,7 @@ ConsumerCallback = t.Callable[[DispatcherT, "GatewayClient", JSONObject], Coro[None]] -@dataclass +@attr.define class Consumer(t.Generic[DispatcherT]): """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" diff --git a/discatcore/utils/event.py b/discatcore/utils/event.py index 7ddce8c..bf34347 100644 --- a/discatcore/utils/event.py +++ b/discatcore/utils/event.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing as t -from dataclasses import dataclass +import attr from typing_extensions import Self if t.TYPE_CHECKING: @@ -46,7 +46,7 @@ def dispatches(cls): return cls.__dispatches -@dataclass(kw_only=True) +@attr.define(kw_only=True) class ExceptionEvent(Event, t.Generic[EventT]): """An event that is dispatched whenever a dispatched event raises an exception.""" diff --git a/pyproject.toml b/pyproject.toml index f4a3169..51eea46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9" aiohttp = ">=3.6.0,<3.9.0" +attrs = "^22.1" discord-typings = {git = "https://github.com/Bluenix2/discord-typings.git"} [tool.poetry.urls] From 796d417d7bcd4b3ffa4107a85a954773eab75dad Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Sun, 4 Dec 2022 11:50:32 -0800 Subject: [PATCH 04/13] chore: update README example --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cff8240..019491a 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,10 @@ dispatcher = discatcore.Dispatcher() intents = 3243773 gateway = discatcore.GatewayClient(http, dispatcher, intents=intents.value) -@dispatcher.new_event("ready").callback -async def ready(event: discord_typings.ReadyData): - print(event) +# alternatively, you can provide the event type in the decorator +@dispatcher.listen_to() +async def ready(event: discatcore.gateway.ReadyEvent): + print(event.data) async def main(): url: str | None = None From 52f253f4ef29e0450f52f0bf49d4d173bd849deb Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Sun, 4 Dec 2022 19:00:34 -0800 Subject: [PATCH 05/13] feat: add overloads for listen_to --- discatcore/utils/dispatcher.py | 44 +++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 6a38b23..23c53b2 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -156,8 +156,44 @@ def unsubscribe(self, event: type[EventT], func: ListenerCallback[EventT]) -> No if not listeners: del self._listeners[event] - def listen_to(self, *events: type[Event]) -> t.Callable[[ListenerCallbackT], ListenerCallbackT]: - def wrapper(func: ListenerCallbackT) -> ListenerCallbackT: + @t.overload + def listen_to( + self, func: ListenerCallback[EventT], *, events: None = ... + ) -> ListenerCallback[EventT]: + pass + + @t.overload + def listen_to( + self, func: ListenerCallback[EventT], *, events: list[type[EventT]] + ) -> t.NoReturn: + pass + + @t.overload + def listen_to( + self, func: None = ..., *, events: list[type[EventT]] + ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: + pass + + @t.overload + def listen_to( + self, func: None = ..., *, events: None = ... + ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: + pass + + def listen_to( + self, + func: t.Optional[ListenerCallback[EventT]] = None, + *, + events: t.Optional[list[type[EventT]]] = None, + ) -> t.Union[ + t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]], + ListenerCallback[EventT], + t.NoReturn, + ]: + if func and events is not None: + raise ValueError(f"func and events parameters cannot both be set!") + + def wrapper(func: ListenerCallback[EventT]) -> ListenerCallback[EventT]: func_sig = inspect.signature(func) event_arg = next(iter(func_sig.parameters.values())) event_arg_anno = event_arg.annotation @@ -188,10 +224,12 @@ def event_check(arg: t.Any) -> None: resolved_events = {t.cast(type[Event], event_arg_anno)} for event in resolved_events: - self.subscribe(event, func) + self.subscribe(event, func) # pyright: ignore return func + if func: + return wrapper(func) return wrapper def dispatch(self, event: Event) -> asyncio.Future[t.Any]: From a7fda0dac1ed9e02f17fb3bf911fd1d2b1d5dff1 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Mon, 5 Dec 2022 00:22:22 -0800 Subject: [PATCH 06/13] fix: string annotations were not evalulated --- discatcore/utils/dispatcher.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 23c53b2..18bff7f 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -9,6 +9,7 @@ import traceback import typing as t from collections import defaultdict +from importlib import reload import attr from typing_extensions import Self, TypeGuard @@ -43,6 +44,25 @@ ConsumerCallback = t.Callable[[DispatcherT, "GatewayClient", JSONObject], Coro[None]] +# ported from discatpy +def _get_globals(x: object) -> dict[str, t.Any]: + module = inspect.getmodule(x) + + if module: + try: + t.TYPE_CHECKING = True + reload(module) + except ModuleNotFoundError: + # incomplete __main__ module + # this does mean that anything defined in TYPE_CHECKING will not be extracted + # TODO: find an alternative solution for __main__ module that extracts items from TYPE_CHECKING statements + pass + finally: + t.TYPE_CHECKING = False + + return module.__dict__ + + @attr.define class Consumer(t.Generic[DispatcherT]): """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" @@ -207,6 +227,8 @@ def wrapper(func: ListenerCallback[EventT]) -> ListenerCallback[EventT]: "No event type was provided! Please provide it as an argument or a type hint." ) else: + if isinstance(event_arg_anno, str): + event_arg_anno = eval(event_arg_anno, _get_globals(func)) def event_check(arg: t.Any) -> None: if not isinstance(arg, type) and not issubclass(arg, Event): From 3eb29aeb26b20938596ea1d3a72e871b568262b2 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Sun, 1 Jan 2023 13:13:32 -0400 Subject: [PATCH 07/13] refactor: migrate classproperty to separate file --- discatcore/utils/event.py | 18 +++--------------- discatcore/utils/functools.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 15 deletions(-) create mode 100644 discatcore/utils/functools.py diff --git a/discatcore/utils/event.py b/discatcore/utils/event.py index bf34347..e5b3d00 100644 --- a/discatcore/utils/event.py +++ b/discatcore/utils/event.py @@ -5,29 +5,17 @@ import typing as t import attr -from typing_extensions import Self + +from .functools import classproperty if t.TYPE_CHECKING: from .dispatcher import ListenerCallback __all__ = ("Event", "ExceptionEvent") -T = t.TypeVar("T") EventT = t.TypeVar("EventT", bound="Event") -class _classproperty(t.Generic[T]): - def __init__(self, fget: t.Callable[[t.Any], T], /) -> None: - self.fget: "classmethod[T]" = t.cast("classmethod[T]", fget) - - def getter(self, fget: t.Callable[[t.Any], T], /) -> Self: - self.fget = t.cast("classmethod[T]", fget) - return self - - def __get__(self, obj: t.Optional[t.Any], type: t.Optional[type]) -> T: - return self.fget.__func__(type) - - class Event: """Represents a dispatcher event. An event class contains information about an event for use in listeners.""" @@ -40,7 +28,7 @@ def __init_subclass__(cls) -> None: cls.__dispatches = tuple(base for base in cls.__mro__ if issubclass(base, Event)) - @_classproperty + @classproperty @classmethod def dispatches(cls): return cls.__dispatches diff --git a/discatcore/utils/functools.py b/discatcore/utils/functools.py new file mode 100644 index 0000000..3361d40 --- /dev/null +++ b/discatcore/utils/functools.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: MIT + +import typing as t + +from typing_extensions import Self + +__all__ = ("classproperty",) + +T = t.TypeVar("T") + + +class classproperty(t.Generic[T]): + def __init__(self, fget: t.Callable[[t.Any], T], /) -> None: + self.fget: "classmethod[T]" + self.getter(fget) + + def getter(self, fget: t.Callable[[t.Any], T], /) -> Self: + if not isinstance(fget, classmethod): + raise ValueError(f"Callable {fget.__name__} is not a classmethod!") + + self.fget = fget + return self + + def __get__(self, obj: t.Optional[t.Any], type: t.Optional[type]) -> T: + return self.fget.__func__(type) From c7e0d65226d1f08c5d9f35775a356977467eafc3 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Fri, 10 Feb 2023 21:35:54 -0800 Subject: [PATCH 08/13] refactor!: remove abstractions for dispatched events --- discatcore/gateway/client.py | 27 ++- discatcore/gateway/events.py | 410 +---------------------------------- 2 files changed, 18 insertions(+), 419 deletions(-) diff --git a/discatcore/gateway/client.py b/discatcore/gateway/client.py index 1e2e52a..8509b17 100644 --- a/discatcore/gateway/client.py +++ b/discatcore/gateway/client.py @@ -16,9 +16,9 @@ from ..errors import GatewayReconnect from ..http import HTTPClient +from ..utils import json from ..utils.dispatcher import Dispatcher -from ..utils.json import dumps, loads -from .events import InvalidSessionEvent, ReconnectEvent, name_to_class +from . import events from .ratelimiter import Ratelimiter from .types import BaseTypedWSMessage, is_binary, is_text @@ -184,7 +184,7 @@ async def send(self, data: Mapping[str, t.Any]) -> None: return await self.ratelimiter.acquire() - await self._ws.send_json(data, dumps=dumps) + await self._ws.send_json(data, dumps=json.dumps) _log.debug("Sent JSON payload %s to the Gateway.", data) async def receive(self) -> t.Optional[bool]: @@ -215,7 +215,7 @@ async def receive(self) -> t.Optional[bool]: else: received_msg = t.cast(str, typed_msg.data) - self.recent_payload = t.cast(dt.GatewayEvent, loads(received_msg)) + self.recent_payload = t.cast(dt.GatewayEvent, json.loads(received_msg)) _log.debug("Received payload from the Gateway: %s", self.recent_payload) self.sequence = self.recent_payload.get("s") return True @@ -280,30 +280,33 @@ async def connection_loop(self) -> None: if res and self.recent_payload is not None: op = int(self.recent_payload["op"]) - if op == DISPATCH and self.recent_payload.get("t") is not None: - event_name = str(self.recent_payload.get("t")).lower() - data = self.recent_payload.get("d") + if op == DISPATCH and (event_name := self.recent_payload.get("t")) is not None: + data = t.cast(json.JSONObject, self.recent_payload.get("d")) - if event_name == "ready": + self._dispatcher.consume(event_name, self, data) + await self._dispatcher.dispatch( + events.DispatchEvent(t.cast(t.Mapping[str, t.Any], data)) + ) + + if event_name == "READY": ready_data = t.cast(dt.ReadyData, data) self.session_id = ready_data["session_id"] self.resume_url = ready_data["resume_gateway_url"] - event = name_to_class[event_name](data) - await self._dispatcher.dispatch(event) + await self._dispatcher.dispatch(events.ReadyEvent(ready_data)) # these should be rare, but it's better to be safe than sorry elif op == HEARTBEAT: await self.heartbeat() elif op == RECONNECT: - await self._dispatcher.dispatch(ReconnectEvent()) + await self._dispatcher.dispatch(events.ReconnectEvent()) await self.close(code=1012) return elif op == INVALID_SESSION: self.can_resume = bool(self.recent_payload.get("d")) - await self._dispatcher.dispatch(InvalidSessionEvent(self.can_resume)) + await self._dispatcher.dispatch(events.InvalidSessionEvent(self.can_resume)) await self.close(code=1012) return diff --git a/discatcore/gateway/events.py b/discatcore/gateway/events.py index 82d4fd1..11c9c6a 100644 --- a/discatcore/gateway/events.py +++ b/discatcore/gateway/events.py @@ -10,68 +10,11 @@ __all__ = ( "GatewayEvent", - "UnknownEvent", - "ApplicationCommandPermissionsUpdateEvent", - "AutoModerationRuleCreateEvent", - "AutoModerationRuleDeleteEvent", - "AutoModerationRuleUpdateEvent", - "ChannelCreateEvent", - "ChannelDeleteEvent", - "ChannelPinsUpdateEvent", - "ChannelUpdateEvent", - "GuildBanAddEvent", - "GuildBanRemoveEvent", - "GuildCreateEvent", - "GuildDeleteEvent", - "GuildEmojisUpdateEvent", - "GuildIntegrationsUpdateEvent", - "GuildMemberAddEvent", - "GuildMemberRemoveEvent", - "GuildMemberUpdateEvent", - "GuildMembersChunkEvent", - "GuildRoleCreateEvent", - "GuildRoleDeleteEvent", - "GuildRoleUpdateEvent", - "GuildScheduledEventCreateEvent", - "GuildScheduledEventDeleteEvent", - "GuildScheduledEventUpdateEvent", - "GuildScheduledEventUserAddEvent", - "GuildScheduledEventUserRemoveEvent", - "GuildStickersUpdateEvent", - "GuildUpdateEvent", - "IntegrationCreateEvent", - "IntegrationDeleteEvent", - "IntegrationUpdateEvent", - "InteractionCreateEvent", + "DispatchEvent", "InvalidSessionEvent", - "InviteCreateEvent", - "InviteDeleteEvent", - "MessageCreateEvent", - "MessageDeleteEvent", - "MessageDeleteBulkEvent", - "MessageReactionAddEvent", - "MessageReactionRemoveEvent", - "MessageReactionRemoveAllEvent", - "MessageReactionRemoveEmojiEvent", - "MessageUpdateEvent", - "PresenceUpdateEvent", "ReadyEvent", "ReconnectEvent", "ResumedEvent", - "StageInstanceCreateEvent", - "StageInstanceDeleteEvent", - "StageInstanceUpdateEvent", - "ThreadCreateEvent", - "ThreadDeleteEvent", - "ThreadListSyncEvent", - "ThreadMemberUpdateEvent", - "ThreadMembersUpdateEvent", - "ThreadUpdateEvent", - "TypingStartEvent", - "UserUpdateEvent", - "VoiceServerUpdateEvent", - "VoiceStateUpdateEvent", - "WebhooksUpdateEvent", ) @@ -81,8 +24,8 @@ class GatewayEvent(Event): @dataclass -class UnknownEvent(GatewayEvent): - data: dt.DispatchEvent +class DispatchEvent(GatewayEvent): + data: t.Mapping[str, t.Any] @dataclass @@ -103,350 +46,3 @@ class ReconnectEvent(GatewayEvent): @dataclass class InvalidSessionEvent(GatewayEvent): resumable: bool - - -@dataclass -class ApplicationCommandPermissionsUpdateEvent(GatewayEvent): - data: dt.ApplicationCommandPermissionsUpdateData - - -@dataclass -class AutoModerationRuleCreateEvent(GatewayEvent): - data: dt.AutoModerationRuleData - - -@dataclass -class AutoModerationRuleDeleteEvent(GatewayEvent): - data: dt.AutoModerationRuleData - - -@dataclass -class AutoModerationRuleUpdateEvent(GatewayEvent): - data: dt.AutoModerationRuleData - - -@dataclass -class ChannelCreateEvent(GatewayEvent): - data: dt.ChannelCreateData - - -@dataclass -class ChannelDeleteEvent(GatewayEvent): - data: dt.ChannelDeleteData - - -@dataclass -class ChannelPinsUpdateEvent(GatewayEvent): - data: dt.ChannelPinsUpdateData - - -@dataclass -class ChannelUpdateEvent(GatewayEvent): - data: dt.ChannelUpdateData - - -@dataclass -class GuildBanAddEvent(GatewayEvent): - data: dt.GuildBanAddData - - -@dataclass -class GuildBanRemoveEvent(GatewayEvent): - data: dt.GuildBanRemoveData - - -@dataclass -class GuildCreateEvent(GatewayEvent): - data: dt.GuildCreateData - - -@dataclass -class GuildDeleteEvent(GatewayEvent): - data: dt.GuildDeleteData - - -@dataclass -class GuildEmojisUpdateEvent(GatewayEvent): - data: dt.GuildEmojisUpdateData - - -@dataclass -class GuildIntegrationsUpdateEvent(GatewayEvent): - data: dt.GuildIntergrationsUpdateData - - -@dataclass -class GuildMemberAddEvent(GatewayEvent): - data: dt.GuildMemberAddData - - -@dataclass -class GuildMemberRemoveEvent(GatewayEvent): - data: dt.GuildMemberRemoveData - - -@dataclass -class GuildMemberUpdateEvent(GatewayEvent): - data: dt.GuildMemberUpdateData - - -@dataclass -class GuildMembersChunkEvent(GatewayEvent): - data: dt.GuildMembersChunkData - - -@dataclass -class GuildRoleCreateEvent(GatewayEvent): - data: dt.GuildRoleCreateData - - -@dataclass -class GuildRoleDeleteEvent(GatewayEvent): - data: dt.GuildRoleDeleteData - - -@dataclass -class GuildRoleUpdateEvent(GatewayEvent): - data: dt.GuildRoleUpdateData - - -@dataclass -class GuildScheduledEventCreateEvent(GatewayEvent): - data: dt.GuildScheduledEventCreateData - - -@dataclass -class GuildScheduledEventDeleteEvent(GatewayEvent): - data: dt.GuildScheduledEventDeleteData - - -@dataclass -class GuildScheduledEventUpdateEvent(GatewayEvent): - data: dt.GuildScheduledEventUpdateData - - -@dataclass -class GuildScheduledEventUserAddEvent(GatewayEvent): - data: dt.GuildScheduledEventUserAddData - - -@dataclass -class GuildScheduledEventUserRemoveEvent(GatewayEvent): - data: dt.GuildScheduledEventUserRemoveData - - -@dataclass -class GuildStickersUpdateEvent(GatewayEvent): - data: dt.GuildStickersUpdateData - - -@dataclass -class GuildUpdateEvent(GatewayEvent): - data: dt.GuildUpdateData - - -@dataclass -class IntegrationCreateEvent(GatewayEvent): - data: dt.IntegrationCreateData - - -@dataclass -class IntegrationDeleteEvent(GatewayEvent): - data: dt.IntegrationDeleteData - - -@dataclass -class IntegrationUpdateEvent(GatewayEvent): - data: dt.IntegrationUpdateData - - -@dataclass -class InteractionCreateEvent(GatewayEvent): - data: dt.InteractionCreateData - - -@dataclass -class InviteCreateEvent(GatewayEvent): - data: dt.InviteCreateData - - -@dataclass -class InviteDeleteEvent(GatewayEvent): - data: dt.InviteDeleteData - - -@dataclass -class MessageCreateEvent(GatewayEvent): - data: dt.MessageCreateData - - -@dataclass -class MessageDeleteEvent(GatewayEvent): - data: dt.MessageDeleteData - - -@dataclass -class MessageDeleteBulkEvent(GatewayEvent): - data: dt.MessageDeleteBulkData - - -@dataclass -class MessageReactionAddEvent(GatewayEvent): - data: dt.MessageReactionAddData - - -@dataclass -class MessageReactionRemoveEvent(GatewayEvent): - data: dt.MessageReactionRemoveData - - -@dataclass -class MessageReactionRemoveAllEvent(GatewayEvent): - data: dt.MessageReactionRemoveAllData - - -@dataclass -class MessageReactionRemoveEmojiEvent(GatewayEvent): - data: dt.MessageReactionRemoveEmojiData - - -@dataclass -class MessageUpdateEvent(GatewayEvent): - data: dt.MessageUpdateData - - -@dataclass -class PresenceUpdateEvent(GatewayEvent): - data: dt.PresenceUpdateData - - -@dataclass -class StageInstanceCreateEvent(GatewayEvent): - data: dt.StageInstanceCreateData - - -@dataclass -class StageInstanceDeleteEvent(GatewayEvent): - data: dt.StageInstanceDeleteData - - -@dataclass -class StageInstanceUpdateEvent(GatewayEvent): - data: dt.StageInstanceUpdateData - - -@dataclass -class ThreadCreateEvent(GatewayEvent): - data: dt.ThreadCreateData - - -@dataclass -class ThreadDeleteEvent(GatewayEvent): - data: dt.ThreadDeleteData - - -@dataclass -class ThreadListSyncEvent(GatewayEvent): - data: dt.ThreadListSyncData - - -@dataclass -class ThreadMemberUpdateEvent(GatewayEvent): - data: dt.ThreadMemberUpdateData - - -@dataclass -class ThreadMembersUpdateEvent(GatewayEvent): - data: dt.ThreadMembersUpdateData - - -@dataclass -class ThreadUpdateEvent(GatewayEvent): - data: dt.ThreadUpdateData - - -@dataclass -class TypingStartEvent(GatewayEvent): - data: dt.TypingStartData - - -@dataclass -class UserUpdateEvent(GatewayEvent): - data: dt.UserUpdateData - - -@dataclass -class VoiceServerUpdateEvent(GatewayEvent): - data: dt.VoiceServerUpdateData - - -@dataclass -class VoiceStateUpdateEvent(GatewayEvent): - data: dt.VoiceStateData - - -@dataclass -class WebhooksUpdateEvent(GatewayEvent): - data: dt.WebhooksUpdateData - - -name_to_class: dict[str, t.Any] = { - "application_command_permissions_update": ApplicationCommandPermissionsUpdateEvent, - "auto_moderation_rule_create": AutoModerationRuleCreateEvent, - "auto_moderation_rule_delete": AutoModerationRuleDeleteEvent, - "auto_moderation_rule_update": AutoModerationRuleUpdateEvent, - "channel_create": ChannelCreateEvent, - "channel_delete": ChannelDeleteEvent, - "channel_pins_update": ChannelPinsUpdateEvent, - "channel_update": ChannelUpdateEvent, - "guild_ban_add": GuildBanAddEvent, - "guild_ban_remove": GuildBanRemoveEvent, - "guild_create": GuildCreateEvent, - "guild_delete": GuildDeleteEvent, - "guild_emojis_update": GuildEmojisUpdateEvent, - "guild_integrations_update": GuildIntegrationsUpdateEvent, - "guild_member_add": GuildMemberAddEvent, - "guild_member_remove": GuildMemberRemoveEvent, - "guild_member_update": GuildMemberUpdateEvent, - "guild_members_chunk": GuildMembersChunkEvent, - "guild_role_create": GuildRoleCreateEvent, - "guild_role_delete": GuildRoleDeleteEvent, - "guild_role_update": GuildRoleUpdateEvent, - "guild_scheduled_event_create": GuildScheduledEventCreateEvent, - "guild_scheduled_event_delete": GuildScheduledEventDeleteEvent, - "guild_scheduled_event_update": GuildScheduledEventUpdateEvent, - "guild_scheduled_event_user_add": GuildScheduledEventUserAddEvent, - "guild_scheduled_event_user_remove": GuildScheduledEventUserRemoveEvent, - "guild_stickers_update": GuildStickersUpdateEvent, - "guild_update": GuildUpdateEvent, - "integration_create": IntegrationCreateEvent, - "integration_delete": IntegrationDeleteEvent, - "integration_update": IntegrationUpdateEvent, - "interaction_create": InteractionCreateEvent, - "invite_create": InviteCreateEvent, - "invite_delete": InviteDeleteEvent, - "message_create": MessageCreateEvent, - "message_delete": MessageDeleteEvent, - "message_delete_bulk": MessageDeleteBulkEvent, - "message_reaction_add": MessageReactionAddEvent, - "message_reaction_remove": MessageReactionRemoveEvent, - "message_reaction_remove_all": MessageReactionRemoveAllEvent, - "message_reaction_remove_emoji": MessageReactionRemoveEmojiEvent, - "message_update": MessageUpdateEvent, - "presence_update": PresenceUpdateEvent, - "ready": ReadyEvent, - "stage_instance_create": StageInstanceCreateEvent, - "stage_instance_delete": StageInstanceDeleteEvent, - "stage_instance_update": StageInstanceUpdateEvent, - "thread_create": ThreadCreateEvent, - "thread_delete": ThreadDeleteEvent, - "thread_list_sync": ThreadListSyncEvent, - "thread_member_update": ThreadMemberUpdateEvent, - "thread_members_update": ThreadMembersUpdateEvent, - "thread_update": ThreadUpdateEvent, - "typing_start": TypingStartEvent, - "user_update": UserUpdateEvent, - "voice_server_update": VoiceServerUpdateEvent, - "voice_state_update": VoiceStateUpdateEvent, - "webhooks_update": WebhooksUpdateEvent, -} From f1a35689d314e21dad189a30467a228bef0c8d33 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Fri, 10 Feb 2023 23:48:45 -0800 Subject: [PATCH 09/13] refactor: swap dataclasses with attrs --- discatcore/gateway/events.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/discatcore/gateway/events.py b/discatcore/gateway/events.py index 11c9c6a..a15e36e 100644 --- a/discatcore/gateway/events.py +++ b/discatcore/gateway/events.py @@ -2,8 +2,8 @@ from __future__ import annotations import typing as t -from dataclasses import dataclass +import attr import discord_typings as dt from ..utils.event import Event @@ -18,31 +18,31 @@ ) -@dataclass +@attr.define class GatewayEvent(Event): pass -@dataclass +@attr.define class DispatchEvent(GatewayEvent): data: t.Mapping[str, t.Any] -@dataclass +@attr.define class ReadyEvent(GatewayEvent): data: dt.ReadyData -@dataclass +@attr.define class ResumedEvent(GatewayEvent): pass -@dataclass +@attr.define class ReconnectEvent(GatewayEvent): pass -@dataclass +@attr.define class InvalidSessionEvent(GatewayEvent): resumable: bool From b4176fe6bb7fcc247b49f636bea154020c11a5b0 Mon Sep 17 00:00:00 2001 From: EmreTech <50607143+EmreTech@users.noreply.github.com> Date: Mon, 20 Feb 2023 23:09:40 -0800 Subject: [PATCH 10/13] refactor: consumers only accept event names --- discatcore/utils/dispatcher.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 18bff7f..b023947 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -68,16 +68,14 @@ class Consumer(t.Generic[DispatcherT]): """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" callback: ConsumerCallback[DispatcherT] - events: tuple[type[Event], ...] + events: tuple[str, ...] def consumer_for( - *event_types: type[Event], + *events: str, ) -> t.Callable[[ConsumerCallback[DispatcherT]], Consumer[DispatcherT]]: - event_types = tuple({event for event_type in event_types for event in event_type.dispatches}) - def wrapper(func: ConsumerCallback[DispatcherT]) -> Consumer[DispatcherT]: - return Consumer(func, event_types) + return Consumer(func, events) return wrapper @@ -103,6 +101,9 @@ def __init__(self) -> None: self._consumers[name.lower()] = value + for event_name in value.events: + self._consumers[event_name.lower()] = value + async def _run_listener(self, event: EventT, listener: ListenerCallback[EventT]) -> None: try: await listener(event) From 541ea5bd2cb77d5ae304c773dce730e683b60a32 Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Sun, 16 Jul 2023 15:08:29 -0700 Subject: [PATCH 11/13] refactor!: listener does not read type hints --- README.md | 5 +- discatcore/utils/dispatcher.py | 103 ++------------------------------- 2 files changed, 7 insertions(+), 101 deletions(-) diff --git a/README.md b/README.md index 019491a..10685d9 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,8 @@ dispatcher = discatcore.Dispatcher() intents = 3243773 gateway = discatcore.GatewayClient(http, dispatcher, intents=intents.value) -# alternatively, you can provide the event type in the decorator -@dispatcher.listen_to() -async def ready(event: discatcore.gateway.ReadyEvent): +@dispatcher.listen_to(discatcore.gateway.ReadyEvent) +async def ready(event): print(event.data) async def main(): diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index b023947..932e83a 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -5,11 +5,9 @@ import asyncio import inspect import logging -import sys import traceback import typing as t from collections import defaultdict -from importlib import reload import attr from typing_extensions import Self, TypeGuard @@ -20,13 +18,6 @@ if t.TYPE_CHECKING: from ..gateway import GatewayClient -if sys.version_info >= (3, 10): - from types import UnionType - - _union_types = {t.Union, UnionType} -else: - _union_types = {t.Union} - _log = logging.getLogger(__name__) __all__ = ( @@ -44,25 +35,6 @@ ConsumerCallback = t.Callable[[DispatcherT, "GatewayClient", JSONObject], Coro[None]] -# ported from discatpy -def _get_globals(x: object) -> dict[str, t.Any]: - module = inspect.getmodule(x) - - if module: - try: - t.TYPE_CHECKING = True - reload(module) - except ModuleNotFoundError: - # incomplete __main__ module - # this does mean that anything defined in TYPE_CHECKING will not be extracted - # TODO: find an alternative solution for __main__ module that extracts items from TYPE_CHECKING statements - pass - finally: - t.TYPE_CHECKING = False - - return module.__dict__ - - @attr.define class Consumer(t.Generic[DispatcherT]): """Represents a dispatcher consumer. A consumer consumes a raw event and performs actions based on the raw event.""" @@ -91,7 +63,7 @@ class Dispatcher: def __init__(self) -> None: self._listeners: defaultdict[type[Event], list[ListenerCallback[Event]]] = defaultdict( - lambda: [] + list ) self._consumers: dict[str, Consumer[Self]] = {} @@ -177,82 +149,17 @@ def unsubscribe(self, event: type[EventT], func: ListenerCallback[EventT]) -> No if not listeners: del self._listeners[event] - @t.overload - def listen_to( - self, func: ListenerCallback[EventT], *, events: None = ... - ) -> ListenerCallback[EventT]: - pass - - @t.overload - def listen_to( - self, func: ListenerCallback[EventT], *, events: list[type[EventT]] - ) -> t.NoReturn: - pass - - @t.overload - def listen_to( - self, func: None = ..., *, events: list[type[EventT]] - ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: - pass - - @t.overload - def listen_to( - self, func: None = ..., *, events: None = ... - ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: - pass - def listen_to( self, - func: t.Optional[ListenerCallback[EventT]] = None, *, - events: t.Optional[list[type[EventT]]] = None, - ) -> t.Union[ - t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]], - ListenerCallback[EventT], - t.NoReturn, - ]: - if func and events is not None: - raise ValueError(f"func and events parameters cannot both be set!") - + events: list[type[EventT]] + ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: def wrapper(func: ListenerCallback[EventT]) -> ListenerCallback[EventT]: - func_sig = inspect.signature(func) - event_arg = next(iter(func_sig.parameters.values())) - event_arg_anno = event_arg.annotation - - resolved_events: set[type[Event]] - if event_arg_anno is inspect.Parameter.empty: - if events: - resolved_events = set(events) - else: - raise TypeError( - "No event type was provided! Please provide it as an argument or a type hint." - ) - else: - if isinstance(event_arg_anno, str): - event_arg_anno = eval(event_arg_anno, _get_globals(func)) - - def event_check(arg: t.Any) -> None: - if not isinstance(arg, type) and not issubclass(arg, Event): - raise TypeError(f"Expected an event, got {arg!r}.") - - if t.get_origin(event_arg_anno) in _union_types: - union_args = t.get_args(event_arg_anno) - - for arg in union_args: - event_check(arg) - - resolved_events = t.cast(set[type[Event]], set(union_args)) - else: - event_check(event_arg_anno) - resolved_events = {t.cast(type[Event], event_arg_anno)} - - for event in resolved_events: - self.subscribe(event, func) # pyright: ignore + for event in events: + self.subscribe(event, func) return func - if func: - return wrapper(func) return wrapper def dispatch(self, event: Event) -> asyncio.Future[t.Any]: From a95c8f3dd8a12f283a5bb707bb0a464957fe8edd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Jul 2023 22:08:39 +0000 Subject: [PATCH 12/13] style: pre-commit autofix --- discatcore/utils/dispatcher.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/discatcore/utils/dispatcher.py b/discatcore/utils/dispatcher.py index 932e83a..af028c0 100644 --- a/discatcore/utils/dispatcher.py +++ b/discatcore/utils/dispatcher.py @@ -62,9 +62,7 @@ class Dispatcher: __slots__ = ("_listeners", "_consumers") def __init__(self) -> None: - self._listeners: defaultdict[type[Event], list[ListenerCallback[Event]]] = defaultdict( - list - ) + self._listeners: defaultdict[type[Event], list[ListenerCallback[Event]]] = defaultdict(list) self._consumers: dict[str, Consumer[Self]] = {} for name, value in inspect.getmembers(self): @@ -150,9 +148,7 @@ def unsubscribe(self, event: type[EventT], func: ListenerCallback[EventT]) -> No del self._listeners[event] def listen_to( - self, - *, - events: list[type[EventT]] + self, *, events: list[type[EventT]] ) -> t.Callable[[ListenerCallback[EventT]], ListenerCallback[EventT]]: def wrapper(func: ListenerCallback[EventT]) -> ListenerCallback[EventT]: for event in events: From e0ba017859a9a655a4a7d1af5e5752a3f3f82a37 Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Sun, 16 Jul 2023 15:16:32 -0700 Subject: [PATCH 13/13] refactor: update classproperty type hints --- discatcore/utils/functools.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/discatcore/utils/functools.py b/discatcore/utils/functools.py index 3361d40..07a9829 100644 --- a/discatcore/utils/functools.py +++ b/discatcore/utils/functools.py @@ -7,19 +7,20 @@ __all__ = ("classproperty",) T = t.TypeVar("T") +ClsT = t.TypeVar("ClsT") -class classproperty(t.Generic[T]): - def __init__(self, fget: t.Callable[[t.Any], T], /) -> None: - self.fget: "classmethod[T]" +class classproperty(t.Generic[ClsT, T]): + def __init__(self, fget: t.Callable[[ClsT], T], /) -> None: + self.fget: "classmethod[ClsT, ..., T]" self.getter(fget) - def getter(self, fget: t.Callable[[t.Any], T], /) -> Self: + def getter(self, fget: t.Callable[[ClsT], T], /) -> Self: if not isinstance(fget, classmethod): raise ValueError(f"Callable {fget.__name__} is not a classmethod!") self.fget = fget return self - def __get__(self, obj: t.Optional[t.Any], type: t.Optional[type]) -> T: + def __get__(self, obj: t.Optional[t.Any], type: t.Optional[ClsT]) -> T: return self.fget.__func__(type)