Skip to content

Commit

Permalink
Don't send a RST frame when closing the stream in a write future whil… (
Browse files Browse the repository at this point in the history
#13973)

…e processing inbound frames.

Motiviation:

Due a bug in netty we would send a RST frame in some cases even tho we
correctly received the endOfStream already. This is not necessary and
might even confuse the remote peer.

Modifications:

- Keep track of if we received endOfStream and send endOfStream in our
Channel implementation and only send a RST frame if this is not the case
during close
- Add unit tests

Result:

Don't send RST frame if we received endOfStream and send endOfStream.
  • Loading branch information
normanmaurer committed Apr 14, 2024
1 parent e19c91f commit 4d961d0
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ void fireChildRead(Http2Frame frame) {
// otherwise we would have drained it from the queue and processed it during the read cycle.
assert inboundBuffer == null || inboundBuffer.isEmpty();
final RecvByteBufAllocator.Handle allocHandle = unsafe.recvBufAllocHandle();

unsafe.doRead0(frame, allocHandle);
// We currently don't need to check for readEOS because the parent channel and child channel are limited
// to the same EventLoop thread. There are a limited number of frame types that may come after EOS is
Expand Down Expand Up @@ -635,6 +636,9 @@ private final class Http2ChannelUnsafe implements Unsafe {
private boolean closeInitiated;
private boolean readEOS;

private boolean receivedEndOfStream;
private boolean sentEndOfStream;

@Override
public void connect(final SocketAddress remoteAddress,
SocketAddress localAddress, final ChannelPromise promise) {
Expand Down Expand Up @@ -731,7 +735,9 @@ public void operationComplete(ChannelFuture future) {

// Only ever send a reset frame if the connection is still alive and if the stream was created before
// as otherwise we may send a RST on a stream in an invalid state and cause a connection error.
if (parent().isActive() && !readEOS && isStreamIdValid(stream.id())) {
if (parent().isActive() && isStreamIdValid(stream.id()) &&
// Also ensure the stream was never "closed" before.
!readEOS && !(receivedEndOfStream && sentEndOfStream)) {
Http2StreamFrame resetFrame = new DefaultHttp2ResetFrame(error).stream(stream());
write(resetFrame, unsafe().voidPromise());
flush();
Expand Down Expand Up @@ -953,7 +959,6 @@ void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) {
final int bytes;
if (frame instanceof Http2DataFrame) {
bytes = ((Http2DataFrame) frame).initialFlowControlledBytes();

// It is important that we increment the flowControlledBytes before we call fireChannelRead(...)
// as it may cause a read() that will call updateLocalWindowIfNeeded() and we need to ensure
// in this case that we accounted for it.
Expand All @@ -963,6 +968,11 @@ void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) {
} else {
bytes = MIN_HTTP2_FRAME_SIZE;
}

// Let's keep track of what we received as the stream state itself will only be updated once the frame
// was dispatched for reading which might cause problems if we try to close the channel in a write future.
receivedEndOfStream |= isEndOfStream(frame);

// Update before firing event through the pipeline to be consistent with other Channel implementation.
allocHandle.attemptedBytesRead(bytes);
allocHandle.lastBytesRead(bytes);
Expand Down Expand Up @@ -1003,6 +1013,16 @@ public void write(Object msg, final ChannelPromise promise) {
}
}

private boolean isEndOfStream(Http2Frame frame) {
if (frame instanceof Http2HeadersFrame) {
return ((Http2HeadersFrame) frame).isEndStream();
}
if (frame instanceof Http2DataFrame) {
return ((Http2DataFrame) frame).isEndStream();
}
return false;
}

private void writeHttp2StreamFrame(Http2StreamFrame frame, final ChannelPromise promise) {
if (!firstFrameWritten && !isStreamIdValid(stream().id()) && !(frame instanceof Http2HeadersFrame)) {
ReferenceCountUtil.release(frame);
Expand All @@ -1019,6 +1039,9 @@ private void writeHttp2StreamFrame(Http2StreamFrame frame, final ChannelPromise
firstWrite = firstFrameWritten = true;
}

// Let's keep track of what we send as the stream state itself will only be updated once the frame
// was written which might cause problems if we try to close the channel in a write future.
sentEndOfStream |= isEndOfStream(frame);
ChannelFuture f = write0(parentContext(), frame);
if (f.isDone()) {
if (firstWrite) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownReadComplete;
import io.netty.channel.socket.ChannelOutputShutdownEvent;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpScheme;
Expand All @@ -36,12 +37,15 @@
import io.netty.handler.ssl.SslCloseCompletionEvent;
import io.netty.util.AsciiString;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand Down Expand Up @@ -71,6 +75,7 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
Expand Down Expand Up @@ -229,6 +234,86 @@ public void headerAndDataFramesShouldBeDelivered() {
assertNull(inboundHandler.readInbound());
}

enum RstFrameTestMode {
HEADERS_END_STREAM,
DATA_END_STREAM,
TRAILERS_END_STREAM;
}
@ParameterizedTest
@EnumSource(RstFrameTestMode.class)
void noRstFrameSentOnCloseViaListener(final RstFrameTestMode mode) throws Exception {
LastInboundHandler inboundHandler = new LastInboundHandler() {
private boolean headersReceived;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
final boolean endStream;
if (msg instanceof Http2HeadersFrame) {
endStream = ((Http2HeadersFrame) msg).isEndStream();
switch (mode) {
case HEADERS_END_STREAM:
assertFalse(headersReceived);
assertTrue(endStream);
break;
case TRAILERS_END_STREAM:
if (headersReceived) {
assertTrue(endStream);
} else {
assertFalse(endStream);
}
break;
case DATA_END_STREAM:
assertFalse(endStream);
break;
default:
fail();
}
headersReceived = true;
} else if (msg instanceof Http2DataFrame) {
endStream = ((Http2DataFrame) msg).isEndStream();
switch (mode) {
case HEADERS_END_STREAM:
fail();
break;
case TRAILERS_END_STREAM:
assertFalse(endStream);
break;
case DATA_END_STREAM:
assertTrue(endStream);
break;
default:
fail();
}
} else {
throw new UnsupportedMessageTypeException(msg);
}
if (endStream) {
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), true, 0))
.addListener(ChannelFutureListener.CLOSE);
}
} finally {
ReferenceCountUtil.release(msg);
}
}
};

Http2StreamChannel channel = newInboundStream(3, mode == RstFrameTestMode.HEADERS_END_STREAM, inboundHandler);
if (mode != RstFrameTestMode.HEADERS_END_STREAM) {
frameInboundWriter.writeInboundData(
channel.stream().id(), bb("something"), 0, mode == RstFrameTestMode.DATA_END_STREAM);
if (mode != RstFrameTestMode.DATA_END_STREAM) {
frameInboundWriter.writeInboundHeaders(channel.stream().id(), new DefaultHttp2Headers(), 0, true);
}
}
channel.closeFuture().syncUninterruptibly();

// We should never produce a RST frame in this case as we received the endOfStream before we write a frame
// with the endOfStream flag.
verify(frameWriter, never()).writeRstStream(eqCodecCtx(),
eqStreamId(channel), anyLong(), anyChannelPromise());
inboundHandler.checkException();
}

@Test
public void headerMultipleContentLengthValidationShouldPropagate() {
headerMultipleContentLengthValidationShouldPropagate(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,24 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
Expand Down Expand Up @@ -71,7 +85,10 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forClient;
import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forServer;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -743,4 +760,117 @@ public boolean isSharable() {
}
}
}

@Test
public void testServerCloseShouldNotSendResetIfClientSentEOS() throws Exception {
EventLoopGroup group = null;
Channel serverChannel = null;
Channel clientChannel = null;
Channel clientStreamChannel = null;
try {
final CountDownLatch clientReceivedResponseLatch = new CountDownLatch(1);
final CountDownLatch resetFrameLatch = new CountDownLatch(1);
group = new DefaultEventLoop();
LocalAddress serverAddress = new LocalAddress(getClass().getName());
ServerBootstrap sb = new ServerBootstrap()
.channel(LocalServerChannel.class)
.group(group)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(forServer().build());
pipeline.addLast(new Http2FrameIgnore<Http2SettingsFrame>(Http2SettingsFrame.class));
pipeline.addLast(new Http2FrameIgnore<Http2SettingsAckFrame>(Http2SettingsAckFrame.class));
pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(Http2StreamChannel ch) {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new Http2StreamFrameToHttpObjectCodec(true, true));
pipeline.addLast(new HttpObjectAggregator(16384));
pipeline.addLast(new SimpleChannelInboundHandler<FullHttpRequest>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) {
ctx.writeAndFlush(
new DefaultFullHttpResponse(
msg.protocolVersion(), HttpResponseStatus.OK,
Unpooled.copiedBuffer("hello", CharsetUtil.US_ASCII)))
.addListeners(ChannelFutureListener.CLOSE);
}
});
}
}));
}
});
serverChannel = sb.bind(serverAddress).sync().channel();

Bootstrap cb = new Bootstrap()
.channel(LocalChannel.class)
.group(group)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(forClient().build());
pipeline.addLast(new Http2FrameIgnore<Http2SettingsFrame>(Http2SettingsFrame.class));
pipeline.addLast(new Http2FrameIgnore<Http2SettingsAckFrame>(Http2SettingsAckFrame.class));
pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(Http2StreamChannel ch) {
// noop
}
}));
}
});

clientChannel = cb.connect(serverAddress).sync().channel();
clientStreamChannel = new Http2StreamChannelBootstrap(clientChannel)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new Http2StreamFrameToHttpObjectCodec(false, true));
pipeline.addLast(new HttpObjectAggregator(16384));
pipeline.addLast(new SimpleChannelInboundHandler<FullHttpResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) {
clientReceivedResponseLatch.countDown();
}
});
}
})
.open().sync().get();

clientStreamChannel.writeAndFlush(
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/test/")).sync();

assertTrue(clientReceivedResponseLatch.await(3, SECONDS));

// The server should NOT send any RST_STREAM frame.
assertFalse(resetFrameLatch.await(1, SECONDS));
} finally {
if (clientStreamChannel != null) {
clientStreamChannel.close().syncUninterruptibly();
}
if (clientChannel != null) {
clientChannel.close().syncUninterruptibly();
}
if (serverChannel != null) {
serverChannel.close().syncUninterruptibly();
}
if (group != null) {
group.shutdownGracefully(0, 3, SECONDS);
}
}
}

private static final class Http2FrameIgnore<T extends Http2Frame> extends SimpleChannelInboundHandler<T> {
Http2FrameIgnore(Class<? extends T> inboundMessageType) {
super(inboundMessageType);
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, T msg) {
}
}
}

0 comments on commit 4d961d0

Please sign in to comment.