Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TargetHostName to QuicConnection #84976

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.InteropServices;

namespace System.Net.Security
{
internal static class TargetHostNameHelper
{
private static readonly IdnMapping s_idnMapping = new IdnMapping();
private static readonly IndexOfAnyValues<char> s_safeDnsChars =
IndexOfAnyValues.Create("-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");

private static bool IsSafeDnsString(ReadOnlySpan<char> name) =>
name.IndexOfAnyExcept(s_safeDnsChars) < 0;

internal static string NormalizeHostName(string? targetHost)
{
if (string.IsNullOrEmpty(targetHost))
{
return string.Empty;
}

// RFC 6066 section 3 says to exclude trailing dot from fully qualified DNS hostname
targetHost = targetHost.TrimEnd('.');

try
{
return s_idnMapping.GetAscii(targetHost);
}
catch (ArgumentException) when (IsSafeDnsString(targetHost))
{
// Seems like name that does not confrom to IDN but apers somewhat valid according to original DNS rfc.
}

return targetHost;
}

// Simplified version of IPAddressParser.Parse to avoid allocations and dependencies.
// It purposely ignores scopeId as we don't really use so we do not need to map it to actual interface id.
internal static unsafe bool IsValidAddress(string? hostname)
{
if (string.IsNullOrEmpty(hostname))
{
return false;
}

ReadOnlySpan<char> ipSpan = hostname.AsSpan();

int end = ipSpan.Length;

if (ipSpan.Contains(':'))
{
// The address is parsed as IPv6 if and only if it contains a colon. This is valid because
// we don't support/parse a port specification at the end of an IPv4 address.
Span<ushort> numbers = stackalloc ushort[IPAddressParserStatics.IPv6AddressShorts];

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
return IPv6AddressHelper.IsValidStrict(ipStringPtr, 0, ref end);
}
}
else if (char.IsDigit(ipSpan[0]))
{
long tmpAddr;

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
tmpAddr = IPv4AddressHelper.ParseNonCanonical(ipStringPtr, 0, ref end, notImplicitFile: true);
}

if (tmpAddr != IPv4AddressHelper.Invalid && end == ipSpan.Length)
{
return true;
}
}

return false;
}
}
}
2 changes: 2 additions & 0 deletions src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public sealed partial class QuicConnection : System.IAsyncDisposable
public System.Net.Security.SslApplicationProtocol NegotiatedApplicationProtocol { get { throw null; } }
public System.Security.Cryptography.X509Certificates.X509Certificate? RemoteCertificate { get { throw null; } }
public System.Net.IPEndPoint RemoteEndPoint { get { throw null; } }
public string TargetHostName { get { throw null; } }
public System.Threading.Tasks.ValueTask<System.Net.Quic.QuicStream> AcceptInboundStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.ValueTask<System.Net.Quic.QuicConnection> ConnectAsync(System.Net.Quic.QuicClientConnectionOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down Expand Up @@ -122,6 +123,7 @@ public sealed partial class QuicStream : System.IO.Stream
public override int ReadByte() { throw null; }
public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
public override void SetLength(long value) { }
public override string ToString() { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
4 changes: 4 additions & 0 deletions src/libraries/System.Net.Quic/src/System.Net.Quic.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs" Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\Internals\IPEndPointExtensions.cs" Link="Common\System\Net\Internals\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TlsAlertMessage.cs" Link="Common\System\Net\Security\TlsAlertMessage.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TargetHostNameHelper.cs" Link="Common\System\Net\Security\TargetHostNameHelper.cs" />
<!-- IP parser -->
<Compile Include="$(CommonPath)System\Net\IPv4AddressHelper.Common.cs" Link="System\Net\IPv4AddressHelper.Common.cs" />
<Compile Include="$(CommonPath)System\Net\IPv6AddressHelper.Common.cs" Link="System\Net\IPv6AddressHelper.Common.cs" />
</ItemGroup>
<!-- Unsupported platforms -->
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == ''">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public partial class QuicConnection
/// <summary>
/// Host name send in SNI, set only for outbound/client connections. Configured via <see cref="SslClientAuthenticationOptions.TargetHost"/>.
/// </summary>
private readonly string? _targetHost;
private readonly string _targetHost;
/// <summary>
/// Always <c>true</c> for outbound/client connections. Configured for inbound/server ones via <see cref="SslServerAuthenticationOptions.ClientCertificateRequired"/>.
/// </summary>
Expand All @@ -47,8 +47,10 @@ public partial class QuicConnection
/// </summary>
private readonly X509ChainPolicy? _certificateChainPolicy;

internal string TargetHost => _targetHost;

public SslConnectionOptions(QuicConnection connection, bool isClient,
string? targetHost, bool certificateRequired, X509RevocationMode
string targetHost, bool certificateRequired, X509RevocationMode
revocationMode, RemoteCertificateValidationCallback? validationCallback,
X509ChainPolicy? certificateChainPolicy)
{
Expand Down Expand Up @@ -118,7 +120,7 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
if (result is not null)
{
bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, _targetHost, certificateBuffer, certificateLength);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certificateBuffer, certificateLength);
}
else if (_certificateRequired)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ public static async ValueTask<QuicConnection> ConnectAsync(QuicClientConnectionO
/// </summary>
public IPEndPoint LocalEndPoint => _localEndPoint;

/// <summary>
/// Gets the name of the server the client is trying to connect to. That name is used for server certificate validation. It can be a DNS name or an IP address.
/// </summary>
/// <returns>The name of the server the client is trying to connect to.</returns>
public string TargetHostName => _sslConnectionOptions.TargetHost ?? string.Empty;

/// <summary>
/// The certificate provided by the peer.
/// For an outbound/client connection will always have the peer's (server) certificate; for an inbound/server one, only if the connection requested and the peer (client) provided one.
Expand Down Expand Up @@ -279,10 +285,16 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS, quicAddress);
}

// RFC 6066 forbids IP literals
// DNI mapping is handled by MsQuic
var hostname = TargetHostNameHelper.IsValidAddress(options.ClientAuthenticationOptions.TargetHost)
? string.Empty
: options.ClientAuthenticationOptions.TargetHost ?? string.Empty;

_sslConnectionOptions = new SslConnectionOptions(
this,
isClient: true,
options.ClientAuthenticationOptions.TargetHost,
hostname,
certificateRequired: true,
options.ClientAuthenticationOptions.CertificateRevocationCheckMode,
options.ClientAuthenticationOptions.RemoteCertificateValidationCallback,
Expand Down Expand Up @@ -312,7 +324,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
await valueTask.ConfigureAwait(false);
}

internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string? targetHost, CancellationToken cancellationToken = default)
internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string targetHost, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);

Expand All @@ -322,10 +334,16 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str
_defaultStreamErrorCode = options.DefaultStreamErrorCode;
_defaultCloseErrorCode = options.DefaultCloseErrorCode;

// RFC 6066 forbids IP literals, avoid setting IP address here for consistency with SslStream
if (TargetHostNameHelper.IsValidAddress(targetHost))
{
targetHost = string.Empty;
}

_sslConnectionOptions = new SslConnectionOptions(
this,
isClient: false,
targetHost: null,
targetHost,
options.ServerAuthenticationOptions.ClientCertificateRequired,
options.ServerAuthenticationOptions.CertificateRevocationCheckMode,
options.ServerAuthenticationOptions.RemoteCertificateValidationCallback,
Expand Down
53 changes: 53 additions & 0 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,5 +1204,58 @@ await using (serverConnection)
await AssertThrowsQuicExceptionAsync(QuicError.ConnectionIdle, async () => await acceptTask).WaitAsync(TimeSpan.FromSeconds(10));
}
}

private async Task SniTestCore(string hostname, bool shouldSendSni)
{
string expectedHostName = shouldSendSni ? hostname : string.Empty;

using X509Certificate serverCert = Configuration.Certificates.GetSelfSignedServerCertificate();
var listenerOptions = new QuicListenerOptions()
{
ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
ConnectionOptionsCallback = (_, _, _) =>
{
var serverOptions = CreateQuicServerOptions();
serverOptions.ServerAuthenticationOptions.ServerCertificateContext = null;
serverOptions.ServerAuthenticationOptions.ServerCertificate = null;
serverOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, actualHostName) =>
{
Assert.Equal(expectedHostName, actualHostName);
return serverCert;
};
return ValueTask.FromResult(serverOptions);
}
};

// Use whatever endpoint, it'll get overwritten in CreateConnectedQuicConnection.
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(listenerOptions.ListenEndPoint);
clientOptions.ClientAuthenticationOptions.TargetHost = hostname;
clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = delegate { return true; };


(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
await using (clientConnection)
await using (serverConnection)
{
Assert.Equal(expectedHostName, clientConnection.TargetHostName);
Assert.Equal(expectedHostName, serverConnection.TargetHostName);
}
}

[Theory]
[InlineData("a")]
[InlineData("test")]
[InlineData("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")] // max allowed hostname length is 63
[InlineData("\u017C\u00F3\u0142\u0107 g\u0119\u015Bl\u0105 ja\u017A\u0144. \u7EA2\u70E7. \u7167\u308A\u713C\u304D")]
public Task ClientSendsSniServerReceives_Ok(string hostname) => SniTestCore(hostname, true);

[Theory]
[InlineData("127.0.0.1")]
[InlineData("::1")]
[InlineData("2001:11:22::1")]
[InlineData("fe80::9c3a:b64d:6249:1de8%2")]
[InlineData("fe80::9c3a:b64d:6249:1de8")]
public Task DoesNotSendIPAsSni(string target) => SniTestCore(target, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ public async Task TestConnect()
{
await using QuicListener listener = await CreateQuicListener();

ValueTask<QuicConnection> connectTask = CreateQuicConnection(listener.LocalEndPoint);
var options = CreateQuicClientOptions(listener.LocalEndPoint);
ValueTask<QuicConnection> connectTask = CreateQuicConnection(options);
ValueTask<QuicConnection> acceptTask = listener.AcceptConnectionAsync();

await new Task[] { connectTask.AsTask(), acceptTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
Expand All @@ -34,6 +35,8 @@ public async Task TestConnect()
Assert.Equal(clientConnection.LocalEndPoint, serverConnection.RemoteEndPoint);
Assert.Equal(ApplicationProtocol.ToString(), clientConnection.NegotiatedApplicationProtocol.ToString());
Assert.Equal(ApplicationProtocol.ToString(), serverConnection.NegotiatedApplicationProtocol.ToString());
Assert.Equal(options.ClientAuthenticationOptions.TargetHost, clientConnection.TargetHostName);
Assert.Equal(options.ClientAuthenticationOptions.TargetHost, serverConnection.TargetHostName);
}

private static async Task<QuicStream> OpenAndUseStreamAsync(QuicConnection c)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
Link="Common\System\NotImplemented.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TlsAlertMessage.cs"
Link="Common\System\Net\Security\TlsAlertMessage.cs" />
<Compile Include="$(CommonPath)System\Net\Security\TargetHostNameHelper.cs"
Link="Common\System\Net\Security\TargetHostNameHelper.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SafeCredentialReference.cs"
Link="Common\System\Net\Security\SafeCredentialReference.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SSPIHandleCache.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,13 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

namespace System.Net.Security
{
internal sealed class SslAuthenticationOptions
{
private static readonly IdnMapping s_idnMapping = new IdnMapping();

// Simplified version of IPAddressParser.Parse to avoid allocations and dependencies.
// It purposely ignores scopeId as we don't really use so we do not need to map it to actual interface id.
private static unsafe bool IsValidAddress(ReadOnlySpan<char> ipSpan)
{
int end = ipSpan.Length;

if (ipSpan.Contains(':'))
{
// The address is parsed as IPv6 if and only if it contains a colon. This is valid because
// we don't support/parse a port specification at the end of an IPv4 address.
Span<ushort> numbers = stackalloc ushort[IPAddressParserStatics.IPv6AddressShorts];

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
return IPv6AddressHelper.IsValidStrict(ipStringPtr, 0, ref end);
}
}
else if (char.IsDigit(ipSpan[0]))
{
long tmpAddr;

fixed (char* ipStringPtr = &MemoryMarshal.GetReference(ipSpan))
{
tmpAddr = IPv4AddressHelper.ParseNonCanonical(ipStringPtr, 0, ref end, notImplicitFile: true);
}

if (tmpAddr != IPv4AddressHelper.Invalid && end == ipSpan.Length)
{
return true;
}
}

return false;
}

private static readonly IndexOfAnyValues<char> s_safeDnsChars =
IndexOfAnyValues.Create("-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");

private static bool IsSafeDnsString(ReadOnlySpan<char> name) =>
name.IndexOfAnyExcept(s_safeDnsChars) < 0;

internal SslAuthenticationOptions()
{
TargetHost = string.Empty;
Expand Down Expand Up @@ -93,29 +48,11 @@ internal void UpdateOptions(SslClientAuthenticationOptions sslClientAuthenticati
IsServer = false;
RemoteCertRequired = true;
CertificateContext = sslClientAuthenticationOptions.ClientCertificateContext;
if (!string.IsNullOrEmpty(sslClientAuthenticationOptions.TargetHost))
{
// RFC 6066 section 3 says to exclude trailing dot from fully qualified DNS hostname
string targetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.');

// RFC 6066 forbids IP literals
if (IsValidAddress(targetHost))
{
TargetHost = string.Empty;
}
else
{
try
{
TargetHost = s_idnMapping.GetAscii(targetHost);
}
catch (ArgumentException) when (IsSafeDnsString(targetHost))
{
// Seems like name that does not confrom to IDN but apers somewhat valid according to orogional DNS rfc.
TargetHost = targetHost;
}
}
}
// RFC 6066 forbids IP literals
TargetHost = TargetHostNameHelper.IsValidAddress(sslClientAuthenticationOptions.TargetHost)
? string.Empty
: sslClientAuthenticationOptions.TargetHost ?? string.Empty;

// Client specific options.
CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode;
Expand Down