Skip to content

Commit

Permalink
Change API of Http3RequestStreamInboundHandler to better handle FIN (#…
Browse files Browse the repository at this point in the history
…240)

Motivation:

5740c52 introduced a workaround for
correctly handle FIN when multiple http3 frames are contained in one
QUIC frame. While this worked it was kind of hacky and we should better
adjust the API.

Modifications:

- Change Http3RequestStreamInboundHandler API for handling frames and
FIN
- Adjust tests.

Result:

Better fix for handling multiple frames
  • Loading branch information
normanmaurer committed Aug 31, 2023
1 parent 815a8cc commit 01a15ab
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 207 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
Expand Down Expand Up @@ -60,6 +59,7 @@ public final class Http3FrameToHttpObjectCodec extends Http3RequestStreamInbound

private final boolean isServer;
private final boolean validateHeaders;
private boolean inboundTranslationInProgress;

public Http3FrameToHttpObjectCodec(final boolean isServer,
final boolean validateHeaders) {
Expand All @@ -72,7 +72,12 @@ public Http3FrameToHttpObjectCodec(final boolean isServer) {
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, boolean isLast) throws Exception {
public boolean isSharable() {
return false;
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame) throws Exception {
Http3Headers headers = frame.headers();
long id = ((QuicStreamChannel) ctx.channel()).streamId();

Expand All @@ -86,31 +91,33 @@ protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, b
return;
}

if (isLast) {
if (headers.method() == null && status == null) {
LastHttpContent last = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders);
HttpConversionUtil.addHttp3ToHttpHeaders(id, headers, last.trailingHeaders(),
HttpVersion.HTTP_1_1, true, true);
ctx.fireChannelRead(last);
} else {
FullHttpMessage full = newFullMessage(id, headers, ctx.alloc());
ctx.fireChannelRead(full);
}
if (headers.method() == null && status == null) {
// Must be trailers!
LastHttpContent last = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders);
HttpConversionUtil.addHttp3ToHttpHeaders(id, headers, last.trailingHeaders(),
HttpVersion.HTTP_1_1, true, true);
inboundTranslationInProgress = false;
ctx.fireChannelRead(last);
} else {
HttpMessage req = newMessage(id, headers);
if (!HttpUtil.isContentLengthSet(req)) {
req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
}
inboundTranslationInProgress = true;
ctx.fireChannelRead(req);
}
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame, boolean isLast) throws Exception {
if (isLast) {
ctx.fireChannelRead(new DefaultLastHttpContent(frame.content()));
} else {
ctx.fireChannelRead(new DefaultHttpContent(frame.content()));
protected void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame) throws Exception {
inboundTranslationInProgress = true;
ctx.fireChannelRead(new DefaultHttpContent(frame.content()));
}

@Override
protected void channelInputClosed(ChannelHandlerContext ctx) throws Exception {
if (inboundTranslationInProgress) {
ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
*/
package io.netty.incubator.codec.http3;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.incubator.codec.quic.QuicException;
import io.netty.incubator.codec.quic.QuicStreamChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand All @@ -32,72 +30,24 @@
public abstract class Http3RequestStreamInboundHandler extends ChannelInboundHandlerAdapter {
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(Http3RequestStreamInboundHandler.class);
private static final Http3DataFrame EMPTY = new DefaultHttp3DataFrame(Unpooled.EMPTY_BUFFER);
private Object bufferedMessage;
private boolean lastFrameDetected;
private boolean firstFrameReceived;

/**
* Always returns {@code true} as this handler and sub-types are not sharable, due internal state.
*/
@Override
public final boolean isSharable() {
return false;
}

@Override
public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
firstFrameReceived = true;
handleBufferedMessage(ctx, false);
bufferedMessage = msg;
}

@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
// Once we receive a channelReadComplete we know that we handled all the data that was contained in a
// stream frame. At this point we should check if the input was closed and so if this received frame
// will be the last on the stream.
handleBufferedMessage(ctx, ((QuicStreamChannel) ctx.channel()).isInputShutdown());

super.channelReadComplete(ctx);
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (bufferedMessage != null) {
ReferenceCountUtil.release(bufferedMessage);
bufferedMessage = null;
}
}

private void handleBufferedMessage(ChannelHandlerContext ctx, boolean noMoreInput) throws Exception {
Object msg = bufferedMessage;
if (msg == null) {
return;
}
bufferedMessage = null;
if (msg instanceof Http3UnknownFrame) {
channelRead(ctx, (Http3UnknownFrame) msg);
if (noMoreInput) {
notifyLast(ctx);
}
} else if (msg instanceof Http3HeadersFrame) {
channelRead(ctx, (Http3HeadersFrame) msg);
} else if (msg instanceof Http3DataFrame) {
channelRead(ctx, (Http3DataFrame) msg);
} else {
if (noMoreInput) {
lastFrameDetected = true;
}
if (msg instanceof Http3HeadersFrame) {
channelRead(ctx, (Http3HeadersFrame) msg, noMoreInput);
}
if (msg instanceof Http3DataFrame) {
channelRead(ctx, (Http3DataFrame) msg, noMoreInput);
}
super.channelRead(ctx, msg);
}
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt == ChannelInputShutdownEvent.INSTANCE) {
notifyLast(ctx);
channelInputClosed(ctx);
}
ctx.fireUserEventTriggered(evt);
}
Expand All @@ -113,34 +63,31 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}

private void notifyLast(ChannelHandlerContext ctx) throws Exception {
if (!lastFrameDetected && firstFrameReceived) {
lastFrameDetected = true;
channelRead(ctx, EMPTY, true);
}
}

/**
* Called once a {@link Http3HeadersFrame} is ready for this stream to process.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @param frame the {@link Http3HeadersFrame} that was read
* @param isLast {@code true} if this is the last frame that will be read for this stream.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, boolean isLast)
throws Exception;
protected abstract void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame) throws Exception;

/**
* Called once a {@link Http3DataFrame} is ready for this stream to process.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @param frame the {@link Http3DataFrame} that was read
* @param isLast {@code true} if this is the last frame that will be read for this stream.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame, boolean isLast)
throws Exception;
protected abstract void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame) throws Exception;

/**
* Called once the input is closed and so no more inbound data is received on it.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelInputClosed(ChannelHandlerContext ctx) throws Exception;

/**
* Called once a {@link Http3UnknownFrame} is ready for this stream to process. By default these frames are just
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,31 +306,6 @@ public void testDowngradeHeadersWithContentLength() {
assertFalse(ch.finish());
}

@Test
public void testDowngradeFullHeaders() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(true));
Http3Headers headers = new DefaultHttp3Headers();
headers.path("/");
headers.method("GET");

assertTrue(ch.writeInboundWithFin(new DefaultHttp3HeadersFrame(headers)));

FullHttpRequest request = ch.readInbound();
try {
assertThat(request.uri(), is("/"));
assertThat(request.method(), is(HttpMethod.GET));
assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1));
assertThat(request.content().readableBytes(), is(0));
assertTrue(request.trailingHeaders().isEmpty());
assertFalse(HttpUtil.isTransferEncodingChunked(request));
} finally {
request.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}

@Test
public void testDowngradeTrailers() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(true));
Expand Down Expand Up @@ -376,14 +351,21 @@ public void testDowngradeEndData() {
ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8);
assertTrue(ch.writeInboundWithFin(new DefaultHttp3DataFrame(hello)));

LastHttpContent content = ch.readInbound();
HttpContent content = ch.readInbound();
try {
assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world"));
assertTrue(content.trailingHeaders().isEmpty());
} finally {
content.release();
}

LastHttpContent last = ch.readInbound();
try {
assertFalse(last.content().isReadable());
assertTrue(last.trailingHeaders().isEmpty());
} finally {
last.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}
Expand Down Expand Up @@ -882,34 +864,6 @@ public void testDecodeResponseHeadersWithContentLength() {
assertFalse(ch.finish());
}

@Test
public void testDecodeFullResponseHeaders() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(false));
Http3Headers headers = new DefaultHttp3Headers();
headers.scheme(HttpScheme.HTTP.name());
headers.status(HttpResponseStatus.OK.codeAsText());

Http3HeadersFrame frame = new DefaultHttp3HeadersFrame(headers);

assertTrue(ch.writeInboundWithFin(frame));

FullHttpResponse response = ch.readInbound();
try {
assertThat(response.status(), is(HttpResponseStatus.OK));
assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1));
assertThat(response.content().readableBytes(), is(0));
assertTrue(response.trailingHeaders().isEmpty());
assertFalse(HttpUtil.isTransferEncodingChunked(response));
assertEquals(0,
(int) response.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text()));
} finally {
response.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}

@Test
public void testDecodeResponseTrailersAsClient() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(false));
Expand Down Expand Up @@ -954,14 +908,21 @@ public void testDecodeEndDataAsClient() {
ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8);
assertTrue(ch.writeInboundWithFin(new DefaultHttp3DataFrame(hello)));

LastHttpContent content = ch.readInbound();
HttpContent content = ch.readInbound();
try {
assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world"));
assertTrue(content.trailingHeaders().isEmpty());
} finally {
content.release();
}

LastHttpContent last = ch.readInbound();
try {
assertFalse(last.content().isReadable());
assertTrue(last.trailingHeaders().isEmpty());
} finally {
last.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}
Expand Down Expand Up @@ -1074,10 +1035,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

HttpResponse respHeaders = (HttpResponse) received.poll(20, TimeUnit.SECONDS);
assertThat(respHeaders.status(), is(HttpResponseStatus.OK));
assertThat(respHeaders, not(instanceOf(LastHttpContent.class))); // this assertion failed before this PR
LastHttpContent respBody = (LastHttpContent) received.poll(20, TimeUnit.SECONDS);
assertThat(respHeaders, not(instanceOf(LastHttpContent.class)));
HttpContent respBody = (HttpContent) received.poll(20, TimeUnit.SECONDS);
assertThat(respBody.content().toString(CharsetUtil.UTF_8), is("foo"));
respBody.release();

LastHttpContent last = (LastHttpContent) received.poll(20, TimeUnit.SECONDS);
last.release();
} finally {
group.shutdownGracefully();
}
Expand Down

0 comments on commit 01a15ab

Please sign in to comment.