Skip to content

Commit

Permalink
Formalize authc context with direct token (#101384)
Browse files Browse the repository at this point in the history
This is a refactoring PR that formalizes the logic
where an authc token is directly provided when constructing
the authenticator context. This is in contrast to the regular
logic where authenticators are responsible to extract the
token(s) from the thread context.

When an authc token is supplied as an argument to the authenticator
context constructor, no credentials extraction from the thread context
is attempted by any authenticator. If the provided token fails to be
authenticated, authentication will fail rather than defaulting to the
anonymous user.
  • Loading branch information
albertzaharovits committed Nov 3, 2023
1 parent 23a8750 commit 4d4d8ce
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 173 deletions.
Expand Up @@ -215,12 +215,10 @@ public void authenticate(
final Authenticator.Context context = new Authenticator.Context(
threadContext,
new AuditableTransportRequest(auditTrailService.get(), failureHandler, threadContext, action, transportRequest),
null,
true,
realms
realms,
token
);
context.addAuthenticationToken(token);
authenticatorChain.authenticateAsync(context, listener);
authenticatorChain.authenticate(context, listener);
}

public void expire(String principal) {
Expand All @@ -247,18 +245,21 @@ public void onSecurityIndexStateChange(SecurityIndexManager.State previousState,
}
}

Authenticator.Context newContext(final String action, final TransportRequest request, final boolean allowAnonymous) {
/**
* Returns an authenticator context for verifying only the provided {@param authenticationToken} without trying
* to extract any other tokens from the thread context.
*/
Authenticator.Context newContext(final String action, final TransportRequest request, AuthenticationToken authenticationToken) {
return new Authenticator.Context(
threadContext,
new AuditableTransportRequest(auditTrailService.get(), failureHandler, threadContext, action, request),
null,
allowAnonymous,
realms
realms,
authenticationToken
);
}

void authenticate(final Authenticator.Context context, final ActionListener<Authentication> listener) {
authenticatorChain.authenticateAsync(context, listener);
authenticatorChain.authenticate(context, listener);
}

// pkg private method for testing
Expand Down
Expand Up @@ -44,13 +44,6 @@ public interface Authenticator {
@Nullable
AuthenticationToken extractCredentials(Context context);

/**
* Whether authentication with anonymous or fallback user is allowed after this authenticator.
*/
default boolean canBeFollowedByNullTokenHandler() {
return true;
}

/**
* Attempt to authenticate current request encapsulated by the {@link Context} object.
* @param context The context object encapsulating current request and other information relevant for authentication.
Expand Down Expand Up @@ -87,21 +80,56 @@ static SecureString extractBearerTokenFromHeader(ThreadContext threadContext) {
* required for authentication.
* It is instantiated for every incoming request and passed around to {@link AuthenticatorChain} and subsequently all
* {@link Authenticator}.
* {@link Authenticator}s are consulted in order (see {@link AuthenticatorChain}),
* where each is given the chance to first extract some token, and then to verify it.
* If token verification fails in some particular way (i.e. {@code AuthenticationResult.Status.CONTINUE}),
* the next {@link Authenticator} is tried.
* The extracted tokens are all appended with {@link #addAuthenticationToken(AuthenticationToken)}.
*/
class Context implements Closeable {
private final ThreadContext threadContext;
private final AuthenticationService.AuditableRequest request;
private final User fallbackUser;
private final boolean allowAnonymous;
private final boolean extractCredentials;
private final Realms realms;
private final List<AuthenticationToken> authenticationTokens = new ArrayList<>();
private final List<AuthenticationToken> authenticationTokens;
private final List<String> unsuccessfulMessages = new ArrayList<>();
private boolean handleNullToken = true;
private SecureString bearerString = null;
private List<Realm> defaultOrderedRealmList = null;
private List<Realm> unlicensedRealms = null;

public Context(
/**
* Context constructor that provides the authentication token directly as an argument.
* This avoids extracting any tokens from the thread context, which is the regular way that authn works.
* In this case, the authentication process will simply verify the provided token, and will never fall back to the null-token case
* (i.e. in case the token CAN NOT be verified, the user IS NOT authenticated as the anonymous or the fallback user, and
* instead the authentication process fails, see {@link AuthenticatorChain#doAuthenticate}). If a {@code null} token is provided
* the authentication will invariably fail.
*/
Context(
ThreadContext threadContext,
AuthenticationService.AuditableRequest request,
Realms realms,
@Nullable AuthenticationToken token
) {
this.threadContext = threadContext;
this.request = request;
this.realms = realms;
// when a token is directly supplied for authn, don't extract other tokens, and don't handle the null-token case
this.authenticationTokens = token != null ? List.of(token) : List.of(); // no other tokens should be added
this.extractCredentials = false;
this.handleNullToken = false;
// if handleNullToken is false, fallbackUser and allowAnonymous are irrelevant
this.fallbackUser = null;
this.allowAnonymous = false;
}

/**
* Context constructor where authentication looks for credentials in the thread context.
*/
Context(
ThreadContext threadContext,
AuthenticationService.AuditableRequest request,
User fallbackUser,
Expand All @@ -110,6 +138,9 @@ public Context(
) {
this.threadContext = threadContext;
this.request = request;
this.extractCredentials = true;
// the extracted tokens, in order, for each {@code Authenticator}
this.authenticationTokens = new ArrayList<>();
this.fallbackUser = fallbackUser;
this.allowAnonymous = allowAnonymous;
this.realms = realms;
Expand Down Expand Up @@ -139,6 +170,17 @@ public boolean shouldHandleNullToken() {
return handleNullToken;
}

/**
* Returns {@code true}, if {@code Authenticator}s should first be tried in order to extract the credentials token
* from the thread context. The extracted tokens are appended to this authenticator context with
* {@link #addAuthenticationToken(AuthenticationToken)}.
* If {@code false}, the credentials token is directly passed in to this authenticator context, and the authenticators
* themselves are only consulted to authenticate the token, and never to extract any tokens from the thread context.
*/
public boolean shouldExtractCredentials() {
return extractCredentials;
}

public List<String> getUnsuccessfulMessages() {
return unsuccessfulMessages;
}
Expand Down
Expand Up @@ -68,7 +68,7 @@ class AuthenticatorChain {
this.allAuthenticators = List.of(serviceAccountAuthenticator, oAuth2TokenAuthenticator, apiKeyAuthenticator, realmsAuthenticator);
}

void authenticateAsync(Authenticator.Context context, ActionListener<Authentication> originalListener) {
void authenticate(Authenticator.Context context, ActionListener<Authentication> originalListener) {
assert false == context.getDefaultOrderedRealmList().isEmpty() : "realm list must not be empty";
// Check whether authentication is an operator user and mark the threadContext if necessary
// before returning the authentication object
Expand All @@ -79,7 +79,7 @@ void authenticateAsync(Authenticator.Context context, ActionListener<Authenticat
});
// If a token is directly provided in the context, authenticate with it
if (context.getMostRecentAuthenticationToken() != null) {
authenticateAsyncWithExistingAuthenticationToken(context, listener);
doAuthenticate(context, listener);
return;
}
final Authentication authentication;
Expand All @@ -93,23 +93,11 @@ void authenticateAsync(Authenticator.Context context, ActionListener<Authenticat
logger.trace("Found existing authentication [{}] in request [{}]", authentication, context.getRequest());
listener.onResponse(authentication);
} else {
doAuthenticate(context, true, ActionListener.runBefore(listener, context::close));
doAuthenticate(context, ActionListener.runBefore(listener, context::close));
}
}

/**
* Similar to {@link #authenticateAsync} but without extracting credentials. The credentials should
* be prepared by the called and made available in the context before calling this method.
* This method currently uses a shorter chain to match existing behaviour. But there is no reason
* why this could not use the same chain.
*/
private void authenticateAsyncWithExistingAuthenticationToken(Authenticator.Context context, ActionListener<Authentication> listener) {
assert context.getMostRecentAuthenticationToken() != null : "existing authentication token must not be null";
context.setHandleNullToken(false); // already has a token, should not try null token
doAuthenticate(context, false, listener);
}

private void doAuthenticate(Authenticator.Context context, boolean shouldExtractCredentials, ActionListener<Authentication> listener) {
private void doAuthenticate(Authenticator.Context context, ActionListener<Authentication> listener) {
// The iterating listener walks through the list of Authenticators and attempts to authenticate using
// each Authenticator (and optionally asks it to extract the authenticationToken).
// Depending on the authentication result from each Authenticator, the iteration may stop earlier
Expand All @@ -121,14 +109,15 @@ private void doAuthenticate(Authenticator.Context context, boolean shouldExtract
if (result.getStatus() == AuthenticationResult.Status.SUCCESS) {
maybeLookupRunAsUser(context, result.getValue(), l);
} else {
assert result.getStatus() == AuthenticationResult.Status.CONTINUE;
if (context.shouldHandleNullToken()) {
handleNullToken(context, l);
} else {
l.onFailure(Exceptions.authenticationError("failed to authenticate", result.getException()));
}
}
}),
getAuthenticatorConsumer(context, shouldExtractCredentials),
getAuthenticatorConsumer(context),
allAuthenticators,
context.getThreadContext(),
Function.identity(),
Expand All @@ -138,11 +127,10 @@ private void doAuthenticate(Authenticator.Context context, boolean shouldExtract
}

private static BiConsumer<Authenticator, ActionListener<AuthenticationResult<Authentication>>> getAuthenticatorConsumer(
Authenticator.Context context,
boolean shouldExtractCredentials
Authenticator.Context context
) {
return (authenticator, listener) -> {
if (shouldExtractCredentials) {
if (context.shouldExtractCredentials()) {
final AuthenticationToken authenticationToken;
try {
authenticationToken = authenticator.extractCredentials(context);
Expand All @@ -161,7 +149,6 @@ private static BiConsumer<Authenticator, ActionListener<AuthenticationResult<Aut
}
context.addAuthenticationToken(authenticationToken);
}
context.setHandleNullToken(context.shouldHandleNullToken() && authenticator.canBeFollowedByNullTokenHandler());

final Consumer<Exception> onFailure = (e) -> {
assert e != null : "exception cannot be null";
Expand Down
Expand Up @@ -53,16 +53,21 @@ public CrossClusterAccessAuthenticationService(
}

public void authenticate(final String action, final TransportRequest request, final ActionListener<Authentication> listener) {
final Authenticator.Context authcContext = authenticationService.newContext(action, request, false);
final ThreadContext threadContext = authcContext.getThreadContext();

final ThreadContext threadContext = clusterService.threadPool().getThreadContext();
final CrossClusterAccessHeaders crossClusterAccessHeaders;
final Authenticator.Context authcContext;
try {
// parse and add as authentication token as early as possible so that failure events in audit log include API key ID
crossClusterAccessHeaders = CrossClusterAccessHeaders.readFromContext(threadContext);
final ApiKeyService.ApiKeyCredentials apiKeyCredentials = crossClusterAccessHeaders.credentials();
assert ApiKey.Type.CROSS_CLUSTER == apiKeyCredentials.getExpectedType();
authcContext.addAuthenticationToken(apiKeyCredentials);
// authn must verify only the provided api key and not try to extract any other credential from the thread context
authcContext = authenticationService.newContext(action, request, apiKeyCredentials);
} catch (Exception ex) {
withRequestProcessingFailure(authenticationService.newContext(action, request, null), ex, listener);
return;
}
try {
apiKeyService.ensureEnabled();
} catch (Exception ex) {
withRequestProcessingFailure(authcContext, ex, listener);
Expand Down
Expand Up @@ -44,7 +44,6 @@ class RealmsAuthenticator implements Authenticator {

private final AtomicLong numInvalidation;
private final Cache<String, Realm> lastSuccessfulAuthCache;
private boolean authenticationTokenExtracted = false;

RealmsAuthenticator(AtomicLong numInvalidation, Cache<String, Realm> lastSuccessfulAuthCache) {
this.numInvalidation = numInvalidation;
Expand All @@ -60,17 +59,15 @@ public String name() {
public AuthenticationToken extractCredentials(Context context) {
final AuthenticationToken authenticationToken = extractToken(context);
if (authenticationToken != null) {
authenticationTokenExtracted = true;
// Once a token is extracted by realms, from the thread context,
// authentication must not handle the null-token case (in case no realm can verify the extracted token).
// In other words, the handle null-token case (i.e. authenticate as the anonymous user) runs only when no realm can extract
// a token from the thread context (i.e. from the request).
context.setHandleNullToken(false);
}
return authenticationToken;
}

@Override
public boolean canBeFollowedByNullTokenHandler() {
// TODO: once a token is extracted by realms, we should no longer handle null token if no realm can authenticate the token
return false == authenticationTokenExtracted;
}

@Override
public void authenticate(Context context, ActionListener<AuthenticationResult<Authentication>> listener) {
if (context.getMostRecentAuthenticationToken() == null) {
Expand Down

0 comments on commit 4d4d8ce

Please sign in to comment.