diff --git a/src/NUnitEngine/nunit-agent/Program.cs b/src/NUnitEngine/nunit-agent/Program.cs index d132ba094..09f2189eb 100644 --- a/src/NUnitEngine/nunit-agent/Program.cs +++ b/src/NUnitEngine/nunit-agent/Program.cs @@ -22,8 +22,6 @@ // *********************************************************************** using System; -using System.Runtime.Remoting.Channels; -using System.Runtime.Remoting.Channels.Tcp; using System.Diagnostics; using NUnit.Engine; using NUnit.Engine.Agents; @@ -45,12 +43,6 @@ public class NUnitTestAgent static ITestAgency Agency; static RemoteTestAgent Agent; - /// - /// Channel used for communications with the agency - /// and with clients - /// - static TcpChannel Channel; - private const string LOG_FILE_FORMAT = "nunit-agent_{0}.log"; /// @@ -133,7 +125,8 @@ public static int Main(string[] args) log.Info("Initializing Services"); engine.Initialize(); - Channel = ServerUtilities.GetTcpChannel(); + // Owns the channel used for communications with the agency and with clients + var testAgencyServer = engine.Services.GetService(); log.Info("Connecting to TestAgency at {0}", AgencyUrl); try @@ -145,32 +138,29 @@ public static int Main(string[] args) log.Error("Unable to connect", ex); } - if (Channel != null) - { - log.Info("Starting RemoteTestAgent"); - Agent = new RemoteTestAgent(AgentId, Agency, engine.Services); + log.Info("Starting RemoteTestAgent"); + Agent = new RemoteTestAgent(AgentId, Agency, engine.Services); - try - { - if (Agent.Start()) - WaitForStop(); - else - log.Error("Failed to start RemoteTestAgent"); - } - catch (Exception ex) - { - log.Error("Exception in RemoteTestAgent", ex); - } + try + { + if (Agent.Start()) + WaitForStop(); + else + log.Error("Failed to start RemoteTestAgent"); + } + catch (Exception ex) + { + log.Error("Exception in RemoteTestAgent", ex); + } - //log.Info("Unregistering Channel"); - try - { - ChannelServices.UnregisterChannel(Channel); - } - catch (Exception ex) - { - log.Error("ChannelServices.UnregisterChannel threw an exception", ex); - } + try + { + // Unregister the channel + testAgencyServer.Stop(); + } + catch (Exception ex) + { + log.Error("Exception in TestAgency.Stop", ex); } log.Info("Agent process {0} exiting", Process.GetCurrentProcess().Id); diff --git a/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs b/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs new file mode 100644 index 000000000..3b4fe8717 --- /dev/null +++ b/src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs @@ -0,0 +1,27 @@ +using System.Threading; + +namespace NUnit.Engine.Internal +{ + public sealed class CurrentMessageCounter + { + private readonly ManualResetEvent _noMessages = new ManualResetEvent(true); + private int _currentMessageCount; + + public void OnMessageStart() + { + if (Interlocked.Increment(ref _currentMessageCount) == 1) + _noMessages.Reset(); + } + + public void OnMessageEnd() + { + if (Interlocked.Decrement(ref _currentMessageCount) == 0) + _noMessages.Set(); + } + + public void WaitForAllCurrentMessages() + { + _noMessages.WaitOne(); + } + } +} \ No newline at end of file diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs b/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs index 79e7dfb81..5f022cea9 100644 --- a/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs +++ b/src/NUnitEngine/nunit.engine/Internal/ServerBase.cs @@ -24,7 +24,6 @@ using System; using System.Threading; using System.Runtime.Remoting; -using System.Runtime.Remoting.Services; using System.Runtime.Remoting.Channels; using System.Runtime.Remoting.Channels.Tcp; @@ -39,6 +38,7 @@ public abstract class ServerBase : MarshalByRefObject, IDisposable protected int port; private TcpChannel channel; + private CurrentMessageCounter currentMessageCounter; private bool isMarshalled; private object theLock = new object(); @@ -47,11 +47,6 @@ protected ServerBase() { } - /// - /// Constructor used to provide - /// - /// - /// protected ServerBase(string uri, int port) { this.uri = uri; @@ -69,7 +64,8 @@ public virtual void Start() { lock (theLock) { - this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100); + this.currentMessageCounter = new CurrentMessageCounter(); + this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100, currentMessageCounter); RemotingServices.Marshal(this, uri); this.isMarshalled = true; @@ -90,6 +86,8 @@ public virtual void Start() [System.Runtime.Remoting.Messaging.OneWay] public virtual void Stop() { + currentMessageCounter.WaitForAllCurrentMessages(); + lock( theLock ) { if ( this.isMarshalled ) diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs new file mode 100644 index 000000000..071f73b27 --- /dev/null +++ b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.ObservableServerChannelSinkProvider.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections; +using System.IO; +using System.Runtime.Remoting.Channels; +using System.Runtime.Remoting.Messaging; + +namespace NUnit.Engine.Internal +{ + partial class ServerUtilities + { + private sealed class ObservableServerChannelSinkProvider : IServerChannelSinkProvider + { + private readonly CurrentMessageCounter _currentMessageCounter; + + public ObservableServerChannelSinkProvider(CurrentMessageCounter currentMessageCounter) + { + if (currentMessageCounter == null) throw new ArgumentNullException(nameof(currentMessageCounter)); + _currentMessageCounter = currentMessageCounter; + } + + public void GetChannelData(IChannelDataStore channelData) + { + } + + public IServerChannelSink CreateSink(IChannelReceiver channel) + { + if (Next == null) + throw new InvalidOperationException("Cannot create a sink without setting the next provider."); + return new ObservableServerChannelSink(_currentMessageCounter, Next.CreateSink(channel)); + } + + public IServerChannelSinkProvider Next { get; set; } + + + private sealed class ObservableServerChannelSink : IServerChannelSink + { + private readonly IServerChannelSink _next; + private readonly CurrentMessageCounter _currentMessageCounter; + + public ObservableServerChannelSink(CurrentMessageCounter currentMessageCounter, IServerChannelSink next) + { + if (next == null) throw new ArgumentNullException(nameof(next)); + _currentMessageCounter = currentMessageCounter; + _next = next; + } + + public IDictionary Properties => _next.Properties; + + public ServerProcessing ProcessMessage(IServerChannelSinkStack sinkStack, IMessage requestMsg, + ITransportHeaders requestHeaders, Stream requestStream, out IMessage responseMsg, + out ITransportHeaders responseHeaders, out Stream responseStream) + { + _currentMessageCounter.OnMessageStart(); + var isAsync = false; + try + { + var processing = _next.ProcessMessage(sinkStack, requestMsg, requestHeaders, requestStream, + out responseMsg, out responseHeaders, out responseStream); + isAsync = processing == ServerProcessing.Async; + return processing; + } + finally + { + if (!isAsync) _currentMessageCounter.OnMessageEnd(); + } + } + + public void AsyncProcessResponse(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg, + ITransportHeaders headers, Stream stream) + { + try + { + _next.AsyncProcessResponse(sinkStack, state, msg, headers, stream); + } + finally + { + _currentMessageCounter.OnMessageEnd(); + } + } + + public Stream GetResponseStream(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg, + ITransportHeaders headers) + { + return _next.GetResponseStream(sinkStack, state, msg, headers); + } + + public IServerChannelSink NextChannelSink => _next.NextChannelSink; + } + } + } +} \ No newline at end of file diff --git a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs index 8c6e28b94..c15512e8d 100644 --- a/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs +++ b/src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs @@ -34,7 +34,7 @@ namespace NUnit.Engine.Internal /// A collection of utility methods used to create, retrieve /// and release s. /// - public static class ServerUtilities + public static partial class ServerUtilities { static Logger log = InternalTrace.GetLogger(typeof(ServerUtilities)); @@ -44,8 +44,9 @@ public static class ServerUtilities /// The name of the channel to create. /// The port number of the channel to create. /// The rate limit of the channel to create. + /// An optional counter to provide the ability to wait for all current messages /// A configured with the given name and port. - private static TcpChannel CreateTcpChannel( string name, int port, int limit ) + private static TcpChannel CreateTcpChannel( string name, int port, int limit, CurrentMessageCounter currentMessageCounter ) { Hashtable props = new Hashtable(); props.Add( "port", port ); @@ -69,16 +70,22 @@ private static TcpChannel CreateTcpChannel( string name, int port, int limit ) BinaryClientFormatterSinkProvider clientProvider = new BinaryClientFormatterSinkProvider(); - return new TcpChannel( props, clientProvider, serverProvider ); + return new TcpChannel( + props, + clientProvider, + currentMessageCounter != null + ? new ObservableServerChannelSinkProvider(currentMessageCounter) { Next = serverProvider } + : (IServerChannelSinkProvider)serverProvider); } /// /// Get a default channel. If one does not exist, then one is created and registered. /// + /// An optional counter to provide the ability to wait for all current messages /// The specified or null if it cannot be found and created - public static TcpChannel GetTcpChannel() + public static TcpChannel GetTcpChannel( CurrentMessageCounter currentMessageCounter = null ) { - return GetTcpChannel( "", 0, 2 ); + return GetTcpChannel( "", 0, 2, currentMessageCounter ); } /// @@ -88,12 +95,13 @@ public static TcpChannel GetTcpChannel() /// /// The name of the channel /// The port to use if the channel must be created + /// An optional counter to provide the ability to wait for all current messages /// The specified or null if it cannot be found and created - public static TcpChannel GetTcpChannel( string name, int port ) + public static TcpChannel GetTcpChannel( string name, int port, CurrentMessageCounter currentMessageCounter = null ) { - return GetTcpChannel( name, port, 2 ); + return GetTcpChannel( name, port, 2, currentMessageCounter ); } - + /// /// Get a channel by name, casting it to a . /// Otherwise, create, register and return a with @@ -102,8 +110,9 @@ public static TcpChannel GetTcpChannel( string name, int port ) /// The name of the channel /// The port to use if the channel must be created /// The client connection limit or negative for the default + /// An optional counter to provide the ability to wait for all current messages /// The specified or null if it cannot be found and created - public static TcpChannel GetTcpChannel(string name, int port, int limit) + public static TcpChannel GetTcpChannel(string name, int port, int limit, CurrentMessageCounter currentMessageCounter = null) { TcpChannel channel = ChannelServices.GetChannel( name ) as TcpChannel; @@ -115,7 +124,7 @@ public static TcpChannel GetTcpChannel(string name, int port, int limit) while( --retries > 0 ) try { - channel = CreateTcpChannel( name, port, limit ); + channel = CreateTcpChannel( name, port, limit, currentMessageCounter ); ChannelServices.RegisterChannel( channel, false ); break; } diff --git a/src/NUnitEngine/nunit.engine/nunit.engine.csproj b/src/NUnitEngine/nunit.engine/nunit.engine.csproj index 467b7b910..22ae02573 100644 --- a/src/NUnitEngine/nunit.engine/nunit.engine.csproj +++ b/src/NUnitEngine/nunit.engine/nunit.engine.csproj @@ -77,10 +77,12 @@ + +