Skip to content

Commit

Permalink
netty: Add option to limit RST_STREAM rate
Browse files Browse the repository at this point in the history
The behavior purposefully mirrors that of Netty's
AbstractHttp2ConnectionHandlerBuilder.decoderEnforceMaxRstFramesPerWindow().
That API is not available to our code as we extend the
Http2ConnectionHandler, but we want our API to be able to delegate to
Netty's in the future if that ever becomes possible.
  • Loading branch information
ejona86 committed Nov 20, 2023
1 parent 4caf106 commit ac94b6e
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 2 deletions.
7 changes: 7 additions & 0 deletions netty/src/main/java/io/grpc/netty/NettyServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
private final long maxConnectionAgeGraceInNanos;
private final boolean permitKeepAliveWithoutCalls;
private final long permitKeepAliveTimeInNanos;
private final int maxRstCount;
private final long maxRstPeriodNanos;
private final Attributes eagAttributes;
private final ReferenceCounted sharedResourceReferenceCounter =
new SharedResourceReferenceCounter();
Expand Down Expand Up @@ -127,6 +129,7 @@ class NettyServer implements InternalServer, InternalWithLogId {
long maxConnectionIdleInNanos,
long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos,
boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos,
int maxRstCount, long maxRstPeriodNanos,
Attributes eagAttributes, InternalChannelz channelz) {
this.addresses = checkNotNull(addresses, "addresses");
this.channelFactory = checkNotNull(channelFactory, "channelFactory");
Expand Down Expand Up @@ -156,6 +159,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls;
this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos;
this.maxRstCount = maxRstCount;
this.maxRstPeriodNanos = maxRstPeriodNanos;
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
this.channelz = Preconditions.checkNotNull(channelz);
this.logId = InternalLogId.allocate(getClass(), addresses.isEmpty() ? "No address" :
Expand Down Expand Up @@ -257,6 +262,8 @@ public void initChannel(Channel ch) {
maxConnectionAgeGraceInNanos,
permitKeepAliveWithoutCalls,
permitKeepAliveTimeInNanos,
maxRstCount,
maxRstPeriodNanos,
eagAttributes);
ServerTransportListener transportListener;
// This is to order callbacks on the listener, not to guard access to channel.
Expand Down
32 changes: 31 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public final class NettyServerBuilder extends ForwardingServerBuilder<NettyServe
static final long MAX_CONNECTION_IDLE_NANOS_DISABLED = Long.MAX_VALUE;
static final long MAX_CONNECTION_AGE_NANOS_DISABLED = Long.MAX_VALUE;
static final long MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE = Long.MAX_VALUE;
static final int MAX_RST_COUNT_DISABLED = 0;

private static final long MIN_KEEPALIVE_TIME_NANO = TimeUnit.MILLISECONDS.toNanos(1L);
private static final long MIN_KEEPALIVE_TIMEOUT_NANO = TimeUnit.MICROSECONDS.toNanos(499L);
Expand Down Expand Up @@ -113,6 +114,8 @@ public final class NettyServerBuilder extends ForwardingServerBuilder<NettyServe
private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
private boolean permitKeepAliveWithoutCalls;
private long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);
private int maxRstCount;
private long maxRstPeriodNanos;
private Attributes eagAttributes = Attributes.EMPTY;

/**
Expand Down Expand Up @@ -644,6 +647,33 @@ public NettyServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
return this;
}

/**
* Limits the rate of incoming RST_STREAM frames per connection to maxRstStream per
* secondsPerWindow. When exceeded on a connection, the connection is closed. This can reduce the
* impact of an attacker continually resetting RPCs before they complete, when combined with TLS
* and {@link #maxConcurrentCallsPerConnection(int)}.
*
* <p>gRPC clients send RST_STREAM when they cancel RPCs, so some RST_STREAMs are normal and
* setting this too low can cause errors for legimitate clients.
*
* <p>By default there is no limit.
*
* @param maxRstStream the positive limit of RST_STREAM frames per connection per period, or
* {@code Integer.MAX_VALUE} for unlimited
* @param secondsPerWindow the positive number of seconds per period
*/
@CanIgnoreReturnValue
public NettyServerBuilder maxRstFramesPerWindow(int maxRstStream, int secondsPerWindow) {
checkArgument(maxRstStream > 0, "maxRstStream must be positive");
checkArgument(secondsPerWindow > 0, "secondsPerWindow must be positive");
if (maxRstStream == Integer.MAX_VALUE) {
maxRstStream = MAX_RST_COUNT_DISABLED;
}
this.maxRstCount = maxRstStream;
this.maxRstPeriodNanos = TimeUnit.SECONDS.toNanos(secondsPerWindow);
return this;
}

/** Sets the EAG attributes available to protocol negotiators. Not for general use. */
void eagAttributes(Attributes eagAttributes) {
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
Expand All @@ -664,7 +694,7 @@ NettyServer buildTransportServers(
keepAliveTimeInNanos, keepAliveTimeoutInNanos,
maxConnectionIdleInNanos, maxConnectionAgeInNanos,
maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos,
eagAttributes, this.serverImplBuilder.getChannelz());
maxRstCount, maxRstPeriodNanos, eagAttributes, this.serverImplBuilder.getChannelz());
}

@VisibleForTesting
Expand Down
40 changes: 40 additions & 0 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,13 @@ class NettyServerHandler extends AbstractNettyHandler {
private final long keepAliveTimeoutInNanos;
private final long maxConnectionAgeInNanos;
private final long maxConnectionAgeGraceInNanos;
private final int maxRstCount;
private final long maxRstPeriodNanos;
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer transportTracer;
private final KeepAliveEnforcer keepAliveEnforcer;
private final Attributes eagAttributes;
private final Ticker ticker;
/** Incomplete attributes produced by negotiator. */
private Attributes negotiationAttributes;
private InternalChannelz.Security securityInfo;
Expand All @@ -146,6 +149,9 @@ class NettyServerHandler extends AbstractNettyHandler {
private ScheduledFuture<?> maxConnectionAgeMonitor;
@CheckForNull
private GracefulShutdown gracefulShutdown;
private int rstCount;
private long lastRstNanoTime;


static NettyServerHandler newHandler(
ServerTransportListener transportListener,
Expand All @@ -164,6 +170,8 @@ static NettyServerHandler newHandler(
long maxConnectionAgeGraceInNanos,
boolean permitKeepAliveWithoutCalls,
long permitKeepAliveTimeInNanos,
int maxRstCount,
long maxRstPeriodNanos,
Attributes eagAttributes) {
Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s",
maxHeaderListSize);
Expand Down Expand Up @@ -192,6 +200,8 @@ static NettyServerHandler newHandler(
maxConnectionAgeGraceInNanos,
permitKeepAliveWithoutCalls,
permitKeepAliveTimeInNanos,
maxRstCount,
maxRstPeriodNanos,
eagAttributes,
Ticker.systemTicker());
}
Expand All @@ -215,6 +225,8 @@ static NettyServerHandler newHandler(
long maxConnectionAgeGraceInNanos,
boolean permitKeepAliveWithoutCalls,
long permitKeepAliveTimeInNanos,
int maxRstCount,
long maxRstPeriodNanos,
Attributes eagAttributes,
Ticker ticker) {
Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams);
Expand Down Expand Up @@ -266,6 +278,8 @@ static NettyServerHandler newHandler(
maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos,
keepAliveEnforcer,
autoFlowControl,
maxRstCount,
maxRstPeriodNanos,
eagAttributes, ticker);
}

Expand All @@ -286,6 +300,8 @@ private NettyServerHandler(
long maxConnectionAgeGraceInNanos,
final KeepAliveEnforcer keepAliveEnforcer,
boolean autoFlowControl,
int maxRstCount,
long maxRstPeriodNanos,
Attributes eagAttributes,
Ticker ticker) {
super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(),
Expand Down Expand Up @@ -328,8 +344,12 @@ public void onStreamClosed(Http2Stream stream) {
this.maxConnectionAgeInNanos = maxConnectionAgeInNanos;
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer");
this.maxRstCount = maxRstCount;
this.maxRstPeriodNanos = maxRstPeriodNanos;
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
this.ticker = checkNotNull(ticker, "ticker");

this.lastRstNanoTime = ticker.read();
streamKey = encoder.connection().newKey();
this.transportListener = checkNotNull(transportListener, "transportListener");
this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories");
Expand Down Expand Up @@ -527,6 +547,26 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
}

private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception {
if (maxRstCount > 0) {
long now = ticker.read();
if (now - lastRstNanoTime > maxRstPeriodNanos) {
lastRstNanoTime = now;
rstCount = 1;
} else {
rstCount++;
if (rstCount > maxRstCount) {
throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
@Override
public Throwable fillInStackTrace() {
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
return this;
}
};
}
}
}

try {
NettyServerStream.TransportState stream = serverStream(connection().stream(streamId));
if (stream != null) {
Expand Down
8 changes: 8 additions & 0 deletions netty/src/main/java/io/grpc/netty/NettyServerTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class NettyServerTransport implements ServerTransport {
private final long maxConnectionAgeGraceInNanos;
private final boolean permitKeepAliveWithoutCalls;
private final long permitKeepAliveTimeInNanos;
private final int maxRstCount;
private final long maxRstPeriodNanos;
private final Attributes eagAttributes;
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer transportTracer;
Expand All @@ -99,6 +101,8 @@ class NettyServerTransport implements ServerTransport {
long maxConnectionAgeGraceInNanos,
boolean permitKeepAliveWithoutCalls,
long permitKeepAliveTimeInNanos,
int maxRstCount,
long maxRstPeriodNanos,
Attributes eagAttributes) {
this.channel = Preconditions.checkNotNull(channel, "channel");
this.channelUnused = channelUnused;
Expand All @@ -118,6 +122,8 @@ class NettyServerTransport implements ServerTransport {
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
this.permitKeepAliveWithoutCalls = permitKeepAliveWithoutCalls;
this.permitKeepAliveTimeInNanos = permitKeepAliveTimeInNanos;
this.maxRstCount = maxRstCount;
this.maxRstPeriodNanos = maxRstPeriodNanos;
this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes");
SocketAddress remote = channel.remoteAddress();
this.logId = InternalLogId.allocate(getClass(), remote != null ? remote.toString() : null);
Expand Down Expand Up @@ -277,6 +283,8 @@ private NettyServerHandler createHandler(
maxConnectionAgeGraceInNanos,
permitKeepAliveWithoutCalls,
permitKeepAliveTimeInNanos,
maxRstCount,
maxRstPeriodNanos,
eagAttributes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand Down Expand Up @@ -781,7 +782,7 @@ private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) thr
DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS,
MAX_CONNECTION_IDLE_NANOS_DISABLED,
MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0,
Attributes.EMPTY,
MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY,
channelz);
server.start(serverListener);
address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress());
Expand Down
45 changes: 45 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTP_METHOD;
Expand All @@ -33,6 +34,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -85,6 +87,7 @@
import io.netty.handler.codec.http2.Http2Stream;
import io.netty.util.AsciiString;
import java.io.InputStream;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -143,6 +146,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
private long keepAliveTimeInNanos = DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
private long keepAliveTimeoutInNanos = DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
private int maxRstCount = MAX_RST_COUNT_DISABLED;
private long maxRstPeriodNanos;

private class ServerTransportListenerImpl implements ServerTransportListener {

Expand Down Expand Up @@ -1249,6 +1254,44 @@ public void maxConnectionAgeGrace_channelClosedAfterGracePeriod_withPingAck()
assertFalse(channel().isOpen());
}

@Test
public void maxRstCount_withinLimit_succeeds() throws Exception {
maxRstCount = 10;
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
rapidReset(maxRstCount);
assertTrue(channel().isOpen());
}

@Test
public void maxRstCount_exceedsLimit_fails() throws Exception {
maxRstCount = 10;
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));
assertFalse(channel().isOpen());
}

private void rapidReset(int burstSize) throws Exception {
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
.set(TE_HEADER, TE_TRAILERS)
.path(new AsciiString("/foo/bar"));
int streamId = 1;
long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize;
for (int period = 0; period < 3; period++) {
for (int i = 0; i < burstSize; i++) {
channelRead(headersFrame(streamId, headers));
channelRead(rstStreamFrame(streamId, (int) Http2Error.CANCEL.code()));
streamId += 2;
fakeClock().forwardNanos(rpcTimeNanos);
}
while (channel().readOutbound() != null) {}
fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1);
}
}

private void createStream() throws Exception {
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
Expand Down Expand Up @@ -1296,6 +1339,8 @@ protected NettyServerHandler newHandler() {
maxConnectionAgeGraceInNanos,
permitKeepAliveWithoutCalls,
permitKeepAliveTimeInNanos,
maxRstCount,
maxRstPeriodNanos,
Attributes.EMPTY,
fakeClock().getTicker());
}
Expand Down
7 changes: 7 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class NoHandlerProtocolNegotiator implements ProtocolNegotiator {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> serverShutdownCalled = SettableFuture.create();
Expand Down Expand Up @@ -203,6 +204,7 @@ public void multiPortStartStopGet() throws Exception {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
Expand Down Expand Up @@ -276,6 +278,7 @@ public void multiPortConnections() throws Exception {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
Expand Down Expand Up @@ -337,6 +340,7 @@ public void getPort_notStarted() {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);

Expand Down Expand Up @@ -411,6 +415,7 @@ class TestProtocolNegotiator implements ProtocolNegotiator {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
eagAttributes,
channelz);
ns.start(new ServerListener() {
Expand Down Expand Up @@ -458,6 +463,7 @@ public void channelzListenSocket() throws Exception {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);
final SettableFuture<Void> shutdownCompleted = SettableFuture.create();
Expand Down Expand Up @@ -600,6 +606,7 @@ private NettyServer getServer(List<SocketAddress> addr, EventLoopGroup ev) {
1, 1, // ignore
1, 1, // ignore
true, 0, // ignore
0, 0, // ignore
Attributes.EMPTY,
channelz);
}
Expand Down

0 comments on commit ac94b6e

Please sign in to comment.