Skip to content

Commit

Permalink
Kestrel response header encoding (#33776)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher committed Jul 6, 2021
1 parent e65cec1 commit 051aa95
Show file tree
Hide file tree
Showing 44 changed files with 1,632 additions and 350 deletions.
398 changes: 206 additions & 192 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.cs
Expand Up @@ -9,6 +9,7 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -260,21 +261,25 @@ IEnumerator IEnumerable.GetEnumerator()
return TryGetValueFast(key, out value);
}

public static void ValidateHeaderValueCharacters(StringValues headerValues)
public static void ValidateHeaderValueCharacters(string headerName, StringValues headerValues, Func<string, Encoding?> encodingSelector)
{
var requireAscii = ReferenceEquals(encodingSelector, KestrelServerOptions.DefaultHeaderEncodingSelector)
|| encodingSelector(headerName) == null;

var count = headerValues.Count;
for (var i = 0; i < count; i++)

{
ValidateHeaderValueCharacters(headerValues[i]);
ValidateHeaderValueCharacters(headerValues[i], requireAscii);
}
}

public static void ValidateHeaderValueCharacters(string headerCharacters)
public static void ValidateHeaderValueCharacters(string headerCharacters, bool requireAscii)
{
if (headerCharacters != null)
{
var invalid = HttpCharacters.IndexOfInvalidFieldValueChar(headerCharacters);
var invalid = requireAscii ? HttpCharacters.IndexOfInvalidFieldValueChar(headerCharacters)
: HttpCharacters.IndexOfInvalidFieldValueCharExtended(headerCharacters);
if (invalid >= 0)
{
ThrowInvalidHeaderCharacter(headerCharacters[invalid]);
Expand Down
1 change: 1 addition & 0 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
Expand Up @@ -374,6 +374,7 @@ public void Reset()
HttpRequestHeaders.EncodingSelector = ServerOptions.RequestHeaderEncodingSelector;
HttpRequestHeaders.ReuseHeaderValues = !ServerOptions.DisableStringReuse;
HttpResponseHeaders.Reset();
HttpResponseHeaders.EncodingSelector = ServerOptions.ResponseHeaderEncodingSelector;
RequestHeaders = HttpRequestHeaders;
ResponseHeaders = HttpResponseHeaders;
RequestTrailers.Clear();
Expand Down
Expand Up @@ -24,7 +24,7 @@ internal sealed partial class HttpRequestHeaders : HttpHeaders
public HttpRequestHeaders(bool reuseHeaderValues = true, Func<string, Encoding?>? encodingSelector = null)
{
ReuseHeaderValues = reuseHeaderValues;
EncodingSelector = encodingSelector ?? KestrelServerOptions.DefaultRequestHeaderEncodingSelector;
EncodingSelector = encodingSelector ?? KestrelServerOptions.DefaultHeaderEncodingSelector;
}

public void OnHeadersComplete()
Expand Down Expand Up @@ -97,7 +97,7 @@ private void AppendContentLength(ReadOnlySpan<byte> value)

[MethodImpl(MethodImplOptions.NoInlining)]
[SkipLocalsInit]
private void AppendContentLengthCustomEncoding(ReadOnlySpan<byte> value, Encoding? customEncoding)
private void AppendContentLengthCustomEncoding(ReadOnlySpan<byte> value, Encoding customEncoding)
{
if (_contentLength.HasValue)
{
Expand All @@ -106,7 +106,7 @@ private void AppendContentLengthCustomEncoding(ReadOnlySpan<byte> value, Encodin

// long.MaxValue = 9223372036854775807 (19 chars)
Span<char> decodedChars = stackalloc char[20];
var numChars = customEncoding!.GetChars(value, decodedChars);
var numChars = customEncoding.GetChars(value, decodedChars);
long parsed = -1;

if (numChars > 19 ||
Expand Down
49 changes: 46 additions & 3 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseHeaders.cs
Expand Up @@ -4,10 +4,11 @@
using System;
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.IO.Pipelines;
using System.Collections;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;

Expand All @@ -19,6 +20,13 @@ internal sealed partial class HttpResponseHeaders : HttpHeaders
private static ReadOnlySpan<byte> CrLf => new[] { (byte)'\r', (byte)'\n' };
private static ReadOnlySpan<byte> ColonSpace => new[] { (byte)':', (byte)' ' };

public Func<string, Encoding?> EncodingSelector { get; set; }

public HttpResponseHeaders(Func<string, Encoding?>? encodingSelector = null)
{
EncodingSelector = encodingSelector ?? KestrelServerOptions.DefaultHeaderEncodingSelector;
}

public Enumerator GetEnumerator()
{
return new Enumerator(this);
Expand All @@ -34,10 +42,18 @@ internal void CopyTo(ref BufferWriter<PipeWriter> buffer)
CopyToFast(ref buffer);

var extraHeaders = MaybeUnknown;
// Only reserve stack space for the enumerators if there are extra headers
if (extraHeaders != null && extraHeaders.Count > 0)
{
// Only reserve stack space for the enumartors if there are extra headers
CopyExtraHeaders(ref buffer, extraHeaders);
var encodingSelector = EncodingSelector;
if (ReferenceEquals(encodingSelector, KestrelServerOptions.DefaultHeaderEncodingSelector))
{
CopyExtraHeaders(ref buffer, extraHeaders);
}
else
{
CopyExtraHeadersCustomEncoding(ref buffer, extraHeaders, encodingSelector);
}
}

static void CopyExtraHeaders(ref BufferWriter<PipeWriter> buffer, Dictionary<string, StringValues> headers)
Expand All @@ -56,6 +72,33 @@ static void CopyExtraHeaders(ref BufferWriter<PipeWriter> buffer, Dictionary<str
}
}
}

static void CopyExtraHeadersCustomEncoding(ref BufferWriter<PipeWriter> buffer, Dictionary<string, StringValues> headers,
Func<string, Encoding?> encodingSelector)
{
foreach (var kv in headers)
{
var encoding = encodingSelector(kv.Key);
foreach (var value in kv.Value)
{
if (value != null)
{
buffer.Write(CrLf);
buffer.WriteAscii(kv.Key);
buffer.Write(ColonSpace);

if (encoding is null)
{
buffer.WriteAscii(value);
}
else
{
buffer.WriteEncoded(value, encoding);
}
}
}
}
}
}

private static long ParseContentLength(string value)
Expand Down
Expand Up @@ -5,12 +5,20 @@
using System.Collections;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.Extensions.Primitives;

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
internal partial class HttpResponseTrailers : HttpHeaders
{
public Func<string, Encoding?> EncodingSelector { get; set; }

public HttpResponseTrailers(Func<string, Encoding?>? encodingSelector = null)
{
EncodingSelector = encodingSelector ?? KestrelServerOptions.DefaultHeaderEncodingSelector;
}

public Enumerator GetEnumerator()
{
return new Enumerator(this);
Expand Down
Expand Up @@ -87,7 +87,7 @@ private static bool EncodeStatusHeader(int statusCode, DynamicHPackEncoder hpack
default:
const string name = ":status";
var value = StatusCodes.ToStatusString(statusCode);
return hpackEncoder.EncodeHeader(buffer, H2StaticTable.Status200, HeaderEncodingHint.Index, name, value, out length);
return hpackEncoder.EncodeHeader(buffer, H2StaticTable.Status200, HeaderEncodingHint.Index, name, value, valueEncoding: null, out length);
}
}

Expand All @@ -99,6 +99,9 @@ private static bool EncodeHeadersCore(DynamicHPackEncoder hpackEncoder, Http2Hea
var staticTableId = headersEnumerator.HPackStaticTableId;
var name = headersEnumerator.Current.Key;
var value = headersEnumerator.Current.Value;
var valueEncoding =
ReferenceEquals(headersEnumerator.EncodingSelector, KestrelServerOptions.DefaultHeaderEncodingSelector)
? null : headersEnumerator.EncodingSelector(name);

var hint = ResolveHeaderEncodingHint(staticTableId, name);

Expand All @@ -108,6 +111,7 @@ private static bool EncodeHeadersCore(DynamicHPackEncoder hpackEncoder, Http2Hea
hint,
name,
value,
valueEncoding,
out var headerLength))
{
// If the header wasn't written, and no headers have been written, then the header is too large.
Expand Down
18 changes: 11 additions & 7 deletions src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs
Expand Up @@ -189,11 +189,13 @@ public void WriteResponseHeaders(int streamId, int statusCode, Http2HeadersFrame
var done = HPackHeaderWriter.BeginEncodeHeaders(statusCode, _hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
}
catch (HPackEncodingException hex)
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.HPackEncodingError(_connectionId, streamId, hex);
_http2Connection.Abort(new ConnectionAbortedException(hex.Message, hex));
throw new InvalidOperationException(hex.Message, hex); // Report the error to the user if this was the first write.
_log.HPackEncodingError(_connectionId, streamId, ex);
_http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex));
throw new InvalidOperationException(ex.Message, ex); // Report the error to the user if this was the first write.
}
}
}
Expand All @@ -215,10 +217,12 @@ public ValueTask<FlushResult> WriteResponseTrailersAsync(int streamId, HttpRespo
var done = HPackHeaderWriter.BeginEncodeHeaders(_hpackEncoder, _headersEnumerator, buffer, out var payloadLength);
FinishWritingHeaders(streamId, payloadLength, done);
}
catch (HPackEncodingException hex)
// Any exception from the HPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.HPackEncodingError(_connectionId, streamId, hex);
_http2Connection.Abort(new ConnectionAbortedException(hex.Message, hex));
_log.HPackEncodingError(_connectionId, streamId, ex);
_http2Connection.Abort(new ConnectionAbortedException(ex.Message, ex));
}

return TimeFlushUnsynchronizedAsync();
Expand Down
@@ -1,9 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Net.Http.HPack;
using System.Text;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.Extensions.Primitives;

Expand All @@ -25,23 +27,23 @@ private enum HeadersType : byte
private bool _hasMultipleValues;
private KnownHeaderType _knownHeaderType;

public Func<string, Encoding?> EncodingSelector { get; set; } = KestrelServerOptions.DefaultHeaderEncodingSelector;

public int HPackStaticTableId => GetResponseHeaderStaticTableId(_knownHeaderType);
public KeyValuePair<string, string> Current { get; private set; }
object IEnumerator.Current => Current;

public Http2HeadersEnumerator()
{
}

public void Initialize(HttpResponseHeaders headers)
{
EncodingSelector = headers.EncodingSelector;
_headersEnumerator = headers.GetEnumerator();
_headersType = HeadersType.Headers;
_hasMultipleValues = false;
}

public void Initialize(HttpResponseTrailers headers)
{
EncodingSelector = headers.EncodingSelector;
_trailersEnumerator = headers.GetEnumerator();
_headersType = HeadersType.Trailers;
_hasMultipleValues = false;
Expand Down
Expand Up @@ -25,7 +25,7 @@ IHeaderDictionary IHttpResponseTrailersFeature.Trailers
{
if (ResponseTrailers == null)
{
ResponseTrailers = new HttpResponseTrailers();
ResponseTrailers = new HttpResponseTrailers(ServerOptions.ResponseHeaderEncodingSelector);
if (HasResponseCompleted)
{
ResponseTrailers.SetReadOnly();
Expand Down
32 changes: 14 additions & 18 deletions src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs
Expand Up @@ -8,6 +8,7 @@
using System.IO.Pipelines;
using System.Net.Http;
using System.Net.Http.QPack;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
Expand Down Expand Up @@ -44,7 +45,7 @@ internal class Http3FrameWriter
// Write headers to a buffer that can grow. Possible performance improvement
// by writing directly to output writer (difficult as frame length is prefixed).
private readonly ArrayBufferWriter<byte> _headerEncodingBuffer;
private IEnumerator<KeyValuePair<string, string>>? _headersEnumerator;
private Http3HeadersEnumerator _headersEnumerator = new();
private int _headersTotalSize;

private long _unflushedBytes;
Expand Down Expand Up @@ -271,7 +272,7 @@ public ValueTask<FlushResult> WriteResponseTrailersAsync(long streamId, HttpResp

try
{
_headersEnumerator = EnumerateHeaders(headers).GetEnumerator();
_headersEnumerator.Initialize(headers);
_headersTotalSize = 0;
_headerEncodingBuffer.Clear();

Expand All @@ -280,9 +281,12 @@ public ValueTask<FlushResult> WriteResponseTrailersAsync(long streamId, HttpResp
var done = QPackHeaderWriter.BeginEncode(_headersEnumerator, buffer, ref _headersTotalSize, out var payloadLength);
FinishWritingHeaders(payloadLength, done);
}
catch (QPackEncodingException ex)
// Any exception from the QPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.QPackEncodingError(_connectionId, streamId, ex);
_connectionContext.Abort(new ConnectionAbortedException(ex.Message, ex));
_http3Stream.Abort(new ConnectionAbortedException(ex.Message, ex), Http3ErrorCode.InternalError);
}

Expand Down Expand Up @@ -314,7 +318,7 @@ public ValueTask<FlushResult> FlushAsync(IHttpOutputAborter? outputAborter, Canc
}
}

internal void WriteResponseHeaders(int statusCode, IHeaderDictionary headers)
internal void WriteResponseHeaders(int statusCode, HttpResponseHeaders headers)
{
lock (_writeLock)
{
Expand All @@ -325,15 +329,19 @@ internal void WriteResponseHeaders(int statusCode, IHeaderDictionary headers)

try
{
_headersEnumerator = EnumerateHeaders(headers).GetEnumerator();
_headersEnumerator.Initialize(headers);

_outgoingFrame.PrepareHeaders();
var buffer = _headerEncodingBuffer.GetSpan(HeaderBufferSize);
var done = QPackHeaderWriter.BeginEncode(statusCode, _headersEnumerator, buffer, ref _headersTotalSize, out var payloadLength);
FinishWritingHeaders(payloadLength, done);
}
catch (QPackEncodingException ex)
// Any exception from the QPack encoder can leave the dynamic table in a corrupt state.
// Since we allow custom header encoders we don't know what type of exceptions to expect.
catch (Exception ex)
{
_log.QPackEncodingError(_connectionId, _http3Stream.StreamId, ex);
_connectionContext.Abort(new ConnectionAbortedException(ex.Message, ex));
_http3Stream.Abort(new ConnectionAbortedException(ex.Message, ex), Http3ErrorCode.InternalError);
throw new InvalidOperationException(ex.Message, ex); // Report the error to the user if this was the first write.
}
Expand All @@ -347,7 +355,6 @@ private void FinishWritingHeaders(int payloadLength, bool done)
while (!done)
{
ValidateHeadersTotalSize();

var buffer = _headerEncodingBuffer.GetSpan(HeaderBufferSize);
done = QPackHeaderWriter.Encode(_headersEnumerator!, buffer, ref _headersTotalSize, out payloadLength);
_headerEncodingBuffer.Advance(payloadLength);
Expand Down Expand Up @@ -404,16 +411,5 @@ public void Abort(ConnectionAbortedException error)
_outputWriter.Complete();
}
}

private static IEnumerable<KeyValuePair<string, string>> EnumerateHeaders(IHeaderDictionary headers)
{
foreach (var header in headers)
{
foreach (var value in header.Value)
{
yield return new KeyValuePair<string, string>(header.Key, value);
}
}
}
}
}

0 comments on commit 051aa95

Please sign in to comment.