Skip to content

Commit

Permalink
feat: add node selection/balancing (#18)
Browse files Browse the repository at this point in the history
* feat: add region enums

* feat: add strategies for node balancing

* feat(node): add weight for stats comparisons

* fix(player): properly handle ping not existing in 3.5

* feat: add logging for node selection

* fix(plugin): add __all__

* feat(pool): add node selection

* test: add testing for node selection
  • Loading branch information
ooliver1 committed Nov 9, 2022
1 parent f228bed commit f5fc457
Show file tree
Hide file tree
Showing 12 changed files with 527 additions and 26 deletions.
2 changes: 2 additions & 0 deletions mafic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from .node import *
from .player import *
from .pool import *
from .region import *
from .search_type import *
from .strategy import *
from .track import *

del __libraries
Expand Down
6 changes: 6 additions & 0 deletions mafic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"MaficException",
"MultipleCompatibleLibraries",
"NoCompatibleLibraries",
"NoNodesAvailable",
"PlayerNotConnected",
"TrackLoadException",
)
Expand Down Expand Up @@ -55,3 +56,8 @@ def from_data(cls, data: FriendlyException) -> Self:
class PlayerNotConnected(MaficException):
def __init__(self) -> None:
super().__init__("The player is not connected to a voice channel.")


class NoNodesAvailable(MaficException):
def __init__(self) -> None:
super().__init__("No nodes are available to handle this player.")
103 changes: 102 additions & 1 deletion mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
)
from .playlist import Playlist
from .plugin import Plugin
from .region import VOICE_TO_REGION, Group, Region, VoiceRegion
from .stats import NodeStats
from .track import Track

if TYPE_CHECKING:
from asyncio import Task
from typing import Any
from typing import Any, Sequence

from aiohttp import ClientWebSocketResponse

Expand All @@ -53,6 +54,33 @@
_log = getLogger(__name__)
URL_REGEX = re.compile(r"https?://")

__all__ = ("Node",)


def _wrap_regions(
regions: Sequence[Group | Region | VoiceRegion] | None,
) -> list[Region] | None:
if not regions:
return None

actual_regions: list[Region] = []

for item in regions:
if isinstance(item, Group):
actual_regions.extend(item.value)
elif isinstance(item, Region):
actual_regions.append(item)
elif isinstance(
item, VoiceRegion
): # pyright: ignore[reportUnnecessaryIsInstance]
actual_regions.append(VOICE_TO_REGION[item.value])
else:
raise TypeError(
f"Expected Group, Region, or VoiceRegion, got {type(item)!r}."
)

return actual_regions


class Node:
__slots__ = (
Expand All @@ -73,6 +101,8 @@ class Node:
"_ws_uri",
"_ws_task",
"players",
"regions",
"shard_ids",
)

def __init__(
Expand All @@ -88,6 +118,8 @@ def __init__(
timeout: float = 10,
session: ClientSession | None = None,
resume_key: str | None = None,
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
) -> None:
self._host = host
self._port = port
Expand All @@ -98,6 +130,8 @@ def __init__(
self._timeout = timeout
self._client = client
self.__session = session
self.shard_ids: Sequence[int] | None = shard_ids
self.regions: list[Region] | None = _wrap_regions(regions)

self._rest_uri = f"http{'s' if secure else ''}://{host}:{port}"
self._ws_uri = f"ws{'s' if secure else ''}://{host}:{port}"
Expand Down Expand Up @@ -136,6 +170,72 @@ def secure(self) -> bool:
def stats(self) -> NodeStats | None:
return self._stats

@property
def available(self) -> bool:
return self._available

@property
def weight(self) -> float:
if self._stats is None:
# Stats haven't been set yet, so we'll just return a high value.
# This is so we can properly balance known nodes.
# If stats sending is turned off
# - that's on the user
# - they likely have done it on all if they have multiple, so it is equal
return 6.63e34

stats = self._stats

players = stats.playing_player_count

# These are exponential equations.

# Load is *basically* a percentage (I know it isn't but it is close enough).

# | cores | load | weight |
# |-------|------|--------|
# | 1 | 0.1 | 16 |
# | 1 | 0.5 | 114 |
# | 1 | 0.75 | 388 |
# | 1 | 1 | 1315 |
# | 3 | 0.1 | 12 |
# | 3 | 1 | 51 |
# | 3 | 2 | 259 |
# | 3 | 3 | 1315 |
cpu = 1.05 ** (100 * (stats.cpu.system_load / stats.cpu.cores)) * 10 - 10

# | null frames | weight |
# | ----------- | ------ |
# | 10 | 30 |
# | 20 | 62 |
# | 100 | 382 |
# | 250 | 1456 |

frame_stats = stats.frame_stats
if frame_stats is None:
null = 0
deficit = 0
else:
null = 1.03 ** (frame_stats.nulled / 6) * 600 - 600
deficit = 1.03 ** (frame_stats.deficit / 6) * 600 - 600

# High memory usage isnt bad, but we generally don't want to overload it.
# Especially due to the chance of regular GC pauses.

# | memory usage | weight |
# | ------------ | ------ |
# | 96% | 0 |
# | 97% | 9 |
# | 98% | 99 |
# | 99% | 999 |
# | 99.5% | 3161 |
# | 100% | 9999 |

mem_stats = stats.memory
mem = max(10 ** (100 * (mem_stats.used / mem_stats.reservable) - 96), 1) - 1

return players + cpu + null + deficit + mem

async def connect(self) -> None:
_log.info("Waiting for client to be ready...", extra={"label": self._label})
await self._client.wait_until_ready()
Expand Down Expand Up @@ -256,6 +356,7 @@ async def __send(self, data: OutgoingMessage) -> None:

async def _handle_msg(self, data: IncomingMessage) -> None:
_log.debug("Received event with op %s", data["op"])
_log.debug("Event data: %s", data)

if data["op"] == "playerUpdate":
guild_id = int(data["guildId"])
Expand Down
10 changes: 8 additions & 2 deletions mafic/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


_log = getLogger(__name__)
__all__ = ("Player",)


class Player(VoiceProtocol):
Expand All @@ -56,7 +57,7 @@ def __init__(
self._connected: bool = False
self._position: int = 0
self._last_update: int = 0
self._ping = 0
self._ping = -1
self._current: Track | None = None
self._filters: OrderedDict[str, Filter] = OrderedDict()

Expand Down Expand Up @@ -96,7 +97,7 @@ def update_state(self, state: PlayerUpdateState) -> None:
self._last_update = state["time"]
self._position = state.get("position", 0)
self._connected = state["connected"]
self._ping = state["ping"]
self._ping = state.get("ping", -1)

# If people are so in love with the VoiceClient interface
def is_connected(self) -> bool:
Expand Down Expand Up @@ -146,6 +147,11 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
self._node = NodePool.get_node(
guild_id=data["guild_id"], endpoint=data["endpoint"]
)
_log.debug(
"Got best node for player: %s",
self._node.label,
extra={"guild": self._guild_id},
)

self._node.players[self._guild_id] = self

Expand Down
3 changes: 3 additions & 0 deletions mafic/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .typings import PluginData


__all__ = ("Plugin",)


@dataclass(repr=True)
class Plugin:
name: str
Expand Down
105 changes: 95 additions & 10 deletions mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,28 @@
from typing import TYPE_CHECKING, Generic, TypeVar

from .__libraries import Client
from .errors import NoNodesAvailable
from .node import Node
from .strategy import STRATEGIES, Strategy

if TYPE_CHECKING:
from typing import ClassVar
from typing import ClassVar, Sequence, Union

from aiohttp import ClientSession

from .region import Group, Region, VoiceRegion
from .strategy import StrategyCallable

StrategyList = Union[
Sequence[Strategy],
StrategyCallable,
Sequence[StrategyCallable],
Sequence[Union[Strategy, StrategyCallable]],
]


ClientT = TypeVar("ClientT", bound=Client)
__all__ = ("NodePool",)


_log = getLogger(__name__)
Expand All @@ -24,52 +37,124 @@
class NodePool(Generic[ClientT]):
__slots__ = ()
_nodes: ClassVar[dict[str, Node]] = {}
_node_regions: ClassVar[dict[Region, set[Node]]] = {}
_node_shards: ClassVar[dict[int, set[Node]]] = {}
_client: ClientT | None = None
_default_strategies: StrategyList = [
Strategy.SHARD,
Strategy.LOCATION,
Strategy.USAGE,
]

def __init__(
self,
client: ClientT,
default_strategies: StrategyList | None = None,
) -> None:
NodePool._client = client

if default_strategies is not None:
NodePool._default_strategies = default_strategies

@property
def nodes(self) -> dict[str, Node]:
return self._nodes

@classmethod
async def create_node(
cls,
self,
*,
host: str,
port: int,
label: str,
password: str,
client: ClientT,
secure: bool = False,
heartbeat: int = 30,
timeout: float = 10,
session: ClientSession | None = None,
resume_key: str | None = None,
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
) -> Node:
assert self._client is not None, "NodePool has not been initialized."

node = Node(
host=host,
port=port,
label=label,
password=password,
client=client,
client=self._client,
secure=secure,
heartbeat=heartbeat,
timeout=timeout,
session=session,
resume_key=resume_key,
regions=regions,
shard_ids=shard_ids,
)

# TODO: assign dicts for regions and such
cls._nodes[label] = node
self._nodes[label] = node

# Add to dictionaries, creating a set or extending it if needed.
if node.regions:
for region in node.regions:
self._node_regions[region] = {
node,
*self._node_regions.get(region, set()),
}

if node.shard_ids:
for shard_id in node.shard_ids:
self._node_shards[shard_id] = {
node,
*self._node_shards.get(shard_id, set()),
}

_log.info("Created node, connecting it...", extra={"label": label})
await node.connect()

return node

@classmethod
def get_node(cls, *, guild_id: str | int, endpoint: str | None) -> Node:
# TODO: use guild id, endpoint and other stuff like usage to determine node
def get_node(
cls,
*,
guild_id: str | int,
endpoint: str | None,
strategies: StrategyList | None = None,
) -> Node:
assert cls._client is not None, "NodePool has not been initialized."

return choice(list(cls._nodes.values()))
actual_strategies: Sequence[StrategyCallable | Strategy]

strategies = strategies or cls._default_strategies

if callable(strategies):
actual_strategies = [strategies]
else:
actual_strategies = strategies

nodes = cls._nodes.values()

for strategy in actual_strategies:
if isinstance(strategy, Strategy):
strategy = STRATEGIES[strategy]

nodes = strategy(
list(nodes), int(guild_id), cls._client.shard_count, endpoint
)

_log.debug(
"Strategy %s returned nodes %s.",
strategy.__name__,
", ".join(n.label for n in nodes),
)

if len(nodes) == 1:
return nodes[0]
elif len(nodes) == 0:
raise NoNodesAvailable

return choice(list(nodes))

@classmethod
def get_random_node(cls) -> Node:
Expand Down
Loading

0 comments on commit f5fc457

Please sign in to comment.