Skip to content

Commit

Permalink
H/3 Server Cert validation callback exception fix (#55526)
Browse files Browse the repository at this point in the history
* Fix and test.

* MsQuicConnection now call the cert validation callback only once, removed code duplication
  • Loading branch information
ManickaP committed Jul 13, 2021
1 parent cf73943 commit 9a9b105
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public CertificateCallbackMapper(Func<HttpRequestMessage, X509Certificate2?, X50
}
}

public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken)
private static SslClientAuthenticationOptions SetUpRemoteCertificateValidationCallback(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request)
{
// If there's a cert validation callback, and if it came from HttpClientHandler,
// wrap the original delegate in order to change the sender to be the request message (expected by HttpClientHandler's delegate).
Expand All @@ -52,12 +52,13 @@ public static ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenti
};
}

// Create the SslStream, authenticate, and return it.
return EstablishSslConnectionAsyncCore(async, stream, sslOptions, cancellationToken);
return sslOptions;
}

private static async ValueTask<SslStream> EstablishSslConnectionAsyncCore(bool async, Stream stream, SslClientAuthenticationOptions sslOptions, CancellationToken cancellationToken)
public static async ValueTask<SslStream> EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken)
{
sslOptions = SetUpRemoteCertificateValidationCallback(sslOptions, request);

SslStream sslStream = new SslStream(stream);

try
Expand Down Expand Up @@ -104,8 +105,10 @@ private static async ValueTask<SslStream> EstablishSslConnectionAsyncCore(bool a
[SupportedOSPlatform("windows")]
[SupportedOSPlatform("linux")]
[SupportedOSPlatform("macos")]
public static async ValueTask<QuicConnection> ConnectQuicAsync(QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions? clientAuthenticationOptions, CancellationToken cancellationToken)
public static async ValueTask<QuicConnection> ConnectQuicAsync(HttpRequestMessage request, QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions clientAuthenticationOptions, CancellationToken cancellationToken)
{
clientAuthenticationOptions = SetUpRemoteCertificateValidationCallback(clientAuthenticationOptions, request);

QuicConnection con = new QuicConnection(quicImplementationProvider, endPoint, clientAuthenticationOptions);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ private async ValueTask<Http3Connection> GetHttp3ConnectionAsync(HttpRequestMess
QuicConnection quicConnection;
try
{
quicConnection = await ConnectHelper.ConnectQuicAsync(Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3, cancellationToken).ConfigureAwait(false);
quicConnection = await ConnectHelper.ConnectQuicAsync(request, Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3!, cancellationToken).ConfigureAwait(false);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,64 @@ public async Task ReservedFrameType_Throws()
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task ServerCertificateCustomValidationCallback_Succeeds()
{
// Mock doesn't make use of cart validation callback.
if (UseQuicImplementationProvider == QuicImplementationProviders.Mock)
{
return;
}

HttpRequestMessage? callbackRequest = null;
int invocationCount = 0;

var httpClientHandler = CreateHttpClientHandler();
httpClientHandler.ServerCertificateCustomValidationCallback = (request, _, _, _) =>
{
callbackRequest = request;
++invocationCount;
return true;
};

using Http3LoopbackServer server = CreateHttp3LoopbackServer();
using HttpClient client = CreateHttpClient(httpClientHandler);

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
await stream2.HandleRequestAsync();
});

var request = new HttpRequestMessage(HttpMethod.Get, server.Address);
request.Version = HttpVersion.Version30;
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response = await client.SendAsync(request);

response.EnsureSuccessStatusCode();
Assert.Equal(HttpVersion.Version30, response.Version);
Assert.Same(request, callbackRequest);
Assert.Equal(1, invocationCount);

// Second request, the callback shouldn't be hit at all.
callbackRequest = null;

request = new HttpRequestMessage(HttpMethod.Get, server.Address);
request.Version = HttpVersion.Version30;
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

response = await client.SendAsync(request);

response.EnsureSuccessStatusCode();
Assert.Equal(HttpVersion.Version30, response.Version);
Assert.Null(callbackRequest);
Assert.Equal(1, invocationCount);
}

[OuterLoop]
[ConditionalTheory(nameof(IsMsQuicSupported))]
[MemberData(nameof(InteropUris))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
if (connection._remoteCertificateValidationCallback != null)
{
bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors);
// Unset the callback to prevent multiple invocations of the callback per a single connection.
// Return the same value as the custom callback just did.
connection._remoteCertificateValidationCallback = (_, _, _, _) => success;

if (!success && NetEventSource.Log.IsEnabled())
NetEventSource.Error(state, $"{state.TraceId} Remote certificate rejected by verification callback");
return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
Expand Down

0 comments on commit 9a9b105

Please sign in to comment.