Skip to content

Commit

Permalink
fix(player): Player.connected being False sometimes (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 committed May 9, 2023
1 parent 59a16ae commit a7a9b8d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
32 changes: 27 additions & 5 deletions mafic/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

import asyncio
from asyncio import Event
from collections import OrderedDict
from functools import reduce
from logging import getLogger
Expand Down Expand Up @@ -102,6 +104,9 @@ def __init__(
self._last_track: Track | None = None
self._paused: bool = False

self._voice_server_update_event: Event = Event()
self._voice_state_update_event: Event = Event()

def set_state(self, state: PlayerPayload) -> None:
"""Set the state of the player.
Expand Down Expand Up @@ -334,7 +339,8 @@ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
if channel_id is None: # pyright: ignore[reportUnnecessaryComparison]
# This can happen and is on disconnect.
# Not sure why this is typed as always Snowflake.
return self.cleanup()
await self.disconnect(force=True)
return

channel = self.guild.get_channel(int(channel_id))
if not isinstance(channel, (VoiceChannel, StageChannel)):
Expand All @@ -343,8 +349,10 @@ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:

self.channel = channel

if self._session_id != before_session_id: # noqa: RET503
await self._dispatch_player_update() # noqa: RET503
if self._session_id != before_session_id:
await self._dispatch_player_update()

self._voice_state_update_event.set()

async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
"""Handle a voice server update.
Expand Down Expand Up @@ -380,10 +388,12 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:

await self._dispatch_player_update()

self._voice_server_update_event.set()

async def connect(
self,
*,
timeout: float, # noqa: ARG002
timeout: float,
reconnect: bool, # noqa: ARG002
self_mute: bool = False,
self_deaf: bool = False,
Expand Down Expand Up @@ -415,7 +425,19 @@ async def connect(
await self.channel.guild.change_voice_state(
channel=self.channel, self_mute=self_mute, self_deaf=self_deaf
)
self._connected = True
futures = [
self._voice_state_update_event.wait(),
self._voice_server_update_event.wait(),
]

ensured = [asyncio.ensure_future(fut) for fut in futures]
_, pending = await asyncio.wait(
ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED
)

if len(pending) != 0:
await self.disconnect(force=True)
raise asyncio.TimeoutError

async def disconnect(self, *, force: bool = False) -> None:
"""Disconnect from the voice channel.
Expand Down
10 changes: 9 additions & 1 deletion test_bot/bot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import traceback
from asyncio import gather
from logging import DEBUG, getLogger
Expand Down Expand Up @@ -136,8 +137,15 @@ async def join(inter: Interaction):
if not inter.user.voice:
return await inter.response.send_message("You are not in a voice channel.")

await inter.response.defer()

channel = inter.user.voice.channel
await channel.connect(cls=MyPlayer)

try:
await channel.connect(cls=MyPlayer, timeout=5)
except asyncio.TimeoutError:
return await inter.send("Timed out connecting to voice channel.")

await inter.send(f"Joined {channel.mention}.")


Expand Down

0 comments on commit a7a9b8d

Please sign in to comment.