diff --git a/mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs b/mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs index f04cb3c054343..3778c740538ff 100644 --- a/mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs +++ b/mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs @@ -1,10 +1,12 @@ // // ClientWebSocket.cs // -// Author: -// Martin Baulig +// Authors: +// Jérémie Laval // -// Copyright (c) 2013 Xamarin Inc. (http://www.xamarin.com) +// Copyright 2013 Xamarin Inc (http://www.xamarin.com). +// +// Lightly inspired from WebSocket4Net distributed under the Apache License 2.0 // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,71 +29,327 @@ #if NET_4_5 using System; +using System.Net; +using System.Net.Sockets; +using System.Security.Principal; +using System.Security.Cryptography.X509Certificates; +using System.Runtime.CompilerServices; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using System.Globalization; +using System.Text; +using System.Security.Cryptography; namespace System.Net.WebSockets { - [MonoTODO] - public class ClientWebSocket : WebSocket + public class ClientWebSocket : WebSocket, IDisposable { - public ClientWebSocketOptions Options { - get { throw new NotImplementedException (); } + const string Magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + const string VersionTag = "13"; + + ClientWebSocketOptions options; + WebSocketState state; + string subProtocol; + + HttpWebRequest req; + WebConnection connection; + Socket underlyingSocket; + + Random random = new Random (); + + const int HeaderMaxLength = 14; + byte[] headerBuffer; + byte[] sendBuffer; + + public ClientWebSocket () + { + options = new ClientWebSocketOptions (); + state = WebSocketState.None; + headerBuffer = new byte[HeaderMaxLength]; } - public Task ConnectAsync (Uri uri, CancellationToken cancellationToken) + public override void Dispose () { - throw new NotImplementedException (); + if (connection != null) + connection.Close (false); } - #region implemented abstract members of WebSocket + [MonoTODO] public override void Abort () { throw new NotImplementedException (); } - public override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + + public ClientWebSocketOptions Options { + get { + return options; + } + } + + public override WebSocketState State { + get { + return state; + } + } + + public override WebSocketCloseStatus? CloseStatus { + get { + if (state != WebSocketState.Closed) + return (WebSocketCloseStatus?)null; + return WebSocketCloseStatus.Empty; + } + } + + public override string CloseStatusDescription { + get { + return null; + } + } + + public override string SubProtocol { + get { + return subProtocol; + } + } + + public async Task ConnectAsync (Uri uri, CancellationToken cancellationToken) { - throw new NotImplementedException (); + state = WebSocketState.Connecting; + var httpUri = new UriBuilder (uri); + if (uri.Scheme == "wss") + httpUri.Scheme = "https"; + else + httpUri.Scheme = "http"; + req = (HttpWebRequest)WebRequest.Create (httpUri.Uri); + req.ReuseConnection = true; + if (options.Cookies != null) + req.CookieContainer = options.Cookies; + + if (options.CustomRequestHeaders.Count > 0) { + foreach (var header in options.CustomRequestHeaders) + req.Headers[header.Key] = header.Value; + } + + var secKey = Convert.ToBase64String (Encoding.ASCII.GetBytes (Guid.NewGuid ().ToString ().Substring (0, 16))); + string expectedAccept = Convert.ToBase64String (SHA1.Create ().ComputeHash (Encoding.ASCII.GetBytes (secKey + Magic))); + + req.Headers["Upgrade"] = "WebSocket"; + req.Headers["Sec-WebSocket-Version"] = VersionTag; + req.Headers["Sec-WebSocket-Key"] = secKey; + req.Headers["Sec-WebSocket-Origin"] = uri.Host; + if (options.SubProtocols.Count > 0) + req.Headers["Sec-WebSocket-Protocol"] = string.Join (",", options.SubProtocols); + + if (options.Credentials != null) + req.Credentials = options.Credentials; + if (options.ClientCertificates != null) + req.ClientCertificates = options.ClientCertificates; + if (options.Proxy != null) + req.Proxy = options.Proxy; + req.UseDefaultCredentials = options.UseDefaultCredentials; + req.Connection = "Upgrade"; + + HttpWebResponse resp = null; + try { + resp = (HttpWebResponse)(await req.GetResponseAsync ().ConfigureAwait (false)); + } catch (Exception e) { + throw new WebSocketException (WebSocketError.Success, e); + } + + connection = req.StoredConnection; + underlyingSocket = connection.socket; + + if (resp.StatusCode != HttpStatusCode.SwitchingProtocols) + throw new WebSocketException ("The server returned status code '" + (int)resp.StatusCode + "' when status code '101' was expected"); + if (!string.Equals (resp.Headers["Upgrade"], "WebSocket", StringComparison.OrdinalIgnoreCase) + || !string.Equals (resp.Headers["Connection"], "Upgrade", StringComparison.OrdinalIgnoreCase) + || !string.Equals (resp.Headers["Sec-WebSocket-Accept"], expectedAccept)) + throw new WebSocketException ("HTTP header error during handshake"); + if (resp.Headers["Sec-WebSocket-Protocol"] != null) { + if (!options.SubProtocols.Contains (resp.Headers["Sec-WebSocket-Protocol"])) + throw new WebSocketException (WebSocketError.UnsupportedProtocol); + subProtocol = resp.Headers["Sec-WebSocket-Protocol"]; + } + + state = WebSocketState.Open; } - public override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + + public override Task SendAsync (ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - throw new NotImplementedException (); + EnsureWebSocketConnected (); + ValidateArraySegment (buffer); + if (connection == null) + throw new WebSocketException (WebSocketError.Faulted); + var count = Math.Max (options.SendBufferSize, buffer.Count) + HeaderMaxLength; + if (sendBuffer == null || sendBuffer.Length != count) + sendBuffer = new byte[count]; + return Task.Run (() => { + EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseReceived); + var maskOffset = WriteHeader (messageType, buffer, endOfMessage); + + if (buffer.Count > 0) + MaskData (buffer, maskOffset); + //underlyingSocket.Send (headerBuffer, 0, maskOffset + 4, SocketFlags.None); + var headerLength = maskOffset + 4; + Array.Copy (headerBuffer, sendBuffer, headerLength); + underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None); + }); } + public override Task ReceiveAsync (ArraySegment buffer, CancellationToken cancellationToken) { - throw new NotImplementedException (); + EnsureWebSocketConnected (); + ValidateArraySegment (buffer); + return Task.Run (() => { + EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent); + // First read the two first bytes to know what we are doing next + connection.Read (req, headerBuffer, 0, 2); + var isLast = (headerBuffer[0] >> 7) > 0; + var isMasked = (headerBuffer[1] >> 7) > 0; + int mask = 0; + var type = (WebSocketMessageType)(headerBuffer[0] & 0xF); + long length = headerBuffer[1] & 0x7F; + int offset = 0; + if (length == 126) { + offset = 2; + connection.Read (req, headerBuffer, 2, offset); + length = (headerBuffer[2] << 8) | headerBuffer[3]; + } else if (length == 127) { + offset = 8; + connection.Read (req, headerBuffer, 2, offset); + length = 0; + for (int i = 2; i <= 9; i++) + length = (length << 8) | headerBuffer[i]; + } + + if (isMasked) { + connection.Read (req, headerBuffer, 2 + offset, 4); + for (int i = 0; i < 4; i++) { + var pos = i + offset + 2; + mask = (mask << 8) | headerBuffer[pos]; + } + } + + if (type == WebSocketMessageType.Close) { + state = WebSocketState.Closed; + var tmpBuffer = new byte[length]; + connection.Read (req, tmpBuffer, 0, tmpBuffer.Length); + var closeStatus = (WebSocketCloseStatus)(tmpBuffer[0] << 8 | tmpBuffer[1]); + var closeDesc = tmpBuffer.Length > 2 ? Encoding.UTF8.GetString (tmpBuffer, 2, tmpBuffer.Length - 2) : string.Empty; + return new WebSocketReceiveResult ((int)length, type, isLast, closeStatus, closeDesc); + } else { + var readLength = (int)(buffer.Count < length ? buffer.Count : length); + connection.Read (req, buffer.Array, buffer.Offset, readLength); + + return new WebSocketReceiveResult ((int)length, type, isLast); + } + }); } - public override Task SendAsync (ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + + // The damn difference between those two methods is that CloseAsync will wait for server acknowledgement before completing + // while CloseOutputAsync will send the close packet and simply complete. + + public async override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { - throw new NotImplementedException (); + EnsureWebSocketConnected (); + await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false); + state = WebSocketState.CloseSent; + // TODO: figure what's exceptions are thrown if the server returns something faulty here + await ReceiveAsync (new ArraySegment (new byte[0]), cancellationToken).ConfigureAwait (false); + state = WebSocketState.Closed; } - public override void Dispose () + + public async override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { - throw new NotImplementedException (); + EnsureWebSocketConnected (); + await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false); + state = WebSocketState.CloseSent; } - public override WebSocketCloseStatus? CloseStatus { - get { - throw new NotImplementedException (); - } + + async Task SendCloseFrame (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + var statusDescBuffer = string.IsNullOrEmpty (statusDescription) ? new byte[2] : new byte[2 + Encoding.UTF8.GetByteCount (statusDescription)]; + statusDescBuffer[0] = (byte)(((ushort)closeStatus) >> 8); + statusDescBuffer[1] = (byte)(((ushort)closeStatus) & 0xFF); + if (!string.IsNullOrEmpty (statusDescription)) + Encoding.UTF8.GetBytes (statusDescription, 0, statusDescription.Length, statusDescBuffer, 2); + await SendAsync (new ArraySegment (statusDescBuffer), WebSocketMessageType.Close, true, cancellationToken).ConfigureAwait (false); } - public override string CloseStatusDescription { - get { - throw new NotImplementedException (); + + int WriteHeader (WebSocketMessageType type, ArraySegment buffer, bool endOfMessage) + { + var opCode = (byte)type; + var length = buffer.Count; + + headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0 : 0x80)); + if (length < 126) { + headerBuffer[1] = (byte)length; + } else if (length <= ushort.MaxValue) { + headerBuffer[1] = (byte)126; + headerBuffer[2] = (byte)(length / 256); + headerBuffer[3] = (byte)(length % 256); + } else { + headerBuffer[1] = (byte)127; + + int left = length; + int unit = 256; + + for (int i = 9; i > 1; i--) { + headerBuffer[i] = (byte)(left % unit); + left = left / unit; + } } + + var l = Math.Max (0, headerBuffer[1] - 125); + var maskOffset = 2 + l * l * 2; + GenerateMask (headerBuffer, maskOffset); + + // Since we are client only, we always mask the payload + headerBuffer[1] |= 0x80; + + return maskOffset; } - public override WebSocketState State { - get { - throw new NotImplementedException (); - } + + void GenerateMask (byte[] mask, int offset) + { + mask[offset + 0] = (byte)random.Next (0, 255); + mask[offset + 1] = (byte)random.Next (0, 255); + mask[offset + 2] = (byte)random.Next (0, 255); + mask[offset + 3] = (byte)random.Next (0, 255); } - public override string SubProtocol { - get { - throw new NotImplementedException (); - } + + void MaskData (ArraySegment buffer, int maskOffset) + { + var sendBufferOffset = maskOffset + 4; + for (var i = 0; i < buffer.Count; i++) + sendBuffer[i + sendBufferOffset] = (byte)(buffer.Array[buffer.Offset + i] ^ headerBuffer[maskOffset + (i % 4)]); + } + + void EnsureWebSocketConnected () + { + if (state < WebSocketState.Open) + throw new InvalidOperationException ("The WebSocket is not connected"); + } + + void EnsureWebSocketState (params WebSocketState[] validStates) + { + foreach (var validState in validStates) + if (state == validState) + return; + throw new WebSocketException ("The WebSocket is in an invalid state ('" + state + "') for this operation. Valid states are: " + string.Join (", ", validStates)); + } + + void ValidateArraySegment (ArraySegment segment) + { + if (segment.Array == null) + throw new ArgumentNullException ("buffer.Array"); + if (segment.Offset < 0) + throw new ArgumentOutOfRangeException ("buffer.Offset"); + if (segment.Offset + segment.Count > segment.Array.Length) + throw new ArgumentOutOfRangeException ("buffer.Count"); } - #endregion } } #endif - diff --git a/mcs/class/System/System.Net.WebSockets/ClientWebSocketOptions.cs b/mcs/class/System/System.Net.WebSockets/ClientWebSocketOptions.cs index a1d617cbadb88..586752d7f6756 100644 --- a/mcs/class/System/System.Net.WebSockets/ClientWebSocketOptions.cs +++ b/mcs/class/System/System.Net.WebSockets/ClientWebSocketOptions.cs @@ -33,11 +33,15 @@ using System.Security.Principal; using System.Security.Cryptography.X509Certificates; using System.Runtime.CompilerServices; +using System.Collections.Generic; namespace System.Net.WebSockets { public sealed class ClientWebSocketOptions { + List subprotocols = new List (); + Dictionary customRequestHeaders = new Dictionary (); + public X509CertificateCollection ClientCertificates { get; set; } public CookieContainer Cookies { get; set; } @@ -50,28 +54,53 @@ public sealed class ClientWebSocketOptions public bool UseDefaultCredentials { get; set; } - [MonoTODO] + internal IList SubProtocols { + get { + return subprotocols.AsReadOnly (); + } + } + + internal Dictionary CustomRequestHeaders { + get { + return customRequestHeaders; + } + } + + internal int ReceiveBufferSize { + get; + private set; + } + + internal ArraySegment CustomReceiveBuffer { + get; + private set; + } + + internal int SendBufferSize { + get; + private set; + } + public void AddSubProtocol (string subProtocol) { - throw new NotImplementedException (); + subprotocols.Add (subProtocol); } - [MonoTODO] public void SetBuffer (int receiveBufferSize, int sendBufferSize) { - throw new NotImplementedException (); + SetBuffer (receiveBufferSize, sendBufferSize, new ArraySegment ()); } - [MonoTODO] public void SetBuffer (int receiveBufferSize, int sendBufferSize, ArraySegment buffer) { - throw new NotImplementedException (); + ReceiveBufferSize = receiveBufferSize; + SendBufferSize = sendBufferSize; + CustomReceiveBuffer = buffer; } - [MonoTODO] public void SetRequestHeader (string headerName, string headerValue) { - throw new NotImplementedException (); + customRequestHeaders[headerName] = headerValue; } } } diff --git a/mcs/class/System/System.Net.WebSockets/WebSocketException.cs b/mcs/class/System/System.Net.WebSockets/WebSocketException.cs index b4980174f6bb7..e617ab38b0bfe 100644 --- a/mcs/class/System/System.Net.WebSockets/WebSocketException.cs +++ b/mcs/class/System/System.Net.WebSockets/WebSocketException.cs @@ -36,72 +36,68 @@ namespace System.Net.WebSockets { public sealed class WebSocketException : Win32Exception { - public WebSocketException () + const string DefaultMessage = "Generic WebSocket exception"; + + public WebSocketException () : this (WebSocketError.Success, -1, DefaultMessage, null) { } - public WebSocketException (int nativeError) : base (nativeError) + public WebSocketException (int nativeError) : this (WebSocketError.Success, nativeError, DefaultMessage, null) { } - public WebSocketException (string message) : base (message) + public WebSocketException (string message) : this (WebSocketError.Success, -1, message, null) { } - public WebSocketException (WebSocketError error) + public WebSocketException (WebSocketError error) : this (error, -1, DefaultMessage, null) { - WebSocketErrorCode = error; } - public WebSocketException (int nativeError, Exception innerException) + public WebSocketException (int nativeError, Exception innerException) : this (WebSocketError.Success, nativeError, DefaultMessage, innerException) { } - public WebSocketException (int nativeError, string message) : base (nativeError, message) + public WebSocketException (int nativeError, string message) : this (WebSocketError.Success, nativeError, message, null) { } - public WebSocketException (string message, Exception innerException) : base (message, innerException) + public WebSocketException (string message, Exception innerException) : this (WebSocketError.Success, -1, message, innerException) { } - public WebSocketException (WebSocketError error, Exception innerException) + public WebSocketException (WebSocketError error, Exception innerException) : this (error, -1, DefaultMessage, innerException) { - WebSocketErrorCode = error; + } - public WebSocketException (WebSocketError error, int nativeError) : base (nativeError) + public WebSocketException (WebSocketError error, int nativeError) : this (error, nativeError, DefaultMessage, null) { - WebSocketErrorCode = error; } - public WebSocketException (WebSocketError error, string message) : base (message) + public WebSocketException (WebSocketError error, string message) : this (error, -1, message, null) { - WebSocketErrorCode = error; } - public WebSocketException (WebSocketError error, int nativeError, Exception innerException) : base (nativeError) + public WebSocketException (WebSocketError error, int nativeError, Exception innerException) : this (error, nativeError, DefaultMessage, innerException) { - WebSocketErrorCode = error; } - public WebSocketException (WebSocketError error, int nativeError, string message) : base (nativeError, message) + public WebSocketException (WebSocketError error, int nativeError, string message) : this (error, nativeError, message, null) { - WebSocketErrorCode = error; } - public WebSocketException (WebSocketError error, string message, Exception innerException) + public WebSocketException (WebSocketError error, string message, Exception innerException) : this (error, -1, message, innerException) { - WebSocketErrorCode = error; } - public WebSocketException (WebSocketError error, int nativeError, string message, Exception innerException) : base (nativeError, message) + public WebSocketException (WebSocketError error, int nativeError, string message, Exception innerException) : base (message, innerException) { WebSocketErrorCode = error; } diff --git a/mcs/class/System/System.Net.WebSockets/WebSocketMessageType.cs b/mcs/class/System/System.Net.WebSockets/WebSocketMessageType.cs index 18e2d9ecbe8d4..50cbc003c0f7c 100644 --- a/mcs/class/System/System.Net.WebSockets/WebSocketMessageType.cs +++ b/mcs/class/System/System.Net.WebSockets/WebSocketMessageType.cs @@ -35,9 +35,9 @@ namespace System.Net.WebSockets { public enum WebSocketMessageType { - Text, - Binary, - Close + Text = 1, + Binary = 2, + Close = 8 } } diff --git a/mcs/class/System/System.Net.WebSockets/WebSocketReceiveResult.cs b/mcs/class/System/System.Net.WebSockets/WebSocketReceiveResult.cs index e237344e46687..af97ebcdca92c 100644 --- a/mcs/class/System/System.Net.WebSockets/WebSocketReceiveResult.cs +++ b/mcs/class/System/System.Net.WebSockets/WebSocketReceiveResult.cs @@ -36,20 +36,22 @@ namespace System.Net.WebSockets { public class WebSocketReceiveResult { - [MonoTODO] public WebSocketReceiveResult (int count, WebSocketMessageType messageType, bool endOfMessage) + : this (count, messageType, endOfMessage, null, null) { - throw new NotImplementedException (); } - [MonoTODO] public WebSocketReceiveResult (int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string closeStatusDescription) { - throw new NotImplementedException (); + MessageType = messageType; + CloseStatus = closeStatus; + CloseStatusDescription = closeStatusDescription; + Count = count; + EndOfMessage = endOfMessage; } public WebSocketCloseStatus? CloseStatus { diff --git a/mcs/class/System/System.Net/HttpWebRequest.cs b/mcs/class/System/System.Net/HttpWebRequest.cs index 84d59f6802517..c3c4772b5aa3f 100644 --- a/mcs/class/System/System.Net/HttpWebRequest.cs +++ b/mcs/class/System/System.Net/HttpWebRequest.cs @@ -1163,7 +1163,9 @@ string GetHeaders () bool spoint10 = (proto_version == null || proto_version == HttpVersion.Version10); if (keepAlive && (version == HttpVersion.Version10 || spoint10)) { - webHeaders.RemoveAndAdd (connectionHeader, "keep-alive"); + if (webHeaders[connectionHeader] == null + || webHeaders[connectionHeader].IndexOf ("keep-alive", StringComparison.OrdinalIgnoreCase) == -1) + webHeaders.RemoveAndAdd (connectionHeader, "keep-alive"); } else if (!keepAlive && version == HttpVersion.Version11) { webHeaders.RemoveAndAdd (connectionHeader, "close"); } @@ -1605,6 +1607,13 @@ bool CheckFinalStatus (WebAsyncResult result) throw throwMe; } + + internal bool ReuseConnection { + get; + set; + } + + internal WebConnection StoredConnection; } } diff --git a/mcs/class/System/System.Net/WebConnection.cs b/mcs/class/System/System.Net/WebConnection.cs index 7f09859cb5e6c..cf07f572effc6 100644 --- a/mcs/class/System/System.Net/WebConnection.cs +++ b/mcs/class/System/System.Net/WebConnection.cs @@ -62,7 +62,7 @@ class WebConnection { ServicePoint sPoint; Stream nstream; - Socket socket; + internal Socket socket; object socketLock = new object (); WebExceptionStatus status; WaitCallback initConn; @@ -750,6 +750,8 @@ void InitConnection (object state) { HttpWebRequest request = (HttpWebRequest) state; request.WebConnection = this; + if (request.ReuseConnection) + request.StoredConnection = this; if (request.Aborted) return; @@ -1183,6 +1185,11 @@ internal bool Write (HttpWebRequest request, byte [] buffer, int offset, int siz internal void Close (bool sendNext) { lock (this) { + if (Data != null && Data.request != null && Data.request.ReuseConnection) { + Data.request.ReuseConnection = false; + return; + } + if (nstream != null) { try { nstream.Close (); diff --git a/mcs/class/System/System_test.dll.sources b/mcs/class/System/System_test.dll.sources index a306d52be4078..33f8326b2d6e9 100644 --- a/mcs/class/System/System_test.dll.sources +++ b/mcs/class/System/System_test.dll.sources @@ -499,3 +499,4 @@ System.Collections.Concurrent/BlockingCollectionTests.cs System.Collections.Concurrent/ConcurrentBagTests.cs System.Collections.Concurrent/CollectionStressTestHelper.cs System.Collections.Concurrent/ParallelTestHelper.cs +System.Net.WebSockets/ClientWebSocketTest.cs diff --git a/mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs b/mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs new file mode 100644 index 0000000000000..212d5db9ca5f3 --- /dev/null +++ b/mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs @@ -0,0 +1,242 @@ +using System; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; +using System.Net.WebSockets; +using System.Reflection; +using System.Text; + +using NUnit.Framework; + +#if NET_4_5 + +namespace MonoTests.System.Net.WebSockets +{ + [TestFixture] + public class ClientWebSocketTest + { + const string EchoServerUrl = "ws://echo.websocket.org"; + const int Port = 42123; + HttpListener listener; + ClientWebSocket socket; + MethodInfo headerSetMethod; + + [SetUp] + public void Setup () + { + listener = new HttpListener (); + listener.Prefixes.Add ("http://localhost:" + Port + "/"); + listener.Start (); + socket = new ClientWebSocket (); + } + + [TearDown] + public void Teardown () + { + if (listener != null) { + listener.Stop (); + listener = null; + } + if (socket != null) { + if (socket.State == WebSocketState.Open) + socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + socket.Dispose (); + socket = null; + } + } + + [Test] + public void ServerHandshakeReturnCrapStatusCodeTest () + { + HandleHttpRequestAsync ((req, resp) => resp.StatusCode = 418); + try { + socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success, typeof (WebException)); + return; + } + Assert.Fail ("Should have thrown"); + } + + [Test] + public void ServerHandshakeReturnWrongUpgradeHeader () + { + HandleHttpRequestAsync ((req, resp) => { + resp.StatusCode = 101; + resp.Headers["Upgrade"] = "gtfo"; + }); + try { + socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success); + return; + } + Assert.Fail ("Should have thrown"); + } + + [Test] + public void ServerHandshakeReturnWrongConnectionHeader () + { + HandleHttpRequestAsync ((req, resp) => { + resp.StatusCode = 101; + resp.Headers["Upgrade"] = "websocket"; + // Mono http request doesn't like the forcing, test still valid since the default connection header value is empty + //ForceSetHeader (resp.Headers, "Connection", "Foo"); + }); + try { + socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success); + return; + } + Assert.Fail ("Should have thrown"); + } + + [Test] + public void EchoTest () + { + const string Payload = "This is a websocket test"; + + Assert.AreEqual (WebSocketState.None, socket.State); + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.Open, socket.State); + + var sendBuffer = Encoding.ASCII.GetBytes (Payload); + socket.SendAsync (new ArraySegment (sendBuffer), WebSocketMessageType.Text, true, CancellationToken.None).Wait (); + + var receiveBuffer = new byte[Payload.Length]; + var resp = socket.ReceiveAsync (new ArraySegment (receiveBuffer), CancellationToken.None).Result; + + Assert.AreEqual (Payload.Length, resp.Count); + Assert.IsTrue (resp.EndOfMessage); + Assert.AreEqual (WebSocketMessageType.Text, resp.MessageType); + Assert.AreEqual (Payload, Encoding.ASCII.GetString (receiveBuffer, 0, resp.Count)); + + socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.Closed, socket.State); + } + + [Test] + public void CloseOutputAsyncTest () + { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.Open, socket.State); + + socket.CloseOutputAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.CloseSent, socket.State); + + var resp = socket.ReceiveAsync (new ArraySegment (new byte[0]), CancellationToken.None).Result; + Assert.AreEqual (WebSocketState.Closed, socket.State); + Assert.AreEqual (WebSocketMessageType.Close, resp.MessageType); + Assert.AreEqual (WebSocketCloseStatus.NormalClosure, resp.CloseStatus); + Assert.AreEqual (string.Empty, resp.CloseStatusDescription); + } + + [Test] + public void CloseAsyncTest () + { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.Open, socket.State); + + socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + Assert.AreEqual (WebSocketState.Closed, socket.State); + } + + [Test, ExpectedException (typeof (InvalidOperationException))] + public void SendAsyncArgTest_NotConnected () + { + socket.SendAsync (new ArraySegment (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None); + } + + [Test, ExpectedException (typeof (ArgumentNullException))] + public void SendAsyncArgTest_NoArray () + { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + socket.SendAsync (new ArraySegment (), WebSocketMessageType.Text, true, CancellationToken.None); + } + + [Test, ExpectedException (typeof (InvalidOperationException))] + public void ReceiveAsyncArgTest_NotConnected () + { + socket.ReceiveAsync (new ArraySegment (new byte[0]), CancellationToken.None); + } + + [Test, ExpectedException (typeof (ArgumentNullException))] + public void ReceiveAsyncArgTest_NoArray () + { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + socket.ReceiveAsync (new ArraySegment (), CancellationToken.None); + } + + [Test] + public void ReceiveAsyncWrongState_Closed () + { + try { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + socket.ReceiveAsync (new ArraySegment (new byte[0]), CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success); + return; + } + Assert.Fail ("Should have thrown"); + } + + [Test] + public void SendAsyncWrongState_Closed () + { + try { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + socket.SendAsync (new ArraySegment (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success); + return; + } + Assert.Fail ("Should have thrown"); + } + + [Test] + public void SendAsyncWrongState_CloseSent () + { + try { + socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait (); + socket.CloseOutputAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait (); + socket.SendAsync (new ArraySegment (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None).Wait (); + } catch (AggregateException e) { + AssertWebSocketException (e, WebSocketError.Success); + return; + } + Assert.Fail ("Should have thrown"); + } + + async Task HandleHttpRequestAsync (Action handler) + { + var ctx = await listener.GetContextAsync (); + handler (ctx.Request, ctx.Response); + ctx.Response.Close (); + } + + void AssertWebSocketException (AggregateException e, WebSocketError error, Type inner = null) + { + var wsEx = e.InnerException as WebSocketException; + Console.WriteLine (e.InnerException.ToString ()); + Assert.IsNotNull (wsEx, "Not a websocketexception"); + Assert.AreEqual (error, wsEx.WebSocketErrorCode); + if (inner != null) { + Assert.IsNotNull (wsEx.InnerException); + Assert.IsInstanceOfType (inner, wsEx.InnerException); + } + } + + void ForceSetHeader (WebHeaderCollection headers, string name, string value) + { + if (headerSetMethod == null) + headerSetMethod = typeof (WebHeaderCollection).GetMethod ("AddValue", BindingFlags.NonPublic); + headerSetMethod.Invoke (headers, new[] { name, value }); + } + } +} + +#endif