Skip to content

Commit

Permalink
chore: merge unstable into stable (#1117)
Browse files Browse the repository at this point in the history
* feat: Implement helper methods for invites (#1098)

* feat: mention spam trigger type

* fix: Reimplement manual sharding/presence, fix forum tag implementation (#1115)

* fix: Reimplement manual sharding/presence instantiation.

(This was accidentally removed per gateway rework)

* refactor: Reorganise tag creation/updating/deletion to non-deprecated endpoints and make it cache-reflective.

* chore: bump version (#1116)

* fix: properly initialise private attributes in iterators (#1114)

* fix: set `message.member.user` as `message.author` again (#1118)

Co-authored-by: Damego <danyabatueff@gmail.com>
Co-authored-by: i0 <41456914+i0bs@users.noreply.github.com>
Co-authored-by: DeltaX <33706469+DeltaXWizard@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 9, 2022
1 parent 838330d commit 2c902a2
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 17 deletions.
10 changes: 8 additions & 2 deletions interactions/api/gateway/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def __init__(
intents: Intents,
session_id: Optional[str] = MISSING,
sequence: Optional[int] = MISSING,
shards: Optional[List[Tuple[int]]] = MISSING,
presence: Optional[ClientPresence] = MISSING,
) -> None:
"""
:param token: The token of the application for connecting to the Gateway.
Expand All @@ -132,6 +134,10 @@ def __init__(
:type session_id?: Optional[str]
:param sequence?: The identifier sequence if trying to reconnect. Defaults to ``None``.
:type sequence?: Optional[int]
:param shards?: The list of shards for the application's initial connection, if provided. Defaults to ``None``.
:type shards?: Optional[List[Tuple[int]]]
:param presence?: The presence shown on an application once first connected. Defaults to ``None``.
:type presence?: Optional[ClientPresence]
"""
try:
self._loop = get_event_loop() if version_info < (3, 10) else get_running_loop()
Expand Down Expand Up @@ -161,8 +167,8 @@ def __init__(
}

self._intents: Intents = intents
self.__shard: Optional[List[Tuple[int]]] = None
self.__presence: Optional[ClientPresence] = None
self.__shard: Optional[List[Tuple[int]]] = None if shards is MISSING else shards
self.__presence: Optional[ClientPresence] = None if presence is MISSING else presence

self._task: Optional[Task] = None
self.__heartbeat_event = Event(loop=self._loop) if version_info < (3, 10) else Event()
Expand Down
67 changes: 60 additions & 7 deletions interactions/api/http/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..error import LibraryException
from ..models.channel import Channel
from ..models.message import Message
from ..models.misc import Snowflake
from .request import _Request
from .route import Route

Expand Down Expand Up @@ -312,8 +313,10 @@ async def create_tag(
self,
channel_id: int,
name: str,
moderated: bool = False,
emoji_id: Optional[int] = None,
emoji_name: Optional[str] = None,
reason: Optional[str] = None,
) -> dict:
"""
Create a new tag.
Expand All @@ -324,25 +327,41 @@ async def create_tag(
:param channel_id: Channel ID snowflake.
:param name: The name of the tag
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
:param emoji_id: The ID of the emoji to use for the tag
:param emoji_name: The name of the emoji to use for the tag
:param reason: The reason for the creating the tag, if any.
:return: A Forum tag.
"""

_dct = {"name": name}
# This *assumes* cache is up-to-date.

_channel = self.cache[Channel].get(Snowflake(channel_id))
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form

_dct = {"name": name, "moderated": moderated}
if emoji_id:
_dct["emoji_id"] = emoji_id
if emoji_name:
_dct["emoji_name"] = emoji_name

return await self._req.request(Route("POST", f"/channels/{channel_id}/tags"), json=_dct)
_tags.append(_dct)

updated_channel = await self.modify_channel(
channel_id, {"available_tags": _tags}, reason=reason
)
_channel_obj = Channel(**updated_channel, _client=self)
return _channel_obj.available_tags[-1]._json

async def edit_tag(
self,
channel_id: int,
tag_id: int,
name: str,
moderated: Optional[bool] = None,
emoji_id: Optional[int] = None,
emoji_name: Optional[str] = None,
reason: Optional[str] = None,
) -> dict:
"""
Update a tag.
Expand All @@ -351,28 +370,62 @@ async def edit_tag(
Can either have an emoji_id or an emoji_name, but not both.
emoji_id is meant for custom emojis, emoji_name is meant for unicode emojis.
The object returns *will* have a different tag ID.
:param channel_id: Channel ID snowflake.
:param tag_id: The ID of the tag to update.
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
:param name: The new name of the tag
:param emoji_id: The ID of the emoji to use for the tag
:param emoji_name: The name of the emoji to use for the tag
:param reason: The reason for deleting the tag, if any.
:return The updated tag object.
"""

_dct = {"name": name}
# This *assumes* cache is up-to-date.

_channel = self.cache[Channel].get(Snowflake(channel_id))
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form

_old_tag = [tag for tag in _tags if tag["id"] == tag_id][0]

_tags.remove(_old_tag)

_dct = {"name": name, "tag_id": tag_id}
if moderated:
_dct["moderated"] = moderated
if emoji_id:
_dct["emoji_id"] = emoji_id
if emoji_name:
_dct["emoji_name"] = emoji_name

return await self._req.request(
Route("PUT", f"/channels/{channel_id}/tags/{tag_id}"), json=_dct
_tags.append(_dct)

updated_channel = await self.modify_channel(
channel_id, {"available_tags": _tags}, reason=reason
)
_channel_obj = Channel(**updated_channel, _client=self)

self.cache[Channel].merge(_channel_obj)

return [tag for tag in _channel_obj.available_tags if tag.name == name][0]

async def delete_tag(self, channel_id: int, tag_id: int) -> None: # wha?
async def delete_tag(self, channel_id: int, tag_id: int, reason: Optional[str] = None) -> None:
"""
Delete a forum tag.
:param channel_id: Channel ID snowflake.
:param tag_id: The ID of the tag to delete
:param reason: The reason for deleting the tag, if any.
"""
return await self._req.request(Route("DELETE", f"/channels/{channel_id}/tags/{tag_id}"))
_channel = self.cache[Channel].get(Snowflake(channel_id))
_tags = [_._json for _ in _channel.available_tags]

_old_tag = [tag for tag in _tags if tag["id"] == Snowflake(tag_id)][0]

_tags.remove(_old_tag)

request = await self.modify_channel(channel_id, {"available_tags": _tags}, reason=reason)

self.cache[Channel].merge(Channel(**request, _client=self))
4 changes: 1 addition & 3 deletions interactions/api/http/invite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ async def get_invite(
"""
Gets a Discord invite using its code.
.. note:: with_expiration is currently broken, the API will always return expiration_date.
:param invite_code: A string representing the invite code.
:param with_counts: Whether approximate_member_count and approximate_presence_count are returned.
:param with_expiration: Whether the invite's expiration is returned.
:param with_expiration: Whether the invite's expiration date is returned.
:param guild_scheduled_event_id: A guild scheduled event's ID.
"""
params_set = {
Expand Down
4 changes: 2 additions & 2 deletions interactions/api/http/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def create_thread(
reason: Optional[str] = None,
) -> dict:
"""
From a given channel, create a Thread with an optional message to start with..
From a given channel, create a Thread with an optional message to start with.
:param channel_id: The ID of the channel to create this thread in
:param name: The name of the thread
Expand Down Expand Up @@ -212,7 +212,7 @@ async def create_thread_in_forum(
:param name: The name of the thread
:param auto_archive_duration: duration in minutes to automatically archive the thread after recent activity,
can be set to: 60, 1440, 4320, 10080
:param message_payload: The payload/dictionary contents of the first message in the forum thread.
:param message: The payload/dictionary contents of the first message in the forum thread.
:param applied_tags: List of tag ids that can be applied to the forum, if any.
:param files: An optional list of files to send attached to the message.
:param rate_limit_per_user: Seconds a user has to wait before sending another message (0 to 21600), if given.
Expand Down
2 changes: 2 additions & 0 deletions interactions/api/models/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def __init__(
):
super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check)

self.__stop: bool = False

from .message import Message

if reverse and start_at is MISSING:
Expand Down
63 changes: 63 additions & 0 deletions interactions/api/models/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def __init__(
start_at: Optional[Union[int, str, Snowflake, Member]] = MISSING,
check: Optional[Callable[[Member], bool]] = None,
):

self.__stop: bool = False

super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check)

self.after = self.start_at
Expand Down Expand Up @@ -2830,6 +2833,66 @@ async def get_full_audit_logs(

return AuditLogs(**_audit_log_dict)

async def get_invite(
self,
invite_code: str,
with_counts: Optional[bool] = MISSING,
with_expiration: Optional[bool] = MISSING,
guild_scheduled_event_id: Optional[int] = MISSING,
) -> "Invite":
"""
Gets the invite using its code.
:param str invite_code: A string representing the invite code.
:param Optional[bool] with_counts: Whether approximate_member_count and approximate_presence_count are returned.
:param Optional[bool] with_expiration: Whether the invite's expiration date is returned.
:param Optional[int] guild_scheduled_event_id: A guild scheduled event's ID.
:return: An invite
:rtype: Invite
"""
if not self._client:
raise LibraryException(code=13)

_with_counts = with_counts if with_counts is not MISSING else None
_with_expiration = with_expiration if with_expiration is not MISSING else None
_guild_scheduled_event_id = (
guild_scheduled_event_id if guild_scheduled_event_id is not MISSING else None
)

res = await self._client.get_invite(
invite_code=invite_code,
with_counts=_with_counts,
with_expiration=_with_expiration,
guild_scheduled_event_id=_guild_scheduled_event_id,
)

return Invite(**res, _client=self._client)

async def delete_invite(self, invite_code: str, reason: Optional[str] = None) -> None:
"""
Deletes the invite using its code.
:param str invite_code: A string representing the invite code.
:param Optional[str] reason: The reason of the deletion
"""
if not self._client:
raise LibraryException(code=13)

await self._client.delete_invite(invite_code=invite_code, reason=reason)

async def get_invites(self) -> List["Invite"]:
"""
Gets invites of the guild.
:return: A list of guild invites
:rtype: List[Invite]
"""
if not self._client:
raise LibraryException(code=13)

res = await self._client.get_guild_invites(guild_id=int(self.id))
return [Invite(**_, _client=self._client) for _ in res]

@property
def icon_url(self) -> Optional[str]:
"""
Expand Down
3 changes: 3 additions & 0 deletions interactions/api/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,9 @@ def __attrs_post_init__(self):
if self.guild_id:
self.member._extras["guild_id"] = self.guild_id

if self.author and self.member:
self.member.user = self.author

async def get_channel(self) -> Channel:
"""
Gets the channel where the message was sent.
Expand Down
1 change: 1 addition & 0 deletions interactions/api/models/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class AutoModTriggerType(IntEnum):
HARMFUL_LINK = 2
SPAM = 3
KEYWORD_PRESET = 4
MENTION_SPAM = 5


class AutoModKeywordPresetTypes(IntEnum):
Expand Down
2 changes: 1 addition & 1 deletion interactions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"__authors__",
)

__version__ = "4.3.2"
__version__ = "4.3.3"

__authors__ = {
"current": [
Expand Down
4 changes: 3 additions & 1 deletion interactions/client/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ def __init__(
self._loop: AbstractEventLoop = get_event_loop()
self._http: HTTPClient = token
self._intents: Intents = kwargs.get("intents", Intents.DEFAULT)
self._websocket: WSClient = WSClient(token=token, intents=self._intents)
self._shards: List[Tuple[int]] = kwargs.get("shards", [])
self._commands: List[Command] = []
self._default_scope = kwargs.get("default_scope")
self._presence = kwargs.get("presence")
self._websocket: WSClient = WSClient(
token=token, intents=self._intents, shards=self._shards, presence=self._presence
)
self._token = token
self._extensions = {}
self._scopes = set([])
Expand Down
1 change: 0 additions & 1 deletion interactions/utils/abc/base_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(
if not hasattr(start_at, "id")
else int(start_at.id)
)
self.__stop: bool = False
self.objects: Optional[List[_O]] = None


Expand Down

0 comments on commit 2c902a2

Please sign in to comment.