Skip to content

Commit

Permalink
Fix Kestrel overpooling of HTTP/2 and HTTP/3 request headers (#40087)
Browse files Browse the repository at this point in the history
* Fix Kestrel overpooling of HTTP/2 and HTTP/3 request headers

* Fix build
  • Loading branch information
JamesNK committed Feb 9, 2022
1 parent 6a6be13 commit e371b5d
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,8 @@ private Task DecodeHeadersAsync(bool endHeaders, in ReadOnlySequence<byte> paylo

if (endHeaders)
{
_currentHeadersStream.OnHeadersComplete();

StartStream();
ResetRequestHeaderParsingState();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3;

internal class Http3Connection : IHttp3StreamLifetimeHandler, IRequestProcessor
{
private static readonly object StreamPersistentStateKey = new object();
internal static readonly object StreamPersistentStateKey = new object();

// Internal for unit testing
internal readonly Dictionary<long, IHttp3Stream> _streams = new Dictionary<long, IHttp3Stream>();
Expand Down
2 changes: 2 additions & 0 deletions src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,8 @@ private static Task ProcessUnknownFrameAsync()

InputRemaining = HttpRequestHeaders.ContentLength;

OnHeadersComplete();

// If the stream is complete after receiving the headers then run OnEndStreamReceived.
// If there is a bad content length then this will throw before the request delegate is called.
if (isCompleted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
using System.Net.Http;
using System.Net.Http.HPack;
using System.Net.Security;
using System.Reflection;
using System.Security.Authentication;
using System.Text;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Testing;
Expand Down Expand Up @@ -207,6 +209,73 @@ public async Task RequestHeaderStringReuse_MultipleStreams_KnownHeaderReused()
await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
}

[Fact]
public async Task RequestHeaderStringReuse_MultipleStreams_KnownHeaderClearedIfNotReused()
{
const BindingFlags privateFlags = BindingFlags.NonPublic | BindingFlags.Instance;

IEnumerable<KeyValuePair<string, string>> requestHeaders1 = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
new KeyValuePair<string, string>(HeaderNames.ContentType, "application/json")
};

// Note: No content-type
IEnumerable<KeyValuePair<string, string>> requestHeaders2 = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80")
};

await InitializeConnectionAsync(_noopApplication);

await StartStreamAsync(1, requestHeaders1, endStream: true);

await ExpectAsync(Http2FrameType.HEADERS,
withLength: 36,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);

// TriggerTick will trigger the stream to be returned to the pool so we can assert it
TriggerTick();

// Stream has been returned to the pool
Assert.Equal(1, _connection.StreamPool.Count);
Assert.True(_connection.StreamPool.TryPeek(out var stream1));

// Hacky but required because header references is private.
var headerReferences1 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(stream1.RequestHeaders);
var contentTypeValue1 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences1);

await StartStreamAsync(3, requestHeaders2, endStream: true);

await ExpectAsync(Http2FrameType.HEADERS,
withLength: 6,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 3);

// TriggerTick will trigger the stream to be returned to the pool so we can assert it
TriggerTick();

// Stream has been returned to the pool
Assert.Equal(1, _connection.StreamPool.Count);
Assert.True(_connection.StreamPool.TryPeek(out var stream2));

// Hacky but required because header references is private.
var headerReferences2 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(stream2.RequestHeaders);
var contentTypeValue2 = (StringValues)headerReferences2.GetType().GetField("_ContentType").GetValue(headerReferences2);

Assert.Equal("application/json", contentTypeValue1);
Assert.Equal(StringValues.Empty, contentTypeValue2);

await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
}

private class ResponseTrailersWrapper : IHeaderDictionary
{
readonly IHeaderDictionary _innerHeaders;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
using System.Collections.Generic;
using System.Globalization;
using System.Net.Http;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -372,6 +375,49 @@ public async Task StreamPool_MultipleStreamsInSequence_KnownHeaderReused()
Assert.Same(authority1, authority2);
}

[Fact]
public async Task RequestHeaderStringReuse_MultipleStreams_KnownHeaderClearedIfNotReused()
{
const BindingFlags privateFlags = BindingFlags.NonPublic | BindingFlags.Instance;

KeyValuePair<string, string>[] requestHeaders1 = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
new KeyValuePair<string, string>(HeaderNames.ContentType, "application/json")
};

// Note: No content-type
KeyValuePair<string, string>[] requestHeaders2 = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80")
};

await Http3Api.InitializeConnectionAsync(_echoApplication);

var streamContext1 = await MakeRequestAsync(0, requestHeaders1, sendData: true, waitForServerDispose: true);
var http3Stream1 = (Http3Stream)streamContext1.Features.Get<IPersistentStateFeature>().State[Http3Connection.StreamPersistentStateKey];

// Hacky but required because header references is private.
var headerReferences1 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(http3Stream1.RequestHeaders);
var contentTypeValue1 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences1);

var streamContext2 = await MakeRequestAsync(1, requestHeaders2, sendData: true, waitForServerDispose: true);
var http3Stream2 = (Http3Stream)streamContext2.Features.Get<IPersistentStateFeature>().State[Http3Connection.StreamPersistentStateKey];

// Hacky but required because header references is private.
var headerReferences2 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(http3Stream2.RequestHeaders);
var contentTypeValue2 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences2);

Assert.Equal("application/json", contentTypeValue1);
Assert.Equal(StringValues.Empty, contentTypeValue2);
}

[Theory]
[InlineData(10)]
[InlineData(100)]
Expand Down

0 comments on commit e371b5d

Please sign in to comment.