diff --git a/src/NUnitEngine/nunit-agent/Program.cs b/src/NUnitEngine/nunit-agent/Program.cs
index d132ba094..09f2189eb 100644
--- a/src/NUnitEngine/nunit-agent/Program.cs
+++ b/src/NUnitEngine/nunit-agent/Program.cs
@@ -22,8 +22,6 @@
// ***********************************************************************
using System;
-using System.Runtime.Remoting.Channels;
-using System.Runtime.Remoting.Channels.Tcp;
using System.Diagnostics;
using NUnit.Engine;
using NUnit.Engine.Agents;
@@ -45,12 +43,6 @@ public class NUnitTestAgent
static ITestAgency Agency;
static RemoteTestAgent Agent;
- ///
- /// Channel used for communications with the agency
- /// and with clients
- ///
- static TcpChannel Channel;
-
private const string LOG_FILE_FORMAT = "nunit-agent_{0}.log";
///
@@ -133,7 +125,8 @@ public static int Main(string[] args)
log.Info("Initializing Services");
engine.Initialize();
- Channel = ServerUtilities.GetTcpChannel();
+ // Owns the channel used for communications with the agency and with clients
+ var testAgencyServer = engine.Services.GetService();
log.Info("Connecting to TestAgency at {0}", AgencyUrl);
try
@@ -145,32 +138,29 @@ public static int Main(string[] args)
log.Error("Unable to connect", ex);
}
- if (Channel != null)
- {
- log.Info("Starting RemoteTestAgent");
- Agent = new RemoteTestAgent(AgentId, Agency, engine.Services);
+ log.Info("Starting RemoteTestAgent");
+ Agent = new RemoteTestAgent(AgentId, Agency, engine.Services);
- try
- {
- if (Agent.Start())
- WaitForStop();
- else
- log.Error("Failed to start RemoteTestAgent");
- }
- catch (Exception ex)
- {
- log.Error("Exception in RemoteTestAgent", ex);
- }
+ try
+ {
+ if (Agent.Start())
+ WaitForStop();
+ else
+ log.Error("Failed to start RemoteTestAgent");
+ }
+ catch (Exception ex)
+ {
+ log.Error("Exception in RemoteTestAgent", ex);
+ }
- //log.Info("Unregistering Channel");
- try
- {
- ChannelServices.UnregisterChannel(Channel);
- }
- catch (Exception ex)
- {
- log.Error("ChannelServices.UnregisterChannel threw an exception", ex);
- }
+ try
+ {
+ // Unregister the channel
+ testAgencyServer.Stop();
+ }
+ catch (Exception ex)
+ {
+ log.Error("Exception in TestAgency.Stop", ex);
}
log.Info("Agent process {0} exiting", Process.GetCurrentProcess().Id);
diff --git a/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs b/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs
new file mode 100644
index 000000000..3b4fe8717
--- /dev/null
+++ b/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs
@@ -0,0 +1,27 @@
+using System.Threading;
+
+namespace NUnit.Engine.Internal
+{
+ public sealed class CurrentMessageCounter
+ {
+ private readonly ManualResetEvent _noMessages = new ManualResetEvent(true);
+ private int _currentMessageCount;
+
+ public void OnMessageStart()
+ {
+ if (Interlocked.Increment(ref _currentMessageCount) == 1)
+ _noMessages.Reset();
+ }
+
+ public void OnMessageEnd()
+ {
+ if (Interlocked.Decrement(ref _currentMessageCount) == 0)
+ _noMessages.Set();
+ }
+
+ public void WaitForAllCurrentMessages()
+ {
+ _noMessages.WaitOne();
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs b/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs
index 79e7dfb81..5f022cea9 100644
--- a/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs
+++ b/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs
@@ -24,7 +24,6 @@
using System;
using System.Threading;
using System.Runtime.Remoting;
-using System.Runtime.Remoting.Services;
using System.Runtime.Remoting.Channels;
using System.Runtime.Remoting.Channels.Tcp;
@@ -39,6 +38,7 @@ public abstract class ServerBase : MarshalByRefObject, IDisposable
protected int port;
private TcpChannel channel;
+ private CurrentMessageCounter currentMessageCounter;
private bool isMarshalled;
private object theLock = new object();
@@ -47,11 +47,6 @@ protected ServerBase()
{
}
- ///
- /// Constructor used to provide
- ///
- ///
- ///
protected ServerBase(string uri, int port)
{
this.uri = uri;
@@ -69,7 +64,8 @@ public virtual void Start()
{
lock (theLock)
{
- this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100);
+ this.currentMessageCounter = new CurrentMessageCounter();
+ this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100, currentMessageCounter);
RemotingServices.Marshal(this, uri);
this.isMarshalled = true;
@@ -90,6 +86,8 @@ public virtual void Start()
[System.Runtime.Remoting.Messaging.OneWay]
public virtual void Stop()
{
+ currentMessageCounter.WaitForAllCurrentMessages();
+
lock( theLock )
{
if ( this.isMarshalled )
diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs
new file mode 100644
index 000000000..071f73b27
--- /dev/null
+++ b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs
@@ -0,0 +1,91 @@
+using System;
+using System.Collections;
+using System.IO;
+using System.Runtime.Remoting.Channels;
+using System.Runtime.Remoting.Messaging;
+
+namespace NUnit.Engine.Internal
+{
+ partial class ServerUtilities
+ {
+ private sealed class ObservableServerChannelSinkProvider : IServerChannelSinkProvider
+ {
+ private readonly CurrentMessageCounter _currentMessageCounter;
+
+ public ObservableServerChannelSinkProvider(CurrentMessageCounter currentMessageCounter)
+ {
+ if (currentMessageCounter == null) throw new ArgumentNullException(nameof(currentMessageCounter));
+ _currentMessageCounter = currentMessageCounter;
+ }
+
+ public void GetChannelData(IChannelDataStore channelData)
+ {
+ }
+
+ public IServerChannelSink CreateSink(IChannelReceiver channel)
+ {
+ if (Next == null)
+ throw new InvalidOperationException("Cannot create a sink without setting the next provider.");
+ return new ObservableServerChannelSink(_currentMessageCounter, Next.CreateSink(channel));
+ }
+
+ public IServerChannelSinkProvider Next { get; set; }
+
+
+ private sealed class ObservableServerChannelSink : IServerChannelSink
+ {
+ private readonly IServerChannelSink _next;
+ private readonly CurrentMessageCounter _currentMessageCounter;
+
+ public ObservableServerChannelSink(CurrentMessageCounter currentMessageCounter, IServerChannelSink next)
+ {
+ if (next == null) throw new ArgumentNullException(nameof(next));
+ _currentMessageCounter = currentMessageCounter;
+ _next = next;
+ }
+
+ public IDictionary Properties => _next.Properties;
+
+ public ServerProcessing ProcessMessage(IServerChannelSinkStack sinkStack, IMessage requestMsg,
+ ITransportHeaders requestHeaders, Stream requestStream, out IMessage responseMsg,
+ out ITransportHeaders responseHeaders, out Stream responseStream)
+ {
+ _currentMessageCounter.OnMessageStart();
+ var isAsync = false;
+ try
+ {
+ var processing = _next.ProcessMessage(sinkStack, requestMsg, requestHeaders, requestStream,
+ out responseMsg, out responseHeaders, out responseStream);
+ isAsync = processing == ServerProcessing.Async;
+ return processing;
+ }
+ finally
+ {
+ if (!isAsync) _currentMessageCounter.OnMessageEnd();
+ }
+ }
+
+ public void AsyncProcessResponse(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg,
+ ITransportHeaders headers, Stream stream)
+ {
+ try
+ {
+ _next.AsyncProcessResponse(sinkStack, state, msg, headers, stream);
+ }
+ finally
+ {
+ _currentMessageCounter.OnMessageEnd();
+ }
+ }
+
+ public Stream GetResponseStream(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg,
+ ITransportHeaders headers)
+ {
+ return _next.GetResponseStream(sinkStack, state, msg, headers);
+ }
+
+ public IServerChannelSink NextChannelSink => _next.NextChannelSink;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs
index 8c6e28b94..c15512e8d 100644
--- a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs
+++ b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs
@@ -34,7 +34,7 @@ namespace NUnit.Engine.Internal
/// A collection of utility methods used to create, retrieve
/// and release s.
///
- public static class ServerUtilities
+ public static partial class ServerUtilities
{
static Logger log = InternalTrace.GetLogger(typeof(ServerUtilities));
@@ -44,8 +44,9 @@ public static class ServerUtilities
/// The name of the channel to create.
/// The port number of the channel to create.
/// The rate limit of the channel to create.
+ /// An optional counter to provide the ability to wait for all current messages
/// A configured with the given name and port.
- private static TcpChannel CreateTcpChannel( string name, int port, int limit )
+ private static TcpChannel CreateTcpChannel( string name, int port, int limit, CurrentMessageCounter currentMessageCounter )
{
Hashtable props = new Hashtable();
props.Add( "port", port );
@@ -69,16 +70,22 @@ private static TcpChannel CreateTcpChannel( string name, int port, int limit )
BinaryClientFormatterSinkProvider clientProvider =
new BinaryClientFormatterSinkProvider();
- return new TcpChannel( props, clientProvider, serverProvider );
+ return new TcpChannel(
+ props,
+ clientProvider,
+ currentMessageCounter != null
+ ? new ObservableServerChannelSinkProvider(currentMessageCounter) { Next = serverProvider }
+ : (IServerChannelSinkProvider)serverProvider);
}
///
/// Get a default channel. If one does not exist, then one is created and registered.
///
+ /// An optional counter to provide the ability to wait for all current messages
/// The specified or null if it cannot be found and created
- public static TcpChannel GetTcpChannel()
+ public static TcpChannel GetTcpChannel( CurrentMessageCounter currentMessageCounter = null )
{
- return GetTcpChannel( "", 0, 2 );
+ return GetTcpChannel( "", 0, 2, currentMessageCounter );
}
///
@@ -88,12 +95,13 @@ public static TcpChannel GetTcpChannel()
///
/// The name of the channel
/// The port to use if the channel must be created
+ /// An optional counter to provide the ability to wait for all current messages
/// The specified or null if it cannot be found and created
- public static TcpChannel GetTcpChannel( string name, int port )
+ public static TcpChannel GetTcpChannel( string name, int port, CurrentMessageCounter currentMessageCounter = null )
{
- return GetTcpChannel( name, port, 2 );
+ return GetTcpChannel( name, port, 2, currentMessageCounter );
}
-
+
///
/// Get a channel by name, casting it to a .
/// Otherwise, create, register and return a with
@@ -102,8 +110,9 @@ public static TcpChannel GetTcpChannel( string name, int port )
/// The name of the channel
/// The port to use if the channel must be created
/// The client connection limit or negative for the default
+ /// An optional counter to provide the ability to wait for all current messages
/// The specified or null if it cannot be found and created
- public static TcpChannel GetTcpChannel(string name, int port, int limit)
+ public static TcpChannel GetTcpChannel(string name, int port, int limit, CurrentMessageCounter currentMessageCounter = null)
{
TcpChannel channel = ChannelServices.GetChannel( name ) as TcpChannel;
@@ -115,7 +124,7 @@ public static TcpChannel GetTcpChannel(string name, int port, int limit)
while( --retries > 0 )
try
{
- channel = CreateTcpChannel( name, port, limit );
+ channel = CreateTcpChannel( name, port, limit, currentMessageCounter );
ChannelServices.RegisterChannel( channel, false );
break;
}
diff --git a/src/NUnitEngine/nunit.engine/nunit.engine.csproj b/src/NUnitEngine/nunit.engine/nunit.engine.csproj
index 467b7b910..22ae02573 100644
--- a/src/NUnitEngine/nunit.engine/nunit.engine.csproj
+++ b/src/NUnitEngine/nunit.engine/nunit.engine.csproj
@@ -77,10 +77,12 @@
+
+