Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SignalR performance: track groups per connection, remove on disconnect #53486

Merged
merged 10 commits into from
Feb 7, 2024
36 changes: 35 additions & 1 deletion src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.Http.Features;
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved

namespace Microsoft.AspNetCore.SignalR;

Expand Down Expand Up @@ -42,6 +43,16 @@ public override Task AddToGroupAsync(string connectionId, string groupName, Canc
return Task.CompletedTask;
}

//track groups in the connection object
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
var groupNames = connection.Features.GetRequiredFeature<GroupTrackerFeature>().Groups;
lock (groupNames)
{
if (!groupNames.Add(groupName))
{
return Task.CompletedTask; // Connection already in group
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
}
}

_groups.Add(connection, groupName);
// Connection disconnected while adding to group, remove it in case the Add was called after OnDisconnectedAsync removed items from the group
if (connection.ConnectionAborted.IsCancellationRequested)
Expand All @@ -64,6 +75,16 @@ public override Task RemoveFromGroupAsync(string connectionId, string groupName,
return Task.CompletedTask;
}

//remove from previouslyy saved groups
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
var groupNames = connection.Features.GetRequiredFeature<GroupTrackerFeature>().Groups;
lock (groupNames)
{
if (!groupNames.Remove(groupName))
{
return Task.CompletedTask; // Connection not in group
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
}
}

alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
_groups.Remove(connectionId, groupName);

return Task.CompletedTask;
Expand Down Expand Up @@ -271,14 +292,22 @@ public override Task SendUserAsync(string userId, string methodName, object?[] a
public override Task OnConnectedAsync(HubConnectionContext connection)
{
_connections.Add(connection);
connection.Features.Set(new GroupTrackerFeature()); //add a group tracker to every concection
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
return Task.CompletedTask;
}

/// <inheritdoc />
public override Task OnDisconnectedAsync(HubConnectionContext connection)
{
//now remove from tracked groups one by one
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
//this is faster than calling _groups.RemoveDisconnectedConnection
//because that method iteratas through ALL the groups
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
foreach (var grpName in connection.Features.GetRequiredFeature<GroupTrackerFeature>().Groups.ToArray()) //copy to array because groups can be modified in other methods, prevent "collection was modified"
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
{
_groups.Remove(connection.ConnectionId, grpName);
}

_connections.Remove(connection);
_groups.RemoveDisconnectedConnection(connection.ConnectionId);

return Task.CompletedTask;
}
Expand Down Expand Up @@ -351,6 +380,11 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
}
}

private class GroupTrackerFeature
alex-jitbit marked this conversation as resolved.
Show resolved Hide resolved
{
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}

/// <inheritdoc/>
public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
Expand Down
10 changes: 0 additions & 10 deletions src/SignalR/server/Core/src/Internal/HubGroupList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Collections.Concurrent;
using System.Linq;

namespace Microsoft.AspNetCore.SignalR.Internal;

Expand Down Expand Up @@ -43,15 +42,6 @@ public void Remove(string connectionId, string groupName)
}
}

public void RemoveDisconnectedConnection(string connectionId)
{
var groupNames = _groups.Where(x => x.Value.ContainsKey(connectionId)).Select(x => x.Key);
foreach (var groupName in groupNames)
{
Remove(connectionId, groupName);
}
}

public int Count => _groups.Count;

public IEnumerator<ConcurrentDictionary<string, HubConnectionContext>> GetEnumerator()
Expand Down