diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index fe7913870fa..2960604e5b5 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -99,6 +99,8 @@ class NettyServer implements InternalServer, InternalWithLogId { private final long maxConnectionAgeGraceInNanos; private final boolean permitKeepAliveWithoutCalls; private final long permitKeepAliveTimeInNanos; + private final int maxRstCount; + private final long maxRstPeriodNanos; private final Attributes eagAttributes; private final ReferenceCounted sharedResourceReferenceCounter = new SharedResourceReferenceCounter(); @@ -127,6 +129,7 @@ class NettyServer implements InternalServer, InternalWithLogId { long maxConnectionIdleInNanos, long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, long maxRstPeriodNanos, Attributes eagAttributes, InternalChannelz channelz) { this.addresses = checkNotNull(addresses, "addresses"); this.channelFactory = checkNotNull(channelFactory, "channelFactory"); @@ -156,6 +159,8 @@ class NettyServer implements InternalServer, InternalWithLogId { this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls; this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos; + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); this.channelz = Preconditions.checkNotNull(channelz); this.logId = InternalLogId.allocate(getClass(), addresses.isEmpty() ? "No address" : @@ -257,6 +262,8 @@ public void initChannel(Channel ch) { maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes); ServerTransportListener transportListener; // This is to order callbacks on the listener, not to guard access to channel. diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 9411a979ed4..525f8953e05 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -75,6 +75,7 @@ public final class NettyServerBuilder extends ForwardingServerBuildergRPC clients send RST_STREAM when they cancel RPCs, so some RST_STREAMs are normal and + * setting this too low can cause errors for legimitate clients. + * + *

By default there is no limit. + * + * @param maxRstStream the positive limit of RST_STREAM frames per connection per period, or + * {@code Integer.MAX_VALUE} for unlimited + * @param secondsPerWindow the positive number of seconds per period + */ + @CanIgnoreReturnValue + public NettyServerBuilder maxRstFramesPerWindow(int maxRstStream, int secondsPerWindow) { + checkArgument(maxRstStream > 0, "maxRstStream must be positive"); + checkArgument(secondsPerWindow > 0, "secondsPerWindow must be positive"); + if (maxRstStream == Integer.MAX_VALUE) { + maxRstStream = MAX_RST_COUNT_DISABLED; + } + this.maxRstCount = maxRstStream; + this.maxRstPeriodNanos = TimeUnit.SECONDS.toNanos(secondsPerWindow); + return this; + } + /** Sets the EAG attributes available to protocol negotiators. Not for general use. */ void eagAttributes(Attributes eagAttributes) { this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); @@ -664,7 +694,7 @@ NettyServer buildTransportServers( keepAliveTimeInNanos, keepAliveTimeoutInNanos, maxConnectionIdleInNanos, maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, - eagAttributes, this.serverImplBuilder.getChannelz()); + maxRstCount, maxRstPeriodNanos, eagAttributes, this.serverImplBuilder.getChannelz()); } @VisibleForTesting diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index bbf2f17748b..2b1e8126fdc 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -125,10 +125,13 @@ class NettyServerHandler extends AbstractNettyHandler { private final long keepAliveTimeoutInNanos; private final long maxConnectionAgeInNanos; private final long maxConnectionAgeGraceInNanos; + private final int maxRstCount; + private final long maxRstPeriodNanos; private final List streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; private final Attributes eagAttributes; + private final Ticker ticker; /** Incomplete attributes produced by negotiator. */ private Attributes negotiationAttributes; private InternalChannelz.Security securityInfo; @@ -146,6 +149,9 @@ class NettyServerHandler extends AbstractNettyHandler { private ScheduledFuture maxConnectionAgeMonitor; @CheckForNull private GracefulShutdown gracefulShutdown; + private int rstCount; + private long lastRstNanoTime; + static NettyServerHandler newHandler( ServerTransportListener transportListener, @@ -164,6 +170,8 @@ static NettyServerHandler newHandler( long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); @@ -192,6 +200,8 @@ static NettyServerHandler newHandler( maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes, Ticker.systemTicker()); } @@ -215,6 +225,8 @@ static NettyServerHandler newHandler( long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes, Ticker ticker) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); @@ -266,6 +278,8 @@ static NettyServerHandler newHandler( maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, + maxRstCount, + maxRstPeriodNanos, eagAttributes, ticker); } @@ -286,6 +300,8 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes, Ticker ticker) { super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), @@ -328,8 +344,12 @@ public void onStreamClosed(Http2Stream stream) { this.maxConnectionAgeInNanos = maxConnectionAgeInNanos; this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer"); + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); + this.ticker = checkNotNull(ticker, "ticker"); + this.lastRstNanoTime = ticker.read(); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); @@ -527,6 +547,26 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt } private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception { + if (maxRstCount > 0) { + long now = ticker.read(); + if (now - lastRstNanoTime > maxRstPeriodNanos) { + lastRstNanoTime = now; + rstCount = 1; + } else { + rstCount++; + if (rstCount > maxRstCount) { + throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses + @Override + public Throwable fillInStackTrace() { + // Avoid the CPU cycles, since the resets may be a CPU consumption attack + return this; + } + }; + } + } + } + try { NettyServerStream.TransportState stream = serverStream(connection().stream(streamId)); if (stream != null) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 46ddeb27c93..9511927a09f 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -77,6 +77,8 @@ class NettyServerTransport implements ServerTransport { private final long maxConnectionAgeGraceInNanos; private final boolean permitKeepAliveWithoutCalls; private final long permitKeepAliveTimeInNanos; + private final int maxRstCount; + private final long maxRstPeriodNanos; private final Attributes eagAttributes; private final List streamTracerFactories; private final TransportTracer transportTracer; @@ -99,6 +101,8 @@ class NettyServerTransport implements ServerTransport { long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, + int maxRstCount, + long maxRstPeriodNanos, Attributes eagAttributes) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.channelUnused = channelUnused; @@ -118,6 +122,8 @@ class NettyServerTransport implements ServerTransport { this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls; this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos; + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes"); SocketAddress remote = channel.remoteAddress(); this.logId = InternalLogId.allocate(getClass(), remote != null ? remote.toString() : null); @@ -277,6 +283,8 @@ private NettyServerHandler createHandler( maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, eagAttributes); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index eabbbda3180..39e6718a24e 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -27,6 +27,7 @@ import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; +import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -826,7 +827,7 @@ private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) thr DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, MAX_CONNECTION_IDLE_NANOS_DISABLED, MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0, - Attributes.EMPTY, + MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY, channelz); server.start(serverListener); address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress()); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 368b0600f9e..281ff3b17d6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -23,6 +23,7 @@ import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; +import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.HTTP_METHOD; @@ -33,6 +34,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -85,6 +87,7 @@ import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import java.io.InputStream; +import java.nio.channels.ClosedChannelException; import java.util.Arrays; import java.util.LinkedList; import java.util.List; @@ -143,6 +146,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); + } + + private void rapidReset(int burstSize) throws Exception { + Http2Headers headers = new DefaultHttp2Headers() + .method(HTTP_METHOD) + .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) + .set(TE_HEADER, TE_TRAILERS) + .path(new AsciiString("/foo/bar")); + int streamId = 1; + long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize; + for (int period = 0; period < 3; period++) { + for (int i = 0; i < burstSize; i++) { + channelRead(headersFrame(streamId, headers)); + channelRead(rstStreamFrame(streamId, (int) Http2Error.CANCEL.code())); + streamId += 2; + fakeClock().forwardNanos(rpcTimeNanos); + } + while (channel().readOutbound() != null) {} + fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1); + } + } + private void createStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) @@ -1296,6 +1339,8 @@ protected NettyServerHandler newHandler() { maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, Attributes.EMPTY, fakeClock().getTicker()); } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerTest.java index 1c212bd42fd..64d31070156 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerTest.java @@ -153,6 +153,7 @@ class NoHandlerProtocolNegotiator implements ProtocolNegotiator { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture serverShutdownCalled = SettableFuture.create(); @@ -203,6 +204,7 @@ public void multiPortStartStopGet() throws Exception { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -276,6 +278,7 @@ public void multiPortConnections() throws Exception { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -337,6 +340,7 @@ public void getPort_notStarted() { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); @@ -411,6 +415,7 @@ class TestProtocolNegotiator implements ProtocolNegotiator { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore eagAttributes, channelz); ns.start(new ServerListener() { @@ -458,6 +463,7 @@ public void channelzListenSocket() throws Exception { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); final SettableFuture shutdownCompleted = SettableFuture.create(); @@ -600,6 +606,7 @@ private NettyServer getServer(List addr, EventLoopGroup ev) { 1, 1, // ignore 1, 1, // ignore true, 0, // ignore + 0, 0, // ignore Attributes.EMPTY, channelz); }