From 4f89b4a782aee54a1c69460a30903b0d0d6b7565 Mon Sep 17 00:00:00 2001 From: MaciejMierzwa Date: Tue, 24 Oct 2023 23:33:07 +0200 Subject: [PATCH] Switch JWT library implementations from cxf to nimbus (#3421) Switch from org.apache.cxf.rs.security.jose to com.nimbusds.jose.jwk. Signed-off-by: Peter Nied Signed-off-by: Maciej Mierzwa Signed-off-by: Ryan Liang Co-authored-by: Peter Nied Co-authored-by: Ryan Liang --- build.gradle | 4 +- checkstyle/checkstyle.xml | 1 + .../jwt/AbstractHTTPJwtAuthenticator.java | 26 ++- .../auth/http/jwt/keybyoidc/JwtVerifier.java | 88 ++++----- .../auth/http/jwt/keybyoidc/KeyProvider.java | 6 +- .../http/jwt/keybyoidc/KeySetProvider.java | 4 +- .../http/jwt/keybyoidc/KeySetRetriever.java | 11 +- .../jwt/keybyoidc/SelfRefreshingKeySet.java | 46 ++--- .../http/saml/AuthTokenProcessorHandler.java | 108 +++++------ .../auth/http/saml/HTTPSamlAuthenticator.java | 6 +- .../security/authtoken/jwt/JwtVendor.java | 150 ++++++++-------- .../authtoken/jwt/KeyPaddingUtil.java | 33 ++++ .../http/jwt/HTTPJwtAuthenticatorTest.java | 77 +++++++- ...wtKeyByOpenIdConnectAuthenticatorTest.java | 54 +++--- .../http/jwt/keybyoidc/MockIpdServer.java | 12 +- .../keybyoidc/SelfRefreshingKeySetTest.java | 25 +-- ...wtKeyByOpenIdConnectAuthenticatorTest.java | 11 +- .../dlic/auth/http/jwt/keybyoidc/TestJwk.java | 96 +++++----- .../auth/http/jwt/keybyoidc/TestJwts.java | 110 +++++++----- .../http/saml/HTTPSamlAuthenticatorTest.java | 131 ++++++++------ .../security/authtoken/jwt/JwtVendorTest.java | 168 ++++++++---------- .../authtoken/jwt/KeyPaddingUtilTest.java | 42 +++++ 22 files changed, 672 insertions(+), 537 deletions(-) create mode 100644 src/main/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtil.java create mode 100644 src/test/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtilTest.java diff --git a/build.gradle b/build.gradle index cd5d2f301c..2b2358d9cf 100644 --- a/build.gradle +++ b/build.gradle @@ -558,6 +558,7 @@ dependencies { implementation 'commons-cli:commons-cli:1.5.0' implementation "org.bouncycastle:bcprov-jdk15to18:${versions.bouncycastle}" implementation 'org.ldaptive:ldaptive:1.2.3' + implementation 'com.nimbusds:nimbus-jose-jwt:9.31' //JWT implementation "io.jsonwebtoken:jjwt-api:${jjwt_version}" @@ -581,9 +582,6 @@ dependencies { runtimeOnly 'net.minidev:accessors-smart:2.5.0' - implementation("org.apache.cxf:cxf-rt-rs-security-jose:${apache_cxf_version}") { - exclude(group: 'jakarta.activation', module: 'jakarta.activation-api') - } runtimeOnly "org.apache.cxf:cxf-core:${apache_cxf_version}" implementation "org.apache.cxf:cxf-rt-rs-json-basic:${apache_cxf_version}" runtimeOnly "org.apache.cxf:cxf-rt-security:${apache_cxf_version}" diff --git a/checkstyle/checkstyle.xml b/checkstyle/checkstyle.xml index 4484ea4e04..b679ce24ce 100644 --- a/checkstyle/checkstyle.xml +++ b/checkstyle/checkstyle.xml @@ -120,6 +120,7 @@ + diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index b183593a91..4271e68f1c 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -16,6 +16,7 @@ import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; +import java.text.ParseException; import java.util.Collection; import java.util.Map; import java.util.Optional; @@ -23,8 +24,8 @@ import java.util.regex.Pattern; import com.google.common.annotations.VisibleForTesting; -import org.apache.cxf.rs.security.jose.jwt.JwtClaims; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -112,37 +113,34 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) throw return null; } - JwtToken jwt; + SignedJWT jwt; + JWTClaimsSet claimsSet; try { jwt = jwtVerifier.getVerifiedJwtToken(jwtString); + claimsSet = jwt.getJWTClaimsSet(); } catch (AuthenticatorUnavailableException e) { log.info(e.toString()); throw new OpenSearchSecurityException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE); - } catch (BadCredentialsException e) { + } catch (BadCredentialsException | ParseException e) { log.info("Extracting JWT token from {} failed", jwtString, e); return null; } - JwtClaims claims = jwt.getClaims(); - - final String subject = extractSubject(claims); - + final String subject = extractSubject(claimsSet); if (subject == null) { log.error("No subject found in JWT token"); return null; } - final String[] roles = extractRoles(claims); - + final String[] roles = extractRoles(claimsSet); final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete(); - for (Entry claim : claims.asMap().entrySet()) { + for (Entry claim : claimsSet.getClaims().entrySet()) { ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue())); } return ac; - } protected String getJwtTokenString(SecurityRequest request) { @@ -174,7 +172,7 @@ protected String getJwtTokenString(SecurityRequest request) { } @VisibleForTesting - public String extractSubject(JwtClaims claims) { + public String extractSubject(JWTClaimsSet claims) { String subject = claims.getSubject(); if (subjectKey != null) { @@ -204,7 +202,7 @@ public String extractSubject(JwtClaims claims) { @SuppressWarnings("unchecked") @VisibleForTesting - public String[] extractRoles(JwtClaims claims) { + public String[] extractRoles(JWTClaimsSet claims) { if (rolesKey == null) { return new String[0]; } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java index 8cac2a23d1..3716eb7997 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java @@ -12,21 +12,24 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; import com.google.common.base.Strings; +import com.nimbusds.jose.Algorithm; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSVerifier; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory; +import com.nimbusds.jose.proc.SimpleSecurityContext; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.proc.BadJWTException; +import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier; import org.apache.commons.lang3.StringEscapeUtils; -import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.KeyType; -import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse; -import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; -import org.apache.cxf.rs.security.jose.jws.JwsSignatureVerifier; -import org.apache.cxf.rs.security.jose.jws.JwsUtils; -import org.apache.cxf.rs.security.jose.jwt.JwtClaims; -import org.apache.cxf.rs.security.jose.jwt.JwtException; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; -import org.apache.cxf.rs.security.jose.jwt.JwtUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.text.ParseException; +import java.util.Collections; + public class JwtVerifier { private final static Logger log = LogManager.getLogger(JwtVerifier.class); @@ -43,31 +46,24 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin this.requiredAudience = requiredAudience; } - public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException { + public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException { try { - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(encodedJwt); - String escapedKid = jwt.getJwsHeaders().getKeyId(); + String escapedKid = jwt.getHeader().getKeyID(); String kid = escapedKid; if (!Strings.isNullOrEmpty(kid)) { kid = StringEscapeUtils.unescapeJava(escapedKid); } - JsonWebKey key = keyProvider.getKey(kid); - - // Algorithm is not mandatory for the key material, so we set it to the same as the JWT - if (key.getAlgorithm() == null && key.getPublicKeyUse() == PublicKeyUse.SIGN && key.getKeyType() == KeyType.RSA) { - key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm()); - } - - JwsSignatureVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt); + JWK key = keyProvider.getKey(kid); - boolean signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier); + JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt); + boolean signatureValid = jwt.verify(signatureVerifier); if (!signatureValid && Strings.isNullOrEmpty(kid)) { key = keyProvider.getKeyAfterRefresh(null); signatureVerifier = getInitializedSignatureVerifier(key, jwt); - signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier); + signatureValid = jwt.verify(signatureVerifier); } if (!signatureValid) { @@ -77,18 +73,18 @@ public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExce validateClaims(jwt); return jwt; - } catch (JwtException e) { + } catch (JOSEException | ParseException | BadJWTException e) { throw new BadCredentialsException(e.getMessage(), e); } } - private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws BadCredentialsException { - if (Strings.isNullOrEmpty(key.getAlgorithm())) { + private void validateSignatureAlgorithm(JWK key, SignedJWT jwt) throws BadCredentialsException { + if (key.getAlgorithm() == null || jwt.getHeader().getAlgorithm() == null) { return; } - SignatureAlgorithm keyAlgorithm = SignatureAlgorithm.getAlgorithm(key.getAlgorithm()); - SignatureAlgorithm tokenAlgorithm = SignatureAlgorithm.getAlgorithm(jwt.getJwsHeaders().getAlgorithm()); + Algorithm keyAlgorithm = key.getAlgorithm(); + Algorithm tokenAlgorithm = jwt.getHeader().getAlgorithm(); if (!keyAlgorithm.equals(tokenAlgorithm)) { throw new BadCredentialsException( @@ -97,11 +93,16 @@ private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws Bad } } - private JwsSignatureVerifier getInitializedSignatureVerifier(JsonWebKey key, JwtToken jwt) throws BadCredentialsException, - JwtException { + private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) throws BadCredentialsException, JOSEException { validateSignatureAlgorithm(key, jwt); - JwsSignatureVerifier result = JwsUtils.getSignatureVerifier(key, jwt.getJwsHeaders().getSignatureAlgorithm()); + final JWSVerifier result; + if (key.getClass() == OctetSequenceKey.class) { + result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toOctetSequenceKey().toSecretKey()); + } else { + result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toRSAKey().toRSAPublicKey()); + } + if (result == null) { throw new BadCredentialsException("Cannot verify JWT"); } else { @@ -109,26 +110,31 @@ private JwsSignatureVerifier getInitializedSignatureVerifier(JsonWebKey key, Jwt } } - private void validateClaims(JwtToken jwt) throws JwtException { - JwtClaims claims = jwt.getClaims(); + private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTException { + JWTClaimsSet claims = jwt.getJWTClaimsSet(); if (claims != null) { - JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false); - JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false); + DefaultJWTClaimsVerifier claimsVerifier = new DefaultJWTClaimsVerifier<>( + requiredAudience, + null, + Collections.emptySet() + ); + claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds); + claimsVerifier.verify(claims, null); validateRequiredAudienceAndIssuer(claims); } } - private void validateRequiredAudienceAndIssuer(JwtClaims claims) { - String audience = claims.getAudience(); + private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException { + String audience = claims.getAudience().stream().findFirst().orElse(""); String issuer = claims.getIssuer(); if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) { - throw new JwtException("Invalid audience"); + throw new BadJWTException("Invalid audience"); } if (!Strings.isNullOrEmpty(requiredIssuer) && !requiredIssuer.equals(issuer)) { - throw new JwtException("Invalid issuer"); + throw new BadJWTException("Invalid issuer"); } } } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeyProvider.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeyProvider.java index a0e76c918f..ef539cfa35 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeyProvider.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeyProvider.java @@ -11,10 +11,10 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; +import com.nimbusds.jose.jwk.JWK; public interface KeyProvider { - public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException; + JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException; - public JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException; + JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException; } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetProvider.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetProvider.java index 53ea0237db..c1f5979cde 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetProvider.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetProvider.java @@ -11,9 +11,9 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; +import com.nimbusds.jose.jwk.JWKSet; @FunctionalInterface public interface KeySetProvider { - JsonWebKeys get() throws AuthenticatorUnavailableException; + JWKSet get() throws AuthenticatorUnavailableException; } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java index 9ef50a4404..05f0d88768 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java @@ -12,11 +12,11 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; import java.io.IOException; +import java.text.ParseException; import java.util.concurrent.TimeUnit; +import com.nimbusds.jose.jwk.JWKSet; import joptsimple.internal.Strings; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; -import org.apache.cxf.rs.security.jose.jwk.JwkUtils; import org.apache.hc.client5.http.cache.HttpCacheContext; import org.apache.hc.client5.http.cache.HttpCacheStorage; import org.apache.hc.client5.http.classic.methods.HttpGet; @@ -70,7 +70,7 @@ public class KeySetRetriever implements KeySetProvider { configureCache(useCacheForOidConnectEndpoint); } - public JsonWebKeys get() throws AuthenticatorUnavailableException { + public JWKSet get() throws AuthenticatorUnavailableException { String uri = getJwksUri(); try (CloseableHttpClient httpClient = createHttpClient(null)) { @@ -94,10 +94,11 @@ public JsonWebKeys get() throws AuthenticatorUnavailableException { if (httpEntity == null) { throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity"); } - - JsonWebKeys keySet = JwkUtils.readJwkSet(httpEntity.getContent()); + JWKSet keySet = JWKSet.load(httpEntity.getContent()); return keySet; + } catch (ParseException e) { + throw new RuntimeException(e); } } catch (IOException e) { throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e); diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySet.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySet.java index fe410b171c..d1d823e3a8 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySet.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySet.java @@ -19,8 +19,8 @@ import java.util.concurrent.TimeUnit; import com.google.common.base.Strings; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -35,7 +35,7 @@ public class SelfRefreshingKeySet implements KeyProvider { TimeUnit.MILLISECONDS, new LinkedBlockingQueue() ); - private volatile JsonWebKeys jsonWebKeys = new JsonWebKeys(); + private volatile JWKSet jsonWebKeys = new JWKSet(); private boolean refreshInProgress = false; private long refreshCount = 0; private long queuedGetCount = 0; @@ -51,7 +51,7 @@ public SelfRefreshingKeySet(KeySetProvider refreshFunction) { this.keySetProvider = refreshFunction; } - public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { + public JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { if (Strings.isNullOrEmpty(kid)) { return getKeyWithoutKeyId(); } else { @@ -59,8 +59,8 @@ public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, B } } - public synchronized JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { - JsonWebKey result = getKeyAfterRefreshInternal(kid); + public synchronized JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { + JWK result = getKeyAfterRefreshInternal(kid); if (result != null) { return result; @@ -71,7 +71,7 @@ public synchronized JsonWebKey getKeyAfterRefresh(String kid) throws Authenticat } } - private synchronized JsonWebKey getKeyAfterRefreshInternal(String kid) throws AuthenticatorUnavailableException { + private synchronized JWK getKeyAfterRefreshInternal(String kid) throws AuthenticatorUnavailableException { if (refreshInProgress) { return waitForRefreshToFinish(kid); } else { @@ -79,11 +79,11 @@ private synchronized JsonWebKey getKeyAfterRefreshInternal(String kid) throws Au } } - private JsonWebKey getKeyWithoutKeyId() throws AuthenticatorUnavailableException, BadCredentialsException { - List keys = jsonWebKeys.getKeys(); + private JWK getKeyWithoutKeyId() throws AuthenticatorUnavailableException, BadCredentialsException { + List keys = jsonWebKeys.getKeys(); if (keys == null || keys.size() == 0) { - JsonWebKey result = getKeyWithRefresh(null); + JWK result = getKeyWithRefresh(null); if (result != null) { return result; @@ -93,7 +93,7 @@ private JsonWebKey getKeyWithoutKeyId() throws AuthenticatorUnavailableException } else if (keys.size() == 1) { return keys.get(0); } else { - JsonWebKey result = getKeyWithRefresh(null); + JWK result = getKeyWithRefresh(null); if (result != null) { return result; @@ -103,8 +103,8 @@ private JsonWebKey getKeyWithoutKeyId() throws AuthenticatorUnavailableException } } - private JsonWebKey getKeyWithKeyId(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { - JsonWebKey result = jsonWebKeys.getKey(kid); + private JWK getKeyWithKeyId(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { + JWK result = jsonWebKeys.getKeyByKeyId(kid); if (result != null) { return result; @@ -119,11 +119,11 @@ private JsonWebKey getKeyWithKeyId(String kid) throws AuthenticatorUnavailableEx return result; } - private synchronized JsonWebKey getKeyWithRefresh(String kid) throws AuthenticatorUnavailableException { + private synchronized JWK getKeyWithRefresh(String kid) throws AuthenticatorUnavailableException { // Always re-check within synchronized to handle any races - JsonWebKey result = getKeySimple(kid); + JWK result = getKeySimple(kid); if (result != null) { return result; @@ -132,9 +132,9 @@ private synchronized JsonWebKey getKeyWithRefresh(String kid) throws Authenticat return getKeyAfterRefreshInternal(kid); } - private JsonWebKey getKeySimple(String kid) { + private JWK getKeySimple(String kid) { if (Strings.isNullOrEmpty(kid)) { - List keys = jsonWebKeys.getKeys(); + List keys = jsonWebKeys.getKeys(); if (keys != null && keys.size() == 1) { return keys.get(0); @@ -143,11 +143,11 @@ private JsonWebKey getKeySimple(String kid) { } } else { - return jsonWebKeys.getKey(kid); + return jsonWebKeys.getKeyByKeyId(kid); } } - private synchronized JsonWebKey waitForRefreshToFinish(String kid) { + private synchronized JWK waitForRefreshToFinish(String kid) { queuedGetCount++; long currentRefreshCount = refreshCount; @@ -160,7 +160,7 @@ private synchronized JsonWebKey waitForRefreshToFinish(String kid) { // Just be optimistic and re-check the key - JsonWebKey result = getKeySimple(kid); + JWK result = getKeySimple(kid); if (result != null) { return result; @@ -177,7 +177,7 @@ private synchronized JsonWebKey waitForRefreshToFinish(String kid) { } } - private synchronized JsonWebKey performRefresh(String kid) { + private synchronized JWK performRefresh(String kid) { if (log.isDebugEnabled()) { log.debug("performRefresh({})", kid); } @@ -209,7 +209,7 @@ private synchronized JsonWebKey performRefresh(String kid) { @Override public void run() { try { - JsonWebKeys newKeys = keySetProvider.get(); + JWKSet newKeys = keySetProvider.get(); if (newKeys == null) { throw new RuntimeException("Refresh function " + keySetProvider + " yielded null"); @@ -247,7 +247,7 @@ public void run() { log.debug(e.toString()); } - JsonWebKey result = getKeySimple(kid); + JWK result = getKeySimple(kid); if (result != null) { return result; diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index f545c2425b..9f9e654b69 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -18,6 +18,8 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.Base64; +import java.util.Date; import java.util.List; import java.util.Optional; import java.util.regex.Matcher; @@ -30,26 +32,26 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.common.base.Strings; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; import com.onelogin.saml2.authn.SamlResponse; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; import org.apache.commons.lang3.StringUtils; -import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.KeyType; -import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse; -import org.apache.cxf.rs.security.jose.jws.JwsUtils; -import org.apache.cxf.rs.security.jose.jwt.JoseJwtProducer; -import org.apache.cxf.rs.security.jose.jwt.JwtClaims; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; -import org.apache.cxf.rs.security.jose.jwt.JwtUtils; import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; +import org.opensearch.core.common.Strings; + import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; import org.opensearch.core.common.bytes.BytesReference; @@ -62,13 +64,14 @@ import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction; import org.opensearch.security.filter.SecurityResponse; +import static org.opensearch.security.authtoken.jwt.KeyPaddingUtil.padSecret; + class AuthTokenProcessorHandler { private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class); private static final Logger token_log = LogManager.getLogger("com.amazon.dlic.auth.http.saml.Token"); private static final Pattern EXPIRY_SETTINGS_PATTERN = Pattern.compile("\\s*(\\w+)\\s*(?:\\+\\s*(\\w+))?\\s*"); private Saml2SettingsProvider saml2SettingsProvider; - private JoseJwtProducer jwtProducer; private String jwtSubjectKey; private String jwtRolesKey; private String samlSubjectKey; @@ -77,8 +80,8 @@ class AuthTokenProcessorHandler { private long expiryOffset = 0; private ExpiryBaseValue expiryBaseValue = ExpiryBaseValue.AUTO; - private JsonWebKey signingKey; - private JsonMapObjectReaderWriter jsonMapReaderWriter = new JsonMapObjectReaderWriter(); + private JWK signingKey; + private JWSHeader jwsHeader; private Pattern samlRolesSeparatorPattern; AuthTokenProcessorHandler(Settings settings, Settings jwtSettings, Saml2SettingsProvider saml2SettingsProvider) throws Exception { @@ -113,10 +116,7 @@ class AuthTokenProcessorHandler { this.initJwtExpirySettings(settings); this.signingKey = this.createJwkFromSettings(settings, jwtSettings); - - this.jwtProducer = new JoseJwtProducer(); - this.jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(this.signingKey)); - + this.jwsHeader = this.createJwsHeaderFromSettings(); } @SuppressWarnings("removal") @@ -243,80 +243,68 @@ private Optional handleLowLevel(RestRequest restRequest) throw } } - JsonWebKey createJwkFromSettings(Settings settings, Settings jwtSettings) throws Exception { + private JWSHeader createJwsHeaderFromSettings() { + JWSHeader.Builder jwsHeaderBuilder = new JWSHeader.Builder(JWSAlgorithm.HS512); + return jwsHeaderBuilder.build(); + } + JWK createJwkFromSettings(Settings settings, Settings jwtSettings) throws Exception { String exchangeKey = settings.get("exchange_key"); if (!Strings.isNullOrEmpty(exchangeKey)) { + exchangeKey = padSecret(new String(Base64.getDecoder().decode(exchangeKey), StandardCharsets.UTF_8), JWSAlgorithm.HS512); - JsonWebKey jwk = new JsonWebKey(); - - jwk.setKeyType(KeyType.OCTET); - jwk.setAlgorithm("HS512"); - jwk.setPublicKeyUse(PublicKeyUse.SIGN); - jwk.setProperty("k", exchangeKey); - - return jwk; + return new OctetSequenceKey.Builder(exchangeKey.getBytes(StandardCharsets.UTF_8)).algorithm(JWSAlgorithm.HS512) + .keyUse(KeyUse.SIGNATURE) + .build(); } else { - Settings jwkSettings = jwtSettings.getAsSettings("key"); - if (jwkSettings.isEmpty()) { + if (!jwkSettings.hasValue("k") && !Strings.isNullOrEmpty(jwkSettings.get("k"))) { throw new Exception( "Settings for key exchange missing. Please specify at least the option exchange_key with a shared secret." ); } - JsonWebKey jwk = new JsonWebKey(); - - for (String key : jwkSettings.keySet()) { - jwk.setProperty(key, jwkSettings.get(key)); - } + String k = padSecret(new String(Base64.getDecoder().decode(jwkSettings.get("k")), StandardCharsets.UTF_8), JWSAlgorithm.HS512); - return jwk; + return new OctetSequenceKey.Builder(k.getBytes(StandardCharsets.UTF_8)).algorithm(JWSAlgorithm.HS512) + .keyUse(KeyUse.SIGNATURE) + .build(); } } private String createJwt(SamlResponse samlResponse) throws Exception { - JwtClaims jwtClaims = new JwtClaims(); - JwtToken jwt = new JwtToken(jwtClaims); - - jwtClaims.setNotBefore(System.currentTimeMillis() / 1000); - jwtClaims.setExpiryTime(getJwtExpiration(samlResponse)); - - jwtClaims.setProperty(this.jwtSubjectKey, this.extractSubject(samlResponse)); + JWTClaimsSet.Builder jwtClaimsBuilder = new JWTClaimsSet.Builder().notBeforeTime(new Date()) + .expirationTime(new Date(getJwtExpiration(samlResponse))) + .claim(this.jwtSubjectKey, this.extractSubject(samlResponse)); if (this.samlSubjectKey != null) { - jwtClaims.setProperty("saml_ni", samlResponse.getNameId()); + jwtClaimsBuilder.claim("saml_ni", samlResponse.getNameId()); } - if (samlResponse.getNameIdFormat() != null) { - jwtClaims.setProperty("saml_nif", SamlNameIdFormat.getByUri(samlResponse.getNameIdFormat()).getShortName()); + jwtClaimsBuilder.claim("saml_nif", SamlNameIdFormat.getByUri(samlResponse.getNameIdFormat()).getShortName()); } String sessionIndex = samlResponse.getSessionIndex(); if (sessionIndex != null) { - jwtClaims.setProperty("saml_si", sessionIndex); + jwtClaimsBuilder.claim("saml_si", sessionIndex); } if (this.samlRolesKey != null && this.jwtRolesKey != null) { String[] roles = this.extractRoles(samlResponse); - jwtClaims.setProperty(this.jwtRolesKey, roles); + jwtClaimsBuilder.claim(this.jwtRolesKey, roles); } + JWTClaimsSet jwtClaims = jwtClaimsBuilder.build(); + SignedJWT jwt = new SignedJWT(this.jwsHeader, jwtClaims); + jwt.sign(new DefaultJWSSignerFactory().createJWSSigner(this.signingKey)); - String encodedJwt = this.jwtProducer.processJwt(jwt); + String encodedJwt = jwt.serialize(); if (token_log.isDebugEnabled()) { - token_log.debug( - "Created JWT: " - + encodedJwt - + "\n" - + jsonMapReaderWriter.toJson(jwt.getJwsHeaders()) - + "\n" - + JwtUtils.claimsToJson(jwt.getClaims()) - ); + token_log.debug("Created JWT: " + encodedJwt + "\n" + jwt.getHeader().toString() + "\n" + jwt.getJWTClaimsSet().toString()); } return encodedJwt; @@ -326,10 +314,10 @@ private long getJwtExpiration(SamlResponse samlResponse) throws Exception { DateTime sessionNotOnOrAfter = samlResponse.getSessionNotOnOrAfter(); if (this.expiryBaseValue == ExpiryBaseValue.NOW) { - return System.currentTimeMillis() / 1000 + this.expiryOffset; + return System.currentTimeMillis() + this.expiryOffset * 1000; } else if (this.expiryBaseValue == ExpiryBaseValue.SESSION) { if (sessionNotOnOrAfter != null) { - return sessionNotOnOrAfter.getMillis() / 1000 + this.expiryOffset; + return sessionNotOnOrAfter.getMillis() + this.expiryOffset * 1000; } else { throw new Exception("Error while determining JWT expiration time: SamlResponse did not contain sessionNotOnOrAfter value"); } @@ -337,9 +325,9 @@ private long getJwtExpiration(SamlResponse samlResponse) throws Exception { // AUTO if (sessionNotOnOrAfter != null) { - return sessionNotOnOrAfter.getMillis() / 1000; + return sessionNotOnOrAfter.getMillis(); } else { - return System.currentTimeMillis() / 1000 + (this.expiryOffset > 0 ? this.expiryOffset : 60 * 60); + return System.currentTimeMillis() + (this.expiryOffset > 0 ? this.expiryOffset * 1000 : 60 * 60_000); } } } @@ -440,7 +428,7 @@ private enum ExpiryBaseValue { SESSION } - public JsonWebKey getSigningKey() { + public JWK getSigningKey() { return signingKey; } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index 16866734e8..918e3be5ab 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; +import com.nimbusds.jose.jwk.JWK; import com.onelogin.saml2.authn.AuthnRequest; import com.onelogin.saml2.logout.LogoutRequest; import com.onelogin.saml2.settings.Saml2Settings; @@ -36,7 +37,6 @@ import net.shibboleth.utilities.java.support.component.DestructableComponent; import net.shibboleth.utilities.java.support.xml.BasicParserPool; import org.apache.commons.lang3.StringEscapeUtils; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -480,12 +480,12 @@ protected KeyProvider initKeyProvider(Settings settings, Path configPath) throws return new KeyProvider() { @Override - public JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { + public JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { return authTokenProcessorHandler.getSigningKey(); } @Override - public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { + public JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException { return authTokenProcessorHandler.getSigningKey(); } }; diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java index e68a5ef2d7..5500eb5588 100644 --- a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java +++ b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java @@ -11,59 +11,58 @@ package org.opensearch.security.authtoken.jwt; -import java.time.Instant; +import java.text.ParseException; +import java.util.Base64; +import java.util.Date; import java.util.List; import java.util.Optional; import java.util.function.LongSupplier; -import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.KeyType; -import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse; -import org.apache.cxf.rs.security.jose.jws.JwsUtils; -import org.apache.cxf.rs.security.jose.jwt.JoseJwtProducer; -import org.apache.cxf.rs.security.jose.jwt.JwtClaims; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; -import org.apache.cxf.rs.security.jose.jwt.JwtUtils; +import com.nimbusds.jose.JOSEException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.KeyLengthException; +import com.nimbusds.jose.crypto.MACSigner; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; + +import org.opensearch.OpenSearchException; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.Settings; -import org.opensearch.security.ssl.util.ExceptionUtils; import static org.opensearch.security.util.AuthTokenUtils.isKeyNull; public class JwtVendor { private static final Logger logger = LogManager.getLogger(JwtVendor.class); - private static JsonMapObjectReaderWriter jsonMapReaderWriter = new JsonMapObjectReaderWriter(); - - private final String claimsEncryptionKey; - private final JsonWebKey signingKey; - private final JoseJwtProducer jwtProducer; + private final JWK signingKey; + private final JWSSigner signer; private final LongSupplier timeProvider; private final EncryptionDecryptionUtil encryptionDecryptionUtil; - private final Integer defaultExpirySeconds = 300; - private final Integer maxExpirySeconds = 600; + private static final Integer DEFAULT_EXPIRY_SECONDS = 300; + private static final Integer MAX_EXPIRY_SECONDS = 600; public JwtVendor(final Settings settings, final Optional timeProvider) { - JoseJwtProducer jwtProducer = new JoseJwtProducer(); - try { - this.signingKey = createJwkFromSettings(settings); - } catch (Exception e) { - throw ExceptionUtils.createJwkCreationException(e); - } - this.jwtProducer = jwtProducer; + final Tuple tuple = createJwkFromSettings(settings); + signingKey = tuple.v1(); + signer = tuple.v2(); + if (isKeyNull(settings, "encryption_key")) { throw new IllegalArgumentException("encryption_key cannot be null"); } else { - this.claimsEncryptionKey = settings.get("encryption_key"); - this.encryptionDecryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); + this.encryptionDecryptionUtil = new EncryptionDecryptionUtil(settings.get("encryption_key")); } if (timeProvider.isPresent()) { this.timeProvider = timeProvider.get(); } else { - this.timeProvider = () -> System.currentTimeMillis() / 1000; + this.timeProvider = () -> System.currentTimeMillis(); } } @@ -73,34 +72,32 @@ public JwtVendor(final Settings settings, final Optional timeProvi * PublicKeyUse: SIGN * Encryption Algorithm: HS512 * */ - static JsonWebKey createJwkFromSettings(Settings settings) throws Exception { + static Tuple createJwkFromSettings(Settings settings) { + final OctetSequenceKey key; if (!isKeyNull(settings, "signing_key")) { String signingKey = settings.get("signing_key"); - - JsonWebKey jwk = new JsonWebKey(); - - jwk.setKeyType(KeyType.OCTET); - jwk.setAlgorithm("HS512"); - jwk.setPublicKeyUse(PublicKeyUse.SIGN); - jwk.setProperty("k", signingKey); - - return jwk; + key = new OctetSequenceKey.Builder(Base64.getDecoder().decode(signingKey)).algorithm(JWSAlgorithm.HS512) + .keyUse(KeyUse.SIGNATURE) + .build(); } else { Settings jwkSettings = settings.getAsSettings("jwt").getAsSettings("key"); if (jwkSettings.isEmpty()) { - throw new Exception( + throw new OpenSearchException( "Settings for signing key is missing. Please specify at least the option signing_key with a shared secret." ); } - JsonWebKey jwk = new JsonWebKey(); - - for (String key : jwkSettings.keySet()) { - jwk.setProperty(key, jwkSettings.get(key)); - } + String signingKey = jwkSettings.get("k"); + key = new OctetSequenceKey.Builder(Base64.getDecoder().decode(signingKey)).algorithm(JWSAlgorithm.HS512) + .keyUse(KeyUse.SIGNATURE) + .build(); + } - return jwk; + try { + return new Tuple<>(key, new MACSigner(key)); + } catch (KeyLengthException kle) { + throw new OpenSearchException(kle); } } @@ -112,60 +109,53 @@ public String createJwt( List roles, List backendRoles, boolean roleSecurityMode - ) throws Exception { - final long nowAsMillis = timeProvider.getAsLong(); - final Instant nowAsInstant = Instant.ofEpochMilli(timeProvider.getAsLong()); - - jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(signingKey)); - JwtClaims jwtClaims = new JwtClaims(); - JwtToken jwt = new JwtToken(jwtClaims); - - jwtClaims.setIssuer(issuer); - - jwtClaims.setIssuedAt(nowAsMillis); - - jwtClaims.setSubject(subject); - - jwtClaims.setAudience(audience); - - jwtClaims.setNotBefore(nowAsMillis); - - if (expirySeconds > maxExpirySeconds) { - throw new Exception("The provided expiration time exceeds the maximum allowed duration of " + maxExpirySeconds + " seconds"); + ) throws JOSEException, ParseException { + final Date now = new Date(timeProvider.getAsLong()); + + final JWTClaimsSet.Builder claimsBuilder = new JWTClaimsSet.Builder(); + claimsBuilder.issuer(issuer); + claimsBuilder.issueTime(now); + claimsBuilder.subject(subject); + claimsBuilder.audience(audience); + claimsBuilder.notBeforeTime(now); + + if (expirySeconds > MAX_EXPIRY_SECONDS) { + throw new IllegalArgumentException( + "The provided expiration time exceeds the maximum allowed duration of " + MAX_EXPIRY_SECONDS + " seconds" + ); } - expirySeconds = (expirySeconds == null) ? defaultExpirySeconds : Math.min(expirySeconds, maxExpirySeconds); + expirySeconds = (expirySeconds == null) ? DEFAULT_EXPIRY_SECONDS : Math.min(expirySeconds, MAX_EXPIRY_SECONDS); if (expirySeconds <= 0) { - throw new Exception("The expiration time should be a positive integer"); + throw new IllegalArgumentException("The expiration time should be a positive integer"); } - long expiryTime = timeProvider.getAsLong() + expirySeconds; - jwtClaims.setExpiryTime(expiryTime); + final Date expiryTime = new Date(timeProvider.getAsLong() + expirySeconds * 1000); + claimsBuilder.expirationTime(expiryTime); if (roles != null) { String listOfRoles = String.join(",", roles); - jwtClaims.setProperty("er", encryptionDecryptionUtil.encrypt(listOfRoles)); + claimsBuilder.claim("er", encryptionDecryptionUtil.encrypt(listOfRoles)); } else { - throw new Exception("Roles cannot be null"); + throw new IllegalArgumentException("Roles cannot be null"); } if (!roleSecurityMode && backendRoles != null) { String listOfBackendRoles = String.join(",", backendRoles); - jwtClaims.setProperty("br", listOfBackendRoles); + claimsBuilder.claim("br", listOfBackendRoles); } - String encodedJwt = jwtProducer.processJwt(jwt); + final JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.parse(signingKey.getAlgorithm().getName())).build(); + final SignedJWT signedJwt = new SignedJWT(header, claimsBuilder.build()); + + // Sign the JWT so it can be serialized + signedJwt.sign(signer); if (logger.isDebugEnabled()) { logger.debug( - "Created JWT: " - + encodedJwt - + "\n" - + jsonMapReaderWriter.toJson(jwt.getJwsHeaders()) - + "\n" - + JwtUtils.claimsToJson(jwt.getClaims()) + "Created JWT: " + signedJwt.serialize() + "\n" + signedJwt.getHeader().toJSONObject() + "\n" + signedJwt.getJWTClaimsSet() ); } - return encodedJwt; + return signedJwt.serialize(); } } diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtil.java b/src/main/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtil.java new file mode 100644 index 0000000000..41bf2955f2 --- /dev/null +++ b/src/main/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtil.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.authtoken.jwt; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.util.ByteUtils; +import org.apache.commons.lang3.StringUtils; + +import static com.nimbusds.jose.crypto.MACSigner.getMinRequiredSecretLength; + +public class KeyPaddingUtil { + public static String padSecret(String signingKey, JWSAlgorithm jwsAlgorithm) { + int requiredSecretLength; + try { + requiredSecretLength = getMinRequiredSecretLength(jwsAlgorithm); + } catch (JOSEException e) { + throw new RuntimeException(e); + } + int requiredByteLength = ByteUtils.byteLength(requiredSecretLength); + // padding the signing key with 0s to meet the minimum required length + return StringUtils.rightPad(signingKey, requiredByteLength, "\0"); + } +} diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java index 4a28c0a752..225b93dbbb 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java @@ -67,7 +67,7 @@ public void testEmptyKey() throws Exception { } @Test - public void testBadKey() throws Exception { + public void testBadKey() { final AuthCredentials credentials = extractCredentialsFromJwtHeader( Settings.builder().put("signing_key", BaseEncoding.base64().encode(new byte[] { 1, 3, 3, 4, 3, 6, 7, 8, 3, 10 })), @@ -78,7 +78,7 @@ public void testBadKey() throws Exception { } @Test - public void testTokenMissing() throws Exception { + public void testTokenMissing() { Settings settings = Settings.builder().put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)).build(); @@ -111,6 +111,79 @@ public void testInvalid() throws Exception { Assert.assertNull(credentials); } + /** Here is the original encoded jwt token generation with cxf library: + * + * String base64EncodedSecret = Base64.getEncoder().encodeToString(someSecret.getBytes(StandardCharsets.UTF_8)); + * JwtClaims claims = new JwtClaims(); + * claims.setNotBefore(854113533); + * claim.setExpiration(4853843133) + * claims.setSubject("horst"); + * claims.setProperty("saml_nif", "u"); + * claims.setProperty("saml_si", "MOCKSAML_3"); + * JwsSignatureProvider jwsSignatureProvider = new HmacJwsSignatureProvider(base64EncodedSecret, SignatureAlgorithm.HS512); + * JweEncryptionProvider jweEncryptionProvider = null; + * JoseJwtProducer producer = new JoseJwtProducer(); + * String encodedCxfJwt = producer.processJwt(jwtToken, jweEncryptionProvider, jwsSignatureProvider); + */ + @Test + public void testParsePrevGeneratedJwt() { + String encodedCxfJwt = + "eyJhbGciOiJIUzUxMiJ9.eyJzdWIiOiJob3JzdCIsIm5iZiI6ODU0MTEzNTMzLCJzYW1sX25pZiI6InUiLCJleHAiOjQ4NTM4NDMxMzMsInNhbWxfc2kiOiJNT0NLU0FNTF8zIn0.MQ9lidZ774EPHjDNB43O4d2Q1SGtG4-lASoLXDPdtE0qJGvZOYDUCN3h2HxBIX5NmwXQQvjJ2PUzN6f6FgY0Iw"; + Settings settings = Settings.builder() + .put( + "signing_key", + BaseEncoding.base64() + .encode( + "thisIsSecretThatIsVeryHardToCrackItsPracticallyImpossibleToDothisIsSecretThatIsVeryHardToCrackItsPracticallyImpossibleToDo" + .getBytes(StandardCharsets.UTF_8) + ) + ) + .build(); + + HTTPJwtAuthenticator jwtAuth = new HTTPJwtAuthenticator(settings, null); + Map headers = new HashMap(); + headers.put("Authorization", "Bearer " + encodedCxfJwt); + + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(credentials); + Assert.assertEquals("horst", credentials.getUsername()); + Assert.assertEquals(0, credentials.getBackendRoles().size()); + Assert.assertEquals(5, credentials.getAttributes().size()); + Assert.assertEquals("854113533", credentials.getAttributes().get("attr.jwt.nbf")); + Assert.assertEquals("4853843133", credentials.getAttributes().get("attr.jwt.exp")); + } + + @Test + public void testFailToParsePrevGeneratedJwt() { + String jwsToken = + "eyJhbGciOiJIUzUxMiJ9.eyJuYmYiOjE2OTgxNTE4ODQsImV4cCI6MTY5ODE1NTQ4NCwic3ViIjoiaG9yc3QiLCJzYW1sX25pZiI6InUiLCJzYW1sX3NpIjoiTU9DS1NBTUxfMyIsInJvbGVzIjpudWxsfQ.E_MP8wVVu1P7_RATtjhnCvPft2gQTFdY5NlmRTCsrjdDXTUfxkswktWCB_k_wXDKCuNukNlSL2FSo3EV2VtUEQ"; + Settings settings = Settings.builder() + .put( + "signing_key", + BaseEncoding.base64() + .encode( + "additionalDatathisIsSecretThatIsVeryHardToCrackItsPracticallyImpossibleToDothisIsSecretThatIsVeryHardToCrackItsPracticallyImpossibleToDo" + .getBytes(StandardCharsets.UTF_8) + ) + ) + .build(); + + HTTPJwtAuthenticator jwtAuth = new HTTPJwtAuthenticator(settings, null); + Map headers = new HashMap(); + headers.put("Authorization", "Bearer " + jwsToken); + + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + + Assert.assertNull(credentials); + } + @Test public void testBearer() throws Exception { diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index d483a2ec81..a31e30db39 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -11,6 +11,7 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; import java.util.HashMap; +import java.util.List; import com.google.common.collect.ImmutableMap; import org.junit.AfterClass; @@ -52,14 +53,13 @@ public void basicTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) - .asSecurityRequest(), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), null ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } @@ -81,7 +81,7 @@ public void jwksUriTest() { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } @@ -187,7 +187,7 @@ public void testEscapeKid() { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } @@ -210,13 +210,13 @@ public void bearerTest() { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } @Test - public void testRoles() throws Exception { + public void testRoles() { Settings settings = Settings.builder() .put("openid_connect_url", mockIdpServer.getDiscoverUri()) .put("roles_key", TestJwts.ROLES_CLAIM) @@ -238,16 +238,14 @@ public void testRoles() throws Exception { } @Test - public void testExp() throws Exception { + public void testExp() { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest( - ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_EXPIRED_SIGNED_OCT_1), - new HashMap() - ).asSecurityRequest(), + new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_EXPIRED_SIGNED_OCT_1), new HashMap<>()) + .asSecurityRequest(), null ); @@ -255,7 +253,7 @@ public void testExp() throws Exception { } @Test - public void testExpInSkew() throws Exception { + public void testExpInSkew() { Settings settings = Settings.builder() .put("openid_connect_url", mockIdpServer.getDiscoverUri()) .put("jwt_clock_skew_tolerance_seconds", "10") @@ -271,7 +269,7 @@ public void testExpInSkew() throws Exception { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap() + new HashMap<>() ).asSecurityRequest(), null ); @@ -280,7 +278,7 @@ public void testExpInSkew() throws Exception { } @Test - public void testNbf() throws Exception { + public void testNbf() { Settings settings = Settings.builder() .put("openid_connect_url", mockIdpServer.getDiscoverUri()) .put("jwt_clock_skew_tolerance_seconds", "0") @@ -296,7 +294,7 @@ public void testNbf() throws Exception { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap() + new HashMap<>() ).asSecurityRequest(), null ); @@ -305,7 +303,7 @@ public void testNbf() throws Exception { } @Test - public void testNbfInSkew() throws Exception { + public void testNbfInSkew() { Settings settings = Settings.builder() .put("openid_connect_url", mockIdpServer.getDiscoverUri()) .put("jwt_clock_skew_tolerance_seconds", "10") @@ -321,7 +319,7 @@ public void testNbfInSkew() throws Exception { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap() + new HashMap<>() ).asSecurityRequest(), null ); @@ -330,7 +328,7 @@ public void testNbfInSkew() throws Exception { } @Test - public void testRS256() throws Exception { + public void testRS256() { Settings settings = Settings.builder() .put("openid_connect_url", mockIdpServer.getDiscoverUri()) @@ -341,28 +339,26 @@ public void testRS256() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) - .asSecurityRequest(), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap<>()).asSecurityRequest(), null ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } @Test - public void testBadSignature() throws Exception { + public void testBadSignature() { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap()) - .asSecurityRequest(), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap<>()).asSecurityRequest(), null ); @@ -380,16 +376,14 @@ public void testPeculiarJsonEscaping() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest( - ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), - new HashMap() - ).asSecurityRequest(), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), new HashMap<>()) + .asSecurityRequest(), null ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java index 68f852da5c..20c71b0340 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java @@ -24,7 +24,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.TrustManagerFactory; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; +import com.nimbusds.jose.jwk.JWKSet; import org.apache.hc.core5.function.Callback; import org.apache.hc.core5.http.ClassicHttpRequest; import org.apache.hc.core5.http.ClassicHttpResponse; @@ -42,8 +42,6 @@ import org.opensearch.security.test.helper.file.FileHelper; import org.opensearch.security.test.helper.network.SocketUtils; -import static com.amazon.dlic.auth.http.jwt.keybyoidc.CxfTestTools.toJson; - class MockIpdServer implements Closeable { final static String CTX_DISCOVER = "/discover"; final static String CTX_KEYS = "/api/oauth/keys"; @@ -52,13 +50,13 @@ class MockIpdServer implements Closeable { private final int port; private final String uri; private final boolean ssl; - private final JsonWebKeys jwks; + private final JWKSet jwks; - MockIpdServer(JsonWebKeys jwks) throws IOException { + MockIpdServer(JWKSet jwks) throws IOException { this(jwks, SocketUtils.findAvailableTcpPort(), false); } - MockIpdServer(JsonWebKeys jwks, int port, boolean ssl) throws IOException { + MockIpdServer(JWKSet jwks, int port, boolean ssl) throws IOException { this.port = port; this.uri = (ssl ? "https" : "http") + "://localhost:" + port; this.ssl = ssl; @@ -143,7 +141,7 @@ protected void handleDiscoverRequest(HttpRequest request, ClassicHttpResponse re protected void handleKeysRequest(HttpRequest request, ClassicHttpResponse response, HttpContext context) throws HttpException, IOException { response.setCode(200); - response.setEntity(new StringEntity(toJson(jwks))); + response.setEntity(new StringEntity(jwks.toString(false))); } private SSLContext createSSLContext() { diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySetTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySetTest.java index 6bbce7d85d..ba7f65b7ee 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySetTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SelfRefreshingKeySetTest.java @@ -15,8 +15,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.OctetSequenceKey; import org.junit.Assert; import org.junit.Test; @@ -26,12 +27,12 @@ public class SelfRefreshingKeySetTest { public void basicTest() throws AuthenticatorUnavailableException, BadCredentialsException { SelfRefreshingKeySet selfRefreshingKeySet = new SelfRefreshingKeySet(new MockKeySetProvider()); - JsonWebKey key1 = selfRefreshingKeySet.getKey("kid/a"); - Assert.assertEquals(TestJwk.OCT_1_K, key1.getProperty("k")); + OctetSequenceKey key1 = (OctetSequenceKey) selfRefreshingKeySet.getKey("kid/a"); + Assert.assertEquals(TestJwk.OCT_1_K, key1.getKeyValue().decodeToString()); Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount()); - JsonWebKey key2 = selfRefreshingKeySet.getKey("kid/b"); - Assert.assertEquals(TestJwk.OCT_2_K, key2.getProperty("k")); + OctetSequenceKey key2 = (OctetSequenceKey) selfRefreshingKeySet.getKey("kid/b"); + Assert.assertEquals(TestJwk.OCT_2_K, key2.getKeyValue().decodeToString()); Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount()); try { @@ -51,11 +52,11 @@ public void twoThreadedTest() throws Exception { ExecutorService executorService = Executors.newCachedThreadPool(); - Future f1 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/a")); + Future f1 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/a")); provider.waitForCalled(); - Future f2 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/b")); + Future f2 = executorService.submit(() -> selfRefreshingKeySet.getKey("kid/b")); while (selfRefreshingKeySet.getQueuedGetCount() == 0) { Thread.sleep(10); @@ -63,8 +64,8 @@ public void twoThreadedTest() throws Exception { provider.unblock(); - Assert.assertEquals(TestJwk.OCT_1_K, f1.get().getProperty("k")); - Assert.assertEquals(TestJwk.OCT_2_K, f2.get().getProperty("k")); + Assert.assertEquals(TestJwk.OCT_1_K, ((OctetSequenceKey) f1.get()).getKeyValue().decodeToString()); + Assert.assertEquals(TestJwk.OCT_2_K, ((OctetSequenceKey) f2.get()).getKeyValue().decodeToString()); Assert.assertEquals(1, selfRefreshingKeySet.getRefreshCount()); Assert.assertEquals(1, selfRefreshingKeySet.getQueuedGetCount()); @@ -74,7 +75,7 @@ public void twoThreadedTest() throws Exception { static class MockKeySetProvider implements KeySetProvider { @Override - public JsonWebKeys get() throws AuthenticatorUnavailableException { + public JWKSet get() throws AuthenticatorUnavailableException { return TestJwk.OCT_1_2_3; } @@ -85,7 +86,7 @@ static class BlockingMockKeySetProvider extends MockKeySetProvider { private boolean called = false; @Override - public synchronized JsonWebKeys get() throws AuthenticatorUnavailableException { + public synchronized JWKSet get() throws AuthenticatorUnavailableException { called = true; notifyAll(); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index 6b5c541981..196e91be21 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -12,6 +12,7 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; import java.util.HashMap; +import java.util.List; import com.google.common.collect.ImmutableMap; import org.junit.Assert; @@ -39,7 +40,7 @@ public void basicTest() throws Exception { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); @@ -89,7 +90,7 @@ public void noAlgTest() throws Exception { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); } finally { @@ -139,7 +140,7 @@ public void keyExchangeTest() throws Exception { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); @@ -167,7 +168,7 @@ public void keyExchangeTest() throws Exception { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); @@ -190,7 +191,7 @@ public void keyExchangeTest() throws Exception { Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); - Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); Assert.assertEquals(0, creds.getBackendRoles().size()); Assert.assertEquals(4, creds.getAttributes().size()); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwk.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwk.java index 5b0d5738a3..390f5b16f7 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwk.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwk.java @@ -11,12 +11,16 @@ package com.amazon.dlic.auth.http.jwt.keybyoidc; +import java.nio.charset.StandardCharsets; import java.util.Arrays; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys; -import org.apache.cxf.rs.security.jose.jwk.KeyType; -import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.util.Base64URL; class TestJwk { @@ -29,13 +33,13 @@ class TestJwk { static final String OCT_3_K = "r3aeW3OK7-B4Hs3hq9BmlT1D3jRiolH9PL82XUz9xAS7dniAdmvMnN5GkOc1vqibOe2T-CC_103UglDm9D0iU9S9zn6wTuQt1L5wfZIoHd9f5IjJ_YFEzZMvsoUY_-ji_0K_ugVvBPwi9JnBQHHS4zrgmP06dGjmcnZDcIf4W_iFas3lDYSXilL1V2QhNaynpSqTarpfBGSphKv4Zg2JhsX8xB0VSaTlEq4lF8pzvpWSxXCW9CtomhB80daSuTizrmSTEPpdN3XzQ2-Tovo1ieMOfDU4csvjEk7Bwc2ThjpnA8ucKQUYpUv9joBxKuCdUltssthWnetrogjYOn_xGA"; - static final JsonWebKey OCT_1 = createOct("kid/a", "HS256", OCT_1_K); - static final JsonWebKey OCT_2 = createOct("kid/b", "HS256", OCT_2_K); - static final JsonWebKey OCT_3 = createOct("kid/c", "HS256", OCT_3_K); - static final JsonWebKey ESCAPED_SLASH_KID_OCT_1 = createOct("kid\\/_a", "HS256", OCT_1_K); - static final JsonWebKey FORWARD_SLASH_KID_OCT_1 = createOct("kid/_a", "HS256", OCT_1_K); + static final JWK OCT_1 = createOct("kid/a", "HS256", OCT_1_K); + static final JWK OCT_2 = createOct("kid/b", "HS256", OCT_2_K); + static final JWK OCT_3 = createOct("kid/c", "HS256", OCT_3_K); + static final JWK ESCAPED_SLASH_KID_OCT_1 = createOct("kid\\/_a", "HS256", OCT_1_K); + static final JWK FORWARD_SLASH_KID_OCT_1 = createOct("kid/_a", "HS256", OCT_1_K); - static final JsonWebKeys OCT_1_2_3 = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, ESCAPED_SLASH_KID_OCT_1); + static final JWKSet OCT_1_2_3 = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, ESCAPED_SLASH_KID_OCT_1); static final String RSA_1_D = "On8XGMmdM5Fm5hvuhQk-qAkIP2CoK5QMx0OH5m_WDzKXZv8lZ2eg89I4ehBiOKGdw1h_mjmWwTah-evpXV-BF5QpejPQqxkXS-8s5r2AvietQq32jl-gwIwZWTvfzjpT9On0YJZ4q01tMDj3r-YOLUW2xrz3za9tl6pPU_5kP63C-hoj1ybTwcC7ujbCPwhY6yAopMA1v10uVmCxsjsNikEjB6YePgHixez51wO3Z8mXNwefWukFWYJ5T7t4kHMSf5P_8FJZ14u5yvYZnngE_tJCyHFdIDb6UWsrgxomtlQU-SdZYK_NY6gw6mCkjjlqOoYqlsrRJ16kJ81Ds269oQ"; @@ -55,67 +59,53 @@ class TestJwk { "jDDVUMXOXDVcaRVAT5TtuiAsLxk7XAAwyyECfmySZul7D5XVLMtGe6rP2900q3nM4BaCEiuwXjmTCZDAGlFGs2a3eQ1vbBSv9_0KGHL-gZGFPNiv0v8aR7QzZ-abhGnRy5F52PlTWsypGgG_kQpF2t2TBotvYhvVPagAt4ljllDKvY1siOvS3nh4TqcUtWcbgQZEWPmaXuhx0eLmhQJca7UEw99YlGNew48AEzt7ZnfU0Qkz3JwSz7IcPx-NfIh6BN6LwAg_ASdoM3MR8rDOtLYavmJVhutrfOpE-4-fw1mf3eLYu7xrxIplSiOIsHunTUssnTiBkXAaGqGJs604Pw"; static final String RSA_X_E = "AQAB"; - static final JsonWebKey RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D); - static final JsonWebKey RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N); - static final JsonWebKey RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N); - static final JsonWebKey RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N); + static final JWK RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D); - static final JsonWebKey RSA_2 = createRsa("kid/2", "RS256", RSA_2_E, RSA_2_N, RSA_2_D); - static final JsonWebKey RSA_2_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_2_E, RSA_2_N); + static final JWK RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N); + static final JWK RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N); + static final JWK RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N); - static final JsonWebKey RSA_X = createRsa("kid/2", "RS256", RSA_X_E, RSA_X_N, RSA_X_D); - static final JsonWebKey RSA_X_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_X_E, RSA_X_N); + static final JWK RSA_2 = createRsa("kid/2", "RS256", RSA_2_E, RSA_2_N, RSA_2_D); + static final JWK RSA_2_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_2_E, RSA_2_N); - static final JsonWebKeys RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC); + static final JWK RSA_X = createRsa("kid/2", "RS256", RSA_X_E, RSA_X_N, RSA_X_D); + static final JWK RSA_X_PUBLIC = createRsaPublic("kid/2", "RS256", RSA_X_E, RSA_X_N); + + static final JWKSet RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC); static class Jwks { - static final JsonWebKeys ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC); - static final JsonWebKeys RSA_1 = createJwks(RSA_1_PUBLIC); - static final JsonWebKeys RSA_2 = createJwks(RSA_2_PUBLIC); - static final JsonWebKeys RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG); - static final JsonWebKeys RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG); + static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC); + static final JWKSet RSA_1 = createJwks(RSA_1_PUBLIC); + static final JWKSet RSA_2 = createJwks(RSA_2_PUBLIC); + static final JWKSet RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG); + static final JWKSet RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG); } - private static JsonWebKey createOct(String keyId, String algorithm, String k) { - JsonWebKey result = new JsonWebKey(); - - result.setKeyId(keyId); - result.setKeyType(KeyType.OCTET); - result.setAlgorithm(algorithm); - result.setPublicKeyUse(PublicKeyUse.SIGN); - result.setProperty("k", k); - - return result; + private static JWK createOct(String keyId, String algorithm, String k) { + return new OctetSequenceKey.Builder(k.getBytes(StandardCharsets.UTF_8)).keyID(keyId) + .keyUse(KeyUse.SIGNATURE) + .algorithm(JWSAlgorithm.parse(algorithm)) + .build(); } - private static JsonWebKey createRsa(String keyId, String algorithm, String e, String n, String d) { - JsonWebKey result = new JsonWebKey(); - - result.setKeyId(keyId); - result.setKeyType(KeyType.RSA); - result.setAlgorithm(algorithm); - result.setPublicKeyUse(PublicKeyUse.SIGN); + private static JWK createRsa(String keyId, String algorithm, String e, String n, String d) { + RSAKey.Builder builder = new RSAKey.Builder(Base64URL.from(n), Base64URL.from(e)).keyUse(KeyUse.SIGNATURE) + .algorithm(algorithm == null ? null : JWSAlgorithm.parse(algorithm)) + .keyID(keyId); if (d != null) { - result.setProperty("d", d); + builder.privateExponent(Base64URL.from(d)); } - result.setProperty("e", e); - result.setProperty("n", n); - - return result; + return builder.build(); } - private static JsonWebKey createRsaPublic(String keyId, String algorithm, String e, String n) { + private static JWK createRsaPublic(String keyId, String algorithm, String e, String n) { return createRsa(keyId, algorithm, e, n, null); } - private static JsonWebKeys createJwks(JsonWebKey... array) { - JsonWebKeys result = new JsonWebKeys(); - - result.setKeys(Arrays.asList(array)); - - return result; + private static JWKSet createJwks(JWK... array) { + return new JWKSet(Arrays.asList(array)); } } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java index 292af6c014..9d49596e73 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java @@ -14,16 +14,19 @@ import java.util.Set; import com.google.common.collect.ImmutableSet; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jws.JwsHeaders; -import org.apache.cxf.rs.security.jose.jws.JwsSignatureProvider; -import org.apache.cxf.rs.security.jose.jws.JwsUtils; -import org.apache.cxf.rs.security.jose.jwt.JoseJwtProducer; -import org.apache.cxf.rs.security.jose.jwt.JwtClaims; -import org.apache.cxf.rs.security.jose.jwt.JwtConstants; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; import org.apache.logging.log4j.util.Strings; +import static com.nimbusds.jwt.JWTClaimNames.EXPIRATION_TIME; +import static com.nimbusds.jwt.JWTClaimNames.NOT_BEFORE; + class TestJwts { static final String ROLES_CLAIM = "roles"; static final Set TEST_ROLES = ImmutableSet.of("role1", "role2"); @@ -35,21 +38,21 @@ class TestJwts { static final String TEST_ISSUER = "TestIssuer"; - static final JwtToken MC_COY = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); - static final JwtToken MC_COY_2 = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_2 = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); - static final JwtToken MC_COY_NO_AUDIENCE = create(MCCOY_SUBJECT, null, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_NO_AUDIENCE = create(MCCOY_SUBJECT, null, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); - static final JwtToken MC_COY_NO_ISSUER = create(MCCOY_SUBJECT, TEST_AUDIENCE, null, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_NO_ISSUER = create(MCCOY_SUBJECT, TEST_AUDIENCE, null, ROLES_CLAIM, TEST_ROLES_STRING); - static final JwtToken MC_COY_EXPIRED = create( + static final JWTClaimsSet MC_COY_EXPIRED = create( MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING, - JwtConstants.CLAIM_EXPIRY, + EXPIRATION_TIME, 10 ); @@ -78,73 +81,82 @@ static class PeculiarEscaping { static final String MC_COY_SIGNED_RSA_1 = createSignedWithPeculiarEscaping(MC_COY, TestJwk.RSA_1); } - static JwtToken create(String subject, String audience, String issuer, Object... moreClaims) { - JwtClaims claims = new JwtClaims(); + static JWTClaimsSet create(String subject, String audience, String issuer, Object... moreClaims) { + JWTClaimsSet.Builder claimsBuilder = new JWTClaimsSet.Builder(); - claims.setSubject(subject); + claimsBuilder.subject(subject); if (audience != null) { - claims.setAudience(audience); + claimsBuilder.audience(audience); } if (issuer != null) { - claims.setIssuer(issuer); + claimsBuilder.issuer(issuer); } if (moreClaims != null) { for (int i = 0; i < moreClaims.length; i += 2) { - claims.setClaim(String.valueOf(moreClaims[i]), moreClaims[i + 1]); + claimsBuilder.claim(String.valueOf(moreClaims[i]), moreClaims[i + 1]); } } - JwtToken result = new JwtToken(claims); - - return result; + // JwtToken result = new JwtToken(claimsBuilder); + return claimsBuilder.build(); } - static String createSigned(JwtToken baseJwt, JsonWebKey jwk) { - return createSigned(baseJwt, jwk, JwsUtils.getSignatureProvider(jwk)); - } - - static String createSigned(JwtToken baseJwt, JsonWebKey jwk, JwsSignatureProvider signatureProvider) { - JwsHeaders jwsHeaders = new JwsHeaders(); - JwtToken signedToken = new JwtToken(jwsHeaders, baseJwt.getClaims()); - - jwsHeaders.setKeyId(jwk.getKeyId()); + static String createSigned(JWTClaimsSet jwtClaimsSet, JWK jwk) { + JWSHeader jwsHeader = new JWSHeader.Builder(new JWSAlgorithm(jwk.getAlgorithm().getName())).keyID(jwk.getKeyID()).build(); + SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet); + try { + JWSSigner signer = new DefaultJWSSignerFactory().createJWSSigner(jwk); + signedJWT.sign(signer); + } catch (JOSEException e) { + throw new RuntimeException(e); + } - return new JoseJwtProducer().processJwt(signedToken, null, signatureProvider); + return signedJWT.serialize(); } - static String createSignedWithoutKeyId(JwtToken baseJwt, JsonWebKey jwk) { - JwsHeaders jwsHeaders = new JwsHeaders(); - JwtToken signedToken = new JwtToken(jwsHeaders, baseJwt.getClaims()); + static String createSignedWithoutKeyId(JWTClaimsSet jwtClaimsSet, JWK jwk) { + JWSHeader jwsHeader = new JWSHeader.Builder(new JWSAlgorithm(jwk.getAlgorithm().getName())).keyID(jwk.getKeyID()).build(); + SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet); + try { + JWSSigner signer = new DefaultJWSSignerFactory().createJWSSigner(jwk); + signedJWT.sign(signer); + } catch (JOSEException e) { + throw new RuntimeException(e); + } - return new JoseJwtProducer().processJwt(signedToken, null, JwsUtils.getSignatureProvider(jwk)); + return signedJWT.serialize(); } - static String createSignedWithPeculiarEscaping(JwtToken baseJwt, JsonWebKey jwk) { - JwsSignatureProvider signatureProvider = JwsUtils.getSignatureProvider(jwk); - JwsHeaders jwsHeaders = new JwsHeaders(); - JwtToken signedToken = new JwtToken(jwsHeaders, baseJwt.getClaims()); - - // Depends on CXF not escaping the input string. This may fail for other frameworks or versions. - jwsHeaders.setKeyId(jwk.getKeyId().replace("/", "\\/")); + static String createSignedWithPeculiarEscaping(JWTClaimsSet jwtClaimsSet, JWK jwk) { + JWSHeader jwsHeader = new JWSHeader.Builder(new JWSAlgorithm(jwk.getAlgorithm().getName())).keyID( + jwk.getKeyID().replace("/", "\\/") + ).build(); + SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet); + try { + JWSSigner signer = new DefaultJWSSignerFactory().createJWSSigner(jwk); + signedJWT.sign(signer); + } catch (JOSEException e) { + throw new RuntimeException(e); + } - return new JoseJwtProducer().processJwt(signedToken, null, signatureProvider); + return signedJWT.serialize(); } static String createMcCoySignedOct1(long nbf, long exp) { - JwtToken jwt_token = create( + JWTClaimsSet jwtClaimsSet = create( MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING, - JwtConstants.CLAIM_NOT_BEFORE, + NOT_BEFORE, nbf, - JwtConstants.CLAIM_EXPIRY, + EXPIRATION_TIME, exp ); - return createSigned(jwt_token, TestJwk.OCT_1); + return createSigned(jwtClaimsSet, TestJwk.OCT_1); } } diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index 17a2148fa5..b475a42e2c 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -33,8 +33,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.collect.ImmutableMap; -import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import com.nimbusds.jwt.SignedJWT; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; @@ -121,6 +120,46 @@ public void tearDown() { } } + @Test + public void testRawHMACSettings() throws Exception { + mockSamlIdpServer.setSignResponses(true); + mockSamlIdpServer.loadSigningKeys("saml/kirk-keystore.jks", "kirk"); + mockSamlIdpServer.setAuthenticateUser("horst"); + mockSamlIdpServer.setEndpointQueryString(null); + + Settings settings = Settings.builder() + .put(IDP_METADATA_URL, mockSamlIdpServer.getMetadataUri()) + .put("kibana_url", "http://wherever") + .put("idp.entity_id", mockSamlIdpServer.getIdpEntityId()) + .put("roles_key", "roles") + .put("jwt.key.kty", "oct") + .put("jwt.key.k", "abc") + .put("path.home", ".") + .build(); + + HTTPSamlAuthenticator samlAuthenticator = new HTTPSamlAuthenticator(settings, null); + + AuthenticateHeaders authenticateHeaders = getAutenticateHeaders(samlAuthenticator); + + String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); + + RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); + + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); + String authorization = (String) response.get("authorization"); + + Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); + + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); + } + @Test public void basicTest() throws Exception { mockSamlIdpServer.setSignResponses(true); @@ -155,10 +194,9 @@ public void basicTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } private Optional sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) { @@ -209,10 +247,9 @@ public void decryptAssertionsTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } @Test @@ -253,13 +290,12 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + Assert.assertEquals("ABC\\User1", jwt.getJWTClaimsSet().getClaim("sub")); + Assert.assertEquals("ABC\\User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getJWTClaimsSet())); + Assert.assertEquals("[ABC\\Admin]", String.valueOf(jwt.getJWTClaimsSet().getClaim("roles"))); + Assert.assertEquals("ABC\\Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getJWTClaimsSet())[0]); - Assert.assertEquals("ABC\\User1", jwt.getClaim("sub")); - Assert.assertEquals("ABC\\User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getClaims())); - Assert.assertEquals("[ABC\\Admin]", String.valueOf(jwt.getClaim("roles"))); - Assert.assertEquals("ABC\\Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getClaims())[0]); } @Test @@ -300,13 +336,11 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); - - Assert.assertEquals("ABC\"User1", jwt.getClaim("sub")); - Assert.assertEquals("ABC\"User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getClaims())); - Assert.assertEquals("[ABC\"Admin]", String.valueOf(jwt.getClaim("roles"))); - Assert.assertEquals("ABC\"Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getClaims())[0]); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + Assert.assertEquals("ABC\"User1", jwt.getJWTClaimsSet().getClaim("sub")); + Assert.assertEquals("ABC\"User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getJWTClaimsSet())); + Assert.assertEquals("[ABC\"Admin]", String.valueOf(jwt.getJWTClaimsSet().getClaim("roles"))); + Assert.assertEquals("ABC\"Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getJWTClaimsSet())[0]); } @Test @@ -347,13 +381,11 @@ public void shouldNotEscapeSamlEntities() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); - - Assert.assertEquals("ABC/User1", jwt.getClaim("sub")); - Assert.assertEquals("ABC/User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getClaims())); - Assert.assertEquals("[ABC/Admin]", String.valueOf(jwt.getClaim("roles"))); - Assert.assertEquals("ABC/Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getClaims())[0]); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + Assert.assertEquals("ABC/User1", jwt.getJWTClaimsSet().getClaim("sub")); + Assert.assertEquals("ABC/User1", samlAuthenticator.httpJwtAuthenticator.extractSubject(jwt.getJWTClaimsSet())); + Assert.assertEquals("[ABC/Admin]", String.valueOf(jwt.getJWTClaimsSet().getClaim("roles"))); + Assert.assertEquals("ABC/Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getJWTClaimsSet())[0]); } @Test @@ -394,10 +426,9 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + Assert.assertEquals("ABC/Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getJWTClaimsSet())[0]); - Assert.assertEquals("ABC/Admin", samlAuthenticator.httpJwtAuthenticator.extractRoles(jwt.getClaims())[0]); } @Test @@ -437,10 +468,9 @@ public void testMetadataBody() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } @Test(expected = RuntimeException.class) @@ -498,10 +528,9 @@ public void unsolicitedSsoTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } @Test @@ -633,13 +662,12 @@ public void rolesTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); Assert.assertArrayEquals( new String[] { "a ", "c", "b ", "d", " e", "f", "g", "h", " ", "i" }, - ((List) jwt.getClaim("roles")).toArray(new String[0]) + ((List) jwt.getJWTClaimsSet().getClaim("roles")).toArray(new String[0]) ); } @@ -676,10 +704,9 @@ public void idpEndpointWithQueryStringTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } @Test @@ -726,11 +753,13 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); - Assert.assertEquals("horst", jwt.getClaim("sub")); - Assert.assertArrayEquals(new String[] { "a", "b" }, ((List) jwt.getClaim("roles")).toArray(new String[0])); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); + Assert.assertArrayEquals( + new String[] { "a", "b" }, + ((List) jwt.getJWTClaimsSet().getClaim("roles")).toArray(new String[0]) + ); } @Test @@ -844,10 +873,8 @@ public void initialConnectionFailureTest() throws Exception { Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(authorization.replaceAll("\\s*bearer\\s*", "")); - JwtToken jwt = jwtConsumer.getJwtToken(); - - Assert.assertEquals("horst", jwt.getClaim("sub")); + SignedJWT jwt = SignedJWT.parse(authorization.replaceAll("\\s*bearer\\s*", "")); + Assert.assertEquals("horst", jwt.getJWTClaimsSet().getClaim("sub")); } } diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java index 03cbd20b42..e271c7b838 100644 --- a/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java +++ b/src/test/java/org/opensearch/security/authtoken/jwt/JwtVendorTest.java @@ -11,27 +11,40 @@ package org.opensearch.security.authtoken.jwt; +import java.nio.charset.StandardCharsets; +import java.util.Date; +import java.util.List; +import java.util.Optional; +import java.util.function.LongSupplier; + +import com.google.common.io.BaseEncoding; +import com.nimbusds.jwt.SignedJWT; import org.apache.commons.lang3.RandomStringUtils; -import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; -import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; -import org.apache.cxf.rs.security.jose.jwt.JwtToken; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Appender; import org.apache.logging.log4j.core.LogEvent; import org.apache.logging.log4j.core.Logger; +import org.junit.Assert; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.opensearch.OpenSearchException; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.Settings; import org.opensearch.security.support.ConfigConstants; -import java.util.List; -import java.util.Optional; -import java.util.function.LongSupplier; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.jwk.JWK; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +import static org.hamcrest.core.IsNull.notNullValue; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -43,62 +56,34 @@ public class JwtVendorTest { private Appender mockAppender; private ArgumentCaptor logEventCaptor; - @Test - public void testCreateJwkFromSettingsThrowsException() { - Settings faultySettings = Settings.builder().put("key.someProperty", "badValue").build(); - - Exception thrownException = assertThrows(Exception.class, () -> new JwtVendor(faultySettings, null)); - - String expectedMessagePart = "An error occurred during the creation of Jwk: "; - assertTrue(thrownException.getMessage().contains(expectedMessagePart)); - } - - @Test - public void testJsonWebKeyPropertiesSetFromJwkSettings() throws Exception { - Settings settings = Settings.builder().put("jwt.key.key1", "value1").put("jwt.key.key2", "value2").build(); - - JsonWebKey jwk = JwtVendor.createJwkFromSettings(settings); - - assertEquals("value1", jwk.getProperty("key1")); - assertEquals("value2", jwk.getProperty("key2")); - } + final static String signingKey = + "This is my super safe signing key that no one will ever be able to guess. It's would take billions of years and the world's most powerful quantum computer to crack"; + final static String signingKeyB64Encoded = BaseEncoding.base64().encode(signingKey.getBytes(StandardCharsets.UTF_8)); @Test - public void testJsonWebKeyPropertiesSetFromSettings() { - Settings jwkSettings = Settings.builder().put("key1", "value1").put("key2", "value2").build(); + public void testCreateJwkFromSettings() { + final Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).build(); - JsonWebKey jwk = new JsonWebKey(); - for (String key : jwkSettings.keySet()) { - jwk.setProperty(key, jwkSettings.get(key)); - } - - assertEquals("value1", jwk.getProperty("key1")); - assertEquals("value2", jwk.getProperty("key2")); + final Tuple jwk = JwtVendor.createJwkFromSettings(settings); + Assert.assertEquals("HS512", jwk.v1().getAlgorithm().getName()); + Assert.assertEquals("sig", jwk.v1().getKeyUse().toString()); + Assert.assertTrue(jwk.v1().toOctetSequenceKey().getKeyValue().decodeToString().startsWith(signingKey)); } @Test - public void testCreateJwkFromSettings() throws Exception { - Settings settings = Settings.builder().put("signing_key", "abc123").build(); - - JsonWebKey jwk = JwtVendor.createJwkFromSettings(settings); - assertEquals("HS512", jwk.getAlgorithm()); - assertEquals("sig", jwk.getPublicKeyUse().toString()); - assertEquals("abc123", jwk.getProperty("k")); + public void testCreateJwkFromSettingsWithWeakKey() { + Settings settings = Settings.builder().put("signing_key", "abcd1234").build(); + Throwable exception = Assert.assertThrows(OpenSearchException.class, () -> JwtVendor.createJwkFromSettings(settings)); + assertThat(exception.getMessage(), containsString("The secret length must be at least 256 bits")); } @Test public void testCreateJwkFromSettingsWithoutSigningKey() { Settings settings = Settings.builder().put("jwt", "").build(); - Throwable exception = assertThrows(RuntimeException.class, () -> { - try { - JwtVendor.createJwkFromSettings(settings); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - assertEquals( - "java.lang.Exception: Settings for signing key is missing. Please specify at least the option signing_key with a shared secret.", - exception.getMessage() + Throwable exception = Assert.assertThrows(RuntimeException.class, () -> JwtVendor.createJwkFromSettings(settings)); + assertThat( + exception.getMessage(), + equalTo("Settings for signing key is missing. Please specify at least the option signing_key with a shared secret.") ); } @@ -111,26 +96,26 @@ public void testCreateJwtWithRoles() throws Exception { List backendRoles = List.of("Sales", "Support"); String expectedRoles = "IT,HR"; int expirySeconds = 300; - LongSupplier currentTime = () -> (long) 100; - String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); - Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); - Long expectedExp = currentTime.getAsLong() + expirySeconds; + // 2023 oct 4, 10:00:00 AM GMT + LongSupplier currentTime = () -> 1696413600000L; + String claimsEncryptionKey = "1234567890123456"; + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).put("encryption_key", claimsEncryptionKey).build(); JwtVendor jwtVendor = new JwtVendor(settings, Optional.of(currentTime)); - String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, backendRoles, true); + final String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, backendRoles, true); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt); - JwtToken jwt = jwtConsumer.getJwtToken(); + SignedJWT signedJWT = SignedJWT.parse(encodedJwt); - assertEquals("cluster_0", jwt.getClaim("iss")); - assertEquals("admin", jwt.getClaim("sub")); - assertEquals("audience_0", jwt.getClaim("aud")); - assertNotNull(jwt.getClaim("iat")); - assertNotNull(jwt.getClaim("exp")); - assertEquals(expectedExp, jwt.getClaim("exp")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("iss"), equalTo("cluster_0")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("sub"), equalTo("admin")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("aud").toString(), equalTo("[audience_0]")); + // 2023 oct 4, 10:00:00 AM GMT + assertThat(((Date) signedJWT.getJWTClaimsSet().getClaims().get("iat")).getTime(), is(1696413600000L)); + // 2023 oct 4, 10:05:00 AM GMT + assertThat(((Date) signedJWT.getJWTClaimsSet().getClaims().get("exp")).getTime(), is(1696413900000L)); EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); - assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString())); - assertNull(jwt.getClaim("br")); + assertThat(encryptionUtil.decrypt(signedJWT.getJWTClaimsSet().getClaims().get("er").toString()), equalTo(expectedRoles)); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("br"), nullValue()); } @Test @@ -145,32 +130,29 @@ public void testCreateJwtWithRoleSecurityMode() throws Exception { int expirySeconds = 300; LongSupplier currentTime = () -> (long) 100; - String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); + String claimsEncryptionKey = "1234567890123456"; Settings settings = Settings.builder() - .put("signing_key", "abc123") + .put("signing_key", signingKeyB64Encoded) .put("encryption_key", claimsEncryptionKey) // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings - .put(ConfigConstants.EXTENSIONS_BWC_PLUGIN_MODE, "true") + .put(ConfigConstants.EXTENSIONS_BWC_PLUGIN_MODE, true) // CS-ENFORCE-SINGLE .build(); - Long expectedExp = currentTime.getAsLong() + expirySeconds; + final JwtVendor jwtVendor = new JwtVendor(settings, Optional.of(currentTime)); + final String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, backendRoles, false); - JwtVendor jwtVendor = new JwtVendor(settings, Optional.of(currentTime)); - String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, backendRoles, false); + SignedJWT signedJWT = SignedJWT.parse(encodedJwt); - JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt); - JwtToken jwt = jwtConsumer.getJwtToken(); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("iss"), equalTo("cluster_0")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("sub"), equalTo("admin")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("aud").toString(), equalTo("[audience_0]")); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("iat"), is(notNullValue())); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("exp"), is(notNullValue())); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("br"), is(notNullValue())); + assertThat(signedJWT.getJWTClaimsSet().getClaims().get("br").toString(), equalTo(expectedBackendRoles)); - assertEquals("cluster_0", jwt.getClaim("iss")); - assertEquals("admin", jwt.getClaim("sub")); - assertEquals("audience_0", jwt.getClaim("aud")); - assertNotNull(jwt.getClaim("iat")); - assertNotNull(jwt.getClaim("exp")); - assertEquals(expectedExp, jwt.getClaim("exp")); EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey); - assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString())); - assertNotNull(jwt.getClaim("br")); - assertEquals(expectedBackendRoles, jwt.getClaim("br")); + assertThat(encryptionUtil.decrypt(signedJWT.getJWTClaimsSet().getClaims().get("er").toString()), equalTo(expectedRoles)); } @Test @@ -181,7 +163,7 @@ public void testCreateJwtWithNegativeExpiry() { List roles = List.of("admin"); Integer expirySeconds = -300; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); - Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).put("encryption_key", claimsEncryptionKey).build(); JwtVendor jwtVendor = new JwtVendor(settings, Optional.empty()); Throwable exception = assertThrows(RuntimeException.class, () -> { @@ -191,7 +173,7 @@ public void testCreateJwtWithNegativeExpiry() { throw new RuntimeException(e); } }); - assertEquals("java.lang.Exception: The expiration time should be a positive integer", exception.getMessage()); + assertEquals("java.lang.IllegalArgumentException: The expiration time should be a positive integer", exception.getMessage()); } @Test @@ -204,7 +186,7 @@ public void testCreateJwtWithExceededExpiry() throws Exception { int expirySeconds = 900; LongSupplier currentTime = () -> (long) 100; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); - Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).put("encryption_key", claimsEncryptionKey).build(); JwtVendor jwtVendor = new JwtVendor(settings, Optional.of(currentTime)); Throwable exception = assertThrows(RuntimeException.class, () -> { @@ -215,7 +197,7 @@ public void testCreateJwtWithExceededExpiry() throws Exception { } }); assertEquals( - "java.lang.Exception: The provided expiration time exceeds the maximum allowed duration of 600 seconds", + "java.lang.IllegalArgumentException: The provided expiration time exceeds the maximum allowed duration of 600 seconds", exception.getMessage() ); } @@ -228,7 +210,7 @@ public void testCreateJwtWithBadEncryptionKey() { List roles = List.of("admin"); Integer expirySeconds = 300; - Settings settings = Settings.builder().put("signing_key", "abc123").build(); + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).build(); Throwable exception = assertThrows(RuntimeException.class, () -> { try { @@ -248,7 +230,7 @@ public void testCreateJwtWithBadRoles() { List roles = null; Integer expirySeconds = 300; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); - Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).put("encryption_key", claimsEncryptionKey).build(); JwtVendor jwtVendor = new JwtVendor(settings, Optional.empty()); Throwable exception = assertThrows(RuntimeException.class, () -> { @@ -258,7 +240,7 @@ public void testCreateJwtWithBadRoles() { throw new RuntimeException(e); } }); - assertEquals("java.lang.Exception: Roles cannot be null", exception.getMessage()); + assertEquals("java.lang.IllegalArgumentException: Roles cannot be null", exception.getMessage()); } @Test @@ -274,7 +256,7 @@ public void testCreateJwtLogsCorrectly() throws Exception { // Mock settings and other required dependencies LongSupplier currentTime = () -> (long) 100; String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16); - Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build(); + Settings settings = Settings.builder().put("signing_key", signingKeyB64Encoded).put("encryption_key", claimsEncryptionKey).build(); String issuer = "cluster_0"; String subject = "admin"; diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtilTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtilTest.java new file mode 100644 index 0000000000..78bd950964 --- /dev/null +++ b/src/test/java/org/opensearch/security/authtoken/jwt/KeyPaddingUtilTest.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.authtoken.jwt; + +import com.nimbusds.jose.JWSAlgorithm; +import org.junit.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class KeyPaddingUtilTest { + + private String signingKey = "testKey"; + + @Test + public void testPadSecretForHS256() { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + String paddedKey = KeyPaddingUtil.padSecret(signingKey, jwsAlgorithm); + + // For HS256, HMAC using SHA-256, typical key length is 256 bits or 32 bytes + int expectedLength = 32; + assertEquals(expectedLength, paddedKey.length()); + } + + @Test + public void testPadSecretForHS384() { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS384; + String paddedKey = KeyPaddingUtil.padSecret(signingKey, jwsAlgorithm); + + // For HS384, HMAC using SHA-384, typical key length is 384 bits or 48 bytes + int expectedLength = 48; + assertEquals(expectedLength, paddedKey.length()); + } +}