diff --git a/discord/__init__.py b/discord/__init__.py index ecb26038..fc7bf3c3 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -65,6 +65,7 @@ from .team import * from .sticker import Sticker, GuildSticker, StickerPack from .scheduled_event import GuildScheduledEvent +from .monetization import * MISSING = utils.MISSING diff --git a/discord/client.py b/discord/client.py index 2b8c521f..763a68bf 100644 --- a/discord/client.py +++ b/discord/client.py @@ -52,6 +52,8 @@ TYPE_CHECKING ) +from typing_extensions import Literal + from .auto_updater import AutoUpdateChecker from .sticker import StickerPack from .user import ClientUser, User @@ -62,6 +64,7 @@ from .channel import _channel_factory, PartialMessageable from .enums import ChannelType, ApplicationCommandType, Locale from .mentions import AllowedMentions +from .monetization import Entitlement, SKU from .errors import * from .enums import Status, VoiceRegion from .gateway import * @@ -73,7 +76,7 @@ from .object import Object from .backoff import ExponentialBackoff from .webhook import Webhook -from .iterators import GuildIterator +from .iterators import GuildIterator, EntitlementIterator from .appinfo import AppInfo from .application_commands import * @@ -102,7 +105,6 @@ _SubmitCallback = Callable[[ModalSubmitInteraction], Coroutine[Any, Any, Any]] - T = TypeVar('T') Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) @@ -563,12 +565,12 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa ): return await self._sync_commands() state = self._connection # Speedup attribute access - app_id = self.app.id - get_commands = self.http.get_application_commands if not is_cog_reload: + app_id = self.app.id log.info('Collecting global application-commands for application %s (%s)', self.app.name, self.app.id) self._minimal_registered_global_commands_raw = minimal_registered_global_commands_raw = [] + get_commands = self.http.get_application_commands global_registered_raw = await get_commands(app_id) for raw_command in global_registered_raw: @@ -585,7 +587,10 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa command._state = state self._application_commands[command.id] = command - log.info('Done! Cached %s global application-commands', sum([len(cmds) for cmds in self._application_commands_by_type.values()])) + log.info( + 'Done! Cached %s global application-commands', + sum([len(cmds) for cmds in self._application_commands_by_type.values()]) + ) log.info('Collecting guild-specific application-commands for application %s (%s)', self.app.name, app_id) self._minimal_registered_guild_commands_raw = minimal_registered_guild_commands_raw = {} @@ -605,22 +610,33 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa try: guild_commands = self._guild_specific_application_commands[guild.id] except KeyError: - self._guild_specific_application_commands[guild.id] = guild_commands = {'chat_input': {}, 'user': {}, 'message': {}} + self._guild_specific_application_commands[guild.id] = guild_commands = { + 'chat_input': {}, 'user': {}, 'message': {} + } for raw_command in registered_guild_commands_raw: command_type = str(ApplicationCommandType.try_value(raw_command['type'])) - minimal_registered_guild_commands.append({'id': int(raw_command['id']), 'type': command_type, 'name': raw_command['name']}) + minimal_registered_guild_commands.append( + {'id': int(raw_command['id']), 'type': command_type, 'name': raw_command['name']} + ) try: command = guild_commands[command_type][raw_command['name']] except KeyError: command = ApplicationCommand._from_type(state, data=raw_command) command.func = None - self._application_commands[command.id] = guild._application_commands[command.id] = guild_commands[command_type][command.name] = command + self._application_commands[command.id] = guild._application_commands[command.id] \ + = guild_commands[command_type][command.name] = command else: command._fill_data(raw_command) command._state = state self._application_commands[command.id] = guild._application_commands[command.id] = command - log.info('Done! Cached %s commands for %s guilds', sum([len(commands) for commands in list(minimal_registered_guild_commands_raw.values())]), len(minimal_registered_guild_commands_raw.keys())) + log.info( + 'Done! Cached %s commands for %s guilds', + sum([ + len(commands) for commands in list(minimal_registered_guild_commands_raw.values()) + ]), + len(minimal_registered_guild_commands_raw.keys()) + ) else: # re-assign metadata to the commands (for commands added from cogs) @@ -673,11 +689,21 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa self._application_commands[command.id] = guild._application_commands[command.id] = command log.info('Done!') if no_longer_in_code_global: - log.warning('%s global application-commands where removed from code but are still registered in discord', no_longer_in_code_global) + log.warning( + '%s global application-commands where removed from code but are still registered in discord', + no_longer_in_code_global + ) if no_longer_in_code_guild_specific: - log.warning('In total %s guild-specific application-commands from %s guild(s) where removed from code but are still registered in discord', no_longer_in_code_guild_specific, len(no_longer_in_code_guilds)) + log.warning( + 'In total %s guild-specific application-commands from %s guild(s) where removed from code ' + 'but are still registered in discord', no_longer_in_code_guild_specific, + len(no_longer_in_code_guilds) + ) if no_longer_in_code_global or no_longer_in_code_guild_specific: - log.warning('To prevent the above, set `sync_commands_on_cog_reload` of %s to True', self.__class__.__name__) + log.warning( + 'To prevent the above, set `sync_commands_on_cog_reload` of %s to True', + self.__class__.__name__ + ) @utils.deprecated('Guild.chunk') async def request_offline_members(self, *guilds): @@ -2253,23 +2279,27 @@ async def _sync_commands(self) -> None: if any_changed is True: updated = None - if len(to_send) == 1 and has_update and not to_maybe_remove: + if (to_send_count := len(to_send)) == 1 and has_update and not to_maybe_remove: log.info('Detected changes on global application-command %s, updating.', to_send[0]['name']) updated = await self.http.edit_application_command(application_id, to_send[0]['id'], to_send[0]) - elif len(to_send) == 1 and not has_update and not to_maybe_remove: + elif len == 1 and not has_update and not to_maybe_remove: log.info('Registering one new global application-command %s.', to_send[0]['name']) updated = await self.http.create_application_command(application_id, to_send[0]) else: - if len(to_send) > 0: - log.info('Detected %s updated/new global application-commands, bulk overwriting them...', len(to_send)) + if to_send_count > 0: + log.info( + f'Detected %s updated/new global application-commands, bulk overwriting them...', + to_send_count + ) if not self.delete_not_existing_commands: to_send.extend(to_maybe_remove) else: - if len(to_maybe_remove) > 0: + if (to_maybe_remove_count := len(to_maybe_remove)) > 0: log.info( 'Removing %s global application-command(s) that isn\'t/arent used in this code anymore.' ' To prevent this set `delete_not_existing_commands` of %s to False', - len(to_maybe_remove), self.__class__.__name__ + to_maybe_remove_count, + self.__class__.__name__ ) to_send.extend(to_cep) global_registered_raw = await self.http.bulk_overwrite_application_commands(application_id, to_send) @@ -2948,4 +2978,130 @@ async def fetch_voice_regions(self) -> List[VoiceRegionInfo]: The voice regions that can be used. """ data = await self.http.get_voice_regions() - return [VoiceRegionInfo(data=d) for d in data] \ No newline at end of file + return [VoiceRegionInfo(data=d) for d in data] + + async def create_test_entitlement( + self, + sku_id: int, + target: Union[User, Guild, Snowflake], + owner_type: Optional[Literal['guild', 'user']] = MISSING + ): + """|coro| + + .. note:: + + This method is only temporary and probably will be removed with or even before a stable v2 release + as discord is already redesigning the testing system based on developer feedback. + + See https://github.com/discord/discord-api-docs/pull/6502 for more information. + + Creates a test entitlement to a given :class:`SKU` for a given guild or user. + Discord will act as though that user or guild has entitlement to your premium offering. + + After creating a test entitlement, you'll need to reload your Discord client. + After doing so, you'll see that your server or user now has premium access. + + Parameters + ---------- + sku_id: :class:`int` + The ID of the SKU to create a test entitlement for. + target: Union[:class:`User`, :class:`Guild`, :class:`Snowflake`] + The target to create a test entitlement for. + + This can be a user, guild or just the ID, if so the owner_type parameter must be set. + owner_type: :class:`str` + The type of the ``target``, could be ``guild`` or ``user``. + + Returns + -------- + :class:`Entitlement` + The created test entitlement. + """ + target = target.id + + if isinstance(target, Guild): + owner_type = 1 + elif isinstance(target, User): + owner_type = 2 + else: + if owner_type is MISSING: + raise TypeError('owner_type must be set if target is not a Guild or user-like object.') + else: + owner_type = 1 if owner_type == 'guild' else 2 + + data = await self.http.create_test_entitlement( + self.app.id, + sku_id=sku_id, + owner_id=target, + owner_type=owner_type + ) + return Entitlement(data=data, state=self._connection) + + async def delete_test_entitlement(self, entitlement_id: int): + """|coro| + + .. note:: + + This method is only temporary and probably will be removed with or even before a stable v2 release + as discord is already redesigning the testing system based on developer feedback. + + See https://github.com/discord/discord-api-docs/pull/6502 for more information. + + Deletes a currently-active test entitlement. + Discord will act as though that user or guild no longer has entitlement to your premium offering. + + Parameters + ---------- + entitlement_id: :class:`int` + The ID of the entitlement to delete. + """ + await self.http.delete_test_entitlement(self.app.id, entitlement_id) + + async def fetch_entitlements( + self, + *, + limit: int = 100, + user: Optional[User] = None, + guild: Optional[Guild] = None, + sku_ids: Optional[List[int]] = None, + before: Optional[Union[datetime.datetime, Snowflake]] = None, + after: Optional[Union[datetime.datetime, Snowflake]] = None, + exclude_ended: bool = False + ): + """|coro| + + Parameters + ---------- + limit: :class:`int` + The maximum amount of entitlements to fetch. + Defaults to ``100``. + user: Optional[:class:`User`] + The user to fetch entitlements for. + guild: Optional[:class:`Guild`] + The guild to fetch entitlements for. + sku_ids: Optional[List[:class:`int`]] + Optional list of SKU IDs to check entitlements for + before: Optional[Union[:class:`datetime.datetime`, :class:`Snowflake`]] + Retrieve entitlements before this date or object. + If a date is provided it must be a timezone-naive datetime representing UTC time. + after: Optional[Union[:class:`datetime.datetime`, :class:`Snowflake`]] + Retrieve entitlements after this date or object. + If a date is provided it must be a timezone-naive datetime representing UTC time. + exclude_ended: :class:`bool` + Whether ended entitlements should be fetched or not. Defaults to ``False``. + + Return + ------ + :class:`AsyncIterator` + An iterator to fetch all entitlements for the current application. + """ + return EntitlementIterator( + state=self._connection, + limit=limit, + user_id=user.id, + guild_id=guild.id, + sku_ids=sku_ids, + before=before, + after=after, + exclude_ended=exclude_ended + ) diff --git a/discord/enums.py b/discord/enums.py index 9f9dbaed..76c1dff9 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -83,6 +83,7 @@ 'ForumLayout', 'OnboardingMode', 'OnboardingPromptType', + 'SKUType', ) @@ -423,13 +424,15 @@ def from_type(cls, t): class InteractionCallbackType(Enum): - pong = 1 - msg_with_source = 4 + pong = 1 + msg_with_source = 4 deferred_msg_with_source = 5 - deferred_update_msg = 6 - update_msg = 7 - autocomplete_callback = 8 - modal = 9 + deferred_update_msg = 6 + update_msg = 7 + autocomplete_callback = 8 + modal = 9 + premium_required = 10 + @classmethod def from_value(cls, value): @@ -1142,6 +1145,15 @@ class OnboardingPromptType(Enum): dropdown = 1 +class SKUType(Enum): + durable_primary = 1 + durable = 2 + consumable = 3 + bundle = 4 + subscription = 5 + subscription_group = 6 + + def try_enum(cls: Type[Enum], val: Any): """A function that tries to turn the value into enum ``cls``. diff --git a/discord/flags.py b/discord/flags.py index abf75843..b9751f0a 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -46,7 +46,8 @@ 'GuildMemberFlags', 'Intents', 'MemberCacheFlags', - 'ApplicationFlags' + 'ApplicationFlags', + 'SKUFlags', ) @@ -1398,3 +1399,100 @@ def application_commands_badge(self): def active_application(self): """:class:`bool`: Returns ``True`` is an active application (e.g. has at leas one app command executed in the last 30 days).""" return 1 << 24 + + +@fill_with_flags() +class SKUFlags(BaseFlags): + """Wraps up the flags of a :class:`SKU`. + + .. container:: operations + + .. describe:: x == y + + Checks if two SKUFlags are equal. + .. describe:: x != y + + Checks if two SKUFlags are not equal. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + Note that aliases are not shown. + + Attributes + ----------- + value: :class:`int` + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available flags. You should query + flags via the properties rather than using this raw value. + """ + + @flag_value + def premium_purchase(self): + """:class:`bool`: A premium purchase SKU""" + return 1 << 0 + + @flag_value + def has_free_premium_content(self): + """:class:`bool`: An SKU containing free premium content""" + return 1 << 1 + + @flag_value + def available(self): + """:class:`bool`: Whether the SKU is currently available for purchase""" + return 1 << 2 + + # @flag_value + # def premium_and_distribution(self): + # """:class:`bool`: Undocumented""" + # return 1 << 3 + + # @flag_value + # def sticker(self): + # """":class:`bool`: Undocumented""" + # return 1 << 4 + + @flag_value + def guild_role(self): + """:class:`bool`: A role that can be purchased for a guild""" + return 1 << 5 + + @flag_value + def available_for_subscription_gifting(self): + """:class:`bool`: An SKU that can be purchased as a gift for others.""" + return 1 << 6 + + @flag_value + def app_guild_subscription(self): + """:class:`bool`: A recurring SKU that can be purchased by a user and applied to a single server. + Grants access to every user in that server.""" + return 1 << 7 + + @flag_value + def app_user_subscription(self): + """:class:`bool`: Recurring SKU purchased by a user for themselves. + Grants access to the purchasing user in every server.""" + return 1 << 8 + + # @flag_value + # def creator_monetization(self): + # """:class:`bool`: Undocumented""" + # return 1 << 9 + + @flag_value + def guild_product(self): + """:class:`bool`: A product in the guild store""" + return 1 << 10 + + # @flag_value + # def user_update_mask(self): + # """:class:`bool`: Undocumented""" + # return 0 << 0 + # + # @flag_value + # def staff_create_subscription_group_listing_mask(self): + # """:class:`bool`: Undocumented""" + # return 384 diff --git a/discord/http.py b/discord/http.py index ad0af24c..b093f152 100644 --- a/discord/http.py +++ b/discord/http.py @@ -62,7 +62,8 @@ from .embeds import Embed from .message import Attachment, MessageReference from .types import ( - guild + guild, + monetization, ) from .types.snowflake import SnowflakeID from .utils import SnowflakeList @@ -1840,7 +1841,62 @@ def edit_guild_onboarding( # Misc def application_info(self): return self.request(Route('GET', '/oauth2/applications/@me')) - + + def list_entitlements( + self, + application_id: int, + *, + limit: int = 100, + user_id: int = MISSING, + guild_id: int = MISSING, + sku_ids: List[int] = MISSING, + after: int = MISSING, + before: int = MISSING, + exclude_ended: int = False + ) -> Response[List[monetization.Entitlement]]: + params = {'limit': limit} + + if user_id is not MISSING: + params['user_id'] = str(user_id) + if guild_id is not MISSING: # FIXME: Can both be passed at the same time? Consider using elif instead + params['guild_id'] = str(guild_id) + if sku_ids is not MISSING: + params['sku_ids'] = [str(s) for s in sku_ids] + if after is not MISSING: + params['after'] = str(after) + if before is not MISSING: # FIXME: Can both be passed at the same time? Consider using elif instead + params['before'] = str(before) + if exclude_ended is not False: # TODO: what is the api default value? + params['exclude_ended'] = str(exclude_ended) + + r = Route('GET', '/applications/{application_id}/entitlements', application_id=application_id) + return self.request(r, json=params) + + def create_test_entitlement( + self, + application_id: int, + *, + sku_id: int, + owner_id: int, + owner_type: int + ) -> Response[monetization.TestEntitlement]: + payload = { + 'sku_id': sku_id, + 'owner_id': owner_id, + 'owner_type': owner_type + } + r = Route('POST', '/applications/{application_id}/entitlements', application_id=application_id) + return self.request(r, json=payload) + + def delete_test_entitlement(self, application_id: int, entitlement_id: int) -> Response[None]: + r = Route( + 'DELETE', + '/applications/{application_id}/entitlements/{entitlement_id}', + application_id=application_id, + entitlement_id=entitlement_id + ) + return self.request(r) + async def get_gateway(self, *, encoding='json', v=10, zlib=True): try: data = await self.request(Route('GET', '/gateway')) diff --git a/discord/interactions.py b/discord/interactions.py index e3c2ccdf..e40a170d 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -36,6 +36,7 @@ Optional, overload, Sequence, + Set, Tuple, TYPE_CHECKING, Union @@ -44,6 +45,7 @@ from typing_extensions import Literal from . import abc, utils +from .monetization import Entitlement from .object import Object from .channel import _channel_factory, DMChannel, TextChannel, ThreadChannel, VoiceChannel, ForumPost, PartialMessageable from .components import * @@ -468,6 +470,11 @@ class BaseInteraction: The id of the user who triggered the interaction. channel_id: :class:`int` The id of the channel where the interaction was triggered. + entitlements: Optional[List[:class:`~discord.Entitlement`]] + For :ddocs:`monetized apps ` entitlements of the user who triggered the interaction + and optionally, if any. + + This is available for all interaction types. data: :class:`~discord.InteractionData` Some internal needed metadata for the interaction, depending on the type. author_locale: Optional[:class:`~discord.Locale`] @@ -479,7 +486,7 @@ class BaseInteraction: app_permissions: Optional[:class:`~discord.Permissions`] The permissions of the bot in the channel where the interaction was triggered, if it was in a guild. - This is similar to `interaction.channel.permissions_for(innteraction.guild.me)` but calculated on discord side. + This is similar to `interaction.channel.permissions_for(interaction.guild.me)` but calculated on discord side. author_permissions: Optional[:class:`~discord.Permissions`] The author's permissions in the channel where the interaction was triggered, if it was in a guild. @@ -490,12 +497,13 @@ class BaseInteraction: id: int guild_id: Optional[int] channel_id: int - app_permissions: Optional[Permissions] - author_permissions: Optional[Permissions] user_id: int user: User - member: Optional[Member] author_locale: Locale + entitlements: Set[Entitlement] + app_permissions: Optional[Permissions] + author_permissions: Optional[Permissions] + member: Optional[Member] guild_locale: Optional[Locale] data: Optional[InteractionData] message: Optional[Union[Message, EphemeralMessage]] @@ -510,7 +518,7 @@ def __init__(self, state: ConnectionState, data: InteractionPayload) -> None: self.type = InteractionType.try_value(data['type']) self.id = int(data['id']) self.guild_id = guild_id = utils._get_as_snowflake(data, 'guild_id') - + self.entitlements = {Entitlement(data=e, state=state) for e in data.get('entitlements', [])} channel_data = data.get('channel', {}) self.channel_id = channel_id = int(data.get('channel_id', channel_data.get('id', 0))) @@ -651,6 +659,17 @@ async def _defer( msg = self.callback_message = await self.get_original_callback() return msg + async def _premium_required(self) -> None: + """|coro| + Respond with an upgrade button, only available for apps + with :ddocs:`monetized apps ` enabled. + """ + await self._state.http.post_initial_response( + self.id, + self._token, + data={'type': InteractionCallbackType.premium_required.value, 'data': {}} + ) + async def edit( self, *, @@ -1209,6 +1228,18 @@ async def defer(self, hidden: bool = False) -> Union[Message, EphemeralMessage]: data = await super()._defer(InteractionCallbackType.deferred_msg_with_source, hidden) return data + async def respond_with_premium_required(self) -> None: + """|coro| + Respond with an upgrade button, only available for apps + with :ddocs:`monetized apps ` enabled. + + .. note:: + You must respond with this one directly, without using any of + :meth:`~discord.ApplicationCommandInteraction.defer` or + :meth:`~discord.ApplicationCommandInteraction.respond`. + """ + await super()._premium_required() + class ComponentInteraction(BaseInteraction): """ @@ -1302,6 +1333,17 @@ async def defer( """ return await super()._defer(type, hidden) + async def respond_with_premium_required(self) -> None: + """|coro| + Respond with an upgrade button, only available for apps + with :ddocs:`monetized apps ` enabled. + + .. note:: + You must respond with this one directly, without using any of + :meth:`~discord.ApplicationCommandInteraction.defer` or + :meth:`~discord.ApplicationCommandInteraction.respond`. + """ + await super()._premium_required() class AutocompleteInteraction(BaseInteraction): """ @@ -1514,6 +1556,18 @@ async def defer(self, hidden: Optional[bool] = False) -> Optional[Union[Message, async def respond_with_modal(self, modal: Modal) -> NotImplementedError: raise NotImplementedError('You can\'t respond to a modal submit with another modal.') + async def respond_with_premium_required(self) -> None: + """|coro| + Respond with an upgrade button, only available for apps + with :ddocs:`monetized apps ` enabled. + + .. note:: + You must respond with this one directly, without using any of + :meth:`~discord.ApplicationCommandInteraction.defer` or + :meth:`~discord.ApplicationCommandInteraction.respond`. + """ + await super()._premium_required() + class InteractionData: def __init__(self, *, state: ConnectionState, data: InteractionDataPayload, guild: Optional[Guild] = None, **kwargs) -> None: diff --git a/discord/iterators.py b/discord/iterators.py index 80109349..59d07084 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -33,20 +33,25 @@ from typing import ( TYPE_CHECKING, Optional, - Union + Union, + List, ) if TYPE_CHECKING: + from .types import ( + monetization, + ) + from .state import ConnectionState from .guild import Guild from .abc import Snowflake, Messageable from .scheduled_event import GuildScheduledEvent from .channel import ThreadChannel, ForumPost, TextChannel, ForumChannel - + from .monetization import Entitlement +from .utils import MISSING from .errors import NoMoreItems from .utils import time_snowflake, maybe_coroutine from .object import Object -from .audit_logs import AuditLogEntry OLDEST_OBJECT = Object(id=0) BanEntry = namedtuple('BanEntry', 'reason user') @@ -62,6 +67,7 @@ 'MemberIterator', 'ReactionIterator', 'ThreadMemberIterator', + 'EntitlementIterator', ) @@ -445,6 +451,12 @@ def __init__(self, self.action_type = action_type self.after = OLDEST_OBJECT self._users = {} + self._integrations = {} + self._webhooks = {} + self._scheduled_events = {} + self._threads = {} + self._application_commands = {} + self._auto_moderation_rules = {} self._state = guild._state self._filter = None # entry dict -> bool @@ -464,13 +476,12 @@ async def _before_strategy(self, retrieve): before = self.before.id if self.before else None data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before) - entries = data.get('audit_log_entries', []) if len(data) and entries: if self.limit is not None: self.limit -= retrieve self.before = Object(id=int(entries[-1]['id'])) - return data.get('users', []), entries + return data, entries async def _after_strategy(self, retrieve): after = self.after.id if self.after else None @@ -481,7 +492,7 @@ async def _after_strategy(self, retrieve): if self.limit is not None: self.limit -= retrieve self.after = Object(id=int(entries[0]['id'])) - return data.get('users', []), entries + return data, entries async def next(self): if self.entries.empty(): @@ -503,27 +514,73 @@ def _get_retrieve(self): async def _fill(self): from .user import User + from .integrations import _integration_factory + from .webhook import Webhook + from .scheduled_event import GuildScheduledEvent + from .channel import ForumChannel, ForumPost, ThreadChannel + from .application_commands import ApplicationCommand + from .automod import AutoModRule + from .audit_logs import AuditLogEntry if self._get_retrieve(): - users, data = await self._strategy(self.retrieve) - if len(data) < 100: + data, entries = await self._strategy(self.retrieve) + if len(entries) < 100: self.limit = 0 # terminate the infinite loop if self.reverse: - data = reversed(data) + entries = reversed(entries) if self._filter: data = filter(self._filter, data) - for user in users: - u = User(data=user, state=self._state) + _state = self._state + _guild = self.guild + + for user in data.get('users', []): + u = User(data=user, state=_state) self._users[u.id] = u - for element in data: + for integration in data.get('integrations', []): + i, _ = _integration_factory(integration['type']) + self._integrations[i.id] = i(data=integration, guild=_guild) + + for webhook in data.get('webhooks', []): + w = Webhook.from_state(data=webhook, state=_state) + self._webhooks[w.id] = w + + for scheduled_event in data.get('guild_scheduled_events', []): + e = GuildScheduledEvent(state=_state, guild=_guild, data=scheduled_event) + self._scheduled_events[e.id] = e + + for thread in data.get('threads', []): + parent_id = int(thread.get('parent_id')) + parent_channel = _guild.get_channel(parent_id) + if isinstance(parent_channel, ForumChannel): + t = ForumPost(state=_state, guild=_guild, data=thread) + else: + t = ThreadChannel(state=_state, guild=_guild, data=thread) + self._threads[t.id] = t + + for application_command in data.get('application_commands', []): + c = ApplicationCommand._from_type(state=_state, data=application_command) + self._application_commands[c.id] = c + + for automod_rule in data.get('auto_moderation_rules', []): + r = AutoModRule(state=_state, guild=_guild, **automod_rule) + self._auto_moderation_rules[r.id] = r + + for entry in entries: # TODO: remove this if statement later - if element['action_type'] is None: + if entry['action_type'] is None: continue - await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) + await self.entries.put( + AuditLogEntry( + data=entry, + guild=self.guild, + users=self._users, + + ) + ) class GuildIterator(_AsyncIterator): @@ -614,7 +671,7 @@ async def flatten(self): self.limit = 0 if self._filter: - data = filter(self._filter, data) + entries = filter(self._filter, entries) for element in data: result.append(self.create_guild(element)) @@ -978,11 +1035,13 @@ async def _retrieve_users_after_strategy(self, retrieve): class BanIterator(_AsyncIterator): - def __init__(self, - guild: 'Guild', - limit: int = 1000, - before: Optional[Union['Snowflake', datetime.datetime]] = None, - after: Optional[Union['Snowflake', datetime.datetime]] = None,): + def __init__( + self, + guild: Guild, + limit: int = 1000, + before: Optional[Union[Snowflake, datetime.datetime]] = None, + after: Optional[Union[Snowflake, datetime.datetime]] = None + ): self.guild = guild self.guild_id = guild.id self.state = guild._state @@ -1067,3 +1126,117 @@ async def _retrieve_bans_after_strategy(self, retrieve): self.after = Object(id=int(data[0]['user']['id'])) data = reversed(data) return data + + +class EntitlementIterator(_AsyncIterator): + def __init__( + self, + state: ConnectionState, + limit: int = 100, + user_id: int = MISSING, + guild_id: int = MISSING, + sku_ids: List[int] = MISSING, + before: Optional[Union[datetime.datetime, Snowflake]] = None, + after: Optional[Union[datetime.datetime, Snowflake]] = None, + exclude_ended: bool = False + ): + self.application_id = state.application_id + self.guild_id = guild_id + self.user_id = user_id + self.sku_ids = sku_ids + self.state: ConnectionState = state + self.limit: int = limit + self.exclude_ended: bool = exclude_ended + + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=True)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + + self.before: Optional[Object] = before + self.after: Optional[Object] = after + + self.entitlements = asyncio.Queue() + self.getter = state.http.list_entitlements + + self._filter = None + + if self.before and self.after: + self._retrieve_entitlements = self._retrieve_entitlements_before_strategy + self._filter = lambda e: int(e['id']) > self.after.id + elif self.before: + self._retrieve_entitlements = self._retrieve_entitlements_before_strategy + else: + self._retrieve_entitlements = self._retrieve_entitlements_after_strategy + + async def next(self): + if self.entitlements.empty(): + await self.fill_entitlements() + + try: + return self.entitlements.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + def _get_retrieve(self): + l = self.limit + r = 100 if l is None or l > 100 else l + self.retrieve = r + return r > 0 + + async def fill_entitlements(self): + # this is a hack because >circular imports< + from .monetization import Entitlement + + state = self.state + + if self._get_retrieve(): + + data = await self._retrieve_entitlements(self.retrieve) + if self.limit is None or len(data) < 100: + self.limit = 0 + + if self._filter: + data = filter(self._filter, data) + + for element in data: + await self.entitlements.put( + Entitlement(data=element, state=state) + ) + + async def _retrieve_entitlements_before_strategy(self, retrieve) -> List[monetization.Entitlement]: + """Retrieve bans using before parameter.""" + before = self.before.id if self.before else MISSING + data = await self.getter( + self.application_id, + limit=retrieve, + before=before, + user_id=self.user_id, + guild_id=self.guild_id, + sku_ids=self.sku_ids, + exclude_ended=self.exclude_ended + ) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.before = Object(id=int(data[-1]['id'])) + return data + + async def _retrieve_entitlements_after_strategy(self, retrieve) -> List[monetization.Entitlement]: + """Retrieve bans using after parameter.""" + after = self.after.id if self.after else MISSING + data = await self.getter( + self.application_id, + limit=retrieve, + after=after, + user_id=self.user_id, + guild_id=self.guild_id, + sku_ids=self.sku_ids, + exclude_ended=self.exclude_ended, + ) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.after = Object(id=int(data[0]['id'])) + data = reversed(data) + return data diff --git a/discord/message.py b/discord/message.py index 4d92f6ff..9106fd84 100644 --- a/discord/message.py +++ b/discord/message.py @@ -116,7 +116,7 @@ 'PartialMessage', 'MessageReference', 'DeletedReferencedMessage', - 'RoleSubscriptionInfo' + 'RoleSubscriptionInfo', ) @@ -1508,7 +1508,7 @@ async def publish(self) -> None: await self._state.http.publish_message(self.channel.id, self.id) - async def pin(self, *, reason: Optional[str] = None) -> None: + async def pin(self, *, suppress_system_message: bool = False, reason: Optional[str] = None) -> None: """|coro| Pins the message. @@ -1518,6 +1518,10 @@ async def pin(self, *, reason: Optional[str] = None) -> None: Parameters ----------- + suppress_system_message: :class:`bool` + When set to ``True``, the function will wait 5 seconds for the system message and delete it. Defaults to ``False``. + + .. versionadded:: 2.0 reason: Optional[:class:`str`] The reason for pinning the message. Shows up on the audit log. @@ -1532,11 +1536,26 @@ async def pin(self, *, reason: Optional[str] = None) -> None: HTTPException Pinning the message failed, probably due to the channel having more than 50 pinned messages. + :exc:`~asyncio.TimeoutError` + Waiting for the system message timed out. """ await self._state.http.pin_message(self.channel.id, self.id, reason=reason) self.pinned = True + # TODO: we don't get a message create for this... + if suppress_system_message: + try: + msg = await self._state._get_client().wait_for( + 'message', + check=lambda m: m.type is MessageType.pins_add and m.channel == self.channel, + timeout=5.0 + ) + except asyncio.TimeoutError as exc: + raise asyncio.TimeoutError('Timed out waiting for the system message to suppress.') from exc + else: + await msg.delete(reason='Suppressing system message for pinned message') + async def unpin(self, *, reason: Optional[str] = None) -> None: """|coro| diff --git a/discord/state.py b/discord/state.py index 7d96b02f..43765de6 100644 --- a/discord/state.py +++ b/discord/state.py @@ -61,6 +61,7 @@ from .invite import Invite from .automod import AutoModRule, AutoModActionPayload from .interactions import BaseInteraction, InteractionType +from .monetization import Entitlement if TYPE_CHECKING: @@ -541,6 +542,19 @@ def parse_resumed(self, data): self.call_handlers('resumed') self.dispatch('resumed') + async def parse_entitlement_create(self, data): + entitlement = Entitlement(data, self) + self.dispatch('entitlement_create', entitlement) + + def parse_entitlement_update(self, data): + entitlement = Entitlement(data, self) + # TODO: Add this to advanced cache + self.dispatch('entitlement_update', entitlement) + + def parse_entitlement_delete(self, data): + entitlement = Entitlement(data, self) + self.dispatch('entitlement_delete', entitlement) + def parse_message_create(self, data): channel, _ = self._get_guild_channel(data) message = Message(channel=channel, data=data, state=self)