Skip to content

Commit

Permalink
Don't call user callbacks on MsQuic worker thread. (#98361)
Browse files Browse the repository at this point in the history
* Allow switching execution profiles using env vars

* Quick and dirty version to enable benchmarking

* Don't call callbacks from MsQuic threads

* Remove unintentional changes

* Offload parsing to threadpool as well

* Customize TLS ALERT code

* Code review feedback

* Apply suggestions from code review

Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>

* use ConfigureAwaitOptions.ForceYielding

* Version check to work around microsoft/msquic#4132

* Use configure await to yield to threadpool

* Fix functionality on older msquic versions

---------

Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>
  • Loading branch information
rzikm and ManickaP committed Feb 28, 2024
1 parent 9bc96d0 commit 56af107
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

// WARNING: This function will do the verification using OpenSSL. If the intention is to use OS function, caller should use CertificatePal interface.
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, IntPtr certificateBuffer, int bufferLength = 0)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, Span<byte> certificateBuffer)
{
SslPolicyErrors errors = chain.Build(remoteCertificate) ?
SslPolicyErrors.None :
Expand All @@ -31,15 +31,24 @@ internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X
}

SafeX509Handle certHandle;
if (certificateBuffer != IntPtr.Zero && bufferLength > 0)
unsafe
{
certHandle = Interop.Crypto.DecodeX509(certificateBuffer, bufferLength);
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
certHandle = Interop.Crypto.DecodeX509(Marshal.UnsafeAddrOfPinnedArrayElement(der, 0), der.Length);
if (certificateBuffer.Length > 0)
{
fixed (byte* pCert = certificateBuffer)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pCert, certificateBuffer.Length);
}
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
fixed (byte* pDer = der)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pDer, der.Length);
}
}
}

int hostNameMatch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Net
internal static partial class CertificateValidation
{
#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,55 @@ public int StreamReceiveSetEnabled(MsQuicSafeHandle stream, byte enabled)
}
}
}

public int DatagramSend(MsQuicSafeHandle connection, QUIC_BUFFER* buffers, uint buffersCount, QUIC_SEND_FLAGS flags, void* context)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->DatagramSend(connection.QuicHandle, buffers, buffersCount, flags, context);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}

public int ConnectionResumptionTicketValidationComplete(MsQuicSafeHandle connection, byte result)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->ConnectionResumptionTicketValidationComplete(connection.QuicHandle, result);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}

public int ConnectionCertificateValidationComplete(MsQuicSafeHandle connection, byte result, QUIC_TLS_ALERT_CODES alert)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->ConnectionCertificateValidationComplete(connection.QuicHandle, result, alert);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,16 @@ private MsQuicApi(QUIC_API_TABLE* apiTable)
private static readonly Lazy<MsQuicApi> _api = new Lazy<MsQuicApi>(AllocateMsQuicApi);
internal static MsQuicApi Api => _api.Value;

internal static Version? Version { get; private set; }

internal static bool IsQuicSupported { get; }

internal static string MsQuicLibraryVersion { get; } = "unknown";
internal static string? NotSupportedReason { get; }

// workaround for https://github.com/microsoft/msquic/issues/4132
internal static bool SupportsAsyncCertValidation => Version >= new Version(2, 4, 0);

internal static bool UsesSChannelBackend { get; }

internal static bool Tls13ServerMayBeDisabled { get; }
Expand All @@ -69,6 +74,7 @@ static MsQuicApi()
{
bool loaded = false;
IntPtr msQuicHandle;
Version = default;

// MsQuic is using DualMode sockets and that will fail even for IPv4 if AF_INET6 is not available.
if (!Socket.OSSupportsIPv6)
Expand Down Expand Up @@ -135,7 +141,7 @@ static MsQuicApi()
}
return;
}
Version version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);
Version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);

paramSize = 64 * sizeof(sbyte);
sbyte* libGitHash = stackalloc sbyte[64];
Expand All @@ -150,11 +156,11 @@ static MsQuicApi()
}
string? gitHash = Marshal.PtrToStringUTF8((IntPtr)libGitHash);

MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {version} ({gitHash})";
MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {Version} ({gitHash})";

if (version < s_minMsQuicVersion)
if (Version < s_minMsQuicVersion)
{
NotSupportedReason = $"Incompatible MsQuic library version '{version}', expecting higher than '{s_minMsQuicVersion}'.";
NotSupportedReason = $"Incompatible MsQuic library version '{Version}', expecting higher than '{s_minMsQuicVersion}'.";
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, NotSupportedReason);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.Quic;
using static Microsoft.Quic.MsQuic;

Expand Down Expand Up @@ -63,18 +66,122 @@ public partial class QuicConnection
_certificateChainPolicy = certificateChainPolicy;
}

public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* chainPtr, out X509Certificate2? certificate)
internal async Task<bool> StartAsyncCertificateValidation(IntPtr certificatePtr, IntPtr chainPtr)
{
//
// The provided data pointers are valid only while still inside this function, so they need to be
// copied to separate buffers which are then handed off to threadpool.
//

X509Certificate2? certificate = null;

byte[]? certDataRented = null;
Memory<byte> certData = default;
byte[]? chainDataRented = null;
Memory<byte> chainData = default;

if (certificatePtr != IntPtr.Zero)
{
if (MsQuicApi.UsesSChannelBackend)
{
// provided data is a pointer to a CERT_CONTEXT
certificate = new X509Certificate2(certificatePtr);
// TODO: what about chainPtr?
}
else
{
unsafe
{
// On non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the contents are buffers
// with DER encoded cert and chain.
QUIC_BUFFER* certificateBuffer = (QUIC_BUFFER*)certificatePtr;
QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainPtr;

if (certificateBuffer->Length > 0)
{
certDataRented = ArrayPool<byte>.Shared.Rent((int)certificateBuffer->Length);
certData = certDataRented.AsMemory(0, (int)certificateBuffer->Length);
certificateBuffer->Span.CopyTo(certData.Span);
}

if (chainBuffer->Length > 0)
{
chainDataRented = ArrayPool<byte>.Shared.Rent((int)chainBuffer->Length);
chainData = chainDataRented.AsMemory(0, (int)chainBuffer->Length);
chainBuffer->Span.CopyTo(chainData.Span);
}
}
}
}

// We wan't to do the certificate validation asynchronously, but due to a bug in MsQuic, we need to call the callback synchronously on some versions
if (MsQuicApi.SupportsAsyncCertValidation)
{
// force yield to the thread pool to free up MsQuic worker thread.
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
}

// certificatePtr and chainPtr are invalid beyond this point

QUIC_TLS_ALERT_CODES result;
try
{
if (certData.Length > 0)
{
Debug.Assert(certificate == null);
certificate = new X509Certificate2(certData.Span);
}

result = _connection._sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span);
_connection._remoteCertificate = certificate;
}
catch (Exception ex)
{
certificate?.Dispose();
_connection._connectedTcs.TrySetException(ex);
result = QUIC_TLS_ALERT_CODES.USER_CANCELED;
}
finally
{
if (certDataRented != null)
{
ArrayPool<byte>.Shared.Return(certDataRented);
}

if (chainDataRented != null)
{
ArrayPool<byte>.Shared.Return(chainDataRented);
}
}

if (MsQuicApi.SupportsAsyncCertValidation)
{
int status = MsQuicApi.Api.ConnectionCertificateValidationComplete(
_connection._handle,
result == QUIC_TLS_ALERT_CODES.SUCCESS ? (byte)1 : (byte)0,
result);

if (MsQuic.StatusFailed(status))
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(_connection, $"{_connection} ConnectionCertificateValidationComplete failed with {ThrowHelper.GetErrorMessageForStatus(status)}");
}
}
}

return result == QUIC_TLS_ALERT_CODES.SUCCESS;
}

private QUIC_TLS_ALERT_CODES ValidateCertificate(X509Certificate2? certificate, Span<byte> certData, Span<byte> chainData)
{
SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None;
IntPtr certificateBuffer = 0;
int certificateLength = 0;
bool wrapException = false;

X509Chain? chain = null;
X509Certificate2? result = null;
try
{
if (certificatePtr is not null)
if (certificate is not null)
{
chain = new X509Chain();
if (_certificateChainPolicy != null)
Expand All @@ -96,51 +203,34 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid);
}

if (MsQuicApi.UsesSChannelBackend)
if (chainData.Length > 0)
{
result = new X509Certificate2((IntPtr)certificatePtr);
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainData);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
}
else
{
if (certificatePtr->Length > 0)
{
certificateBuffer = (IntPtr)certificatePtr->Buffer;
certificateLength = (int)certificatePtr->Length;
result = new X509Certificate2(certificatePtr->Span);
}

if (chainPtr->Length > 0)
{
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainPtr->Span);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
}
}
}

if (result is not null)
{
bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certificateBuffer, certificateLength);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certData);
}
else if (_certificateRequired)
{
sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable;
}

int status = QUIC_STATUS_SUCCESS;
QUIC_TLS_ALERT_CODES result = QUIC_TLS_ALERT_CODES.SUCCESS;
if (_validationCallback is not null)
{
wrapException = true;
if (!_validationCallback(_connection, result, chain, sslPolicyErrors))
if (!_validationCallback(_connection, certificate, chain, sslPolicyErrors))
{
wrapException = false;
if (_isClient)
{
throw new AuthenticationException(SR.net_quic_cert_custom_validation);
}

status = QUIC_STATUS_USER_CANCELED;
result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE;
}
}
else if (sslPolicyErrors != SslPolicyErrors.None)
Expand All @@ -150,15 +240,13 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
}

status = QUIC_STATUS_HANDSHAKE_FAILURE;
result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE;
}

certificate = result;
return status;
return result;
}
catch (Exception ex)
{
result?.Dispose();
if (wrapException)
{
throw new QuicException(QuicError.CallbackError, null, SR.net_quic_callback_error, ex);
Expand Down
Loading

0 comments on commit 56af107

Please sign in to comment.