diff --git a/core/src/main/scala/de/lhns/jwt/Jwt.scala b/core/src/main/scala/de/lhns/jwt/Jwt.scala index 1942e68..12466cf 100644 --- a/core/src/main/scala/de/lhns/jwt/Jwt.scala +++ b/core/src/main/scala/de/lhns/jwt/Jwt.scala @@ -1,65 +1,38 @@ package de.lhns.jwt -import cats.syntax.bifunctor._ import cats.syntax.either._ import de.lhns.jwt.Jwt._ -import de.lhns.jwt.Jwt.SignedJwt.VerifyPartiallyApplied import io.circe.syntax._ import io.circe.{Codec, Decoder, Encoder, Json} -import scodec.bits.Bases.Alphabets.Base64Url -import scodec.bits.ByteVector import java.nio.charset.StandardCharsets import java.time.Instant -import java.util.Base64 import scala.collection.immutable.ListMap -case class Jwt( - header: JwtHeader, - payload: JwtPayload, - headerBase64: Option[String], - payloadBase64: Option[String] - ) { - def copy( - header: JwtHeader = header, - payload: JwtPayload = payload, - headerBase64: Option[String] = None, - payloadBase64: Option[String] = None - ): Jwt = new Jwt( - header = header, - payload = payload, - headerBase64 = headerBase64, - payloadBase64 = payloadBase64 - ) - +final case class Jwt( + header: JwtHeader, + payload: JwtPayload + ) { def withHeader(header: JwtHeader): Jwt = copy(header = header) def withPayload(payload: JwtPayload): Jwt = copy(payload = payload) - def changeHeader(f: JwtHeader => JwtHeader): Jwt = withHeader(f(header)) + def modifyHeader(f: JwtHeader => JwtHeader): Jwt = withHeader(f(header)) + + def modifyPayload(f: JwtPayload => JwtPayload): Jwt = withPayload(f(payload)) - def changePayload(f: JwtPayload => JwtPayload): Jwt = withPayload(f(payload)) + def reencode: Jwt = copy( + header = header.reencode, + payload = payload.reencode + ) - def encode: String = List[String]( - headerBase64.getOrElse(header.encode), - payloadBase64.getOrElse(payload.encode) - ).mkString(".") + def encode: String = s"${header.encode}.${payload.encode}" def sign[F[_]]: SignPartiallyApplied[F] = new SignPartiallyApplied[F](this) } object Jwt { - def apply( - header: JwtHeader, - payload: JwtPayload - ): Jwt = new Jwt( - header = header, - payload = payload, - headerBase64 = None, - payloadBase64 = None - ) - class SignPartiallyApplied[F[_]](jwt: Jwt) { def apply[Algorithm <: JwtAlgorithm, Key]( algorithm: Algorithm, @@ -67,54 +40,97 @@ object Jwt { )( implicit signer: JwtSigner[F, Algorithm, Key] ): F[SignedJwt] = - signer.sign(jwt.changeHeader(_.withAlgorithm(Some(algorithm))), algorithm, key) + signer.sign(jwt.modifyHeader(_.withAlgorithm(Some(algorithm))), algorithm, key) } - trait JwtComponent[Self <: JwtComponent[Self]] { + trait JwtComponent { + type Self <: JwtComponent + def claims: ListMap[String, Json] def withClaims(claims: ListMap[String, Json]): Self - def claim[A: Decoder](name: String): Option[A] = + final def modifyClaims(f: ListMap[String, Json] => ListMap[String, Json]): Self = + withClaims(f(claims)) + + final def claim[A: Decoder](name: String): Option[A] = claims.get(name).map(_.as[A].toTry.get) - def withClaim[A: Encoder](name: String, valueOption: Option[A]): Self = + final def withClaim[A: Encoder](name: String, valueOption: Option[A]): Self = withClaims(valueOption.fold( claims.filterNot(_._1 == name) )(value => claims.updated(name, value.asJson) )) - implicit val codec: Codec[Jwt] = Codec.from( - Decoder[String].emapTry(Jwt.decode(_).toTry), - Encoder[String].contramap(_.encode) - ) - } - - case class JwtHeader private(claims: ListMap[String, Json]) extends JwtComponent[JwtHeader] { - def copy(claims: ListMap[String, Json] = claims): JwtHeader = new JwtHeader(JwtHeader.defaultClaims ++ claims) + def reencode: Self - override def withClaims(claims: ListMap[String, Json]): JwtHeader = copy(claims = claims) + def encode: String + } + trait JwtHeaderClaims extends JwtComponent { private[Jwt] lazy val typOption = claim[String]("typ") def typ: String = typOption.get lazy val algorithm: Option[JwtAlgorithm] = claim[String]("alg").flatMap(JwtAlgorithm.fromString) + def withAlgorithm(algorithm: Option[JwtAlgorithm]): Self = withClaim("alg", Some(algorithm.fold("none")(_.name))) + lazy val contentType: Option[String] = claim[String]("cty") + def withContentType(contentType: Option[String]): Self = withClaim("cty", contentType) + lazy val keyId: Option[String] = claim[String]("kid") - def withAlgorithm(algorithm: Option[JwtAlgorithm]): JwtHeader = withClaim("alg", Some(algorithm.fold("none")(_.name))) + def withKeyId(keyId: Option[String]): Self = withClaim("kid", keyId) + + // https://www.rfc-editor.org/rfc/rfc7515 + + lazy val x509Url: Option[String] = claim[String]("x5u") + + def withX509Url(x509Url: Option[String]): Self = withClaim("x5u", x509Url) + + lazy val x509CertificateChain: Option[Seq[Array[Byte]]] = + claim[Seq[String]]("x5c").map(_.map(decodeBase64(_).valueOr(throw _))) + + def withX509CertificateChain(x509CertificateChain: Option[Seq[Array[Byte]]]): Self = + withClaim("x5c", x509CertificateChain.map(_.map(encodeBase64Padded))) + + lazy val x509CertificateSha1Thumbprint: Option[Array[Byte]] = + claim[String]("x5t").map(decodeBase64Url(_).valueOr(throw _)) + + def withX509CertificateSha1Thumbprint(x509CertificateSha1Thumbprint: Option[Array[Byte]]): Self = + withClaim("x5t", x509CertificateSha1Thumbprint.map(encodeBase64Url)) - def withContentType(contentType: Option[String]): JwtHeader = withClaim("cty", contentType) + lazy val x509CertificateSha256Thumbprint: Option[Array[Byte]] = + claim[String]("x5t#S256").map(decodeBase64Url(_).valueOr(throw _)) - def withKeyId(keyId: Option[String]): JwtHeader = withClaim("kid", keyId) + def withX509CertificateSha256Thumbprint(x509CertificateSha256Thumbprint: Option[Array[Byte]]): Self = + withClaim("x5t#S256", x509CertificateSha256Thumbprint.map(encodeBase64Url)) + } + + case class JwtHeader private( + claims: ListMap[String, Json], + encoded: Option[String] + ) extends JwtComponent with JwtHeaderClaims { + override type Self = JwtHeader + + override def withClaims(claims: ListMap[String, Json]): JwtHeader = JwtHeader(claims) + + private[Jwt] def withEncoded(encoded: String): JwtHeader = new JwtHeader(claims, Some(encoded)) + + @deprecated + private def copy(claims: ListMap[String, Json], encoded: Option[String]): JwtHeader = + throw new UnsupportedOperationException() + + def copy(claims: ListMap[String, Json] = claims): JwtHeader = withClaims(claims) - private lazy val normalizedClaims: ListMap[String, Json] = JwtHeader.defaultClaims ++ claims + override def reencode: JwtHeader = JwtHeader(claims) - def encode: String = encodeBase64Url((this: JwtHeader).asJson.noSpaces.getBytes(StandardCharsets.UTF_8)) + override def encode: String = encoded.getOrElse( + encodeBase64Url((this: JwtHeader).asJson.noSpaces.getBytes(StandardCharsets.UTF_8)) + ) } object JwtHeader { @@ -123,156 +139,94 @@ object Jwt { "alg" -> Json.fromString("none") ) - def apply(claims: ListMap[String, Json] = ListMap.empty): JwtHeader = - new JwtHeader(defaultClaims ++ claims) + private def normalizeClaims(claims: ListMap[String, Json]): ListMap[String, Json] = + defaultClaims.foldRight(claims) { case (defaultEntry@(defaultKey, _), claims) => + if (!claims.contains(defaultKey)) ListMap(defaultEntry) ++ claims + else claims + } - def apply(algorithm: Option[JwtAlgorithm]): JwtHeader = - JwtHeader().withAlgorithm(algorithm) + def apply(claims: ListMap[String, Json] = ListMap.empty): JwtHeader = + new JwtHeader(normalizeClaims(claims), None) implicit val codec: Codec[JwtHeader] = Codec.from( - Decoder[ListMap[String, Json]].map(new JwtHeader(_)), - Encoder[ListMap[String, Json]].contramap(_.normalizedClaims) + Decoder[ListMap[String, Json]].map(new JwtHeader(_, None)), + Encoder[ListMap[String, Json]].contramap(_.claims) ) } - case class JwtPayload(claims: ListMap[String, Json] = ListMap.empty) extends JwtComponent[JwtPayload] { - override def withClaims(claims: ListMap[String, Json]): JwtPayload = copy(claims = claims) - + trait JwtPayloadClaims extends JwtComponent { lazy val issuer: Option[String] = claim[String]("iss") + def withIssuer(issuer: Option[String]): Self = withClaim("iss", issuer) + lazy val subject: Option[String] = claim[String]("sub") - lazy val audience: Option[String] = claim[String]("aud") + def withSubject(subject: Option[String]): Self = withClaim("sub", subject) - lazy val expiration: Option[Instant] = claim[Long]("exp").map(Instant.ofEpochSecond) + lazy val audience: Option[String] = claim[String]("aud") - lazy val notBefore: Option[Instant] = claim[Long]("nbf").map(Instant.ofEpochSecond) + def withAudience(audience: Option[String]): Self = withClaim("aud", audience) - lazy val issuedAt: Option[Instant] = claim[Long]("iat").map(Instant.ofEpochSecond) + lazy val expiration: Option[Instant] = claim[Long]("exp").map(Instant.ofEpochSecond) - lazy val jwtId: Option[String] = claim[String]("jti") + def withExpiration(expiration: Option[Instant]): Self = withClaim("exp", expiration.map(_.getEpochSecond)) - def withIssuer(issuer: Option[String]): JwtPayload = withClaim("iss", issuer) + lazy val notBefore: Option[Instant] = claim[Long]("nbf").map(Instant.ofEpochSecond) - def withSubject(subject: Option[String]): JwtPayload = withClaim("sub", subject) + def withNotBefore(notBefore: Option[Instant]): Self = withClaim("nbf", notBefore.map(_.getEpochSecond)) - def withAudience(audience: Option[String]): JwtPayload = withClaim("aud", audience) + lazy val issuedAt: Option[Instant] = claim[Long]("iat").map(Instant.ofEpochSecond) - def withExpiration(expiration: Option[Instant]): JwtPayload = withClaim("exp", expiration.map(_.getEpochSecond)) + def withIssuedAt(issuedAt: Option[Instant]): Self = withClaim("iat", issuedAt.map(_.getEpochSecond)) - def withNotBefore(notBefore: Option[Instant]): JwtPayload = withClaim("nbf", notBefore.map(_.getEpochSecond)) + lazy val jwtId: Option[String] = claim[String]("jti") - def withIssuedAt(issuedAt: Option[Instant]): JwtPayload = withClaim("iat", issuedAt.map(_.getEpochSecond)) + def withJwtId(jwtId: Option[String]): Self = withClaim("jti", jwtId) - def withJwtId(jwtId: Option[String]): JwtPayload = withClaim("jti", jwtId) + // https://www.rfc-editor.org/rfc/rfc8693.html - def encode: String = encodeBase64Url((this: JwtPayload).asJson.noSpaces.getBytes(StandardCharsets.UTF_8)) - } + lazy val actor: Option[JwtPayload] = claim[JwtPayload]("act") - object JwtPayload { - implicit val codec: Codec[JwtPayload] = Codec.from( - Decoder[ListMap[String, Json]].map(JwtPayload(_)), - Encoder[ListMap[String, Json]].contramap(_.claims) - ) + def withActor(actor: Option[JwtPayload]): Self = withClaim("act", actor) } - case class SignedJwt( - jwt: Jwt, - signature: Array[Byte] - ) { - def header: JwtHeader = jwt.header - - def payload: JwtPayload = jwt.payload - - def copy( - header: JwtHeader = jwt.header, - payload: JwtPayload = jwt.payload, - signature: Array[Byte] = signature - ): SignedJwt = SignedJwt( - header = header, - payload = payload, - signature = signature - ) - - def withHeader(header: JwtHeader): SignedJwt = copy(header = header) + case class JwtPayload private( + claims: ListMap[String, Json], + encoded: Option[String] + ) extends JwtComponent with JwtPayloadClaims { + override type Self = JwtPayload - def withPayload(payload: JwtPayload): SignedJwt = copy(payload = payload) + override def withClaims(claims: ListMap[String, Json]): JwtPayload = JwtPayload(claims) - def withSignature(signature: Array[Byte]): SignedJwt = copy(signature = signature) + private[Jwt] def withEncoded(encoded: String): JwtPayload = new JwtPayload(claims, Some(encoded)) - def changeHeader(f: JwtHeader => JwtHeader): SignedJwt = withHeader(f(header)) + @deprecated + private def copy(claims: ListMap[String, Json], encoded: Option[String]): JwtPayload = + throw new UnsupportedOperationException() - def changePayload(f: JwtPayload => JwtPayload): SignedJwt = withPayload(f(payload)) + def copy(claims: ListMap[String, Json] = claims): JwtPayload = withClaims(claims) - def changeSignature(f: Array[Byte] => Array[Byte]): SignedJwt = withSignature(f(signature)) + override def reencode: JwtPayload = JwtPayload(claims) - def encode: String = List[String](jwt.encode, encodeBase64Url(signature)).mkString(".") - - def verify[F[_]]: VerifyPartiallyApplied[F] = - new VerifyPartiallyApplied[F](this) - } - - object SignedJwt { - class VerifyPartiallyApplied[F[_]](jwt: SignedJwt) { - def apply[Algorithm <: JwtAlgorithm, Key]( - algorithm: Algorithm, - key: Key, - options: JwtValidationOptions = JwtValidationOptions.default - )( - implicit verifier: JwtVerifier[F, Algorithm, Key] - ): F[Either[Throwable, Jwt]] = - verifier.verify(jwt.changeHeader(_.withAlgorithm(Some(algorithm))), algorithm, key, options) - } - - def apply( - header: JwtHeader, - payload: JwtPayload, - signature: Array[Byte] - ): SignedJwt = SignedJwt( - jwt = Jwt( - header, - payload - ), - signature = signature - ) - - def decode(string: String): Either[Throwable, SignedJwt] = { - string.split('.').toList match { - case headerBase64 +: payloadBase64 +: signatureBase64 +: Nil => - for { - jwt <- Jwt.decodeComponents(headerBase64, payloadBase64) - signature <- decodeBase64Url(signatureBase64) - } yield SignedJwt( - jwt = jwt, - signature = signature - ) - - case _ => - Left(new IllegalArgumentException("must be of format
..")) - } - } - - implicit val codec: Codec[SignedJwt] = Codec.from( - Decoder[String].emapTry(SignedJwt.decode(_).toTry), - Encoder[String].contramap(_.encode) + def encode: String = encoded.getOrElse( + encodeBase64Url((this: JwtPayload).asJson.noSpaces.getBytes(StandardCharsets.UTF_8)) ) } - def decode(string: String): Either[Throwable, Jwt] = { - string.split('.').toList match { - case headerBase64 +: payloadBase64 +: Nil => - decodeComponents(headerBase64, payloadBase64) + object JwtPayload { + def apply(claims: ListMap[String, Json] = ListMap.empty): JwtPayload = + new JwtPayload(claims, None) - case _ => - Left(new IllegalArgumentException("must be of format
.")) - } + implicit val codec: Codec[JwtPayload] = Codec.from( + Decoder[ListMap[String, Json]].map(new JwtPayload(_, None)), + Encoder[ListMap[String, Json]].contramap(_.claims) + ) } - private def decodeBase64Url(base64: String): Either[IllegalArgumentException, Array[Byte]] = - Either.catchOnly[IllegalArgumentException](Base64.getUrlDecoder.decode(base64)) - - private def encodeBase64Url(bytes: Array[Byte]): String = - Base64.getUrlEncoder.withoutPadding.encodeToString(bytes) + implicit val codec: Codec[Jwt] = Codec.from( + Decoder[String].emapTry(Jwt.decode(_).toTry), + Encoder[String].contramap(_.encode) + ) def decodeComponents(headerBase64: String, payloadBase64: String): Either[Throwable, Jwt] = for { @@ -284,9 +238,17 @@ object Jwt { payload <- io.circe.parser.decode[JwtPayload](payloadString) _ <- header.typOption.filter(_ == "JWT").toRight(new IllegalArgumentException("typ must be `JWT`")) } yield Jwt( - header = header, - payload = payload, - headerBase64 = Some(headerBase64), - payloadBase64 = Some(payloadBase64) + header = header.withEncoded(headerBase64), + payload = payload.withEncoded(payloadBase64) ) + + def decode(string: String): Either[Throwable, Jwt] = { + string.split('.').toList match { + case headerBase64 +: payloadBase64 +: Nil => + decodeComponents(headerBase64, payloadBase64) + + case _ => + Left(new IllegalArgumentException("must be of format
.")) + } + } } diff --git a/core/src/main/scala/de/lhns/jwt/JwtSigner.scala b/core/src/main/scala/de/lhns/jwt/JwtSigner.scala index 2710e0a..a3b7cac 100644 --- a/core/src/main/scala/de/lhns/jwt/JwtSigner.scala +++ b/core/src/main/scala/de/lhns/jwt/JwtSigner.scala @@ -1,7 +1,6 @@ package de.lhns.jwt import cats.effect.kernel.Async -import de.lhns.jwt.Jwt.SignedJwt import de.lhns.jwt.JwtAlgorithm.{JwtAsymmetricAlgorithm, JwtHmacAlgorithm} import pdi.jwt.JwtUtils diff --git a/core/src/main/scala/de/lhns/jwt/JwtVerifier.scala b/core/src/main/scala/de/lhns/jwt/JwtVerifier.scala index ae9004f..8c70f5c 100644 --- a/core/src/main/scala/de/lhns/jwt/JwtVerifier.scala +++ b/core/src/main/scala/de/lhns/jwt/JwtVerifier.scala @@ -1,7 +1,6 @@ package de.lhns.jwt import cats.effect.Async -import de.lhns.jwt.Jwt.SignedJwt import de.lhns.jwt.JwtAlgorithm.{JwtAsymmetricAlgorithm, JwtHmacAlgorithm} import pdi.jwt.JwtUtils diff --git a/core/src/main/scala/de/lhns/jwt/SignedJwt.scala b/core/src/main/scala/de/lhns/jwt/SignedJwt.scala new file mode 100644 index 0000000..3bc7c50 --- /dev/null +++ b/core/src/main/scala/de/lhns/jwt/SignedJwt.scala @@ -0,0 +1,96 @@ +package de.lhns.jwt + +import de.lhns.jwt.Jwt.{JwtHeader, JwtPayload} +import de.lhns.jwt.SignedJwt.VerifyPartiallyApplied +import io.circe.{Codec, Decoder, Encoder} + +final case class SignedJwt( + jwt: Jwt, + signature: Array[Byte] + ) { + def header: JwtHeader = jwt.header + + def payload: JwtPayload = jwt.payload + + @deprecated + private def copy(jwt: Jwt, signature: Array[Byte]): SignedJwt = + throw new UnsupportedOperationException() + + def copy( + header: JwtHeader = jwt.header, + payload: JwtPayload = jwt.payload, + signature: Array[Byte] = signature + ): SignedJwt = SignedJwt( + header = header, + payload = payload, + signature = signature + ) + + def withHeader(header: JwtHeader): SignedJwt = copy(header = header) + + def withPayload(payload: JwtPayload): SignedJwt = copy(payload = payload) + + def withSignature(signature: Array[Byte]): SignedJwt = copy(signature = signature) + + def modifyHeader(f: JwtHeader => JwtHeader): SignedJwt = withHeader(f(header)) + + def modifyPayload(f: JwtPayload => JwtPayload): SignedJwt = withPayload(f(payload)) + + def modifySignature(f: Array[Byte] => Array[Byte]): SignedJwt = withSignature(f(signature)) + + def reencode: SignedJwt = copy( + header = header.reencode, + payload = payload.reencode + ) + + def encode: String = s"${jwt.encode}.${encodeBase64Url(signature)}" + + def verify[F[_]]: VerifyPartiallyApplied[F] = + new VerifyPartiallyApplied[F](this) +} + +object SignedJwt { + class VerifyPartiallyApplied[F[_]](jwt: SignedJwt) { + def apply[Algorithm <: JwtAlgorithm, Key]( + algorithm: Algorithm, + key: Key, + options: JwtValidationOptions = JwtValidationOptions.default + )( + implicit verifier: JwtVerifier[F, Algorithm, Key] + ): F[Either[Throwable, Jwt]] = + verifier.verify(jwt.modifyHeader(_.withAlgorithm(Some(algorithm))), algorithm, key, options) + } + + def apply( + header: JwtHeader, + payload: JwtPayload, + signature: Array[Byte] + ): SignedJwt = SignedJwt( + jwt = Jwt( + header, + payload + ), + signature = signature + ) + + def decode(string: String): Either[Throwable, SignedJwt] = { + string.split('.').toList match { + case headerBase64 +: payloadBase64 +: signatureBase64 +: Nil => + for { + jwt <- Jwt.decodeComponents(headerBase64, payloadBase64) + signature <- decodeBase64Url(signatureBase64) + } yield SignedJwt( + jwt = jwt, + signature = signature + ) + + case _ => + Left(new IllegalArgumentException("must be of format
..")) + } + } + + implicit val codec: Codec[SignedJwt] = Codec.from( + Decoder[String].emapTry(SignedJwt.decode(_).toTry), + Encoder[String].contramap(_.encode) + ) +} \ No newline at end of file diff --git a/core/src/main/scala/de/lhns/jwt/package.scala b/core/src/main/scala/de/lhns/jwt/package.scala new file mode 100644 index 0000000..d0e9ef8 --- /dev/null +++ b/core/src/main/scala/de/lhns/jwt/package.scala @@ -0,0 +1,19 @@ +package de.lhns + +import cats.syntax.all._ + +import java.util.Base64 + +package object jwt { + private[jwt] def decodeBase64(base64: String): Either[IllegalArgumentException, Array[Byte]] = + Either.catchOnly[IllegalArgumentException](Base64.getDecoder.decode(base64)) + + private[jwt] def encodeBase64Padded(bytes: Array[Byte]): String = + Base64.getEncoder.encodeToString(bytes) + + private[jwt] def decodeBase64Url(base64: String): Either[IllegalArgumentException, Array[Byte]] = + Either.catchOnly[IllegalArgumentException](Base64.getUrlDecoder.decode(base64)) + + private[jwt] def encodeBase64Url(bytes: Array[Byte]): String = + Base64.getUrlEncoder.withoutPadding.encodeToString(bytes) +}