From 3b73a4a2244973c547a36f972d0f4832c4a2cd4c Mon Sep 17 00:00:00 2001 From: Rob Rudin Date: Mon, 2 Oct 2023 10:26:14 -0400 Subject: [PATCH] Refactored logic for determining SSL inputs I was looking at this while analyzing support for 2-way SSL and I couldn't really understand what I wrote. So I refactored in the following way: 1. All construction for SSLContext/TrustManager is in buildSSLInputs. 2. There are 4 approaches, each clearly identified and implemented in its own method. 3. I moved the tests I use for verifying SSL support to a "test.ssl" package (no changes to the tests themselves). This may actually fix a bug but I'm not certain. The bug would have been that "newCertificateAuthContext" was called before all the SSL-input logic had occurred, meaning that e.g. "default" as an sslProtocol value would not have impacted certificate authentication. --- .../impl/DatabaseClientPropertySource.java | 178 +++++++++--------- .../com/marklogic/client/impl/SSLUtil.java | 10 +- .../{ => ssl}/CheckSSLConnectionTest.java | 3 +- .../client/test/{ => ssl}/SSLTest.java | 3 +- .../client/test/{ => ssl}/TwoWaySSLTest.java | 3 +- 5 files changed, 105 insertions(+), 92 deletions(-) rename marklogic-client-api/src/test/java/com/marklogic/client/test/{ => ssl}/CheckSSLConnectionTest.java (98%) rename marklogic-client-api/src/test/java/com/marklogic/client/test/{ => ssl}/SSLTest.java (98%) rename marklogic-client-api/src/test/java/com/marklogic/client/test/{ => ssl}/TwoWaySSLTest.java (99%) diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java b/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java index 3781897ae..eb5083045 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/impl/DatabaseClientPropertySource.java @@ -19,8 +19,6 @@ import com.marklogic.client.DatabaseClientBuilder; import com.marklogic.client.DatabaseClientFactory; import com.marklogic.client.extra.okhttpclient.RemoveAcceptEncodingConfigurator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.X509TrustManager; @@ -40,7 +38,6 @@ */ public class DatabaseClientPropertySource { - private static final Logger logger = LoggerFactory.getLogger(DatabaseClientPropertySource.class); private static final String PREFIX = DatabaseClientBuilder.PREFIX; private final Function propertySource; @@ -97,7 +94,7 @@ public class DatabaseClientPropertySource { if (value instanceof Boolean && Boolean.TRUE.equals(value)) { disableGzippedResponses = true; } else if (value instanceof String) { - disableGzippedResponses = Boolean.parseBoolean((String)value); + disableGzippedResponses = Boolean.parseBoolean((String) value); } if (disableGzippedResponses) { DatabaseClientFactory.addConfigurator(new RemoveAcceptEncodingConfigurator()); @@ -152,20 +149,13 @@ private DatabaseClientFactory.SecurityContext newSecurityContext() { if (typeValue == null || !(typeValue instanceof String)) { throw new IllegalArgumentException("Security context should be set, or auth type must be of type String"); } - final String authType = (String)typeValue; - final SSLInputs sslInputs = buildSSLInputs(authType); + final String authType = (String) typeValue; + final SSLInputs sslInputs = buildSSLInputs(authType); DatabaseClientFactory.SecurityContext securityContext = newSecurityContext(authType, sslInputs); - - X509TrustManager trustManager = determineTrustManager(sslInputs); - SSLContext sslContext = sslInputs.getSslContext() != null ? - sslInputs.getSslContext() : - determineSSLContext(sslInputs, trustManager); - - if (sslContext != null) { - securityContext.withSSLContext(sslContext, trustManager); + if (sslInputs.getSslContext() != null) { + securityContext.withSSLContext(sslInputs.getSslContext(), sslInputs.getTrustManager()); } - securityContext.withSSLHostnameVerifier(determineHostnameVerifier()); return securityContext; } @@ -202,7 +192,7 @@ private String getNullableStringValue(String propertyName) { if (value != null && !(value instanceof String)) { throw new IllegalArgumentException(propertyName + " must be of type String"); } - return (String)value; + return (String) value; } private DatabaseClientFactory.SecurityContext newBasicAuthContext() { @@ -255,57 +245,6 @@ private DatabaseClientFactory.SecurityContext newSAMLAuthContext() { return new DatabaseClientFactory.SAMLAuthContext(getRequiredStringValue("saml.token")); } - private SSLContext determineSSLContext(SSLInputs sslInputs, X509TrustManager trustManager) { - String protocol = sslInputs.getSslProtocol(); - if (protocol != null) { - if ("default".equalsIgnoreCase(protocol)) { - try { - return SSLContext.getDefault(); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e); - } - } - - SSLContext sslContext; - try { - sslContext = SSLContext.getInstance(protocol); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException("Unable to get SSLContext instance with protocol: " + protocol - + "; cause: " + e.getMessage(), e); - } - // Note that if only a protocol is specified, and not a TrustManager, an attempt will later be made - // to use the JVM's default TrustManager - if (trustManager != null) { - try { - sslContext.init(null, new X509TrustManager[]{trustManager}, null); - } catch (KeyManagementException e) { - throw new RuntimeException("Unable to initialize SSLContext; protocol: " + protocol + "; cause: " + e.getMessage(), e); - } - } - return sslContext; - } - return null; - } - - private X509TrustManager determineTrustManager(SSLInputs sslInputs) { - if (sslInputs.getTrustManager() != null) { - return sslInputs.getTrustManager(); - } - // If the user chooses the "default" SSLContext, then it's already been initialized - but OkHttp still - // needs a separate X509TrustManager, so use the JVM's default trust manager. The assumption is that the - // default SSLContext was initialized with the JVM's default trust manager. A user can of course always override - // this by simply providing their own trust manager. - if ("default".equalsIgnoreCase(sslInputs.getSslProtocol())) { - X509TrustManager defaultTrustManager = SSLUtil.getDefaultTrustManager(); - if (logger.isDebugEnabled() && defaultTrustManager != null && defaultTrustManager.getAcceptedIssuers() != null) { - logger.debug("Count of accepted issuers in default trust manager: {}", - defaultTrustManager.getAcceptedIssuers().length); - } - return defaultTrustManager; - } - return null; - } - private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() { Object verifierObject = propertySource.apply(PREFIX + "sslHostnameVerifier"); if (verifierObject instanceof DatabaseClientFactory.SSLHostnameVerifier) { @@ -329,37 +268,106 @@ private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() { * X509TrustManager. * * @param authType used for applying "default" as the SSL protocol for MarkLogic cloud authentication in - * case the user does not define their own SSLContext or SSL protocol + * case the user does not define their own SSLContext or SSL protocol * @return */ private SSLInputs buildSSLInputs(String authType) { - SSLContext sslContext = null; + X509TrustManager userTrustManager = getTrustManager(); + + // Approach 1 - user provides an SSLContext object, in which case there's nothing further to check. + SSLContext sslContext = getSSLContext(); + if (sslContext != null) { + return new SSLInputs(sslContext, userTrustManager); + } + + // Approaches 2 and 3 - user defines an SSL protocol. + // Approach 2 - "default" is a convenience for using the JVM's default SSLContext. + // Approach 3 - create a new SSLContext, and initialize it if the user-provided TrustManager is not null. + final String sslProtocol = getSSLProtocol(authType); + if (sslProtocol != null) { + return "default".equalsIgnoreCase(sslProtocol) ? + useDefaultSSLContext(userTrustManager) : + useNewSSLContext(sslProtocol, userTrustManager); + } + + // Approach 4 - no SSL connection is needed. + return new SSLInputs(null, null); + } + + private X509TrustManager getTrustManager() { + Object val = propertySource.apply(PREFIX + "trustManager"); + if (val != null) { + if (val instanceof X509TrustManager) { + return (X509TrustManager) val; + } else { + throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName()); + } + } + return null; + } + + private SSLContext getSSLContext() { Object val = propertySource.apply(PREFIX + "sslContext"); if (val != null) { if (val instanceof SSLContext) { - sslContext = (SSLContext) val; + return (SSLContext) val; } else { throw new IllegalArgumentException("SSL context must be an instanceof " + SSLContext.class.getName()); } } + return null; + } + private String getSSLProtocol(String authType) { String sslProtocol = getNullableStringValue("sslProtocol"); - if (sslContext == null && - (sslProtocol == null || sslProtocol.trim().length() == 0) && - DatabaseClientBuilder.AUTH_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(authType)) { + if (sslProtocol != null) { + sslProtocol = sslProtocol.trim(); + } + // For convenience for MarkLogic Cloud users, assume the JVM's default SSLContext should trust the certificate + // used by MarkLogic Cloud. A user can always override this default behavior by providing their own SSLContext. + if ((sslProtocol == null || sslProtocol.length() == 0) && DatabaseClientBuilder.AUTH_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(authType)) { sslProtocol = "default"; } + return sslProtocol; + } + + /** + * Uses the JVM's default SSLContext. Because OkHttp requires a separate TrustManager, this approach will either + * user the user-provided TrustManager or it will assume that the JVM's default TrustManager should be used. + */ + private SSLInputs useDefaultSSLContext(X509TrustManager userTrustManager) { + SSLContext sslContext; + try { + sslContext = SSLContext.getDefault(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e); + } + X509TrustManager trustManager = userTrustManager != null ? userTrustManager : SSLUtil.getDefaultTrustManager(); + return new SSLInputs(sslContext, trustManager); + } - val = propertySource.apply(PREFIX + "trustManager"); - X509TrustManager trustManager = null; - if (val != null) { - if (val instanceof X509TrustManager) { - trustManager = (X509TrustManager) val; - } else { - throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName()); + /** + * Constructs a new SSLContext based on the given protocol (e.g. TLSv1.2). The SSLContext will be initialized if + * the user's TrustManager is not null. Otherwise, OkHttpUtil will eventually initialize the SSLContext using the + * JVM's default TrustManager. + */ + private SSLInputs useNewSSLContext(String sslProtocol, X509TrustManager userTrustManager) { + SSLContext sslContext; + try { + sslContext = SSLContext.getInstance(sslProtocol); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(String.format("Unable to get SSLContext instance with protocol: %s; cause: %s", + sslProtocol, e.getMessage()), e); + } + if (userTrustManager != null) { + try { + sslContext.init(null, new X509TrustManager[]{userTrustManager}, null); + } catch (KeyManagementException e) { + throw new RuntimeException(String.format("Unable to initialize SSLContext; protocol: %s; cause: %s", + sslProtocol, e.getMessage()), e); } } - return new SSLInputs(sslContext, sslProtocol, trustManager); + return new SSLInputs(sslContext, userTrustManager); } /** @@ -367,12 +375,10 @@ private SSLInputs buildSSLInputs(String authType) { */ private static class SSLInputs { private final SSLContext sslContext; - private final String sslProtocol; private final X509TrustManager trustManager; - public SSLInputs(SSLContext sslContext, String sslProtocol, X509TrustManager trustManager) { + public SSLInputs(SSLContext sslContext, X509TrustManager trustManager) { this.sslContext = sslContext; - this.sslProtocol = sslProtocol; this.trustManager = trustManager; } @@ -380,10 +386,6 @@ public SSLContext getSslContext() { return sslContext; } - public String getSslProtocol() { - return sslProtocol; - } - public X509TrustManager getTrustManager() { return trustManager; } diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/impl/SSLUtil.java b/marklogic-client-api/src/main/java/com/marklogic/client/impl/SSLUtil.java index b7e5894a6..217f850dd 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/impl/SSLUtil.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/impl/SSLUtil.java @@ -15,6 +15,9 @@ */ package com.marklogic.client.impl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; @@ -25,7 +28,12 @@ public interface SSLUtil { static X509TrustManager getDefaultTrustManager() { - return (X509TrustManager) getDefaultTrustManagers()[0]; + X509TrustManager trustManager = (X509TrustManager) getDefaultTrustManagers()[0]; + Logger logger = LoggerFactory.getLogger(SSLUtil.class); + if (logger.isDebugEnabled() && trustManager.getAcceptedIssuers() != null) { + logger.debug("Count of accepted issuers in default trust manager: {}", trustManager.getAcceptedIssuers().length); + } + return trustManager; } /** diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/test/CheckSSLConnectionTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/CheckSSLConnectionTest.java similarity index 98% rename from marklogic-client-api/src/test/java/com/marklogic/client/test/CheckSSLConnectionTest.java rename to marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/CheckSSLConnectionTest.java index b5ec9dca8..2680eaeb8 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/test/CheckSSLConnectionTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/CheckSSLConnectionTest.java @@ -1,9 +1,10 @@ -package com.marklogic.client.test; +package com.marklogic.client.test.ssl; import com.marklogic.client.DatabaseClient; import com.marklogic.client.DatabaseClientFactory; import com.marklogic.client.ForbiddenUserException; import com.marklogic.client.MarkLogicIOException; +import com.marklogic.client.test.Common; import com.marklogic.client.test.junit5.RequireSSLExtension; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/test/SSLTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/SSLTest.java similarity index 98% rename from marklogic-client-api/src/test/java/com/marklogic/client/test/SSLTest.java rename to marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/SSLTest.java index e1025fad5..a08adeff8 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/test/SSLTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/SSLTest.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.marklogic.client.test; +package com.marklogic.client.test.ssl; import com.marklogic.client.DatabaseClient; import com.marklogic.client.DatabaseClientFactory.SSLHostnameVerifier; import com.marklogic.client.MarkLogicIOException; import com.marklogic.client.document.TextDocumentManager; import com.marklogic.client.io.StringHandle; +import com.marklogic.client.test.Common; import org.junit.jupiter.api.Test; import javax.net.ssl.*; diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/test/TwoWaySSLTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/TwoWaySSLTest.java similarity index 99% rename from marklogic-client-api/src/test/java/com/marklogic/client/test/TwoWaySSLTest.java rename to marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/TwoWaySSLTest.java index a8f7a67af..ca1870cc6 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/test/TwoWaySSLTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/test/ssl/TwoWaySSLTest.java @@ -1,4 +1,4 @@ -package com.marklogic.client.test; +package com.marklogic.client.test.ssl; import com.fasterxml.jackson.databind.node.ObjectNode; import com.marklogic.client.DatabaseClient; @@ -8,6 +8,7 @@ import com.marklogic.client.document.DocumentDescriptor; import com.marklogic.client.eval.EvalResultIterator; import com.marklogic.client.io.StringHandle; +import com.marklogic.client.test.Common; import com.marklogic.client.test.junit5.RequireSSLExtension; import com.marklogic.mgmt.ManageClient; import com.marklogic.mgmt.resource.appservers.ServerManager;