Skip to content

Commit

Permalink
Added secure socket(wss://) support.
Browse files Browse the repository at this point in the history
  • Loading branch information
statianzo committed Jun 28, 2011
1 parent f10b7eb commit 93c8f1a
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 31 deletions.
6 changes: 3 additions & 3 deletions src/Fleck.Tests/ClientHandshakeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public void HostnameShouldMatchOnUri()
clientHandshake.ResourcePath = "BBB";

clientHandshake.Host = "localhost:8181";
Assert.IsTrue(clientHandshake.Validate(null, "ws://localhost:8181/"));
Assert.IsTrue(clientHandshake.Validate(null, "ws://localhost:8181/", "ws"));
}

[Test]
Expand All @@ -28,7 +28,7 @@ public void CorruptHostShouldNotValidate()
clientHandshake.ResourcePath = "BBB";

clientHandshake.Host = "$%%$%NoT^^^A)()(()VALID--==!!URI&&&@@#$#~~~";
Assert.IsFalse(clientHandshake.Validate(null, "ws://localhost:8181/"));
Assert.IsFalse(clientHandshake.Validate(null, "ws://localhost:8181/", "ws"));
}

[Test]
Expand All @@ -41,7 +41,7 @@ public void NullHostShouldNotValidate()
clientHandshake.ResourcePath = "BBB";

clientHandshake.Host = null;
Assert.IsFalse(clientHandshake.Validate(null, "ws://localhost:8181/"));
Assert.IsFalse(clientHandshake.Validate(null, "ws://localhost:8181/", "ws"));
}
}
}
8 changes: 7 additions & 1 deletion src/Fleck.Tests/HandshakeHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Text;
using Moq;
using NUnit.Framework;
using System.Security.Cryptography.X509Certificates;

namespace Fleck.Tests
{
Expand Down Expand Up @@ -44,7 +45,7 @@ public class HandshakeHandlerTests
[SetUp]
public void Setup()
{
_handler = new HandshakeHandler(null, "ws://fleck-test.com");
_handler = new HandshakeHandler(null, "ws://fleck-test.com", "ws");
}

[Test]
Expand Down Expand Up @@ -170,5 +171,10 @@ public void Listen(int backlog)
{
throw new NotImplementedException();
}

public void Authenticate (X509Certificate2 certificate, Action callback, Action<Exception> error)
{
throw new NotImplementedException ();
}
}
}
23 changes: 23 additions & 0 deletions src/Fleck.Tests/WebSocketServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,28 @@ public void ShouldStart()
socketMock.Verify(s => s.Bind(It.Is<IPEndPoint>(i => i.Port == 8000)));
socketMock.Verify(s => s.BeginAccept(It.IsAny<AsyncCallback>(), It.IsAny<object>()));
}

[Test]
public void ShouldBeSecureWithWssAndCertificate()
{
var server = new WebSocketServer("wss://secureplace.com:8000");
server.Certificate = "MyCert.cer";
Assert.IsTrue(server.IsSecure);
}

[Test]
public void ShouldNotBeSecureWithWssAndNoCertificate()
{
var server = new WebSocketServer("wss://secureplace.com:8000");
Assert.IsFalse(server.IsSecure);
}

[Test]
public void ShouldNotBeSecureWithoutWssAndCertificate()
{
var server = new WebSocketServer("ws://secureplace.com:8000");
server.Certificate = "MyCert.cer";
Assert.IsFalse(server.IsSecure);
}
}
}
8 changes: 5 additions & 3 deletions src/Fleck/ClientHandshake.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ public override string ToString()
return stringShake;
}

public bool Validate(string origin, string host)
public bool Validate(string origin, string host, string scheme)
{
bool hasRequiredFields = (Host != null) &&
(Key1 != null) &&
(Key2 != null) &&
(Origin != null) &&
(ResourcePath != null);
var hostUri = "ws://" + Host;


var hostUri = string.Format("{0}://{1}", scheme, Host);

FleckLog.Debug(string.Format("Client host {0}", Host));

return hasRequiredFields &&
Uri.IsWellFormedUriString(hostUri, UriKind.RelativeOrAbsolute) &&
Expand Down
8 changes: 5 additions & 3 deletions src/Fleck/HandshakeHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ namespace Fleck
{
public class HandshakeHandler
{
public HandshakeHandler(string origin, string location)
public HandshakeHandler(string origin, string location, string scheme)
{
Origin = origin;
Location = location;
Scheme = scheme;
}

public string Scheme { get; set; }
public string Origin { get; set; }
public string Location { get; set; }
public ClientHandshake ClientHandshake { get; set; }
Expand All @@ -43,7 +45,7 @@ public void DoShake(HandShakeState state, int receivedByteCount)
ClientHandshake = ParseClientHandshake(new ArraySegment<byte>(state.Buffer, 0, receivedByteCount));


if (ClientHandshake.Validate(Origin, Location))
if (ClientHandshake.Validate(Origin, Location, Scheme))
{
FleckLog.Debug("Client handshake validated");
ServerHandshake serverShake = GenerateResponseHandshake();
Expand Down Expand Up @@ -113,7 +115,7 @@ public ServerHandshake GenerateResponseHandshake()
{
var responseHandshake = new ServerHandshake
{
Location = "ws://" + ClientHandshake.Host + ClientHandshake.ResourcePath,
Location = string.Format("{0}://{1}{2}", Scheme, ClientHandshake.Host, ClientHandshake.ResourcePath),
Origin = ClientHandshake.Origin,
SubProtocol = ClientHandshake.SubProtocol
};
Expand Down
2 changes: 2 additions & 0 deletions src/Fleck/Interfaces/ISocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;

namespace Fleck
{
Expand All @@ -24,5 +25,6 @@ public interface ISocket
int EndSend(IAsyncResult asyncResult);
void Bind(EndPoint ipLocal);
void Listen(int backlog);
void Authenticate(X509Certificate2 certificate, Action callback, Action<Exception> error);
}
}
33 changes: 29 additions & 4 deletions src/Fleck/SocketWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,39 @@
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.IO;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;

namespace Fleck
{
public class SocketWrapper : ISocket
{
private readonly Socket _socket;
private Stream _stream;

public SocketWrapper(Socket socket)
{
_socket = socket;
if(_socket.Connected)
_stream = new NetworkStream(_socket);
}

public void Authenticate(X509Certificate2 certificate, Action callback, Action<Exception> error)
{
var ssl = new SslStream(_stream, false);
_stream = ssl;
Func<AsyncCallback, object,IAsyncResult> begin = (cb, s) => ssl.BeginAuthenticateAsServer(certificate,false,System.Security.Authentication.SslProtocols.Tls,false, cb, s);

var task = Task.Factory.FromAsync(begin, ssl.EndAuthenticateAsServer, null);
task.ContinueWith(t => {
callback();
}, TaskContinuationOptions.NotOnFaulted);
task.ContinueWith(t => error(t.Exception),
TaskContinuationOptions.OnlyOnFaulted);

}

public EndPoint LocalEndPoint
{
Expand All @@ -37,12 +59,12 @@ public bool Connected
public IAsyncResult BeginReceive(IList<ArraySegment<byte>> buffers, SocketFlags socketFlags, AsyncCallback callback,
object state)
{
return _socket.BeginReceive(buffers, socketFlags, callback, state);
return _stream.BeginRead(buffers[0].Array,buffers[0].Offset, buffers[0].Count, callback, state);
}

public int EndReceive(IAsyncResult asyncResult)
{
return _socket.EndReceive(asyncResult);
return _stream.EndRead(asyncResult);
}

public IAsyncResult BeginAccept(AsyncCallback callback, object state)
Expand All @@ -57,23 +79,26 @@ public ISocket EndAccept(IAsyncResult asyncResult)

public void Dispose()
{
_stream.Dispose();
_socket.Dispose();
}

public void Close()
{
_stream.Close();
_socket.Close();
}

public IAsyncResult BeginSend(IList<ArraySegment<byte>> buffers, SocketFlags socketFlags, AsyncCallback callback,
object state)
{
return _socket.BeginSend(buffers, socketFlags, callback, state);
return _stream.BeginWrite(buffers[0].Array,buffers[0].Offset, buffers[0].Count, callback, state);
}

public int EndSend(IAsyncResult asyncResult)
{
return _socket.EndSend(asyncResult);
_stream.EndWrite(asyncResult);
return 0;
}
}
}
48 changes: 31 additions & 17 deletions src/Fleck/WebSocketServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,54 @@
using System.Net.Sockets;
using System.Net;
using System.Threading.Tasks;
using System.Security.Cryptography.X509Certificates;

namespace Fleck
{
public class WebSocketServer : IDisposable
{
private Action<IWebSocketConnection> _config;
private readonly string _scheme;
private X509Certificate2 _x509certificate;

public WebSocketServer(string location) : this(0, location)
public WebSocketServer(string location) : this(8181, location)
{
var uri = new Uri(location);
Port = uri.Port > 0 ? uri.Port : 8181;
}
public WebSocketServer(int port, string location)
{
Port = port;
var uri = new Uri(location);
Port = uri.Port > 0 ? uri.Port : port;
Location = location;
_scheme = uri.Scheme;
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
ListenerSocket = new SocketWrapper(socket);
}

public WebSocketServer(int port, string location, string origin)
: this(port, location)
{
Origin = origin;
}

public ISocket ListenerSocket { get; set; }
public string Location { get; private set; }
public int Port { get; private set; }
public string Origin { get; set; }
public string Certificate { get; set; }
public bool IsSecure { get { return _scheme == "wss" && Certificate != null; }}

public void Dispose()
{
((IDisposable)ListenerSocket).Dispose();
}

public void Start(Action<IWebSocketConnection> config)
{
var ipLocal = new IPEndPoint(IPAddress.Any, Port);
ListenerSocket.Bind(ipLocal);
ListenerSocket.Listen(100);
FleckLog.Info("Server stated on " + ipLocal);
FleckLog.Info("Server started at " + Location);
if (_scheme == "wss") {
if (Certificate == null) {
FleckLog.Error("Scheme cannot be 'wss' without a Certificate");
return;
}
_x509certificate = new X509Certificate2(Certificate);
}
ListenForClients();
_config = config;
}
Expand All @@ -60,10 +66,9 @@ private void OnClientConnect(Task<ISocket> task)
{
FleckLog.Debug("Client Connected");
ISocket clientSocket = task.Result;
ListenForClients();

ListenForClients();

var shaker = new HandshakeHandler(Origin, Location)
var shaker = new HandshakeHandler(Origin, Location, _scheme)
{
OnSuccess = handshake =>
{
Expand All @@ -74,8 +79,17 @@ private void OnClientConnect(Task<ISocket> task)
wsc.StartReceiving();
}
};

shaker.Shake(clientSocket);

if (IsSecure) {
FleckLog.Debug("Authenticating Secure Connection");
clientSocket.Authenticate(_x509certificate, () => {
FleckLog.Debug("Authentication Successful");
shaker.Shake(clientSocket);
},e => FleckLog.Warn("Failed to Authenticate", e) );
}
else {
shaker.Shake(clientSocket);
}

}
}
Expand Down

0 comments on commit 93c8f1a

Please sign in to comment.