Skip to content

Commit

Permalink
[Feature] Voice reconnection and resuming (#2873)
Browse files Browse the repository at this point in the history
* Voice receive fix (use system-selected port)

* Update SocketGuild.cs

* Reconnect voice after moved, resume voice connection, don't invoke Disconnected event when is going to reconnect

* no more collection primitives

* Disconnected event rallback & dispose audio client after finished

* Update src/Discord.Net.WebSocket/Audio/AudioClient.cs

* Update src/Discord.Net.WebSocket/Audio/AudioClient.cs

---------
  • Loading branch information
Lepterion committed Mar 14, 2024
1 parent d68e06e commit 09680c5
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 15 deletions.
14 changes: 14 additions & 0 deletions src/Discord.Net.WebSocket/API/Voice/ResumeParams.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Newtonsoft.Json;

namespace Discord.API.Voice
{
public class ResumeParams
{
[JsonProperty("server_id")]
public ulong ServerId { get; set; }
[JsonProperty("session_id")]
public string SessionId { get; set; }
[JsonProperty("token")]
public string Token { get; set; }
}
}
165 changes: 154 additions & 11 deletions src/Discord.Net.WebSocket/Audio/AudioClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Discord.API.Voice;
using Discord.Audio.Streams;
using Discord.Logging;
using Discord.Net;
using Discord.Net.Converters;
using Discord.WebSocket;
using Newtonsoft.Json;
Expand All @@ -9,18 +10,23 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Audio
{
//TODO: Add audio reconnecting
internal partial class AudioClient : IAudioClient
{
private static readonly int ConnectionTimeoutMs = 30000; // 30 seconds
private static readonly int KeepAliveIntervalMs = 5000; // 5 seconds

private static readonly int[] BlacklistedResumeCodes = new int[]
{
4001, 4002, 4003, 4004, 4005, 4006, 4009, 4012, 1014, 4016
};

private struct StreamPair
{
public AudioInStream Reader;
Expand Down Expand Up @@ -49,13 +55,16 @@ public StreamPair(AudioInStream reader, AudioOutStream writer)
private ulong _userId;
private uint _ssrc;
private bool _isSpeaking;
private StopReason _stopReason;
private bool _resuming;

public SocketGuild Guild { get; }
public DiscordVoiceAPIClient ApiClient { get; private set; }
public int Latency { get; private set; }
public int UdpLatency { get; private set; }
public ulong ChannelId { get; internal set; }
internal byte[] SecretKey { get; private set; }
internal bool IsFinished { get; private set; }

private DiscordSocketClient Discord => Guild.Discord;
public ConnectionState ConnectionState => _connection.State;
Expand All @@ -78,7 +87,7 @@ internal AudioClient(SocketGuild guild, int clientId, ulong channelId)
_connection = new ConnectionManager(_stateLock, _audioLogger, ConnectionTimeoutMs,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_connection.Connected += () => _connectedEvent.InvokeAsync();
_connection.Disconnected += (ex, recon) => _disconnectedEvent.InvokeAsync(ex);
_connection.Disconnected += (exception, _) => _disconnectedEvent.InvokeAsync(exception);
_heartbeatTimes = new ConcurrentQueue<long>();
_keepaliveTimes = new ConcurrentQueue<KeyValuePair<ulong, int>>();
_ssrcMap = new ConcurrentDictionary<uint, ulong>();
Expand Down Expand Up @@ -110,15 +119,30 @@ internal Task StartAsync(string url, ulong userId, string sessionId, string toke
}

public Task StopAsync()
=> _connection.StopAsync();
=> StopAsync(StopReason.Normal);

internal Task StopAsync(StopReason stopReason)
{
_stopReason = stopReason;
return _connection.StopAsync();
}

private async Task OnConnectingAsync()
{
await _audioLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await _audioLogger.DebugAsync($"Connecting ApiClient. Voice server: wss://{_url}").ConfigureAwait(false);
await ApiClient.ConnectAsync($"wss://{_url}?v={DiscordConfig.VoiceAPIVersion}").ConfigureAwait(false);
await _audioLogger.DebugAsync($"Listening on port {ApiClient.UdpPort}").ConfigureAwait(false);
await _audioLogger.DebugAsync("Sending Identity").ConfigureAwait(false);
await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);

if (!_resuming)
{
await _audioLogger.DebugAsync("Sending Identity").ConfigureAwait(false);
await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);
}
else
{
await _audioLogger.DebugAsync("Sending Resume").ConfigureAwait(false);
await ApiClient.SendResume(_token, _sessionId).ConfigureAwait(false);
}

//Wait for READY
await _connection.WaitAsync().ConfigureAwait(false);
Expand All @@ -128,6 +152,63 @@ private async Task OnDisconnectingAsync(Exception ex)
await _audioLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync().ConfigureAwait(false);

if (_stopReason == StopReason.Unknown && ex.InnerException is WebSocketException exception)
{
await _audioLogger.WarningAsync(
$"Audio connection terminated with unknown reason. Code: {exception.ErrorCode} - {exception.Message}",
exception);

if (_resuming)
{
await _audioLogger.WarningAsync("Resume failed");

_resuming = false;

await FinishDisconnect(ex, true);
return;
}

if (BlacklistedResumeCodes.Contains(exception.ErrorCode))
{
await FinishDisconnect(ex, true);
return;
}

await ClearHeartBeaters();

_resuming = true;
return;
}

await FinishDisconnect(ex, _stopReason != StopReason.Moved);

if (_stopReason == StopReason.Normal)
{
await _audioLogger.DebugAsync("Sending Voice State").ConfigureAwait(false);
await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);
}

_stopReason = StopReason.Unknown;
}

private async Task FinishDisconnect(Exception ex, bool wontTryReconnect)
{
await _audioLogger.DebugAsync("Finishing audio connection").ConfigureAwait(false);

await ClearHeartBeaters().ConfigureAwait(false);

if (wontTryReconnect)
{
await _connection.StopAsync().ConfigureAwait(false);

await ClearInputStreamsAsync().ConfigureAwait(false);

IsFinished = true;
}
}

private async Task ClearHeartBeaters()
{
//Wait for tasks to complete
await _audioLogger.DebugAsync("Waiting for heartbeater").ConfigureAwait(false);

Expand All @@ -143,12 +224,11 @@ private async Task OnDisconnectingAsync(Exception ex)
{ }
_lastMessageTime = 0;

await ClearInputStreamsAsync().ConfigureAwait(false);

await _audioLogger.DebugAsync("Sending Voice State").ConfigureAwait(false);
await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);
while (_keepaliveTimes.TryDequeue(out _))
{ }
}

#region Streams
public AudioOutStream CreateOpusStream(int bufferMillis)
{
var outputStream = new OutputStream(ApiClient); //Ignores header
Expand Down Expand Up @@ -217,6 +297,7 @@ internal async Task ClearInputStreamsAsync()
_ssrcMap.Clear();
_streams.Clear();
}
#endregion

private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
{
Expand Down Expand Up @@ -285,7 +366,7 @@ private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
await _audioLogger.DebugAsync("Received Speaking").ConfigureAwait(false);

var data = (payload as JToken).ToObject<SpeakingEvent>(_serializer);
_ssrcMap[data.Ssrc] = data.UserId; //TODO: Memory Leak: SSRCs are never cleaned up
_ssrcMap[data.Ssrc] = data.UserId;

await _speakingUpdatedEvent.InvokeAsync(data.UserId, data.Speaking);
}
Expand All @@ -299,6 +380,17 @@ private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
await _clientDisconnectedEvent.InvokeAsync(data.UserId);
}
break;
case VoiceOpCode.Resumed:
{
await _audioLogger.DebugAsync($"Voice connection resumed: wss://{_url}");
_resuming = false;

_heartbeatTask = RunHeartbeatAsync(_heartbeatInterval, _connection.CancelToken);
_keepaliveTask = RunKeepaliveAsync(_connection.CancelToken);

_ = _connection.CompleteAsync();
}
break;
default:
await _audioLogger.WarningAsync($"Unknown OpCode ({opCode})").ConfigureAwait(false);
break;
Expand Down Expand Up @@ -485,6 +577,49 @@ public async Task SetSpeakingAsync(bool value)
}
}

/// <summary>
/// Waits until all post-disconnect actions are done.
/// </summary>
/// <param name="timeout">Maximum time to wait.</param>
/// <returns>
/// A <see cref="Task"/> that represents an asynchronous process of waiting.
/// </returns>
internal async Task WaitForDisconnectAsync(TimeSpan timeout)
{
if (ConnectionState == ConnectionState.Disconnected)
return;

var completion = new TaskCompletionSource<Exception>();

var cts = new CancellationTokenSource();

var _ = Task.Delay(timeout, cts.Token).ContinueWith(_ =>
{
completion.TrySetException(new TimeoutException("Exceeded maximum time to wait"));
cts.Dispose();
}, cts.Token);

_connection.Disconnected += HandleDisconnectSubscription;

await completion.Task.ConfigureAwait(false);

Task HandleDisconnectSubscription(Exception exception, bool reconnect)
{
try
{
cts.Cancel();
completion.TrySetResult(exception);
}
finally
{
_connection.Disconnected -= HandleDisconnectSubscription;
cts.Dispose();
}

return Task.CompletedTask;
}
}

internal void Dispose(bool disposing)
{
if (disposing)
Expand All @@ -496,5 +631,13 @@ internal void Dispose(bool disposing)
}
/// <inheritdoc />
public void Dispose() => Dispose(true);

internal enum StopReason
{
Unknown = 0,
Normal,
Disconnected,
Moved
}
}
}
2 changes: 1 addition & 1 deletion src/Discord.Net.WebSocket/ConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ private async Task DisconnectAsync(Exception ex, bool isReconnecting)

await _onDisconnecting(ex).ConfigureAwait(false);

await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
State = ConnectionState.Disconnected;
await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
await _logger.InfoAsync("Disconnected").ConfigureAwait(false);
}

Expand Down
10 changes: 10 additions & 0 deletions src/Discord.Net.WebSocket/DiscordVoiceApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ public Task SendSetSpeaking(bool value)
});
}

public Task SendResume(string token, string sessionId)
{
return SendAsync(VoiceOpCode.Resume, new ResumeParams
{
ServerId = GuildId,
SessionId = sessionId,
Token = token
});
}

public async Task ConnectAsync(string url)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
Expand Down
27 changes: 24 additions & 3 deletions src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,7 @@ public Task<IReadOnlyCollection<RestGuildEvent>> GetEventsAsync(RequestOptions o
/// <returns>
/// A task that represents the asynchronous get operation. The task result contains a read-only collection
/// of the requested audit log entries.
/// </returns>
/// </returns>
public IAsyncEnumerable<IReadOnlyCollection<RestAuditLogEntry>> GetAuditLogsAsync(int limit, RequestOptions options = null, ulong? beforeId = null, ulong? userId = null, ActionType? actionType = null, ulong? afterId = null)
=> GuildHelper.GetAuditLogsAsync(this, Discord, beforeId, limit, options, userId: userId, actionType: actionType, afterId: afterId);

Expand Down Expand Up @@ -1687,7 +1687,7 @@ internal async Task<SocketVoiceState> AddOrUpdateVoiceStateAsync(ClientState sta
if (after.VoiceChannel != null && _audioClient.ChannelId != after.VoiceChannel?.Id)
{
_audioClient.ChannelId = after.VoiceChannel.Id;
await RepopulateAudioStreamsAsync().ConfigureAwait(false);
await _audioClient.StopAsync(Audio.AudioClient.StopReason.Moved);
}
}
else
Expand All @@ -1711,7 +1711,13 @@ internal async Task<SocketVoiceState> AddOrUpdateVoiceStateAsync(ClientState sta
if (_voiceStates.TryRemove(id, out SocketVoiceState voiceState))
{
if (_audioClient != null)
{
await _audioClient.RemoveInputStreamAsync(id).ConfigureAwait(false); //User changed channels, end their stream

if (id == CurrentUser.Id)
await _audioClient.StopAsync(Audio.AudioClient.StopReason.Disconnected);
}

return voiceState;
}
return null;
Expand Down Expand Up @@ -1755,7 +1761,7 @@ internal async Task<IAudioClient> ConnectAudioAsync(ulong channelId, bool selfDe
var audioClient = new AudioClient(this, Discord.GetAudioId(), channelId);
audioClient.Disconnected += async ex =>
{
if (!promise.Task.IsCompleted)
if (promise.Task.IsCompleted && audioClient.IsFinished)
{
try
{ audioClient.Dispose(); }
Expand Down Expand Up @@ -1866,6 +1872,21 @@ internal async Task FinishConnectAudio(string url, string token)
if (_audioClient != null)
{
await RepopulateAudioStreamsAsync().ConfigureAwait(false);

if (_audioClient.ConnectionState != ConnectionState.Disconnected)
{
try
{
await _audioClient.WaitForDisconnectAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false);
}
catch (TimeoutException)
{
await Discord.LogManager.WarningAsync("Failed to wait for disconnect audio client in time", null).ConfigureAwait(false);
}
}

await Task.Delay(TimeSpan.FromMilliseconds(5)).ConfigureAwait(false);

await _audioClient.StartAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
}
}
Expand Down

0 comments on commit 09680c5

Please sign in to comment.