Skip to content

Commit

Permalink
#127 Cache verified keys, to improve performance.
Browse files Browse the repository at this point in the history
  • Loading branch information
dniel committed Sep 21, 2019
1 parent 6f91509 commit 387009e
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 70 deletions.
49 changes: 25 additions & 24 deletions src/main/kotlin/dniel/forwardauth/application/AuthorizeHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ class AuthorizeHandler(val properties: AuthProperties,
val method: String
) : Command


/**
* Main Handle Command method.
*/
override fun handle(params: AuthorizeCommand): List<AuthEvent> {
val context = createAuthRuleContext(params)
val rules = listOf<AuthRule>(
VerifyAllowSignInRequest(context),
VerifyRestrictedMethod(context),
VerifyHasPermission(context),
VerifyValidAccessToken(context),
VerifyValidIdToken(context),
VerifySameSubInBothTokens(context))

val events = rules.foldRight(mutableListOf<AuthEvent>()) { rule, acc ->
rule.verify(params)?.let {
acc.add(it)
}
acc
}
return events
}

/**
* This command can produce a set of events as response from the handle method.
*/
Expand Down Expand Up @@ -230,28 +253,6 @@ class AuthorizeHandler(val properties: AuthProperties,
}
}

/**
* Main Handle Command method.
*/
override fun handle(params: AuthorizeCommand): List<AuthEvent> {
val context = createAuthRuleContext(params)
val rules = listOf<AuthRule>(
VerifyAllowSignInRequest(context),
VerifyRestrictedMethod(context),
VerifyHasPermission(context),
VerifyValidAccessToken(context),
VerifyValidIdToken(context),
VerifySameSubInBothTokens(context))

val events = rules.foldRight(mutableListOf<AuthEvent>()) { rule, acc ->
rule.verify(params)?.let {
acc.add(it)
}
acc
}
return events
}

private fun createAuthRuleContext(params: AuthorizeCommand): MutableMap<String, Any> {
val app = properties.findApplicationOrDefault(params.host)
val nonce = nonceService.generate()
Expand All @@ -263,8 +264,8 @@ class AuthorizeHandler(val properties: AuthProperties,
LOGGER.debug("Authorize request=${originUrl} to app=${app.name}")
val context = emptyMap<String, Any>().toMutableMap()

val accessToken = verifyTokenService.verify(params.accessToken, app.audience, AUTH_DOMAIN)
val idToken = verifyTokenService.verify(params.idToken, app.clientId, AUTH_DOMAIN)
val accessToken = verifyTokenService.verify(params.accessToken, app.audience)
val idToken = verifyTokenService.verify(params.idToken, app.clientId)

context.put("access_token", accessToken)
context.put("id_token", idToken)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package dniel.forwardauth.domain.service

import com.auth0.jwt.exceptions.JWTVerificationException
import com.auth0.jwt.interfaces.DecodedJWT
import com.google.common.cache.CacheBuilder
import dniel.forwardauth.domain.InvalidToken
import dniel.forwardauth.domain.JwtToken
import dniel.forwardauth.domain.OpaqueToken
import dniel.forwardauth.domain.Token
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import java.util.*
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit

/**
* Interface for decoder of Jwt Tokens that is provided to the VerifyTokenService
Expand All @@ -16,14 +19,15 @@ import org.springframework.stereotype.Component
* @see dniel.forwardauth.infrastructure.auth0.Auth0JwtDecoder
*/
interface JwtDecoder {
fun verify(token: String, domain: String): DecodedJWT
fun verify(token: String): DecodedJWT
}

@Component
class VerifyTokenService(val decoder: JwtDecoder) {
private val LOGGER = LoggerFactory.getLogger(this.javaClass)
val cache = CacheBuilder.newBuilder().expireAfterAccess(15, TimeUnit.MINUTES).build<String, DecodedJWT>()

fun verify(token: String?, expectedAudience: String, expectedDomain: String): Token {
fun verify(token: String?, expectedAudience: String): Token {
return when {
// if its a null or empty string just fail fast.
token.isNullOrEmpty() -> InvalidToken("Missing token")
Expand All @@ -33,26 +37,27 @@ class VerifyTokenService(val decoder: JwtDecoder) {

else -> {
try {
val decodedJWT = decodeToken(token, expectedDomain)
if (verifyAudience(decodedJWT, expectedAudience)) {
LOGGER.debug("Verify audience failed, expected ${expectedAudience} but got ${decodedJWT.audience}")
throw IllegalStateException("Verify audience failed, expected ${expectedAudience} but got ${decodedJWT.audience}. " +
"Probably error in application configuration or Auth0 service configuration.")
} else {
JwtToken(decodedJWT)
val decodedJWT = cache.get(token, Callable<DecodedJWT> {
decodeToken(token)
})
when {
hasIllegalAudience(decodedJWT, expectedAudience) -> throw IllegalStateException("Verify audience failed, expected ${expectedAudience} but got ${decodedJWT.audience}")
hasExpired(decodedJWT) -> throw IllegalStateException("Token has expired ${decodedJWT.expiresAt}")
else -> JwtToken(decodedJWT)
}
} catch (e: JWTVerificationException) {
val reason = "Failed to decode the token: ${e.message}"
LOGGER.debug(reason)
InvalidToken(reason)
} catch (e: Exception) {
cache.invalidate(token)
InvalidToken("" + e.message)
}
}
}
}

private fun hasExpired(decodedJWT: DecodedJWT): Boolean = decodedJWT.expiresAt.before(Date())

private fun isOpaqueToken(token: String): Boolean = token.split(".").size == 0

private fun verifyAudience(decodedJWT: DecodedJWT, expectedAudience: String): Boolean = !decodedJWT.audience.contains(expectedAudience)
private fun hasIllegalAudience(decodedJWT: DecodedJWT, expectedAudience: String): Boolean = !decodedJWT.audience.contains(expectedAudience)

private fun decodeToken(token: String, domain: String): DecodedJWT = decoder.verify(token, domain)
private fun decodeToken(token: String): DecodedJWT = decoder.verify(token)
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class Auth0Client(val properties: AuthProperties) {
.asJson();
val status = response.status
val body = response.body
LOGGER.trace("Response status: ${response.status}")
LOGGER.trace("Response body: ${response.body}")
LOGGER.trace("Response status: ${status}")
LOGGER.trace("Response body: ${body}")

if (body.`object`.has("error")) {
val error = body.`object`.getString("error")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
package dniel.forwardauth.infrastructure.auth0

import com.auth0.jwk.GuavaCachedJwkProvider
import com.auth0.jwk.JwkProvider
import com.auth0.jwk.UrlJwkProvider
import com.auth0.jwt.JWT
import com.auth0.jwt.JWTVerifier
import com.auth0.jwt.algorithms.Algorithm
import com.auth0.jwt.interfaces.DecodedJWT
import com.auth0.jwt.interfaces.RSAKeyProvider
import com.google.common.cache.CacheBuilder
import dniel.forwardauth.AuthProperties
import dniel.forwardauth.domain.service.JwtDecoder
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit


@Component
class Auth0JwtDecoder : JwtDecoder {
class Auth0JwtDecoder(val properties: AuthProperties) : JwtDecoder {
val LOGGER = LoggerFactory.getLogger(this.javaClass)
var provider: JwkProvider? = null
val AUTH_DOMAIN = properties.domain
val cache = CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.HOURS).build<String, DecodedJWT>()
val provider = GuavaCachedJwkProvider(UrlJwkProvider(AUTH_DOMAIN))

override fun verify(token: String, domain: String): DecodedJWT {
this.provider = GuavaCachedJwkProvider(UrlJwkProvider(domain))
return verifyJWT(token, domain)
override fun verify(token: String): DecodedJWT {
return cache.get(token, Callable<DecodedJWT> {
verifyJWT(token, AUTH_DOMAIN)
})
}

private fun verifyJWT(token: String, domain: String): DecodedJWT {
val decodedJWT = JWT.decode(token)
val jwk = provider!!.get(decodedJWT.keyId)
val jwk = provider.get(decodedJWT.keyId)
val keyProvider = object : RSAKeyProvider {
override fun getPublicKeyById(kid: String): RSAPublicKey = jwk.publicKey as RSAPublicKey
override fun getPrivateKey(): RSAPrivateKey? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(validJwtTokenString, _, _) >> new JwtToken(jwtToken)
verifyTokenService.verify(validJwtTokenString, _) >> new JwtToken(jwtToken)

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -66,8 +66,8 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(null, _, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify("", _, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify(null, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify("", _) >> new InvalidToken("missing token return invalid token.")

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -105,7 +105,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify( _, _, _) >> new InvalidToken("simulating an invalid token resposne from the token service.")
verifyTokenService.verify( _, _) >> new InvalidToken("simulating an invalid token resposne from the token service.")

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -136,9 +136,9 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(validJwtTokenString, _, _) >> new JwtToken(jwtToken)
verifyTokenService.verify(null, _, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify("", _, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify(validJwtTokenString, _) >> new JwtToken(jwtToken)
verifyTokenService.verify(null, _) >> new InvalidToken("missing token return invalid token.")
verifyTokenService.verify("", _) >> new InvalidToken("missing token return invalid token.")

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -175,7 +175,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(_,_,_) >> new JwtToken(jwtToken)
verifyTokenService.verify(_,_) >> new JwtToken(jwtToken)

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -207,7 +207,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify( _, _, _) >> new JwtToken(jwtToken)
verifyTokenService.verify( _, _) >> new JwtToken(jwtToken)

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down Expand Up @@ -238,7 +238,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(_, _, _) >> new InvalidToken(("Just to get a redirect event to check"))
verifyTokenService.verify(_, _) >> new InvalidToken(("Just to get a redirect event to check"))

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand All @@ -265,7 +265,7 @@ class AuthorizeHandlerTest extends Specification {

and: "a stub VerifyTokenService that return a valid JWT JwtToken"
def verifyTokenService = Stub(VerifyTokenService)
verifyTokenService.verify(_, _, _) >> new InvalidToken(("Just to get a redirect event to check"))
verifyTokenService.verify(_, _) >> new InvalidToken(("Just to get a redirect event to check"))

and: "a command handler that is the system under test"
AuthorizeHandler sut = new AuthorizeHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import dniel.forwardauth.domain.InvalidToken
import dniel.forwardauth.domain.JwtToken
import spock.lang.Specification

import static org.hamcrest.Matchers.containsString
import static org.hamcrest.Matchers.instanceOf
import static org.hamcrest.Matchers.is
import static org.hamcrest.Matchers.*
import static spock.util.matcher.HamcrestSupport.that

class VerifyTokenServiceTest extends Specification {
Expand All @@ -22,14 +20,14 @@ class VerifyTokenServiceTest extends Specification {

and: "a stubbed jwt decoder"
def decoder = Stub(JwtDecoder) {
verify(_, _) >> exampleToken
verify(_) >> exampleToken
}

and: "a verification token service which is the system under test"
VerifyTokenService sut = new VerifyTokenService(decoder)

when: "we verify the token"
def verifiedToken = sut.verify(tokenString, exampleAudience, domain)
def verifiedToken = sut.verify(tokenString, exampleAudience)

then:
that(verifiedToken, is(instanceOf(JwtToken)))
Expand All @@ -44,17 +42,17 @@ class VerifyTokenServiceTest extends Specification {

and: "a stubbed jwt decoder"
def decoder = Stub(JwtDecoder) {
verify(_, _) >> exampleToken
verify(_) >> exampleToken
}

and: "a verification token service which is the system under test"
VerifyTokenService sut = new VerifyTokenService(decoder)

when: "we verify the token"
def token = sut.verify(tokenString, exampleAudience, domain)
def token = sut.verify(tokenString, exampleAudience)

then:
thrown(IllegalStateException)
that(token, is(instanceOf(InvalidToken)))
}

def "should return invalid token with reason if token fails to decode"() {
Expand All @@ -66,7 +64,7 @@ class VerifyTokenServiceTest extends Specification {

and: "a stubbed jwt decoder that throws an exception on verify"
def decoder = Stub(JwtDecoder) {
verify(_, _) >> {
verify(_) >> {
throw new JWTVerificationException("something went wrong.")
}
}
Expand All @@ -75,7 +73,7 @@ class VerifyTokenServiceTest extends Specification {
VerifyTokenService sut = new VerifyTokenService(decoder)

when: "we verify the token"
def token = sut.verify(tokenString, exampleAudience, domain)
def token = sut.verify(tokenString, exampleAudience)

then:
that(token, is(instanceOf(InvalidToken)))
Expand Down

0 comments on commit 387009e

Please sign in to comment.