Skip to content

Commit

Permalink
Add JWT cache to JWT realm. (#84842)
Browse files Browse the repository at this point in the history
  • Loading branch information
justincr-elastic committed Mar 21, 2022
1 parent 9d94cb5 commit 070dec4
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 52 deletions.
15 changes: 15 additions & 0 deletions docs/reference/settings/security-settings.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,21 @@ If this setting is used, then the JWT realm does not perform role
mapping and instead loads the user from the listed realms.
See <<authorization_realms>>.

`jwt.cache.ttl`::
(<<static-cluster-setting,Static>>)
Specifies the time-to-live for JWT cache entries.
JWT entries will be cached for this period of time.
JWTs can only be cached if client authentication is successful (or disabled).
Use the standard {es} <<time-units,time units>>.
Defaults to `20m`. Zero disables JWT cache.
If clients use a different JWT for every request, set to 0 to disable JWT cache.

`jwt.cache.size`::
(<<static-cluster-setting,Static>>)
Specifies the maximum number of JWT cache entries.
Defaults to `100000`. Zero disables JWT cache.
If clients use a different JWT for every request, set to 0 to disable JWT cache.

// tag::jwt-http-connect-timeout-tag[]
`http.connect_timeout` {ess-icon}::
(<<static-cluster-setting,Static>>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,9 @@ public static ClientAuthenticationType parse(String value, String settingKey) {
private static final TimeValue DEFAULT_ALLOWED_CLOCK_SKEW = TimeValue.timeValueSeconds(60);
private static final List<String> DEFAULT_ALLOWED_SIGNATURE_ALGORITHMS = Collections.singletonList("RS256");
private static final boolean DEFAULT_POPULATE_USER_METADATA = true;
private static final String DEFAULT_JWT_VALIDATION_CACHE_HASH_ALGO = "ssha256";
private static final TimeValue DEFAULT_JWT_VALIDATION_CACHE_TTL = TimeValue.timeValueMinutes(20);
private static final int DEFAULT_JWT_VALIDATION_CACHE_MAX_USERS = 100_000;
private static final int MIN_JWT_VALIDATION_CACHE_MAX_USERS = 0;
private static final TimeValue DEFAULT_ROLES_LOOKUP_CACHE_TTL = TimeValue.timeValueMinutes(20);
private static final int DEFAULT_ROLES_LOOKUP_CACHE_MAX_USERS = 100_000;
private static final int MIN_ROLES_LOOKUP_CACHE_MAX_USERS = 0;
private static final TimeValue DEFAULT_JWT_CACHE_TTL = TimeValue.timeValueMinutes(20);
private static final int DEFAULT_JWT_CACHE_SIZE = 100_000;
private static final int MIN_JWT_CACHE_SIZE = 0;
private static final TimeValue DEFAULT_HTTP_CONNECT_TIMEOUT = TimeValue.timeValueSeconds(5);
private static final TimeValue DEFAULT_HTTP_CONNECTION_READ_TIMEOUT = TimeValue.timeValueSeconds(5);
private static final TimeValue DEFAULT_HTTP_SOCKET_TIMEOUT = TimeValue.timeValueSeconds(5);
Expand Down Expand Up @@ -140,6 +136,8 @@ private static Set<Setting.AffixSetting<?>> getNonSecureSettings() {
);
// JWT Client settings
set.addAll(List.of(CLIENT_AUTHENTICATION_TYPE));
// JWT Cache settings
set.addAll(List.of(JWT_CACHE_TTL, JWT_CACHE_SIZE));
// Standard HTTP settings for outgoing connections to get JWT issuer jwkset_path
set.addAll(
List.of(
Expand Down Expand Up @@ -238,6 +236,20 @@ private static Set<Setting.AffixSetting<SecureString>> getSecureSettings() {
"client_authentication.shared_secret"
);

// Individual Cache settings

public static final Setting.AffixSetting<TimeValue> JWT_CACHE_TTL = Setting.affixKeySetting(
RealmSettings.realmSettingPrefix(TYPE),
"jwt.cache.ttl",
key -> Setting.timeSetting(key, DEFAULT_JWT_CACHE_TTL, Setting.Property.NodeScope)
);

public static final Setting.AffixSetting<Integer> JWT_CACHE_SIZE = Setting.affixKeySetting(
RealmSettings.realmSettingPrefix(TYPE),
"jwt.cache.size",
key -> Setting.intSetting(key, DEFAULT_JWT_CACHE_SIZE, MIN_JWT_CACHE_SIZE, Setting.Property.NodeScope)
);

// Individual outgoing HTTP settings

public static final Setting.AffixSetting<TimeValue> HTTP_CONNECT_TIMEOUT = Setting.affixKeySetting(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,17 @@ public void removeKeysIf(Predicate<K> removeIf) {
}
}
}

public void removeValuesIf(Predicate<V> removeIf) {
// the cache cannot be modified while doing this operation per the terms of the cache iterator
try (ReleasableLock ignored = this.acquireForIterator()) {
Iterator<V> iterator = cache.values().iterator();
while (iterator.hasNext()) {
V value = iterator.next();
if (removeIf.test(value)) {
iterator.remove();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.common.util.concurrent.ReleasableLock;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
Expand All @@ -30,16 +34,20 @@
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
import org.elasticsearch.xpack.core.security.authc.support.CachingRealm;
import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.authc.BytesKey;
import org.elasticsearch.xpack.security.authc.support.ClaimParser;
import org.elasticsearch.xpack.security.authc.support.DelegatedAuthorizationSupport;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

Expand All @@ -50,6 +58,8 @@
public class JwtRealm extends Realm implements CachingRealm, Releasable {
private static final Logger LOGGER = LogManager.getLogger(JwtRealm.class);

record ExpiringUser(User user, Date exp) {}

record JwksAlgs(List<JWK> jwks, List<String> algs) {
boolean isEmpty() {
return jwks.isEmpty() && algs.isEmpty();
Expand Down Expand Up @@ -77,6 +87,8 @@ boolean isEmpty() {
final ClaimParser claimParserName;
final JwtRealmSettings.ClientAuthenticationType clientAuthenticationType;
final SecureString clientAuthenticationSharedSecret;
final Cache<BytesKey, ExpiringUser> jwtCache;
final CacheIteratorHelper<BytesKey, ExpiringUser> jwtCacheHelper;
DelegatedAuthorizationSupport delegatedAuthorizationSupport = null;

public JwtRealm(final RealmConfig realmConfig, final SSLService sslService, final UserRoleMapper userRoleMapper)
Expand All @@ -96,6 +108,8 @@ public JwtRealm(final RealmConfig realmConfig, final SSLService sslService, fina
this.clientAuthenticationType = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_TYPE);
final SecureString sharedSecret = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_SHARED_SECRET);
this.clientAuthenticationSharedSecret = Strings.hasText(sharedSecret) ? sharedSecret : null; // convert "" to null
this.jwtCache = this.buildJwtCache();
this.jwtCacheHelper = (this.jwtCache == null) ? null : new CacheIteratorHelper<>(this.jwtCache);

// Validate Client Authentication settings. Throw SettingsException there was a problem.
JwtUtil.validateClientAuthenticationSettings(
Expand Down Expand Up @@ -143,6 +157,15 @@ public JwtRealm(final RealmConfig realmConfig, final SSLService sslService, fina
}
}

private Cache<BytesKey, ExpiringUser> buildJwtCache() {
final TimeValue jwtCacheTtl = super.config.getSetting(JwtRealmSettings.JWT_CACHE_TTL);
final int jwtCacheSize = super.config.getSetting(JwtRealmSettings.JWT_CACHE_SIZE);
if ((jwtCacheTtl.getNanos() > 0) && (jwtCacheSize > 0)) {
return CacheBuilder.<BytesKey, ExpiringUser>builder().setExpireAfterWrite(jwtCacheTtl).setMaximumWeight(jwtCacheSize).build();
}
return null;
}

// must call parseAlgsAndJwksHmac() before parseAlgsAndJwksPkc()
private JwtRealm.JwksAlgs parseJwksAlgsHmac() {
final JwtRealm.JwksAlgs jwksAlgsHmac;
Expand Down Expand Up @@ -252,8 +275,19 @@ public void initialize(final Iterable<Realm> allRealms, final XPackLicenseState
this.delegatedAuthorizationSupport = new DelegatedAuthorizationSupport(allRealms, super.config, xpackLicenseState);
}

/**
* Clean up JWT cache (if enabled).
* Clean up HTTPS client cache (if enabled).
*/
@Override
public void close() {
if (this.jwtCache != null) {
try {
this.jwtCache.invalidateAll();
} catch (Exception e) {
LOGGER.warn("Exception invalidating JWT cache for realm [" + super.name() + "]", e);
}
}
if (this.httpClient != null) {
try {
this.httpClient.close();
Expand All @@ -272,11 +306,21 @@ public void lookupUser(final String username, final ActionListener<User> listene
@Override
public void expire(final String username) {
this.ensureInitialized();
LOGGER.trace("Expiring JWT cache entries for realm [" + super.name() + "] principal=[" + username + "]");
if (this.jwtCacheHelper != null) {
this.jwtCacheHelper.removeValuesIf(expiringUser -> expiringUser.user.principal().equals(username));
}
}

@Override
public void expireAll() {
this.ensureInitialized();
if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
LOGGER.trace("Invalidating JWT cache for realm [" + super.name() + "]");
try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
this.jwtCache.invalidateAll();
}
}
}

@Override
Expand Down Expand Up @@ -321,22 +365,64 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
return; // FAILED (secret is missing or mismatched)
}

// Parse JWT: Extract claims for logs and role-mapping.
// JWT cache
final SecureString serializedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
final BytesKey jwtCacheKey = (this.jwtCache == null) ? null : computeBytesKey(serializedJwt);
if (jwtCacheKey != null) {
final ExpiringUser expiringUser = this.jwtCache.get(jwtCacheKey);
if (expiringUser == null) {
LOGGER.trace("Realm [" + super.name() + "] JWT cache miss token=[" + tokenPrincipal + "] key=[" + jwtCacheKey + "].");
} else {
final User user = expiringUser.user;
final Date exp = expiringUser.exp; // claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis()
final String principal = user.principal();
final Date now = new Date();
if (now.getTime() < exp.getTime()) {
LOGGER.trace(
"Realm ["
+ super.name()
+ "] JWT cache hit token=["
+ tokenPrincipal
+ "] key=["
+ jwtCacheKey
+ "] principal=["
+ principal
+ "] exp=["
+ exp
+ "] now=["
+ now
+ "]."
);
if (this.delegatedAuthorizationSupport.hasDelegation()) {
this.delegatedAuthorizationSupport.resolve(principal, listener);
} else {
listener.onResponse(AuthenticationResult.success(user));
}
return;
}
LOGGER.trace(
"Realm ["
+ super.name()
+ "] JWT cache exp token=["
+ tokenPrincipal
+ "] key=["
+ jwtCacheKey
+ "] principal=["
+ principal
+ "] exp=["
+ exp
+ "] now=["
+ now
+ "]."
);
}
}

// Validate JWT: Extract JWT and claims set, and validate JWT.
final SignedJWT jwt;
final JWTClaimsSet claimsSet;
try {
jwt = SignedJWT.parse(serializedJwt.toString());
claimsSet = jwt.getJWTClaimsSet();
} catch (Exception e) {
final String msg = "Realm [" + super.name() + "] JWT parse failed for token=[" + tokenPrincipal + "].";
LOGGER.debug(msg);
listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
return; // FAILED (JWT parse fail or regex parse fail)
}

// Validate JWT
try {
final String jwtAlg = jwt.getHeader().getAlgorithm().getName();
final boolean isJwtAlgHmac = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(jwtAlg);
final JwtRealm.JwksAlgs jwksAndAlgs = isJwtAlgHmac ? this.jwksAlgsHmac : this.jwksAlgsPkc;
Expand All @@ -348,6 +434,7 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
jwksAndAlgs.algs,
jwksAndAlgs.jwks
);
claimsSet = jwt.getJWTClaimsSet();
LOGGER.trace("Realm [" + super.name() + "] JWT validation succeeded for token=[" + tokenPrincipal + "].");
} catch (Exception e) {
final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "].";
Expand Down Expand Up @@ -375,23 +462,25 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
return;
}

// Delegated role lookup: If enabled, lookup in authz realms. Otherwise, fall through to JWT realm role mapping.
if (this.delegatedAuthorizationSupport.hasDelegation()) {
this.delegatedAuthorizationSupport.resolve(principal, ActionListener.wrap(result -> {
if (result.isAuthenticated()) {
// Intercept the delegated authorization listener response to log roles. Empty roles is OK.
final User user = result.getValue();
final String rolesString = Arrays.toString(user.roles());
LOGGER.debug(
"Realm [" + super.name() + "] delegated roles [" + rolesString + "] for principal=[" + principal + "]."
);
// Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled.
final ActionListener<AuthenticationResult<User>> logAndCacheListener = ActionListener.wrap(result -> {
if (result.isAuthenticated()) {
final User user = result.getValue();
final String rolesString = Arrays.toString(user.roles());
LOGGER.debug("Realm [" + super.name() + "] roles [" + rolesString + "] for principal=[" + principal + "].");
if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis();
this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis)));
}
}
listener.onResponse(result);
}, e -> {
final String msg = "Realm [" + super.name() + "] delegated roles failed for principal=[" + principal + "].";
LOGGER.warn(msg, e);
listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
}));
}
listener.onResponse(result);
}, listener::onFailure);

// Delegated role lookup or Role mapping: Use the above listener to log roles and cache User.
if (this.delegatedAuthorizationSupport.hasDelegation()) {
this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener);
return;
}

Expand All @@ -415,13 +504,8 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config);
this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> {
final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true);
LOGGER.debug("Realm [" + super.name() + "] roles " + String.join(",", rolesSet) + " for principal=[" + principal + "].");
listener.onResponse(AuthenticationResult.success(user));
}, e -> {
final String msg = "Realm [" + super.name() + "] roles failed for principal=[" + principal + "].";
LOGGER.warn(msg, e);
listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
}));
logAndCacheListener.onResponse(AuthenticationResult.success(user));
}, logAndCacheListener::onFailure));
} else {
final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName();
final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "].";
Expand All @@ -433,6 +517,15 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
@Override
public void usageStats(final ActionListener<Map<String, Object>> listener) {
this.ensureInitialized();
super.usageStats(ActionListener.wrap(listener::onResponse, listener::onFailure));
super.usageStats(ActionListener.wrap(stats -> {
stats.put("jwt.cache", Collections.singletonMap("size", this.jwtCache == null ? -1 : this.jwtCache.count()));
listener.onResponse(stats);
}, listener::onFailure));
}

static BytesKey computeBytesKey(final CharSequence charSequence) {
final MessageDigest messageDigest = MessageDigests.sha256();
messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
return new BytesKey(messageDigest.digest());
}
}

0 comments on commit 070dec4

Please sign in to comment.