Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change API of Http3RequestStreamInboundHandler to better handle FIN #240

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
Expand Down Expand Up @@ -60,6 +59,7 @@ public final class Http3FrameToHttpObjectCodec extends Http3RequestStreamInbound

private final boolean isServer;
private final boolean validateHeaders;
private boolean inboundTranslationInProgress;

public Http3FrameToHttpObjectCodec(final boolean isServer,
final boolean validateHeaders) {
Expand All @@ -72,7 +72,12 @@ public Http3FrameToHttpObjectCodec(final boolean isServer) {
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, boolean isLast) throws Exception {
public boolean isSharable() {
return false;
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame) throws Exception {
Http3Headers headers = frame.headers();
long id = ((QuicStreamChannel) ctx.channel()).streamId();

Expand All @@ -86,31 +91,33 @@ protected void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, b
return;
}

if (isLast) {
if (headers.method() == null && status == null) {
LastHttpContent last = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders);
HttpConversionUtil.addHttp3ToHttpHeaders(id, headers, last.trailingHeaders(),
HttpVersion.HTTP_1_1, true, true);
ctx.fireChannelRead(last);
} else {
FullHttpMessage full = newFullMessage(id, headers, ctx.alloc());
ctx.fireChannelRead(full);
}
if (headers.method() == null && status == null) {
// Must be trailers!
LastHttpContent last = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders);
HttpConversionUtil.addHttp3ToHttpHeaders(id, headers, last.trailingHeaders(),
HttpVersion.HTTP_1_1, true, true);
inboundTranslationInProgress = false;
ctx.fireChannelRead(last);
} else {
HttpMessage req = newMessage(id, headers);
if (!HttpUtil.isContentLengthSet(req)) {
req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
}
inboundTranslationInProgress = true;
ctx.fireChannelRead(req);
}
}

@Override
protected void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame, boolean isLast) throws Exception {
if (isLast) {
ctx.fireChannelRead(new DefaultLastHttpContent(frame.content()));
} else {
ctx.fireChannelRead(new DefaultHttpContent(frame.content()));
protected void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame) throws Exception {
inboundTranslationInProgress = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be hit without a Http3HeadersFrame that already set this flag?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should never happen as a header frame must be sent first (which is validated by the frame decoder etc)

ctx.fireChannelRead(new DefaultHttpContent(frame.content()));
}

@Override
protected void channelInputClosed(ChannelHandlerContext ctx) throws Exception {
if (inboundTranslationInProgress) {
ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
*/
package io.netty.incubator.codec.http3;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.incubator.codec.quic.QuicException;
import io.netty.incubator.codec.quic.QuicStreamChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand All @@ -32,72 +30,24 @@
public abstract class Http3RequestStreamInboundHandler extends ChannelInboundHandlerAdapter {
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(Http3RequestStreamInboundHandler.class);
private static final Http3DataFrame EMPTY = new DefaultHttp3DataFrame(Unpooled.EMPTY_BUFFER);
private Object bufferedMessage;
private boolean lastFrameDetected;
private boolean firstFrameReceived;

/**
* Always returns {@code true} as this handler and sub-types are not sharable, due internal state.
*/
@Override
public final boolean isSharable() {
return false;
}

@Override
public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
firstFrameReceived = true;
handleBufferedMessage(ctx, false);
bufferedMessage = msg;
}

@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
// Once we receive a channelReadComplete we know that we handled all the data that was contained in a
// stream frame. At this point we should check if the input was closed and so if this received frame
// will be the last on the stream.
handleBufferedMessage(ctx, ((QuicStreamChannel) ctx.channel()).isInputShutdown());

super.channelReadComplete(ctx);
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (bufferedMessage != null) {
ReferenceCountUtil.release(bufferedMessage);
bufferedMessage = null;
}
}

private void handleBufferedMessage(ChannelHandlerContext ctx, boolean noMoreInput) throws Exception {
Object msg = bufferedMessage;
if (msg == null) {
return;
}
bufferedMessage = null;
if (msg instanceof Http3UnknownFrame) {
channelRead(ctx, (Http3UnknownFrame) msg);
if (noMoreInput) {
notifyLast(ctx);
}
} else if (msg instanceof Http3HeadersFrame) {
channelRead(ctx, (Http3HeadersFrame) msg);
} else if (msg instanceof Http3DataFrame) {
channelRead(ctx, (Http3DataFrame) msg);
} else {
if (noMoreInput) {
lastFrameDetected = true;
}
if (msg instanceof Http3HeadersFrame) {
channelRead(ctx, (Http3HeadersFrame) msg, noMoreInput);
}
if (msg instanceof Http3DataFrame) {
channelRead(ctx, (Http3DataFrame) msg, noMoreInput);
}
super.channelRead(ctx, msg);
}
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt == ChannelInputShutdownEvent.INSTANCE) {
notifyLast(ctx);
channelInputClosed(ctx);
}
ctx.fireUserEventTriggered(evt);
}
Expand All @@ -113,34 +63,31 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}

private void notifyLast(ChannelHandlerContext ctx) throws Exception {
if (!lastFrameDetected && firstFrameReceived) {
lastFrameDetected = true;
channelRead(ctx, EMPTY, true);
}
}

/**
* Called once a {@link Http3HeadersFrame} is ready for this stream to process.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @param frame the {@link Http3HeadersFrame} that was read
* @param isLast {@code true} if this is the last frame that will be read for this stream.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame, boolean isLast)
throws Exception;
protected abstract void channelRead(ChannelHandlerContext ctx, Http3HeadersFrame frame) throws Exception;

/**
* Called once a {@link Http3DataFrame} is ready for this stream to process.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @param frame the {@link Http3DataFrame} that was read
* @param isLast {@code true} if this is the last frame that will be read for this stream.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame, boolean isLast)
throws Exception;
protected abstract void channelRead(ChannelHandlerContext ctx, Http3DataFrame frame) throws Exception;

/**
* Called once the input is closed and so no more inbound data is received on it.
*
* @param ctx the {@link ChannelHandlerContext} of this handler.
* @throws Exception thrown if an error happens during processing.
*/
protected abstract void channelInputClosed(ChannelHandlerContext ctx) throws Exception;

/**
* Called once a {@link Http3UnknownFrame} is ready for this stream to process. By default these frames are just
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,31 +306,6 @@ public void testDowngradeHeadersWithContentLength() {
assertFalse(ch.finish());
}

@Test
public void testDowngradeFullHeaders() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(true));
Http3Headers headers = new DefaultHttp3Headers();
headers.path("/");
headers.method("GET");

assertTrue(ch.writeInboundWithFin(new DefaultHttp3HeadersFrame(headers)));

FullHttpRequest request = ch.readInbound();
try {
assertThat(request.uri(), is("/"));
assertThat(request.method(), is(HttpMethod.GET));
assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1));
assertThat(request.content().readableBytes(), is(0));
assertTrue(request.trailingHeaders().isEmpty());
assertFalse(HttpUtil.isTransferEncodingChunked(request));
} finally {
request.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}

@Test
public void testDowngradeTrailers() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(true));
Expand Down Expand Up @@ -376,14 +351,21 @@ public void testDowngradeEndData() {
ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8);
assertTrue(ch.writeInboundWithFin(new DefaultHttp3DataFrame(hello)));

LastHttpContent content = ch.readInbound();
HttpContent content = ch.readInbound();
try {
assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world"));
assertTrue(content.trailingHeaders().isEmpty());
} finally {
content.release();
}

LastHttpContent last = ch.readInbound();
try {
assertFalse(last.content().isReadable());
assertTrue(last.trailingHeaders().isEmpty());
} finally {
last.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}
Expand Down Expand Up @@ -882,34 +864,6 @@ public void testDecodeResponseHeadersWithContentLength() {
assertFalse(ch.finish());
}

@Test
public void testDecodeFullResponseHeaders() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(false));
Http3Headers headers = new DefaultHttp3Headers();
headers.scheme(HttpScheme.HTTP.name());
headers.status(HttpResponseStatus.OK.codeAsText());

Http3HeadersFrame frame = new DefaultHttp3HeadersFrame(headers);

assertTrue(ch.writeInboundWithFin(frame));

FullHttpResponse response = ch.readInbound();
try {
assertThat(response.status(), is(HttpResponseStatus.OK));
assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1));
assertThat(response.content().readableBytes(), is(0));
assertTrue(response.trailingHeaders().isEmpty());
assertFalse(HttpUtil.isTransferEncodingChunked(response));
assertEquals(0,
(int) response.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text()));
} finally {
response.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}

@Test
public void testDecodeResponseTrailersAsClient() {
EmbeddedQuicStreamChannel ch = new EmbeddedQuicStreamChannel(new Http3FrameToHttpObjectCodec(false));
Expand Down Expand Up @@ -954,14 +908,21 @@ public void testDecodeEndDataAsClient() {
ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8);
assertTrue(ch.writeInboundWithFin(new DefaultHttp3DataFrame(hello)));

LastHttpContent content = ch.readInbound();
HttpContent content = ch.readInbound();
try {
assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world"));
assertTrue(content.trailingHeaders().isEmpty());
} finally {
content.release();
}

LastHttpContent last = ch.readInbound();
try {
assertFalse(last.content().isReadable());
assertTrue(last.trailingHeaders().isEmpty());
} finally {
last.release();
}

assertThat(ch.readInbound(), is(nullValue()));
assertFalse(ch.finish());
}
Expand Down Expand Up @@ -1074,10 +1035,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

HttpResponse respHeaders = (HttpResponse) received.poll(20, TimeUnit.SECONDS);
assertThat(respHeaders.status(), is(HttpResponseStatus.OK));
assertThat(respHeaders, not(instanceOf(LastHttpContent.class))); // this assertion failed before this PR
LastHttpContent respBody = (LastHttpContent) received.poll(20, TimeUnit.SECONDS);
assertThat(respHeaders, not(instanceOf(LastHttpContent.class)));
HttpContent respBody = (HttpContent) received.poll(20, TimeUnit.SECONDS);
assertThat(respBody.content().toString(CharsetUtil.UTF_8), is("foo"));
respBody.release();

LastHttpContent last = (LastHttpContent) received.poll(20, TimeUnit.SECONDS);
last.release();
} finally {
group.shutdownGracefully();
}
Expand Down