diff --git a/src/Middleware/HttpOverrides/src/ForwardedHeadersMiddleware.cs b/src/Middleware/HttpOverrides/src/ForwardedHeadersMiddleware.cs index da126640af9e..6bec4d7975ee 100644 --- a/src/Middleware/HttpOverrides/src/ForwardedHeadersMiddleware.cs +++ b/src/Middleware/HttpOverrides/src/ForwardedHeadersMiddleware.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Linq; using System.Net; using System.Runtime.CompilerServices; @@ -17,50 +18,24 @@ namespace Microsoft.AspNetCore.HttpOverrides; /// public class ForwardedHeadersMiddleware { - private static readonly bool[] HostCharValidity = new bool[127]; - private static readonly bool[] SchemeCharValidity = new bool[123]; - private readonly ForwardedHeadersOptions _options; private readonly RequestDelegate _next; private readonly ILogger _logger; private bool _allowAllHosts; private IList? _allowedHosts; - static ForwardedHeadersMiddleware() - { - // RFC 3986 scheme = ALPHA * (ALPHA / DIGIT / "+" / "-" / ".") - SchemeCharValidity['+'] = true; - SchemeCharValidity['-'] = true; - SchemeCharValidity['.'] = true; - - // Host Matches Http.Sys and Kestrel - // Host Matches RFC 3986 except "*" / "+" / "," / ";" / "=" and "%" HEXDIG HEXDIG which are not allowed by Http.Sys - HostCharValidity['!'] = true; - HostCharValidity['$'] = true; - HostCharValidity['&'] = true; - HostCharValidity['\''] = true; - HostCharValidity['('] = true; - HostCharValidity[')'] = true; - HostCharValidity['-'] = true; - HostCharValidity['.'] = true; - HostCharValidity['_'] = true; - HostCharValidity['~'] = true; - for (var ch = '0'; ch <= '9'; ch++) - { - SchemeCharValidity[ch] = true; - HostCharValidity[ch] = true; - } - for (var ch = 'A'; ch <= 'Z'; ch++) - { - SchemeCharValidity[ch] = true; - HostCharValidity[ch] = true; - } - for (var ch = 'a'; ch <= 'z'; ch++) - { - SchemeCharValidity[ch] = true; - HostCharValidity[ch] = true; - } - } + // RFC 3986 scheme = ALPHA * (ALPHA / DIGIT / "+" / "-" / ".") + private static readonly IndexOfAnyValues SchemeChars = + IndexOfAnyValues.Create("+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + // Host Matches Http.Sys and Kestrel + // Host Matches RFC 3986 except "*" / "+" / "," / ";" / "=" and "%" HEXDIG HEXDIG which are not allowed by Http.Sys + private static readonly IndexOfAnyValues HostChars = + IndexOfAnyValues.Create("!$&'()-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~"); + + // 0-9 / A-F / a-f / ":" / "." + private static readonly IndexOfAnyValues Ipv6HostChars = + IndexOfAnyValues.Create(".0123456789:ABCDEFabcdef"); /// /// Create a new . @@ -264,7 +239,7 @@ public void ApplyForwarders(HttpContext context) if (checkProto) { - if (!string.IsNullOrEmpty(set.Scheme) && TryValidateScheme(set.Scheme)) + if (!string.IsNullOrEmpty(set.Scheme) && set.Scheme.AsSpan().IndexOfAnyExcept(SchemeChars) < 0) { applyChanges = true; currentValues.Scheme = set.Scheme; @@ -383,26 +358,6 @@ private struct SetOfForwarders public string Scheme; } - // Empty was checked for by the caller - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool TryValidateScheme(string scheme) - { - for (var i = 0; i < scheme.Length; i++) - { - if (!IsValidSchemeChar(scheme[i])) - { - return false; - } - } - return true; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsValidSchemeChar(char ch) - { - return ch < SchemeCharValidity.Length && SchemeCharValidity[ch]; - } - // Empty was checked for by the caller [MethodImpl(MethodImplOptions.AggressiveInlining)] private static bool TryValidateHost(string host) @@ -418,87 +373,48 @@ private static bool TryValidateHost(string host) return false; } - var i = 0; - for (; i < host.Length; i++) + var firstNonHostCharIdx = host.AsSpan().IndexOfAnyExcept(HostChars); + if (firstNonHostCharIdx == -1) { - if (!IsValidHostChar(host[i])) - { - break; - } + // no port + return true; + } + else + { + return TryValidateHostPort(host, firstNonHostCharIdx); } - return TryValidateHostPort(host, i); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsValidHostChar(char ch) - { - return ch < HostCharValidity.Length && HostCharValidity[ch]; } // The lead '[' was already checked [MethodImpl(MethodImplOptions.AggressiveInlining)] private static bool TryValidateIPv6Host(string hostText) { - for (var i = 1; i < hostText.Length; i++) - { - var ch = hostText[i]; - if (ch == ']') - { - // [::1] is the shortest valid IPv6 host - if (i < 4) - { - return false; - } - return TryValidateHostPort(hostText, i + 1); - } + var host = hostText.AsSpan(1); - if (!IsHex(ch) && ch != ':' && ch != '.') - { - return false; - } + var hostEndIdx = host.IndexOfAnyExcept(Ipv6HostChars); + if ((uint)hostEndIdx >= (uint)host.Length || // No ']'. The uint cast is there to eliminate the + // bounds check on the 'host[hostEndIdx]' access below. + host[hostEndIdx] != ']' || // We found an invalid host character + hostEndIdx < 3) // [::1] is the shortest valid IPv6 host + { + return false; } - // Must contain a ']' - return false; + // If there's nothing left, we're good. If there's more, validate it as a port. + // +2 to skip the '[' and ']' (the '[' wasn't included in hostEndIdx because we + // cut it off in the AsSpan above). + return (hostEndIdx + 2 == hostText.Length) || TryValidateHostPort(hostText, hostEndIdx + 2); } [MethodImpl(MethodImplOptions.AggressiveInlining)] private static bool TryValidateHostPort(string hostText, int offset) { - if (offset == hostText.Length) - { - // No port - return true; - } - if (hostText[offset] != ':' || hostText.Length == offset + 1) { // Must have at least one number after the colon if present. return false; } - for (var i = offset + 1; i < hostText.Length; i++) - { - if (!IsNumeric(hostText[i])) - { - return false; - } - } - - return true; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsNumeric(char ch) - { - return '0' <= ch && ch <= '9'; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsHex(char ch) - { - return IsNumeric(ch) - || ('a' <= ch && ch <= 'f') - || ('A' <= ch && ch <= 'F'); + return hostText.AsSpan(offset + 1).IndexOfAnyExceptInRange('0', '9') < 0; } }