Skip to content

Commit

Permalink
feat: Add support for application subscriptions
Browse files Browse the repository at this point in the history
See discord/discord-api-docs#6451, discord/discord-api-docs#6452,
and discord/discord-api-docs#6477.

Signed-off-by: mccoderpy <mccuber04@outlook.de>
  • Loading branch information
mccoderpy committed Nov 4, 2023
1 parent ba9b448 commit bb44f47
Show file tree
Hide file tree
Showing 9 changed files with 638 additions and 55 deletions.
1 change: 1 addition & 0 deletions discord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .team import *
from .sticker import Sticker, GuildSticker, StickerPack
from .scheduled_event import GuildScheduledEvent
from .monetization import *


MISSING = utils.MISSING
Expand Down
194 changes: 175 additions & 19 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
TYPE_CHECKING
)

from typing_extensions import Literal

from .auto_updater import AutoUpdateChecker
from .sticker import StickerPack
from .user import ClientUser, User
Expand All @@ -62,6 +64,7 @@
from .channel import _channel_factory, PartialMessageable
from .enums import ChannelType, ApplicationCommandType, Locale
from .mentions import AllowedMentions
from .monetization import Entitlement, SKU
from .errors import *
from .enums import Status, VoiceRegion
from .gateway import *
Expand All @@ -73,7 +76,7 @@
from .object import Object
from .backoff import ExponentialBackoff
from .webhook import Webhook
from .iterators import GuildIterator
from .iterators import GuildIterator, EntitlementIterator
from .appinfo import AppInfo
from .application_commands import *

Expand Down Expand Up @@ -102,7 +105,6 @@
_SubmitCallback = Callable[[ModalSubmitInteraction], Coroutine[Any, Any, Any]]



T = TypeVar('T')
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])

Expand Down Expand Up @@ -563,12 +565,12 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa
):
return await self._sync_commands()
state = self._connection # Speedup attribute access
app_id = self.app.id
get_commands = self.http.get_application_commands
if not is_cog_reload:
app_id = self.app.id
log.info('Collecting global application-commands for application %s (%s)', self.app.name, self.app.id)

self._minimal_registered_global_commands_raw = minimal_registered_global_commands_raw = []
get_commands = self.http.get_application_commands
global_registered_raw = await get_commands(app_id)

for raw_command in global_registered_raw:
Expand All @@ -585,7 +587,10 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa
command._state = state
self._application_commands[command.id] = command

log.info('Done! Cached %s global application-commands', sum([len(cmds) for cmds in self._application_commands_by_type.values()]))
log.info(
'Done! Cached %s global application-commands',
sum([len(cmds) for cmds in self._application_commands_by_type.values()])
)
log.info('Collecting guild-specific application-commands for application %s (%s)', self.app.name, app_id)

self._minimal_registered_guild_commands_raw = minimal_registered_guild_commands_raw = {}
Expand All @@ -605,22 +610,33 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa
try:
guild_commands = self._guild_specific_application_commands[guild.id]
except KeyError:
self._guild_specific_application_commands[guild.id] = guild_commands = {'chat_input': {}, 'user': {}, 'message': {}}
self._guild_specific_application_commands[guild.id] = guild_commands = {
'chat_input': {}, 'user': {}, 'message': {}
}
for raw_command in registered_guild_commands_raw:
command_type = str(ApplicationCommandType.try_value(raw_command['type']))
minimal_registered_guild_commands.append({'id': int(raw_command['id']), 'type': command_type, 'name': raw_command['name']})
minimal_registered_guild_commands.append(
{'id': int(raw_command['id']), 'type': command_type, 'name': raw_command['name']}
)
try:
command = guild_commands[command_type][raw_command['name']]
except KeyError:
command = ApplicationCommand._from_type(state, data=raw_command)
command.func = None
self._application_commands[command.id] = guild._application_commands[command.id] = guild_commands[command_type][command.name] = command
self._application_commands[command.id] = guild._application_commands[command.id] \
= guild_commands[command_type][command.name] = command
else:
command._fill_data(raw_command)
command._state = state
self._application_commands[command.id] = guild._application_commands[command.id] = command

log.info('Done! Cached %s commands for %s guilds', sum([len(commands) for commands in list(minimal_registered_guild_commands_raw.values())]), len(minimal_registered_guild_commands_raw.keys()))
log.info(
'Done! Cached %s commands for %s guilds',
sum([
len(commands) for commands in list(minimal_registered_guild_commands_raw.values())
]),
len(minimal_registered_guild_commands_raw.keys())
)

else:
# re-assign metadata to the commands (for commands added from cogs)
Expand Down Expand Up @@ -673,11 +689,21 @@ async def _request_sync_commands(self, is_cog_reload: bool = False, *, reload_fa
self._application_commands[command.id] = guild._application_commands[command.id] = command
log.info('Done!')
if no_longer_in_code_global:
log.warning('%s global application-commands where removed from code but are still registered in discord', no_longer_in_code_global)
log.warning(
'%s global application-commands where removed from code but are still registered in discord',
no_longer_in_code_global
)
if no_longer_in_code_guild_specific:
log.warning('In total %s guild-specific application-commands from %s guild(s) where removed from code but are still registered in discord', no_longer_in_code_guild_specific, len(no_longer_in_code_guilds))
log.warning(
'In total %s guild-specific application-commands from %s guild(s) where removed from code '
'but are still registered in discord', no_longer_in_code_guild_specific,
len(no_longer_in_code_guilds)
)
if no_longer_in_code_global or no_longer_in_code_guild_specific:
log.warning('To prevent the above, set `sync_commands_on_cog_reload` of %s to True', self.__class__.__name__)
log.warning(
'To prevent the above, set `sync_commands_on_cog_reload` of %s to True',
self.__class__.__name__
)

@utils.deprecated('Guild.chunk')
async def request_offline_members(self, *guilds):
Expand Down Expand Up @@ -2253,23 +2279,27 @@ async def _sync_commands(self) -> None:

if any_changed is True:
updated = None
if len(to_send) == 1 and has_update and not to_maybe_remove:
if (to_send_count := len(to_send)) == 1 and has_update and not to_maybe_remove:
log.info('Detected changes on global application-command %s, updating.', to_send[0]['name'])
updated = await self.http.edit_application_command(application_id, to_send[0]['id'], to_send[0])
elif len(to_send) == 1 and not has_update and not to_maybe_remove:
elif len == 1 and not has_update and not to_maybe_remove:
log.info('Registering one new global application-command %s.', to_send[0]['name'])
updated = await self.http.create_application_command(application_id, to_send[0])
else:
if len(to_send) > 0:
log.info('Detected %s updated/new global application-commands, bulk overwriting them...', len(to_send))
if to_send_count > 0:
log.info(
f'Detected %s updated/new global application-commands, bulk overwriting them...',
to_send_count
)
if not self.delete_not_existing_commands:
to_send.extend(to_maybe_remove)
else:
if len(to_maybe_remove) > 0:
if (to_maybe_remove_count := len(to_maybe_remove)) > 0:
log.info(
'Removing %s global application-command(s) that isn\'t/arent used in this code anymore.'
' To prevent this set `delete_not_existing_commands` of %s to False',
len(to_maybe_remove), self.__class__.__name__
to_maybe_remove_count,
self.__class__.__name__
)
to_send.extend(to_cep)
global_registered_raw = await self.http.bulk_overwrite_application_commands(application_id, to_send)
Expand Down Expand Up @@ -2948,4 +2978,130 @@ async def fetch_voice_regions(self) -> List[VoiceRegionInfo]:
The voice regions that can be used.
"""
data = await self.http.get_voice_regions()
return [VoiceRegionInfo(data=d) for d in data]
return [VoiceRegionInfo(data=d) for d in data]

async def create_test_entitlement(
self,
sku_id: int,
target: Union[User, Guild, Snowflake],
owner_type: Optional[Literal['guild', 'user']] = MISSING
):
"""|coro|
.. note::
This method is only temporary and probably will be removed with or even before a stable v2 release
as discord is already redesigning the testing system based on developer feedback.
See https://github.com/discord/discord-api-docs/pull/6502 for more information.
Creates a test entitlement to a given :class:`SKU` for a given guild or user.
Discord will act as though that user or guild has entitlement to your premium offering.
After creating a test entitlement, you'll need to reload your Discord client.
After doing so, you'll see that your server or user now has premium access.
Parameters
----------
sku_id: :class:`int`
The ID of the SKU to create a test entitlement for.
target: Union[:class:`User`, :class:`Guild`, :class:`Snowflake`]
The target to create a test entitlement for.
This can be a user, guild or just the ID, if so the owner_type parameter must be set.
owner_type: :class:`str`
The type of the ``target``, could be ``guild`` or ``user``.
Returns
--------
:class:`Entitlement`
The created test entitlement.
"""
target = target.id

if isinstance(target, Guild):
owner_type = 1
elif isinstance(target, User):
owner_type = 2
else:
if owner_type is MISSING:
raise TypeError('owner_type must be set if target is not a Guild or user-like object.')
else:
owner_type = 1 if owner_type == 'guild' else 2

data = await self.http.create_test_entitlement(
self.app.id,
sku_id=sku_id,
owner_id=target,
owner_type=owner_type
)
return Entitlement(data=data, state=self._connection)

async def delete_test_entitlement(self, entitlement_id: int):
"""|coro|
.. note::
This method is only temporary and probably will be removed with or even before a stable v2 release
as discord is already redesigning the testing system based on developer feedback.
See https://github.com/discord/discord-api-docs/pull/6502 for more information.
Deletes a currently-active test entitlement.
Discord will act as though that user or guild no longer has entitlement to your premium offering.
Parameters
----------
entitlement_id: :class:`int`
The ID of the entitlement to delete.
"""
await self.http.delete_test_entitlement(self.app.id, entitlement_id)

async def fetch_entitlements(
self,
*,
limit: int = 100,
user: Optional[User] = None,
guild: Optional[Guild] = None,
sku_ids: Optional[List[int]] = None,
before: Optional[Union[datetime.datetime, Snowflake]] = None,
after: Optional[Union[datetime.datetime, Snowflake]] = None,
exclude_ended: bool = False
):
"""|coro|
Parameters
----------
limit: :class:`int`
The maximum amount of entitlements to fetch.
Defaults to ``100``.
user: Optional[:class:`User`]
The user to fetch entitlements for.
guild: Optional[:class:`Guild`]
The guild to fetch entitlements for.
sku_ids: Optional[List[:class:`int`]]
Optional list of SKU IDs to check entitlements for
before: Optional[Union[:class:`datetime.datetime`, :class:`Snowflake`]]
Retrieve entitlements before this date or object.
If a date is provided it must be a timezone-naive datetime representing UTC time.
after: Optional[Union[:class:`datetime.datetime`, :class:`Snowflake`]]
Retrieve entitlements after this date or object.
If a date is provided it must be a timezone-naive datetime representing UTC time.
exclude_ended: :class:`bool`
Whether ended entitlements should be fetched or not. Defaults to ``False``.
Return
------
:class:`AsyncIterator`
An iterator to fetch all entitlements for the current application.
"""
return EntitlementIterator(
state=self._connection,
limit=limit,
user_id=user.id,
guild_id=guild.id,
sku_ids=sku_ids,
before=before,
after=after,
exclude_ended=exclude_ended
)
24 changes: 18 additions & 6 deletions discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
'ForumLayout',
'OnboardingMode',
'OnboardingPromptType',
'SKUType',
)


Expand Down Expand Up @@ -423,13 +424,15 @@ def from_type(cls, t):


class InteractionCallbackType(Enum):
pong = 1
msg_with_source = 4
pong = 1
msg_with_source = 4
deferred_msg_with_source = 5
deferred_update_msg = 6
update_msg = 7
autocomplete_callback = 8
modal = 9
deferred_update_msg = 6
update_msg = 7
autocomplete_callback = 8
modal = 9
premium_required = 10


@classmethod
def from_value(cls, value):
Expand Down Expand Up @@ -1142,6 +1145,15 @@ class OnboardingPromptType(Enum):
dropdown = 1


class SKUType(Enum):
durable_primary = 1
durable = 2
consumable = 3
bundle = 4
subscription = 5
subscription_group = 6


def try_enum(cls: Type[Enum], val: Any):
"""A function that tries to turn the value into enum ``cls``.
Expand Down
Loading

0 comments on commit bb44f47

Please sign in to comment.