Skip to content

Commit

Permalink
feat: support getting tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 committed Sep 21, 2022
1 parent 8a12d6e commit d096e4e
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 31 deletions.
5 changes: 4 additions & 1 deletion mafic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from .errors import *
from .node import *
from .player import *
from .pool import *
from .search_type import *
from .track import *

# TODO: filters
# TODO: tracks
# TODO: playlists

del __libraries

Expand Down
29 changes: 18 additions & 11 deletions mafic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,39 @@

from __future__ import annotations

__all__ = ("maficException", "LibraryCompatibilityError", "NoCompatibleLibraries")
__all__ = (
"LibraryCompatibilityError",
"MaficException",
"MultipleCompatibleLibraries",
"NoCompatibleLibraries",
"TrackLoadException",
)


class maficException(Exception):
"""The base exception for mafic errors."""
class MaficException(Exception):
...


class LibraryCompatibilityError(maficException):
"""An error raised when no compatible libraries are found."""
class LibraryCompatibilityError(MaficException):
...


class NoCompatibleLibraries(LibraryCompatibilityError):
"""An error raised when no compatible libraries are found."""

def __init__(self):
def __init__(self) -> None:
super().__init__(
"No compatible libraries were found. Please install one of the following: "
"nextcord, disnake, py-cord, discord.py or discord."
)


class MultipleCompatibleLibraries(LibraryCompatibilityError):
"""An error raised when multiple compatible libraries are found."""

def __init__(self, libraries: list[str]):
def __init__(self, libraries: list[str]) -> None:
super().__init__(
f"Multiple compatible libraries were found: {', '.join(libraries)}. "
"Please remove all but one of the libraries."
)


class TrackLoadException(MaficException):
def __init__(self, *, message: str, severity: str) -> None:
super().__init__(f"The track could not be loaded: {message} ({severity} error)")
74 changes: 72 additions & 2 deletions mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

from __future__ import annotations

import re
from asyncio import create_task, sleep
from logging import getLogger
from typing import TYPE_CHECKING, cast

from aiohttp import ClientSession, WSMsgType

from .__libraries import ExponentialBackoff, dumps, loads
from .errors import TrackLoadException
from .track import Track

if TYPE_CHECKING:
from asyncio import Task
from typing import Any

from aiohttp import ClientWebSocketResponse

Expand All @@ -20,12 +24,14 @@
from .typings import (
Coro,
EventPayload,
GetTracks,
IncomingMessage,
OutgoingMessage,
PlayPayload,
)

_log = getLogger(__name__)
URL_REGEX = re.compile(r"https?://")


class Node:
Expand Down Expand Up @@ -90,7 +96,7 @@ async def connect(self) -> None:
assert self._client.user is not None

if self.__session is None:
self.__session = ClientSession()
self.__session = await self._create_session()

session = self.__session

Expand All @@ -112,6 +118,8 @@ async def connect(self) -> None:
heartbeat=self._heartbeat,
headers=headers,
)
# TODO: handle exceptions from ws_connect

_log.info("Connected to lavalink.", extra={"label": self._label})
_log.debug(
"Creating task to send configuration to resume with key %s",
Expand Down Expand Up @@ -322,7 +330,69 @@ def play(
# TODO: volume
# TODO: filter
# TODO: API routes:
# TODO: fetch tracks

async def _create_session(self) -> ClientSession:
return ClientSession(json_serialize=dumps)

async def __request(
self,
method: str,
path: str,
json: Any | None = None,
params: dict[str, str] | None = None,
) -> Any:
if self.__session is None:
self.__session = await self._create_session()

session = self.__session
uri = self._rest_uri + path

async with session.request(
method,
uri,
json=json,
params=params,
headers={"Authorization": self.__password},
) as resp:
if not (200 <= resp.status < 300):
# TODO: raise proper error
raise RuntimeError(f"Got status code {resp.status} from lavalink.")

_log.debug(
"Received status %s from lavalink from path %s", resp.status, path
)

json = await resp.json(loads=loads)
_log.debug("Received raw data %s", json)
return json

async def fetch_tracks(
self, query: str, *, search_type: str
) -> list[Track] | None: # TODO: | Playlist
if not URL_REGEX.match(query):
query = f"{search_type}:{query}"

# TODO: handle errors from lavalink
# TODO: handle playlists
# TODO: return actual objects
data: GetTracks = await self.__request(
"GET", "/loadtracks", params={"identifier": query}
)

if data["loadType"] == "NO_MATCHES":
return []
elif data["loadType"] == "TRACK_LOADED":
return [Track(**data["tracks"][0])]
elif data["loadType"] == "PLAYLIST_LOADED":
# TODO: handle playlists
...
elif data["loadType"] == "SEARCH_RESULT":
return [Track(**track) for track in data["tracks"]]
elif data["loadType"] == "LOAD_FAILED":
raise TrackLoadException(**data["exception"])
else:
_log.warning("Unknown load type recieved: %s", data["loadType"])

# TODO: decode track
# TODO: plugins
# TODO: route planner status
Expand Down
32 changes: 26 additions & 6 deletions mafic/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .__libraries import GuildChannel, StageChannel, VoiceChannel, VoiceProtocol
from .pool import NodePool
from .search_type import SearchType

if TYPE_CHECKING:
from .__libraries import (
Expand Down Expand Up @@ -53,9 +54,7 @@ def connected(self) -> bool:
def is_connected(self):
return self._connected

async def _dispatch_player_update(
self, data: GuildVoiceStatePayload | VoiceServerUpdatePayload
) -> None:
async def _dispatch_player_update(self) -> None:
if self._node is None:
_log.debug("Recieved voice update before node was found.")
return
Expand Down Expand Up @@ -85,10 +84,16 @@ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.channel = channel

if self._session_id != before_session_id:
await self._dispatch_player_update(data)
await self._dispatch_player_update()

async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._node is None:
# Fetch the best node as we either don't know the best one yet.
# Or the node we were using was not the best one (endpoint optimisation).
if (
self._node is None
or self._server_state is None
or self._server_state["endpoint"] != data["endpoint"]
):
_log.debug("Getting best node for player", extra={"guild": self._guild_id})
self._node = NodePool.get_node(
guild_id=data["guild_id"], endpoint=data["endpoint"]
Expand All @@ -99,7 +104,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
self._guild_id = int(data["guild_id"])
self._server_state = data

await self._dispatch_player_update(data)
await self._dispatch_player_update()

async def connect(
self,
Expand Down Expand Up @@ -140,3 +145,18 @@ async def destroy(self) -> None:
if self._node is not None:
self._node.players.pop(self.guild.id, None)
await self._node.destroy(guild_id=self.guild.id)

async def fetch_tracks(
self, query: str, search_type: SearchType | str = SearchType.YOUTUBE
): # TODO:-> list[Track] | None:
if self._node is None:
# TODO: raise proper error
raise RuntimeError("No node found.")

raw_type: str
if isinstance(search_type, SearchType):
raw_type = search_type.value
else:
raw_type = search_type

return await self._node.fetch_tracks(query, search_type=raw_type)
4 changes: 2 additions & 2 deletions mafic/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .node import Node

if TYPE_CHECKING:
from typing import ClassVar, Optional
from typing import ClassVar

from aiohttp import ClientSession

Expand Down Expand Up @@ -63,7 +63,7 @@ async def create_node(
return node

@classmethod
def get_node(cls, *, guild_id: str | int, endpoint: Optional[str]) -> Node:
def get_node(cls, *, guild_id: str | int, endpoint: str | None) -> Node:
# TODO: use guild id, endpoint and other stuff like usage to determine node

return choice(list(cls._nodes.values()))
13 changes: 13 additions & 0 deletions mafic/search_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

from enum import Enum

__all__ = ("SearchType",)


class SearchType(Enum):
YOUTUBE = "ytsearch"
YOUTUBE_MUSIC = "ytmsearch"
SOUNDCLOUD = "scsearch"
28 changes: 28 additions & 0 deletions mafic/track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .typings import TrackInfo

__all__ = ("Track",)


class Track:
def __init__(self, track: str, info: TrackInfo) -> None:
self.id = track

self.title: str = info["title"]
self.author: str = info["author"]

self.identifier: str = info["identifier"]
self.uri: str = info["uri"]
self.source: str = info["sourceName"]

self.stream: bool = info["isStream"]
self.seekable: bool = info["isSeekable"]

self.position: int = info["position"]
self.length: int = info["length"]
1 change: 1 addition & 0 deletions mafic/typings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT

from .http import *
from .incoming import *
from .misc import *
from .outgoing import *
57 changes: 57 additions & 0 deletions mafic/typings/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict, Union

if TYPE_CHECKING:
from typing import Literal

from .misc import FriendlyException

__all__ = (
"GetTracks",
"PlaylistInfo",
"Tracks",
"TrackInfo",
"TrackWithInfo",
"TracksFailed",
)


class PlaylistInfo(TypedDict):
name: str
selectedTrack: int


class TrackInfo(TypedDict):
identifier: str
isSeekable: bool
author: str
length: int
isStream: bool
position: int
sourceName: str
title: str
uri: str


class TrackWithInfo(TypedDict):
track: str
info: TrackInfo


class Tracks(TypedDict):
loadType: Literal["TRACK_LOADED", "PLAYLIST_LOADED", "SEARCH_RESULT", "NO_MATCHES"]
playlistInfo: PlaylistInfo
tracks: list[TrackWithInfo]


class TracksFailed(TypedDict):
loadType: Literal["LOAD_FAILED"]
playlistInfo: PlaylistInfo
tracks: list[TrackWithInfo]
exception: FriendlyException


GetTracks = Union[Tracks, TracksFailed]
Loading

0 comments on commit d096e4e

Please sign in to comment.