Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions core/src/main/java/io/grpc/ServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.File;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -172,6 +173,20 @@ public T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704")
public abstract T compressorRegistry(@Nullable CompressorRegistry registry);

/**
* Sets the permitted time for new connections to complete negotiation handshakes before being
* killed.
*
* @return this
* @throws IllegalArgumentException if timeout is negative
* @throws UnsupportedOperationException if unsupported
* @since 1.8.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/3706")
public T handshakeTimeout(long timeout, TimeUnit unit) {
throw new UnsupportedOperationException();
}

/**
* Builds a server using the given parameters.
*
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.grpc.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
Expand All @@ -39,6 +40,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -71,6 +73,7 @@ public List<ServerServiceDefinition> getServices() {
DecompressorRegistry.getDefaultInstance();
private static final CompressorRegistry DEFAULT_COMPRESSOR_REGISTRY =
CompressorRegistry.getDefaultInstance();
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(20);

final InternalHandlerRegistry.Builder registryBuilder =
new InternalHandlerRegistry.Builder();
Expand All @@ -94,6 +97,8 @@ public List<ServerServiceDefinition> getServices() {

CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY;

long handshakeTimeoutMillis = DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;

@Nullable
private CensusStatsModule censusStatsOverride;

Expand Down Expand Up @@ -178,6 +183,13 @@ public final T compressorRegistry(CompressorRegistry registry) {
return thisT();
}

@Override
public final T handshakeTimeout(long timeout, TimeUnit unit) {
checkArgument(timeout > 0, "handshake timeout is %s, but must be positive", timeout);
handshakeTimeoutMillis = unit.toMillis(timeout);
return thisT();
}

/**
* Override the default stats implementation.
*/
Expand Down
26 changes: 25 additions & 1 deletion core/src/main/java/io/grpc/internal/ServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -82,6 +83,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
// This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
// creations.
private final ServerInterceptor[] interceptors;
private final long handshakeTimeoutMillis;
@GuardedBy("lock") private boolean started;
@GuardedBy("lock") private boolean shutdown;
/** non-{@code null} if immediate shutdown has been requested. */
Expand Down Expand Up @@ -127,6 +129,7 @@ public final class ServerImpl extends io.grpc.Server implements WithLogId {
new ArrayList<ServerTransportFilter>(builder.transportFilters));
this.interceptors =
builder.interceptors.toArray(new ServerInterceptor[builder.interceptors.size()]);
this.handshakeTimeoutMillis = builder.handshakeTimeoutMillis;
}

/**
Expand Down Expand Up @@ -308,7 +311,9 @@ public ServerTransportListener transportCreated(ServerTransport transport) {
synchronized (lock) {
transports.add(transport);
}
return new ServerTransportListenerImpl(transport);
ServerTransportListenerImpl stli = new ServerTransportListenerImpl(transport);
stli.init();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the precise point that protocol negotiation starts, and transportReady the precise point that it ends? InProcessServer also does this? Does it make sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point a TCP connection has been established, but nothing else. transportReady is called before any RPCs are triggered.

return stli;
}

@Override
Expand Down Expand Up @@ -338,14 +343,29 @@ public void serverShutdown() {

private final class ServerTransportListenerImpl implements ServerTransportListener {
private final ServerTransport transport;
private Future<?> handshakeTimeoutFuture;
private Attributes attributes;

ServerTransportListenerImpl(ServerTransport transport) {
this.transport = transport;
}

public void init() {
class TransportShutdownNow implements Runnable {
@Override public void run() {
transport.shutdownNow(Status.CANCELLED.withDescription("Handshake timeout exceeded"));
}
}

handshakeTimeoutFuture = transport.getScheduledExecutorService()
.schedule(new TransportShutdownNow(), handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
}

@Override
public Attributes transportReady(Attributes attributes) {
handshakeTimeoutFuture.cancel(false);
handshakeTimeoutFuture = null;

for (ServerTransportFilter filter : transportFilters) {
attributes = Preconditions.checkNotNull(filter.transportReady(attributes),
"Filter %s returned null", filter);
Expand All @@ -356,6 +376,10 @@ public Attributes transportReady(Attributes attributes) {

@Override
public void transportTerminated() {
if (handshakeTimeoutFuture != null) {
handshakeTimeoutFuture.cancel(false);
handshakeTimeoutFuture = null;
}
for (ServerTransportFilter filter : transportFilters) {
filter.transportTerminated(attributes);
}
Expand Down
32 changes: 32 additions & 0 deletions core/src/test/java/io/grpc/internal/ServerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -348,11 +349,35 @@ public void start(ServerListener listener) throws IOException {
verifyNoMoreInteractions(executorPool);
}

@Test
public void transportHandshakeTimeout_expired() throws Exception {
class ShutdownRecordingTransport extends SimpleServerTransport {
Status shutdownNowStatus;

@Override public void shutdownNow(Status status) {
shutdownNowStatus = status;
super.shutdownNow(status);
}
}

builder.handshakeTimeout(60, TimeUnit.SECONDS);
createAndStartServer();
ShutdownRecordingTransport serverTransport = new ShutdownRecordingTransport();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(serverTransport);
timer.forwardTime(59, TimeUnit.SECONDS);
assertNull("shutdownNow status", serverTransport.shutdownNowStatus);
// Don't call transportReady() in time
timer.forwardTime(2, TimeUnit.SECONDS);
assertNotNull("shutdownNow status", serverTransport.shutdownNowStatus);
}

@Test
public void methodNotFound() throws Exception {
createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
Expand All @@ -379,6 +404,7 @@ public void decompressorNotFound() throws Exception {
createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
requestHeaders.put(MESSAGE_ENCODING_KEY, decompressorName);
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -423,6 +449,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
requestHeaders.put(metadataKey, "value");
Expand Down Expand Up @@ -622,6 +649,7 @@ public ServerCall.Listener<String> startCall(

ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -666,6 +694,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -827,6 +856,7 @@ private void checkContext() {
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);

Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
Expand Down Expand Up @@ -892,6 +922,7 @@ public ServerCall.Listener<String> startCall(
}).build());
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
Expand Down Expand Up @@ -998,6 +1029,7 @@ public void handlerRegistryPriorities() throws Exception {

ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
Expand Down
16 changes: 15 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class NettyServerHandler extends AbstractNettyHandler {
private final List<ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer transportTracer;
private final KeepAliveEnforcer keepAliveEnforcer;
/** Incomplete attributes produced by negotiator. */
private Attributes negotiationAttributes;
/** Completed attributes produced by transportReady. */
private Attributes attributes;
private Throwable connectionError;
private boolean teWarningLogged;
Expand Down Expand Up @@ -481,7 +484,7 @@ protected void onStreamError(ChannelHandlerContext ctx, Throwable cause,

@Override
public void handleProtocolNegotiationCompleted(Attributes attrs) {
attributes = transportListener.transportReady(attrs);
negotiationAttributes = attrs;
}

@VisibleForTesting
Expand Down Expand Up @@ -680,6 +683,17 @@ private Http2Exception newStreamException(int streamId, Throwable cause) {
}

private class FrameListener extends Http2FrameAdapter {
private boolean firstSettings = true;

@Override
public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) {
if (firstSettings) {
firstSettings = false;
// Delay transportReady until we see the client's HTTP handshake, for coverage with
// handshakeTimeout
attributes = transportListener.transportReady(negotiationAttributes);
}
}

@Override
public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding,
Expand Down
32 changes: 24 additions & 8 deletions netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@
import java.util.List;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
Expand Down Expand Up @@ -136,7 +136,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
private long keepAliveTimeInNanos = DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
private long keepAliveTimeoutInNanos = DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
private TransportTracer transportTracer;
private TransportTracer transportTracer = new TransportTracer();

private class ServerTransportListenerImpl implements ServerTransportListener {

Expand All @@ -155,12 +155,10 @@ public void transportTerminated() {
}
}

@Override
protected void manualSetUp() throws Exception {
assertNull("manualSetUp should not run more than once", handler());

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
transportTracer = new TransportTracer();

when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
.thenReturn(streamTracer);

Expand All @@ -178,7 +176,12 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
}
})
.when(streamListener)
.messagesAvailable(Matchers.<StreamListener.MessageProducer>any());
.messagesAvailable(any(StreamListener.MessageProducer.class));
}

@Override
protected void manualSetUp() throws Exception {
assertNull("manualSetUp should not run more than once", handler());

initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));

Expand All @@ -195,6 +198,19 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
channelRead(serializedSettings);
}

@Test
public void transportReadyDelayedUntilConnectionPreface() throws Exception {
initChannel(new GrpcHttp2ServerHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE));

handler().handleProtocolNegotiationCompleted(Attributes.EMPTY);
verify(transportListener, never()).transportReady(any(Attributes.class));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/any/isA/ ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're equivalent here. But isA would actually be worse, because if I had the wrong class I could miss callbacks that I wanted to trigger failure. I'm not trying to verify it is an Attributes; I just checking the method which happens to have certain arguments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this verify checks never(), I should have meant the second verify.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the second verify would fall into "I'm not trying to verify it is an Attributes".

// Simulate receipt of the connection preface
channelRead(Http2CodecUtil.connectionPrefaceBuf());
channelRead(serializeSettings(new Http2Settings()));
verify(transportListener).transportReady(any(Attributes.class));
}

@Test
public void sendFrameShouldSucceed() throws Exception {
manualSetUp();
Expand Down