Skip to content

Commit

Permalink
If signature validation fails, reload JWKs and retry if new JWKs are …
Browse files Browse the repository at this point in the history
…found (#88023)

Co-authored-by: Niels Dewulf
  • Loading branch information
justincr-elastic committed Jul 22, 2022
1 parent c3e5daa commit 89e54be
Show file tree
Hide file tree
Showing 13 changed files with 950 additions and 559 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/88023.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 88023
summary: "If signature validation fails, reload JWKs and retry if new JWKs are found"
area: Authentication
type: enhancement
issues: []

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.util.JSONObjectUtils;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
Expand All @@ -33,8 +32,9 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.common.ssl.SslConfiguration;
Expand All @@ -51,6 +51,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
Expand Down Expand Up @@ -185,16 +186,25 @@ public static URI parseHttpsUri(final String uriString) {
return null;
}

public static byte[] readUriContents(
public static void readUriContents(
final String jwkSetConfigKeyPkc,
final URI jwkSetPathPkcUri,
final CloseableHttpAsyncClient httpClient
) throws SettingsException {
try {
return JwtUtil.readBytes(httpClient, jwkSetPathPkcUri);
} catch (Exception e) {
throw new SettingsException("Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].", e);
}
final CloseableHttpAsyncClient httpClient,
final ActionListener<byte[]> listener
) {
JwtUtil.readBytes(
httpClient,
jwkSetPathPkcUri,
ActionListener.wrap(
listener::onResponse,
ex -> listener.onFailure(
new SettingsException(
"Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].",
ex
)
)
)
);
}

public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final String jwkSetPathPkc, final Environment environment)
Expand All @@ -211,7 +221,7 @@ public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final Str
}

public static String serializeJwkSet(final JWKSet jwkSet, final boolean publicKeysOnly) {
if ((jwkSet == null) || (jwkSet.getKeys().isEmpty())) {
if (jwkSet == null) {
return null;
}
return JSONObjectUtils.toJSONString(jwkSet.toJSONObject(publicKeysOnly));
Expand Down Expand Up @@ -262,13 +272,11 @@ public static CloseableHttpAsyncClient createHttpClient(final RealmConfig realmC
}

/**
* Use the HTTP Client to get URL content bytes up to N max bytes.
* Use the HTTP Client to get URL content bytes.
* @param httpClient Configured HTTP/HTTPS client.
* @param uri URI to download.
* @return Byte array of the URI contents up to N max bytes.
*/
public static byte[] readBytes(final CloseableHttpAsyncClient httpClient, final URI uri) {
final PlainActionFuture<byte[]> plainActionFuture = PlainActionFuture.newFuture();
public static void readBytes(final CloseableHttpAsyncClient httpClient, final URI uri, ActionListener<byte[]> listener) {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
httpClient.execute(new HttpGet(uri), new FutureCallback<>() {
@Override
Expand All @@ -278,12 +286,12 @@ public void completed(final HttpResponse result) {
if (statusCode == 200) {
final HttpEntity entity = result.getEntity();
try (InputStream inputStream = entity.getContent()) {
plainActionFuture.onResponse(inputStream.readAllBytes());
listener.onResponse(inputStream.readAllBytes());
} catch (Exception e) {
plainActionFuture.onFailure(e);
listener.onFailure(e);
}
} else {
plainActionFuture.onFailure(
listener.onFailure(
new ElasticsearchSecurityException(
"Get [" + uri + "] failed, status [" + statusCode + "], reason [" + statusLine.getReasonPhrase() + "]."
)
Expand All @@ -293,17 +301,16 @@ public void completed(final HttpResponse result) {

@Override
public void failed(Exception e) {
plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
}

@Override
public void cancelled() {
plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
}
});
return null;
});
return plainActionFuture.actionGet();
}

public static Path resolvePath(final Environment environment, final String jwkSetPath) {
Expand Down Expand Up @@ -335,14 +342,10 @@ public static SecureString join(final CharSequence delimiter, final CharSequence
* JWSHeader: Header are not support.
* JWTClaimsSet: Claims are supported. Claim keys are prefixed by "jwt_claim_".
* Base64URL: Signature is not supported.
* @param jwt SignedJWT object.
* @return Map of formatted and filtered values to be used as user metadata.
* @throws Exception Parse error.
*/
//
// Values will be filtered by type using isAllowedTypeForClaim().
public static Map<String, Object> toUserMetadata(final SignedJWT jwt) throws Exception {
final JWTClaimsSet claimsSet = jwt.getJWTClaimsSet();
public static Map<String, Object> toUserMetadata(JWTClaimsSet claimsSet) {
return claimsSet.getClaims()
.entrySet()
.stream()
Expand All @@ -366,4 +369,10 @@ static boolean isAllowedTypeForClaim(final Object value) {
|| (value instanceof Collection
&& ((Collection<?>) value).stream().allMatch(e -> e instanceof String || e instanceof Boolean || e instanceof Number)));
}

public static byte[] sha256(final CharSequence charSequence) {
final MessageDigest messageDigest = MessageDigests.sha256();
messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
return messageDigest.digest();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.settings.SecureString;

import java.util.Date;
Expand All @@ -48,55 +49,6 @@ public class JwtValidateUtil {
null
);

/**
* Validate a SignedJWT. Use iss/aud/alg filters for those claims, JWKSet for signature, and skew seconds for time claims.
* @param jwt Signed JWT to be validated.
* @param allowedIssuer Filter for the "iss" claim.
* @param allowedAudiences Filter for the "aud" claim.
* @param allowedClockSkewSeconds Skew tolerance for the "auth_time", "iat", "nbf", and "exp" claims.
* @param allowedSignatureAlgorithms Filter for the "aud" header.
* @param jwks JWKs of HMAC secret keys or RSA/EC public keys.
* @throws Exception Error for the first validation to fail.
*/
public static void validate(
final SignedJWT jwt,
final String allowedIssuer,
final List<String> allowedAudiences,
final long allowedClockSkewSeconds,
final List<String> allowedSignatureAlgorithms,
final List<JWK> jwks
) throws Exception {
final Date now = new Date();

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], typ [{}],"
+ " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]",
now,
jwt.getHeader().getAlgorithm(),
jwt.getJWTClaimsSet().getIssuer(),
jwt.getJWTClaimsSet().getAudience(),
jwt.getHeader().getType(),
jwt.getJWTClaimsSet().getDateClaim("auth_time"),
jwt.getJWTClaimsSet().getIssueTime(),
jwt.getJWTClaimsSet().getNotBeforeTime(),
jwt.getJWTClaimsSet().getExpirationTime(),
jwt.getHeader().getKeyID(),
jwt.getJWTClaimsSet().getJWTID()
);
}
// validate claims before signature, because log messages about rejected claims can be more helpful than rejected signatures
JwtValidateUtil.validateType(jwt);
JwtValidateUtil.validateIssuer(jwt, allowedIssuer);
JwtValidateUtil.validateAudiences(jwt, allowedAudiences);
JwtValidateUtil.validateSignatureAlgorithm(jwt, allowedSignatureAlgorithms);
JwtValidateUtil.validateAuthTime(jwt, now, allowedClockSkewSeconds);
JwtValidateUtil.validateIssuedAtTime(jwt, now, allowedClockSkewSeconds);
JwtValidateUtil.validateNotBeforeTime(jwt, now, allowedClockSkewSeconds);
JwtValidateUtil.validateExpiredTime(jwt, now, allowedClockSkewSeconds);
JwtValidateUtil.validateSignature(jwt, jwks);
}

public static void validateType(final SignedJWT jwt) throws Exception {
final JOSEObjectType jwtHeaderType = jwt.getHeader().getType();
try {
Expand Down Expand Up @@ -277,7 +229,10 @@ static void validateExpiredTime(final Date exp, final Date now, final long allow
* @throws Exception Error if JWKs fail to validate the Signed JWT.
*/
public static void validateSignature(final SignedJWT jwt, final List<JWK> jwks) throws Exception {
assert jwks != null && jwks.isEmpty() == false : "Caller must provide a non-empty JWK list";
assert jwks != null : "Verify requires a non-null JWK list";
if (jwks.isEmpty()) {
throw new ElasticsearchException("Verify requires a non-empty JWK list");
}
final String id = jwt.getHeader().getKeyID();
final JWSAlgorithm alg = jwt.getHeader().getAlgorithm();
LOGGER.trace("JWKs [{}], JWT KID [{}], and JWT Algorithm [{}] before filters.", jwks.size(), id, alg.getName());
Expand Down Expand Up @@ -305,12 +260,35 @@ public static void validateSignature(final SignedJWT jwt, final List<JWK> jwks)
final List<JWK> jwksStrength = jwksAlg.stream().filter(j -> JwkValidateUtil.isMatch(j, alg.getName())).toList();
LOGGER.debug("JWKs [{}] after Algorithm [{}] match filter.", jwksStrength.size(), alg);

// No JWKs passed the kid, alg, and strength checks, so nothing left to use in verifying the JWT signature
if (jwksStrength.isEmpty()) {
throw new ElasticsearchException("Verify failed because all " + jwks.size() + " provided JWKs were filtered.");
}

for (final JWK jwk : jwksStrength) {
if (jwt.verify(JwtValidateUtil.createJwsVerifier(jwk))) {
return; // VERIFY SUCCEEDED
LOGGER.trace(
"JWT signature validation succeeded with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops=[{}]",
jwk.getKeyType(),
jwk.getAlgorithm(),
jwk.getKeyID(),
jwk.getKeyUse(),
jwk.getKeyOperations()
);
return;
} else {
LOGGER.trace(
"JWT signature validation failed with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops={}",
jwk.getKeyType(),
jwk.getAlgorithm(),
jwk.getKeyID(),
jwk.getKeyUse(),
jwk.getKeyOperations() == null ? "[null]" : jwk.getKeyOperations()
);
}
}
throw new Exception("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");

throw new ElasticsearchException("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
}

public static JWSVerifier createJwsVerifier(final JWK jwk) throws JOSEException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.util.Base64URL;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import static org.hamcrest.Matchers.anyOf;
Expand All @@ -27,46 +28,27 @@ public class JwkValidateUtilTests extends JwtTestCase {

private static final Logger LOGGER = LogManager.getLogger(JwkValidateUtilTests.class);

// HMAC JWKSet setting can use keys from randomJwkHmac()
// HMAC key setting cannot use randomJwkHmac(), it must use randomJwkHmacString()
public void testConvertHmacJwkToStringToJwk() throws Exception {
final JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC));

// Use HMAC random bytes for OIDC JWKSet setting only. Demonstrate encode/decode fails if used in OIDC HMAC key setting.
final OctetSequenceKey hmacKeyRandomBytes = JwtTestCase.randomJwkHmac(jwsAlgorithm);
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyRandomBytes), is(false));

// Convert HMAC random bytes to UTF8 bytes. This makes it usable as an OIDC HMAC key setting.
final OctetSequenceKey hmacKeyString1 = JwtTestCase.conditionJwkHmacForOidc(hmacKeyRandomBytes);
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString1), is(true));

// Generate HMAC UTF8 bytes. This is usable as an OIDC HMAC key setting.
final OctetSequenceKey hmacKeyString2 = JwtTestCase.randomJwkHmacOidc(jwsAlgorithm);
assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString2), is(true));
// Test decode bytes as UTF8 to String, encode back to UTF8, and compare to original bytes. If same, it is safe for OIDC JWK encode.
static boolean isJwkHmacOidcSafe(final JWK jwk) {
if (jwk instanceof OctetSequenceKey jwkHmac) {
final byte[] rawKeyBytes = jwkHmac.getKeyValue().decode();
return Arrays.equals(rawKeyBytes, new String(rawKeyBytes, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8));
}
return true;
}

private boolean hmacEncodeDecodeAsPasswordTestHelper(final OctetSequenceKey hmacKey) {
final OctetSequenceKey hmacKeyNoAttributes = JwtTestCase.jwkHmacRemoveAttributes(hmacKey);
// Encode input key as Base64(keyBytes) and Utf8String(keyBytes)
final String keyBytesToBase64 = hmacKey.getKeyValue().toString();
final String keyBytesAsUtf8 = hmacKey.getKeyValue().decodeToString();

// Decode Base64(keyBytes) into new key and compare to original. This always works.
final OctetSequenceKey decodeFromBase64 = new OctetSequenceKey.Builder(new Base64URL(keyBytesToBase64)).build();
LOGGER.info("Base64 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesToBase64 + "]\ndec: [" + decodeFromBase64 + "]\n");
if (decodeFromBase64.equals(hmacKeyNoAttributes) == false) {
return false;
static boolean areJwkHmacOidcSafe(final Collection<JWK> jwks) {
for (final JWK jwk : jwks) {
if (JwkValidateUtilTests.isJwkHmacOidcSafe(jwk) == false) {
return false;
}
}

// Decode Utf8String(keyBytes) into new key and compare to original. Only works for randomJwkHmacString, fails for randomJwkHmac.
final OctetSequenceKey decodeFromUtf8 = new OctetSequenceKey.Builder(keyBytesAsUtf8.getBytes(StandardCharsets.UTF_8)).build();
LOGGER.info("UTF8 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesAsUtf8 + "]\ndec: [" + decodeFromUtf8 + "]\n");
return decodeFromUtf8.equals(hmacKeyNoAttributes);
return true;
}

public void testComputeBitLengthRsa() throws Exception {
for (final String signatureAlgorithmRsa : JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_RSA) {
final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithmRsa);
final JWK jwk = JwtTestCase.randomJwkRsa(JWSAlgorithm.parse(signatureAlgorithmRsa));
final int minLength = JwkValidateUtil.computeBitLengthRsa(jwk.toRSAKey().toPublicKey());
assertThat(minLength, is(anyOf(equalTo(2048), equalTo(3072))));
}
Expand All @@ -86,7 +68,7 @@ public void testAlgsJwksAllPkcNotFiltered() throws Exception {

private void filterJwksAndAlgorithmsTestHelper(final List<String> candidateAlgs) throws JOSEException {
final List<String> algsRandom = randomOfMinUnique(2, candidateAlgs); // duplicates allowed
final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom);
final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom, randomBoolean());
final List<JWK> jwks = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();
final List<String> algsAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::alg).toList();
final List<JWK> jwksAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();
Expand Down

0 comments on commit 89e54be

Please sign in to comment.