Skip to content

Commit

Permalink
Generify AuthenticationResult (#79694)
Browse files Browse the repository at this point in the history
This PR changes AuthenticationResult to use generics for the value it
contains. Previously the value type is fixed to be User. While an User
object is suitable for realm authenticators, it does not fit smoothly for other
authenticators like ApiKey and Token. This is more evident after the
authentication chain is unified (#77293).

Because of the use of generic, the signature of "Realm#authenticate" is
also changed to take a listener that accepts a generified 
AuthenticationResult.

Relates: #75607
  • Loading branch information
ywangd committed Oct 26, 2021
1 parent 3f0ea76 commit 7d21e3c
Show file tree
Hide file tree
Showing 51 changed files with 510 additions and 539 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,39 @@
* <li>Unable to authenticate user, terminate authentication (with an error message)</li>
* </ol>
*/
public final class AuthenticationResult {
private static final AuthenticationResult NOT_HANDLED = new AuthenticationResult(Status.CONTINUE, null, null, null, null);
public final class AuthenticationResult<T> {
private static final AuthenticationResult<?> NOT_HANDLED = new AuthenticationResult<>(Status.CONTINUE, null, null, null, null);

public static String THREAD_CONTEXT_KEY = "_xpack_security_auth_result";

public enum Status {
/**
* The authenticator successfully handled the authentication request
*/
SUCCESS,
/**
* The authenticator either did not handle the authentication request for reasons such as
* it cannot find necessary credentials
* Or the authenticator tried to handle the authentication request but it was unsuccessful.
* Subsequent authenticators (if any) still have chance to attempt authentication.
*/
CONTINUE,
/**
* The authenticator fail to authenticate the request and also requires the whole authentication chain to be stopped
*/
TERMINATE,
}

private final Status status;
private final User user;
private final T value;
private final String message;
private final Exception exception;
private final Map<String, Object> metadata;

private AuthenticationResult(Status status, @Nullable User user, @Nullable String message, @Nullable Exception exception,
private AuthenticationResult(Status status, @Nullable T value, @Nullable String message, @Nullable Exception exception,
@Nullable Map<String, Object> metadata) {
this.status = status;
this.user = user;
this.value = value;
this.message = message;
this.exception = exception;
this.metadata = metadata == null ? Collections.emptyMap() : Collections.unmodifiableMap(metadata);
Expand All @@ -53,8 +65,8 @@ public Status getStatus() {
return status;
}

public User getUser() {
return user;
public T getValue() {
return value;
}

public String getMessage() {
Expand All @@ -77,20 +89,20 @@ public Map<String, Object> getMetadata() {
* </p><p>
* Neither the {@link #getMessage() message} nor {@link #getException() exception} are populated.
* </p>
* @param user The user that was authenticated. Cannot be {@code null}.
* @param value The user that was authenticated. Cannot be {@code null}.
*/
public static AuthenticationResult success(User user) {
return success(user, null);
public static <T> AuthenticationResult<T> success(T value) {
return success(value, null);
}

/**
* Creates a successful result, with optional metadata
*
* @see #success(User)
* @see #success(Object)
*/
public static AuthenticationResult success(User user, @Nullable Map<String, Object> metadata) {
Objects.requireNonNull(user);
return new AuthenticationResult(Status.SUCCESS, user, null, null, metadata);
public static <T> AuthenticationResult<T> success(T value, @Nullable Map<String, Object> metadata) {
Objects.requireNonNull(value);
return new AuthenticationResult<>(Status.SUCCESS, value, null, null, metadata);
}

/**
Expand All @@ -99,11 +111,12 @@ public static AuthenticationResult success(User user, @Nullable Map<String, Obje
* <p>
* The {@link #getStatus() status} is set to {@link Status#CONTINUE}.
* </p><p>
* The {@link #getMessage() message}, {@link #getException() exception}, and {@link #getUser() user} are all set to {@code null}.
* The {@link #getMessage() message}, {@link #getException() exception}, and {@link #getValue() user} are all set to {@code null}.
* </p>
*/
public static AuthenticationResult notHandled() {
return NOT_HANDLED;
@SuppressWarnings("unchecked")
public static <T> AuthenticationResult<T> notHandled() {
return (AuthenticationResult<T>) NOT_HANDLED;
}

/**
Expand All @@ -112,12 +125,12 @@ public static AuthenticationResult notHandled() {
* <p>
* The {@link #getStatus() status} is set to {@link Status#CONTINUE}.
* </p><p>
* The {@link #getUser() user} is not populated.
* The {@link #getValue() value} is not populated.
* </p>
*/
public static AuthenticationResult unsuccessful(String message, @Nullable Exception cause) {
public static <T> AuthenticationResult<T> unsuccessful(String message, @Nullable Exception cause) {
Objects.requireNonNull(message);
return new AuthenticationResult(Status.CONTINUE, null, message, cause, null);
return new AuthenticationResult<>(Status.CONTINUE, null, message, cause, null);
}

/**
Expand All @@ -127,11 +140,11 @@ public static AuthenticationResult unsuccessful(String message, @Nullable Except
* <p>
* The {@link #getStatus() status} is set to {@link Status#TERMINATE}.
* </p><p>
* The {@link #getUser() user} is not populated.
* The {@link #getValue() value} is not populated.
* </p>
*/
public static AuthenticationResult terminate(String message, @Nullable Exception cause) {
return new AuthenticationResult(Status.TERMINATE, null, message, cause, null);
public static <T> AuthenticationResult<T> terminate(String message, @Nullable Exception cause) {
return new AuthenticationResult<>(Status.TERMINATE, null, message, cause, null);
}

/**
Expand All @@ -141,10 +154,10 @@ public static AuthenticationResult terminate(String message, @Nullable Exception
* <p>
* The {@link #getStatus() status} is set to {@link Status#TERMINATE}.
* </p><p>
* The {@link #getUser() user} is not populated.
* The {@link #getValue() value} is not populated.
* </p>
*/
public static AuthenticationResult terminate(String message) {
public static <T> AuthenticationResult<T> terminate(String message) {
return terminate(message, null);
}

Expand All @@ -156,7 +169,7 @@ public boolean isAuthenticated() {
public String toString() {
return "AuthenticationResult{" +
"status=" + status +
", user=" + user +
", value=" + value +
", message=" + message +
", exception=" + exception +
'}';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public int compareTo(Realm other) {
* @param token The authentication token
* @param listener The listener to pass the authentication result to
*/
public abstract void authenticate(AuthenticationToken token, ActionListener<AuthenticationResult> listener);
public abstract void authenticate(AuthenticationToken token, ActionListener<AuthenticationResult<User>> listener);

/**
* Looks up the user identified the String identifier. A successful lookup will call the {@link ActionListener#onResponse}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ private void testScenario(Scenario scenario) throws Exception {
Map<String, Map<Realm, User>> users = new HashMap<>();
for (Realm realm : realms) {
for (String username : usernames) {
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
PlainActionFuture<AuthenticationResult<User>> future = new PlainActionFuture<>();
realm.authenticate(tokens.get(username), future);
User user = future.actionGet().getUser();
User user = future.actionGet().getValue();
assertThat(user, notNullValue());
Map<Realm, User> realmToUser = users.get(username);
if (realmToUser == null) {
Expand All @@ -251,9 +251,9 @@ private void testScenario(Scenario scenario) throws Exception {

for (String username : usernames) {
for (Realm realm : realms) {
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
PlainActionFuture<AuthenticationResult<User>> future = new PlainActionFuture<>();
realm.authenticate(tokens.get(username), future);
User user = future.actionGet().getUser();
User user = future.actionGet().getValue();
assertThat(user, sameInstance(users.get(username).get(realm)));
}
}
Expand All @@ -264,9 +264,9 @@ private void testScenario(Scenario scenario) throws Exception {
// now, user_a should have been evicted, but user_b should still be cached
for (String username : usernames) {
for (Realm realm : realms) {
PlainActionFuture<AuthenticationResult> future = new PlainActionFuture<>();
PlainActionFuture<AuthenticationResult<User>> future = new PlainActionFuture<>();
realm.authenticate(tokens.get(username), future);
User user = future.actionGet().getUser();
User user = future.actionGet().getValue();
assertThat(user, notNullValue());
scenario.assertEviction(users.get(username).get(realm), user);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.security.action.oidc.OpenIdConnectAuthenticateAction;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authc.TokenService;
import org.elasticsearch.xpack.security.authc.oidc.OpenIdConnectRealm;
Expand Down Expand Up @@ -65,9 +66,9 @@ protected void doExecute(Task task, OpenIdConnectAuthenticateRequest request,
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
authenticationService.authenticate(OpenIdConnectAuthenticateAction.NAME, request, token, ActionListener.wrap(
authentication -> {
AuthenticationResult result = threadContext.getTransient(AuthenticationResult.THREAD_CONTEXT_KEY);
AuthenticationResult<User> result = threadContext.getTransient(AuthenticationResult.THREAD_CONTEXT_KEY);
if (result == null) {
listener.onFailure(new IllegalStateException("Cannot find AuthenticationResult on thread context"));
listener.onFailure(new IllegalStateException("Cannot find User AuthenticationResult on thread context"));
return;
}
@SuppressWarnings("unchecked") final Map<String, Object> tokenMetadata = (Map<String, Object>) result.getMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.core.security.action.saml.SamlAuthenticateResponse;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authc.TokenService;
import org.elasticsearch.xpack.security.authc.saml.SamlRealm;
Expand Down Expand Up @@ -58,9 +59,9 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe
Authentication originatingAuthentication = securityContext.getAuthentication();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
authenticationService.authenticate(SamlAuthenticateAction.NAME, request, saml, ActionListener.wrap(authentication -> {
AuthenticationResult result = threadContext.getTransient(AuthenticationResult.THREAD_CONTEXT_KEY);
AuthenticationResult<User> result = threadContext.getTransient(AuthenticationResult.THREAD_CONTEXT_KEY);
if (result == null) {
listener.onFailure(new IllegalStateException("Cannot find AuthenticationResult on thread context"));
listener.onFailure(new IllegalStateException("Cannot find User AuthenticationResult on thread context"));
return;
}
assert authentication != null : "authentication should never be null at this point";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ public AuthenticationToken extractCredentials(Context context) {
}

@Override
public void authenticate(Context context, ActionListener<Result> listener) {
public void authenticate(Context context, ActionListener<AuthenticationResult<Authentication>> listener) {
final AuthenticationToken authenticationToken = context.getMostRecentAuthenticationToken();
if (false == authenticationToken instanceof ApiKeyCredentials) {
listener.onResponse(Authenticator.Result.notHandled());
listener.onResponse(AuthenticationResult.notHandled());
return;
}
ApiKeyCredentials apiKeyCredentials = (ApiKeyCredentials) authenticationToken;
apiKeyService.tryAuthenticate(context.getThreadContext(), apiKeyCredentials, ActionListener.wrap(authResult -> {
if (authResult.isAuthenticated()) {
final Authentication authentication = apiKeyService.createApiKeyAuthentication(authResult, nodeName);
listener.onResponse(Authenticator.Result.success(authentication));
listener.onResponse(AuthenticationResult.success(authentication));
} else if (authResult.getStatus() == AuthenticationResult.Status.TERMINATE) {
Exception e = (authResult.getException() != null) ?
authResult.getException() :
Expand All @@ -67,7 +67,7 @@ public void authenticate(Context context, ActionListener<Result> listener) {
logger.warn("Authentication using apikey failed - {}", authResult.getMessage());
}
}
listener.onResponse(Authenticator.Result.unsuccessful(
listener.onResponse(AuthenticationResult.unsuccessful(
authResult.getMessage(),
authResult.getException()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ XContentBuilder newDocument(char[] apiKeyHashChars, String name, Authentication
return builder;
}

void tryAuthenticate(ThreadContext ctx, ApiKeyCredentials credentials, ActionListener<AuthenticationResult> listener) {
void tryAuthenticate(ThreadContext ctx, ApiKeyCredentials credentials, ActionListener<AuthenticationResult<User>> listener) {
if (false == isEnabled()) {
listener.onResponse(AuthenticationResult.notHandled());
}
Expand All @@ -394,18 +394,18 @@ void tryAuthenticate(ThreadContext ctx, ApiKeyCredentials credentials, ActionLis
));
}

public Authentication createApiKeyAuthentication(AuthenticationResult authResult, String nodeName) {
public Authentication createApiKeyAuthentication(AuthenticationResult<User> authResult, String nodeName) {
if (false == authResult.isAuthenticated()) {
throw new IllegalArgumentException("API Key authn result must be successful");
}
final User user = authResult.getUser();
final User user = authResult.getValue();
final RealmRef authenticatedBy = new RealmRef(ApiKeyService.API_KEY_REALM_NAME, ApiKeyService.API_KEY_REALM_TYPE, nodeName);
return new Authentication(user, authenticatedBy, null, Version.CURRENT, Authentication.AuthenticationType.API_KEY,
authResult.getMetadata());
}

void loadApiKeyAndValidateCredentials(ThreadContext ctx, ApiKeyCredentials credentials,
ActionListener<AuthenticationResult> listener) {
ActionListener<AuthenticationResult<User>> listener) {
final String docId = credentials.getId();

Consumer<ApiKeyDoc> validator = apiKeyDoc ->
Expand Down Expand Up @@ -596,7 +596,7 @@ public List<RoleDescriptor> parseRoleDescriptors(final String apiKeyId, BytesRef
* @param listener the listener to notify after verification
*/
void validateApiKeyCredentials(String docId, ApiKeyDoc apiKeyDoc, ApiKeyCredentials credentials, Clock clock,
ActionListener<AuthenticationResult> listener) {
ActionListener<AuthenticationResult<User>> listener) {
if ("api_key".equals(apiKeyDoc.docType) == false) {
listener.onResponse(
AuthenticationResult.unsuccessful("document [" + docId + "] is [" + apiKeyDoc.docType + "] not an api key", null));
Expand Down Expand Up @@ -694,7 +694,7 @@ Cache<String, BytesReference> getRoleDescriptorsBytesCache() {

// package-private for testing
void validateApiKeyExpiration(ApiKeyDoc apiKeyDoc, ApiKeyCredentials credentials, Clock clock,
ActionListener<AuthenticationResult> listener) {
ActionListener<AuthenticationResult<User>> listener) {
if (apiKeyDoc.expirationTime == -1 || Instant.ofEpochMilli(apiKeyDoc.expirationTime).isAfter(clock.instant())) {
final String principal = Objects.requireNonNull((String) apiKeyDoc.creator.get("principal"));
final String fullName = (String) apiKeyDoc.creator.get("full_name");
Expand Down

0 comments on commit 7d21e3c

Please sign in to comment.