Skip to content

Commit

Permalink
Merge pull request #7080 from crobibero/ws-token
Browse files Browse the repository at this point in the history
  • Loading branch information
crobibero committed Jan 4, 2022
2 parents fddcaf1 + 0765fd5 commit c6a1dcf
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 43 deletions.
11 changes: 1 addition & 10 deletions Emby.Server.Implementations/HttpServer/WebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,14 @@ public class WebSocketConnection : IWebSocketConnection, IDisposable
/// <param name="logger">The logger.</param>
/// <param name="socket">The socket.</param>
/// <param name="remoteEndPoint">The remote end point.</param>
/// <param name="query">The query.</param>
public WebSocketConnection(
ILogger<WebSocketConnection> logger,
WebSocket socket,
IPAddress? remoteEndPoint,
IQueryCollection query)
IPAddress? remoteEndPoint)
{
_logger = logger;
_socket = socket;
RemoteEndPoint = remoteEndPoint;
QueryString = query;

_jsonOptions = JsonDefaults.Options;
LastActivityDate = DateTime.Now;
Expand Down Expand Up @@ -81,12 +78,6 @@ public class WebSocketConnection : IWebSocketConnection, IDisposable
/// <inheritdoc />
public DateTime LastKeepAliveDate { get; set; }

/// <summary>
/// Gets the query string.
/// </summary>
/// <value>The query string.</value>
public IQueryCollection QueryString { get; }

/// <summary>
/// Gets the state.
/// </summary>
Expand Down
6 changes: 3 additions & 3 deletions Emby.Server.Implementations/HttpServer/WebSocketManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Net.WebSockets;
using System.Threading.Tasks;
using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -50,16 +51,15 @@ public async Task WebSocketRequestHandler(HttpContext context)
using var connection = new WebSocketConnection(
_loggerFactory.CreateLogger<WebSocketConnection>(),
webSocket,
context.Connection.RemoteIpAddress,
context.Request.Query)
context.GetNormalizedRemoteIp())
{
OnReceive = ProcessWebSocketMessageReceived
};

var tasks = new Task[_webSocketListeners.Length];
for (var i = 0; i < _webSocketListeners.Length; ++i)
{
tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection);
tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context);
}

await Task.WhenAll(tasks).ConfigureAwait(false);
Expand Down
37 changes: 19 additions & 18 deletions Emby.Server.Implementations/Session/SessionWebSocketListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net;
using MediaBrowser.Controller.Session;
using MediaBrowser.Model.Net;
Expand Down Expand Up @@ -50,16 +51,10 @@ public sealed class SessionWebSocketListener : IWebSocketListener, IDisposable
/// </summary>
private readonly object _webSocketsLock = new object();

/// <summary>
/// The _session manager.
/// </summary>
private readonly ISessionManager _sessionManager;

/// <summary>
/// The _logger.
/// </summary>
private readonly ILogger<SessionWebSocketListener> _logger;
private readonly ILoggerFactory _loggerFactory;
private readonly IAuthorizationContext _authorizationContext;

/// <summary>
/// The KeepAlive cancellation token.
Expand All @@ -72,14 +67,17 @@ public sealed class SessionWebSocketListener : IWebSocketListener, IDisposable
/// <param name="logger">The logger.</param>
/// <param name="sessionManager">The session manager.</param>
/// <param name="loggerFactory">The logger factory.</param>
/// <param name="authorizationContext">The authorization context.</param>
public SessionWebSocketListener(
ILogger<SessionWebSocketListener> logger,
ISessionManager sessionManager,
ILoggerFactory loggerFactory)
ILoggerFactory loggerFactory,
IAuthorizationContext authorizationContext)
{
_logger = logger;
_sessionManager = sessionManager;
_loggerFactory = loggerFactory;
_authorizationContext = authorizationContext;
}

/// <inheritdoc />
Expand All @@ -97,35 +95,38 @@ public Task ProcessMessageAsync(WebSocketMessageInfo message)
=> Task.CompletedTask;

/// <inheritdoc />
public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection)
public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext)
{
var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false);
var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false);
if (session != null)
{
EnsureController(session, connection);
await KeepAliveWebSocket(connection).ConfigureAwait(false);
}
else
{
_logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString);
_logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString);
}
}

private Task<SessionInfo> GetSession(IQueryCollection queryString, string remoteEndpoint)
private async Task<SessionInfo> GetSession(HttpContext httpContext, string remoteEndpoint)
{
if (queryString == null)
var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext)
.ConfigureAwait(false);

if (!authorizationInfo.IsAuthenticated)
{
return null;
}

var token = queryString["api_key"];
if (string.IsNullOrWhiteSpace(token))
var deviceId = authorizationInfo.DeviceId;
if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId))
{
return null;
deviceId = queryDeviceId;
}

var deviceId = queryString["deviceId"];
return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint);
return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint)
.ConfigureAwait(false);
}

private void EnsureController(SessionInfo session, IWebSocketConnection connection)
Expand Down
3 changes: 2 additions & 1 deletion MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using MediaBrowser.Model.Net;
using MediaBrowser.Model.Session;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;

namespace MediaBrowser.Controller.Net
Expand Down Expand Up @@ -95,7 +96,7 @@ public Task ProcessMessageAsync(WebSocketMessageInfo message)
}

/// <inheritdoc />
public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask;
public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask;

/// <summary>
/// Starts sending messages over a web socket.
Expand Down
6 changes: 0 additions & 6 deletions MediaBrowser.Controller/Net/IWebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ public interface IWebSocketConnection
/// <value>The date of last Keeplive received.</value>
DateTime LastKeepAliveDate { get; set; }

/// <summary>
/// Gets the query string.
/// </summary>
/// <value>The query string.</value>
IQueryCollection QueryString { get; }

/// <summary>
/// Gets or sets the receive action.
/// </summary>
Expand Down
4 changes: 3 additions & 1 deletion MediaBrowser.Controller/Net/IWebSocketListener.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

namespace MediaBrowser.Controller.Net
{
Expand All @@ -18,7 +19,8 @@ public interface IWebSocketListener
/// Processes a new web socket connection.
/// </summary>
/// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param>
/// <param name="httpContext">The current http context.</param>
/// <returns>Task.</returns>
Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection);
Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class WebSocketConnectionTests
[Fact]
public void DeserializeWebSocketMessage_SingleSegment_Success()
{
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed);
Expand All @@ -23,7 +23,7 @@ public void DeserializeWebSocketMessage_SingleSegment_Success()
public void DeserializeWebSocketMessage_MultipleSegments_Success()
{
const int SplitPos = 64;
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
Expand All @@ -34,7 +34,7 @@ public void DeserializeWebSocketMessage_MultipleSegments_Success()
[Fact]
public void DeserializeWebSocketMessage_ValidPartial_Success()
{
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed);
Expand All @@ -43,7 +43,7 @@ public void DeserializeWebSocketMessage_ValidPartial_Success()
[Fact]
public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
{
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
}
Expand Down

0 comments on commit c6a1dcf

Please sign in to comment.