Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions core/src/main/java/io/grpc/ServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.File;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -172,6 +173,20 @@ public T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704")
public abstract T compressorRegistry(@Nullable CompressorRegistry registry);

/**
* Sets the permitted time for new connections to complete negotiation handshakes before being
* killed.
*
* @return this
* @throws IllegalArgumentException if timeout is negative
* @throws UnsupportedOperationException if unsupported
* @since 1.8.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/3706")
public T handshakeTimeout(long timeout, TimeUnit unit) {
throw new UnsupportedOperationException();
}

/**
* Builds a server using the given parameters.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.internal.GrpcUtil;
import java.io.File;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
* Builder for a server that services in-process requests. Clients identify the in-process server by
Expand Down Expand Up @@ -79,6 +80,9 @@ public static InProcessServerBuilder forPort(int port) {

private InProcessServerBuilder(String name) {
this.name = Preconditions.checkNotNull(name, "name");
// Disable handshake timeout because it is unnecessary, and can trigger Thread creation that can
// break some environments (like tests).
handshakeTimeout(Long.MAX_VALUE, TimeUnit.SECONDS);
}

@Override
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.grpc.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
Expand All @@ -41,6 +42,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -73,6 +75,7 @@ public List<ServerServiceDefinition> getServices() {
DecompressorRegistry.getDefaultInstance();
private static final CompressorRegistry DEFAULT_COMPRESSOR_REGISTRY =
CompressorRegistry.getDefaultInstance();
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(120);

final InternalHandlerRegistry.Builder registryBuilder =
new InternalHandlerRegistry.Builder();
Expand All @@ -96,6 +99,8 @@ public List<ServerServiceDefinition> getServices() {

CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY;

long handshakeTimeoutMillis = DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;

@Nullable
private StatsContextFactory statsFactory;

Expand Down Expand Up @@ -179,6 +184,13 @@ public final T compressorRegistry(CompressorRegistry registry) {
return thisT();
}

@Override
public final T handshakeTimeout(long timeout, TimeUnit unit) {
checkArgument(timeout > 0, "handshake timeout is %s, but must be positive", timeout);
handshakeTimeoutMillis = unit.toMillis(timeout);
return thisT();
}

/**
* Override the default stats implementation.
*/
Expand Down
34 changes: 33 additions & 1 deletion core/src/main/java/io/grpc/internal/ServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -82,6 +84,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
// This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
// creations.
private final ServerInterceptor[] interceptors;
private final long handshakeTimeoutMillis;
@GuardedBy("lock") private boolean started;
@GuardedBy("lock") private boolean shutdown;
/** non-{@code null} if immediate shutdown has been requested. */
Expand Down Expand Up @@ -127,6 +130,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
new ArrayList<ServerTransportFilter>(builder.transportFilters));
this.interceptors =
builder.interceptors.toArray(new ServerInterceptor[builder.interceptors.size()]);
this.handshakeTimeoutMillis = builder.handshakeTimeoutMillis;
}

/**
Expand Down Expand Up @@ -308,7 +312,9 @@ public ServerTransportListener transportCreated(ServerTransport transport) {
synchronized (lock) {
transports.add(transport);
}
return new ServerTransportListenerImpl(transport);
ServerTransportListenerImpl stli = new ServerTransportListenerImpl(transport);
stli.init();
return stli;
}

@Override
Expand Down Expand Up @@ -338,14 +344,36 @@ public void serverShutdown() {

private final class ServerTransportListenerImpl implements ServerTransportListener {
private final ServerTransport transport;
private Future<?> handshakeTimeoutFuture;
private Attributes attributes;

ServerTransportListenerImpl(ServerTransport transport) {
this.transport = transport;
}

public void init() {
class TransportShutdownNow implements Runnable {
@Override public void run() {
transport.shutdownNow(Status.CANCELLED.withDescription("Handshake timeout exceeded"));
}
}

if (handshakeTimeoutMillis != Long.MAX_VALUE) {
handshakeTimeoutFuture = transport.getScheduledExecutorService()
.schedule(new TransportShutdownNow(), handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
} else {
// Noop, to avoid triggering Thread creation in InProcessServer
handshakeTimeoutFuture = new FutureTask<Void>(new Runnable() {
@Override public void run() {}
}, null);
}
}

@Override
public Attributes transportReady(Attributes attributes) {
handshakeTimeoutFuture.cancel(false);
handshakeTimeoutFuture = null;

for (ServerTransportFilter filter : transportFilters) {
attributes = Preconditions.checkNotNull(filter.transportReady(attributes),
"Filter %s returned null", filter);
Expand All @@ -356,6 +384,10 @@ public Attributes transportReady(Attributes attributes) {

@Override
public void transportTerminated() {
if (handshakeTimeoutFuture != null) {
handshakeTimeoutFuture.cancel(false);
handshakeTimeoutFuture = null;
}
for (ServerTransportFilter filter : transportFilters) {
filter.transportTerminated(attributes);
}
Expand Down
32 changes: 32 additions & 0 deletions core/src/test/java/io/grpc/internal/ServerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -346,11 +347,35 @@ public void start(ServerListener listener) throws IOException {
verifyNoMoreInteractions(executorPool);
}

@Test
public void transportHandshakeTimeout_expired() throws Exception {
class ShutdownRecordingTransport extends SimpleServerTransport {
Status shutdownNowStatus;

@Override public void shutdownNow(Status status) {
shutdownNowStatus = status;
super.shutdownNow(status);
}
}

builder.handshakeTimeout(60, TimeUnit.SECONDS);
createAndStartServer();
ShutdownRecordingTransport serverTransport = new ShutdownRecordingTransport();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(serverTransport);
timer.forwardTime(59, TimeUnit.SECONDS);
assertNull("shutdownNow status", serverTransport.shutdownNowStatus);
// Don't call transportReady() in time
timer.forwardTime(2, TimeUnit.SECONDS);
assertNotNull("shutdownNow status", serverTransport.shutdownNowStatus);
}

@Test
public void methodNotFound() throws Exception {
createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
Expand All @@ -377,6 +402,7 @@ public void decompressorNotFound() throws Exception {
createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName);
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -421,6 +447,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
requestHeaders.put(metadataKey, "value");
Expand Down Expand Up @@ -620,6 +647,7 @@ public ServerCall.Listener<String> startCall(

ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -664,6 +692,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -825,6 +854,7 @@ private void checkContext() {
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -890,6 +920,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
Expand Down Expand Up @@ -996,6 +1027,7 @@ public void handlerRegistryPriorities() throws Exception {

ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
Expand Down