diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index ac9b3301eea..e4a84986976 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -21,9 +21,11 @@ import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.ManagedChannel; import io.grpc.Server; -import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsServerCredentials; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; @@ -69,14 +71,15 @@ public void noNormalNetty() throws Exception { @Test public void serviceLoaderFindsNetty() throws Exception { - assertThat(ServerBuilder.forPort(0)).isInstanceOf(NettyServerBuilder.class); + assertThat(Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())) + .isInstanceOf(NettyServerBuilder.class); assertThat(Grpc.newChannelBuilder("localhost:1234", InsecureChannelCredentials.create())) .isInstanceOf(NettyChannelBuilder.class); } @Test public void basic() throws Exception { - server = ServerBuilder.forPort(0) + server = Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) .addService(new SimpleServiceImpl()) .build().start(); channel = Grpc.newChannelBuilder( @@ -89,8 +92,9 @@ public void basic() throws Exception { @Test public void tcnative() throws Exception { - server = NettyServerBuilder.forPort(0) - .useTransportSecurity(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) + ServerCredentials serverCreds = TlsServerCredentials.create( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); + server = Grpc.newServerBuilderForPort(0, serverCreds) .addService(new SimpleServiceImpl()) .build().start(); ChannelCredentials creds = NettySslContextChannelCredentials.create( diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java index d121c563009..ab962e59d59 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java @@ -18,8 +18,6 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; -import io.netty.channel.ChannelHandler; -import io.netty.util.AsciiString; /** * Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the @@ -50,27 +48,8 @@ final class ClientFactory implements InternalProtocolNegotiator.ClientFactory { @Override public InternalProtocolNegotiator.ProtocolNegotiator newNegotiator() { - final ProtocolNegotiator pn = result.negotiator.newNegotiator(); - final class LocalProtocolNegotiator - implements InternalProtocolNegotiator.ProtocolNegotiator { - - @Override - public AsciiString scheme() { - return pn.scheme(); - } - - @Override - public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - return pn.newHandler(grpcHandler); - } - - @Override - public void close() { - pn.close(); - } - } - - return new LocalProtocolNegotiator(); + return new InternalProtocolNegotiator.ProtocolNegotiatorAdapter( + result.negotiator.newNegotiator()); } @Override diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyServerCredentials.java b/netty/src/main/java/io/grpc/netty/InternalNettyServerCredentials.java new file mode 100644 index 00000000000..16e58d94369 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/InternalNettyServerCredentials.java @@ -0,0 +1,69 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.Internal; +import io.grpc.ServerCredentials; +import io.grpc.internal.ObjectPool; +import java.util.concurrent.Executor; + +/** + * Internal {@link NettyServerCredentials} accessor. This is intended for usage internal to the + * gRPC team. If you *really* think you need to use this, contact the gRPC team first. + */ +@Internal +public final class InternalNettyServerCredentials { + private InternalNettyServerCredentials() {} + + /** Creates a {@link ServerCredentials} that will use the provided {@code negotiator}. */ + public static ServerCredentials create(InternalProtocolNegotiator.ProtocolNegotiator negotiator) { + return NettyServerCredentials.create(ProtocolNegotiators.fixedServerFactory(negotiator)); + } + + /** + * Creates a {@link ServerCredentials} that will use the provided {@code negotiator}. Use of + * {@link #create(io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator)} is preferred over + * this method when possible. + */ + public static ServerCredentials create(InternalProtocolNegotiator.ServerFactory negotiator) { + return NettyServerCredentials.create(negotiator); + } + + /** + * Converts a {@link ServerCredentials} to a negotiator, in similar fashion as for a new server. + * + * @throws IllegalArgumentException if unable to convert + */ + public static InternalProtocolNegotiator.ServerFactory toNegotiator( + ServerCredentials channelCredentials) { + final ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(channelCredentials); + if (result.error != null) { + throw new IllegalArgumentException(result.error); + } + final class ServerFactory implements InternalProtocolNegotiator.ServerFactory { + @Override + public InternalProtocolNegotiator.ProtocolNegotiator newNegotiator( + ObjectPool offloadExecutorPool) { + return new InternalProtocolNegotiator.ProtocolNegotiatorAdapter( + result.negotiator.newNegotiator(offloadExecutorPool)); + } + } + + return new ServerFactory(); + } +} diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java index 0efa85eea75..1863ec07ced 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiator.java @@ -16,7 +16,12 @@ package io.grpc.netty; +import com.google.common.base.Preconditions; import io.grpc.Internal; +import io.grpc.internal.ObjectPool; +import io.netty.channel.ChannelHandler; +import io.netty.util.AsciiString; +import java.util.concurrent.Executor; /** * Internal accessor for {@link ProtocolNegotiator}. @@ -28,7 +33,35 @@ private InternalProtocolNegotiator() {} public interface ProtocolNegotiator extends io.grpc.netty.ProtocolNegotiator {} + static final class ProtocolNegotiatorAdapter + implements InternalProtocolNegotiator.ProtocolNegotiator { + private final io.grpc.netty.ProtocolNegotiator negotiator; + + public ProtocolNegotiatorAdapter(io.grpc.netty.ProtocolNegotiator negotiator) { + this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); + } + + @Override + public AsciiString scheme() { + return negotiator.scheme(); + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + return negotiator.newHandler(grpcHandler); + } + + @Override + public void close() { + negotiator.close(); + } + } + public interface ClientFactory extends io.grpc.netty.ProtocolNegotiator.ClientFactory { @Override ProtocolNegotiator newNegotiator(); } + + public interface ServerFactory extends io.grpc.netty.ProtocolNegotiator.ServerFactory { + @Override ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 520e3bfbd9a..70b97d8d5c3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -29,6 +29,7 @@ import io.grpc.ExperimentalApi; import io.grpc.Internal; import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.FixedObjectPool; @@ -58,7 +59,6 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import javax.annotation.CheckReturnValue; -import javax.annotation.Nullable; import javax.net.ssl.SSLException; /** @@ -98,8 +98,8 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder workerEventLoopGroupPool = DEFAULT_WORKER_EVENT_LOOP_GROUP_POOL; private boolean forceHeapBuffer; - private SslContext sslContext; - private ProtocolNegotiator protocolNegotiator; + private ProtocolNegotiator.ServerFactory protocolNegotiatorFactory; + private final boolean freezeProtocolNegotiatorFactory; private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE; private boolean autoFlowControl = true; private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; @@ -121,7 +121,18 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder buildClientTransportServers( @@ -144,15 +170,20 @@ public List buildClientTransportServers( } @CheckReturnValue - private NettyServerBuilder(int port) { + private NettyServerBuilder(SocketAddress address) { serverImplBuilder = new ServerImplBuilder(new NettyClientTransportServersBuilder()); - this.listenAddresses.add(new InetSocketAddress(port)); + this.listenAddresses.add(address); + this.protocolNegotiatorFactory = ProtocolNegotiators.serverPlaintextFactory(); + this.freezeProtocolNegotiatorFactory = false; } @CheckReturnValue - private NettyServerBuilder(SocketAddress address) { + NettyServerBuilder( + SocketAddress address, ProtocolNegotiator.ServerFactory negotiatorFactory) { serverImplBuilder = new ServerImplBuilder(new NettyClientTransportServersBuilder()); this.listenAddresses.add(address); + this.protocolNegotiatorFactory = checkNotNull(negotiatorFactory, "negotiatorFactory"); + this.freezeProtocolNegotiatorFactory = true; } @Internal @@ -317,25 +348,28 @@ void setForceHeapBuffer(boolean value) { * have been configured with {@link GrpcSslContexts}, but options could have been overridden. */ public NettyServerBuilder sslContext(SslContext sslContext) { + checkState(!freezeProtocolNegotiatorFactory, + "Cannot change security when using ServerCredentials"); if (sslContext != null) { checkArgument(sslContext.isServer(), "Client SSL context can not be used for server"); GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); + protocolNegotiatorFactory = ProtocolNegotiators.serverTlsFactory(sslContext); + } else { + protocolNegotiatorFactory = ProtocolNegotiators.serverPlaintextFactory(); } - this.sslContext = sslContext; return this; } /** - * Sets the {@link ProtocolNegotiator} to be used. If non-{@code null}, overrides the value - * specified in {@link #sslContext(SslContext)}. - * - *

Default: {@code null}. + * Sets the {@link ProtocolNegotiator} to be used. Overrides the value specified in {@link + * #sslContext(SslContext)}. */ @Internal - public final NettyServerBuilder protocolNegotiator( - @Nullable ProtocolNegotiator protocolNegotiator) { - this.protocolNegotiator = protocolNegotiator; + public final NettyServerBuilder protocolNegotiator(ProtocolNegotiator protocolNegotiator) { + checkState(!freezeProtocolNegotiatorFactory, + "Cannot change security when using ServerCredentials"); + this.protocolNegotiatorFactory = ProtocolNegotiators.fixedServerFactory(protocolNegotiator); return this; } @@ -586,12 +620,8 @@ List buildTransportServers( List streamTracerFactories) { assertEventLoopsAndChannelType(); - ProtocolNegotiator negotiator = protocolNegotiator; - if (negotiator == null) { - negotiator = sslContext != null - ? ProtocolNegotiators.serverTls(sslContext, this.serverImplBuilder.getExecutorPool()) - : ProtocolNegotiators.serverPlaintext(); - } + ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator( + this.serverImplBuilder.getExecutorPool()); List transportServers = new ArrayList<>(listenAddresses.size()); for (SocketAddress listenAddress : listenAddresses) { @@ -631,23 +661,31 @@ NettyServerBuilder setTransportTracerFactory( @Override public NettyServerBuilder useTransportSecurity(File certChain, File privateKey) { + checkState(!freezeProtocolNegotiatorFactory, + "Cannot change security when using ServerCredentials"); + SslContext sslContext; try { sslContext = GrpcSslContexts.forServer(certChain, privateKey).build(); } catch (SSLException e) { // This should likely be some other, easier to catch exception. throw new RuntimeException(e); } + protocolNegotiatorFactory = ProtocolNegotiators.serverTlsFactory(sslContext); return this; } @Override public NettyServerBuilder useTransportSecurity(InputStream certChain, InputStream privateKey) { + checkState(!freezeProtocolNegotiatorFactory, + "Cannot change security when using ServerCredentials"); + SslContext sslContext; try { sslContext = GrpcSslContexts.forServer(certChain, privateKey).build(); } catch (SSLException e) { // This should likely be some other, easier to catch exception. throw new RuntimeException(e); } + protocolNegotiatorFactory = ProtocolNegotiators.serverTlsFactory(sslContext); return this; } } diff --git a/netty/src/main/java/io/grpc/netty/NettyServerCredentials.java b/netty/src/main/java/io/grpc/netty/NettyServerCredentials.java new file mode 100644 index 00000000000..bc8b4f5f621 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NettyServerCredentials.java @@ -0,0 +1,37 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.base.Preconditions; +import io.grpc.ServerCredentials; + +/** A credential with full control over the security handshake. */ +final class NettyServerCredentials extends ServerCredentials { + public static ServerCredentials create(ProtocolNegotiator.ServerFactory negotiator) { + return new NettyServerCredentials(negotiator); + } + + private final ProtocolNegotiator.ServerFactory negotiator; + + private NettyServerCredentials(ProtocolNegotiator.ServerFactory negotiator) { + this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); + } + + public ProtocolNegotiator.ServerFactory getNegotiator() { + return negotiator; + } +} diff --git a/netty/src/main/java/io/grpc/netty/NettyServerProvider.java b/netty/src/main/java/io/grpc/netty/NettyServerProvider.java index 23b1e66c940..42d075d05cb 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerProvider.java @@ -17,10 +17,11 @@ package io.grpc.netty; import io.grpc.Internal; +import io.grpc.ServerCredentials; import io.grpc.ServerProvider; +import java.net.InetSocketAddress; - -/** Provider for {@link NettyChannelBuilder} instances. */ +/** Provider for {@link NettyServerBuilder} instances. */ @Internal public final class NettyServerProvider extends ServerProvider { @@ -38,5 +39,15 @@ protected int priority() { protected NettyServerBuilder builderForPort(int port) { return NettyServerBuilder.forPort(port); } + + @Override + protected NewServerBuilderResult newServerBuilderForPort(int port, ServerCredentials creds) { + ProtocolNegotiators.FromServerCredentialsResult result = ProtocolNegotiators.from(creds); + if (result.error != null) { + return NewServerBuilderResult.error(result.error); + } + return NewServerBuilderResult.serverBuilder( + new NettyServerBuilder(new InetSocketAddress(port), result.negotiator)); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java index ff8a24669bc..ede511b68f6 100644 --- a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java @@ -16,6 +16,7 @@ package io.grpc.netty; +import com.google.common.base.Preconditions; import io.grpc.ChannelCredentials; import io.grpc.ExperimentalApi; import io.netty.handler.ssl.SslContext; @@ -30,6 +31,9 @@ private NettySslContextChannelCredentials() {} * with {@link GrpcSslContexts}, but options could have been overridden. */ public static ChannelCredentials create(SslContext sslContext) { + Preconditions.checkArgument(sslContext.isClient(), + "Server SSL context can not be used for client channel"); + GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext)); } } diff --git a/netty/src/main/java/io/grpc/netty/NettySslContextServerCredentials.java b/netty/src/main/java/io/grpc/netty/NettySslContextServerCredentials.java new file mode 100644 index 00000000000..9396cabdd23 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NettySslContextServerCredentials.java @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.base.Preconditions; +import io.grpc.ExperimentalApi; +import io.grpc.ServerCredentials; +import io.netty.handler.ssl.SslContext; + +/** A credential that performs TLS with Netty's SslContext as configuration. */ +@ExperimentalApi("There is no plan to make this API stable, given transport API instability") +public final class NettySslContextServerCredentials { + private NettySslContextServerCredentials() {} + + /** + * Create a credential using Netty's SslContext as configuration. It must have been configured + * with {@link GrpcSslContexts}, but options could have been overridden. + */ + public static ServerCredentials create(SslContext sslContext) { + Preconditions.checkArgument(sslContext.isServer(), + "Client SSL context can not be used for server"); + GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); + return NettyServerCredentials.create(ProtocolNegotiators.serverTlsFactory(sslContext)); + } +} diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java index 6307b97c8a0..8a2c6f104b2 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java @@ -16,8 +16,10 @@ package io.grpc.netty; +import io.grpc.internal.ObjectPool; import io.netty.channel.ChannelHandler; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; /** * An class that provides a Netty handler to control protocol negotiation. @@ -52,4 +54,13 @@ interface ClientFactory { /** Returns the implicit port to use if no port was specified explicitly by the user. */ int getDefaultPort(); } + + interface ServerFactory { + /** + * Creates a new negotiator. + * + * @param offloadExecutorPool an executor pool for time-consuming tasks + */ + ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool); + } } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 7dbbddf9fc3..3779d050095 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -28,15 +28,19 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ChoiceChannelCredentials; +import io.grpc.ChoiceServerCredentials; import io.grpc.CompositeCallCredentials; import io.grpc.CompositeChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.InternalChannelz.Security; import io.grpc.InternalChannelz.Tls; import io.grpc.SecurityLevel; +import io.grpc.ServerCredentials; import io.grpc.Status; import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -57,12 +61,14 @@ import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslEngine; import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.ssl.SslProvider; import io.netty.util.AsciiString; import io.netty.util.Attribute; import io.netty.util.AttributeMap; +import java.io.ByteArrayInputStream; import java.net.SocketAddress; import java.net.URI; import java.nio.channels.ClosedChannelException; @@ -85,6 +91,8 @@ final class ProtocolNegotiators { private static final Logger log = Logger.getLogger(ProtocolNegotiators.class.getName()); private static final EnumSet understoodTlsFeatures = EnumSet.noneOf(TlsChannelCredentials.Feature.class); + private static final EnumSet understoodServerTlsFeatures = + EnumSet.noneOf(TlsServerCredentials.Feature.class); private ProtocolNegotiators() { @@ -167,6 +175,72 @@ public FromChannelCredentialsResult withCallCredentials(CallCredentials callCred } } + public static FromServerCredentialsResult from(ServerCredentials creds) { + if (creds instanceof TlsServerCredentials) { + TlsServerCredentials tlsCreds = (TlsServerCredentials) creds; + Set incomprehensible = + tlsCreds.incomprehensible(understoodServerTlsFeatures); + if (!incomprehensible.isEmpty()) { + return FromServerCredentialsResult.error( + "TLS features not understood: " + incomprehensible); + } + SslContextBuilder builder = GrpcSslContexts.forServer( + new ByteArrayInputStream(tlsCreds.getCertificateChain()), + new ByteArrayInputStream(tlsCreds.getPrivateKey()), + tlsCreds.getPrivateKeyPassword()); + SslContext sslContext; + try { + sslContext = builder.build(); + } catch (SSLException ex) { + throw new IllegalArgumentException( + "Unexpected error converting ServerCredentials to Netty SslContext", ex); + } + return FromServerCredentialsResult.negotiator(serverTlsFactory(sslContext)); + + } else if (creds instanceof InsecureServerCredentials) { + return FromServerCredentialsResult.negotiator(serverPlaintextFactory()); + + } else if (creds instanceof NettyServerCredentials) { + NettyServerCredentials nettyCreds = (NettyServerCredentials) creds; + return FromServerCredentialsResult.negotiator(nettyCreds.getNegotiator()); + + } else if (creds instanceof ChoiceServerCredentials) { + ChoiceServerCredentials choiceCreds = (ChoiceServerCredentials) creds; + StringBuilder error = new StringBuilder(); + for (ServerCredentials innerCreds : choiceCreds.getCredentialsList()) { + FromServerCredentialsResult result = from(innerCreds); + if (result.error == null) { + return result; + } + error.append(", "); + error.append(result.error); + } + return FromServerCredentialsResult.error(error.substring(2)); + + } else { + return FromServerCredentialsResult.error( + "Unsupported credential type: " + creds.getClass().getName()); + } + } + + public static final class FromServerCredentialsResult { + public final ProtocolNegotiator.ServerFactory negotiator; + public final String error; + + private FromServerCredentialsResult(ProtocolNegotiator.ServerFactory negotiator, String error) { + this.negotiator = negotiator; + this.error = error; + } + + public static FromServerCredentialsResult error(String error) { + return new FromServerCredentialsResult(null, Preconditions.checkNotNull(error, "error")); + } + + public static FromServerCredentialsResult negotiator(ProtocolNegotiator.ServerFactory factory) { + return new FromServerCredentialsResult(Preconditions.checkNotNull(factory, "factory"), null); + } + } + static ChannelLogger negotiationLogger(ChannelHandlerContext ctx) { return negotiationLogger(ctx.channel()); } @@ -190,6 +264,26 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} return new NoopChannelLogger(); } + public static ProtocolNegotiator.ServerFactory fixedServerFactory( + ProtocolNegotiator negotiator) { + return new FixedProtocolNegotiatorServerFactory(negotiator); + } + + private static final class FixedProtocolNegotiatorServerFactory + implements ProtocolNegotiator.ServerFactory { + private final ProtocolNegotiator protocolNegotiator; + + public FixedProtocolNegotiatorServerFactory(ProtocolNegotiator protocolNegotiator) { + this.protocolNegotiator = + Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator"); + } + + @Override + public ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool) { + return protocolNegotiator; + } + } + /** * Create a server plaintext handler for gRPC. */ @@ -197,6 +291,41 @@ public static ProtocolNegotiator serverPlaintext() { return new PlaintextProtocolNegotiator(); } + /** + * Create a server plaintext handler factory for gRPC. + */ + public static ProtocolNegotiator.ServerFactory serverPlaintextFactory() { + return new PlaintextProtocolNegotiatorServerFactory(); + } + + @VisibleForTesting + static final class PlaintextProtocolNegotiatorServerFactory + implements ProtocolNegotiator.ServerFactory { + @Override + public ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool) { + return serverPlaintext(); + } + } + + public static ProtocolNegotiator.ServerFactory serverTlsFactory(SslContext sslContext) { + return new TlsProtocolNegotiatorServerFactory(sslContext); + } + + @VisibleForTesting + static final class TlsProtocolNegotiatorServerFactory + implements ProtocolNegotiator.ServerFactory { + private final SslContext sslContext; + + public TlsProtocolNegotiatorServerFactory(SslContext sslContext) { + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + } + + @Override + public ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool) { + return serverTls(sslContext, offloadExecutorPool); + } + } + /** * Create a server TLS handler for HTTP/2 capable of using ALPN/NPN. * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks diff --git a/netty/src/test/java/io/grpc/netty/NettyServerProviderTest.java b/netty/src/test/java/io/grpc/netty/NettyServerProviderTest.java index 70dbd643d58..6301ca51bef 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerProviderTest.java @@ -16,10 +16,13 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import io.grpc.InsecureServerCredentials; +import io.grpc.ServerCredentials; import io.grpc.ServerProvider; import org.junit.Test; import org.junit.runner.RunWith; @@ -45,4 +48,20 @@ public void basicMethods() { public void builderIsANettyBuilder() { assertSame(NettyServerBuilder.class, provider.builderForPort(443).getClass()); } + + @Test + public void newServerBuilderForPort_success() { + ServerProvider.NewServerBuilderResult result = + provider.newServerBuilderForPort(80, InsecureServerCredentials.create()); + assertThat(result.getServerBuilder()).isInstanceOf(NettyServerBuilder.class); + } + + @Test + public void newServerBuilderForPort_fail() { + ServerProvider.NewServerBuilderResult result = provider.newServerBuilderForPort( + 80, new FakeServerCredentials()); + assertThat(result.getError()).contains("FakeServerCredentials"); + } + + private static final class FakeServerCredentials extends ServerCredentials {} } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 2e87a089e89..e5fc1b74c90 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -33,14 +33,18 @@ import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.ChoiceChannelCredentials; +import io.grpc.ChoiceServerCredentials; import io.grpc.CompositeChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.InternalChannelz.Security; import io.grpc.SecurityLevel; +import io.grpc.ServerCredentials; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; @@ -110,6 +114,7 @@ import javax.net.ssl.SSLSession; import org.junit.After; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; @@ -128,6 +133,15 @@ public class ProtocolNegotiatorsTest { @Override public void run() {} }; + private static File server1Cert; + private static File server1Key; + + @BeforeClass + public static void loadCerts() throws Exception { + server1Cert = TestUtils.loadCert("server1.pem"); + server1Key = TestUtils.loadCert("server1.key"); + } + private static final int TIMEOUT_SECONDS = 60; @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS)); @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @@ -168,7 +182,7 @@ public void tearDown() { } @Test - public void from_unknown() { + public void fromClient_unknown() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(new ChannelCredentials() {}); assertThat(result.error).isNotNull(); @@ -177,7 +191,7 @@ public void from_unknown() { } @Test - public void from_tls() { + public void fromClient_tls() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(TlsChannelCredentials.create()); assertThat(result.error).isNull(); @@ -187,7 +201,7 @@ public void from_tls() { } @Test - public void from_unspportedTls() { + public void fromClient_unsupportedTls() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(TlsChannelCredentials.newBuilder().requireFakeFeature().build()); assertThat(result.error).contains("FAKE"); @@ -196,7 +210,7 @@ public void from_unspportedTls() { } @Test - public void from_insecure() { + public void fromClient_insecure() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(InsecureChannelCredentials.create()); assertThat(result.error).isNull(); @@ -206,7 +220,7 @@ public void from_insecure() { } @Test - public void from_composite() { + public void fromClient_composite() { CallCredentials callCredentials = mock(CallCredentials.class); ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(CompositeChannelCredentials.create( @@ -225,7 +239,7 @@ public void from_composite() { } @Test - public void from_netty() { + public void fromClient_netty() { ProtocolNegotiator.ClientFactory factory = mock(ProtocolNegotiator.ClientFactory.class); ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(NettyChannelCredentials.create(factory)); @@ -235,7 +249,7 @@ public void from_netty() { } @Test - public void from_choice() { + public void fromClient_choice() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(ChoiceChannelCredentials.create( new ChannelCredentials() {}, @@ -257,7 +271,7 @@ public void from_choice() { } @Test - public void from_choice_unknown() { + public void fromClient_choice_unknown() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(ChoiceChannelCredentials.create( new ChannelCredentials() {})); @@ -266,6 +280,82 @@ public void from_choice_unknown() { assertThat(result.negotiator).isNull(); } + @Test + public void fromServer_unknown() { + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(new ServerCredentials() {}); + assertThat(result.error).isNotNull(); + assertThat(result.negotiator).isNull(); + } + + @Test + public void fromServer_tls() throws Exception { + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(TlsServerCredentials.create(server1Cert, server1Key)); + assertThat(result.error).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class); + } + + @Test + public void fromServer_unsupportedTls() throws Exception { + ProtocolNegotiators.FromServerCredentialsResult result = ProtocolNegotiators.from( + TlsServerCredentials.newBuilder() + .keyManager(server1Cert, server1Key) + .requireFakeFeature() + .build()); + assertThat(result.error).contains("FAKE"); + assertThat(result.negotiator).isNull(); + } + + @Test + public void fromServer_insecure() { + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(InsecureServerCredentials.create()); + assertThat(result.error).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class); + } + + @Test + public void fromServer_netty() { + ProtocolNegotiator.ServerFactory factory = mock(ProtocolNegotiator.ServerFactory.class); + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(NettyServerCredentials.create(factory)); + assertThat(result.error).isNull(); + assertThat(result.negotiator).isSameInstanceAs(factory); + } + + @Test + public void fromServer_choice() throws Exception { + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(ChoiceServerCredentials.create( + new ServerCredentials() {}, + TlsServerCredentials.create(server1Cert, server1Key), + InsecureServerCredentials.create())); + assertThat(result.error).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class); + + result = ProtocolNegotiators.from(ChoiceServerCredentials.create( + InsecureServerCredentials.create(), + new ServerCredentials() {}, + TlsServerCredentials.create(server1Cert, server1Key))); + assertThat(result.error).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class); + } + + @Test + public void fromServer_choice_unknown() { + ProtocolNegotiators.FromServerCredentialsResult result = + ProtocolNegotiators.from(ChoiceServerCredentials.create( + new ServerCredentials() {})); + assertThat(result.error).isNotNull(); + assertThat(result.negotiator).isNull(); + } + + @Test public void waitUntilActiveHandler_handlerAdded() throws Exception { final CountDownLatch latch = new CountDownLatch(1);