Skip to content

Commit

Permalink
feat: transfer players to new nodes (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 committed Jun 1, 2023
1 parent 6baa78b commit bd833ee
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 14 deletions.
61 changes: 61 additions & 0 deletions mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,29 @@ def get_player(self, guild_id: int) -> Player[ClientT] | None:
"""
return self._players.get(guild_id)

async def fetch_player(self, guild_id: int) -> PlayerPayload:
"""Fetch player data from the node.
.. note::
This is an API call. Usually you should use :meth:`get_player` instead.
.. versionadded:: 2.6
Parameters
----------
guild_id:
The guild ID to fetch the player for.
Returns
-------
:class:`dict`
The player data for the guild.
"""
return await self.__request(
"GET", f"sessions/{self._session_id}/players/{guild_id}"
)

def add_player(self, guild_id: int, player: Player[ClientT]) -> None:
"""Add a player to the node.
Expand Down Expand Up @@ -654,6 +677,44 @@ def remove_task(_: Task[None]) -> None:
self._available = True
self._client.dispatch("node_ready", self)

async def close(self) -> None:
"""Close the node.
This will disconnect the websocket and close the session.
.. versionadded:: 2.6
"""
if self._ws is not None:
_log.debug("Closing websocket.", extra={"label": self._label})
await self._ws.close()
self._ws = None
_log.debug("Websocket closed.", extra={"label": self._label})

if self.__session is not None:
_log.debug("Closing session.", extra={"label": self._label})
await self.__session.close()
self.__session = None
_log.debug("Session closed.", extra={"label": self._label})

if self._ws_task is not None:
_log.debug(
"Cancelling websocket listener task.", extra={"label": self._label}
)
self._ws_task.cancel()
self._ws_task = None
_log.debug(
"Websocket listener task cancelled.", extra={"label": self._label}
)

if self._connect_task is not None:
_log.debug("Cancelling connection task.", extra={"label": self._label})
self._connect_task.cancel()
self._connect_task = None
_log.debug("Connection task cancelled.", extra={"label": self._label})

_log.info("Node %s is now closed.", self._label, extra={"label": self._label})
self._available = False

async def _ws_listener(self) -> None:
"""Listen for messages from the websocket."""
backoff = ExponentialBackoff()
Expand Down
50 changes: 50 additions & 0 deletions mafic/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
raise TypeError(msg)

self.guild: Guild = self.channel.guild
self.endpoint: str | None = None

self._node = node

Expand Down Expand Up @@ -384,6 +385,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
self._node.add_player(self._guild_id, self)

self._guild_id = int(data["guild_id"])
self._endpoint = data["endpoint"]
self._server_state = data

await self._dispatch_player_update()
Expand Down Expand Up @@ -478,6 +480,54 @@ def cleanup(self) -> None:

return super().cleanup()

async def transfer_to(self, node: Node[ClientT]) -> None:
"""Transfer the player to a new node.
Parameters
----------
node:
The node to transfer to.
"""
if self._node is None:
raise PlayerNotConnected

if self._node == node:
return

state = await self._node.fetch_player(self.guild.id)

# Remove from the current node, but no need to destroy.
self._node.remove_player(self.guild.id)

old_node = self._node
self._node = node
self._node.add_player(self.guild.id, self)

if self._session_id is None or self._server_state is None:
msg = "Cannot transfer player with session data."
raise RuntimeError(msg)

# We need to update the voice server as the endpoint may have changed.
await self._node.voice_update(
guild_id=self._guild_id,
session_id=self._session_id,
data=self._server_state,
)

# Needed so .update does not fail.
self._connected = True
# Update player with all other state.
# Position, filters, track, etc.
await self.update(
track=self._current,
position=self.position,
volume=state["volume"],
pause=self._paused,
filter=reduce(or_, self._filters.values()) if self._filters else Filter(),
)

await old_node.destroy(guild_id=self.guild.id)

async def destroy(self) -> None:
"""Destroy the player.
Expand Down
78 changes: 77 additions & 1 deletion mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

from __future__ import annotations

import asyncio
from collections.abc import Sequence
from functools import partial
from logging import getLogger
from random import choice
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar, Union, cast

from .errors import NoNodesAvailable
from .errors import NoNodesAvailable, PlayerNotConnected
from .node import Node
from .strategy import Strategy, StrategyCallable, call_strategy
from .type_variables import ClientT
Expand Down Expand Up @@ -207,6 +208,81 @@ async def create_node(
self._nodes[label] = node
return node

async def remove_node(
self, node: Node[ClientT] | str, *, transfer_players: bool = True
) -> None:
"""Remove a node from the pool.
.. versionadded:: 2.6
Parameters
----------
node:
The node to remove.
transfer_players:
Whether to transfer players to other nodes or destroy them.
"""
if isinstance(node, str):
node = self._nodes[node]

if node.regions:
for region in node.regions:
self._node_regions[region].remove(node)

if node.shard_ids:
for shard_id in node.shard_ids:
self._node_shards[shard_id].remove(node)

# Remove prematurely so it is not chosen.
del self._nodes[node.label]

if transfer_players:
if TYPE_CHECKING:
from .player import Player

async def transfer_player(player: Player[ClientT]) -> None:
try:
target = self.get_node(
guild_id=player.guild.id,
endpoint=player.endpoint, # pyright: ignore[reportPrivateUsage]
)
await player.transfer_to(target)
except (RuntimeError, NoNodesAvailable, PlayerNotConnected):
_log.error(
"Failed to transfer player %d, destroying it...",
player.guild.id,
exc_info=True,
extra={"label": node.label},
)
await player.destroy()

tasks = [transfer_player(player) for player in node.players]
await asyncio.gather(*tasks)
else:
if TYPE_CHECKING:
from .player import Player

async def destroy_player(player: Player[ClientT]) -> None:
_log.debug(
"Destroying player %d due to node removal...",
player.guild.id,
extra={"label": node.label},
)
await player.destroy()

tasks = [destroy_player(player) for player in node.players]
await asyncio.gather(*tasks)

await node.close()

async def close(self) -> None:
"""Close all nodes in the pool.
.. versionadded:: 2.6
"""
for node in self._nodes.values():
await node.close()

@classmethod
def get_node(
cls,
Expand Down
1 change: 1 addition & 0 deletions test_bot/.dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
!pyproject.toml
!poetry.lock
!lavalink/
!gateway-proxy/
36 changes: 26 additions & 10 deletions test_bot/bot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def before_identify_hook(
# gateway-proxy
return

async def add_nodes(self) -> None:
async def add_nodes(self) -> None: # noqa: PLR0912
with open(environ["LAVALINK_FILE"], "rb") as f:
data: list[LavalinkInfo] = orjson.loads(f.read())

Expand All @@ -100,15 +100,25 @@ async def add_nodes(self) -> None:

regions.append(region)

await self.pool.create_node(
host=node["host"],
port=node["port"],
password=node["password"],
regions=regions,
label=node["label"],
shard_ids=node.get("shard_ids"),
resuming_session_id=session_id,
)
if environ["LAVALINK_FILE"] == "lavalink/multi-nodes.json":
await asyncio.sleep(10)

for tries in range(5):
try:
await self.pool.create_node(
host=node["host"],
port=node["port"],
password=node["password"],
regions=regions,
label=node["label"],
shard_ids=node.get("shard_ids"),
resuming_session_id=session_id,
)
except: # noqa: E722
traceback.print_exc()
await asyncio.sleep(tries * 2)
else:
break

async def start(self, token: str, *, reconnect: bool = True) -> None:
await gather(self.add_nodes(), super().start(token, reconnect=reconnect))
Expand Down Expand Up @@ -245,6 +255,12 @@ async def close(inter: Interaction):
await bot.close()


@bot.slash_command()
async def transfer(inter: Interaction):
await bot.pool.remove_node(inter.guild.voice_client.node)
await inter.send("Transferred node.")


@bot.slash_command()
async def boost(inter: Interaction):
if not inter.guild.voice_client:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ version: "3"

services:
bot:
restart: on-failure:5
depends_on:
- lavalink
build: .
Expand All @@ -10,21 +11,33 @@ services:
- ../mafic:/bot/mafic
environment:
TOKEN: $TOKEN
GW_PROXY: ws://127.0.0.1:7878
GW_PROXY: ws://localhost:7878
LAVALINK_FILE: lavalink/multi-nodes.json
network_mode: host
lavalink:
build: lavalink
image: ghcr.io/freyacodes/lavalink:v4
volumes:
- ./logs/lava:/opt/Lavalink/logs
deploy:
replicas: 8
ports:
- "6962-6969:6969"
environment:
JDK_JAVA_OPTIONS: -Xmx2G
SERVER_PORT: 6969
SERVER_ADDRESS: 0.0.0.0
LAVALINK_SERVER_PASSWORD: haha
LOGGING_FILE_PATH: ./logs/
LOGGING_LEVEL_ROOT: INFO
LOGGING_LEVEL_LAVALINK: INFO
LOGGING_REQUEST_ENABLED: true
LOGGING_LOGBACK_ROLLINGPOLICY_MAXFILESIZE: 100MB
LOGGING_LOGBACK_ROLLINGPOLICY_MAXHISTORY: 30
gateway-proxy:
image: gelbpunkt/gateway-proxy:x86-64
volumes:
- ./gateway-proxy/noshard.config.json:/config.json
environment:
TOKEN: $TOKEN
network_mode: host
ports:
- "7878:7878"

0 comments on commit bd833ee

Please sign in to comment.