Skip to content

Commit

Permalink
fix: always rotate refresh tokens for public clients (#2846)
Browse files Browse the repository at this point in the history
When refreshing a token, always rotate for public clients, thus
not requiring rotation to be enabled for all clients and
preventing the possible error condition for public clients.

Change-Id: I6ab80dd8b1928ab55863cea52849ff22f35c2779
  • Loading branch information
mikeroda authored and hsinn0 committed May 17, 2024
1 parent 6249e49 commit 4eb307a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ public OAuth2AccessToken refreshAccessToken(String refreshTokenValue, TokenReque
addAuthenticationMethod(claims, additionalRootClaims, authenticationData);

String accessTokenId = generateUniqueTokenId();
refreshTokenValue = refreshTokenCreator.createRefreshTokenValue(jwtToken, claims);
String clientAuth = authenticationData.clientAuth;
refreshTokenValue = refreshTokenCreator.createRefreshTokenValue(jwtToken, claims, clientAuth);
CompositeToken compositeToken =
createCompositeToken(
accessTokenId,
Expand All @@ -313,7 +314,7 @@ refreshTokenValue, new Date(refreshTokenExpireMillis), claims.getJti()
);

String tokenIdToBeDeleted = null;
if (isRevocable && refreshTokenCreator.shouldRotateRefreshTokens()) {
if (isRevocable && refreshTokenCreator.shouldRotateRefreshTokens(clientAuth)) {
tokenIdToBeDeleted = (String) jwtToken.getClaims().get(JTI);
}
return persistRevocableToken(accessTokenId, compositeToken, expiringRefreshToken, claims.getClientId(), user.getId(), isOpaque, isRevocable, tokenIdToBeDeleted);
Expand All @@ -328,7 +329,7 @@ private void addAuthenticationMethod(Claims claims, Map<String, Object> addition
// public refresh flow, allowed if access_token before was also without authentication (claim: client_auth_method=none) and refresh token is one time use (rotate it in refresh)
if (CLIENT_AUTH_NONE.equals(authenticationData.clientAuth) && // current authentication
(!CLIENT_AUTH_NONE.equals(claims.getClientAuth()) || // authentication before
!refreshTokenCreator.shouldRotateRefreshTokens())) {
!refreshTokenCreator.shouldRotateRefreshTokens(authenticationData.clientAuth))) {
throw new TokenRevokedException("Refresh without client authentication not allowed.");
}
addRootClaimEntry(additionalRootClaims, CLIENT_AUTH_METHOD, authenticationData.clientAuth);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,18 @@ public void setTimeService(TimeService timeService) {
this.timeService = timeService;
}

public boolean shouldRotateRefreshTokens() {
return getActiveTokenPolicy().isRefreshTokenRotate();
public boolean shouldRotateRefreshTokens(String clientAuth) {
return getActiveTokenPolicy().isRefreshTokenRotate() || CLIENT_AUTH_NONE.equals(clientAuth);
}

private Map<String, Object> getRefreshedTokenMap(Claims claims) {
claims.setJti(UUID.randomUUID().toString().replace("-", "") + REFRESH_TOKEN_SUFFIX);
return claims.getClaimMap();
}

public String createRefreshTokenValue(JwtTokenSignedByThisUAA jwtToken, Claims claims) {
public String createRefreshTokenValue(JwtTokenSignedByThisUAA jwtToken, Claims claims, String clientAuth) {
String refreshTokenValue;
if (shouldRotateRefreshTokens()) {
if (shouldRotateRefreshTokens(clientAuth)) {
refreshTokenValue = JwtHelper.encode(getRefreshedTokenMap(claims), getActiveKeyInfo()).getEncoded();
} else {
refreshTokenValue = jwtToken.getJwt().getEncoded();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ void testRefreshPublicClientWithRotationAndEmpyAuthentication() {
}

@Test
@DisplayName("Refresh Token with allowpublic but without rotation")
void testRefreshPublicClientWithoutRotation() {
@DisplayName("Refresh Token with allowpublic and implicit rotation")
void testRefreshPublicClientImplicitRotation() {
UaaClientDetails clientDetails = new UaaClientDetails(tokenSupport.defaultClient);
clientDetails.setAutoApproveScopes(singleton("true"));
tokenSupport.clientDetailsService.setClientDetailsStore(IdentityZoneHolder.get().getId(), Collections.singletonMap(CLIENT_ID, clientDetails));
Expand All @@ -196,9 +196,10 @@ void testRefreshPublicClientWithoutRotation() {
assertThat(refreshTokenValue, is(notNullValue()));

setupOAuth2Authentication(oAuth2Request);
RuntimeException exception = assertThrows(TokenRevokedException.class, () ->
tokenServices.refreshAccessToken(refreshTokenValue, new TokenRequest(new HashMap<>(), CLIENT_ID, Lists.newArrayList("openid"), GRANT_TYPE_REFRESH_TOKEN)));
assertEquals("Refresh without client authentication not allowed.", exception.getMessage());
OAuth2AccessToken refreshedToken = tokenServices.refreshAccessToken(refreshTokenValue, new TokenRequest(new HashMap<>(), CLIENT_ID, Lists.newArrayList("openid"), GRANT_TYPE_REFRESH_TOKEN));
assertThat(refreshedToken, is(notNullValue()));
assertNotEquals("New access token should be different from the old one.", refreshTokenValue, refreshedToken.getRefreshToken().getValue());
assertThat((Map<String, Object>) UaaTokenUtils.getClaims(refreshedToken.getValue(), Map.class), hasEntry(CLIENT_AUTH_METHOD, CLIENT_AUTH_NONE));
}

@Test
Expand All @@ -212,13 +213,11 @@ void testRefreshPublicClientButExistingTokenWasEmptyAuthentication() {
Map<String, String> azParameters = new HashMap<>(authorizationRequest.getRequestParameters());
azParameters.put(GRANT_TYPE, GRANT_TYPE_AUTHORIZATION_CODE);
authorizationRequest.setRequestParameters(azParameters);
authorizationRequest.setExtensions(Map.of(CLIENT_AUTH_METHOD, CLIENT_AUTH_EMPTY));
OAuth2Request oAuth2Request = authorizationRequest.createOAuth2Request();
OAuth2Authentication authentication = new OAuth2Authentication(oAuth2Request, tokenSupport.defaultUserAuthentication);
new IdentityZoneManagerImpl().getCurrentIdentityZone().getConfig().getTokenPolicy().setRefreshTokenRotate(true);
CompositeToken accessToken = (CompositeToken) tokenServices.createAccessToken(authentication);

assertThat((Map<String, Object>) UaaTokenUtils.getClaims(accessToken.getValue(), Map.class), hasEntry(CLIENT_AUTH_METHOD, CLIENT_AUTH_NONE));
String refreshTokenValue = accessToken.getRefreshToken().getValue();
assertThat(refreshTokenValue, is(notNullValue()));

Expand Down

0 comments on commit 4eb307a

Please sign in to comment.