Skip to content
This repository has been archived by the owner on Jan 1, 2023. It is now read-only.

Commit

Permalink
Utilised Tanjun's new features to reduce boilerplate (2) (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
fdnt7 authored Sep 26, 2022
1 parent 44138f4 commit 16c75c9
Show file tree
Hide file tree
Showing 21 changed files with 442 additions and 443 deletions.
2 changes: 1 addition & 1 deletion lyra/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
alluka==0.1.1
attrs
hikari==2.0.0.dev110
hikari-tanjun==2.6.2a1
hikari-tanjun==2.7.0a1
lyricsgenius==3.0.1
PyYAML
requests
Expand Down
12 changes: 6 additions & 6 deletions lyra/src/lib/cmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
SlashCommandGroupType,
SlashCommandType,
)
from ..utils.types import (
Contextish,
from ..utils import (
ContextishType,
)
from ..extras import RecurserSig, recurse
from ..extras.types import Option


def get_implied_prefix(ctx_cmd: Contextish | GenericAnyCommandType, /) -> str:
def get_implied_prefix(ctx_cmd: ContextishType | GenericAnyCommandType, /) -> str:
if isinstance(ctx_cmd, tj.abc.MessageContext):
return next(iter(ctx_cmd.client.prefixes))
if isinstance(
Expand Down Expand Up @@ -107,7 +107,7 @@ def get_full_cmd_repr(


def get_full_cmd_repr(
_ctx_: Option[Contextish],
_ctx_: Option[ContextishType],
/,
cmd: Option[GenericAnyCommandType] = None,
*,
Expand All @@ -132,7 +132,7 @@ def get_full_cmd_repr(

@t.overload
def get_full_cmd_repr_from_identifier(
identifier: CommandIdentifier, /, ctx_: Contextish, *, pretty: bool = True
identifier: CommandIdentifier, /, ctx_: ContextishType, *, pretty: bool = True
) -> str:
...

Expand All @@ -147,7 +147,7 @@ def get_full_cmd_repr_from_identifier(
def get_full_cmd_repr_from_identifier(
identifier: CommandIdentifier,
/,
_ctx_c: Contextish | tj.abc.Client,
_ctx_c: ContextishType | tj.abc.Client,
*,
pretty: bool = True,
):
Expand Down
26 changes: 14 additions & 12 deletions lyra/src/lib/cmd/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from ..utils import (
BindSig,
Contextish,
ContextishType,
get_client,
fetch_permissions,
start_confirmation_prompt,
Expand All @@ -46,7 +46,9 @@
from ..lava.utils import get_queue


async def others_not_in_vc_check(ctx_: Contextish, lvc: lv.Lavalink, /) -> Result[bool]:
async def others_not_in_vc_check(
ctx_: ContextishType, lvc: lv.Lavalink, /
) -> Result[bool]:
assert ctx_.guild_id

conn = t.cast(
Expand All @@ -72,7 +74,7 @@ def with_cb_check(
P = t.ParamSpec('P')

async def _check_in_vc(
ctx_: Contextish, conn: ConnectionInfo, /, *, perms: hkperms = DJ_PERMS
ctx_: ContextishType, conn: ConnectionInfo, /, *, perms: hkperms = DJ_PERMS
):
member = ctx_.member
assert ctx_.guild_id
Expand All @@ -94,7 +96,7 @@ async def _check_in_vc(
if not (auth_perms & (perms | hkperms.ADMINISTRATOR)) and not author_in_voice:
raise AlreadyConnected(channel)

async def check_in_vc(ctx_: Contextish, lvc: lv.Lavalink):
async def check_in_vc(ctx_: ContextishType, lvc: lv.Lavalink):
assert ctx_.guild_id

conn = t.cast(
Expand All @@ -103,7 +105,7 @@ async def check_in_vc(ctx_: Contextish, lvc: lv.Lavalink):
assert conn is not None
await _check_in_vc(ctx_, conn)

async def check_np_yours(ctx_: Contextish, lvc: lv.Lavalink):
async def check_np_yours(ctx_: ContextishType, lvc: lv.Lavalink):
assert ctx_.member

auth_perms = await fetch_permissions(ctx_)
Expand All @@ -114,7 +116,7 @@ async def check_np_yours(ctx_: Contextish, lvc: lv.Lavalink):
):
raise PlaybackChangeRefused(q.current)

async def check_can_seek_any(ctx_: Contextish, lvc: lv.Lavalink):
async def check_can_seek_any(ctx_: ContextishType, lvc: lv.Lavalink):
assert ctx_.member

auth_perms = await fetch_permissions(ctx_)
Expand All @@ -124,26 +126,26 @@ async def check_can_seek_any(ctx_: Contextish, lvc: lv.Lavalink):
if ctx_.member.id != np.requester:
raise PlaybackChangeRefused(np)

async def check_stop(ctx_: Contextish, lvc: lv.Lavalink):
async def check_stop(ctx_: ContextishType, lvc: lv.Lavalink):
if (await get_queue(ctx_, lvc)).is_stopped:
raise TrackStopped

async def check_conn(ctx_: Contextish, lvc: lv.Lavalink):
async def check_conn(ctx_: ContextishType, lvc: lv.Lavalink):
assert ctx_.guild_id

conn = lvc.get_guild_gateway_connection_info(ctx_.guild_id)
if not conn:
raise NotConnected

async def check_queue(ctx_: Contextish, lvc: lv.Lavalink):
async def check_queue(ctx_: ContextishType, lvc: lv.Lavalink):
if not await get_queue(ctx_, lvc):
raise QueueEmpty

async def check_playing(ctx_: Contextish, lvc: lv.Lavalink):
async def check_playing(ctx_: ContextishType, lvc: lv.Lavalink):
if not (await get_queue(ctx_, lvc)).current:
raise NotPlaying

async def check_pause(ctx_: Contextish, lvc: lv.Lavalink):
async def check_pause(ctx_: ContextishType, lvc: lv.Lavalink):
if (await get_queue(ctx_, lvc)).is_paused:
raise TrackPaused

Expand All @@ -153,7 +155,7 @@ def callback(
@ft.wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> None:

ctx_ = next((a for a in args if isinstance(a, Contextish)), NULL)
ctx_ = next((a for a in args if isinstance(a, ContextishType)), NULL)

assert ctx_, "Missing a Contextish object"
assert ctx_.guild_id
Expand Down
4 changes: 2 additions & 2 deletions lyra/src/lib/cmd/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DJ_PERMS,
TIMEOUT,
BindSig,
Contextish,
ContextishType,
ConnectionInfo,
delete_after,
fetch_permissions,
Expand Down Expand Up @@ -81,7 +81,7 @@ class Binds(AutoDocsFlag):
VOTE = """Binds a voting prompt to be used when needed"""


async def speaker_check(ctx_: Contextish, /) -> Result[bool]:
async def speaker_check(ctx_: ContextishType, /) -> Result[bool]:
assert ctx_.guild_id

client = get_client(ctx_)
Expand Down
4 changes: 2 additions & 2 deletions lyra/src/lib/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DJ_PERMS,
RESTRICTOR,
ConnectionInfo,
Contextish,
ContextishType,
err_say,
fetch_permissions,
get_client,
Expand Down Expand Up @@ -249,7 +249,7 @@ async def join_impl_precaught(


async def others_not_in_vc_check_impl(
ctx_: Contextish, conn: ConnectionInfo, /, *, perms: hkperms = DJ_PERMS
ctx_: ContextishType, conn: ConnectionInfo, /, *, perms: hkperms = DJ_PERMS
) -> Result[bool]:
auth_perms = await fetch_permissions(ctx_)
member = ctx_.member
Expand Down
8 changes: 4 additions & 4 deletions lyra/src/lib/errors/expects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from ..cmd import get_full_cmd_repr_from_identifier
from ..cmd.ids import CommandIdentifier
from ..utils import Contextish, dj_perms_fmt, err_say, get_rest, say
from ..utils import ContextishType, dj_perms_fmt, err_say, get_rest, say
from ..extras import Result, format_flags


Expand All @@ -36,7 +36,7 @@

@a.frozen
class BaseErrorExpects(abc.ABC):
context: Contextish
context: ContextishType

@abc.abstractmethod
def match_expect(self, error: Exception, /) -> Result[ExpectSig]:
Expand All @@ -54,7 +54,7 @@ async def expect(self, error: Exception, /) -> bool:

@a.frozen
class CheckErrorExpects(BaseErrorExpects):
context: Contextish
context: ContextishType

async def expect_network_error(self):
await err_say(self.context, content="⁉️ A network error has occurred")
Expand Down Expand Up @@ -192,7 +192,7 @@ async def expect(self, error: Exception, /) -> bool:

@a.frozen
class BindErrorExpects(BaseErrorExpects):
context: Contextish
context: ContextishType

async def expect_command_cancelled(self):
await err_say(self.context, follow_up=False, content="🛑 Cancelled the command")
Expand Down
2 changes: 1 addition & 1 deletion lyra/src/lib/extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Panic,
NULL,
RGBTriplet,
MaybeIterable,
IterableOr,
MapSig,
AsyncVoidAnySig,
URLstr,
Expand Down
3 changes: 2 additions & 1 deletion lyra/src/lib/extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __bool__(self) -> t.Literal[False]:
OptionResult = Option[Result[_T]]
Panic = t.Annotated[_T | t.NoReturn, ...]
Require = t.Annotated[_T_co, ...]
MaybeIterable = _T | t.Iterable[_T]
AnyOr = t.Any | _T
IterableOr = _T | t.Iterable[_T]

KeySig = t.Callable[[__E], _KE]
MapSig = t.Callable[[_T], _T]
Expand Down
26 changes: 15 additions & 11 deletions lyra/src/lib/lava/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import hikari as hk
import lavasnek_rs as lv

from ..utils import GuildOrInferable, infer_guild, limit_img_size_by_guild
from ..utils import MaybeGuildIDAware, IntCastable, infer_guild, limit_img_size_by_guild
from ..errors import NotConnected, QueueEmpty
from ..consts import STOP_REFRESH
from ..extras import (
Expand Down Expand Up @@ -340,34 +340,36 @@ async def set_data(


@ctxlib.asynccontextmanager
async def access_queue(g_inf: GuildOrInferable, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_inf), lvc)
async def access_queue(g_: IntCastable | MaybeGuildIDAware, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_), lvc)
try:
yield data.queue
finally:
await set_data(g, lvc, data)


@ctxlib.asynccontextmanager
async def access_equalizer(g_inf: GuildOrInferable, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_inf), lvc)
async def access_equalizer(g_: IntCastable | MaybeGuildIDAware, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_), lvc)
try:
yield data.equalizer
finally:
await set_data(g, lvc, data)


@ctxlib.asynccontextmanager
async def access_data(g_inf: GuildOrInferable, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_inf), lvc)
async def access_data(g_: IntCastable | MaybeGuildIDAware, lvc: lv.Lavalink, /):
data = await get_data(g := infer_guild(g_), lvc)
try:
yield data
finally:
await set_data(g, lvc, data)


async def get_queue(g_inf: GuildOrInferable, lvc: lv.Lavalink, /) -> Panic[QueueList]:
return (await get_data(infer_guild(g_inf), lvc)).queue
async def get_queue(
g_: IntCastable | MaybeGuildIDAware, lvc: lv.Lavalink, /
) -> Panic[QueueList]:
return (await get_data(infer_guild(g_), lvc)).queue


def get_repeat_emoji(q: QueueList, /):
Expand Down Expand Up @@ -416,9 +418,11 @@ async def generate_nowplaying_embed(
return embed


async def wait_until_current_track_valid(g_inf: GuildOrInferable, lvc: lv.Lavalink, /):
async def wait_until_current_track_valid(
g_: IntCastable | MaybeGuildIDAware, lvc: lv.Lavalink, /
):
while True:
d = await get_data(infer_guild(g_inf), lvc)
d = await get_data(infer_guild(g_), lvc)
if d.queue.current and d.out_channel_id:
return
await asyncio.sleep(STOP_REFRESH)
18 changes: 13 additions & 5 deletions lyra/src/lib/musicutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import alluka as al
import lavasnek_rs as lv

from .extras.types import MaybeIterable

from .utils import (
Q_CHUNK,
TIMEOUT,
Expand All @@ -17,7 +15,17 @@
err_say,
say,
)
from .extras import Result, Option, MapSig, chunk, chunk_b, map_in_place, to_stamp, wr
from .extras import (
Result,
Option,
MapSig,
IterableOr,
chunk,
chunk_b,
map_in_place,
to_stamp,
wr,
)
from .errors import (
NotConnected,
VotingTimeout,
Expand All @@ -41,8 +49,8 @@ def __init_component__(
*,
guild_check: bool = True,
music_hook: bool = True,
other_checks: MaybeIterable[tj.abc.CheckSig] = (),
other_hooks: MaybeIterable[tj.abc.Hooks[tj.abc.Context]] = (),
other_checks: IterableOr[tj.abc.CheckSig] = (),
other_hooks: IterableOr[tj.abc.Hooks[tj.abc.Context]] = (),
):
comp = tj.Component(name=dunder_name.split('.')[-1].capitalize(), strict=True)
if guild_check:
Expand Down
Loading

0 comments on commit 16c75c9

Please sign in to comment.