Skip to content

Commit

Permalink
Fix for SocketException caused by --explore
Browse files Browse the repository at this point in the history
  • Loading branch information
jnm2 committed May 20, 2017
1 parent 3750f3d commit b0dacf0
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 50 deletions.
56 changes: 23 additions & 33 deletions src/NUnitEngine/nunit-agent/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
// ***********************************************************************

using System;
using System.Runtime.Remoting.Channels;
using System.Runtime.Remoting.Channels.Tcp;
using System.Diagnostics;
using NUnit.Engine;
using NUnit.Engine.Agents;
Expand All @@ -45,12 +43,6 @@ public class NUnitTestAgent
static ITestAgency Agency;
static RemoteTestAgent Agent;

/// <summary>
/// Channel used for communications with the agency
/// and with clients
/// </summary>
static TcpChannel Channel;

private const string LOG_FILE_FORMAT = "nunit-agent_{0}.log";

/// <summary>
Expand Down Expand Up @@ -133,7 +125,8 @@ public static int Main(string[] args)
log.Info("Initializing Services");
engine.Initialize();

Channel = ServerUtilities.GetTcpChannel();
// Owns the channel used for communications with the agency and with clients
var testAgencyServer = engine.Services.GetService<TestAgency>();

log.Info("Connecting to TestAgency at {0}", AgencyUrl);
try
Expand All @@ -145,32 +138,29 @@ public static int Main(string[] args)
log.Error("Unable to connect", ex);
}

if (Channel != null)
{
log.Info("Starting RemoteTestAgent");
Agent = new RemoteTestAgent(AgentId, Agency, engine.Services);
log.Info("Starting RemoteTestAgent");
Agent = new RemoteTestAgent(AgentId, Agency, engine.Services);

try
{
if (Agent.Start())
WaitForStop();
else
log.Error("Failed to start RemoteTestAgent");
}
catch (Exception ex)
{
log.Error("Exception in RemoteTestAgent", ex);
}
try
{
if (Agent.Start())
WaitForStop();
else
log.Error("Failed to start RemoteTestAgent");
}
catch (Exception ex)
{
log.Error("Exception in RemoteTestAgent", ex);
}

//log.Info("Unregistering Channel");
try
{
ChannelServices.UnregisterChannel(Channel);
}
catch (Exception ex)
{
log.Error("ChannelServices.UnregisterChannel threw an exception", ex);
}
try
{
// Unregister the channel
testAgencyServer.Stop();
}
catch (Exception ex)
{
log.Error("Exception in TestAgency.Stop", ex);
}

log.Info("Agent process {0} exiting", Process.GetCurrentProcess().Id);
Expand Down
27 changes: 27 additions & 0 deletions src/NUnitEngine/nunit.engine/Internal/CurrentMessageCounter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System.Threading;

namespace NUnit.Engine.Internal
{
public sealed class CurrentMessageCounter
{
private readonly ManualResetEvent _noMessages = new ManualResetEvent(true);
private int _currentMessageCount;

public void OnMessageStart()
{
if (Interlocked.Increment(ref _currentMessageCount) == 1)
_noMessages.Reset();
}

public void OnMessageEnd()
{
if (Interlocked.Decrement(ref _currentMessageCount) == 0)
_noMessages.Set();
}

public void WaitForAllCurrentMessages()
{
_noMessages.WaitOne();
}
}
}
12 changes: 5 additions & 7 deletions src/NUnitEngine/nunit.engine/Internal/ServerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
using System;
using System.Threading;
using System.Runtime.Remoting;
using System.Runtime.Remoting.Services;
using System.Runtime.Remoting.Channels;
using System.Runtime.Remoting.Channels.Tcp;

Expand All @@ -39,6 +38,7 @@ public abstract class ServerBase : MarshalByRefObject, IDisposable
protected int port;

private TcpChannel channel;
private CurrentMessageCounter currentMessageCounter;
private bool isMarshalled;

private object theLock = new object();
Expand All @@ -47,11 +47,6 @@ protected ServerBase()
{
}

/// <summary>
/// Constructor used to provide
/// </summary>
/// <param name="uri"></param>
/// <param name="port"></param>
protected ServerBase(string uri, int port)
{
this.uri = uri;
Expand All @@ -69,7 +64,8 @@ public virtual void Start()
{
lock (theLock)
{
this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100);
this.currentMessageCounter = new CurrentMessageCounter();
this.channel = ServerUtilities.GetTcpChannel(uri + "Channel", port, 100, currentMessageCounter);

RemotingServices.Marshal(this, uri);
this.isMarshalled = true;
Expand All @@ -90,6 +86,8 @@ public virtual void Start()
[System.Runtime.Remoting.Messaging.OneWay]
public virtual void Stop()
{
currentMessageCounter.WaitForAllCurrentMessages();

lock( theLock )
{
if ( this.isMarshalled )
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System;
using System.Collections;
using System.IO;
using System.Runtime.Remoting.Channels;
using System.Runtime.Remoting.Messaging;

namespace NUnit.Engine.Internal
{
partial class ServerUtilities
{
private sealed class ObservableServerChannelSinkProvider : IServerChannelSinkProvider
{
private readonly CurrentMessageCounter _currentMessageCounter;

public ObservableServerChannelSinkProvider(CurrentMessageCounter currentMessageCounter)
{
if (currentMessageCounter == null) throw new ArgumentNullException(nameof(currentMessageCounter));
_currentMessageCounter = currentMessageCounter;
}

public void GetChannelData(IChannelDataStore channelData)
{
}

public IServerChannelSink CreateSink(IChannelReceiver channel)
{
if (Next == null)
throw new InvalidOperationException("Cannot create a sink without setting the next provider.");
return new ObservableServerChannelSink(_currentMessageCounter, Next.CreateSink(channel));
}

public IServerChannelSinkProvider Next { get; set; }


private sealed class ObservableServerChannelSink : IServerChannelSink
{
private readonly IServerChannelSink _next;
private readonly CurrentMessageCounter _currentMessageCounter;

public ObservableServerChannelSink(CurrentMessageCounter currentMessageCounter, IServerChannelSink next)
{
if (next == null) throw new ArgumentNullException(nameof(next));
_currentMessageCounter = currentMessageCounter;
_next = next;
}

public IDictionary Properties => _next.Properties;

public ServerProcessing ProcessMessage(IServerChannelSinkStack sinkStack, IMessage requestMsg,
ITransportHeaders requestHeaders, Stream requestStream, out IMessage responseMsg,
out ITransportHeaders responseHeaders, out Stream responseStream)
{
_currentMessageCounter.OnMessageStart();
var isAsync = false;
try
{
var processing = _next.ProcessMessage(sinkStack, requestMsg, requestHeaders, requestStream,
out responseMsg, out responseHeaders, out responseStream);
isAsync = processing == ServerProcessing.Async;
return processing;
}
finally
{
if (!isAsync) _currentMessageCounter.OnMessageEnd();
}
}

public void AsyncProcessResponse(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg,
ITransportHeaders headers, Stream stream)
{
try
{
_next.AsyncProcessResponse(sinkStack, state, msg, headers, stream);
}
finally
{
_currentMessageCounter.OnMessageEnd();
}
}

public Stream GetResponseStream(IServerResponseChannelSinkStack sinkStack, object state, IMessage msg,
ITransportHeaders headers)
{
return _next.GetResponseStream(sinkStack, state, msg, headers);
}

public IServerChannelSink NextChannelSink => _next.NextChannelSink;
}
}
}
}
29 changes: 19 additions & 10 deletions src/NUnitEngine/nunit.engine/Internal/ServerUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace NUnit.Engine.Internal
/// A collection of utility methods used to create, retrieve
/// and release <see cref="TcpChannel"/>s.
/// </summary>
public static class ServerUtilities
public static partial class ServerUtilities
{
static Logger log = InternalTrace.GetLogger(typeof(ServerUtilities));

Expand All @@ -44,8 +44,9 @@ public static class ServerUtilities
/// <param name="name">The name of the channel to create.</param>
/// <param name="port">The port number of the channel to create.</param>
/// <param name="limit">The rate limit of the channel to create.</param>
/// <param name="currentMessageCounter">An optional counter to provide the ability to wait for all current messages</param>
/// <returns>A <see cref="TcpChannel"/> configured with the given name and port.</returns>
private static TcpChannel CreateTcpChannel( string name, int port, int limit )
private static TcpChannel CreateTcpChannel( string name, int port, int limit, CurrentMessageCounter currentMessageCounter )
{
Hashtable props = new Hashtable();
props.Add( "port", port );
Expand All @@ -69,16 +70,22 @@ private static TcpChannel CreateTcpChannel( string name, int port, int limit )
BinaryClientFormatterSinkProvider clientProvider =
new BinaryClientFormatterSinkProvider();

return new TcpChannel( props, clientProvider, serverProvider );
return new TcpChannel(
props,
clientProvider,
currentMessageCounter != null
? new ObservableServerChannelSinkProvider(currentMessageCounter) { Next = serverProvider }
: (IServerChannelSinkProvider)serverProvider);
}

/// <summary>
/// Get a default channel. If one does not exist, then one is created and registered.
/// </summary>
/// <param name="currentMessageCounter">An optional counter to provide the ability to wait for all current messages</param>
/// <returns>The specified <see cref="TcpChannel"/> or null if it cannot be found and created</returns>
public static TcpChannel GetTcpChannel()
public static TcpChannel GetTcpChannel( CurrentMessageCounter currentMessageCounter = null )
{
return GetTcpChannel( "", 0, 2 );
return GetTcpChannel( "", 0, 2, currentMessageCounter );
}

/// <summary>
Expand All @@ -88,12 +95,13 @@ public static TcpChannel GetTcpChannel()
/// </summary>
/// <param name="name">The name of the channel</param>
/// <param name="port">The port to use if the channel must be created</param>
/// <param name="currentMessageCounter">An optional counter to provide the ability to wait for all current messages</param>
/// <returns>The specified <see cref="TcpChannel"/> or null if it cannot be found and created</returns>
public static TcpChannel GetTcpChannel( string name, int port )
public static TcpChannel GetTcpChannel( string name, int port, CurrentMessageCounter currentMessageCounter = null )
{
return GetTcpChannel( name, port, 2 );
return GetTcpChannel( name, port, 2, currentMessageCounter );
}

/// <summary>
/// Get a channel by name, casting it to a <see cref="TcpChannel"/>.
/// Otherwise, create, register and return a <see cref="TcpChannel"/> with
Expand All @@ -102,8 +110,9 @@ public static TcpChannel GetTcpChannel( string name, int port )
/// <param name="name">The name of the channel</param>
/// <param name="port">The port to use if the channel must be created</param>
/// <param name="limit">The client connection limit or negative for the default</param>
/// <param name="currentMessageCounter">An optional counter to provide the ability to wait for all current messages</param>
/// <returns>The specified <see cref="TcpChannel"/> or null if it cannot be found and created</returns>
public static TcpChannel GetTcpChannel(string name, int port, int limit)
public static TcpChannel GetTcpChannel(string name, int port, int limit, CurrentMessageCounter currentMessageCounter = null)
{
TcpChannel channel = ChannelServices.GetChannel( name ) as TcpChannel;

Expand All @@ -115,7 +124,7 @@ public static TcpChannel GetTcpChannel(string name, int port, int limit)
while( --retries > 0 )
try
{
channel = CreateTcpChannel( name, port, limit );
channel = CreateTcpChannel( name, port, limit, currentMessageCounter );
ChannelServices.RegisterChannel( channel, false );
break;
}
Expand Down
2 changes: 2 additions & 0 deletions src/NUnitEngine/nunit.engine/nunit.engine.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@
<Compile Include="InternalEnginePackageSettings.cs" />
<Compile Include="Internal\AssemblyHelper.cs" />
<Compile Include="Internal\CecilExtensions.cs" />
<Compile Include="Internal\CurrentMessageCounter.cs" />
<Compile Include="Internal\DirectoryFinder.cs" />
<Compile Include="Internal\Logging\InternalTrace.cs" />
<Compile Include="Internal\Logging\InternalTraceWriter.cs" />
<Compile Include="Internal\Logging\Logger.cs" />
<Compile Include="Internal\ServerUtilities.ObservableServerChannelSinkProvider.cs" />
<Compile Include="Internal\ProvidedPathsAssemblyResolver.cs" />
<Compile Include="ITestAgency.cs" />
<Compile Include="ITestAgent.cs" />
Expand Down

0 comments on commit b0dacf0

Please sign in to comment.