Skip to content

Commit

Permalink
Ensure authentication is wire compatible when setting user (#86741) (#…
Browse files Browse the repository at this point in the history
…86828)

* Ensure authentication is wire compatible when setting user (#86741)

The SecurityServerTransportInterceptor class is responsible for writing
authentication header in a wire compatible format before the request
leaving the local node. However, a bug made it ignore the wire version
when setting user based on the action origin. This PR fixes it and adds
relevant tests.

It is an old bug but never manifested itself previously because (1) the
code path is rare enuough and (2) authentication didn't have any version
difference till 8.2.

Resolves: #86716

* fix compilation
  • Loading branch information
ywangd committed May 17, 2022
1 parent 5a45a88 commit e7d1406
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/86741.yaml
@@ -0,0 +1,6 @@
pr: 86741
summary: Ensure authentication is wire compatible when setting user
area: Authentication
type: bug
issues:
- 86716
Expand Up @@ -8,6 +8,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
Expand Down Expand Up @@ -105,6 +106,7 @@ operations are blocked on license expiration. All data operations (read and writ
AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(
threadContext,
securityContext,
Version.CURRENT, // current version since this is on the same node
(original) -> { applyInternal(task, chain, action, request, contextPreservingListener); }
);
} else {
Expand Down
Expand Up @@ -105,6 +105,7 @@ public static boolean shouldSetUserBasedOnActionOrigin(ThreadContext context) {
public static void switchUserBasedOnActionOriginAndExecute(
ThreadContext threadContext,
SecurityContext securityContext,
Version version,
Consumer<ThreadContext.StoredContext> consumer
) {
final String actionOrigin = threadContext.getTransient(ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME);
Expand All @@ -115,7 +116,7 @@ public static void switchUserBasedOnActionOriginAndExecute(

switch (actionOrigin) {
case SECURITY_ORIGIN:
securityContext.executeAsInternalUser(XPackSecurityUser.INSTANCE, Version.CURRENT, consumer);
securityContext.executeAsInternalUser(XPackSecurityUser.INSTANCE, version, consumer);
break;
case WATCHER_ORIGIN:
case ML_ORIGIN:
Expand All @@ -133,10 +134,10 @@ public static void switchUserBasedOnActionOriginAndExecute(
case LOGSTASH_MANAGEMENT_ORIGIN:
case FLEET_ORIGIN:
case TASKS_ORIGIN: // TODO use a more limited user for tasks
securityContext.executeAsInternalUser(XPackUser.INSTANCE, Version.CURRENT, consumer);
securityContext.executeAsInternalUser(XPackUser.INSTANCE, version, consumer);
break;
case ASYNC_SEARCH_ORIGIN:
securityContext.executeAsInternalUser(AsyncSearchUser.INSTANCE, Version.CURRENT, consumer);
securityContext.executeAsInternalUser(AsyncSearchUser.INSTANCE, version, consumer);
break;
default:
assert false : "action.origin [" + actionOrigin + "] is unknown!";
Expand Down
Expand Up @@ -106,6 +106,7 @@ public <T extends TransportResponse> void sendRequest(
AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(
threadPool.getThreadContext(),
securityContext,
minVersion,
(original) -> sendWithUser(
connection,
action,
Expand Down
Expand Up @@ -6,11 +6,13 @@
*/
package org.elasticsearch.xpack.security.authz;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.authc.Authentication;
Expand Down Expand Up @@ -98,7 +100,7 @@ public void testShouldSetUser() {
}

public void testSwitchAndExecuteXpackSecurityUser() throws Exception {
assertSwitchBasedOnOriginAndExecute(ClientHelper.SECURITY_ORIGIN, XPackSecurityUser.INSTANCE);
assertSwitchBasedOnOriginAndExecute(ClientHelper.SECURITY_ORIGIN, XPackSecurityUser.INSTANCE, randomVersion());
}

public void testSwitchAndExecuteXpackUser() throws Exception {
Expand All @@ -110,20 +112,20 @@ public void testSwitchAndExecuteXpackUser() throws Exception {
PersistentTasksService.PERSISTENT_TASK_ORIGIN,
ClientHelper.INDEX_LIFECYCLE_ORIGIN
)) {
assertSwitchBasedOnOriginAndExecute(origin, XPackUser.INSTANCE);
assertSwitchBasedOnOriginAndExecute(origin, XPackUser.INSTANCE, randomVersion());
}
}

public void testSwitchAndExecuteAsyncSearchUser() throws Exception {
String origin = ClientHelper.ASYNC_SEARCH_ORIGIN;
assertSwitchBasedOnOriginAndExecute(origin, AsyncSearchUser.INSTANCE);
assertSwitchBasedOnOriginAndExecute(origin, AsyncSearchUser.INSTANCE, randomVersion());
}

public void testSwitchWithTaskOrigin() throws Exception {
assertSwitchBasedOnOriginAndExecute(TASKS_ORIGIN, XPackUser.INSTANCE);
assertSwitchBasedOnOriginAndExecute(TASKS_ORIGIN, XPackUser.INSTANCE, randomVersion());
}

private void assertSwitchBasedOnOriginAndExecute(String origin, User user) throws Exception {
private void assertSwitchBasedOnOriginAndExecute(String origin, User user, Version version) throws Exception {
SecurityContext securityContext = new SecurityContext(Settings.EMPTY, threadContext);
final String headerName = randomAlphaOfLengthBetween(4, 16);
final String headerValue = randomAlphaOfLengthBetween(4, 16);
Expand All @@ -132,23 +134,30 @@ private void assertSwitchBasedOnOriginAndExecute(String origin, User user) throw
final ActionListener<Void> listener = ActionListener.wrap(v -> {
assertNull(threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
assertNull(threadContext.getHeader(headerName));
assertEquals(user, securityContext.getAuthentication().getUser());
final Authentication authentication = securityContext.getAuthentication();
assertEquals(user, authentication.getUser());
assertEquals(version, authentication.getVersion());
latch.countDown();
}, e -> fail(e.getMessage()));

final Consumer<ThreadContext.StoredContext> consumer = original -> {
assertNull(threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
assertNull(threadContext.getHeader(headerName));
assertEquals(user, securityContext.getAuthentication().getUser());
final Authentication authentication = securityContext.getAuthentication();
assertEquals(user, authentication.getUser());
assertEquals(version, authentication.getVersion());
latch.countDown();
listener.onResponse(null);
};

threadContext.putHeader(headerName, headerValue);
try (ThreadContext.StoredContext ignored = threadContext.stashWithOrigin(origin)) {
AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(threadContext, securityContext, consumer);
AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(threadContext, securityContext, version, consumer);
latch.await();
}
}

private Version randomVersion() {
return VersionUtils.randomCompatibleVersion(random(), Version.CURRENT);
}
}
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
Expand All @@ -34,8 +35,11 @@
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef;
import org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField;
import org.elasticsearch.xpack.core.security.user.AsyncSearchUser;
import org.elasticsearch.xpack.core.security.user.SystemUser;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.core.security.user.XPackSecurityUser;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService;
Expand All @@ -48,6 +52,12 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.TRANSFORM_ORIGIN;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -324,6 +334,63 @@ public <T extends TransportResponse> void sendRequest(
assertEquals(Version.CURRENT, authentication.getVersion());
}

public void testSetUserBasedOnActionOrigin() {
final Map<String, User> originToUserMap = Map.of(
SECURITY_ORIGIN,
XPackSecurityUser.INSTANCE,
TRANSFORM_ORIGIN,
XPackUser.INSTANCE,
ASYNC_SEARCH_ORIGIN,
AsyncSearchUser.INSTANCE
);

final String origin = randomFrom(originToUserMap.keySet());

threadContext.putTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME, origin);
SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor(
settings,
threadPool,
mock(AuthenticationService.class),
mock(AuthorizationService.class),
mock(SSLService.class),
securityContext,
new DestructiveOperations(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING))
)
);

final AtomicBoolean calledWrappedSender = new AtomicBoolean(false);
final AtomicReference<Authentication> authenticationRef = new AtomicReference<>();
final AsyncSender intercepted = new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(
Transport.Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler<T> handler
) {
if (calledWrappedSender.compareAndSet(false, true) == false) {
fail("sender called more than once!");
}
authenticationRef.set(securityContext.getAuthentication());
}
};
final AsyncSender sender = interceptor.interceptSender(intercepted);

Transport.Connection connection = mock(Transport.Connection.class);
final Version connectionVersion = VersionUtils.randomCompatibleVersion(random(), Version.CURRENT);
when(connection.getVersion()).thenReturn(connectionVersion);

sender.sendRequest(connection, "indices:foo[s]", null, null, null);
assertThat(calledWrappedSender.get(), is(true));
final Authentication authentication = authenticationRef.get();
assertThat(authentication, notNullValue());
assertThat(authentication.getUser(), equalTo(originToUserMap.get(origin)));
assertThat(authentication.getVersion(), equalTo(connectionVersion));
}

public void testContextRestoreResponseHandler() throws Exception {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);

Expand Down

0 comments on commit e7d1406

Please sign in to comment.