Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public abstract class AbstractUaaTokenProvider implements TokenProvider {
private final ConcurrentMap<ConnectionContext, RefreshToken> refreshTokenStreams =
new ConcurrentHashMap<>(1);

private final ConcurrentMap<ConnectionContext, Mono<String>> refreshTokens =
private final ConcurrentMap<ConnectionContext, String> refreshTokens =
new ConcurrentHashMap<>(1);

/**
Expand Down Expand Up @@ -116,7 +116,10 @@ public final Mono<String> getToken(ConnectionContext connectionContext) {

@Override
public void invalidate(ConnectionContext connectionContext) {
this.accessTokens.put(connectionContext, token(connectionContext));
String refreshToken = this.refreshTokens.remove(connectionContext);
if (refreshToken != null) {
this.accessTokens.put(connectionContext, token(connectionContext, refreshToken));
}
}

/**
Expand All @@ -133,6 +136,30 @@ public void invalidate(ConnectionContext connectionContext) {
*/
abstract void tokenRequestTransformer(HttpClientRequest request, HttpClientForm form);

private Mono<String> token(ConnectionContext connectionContext) {
Mono<String> token =
primaryToken(connectionContext)
.doOnSubscribe(s -> LOGGER.debug("Negotiating using token provider"));

return cacheResult(connectionContext, token);
}

private Mono<String> token(ConnectionContext connectionContext, String refreshToken) {
Mono<String> token =
refreshToken(connectionContext, refreshToken)
.doOnSubscribe(s -> LOGGER.debug("Negotiating using refresh token"))
// fall back to primary token in case the refresh_token grant fails
// (expired, revoked, ...)
.switchIfEmpty(
primaryToken(connectionContext)
.doOnSubscribe(
s ->
LOGGER.debug(
"Falling back to token provider")));

return cacheResult(connectionContext, token);
}

private static String extractAccessToken(Map<String, String> payload) {
String accessToken = payload.get(ACCESS_TOKEN);

Expand Down Expand Up @@ -227,8 +254,7 @@ private Consumer<Map<String, String>> extractRefreshToken(ConnectionContext conn
});
}

this.refreshTokens.put(
connectionContext, Mono.just(refreshToken));
this.refreshTokens.put(connectionContext, refreshToken);
getRefreshTokenStream(connectionContext)
.sink
.emitNext(refreshToken, FAIL_FAST);
Expand Down Expand Up @@ -297,30 +323,16 @@ private void setAuthorization(HttpHeaders headers) {
headers.set(AUTHORIZATION, String.format("Basic %s", encoded));
}

private Mono<String> token(ConnectionContext connectionContext) {
Mono<String> cached =
this.refreshTokens
.getOrDefault(connectionContext, Mono.empty())
.flatMap(
refreshToken ->
refreshToken(connectionContext, refreshToken)
.doOnSubscribe(
s ->
LOGGER.debug(
"Negotiating using refresh"
+ " token")))
.switchIfEmpty(
primaryToken(connectionContext)
.doOnSubscribe(
s ->
LOGGER.debug(
"Negotiating using token"
+ " provider")));

/**
* Cache the given mono. If {@link ConnectionContext#getCacheDuration()} is not null, use that
* as the cache TTL. Otherwise, cache indefinitely.
*/
private static Mono<String> cacheResult(
ConnectionContext connectionContext, Mono<String> token) {
return connectionContext
.getCacheDuration()
.map(cached::cache)
.orElseGet(cached::cache)
.map(token::cache)
.orElseGet(token::cache)
.checkpoint();
}

Expand Down