From 936bda9e0fbba88c7189213a4873f87a496f37c4 Mon Sep 17 00:00:00 2001 From: Rob Rudin Date: Wed, 29 Mar 2023 09:05:07 -0400 Subject: [PATCH] DEVEXP-249 Now renewing tokens in ML Cloud auth --- marklogic-client-api/build.gradle | 3 +- ...arkLogicCloudAuthenticationConfigurer.java | 203 ++++++++++++------ ...ogicCloudAuthenticationConfigurerTest.java | 18 +- .../TokenAuthenticationInterceptorTest.java | 156 ++++++++++++++ 4 files changed, 312 insertions(+), 68 deletions(-) create mode 100644 marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/TokenAuthenticationInterceptorTest.java diff --git a/marklogic-client-api/build.gradle b/marklogic-client-api/build.gradle index 54adf1346..87098bc36 100644 --- a/marklogic-client-api/build.gradle +++ b/marklogic-client-api/build.gradle @@ -38,11 +38,12 @@ dependencies { // Allows talking to the Manage API. It depends on the Java Client itself, which will usually be a slightly older // version, but that should not have any impact on the tests. - testImplementation "com.marklogic:ml-app-deployer:4.4.0" + testImplementation "com.marklogic:ml-app-deployer:4.5.1" // Starting with mockito 5.x, Java 11 is required, so sticking with 4.x as we have to support Java 8. testImplementation "org.mockito:mockito-core:4.11.0" testImplementation "org.mockito:mockito-inline:4.11.0" + testImplementation "com.squareup.okhttp3:mockwebserver:4.10.0" testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-xml:2.14.1' testImplementation 'ch.qos.logback:logback-classic:1.3.5' diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurer.java b/marklogic-client-api/src/main/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurer.java index 580a5f0ee..0ce123298 100644 --- a/marklogic-client-api/src/main/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurer.java +++ b/marklogic-client-api/src/main/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurer.java @@ -18,7 +18,13 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.marklogic.client.DatabaseClientFactory.MarkLogicCloudAuthContext; -import okhttp3.*; +import okhttp3.Call; +import okhttp3.FormBody; +import okhttp3.HttpUrl; +import okhttp3.Interceptor; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -26,8 +32,6 @@ public class MarkLogicCloudAuthenticationConfigurer implements AuthenticationConfigurer { - private final static Logger logger = LoggerFactory.getLogger(MarkLogicCloudAuthenticationConfigurer.class); - private String host; public MarkLogicCloudAuthenticationConfigurer(String host) { @@ -40,78 +44,155 @@ public void configureAuthentication(OkHttpClient.Builder clientBuilder, MarkLogi if (apiKey == null || apiKey.trim().length() < 1) { throw new IllegalArgumentException("No API key provided"); } + TokenGenerator tokenGenerator = new DefaultTokenGenerator(this.host, securityContext); + clientBuilder.addInterceptor(new TokenAuthenticationInterceptor(tokenGenerator)); + } - final Response response = callTokenEndpoint(securityContext); - final String accessToken = getAccessTokenFromResponse(response); - if (logger.isInfoEnabled()) { - logger.info("Successfully obtained authentication token"); - } - clientBuilder - .addInterceptor(chain -> { - Request authenticatedRequest = chain.request().newBuilder() - .header("Authorization", "Bearer " + accessToken) - .build(); - return chain.proceed(authenticatedRequest); - }); + /** + * Exists solely for mocking in unit tests. + */ + public interface TokenGenerator { + String generateToken(); } - private Response callTokenEndpoint(MarkLogicCloudAuthContext securityContext) { - final HttpUrl tokenUrl = buildTokenUrl(securityContext); - OkHttpClient.Builder clientBuilder = OkHttpUtil.newClientBuilder(); - // Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable - // for connecting to MarkLogic Cloud's "/token" endpoint. - OkHttpUtil.configureSocketFactory(clientBuilder, securityContext.getSSLContext(), securityContext.getTrustManager()); - OkHttpUtil.configureHostnameVerifier(clientBuilder, securityContext.getSSLHostnameVerifier()); + /** + * Knows how to call the "/token" endpoint in MarkLogic Cloud to generate a new token based on the + * user-provided API key. + */ + static class DefaultTokenGenerator implements TokenGenerator { + + private final static Logger logger = LoggerFactory.getLogger(DefaultTokenGenerator.class); + private String host; + private MarkLogicCloudAuthContext securityContext; + + public DefaultTokenGenerator(String host, MarkLogicCloudAuthContext securityContext) { + this.host = host; + this.securityContext = securityContext; + } - if (logger.isInfoEnabled()) { - logger.info("Calling token endpoint at: " + tokenUrl); + public String generateToken() { + final Response tokenResponse = callTokenEndpoint(); + String token = getAccessTokenFromResponse(tokenResponse); + if (logger.isInfoEnabled()) { + logger.info("Successfully obtained authentication token"); + } + return token; } - final Call call = clientBuilder - .build() - .newCall(new Request.Builder() - .url(tokenUrl) - .post(newFormBody(securityContext)) - .build() + private Response callTokenEndpoint() { + final HttpUrl tokenUrl = buildTokenUrl(); + OkHttpClient.Builder clientBuilder = OkHttpUtil.newClientBuilder(); + // Current assumption is that the SSL config provided for connecting to MarkLogic should also be applicable + // for connecting to MarkLogic Cloud's "/token" endpoint. + OkHttpUtil.configureSocketFactory(clientBuilder, securityContext.getSSLContext(), securityContext.getTrustManager()); + OkHttpUtil.configureHostnameVerifier(clientBuilder, securityContext.getSSLHostnameVerifier()); + + if (logger.isInfoEnabled()) { + logger.info("Calling token endpoint at: " + tokenUrl); + } + + final Call call = clientBuilder.build().newCall( + new Request.Builder() + .url(tokenUrl) + .post(newFormBody()) + .build() ); - try { - return call.execute(); - } catch (IOException e) { - throw new RuntimeException(String.format("Unable to call token endpoint at %s; cause: %s", - tokenUrl, e.getMessage(), e)); + try { + return call.execute(); + } catch (IOException e) { + throw new RuntimeException(String.format("Unable to call token endpoint at %s; cause: %s", + tokenUrl, e.getMessage(), e)); + } } - } - protected HttpUrl buildTokenUrl(MarkLogicCloudAuthContext securityContext) { - // For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud, - // so providing the ability to customize this would be misleading. - return new HttpUrl.Builder() - .scheme("https") - .host(host) - .port(443) - .build() - .resolve(securityContext.getTokenEndpoint()).newBuilder().build(); - } + protected HttpUrl buildTokenUrl() { + // For the near future, it's guaranteed that https and 443 will be required for connecting to MarkLogic Cloud, + // so providing the ability to customize this would be misleading. + return new HttpUrl.Builder() + .scheme("https") + .host(host) + .port(443) + .build() + .resolve(securityContext.getTokenEndpoint()).newBuilder().build(); + } - protected FormBody newFormBody(MarkLogicCloudAuthContext securityContext) { - return new FormBody.Builder() - .add("grant_type", securityContext.getGrantType()) - .add("key", securityContext.getApiKey()).build(); + protected FormBody newFormBody() { + return new FormBody.Builder() + .add("grant_type", securityContext.getGrantType()) + .add("key", securityContext.getApiKey()).build(); + } + + private String getAccessTokenFromResponse(Response response) { + String responseBody = null; + JsonNode payload; + try { + responseBody = response.body().string(); + payload = new ObjectMapper().readTree(responseBody); + } catch (IOException ex) { + throw new RuntimeException("Unable to get access token; response: " + responseBody, ex); + } + if (!payload.has("access_token")) { + throw new RuntimeException("Unable to get access token; unexpected JSON response: " + payload); + } + return payload.get("access_token").asText(); + } } - private String getAccessTokenFromResponse(Response response) { - String responseBody = null; - JsonNode payload; - try { - responseBody = response.body().string(); - payload = new ObjectMapper().readTree(responseBody); - } catch (IOException ex) { - throw new RuntimeException("Unable to get access token; response: " + responseBody, ex); + /** + * OkHttp interceptor that handles adding a token to an HTTP request and renewing it when necessary. + */ + static class TokenAuthenticationInterceptor implements Interceptor { + + private final static Logger logger = LoggerFactory.getLogger(TokenAuthenticationInterceptor.class); + + private TokenGenerator tokenGenerator; + private String token; + + public TokenAuthenticationInterceptor(TokenGenerator tokenGenerator) { + this.tokenGenerator = tokenGenerator; + this.token = tokenGenerator.generateToken(); } - if (!payload.has("access_token")) { - throw new RuntimeException("Unable to get access token; unexpected JSON response: " + payload); + + @Override + public Response intercept(Chain chain) throws IOException { + Response response = chain.proceed(addTokenToRequest(chain)); + if (response.code() == 403) { + logger.info("Received 403; will generate new token if necessary and retry request"); + response.close(); + final String currentToken = this.token; + generateNewTokenIfNecessary(currentToken); + response = chain.proceed(addTokenToRequest(chain)); + } + return response; + } + + /** + * In the case of N threads using the same DatabaseClient - e.g. when using DMSDK - all of them + * may make a request at the same time and get a 403 back. Functionally, it should be fine if all + * make their own requests to renew the token, with the last thread being the one whose token + * value is retained on this class. But to simplify matters, this block is synchronized so only one + * thread can be in here. And if that thread finds that this.token is different from currentToken, + * then some other thread already renewed the token - so this thread doesn't need to do anything and + * can just try again. + * + * @param currentToken the value of this instance's token right before calling this method; in the event that + * another thread using this instance got here first, then this value will differ from the + * instance's token field + */ + private synchronized void generateNewTokenIfNecessary(String currentToken) { + if (currentToken.equals(this.token)) { + logger.info("Generating new token based on receiving 403"); + this.token = tokenGenerator.generateToken(); + } else if (logger.isDebugEnabled()) { + logger.debug("This instance's token has already been updated, presumably by another thread"); + } + } + + private Request addTokenToRequest(Chain chain) { + return chain.request().newBuilder() + .header("Authorization", "Bearer " + token) + .build(); } - return payload.get("access_token").asText(); } } diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurerTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurerTest.java index 924d7e9b3..a590f7590 100644 --- a/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurerTest.java +++ b/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/MarkLogicCloudAuthenticationConfigurerTest.java @@ -17,10 +17,12 @@ public class MarkLogicCloudAuthenticationConfigurerTest { @Test void buildTokenUrl() throws Exception { - HttpUrl tokenUrl = new MarkLogicCloudAuthenticationConfigurer("somehost").buildTokenUrl( + MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator client = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("somehost", new DatabaseClientFactory.MarkLogicCloudAuthContext("doesnt-matter") .withSSLContext(SSLContext.getDefault(), null) ); + + HttpUrl tokenUrl = client.buildTokenUrl(); assertEquals("https://somehost/token", tokenUrl.toString()); } @@ -30,17 +32,20 @@ void buildTokenUrl() throws Exception { */ @Test void buildTokenUrlWithCustomTokenPath() throws Exception { - HttpUrl tokenUrl = new MarkLogicCloudAuthenticationConfigurer("otherhost").buildTokenUrl( + MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator client = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("otherhost", new DatabaseClientFactory.MarkLogicCloudAuthContext("doesnt-matter", "/customToken", "doesnt-matter") .withSSLContext(SSLContext.getDefault(), null) ); + + HttpUrl tokenUrl = client.buildTokenUrl(); assertEquals("https://otherhost/customToken", tokenUrl.toString()); } @Test void newFormBody() { - FormBody body = new MarkLogicCloudAuthenticationConfigurer("doesnt-matter") - .newFormBody(new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey")); + FormBody body = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("host-doesnt-matter", + new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey")) + .newFormBody(); assertEquals("grant_type", body.name(0)); assertEquals("apikey", body.value(0)); assertEquals("key", body.name(1)); @@ -53,8 +58,9 @@ void newFormBody() { */ @Test void newFormBodyWithOverrides() { - FormBody body = new MarkLogicCloudAuthenticationConfigurer("doesnt-matter") - .newFormBody(new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey", "doesnt-matter", "custom-grant-type")); + FormBody body = new MarkLogicCloudAuthenticationConfigurer.DefaultTokenGenerator("host-doesnt-matter", + new DatabaseClientFactory.MarkLogicCloudAuthContext("myKey", "doesnt-matter", "custom-grant-type")) + .newFormBody(); assertEquals("grant_type", body.name(0)); assertEquals("custom-grant-type", body.value(0)); assertEquals("key", body.name(1)); diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/TokenAuthenticationInterceptorTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/TokenAuthenticationInterceptorTest.java new file mode 100644 index 000000000..7c01c7703 --- /dev/null +++ b/marklogic-client-api/src/test/java/com/marklogic/client/impl/okhttp/TokenAuthenticationInterceptorTest.java @@ -0,0 +1,156 @@ +package com.marklogic.client.impl.okhttp; + +import com.marklogic.client.ext.helper.LoggingObject; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Uses OkHttp's MockWebServer to completely mock a MarkLogic instance so that we can control what response codes are + * returned and processed by TokenAuthenticationInterceptor. + */ +public class TokenAuthenticationInterceptorTest extends LoggingObject { + + private MockWebServer mockWebServer; + private FakeTokenGenerator fakeTokenGenerator; + private OkHttpClient okHttpClient; + + @BeforeEach + void beforeEach() { + mockWebServer = new MockWebServer(); + fakeTokenGenerator = new FakeTokenGenerator(); + + MarkLogicCloudAuthenticationConfigurer.TokenAuthenticationInterceptor interceptor = + new MarkLogicCloudAuthenticationConfigurer.TokenAuthenticationInterceptor(fakeTokenGenerator); + assertEquals(1, fakeTokenGenerator.timesInvoked, + "When the interceptor is created, it should immediately generate a token so that when multiple threads " + + "are using the DatabaseClient, they will all use the same token."); + + okHttpClient = new OkHttpClient.Builder().addInterceptor(interceptor).build(); + } + + @Test + void receive403() { + enqueueResponseCodes(200, 200, 403, 200); + + verifyRequestReturnsResponseCode(200); + verifyRequestReturnsResponseCode(200); + verifyRequestReturnsResponseCode(200, + "If a 403 is received from the server, then the token should be renewed, and then the 200 should be " + + "returned to the caller."); + + assertEquals(2, fakeTokenGenerator.timesInvoked, + "A token should have been generated for the first request and then again when the 403 was received."); + } + + @Test + void receive401() { + enqueueResponseCodes(200, 200, 401); + + verifyRequestReturnsResponseCode(200); + verifyRequestReturnsResponseCode(200); + verifyRequestReturnsResponseCode(401); + + assertEquals(1, fakeTokenGenerator.timesInvoked, + "A token should have been generated for the first request, and the 401 should not have resulted in the " + + "token being renewed; only a 403 should."); + } + + @Test + void multipleThreads() throws Exception { + Runnable threadThatMakesThreeCalls = () -> { + for (int i = 0; i < 3; i++) { + sleep(100); + callMockWebServer(); + } + }; + + // Mock up 4 responses for each of the 2 threads created below. For each thread, the first call succeeds; the + // second receives a 403 and then succeeds; and the third call succeeds. + enqueueResponseCodes(200, 200, 403, 403, 200, 200, 200, 200); + + // Spawn two threads and wait for them to complete. + ExecutorService service = Executors.newFixedThreadPool(2); + Future f1 = service.submit(threadThatMakesThreeCalls); + Future f2 = service.submit(threadThatMakesThreeCalls); + f1.get(); + f2.get(); + + assertEquals(2, fakeTokenGenerator.timesInvoked, + "The fake token generator should have been invoked twice - once when the interceptor was created, and then " + + "only one more time when the two threads received 403's at almost the exact same time. The interceptor " + + "is expected to synchronize the call for generating a token such that only one thread will generate a " + + "new token. The other token is expected to see that the token has changed and uses the new token " + + "instead of generating a new token itself."); + } + + /** + * Uses OkHttp's MockWebServer to enqueue mock responses with the given codes. This allows us to mock a 403 being + * returned to ensure that a new token is generated if necessary. + * + * @param codes + */ + private void enqueueResponseCodes(int... codes) { + for (int code : codes) { + mockWebServer.enqueue(new MockResponse().setResponseCode(code)); + } + } + + private void verifyRequestReturnsResponseCode(int expectedCode) { + verifyRequestReturnsResponseCode(expectedCode, null); + } + + private void verifyRequestReturnsResponseCode(int expectedCode, String optionalMessage) { + int actualCode = callMockWebServer(); + if (optionalMessage != null) { + assertEquals(expectedCode, actualCode, optionalMessage); + } else { + assertEquals(expectedCode, actualCode); + } + } + + private int callMockWebServer() { + Request request = new Request.Builder().url(mockWebServer.url("/url-doesnt-matter")).build(); + try { + return okHttpClient.newCall(request).execute().code(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Fake token generator that allows us to assert on how many times it's invoked, which ensures that new tokens are + * or are not being generated when required. + */ + private static class FakeTokenGenerator implements MarkLogicCloudAuthenticationConfigurer.TokenGenerator { + int timesInvoked; + + @Override + public String generateToken() { + // A slight delay is added here for the multipleThread test case to simulate the token generation taking + // some amount of time. This allows us to verify that the synchronization is working properly in the + // interceptor. + sleep(100); + timesInvoked++; + return "fake-token-" + timesInvoked; + } + } + + private static void sleep(long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +}