diff --git a/dependencies/pom.xml b/dependencies/pom.xml index ffc9f917d2c..7dd676084df 100644 --- a/dependencies/pom.xml +++ b/dependencies/pom.xml @@ -109,7 +109,7 @@ 4.0 2.0 4.0 - 2.0 + 2.1 4.0 3.0 3.0 diff --git a/microprofile/jwt-auth/src/main/java/io/helidon/microprofile/jwt/auth/JwtAuthProvider.java b/microprofile/jwt-auth/src/main/java/io/helidon/microprofile/jwt/auth/JwtAuthProvider.java index 7ebd31cd674..897ea3630e9 100644 --- a/microprofile/jwt-auth/src/main/java/io/helidon/microprofile/jwt/auth/JwtAuthProvider.java +++ b/microprofile/jwt-auth/src/main/java/io/helidon/microprofile/jwt/auth/JwtAuthProvider.java @@ -32,6 +32,7 @@ import java.security.interfaces.ECPublicKey; import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; +import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Base64; @@ -100,9 +101,6 @@ * Provider that provides JWT authentication. */ public class JwtAuthProvider extends SynchronousProvider implements AuthenticationProvider, OutboundSecurityProvider { - private static final System.Logger LOGGER = System.getLogger(JwtAuthProvider.class.getName()); - - private static final JsonReaderFactory JSON = Json.createReaderFactory(Collections.emptyMap()); /** * Configure this for outbound requests to override user to use. @@ -116,6 +114,9 @@ public class JwtAuthProvider extends SynchronousProvider implements Authenticati * Configuration key for expected audiences of incoming tokens. Used for validation of JWT. */ public static final String CONFIG_EXPECTED_AUDIENCES = "mp.jwt.verify.audiences"; + + private static final String CONFIG_EXPECTED_MAX_TOKEN_AGE = "mp.jwt.verify.token.age"; + private static final String CONFIG_CLOCK_SKEW = "mp.jwt.verify.clock.skew"; /** * Configuration of Cookie property name which contains JWT token. * @@ -128,6 +129,8 @@ public class JwtAuthProvider extends SynchronousProvider implements Authenticati * Default value is {@link Http.Header#AUTHORIZATION}. */ private static final String CONFIG_JWT_HEADER = "mp.jwt.token.header"; + private static final System.Logger LOGGER = System.getLogger(JwtAuthProvider.class.getName()); + private static final JsonReaderFactory JSON = Json.createReaderFactory(Collections.emptyMap()); private final boolean optional; private final boolean authenticate; @@ -147,7 +150,10 @@ public class JwtAuthProvider extends SynchronousProvider implements Authenticati private final Map targetToJwtConfig = new IdentityHashMap<>(); private final String expectedIssuer; private final String cookiePrefix; + private final String decryptionKeyAlgorithm; private final boolean useCookie; + private final Duration expectedMaxTokenAge; + private final Duration clockSkew; private JwtAuthProvider(Builder builder) { this.optional = builder.optional; @@ -167,6 +173,9 @@ private JwtAuthProvider(Builder builder) { this.useCookie = builder.useCookie; this.decryptionKeys = builder.decryptionKeys; this.defaultDecryptionJwk = builder.defaultDecryptionJwk; + this.decryptionKeyAlgorithm = builder.decryptionKeyAlgorithm; + this.expectedMaxTokenAge = builder.expectedMaxTokenAge; + this.clockSkew = builder.clockSkew; if (null == atnTokenHandler) { defaultTokenHandler = TokenHandler.builder() @@ -252,7 +261,14 @@ AuthenticationResponse authenticate(ProviderRequest providerRequest, LoginConfig throw new JwtException("Header \"cty\" (content type) must be set to \"JWT\" " + "for encrypted tokens"); } - signedJwt = encryptedJwt.decrypt(decryptionKeys.get(), defaultDecryptionJwk.get()); + List> validators = new LinkedList<>(); + EncryptedJwt.addKekValidator(validators, decryptionKeyAlgorithm, true); + Errors errors = encryptedJwt.validate(validators); + if (errors.isValid()) { + signedJwt = encryptedJwt.decrypt(decryptionKeys.get(), defaultDecryptionJwk.get()); + } else { + return AuthenticationResponse.failed(errors.toString()); + } } else { signedJwt = SignedJwt.parseToken(token); } @@ -278,7 +294,13 @@ AuthenticationResponse authenticate(ProviderRequest providerRequest, LoginConfig } // validate user principal is present Jwt.addUserPrincipalValidator(validators); - validators.add(Jwt.ExpirationValidator.create(true)); + validators.add(Jwt.ExpirationValidator.create(Instant.now(), + (int) clockSkew.getSeconds(), + ChronoUnit.SECONDS, + true)); + if (expectedMaxTokenAge != null) { + Jwt.addMaxTokenAgeValidator(validators, expectedMaxTokenAge, clockSkew, true); + } Errors validate = jwt.validate(validators); @@ -579,6 +601,7 @@ public static class Builder implements io.helidon.common.Builder audiences) { return this; } + /** + * Maximal expected token age. If this value is set, {@code iat} claim needs to be present in the JWT. + * + * @param expectedMaxTokenAge expected maximal token age in seconds + * @return updated builder instance + */ + public Builder expectedMaxTokenAge(int expectedMaxTokenAge) { + this.expectedMaxTokenAge = Duration.ofSeconds(expectedMaxTokenAge); + return this; + } + /** * Private key to decryption of encrypted claims. * @@ -1205,6 +1245,17 @@ public Builder decryptKeyLocation(String decryptKeyLocation) { return this; } + /** + * Expected decryption key algorithm. + * + * @param decryptionKeyAlgorithm expected decryption key algorithm + * @return updated builder instance + */ + public Builder decryptKeyAlgorithm(String decryptionKeyAlgorithm) { + this.decryptionKeyAlgorithm = decryptionKeyAlgorithm; + return this; + } + /** * Whether to load JWK verification keys on server startup * Default value is {@code false}. @@ -1217,6 +1268,17 @@ public Builder loadOnStartup(boolean loadOnStartup) { return this; } + /** + * Clock skew to be accounted for in token expiration and max age validations. + * + * @param clockSkew clock skew + * @return updated builder instance + */ + public Builder clockSkew(int clockSkew) { + this.clockSkew = Duration.ofSeconds(clockSkew); + return this; + } + private void verifyKeys(Config config) { config.get("jwk.resource").as(Resource::create).ifPresent(this::verifyJwk); } diff --git a/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java b/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java index 7d949399d6f..2806c58f599 100644 --- a/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java +++ b/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2022 Oracle and/or its affiliates. + * Copyright (c) 2021, 2023 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,14 @@ import java.security.PublicKey; import java.security.SecureRandom; import java.security.spec.AlgorithmParameterSpec; +import java.security.spec.MGF1ParameterSpec; import java.util.Arrays; import java.util.Base64; +import java.util.Collection; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -36,6 +40,8 @@ import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.OAEPParameterSpec; +import javax.crypto.spec.PSource; import javax.crypto.spec.SecretKeySpec; import io.helidon.common.Errors; @@ -54,7 +60,6 @@ */ public final class EncryptedJwt { - private static final Map RSA_ALGORITHMS; private static final Map CONTENT_ENCRYPTION; private static final Pattern JWE_PATTERN = Pattern @@ -63,10 +68,6 @@ public final class EncryptedJwt { private static final Base64.Encoder URL_ENCODER = Base64.getUrlEncoder().withoutPadding(); static { - RSA_ALGORITHMS = Map.of(SupportedAlgorithm.RSA_OAEP, "RSA/ECB/OAEPWithSHA-1AndMGF1Padding", - SupportedAlgorithm.RSA_OAEP_256, "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - SupportedAlgorithm.RSA1_5, "RSA/ECB/PKCS1Padding"); - CONTENT_ENCRYPTION = Map.of(SupportedEncryption.A128GCM, new AesGcmAlgorithm(128), SupportedEncryption.A192GCM, new AesGcmAlgorithm(192), SupportedEncryption.A256GCM, new AesGcmAlgorithm(256), @@ -193,6 +194,31 @@ public static EncryptedJwt parseToken(JwtHeaders header, String token) { } } + /** + * Add validator of kek algorithm to the collection of validators. + * + * @param validators collection of validators + * @param expectedKekAlg audience key encryption key algorithm + * @param mandatory whether the alg field is mandatory in the token + */ + public static void addKekValidator(Collection> validators, String expectedKekAlg, boolean mandatory) { + validators.add((encryptedJwt, collector) -> { + Optional kekAlgorithm = encryptedJwt.header.algorithm(); + if (kekAlgorithm.isPresent()) { + //if null, any kek alg is allowed + if (expectedKekAlg == null || kekAlgorithm.get().equals(expectedKekAlg)) { + return; + } + collector.fatal(encryptedJwt, "Key encryption key algorithm must be equal to " + expectedKekAlg + + ", yet it is: " + kekAlgorithm.get()); + } else { + if (mandatory) { + collector.fatal(encryptedJwt, "Key encryption key algorithm is expected to be present for encrypted JWT"); + } + } + }); + } + private static EncryptedJwt parse(String token, Errors.Collector collector, JwtHeaders header, @@ -211,21 +237,29 @@ private static EncryptedJwt parse(String token, return new EncryptedJwt(token, header, iv, encryptedKey, authTag, encryptedPayload); } - private static byte[] encryptRsa(String algorithm, PublicKey publicKey, byte[] unencryptedKey) { + private static byte[] wrapRsa(SupportedAlgorithm supportedAlgorithm, PublicKey publicKey, byte[] unencryptedKey) { try { - Cipher rsaCipher = Cipher.getInstance(algorithm); - rsaCipher.init(Cipher.ENCRYPT_MODE, publicKey); - return rsaCipher.doFinal(unencryptedKey); + Cipher rsaCipher = Cipher.getInstance(supportedAlgorithm.cipherName()); + if (supportedAlgorithm.parameterSpec() == null) { + rsaCipher.init(Cipher.WRAP_MODE, publicKey); + } else { + rsaCipher.init(Cipher.WRAP_MODE, publicKey, supportedAlgorithm.parameterSpec()); + } + return rsaCipher.wrap(new SecretKeySpec(unencryptedKey, "AES")); } catch (Exception e) { throw new JwtException("Exception during rsa key decryption occurred.", e); } } - private static byte[] decryptRsa(String algorithm, PrivateKey privateKey, byte[] encryptedKey) { + private static byte[] unwrapRsa(SupportedAlgorithm supportedAlgorithm, PrivateKey privateKey, byte[] encryptedKey) { try { - Cipher rsaCipher = Cipher.getInstance(algorithm); - rsaCipher.init(Cipher.DECRYPT_MODE, privateKey); - return rsaCipher.doFinal(encryptedKey); + Cipher rsaCipher = Cipher.getInstance(supportedAlgorithm.cipherName()); + if (supportedAlgorithm.parameterSpec() == null) { + rsaCipher.init(Cipher.UNWRAP_MODE, privateKey); + } else { + rsaCipher.init(Cipher.UNWRAP_MODE, privateKey, supportedAlgorithm.parameterSpec()); + } + return rsaCipher.unwrap(encryptedKey, "AES", Cipher.SECRET_KEY).getEncoded(); } catch (Exception e) { throw new JwtException("Exception during rsa key decryption occurred.", e); } @@ -255,7 +289,7 @@ private static byte[] decodeBytes(String base64, Errors.Collector collector, Str * * Selected {@link Jwk} needs to have private key set. * - * @param jwkKeys jwk keys + * @param jwkKeys jwk keys * @return empty optional if any error has occurred or SignedJwt instance if the decryption and validation was successful */ public SignedJwt decrypt(JwkKeys jwkKeys) { @@ -268,7 +302,7 @@ public SignedJwt decrypt(JwkKeys jwkKeys) { * * Provided {@link Jwk} needs to have private key set. * - * @param jwk jwk keys + * @param jwk jwk keys * @return empty optional if any error has occurred or SignedJwt instance if the decryption and validation was successful */ public SignedJwt decrypt(Jwk jwk) { @@ -313,11 +347,10 @@ public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) { if (enc == null) { errors.fatal("Content encryption algorithm not set."); } - + SupportedAlgorithm supportedAlgorithm = null; if (alg != null) { try { - SupportedAlgorithm supportedAlgorithm = SupportedAlgorithm.getValue(alg); - algorithm = RSA_ALGORITHMS.get(supportedAlgorithm); + supportedAlgorithm = SupportedAlgorithm.getValue(alg); } catch (IllegalArgumentException e) { errors.fatal("Value of the claim alg not supported. alg: " + alg); } @@ -345,7 +378,7 @@ public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) { errors.collect().checkValid(); - byte[] decryptedKey = decryptRsa(algorithm, privateKey, encryptedKey); + byte[] decryptedKey = unwrapRsa(supportedAlgorithm, privateKey, encryptedKey); //Base64 headers are used as an aad. This aad has to be in US_ASCII encoding. EncryptionParts encryptionParts = new EncryptionParts(decryptedKey, iv, @@ -417,6 +450,18 @@ public byte[] encryptedPayload() { return Arrays.copyOf(encryptedPayload, encryptedPayload.length); } + /** + * Validate this Encrypted JWT against provided validators. + * + * @param validators Validators to validate with. + * @return errors instance to check if valid and access error messages + */ + public Errors validate(List> validators) { + Errors.Collector collector = Errors.collector(); + validators.forEach(it -> it.validate(this, collector)); + return collector.collect(); + } + /** * Encrypted JWT builder. */ @@ -511,14 +556,13 @@ public EncryptedJwt build() { JwtHeaders headers = headersBuilder.build(); StringBuilder tokenBuilder = new StringBuilder(); String headersBase64 = encode(headers.headerJson().toString()); - String rsaCipherType = RSA_ALGORITHMS.get(algorithm); AesAlgorithm contentEncryption = CONTENT_ENCRYPTION.get(encryption); //Base64 headers are used as an aad. This aad has to be in US_ASCII encoding. EncryptionParts encryptionParts = contentEncryption.encrypt(jwt.tokenContent().getBytes(StandardCharsets.UTF_8), headersBase64.getBytes(StandardCharsets.US_ASCII)); byte[] aesKey = encryptionParts.key(); - byte[] encryptedAesKey = encryptRsa(rsaCipherType, publicKey, aesKey); + byte[] encryptedAesKey = wrapRsa(algorithm, publicKey, aesKey); String token = tokenBuilder.append(headersBase64).append(".") .append(encode(encryptedAesKey)).append(".") .append(encode(encryptionParts.iv())).append(".") @@ -544,20 +588,28 @@ public enum SupportedAlgorithm { /** * RSA-OAEP declares that RSA/ECB/OAEPWithSHA-1AndMGF1Padding cipher will be used for content key encryption. */ - RSA_OAEP("RSA-OAEP"), + RSA_OAEP("RSA-OAEP", + "RSA/ECB/OAEPWithSHA-1AndMGF1Padding", + new OAEPParameterSpec("SHA-1", "MGF1", MGF1ParameterSpec.SHA1, PSource.PSpecified.DEFAULT)), /** * RSA-OAEP-256 declares that RSA/ECB/OAEPWithSHA-256AndMGF1Padding cipher will be used for content key encryption. */ - RSA_OAEP_256("RSA-OAEP-256"), + RSA_OAEP_256("RSA-OAEP-256", + "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", + new OAEPParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, PSource.PSpecified.DEFAULT)), /** * RSA1_5 declares that RSA/ECB/PKCS1Padding cipher will be used for content key encryption. */ - RSA1_5("RSA1_5"); + RSA1_5("RSA1_5", "RSA/ECB/PKCS1Padding", null); private final String algorithmName; + private final String cipherName; + private final AlgorithmParameterSpec parameterSpec; - SupportedAlgorithm(String algorithmName) { + SupportedAlgorithm(String algorithmName, String cipherName, AlgorithmParameterSpec parameterSpec) { this.algorithmName = algorithmName; + this.cipherName = cipherName; + this.parameterSpec = parameterSpec; } @Override @@ -565,9 +617,21 @@ public String toString() { return algorithmName; } - static SupportedAlgorithm getValue(String value) { + String cipherName() { + return cipherName; + } + + String algorithmName() { + return algorithmName; + } + + AlgorithmParameterSpec parameterSpec() { + return parameterSpec; + } + + static SupportedAlgorithm getValue(String algorithmName) { for (SupportedAlgorithm v : values()) { - if (v.algorithmName.equalsIgnoreCase(value)) { + if (v.algorithmName.equalsIgnoreCase(algorithmName)) { return v; } } diff --git a/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java b/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java index 8ef8950f250..881679f3a54 100644 --- a/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java +++ b/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2022 Oracle and/or its affiliates. + * Copyright (c) 2018, 2023 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package io.helidon.security.jwt; import java.net.URI; +import java.time.Duration; import java.time.Instant; import java.time.LocalDate; import java.time.ZoneId; @@ -426,6 +427,36 @@ public static void addAudienceValidator(Collection> validators, S }); } + /** + * Add validator of max token age to the collection of validators. + * + * @param validators collection of validators + * @param expectedMaxTokenAge max token age since issue time + * @param clockSkew clock skew + * @param iatRequired whether to fail if iat clam is present + */ + public static void addMaxTokenAgeValidator(Collection> validators, + Duration expectedMaxTokenAge, + Duration clockSkew, + boolean iatRequired) { + validators.add((jwt, collector) -> { + Optional maybeIssueTime = jwt.issueTime(); + if (maybeIssueTime.isPresent()) { + Instant now = Instant.now(); + Instant issueTime = maybeIssueTime.get().minus(clockSkew); + Instant maxValidTime = issueTime.plus(expectedMaxTokenAge).plus(clockSkew); + if (issueTime.isBefore(now) && maxValidTime.isAfter(now)) { + return; + } + collector.fatal(jwt, "Current time need to be between " + issueTime + + " and " + maxValidTime + ", but was " + now); + } else if (iatRequired) { + collector.fatal(jwt, "Claim iat is required to be present in JWT when validating token max allowed age."); + } + + }); + } + /** * Get a builder to create a JWT. *