Permalink
Browse files

Re-implement timeout using cancellation tokens.

- Use a cancellation token to trigger timeouts
- Allow host to provide a cancellation token for when it shuts down
- Trigger cancellation token for ASP.NET shutdown.
  • Loading branch information...
1 parent 1dcf5c1 commit 892f7afb3b48a3ba33b87a3a8322f266dbbd22bb @davidfowl davidfowl committed Feb 9, 2012
@@ -3,7 +3,6 @@
using Microsoft.Web.Infrastructure.DynamicModuleHelper;
using SignalR.Hosting.AspNet;
using SignalR.Hubs;
-using SignalR.Infrastructure;
[assembly: PreApplicationStartMethod(typeof(AspNetBootstrapper), "Initialize")]
@@ -59,12 +58,9 @@ private static void RegisterHubModule()
private static void OnAppDomainShutdown()
{
- var connectionManager = AspNetHost.DependencyResolver.Resolve<IConnectionManager>();
-
// Close all connections before the app domain goes down.
- // Only signal all connections on a particular appdomain (if this was cross machine we
- // don't want to end up disconnecting everyone on the farm)
- connectionManager.CloseConnections(ConnectionScope.AppDomain).Wait();
+ // Only signal all connections on a particular appdomain
+ AspNetHost.AppDomainTokenSource.Cancel();
}
}
}
@@ -1,8 +1,8 @@
using System;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
using System.Web;
-using SignalR.Hosting;
using SignalR.Infrastructure;
namespace SignalR.Hosting.AspNet
@@ -12,6 +12,9 @@ public class AspNetHost : HttpTaskAsyncHandler
private static readonly IDependencyResolver _defaultResolver = new DefaultDependencyResolver();
private static IDependencyResolver _resolver;
+ // This will fire when the app domain is shutting down
+ internal static readonly CancellationTokenSource AppDomainTokenSource = new CancellationTokenSource();
+
private readonly PersistentConnection _connection;
private static readonly Lazy<bool> _hasAcceptWebSocketRequest =
@@ -42,6 +45,9 @@ public override Task ProcessRequestAsync(HttpContextBase context)
// Set the debugging flag
hostContext.Items[HostConstants.DebugMode] = context.IsDebuggingEnabled;
+ // Set the host shutdown token
+ hostContext.Items[HostConstants.ShutdownToken] = AppDomainTokenSource.Token;
+
// Stick the context in here so transports or other asp.net specific logic can
// grab at it.
hostContext.Items["System.Web.HttpContext"] = context;
@@ -1,11 +1,7 @@
using System;
-using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.Data.SqlClient;
-using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-using SignalR.Infrastructure;
using SignalR.MessageBus;
namespace SignalR.ScaleOut
@@ -195,7 +191,7 @@ public Task Send(string eventKey, object value)
throw new NotImplementedException();
}
- public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id)
+ public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id, CancellationToken timeoutToken)
{
throw new NotImplementedException();
}
@@ -130,7 +130,7 @@ private static void ReceiveLoop(InProcessMessageBus bus, string[] eventKeys, str
{
try
{
- bus.GetMessages(eventKeys, id).ContinueWith(task =>
+ bus.GetMessages(eventKeys, id, CancellationToken.None).ContinueWith(task =>
{
if (task.IsFaulted)
{
@@ -1,6 +1,7 @@
using SignalR.Infrastructure;
using SignalR.MessageBus;
using Xunit;
+using System.Threading;
namespace SignalR.Tests
{
@@ -22,7 +23,7 @@ public void ReturnsAllMessagesWhenLastMessageIdIsLessThanAllMessages()
bus.Send("foo", "3").Wait();
bus.Send("foo", "4").Wait();
- var result = bus.GetMessages(new[] { "foo" }, "1").Result;
+ var result = bus.GetMessages(new[] { "foo" }, "1", CancellationToken.None).Result;
Assert.Equal(2, result.Messages.Count);
}
@@ -76,7 +77,7 @@ public void ReturnsMessagesGreaterThanLastMessageIdWhenLastMessageIdNotInStore()
bus.Send("bar", "5").Wait();
bus.Send("foo", "6").Wait();
- var result = bus.GetMessages(new[] { "foo" }, "3").Result;
+ var result = bus.GetMessages(new[] { "foo" }, "3", CancellationToken.None).Result;
Assert.Equal(2, result.Messages.Count);
}
View
@@ -4,7 +4,6 @@ public enum CommandType
{
AddToGroup,
RemoveFromGroup,
- Disconnect,
- Timeout
+ Disconnect
}
}
View
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
using SignalR.Infrastructure;
using SignalR.MessageBus;
@@ -16,7 +17,6 @@ public class Connection : IConnection, IReceivingConnection
private readonly HashSet<string> _groups;
private readonly ITraceManager _trace;
private bool _disconnected;
- private bool _timedOut;
public Connection(IMessageBus messageBus,
IJsonSerializer jsonSerializer,
@@ -58,16 +58,16 @@ public Task Send(object value)
return SendMessage(_connectionId, value);
}
- public Task<PersistentResponse> ReceiveAsync()
+ public Task<PersistentResponse> ReceiveAsync(CancellationToken timeoutToken)
{
- return _messageBus.GetMessages(Signals, null)
- .Then(result => GetResponse(result));
+ return _messageBus.GetMessages(Signals, null, timeoutToken)
+ .Then(result => GetResponse(result));
}
- public Task<PersistentResponse> ReceiveAsync(string messageId)
+ public Task<PersistentResponse> ReceiveAsync(string messageId, CancellationToken timeoutToken)
{
- return _messageBus.GetMessages(Signals, messageId)
- .Then(result => GetResponse(result));
+ return _messageBus.GetMessages(Signals, messageId, timeoutToken)
+ .Then(result => GetResponse(result));
}
public Task SendCommand(SignalCommand command)
@@ -85,7 +85,7 @@ private PersistentResponse GetResponse(MessageResult result)
MessageId = result.LastMessageId,
Messages = messageValues,
Disconnect = _disconnected,
- TimedOut = _timedOut
+ TimedOut = result.TimedOut
};
PopulateResponseState(response);
@@ -126,9 +126,6 @@ private void ProcessCommand(SignalCommand command)
case CommandType.Disconnect:
_disconnected = true;
break;
- case CommandType.Timeout:
- _timedOut = true;
- break;
}
}
@@ -15,15 +15,5 @@ public static Task Close(this IReceivingConnection connection)
return connection.SendCommand(command);
}
-
- public static Task Timeout(this IReceivingConnection connection)
- {
- var command = new SignalCommand
- {
- Type = CommandType.Timeout
- };
-
- return connection.SendCommand(command);
- }
}
}
@@ -1,6 +1,5 @@
using System;
using System.Linq;
-using System.Threading.Tasks;
using SignalR.Hubs;
using SignalR.Infrastructure;
using SignalR.MessageBus;
@@ -10,7 +9,7 @@ namespace SignalR
public class ConnectionManager : IConnectionManager
{
private readonly IDependencyResolver _resolver;
-
+
public ConnectionManager(IDependencyResolver resolver)
{
_resolver = resolver;
@@ -37,24 +36,6 @@ public dynamic GetClients(string hubName)
return new ClientAgent(connection, hubName);
}
- public Task CloseConnections(string scope)
- {
- // Get the connection that represents all clients (even if the type really means nothing
- // since we're just broadcasting to all connected clients
- var connection = GetConnection<PersistentConnection>();
-
- // We're targeting all clients
- string key = SignalCommand.AddCommandSuffix(scope);
-
- // Tell them all to go away
- var command = new SignalCommand
- {
- Type = CommandType.Timeout
- };
-
- return connection.Broadcast(key, command);
- }
-
private IConnection GetConnection(string connectionType)
{
return new Connection(_resolver.Resolve<IMessageBus>(),
View
@@ -1,12 +0,0 @@
-using System;
-
-namespace SignalR
-{
- public static class ConnectionScope
- {
- public static readonly string Global = typeof(PersistentConnection).FullName;
- // TODO: Come up with something here
- // public static readonly string Machine = typeof(PersistentConnection).FullName;
- public static readonly string AppDomain = Guid.NewGuid().ToString();
- }
-}
@@ -16,5 +16,7 @@ public static class HostConstants
/// The host should set this if the web socket url is different
/// </summary>
public static readonly string WebSocketServerUrl = "webSocketServerUrl";
+
+ public static readonly string ShutdownToken = "shutdownToken";
}
}
@@ -1,4 +1,5 @@
-namespace SignalR.Hosting
+using System.Threading;
+namespace SignalR.Hosting
{
public static class HostContextExtensions
{
@@ -26,5 +27,10 @@ public static string WebSocketServerUrl(this HostContext context)
{
return context.GetValue<string>(HostConstants.WebSocketServerUrl);
}
+
+ public static CancellationToken HostShutdownToken(this HostContext context)
+ {
+ return context.GetValue<CancellationToken>(HostConstants.ShutdownToken);
+ }
}
}
@@ -1,4 +1,4 @@
-using System.Threading.Tasks;
+using System;
using SignalR.Hubs;
namespace SignalR
@@ -7,6 +7,5 @@ public interface IConnectionManager
{
dynamic GetClients<T>() where T : IHub;
IConnection GetConnection<T>() where T : PersistentConnection;
- Task CloseConnections(string scope);
}
}
@@ -1,11 +1,12 @@
-using System.Threading.Tasks;
+using System.Threading;
+using System.Threading.Tasks;
namespace SignalR
{
public interface IReceivingConnection
{
- Task<PersistentResponse> ReceiveAsync();
- Task<PersistentResponse> ReceiveAsync(string messageId);
+ Task<PersistentResponse> ReceiveAsync(CancellationToken timeoutToken);
+ Task<PersistentResponse> ReceiveAsync(string messageId, CancellationToken timeoutToken);
Task SendCommand(SignalCommand command);
}
@@ -1,11 +1,12 @@
using System.Collections.Generic;
+using System.Threading;
using System.Threading.Tasks;
namespace SignalR.MessageBus
{
public interface IMessageBus
{
- Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id);
+ Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id, CancellationToken timeoutToken);
Task Send(string eventKey, object value);
}
}
@@ -49,13 +49,13 @@ public InProcessMessageBus(ITraceManager traceManager, bool garbageCollectMessag
}
}
- public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id)
+ public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id, CancellationToken timeoutToken)
{
if (String.IsNullOrEmpty(id))
{
// Wait for new messages
_trace.Source.TraceInformation("MessageBus: New connection waiting for messages");
- return WaitForMessages(eventKeys);
+ return WaitForMessages(eventKeys, timeoutToken);
}
try
@@ -68,7 +68,7 @@ public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id)
{
// Connection already has the latest message, so start wating
_trace.Source.TraceInformation("MessageBus: Connection waiting for new messages from id {0}", id);
- return WaitForMessages(eventKeys);
+ return WaitForMessages(eventKeys, timeoutToken);
}
var messages = eventKeys.SelectMany(key => GetMessagesSince(key, uid));
@@ -82,7 +82,7 @@ public Task<MessageResult> GetMessages(IEnumerable<string> eventKeys, string id)
// Wait for new messages
_trace.Source.TraceInformation("MessageBus: Connection waiting for new messages from id {0}", id);
- return WaitForMessages(eventKeys);
+ return WaitForMessages(eventKeys, timeoutToken);
}
finally
{
@@ -186,17 +186,26 @@ private IList<InMemoryMessage> GetMessagesSince(string eventKey, ulong id)
return snapshot.GetRange(startIndex, snapshot.Count - startIndex);
}
- private Task<MessageResult> WaitForMessages(IEnumerable<string> eventKeys)
+ private Task<MessageResult> WaitForMessages(IEnumerable<string> eventKeys, CancellationToken timeoutToken)
{
var tcs = new TaskCompletionSource<MessageResult>();
int callbackCalled = 0;
Action<IList<InMemoryMessage>> callback = null;
+ timeoutToken.Register(() =>
+ {
+ if (Interlocked.Exchange(ref callbackCalled, 1) == 0)
+ {
+ string lastMessageId = _lastMessageId.ToString(CultureInfo.InvariantCulture);
+ tcs.TrySetResult(new MessageResult(lastMessageId, timedOut: true));
+ }
+ });
+
callback = messages =>
{
if (Interlocked.Exchange(ref callbackCalled, 1) == 0)
{
- tcs.SetResult(GetMessageResult(messages));
+ tcs.TrySetResult(GetMessageResult(messages));
}
// Remove callback for all keys
@@ -4,8 +4,17 @@ namespace SignalR.MessageBus
{
public struct MessageResult
{
+ private static readonly List<Message> _emptyList = new List<Message>();
+
public IList<Message> Messages { get; private set; }
public string LastMessageId { get; private set; }
+ public bool TimedOut { get; set; }
+
+ public MessageResult(string lastMessageId, bool timedOut)
+ : this(_emptyList, lastMessageId)
+ {
+ TimedOut = timedOut;
+ }
public MessageResult(IList<Message> messages, string lastMessageId)
: this()
Oops, something went wrong.

0 comments on commit 892f7af

Please sign in to comment.