Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion marklogic-client-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ dependencies {

// Allows talking to the Manage API. It depends on the Java Client itself, which will usually be a slightly older
// version, but that should not have any impact on the tests.
testImplementation "com.marklogic:ml-app-deployer:4.4.0"
testImplementation "com.marklogic:ml-app-deployer:4.5.1"

// Starting with mockito 5.x, Java 11 is required, so sticking with 4.x as we have to support Java 8.
testImplementation "org.mockito:mockito-core:4.11.0"
testImplementation "org.mockito:mockito-inline:4.11.0"
testImplementation "com.squareup.okhttp3:mockwebserver:4.10.0"

testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-xml:2.14.1'
testImplementation 'ch.qos.logback:logback-classic:1.3.5'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.marklogic.client.DatabaseClientFactory.MarkLogicCloudAuthContext;
import okhttp3.*;
import okhttp3.Call;
import okhttp3.FormBody;
import okhttp3.HttpUrl;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

public class MarkLogicCloudAuthenticationConfigurer implements AuthenticationConfigurer<MarkLogicCloudAuthContext> {

private final static Logger logger = LoggerFactory.getLogger(MarkLogicCloudAuthenticationConfigurer.class);

private String host;

public MarkLogicCloudAuthenticationConfigurer(String host) {
Expand All @@ -40,78 +44,155 @@ public void configureAuthentication(OkHttpClient.Builder clientBuilder, MarkLogi
if (apiKey == null || apiKey.trim().length() < 1) {
throw new IllegalArgumentException("No API key provided");
}
TokenGenerator tokenGenerator = new DefaultTokenGenerator(this.host, securityContext);
clientBuilder.addInterceptor(new TokenAuthenticationInterceptor(tokenGenerator));
}

final Response response = callTokenEndpoint(securityContext);
final String accessToken = getAccessTokenFromResponse(response);
if (logger.isInfoEnabled()) {
logger.info("Successfully obtained authentication token");
}
clientBuilder
.addInterceptor(chain -> {
Request authenticatedRequest = chain.request().newBuilder()
.header("Authorization", "Bearer " + accessToken)
.build();
return chain.proceed(authenticatedRequest);
});
/**
* Exists solely for mocking in unit tests.
*/
public interface TokenGenerator {
String generateToken();
}

private Response callTokenEndpoint(MarkLogicCloudAuthContext securityContext) {
final HttpUrl tokenUrl = buildTokenUrl(securityContext);
OkHttpClient.Builder clientBuilder = OkHttpUtil.newClientBuilder();
// Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable
// for connecting to MarkLogic Cloud's "/token" endpoint.
OkHttpUtil.configureSocketFactory(clientBuilder, securityContext.getSSLContext(), securityContext.getTrustManager());
OkHttpUtil.configureHostnameVerifier(clientBuilder, securityContext.getSSLHostnameVerifier());
/**
* Knows how to call the "/token" endpoint in MarkLogic Cloud to generate a new token based on the
* user-provided API key.
*/
static class DefaultTokenGenerator implements TokenGenerator {

private final static Logger logger = LoggerFactory.getLogger(DefaultTokenGenerator.class);
private String host;
private MarkLogicCloudAuthContext securityContext;

public DefaultTokenGenerator(String host, MarkLogicCloudAuthContext securityContext) {
this.host = host;
this.securityContext = securityContext;
}

if (logger.isInfoEnabled()) {
logger.info("Calling token endpoint at: " + tokenUrl);
public String generateToken() {
final Response tokenResponse = callTokenEndpoint();
String token = getAccessTokenFromResponse(tokenResponse);
if (logger.isInfoEnabled()) {
logger.info("Successfully obtained authentication token");
}
return token;
}

final Call call = clientBuilder
.build()
.newCall(new Request.Builder()
.url(tokenUrl)
.post(newFormBody(securityContext))
.build()
private Response callTokenEndpoint() {
final HttpUrl tokenUrl = buildTokenUrl();
OkHttpClient.Builder clientBuilder = OkHttpUtil.newClientBuilder();
// Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable
// for connecting to MarkLogic Cloud's "/token" endpoint.
OkHttpUtil.configureSocketFactory(clientBuilder, securityContext.getSSLContext(), securityContext.getTrustManager());
OkHttpUtil.configureHostnameVerifier(clientBuilder, securityContext.getSSLHostnameVerifier());

if (logger.isInfoEnabled()) {
logger.info("Calling token endpoint at: " + tokenUrl);
}

final Call call = clientBuilder.build().newCall(
new Request.Builder()
.url(tokenUrl)
.post(newFormBody())
.build()
);

try {
return call.execute();
} catch (IOException e) {
throw new RuntimeException(String.format("Unable to call token endpoint at %s; cause: %s",
tokenUrl, e.getMessage(), e));
try {
return call.execute();
} catch (IOException e) {
throw new RuntimeException(String.format("Unable to call token endpoint at %s; cause: %s",
tokenUrl, e.getMessage(), e));
}
}
}

protected HttpUrl buildTokenUrl(MarkLogicCloudAuthContext securityContext) {
// For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud,
// so providing the ability to customize this would be misleading.
return new HttpUrl.Builder()
.scheme("https")
.host(host)
.port(443)
.build()
.resolve(securityContext.getTokenEndpoint()).newBuilder().build();
}
protected HttpUrl buildTokenUrl() {
// For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud,
// so providing the ability to customize this would be misleading.
return new HttpUrl.Builder()
.scheme("https")
.host(host)
.port(443)
.build()
.resolve(securityContext.getTokenEndpoint()).newBuilder().build();
}

protected FormBody newFormBody(MarkLogicCloudAuthContext securityContext) {
return new FormBody.Builder()
.add("grant_type", securityContext.getGrantType())
.add("key", securityContext.getApiKey()).build();
protected FormBody newFormBody() {
return new FormBody.Builder()
.add("grant_type", securityContext.getGrantType())
.add("key", securityContext.getApiKey()).build();
}

private String getAccessTokenFromResponse(Response response) {
String responseBody = null;
JsonNode payload;
try {
responseBody = response.body().string();
payload = new ObjectMapper().readTree(responseBody);
} catch (IOException ex) {
throw new RuntimeException("Unable to get access token; response: " + responseBody, ex);
}
if (!payload.has("access_token")) {
throw new RuntimeException("Unable to get access token; unexpected JSON response: " + payload);
}
return payload.get("access_token").asText();
}
}

private String getAccessTokenFromResponse(Response response) {
String responseBody = null;
JsonNode payload;
try {
responseBody = response.body().string();
payload = new ObjectMapper().readTree(responseBody);
} catch (IOException ex) {
throw new RuntimeException("Unable to get access token; response: " + responseBody, ex);
/**
* OkHttp interceptor that handles adding a token to an HTTP request and renewing it when necessary.
*/
static class TokenAuthenticationInterceptor implements Interceptor {

private final static Logger logger = LoggerFactory.getLogger(TokenAuthenticationInterceptor.class);

private TokenGenerator tokenGenerator;
private String token;

public TokenAuthenticationInterceptor(TokenGenerator tokenGenerator) {
this.tokenGenerator = tokenGenerator;
this.token = tokenGenerator.generateToken();
}
if (!payload.has("access_token")) {
throw new RuntimeException("Unable to get access token; unexpected JSON response: " + payload);

@Override
public Response intercept(Chain chain) throws IOException {
Response response = chain.proceed(addTokenToRequest(chain));
if (response.code() == 403) {
logger.info("Received 403; will generate new token if necessary and retry request");
response.close();
final String currentToken = this.token;
generateNewTokenIfNecessary(currentToken);
response = chain.proceed(addTokenToRequest(chain));
}
return response;
}

/**
* In the case of N threads using the same DatabaseClient - e.g. when using DMSDK - all of them
* may make a request at the same time and get a 403 back. Functionally, it should be fine if all
* make their own requests to renew the token, with the last thread being the one whose token
* value is retained on this class. But to simplify matters, this block is synchronized so only one
* thread can be in here. And if that thread finds that this.token is different from currentToken,
* then some other thread already renewed the token - so this thread doesn't need to do anything and
* can just try again.
*
* @param currentToken the value of this instance's token right before calling this method; in the event that
* another thread using this instance got here first, then this value will differ from the
* instance's token field
*/
private synchronized void generateNewTokenIfNecessary(String currentToken) {
if (currentToken.equals(this.token)) {
logger.info("Generating new token based on receiving 403");
this.token = tokenGenerator.generateToken();
} else if (logger.isDebugEnabled()) {
logger.debug("This instance's token has already been updated, presumably by another thread");
}
}

private Request addTokenToRequest(Chain chain) {
return chain.request().newBuilder()
.header("Authorization", "Bearer " + token)
.build();
}
return payload.get("access_token").asText();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ public class MarkLogicCloudAuthenticationConfigurerTest {

@Test
void buildTokenUrl() throws Exception {
HttpUrl tokenUrl = new MarkLogicCloudAuthenticationConfigurer("somehost").buildTokenUrl(
MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator client = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("somehost",
new DatabaseClientFactory.MarkLogicCloudAuthContext("doesnt-matter")
.withSSLContext(SSLContext.getDefault(), null)
);

HttpUrl tokenUrl = client.buildTokenUrl();
assertEquals("https://somehost/token", tokenUrl.toString());
}

Expand All @@ -30,17 +32,20 @@ void buildTokenUrl() throws Exception {
*/
@Test
void buildTokenUrlWithCustomTokenPath() throws Exception {
HttpUrl tokenUrl = new MarkLogicCloudAuthenticationConfigurer("otherhost").buildTokenUrl(
MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator client = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("otherhost",
new DatabaseClientFactory.MarkLogicCloudAuthContext("doesnt-matter", "/customToken", "doesnt-matter")
.withSSLContext(SSLContext.getDefault(), null)
);

HttpUrl tokenUrl = client.buildTokenUrl();
assertEquals("https://otherhost/customToken", tokenUrl.toString());
}

@Test
void newFormBody() {
FormBody body = new MarkLogicCloudAuthenticationConfigurer("doesnt-matter")
.newFormBody(new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey"));
FormBody body = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("host-doesnt-matter",
new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey"))
.newFormBody();
assertEquals("grant_type", body.name(0));
assertEquals("apikey", body.value(0));
assertEquals("key", body.name(1));
Expand All @@ -53,8 +58,9 @@ void newFormBody() {
*/
@Test
void newFormBodyWithOverrides() {
FormBody body = new MarkLogicCloudAuthenticationConfigurer("doesnt-matter")
.newFormBody(new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey", "doesnt-matter", "custom-grant-type"));
FormBody body = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("host-doesnt-matter",
new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey", "doesnt-matter", "custom-grant-type"))
.newFormBody();
assertEquals("grant_type", body.name(0));
assertEquals("custom-grant-type", body.value(0));
assertEquals("key", body.name(1));
Expand Down
Loading