Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ public DatabaseClientBuilder withSecurityContext(DatabaseClientFactory.SecurityC
}

/**
*
* @param type must be one of "basic", "digest", "cloud", "kerberos", "certificate", or "saml"
* @return
*/
Expand All @@ -143,17 +142,9 @@ public DatabaseClientBuilder withDigestAuth(String username, String password) {
}

public DatabaseClientBuilder withMarkLogicCloudAuth(String apiKey, String basePath) {
withSecurityContextType(SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD)
return withSecurityContextType(SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD)
.withCloudApiKey(apiKey)
.withBasePath(basePath);

// Assume sensible defaults for establishing an SSL connection. In the scenario where the user's JVM's
// truststore has a certificate matching that of the MarkLogic Cloud instance, this saves the user from having
// to configure anything except the API key and base path.
if (null == props.get(PREFIX + "sslProtocol") && null == props.get(PREFIX + "sslContext")) {
withSSLProtocol("default");
}
return this;
}

public DatabaseClientBuilder withKerberosAuth(String principal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,8 @@ public String getCertificatePassword() {
* <li>marklogic.client.database = must be a String</li>
* <li>marklogic.client.connectionType = must be a String or instance of {@code ConnectionType}</li>
* <li>marklogic.client.securityContext = an instance of {@code SecurityContext}; if set, then all other
* properties pertaining to the construction of a {@code SecurityContext} will be ignored</li>
* properties pertaining to the construction of a {@code SecurityContext} will be ignored, including the
* properties pertaing to SSL</li>
* <li>marklogic.client.securityContextType = required if marklogic.client.securityContext is not set;
* must be a String and one of "basic", "digest", "cloud", "kerberos", "certificate", or "saml"</li>
* <li>marklogic.client.username = must be a String; required for basic and digest authentication</li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,36 @@ public class DatabaseClientPropertySource {

static {
connectionPropertyHandlers = new LinkedHashMap<>();
connectionPropertyHandlers.put(PREFIX + "host", (bean, value) -> bean.setHost((String) value));
connectionPropertyHandlers.put(PREFIX + "host", (bean, value) -> {
if (value instanceof String) {
bean.setHost((String) value);
} else {
throw new IllegalArgumentException("Host must be of type String");
}
});
connectionPropertyHandlers.put(PREFIX + "port", (bean, value) -> {
if (value instanceof String) {
bean.setPort(Integer.parseInt((String) value));
} else {
} else if (value instanceof Integer) {
bean.setPort((int) value);
} else {
throw new IllegalArgumentException("Port must be of type String or Integer");
}
});
connectionPropertyHandlers.put(PREFIX + "database", (bean, value) -> {
if (value instanceof String) {
bean.setDatabase((String) value);
} else {
throw new IllegalArgumentException("Database must be of type String");
}
});
connectionPropertyHandlers.put(PREFIX + "basePath", (bean, value) -> {
if (value instanceof String) {
bean.setBasePath((String) value);
} else {
throw new IllegalArgumentException("Base path must be of type String");
}
});
connectionPropertyHandlers.put(PREFIX + "database", (bean, value) -> bean.setDatabase((String) value));
connectionPropertyHandlers.put(PREFIX + "basePath", (bean, value) -> bean.setBasePath((String) value));
connectionPropertyHandlers.put(PREFIX + "connectionType", (bean, value) -> {
if (value instanceof DatabaseClient.ConnectionType) {
bean.setConnectionType((DatabaseClient.ConnectionType) value);
Expand All @@ -67,15 +87,19 @@ public class DatabaseClientPropertySource {
if (val.trim().length() > 0) {
bean.setConnectionType(DatabaseClient.ConnectionType.valueOf(val.toUpperCase()));
}
} else
} else {
throw new IllegalArgumentException("Connection type must either be a String or an instance of ConnectionType");
}
});
}

public DatabaseClientPropertySource(Function<String, Object> propertySource) {
this.propertySource = propertySource;
}

/**
* @return an instance of {@code DatabaseClient} based on the given property source
*/
public DatabaseClient newClient() {
DatabaseClientFactory.Bean bean = newClientBean();
// For consistency with how clients have been created - i.e. not via a Bean class, but via
Expand All @@ -85,9 +109,12 @@ public DatabaseClient newClient() {
// (and this behavior is expected by some existing tests).
return DatabaseClientFactory.newClient(bean.getHost(), bean.getPort(), bean.getBasePath(), bean.getDatabase(),
bean.getSecurityContext(), bean.getConnectionType());

}

/**
* @return an instance of {@code DatabaseClientFactory.Bean} based on the given property source. This is primarily
* intended for testing purposes so that the Bean can be verified without creating a client.
*/
public DatabaseClientFactory.Bean newClientBean() {
final DatabaseClientFactory.Bean bean = new DatabaseClientFactory.Bean();
connectionPropertyHandlers.forEach((propName, consumer) -> {
Expand All @@ -101,20 +128,28 @@ public DatabaseClientFactory.Bean newClientBean() {
}

private DatabaseClientFactory.SecurityContext newSecurityContext() {
DatabaseClientFactory.SecurityContext securityContext = (DatabaseClientFactory.SecurityContext)
propertySource.apply(PREFIX + "securityContext");
if (securityContext != null) {
return securityContext;
Object securityContextValue = propertySource.apply(PREFIX + "securityContext");
if (securityContextValue != null) {
if (securityContextValue instanceof DatabaseClientFactory.SecurityContext) {
return (DatabaseClientFactory.SecurityContext) securityContextValue;
}
throw new IllegalArgumentException("Security context must be of type " + DatabaseClientFactory.SecurityContext.class.getName());
}

String type = (String) propertySource.apply(PREFIX + "securityContextType");
if (type == null || type.trim().length() == 0) {
throw new IllegalArgumentException("Must define a security context or security context type");
Object typeValue = propertySource.apply(PREFIX + "securityContextType");
if (typeValue == null || !(typeValue instanceof String)) {
throw new IllegalArgumentException("Security context should be set, or security context type must be of type String");
}
securityContext = newSecurityContext(type);
final String securityContextType = (String)typeValue;
final SSLInputs sslInputs = buildSSLInputs(securityContextType);

DatabaseClientFactory.SecurityContext securityContext = newSecurityContext(securityContextType, sslInputs);

X509TrustManager trustManager = determineTrustManager(sslInputs);
SSLContext sslContext = sslInputs.getSslContext() != null ?
sslInputs.getSslContext() :
determineSSLContext(sslInputs, trustManager);

X509TrustManager trustManager = determineTrustManager();
SSLContext sslContext = determineSSLContext(trustManager);
if (sslContext != null) {
securityContext.withSSLContext(sslContext, trustManager);
}
Expand All @@ -123,7 +158,7 @@ private DatabaseClientFactory.SecurityContext newSecurityContext() {
return securityContext;
}

private DatabaseClientFactory.SecurityContext newSecurityContext(String type) {
private DatabaseClientFactory.SecurityContext newSecurityContext(String type, SSLInputs sslInputs) {
switch (type.toLowerCase()) {
case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_BASIC:
return newBasicAuthContext();
Expand All @@ -134,64 +169,68 @@ private DatabaseClientFactory.SecurityContext newSecurityContext(String type) {
case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_KERBEROS:
return newKerberosAuthContext();
case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_CERTIFICATE:
return newCertificateAuthContext();
return newCertificateAuthContext(sslInputs);
case DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_SAML:
return newSAMLAuthContext();
default:
throw new IllegalArgumentException("Unrecognized security context type: " + type);
}
}

private String getRequiredStringValue(String propertyName) {
Object value = propertySource.apply(PREFIX + propertyName);
if (value == null || !(value instanceof String)) {
throw new IllegalArgumentException(propertyName + " must be of type String");
}
return (String) value;
}

private String getNullableStringValue(String propertyName) {
Object value = propertySource.apply(PREFIX + propertyName);
if (value != null && !(value instanceof String)) {
throw new IllegalArgumentException(propertyName + " must be of type String");
}
return (String)value;
}

private DatabaseClientFactory.SecurityContext newBasicAuthContext() {
return new DatabaseClientFactory.BasicAuthContext(
(String) propertySource.apply(PREFIX + "username"),
(String) propertySource.apply(PREFIX + "password")
getRequiredStringValue("username"), getRequiredStringValue("password")
);
}

private DatabaseClientFactory.SecurityContext newDigestAuthContext() {
return new DatabaseClientFactory.DigestAuthContext(
(String) propertySource.apply(PREFIX + "username"),
(String) propertySource.apply(PREFIX + "password")
getRequiredStringValue("username"), getRequiredStringValue("password")
);
}

private DatabaseClientFactory.SecurityContext newCloudAuthContext() {
return new DatabaseClientFactory.MarkLogicCloudAuthContext(
(String) propertySource.apply(PREFIX + "cloud.apiKey")
);
return new DatabaseClientFactory.MarkLogicCloudAuthContext(getRequiredStringValue("cloud.apiKey"));
}

private DatabaseClientFactory.SecurityContext newCertificateAuthContext() {
private DatabaseClientFactory.SecurityContext newCertificateAuthContext(SSLInputs sslInputs) {
try {
return new DatabaseClientFactory.CertificateAuthContext(
(String) propertySource.apply(PREFIX + "certificate.file"),
(String) propertySource.apply(PREFIX + "certificate.password"),
determineTrustManager()
getRequiredStringValue("certificate.file"),
getRequiredStringValue("certificate.password"),
sslInputs.getTrustManager()
);
} catch (Exception e) {
throw new RuntimeException("Unable to create CertificateAuthContext; cause " + e.getMessage(), e);
}
}

private DatabaseClientFactory.SecurityContext newKerberosAuthContext() {
return new DatabaseClientFactory.KerberosAuthContext(
(String) propertySource.apply(PREFIX + "kerberos.principal")
);
return new DatabaseClientFactory.KerberosAuthContext(getRequiredStringValue("kerberos.principal"));
}

private DatabaseClientFactory.SecurityContext newSAMLAuthContext() {
return new DatabaseClientFactory.SAMLAuthContext(
(String) propertySource.apply(PREFIX + "saml.token")
);
return new DatabaseClientFactory.SAMLAuthContext(getRequiredStringValue("saml.token"));
}

private SSLContext determineSSLContext(X509TrustManager trustManager) {
SSLContext sslContext = (SSLContext) propertySource.apply(PREFIX + "sslContext");
if (sslContext != null) {
return sslContext;
}
String protocol = (String) propertySource.apply(PREFIX + "sslProtocol");
private SSLContext determineSSLContext(SSLInputs sslInputs, X509TrustManager trustManager) {
String protocol = sslInputs.getSslProtocol();
if (protocol != null) {
if ("default".equalsIgnoreCase(protocol)) {
try {
Expand All @@ -200,6 +239,8 @@ private SSLContext determineSSLContext(X509TrustManager trustManager) {
throw new RuntimeException("Unable to obtain default SSLContext; cause: " + e.getMessage(), e);
}
}

SSLContext sslContext;
try {
sslContext = SSLContext.getInstance(protocol);
} catch (NoSuchAlgorithmException e) {
Expand All @@ -220,20 +261,15 @@ private SSLContext determineSSLContext(X509TrustManager trustManager) {
return null;
}

private X509TrustManager determineTrustManager() {
Object trustManagerObject = propertySource.apply(PREFIX + "trustManager");
if (trustManagerObject != null) {
if (trustManagerObject instanceof X509TrustManager) {
return (X509TrustManager) trustManagerObject;
}
throw new IllegalArgumentException(
String.format("Trust manager must be an instance of %s", X509TrustManager.class.getName()));
private X509TrustManager determineTrustManager(SSLInputs sslInputs) {
if (sslInputs.getTrustManager() != null) {
return sslInputs.getTrustManager();
}
// If the user chooses the "default" SSLContext, then it's already been initialized - but OkHttp still
// needs a separate X509TrustManager, so use the JVM's default trust manager. The assumption is that the
// default SSLContext was initialized with the JVM's default trust manager. A user can of course always override
// this by simply providing their own trust manager.
if ("default".equalsIgnoreCase((String) propertySource.apply(PREFIX + "sslProtocol"))) {
if ("default".equalsIgnoreCase(sslInputs.getSslProtocol())) {
X509TrustManager defaultTrustManager = SSLUtil.getDefaultTrustManager();
if (logger.isDebugEnabled() && defaultTrustManager != null && defaultTrustManager.getAcceptedIssuers() != null) {
logger.debug("Count of accepted issuers in default trust manager: {}",
Expand Down Expand Up @@ -261,4 +297,69 @@ private DatabaseClientFactory.SSLHostnameVerifier determineHostnameVerifier() {
}
return null;
}

/**
* Uses the given propertySource to construct the inputs pertaining to constructing an SSLContext and an
* X509TrustManager.
*
* @param securityContextType used for applying "default" as the SSL protocol for MarkLogic cloud authentication in
* case the user does not define their own SSLContext or SSL protocol
* @return
*/
private SSLInputs buildSSLInputs(String securityContextType) {
SSLContext sslContext = null;
Object val = propertySource.apply(PREFIX + "sslContext");
if (val != null) {
if (val instanceof SSLContext) {
sslContext = (SSLContext) val;
} else {
throw new IllegalArgumentException("SSL context must be an instanceof " + SSLContext.class.getName());
}
}

String sslProtocol = getNullableStringValue("sslProtocol");
if (sslContext == null &&
(sslProtocol == null || sslProtocol.trim().length() == 0) &&
DatabaseClientBuilder.SECURITY_CONTEXT_TYPE_MARKLOGIC_CLOUD.equalsIgnoreCase(securityContextType)) {
sslProtocol = "default";
}

val = propertySource.apply(PREFIX + "trustManager");
X509TrustManager trustManager = null;
if (val != null) {
if (val instanceof X509TrustManager) {
trustManager = (X509TrustManager) val;
} else {
throw new IllegalArgumentException("Trust manager must be an instanceof " + X509TrustManager.class.getName());
}
}
return new SSLInputs(sslContext, sslProtocol, trustManager);
}

/**
* Captures the inputs provided by the caller that pertain to constructing an SSLContext.
*/
private static class SSLInputs {
private final SSLContext sslContext;
private final String sslProtocol;
private final X509TrustManager trustManager;

public SSLInputs(SSLContext sslContext, String sslProtocol, X509TrustManager trustManager) {
this.sslContext = sslContext;
this.sslProtocol = sslProtocol;
this.trustManager = trustManager;
}

public SSLContext getSslContext() {
return sslContext;
}

public String getSslProtocol() {
return sslProtocol;
}

public X509TrustManager getTrustManager() {
return trustManager;
}
}
}
Loading