From cbd72f17079a595a73f68eba95bb151a54aeea2c Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 6 Nov 2017 14:01:34 -0800 Subject: [PATCH 1/2] core: Add negotation timeout for all server transports Negotation is the period between transport creation and ready. --- core/src/main/java/io/grpc/ServerBuilder.java | 15 +++++++++ .../internal/AbstractServerImplBuilder.java | 12 +++++++ .../java/io/grpc/internal/ServerImpl.java | 26 ++++++++++++++- .../java/io/grpc/internal/ServerImplTest.java | 32 +++++++++++++++++++ 4 files changed, 84 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/io/grpc/ServerBuilder.java b/core/src/main/java/io/grpc/ServerBuilder.java index e646d69cfb5..fda9cb5c1f4 100644 --- a/core/src/main/java/io/grpc/ServerBuilder.java +++ b/core/src/main/java/io/grpc/ServerBuilder.java @@ -18,6 +18,7 @@ import java.io.File; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** @@ -172,6 +173,20 @@ public T addStreamTracerFactory(ServerStreamTracer.Factory factory) { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704") public abstract T compressorRegistry(@Nullable CompressorRegistry registry); + /** + * Sets the permitted time for new connections to complete negotiation handshakes before being + * killed. + * + * @return this + * @throws IllegalArgumentException if timeout is negative + * @throws UnsupportedOperationException if unsupported + * @since 1.8.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3706") + public T handshakeTimeout(long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + /** * Builds a server using the given parameters. * diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 29768d15037..0280d51d553 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; @@ -39,6 +40,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** @@ -71,6 +73,7 @@ public List getServices() { DecompressorRegistry.getDefaultInstance(); private static final CompressorRegistry DEFAULT_COMPRESSOR_REGISTRY = CompressorRegistry.getDefaultInstance(); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(20); final InternalHandlerRegistry.Builder registryBuilder = new InternalHandlerRegistry.Builder(); @@ -94,6 +97,8 @@ public List getServices() { CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY; + long handshakeTimeoutMillis = DEFAULT_HANDSHAKE_TIMEOUT_MILLIS; + @Nullable private CensusStatsModule censusStatsOverride; @@ -178,6 +183,13 @@ public final T compressorRegistry(CompressorRegistry registry) { return thisT(); } + @Override + public final T handshakeTimeout(long timeout, TimeUnit unit) { + checkArgument(timeout > 0, "handshake timeout is %s, but must be positive", timeout); + handshakeTimeoutMillis = unit.toMillis(timeout); + return thisT(); + } + /** * Override the default stats implementation. */ diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 5f4c731716a..591b77d21d5 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -49,6 +49,7 @@ import java.util.HashSet; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; @@ -82,6 +83,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { // This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator // creations. private final ServerInterceptor[] interceptors; + private final long handshakeTimeoutMillis; @GuardedBy("lock") private boolean started; @GuardedBy("lock") private boolean shutdown; /** non-{@code null} if immediate shutdown has been requested. */ @@ -127,6 +129,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { new ArrayList(builder.transportFilters)); this.interceptors = builder.interceptors.toArray(new ServerInterceptor[builder.interceptors.size()]); + this.handshakeTimeoutMillis = builder.handshakeTimeoutMillis; } /** @@ -308,7 +311,9 @@ public ServerTransportListener transportCreated(ServerTransport transport) { synchronized (lock) { transports.add(transport); } - return new ServerTransportListenerImpl(transport); + ServerTransportListenerImpl stli = new ServerTransportListenerImpl(transport); + stli.init(); + return stli; } @Override @@ -338,14 +343,29 @@ public void serverShutdown() { private final class ServerTransportListenerImpl implements ServerTransportListener { private final ServerTransport transport; + private Future handshakeTimeoutFuture; private Attributes attributes; ServerTransportListenerImpl(ServerTransport transport) { this.transport = transport; } + public void init() { + class TransportShutdownNow implements Runnable { + @Override public void run() { + transport.shutdownNow(Status.CANCELLED.withDescription("Handshake timeout exceeded")); + } + } + + handshakeTimeoutFuture = transport.getScheduledExecutorService() + .schedule(new TransportShutdownNow(), handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + } + @Override public Attributes transportReady(Attributes attributes) { + handshakeTimeoutFuture.cancel(false); + handshakeTimeoutFuture = null; + for (ServerTransportFilter filter : transportFilters) { attributes = Preconditions.checkNotNull(filter.transportReady(attributes), "Filter %s returned null", filter); @@ -356,6 +376,10 @@ public Attributes transportReady(Attributes attributes) { @Override public void transportTerminated() { + if (handshakeTimeoutFuture != null) { + handshakeTimeoutFuture.cancel(false); + handshakeTimeoutFuture = null; + } for (ServerTransportFilter filter : transportFilters) { filter.transportTerminated(attributes); } diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 6a40924d1df..ef92481936c 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -77,6 +77,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -348,11 +349,35 @@ public void start(ServerListener listener) throws IOException { verifyNoMoreInteractions(executorPool); } + @Test + public void transportHandshakeTimeout_expired() throws Exception { + class ShutdownRecordingTransport extends SimpleServerTransport { + Status shutdownNowStatus; + + @Override public void shutdownNow(Status status) { + shutdownNowStatus = status; + super.shutdownNow(status); + } + } + + builder.handshakeTimeout(60, TimeUnit.SECONDS); + createAndStartServer(); + ShutdownRecordingTransport serverTransport = new ShutdownRecordingTransport(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(serverTransport); + timer.forwardTime(59, TimeUnit.SECONDS); + assertNull("shutdownNow status", serverTransport.shutdownNowStatus); + // Don't call transportReady() in time + timer.forwardTime(2, TimeUnit.SECONDS); + assertNotNull("shutdownNow status", serverTransport.shutdownNowStatus); + } + @Test public void methodNotFound() throws Exception { createAndStartServer(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext( @@ -379,6 +404,7 @@ public void decompressorNotFound() throws Exception { createAndStartServer(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName); StatsTraceContext statsTraceCtx = @@ -423,6 +449,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); requestHeaders.put(metadataKey, "value"); @@ -622,6 +649,7 @@ public ServerCall.Listener startCall( ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -666,6 +694,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -827,6 +856,7 @@ private void checkContext() { }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -892,6 +922,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders); @@ -998,6 +1029,7 @@ public void handlerRegistryPriorities() throws Exception { ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders); From 17105b860f6fee6dbcb1386df3386f99b31f1e73 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 6 Nov 2017 15:03:42 -0800 Subject: [PATCH 2/2] netty: Move server transportReady after client preface receipt This mirrors the behavior of client-side. --- .../io/grpc/netty/NettyServerHandler.java | 16 +++++++++- .../io/grpc/netty/NettyServerHandlerTest.java | 32 ++++++++++++++----- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 79b10236227..1180df0912e 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -106,6 +106,9 @@ class NettyServerHandler extends AbstractNettyHandler { private final List streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; + /** Incomplete attributes produced by negotiator. */ + private Attributes negotiationAttributes; + /** Completed attributes produced by transportReady. */ private Attributes attributes; private Throwable connectionError; private boolean teWarningLogged; @@ -481,7 +484,7 @@ protected void onStreamError(ChannelHandlerContext ctx, Throwable cause, @Override public void handleProtocolNegotiationCompleted(Attributes attrs) { - attributes = transportListener.transportReady(attrs); + negotiationAttributes = attrs; } @VisibleForTesting @@ -681,6 +684,17 @@ private Http2Exception newStreamException(int streamId, Throwable cause) { } private class FrameListener extends Http2FrameAdapter { + private boolean firstSettings = true; + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) { + if (firstSettings) { + firstSettings = false; + // Delay transportReady until we see the client's HTTP handshake, for coverage with + // handshakeTimeout + attributes = transportListener.transportReady(negotiationAttributes); + } + } @Override public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 1f056570c5c..3c42959e10a 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -89,13 +89,13 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.TimeUnit; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; @@ -136,7 +136,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBaseany()); + .messagesAvailable(any(StreamListener.MessageProducer.class)); + } + + @Override + protected void manualSetUp() throws Exception { + assertNull("manualSetUp should not run more than once", handler()); initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); @@ -195,6 +198,19 @@ public Void answer(InvocationOnMock invocation) throws Throwable { channelRead(serializedSettings); } + @Test + public void transportReadyDelayedUntilConnectionPreface() throws Exception { + initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); + + handler().handleProtocolNegotiationCompleted(Attributes.EMPTY); + verify(transportListener, never()).transportReady(any(Attributes.class)); + + // Simulate receipt of the connection preface + channelRead(Http2CodecUtil.connectionPrefaceBuf()); + channelRead(serializeSettings(new Http2Settings())); + verify(transportListener).transportReady(any(Attributes.class)); + } + @Test public void sendFrameShouldSucceed() throws Exception { manualSetUp();