Skip to content

Commit

Permalink
Implementing SNINetworkStream fixes issue 422 for non-encrypted TCP c…
Browse files Browse the repository at this point in the history
…onnections.
  • Loading branch information
cheenamalhotra committed Nov 17, 2020
1 parent 3eaa757 commit de07449
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNISslStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
Expand All @@ -30,7 +31,49 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
try
{
return _readAsyncQueueSemaphore.WaitAsync()
.ContinueWith<int>(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult());
.ContinueWith(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult());
}
finally
{
_readAsyncQueueSemaphore.Release();
}
}

// Prevent the WriteAsync's collision by running task in Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
try
{
return _writeAsyncQueueSemaphore.WaitAsync().ContinueWith(_ => base.WriteAsync(buffer, offset, count, cancellationToken));
}
finally
{
_writeAsyncQueueSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncQueueSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncQueueSemaphore;

public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
{
_writeAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent the ReadAsync's collision by running task in Semaphore Slim
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
try
{
return _readAsyncQueueSemaphore.WaitAsync()
.ContinueWith(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult());
}
finally
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}

_socket.NoDelay = true;
_tcpStream = new NetworkStream(_socket, true);
_tcpStream = new SNINetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
Expand Down

0 comments on commit de07449

Please sign in to comment.