diff --git a/silhouette/app/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticator.scala b/silhouette/app/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticator.scala index c299c19ae..d5eda9854 100644 --- a/silhouette/app/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticator.scala +++ b/silhouette/app/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticator.scala @@ -31,9 +31,11 @@ import com.nimbusds.jwt.JWTClaimsSet import org.joda.time.DateTime import play.api.libs.Crypto import play.api.libs.concurrent.Execution.Implicits._ -import play.api.libs.json.Json +import play.api.libs.json._ import play.api.mvc.{ RequestHeader, Result } +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.util.{ Failure, Success, Try } @@ -55,13 +57,15 @@ import scala.util.{ Failure, Success, Try } * @param lastUsedDate The last used timestamp. * @param expirationDate The expiration time. * @param idleTimeout The time in seconds an authenticator can be idle before it timed out. + * @param customClaims Custom claims to embed into the token. */ case class JWTAuthenticator( id: String, loginInfo: LoginInfo, lastUsedDate: DateTime, expirationDate: DateTime, - idleTimeout: Option[Int]) extends StorableAuthenticator { + idleTimeout: Option[Int], + customClaims: Option[JsObject] = None) extends StorableAuthenticator { /** * The Type of the generated value an authenticator will be serialized to. @@ -282,6 +286,16 @@ class JWTAuthenticatorService( .issuedAt(authenticator.lastUsedDate.getMillis / 1000) .expirationTime(authenticator.expirationDate.getMillis / 1000) + authenticator.customClaims.map { data => + serializeCustomClaims(data).foreach { + case (key, value) => + if (ReservedClaims.contains(key)) { + throw new AuthenticationException(OverrideReservedClaim.format(ID, key, ReservedClaims.mkString(", "))) + } + jwtBuilder.claim(key, value) + } + } + new NimbusJwtWriterFactory() .macSigningWriter(SigningAlgorithm.HS256, settings.sharedSecret) .jsonToJwt(jwtBuilder.build()) @@ -305,12 +319,15 @@ class JWTAuthenticatorService( }.flatMap { c => val subject = if (settings.encryptSubject) Crypto.decryptAES(c.getSubject) else Base64.decode(c.getSubject) buildLoginInfo(subject).map { loginInfo => + val filteredClaims = c.getAllClaims.asScala.filterNot { case (k, v) => ReservedClaims.contains(k) || v == null } + val customClaims = unserializeCustomClaims(filteredClaims) JWTAuthenticator( id = c.getJWTID, loginInfo = loginInfo, lastUsedDate = new DateTime(c.getIssueTime), expirationDate = new DateTime(c.getExpirationTime), - idleTimeout = settings.authenticatorIdleTimeout + idleTimeout = settings.authenticatorIdleTimeout, + customClaims = if (customClaims.keys.isEmpty) None else Some(customClaims) ) } }.recover { @@ -318,6 +335,44 @@ class JWTAuthenticatorService( } } + /** + * Serializes recursively the custom claims. + * + * @param claims The custom claims to serialize. + * @return A map containing custom claims. + */ + private def serializeCustomClaims(claims: JsObject): java.util.Map[String, Any] = { + def toJava(value: JsValue): Any = value match { + case v: JsString => v.value + case v: JsNumber => v.value + case v: JsBoolean => v.value + case v: JsObject => serializeCustomClaims(v) + case v: JsArray => v.value.map(toJava).asJava + case v => throw new AuthenticationException(UnexpectedJsonValue.format(ID, v)) + } + + claims.fieldSet.map { case (name, value) => name -> toJava(value) }.toMap.asJava + } + + /** + * Unserializes recursively the custom claims. + * + * @param claims The custom claims to deserialize. + * @return A Json object representing the custom claims. + */ + private def unserializeCustomClaims(claims: java.util.Map[String, Any]): JsObject = { + def toJson(value: Any): JsValue = value match { + case v: java.lang.String => JsString(v) + case v: java.lang.Number => JsNumber(BigDecimal(v.toString)) + case v: java.lang.Boolean => JsBoolean(v) + case v: java.util.Map[_, _] => unserializeCustomClaims(v.asInstanceOf[java.util.Map[String, Any]]) + case v: java.util.List[_] => JsArray(v.map(toJson)) + case v => throw new AuthenticationException(UnexpectedJsonValue.format(ID, v)) + } + + JsObject(claims.map { case (name, value) => name -> toJson(value) }.toSeq) + } + /** * Builds the login info from Json. * @@ -356,6 +411,13 @@ object JWTAuthenticatorService { val InvalidJWTToken = "[Silhouette][%s] Error on parsing JWT token: %s" val JsonParseError = "[Silhouette][%s] Cannot parse Json: %s" val InvalidJsonFormat = "[Silhouette][%s] Invalid Json format: %s" + val UnexpectedJsonValue = "[Silhouette][%s] Unexpected Json value: %s" + val OverrideReservedClaim = "[Silhouette][%s] Try to overriding a reserved claim `%s`; list of reserved claims: %s" + + /** + * The reserved claims used by the authenticator. + */ + val ReservedClaims = Seq("jti", "iss", "sub", "iat", "exp") } /** diff --git a/silhouette/test/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticatorSpec.scala b/silhouette/test/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticatorSpec.scala index eb626801a..06f7dd128 100644 --- a/silhouette/test/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticatorSpec.scala +++ b/silhouette/test/com/mohiva/play/silhouette/impl/authenticators/JWTAuthenticatorSpec.scala @@ -28,7 +28,7 @@ import org.specs2.matcher.JsonMatchers import org.specs2.mock.Mockito import org.specs2.specification.Scope import play.api.libs.Crypto -import play.api.libs.json.Json +import play.api.libs.json.{ JsNull, Json } import play.api.mvc.Results import play.api.test.{ WithApplication, FakeRequest, PlaySpecification } @@ -97,6 +97,42 @@ class JWTAuthenticatorSpec extends PlaySpecification with Mockito with JsonMatch json must /("iat" -> authenticator.lastUsedDate.getMillis / 1000) } + + "throw an AuthenticationException if a reserved claim will be overriden" in new WithApplication with Context { + val claims = Json.obj( + "jti" -> "reserved" + ) + + service(None).serialize(authenticator.copy(customClaims = Some(claims))) must throwA[AuthenticationException].like { + case e => e.getMessage must startWith(OverrideReservedClaim.format(ID, "jti", "")) + } + } + + "throw an AuthenticationException if an unexpected value was found in the arbitrary claims" in new WithApplication with Context { + val claims = Json.obj( + "null" -> JsNull + ) + + service(None).serialize(authenticator.copy(customClaims = Some(claims))) must throwA[AuthenticationException].like { + case e => e.getMessage must startWith(UnexpectedJsonValue.format(ID, "")) + } + } + + "return a JWT with arbitrary claims" in new WithApplication with Context { + val jwt = service(None).serialize(authenticator.copy(customClaims = Some(customClaims))) + val json = Base64.decode(jwt.split('.').apply(1)) + + json must /("boolean" -> true) + json must /("string" -> "string") + json must /("number" -> 1234567890) + json must /("array") /# 0 / 1 + json must /("array") /# 1 / 2 + json must /("object") / "array" /# 0 / "string1" + json must /("object") / "array" /# 1 / "string2" + json must /("object") / "object" / "array" /# 0 / "string" + json must /("object") / "object" / "array" /# 1 / false + json must /("object") / "object" / "array" /# 2 / ("number" -> 1) + } } "The `unserialize` method of the service" should { @@ -129,7 +165,6 @@ class JWTAuthenticatorSpec extends PlaySpecification with Mockito with JsonMatch settings.encryptSubject returns true val jwt = service(None).serialize(authenticator) - val msg = Pattern.quote(InvalidJWTToken.format(ID, jwt)) service(None).unserialize(jwt) must beSuccessfulTry.withValue(authenticator.copy( expirationDate = authenticator.expirationDate.withMillisOfSecond(0), @@ -141,13 +176,23 @@ class JWTAuthenticatorSpec extends PlaySpecification with Mockito with JsonMatch settings.encryptSubject returns false val jwt = service(None).serialize(authenticator) - val msg = Pattern.quote(InvalidJWTToken.format(ID, jwt)) service(None).unserialize(jwt) must beSuccessfulTry.withValue(authenticator.copy( expirationDate = authenticator.expirationDate.withMillisOfSecond(0), lastUsedDate = authenticator.lastUsedDate.withMillisOfSecond(0) )) } + + "unserialize a JWT with arbitrary claims" in new WithApplication with Context { + settings.encryptSubject returns false + + val jwt = service(None).serialize(authenticator.copy(customClaims = Some(customClaims))) + + service(None).unserialize(jwt) must beSuccessfulTry.like { + case a => + a.customClaims must beSome(customClaims) + } + } } "The `create` method of the service" should { @@ -540,5 +585,21 @@ class JWTAuthenticatorSpec extends PlaySpecification with Mockito with JsonMatch expirationDate = DateTime.now.plusMinutes(12 * 60), idleTimeout = settings.authenticatorIdleTimeout ) + + /** + * Some custom claims. + */ + lazy val customClaims = Json.obj( + "boolean" -> true, + "string" -> "string", + "number" -> 1234567890, + "array" -> Json.arr(1, 2), + "object" -> Json.obj( + "array" -> Seq("string1", "string2"), + "object" -> Json.obj( + "array" -> Json.arr("string", false, Json.obj("number" -> 1)) + ) + ) + ) } }