Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add node selection/balancing #18

Merged
merged 8 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mafic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from .node import *
from .player import *
from .pool import *
from .region import *
from .search_type import *
from .strategy import *
from .track import *

del __libraries
Expand Down
6 changes: 6 additions & 0 deletions mafic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"MaficException",
"MultipleCompatibleLibraries",
"NoCompatibleLibraries",
"NoNodesAvailable",
"PlayerNotConnected",
"TrackLoadException",
)
Expand Down Expand Up @@ -55,3 +56,8 @@ def from_data(cls, data: FriendlyException) -> Self:
class PlayerNotConnected(MaficException):
def __init__(self) -> None:
super().__init__("The player is not connected to a voice channel.")


class NoNodesAvailable(MaficException):
def __init__(self) -> None:
super().__init__("No nodes are available to handle this player.")
103 changes: 102 additions & 1 deletion mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
)
from .playlist import Playlist
from .plugin import Plugin
from .region import VOICE_TO_REGION, Group, Region, VoiceRegion
from .stats import NodeStats
from .track import Track

if TYPE_CHECKING:
from asyncio import Task
from typing import Any
from typing import Any, Sequence

from aiohttp import ClientWebSocketResponse

Expand All @@ -53,6 +54,33 @@
_log = getLogger(__name__)
URL_REGEX = re.compile(r"https?://")

__all__ = ("Node",)


def _wrap_regions(
regions: Sequence[Group | Region | VoiceRegion] | None,
) -> list[Region] | None:
if not regions:
return None

actual_regions: list[Region] = []

for item in regions:
if isinstance(item, Group):
actual_regions.extend(item.value)
elif isinstance(item, Region):
actual_regions.append(item)
elif isinstance(
item, VoiceRegion
): # pyright: ignore[reportUnnecessaryIsInstance]
actual_regions.append(VOICE_TO_REGION[item.value])
else:
raise TypeError(
f"Expected Group, Region, or VoiceRegion, got {type(item)!r}."
)

return actual_regions


class Node:
__slots__ = (
Expand All @@ -73,6 +101,8 @@ class Node:
"_ws_uri",
"_ws_task",
"players",
"regions",
"shard_ids",
)

def __init__(
Expand All @@ -88,6 +118,8 @@ def __init__(
timeout: float = 10,
session: ClientSession | None = None,
resume_key: str | None = None,
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
) -> None:
self._host = host
self._port = port
Expand All @@ -98,6 +130,8 @@ def __init__(
self._timeout = timeout
self._client = client
self.__session = session
self.shard_ids: Sequence[int] | None = shard_ids
self.regions: list[Region] | None = _wrap_regions(regions)

self._rest_uri = f"http{'s' if secure else ''}://{host}:{port}"
self._ws_uri = f"ws{'s' if secure else ''}://{host}:{port}"
Expand Down Expand Up @@ -136,6 +170,72 @@ def secure(self) -> bool:
def stats(self) -> NodeStats | None:
return self._stats

@property
def available(self) -> bool:
return self._available

@property
def weight(self) -> float:
if self._stats is None:
# Stats haven't been set yet, so we'll just return a high value.
# This is so we can properly balance known nodes.
# If stats sending is turned off
# - that's on the user
# - they likely have done it on all if they have multiple, so it is equal
return 6.63e34

stats = self._stats

players = stats.playing_player_count

# These are exponential equations.

# Load is *basically* a percentage (I know it isn't but it is close enough).

# | cores | load | weight |
# |-------|------|--------|
# | 1 | 0.1 | 16 |
# | 1 | 0.5 | 114 |
# | 1 | 0.75 | 388 |
# | 1 | 1 | 1315 |
# | 3 | 0.1 | 12 |
# | 3 | 1 | 51 |
# | 3 | 2 | 259 |
# | 3 | 3 | 1315 |
cpu = 1.05 ** (100 * (stats.cpu.system_load / stats.cpu.cores)) * 10 - 10

# | null frames | weight |
# | ----------- | ------ |
# | 10 | 30 |
# | 20 | 62 |
# | 100 | 382 |
# | 250 | 1456 |

frame_stats = stats.frame_stats
if frame_stats is None:
null = 0
deficit = 0
else:
null = 1.03 ** (frame_stats.nulled / 6) * 600 - 600
deficit = 1.03 ** (frame_stats.deficit / 6) * 600 - 600

# High memory usage isnt bad, but we generally don't want to overload it.
# Especially due to the chance of regular GC pauses.

# | memory usage | weight |
# | ------------ | ------ |
# | 96% | 0 |
# | 97% | 9 |
# | 98% | 99 |
# | 99% | 999 |
# | 99.5% | 3161 |
# | 100% | 9999 |

mem_stats = stats.memory
mem = max(10 ** (100 * (mem_stats.used / mem_stats.reservable) - 96), 1) - 1

return players + cpu + null + deficit + mem

async def connect(self) -> None:
_log.info("Waiting for client to be ready...", extra={"label": self._label})
await self._client.wait_until_ready()
Expand Down Expand Up @@ -256,6 +356,7 @@ async def __send(self, data: OutgoingMessage) -> None:

async def _handle_msg(self, data: IncomingMessage) -> None:
_log.debug("Received event with op %s", data["op"])
_log.debug("Event data: %s", data)

if data["op"] == "playerUpdate":
guild_id = int(data["guildId"])
Expand Down
10 changes: 8 additions & 2 deletions mafic/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


_log = getLogger(__name__)
__all__ = ("Player",)


class Player(VoiceProtocol):
Expand All @@ -56,7 +57,7 @@ def __init__(
self._connected: bool = False
self._position: int = 0
self._last_update: int = 0
self._ping = 0
self._ping = -1
self._current: Track | None = None
self._filters: OrderedDict[str, Filter] = OrderedDict()

Expand Down Expand Up @@ -96,7 +97,7 @@ def update_state(self, state: PlayerUpdateState) -> None:
self._last_update = state["time"]
self._position = state.get("position", 0)
self._connected = state["connected"]
self._ping = state["ping"]
self._ping = state.get("ping", -1)

# If people are so in love with the VoiceClient interface
def is_connected(self) -> bool:
Expand Down Expand Up @@ -146,6 +147,11 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
self._node = NodePool.get_node(
guild_id=data["guild_id"], endpoint=data["endpoint"]
)
_log.debug(
"Got best node for player: %s",
self._node.label,
extra={"guild": self._guild_id},
)

self._node.players[self._guild_id] = self

Expand Down
3 changes: 3 additions & 0 deletions mafic/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .typings import PluginData


__all__ = ("Plugin",)


@dataclass(repr=True)
class Plugin:
name: str
Expand Down
105 changes: 95 additions & 10 deletions mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,28 @@
from typing import TYPE_CHECKING, Generic, TypeVar

from .__libraries import Client
from .errors import NoNodesAvailable
from .node import Node
from .strategy import STRATEGIES, Strategy

if TYPE_CHECKING:
from typing import ClassVar
from typing import ClassVar, Sequence, Union

from aiohttp import ClientSession

from .region import Group, Region, VoiceRegion
from .strategy import StrategyCallable

StrategyList = Union[
Sequence[Strategy],
StrategyCallable,
Sequence[StrategyCallable],
Sequence[Union[Strategy, StrategyCallable]],
]


ClientT = TypeVar("ClientT", bound=Client)
__all__ = ("NodePool",)


_log = getLogger(__name__)
Expand All @@ -24,52 +37,124 @@
class NodePool(Generic[ClientT]):
__slots__ = ()
_nodes: ClassVar[dict[str, Node]] = {}
_node_regions: ClassVar[dict[Region, set[Node]]] = {}
_node_shards: ClassVar[dict[int, set[Node]]] = {}
_client: ClientT | None = None
_default_strategies: StrategyList = [
Strategy.SHARD,
Strategy.LOCATION,
Strategy.USAGE,
]

def __init__(
self,
client: ClientT,
default_strategies: StrategyList | None = None,
) -> None:
NodePool._client = client

if default_strategies is not None:
NodePool._default_strategies = default_strategies

@property
def nodes(self) -> dict[str, Node]:
return self._nodes

@classmethod
async def create_node(
cls,
self,
*,
host: str,
port: int,
label: str,
password: str,
client: ClientT,
secure: bool = False,
heartbeat: int = 30,
timeout: float = 10,
session: ClientSession | None = None,
resume_key: str | None = None,
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
) -> Node:
assert self._client is not None, "NodePool has not been initialized."

node = Node(
host=host,
port=port,
label=label,
password=password,
client=client,
client=self._client,
secure=secure,
heartbeat=heartbeat,
timeout=timeout,
session=session,
resume_key=resume_key,
regions=regions,
shard_ids=shard_ids,
)

# TODO: assign dicts for regions and such
cls._nodes[label] = node
self._nodes[label] = node

# Add to dictionaries, creating a set or extending it if needed.
if node.regions:
for region in node.regions:
self._node_regions[region] = {
node,
*self._node_regions.get(region, set()),
}

if node.shard_ids:
for shard_id in node.shard_ids:
self._node_shards[shard_id] = {
node,
*self._node_shards.get(shard_id, set()),
}

_log.info("Created node, connecting it...", extra={"label": label})
await node.connect()

return node

@classmethod
def get_node(cls, *, guild_id: str | int, endpoint: str | None) -> Node:
# TODO: use guild id, endpoint and other stuff like usage to determine node
def get_node(
cls,
*,
guild_id: str | int,
endpoint: str | None,
strategies: StrategyList | None = None,
) -> Node:
assert cls._client is not None, "NodePool has not been initialized."

return choice(list(cls._nodes.values()))
actual_strategies: Sequence[StrategyCallable | Strategy]

strategies = strategies or cls._default_strategies

if callable(strategies):
actual_strategies = [strategies]
else:
actual_strategies = strategies

nodes = cls._nodes.values()

for strategy in actual_strategies:
if isinstance(strategy, Strategy):
strategy = STRATEGIES[strategy]

nodes = strategy(
list(nodes), int(guild_id), cls._client.shard_count, endpoint
)

_log.debug(
"Strategy %s returned nodes %s.",
strategy.__name__,
", ".join(n.label for n in nodes),
)

if len(nodes) == 1:
return nodes[0]
elif len(nodes) == 0:
raise NoNodesAvailable

return choice(list(nodes))

@classmethod
def get_random_node(cls) -> Node:
Expand Down
Loading