Skip to content

Commit

Permalink
refactor: add request to jwt claims validator (#396)
Browse files Browse the repository at this point in the history
Close: #381
  • Loading branch information
sdelamo committed Sep 25, 2020
1 parent 5bf1762 commit 797decf
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ protected boolean validate(@NonNull JWTClaimsSet claimsSet) {
return true;
}

@Deprecated
@Override
public boolean validate(JwtClaims claims) {
return validate(JWTClaimsSetUtils.jwtClaimsSetFromClaims(claims));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package io.micronaut.security.token.jwt.validator;

import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import io.micronaut.http.HttpRequest;
import io.micronaut.security.token.jwt.config.JwtConfigurationProperties;
import io.micronaut.security.token.jwt.generator.claims.JwtClaims;

Expand All @@ -29,9 +32,14 @@ public interface JwtClaimsValidator {
String PREFIX = JwtConfigurationProperties.PREFIX + ".claims-validators";

/**
*
* @deprecated Use {@link JwtClaimsValidator#validate(JwtClaims, HttpRequest)} instead.
* @param claims JWT Claims
* @return whether the JWT claims pass validation.
*/
@Deprecated
boolean validate(JwtClaims claims);

default boolean validate(@NonNull JwtClaims claims, @Nullable HttpRequest<?> request) {
return validate(claims);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import javax.inject.Inject;
import javax.inject.Singleton;
import java.util.*;
import java.util.Collection;

/**
* @see <a href="https://connect2id.com/products/nimbus-jose-jwt/examples/validating-jwt-access-tokens">Validating JWT Access Tokens</a>
Expand All @@ -36,7 +36,7 @@
public class JwtTokenValidator implements TokenValidator {

protected final JwtAuthenticationFactory jwtAuthenticationFactory;
private final JwtValidator validator;
protected final JwtValidator validator;

/**
* Constructor.
Expand Down Expand Up @@ -69,6 +69,7 @@ public JwtTokenValidator(JwtValidator validator,
}

/***
* @deprecated Use {@link JwtTokenValidator#validateToken(String, io.micronaut.http.HttpRequest)} instead.
* @param token The token string.
* @return Publishes {@link Authentication} based on the JWT or empty if the validation fails.
*/
Expand All @@ -80,5 +81,4 @@ public Publisher<Authentication> validateToken(String token) {
.map(Flowable::just)
.orElse(Flowable.empty());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import com.nimbusds.jose.*;
import com.nimbusds.jwt.*;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import io.micronaut.http.HttpRequest;
import io.micronaut.security.token.jwt.encryption.EncryptionConfiguration;
import io.micronaut.security.token.jwt.generator.claims.JwtClaims;
import io.micronaut.security.token.jwt.generator.claims.JwtClaimsSetAdapter;
Expand All @@ -25,7 +28,12 @@
import org.slf4j.LoggerFactory;

import java.text.ParseException;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;

/**
* A builder style class for validating JWT tokens against any number of provided
Expand All @@ -43,7 +51,9 @@ public final class JwtValidator {
private final List<EncryptionConfiguration> encryptions;
private final List<JwtClaimsValidator> claimsValidators;

private JwtValidator(List<SignatureConfiguration> signatures, List<EncryptionConfiguration> encryptions, List<JwtClaimsValidator> claimsValidators) {
private JwtValidator(List<SignatureConfiguration> signatures,
List<EncryptionConfiguration> encryptions,
List<JwtClaimsValidator> claimsValidators) {
this.signatures = signatures;
this.encryptions = encryptions;
this.claimsValidators = claimsValidators;
Expand All @@ -52,25 +62,38 @@ private JwtValidator(List<SignatureConfiguration> signatures, List<EncryptionCon
/**
* Validates the supplied token with any configurations and claim validators present.
*
* @deprecated Use {@link JwtValidator#validate(String, HttpRequest)} instead.
* @param token The JWT string
* @return An optional JWT token if validation succeeds
*/
@Deprecated
public Optional<JWT> validate(String token) {
try {
if (hasAtLeastTwoDots(token)) {
JWT jwt = JWTParser.parse(token);
return validate(jwt);
} else {
return validate(token, null);
}

/**
* Validates the supplied token with any configurations and claim validators present.
*
* @param token The JWT string
* @param request HTTP Request
* @return An optional JWT token if validation succeeds
*/
public Optional<JWT> validate(String token, @Nullable HttpRequest<?> request) {
try {
if (hasAtLeastTwoDots(token)) {
JWT jwt = JWTParser.parse(token);
return validate(jwt, request);
} else {
if (LOG.isTraceEnabled()) {
LOG.trace("token {} does not contain two dots", token);
}
}
} catch (final ParseException e) {
if (LOG.isTraceEnabled()) {
LOG.trace("token {} does not contain two dots", token);
LOG.trace("Failed to parse JWT: {}", e.getMessage());
}
}
} catch (final ParseException e) {
if (LOG.isTraceEnabled()) {
LOG.trace("Failed to parse JWT: {}", e.getMessage());
}
}
return Optional.empty();
return Optional.empty();
}

/**
Expand All @@ -85,11 +108,23 @@ private boolean hasAtLeastTwoDots(String token) {

/**
* Validates the supplied token with any configurations and claim validators present.
*
* @deprecated Use {@link JwtValidator#validate(JWT, HttpRequest)} instead
* @param token The JWT token
* @return An optional JWT token if validation succeeds
*/
@Deprecated
public Optional<JWT> validate(JWT token) {
return validate(token, null);
}

/**
* Validates the supplied token with any configurations and claim validators present.
*
* @param token The JWT token
* @param request The HTTP Request which contained the JWT token
* @return An optional JWT token if validation succeeds
*/
public Optional<JWT> validate(@NonNull JWT token, @Nullable HttpRequest<?> request) {
Optional<JWT> validationResult;
if (token instanceof PlainJWT) {
validationResult = validate((PlainJWT) token);
Expand All @@ -106,7 +141,7 @@ public Optional<JWT> validate(JWT token) {
return validationResult.filter(jwt -> {
try {
JwtClaims claims = new JwtClaimsSetAdapter(jwt.getJWTClaimsSet());
return claimsValidators.stream().allMatch(validator -> validator.validate(claims));
return claimsValidators.stream().allMatch(validator -> validator.validate(claims, request));
} catch (ParseException e) {
if (LOG.isErrorEnabled()) {
LOG.error("Failed to retrieve the claims set", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.micronaut.security.token.jwt.generator.claims.JwtClaims;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Singleton;

/**
Expand Down Expand Up @@ -51,8 +52,9 @@ public boolean validate(JWTClaimsSet claimsSet) {
return hasSubject;
}

@Deprecated
@Override
public boolean validate(JwtClaims claims) {
return validate(JWTClaimsSetUtils.jwtClaimsSetFromClaims(claims));
return validate(JWTClaimsSetUtils.jwtClaimsSetFromClaims(claims));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import io.micronaut.security.token.jwt.validator.JwtTokenValidator
import io.micronaut.security.token.validator.TokenValidator
import io.micronaut.testutils.EmbeddedServerSpecification
import io.reactivex.Flowable
import spock.lang.AutoCleanup
import spock.lang.Shared

class JwtClaimsOverrideSpec extends EmbeddedServerSpecification {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package io.micronaut.security.token.jwt.validator

import edu.umd.cs.findbugs.annotations.NonNull
import edu.umd.cs.findbugs.annotations.Nullable
import io.micronaut.context.annotation.Requires
import io.micronaut.http.HttpHeaders
import io.micronaut.http.HttpRequest
import io.micronaut.http.HttpResponse
import io.micronaut.http.HttpStatus
import io.micronaut.http.MediaType
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Get
import io.micronaut.http.annotation.Produces
import io.micronaut.security.annotation.Secured
import io.micronaut.security.authentication.AuthenticationException
import io.micronaut.security.authentication.AuthenticationFailed
import io.micronaut.security.authentication.AuthenticationProvider
import io.micronaut.security.authentication.AuthenticationRequest
import io.micronaut.security.authentication.AuthenticationResponse
import io.micronaut.security.authentication.UserDetails
import io.micronaut.security.authentication.UsernamePasswordCredentials
import io.micronaut.security.rules.SecurityRule
import io.micronaut.security.token.jwt.generator.claims.JwtClaims
import io.micronaut.security.token.jwt.render.BearerAccessRefreshToken
import io.micronaut.testutils.EmbeddedServerSpecification
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import org.reactivestreams.Publisher

import javax.inject.Singleton
import java.security.Principal

class JwtClaimsValidatorRequestNotPassedByDefaultSpec extends EmbeddedServerSpecification {
@Override
String getSpecName() {
'JwtClaimsValidatorRequestNotPassedByDefaultSpec'
}

@Override
Map<String, Object> getConfiguration() {
super.configuration + [
'micronaut.security.token.jwt.signatures.secret.generator.secret': 'pleaseChangeThisSecretForANewOne',
'micronaut.security.authentication' : 'bearer',
]
}

def "by default JwtClaimsValidator which expects request is not invoked"() {
when:
UsernamePasswordCredentials creds = new UsernamePasswordCredentials('user', 'password')
HttpResponse rsp = client.exchange(HttpRequest.POST('/login', creds), BearerAccessRefreshToken)

then:
rsp.status() == HttpStatus.OK
rsp.body().accessToken

when:
final String accessToken = rsp.body().accessToken
HttpRequest request = HttpRequest.GET("/echo/user")
.accept(MediaType.TEXT_PLAIN)
.header(HttpHeaders.AUTHORIZATION, "Bearer $accessToken")
client.exchange(request)

then: // no 401 is thrown because GenericJwtClaimsValidator::validate is invoked claims, null and the HttpRequestClaimsValidator returns true
noExceptionThrown()
}

@Requires(property = 'spec.name', value = 'JwtClaimsValidatorRequestNotPassedByDefaultSpec')
@Singleton
static class HttpRequestClaimsValidator implements GenericJwtClaimsValidator {

@Override
boolean validate(JwtClaims claims) {
false
}

@Override
boolean validate(@NonNull JwtClaims claims, @Nullable HttpRequest<?> request) {
request == null
}
}

@Controller("/echo/user")
@Requires(property = 'spec.name', value = 'JwtClaimsValidatorRequestNotPassedByDefaultSpec')
static class EchoController {

@Secured(SecurityRule.IS_AUTHENTICATED)
@Produces(MediaType.TEXT_PLAIN)
@Get
String index(Principal principal) {
principal.name
}
}

@Singleton
@Requires(property = 'spec.name', value = 'JwtClaimsValidatorRequestNotPassedByDefaultSpec')
static class AuthenticationProviderUserPassword implements AuthenticationProvider {

@Override
Publisher<AuthenticationResponse> authenticate(HttpRequest<?> httpRequest, AuthenticationRequest<?, ?> authenticationRequest) {
Flowable.create({ emitter ->
if (authenticationRequest.identity == 'user' && authenticationRequest.secret == 'password') {
emitter.onNext(new UserDetails('user', []))
emitter.onComplete()
} else {
emitter.onError(new AuthenticationException(new AuthenticationFailed()))
}

}, BackpressureStrategy.ERROR)
}
}
}

0 comments on commit 797decf

Please sign in to comment.