Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/CHttpServer/CHttpServer/Http2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ private enum ConnectionState : byte
private readonly Stream _inputStream;
private uint _streamIdIndex;
private readonly HPackDecoder _hpackDecoder;
private readonly Http2Stream _defaultStream;
private readonly LimitedObjectPool<Http2Stream> _streamPool;

private byte[] _buffer;
private FrameWriter? _writer;
Expand All @@ -49,8 +49,6 @@ public Http2Connection(CHttpConnectionContext connectionContext)
_context = connectionContext;
connectionContext.Features.Get<IConnectionHeartbeatFeature>()?.OnHeartbeat(OnHeartbeat, this);
_streams = [];
_defaultStream = new Http2Stream<object>(0, 0, this, _context.Features, null!);
_currentStream = _defaultStream;
_h2Settings = new();
_hpackDecoder = new(maxDynamicTableSize: 0, maxHeadersLength: connectionContext.ServerOptions.MaxRequestHeaderLength);
_buffer = ArrayPool<byte>.Shared.Rent(checked((int)_h2Settings.ReceiveMaxFrameSize));
Expand All @@ -60,6 +58,8 @@ public Http2Connection(CHttpConnectionContext connectionContext)
_readFrame = new();
_serverWindow = new(_context.ServerOptions.ServerConnectionFlowControlSize + CHttpServerOptions.InitialStreamFlowControlSize);
_clientWindow = new(_h2Settings.InitialWindowSize);
_streamPool = new LimitedObjectPool<Http2Stream>();
_currentStream = _streamPool.Get(CreateConnection, (Connection: this, _context.Features));
}

// Setter is atest hook
Expand Down Expand Up @@ -153,7 +153,7 @@ private ValueTask ProcessFrame<TContext>(IHttpApplication<TContext> application)
if (_readFrame.Type == Http2FrameType.RST_STREAM)
return ProcessResetStreamFrame();
if (_readFrame.Type == Http2FrameType.CONTINUATION)
return ProcessContinuationFrame();
return ProcessContinuationFrame(application);
return ValueTask.CompletedTask;
}

Expand Down Expand Up @@ -278,11 +278,10 @@ private async ValueTask ProcessHeaderFrame<TContext>(IHttpApplication<TContext>

if ((_currentStream.StreamId == _readFrame.StreamId && _currentStream.RequestEndHeaders)
|| _streams.ContainsKey(_readFrame.StreamId))
{
throw new Http2ProtocolException();
}

_currentStream = new Http2Stream<TContext>(_readFrame.StreamId, _h2Settings.InitialWindowSize, this, _context.Features, application);
_currentStream = _streamPool.Get(CreateConnection, (Connection: this, _context.Features));
_currentStream.Initialize(_readFrame.StreamId, _h2Settings.InitialWindowSize, _context.ServerOptions.ServerStreamFlowControlSize);
var addResult = _streams.TryAdd(_readFrame.StreamId, _currentStream);
_streamIdIndex = uint.Max(_readFrame.StreamId, _streamIdIndex);
Debug.Assert(addResult);
Expand All @@ -292,14 +291,15 @@ private async ValueTask ProcessHeaderFrame<TContext>(IHttpApplication<TContext>
if (endHeaders)
{
_currentStream.RequestEndHeadersReceived();
StartStream();
StartStream(_currentStream, application);
}
}

// +---------------------------------------------------------------+
// | Header Block Fragment(*) ...
// +---------------------------------------------------------------+
private async ValueTask ProcessContinuationFrame()
private async ValueTask ProcessContinuationFrame<TContext>(
IHttpApplication<TContext> application) where TContext : notnull
{
if (_currentStream.StreamId != _readFrame.StreamId || _currentStream.RequestEndHeaders)
throw new Http2ProtocolException();
Expand All @@ -313,7 +313,7 @@ private async ValueTask ProcessContinuationFrame()
if (endHeaders)
{
_currentStream.RequestEndHeadersReceived();
StartStream();
StartStream(_currentStream, application);
}
}

Expand All @@ -334,12 +334,15 @@ private async ValueTask ProcessResetStreamFrame()
httpStream.Abort();
OnStreamCompleted(httpStream);
if (_currentStream.StreamId == streamId)
_currentStream = _defaultStream;
_currentStream = _streamPool.Get(CreateConnection, (Connection: this, _context.Features));
}

private void StartStream()
private static void StartStream<TContext>(Http2Stream stream, IHttpApplication<TContext> application) where TContext : notnull
{
ThreadPool.UnsafeQueueUserWorkItem(_currentStream, preferLocal: false);
ThreadPool.UnsafeQueueUserWorkItem(
state => state.Stream.Execute(state.App),
(App: application, Stream: stream),
preferLocal: false);
}

private async ValueTask ProcessWindowUpdateFrame()
Expand Down Expand Up @@ -459,6 +462,11 @@ private void ValidateTlsRequirements()
internal void OnStreamCompleted(Http2Stream stream)
{
_streams.TryRemove(stream.StreamId, out _);
if (!stream.IsAborted)
{
stream.Reset();
_streamPool.Return(stream);
}
if (_gracefulShutdownRequested)
TryGracefulShutdown();
}
Expand All @@ -477,4 +485,6 @@ internal bool ReserveClientFlowControlSize(uint requestedSize, out uint reserved
{
return _clientWindow.TryUseAny(requestedSize, out reservedSize);
}

private static Http2Stream CreateConnection((Http2Connection Connection, FeatureCollection Features) state) => new Http2Stream(state.Connection, state.Features);
}
144 changes: 97 additions & 47 deletions src/CHttpServer/CHttpServer/Http2Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,7 @@

namespace CHttpServer;

internal class Http2Stream<TContext> : Http2Stream where TContext : notnull
{
private readonly IHttpApplication<TContext> _application;
private readonly FeatureCollection _featureCollection;

public Http2Stream(uint streamId, uint initialWindowSize, Http2Connection connection, FeatureCollection features, IHttpApplication<TContext> application)
: base(streamId, initialWindowSize, connection)
{
_application = application;
_featureCollection = features.Copy();
_featureCollection.Add<IHttpRequestFeature>(this);
_featureCollection.Add<IHttpResponseFeature>(this);
_featureCollection.Add<IHttpResponseBodyFeature>(this);
_featureCollection.Add<IHttpResponseTrailersFeature>(this);
_featureCollection.Add<IHttpRequestBodyDetectionFeature>(this);
_featureCollection.Add<IHttpRequestLifetimeFeature>(this);
}

protected override Task RunApplicationAsync()
{
var context = _application.CreateContext(_featureCollection);
return _application.ProcessRequestAsync(context);
}
}

internal abstract partial class Http2Stream : IThreadPoolWorkItem
internal partial class Http2Stream
{
private enum StreamState : byte
{
Expand All @@ -45,39 +20,95 @@ private enum StreamState : byte

private readonly Http2Connection _connection;
private readonly Http2ResponseWriter _writer;
private readonly FeatureCollection _featureCollection;
private FlowControlSize _serverWindowSize; // Controls Data received
private FlowControlSize _clientWindowSize; // Controls Data sent
private StreamState _state;
private CancellationTokenSource _cts;

public Http2Stream(uint streamId, uint initialWindowSize, Http2Connection connection)
public Http2Stream(Http2Connection connection, FeatureCollection featureCollection)
{
StreamId = streamId;
_clientWindowSize = new(initialWindowSize);
_serverWindowSize = new(connection.ServerOptions.ServerStreamFlowControlSize);
_connection = connection;
_writer = connection.ResponseWriter!;
_state = StreamState.Open;
_state = StreamState.Closed;
RequestEndHeaders = false;
_requestHeaders = new HeaderCollection();
_requestBodyPipe = new(new PipeOptions(MemoryPool<byte>.Shared));
_requestBodyPipeReader = new(_requestBodyPipe.Reader, ReleaseServerFlowControl);
_requestBodyPipeWriter = new(_requestBodyPipe.Writer, flushStartingCallback: ConsumeServerFlowControl, flushedCallback: null);

_responseHeaders = null;
_responseBodyPipe = new(new PipeOptions(MemoryPool<byte>.Shared));
_responseBodyPipeWriter = new(_responseBodyPipe.Writer, flushStartingCallback: size =>
{
if (!_hasStarted)
_ = StartAsync();
});
_cts = new();
StatusCode = 200;
_responseWriterFlushedResponse = new(0);
_clientFlowControlBarrier = new(1, 1);
_responseWritingTask = new TaskCompletionSource<Task>();

_featureCollection = featureCollection.Copy();
_featureCollection.Add<IHttpRequestFeature>(this);
_featureCollection.Add<IHttpResponseFeature>(this);
_featureCollection.Add<IHttpResponseBodyFeature>(this);
_featureCollection.Add<IHttpResponseTrailersFeature>(this);
_featureCollection.Add<IHttpRequestBodyDetectionFeature>(this);
_featureCollection.Add<IHttpRequestLifetimeFeature>(this);
}

public uint StreamId { get; }
public void Initialize(uint streamId, uint initialWindowSize, uint serverStreamFlowControlSize)
{
if (_state != StreamState.Closed)
throw new InvalidOperationException("Stream is in use.");
_state = StreamState.Open;
StreamId = streamId;
_clientWindowSize = new(initialWindowSize);
_serverWindowSize = new(serverStreamFlowControlSize);

public bool RequestEndHeaders { get; private set; }
// HasStarted is reset at initialization to avoid race conditions Complete() and StartAsync()
_hasStarted = false;
}

public void Reset()
{
if (_state != StreamState.Closed)
throw new InvalidOperationException("Stream is in use.");
StreamId = 0;
RequestEndHeaders = false;
_requestHeaders = new();
_requestBodyPipe.Reset();
_requestBodyPipeReader.Reset();
_requestBodyPipeWriter.Reset();

_responseBodyPipe.Reset();
_responseBodyPipeWriter.Reset();

protected abstract Task RunApplicationAsync();
_cts = new();
_responseHeaders = null;
_responseTrailers = null;
StatusCode = 200;
ReasonPhrase = null;
Scheme = string.Empty;
Method = string.Empty;
PathBase = string.Empty;
Path = string.Empty;
QueryString = string.Empty;
_onStartingCallback = null;
_onStartingState = null;
_onCompletedCallback = null;
_onCompletedState = null;
_responseWritingTask = new TaskCompletionSource<Task>();

_clientFlowControlBarrier = new(1, 1);
_responseWriterFlushedResponse = new(0);
}

public uint StreamId { get; private set; }

public bool RequestEndHeaders { get; private set; }

internal void RequestEndHeadersReceived() => RequestEndHeaders = true;

Expand Down Expand Up @@ -135,10 +166,13 @@ private void SetPath(ReadOnlySpan<byte> value)
QueryString = Encoding.Latin1.GetString(value[separatorIndex..]);
}

public async void Execute()
public async void Execute<TContext>(IHttpApplication<TContext> application) where TContext : notnull
{
_requestHeaders.SetReadOnly();
await RunApplicationAsync();
var context = application.CreateContext(_featureCollection.Copy());
await application.ProcessRequestAsync(context);
_requestBodyPipeReader.Complete();
_requestBodyPipeWriter.Complete();
_responseBodyPipeWriter.Complete();
await CompleteAsync();
}
Expand All @@ -147,6 +181,7 @@ public void Abort()
{
_cts.Cancel();
_state = StreamState.Closed;
IsAborted = true;
}

public void CompleteRequestStream()
Expand Down Expand Up @@ -205,6 +240,8 @@ internal partial class Http2Stream : IHttpRequestFeature, IHttpRequestBodyDetect

public bool CanHaveBody => true;

public bool IsAborted { get; private set; }

public PipeWriter RequestPipe => _state <= StreamState.HalfOpenLocal ?
_requestBodyPipeWriter : throw new Http2ConnectionException("STREAM CLOSED");

Expand All @@ -217,20 +254,20 @@ internal partial class Http2Stream : IHttpResponseFeature, IHttpResponseBodyFeat
{
private readonly Pipe _responseBodyPipe;
private readonly Http2StreamPipeWriter _responseBodyPipeWriter;
private readonly SemaphoreSlim _applicationFlushedResponse = new(0);
private readonly SemaphoreSlim _clientFlowControlBarrier = new(1, 1);
private SemaphoreSlim _clientFlowControlBarrier;
private SemaphoreSlim _responseWriterFlushedResponse;

private bool _hasStarted = false;
private Task? _responseWritingTask;
private TaskCompletionSource<Task> _responseWritingTask;
private HeaderCollection? _responseHeaders;
private HeaderCollection? _responseTrailers;

public int StatusCode { get; set; } = 200;
public int StatusCode { get; set; }
public string? ReasonPhrase { get; set; }

public bool HasStarted => _hasStarted;

public Stream Stream => throw new NotSupportedException($"Write with the {nameof(IHttpResponseBodyFeature.Writer)}");
public Stream Stream => _responseBodyPipeWriter.AsStream();

public PipeWriter Writer => _responseBodyPipeWriter;

Expand Down Expand Up @@ -306,16 +343,17 @@ public async Task StartAsync(CancellationToken cancellationToken = default)
_responseHeaders ??= new();
_responseHeaders.SetReadOnly();
cancellationToken.Register(() => _cts.Cancel());
_responseWritingTask = WriteResponseAsync(_cts.Token);
_responseWritingTask.SetResult(WriteResponseAsync(_cts.Token));
}

public async Task CompleteAsync()
{
var task = _responseWritingTask;
if (task == null)
var task = _responseWritingTask.Task;
if (!task.IsCompleted)
await StartAsync();

await _responseWritingTask!;
var responseWriting = await task.WaitAsync(_cts.Token);
await responseWriting;
}

private async Task WriteResponseAsync(CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -358,7 +396,7 @@ private async Task WriteBodyResponseAsync(CancellationToken token = default)

ResponseBodyBuffer = buffer.Slice(0, size);
_writer.ScheduleWriteData(this);
await _applicationFlushedResponse.WaitAsync(token);
await _responseWriterFlushedResponse.WaitAsync(token);
buffer = buffer.Slice(size);
}
_responseBodyPipe.Reader.AdvanceTo(readResult.Buffer.End);
Expand All @@ -376,7 +414,7 @@ private async Task WriteBodyResponseAsync(CancellationToken token = default)
public void OnResponseDataFlushed()
{
// Release semaphore for the next write.
_applicationFlushedResponse.Release(1);
_responseWriterFlushedResponse.Release(1);
}

private bool ReserveClientFlowControlSize(uint requestedSize, out uint reservedSize)
Expand Down Expand Up @@ -405,6 +443,7 @@ internal async Task OnStreamCompletedAsync()
{
await (_onCompletedCallback?.Invoke(_onCompletedState!) ?? Task.CompletedTask);
_state = StreamState.Closed;
_responseBodyPipe.Reader.Complete();
_connection.OnStreamCompleted(this);
}
}
Expand Down Expand Up @@ -453,6 +492,12 @@ public override async ValueTask<FlushResult> FlushAsync(CancellationToken cancel
public override Memory<byte> GetMemory(int sizeHint = 0) => _writer.GetMemory(sizeHint);

public override Span<byte> GetSpan(int sizeHint = 0) => _writer.GetSpan(sizeHint);

public void Reset()
{
_unflushedBytes = 0;
_completed = false;
}
}

internal class Http2StreamPipeReader(PipeReader reader, Action<int> onReadCallback) : PipeReader
Expand Down Expand Up @@ -499,4 +544,9 @@ public override bool TryRead(out ReadResult result)
return hasRead;

}

public void Reset()
{
_lastReadStart = default;
}
}
Loading
Loading