Skip to content
This repository has been archived by the owner on Jan 1, 2023. It is now read-only.

Commit

Permalink
Fixed the bot still being silent after session resumes (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
fdnt7 authored Oct 3, 2022
1 parent 0d8c757 commit db48b19
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 42 deletions.
36 changes: 15 additions & 21 deletions lyra/src/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .lib import (
EventHandler,
LyraConfig,
NodeRef,
NodeDataRef,
LyraDBClientType,
LyraDBCollectionType,
repeat_emojis,
Expand Down Expand Up @@ -91,44 +91,32 @@ async def prefix_getter(
@_client.with_listener()
async def on_started(
_: hk.StartedEvent,
bot: al.Injected[hk.GatewayBot],
client: al.Injected[tj.Client],
):
emojis = await client.rest.fetch_guild_emojis(lyra_config.emoji_guild)
emoji_cache.update({e.name: e for e in emojis})
repeat_emojis.extend(emoji_cache[f'repeat{n}_b'] for n in range(3))
logger.info("Fetched emojis from Lýra's Emoji Server")

mongo_client = __init_mongo_client__()

prefs_db = mongo_client.get_database('prefs')
guilds_co = prefs_db.get_collection('guilds')

node_ref = NodeRef({})

(
client.set_type_dependency(LyraDBClientType, mongo_client)
.set_type_dependency(LyraDBCollectionType, guilds_co)
.set_type_dependency(EmojiCache, emoji_cache)
.set_type_dependency(NodeRef, node_ref)
)

repeat_emojis.extend(emoji_cache[f'repeat{n}_b'] for n in range(3))


@_client.with_listener()
async def on_shard_ready(
event: hk.ShardReadyEvent,
client: al.Injected[tj.Client],
) -> None:
"""Event that triggers when the hikari gateway is ready."""
node_data_ref = NodeDataRef({})

host = (
os.environ['LAVALINK_HOST']
if os.environ.get('IN_DOCKER', False)
else '127.0.0.1'
)

bot_u = bot.get_me()
assert bot_u

builder = (
lv.LavalinkBuilder(event.my_user.id, lyra_config.token)
lv.LavalinkBuilder(bot_u.id, lyra_config.token)
.set_host(host)
.set_password(os.environ['LAVALINK_PWD'])
.set_port(int(os.environ['LAVALINK_PORT']))
Expand All @@ -137,7 +125,13 @@ async def on_shard_ready(

lvc = await builder.build(EventHandler())

client.set_type_dependency(lv.Lavalink, lvc)
(
client.set_type_dependency(LyraDBClientType, mongo_client)
.set_type_dependency(LyraDBCollectionType, guilds_co)
.set_type_dependency(EmojiCache, emoji_cache)
.set_type_dependency(NodeDataRef, node_data_ref)
.set_type_dependency(lv.Lavalink, lvc)
)


@_client.with_listener()
Expand Down
2 changes: 1 addition & 1 deletion lyra/src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .connections import cleanup
from .lava import (
EventHandler,
NodeRef,
NodeDataRef,
repeat_emojis,
access_data,
get_data,
Expand Down
26 changes: 13 additions & 13 deletions lyra/src/lib/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
say,
)
from .cmd import CommandIdentifier, get_full_cmd_repr_from_identifier
from .lava import ConnectionCommandsInvokedEvent, NodeRef, get_data, access_data
from .lava import ConnectionCommandsInvokedEvent, NodeDataRef, access_data


logger = logging.getLogger(lgfmt(__name__))
Expand Down Expand Up @@ -54,11 +54,11 @@ async def join(

bot = ctx.client.get_type_dependency(hk.GatewayBot)
cfg = ctx.client.get_type_dependency(LyraDBCollectionType)
nodes = ctx.client.get_type_dependency(NodeRef)
ndt = ctx.client.get_type_dependency(NodeDataRef)
assert (
not isinstance(bot, al.abc.Undefined)
and not isinstance(cfg, al.abc.Undefined)
and not isinstance(nodes, al.abc.Undefined)
and not isinstance(ndt, al.abc.Undefined)
)

if channel is None:
Expand Down Expand Up @@ -136,7 +136,7 @@ async def join(
await lvc.create_session(sess_conn)

async with access_data(ctx, lvc) as d:
nodes.setdefault(ctx.guild_id, d)
ndt.setdefault(ctx.guild_id, d)
d.out_channel_id = ctx.channel_id

is_stage = isinstance(ctx.cache.get_guild_channel(new_ch), hk.GuildStageChannel)
Expand Down Expand Up @@ -165,9 +165,9 @@ async def leave(ctx: tj.abc.Context, lvc: lv.Lavalink, /) -> Fallible[hk.Snowfla
assert ctx.guild_id

bot = ctx.client.get_type_dependency(hk.GatewayBot)
nodes = ctx.client.get_type_dependency(NodeRef)
ndt = ctx.client.get_type_dependency(NodeDataRef)
assert not isinstance(bot, al.abc.Undefined) and not isinstance(
nodes, al.abc.Undefined
ndt, al.abc.Undefined
)

if not (
Expand All @@ -181,7 +181,7 @@ async def leave(ctx: tj.abc.Context, lvc: lv.Lavalink, /) -> Fallible[hk.Snowfla

await others_not_in_vc_check_impl(ctx, conn)

await cleanup(ctx.guild_id, nodes, lvc, bot=bot, also_del_np_msg=False)
await cleanup(ctx.guild_id, ndt, lvc, bot=bot, also_del_np_msg=False)

bot.dispatch(ConnectionCommandsInvokedEvent(bot))
logger.info(f"In guild {ctx.guild_id} left channel {curr_channel} gracefully")
Expand All @@ -191,7 +191,7 @@ async def leave(ctx: tj.abc.Context, lvc: lv.Lavalink, /) -> Fallible[hk.Snowfla
@t.overload
async def cleanup(
guild: hk.Snowflakeish,
nodes: NodeRef,
ndt: NodeDataRef,
lvc: lv.Lavalink,
/,
bot: hk.GatewayBot = ...,
Expand All @@ -205,7 +205,7 @@ async def cleanup(
@t.overload
async def cleanup(
guild: hk.Snowflakeish,
nodes: NodeRef,
ndt: NodeDataRef,
lvc: lv.Lavalink,
/,
bot: hk.GatewayBot = ...,
Expand All @@ -219,7 +219,7 @@ async def cleanup(
@t.overload
async def cleanup(
guild: hk.Snowflakeish,
nodes: NodeRef,
ndt: NodeDataRef,
lvc: lv.Lavalink,
/,
bot: Option[hk.GatewayBot] = None,
Expand All @@ -232,7 +232,7 @@ async def cleanup(

async def cleanup(
guild: hk.Snowflakeish,
nodes: NodeRef,
ndt: NodeDataRef,
lvc: lv.Lavalink,
/,
bot: Option[hk.GatewayBot] = None,
Expand All @@ -246,11 +246,11 @@ async def cleanup(
await bot.update_voice_state(guild, None)
if also_del_np_msg:
assert bot
d = await get_data(guild, lvc)
d = ndt[guild]
if d.out_channel_id and d.nowplaying_msg:
await bot.rest.delete_messages(d.out_channel_id, d.nowplaying_msg)
await lvc.wait_for_connection_info_remove(guild)
nodes.pop(guild)
ndt.pop(guild)
await lvc.remove_guild_node(guild)
await lvc.remove_guild_from_loops(guild)

Expand Down
2 changes: 1 addition & 1 deletion lyra/src/lib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@

__developers__: t.Final = frozenset((548850193202675713, 626062879531204618))
"""Who the `debug` commands can be used"""
__version__: t.Final = '2.4.2'
__version__: t.Final = '2.4.2-hotfix.1'
2 changes: 1 addition & 1 deletion lyra/src/lib/lava/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pyright: reportUnusedImport=false
from .utils import (
NodeData,
NodeRef,
NodeDataRef,
QueueList,
Bands,
RepeatMode,
Expand Down
2 changes: 1 addition & 1 deletion lyra/src/lib/lava/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def edit_now_playing_components(
await rest.edit_message(self.out_channel_id, _np_msg, components=components)


NodeRef = t.NewType('NodeRef', dict[int, NodeData])
NodeDataRef = t.NewType('NodeDataRef', dict[int, NodeData])


class BaseEventHandler(abc.ABC):
Expand Down
8 changes: 4 additions & 4 deletions lyra/src/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
NotConnectedError,
)
from ..lib.cmd import CommandIdentifier as C, with_identifier
from ..lib.lava import ConnectionCommandsInvokedEvent, NodeRef, get_data
from ..lib.lava import ConnectionCommandsInvokedEvent, NodeDataRef, get_data
from ..lib.music import __init_component__
from ..lib.connections import logger, cleanup, join_impl_precaught, leave

Expand Down Expand Up @@ -44,7 +44,7 @@ async def on_voice_state_update(
client: al.Injected[tj.Client],
bot: al.Injected[hk.GatewayBot],
lvc: al.Injected[lv.Lavalink],
nodes: al.Injected[NodeRef],
ndt: al.Injected[NodeDataRef],
):
def conn():
return t.cast(
Expand Down Expand Up @@ -90,7 +90,7 @@ async def get_members_in_vc() -> frozenset[hk.VoiceState]:

if not conn_cmd_invoked and old and old.user_id == bot_u.id:
if not new.channel_id:
await cleanup(event.guild_id, nodes, lvc, bot=bot, also_disconn=False)
await cleanup(event.guild_id, ndt, lvc, bot=bot, also_disconn=False)
await bot.rest.create_message(
out_ch,
f"❕📎 ~~<#{(_vc := old.channel_id)}>~~ `(Bot was forcefully disconnected)`",
Expand Down Expand Up @@ -131,7 +131,7 @@ async def on_everyone_leaves_vc():
__conn = conn()
assert __conn

await cleanup(event.guild_id, nodes, lvc, bot=bot, also_del_np_msg=False)
await cleanup(event.guild_id, ndt, lvc, bot=bot, also_del_np_msg=False)

_vc: int = __conn['channel_id']
logger.info(
Expand Down

0 comments on commit db48b19

Please sign in to comment.