Skip to content

Commit

Permalink
feat: allow for custom player cls when resuming
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 committed Jun 2, 2023
1 parent 325750b commit ed36171
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 13 deletions.
37 changes: 31 additions & 6 deletions mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,21 @@ async def _connect_to_websocket(
raise

async def connect(
self, *, backoff: ExponentialBackoff[Literal[False]] | None = None
self,
*,
backoff: ExponentialBackoff[Literal[False]] | None = None,
player_cls: type[Player[ClientT]] | None = None,
) -> None:
"""Connect to the node.
Parameters
----------
backoff:
The backoff to use when reconnecting.
player_cls:
The player class to use for the node when resuming.
.. versionadded:: 2.8
Raises
------
Expand Down Expand Up @@ -676,7 +683,7 @@ def remove_task(_: Task[None]) -> None:
_log.info(
"Node %s is now available.", self._label, extra={"label": self._label}
)
await self.sync_players()
await self.sync_players(player_cls=player_cls)
self._event_queue.set()
self._available = True
self._client.dispatch("node_ready", self)
Expand Down Expand Up @@ -1293,7 +1300,12 @@ async def unmark_all_addresses(self) -> None:
"""Unmark all failed addresses so they can be used again."""
await self.__request("POST", "routeplanner/free/all")

async def _add_unknown_player(self, player_id: int, state: PlayerPayload) -> None:
async def _add_unknown_player(
self,
player_id: int,
state: PlayerPayload,
cls: type[Player[ClientT]] | None = None,
) -> None:
"""Add an unknown player to the node.
Parameters
Expand All @@ -1302,6 +1314,8 @@ async def _add_unknown_player(self, player_id: int, state: PlayerPayload) -> Non
The guild ID of the player.
state:
The state of the player.
cls:
The class of the player to use.
"""
guild = self.client.get_guild(player_id)
if guild is None:
Expand All @@ -1320,7 +1334,7 @@ async def _add_unknown_player(self, player_id: int, state: PlayerPayload) -> Non
# Circular, pool -> node -> player -> pool
from .player import Player

player = Player(self.client, channel)
player = (cls or Player)(self.client, channel)

player.set_state(state)
player._node = self # pyright: ignore[reportPrivateUsage]
Expand All @@ -1340,13 +1354,22 @@ async def _remove_unknown_player(self, player_id: int) -> None:
await self._players[player_id].disconnect(force=True)
self.remove_player(player_id)

async def sync_players(self) -> None:
async def sync_players(
self, player_cls: type[Player[ClientT]] | None = None
) -> None:
"""Sync the players with the node.
.. note::
This method is called automatically when the client is ready.
You should not need to call this method yourself.
Parameters
----------
player_cls:
The class of the player to use.
.. versionadded:: 2.8
"""
players: list[PlayerPayload] = await self.__request(
"GET", f"sessions/{self._session_id}/players"
Expand All @@ -1357,7 +1380,9 @@ async def sync_players(self) -> None:

await gather(
*(
self._add_unknown_player(player_id, actual_players[player_id])
self._add_unknown_player(
player_id, actual_players[player_id], cls=player_cls
)
for player_id in actual_player_ids - expected_player_ids
),
*(
Expand Down
22 changes: 15 additions & 7 deletions mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import aiohttp

from .player import Player
from .region import Group, Region, VoiceRegion


Expand Down Expand Up @@ -110,6 +111,7 @@ async def create_node(
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
resuming_session_id: str | None = None,
player_cls: type[Player[ClientT]] | None = None,
) -> Node[ClientT]:
r"""Create a node and connect it.
Expand Down Expand Up @@ -156,6 +158,10 @@ async def create_node(
connection to us.
.. versionadded:: 2.2
player_cls:
The player class to use for this node when resuming.
.. versionadded:: 2.8
Returns
-------
Expand Down Expand Up @@ -187,10 +193,12 @@ async def create_node(
resuming_session_id=resuming_session_id,
)

await self.add_node(node)
await self.add_node(node, player_cls=player_cls)
return node

async def add_node(self, node: Node[ClientT]) -> None:
async def add_node(
self, node: Node[ClientT], *, player_cls: type[Player[ClientT]] | None = None
) -> None:
"""Add an existing node to this pool.
.. note::
Expand All @@ -205,6 +213,10 @@ async def add_node(self, node: Node[ClientT]) -> None:
----------
node:
The node to add.
player:
The player class to use for this node when resuming.
.. versionadded:: 2.8
"""
# Add to dictionaries, creating a set or extending it if needed.
if node.regions:
Expand All @@ -222,7 +234,7 @@ async def add_node(self, node: Node[ClientT]) -> None:
}

_log.info("Created node, connecting it...", extra={"label": node.label})
await node.connect()
await node.connect(player_cls=player_cls)

self._nodes[node.label] = node

Expand Down Expand Up @@ -255,8 +267,6 @@ async def remove_node(
del self._nodes[node.label]

if transfer_players:
if TYPE_CHECKING:
from .player import Player

async def transfer_player(player: Player[ClientT]) -> None:
try:
Expand All @@ -277,8 +287,6 @@ async def transfer_player(player: Player[ClientT]) -> None:
tasks = [transfer_player(player) for player in node.players]
await asyncio.gather(*tasks)
else:
if TYPE_CHECKING:
from .player import Player

async def destroy_player(player: Player[ClientT]) -> None:
_log.debug(
Expand Down

0 comments on commit ed36171

Please sign in to comment.