From 199b1e40887907f12ad02604cea04c97430dd7fa Mon Sep 17 00:00:00 2001 From: Cassandra Coyle Date: Thu, 15 May 2025 17:53:24 -0500 Subject: [PATCH] Add TLS support for gRPC client Signed-off-by: Cassandra Coyle --- CONTRIBUTING.md | 5 + client/build.gradle | 19 +- .../durabletask/DurableTaskGrpcClient.java | 67 +++- .../DurableTaskGrpcClientBuilder.java | 53 +++ .../DurableTaskGrpcClientTlsTest.java | 330 ++++++++++++++++++ 5 files changed, 467 insertions(+), 7 deletions(-) create mode 100644 CONTRIBUTING.md create mode 100644 client/src/test/java/io/dapr/durabletask/DurableTaskGrpcClientTlsTest.java diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..fd898fa7 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,5 @@ +Build & test: + +```shell +./gradlew build +``` \ No newline at end of file diff --git a/client/build.gradle b/client/build.gradle index c4901f47..9ef733da 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -10,11 +10,11 @@ plugins { } group 'io.dapr' -version = '1.5.3' +version = '1.5.4' archivesBaseName = 'durabletask-client' -def grpcVersion = '1.59.0' -def protocVersion = '3.19.0' +def grpcVersion = '1.69.0' +def protocVersion = '3.25.5' def jacksonVersion = '2.15.3' // When build on local, you need to set this value to your local jdk11 directory. // Java11 is used to compile and run all the tests. @@ -38,6 +38,19 @@ dependencies { testImplementation(platform('org.junit:junit-bom:5.7.2')) testImplementation('org.junit.jupiter:junit-jupiter') + + // Netty dependencies for TLS + implementation "io.grpc:grpc-netty-shaded:${grpcVersion}" + implementation "io.netty:netty-handler:4.1.94.Final" + implementation "io.netty:netty-tcnative-boringssl-static:2.0.59.Final" + + // Add Netty dependencies to test classpath + testImplementation "io.grpc:grpc-netty-shaded:${grpcVersion}" + testImplementation "io.netty:netty-handler:4.1.94.Final" + testImplementation "io.netty:netty-tcnative-boringssl-static:2.0.59.Final" + + testImplementation 'org.bouncycastle:bcprov-jdk15on:1.70' + testImplementation 'org.bouncycastle:bcpkix-jdk15on:1.70' } compileJava { diff --git a/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClient.java b/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClient.java index c058563a..0ec0291a 100644 --- a/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClient.java +++ b/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClient.java @@ -9,6 +9,11 @@ import io.dapr.durabletask.implementation.protobuf.TaskHubSidecarServiceGrpc.*; import io.grpc.*; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.FileInputStream; +import java.io.InputStream; import javax.annotation.Nullable; import java.time.Duration; @@ -17,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Logger; +import java.io.IOException; /** * Durable Task client implementation that uses gRPC to connect to a remote "sidecar" process. @@ -24,6 +30,10 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { private static final int DEFAULT_PORT = 4001; private static final Logger logger = Logger.getLogger(DurableTaskGrpcClient.class.getPackage().getName()); + private static final String GRPC_TLS_CA_PATH = "DAPR_GRPC_TLS_CA_PATH"; + private static final String GRPC_TLS_CERT_PATH = "DAPR_GRPC_TLS_CERT_PATH"; + private static final String GRPC_TLS_KEY_PATH = "DAPR_GRPC_TLS_KEY_PATH"; + private static final String GRPC_TLS_INSECURE = "DAPR_GRPC_TLS_INSECURE"; private final DataConverter dataConverter; private final ManagedChannel managedSidecarChannel; @@ -44,11 +54,60 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { port = builder.port; } + String endpoint = "localhost:" + port; + ManagedChannelBuilder channelBuilder; + + // Get TLS configuration from builder or environment variables + String tlsCaPath = builder.tlsCaPath != null ? builder.tlsCaPath : System.getenv(GRPC_TLS_CA_PATH); + String tlsCertPath = builder.tlsCertPath != null ? builder.tlsCertPath : System.getenv(GRPC_TLS_CERT_PATH); + String tlsKeyPath = builder.tlsKeyPath != null ? builder.tlsKeyPath : System.getenv(GRPC_TLS_KEY_PATH); + boolean insecure = builder.insecure || Boolean.parseBoolean(System.getenv(GRPC_TLS_INSECURE)); + + if (insecure) { + // Insecure mode - uses TLS but doesn't verify certificates + try { + channelBuilder = NettyChannelBuilder.forTarget(endpoint) + .sslContext(GrpcSslContexts.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build()); + } catch (Exception e) { + throw new RuntimeException("Failed to create insecure TLS credentials", e); + } + } else if (tlsCertPath != null && tlsKeyPath != null) { + // mTLS case - using client cert and key, with optional CA cert for server authentication + try ( + InputStream clientCertInputStream = new FileInputStream(tlsCertPath); + InputStream clientKeyInputStream = new FileInputStream(tlsKeyPath); + InputStream caCertInputStream = tlsCaPath != null ? new FileInputStream(tlsCaPath) : null + ) { + TlsChannelCredentials.Builder tlsBuilder = TlsChannelCredentials.newBuilder() + .keyManager(clientCertInputStream, clientKeyInputStream); // For client authentication + if (caCertInputStream != null) { + tlsBuilder.trustManager(caCertInputStream); // For server authentication + } + ChannelCredentials credentials = tlsBuilder.build(); + channelBuilder = Grpc.newChannelBuilder(endpoint, credentials); + } catch (IOException e) { + throw new RuntimeException("Failed to create mTLS credentials" + + (tlsCaPath != null ? " with CA cert" : ""), e); + } + } else if (tlsCaPath != null) { + // Simple TLS case - using CA cert only for server authentication + try (InputStream caCertInputStream = new FileInputStream(tlsCaPath)) { + ChannelCredentials credentials = TlsChannelCredentials.newBuilder() + .trustManager(caCertInputStream) + .build(); + channelBuilder = Grpc.newChannelBuilder(endpoint, credentials); + } catch (IOException e) { + throw new RuntimeException("Failed to create TLS credentials with CA cert", e); + } + } else { + // No TLS config provided, use plaintext + channelBuilder = ManagedChannelBuilder.forTarget(endpoint).usePlaintext(); + } + // Need to keep track of this channel so we can dispose it on close() - this.managedSidecarChannel = ManagedChannelBuilder - .forAddress("localhost", port) - .usePlaintext() - .build(); + this.managedSidecarChannel = channelBuilder.build(); sidecarGrpcChannel = this.managedSidecarChannel; } diff --git a/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClientBuilder.java b/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClientBuilder.java index 050758fc..bac17849 100644 --- a/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClientBuilder.java +++ b/client/src/main/java/io/dapr/durabletask/DurableTaskGrpcClientBuilder.java @@ -12,6 +12,10 @@ public final class DurableTaskGrpcClientBuilder { DataConverter dataConverter; int port; Channel channel; + String tlsCaPath; + String tlsCertPath; + String tlsKeyPath; + boolean insecure; /** * Sets the {@link DataConverter} to use for converting serializable data payloads. @@ -53,6 +57,55 @@ public DurableTaskGrpcClientBuilder port(int port) { return this; } + /** + * Sets the path to the TLS CA certificate file for server authentication. + * If not set, the system's default CA certificates will be used. + * + * @param tlsCaPath path to the TLS CA certificate file + * @return this builder object + */ + public DurableTaskGrpcClientBuilder tlsCaPath(String tlsCaPath) { + this.tlsCaPath = tlsCaPath; + return this; + } + + /** + * Sets the path to the TLS client certificate file for client authentication. + * This is used for mTLS (mutual TLS) connections. + * + * @param tlsCertPath path to the TLS client certificate file + * @return this builder object + */ + public DurableTaskGrpcClientBuilder tlsCertPath(String tlsCertPath) { + this.tlsCertPath = tlsCertPath; + return this; + } + + /** + * Sets the path to the TLS client key file for client authentication. + * This is used for mTLS (mutual TLS) connections. + * + * @param tlsKeyPath path to the TLS client key file + * @return this builder object + */ + public DurableTaskGrpcClientBuilder tlsKeyPath(String tlsKeyPath) { + this.tlsKeyPath = tlsKeyPath; + return this; + } + + /** + * Sets whether to use insecure (plaintext) mode for gRPC communication. + * When set to true, TLS will be disabled and communication will be unencrypted. + * This should only be used for development/testing. + * + * @param insecure whether to use insecure mode + * @return this builder object + */ + public DurableTaskGrpcClientBuilder insecure(boolean insecure) { + this.insecure = insecure; + return this; + } + /** * Initializes a new {@link DurableTaskClient} object with the settings specified in the current builder object. * @return a new {@link DurableTaskClient} object diff --git a/client/src/test/java/io/dapr/durabletask/DurableTaskGrpcClientTlsTest.java b/client/src/test/java/io/dapr/durabletask/DurableTaskGrpcClientTlsTest.java new file mode 100644 index 00000000..07a95773 --- /dev/null +++ b/client/src/test/java/io/dapr/durabletask/DurableTaskGrpcClientTlsTest.java @@ -0,0 +1,330 @@ +package io.dapr.durabletask; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.junit.jupiter.api.Assumptions; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Date; +import java.math.BigInteger; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +import static org.junit.jupiter.api.Assertions.*; + +public class DurableTaskGrpcClientTlsTest { + private static final int DEFAULT_PORT = 4001; + private static final String DEFAULT_SIDECAR_IP = "127.0.0.1"; + + @TempDir + Path tempDir; + + // Track the client for cleanup + private DurableTaskGrpcClient client; + + @AfterEach + void tearDown() throws Exception { + if (client != null) { + client.close(); + client = null; + } + } + + // Helper method to generate a key pair for testing + private static KeyPair generateKeyPair() throws Exception { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + return keyPairGenerator.generateKeyPair(); + } + + // Helper method to generate a self-signed certificate + private static X509Certificate generateCertificate(KeyPair keyPair) throws Exception { + X500Name issuer = new X500Name("CN=Test Certificate"); + X500Name subject = new X500Name("CN=Test Certificate"); + Date notBefore = new Date(System.currentTimeMillis() - 24 * 60 * 60 * 1000); + Date notAfter = new Date(System.currentTimeMillis() + 365 * 24 * 60 * 60 * 1000L); + SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded()); + X509v3CertificateBuilder certBuilder = new X509v3CertificateBuilder( + issuer, + BigInteger.valueOf(System.currentTimeMillis()), + notBefore, + notAfter, + subject, + publicKeyInfo + ); + ContentSigner signer = new JcaContentSignerBuilder("SHA256withRSA").build(keyPair.getPrivate()); + return new JcaX509CertificateConverter().getCertificate(certBuilder.build(signer)); + } + + private static void writeCertificateToFile(X509Certificate cert, File file) throws Exception { + String certPem = "-----BEGIN CERTIFICATE-----\n" + + Base64.getEncoder().encodeToString(cert.getEncoded()) + + "\n-----END CERTIFICATE-----"; + Files.write(file.toPath(), certPem.getBytes()); + } + + private static void writePrivateKeyToFile(KeyPair keyPair, File file) throws Exception { + String keyPem = "-----BEGIN PRIVATE KEY-----\n" + + Base64.getEncoder().encodeToString(keyPair.getPrivate().getEncoded()) + + "\n-----END PRIVATE KEY-----"; + Files.write(file.toPath(), keyPem.getBytes()); + } + + @Test + public void testBuildGrpcManagedChannelWithTls() throws Exception { + // Generate test certificate and key + KeyPair keyPair = generateKeyPair(); + X509Certificate cert = generateCertificate(keyPair); + + File certFile = File.createTempFile("test-cert", ".pem"); + File keyFile = File.createTempFile("test-key", ".pem"); + try { + writeCertificateToFile(cert, certFile); + writePrivateKeyToFile(keyPair, keyFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCertPath(certFile.getAbsolutePath()) + .tlsKeyPath(keyFile.getAbsolutePath()) + .build(); + + assertNotNull(client); + // Note: We can't easily test the actual TLS configuration without a real server + } finally { + certFile.delete(); + keyFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithTlsAndEndpoint() throws Exception { + // Generate test certificate and key + KeyPair keyPair = generateKeyPair(); + X509Certificate cert = generateCertificate(keyPair); + + File certFile = File.createTempFile("test-cert", ".pem"); + File keyFile = File.createTempFile("test-key", ".pem"); + try { + writeCertificateToFile(cert, certFile); + writePrivateKeyToFile(keyPair, keyFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCertPath(certFile.getAbsolutePath()) + .tlsKeyPath(keyFile.getAbsolutePath()) + .port(443) + .build(); + + assertNotNull(client); + } finally { + certFile.delete(); + keyFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithInvalidTlsCert() { + assertThrows(RuntimeException.class, () -> { + new DurableTaskGrpcClientBuilder() + .tlsCertPath("/nonexistent/cert.pem") + .tlsKeyPath("/nonexistent/key.pem") + .build(); + }); + } + + @Test + @EnabledOnOs({OS.LINUX, OS.MAC}) + public void testBuildGrpcManagedChannelWithTlsAndUnixSocket() throws Exception { + // Skip this test since Unix socket support is not implemented yet + Assumptions.assumeTrue(false, "Unix socket support not implemented yet"); + } + + @Test + public void testBuildGrpcManagedChannelWithTlsAndDnsAuthority() throws Exception { + // Generate test certificate and key + KeyPair keyPair = generateKeyPair(); + X509Certificate cert = generateCertificate(keyPair); + + File certFile = File.createTempFile("test-cert", ".pem"); + File keyFile = File.createTempFile("test-key", ".pem"); + try { + writeCertificateToFile(cert, certFile); + writePrivateKeyToFile(keyPair, keyFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCertPath(certFile.getAbsolutePath()) + .tlsKeyPath(keyFile.getAbsolutePath()) + .port(443) + .build(); + + assertNotNull(client); + } finally { + certFile.delete(); + keyFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithTlsAndCaCert() throws Exception { + // Generate test CA certificate + KeyPair caKeyPair = generateKeyPair(); + X509Certificate caCert = generateCertificate(caKeyPair); + + File caCertFile = File.createTempFile("test-ca-cert", ".pem"); + try { + writeCertificateToFile(caCert, caCertFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCaPath(caCertFile.getAbsolutePath()) + .build(); + + assertNotNull(client); + } finally { + caCertFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithTlsAndCaCertAndEndpoint() throws Exception { + // Generate test CA certificate + KeyPair caKeyPair = generateKeyPair(); + X509Certificate caCert = generateCertificate(caKeyPair); + + File caCertFile = File.createTempFile("test-ca-cert", ".pem"); + try { + writeCertificateToFile(caCert, caCertFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCaPath(caCertFile.getAbsolutePath()) + .port(443) + .build(); + + assertNotNull(client); + } finally { + caCertFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithInvalidCaCert() { + assertThrows(RuntimeException.class, () -> { + new DurableTaskGrpcClientBuilder() + .tlsCaPath("/nonexistent/ca.pem") + .build(); + }); + } + + @Test + public void testBuildGrpcManagedChannelWithMtlsAndCaCert() throws Exception { + // Generate test certificates + KeyPair caKeyPair = generateKeyPair(); + X509Certificate caCert = generateCertificate(caKeyPair); + KeyPair clientKeyPair = generateKeyPair(); + X509Certificate clientCert = generateCertificate(clientKeyPair); + + File caCertFile = File.createTempFile("test-ca-cert", ".pem"); + File clientCertFile = File.createTempFile("test-client-cert", ".pem"); + File clientKeyFile = File.createTempFile("test-client-key", ".pem"); + try { + writeCertificateToFile(caCert, caCertFile); + writeCertificateToFile(clientCert, clientCertFile); + writePrivateKeyToFile(clientKeyPair, clientKeyFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .tlsCaPath(caCertFile.getAbsolutePath()) + .tlsCertPath(clientCertFile.getAbsolutePath()) + .tlsKeyPath(clientKeyFile.getAbsolutePath()) + .build(); + + assertNotNull(client); + } finally { + caCertFile.delete(); + clientCertFile.delete(); + clientKeyFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithInsecureTls() throws Exception { + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .insecure(true) + .port(443) + .build(); + + assertNotNull(client); + } + + @Test + public void testBuildGrpcManagedChannelWithInsecureTlsAndMtls() throws Exception { + // Generate test certificates + KeyPair caKeyPair = generateKeyPair(); + X509Certificate caCert = generateCertificate(caKeyPair); + KeyPair clientKeyPair = generateKeyPair(); + X509Certificate clientCert = generateCertificate(clientKeyPair); + + File caCertFile = File.createTempFile("test-ca-cert", ".pem"); + File clientCertFile = File.createTempFile("test-client-cert", ".pem"); + File clientKeyFile = File.createTempFile("test-client-key", ".pem"); + try { + writeCertificateToFile(caCert, caCertFile); + writeCertificateToFile(clientCert, clientCertFile); + writePrivateKeyToFile(clientKeyPair, clientKeyFile); + + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .insecure(true) + .tlsCaPath(caCertFile.getAbsolutePath()) + .tlsCertPath(clientCertFile.getAbsolutePath()) + .tlsKeyPath(clientKeyFile.getAbsolutePath()) + .port(443) + .build(); + + assertNotNull(client); + } finally { + caCertFile.delete(); + clientCertFile.delete(); + clientKeyFile.delete(); + } + } + + @Test + public void testBuildGrpcManagedChannelWithInsecureTlsAndCustomEndpoint() throws Exception { + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .insecure(true) + .port(443) + .build(); + + assertNotNull(client); + } + + @Test + public void testBuildGrpcManagedChannelWithPlaintext() throws Exception { + // No TLS config provided, should use plaintext + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .port(443) + .build(); + + assertNotNull(client); + } + + @Test + public void testBuildGrpcManagedChannelWithPlaintextAndCustomEndpoint() throws Exception { + // No TLS config provided, should use plaintext + client = (DurableTaskGrpcClient) new DurableTaskGrpcClientBuilder() + .port(50001) // Custom port + .build(); + + assertNotNull(client); + } +} \ No newline at end of file