Skip to content

Commit

Permalink
Execute SAML authentication on the generic threadpool (#105232)
Browse files Browse the repository at this point in the history
This PR changes SAML transport actions to use `GENERIC` executor 
in order to avoid executing potentially slow and blocking IO/HTTP 
operations on the `transport_worker` threads.

Fixes #104962
  • Loading branch information
slobodanadamovic committed Feb 9, 2024
1 parent 6e3daf7 commit 0dc78a3
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 21 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/105232.yaml
@@ -0,0 +1,6 @@
pr: 105232
summary: Execute SAML authentication on the generic threadpool
area: Authentication
type: bug
issues:
- 104962
Expand Up @@ -7,10 +7,10 @@
package org.elasticsearch.xpack.security.action.saml;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
Expand All @@ -29,6 +29,7 @@
import org.elasticsearch.xpack.security.authc.saml.SamlToken;

import java.util.Map;
import java.util.concurrent.Executor;

/**
* Transport action responsible for taking saml content and turning it into a token.
Expand All @@ -39,6 +40,7 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio
private final AuthenticationService authenticationService;
private final TokenService tokenService;
private final SecurityContext securityContext;
private final Executor genericExecutor;

@Inject
public TransportSamlAuthenticateAction(
Expand All @@ -49,21 +51,29 @@ public TransportSamlAuthenticateAction(
TokenService tokenService,
SecurityContext securityContext
) {
// TODO replace SAME when removing workaround for https://github.com/elastic/elasticsearch/issues/97916
super(
SamlAuthenticateAction.NAME,
transportService,
actionFilters,
SamlAuthenticateRequest::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
threadPool.executor(ThreadPool.Names.SAME)
);
this.threadPool = threadPool;
this.authenticationService = authenticationService;
this.tokenService = tokenService;
this.securityContext = securityContext;
this.genericExecutor = threadPool.generic();
}

@Override
protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListener<SamlAuthenticateResponse> listener) {
// workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
genericExecutor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
}

private void doExecuteForked(Task task, SamlAuthenticateRequest request, ActionListener<SamlAuthenticateResponse> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
final SamlToken saml = new SamlToken(request.getSaml(), request.getValidRequestIds(), request.getRealm());
logger.trace("Attempting to authenticate SamlToken [{}]", saml);
final ThreadContext threadContext = threadPool.getThreadContext();
Expand Down
Expand Up @@ -10,12 +10,13 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -34,6 +35,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.function.Predicate;

import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.findSamlRealms;
Expand All @@ -48,6 +50,7 @@ public final class TransportSamlInvalidateSessionAction extends HandledTransport
private static final Logger LOGGER = LogManager.getLogger(TransportSamlInvalidateSessionAction.class);
private final TokenService tokenService;
private final Realms realms;
private final Executor genericExecutor;

@Inject
public TransportSamlInvalidateSessionAction(
Expand All @@ -56,19 +59,27 @@ public TransportSamlInvalidateSessionAction(
TokenService tokenService,
Realms realms
) {
// TODO replace SAME when removing workaround for https://github.com/elastic/elasticsearch/issues/97916
super(
SamlInvalidateSessionAction.NAME,
transportService,
actionFilters,
SamlInvalidateSessionRequest::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
transportService.getThreadPool().executor(ThreadPool.Names.SAME)
);
this.tokenService = tokenService;
this.realms = realms;
this.genericExecutor = transportService.getThreadPool().generic();
}

@Override
protected void doExecute(Task task, SamlInvalidateSessionRequest request, ActionListener<SamlInvalidateSessionResponse> listener) {
// workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
genericExecutor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
}

private void doExecuteForked(Task task, SamlInvalidateSessionRequest request, ActionListener<SamlInvalidateSessionResponse> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
List<SamlRealm> realms = findSamlRealms(this.realms, request.getRealmName(), request.getAssertionConsumerServiceURL());
if (realms.isEmpty()) {
listener.onFailure(SamlUtils.samlException("Cannot find any matching realm for [{}]", request));
Expand Down
Expand Up @@ -8,12 +8,13 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutAction;
import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutRequest;
Expand All @@ -31,6 +32,7 @@
import org.opensaml.saml.saml2.core.LogoutRequest;

import java.util.Map;
import java.util.concurrent.Executor;

/**
* Transport action responsible for generating a SAML {@code &lt;LogoutRequest&gt;} as a redirect binding URL.
Expand All @@ -39,6 +41,7 @@ public final class TransportSamlLogoutAction extends HandledTransportAction<Saml

private final Realms realms;
private final TokenService tokenService;
private final Executor genericExecutor;

@Inject
public TransportSamlLogoutAction(
Expand All @@ -47,13 +50,27 @@ public TransportSamlLogoutAction(
Realms realms,
TokenService tokenService
) {
super(SamlLogoutAction.NAME, transportService, actionFilters, SamlLogoutRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
// TODO replace SAME when removing workaround for https://github.com/elastic/elasticsearch/issues/97916
super(
SamlLogoutAction.NAME,
transportService,
actionFilters,
SamlLogoutRequest::new,
transportService.getThreadPool().executor(ThreadPool.Names.SAME)
);
this.realms = realms;
this.tokenService = tokenService;
this.genericExecutor = transportService.getThreadPool().generic();
}

@Override
protected void doExecute(Task task, SamlLogoutRequest request, ActionListener<SamlLogoutResponse> listener) {
// workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
genericExecutor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
}

private void doExecuteForked(Task task, SamlLogoutRequest request, ActionListener<SamlLogoutResponse> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
invalidateRefreshToken(request.getRefreshToken(), ActionListener.wrap(ignore -> {
try {
final String token = request.getToken();
Expand Down
Expand Up @@ -8,11 +8,12 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.saml.SamlPrepareAuthenticationAction;
import org.elasticsearch.xpack.core.security.action.saml.SamlPrepareAuthenticationRequest;
Expand All @@ -24,6 +25,7 @@
import org.opensaml.saml.saml2.core.AuthnRequest;

import java.util.List;
import java.util.concurrent.Executor;

import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.findSamlRealms;

Expand All @@ -35,17 +37,20 @@ public final class TransportSamlPrepareAuthenticationAction extends HandledTrans
SamlPrepareAuthenticationResponse> {

private final Realms realms;
private final Executor genericExecutor;

@Inject
public TransportSamlPrepareAuthenticationAction(TransportService transportService, ActionFilters actionFilters, Realms realms) {
// TODO replace SAME when removing workaround for https://github.com/elastic/elasticsearch/issues/97916
super(
SamlPrepareAuthenticationAction.NAME,
transportService,
actionFilters,
SamlPrepareAuthenticationRequest::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
transportService.getThreadPool().executor(ThreadPool.Names.SAME)
);
this.realms = realms;
this.genericExecutor = transportService.getThreadPool().generic();
}

@Override
Expand All @@ -54,6 +59,16 @@ protected void doExecute(
SamlPrepareAuthenticationRequest request,
ActionListener<SamlPrepareAuthenticationResponse> listener
) {
// workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
genericExecutor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
}

private void doExecuteForked(
Task task,
SamlPrepareAuthenticationRequest request,
ActionListener<SamlPrepareAuthenticationResponse> listener
) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
List<SamlRealm> realms = findSamlRealms(this.realms, request.getRealmName(), request.getAssertionConsumerServiceURL());
if (realms.isEmpty()) {
listener.onFailure(SamlUtils.samlException("Cannot find any matching realm for [{}]", request));
Expand Down
Expand Up @@ -110,6 +110,7 @@

import static org.elasticsearch.common.Strings.collectionToCommaDelimitedString;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.transport.Transports.assertNotTransportThread;
import static org.elasticsearch.xpack.core.security.authc.saml.SamlRealmSettings.CLOCK_SKEW;
import static org.elasticsearch.xpack.core.security.authc.saml.SamlRealmSettings.DN_ATTRIBUTE;
import static org.elasticsearch.xpack.core.security.authc.saml.SamlRealmSettings.ENCRYPTION_KEY_ALIAS;
Expand Down Expand Up @@ -746,6 +747,7 @@ private static final class PrivilegedHTTPMetadataResolver extends HTTPMetadataRe

@Override
protected byte[] fetchMetadata() throws ResolverException {
assert assertNotTransportThread("fetching SAML metadata from a URL");
try {
return AccessController.doPrivileged(
(PrivilegedExceptionAction<byte[]>) () -> PrivilegedHTTPMetadataResolver.super.fetchMetadata()
Expand All @@ -757,6 +759,20 @@ protected byte[] fetchMetadata() throws ResolverException {

}

@SuppressForbidden(reason = "uses java.io.File")
private static final class SamlFilesystemMetadataResolver extends FilesystemMetadataResolver {

SamlFilesystemMetadataResolver(final java.io.File metadata) throws ResolverException {
super(metadata);
}

@Override
protected byte[] fetchMetadata() throws ResolverException {
assert assertNotTransportThread("fetching SAML metadata from a file");
return super.fetchMetadata();
}
}

@SuppressForbidden(reason = "uses toFile")
private static Tuple<AbstractReloadingMetadataResolver, Supplier<EntityDescriptor>> parseFileSystemMetadata(
Logger logger,
Expand All @@ -767,7 +783,7 @@ private static Tuple<AbstractReloadingMetadataResolver, Supplier<EntityDescripto

final String entityId = require(config, IDP_ENTITY_ID);
final Path path = config.env().configFile().resolve(metadataPath);
final FilesystemMetadataResolver resolver = new FilesystemMetadataResolver(path.toFile());
final FilesystemMetadataResolver resolver = new SamlFilesystemMetadataResolver(path.toFile());

for (var httpSetting : List.of(IDP_METADATA_HTTP_REFRESH, IDP_METADATA_HTTP_MIN_REFRESH, IDP_METADATA_HTTP_FAIL_ON_ERROR)) {
if (config.hasSetting(httpSetting)) {
Expand Down
Expand Up @@ -56,6 +56,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -130,6 +131,7 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
private TransportSamlInvalidateSessionAction action;
private SamlLogoutRequestHandler.Result logoutRequest;
private Function<SearchRequest, SearchHit[]> searchFunction = ignore -> SearchHits.EMPTY;
private ThreadPool threadPool;

@Before
public void setup() throws Exception {
Expand All @@ -147,9 +149,8 @@ public void setup() throws Exception {
.put(getFullSettingKey(realmId, RealmSettings.ORDER_SETTING), 0)
.build();

final ThreadContext threadContext = new ThreadContext(settings);
final ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
this.threadPool = new TestThreadPool("saml test thread pool", settings);
final ThreadContext threadContext = threadPool.getThreadContext();
AuthenticationTestHelper.builder()
.user(new User("kibana"))
.realmRef(new RealmRef("realm", "type", "node"))
Expand Down Expand Up @@ -338,6 +339,7 @@ private SearchHit tokenHit(int idx, BytesReference source) {
@After
public void cleanup() {
samlRealm.close();
threadPool.shutdown();
}

public void testInvalidateCorrectTokensFromLogoutRequest() throws Exception {
Expand Down
Expand Up @@ -37,6 +37,7 @@
import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -99,6 +100,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
private List<BulkRequest> bulkRequests;
private TransportSamlLogoutAction action;
private Client client;
private ThreadPool threadPool;

@SuppressWarnings("unchecked")
@Before
Expand All @@ -116,9 +118,8 @@ public void setup() throws Exception {
.put(getFullSettingKey(realmIdentifier, RealmSettings.ORDER_SETTING), 0)
.build();

final ThreadContext threadContext = new ThreadContext(settings);
final ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
this.threadPool = new TestThreadPool("saml logout test thread pool", settings);
final ThreadContext threadContext = this.threadPool.getThreadContext();
AuthenticationTestHelper.builder()
.user(new User("kibana"))
.realmRef(new Authentication.RealmRef("realm", "type", "node"))
Expand Down Expand Up @@ -241,6 +242,7 @@ public void setup() throws Exception {
@After
public void cleanup() {
samlRealm.close();
threadPool.shutdown();
}

public void testLogoutInvalidatesToken() throws Exception {
Expand Down

0 comments on commit 0dc78a3

Please sign in to comment.