diff --git a/core/src/main/java/io/grpc/ServerBuilder.java b/core/src/main/java/io/grpc/ServerBuilder.java index e646d69cfb5..fda9cb5c1f4 100644 --- a/core/src/main/java/io/grpc/ServerBuilder.java +++ b/core/src/main/java/io/grpc/ServerBuilder.java @@ -18,6 +18,7 @@ import java.io.File; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** @@ -172,6 +173,20 @@ public T addStreamTracerFactory(ServerStreamTracer.Factory factory) { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704") public abstract T compressorRegistry(@Nullable CompressorRegistry registry); + /** + * Sets the permitted time for new connections to complete negotiation handshakes before being + * killed. + * + * @return this + * @throws IllegalArgumentException if timeout is negative + * @throws UnsupportedOperationException if unsupported + * @since 1.8.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3706") + public T handshakeTimeout(long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + /** * Builds a server using the given parameters. * diff --git a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java index 1e25e0fa71d..a1425349196 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java @@ -23,6 +23,7 @@ import io.grpc.internal.GrpcUtil; import java.io.File; import java.util.List; +import java.util.concurrent.TimeUnit; /** * Builder for a server that services in-process requests. Clients identify the in-process server by @@ -79,6 +80,9 @@ public static InProcessServerBuilder forPort(int port) { private InProcessServerBuilder(String name) { this.name = Preconditions.checkNotNull(name, "name"); + // Disable handshake timeout because it is unnecessary, and can trigger Thread creation that can + // break some environments (like tests). + handshakeTimeout(Long.MAX_VALUE, TimeUnit.SECONDS); } @Override diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 238b0fc5ddd..f463a85c865 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; @@ -41,6 +42,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** @@ -73,6 +75,7 @@ public List getServices() { DecompressorRegistry.getDefaultInstance(); private static final CompressorRegistry DEFAULT_COMPRESSOR_REGISTRY = CompressorRegistry.getDefaultInstance(); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(120); final InternalHandlerRegistry.Builder registryBuilder = new InternalHandlerRegistry.Builder(); @@ -96,6 +99,8 @@ public List getServices() { CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY; + long handshakeTimeoutMillis = DEFAULT_HANDSHAKE_TIMEOUT_MILLIS; + @Nullable private StatsContextFactory statsFactory; @@ -179,6 +184,13 @@ public final T compressorRegistry(CompressorRegistry registry) { return thisT(); } + @Override + public final T handshakeTimeout(long timeout, TimeUnit unit) { + checkArgument(timeout > 0, "handshake timeout is %s, but must be positive", timeout); + handshakeTimeoutMillis = unit.toMillis(timeout); + return thisT(); + } + /** * Override the default stats implementation. */ diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 5f4c731716a..c9932d293d3 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -49,6 +49,8 @@ import java.util.HashSet; import java.util.List; import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; @@ -82,6 +84,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { // This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator // creations. private final ServerInterceptor[] interceptors; + private final long handshakeTimeoutMillis; @GuardedBy("lock") private boolean started; @GuardedBy("lock") private boolean shutdown; /** non-{@code null} if immediate shutdown has been requested. */ @@ -127,6 +130,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId { new ArrayList(builder.transportFilters)); this.interceptors = builder.interceptors.toArray(new ServerInterceptor[builder.interceptors.size()]); + this.handshakeTimeoutMillis = builder.handshakeTimeoutMillis; } /** @@ -308,7 +312,9 @@ public ServerTransportListener transportCreated(ServerTransport transport) { synchronized (lock) { transports.add(transport); } - return new ServerTransportListenerImpl(transport); + ServerTransportListenerImpl stli = new ServerTransportListenerImpl(transport); + stli.init(); + return stli; } @Override @@ -338,14 +344,36 @@ public void serverShutdown() { private final class ServerTransportListenerImpl implements ServerTransportListener { private final ServerTransport transport; + private Future handshakeTimeoutFuture; private Attributes attributes; ServerTransportListenerImpl(ServerTransport transport) { this.transport = transport; } + public void init() { + class TransportShutdownNow implements Runnable { + @Override public void run() { + transport.shutdownNow(Status.CANCELLED.withDescription("Handshake timeout exceeded")); + } + } + + if (handshakeTimeoutMillis != Long.MAX_VALUE) { + handshakeTimeoutFuture = transport.getScheduledExecutorService() + .schedule(new TransportShutdownNow(), handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + } else { + // Noop, to avoid triggering Thread creation in InProcessServer + handshakeTimeoutFuture = new FutureTask(new Runnable() { + @Override public void run() {} + }, null); + } + } + @Override public Attributes transportReady(Attributes attributes) { + handshakeTimeoutFuture.cancel(false); + handshakeTimeoutFuture = null; + for (ServerTransportFilter filter : transportFilters) { attributes = Preconditions.checkNotNull(filter.transportReady(attributes), "Filter %s returned null", filter); @@ -356,6 +384,10 @@ public Attributes transportReady(Attributes attributes) { @Override public void transportTerminated() { + if (handshakeTimeoutFuture != null) { + handshakeTimeoutFuture.cancel(false); + handshakeTimeoutFuture = null; + } for (ServerTransportFilter filter : transportFilters) { filter.transportTerminated(attributes); } diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index d95f0b076fc..4a60e2ea736 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -75,6 +75,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -346,11 +347,35 @@ public void start(ServerListener listener) throws IOException { verifyNoMoreInteractions(executorPool); } + @Test + public void transportHandshakeTimeout_expired() throws Exception { + class ShutdownRecordingTransport extends SimpleServerTransport { + Status shutdownNowStatus; + + @Override public void shutdownNow(Status status) { + shutdownNowStatus = status; + super.shutdownNow(status); + } + } + + builder.handshakeTimeout(60, TimeUnit.SECONDS); + createAndStartServer(); + ShutdownRecordingTransport serverTransport = new ShutdownRecordingTransport(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(serverTransport); + timer.forwardTime(59, TimeUnit.SECONDS); + assertNull("shutdownNow status", serverTransport.shutdownNowStatus); + // Don't call transportReady() in time + timer.forwardTime(2, TimeUnit.SECONDS); + assertNotNull("shutdownNow status", serverTransport.shutdownNowStatus); + } + @Test public void methodNotFound() throws Exception { createAndStartServer(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext( @@ -377,6 +402,7 @@ public void decompressorNotFound() throws Exception { createAndStartServer(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName); StatsTraceContext statsTraceCtx = @@ -421,6 +447,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); requestHeaders.put(metadataKey, "value"); @@ -620,6 +647,7 @@ public ServerCall.Listener startCall( ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -664,6 +692,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -825,6 +854,7 @@ private void checkContext() { }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = @@ -890,6 +920,7 @@ public ServerCall.Listener startCall( }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders); @@ -996,6 +1027,7 @@ public void handlerRegistryPriorities() throws Exception { ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); Metadata requestHeaders = new Metadata(); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);