Skip to content

Commit

Permalink
Refactor UaaTokenServices + TokenValidation to avoid redundant DB check
Browse files Browse the repository at this point in the history
for opaque tokens

[#117580865] https://www.pivotaltracker.com/story/show/117580865

Signed-off-by: Madhura Bhave <mbhave@pivotal.io>
  • Loading branch information
Jeremy Coffield authored and cf-identity committed Apr 15, 2016
1 parent 2a67a7f commit 78408b0
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 53 deletions.
Expand Up @@ -43,7 +43,6 @@
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
Expand Down Expand Up @@ -189,9 +188,9 @@ public OAuth2AccessToken refreshAccessToken(String refreshTokenValue, TokenReque
+ request.getRequestParameters().get("grant_type"));
}

refreshTokenValue = getJwtTokenValue(refreshTokenValue);

Map<String, Object> claims = getClaimsForToken(refreshTokenValue);
TokenValidation tokenValidation = validateToken(refreshTokenValue);
Map<String, Object> claims = tokenValidation.getClaims();
refreshTokenValue = tokenValidation.getJwt().getEncoded();

// TODO: Should reuse the access token you get after the first
// successful authentication.
Expand Down Expand Up @@ -899,9 +898,9 @@ public OAuth2Authentication loadAuthentication(String accessToken) throws Authen
throw new InvalidTokenException("Invalid access token value, must be at least 30 characters:"+accessToken);
}

accessToken = getJwtTokenValue(accessToken);

Map<String, Object> claims = getClaimsForToken(accessToken);
TokenValidation tokenValidation = validateToken(accessToken);
Map<String, Object> claims = tokenValidation.getClaims();
accessToken = tokenValidation.getJwt().getEncoded();

// Check token expiry
Integer expiration = (Integer) claims.get(EXP);
Expand Down Expand Up @@ -955,25 +954,15 @@ public OAuth2Authentication loadAuthentication(String accessToken) throws Authen
return authentication;
}

protected String getJwtTokenValue(String token) {
if (token.length()<=36) {
try {
token = tokenProvisioning.retrieve(token).getValue();
} catch (EmptyResultDataAccessException x) {
throw new InvalidTokenException("Revocable token with ID:"+ token +" not found.");
}
}
return token;
}

/**
* This method is implemented to support older API calls that assume the
* presence of a token store
*/
@Override
public OAuth2AccessToken readAccessToken(String accessToken) {
accessToken = getJwtTokenValue(accessToken);
Map<String, Object> claims = getClaimsForToken(accessToken);
TokenValidation tokenValidation = validateToken(accessToken);
Map<String, Object> claims = tokenValidation.getClaims();
accessToken = tokenValidation.getJwt().getEncoded();

// Expiry is verified by check_token
CompositeAccessToken token = new CompositeAccessToken(accessToken);
Expand Down Expand Up @@ -1029,8 +1018,8 @@ private Set<String> getAutoApprovedScopes(Object grantType, Collection<String> t
return UaaTokenUtils.retainAutoApprovedScopes(tokenScopes, autoApprovedScopes);
}

protected Map<String, Object> getClaimsForToken(String token) {
TokenValidation tokenValidation = validate(token).throwIfInvalid();
protected TokenValidation validateToken(String token) {
TokenValidation tokenValidation = validate(tokenProvisioning, token).throwIfInvalid();
Jwt tokenJwt = tokenValidation.getJwt();
Map<String, Object> claims = tokenValidation.getClaims();

Expand Down Expand Up @@ -1077,7 +1066,7 @@ protected Map<String, Object> getClaimsForToken(String token) {
String currentRevocationSignature = UaaTokenUtils.getRevocableTokenSignature(client, user);
tokenValidation.checkRevocationSignature(currentRevocationSignature).throwIfInvalid();

return claims;
return tokenValidation;
}

/**
Expand Down
Expand Up @@ -67,12 +67,23 @@ public class TokenValidation {
private final String token;
private final boolean decoded; // this is used to avoid checking claims on tokens that had errors when decoding
private final List<RuntimeException> validationErrors = new ArrayList<>();
private final boolean serverSide;

public static TokenValidation validate(String token) {
return new TokenValidation(token);
}

public static TokenValidation validate(RevocableTokenProvisioning tokenProvisioning, String token) {
Pattern jwtPattern = Pattern.compile("[a-zA-Z0-9_\\-\\\\=]*\\.[a-zA-Z0-9_\\-\\\\=]*\\.[a-zA-Z0-9_\\-\\\\=]*");
if(jwtPattern.matcher(token).matches()) {
return new TokenValidation(token);
} else {
return new TokenValidation(tokenProvisioning, token);
}
}

private TokenValidation(String token) {
this.serverSide = false;
this.token = token;

Jwt tokenJwt;
Expand Down Expand Up @@ -102,6 +113,29 @@ private TokenValidation(String token) {
this.decoded = isValid();
}

private TokenValidation(RevocableTokenProvisioning tokenProvisioning, String tokenId) {
this.serverSide = true;

String token;
try {
token = tokenProvisioning.retrieve(tokenId).getValue();
} catch (EmptyResultDataAccessException x) {
token = null;
addError("Revocable token with ID:" + tokenId + " not found.");
}
this.token = token;

if(token != null) {
tokenJwt = JwtHelper.decode(token);
claims = JsonUtils.readValue(tokenJwt.getClaims(), new TypeReference<Map<String, Object>>() {});
this.decoded = true;
} else {
tokenJwt = null;
claims = null;
this.decoded = false;
}
}

public boolean isValid() {
return validationErrors.size() == 0;
}
Expand All @@ -122,12 +156,17 @@ private TokenValidation(TokenValidation source) {
this.tokenJwt = source.tokenJwt;
this.token = source.token;
this.decoded = source.decoded;

this.serverSide = source.serverSide;
this.scopes = source.scopes;
}


public TokenValidation checkSignature(SignatureVerifier verifier) {
if(serverSide) {
// serverSide tokens are not JWT and we should not validate the JWT signature
return this;
}

if(!decoded) { return this; }
try {
tokenJwt.verifySignature(verifier);
Expand Down Expand Up @@ -380,6 +419,11 @@ else if(audClaim == null) {
}

public TokenValidation checkRevocableTokenStore(RevocableTokenProvisioning revocableTokenProvisioning) {
if(serverSide) {
// serverSide tokens are inherently present in the token store
return this;
}

if(!decoded) {
addError("The token could not be checked for revocation.");
return this;
Expand Down
Expand Up @@ -803,7 +803,7 @@ public void testExpiredToken() throws Exception {
tokenServices.setClientDetailsService(clientDetailsService);
accessToken = tokenServices.createAccessToken(authentication);
Thread.sleep(1000);
Claims result = endpoint.checkToken(accessToken.getValue(), Collections.emptyList());
endpoint.checkToken(accessToken.getValue(), Collections.emptyList());
}

@Test(expected = InvalidTokenException.class)
Expand Down
Expand Up @@ -1504,11 +1504,11 @@ public void testLoad_Opaque_AuthenticationForAUser() {
assertThat("Opaque refresh token must be shorter than 37 characters", accessToken.getRefreshToken().getValue().length(), lessThanOrEqualTo(36));

String accessTokenValue = tokenProvisioning.retrieve(composite.getValue()).getValue();
Map<String,Object> accessTokenClaims = tokenServices.getClaimsForToken(accessTokenValue);
Map<String,Object> accessTokenClaims = tokenServices.validateToken(accessTokenValue).getClaims();
assertEquals(true, accessTokenClaims.get(ClaimConstants.REVOCABLE));

String refreshTokenValue = tokenProvisioning.retrieve(composite.getRefreshToken().getValue()).getValue();
Map<String,Object> refreshTokenClaims = tokenServices.getClaimsForToken(refreshTokenValue);
Map<String,Object> refreshTokenClaims = tokenServices.validateToken(refreshTokenValue).getClaims();
assertEquals(true, refreshTokenClaims.get(ClaimConstants.REVOCABLE));


Expand Down

0 comments on commit 78408b0

Please sign in to comment.