From 03c06316ae76b9753f3b724afb5e6e3450b80475 Mon Sep 17 00:00:00 2001 From: Rob Rudin Date: Thu, 26 Jan 2023 15:27:51 -0500 Subject: [PATCH] Improving how cloud auth applies default SSLContext The check is now made in `DatabaseClientPropertySource`, which is at a lower level than `DatabaseClientBuilder`. This ensures that e.g. an ml-gradle user gets the benefit of not having to specify "sslProtocol=default" if their intent is to use their JVM's default truststore. Also improved the type-checking in DatabaseClientPropertySource so that a friendly error message is thrown for any value that is of an incorrect type. --- .../client/DatabaseClientBuilder.java | 11 +- .../client/DatabaseClientFactory.java | 3 +- .../impl/DatabaseClientPropertySource.java | 201 +++++++++++++----- .../DatabaseClientPropertySourceTest.java | 28 +++ .../test/DatabaseClientBuilderTest.java | 10 +- 5 files changed, 187 insertions(+), 66 deletions(-) diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientBuilder.java b/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientBuilder.java index 9498ce2ff..38ed6e8ee 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientBuilder.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientBuilder.java @@ -121,7 +121,6 @@ public DatabaseClientBuilder withSecurityContext(DatabaseClientFactory.SecurityC } /** - * * @param type must be one of "basic", "digest", "cloud", "kerberos", "certificate", or "saml" * @return */ @@ -143,17 +142,9 @@ public DatabaseClientBuilder withDigestAuth(String username, String password) { } public DatabaseClientBuilder withMarkLogicCloudAuth(String apiKey, String basePath) { - withSecurityContextType(SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD) + return withSecurityContextType(SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD) .withCloudApiKey(apiKey) .withBasePath(basePath); - - // Assume sensible defaults for establishing an SSL connection. In the scenario where the user's JVM's - // truststore has a certificate matching that of the MarkLogic Cloud instance, this saves the user from having - // to configure anything except the API key and base path. - if (null == props.get(PREFIX + "sslProtocol") && null == props.get(PREFIX + "sslContext")) { - withSSLProtocol("default"); - } - return this; } public DatabaseClientBuilder withKerberosAuth(String principal) { diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientFactory.java b/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientFactory.java index 02ba2d2fe..c5163f9b8 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientFactory.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/DatabaseClientFactory.java @@ -1230,7 +1230,8 @@ public String getCertificatePassword() { *
  • marklogic.client.database = must be a String
  • *
  • marklogic.client.connectionType = must be a String or instance of {@code ConnectionType}
  • *
  • marklogic.client.securityContext = an instance of {@code SecurityContext}; if set, then all other - * properties pertaining to the construction of a {@code SecurityContext} will be ignored
  • + * properties pertaining to the construction of a {@code SecurityContext} will be ignored, including the + * properties pertaing to SSL *
  • marklogic.client.securityContextType = required if marklogic.client.securityContext is not set; * must be a String and one of "basic", "digest", "cloud", "kerberos", "certificate", or "saml"
  • *
  • marklogic.client.username = must be a String; required for basic and digest authentication
  • diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java b/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java index 5d12d89c5..bea899723 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java @@ -49,16 +49,36 @@ public class DatabaseClientPropertySource { static { connectionPropertyHandlers = new LinkedHashMap<>(); - connectionPropertyHandlers.put(PREFIX + "host", (bean, value) -> bean.setHost((String) value)); + connectionPropertyHandlers.put(PREFIX + "host", (bean, value) -> { + if (value instanceof String) { + bean.setHost((String) value); + } else { + throw new IllegalArgumentException("Host must be of type String"); + } + }); connectionPropertyHandlers.put(PREFIX + "port", (bean, value) -> { if (value instanceof String) { bean.setPort(Integer.parseInt((String) value)); - } else { + } else if (value instanceof Integer) { bean.setPort((int) value); + } else { + throw new IllegalArgumentException("Port must be of type String or Integer"); + } + }); + connectionPropertyHandlers.put(PREFIX + "database", (bean, value) -> { + if (value instanceof String) { + bean.setDatabase((String) value); + } else { + throw new IllegalArgumentException("Database must be of type String"); + } + }); + connectionPropertyHandlers.put(PREFIX + "basePath", (bean, value) -> { + if (value instanceof String) { + bean.setBasePath((String) value); + } else { + throw new IllegalArgumentException("Base path must be of type String"); } }); - connectionPropertyHandlers.put(PREFIX + "database", (bean, value) -> bean.setDatabase((String) value)); - connectionPropertyHandlers.put(PREFIX + "basePath", (bean, value) -> bean.setBasePath((String) value)); connectionPropertyHandlers.put(PREFIX + "connectionType", (bean, value) -> { if (value instanceof DatabaseClient.ConnectionType) { bean.setConnectionType((DatabaseClient.ConnectionType) value); @@ -67,8 +87,9 @@ public class DatabaseClientPropertySource { if (val.trim().length() > 0) { bean.setConnectionType(DatabaseClient.ConnectionType.valueOf(val.toUpperCase())); } - } else + } else { throw new IllegalArgumentException("Connection type must either be a String or an instance of ConnectionType"); + } }); } @@ -76,6 +97,9 @@ public DatabaseClientPropertySource(Function propertySource) { this.propertySource = propertySource; } + /** + * @return an instance of {@code DatabaseClient} based on the given property source + */ public DatabaseClient newClient() { DatabaseClientFactory.Bean bean = newClientBean(); // For consistency with how clients have been created - i.e. not via a Bean class, but via @@ -85,9 +109,12 @@ public DatabaseClient newClient() { // (and this behavior is expected by some existing tests). return DatabaseClientFactory.newClient(bean.getHost(), bean.getPort(), bean.getBasePath(), bean.getDatabase(), bean.getSecurityContext(), bean.getConnectionType()); - } + /** + * @return an instance of {@code DatabaseClientFactory.Bean} based on the given property source. This is primarily + * intended for testing purposes so that the Bean can be verified without creating a client. + */ public DatabaseClientFactory.Bean newClientBean() { final DatabaseClientFactory.Bean bean = new DatabaseClientFactory.Bean(); connectionPropertyHandlers.forEach((propName, consumer) -> { @@ -101,20 +128,28 @@ public DatabaseClientFactory.Bean newClientBean() { } private DatabaseClientFactory.SecurityContext newSecurityContext() { - DatabaseClientFactory.SecurityContext securityContext = (DatabaseClientFactory.SecurityContext) - propertySource.apply(PREFIX + "securityContext"); - if (securityContext != null) { - return securityContext; + Object securityContextValue = propertySource.apply(PREFIX + "securityContext"); + if (securityContextValue != null) { + if (securityContextValue instanceof DatabaseClientFactory.SecurityContext) { + return (DatabaseClientFactory.SecurityContext) securityContextValue; + } + throw new IllegalArgumentException("Security context must be of type " + DatabaseClientFactory.SecurityContext.class.getName()); } - String type = (String) propertySource.apply(PREFIX + "securityContextType"); - if (type == null || type.trim().length() == 0) { - throw new IllegalArgumentException("Must define a security context or security context type"); + Object typeValue = propertySource.apply(PREFIX + "securityContextType"); + if (typeValue == null || !(typeValue instanceof String)) { + throw new IllegalArgumentException("Security context should be set, or security context type must be of type String"); } - securityContext = newSecurityContext(type); + final String securityContextType = (String)typeValue; + final SSLInputs sslInputs = buildSSLInputs(securityContextType); + + DatabaseClientFactory.SecurityContext securityContext = newSecurityContext(securityContextType, sslInputs); + + X509TrustManager trustManager = determineTrustManager(sslInputs); + SSLContext sslContext = sslInputs.getSslContext() != null ? + sslInputs.getSslContext() : + determineSSLContext(sslInputs, trustManager); - X509TrustManager trustManager = determineTrustManager(); - SSLContext sslContext = determineSSLContext(trustManager); if (sslContext != null) { securityContext.withSSLContext(sslContext, trustManager); } @@ -123,7 +158,7 @@ private DatabaseClientFactory.SecurityContext newSecurityContext() { return securityContext; } - private DatabaseClientFactory.SecurityContext newSecurityContext(String type) { + private DatabaseClientFactory.SecurityContext newSecurityContext(String type, SSLInputs sslInputs) { switch (type.toLowerCase()) { case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_BASIC: return newBasicAuthContext(); @@ -134,7 +169,7 @@ private DatabaseClientFactory.SecurityContext newSecurityContext(String type) { case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_KERBEROS: return newKerberosAuthContext(); case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_CERTIFICATE: - return newCertificateAuthContext(); + return newCertificateAuthContext(sslInputs); case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_SAML: return newSAMLAuthContext(); default: @@ -142,32 +177,44 @@ private DatabaseClientFactory.SecurityContext newSecurityContext(String type) { } } + private String getRequiredStringValue(String propertyName) { + Object value = propertySource.apply(PREFIX + propertyName); + if (value == null || !(value instanceof String)) { + throw new IllegalArgumentException(propertyName + " must be of type String"); + } + return (String) value; + } + + private String getNullableStringValue(String propertyName) { + Object value = propertySource.apply(PREFIX + propertyName); + if (value != null && !(value instanceof String)) { + throw new IllegalArgumentException(propertyName + " must be of type String"); + } + return (String)value; + } + private DatabaseClientFactory.SecurityContext newBasicAuthContext() { return new DatabaseClientFactory.BasicAuthContext( - (String) propertySource.apply(PREFIX + "username"), - (String) propertySource.apply(PREFIX + "password") + getRequiredStringValue("username"), getRequiredStringValue("password") ); } private DatabaseClientFactory.SecurityContext newDigestAuthContext() { return new DatabaseClientFactory.DigestAuthContext( - (String) propertySource.apply(PREFIX + "username"), - (String) propertySource.apply(PREFIX + "password") + getRequiredStringValue("username"), getRequiredStringValue("password") ); } private DatabaseClientFactory.SecurityContext newCloudAuthContext() { - return new DatabaseClientFactory.MarkLogicCloudAuthContext( - (String) propertySource.apply(PREFIX + "cloud.apiKey") - ); + return new DatabaseClientFactory.MarkLogicCloudAuthContext(getRequiredStringValue("cloud.apiKey")); } - private DatabaseClientFactory.SecurityContext newCertificateAuthContext() { + private DatabaseClientFactory.SecurityContext newCertificateAuthContext(SSLInputs sslInputs) { try { return new DatabaseClientFactory.CertificateAuthContext( - (String) propertySource.apply(PREFIX + "certificate.file"), - (String) propertySource.apply(PREFIX + "certificate.password"), - determineTrustManager() + getRequiredStringValue("certificate.file"), + getRequiredStringValue("certificate.password"), + sslInputs.getTrustManager() ); } catch (Exception e) { throw new RuntimeException("Unable to create CertificateAuthContext; cause " + e.getMessage(), e); @@ -175,23 +222,15 @@ private DatabaseClientFactory.SecurityContext newCertificateAuthContext() { } private DatabaseClientFactory.SecurityContext newKerberosAuthContext() { - return new DatabaseClientFactory.KerberosAuthContext( - (String) propertySource.apply(PREFIX + "kerberos.principal") - ); + return new DatabaseClientFactory.KerberosAuthContext(getRequiredStringValue("kerberos.principal")); } private DatabaseClientFactory.SecurityContext newSAMLAuthContext() { - return new DatabaseClientFactory.SAMLAuthContext( - (String) propertySource.apply(PREFIX + "saml.token") - ); + return new DatabaseClientFactory.SAMLAuthContext(getRequiredStringValue("saml.token")); } - private SSLContext determineSSLContext(X509TrustManager trustManager) { - SSLContext sslContext = (SSLContext) propertySource.apply(PREFIX + "sslContext"); - if (sslContext != null) { - return sslContext; - } - String protocol = (String) propertySource.apply(PREFIX + "sslProtocol"); + private SSLContext determineSSLContext(SSLInputs sslInputs, X509TrustManager trustManager) { + String protocol = sslInputs.getSslProtocol(); if (protocol != null) { if ("default".equalsIgnoreCase(protocol)) { try { @@ -200,6 +239,8 @@ private SSLContext determineSSLContext(X509TrustManager trustManager) { throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e); } } + + SSLContext sslContext; try { sslContext = SSLContext.getInstance(protocol); } catch (NoSuchAlgorithmException e) { @@ -220,20 +261,15 @@ private SSLContext determineSSLContext(X509TrustManager trustManager) { return null; } - private X509TrustManager determineTrustManager() { - Object trustManagerObject = propertySource.apply(PREFIX + "trustManager"); - if (trustManagerObject != null) { - if (trustManagerObject instanceof X509TrustManager) { - return (X509TrustManager) trustManagerObject; - } - throw new IllegalArgumentException( - String.format("Trust manager must be an instance of %s", X509TrustManager.class.getName())); + private X509TrustManager determineTrustManager(SSLInputs sslInputs) { + if (sslInputs.getTrustManager() != null) { + return sslInputs.getTrustManager(); } // If the user chooses the "default" SSLContext, then it's already been initialized - but OkHttp still // needs a separate X509TrustManager, so use the JVM's default trust manager. The assumption is that the // default SSLContext was initialized with the JVM's default trust manager. A user can of course always override // this by simply providing their own trust manager. - if ("default".equalsIgnoreCase((String) propertySource.apply(PREFIX + "sslProtocol"))) { + if ("default".equalsIgnoreCase(sslInputs.getSslProtocol())) { X509TrustManager defaultTrustManager = SSLUtil.getDefaultTrustManager(); if (logger.isDebugEnabled() && defaultTrustManager != null && defaultTrustManager.getAcceptedIssuers() != null) { logger.debug("Count of accepted issuers in default trust manager: {}", @@ -261,4 +297,69 @@ private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() { } return null; } + + /** + * Uses the given propertySource to construct the inputs pertaining to constructing an SSLContext and an + * X509TrustManager. + * + * @param securityContextType used for applying "default" as the SSL protocol for MarkLogic cloud authentication in + * case the user does not define their own SSLContext or SSL protocol + * @return + */ + private SSLInputs buildSSLInputs(String securityContextType) { + SSLContext sslContext = null; + Object val = propertySource.apply(PREFIX + "sslContext"); + if (val != null) { + if (val instanceof SSLContext) { + sslContext = (SSLContext) val; + } else { + throw new IllegalArgumentException("SSL context must be an instanceof " + SSLContext.class.getName()); + } + } + + String sslProtocol = getNullableStringValue("sslProtocol"); + if (sslContext == null && + (sslProtocol == null || sslProtocol.trim().length() == 0) && + DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(securityContextType)) { + sslProtocol = "default"; + } + + val = propertySource.apply(PREFIX + "trustManager"); + X509TrustManager trustManager = null; + if (val != null) { + if (val instanceof X509TrustManager) { + trustManager = (X509TrustManager) val; + } else { + throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName()); + } + } + return new SSLInputs(sslContext, sslProtocol, trustManager); + } + + /** + * Captures the inputs provided by the caller that pertain to constructing an SSLContext. + */ + private static class SSLInputs { + private final SSLContext sslContext; + private final String sslProtocol; + private final X509TrustManager trustManager; + + public SSLInputs(SSLContext sslContext, String sslProtocol, X509TrustManager trustManager) { + this.sslContext = sslContext; + this.sslProtocol = sslProtocol; + this.trustManager = trustManager; + } + + public SSLContext getSslContext() { + return sslContext; + } + + public String getSslProtocol() { + return sslProtocol; + } + + public X509TrustManager getTrustManager() { + return trustManager; + } + } } diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/impl/DatabaseClientPropertySourceTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/impl/DatabaseClientPropertySourceTest.java index b7889c3d2..c1a686ea3 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/impl/DatabaseClientPropertySourceTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/impl/DatabaseClientPropertySourceTest.java @@ -10,6 +10,8 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Intent of this test is to cover code that cannot be covered by DatabaseClientBuilderTest. @@ -26,6 +28,8 @@ public class DatabaseClientPropertySourceTest { void beforeEach() { props = new HashMap() {{ put(PREFIX + "securityContextType", "digest"); + put(PREFIX + "username", "someuser"); + put(PREFIX + "password", "someword"); }}; } @@ -68,6 +72,30 @@ void stringPort() { assertEquals(8000, bean.getPort()); } + @Test + void cloudAuthWithNoSslInputs() { + props.put(PREFIX + "securityContextType", "cloud"); + props.put(PREFIX + "cloud.apiKey", "abc123"); + props.put(PREFIX + "basePath", "/my/path"); + + bean = buildBean(); + + assertEquals("/my/path", bean.getBasePath()); + assertTrue(bean.getSecurityContext() instanceof DatabaseClientFactory.MarkLogicCloudAuthContext); + + DatabaseClientFactory.MarkLogicCloudAuthContext context = (DatabaseClientFactory.MarkLogicCloudAuthContext) bean.getSecurityContext(); + assertEquals("abc123", context.getKey()); + + assertNotNull(context.getSSLContext(), "If cloud is chosen with no SSL protocol or context, the default JVM " + + "SSLContext should be used"); + + assertNotNull(context.getSSLContext().getSocketFactory(), "The default JVM SSLContext should already be " + + "initialized and thus it should be possible to get a socket factory from it"); + + assertNotNull(context.getTrustManager(), "If cloud is chosen with no SSL protocol or context, the default JVM " + + "trust manager should be used"); + } + private DatabaseClientFactory.Bean buildBean() { DatabaseClientPropertySource source = new DatabaseClientPropertySource(propertyName -> props.get(propertyName)); return source.newClientBean(); diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/test/DatabaseClientBuilderTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/test/DatabaseClientBuilderTest.java index 6b40468bd..0966c706f 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/test/DatabaseClientBuilderTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/test/DatabaseClientBuilderTest.java @@ -26,7 +26,7 @@ void minimumConnectionProperties() { bean = new DatabaseClientBuilder() .withHost("myhost") .withPort(8000) - .withSecurityContextType("digest") + .withBasicAuth("someuser", "someword") .buildBean(); assertEquals("myhost", bean.getHost()); @@ -34,7 +34,7 @@ void minimumConnectionProperties() { assertNull(bean.getDatabase()); assertNull(bean.getBasePath()); assertNull(bean.getConnectionType()); - assertTrue(bean.getSecurityContext() instanceof DatabaseClientFactory.DigestAuthContext); + assertTrue(bean.getSecurityContext() instanceof DatabaseClientFactory.BasicAuthContext); } @Test @@ -42,7 +42,7 @@ void allConnectionProperties() { bean = new DatabaseClientBuilder() .withHost("myhost") .withPort(8000) - .withSecurityContextType("digest") + .withDigestAuth("someuser", "someword") .withBasePath("/my/path") .withDatabase("Documents") .withConnectionType(DatabaseClient.ConnectionType.DIRECT) @@ -62,7 +62,7 @@ void noSecurityContextOrType() { .withHost("some-host") .withPort(10) .buildBean()); - assertEquals("Must define a security context or security context type", ex.getMessage()); + assertEquals("Security context should be set, or security context type must be of type String", ex.getMessage()); } @Test @@ -124,7 +124,7 @@ void cloudNoApiKey() { .withSecurityContextType("cloud") .withBasePath("/my/path") .build()); - assertEquals("No API key provided", ex.getMessage()); + assertEquals("cloud.apiKey must be of type String", ex.getMessage()); } @Test